From 2f7c4f434a2bde84b81c773fc64477d3a3244c6a Mon Sep 17 00:00:00 2001 From: Pierce Freeman Date: Thu, 24 Aug 2023 15:17:12 -0700 Subject: [PATCH] Fix typehinting of functions --- gpt_json/gpt.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/gpt_json/gpt.py b/gpt_json/gpt.py index dd77474..47f5072 100644 --- a/gpt_json/gpt.py +++ b/gpt_json/gpt.py @@ -89,7 +89,7 @@ class RunResponse(Generic[SchemaType], BaseModel): raw_response: GPTMessage | None response: SchemaType | None fix_transforms: FixTransforms | None - function_call: Callable[..., BaseModel] | None + function_call: Callable[[BaseModel], Any] | None function_arg: BaseModel | None @@ -108,7 +108,7 @@ def __init__( model: GPTModelVersion | str = GPTModelVersion.GPT_4, auto_trim: bool = False, auto_trim_response_overhead: int = 0, - functions: list[Callable[..., BaseModel]] | None = None, + functions: list[Callable[[Any], Any]] | None = None, # For messages that are relatively deterministic temperature=0.2, timeout: int | None = None, @@ -134,7 +134,13 @@ def __init__( self.openai_max_retries = openai_max_retries self.openai_arguments = kwargs self.schema_model = self._cls_schema_model - self.functions = {function_to_name(fn): fn for fn in (functions or [])} + self.functions = { + function_to_name(cast_fn): cast_fn + for fn in (functions or []) + # Use an explicit cast; Callable can't be typehinted with BaseModel directly + # because [BaseModel] is considered invariant with subclasses + for cast_fn in [cast(Callable[[BaseModel], Any], fn)] + } self.__class__._cls_schema_model = None if not self.schema_model: @@ -233,7 +239,7 @@ async def run( function_arg=None, ) - function_call: Callable[..., BaseModel] | None = None + function_call: Callable[[BaseModel], Any] | None = None function_parsed: BaseModel | None = None if response_message.get("function_call"):