diff --git a/gpt_json/models.py b/gpt_json/models.py index f1b19c1..b7c361b 100644 --- a/gpt_json/models.py +++ b/gpt_json/models.py @@ -48,6 +48,11 @@ class JsonFixEnum(EnumSuper): UNCLOSED_VALUE = "unclosed_value" MISSING_VALUE = "missing_value" + # Drop any additional JSON tags that occur after the main payload + # has been processed; this most often happens when the models spit back + # double close brackets like ]] or }} + DROP_TRAILING_JSON = "drop_trailing_json" + @dataclass class FixTransforms: diff --git a/gpt_json/tests/test_parsers.py b/gpt_json/tests/test_parsers.py index bd05abe..e1d10d0 100644 --- a/gpt_json/tests/test_parsers.py +++ b/gpt_json/tests/test_parsers.py @@ -32,6 +32,11 @@ '{"text": "Test", "numerical": 123, "reason": true, "sub_element": { "name": "Test" }, "items": ["Item 1", "Item 2', ResponseType.DICTIONARY, ), + ( + 'This message has an additional closing bracket: {"text": "Test"}}', + '{"text": "Test"}}', + ResponseType.DICTIONARY, + ), ], ) def test_find_json_response(input_string, expected, extract_type): diff --git a/gpt_json/tests/test_transformations.py b/gpt_json/tests/test_transformations.py index a67e1ed..bcae43e 100644 --- a/gpt_json/tests/test_transformations.py +++ b/gpt_json/tests/test_transformations.py @@ -114,6 +114,13 @@ def test_is_truncated(input_string: str, expected: bool): }, JsonFixEnum.UNCLOSED_VALUE, ), + ( + '{"text": "Test"}}', + { + "text": "Test", + }, + JsonFixEnum.DROP_TRAILING_JSON, + ), ], ) def test_fix_truncated_json(broken_string, expected, expected_fix_reason): diff --git a/gpt_json/transformations.py b/gpt_json/transformations.py index 3c509e1..65fcbb2 100644 --- a/gpt_json/transformations.py +++ b/gpt_json/transformations.py @@ -4,6 +4,7 @@ def build_stack(json_str): stack = [] fixed_str = "" + last_i = -1 open_quotes = False # a flag indicating whether we've seen a comma or colon most recently @@ -18,6 +19,9 @@ def build_stack(json_str): last_seen_comma_or_colon = None # closing a nested elif char in "}]": + if len(stack) == 0: + print("will break") + break stack.pop() last_seen_comma_or_colon = None if char in ",:": @@ -27,8 +31,10 @@ def build_stack(json_str): open_quotes = not open_quotes fixed_str += char + last_i = i + 1 - return (stack, fixed_str, open_quotes, last_seen_comma_or_colon) + unparsed_str = json_str[last_i:] + return (stack, fixed_str, open_quotes, last_seen_comma_or_colon, unparsed_str) def _is_missing_dict_value(stack, fixed_str, open_quotes, last_seen_comma_or_colon): @@ -59,7 +65,7 @@ def is_truncated(json_str): brackets is greater than the number of closing brackets. """ - stack, _, _, _ = build_stack(json_str) + stack, _, _, _, _ = build_stack(json_str) return len(stack) > 0 @@ -70,13 +76,18 @@ def fix_truncated_json(json_str) -> tuple[str, JsonFixEnum | None]: Returns a tuple of (fixed_json_string, fix_type) """ - stack, fixed_str, open_quotes, last_seen_colon_or_comma = build_stack(json_str) + stack, fixed_str, open_quotes, last_seen_colon_or_comma, unparsed_str = build_stack( + json_str + ) missing_value = _is_missing_dict_value( stack, fixed_str, open_quotes, last_seen_colon_or_comma ) is_truncated = len(stack) > 0 if not is_truncated: - return json_str, None + if not unparsed_str.strip(): + return json_str, None + else: + return fixed_str, JsonFixEnum.DROP_TRAILING_JSON fixed_str = fixed_str.strip() @@ -97,6 +108,7 @@ def fix_truncated_json(json_str) -> tuple[str, JsonFixEnum | None]: close_stack = ["]" if char == "[" else "}" for char in stack] fixed_str += "".join(close_stack[::-1]) + print("FIXED", fixed_str) # if the fixed string is valid JSON, return it fix = JsonFixEnum.UNCLOSED_OBJECT if open_quotes: