From c99c2ce1cdeeae63e7c6f4151cd885ab6d4a432b Mon Sep 17 00:00:00 2001 From: Pierce Freeman Date: Thu, 24 Aug 2023 15:17:12 -0700 Subject: [PATCH] Fix typehinting of functions --- gpt_json/gpt.py | 14 ++++++++++---- gpt_json/tests/shared.py | 3 ++- gpt_json/tests/test_fn_calling.py | 2 +- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/gpt_json/gpt.py b/gpt_json/gpt.py index dd77474..47f5072 100644 --- a/gpt_json/gpt.py +++ b/gpt_json/gpt.py @@ -89,7 +89,7 @@ class RunResponse(Generic[SchemaType], BaseModel): raw_response: GPTMessage | None response: SchemaType | None fix_transforms: FixTransforms | None - function_call: Callable[..., BaseModel] | None + function_call: Callable[[BaseModel], Any] | None function_arg: BaseModel | None @@ -108,7 +108,7 @@ def __init__( model: GPTModelVersion | str = GPTModelVersion.GPT_4, auto_trim: bool = False, auto_trim_response_overhead: int = 0, - functions: list[Callable[..., BaseModel]] | None = None, + functions: list[Callable[[Any], Any]] | None = None, # For messages that are relatively deterministic temperature=0.2, timeout: int | None = None, @@ -134,7 +134,13 @@ def __init__( self.openai_max_retries = openai_max_retries self.openai_arguments = kwargs self.schema_model = self._cls_schema_model - self.functions = {function_to_name(fn): fn for fn in (functions or [])} + self.functions = { + function_to_name(cast_fn): cast_fn + for fn in (functions or []) + # Use an explicit cast; Callable can't be typehinted with BaseModel directly + # because [BaseModel] is considered invariant with subclasses + for cast_fn in [cast(Callable[[BaseModel], Any], fn)] + } self.__class__._cls_schema_model = None if not self.schema_model: @@ -233,7 +239,7 @@ async def run( function_arg=None, ) - function_call: Callable[..., BaseModel] | None = None + function_call: Callable[[BaseModel], Any] | None = None function_parsed: BaseModel | None = None if response_message.get("function_call"): diff --git a/gpt_json/tests/shared.py b/gpt_json/tests/shared.py index 7ca85cc..7b60507 100644 --- a/gpt_json/tests/shared.py +++ b/gpt_json/tests/shared.py @@ -27,7 +27,8 @@ class GetCurrentWeatherRequest(BaseModel): def get_current_weather(request: GetCurrentWeatherRequest): """ - Get the current weather in a given location + Get the current weather in a given location. + Second line should also be included. The rest of the docstring should be omitted. """ diff --git a/gpt_json/tests/test_fn_calling.py b/gpt_json/tests/test_fn_calling.py index 2840704..782010d 100644 --- a/gpt_json/tests/test_fn_calling.py +++ b/gpt_json/tests/test_fn_calling.py @@ -35,7 +35,7 @@ def test_parse_function(): """ parse_function(get_current_weather) == { "name": "get_current_weather", - "description": "Get the current weather in a given location", + "description": "Get the current weather in a given location. Second line should also be included.", "parameters": { "type": "object", "properties": {