Switch mocking to raw http requests
piercefreeman committed Apr 13, 2024
1 parent bca2f38 commit 9ecb934
Showing 4 changed files with 228 additions and 221 deletions.
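
This commit swaps the test doubles from `unittest.mock.patch` on `AsyncOpenAI` to `pytest_httpx`, which intercepts the raw HTTP request the SDK makes to `/v1/chat/completions`, so the mocked payload has to look like a real chat completions response body. Below is a minimal sketch of the new pattern, assuming the import paths and constructor signature shown in the test file further down; `PlaceholderSchema` is illustrative and stands in for the repository's `MySchema`.

```python
# Sketch of the HTTP-level mocking pattern used in this commit: pytest_httpx
# intercepts the request AsyncOpenAI sends over httpx, so no client patching
# is needed. PlaceholderSchema is a stand-in; import paths are assumptions
# based on the test file below.
import pytest
from pydantic import BaseModel
from pytest_httpx import HTTPXMock

from gpt_json.gpt import GPTJSON
from gpt_json.models import GPTMessage, GPTMessageRole, TextContent


class PlaceholderSchema(BaseModel):
    text: str


@pytest.mark.asyncio
async def test_run_against_mocked_http(httpx_mock: HTTPXMock):
    httpx_mock.add_response(
        url="https://api.openai.com/v1/chat/completions",
        json={
            # Shape follows https://platform.openai.com/docs/api-reference/chat/create
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1677652288,
            "model": "gpt-3.5-turbo-0125",
            "choices": [
                {
                    "message": {"role": "assistant", "content": '{"text": "hello"}'},
                    "index": 0,
                    "finish_reason": "stop",
                }
            ],
            "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
        },
    )

    gpt = GPTJSON[PlaceholderSchema](api_key="TEST")
    result = await gpt.run(
        [
            GPTMessage(
                role=GPTMessageRole.SYSTEM,
                content=[TextContent(text="message content")],
            )
        ]
    )
    assert result.response is not None
```
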
20 changes: 11 additions & 9 deletions gpt_json/gpt.py
@@ -183,7 +183,9 @@ def __init__(
)

self.schema_prompt = generate_schema_prompt(self.schema_model)
self.client = AsyncOpenAI(api_key=api_key)

# We use separate retry logic, versus the one that's baked into OpenAI
self.client = AsyncOpenAI(api_key=api_key, max_retries=1)

async def run(
self,
@@ -225,7 +227,7 @@ async def run(
)

logger.debug("------- RAW RESPONSE ----------")
logger.debug(response["choices"])
logger.debug(response.choices)
logger.debug("------- END RAW RESPONSE ----------")

# If the response requests a function call, prefer this over the main response
@@ -242,9 +244,9 @@ async def run(
function_call: Callable[[BaseModel], Any] | None = None
function_parsed: BaseModel | None = None

if response_message.get("function_call"):
function_name = response_message["function_call"]["name"]
function_args_string = response_message["function_call"]["arguments"]
if response_message.function_call:
function_name = response_message.function_call.name
function_args_string = response_message.function_call.arguments
if function_name not in self.functions:
raise InvalidFunctionResponse(function_name)

@@ -259,7 +261,7 @@ async def run(
except (ValueError, ValidationError):
raise InvalidFunctionParameters(function_name, function_args_string)

raw_response = GPTMessage.model_validate(response_message)
raw_response = GPTMessage.model_validate_json(response_message.model_dump_json())
raw_response.allow_templating = False

extracted_json, fixed_payload = self.extract_json(
@@ -375,7 +377,7 @@ def extract_json(self, response_message, extract_type: ResponseType):
Assumes one main block of results, either a list or a dictionary
"""

full_response = response_message["content"]
full_response = response_message.content
if not full_response:
return None, None

@@ -402,13 +404,13 @@ def extract_json(self, response_message, extract_type: ResponseType):
return None, fixed_payload

def extract_response_message(self, completion_response):
choices = completion_response["choices"]
choices = completion_response.choices

if not choices:
logger.warning("No choices available, should report error...")
return None

return choices[0]["message"]
return choices[0].message

async def submit_request(
self,
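
The accessor changes above track the v1 OpenAI Python SDK, where `chat.completions.create` returns a typed `ChatCompletion` object rather than a plain dict, so lookups like `response["choices"][0]["message"]` become attribute access. A short sketch of the difference, assuming the v1 `openai` package; the model name and prompt are illustrative.

```python
# Sketch: the v1 SDK returns typed objects, so lookups use attributes, not keys.
from openai import AsyncOpenAI

client = AsyncOpenAI(api_key="TEST", max_retries=1)


async def first_message_content() -> str | None:
    response = await client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
    )
    # Pre-v1 dict style: response["choices"][0]["message"]["content"]
    # v1 attribute style, as used in gpt.py above:
    if not response.choices:
        return None
    return response.choices[0].message.content
```
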
226 changes: 106 additions & 120 deletions gpt_json/tests/test_gpt.py
@@ -1,13 +1,14 @@
import asyncio
from json import dumps as json_dumps
from time import time
from typing import Any
from unittest.mock import AsyncMock, patch

import pytest
from openai._exceptions import APITimeoutError as OpenAITimeout
from pydantic import BaseModel, Field
from pytest_httpx import HTTPXMock

from gpt_json.fn_calling import parse_function
from gpt_json.gpt import GPTJSON, ListResponse
from gpt_json.models import (
FixTransforms,
@@ -26,6 +27,19 @@
from gpt_json.transformations import JsonFixEnum


def make_assistant_response(choices: list[Any]):
# https://platform.openai.com/docs/api-reference/chat/create
return {
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-3.5-turbo-0125",
"system_fingerprint": "fp_3478aj6f3a",
"choices": choices,
"usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
}


def test_throws_error_if_no_model_specified():
with pytest.raises(
ValueError, match="needs to be instantiated with a schema model"
@@ -145,6 +159,7 @@ def test_cast_message_to_gpt_format(role_type: GPTMessageRole, expected: str):
)
async def test_create(
schema_typehint,
httpx_mock: HTTPXMock,
response_raw: str,
parsed: BaseModel,
expected_transformations: FixTransforms,
@@ -158,50 +173,31 @@ async def test_create(
]

# Define mock response
mock_response = {
"choices": [
{
"message": {
"role": "assistant",
"content": response_raw,
},
"index": 0,
"finish_reason": "stop",
}
]
}

with patch("gpt_json.gpt.AsyncOpenAI") as mock_client:
mock_client.return_value.chat.completions.create = AsyncMock(
return_value=mock_response
)

model = GPTJSON[schema_typehint](
api_key="TEST",
model=model_version,
temperature=0.0,
timeout=60,
)

# Call the function and pass the expected parameters
response = await model.run(messages=messages)

# Assert that the mock function was called with the expected parameters
mock_client.return_value.chat.completions.create.assert_called_with(
model=model_version.value.api_name,
messages=[
httpx_mock.add_response(
url="https://api.openai.com/v1/chat/completions",
json=make_assistant_response(
[
{
"role": message.role.value,
"content": [
content.model_dump(by_alias=True, exclude_unset=True)
for content in message.get_content_payloads()
],
"message": {
"role": "assistant",
"content": response_raw,
},
"index": 0,
"finish_reason": "stop",
}
for message in messages
],
temperature=0.0,
stream=False,
)
]
),
)

model = GPTJSON[schema_typehint](
api_key="TEST",
model=model_version,
temperature=0.0,
timeout=60,
)

# Call the function and pass the expected parameters
response = await model.run(messages=messages)

assert response
assert response.response
@@ -210,7 +206,9 @@ async def test_create(


@pytest.mark.asyncio
async def test_create_with_function_calls():
async def test_create_with_function_calls(
httpx_mock: HTTPXMock,
):
model_version = GPTModelVersion.GPT_3_5
messages = [
GPTMessage(
@@ -219,60 +217,41 @@ async def test_create_with_function_calls():
)
]

mock_response = {
"choices": [
{
"message": {
"role": "assistant",
"content": "",
"function_call": {
"name": "get_current_weather",
"arguments": json_dumps(
{
"location": "Boston",
"unit": "fahrenheit",
}
),
# Define mock response
httpx_mock.add_response(
url="https://api.openai.com/v1/chat/completions",
json=make_assistant_response(
[
{
"message": {
"role": "assistant",
"content": "",
"function_call": {
"name": "get_current_weather",
"arguments": json_dumps(
{
"location": "Boston",
"unit": "fahrenheit",
}
),
},
},
},
"index": 0,
"finish_reason": "stop",
}
]
}

with patch("gpt_json.gpt.AsyncOpenAI") as mock_client:
mock_client.return_value.chat.completions.create = AsyncMock(
return_value=mock_response
)

model = GPTJSON[MySchema](
api_key="TEST",
model=model_version,
temperature=0.0,
timeout=60,
functions=[get_current_weather],
)
"index": 0,
"finish_reason": "stop",
}
]
),
)

response = await model.run(messages=messages)
model = GPTJSON[MySchema](
api_key="TEST",
model=model_version,
temperature=0.0,
timeout=60,
functions=[get_current_weather],
)

mock_client.return_value.chat.completions.create.assert_called_with(
model=model_version.value.api_name,
messages=[
{
"role": message.role.value,
"content": [
content.model_dump(by_alias=True, exclude_unset=True)
for content in message.get_content_payloads()
],
}
for message in messages
],
temperature=0.0,
stream=False,
functions=[parse_function(get_current_weather)],
function_call="auto",
)
response = await model.run(messages=messages)

assert response
assert response.response is None
@@ -395,16 +374,17 @@ class TestTemplateSchema(BaseModel):


@pytest.mark.asyncio
async def test_extracted_json_is_None():
async def test_extracted_json_is_none(httpx_mock: HTTPXMock):
gpt = GPTJSON[MySchema](api_key="TRUE")

httpx_mock.add_response(
url="https://api.openai.com/v1/chat/completions",
json=make_assistant_response(
[{"message": {"content": "some content", "role": "assistant"}}]
),
)

with patch.object(
gpt,
"submit_request",
return_value={
"choices": [{"message": {"content": "some content", "role": "assistant"}}]
},
), patch.object(
gpt, "extract_json", return_value=(None, FixTransforms(None, False))
):
result = await gpt.run(
@@ -419,32 +399,38 @@ async def test_extracted_json_is_None():


@pytest.mark.asyncio
async def test_no_valid_results_from_remote_request():
async def test_no_valid_results_from_remote_request(
httpx_mock: HTTPXMock,
):
gpt = GPTJSON[MySchema](api_key="TRUE")

with patch.object(gpt, "submit_request", return_value={"choices": []}):
result = await gpt.run(
[
GPTMessage(
role=GPTMessageRole.SYSTEM,
content=[TextContent(text="message content")],
)
]
)
assert result.response is None
httpx_mock.add_response(
url="https://api.openai.com/v1/chat/completions",
json=make_assistant_response([]),
)

result = await gpt.run(
[
GPTMessage(
role=GPTMessageRole.SYSTEM,
content=[TextContent(text="message content")],
)
]
)
assert result.response is None


@pytest.mark.asyncio
async def test_unable_to_find_valid_json_payload():
async def test_unable_to_find_valid_json_payload(httpx_mock: HTTPXMock):
httpx_mock.add_response(
url="https://api.openai.com/v1/chat/completions",
json=make_assistant_response(
[{"message": {"content": "some content", "role": "assistant"}}]
),
)
gpt = GPTJSON[MySchema](api_key="TRUE")

with patch.object(
gpt,
"submit_request",
return_value={
"choices": [{"message": {"content": "some content", "role": "assistant"}}]
},
), patch.object(
gpt, "extract_json", return_value=(None, FixTransforms(None, False))
):
result = await gpt.run(
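
One coverage note: the removed tests asserted the outgoing payload with `assert_called_with`, and the HTTP-level mocks here do not replicate that check. `pytest_httpx` records every intercepted request, so an equivalent assertion can be made against the serialized body. A sketch follows, under the assumption that the SDK serializes `model`, `temperature`, and `stream` exactly as they are passed by gpt.py; the asserted keys mirror the old `assert_called_with` arguments rather than anything in this commit.

```python
# Sketch: inspect the request body captured by pytest_httpx, run inside one of
# the tests above after `await model.run(...)`. The asserted keys are assumptions
# about how the SDK serializes the parameters it was given.
import json

requests = httpx_mock.get_requests()
assert len(requests) == 1

payload = json.loads(requests[0].content)
assert payload["model"] == model_version.value.api_name
assert payload["temperature"] == 0.0
assert payload["stream"] is False
```
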