Skip to content

Commit

Permalink
Fix multiline descriptions
Browse files Browse the repository at this point in the history
  • Loading branch information
piercefreeman committed Aug 25, 2023
1 parent c99c2ce commit 93f75e4
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 14 deletions.
31 changes: 27 additions & 4 deletions gpt_json/fn_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,8 @@ def parse_function(fn: Callable) -> Dict[str, Any]:
API Reference: https://platform.openai.com/docs/api-reference/chat/create
"""
docstring = getdoc(fn) or ""
lines = docstring.strip().split("\n")
description = lines[0] if lines else None

parameter_type = get_argument_for_function(fn)
description = get_function_description(fn)

# Parse the parameter type into a JSON schema
parameter_schema = model_to_parameter_schema(parameter_type)
Expand All @@ -43,6 +40,32 @@ def function_to_name(fn: Callable) -> str:
return fn.__name__


def get_function_description(fn: Callable) -> str:
"""
The description of a function is everything before an empty linebreak.
For instance:
```
A
B
C
```
Would return "A B"
"""
docstring = getdoc(fn) or ""
lines = docstring.strip().split("\n")
description_lines = []
for line in lines:
if not line.strip():
break
description_lines.append(line.strip())
return " ".join(description_lines)


def get_argument_for_function(fn: Callable) -> Type[BaseModel]:
"""
Function definitions are expected to have one argument, which is a pydantic BaseModel that captures
Expand Down
9 changes: 4 additions & 5 deletions gpt_json/tests/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@ class GetCurrentWeatherRequest(BaseModel):


def get_current_weather(request: GetCurrentWeatherRequest):
"""
Get the current weather in a given location.
Second line should also be included.
"""Test description"""

The rest of the docstring should be omitted.
"""

async def get_current_weather_async(request: GetCurrentWeatherRequest):
"""Test description"""


def get_weather_additional_args(request: GetCurrentWeatherRequest, other_args: str):
Expand Down
38 changes: 33 additions & 5 deletions gpt_json/tests/test_fn_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,30 @@

import pytest

from gpt_json.fn_calling import get_base_type, parse_function
from gpt_json.fn_calling import get_base_type, get_function_description, parse_function
from gpt_json.tests.shared import (
UnitType,
get_current_weather,
get_current_weather_async,
get_weather_additional_args,
get_weather_no_pydantic,
)


def multi_line_description_fn():
"""
Test
description
Hidden
description
"""


def single_line_description_fn():
"""Test description"""


@pytest.mark.parametrize(
"incorrect_fn",
[
Expand All @@ -29,18 +44,31 @@ def test_get_base_type():
assert get_base_type(Union[UnitType, None]) == UnitType


def test_parse_function():
def test_get_function_description():
assert get_function_description(multi_line_description_fn) == "Test description"
assert get_function_description(single_line_description_fn) == "Test description"


@pytest.mark.parametrize(
"function,expected_name",
[
(get_current_weather, "get_current_weather"),
(get_current_weather_async, "get_current_weather_async"),
],
)
def test_parse_function(function, expected_name: str):
"""
Assert the formatted schema conforms to the expected JSON-Schema / GPT format.
"""
parse_function(get_current_weather) == {
"name": "get_current_weather",
"description": "Get the current weather in a given location. Second line should also be included.",
assert parse_function(function) == {
"name": expected_name,
"description": "Test description",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"title": "Location",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
Expand Down

0 comments on commit 93f75e4

Please sign in to comment.