Correct serialization of GPT messages
piercefreeman committed Apr 13, 2024
1 parent 3109055 commit 07e1359
Showing 8 changed files with 45 additions and 112 deletions.
20 changes: 11 additions & 9 deletions README.md
@@ -3,6 +3,7 @@
`gpt-json` is a wrapper around GPT that allows for declarative definition of expected output format. Set up a schema, write a prompt, and get results back as beautiful typehinted objects.

Specifically this library:

- Utilizes Pydantic schema definitions for type casting and validations
- Adds typehinting for both the API and the output schema
- Allows GPT to respond with both single-objects and lists of objects
@@ -11,6 +12,7 @@ Specifically this library:
- Formats the JSON schema as a flexible prompt that can be added into any message
- Supports templating of prompts to allow for dynamic content
- Validate typehinted function calls in the new GPT models, to better support agent creation
- Lightweight dependencies: only OpenAI, pydantic, and backoff

## Getting Started
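
A minimal usage sketch (the top-level imports, the example schema, the `api_key` argument, and the `{json_schema}` placeholder are assumptions for illustration, not details taken from this commit):

```python
import asyncio

from pydantic import BaseModel

from gpt_json import GPTJSON, GPTMessage, GPTMessageRole  # assumed top-level exports


class SentimentSchema(BaseModel):
    sentiment: str


async def main():
    # Parameterize GPTJSON with the schema you expect back
    gpt = GPTJSON[SentimentSchema](api_key="sk-...")  # placeholder key
    result = await gpt.run(
        messages=[
            GPTMessage(
                role=GPTMessageRole.SYSTEM,
                content="Classify the sentiment of the user's text.\n\n{json_schema}",  # assumed schema placeholder
            ),
            GPTMessage(role=GPTMessageRole.USER, content="This library saved me hours!"),
        ]
    )
    print(result)


asyncio.run(main())
```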

@@ -197,19 +199,19 @@ GPT makes no guarantees about the validity of the returned functions. They could

The `GPTJSON` class supports other configuration parameters at initialization.

| Parameter | Type | Description |
|-----------------------------|------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| model | GPTModelVersion \| str | (default: GPTModelVersion.GPT_4) - For convenience we provide the currently supported GPT model versions in the `GPTModelVersion` enum. You can also pass a string value if you want to use another more specific architecture. |
| auto_trim | bool | (default: False) - If your input prompt is too long, perhaps because of dynamic injected content, will automatically truncate the text to create enough room for the model's response. |
| auto_trim_response_overhead | int | (default: 0) - If you're using auto_trim, configures the max amount of tokens to allow in the model's response. |
| **kwargs | Any | Any other parameters you want to pass to the underlying `GPT` class, will just be a passthrough. |
| Parameter | Type | Description |
| --------------------------- | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| model | GPTModelVersion \| str | (default: GPTModelVersion.GPT_4) - For convenience we provide the currently supported GPT model versions in the `GPTModelVersion` enum. You can also pass a string value if you want to use another more specific architecture. |
| auto_trim | bool | (default: False) - If your input prompt is too long, perhaps because of dynamic injected content, will automatically truncate the text to create enough room for the model's response. |
| auto_trim_response_overhead | int | (default: 0) - If you're using auto_trim, configures the max amount of tokens to allow in the model's response. |
| \*\*kwargs | Any | Any other parameters you want to pass to the underlying `GPT` class, will just be a passthrough. |
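
A hedged sketch of passing these options at construction time (the schema, key, and token counts are placeholders):

```python
from pydantic import BaseModel

from gpt_json import GPTJSON, GPTModelVersion  # assumed top-level exports


class SummarySchema(BaseModel):
    summary: str


gpt = GPTJSON[SummarySchema](
    api_key="sk-...",                   # placeholder
    model=GPTModelVersion.GPT_4_0613,   # or a raw string such as "gpt-4-0613"
    auto_trim=True,                     # truncate the prompt if it leaves no room for a response
    auto_trim_response_overhead=1024,   # tokens reserved for the model's reply
)
```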

## Transformations

GPT (especially GPT-4) is relatively good at formatting responses as JSON, but it's not perfect. Some of the more common issues are:

- *Response truncation*: Since GPT is not internally aware of its response length limit, JSON payloads will sometimes exhaust the available token space. This results in a broken JSON payload where much of the data is valid but the JSON object is not closed, which is not valid syntax. There are many cases where this behavior is actually okay for production applications - for instance, if you list 100 generated strings, it's sometimes okay for you to take the 70 that actually rendered. In this case, `gpt-json` will attempt to fix the truncated payload by recreating the JSON object and closing it.
- *Boolean variables*: GPT will sometimes confuse valid JSON boolean values with the boolean tokens that are used in other languages. The most common is generating `True` instead of `true`. `gpt-json` will attempt to fix these values.
- _Response truncation_: Since GPT is not internally aware of its response length limit, JSON payloads will sometimes exhaust the available token space. This results in a broken JSON payload where much of the data is valid but the JSON object is not closed, which is not valid syntax. There are many cases where this behavior is actually okay for production applications - for instance, if you list 100 generated strings, it's sometimes okay for you to take the 70 that actually rendered. In this case, `gpt-json` will attempt to fix the truncated payload by recreating the JSON object and closing it.
- _Boolean variables_: GPT will sometimes confuse valid JSON boolean values with the boolean tokens that are used in other languages. The most common is generating `True` instead of `true`. `gpt-json` will attempt to fix these values.

When calling `gpt_json.run()`, we return a tuple of values:

@@ -225,7 +227,7 @@ FixTransforms(fixed_truncation=True, fixed_bools=False)

The first object is your generated Pydantic model. The second object is our correction storage object `FixTransforms`. This dataclass contains flags for each of the supported transformation cases that are sketched out above. This allows you to determine whether the response was explicitly parsed from the GPT JSON, or was passed through some middle layers to get a correct output. From there you can accept or reject the response based on your own business logic.
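
For example, a sketch of gating on those flags (it assumes the tuple return described above; the retry policy is only illustrative):

```python
response, transforms = await gpt.run(messages=messages)

# Tolerate simple boolean-casing repairs, but treat a truncated payload as a failure:
if transforms.fixed_truncation:
    raise ValueError("Response was truncated; retry with a shorter prompt or a larger budget")

use_response(response)  # hypothetical downstream handler
```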

*Where you can help*: There are certainly more areas of common (and not-so-common failures). If you see these, please add a test case to the unit tests. If you can write a handler to help solve the general case, please do so. Otherwise flag it as a `pytest.xfail` and we'll add it to the backlog.
_Where you can help_: There are certainly more areas of common (and not-so-common) failures. If you see these, please add a test case to the unit tests. If you can write a handler to help solve the general case, please do so. Otherwise flag it as a `pytest.xfail` and we'll add it to the backlog.

## Testing

26 changes: 12 additions & 14 deletions gpt_json/gpt.py
@@ -51,7 +51,7 @@
prepare_streaming_object,
)
from gpt_json.transformations import fix_bools, fix_truncated_json
from gpt_json.truncation import num_tokens_from_messages, truncate_tokens
from gpt_json.truncation import truncate_tokens
from gpt_json.types_oai import ChatCompletionChunk

logger = logging.getLogger("gptjson_logger")
@@ -511,19 +511,17 @@ def fill_messages(
)

# fill the messages without the target variable to calculate the "space" we have left
enc = encoding_for_model("gpt-4")
format_variables_no_target = format_variables.copy()
format_variables_no_target[truncation_options.target_variable] = ""
target_variable_max_tokens = (
truncation_options.max_prompt_tokens
- num_tokens_from_messages(
[
self.message_to_dict(
self.fill_message_template(message, format_variables_no_target)
)
for message in messages
],
self.model.api_name,
)
target_variable_max_tokens = truncation_options.max_prompt_tokens - sum(
[
len(enc.encode(self.get_content_text(new_message)))
for message in messages
for new_message in [
self.fill_message_template(message, format_variables_no_target)
]
],
)

if target_variable_max_tokens < 0:
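
The new budget calculation counts tokens with `tiktoken` directly; a standalone sketch of the same arithmetic (the prompt strings and budget are placeholders):

```python
import tiktoken

enc = tiktoken.encoding_for_model("gpt-4")

max_prompt_tokens = 4000  # placeholder budget
# Message templates rendered with the truncatable variable blanked out:
filled_without_target = [
    "You are a helpful assistant.",
    "Summarize the following text:\n\n",
]

# Whatever the fixed parts of the prompt do not consume is the room left
# for the target variable before truncation kicks in.
used = sum(len(enc.encode(text)) for text in filled_without_target)
target_variable_max_tokens = max_prompt_tokens - used
print(target_variable_max_tokens)
```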
@@ -588,7 +586,7 @@ def fill_message_template(
return new_message

def message_to_dict(self, message: GPTMessage):
obj = json_loads(message.model_dump_json(exclude_unset=True))
obj = json_loads(message.model_dump_json(by_alias=True, exclude_unset=True))
obj.pop("allow_templating", None)
return obj
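
`by_alias=True` is what makes aliased Pydantic fields serialize under their wire names; a minimal standalone illustration (this model is hypothetical, not one from the library):

```python
from pydantic import BaseModel, Field


class ImagePart(BaseModel):
    image_url: str = Field(alias="imageUrl")  # hypothetical aliased field


part = ImagePart(imageUrl="https://example.com/cat.png")

print(part.model_dump_json(exclude_unset=True))
# {"image_url":"https://example.com/cat.png"}  <- Python field name

print(part.model_dump_json(by_alias=True, exclude_unset=True))
# {"imageUrl":"https://example.com/cat.png"}   <- alias, matching the wire format
```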

Expand Down Expand Up @@ -677,7 +675,7 @@ def get_model_metadata(
# Determine if the deprecation date is specified, should re-raise a deprecation warning
if model.value.deprecated_date is not None:
deprecation_date = model.value.deprecated_date
if deprecation_date < datetime.now():
if deprecation_date < datetime.now().date():
raise ValueError(
f"Model {model.value.api_name} is deprecated as of {deprecation_date}."
)
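
The `.date()` call matters because Python refuses to order a `date` against a `datetime`; a quick illustration:

```python
from datetime import date, datetime

deprecated_date = date(2024, 6, 13)  # hypothetical deprecation date

# deprecated_date < datetime.now()  # raises TypeError: a date cannot be ordered against a datetime
print(deprecated_date < datetime.now().date())  # date-to-date comparison is well defined
```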
9 changes: 5 additions & 4 deletions gpt_json/models.py
@@ -1,6 +1,6 @@
import sys
from base64 import b64encode
from dataclasses import dataclass
from dataclasses import dataclass, replace
from datetime import date
from enum import Enum, unique
from typing import Callable, Literal
@@ -42,6 +42,7 @@ class ModelVersionParams:
api_name: str
max_length: int
deprecated_date: date | None = None
archived: bool = False


@unique
@@ -66,8 +67,8 @@ class GPTModelVersion(Enum):

# Deprecated internally - switch to explicit model revisions
# Kept for reverse compatibility
GPT_3_5 = GPT_3_5_0613
GPT_4 = GPT_4_0613
GPT_3_5 = replace(GPT_3_5_0613, archived=True) # type: ignore
GPT_4 = replace(GPT_4_0613, archived=True) # type: ignore
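
`dataclasses.replace` copies an instance while overriding selected fields, which is how the legacy enum names get their own archived values instead of becoming plain aliases; a small sketch with invented values:

```python
from dataclasses import dataclass, replace


@dataclass
class Params:
    api_name: str
    max_length: int
    archived: bool = False


base = Params("model-0613", 8192)      # invented values
legacy = replace(base, archived=True)  # copy of base with one field overridden

print(base.archived, legacy.archived)  # False True
```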


@unique
@@ -103,7 +104,7 @@ class ImageUrl(BaseModel):
url: HttpUrl | str

@field_validator("url", mode="after")
def validate_url(self, value):
def validate_url(cls, value):
if isinstance(value, HttpUrl):
return value
elif isinstance(value, str):
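
Pydantic v2 treats `field_validator` methods as classmethods, so the first parameter is the class rather than an instance; a minimal sketch on a hypothetical model:

```python
from pydantic import BaseModel, field_validator


class Image(BaseModel):
    url: str

    @field_validator("url", mode="after")
    def validate_url(cls, value):  # first argument is the class, not `self`
        if not value.startswith(("http://", "https://", "data:")):
            raise ValueError("url must be an http(s) or data URL")
        return value


print(Image(url="https://example.com/cat.png"))
```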
14 changes: 10 additions & 4 deletions gpt_json/tests/test_gpt.py
@@ -188,11 +188,14 @@ async def test_create(

# Assert that the mock function was called with the expected parameters
mock_client.return_value.chat.completions.create.assert_called_with(
model=model_version.value,
model=model_version.value.api_name,
messages=[
{
"role": message.role.value,
"content": message.content,
"content": [
content.model_dump(by_alias=True, exclude_unset=True)
for content in message.get_content_payloads()
],
}
for message in messages
],
@@ -254,11 +257,14 @@ async def test_create_with_function_calls():
response = await model.run(messages=messages)

mock_client.return_value.chat.completions.create.assert_called_with(
model=model_version.value,
model=model_version.value.api_name,
messages=[
{
"role": message.role.value,
"content": message.content,
"content": [
content.model_dump(by_alias=True, exclude_unset=True)
for content in message.get_content_payloads()
],
}
for message in messages
],
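
These updated assertions reflect that message content is now serialized as a list of typed content parts rather than a bare string; a hedged sketch of the expected request shape (the exact part fields follow the OpenAI chat format and are an assumption here):

```python
# Shape the mocked chat.completions.create call is expected to receive:
expected_messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Tell me a joke."},
            # An image part would serialize under its aliased keys, e.g.:
            # {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
        ],
    }
]
```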
26 changes: 0 additions & 26 deletions gpt_json/tests/test_models.py
@@ -1,34 +1,8 @@
import sys
from enum import Enum

import pytest

import gpt_json.models as models_file
from gpt_json.models import GPTMessage, GPTMessageRole


@pytest.mark.parametrize("model_file", [models_file])
def test_string_enums(model_file):
if sys.version_info < (3, 11):
pytest.skip("Only Python 3.11+ has native support for string-based enums")
return
else:
from enum import StrEnum

found_enums = 0
for obj in model_file.__dict__.values():
if (
isinstance(obj, type)
and issubclass(obj, Enum)
and not obj in {Enum, StrEnum}
):
found_enums += 1
assert issubclass(obj, StrEnum), f"{obj} is not a StrEnum"

# Every file listed in pytest should have at least one enum
assert found_enums > 0, f"No enums found in {model_file}"


def test_gpt_message_validates_function():
with pytest.raises(ValueError):
GPTMessage(
7 changes: 5 additions & 2 deletions gpt_json/tests/test_streaming.py
@@ -135,11 +135,14 @@ async def async_list_to_generator(my_list):

# Assert that the mock function was called with the expected parameters, including streaming
mock_client.return_value.chat.completions.create.assert_called_with(
model=model_version.value,
model=model_version.value.api_name,
messages=[
{
"role": message.role.value,
"content": message.content,
"content": [
content.model_dump(by_alias=True, exclude_unset=True)
for content in message.get_content_payloads()
],
}
for message in messages
],
9 changes: 1 addition & 8 deletions gpt_json/tests/test_truncation.py
@@ -12,14 +12,7 @@
TruncationOptions,
VariableTruncationMode,
)
from gpt_json.truncation import num_tokens_from_messages, truncate_tokens


@pytest.mark.parametrize("model", [model.value for model in GPTModelVersion])
def test_num_tokens_implemented(model):
# no need to assert anything specific, just that its implemented for all models
# i.e. doesn't throw an error
num_tokens_from_messages([], model)
from gpt_json.truncation import truncate_tokens


def test_fill_messages_truncated():
46 changes: 1 addition & 45 deletions gpt_json/truncation.py
@@ -3,7 +3,7 @@

import tiktoken

from gpt_json.models import GPTModelVersion, VariableTruncationMode
from gpt_json.models import VariableTruncationMode


def tokenize(text: str, model: str) -> list[int]:
@@ -16,50 +16,6 @@ def decode(tokens: list[int], model: str) -> str:
return enc.decode(tokens)


def gpt_message_markup_v1(messages: list[dict[str, str]], model: str) -> int:
"""
Converts a list of messages into the number of tokens used by the model, following the
markup rules for GPT-3.5 and GPT-4 defined here: https://platform.openai.com/docs/guides/chat/managing-tokens.
"""
encoding = tiktoken.encoding_for_model(model)

num_tokens = 0
for message in messages:
num_tokens += (
4 # every message follows <im_start>{role/name}\n{content}<im_end>\n
)
for key, value in message.items():
num_tokens += len(encoding.encode(value))
if key == "name": # if there's a name, the role is omitted
num_tokens += -1 # role is always required and always 1 token
num_tokens += 2 # every reply is primed with <im_start>assistant
return num_tokens


MODEL_MESSAGE_MARKUP = {
# For now all the models have the same tokenization strategy
enum_value.value.api_name: gpt_message_markup_v1
for enum_value in [
GPTModelVersion.GPT_4_0613,
GPTModelVersion.GPT_4_32K_0613,
GPTModelVersion.GPT_3_5_0613,
GPTModelVersion.GPT_3_5_1106,
GPTModelVersion.GPT_3_5_0125,
]
}


def num_tokens_from_messages(messages: list[dict[str, str]], model: str) -> int:
"""Returns the number of tokens used by a list of messages.
NOTE: for future models, there may be structural changes to how messages are converted into content
that affect the number of tokens. Future models should be added to MODEL_MESSAGE_MARKUP.
"""
if model not in MODEL_MESSAGE_MARKUP:
raise NotImplementedError(f"Model {model} message markup not implemented")
return MODEL_MESSAGE_MARKUP[model](messages, model)


def truncate_tokens(
text: str,
model: str,
