Skip to content

Commit

Permalink
Fix missing role in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
piercefreeman committed Aug 24, 2023
1 parent 0694734 commit 5fce4fc
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 4 deletions.
9 changes: 9 additions & 0 deletions gpt_json/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,12 @@ def parse_obj_model(model: Type[T], obj: dict[str, Any]) -> T:
return model.model_validate(obj)
else:
raise ValueError(f"Unknown pydantic field class structure: {model}")


def obj_to_json(model: T, **kwargs) -> str:
if hasattr(model, "json"):
return model.json(**kwargs)
elif hasattr(model, "model_dump_json"):
return model.model_dump_json(**kwargs)
else:
raise ValueError(f"Unknown pydantic field class structure: {model}")
7 changes: 5 additions & 2 deletions gpt_json/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pydantic import BaseModel, Field, ValidationError
from tiktoken import encoding_for_model

from gpt_json.common import parse_obj_model
from gpt_json.common import obj_to_json, parse_obj_model
from gpt_json.exceptions import InvalidFunctionParameters, InvalidFunctionResponse
from gpt_json.fn_calling import (
function_to_name,
Expand Down Expand Up @@ -559,7 +559,10 @@ def fill_message_template(
return new_message

def message_to_dict(self, message: GPTMessage):
return {"role": message.role.value, "content": message.content}
# return {"role": message.role.value, "content": message.content}
obj = json_loads(obj_to_json(message, exclude_unset=True))
obj.pop("allow_templating", None)
return obj

def trim_messages(self, messages: list[GPTMessage], n: int):
"""
Expand Down
8 changes: 6 additions & 2 deletions gpt_json/tests/test_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,9 @@ async def test_extracted_json_is_None():
with patch.object(
gpt,
"submit_request",
return_value={"choices": [{"message": {"content": "some content"}}]},
return_value={
"choices": [{"message": {"content": "some content", "role": "assistant"}}]
},
), patch.object(
gpt, "extract_json", return_value=(None, FixTransforms(None, False))
):
Expand Down Expand Up @@ -392,7 +394,9 @@ async def test_unable_to_find_valid_json_payload():
with patch.object(
gpt,
"submit_request",
return_value={"choices": [{"message": {"content": "some content"}}]},
return_value={
"choices": [{"message": {"content": "some content", "role": "assistant"}}]
},
), patch.object(
gpt, "extract_json", return_value=(None, FixTransforms(None, False))
):
Expand Down

0 comments on commit 5fce4fc

Please sign in to comment.