diff --git a/src/scrapeghost/apicall.py b/src/scrapeghost/apicall.py
index 4ccc3a6..51e2297 100644
--- a/src/scrapeghost/apicall.py
+++ b/src/scrapeghost/apicall.py
@@ -1,13 +1,11 @@
 """
 Module for making OpenAI API calls.
 """
+import os
 import time
-from dataclasses import dataclass
 import openai
+from dataclasses import dataclass
 from openai import OpenAI
-
-client = OpenAI()
-import openai.error
 from typing import Callable
 
 from .errors import (
@@ -27,11 +25,14 @@
 
 RETRY_ERRORS = (
     openai.RateLimitError,
-    openai.error.Timeout,
-    openai.error.APIConnectionError,
+    openai.APITimeoutError,
+    openai.APIConnectionError,
 )
 
 
+client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", ""))
+
+
 @dataclass
 class RetryRule:
     max_retries: int = 0
@@ -78,7 +79,7 @@ def __init__(
         self.postprocessors = postprocessors
 
     def _raw_api_request(
-        self, model: str, messages: list[dict[str, str]], response: Response
+        self, model: str, messages: list[dict[str, str]], response: Response,
     ) -> Response:
         """
         Make an OpenAPI request and return the raw response.
@@ -94,10 +95,13 @@ def _raw_api_request(
             raise MaxCostExceeded(
                 f"Total cost {self.total_cost:.2f} exceeds max cost {self.max_cost:.2f}"
             )
+        json_mode = (
+            {"response_format": {"type": "json_object"}} if _model_dict[model].json_mode else {}
+        )
         start_t = time.time()
-        completion = client.chat.completions.create(model=model,
-        messages=messages,
-        **self.model_params)
+        completion = client.chat.completions.create(
+            model=model, messages=messages, **self.model_params, **json_mode,
+        )
         elapsed = time.time() - start_t
         p_tokens = completion.usage.prompt_tokens
         c_tokens = completion.usage.completion_tokens
@@ -168,7 +172,6 @@ def _api_request(self, html: str) -> Response:
                     model=model,
                     html_tokens=tokens,
                 )
-                json_mode = {"response_format": "json_object"} if model_data.json_mode else {}
                 self._raw_api_request(
                     model=model,
                     messages=[
@@ -179,7 +182,6 @@ def _api_request(self, html: str) -> Response:
                         {"role": "user", "content": html},
                     ],
                     response=response,
-                    **json_mode,
                 )
                 return response
             except self.retry.retry_errors + (
diff --git a/tests/test_apicall.py b/tests/test_apicall.py
index 30749ae..72c7909 100644
--- a/tests/test_apicall.py
+++ b/tests/test_apicall.py
@@ -26,7 +26,7 @@ def test_model_fallback():
     with patch_create() as create:
         # fail first request
         create.side_effect = [
-            _mock_response(finish_reason="timeout"),
+            _mock_response(finish_reason="length"),
             _mock_response(),
         ]
         api_call.request("")
@@ -75,7 +75,7 @@ def _timeout_once(**kwargs):
     if _timeout_once.called:
         return _mock_response()
     _timeout_once.called = True
-    raise openai.Timeout()
+    raise openai.APITimeoutError(request=None)
 
 
 _timeout_once.called = False
@@ -93,7 +93,7 @@ def test_retry_failure():
         retry=RetryRule(2, 0),  # disable wait
     )
 
-    with pytest.raises(openai.Timeout):
+    with pytest.raises(openai.APITimeoutError):
         with patch_create() as create:
             # fail first request
             create.side_effect = _timeout
diff --git a/tests/testutils.py b/tests/testutils.py
index 4bccc16..775ec88 100644
--- a/tests/testutils.py
+++ b/tests/testutils.py
@@ -3,36 +3,37 @@
 def _mock_response(**kwargs):
-    mr = openai.openai_object.OpenAIObject.construct_from(
-        dict(
-            model=kwargs.get("model"),
-            choices=[
-                {
-                    "finish_reason": kwargs.get("finish_reason", "stop"),
-                    "message": {
-                        "content": kwargs.get("content", "Hello world"),
kwargs.get("content", "Hello world"), - }, - } - ], - usage={ - "prompt_tokens": kwargs.get("prompt_tokens", 1), - "completion_tokens": kwargs.get("completion_tokens", 1), - }, - created=1629200000, - id="cmpl-xxxxxxxxxxxxxxxxxxxx", - model_version=kwargs.get("model_version", "ada"), - prompt="Hello world", - status="complete", - ) + mr = openai.types.completion.Completion( + model=kwargs.get("model", ""), + object="text_completion", + choices=[ + openai.types.completion.CompletionChoice( + index=0, + text=kwargs.get("content", "hello world"), + finish_reason= kwargs.get("finish_reason", "stop"), + logprobs={}, + ) + ], + usage={ + "prompt_tokens": kwargs.get("prompt_tokens", 1), + "completion_tokens": kwargs.get("completion_tokens", 1), + "total_tokens": kwargs.get("prompt_tokens", 1) + kwargs.get("completion_tokens", 1), + }, + created=1629200000, + id="cmpl-xxxxxxxxxxxxxxxxxxxx", + model_version=kwargs.get("model_version", "ada"), + prompt="Hello world", + status="complete", + finish_reason="stop", ) return mr def _timeout(**kwargs): - raise openai.Timeout() + raise openai.APITimeoutError(request=None) def patch_create(): - p = patch("scrapeghost.apicall.openai.ChatCompletion.create") + p = patch("scrapeghost.apicall.client.chat.completions.create") return p