Skip to content

Commit

Permalink
Allow literal values in pydantic model
Browse files Browse the repository at this point in the history
  • Loading branch information
piercefreeman committed Nov 4, 2023
1 parent d75ca4e commit 7dba4b6
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 2 deletions.
6 changes: 5 additions & 1 deletion gpt_json/prompts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from types import UnionType
from typing import List, Type, get_args, get_origin
from typing import List, Literal, Type, get_args, get_origin

from pydantic import BaseModel

Expand All @@ -17,6 +17,7 @@ def generate_payload(model: Type[BaseModel]):
field_annotation = value.annotation
annotation_origin = get_origin(field_annotation)
annotation_arguments = get_args(field_annotation)
print(annotation_origin)

if field_annotation is None:
continue
Expand All @@ -31,6 +32,9 @@ def generate_payload(model: Type[BaseModel]):
payload.append(
f'"{key}": {" | ".join([arg.__name__.lower() for arg in annotation_arguments])}'
)
elif annotation_origin == Literal:
allowed_values = [f'"{arg}"' for arg in annotation_arguments]
payload.append(f'"{key}": {" | ".join(allowed_values)}')
elif issubclass(field_annotation, BaseModel):
payload.append(f'"{key}": {generate_payload(field_annotation)}')
else:
Expand Down
7 changes: 7 additions & 0 deletions gpt_json/tests/shared.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from enum import Enum
from typing import Literal

from pydantic import BaseModel, Field

Expand All @@ -15,6 +16,12 @@ class MySchema(BaseModel):
reason: bool = Field(description="Explanation")


class LiteralSchema(BaseModel):
work_format: Literal["REMOTE", "OFFICE", "ANY"] = Field(
default="ANY", description="One of the given values"
)


class UnitType(Enum):
CELSIUS = "celsius"
FAHRENHEIT = "fahrenheit"
Expand Down
10 changes: 9 additions & 1 deletion gpt_json/tests/test_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from gpt_json.generics import resolve_generic_model
from gpt_json.gpt import ListResponse
from gpt_json.prompts import generate_schema_prompt
from gpt_json.tests.shared import MySchema
from gpt_json.tests.shared import LiteralSchema, MySchema


def strip_whitespace(input_string: str):
Expand Down Expand Up @@ -45,6 +45,14 @@ def strip_whitespace(input_string: str):
}}
""",
),
(
LiteralSchema,
"""
{{
"work_format": "REMOTE" | "OFFICE" | "ANY" // One of the given values
}}
""",
),
],
)
def test_generate_schema_prompt(schema_definition, expected: str):
Expand Down

0 comments on commit 7dba4b6

Please sign in to comment.