Skip to content

Commit

Permalink
Fix typehinting of functions
Browse files Browse the repository at this point in the history
  • Loading branch information
piercefreeman committed Aug 24, 2023
1 parent c393b34 commit 2f7c4f4
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions gpt_json/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class RunResponse(Generic[SchemaType], BaseModel):
raw_response: GPTMessage | None
response: SchemaType | None
fix_transforms: FixTransforms | None
function_call: Callable[..., BaseModel] | None
function_call: Callable[[BaseModel], Any] | None
function_arg: BaseModel | None


Expand All @@ -108,7 +108,7 @@ def __init__(
model: GPTModelVersion | str = GPTModelVersion.GPT_4,
auto_trim: bool = False,
auto_trim_response_overhead: int = 0,
functions: list[Callable[..., BaseModel]] | None = None,
functions: list[Callable[[Any], Any]] | None = None,
# For messages that are relatively deterministic
temperature=0.2,
timeout: int | None = None,
Expand All @@ -134,7 +134,13 @@ def __init__(
self.openai_max_retries = openai_max_retries
self.openai_arguments = kwargs
self.schema_model = self._cls_schema_model
self.functions = {function_to_name(fn): fn for fn in (functions or [])}
self.functions = {
function_to_name(cast_fn): cast_fn
for fn in (functions or [])
# Use an explicit cast; Callable can't be typehinted with BaseModel directly
# because [BaseModel] is considered invariant with subclasses
for cast_fn in [cast(Callable[[BaseModel], Any], fn)]
}
self.__class__._cls_schema_model = None

if not self.schema_model:
Expand Down Expand Up @@ -233,7 +239,7 @@ async def run(
function_arg=None,
)

function_call: Callable[..., BaseModel] | None = None
function_call: Callable[[BaseModel], Any] | None = None
function_parsed: BaseModel | None = None

if response_message.get("function_call"):
Expand Down

0 comments on commit 2f7c4f4

Please sign in to comment.