Skip to content

Commit

Permalink
fix tests for client v1
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesturk committed Nov 25, 2023
1 parent 83499ef commit de820c2
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 39 deletions.
28 changes: 15 additions & 13 deletions src/scrapeghost/apicall.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""
Module for making OpenAI API calls.
"""
import os
import time
from dataclasses import dataclass
import openai
from dataclasses import dataclass
from openai import OpenAI

client = OpenAI()
import openai.error
from typing import Callable

from .errors import (
Expand All @@ -27,11 +25,14 @@

RETRY_ERRORS = (
openai.RateLimitError,
openai.error.Timeout,
openai.error.APIConnectionError,
openai.APITimeoutError,
openai.APIConnectionError,
)


client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", ""))


@dataclass
class RetryRule:
max_retries: int = 0
Expand Down Expand Up @@ -78,7 +79,7 @@ def __init__(
self.postprocessors = postprocessors

def _raw_api_request(
self, model: str, messages: list[dict[str, str]], response: Response
self, model: str, messages: list[dict[str, str]], response: Response,
) -> Response:
"""
Make an OpenAPI request and return the raw response.
Expand All @@ -94,10 +95,13 @@ def _raw_api_request(
raise MaxCostExceeded(
f"Total cost {self.total_cost:.2f} exceeds max cost {self.max_cost:.2f}"
)
json_mode = (
{"response_format": "json_object"} if _model_dict[model].json_mode else {}
)
start_t = time.time()
completion = client.chat.completions.create(model=model,
messages=messages,
**self.model_params)
completion = client.chat.completions.create(
model=model, messages=messages, **self.model_params, **json_mode,
)
elapsed = time.time() - start_t
p_tokens = completion.usage.prompt_tokens
c_tokens = completion.usage.completion_tokens
Expand Down Expand Up @@ -128,7 +132,7 @@ def _raw_api_request(
f"(prompt_tokens={p_tokens}, "
f"completion_tokens={c_tokens})"
)
response.data = choice.message.content
response.data = choice.text
return response

def _api_request(self, html: str) -> Response:
Expand Down Expand Up @@ -168,7 +172,6 @@ def _api_request(self, html: str) -> Response:
model=model,
html_tokens=tokens,
)
json_mode = {"response_format": "json_object"} if model_data.json_mode else {}
self._raw_api_request(
model=model,
messages=[
Expand All @@ -179,7 +182,6 @@ def _api_request(self, html: str) -> Response:
{"role": "user", "content": html},
],
response=response,
**json_mode,
)
return response
except self.retry.retry_errors + (
Expand Down
6 changes: 3 additions & 3 deletions tests/test_apicall.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_model_fallback():
with patch_create() as create:
# fail first request
create.side_effect = [
_mock_response(finish_reason="timeout"),
_mock_response(finish_reason="length"),
_mock_response(),
]
api_call.request("<html>")
Expand Down Expand Up @@ -75,7 +75,7 @@ def _timeout_once(**kwargs):
if _timeout_once.called:
return _mock_response()
_timeout_once.called = True
raise openai.Timeout()
raise openai.APITimeoutError(request=None)

_timeout_once.called = False

Expand All @@ -93,7 +93,7 @@ def test_retry_failure():
retry=RetryRule(2, 0), # disable wait
)

with pytest.raises(openai.Timeout):
with pytest.raises(openai.APITimeoutError):
with patch_create() as create:
# fail first request
create.side_effect = _timeout
Expand Down
47 changes: 24 additions & 23 deletions tests/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,37 @@


def _mock_response(**kwargs):
mr = openai.openai_object.OpenAIObject.construct_from(
dict(
model=kwargs.get("model"),
choices=[
{
"finish_reason": kwargs.get("finish_reason", "stop"),
"message": {
"content": kwargs.get("content", "Hello world"),
},
}
],
usage={
"prompt_tokens": kwargs.get("prompt_tokens", 1),
"completion_tokens": kwargs.get("completion_tokens", 1),
},
created=1629200000,
id="cmpl-xxxxxxxxxxxxxxxxxxxx",
model_version=kwargs.get("model_version", "ada"),
prompt="Hello world",
status="complete",
)
mr = openai.types.completion.Completion(
model=kwargs.get("model", ""),
object="text_completion",
choices=[
openai.types.completion.CompletionChoice(
index=0,
text=kwargs.get("content", "hello world"),
finish_reason= kwargs.get("finish_reason", "stop"),
logprobs={},
)
],
usage={
"prompt_tokens": kwargs.get("prompt_tokens", 1),
"completion_tokens": kwargs.get("completion_tokens", 1),
"total_tokens": kwargs.get("prompt_tokens", 1) + kwargs.get("completion_tokens", 1),
},
created=1629200000,
id="cmpl-xxxxxxxxxxxxxxxxxxxx",
model_version=kwargs.get("model_version", "ada"),
prompt="Hello world",
status="complete",
finish_reason="stop",
)

return mr


def _timeout(**kwargs):
raise openai.Timeout()
raise openai.APITimeoutError(request=None)


def patch_create():
p = patch("scrapeghost.apicall.openai.ChatCompletion.create")
p = patch("scrapeghost.apicall.client.chat.completions.create")
return p

0 comments on commit de820c2

Please sign in to comment.