Skip to content

Commit

Permalink
Require pydantic2
Browse files — browse the repository at this point in the history
  • Loading branch information
piercefreeman committed Aug 24, 2023
1 parent 5fce4fc commit 469bd79
Show file tree
Hide file tree
Showing 10 changed files with 13 additions and 100 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
strategy:
matrix:
python: ["3.11"]
pydantic: ["1.10.12", "2.1.1"]
pydantic: ["2.1.1"]

steps:
- uses: actions/checkout@v3
Expand Down
62 changes: 0 additions & 62 deletions gpt_json/common.py

This file was deleted.

7 changes: 0 additions & 7 deletions gpt_json/fn_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

from pydantic import BaseModel

from gpt_json.common import get_pydantic_version


def parse_function(fn: Callable) -> Dict[str, Any]:
"""
Expand All @@ -17,11 +15,6 @@ def parse_function(fn: Callable) -> Dict[str, Any]:
API Reference: https://platform.openai.com/docs/api-reference/chat/create
"""
if get_pydantic_version() < 2:
raise ValueError(
f"Function calling is only supported with Pydantic > 2, found {get_pydantic_version()}"
)

docstring = getdoc(fn) or ""
lines = docstring.strip().split("\n")
description = lines[0] if lines else None
Expand Down
4 changes: 1 addition & 3 deletions gpt_json/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from pydantic import BaseModel, Field, create_model

from gpt_json.common import get_model_field_infos


def get_typevar_mapping(t: Any) -> dict[TypeVar, Any]:
origin = get_origin(t)
Expand Down Expand Up @@ -52,7 +50,7 @@ def resolve_generic_model(t: Any) -> Type[BaseModel]:
# Create a dict with all the fields from the original model
fields = {}
for name, type_ in base_model.__annotations__.items():
original_field = get_model_field_infos(base_model).get(name)
original_field = base_model.model_fields.get(name)
if original_field:
fields[name] = (type_, original_field)
else:
Expand Down
6 changes: 2 additions & 4 deletions gpt_json/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from pydantic import BaseModel, Field, ValidationError
from tiktoken import encoding_for_model

from gpt_json.common import obj_to_json, parse_obj_model
from gpt_json.exceptions import InvalidFunctionParameters, InvalidFunctionResponse
from gpt_json.fn_calling import (
function_to_name,
Expand Down Expand Up @@ -252,7 +251,7 @@ async def run(
except (ValueError, ValidationError):
raise InvalidFunctionParameters(function_name, function_args_string)

raw_response = parse_obj_model(GPTMessage, response_message)
raw_response = GPTMessage.model_validate(response_message)
raw_response.allow_templating = False

extracted_json, fixed_payload = self.extract_json(
Expand Down Expand Up @@ -559,8 +558,7 @@ def fill_message_template(
return new_message

def message_to_dict(self, message: GPTMessage):
# return {"role": message.role.value, "content": message.content}
obj = json_loads(obj_to_json(message, exclude_unset=True))
obj = json_loads(message.model_dump_json(exclude_unset=True))
obj.pop("allow_templating", None)
return obj

Expand Down
12 changes: 6 additions & 6 deletions gpt_json/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

from pydantic import BaseModel

from gpt_json.common import get_field_description, get_model_fields


def generate_schema_prompt(schema: Type[BaseModel]) -> str:
"""
Expand All @@ -15,12 +13,14 @@ def generate_schema_prompt(schema: Type[BaseModel]) -> str:

def generate_payload(model: Type[BaseModel]):
payload = []
for key, value in get_model_fields(model).items():
for key, value in model.model_fields.items():
field_annotation = value.annotation
annotation_origin = get_origin(field_annotation)
annotation_arguments = get_args(field_annotation)

if annotation_origin in {list, List}:
if field_annotation is None:
continue
elif annotation_origin in {list, List}:
if issubclass(annotation_arguments[0], BaseModel):
payload.append(
f'"{key}": {generate_payload(annotation_arguments[0])}[]'
Expand All @@ -35,8 +35,8 @@ def generate_payload(model: Type[BaseModel]):
payload.append(f'"{key}": {generate_payload(field_annotation)}')
else:
payload.append(f'"{key}": {field_annotation.__name__.lower()}')
if get_field_description(value):
payload[-1] += f" // {get_field_description(value)}"
if value.description:
payload[-1] += f" // {value.description}"
# All brackets are double defined so they will passthrough a call to `.format()` where we
# pass custom variables
return "{{\n" + ",\n".join(payload) + "\n}}"
Expand Down
7 changes: 2 additions & 5 deletions gpt_json/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from pydantic import BaseModel

from gpt_json.common import get_model_fields
from gpt_json.transformations import JsonFixEnum, fix_truncated_json

SchemaType = TypeVar("SchemaType", bound=BaseModel)
Expand Down Expand Up @@ -41,9 +40,7 @@ def _create_schema_from_partial(
TODO: this is hacky. ideally we want pydantic to implement Partial[SchemaType]
https://github.com/pydantic/pydantic/issues/1673
my fix is to create the schema object with all string values for now"""
cleaned_obj_data = {
field: "" for field, typ in get_model_fields(schema_model).items()
}
cleaned_obj_data = {field: "" for field, typ in schema_model.model_fields.items()}
cleaned_obj_data.update({k: v for k, v in partial.items() if v is not None})
return schema_model(**cleaned_obj_data)

Expand All @@ -69,7 +66,7 @@ def prepare_streaming_object(
list(current_partial_raw.keys())[-1] if current_partial_raw else None
)
updated_key = (
raw_recent_key if raw_recent_key in get_model_fields(schema_model) else None
raw_recent_key if raw_recent_key in schema_model.model_fields else None
)

event = proposed_event
Expand Down
7 changes: 0 additions & 7 deletions gpt_json/tests/test_fn_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import pytest

from gpt_json.common import get_pydantic_version
from gpt_json.fn_calling import get_base_type, parse_function
from gpt_json.tests.shared import (
UnitType,
Expand All @@ -12,9 +11,6 @@
)


@pytest.mark.skipif(
get_pydantic_version() < 2, reason="Pydantic 2+ required for function calls"
)
@pytest.mark.parametrize(
"incorrect_fn",
[
Expand All @@ -33,9 +29,6 @@ def test_get_base_type():
assert get_base_type(Union[UnitType, None]) == UnitType


@pytest.mark.skipif(
get_pydantic_version() < 2, reason="Pydantic 2+ required for function calls"
)
def test_parse_function():
"""
Assert the formatted schema conforms to the expected JSON-Schema / GPT format.
Expand Down
4 changes: 0 additions & 4 deletions gpt_json/tests/test_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from openai.error import Timeout as OpenAITimeout
from pydantic import BaseModel, Field

from gpt_json.common import get_pydantic_version
from gpt_json.fn_calling import parse_function
from gpt_json.generics import resolve_generic_model
from gpt_json.gpt import GPTJSON, ListResponse
Expand Down Expand Up @@ -204,9 +203,6 @@ async def test_acreate(
assert response.fix_transforms == expected_transformations


@pytest.mark.skipif(
get_pydantic_version() < 2, reason="Pydantic 2+ required for function calls"
)
@pytest.mark.asyncio
async def test_acreate_with_function_calls():
model_version = GPTModelVersion.GPT_3_5
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ packages = [{include = "gpt_json"}]
python = "^3.11"
tiktoken = "^0.3.3"
openai = "^0.27.6"
pydantic = ">1.10.7, <3.0.0"
pydantic = ">2.0.0, <3.0.0"
backoff = "^2.2.1"


Expand Down

0 comments on commit 469bd79

Please sign in to comment.