Upgrade OpenAI to APIV1 #43

Merged: 1 commit, Apr 13, 2024
3 changes: 2 additions & 1 deletion gpt_json/__init__.py
@@ -1,2 +1,3 @@
from gpt_json.gpt import GPTJSON, ListResponse # noqa
from gpt_json.gpt import GPTJSON as GPTJSON
from gpt_json.gpt import ListResponse as ListResponse
from gpt_json.models import * # noqa
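The re-export change above swaps the `# noqa` import for the explicit `import X as X` form, which mypy's `--no-implicit-reexport` mode (implied by `--strict`) recognizes as a deliberate public re-export. A minimal sketch of a downstream import that relies on it, using a hypothetical `Quote` schema:

    from pydantic import BaseModel

    from gpt_json import GPTJSON  # resolves via the explicit re-export

    class Quote(BaseModel):  # hypothetical schema for illustration
        text: str

    parser = GPTJSON[Quote](api_key="sk-...")  # placeholder key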
26 changes: 18 additions & 8 deletions gpt_json/gpt.py
@@ -18,9 +18,10 @@
)

import backoff
import openai
from openai.error import APIConnectionError, RateLimitError
from openai.error import Timeout as OpenAITimeout
from openai import AsyncOpenAI
from openai._exceptions import APIConnectionError
from openai._exceptions import APITimeoutError as OpenAITimeout
from openai._exceptions import RateLimitError
from pydantic import BaseModel, Field, ValidationError
from tiktoken import encoding_for_model

@@ -103,7 +104,7 @@ class GPTJSON(Generic[SchemaType]):

def __init__(
self,
api_key: str | None = None,
api_key: str,
model: GPTModelVersion | str = GPTModelVersion.GPT_4,
auto_trim: bool = False,
auto_trim_response_overhead: int = 0,
@@ -180,7 +181,7 @@ def __init__(
)

self.schema_prompt = generate_schema_prompt(self.schema_model)
self.api_key = api_key
self.client = AsyncOpenAI(api_key=api_key)

async def run(
self,
@@ -435,16 +436,23 @@ async def submit_request(
]
optional_parameters["function_call"] = "auto"

execute_prediction = openai.ChatCompletion.acreate(
execute_prediction = self.client.chat.completions.create(
model=self.model,
messages=[self.message_to_dict(message) for message in messages],
temperature=self.temperature,
api_key=self.api_key,
stream=stream,
**optional_parameters,
**self.openai_arguments,
)

# We can't combine streams (which return async generators) with timeouts, since our client
# functions expect us to return immediately
if self.timeout is not None and stream:
raise ValueError("Timeouts are not supported with streaming completions")

if stream:
return execute_prediction

# The 'timeout' parameter supported by the OpenAI API is only used to cycle
# the model while it's warming up
# https://github.com/openai/openai-python/blob/fe3abd16b582ae784d8a73fd249bcdfebd5752c9/openai/api_resources/chat_completion.py#L41
@@ -456,7 +464,9 @@ async def submit_request(
try:
return await wait_for(execute_prediction, timeout=self.timeout)
except AsyncTimeoutError:
raise OpenAITimeout
# We don't have access to the underlying httpx.Request, so we just pass None in place
# of the request object
raise OpenAITimeout(None) # type: ignore

def fill_messages(
self,
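The heart of the migration is in `submit_request` above: the removed module-level `openai.ChatCompletion.acreate(...)` becomes `self.client.chat.completions.create(...)` on an `AsyncOpenAI` instance, and the API key moves from a per-call `api_key=` keyword into the client constructor. A minimal sketch of the v1 call pattern this code now relies on (key and model are placeholders):

    from openai import AsyncOpenAI

    client = AsyncOpenAI(api_key="sk-...")  # key now lives on the client

    async def complete(prompt: str) -> str:
        # chat.completions.create replaces the removed ChatCompletion.acreate
        completion = await client.chat.completions.create(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,
        )
        return completion.choices[0].message.content or ""

In the non-streaming path the coroutine is wrapped in `asyncio.wait_for`, which is why the new guard rejects `timeout` together with `stream=True`: a stream hands back an async generator rather than a single awaitable result.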
87 changes: 45 additions & 42 deletions gpt_json/tests/test_gpt.py
@@ -3,9 +3,8 @@
from time import time
from unittest.mock import AsyncMock, patch

import openai
import pytest
from openai.error import Timeout as OpenAITimeout
from openai._exceptions import APITimeoutError as OpenAITimeout
from pydantic import BaseModel, Field

from gpt_json.fn_calling import parse_function
@@ -25,7 +24,7 @@ def test_throws_error_if_no_model_specified():
with pytest.raises(
ValueError, match="needs to be instantiated with a schema model"
):
GPTJSON(None)
GPTJSON(api_key="TEST")


@pytest.mark.parametrize(
@@ -37,7 +36,7 @@ def test_throws_error_if_no_model_specified():
],
)
def test_cast_message_to_gpt_format(role_type: GPTMessageRole, expected: str):
parser = GPTJSON[MySchema](None)
parser = GPTJSON[MySchema](api_key="TEST")
assert (
parser.message_to_dict(
GPTMessage(
@@ -138,7 +137,7 @@ def test_cast_message_to_gpt_format(role_type: GPTMessageRole, expected: str):
),
],
)
async def test_acreate(
async def test_create(
schema_typehint,
response_raw: str,
parsed: BaseModel,
@@ -152,13 +151,6 @@ async def test_acreate(
)
]

model = GPTJSON[schema_typehint](
None,
model=model_version,
temperature=0.0,
timeout=60,
)

# Define mock response
mock_response = {
"choices": [
@@ -173,16 +165,23 @@
]
}

# Create the mock
with patch.object(openai.ChatCompletion, "acreate") as mock_acreate:
# Make the mock function asynchronous
mock_acreate.return_value = mock_response
with patch("gpt_json.gpt.AsyncOpenAI") as mock_client:
mock_client.return_value.chat.completions.create = AsyncMock(
return_value=mock_response
)

model = GPTJSON[schema_typehint](
api_key="TEST",
model=model_version,
temperature=0.0,
timeout=60,
)

# Call the function and pass the expected parameters
response = await model.run(messages=messages)

# Assert that the mock function was called with the expected parameters
mock_acreate.assert_called_with(
mock_client.return_value.chat.completions.create.assert_called_with(
model=model_version.value,
messages=[
{
@@ -192,7 +191,6 @@
for message in messages
],
temperature=0.0,
api_key=None,
stream=False,
)

@@ -203,7 +201,7 @@


@pytest.mark.asyncio
async def test_acreate_with_function_calls():
async def test_create_with_function_calls():
model_version = GPTModelVersion.GPT_3_5
messages = [
GPTMessage(
@@ -212,14 +210,6 @@ async def test_acreate_with_function_calls():
)
]

model = GPTJSON[MySchema](
None,
model=model_version,
temperature=0.0,
timeout=60,
functions=[get_current_weather],
)

mock_response = {
"choices": [
{
@@ -242,12 +232,22 @@
]
}

with patch.object(openai.ChatCompletion, "acreate") as mock_acreate:
mock_acreate.return_value = mock_response
with patch("gpt_json.gpt.AsyncOpenAI") as mock_client:
mock_client.return_value.chat.completions.create = AsyncMock(
return_value=mock_response
)

model = GPTJSON[MySchema](
api_key="TEST",
model=model_version,
temperature=0.0,
timeout=60,
functions=[get_current_weather],
)

response = await model.run(messages=messages)

mock_acreate.assert_called_with(
mock_client.return_value.chat.completions.create.assert_called_with(
model=model_version.value,
messages=[
{
@@ -257,7 +257,6 @@
for message in messages
],
temperature=0.0,
api_key=None,
stream=False,
functions=[parse_function(get_current_weather)],
function_call="auto",
@@ -299,7 +298,9 @@ async def test_acreate_with_function_calls():
],
)
def test_trim_messages(input_messages, expected_output_messages):
gpt = GPTJSON[MySchema](None, auto_trim=True, auto_trim_response_overhead=0)
gpt = GPTJSON[MySchema](
api_key="TEST", auto_trim=True, auto_trim_response_overhead=0
)

output_messages = gpt.trim_messages(input_messages, n=8192)

@@ -319,15 +320,15 @@ class TestSchema1(BaseModel):
class TestSchema2(BaseModel):
field2: str

gptjson1 = GPTJSON[TestSchema1](None)
gptjson1 = GPTJSON[TestSchema1](api_key="TRUE")

# Shouldn't allow instantiation without a schema
# We already expect a mypy error here, which is why we need a `type ignore`
# but we also want to make sure that the error is raised at runtime
with pytest.raises(ValueError):
gptjson2 = GPTJSON(None) # type: ignore

gptjson2 = GPTJSON[TestSchema2](None)
gptjson2 = GPTJSON[TestSchema2](api_key="TRUE")

assert gptjson1.schema_model == TestSchema1
assert gptjson2.schema_model == TestSchema2
@@ -337,7 +338,7 @@ def test_fill_message_schema_template():
class TestTemplateSchema(BaseModel):
template_field: str = Field(description="Max length {max_length}")

gpt = GPTJSON[TestTemplateSchema](None)
gpt = GPTJSON[TestTemplateSchema](api_key="TRUE")
assert gpt.fill_message_template(
GPTMessage(
role=GPTMessageRole.USER,
@@ -356,7 +357,7 @@ def test_fill_message_functions_template():
class TestTemplateSchema(BaseModel):
template_field: str = Field(description="Max length {max_length}")

gpt = GPTJSON[TestTemplateSchema](None, functions=[get_current_weather])
gpt = GPTJSON[TestTemplateSchema](api_key="TRUE", functions=[get_current_weather])
assert gpt.fill_message_template(
GPTMessage(
role=GPTMessageRole.USER,
@@ -371,7 +372,7 @@

@pytest.mark.asyncio
async def test_extracted_json_is_None():
gpt = GPTJSON[MySchema](None)
gpt = GPTJSON[MySchema](api_key="TRUE")

with patch.object(
gpt,
@@ -390,7 +391,7 @@

@pytest.mark.asyncio
async def test_no_valid_results_from_remote_request():
gpt = GPTJSON[MySchema](None)
gpt = GPTJSON[MySchema](api_key="TRUE")

with patch.object(gpt, "submit_request", return_value={"choices": []}):
result = await gpt.run(
@@ -401,7 +402,7 @@

@pytest.mark.asyncio
async def test_unable_to_find_valid_json_payload():
gpt = GPTJSON[MySchema](None)
gpt = GPTJSON[MySchema](api_key="TRUE")

with patch.object(
gpt,
@@ -421,7 +422,7 @@
@pytest.mark.asyncio
async def test_unknown_model_to_infer_max_tokens():
with pytest.raises(ValueError):
GPTJSON[MySchema](model="UnknownModel", auto_trim=True)
GPTJSON[MySchema](api_key="TRUE", model="UnknownModel", auto_trim=True)


@pytest.mark.asyncio
@@ -453,13 +454,15 @@ async def read(self):
}
return json_dumps(mock_response).encode()

with patch("aiohttp.ClientSession.request", new_callable=AsyncMock) as mock_request:
with patch("gpt_json.gpt.AsyncOpenAI") as mock_client:
# Mock a stalling request
async def side_effect(*args, **kwargs):
await asyncio.sleep(4)
return MockResponse("TEST_RESPONSE")

mock_request.side_effect = side_effect
mock_client.return_value.chat.completions.create = AsyncMock(
side_effect=side_effect
)

gpt = GPTJSON[MySchema](api_key="ABC", timeout=2)

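Since the client is now built in `__init__`, the tests patch `gpt_json.gpt.AsyncOpenAI` (the name as imported into `gpt.py`) instead of `openai.ChatCompletion`, and they construct `GPTJSON` inside the patch so the instance picks up the mocked client. A stripped-down sketch of the pattern, with an illustrative schema and assertion:

    from unittest.mock import AsyncMock, patch

    import pytest
    from pydantic import BaseModel

    from gpt_json import GPTJSON

    class MySchema(BaseModel):  # illustrative schema
        name: str

    @pytest.mark.asyncio
    async def test_patched_client():
        with patch("gpt_json.gpt.AsyncOpenAI") as mock_client:
            mock_client.return_value.chat.completions.create = AsyncMock(
                return_value={"choices": []}
            )
            # Constructed inside the patch, so self.client is the mock instance
            gpt = GPTJSON[MySchema](api_key="TEST")
            assert gpt.client is mock_client.return_value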
22 changes: 9 additions & 13 deletions gpt_json/tests/test_streaming.py
@@ -1,7 +1,6 @@
import json
from unittest.mock import patch

import openai
import pytest

from gpt_json.gpt import GPTJSON
@@ -90,13 +89,6 @@ async def test_gpt_stream(
)
]

model = GPTJSON[schema_typehint]( # type: ignore
None,
model=model_version,
temperature=0.0,
timeout=60,
)

# Define mock response
mocked_oai_raw_responses = _mock_oai_streaming_chunks(full_object)

@@ -107,9 +99,14 @@ async def async_list_to_generator(my_list):
mock_response = async_list_to_generator(mocked_oai_raw_responses)

# Create the mock
with patch.object(openai.ChatCompletion, "acreate") as mock_acreate:
# Make the mock function asynchronous
mock_acreate.return_value = mock_response
with patch("gpt_json.gpt.AsyncOpenAI") as mock_client:
mock_client.return_value.chat.completions.create.return_value = mock_response

model = GPTJSON[schema_typehint]( # type: ignore
api_key="TEST",
model=model_version,
temperature=0.0,
)

if not should_support:
with pytest.raises(NotImplementedError):
@@ -137,7 +134,7 @@ async def async_list_to_generator(my_list):
idx += 1

# Assert that the mock function was called with the expected parameters, including streaming
mock_acreate.assert_called_with(
mock_client.return_value.chat.completions.create.assert_called_with(
model=model_version.value,
messages=[
{
@@ -148,5 +145,4 @@
],
temperature=0.0,
stream=True,
api_key=None,
)
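The streaming test stubs `create` with an async generator because, in the v1 SDK, a streaming completion is awaited once and then consumed as an async iterator of delta chunks. A minimal sketch of the real consumption pattern (key and model are placeholders):

    from openai import AsyncOpenAI

    async def stream_answer(prompt: str) -> None:
        client = AsyncOpenAI(api_key="sk-...")  # placeholder key
        stream = await client.chat.completions.create(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            stream=True,
        )
        async for chunk in stream:
            # Each chunk carries an incremental delta, not a full message
            delta = chunk.choices[0].delta.content
            if delta:
                print(delta, end="", flush=True)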
4 changes: 2 additions & 2 deletions gpt_json/tests/test_truncation.py
@@ -25,7 +25,7 @@ def test_fill_messages_truncated():
class TestSchema(BaseModel):
summary: str

gpt = GPTJSON[TestSchema](None)
gpt = GPTJSON[TestSchema](api_key="TEST")
assert gpt.fill_messages(
[
GPTMessage(role=GPTMessageRole.SYSTEM, content="system"),
@@ -53,7 +53,7 @@ def test_fill_messages_truncated_failure_case():
class TestSchema(BaseModel):
summary: str

gpt = GPTJSON[TestSchema](None)
gpt = GPTJSON[TestSchema](api_key="TEST")

# this should fail because the max_prompt_tokens is too small
with pytest.raises(ValueError, match=".* max_prompt_tokens .* too small .*"):
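These truncation tests hinge on exact token counts, which `gpt.py` computes with `tiktoken`'s `encoding_for_model` (see its import in the `gpt.py` diff above). A short sketch of the counting the trimmer performs, with a placeholder model name:

    from tiktoken import encoding_for_model

    enc = encoding_for_model("gpt-4")  # placeholder model
    print(len(enc.encode("system")))  # token count the trimmer would use for this message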