Skip to content

Commit

Permalink
Initial Pydantic2 Support (#37)
Browse files Browse the repository at this point in the history
* Pydantic2 test runner

* Add pydantic2 support

* Fix poetry requirements syntax
  • Loading branch information
piercefreeman authored Aug 21, 2023
1 parent a1e68a3 commit 11efb9a
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 90 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ jobs:
strategy:
matrix:
python: ["3.11"]
pydantic: ["1.10.12", "2.1.1"]

steps:
- uses: actions/checkout@v3
Expand All @@ -28,6 +29,10 @@ jobs:
export PATH="/Users/runner/.local/bin:$PATH"
poetry install
- name: Install pydantic version
run: |
poetry add pydantic==${{ matrix.pydantic }}
- name: Run tests
run: |
poetry run pytest
Expand Down
36 changes: 36 additions & 0 deletions gpt_json/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""
Pydantic V1 and V2 compatibility layer. Pydantic V2 has a better API for the type inspection
that we do in GPT-JSON, but we can easily bridge some of the concepts in V1.
"""

from typing import Type

from pydantic import BaseModel


def get_field_description(field):
    """Return the human-readable description attached to a pydantic field.

    Accepts either a Pydantic V2 ``FieldInfo`` (exposes ``description``
    directly) or a Pydantic V1 ``ModelField`` (nests it under ``field_info``).

    Raises:
        ValueError: if the object matches neither known field layout.
    """
    # Pydantic V2: the FieldInfo object carries the description itself.
    if hasattr(field, "description"):
        return field.description
    # Pydantic V1: ModelField wraps a FieldInfo under ``field_info``.
    if hasattr(field, "field_info"):
        return field.field_info.description
    raise ValueError(f"Unknown pydantic field class structure: {field}")


def get_model_fields(model: Type[BaseModel]):
    """Return the raw field mapping for *model* under Pydantic V1 or V2.

    V2 models expose ``model_fields`` (name -> FieldInfo); V1 models expose
    ``__fields__`` (name -> ModelField). ``model_fields`` is checked first,
    so it wins when both attributes exist.

    Raises:
        ValueError: if *model* exposes neither attribute.
    """
    for attribute in ("model_fields", "__fields__"):
        if hasattr(model, attribute):
            return getattr(model, attribute)
    raise ValueError(f"Unknown pydantic field class structure: {model}")


def get_model_field_infos(model: Type[BaseModel]):
    """Return a name -> FieldInfo mapping for *model*.

    On Pydantic V2 this is simply ``model_fields``. On V1, each ``ModelField``
    stored in ``__fields__`` is unwrapped to its inner ``field_info`` so the
    caller sees a uniform FieldInfo-style mapping on either major version.

    Raises:
        ValueError: if *model* matches neither layout.
    """
    if hasattr(model, "model_fields"):
        # V2 already stores FieldInfo objects directly.
        return model.model_fields
    if hasattr(model, "__fields__"):
        # V1 nests the FieldInfo inside each ModelField wrapper.
        return {
            name: wrapper.field_info  # type: ignore
            for name, wrapper in model.__fields__.items()
        }
    raise ValueError(f"Unknown pydantic field class structure: {model}")
11 changes: 8 additions & 3 deletions gpt_json/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from pydantic import BaseModel, Field, create_model

from gpt_json.common import get_model_field_infos


def get_typevar_mapping(t: Any) -> dict[TypeVar, Any]:
origin = get_origin(t)
Expand Down Expand Up @@ -50,15 +52,18 @@ def resolve_generic_model(t: Any) -> Type[BaseModel]:
# Create a dict with all the fields from the original model
fields = {}
for name, type_ in base_model.__annotations__.items():
original_field = base_model.__fields__.get(name)
original_field = get_model_field_infos(base_model).get(name)
if original_field:
fields[name] = (type_, original_field.field_info)
fields[name] = (type_, original_field)
else:
fields[name] = (type_, Field())

# Replace the fields that have a TypeVar with their resolved types
for name, (type_, field) in fields.items():
fields[name] = (resolve_type(type_, typevar_mapping), field)
resolved_annotation = resolve_type(type_, typevar_mapping)
fields[name] = (resolved_annotation, field)
if hasattr(field, "annotation"):
field.annotation = resolved_annotation

# Use the Pydantic's create_model function to create a new model with the resolved fields
return create_model(base_model.__name__, **fields) # type: ignore
16 changes: 9 additions & 7 deletions gpt_json/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from pydantic import BaseModel

from gpt_json.common import get_field_description, get_model_fields


def generate_schema_prompt(schema: Type[BaseModel]) -> str:
"""
Expand All @@ -13,8 +15,8 @@ def generate_schema_prompt(schema: Type[BaseModel]) -> str:

def generate_payload(model: Type[BaseModel]):
payload = []
for key, value in model.__fields__.items():
field_annotation = model.__annotations__[key]
for key, value in get_model_fields(model).items():
field_annotation = value.annotation
annotation_origin = get_origin(field_annotation)
annotation_arguments = get_args(field_annotation)

Expand All @@ -29,12 +31,12 @@ def generate_payload(model: Type[BaseModel]):
payload.append(
f'"{key}": {" | ".join([arg.__name__.lower() for arg in annotation_arguments])}'
)
elif issubclass(value.type_, BaseModel):
payload.append(f'"{key}": {generate_payload(value.type_)}')
elif issubclass(field_annotation, BaseModel):
payload.append(f'"{key}": {generate_payload(field_annotation)}')
else:
payload.append(f'"{key}": {value.type_.__name__.lower()}')
if value.field_info.description:
payload[-1] += f" // {value.field_info.description}"
payload.append(f'"{key}": {field_annotation.__name__.lower()}')
if get_field_description(value):
payload[-1] += f" // {get_field_description(value)}"
# All brackets are double defined so they will passthrough a call to `.format()` where we
# pass custom variables
return "{{\n" + ",\n".join(payload) + "\n}}"
Expand Down
9 changes: 7 additions & 2 deletions gpt_json/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from pydantic import BaseModel

from gpt_json.common import get_model_fields
from gpt_json.transformations import JsonFixEnum, fix_truncated_json

SchemaType = TypeVar("SchemaType", bound=BaseModel)
Expand Down Expand Up @@ -40,7 +41,9 @@ def _create_schema_from_partial(
TODO: this is hacky. ideally we want pydantic to implement Partial[SchemaType]
https://github.com/pydantic/pydantic/issues/1673
my fix is to create the schema object with all string values for now"""
cleaned_obj_data = {field: "" for field, typ in schema_model.__fields__.items()}
cleaned_obj_data = {
field: "" for field, typ in get_model_fields(schema_model).items()
}
cleaned_obj_data.update({k: v for k, v in partial.items() if v is not None})
return schema_model(**cleaned_obj_data)

Expand All @@ -65,7 +68,9 @@ def prepare_streaming_object(
raw_recent_key = (
list(current_partial_raw.keys())[-1] if current_partial_raw else None
)
updated_key = raw_recent_key if raw_recent_key in schema_model.__fields__ else None
updated_key = (
raw_recent_key if raw_recent_key in get_model_fields(schema_model) else None
)

event = proposed_event
if proposed_event == StreamEventEnum.KEY_UPDATED and updated_key is None:
Expand Down
13 changes: 10 additions & 3 deletions gpt_json/tests/test_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from openai.error import Timeout as OpenAITimeout
from pydantic import BaseModel, Field

from gpt_json.generics import resolve_generic_model
from gpt_json.gpt import GPTJSON, ListResponse
from gpt_json.models import FixTransforms, GPTMessage, GPTMessageRole, GPTModelVersion
from gpt_json.tests.shared import MySchema, MySubSchema
Expand Down Expand Up @@ -89,7 +90,9 @@ def test_cast_message_to_gpt_format(role_type: GPTMessageRole, expected: str):
}
Your response is above.
""",
ListResponse(
# Slight hack to work around ListResponse being a generic base that Pydantic can't
# otherwise validate / output to a dictionary
resolve_generic_model(ListResponse[MySchema])(
items=[
MySchema(
text="Test",
Expand Down Expand Up @@ -130,7 +133,10 @@ def test_cast_message_to_gpt_format(role_type: GPTMessageRole, expected: str):
],
)
async def test_acreate(
schema_typehint, response_raw, parsed, expected_transformations: FixTransforms
schema_typehint,
response_raw: str,
parsed: BaseModel,
expected_transformations: FixTransforms,
):
model_version = GPTModelVersion.GPT_3_5
messages = [
Expand Down Expand Up @@ -184,7 +190,8 @@ async def test_acreate(
stream=False,
)

assert response == parsed
assert response
assert response.dict() == parsed.dict()
assert transformations == expected_transformations


Expand Down
6 changes: 3 additions & 3 deletions gpt_json/types_oai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@


class ChatCompletionDelta(BaseModel):
    """One incremental delta from a streamed chat-completion response.

    Explicit ``= None`` defaults are required for Pydantic V2 support:
    V2 no longer treats an ``X | None`` annotation as implicitly optional,
    and a streamed chunk may omit either field.
    """

    content: str | None = None
    role: str | None = None


class ChatCompletionChunkChoice(BaseModel):
    """A single choice entry inside a streamed chat-completion chunk."""

    delta: ChatCompletionDelta
    # NOTE(review): presumably set only on the terminating chunk — the
    # explicit None default mirrors the optional fields above for V2.
    finish_reason: str | None = None
    index: int


Expand Down
Loading

0 comments on commit 11efb9a

Please sign in to comment.