Skip to content

Commit

Permalink
Fix typehinting of functions
Browse files Browse the repository at this point in the history
  • Loading branch information
piercefreeman committed Aug 25, 2023
1 parent c393b34 commit c99c2ce
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
14 changes: 10 additions & 4 deletions gpt_json/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class RunResponse(Generic[SchemaType], BaseModel):
raw_response: GPTMessage | None
response: SchemaType | None
fix_transforms: FixTransforms | None
function_call: Callable[..., BaseModel] | None
function_call: Callable[[BaseModel], Any] | None
function_arg: BaseModel | None


Expand All @@ -108,7 +108,7 @@ def __init__(
model: GPTModelVersion | str = GPTModelVersion.GPT_4,
auto_trim: bool = False,
auto_trim_response_overhead: int = 0,
functions: list[Callable[..., BaseModel]] | None = None,
functions: list[Callable[[Any], Any]] | None = None,
# For messages that are relatively deterministic
temperature=0.2,
timeout: int | None = None,
Expand All @@ -134,7 +134,13 @@ def __init__(
self.openai_max_retries = openai_max_retries
self.openai_arguments = kwargs
self.schema_model = self._cls_schema_model
self.functions = {function_to_name(fn): fn for fn in (functions or [])}
self.functions = {
function_to_name(cast_fn): cast_fn
for fn in (functions or [])
# Use an explicit cast; Callable can't be typehinted with BaseModel directly
# because [BaseModel] is considered invariant with subclasses
for cast_fn in [cast(Callable[[BaseModel], Any], fn)]
}
self.__class__._cls_schema_model = None

if not self.schema_model:
Expand Down Expand Up @@ -233,7 +239,7 @@ async def run(
function_arg=None,
)

function_call: Callable[..., BaseModel] | None = None
function_call: Callable[[BaseModel], Any] | None = None
function_parsed: BaseModel | None = None

if response_message.get("function_call"):
Expand Down
3 changes: 2 additions & 1 deletion gpt_json/tests/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class GetCurrentWeatherRequest(BaseModel):

def get_current_weather(request: GetCurrentWeatherRequest):
"""
Get the current weather in a given location
Get the current weather in a given location.
Second line should also be included.
The rest of the docstring should be omitted.
"""
Expand Down
2 changes: 1 addition & 1 deletion gpt_json/tests/test_fn_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_parse_function():
"""
parse_function(get_current_weather) == {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"description": "Get the current weather in a given location. Second line should also be included.",
"parameters": {
"type": "object",
"properties": {
Expand Down

0 comments on commit c99c2ce

Please sign in to comment.