Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function calling support #38

Merged
merged 19 commits into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
strategy:
matrix:
python: ["3.11"]
pydantic: ["1.10.12", "2.1.1"]
pydantic: ["2.1.1"]

steps:
- uses: actions/checkout@v3
Expand Down
71 changes: 64 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Specifically this library:
- Includes retry logic for the most common API failures
- Formats the JSON schema as a flexible prompt that can be added into any message
- Supports templating of prompts to allow for dynamic content
- Validate typehinted function calls in the new GPT models, to better support agent creation

## Getting Started

Expand Down Expand Up @@ -38,7 +39,7 @@ Respond with the following JSON schema:

async def runner():
gpt_json = GPTJSON[SentimentSchema](API_KEY)
response, _ = await gpt_json.run(
payload = await gpt_json.run(
messages=[
GPTMessage(
role=GPTMessageRole.SYSTEM,
Expand All @@ -50,8 +51,8 @@ async def runner():
)
]
)
print(response)
print(f"Detected sentiment: {response.sentiment}")
print(payload.response)
print(f"Detected sentiment: {payload.response.sentiment}")

asyncio.run(runner())
```
Expand Down Expand Up @@ -101,7 +102,7 @@ Generate fictitious quotes that are {sentiment}.
"""

gpt_json = GPTJSON[QuoteSchema](API_KEY)
response, _ = await gpt_json.run(
response = await gpt_json.run(
messages=[
GPTMessage(
role=GPTMessageRole.SYSTEM,
Expand All @@ -125,7 +126,7 @@ Generate fictitious quotes that are {sentiment}.
"""

gpt_json = GPTJSON[QuoteSchema](API_KEY)
response, _ = await gpt_json.run(
response = await gpt_json.run(
messages=[
GPTMessage(
role=GPTMessageRole.SYSTEM,
Expand All @@ -136,6 +137,62 @@ response, _ = await gpt_json.run(
)
```

## Function Calls

`gpt-3.5-turbo-0613` and `gpt-4-0613` were fine-tuned to support a specific syntax for function calls. We support this syntax in `gpt-json` as well. Here's an example of how to use it:

```python
class UnitType(Enum):
CELSIUS = "celsius"
FAHRENHEIT = "fahrenheit"


class GetCurrentWeatherRequest(BaseModel):
location: str = Field(description="The city and state, e.g. San Francisco, CA")
unit: UnitType | None = None


class DataPayload(BaseModel):
data: str


def get_current_weather(request: GetCurrentWeatherRequest):
"""
Get the current weather in a given location
"""
weather_info = {
"location": request.location,
"temperature": "72",
"unit": request.unit,
"forecast": ["sunny", "windy"],
}
return json_dumps(weather_info)


async def runner():
gpt_json = GPTJSON[DataPayload](API_KEY, functions=[get_current_weather])
response = await gpt_json.run(
messages=[
GPTMessage(
role=GPTMessageRole.USER,
content="What's the weather like in Boston, in F?",
),
],
)

assert response.function_call == get_current_weather
assert response.function_arg == GetCurrentWeatherRequest(
location="Boston", unit=UnitType.FAHRENHEIT
)
```

The response provides the original function alongside a formatted Pydantic object. If users want to execute the function, they can run response.function_call(response.function_arg). We will parse the get_current_weather function and the GetCurrentWeatherRequest parameter into the format that GPT expects, so it is more likely to return you a correct function execution.

GPT makes no guarantees about the validity of the returned functions. They could hallucinate a function name or the function signature. To address these cases, the run() function may now throw two new exceptions:

`InvalidFunctionResponse` - The function name is incorrect.
`InvalidFunctionParameters` - The function name is correct, but doesn't match the input schema that was provided.

## Other Configurations

The `GPTJSON` class supports other configuration parameters at initialization.
Expand All @@ -157,9 +214,9 @@ GPT (especially GPT-4) is relatively good at formatting responses at JSON, but i
When calling `gpt_json.run()`, we return a tuple of values:

```python
response, transformations = await gpt_json.run(...)
payload = await gpt_json.run(...)

print(transformations)
print(transformations.fix_transforms)
```

```bash
Expand Down
62 changes: 62 additions & 0 deletions examples/function_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import asyncio
from enum import Enum
from json import dumps as json_dumps
from os import getenv

from dotenv import load_dotenv
from pydantic import BaseModel, Field

from gpt_json import GPTJSON, GPTMessage, GPTMessageRole

load_dotenv()
API_KEY = getenv("OPENAI_API_KEY")


class UnitType(Enum):
CELSIUS = "celsius"
FAHRENHEIT = "fahrenheit"


class GetCurrentWeatherRequest(BaseModel):
location: str = Field(description="The city and state, e.g. San Francisco, CA")
unit: UnitType | None = None


class DataPayload(BaseModel):
data: str


def get_current_weather(request: GetCurrentWeatherRequest):
"""
Get the current weather in a given location

The rest of the docstring should be omitted.
"""
weather_info = {
"location": request.location,
"temperature": "72",
"unit": request.unit,
"forecast": ["sunny", "windy"],
}
return json_dumps(weather_info)


async def runner():
gpt_json = GPTJSON[DataPayload](API_KEY, functions=[get_current_weather])
response = await gpt_json.run(
messages=[
GPTMessage(
role=GPTMessageRole.USER,
content="What's the weather like in Boston, in F?",
),
],
)

print(response)
assert response.function_call == get_current_weather
assert response.function_arg == GetCurrentWeatherRequest(
location="Boston", unit=UnitType.FAHRENHEIT
)


asyncio.run(runner())
4 changes: 2 additions & 2 deletions examples/hint_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ async def runner():
),
]
)
print(response)
print(f"Detected sentiment: {response.sentiment}")
print(response.response)
print(f"Detected sentiment: {response.response.sentiment}")


asyncio.run(runner())
Loading