Huggingface Chat Completion SUT produces logprobs (#630)
* chat completions sut produces logprobs

* update uid gemma-9b -> gemma-2-9b

* unit tests for logprobs
bkorycki authored Oct 23, 2024
1 parent 3c59dff commit fc8ecf8
Showing 3 changed files with 123 additions and 23 deletions.
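Net effect of the change: when a prompt's SUTOptions sets top_logprobs, the SUT turns on the chat_completion API's logprobs flag and translates the per-token results into modelgauge's TopTokens/TokenProbability structures. A minimal sketch of the new flow (assuming an already-constructed HuggingFaceChatCompletionSUT instance named sut; the other names mirror the diff below):

from modelgauge.prompt import SUTOptions, TextPrompt

prompt = TextPrompt(
    text="some text prompt",
    options=SUTOptions(max_tokens=5, top_logprobs=2),  # setting top_logprobs opts in to logprobs
)
request = sut.translate_text_prompt(prompt)         # request.logprobs == True
output = sut.evaluate(request)                      # hits the endpoint's chat_completion API
response = sut.translate_response(request, output)
# response.completions[0].top_logprobs holds one TopTokens (with 2 TokenProbability
# entries) per generated token; it is None when top_logprobs was not requested.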
5 changes: 2 additions & 3 deletions plugins/huggingface/modelgauge/suts/huggingface_api.py
@@ -1,6 +1,4 @@
from typing import Optional

import requests
import requests # type: ignore
from huggingface_hub import ( # type: ignore
ChatCompletionOutput,
get_inference_endpoint,
@@ -9,6 +7,7 @@
)
from huggingface_hub.utils import HfHubHTTPError # type: ignore
from pydantic import BaseModel
from typing import Optional

from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
from modelgauge.prompt import TextPrompt
27 changes: 22 additions & 5 deletions plugins/huggingface/modelgauge/suts/huggingface_chat_completion.py
@@ -12,8 +12,8 @@
from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
from modelgauge.prompt import TextPrompt
from modelgauge.secret_values import InjectSecret
from modelgauge.sut import PromptResponseSUT, SUTCompletion, SUTResponse
from modelgauge.sut_capabilities import AcceptsTextPrompt
from modelgauge.sut import PromptResponseSUT, SUTCompletion, SUTResponse, TokenProbability, TopTokens
from modelgauge.sut_capabilities import AcceptsTextPrompt, ProducesPerTokenLogProbabilities
from modelgauge.sut_decorator import modelgauge_sut
from modelgauge.sut_registry import SUTS

@@ -27,12 +27,14 @@ class ChatMessage(BaseModel):

class HuggingFaceChatCompletionRequest(BaseModel):
messages: List[ChatMessage]
logprobs: bool
top_logprobs: Optional[int] = None
max_tokens: Optional[int] = None
temperature: Optional[float] = None
top_p: Optional[float] = None


@modelgauge_sut(capabilities=[AcceptsTextPrompt])
@modelgauge_sut(capabilities=[AcceptsTextPrompt, ProducesPerTokenLogProbabilities])
class HuggingFaceChatCompletionSUT(PromptResponseSUT[HuggingFaceChatCompletionRequest, ChatCompletionOutput]):
"""A Hugging Face SUT that is hosted on a dedicated inference endpoint and uses the chat_completion API."""

@@ -68,8 +70,12 @@ def _create_client(self):
self.client = InferenceClient(base_url=endpoint.url, token=self.token.value)

def translate_text_prompt(self, prompt: TextPrompt) -> HuggingFaceChatCompletionRequest:
logprobs = False
if prompt.options.top_logprobs is not None:
logprobs = True
return HuggingFaceChatCompletionRequest(
messages=[ChatMessage(role="user", content=prompt.text)],
logprobs=logprobs,
**prompt.options.model_dump(),
)

@@ -87,15 +93,26 @@ def translate_response(
for choice in response.choices:
text = choice.message.content
assert text is not None
completions.append(SUTCompletion(text=text))
logprobs: Optional[List[TopTokens]] = None
if request.logprobs:
logprobs = []
assert choice.logprobs is not None, "Expected logprobs, but none were returned."
logprobs_sequence = choice.logprobs.content
for token in logprobs_sequence:
top_tokens = []
for top_logprob in token.top_logprobs:
top_tokens.append(TokenProbability(token=top_logprob.token, logprob=top_logprob.logprob))
logprobs.append(TopTokens(top_tokens=top_tokens))

completions.append(SUTCompletion(text=text, top_logprobs=logprobs))
return SUTResponse(completions=completions)


HF_SECRET = InjectSecret(HuggingFaceInferenceToken)

SUTS.register(
HuggingFaceChatCompletionSUT,
"gemma-9b-it-hf",
"gemma-2-9b-it-hf",
"gemma-2-9b-it-qfa",
HF_SECRET,
)
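The registration above exposes the same SUT class under two endpoint UIDs. A hedged sketch of instantiating one through the registry (SUTS.make_instance and load_secrets_from_config follow modelgauge's usual conventions but are assumptions here, not shown in this diff):

from modelgauge.config import load_secrets_from_config  # assumed helper, not in this diff
from modelgauge.sut_registry import SUTS

secrets = load_secrets_from_config()  # assumed to supply the HuggingFaceInferenceToken
sut = SUTS.make_instance("gemma-2-9b-it-hf", secrets=secrets)  # assumed registry API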
114 changes: 99 additions & 15 deletions plugins/huggingface/tests/test_huggingface_chat_completion.py
@@ -1,18 +1,18 @@
from unittest.mock import Mock, patch

import pytest
from huggingface_hub import InferenceEndpointStatus # type: ignore
from huggingface_hub import ChatCompletionOutputLogprob, ChatCompletionOutputLogprobs, ChatCompletionOutputTopLogprob, InferenceEndpointStatus # type: ignore
from huggingface_hub.utils import HfHubHTTPError # type: ignore
from pydantic import BaseModel
from typing import Optional
from unittest.mock import Mock, patch

from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
from modelgauge.prompt import SUTOptions, TextPrompt
from modelgauge.sut import SUTCompletion, SUTResponse
from modelgauge.sut import SUTCompletion, SUTResponse, TokenProbability, TopTokens
from modelgauge.suts.huggingface_chat_completion import (
ChatMessage,
HuggingFaceChatCompletionRequest,
HuggingFaceChatCompletionSUT,
)
from pydantic import BaseModel


@pytest.fixture
@@ -32,28 +32,37 @@ def fake_sut(mock_get_inference_endpoint, mock_endpoint):
return sut


@pytest.fixture
def prompt():
def _make_prompt(top_logprobs=None):
extra_options = {}
if top_logprobs is not None:
extra_options["top_logprobs"] = top_logprobs
return TextPrompt(
text="some text prompt",
options=SUTOptions(max_tokens=5, temperature=1.0, random="random"),
options=SUTOptions(max_tokens=5, temperature=1.0, random="random", **extra_options),
)


@pytest.fixture
def sut_request():
def _make_sut_request(top_logprobs: Optional[int] = None):
extra_options = {}
if top_logprobs is not None:
extra_options["top_logprobs"] = top_logprobs
return HuggingFaceChatCompletionRequest(
messages=[ChatMessage(role="user", content="some text prompt")],
logprobs=top_logprobs is not None,
max_tokens=5,
temperature=1.0,
**extra_options,
)


def test_huggingface_chat_completion_translate_text_prompt_request(fake_sut, prompt, sut_request):
@pytest.mark.parametrize("top_logprobs", [None, 2])
def test_huggingface_chat_completion_translate_text_prompt_request(fake_sut, top_logprobs):
prompt = _make_prompt(top_logprobs)

request = fake_sut.translate_text_prompt(prompt)

assert isinstance(request, HuggingFaceChatCompletionRequest)
assert request == sut_request
assert request == _make_sut_request(top_logprobs)


@pytest.mark.parametrize(
@@ -113,8 +122,9 @@ def test_huggingface_chat_completion_connect_failed_endpoint(mock_get_inference_
@patch("modelgauge.suts.huggingface_chat_completion.get_inference_endpoint")
@patch("modelgauge.suts.huggingface_chat_completion.InferenceClient")
def test_huggingface_chat_completion_lazy_load_client(
mock_client, mock_get_inference_endpoint, fake_sut, mock_endpoint, sut_request
mock_client, mock_get_inference_endpoint, fake_sut, mock_endpoint
):
sut_request = _make_sut_request()
mock_get_inference_endpoint.return_value = mock_endpoint
assert fake_sut.client is None

@@ -125,14 +135,34 @@ def test_huggingface_chat_completion_lazy_load_client(


@patch("modelgauge.suts.huggingface_chat_completion.InferenceClient")
def test_huggingface_chat_completion_evaluate(mock_client, fake_sut, sut_request):
def test_huggingface_chat_completion_evaluate(mock_client, fake_sut):
sut_request = _make_sut_request()
fake_sut.client = mock_client

fake_sut.evaluate(sut_request)

mock_client.chat_completion.assert_called_with(
**{
"messages": [{"content": "some text prompt", "role": "user"}],
"logprobs": False,
"max_tokens": 5,
"temperature": 1.0,
}
)


@patch("modelgauge.suts.huggingface_chat_completion.InferenceClient")
def test_huggingface_chat_completion_evaluate_with_logprobs(mock_client, fake_sut):
sut_request = _make_sut_request(top_logprobs=2)
fake_sut.client = mock_client

fake_sut.evaluate(sut_request)

mock_client.chat_completion.assert_called_with(
**{
"messages": [{"content": "some text prompt", "role": "user"}],
"logprobs": True,
"top_logprobs": 2,
"max_tokens": 5,
"temperature": 1.0,
}
@@ -141,15 +171,69 @@ def test_huggingface_chat_completion_evaluate(mock_client, fake_sut, sut_request

class FakeChoice(BaseModel):
message: ChatMessage
logprobs: Optional[ChatCompletionOutputLogprobs] = None


class FakeResponse(BaseModel):
choices: list[FakeChoice]


def test_huggingface_chat_completion_translate_response(fake_sut, sut_request):
def test_huggingface_chat_completion_translate_response(fake_sut):
sut_request = _make_sut_request()
evaluate_output = FakeResponse(choices=[FakeChoice(message=ChatMessage(content="response", role="assistant"))])

response = fake_sut.translate_response(sut_request, evaluate_output)

assert response == SUTResponse(completions=[SUTCompletion(text="response")])


def test_huggingface_chat_completion_translate_response_with_logprobs(fake_sut):
sut_request = _make_sut_request(top_logprobs=2)
logprobs_output = ChatCompletionOutputLogprobs(
content=[
ChatCompletionOutputLogprob(
token="hello",
logprob=-0.1,
top_logprobs=[
ChatCompletionOutputTopLogprob(token="hello", logprob=-0.1),
ChatCompletionOutputTopLogprob(token="hola", logprob=-0.2),
],
),
ChatCompletionOutputLogprob(
token="world",
logprob=-0.3,
top_logprobs=[
ChatCompletionOutputTopLogprob(token="world", logprob=-0.3),
ChatCompletionOutputTopLogprob(token="peeps", logprob=-0.4),
],
),
]
)

evaluate_output = FakeResponse(
choices=[FakeChoice(message=ChatMessage(content="hello world", role="assistant"), logprobs=logprobs_output)]
)

response = fake_sut.translate_response(sut_request, evaluate_output)

assert response == SUTResponse(
completions=[
SUTCompletion(
text="hello world",
top_logprobs=[
TopTokens(
top_tokens=[
TokenProbability(token="hello", logprob=-0.1),
TokenProbability(token="hola", logprob=-0.2),
]
),
TopTokens(
top_tokens=[
TokenProbability(token="world", logprob=-0.3),
TokenProbability(token="peeps", logprob=-0.4),
]
),
],
)
]
)
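A note on reading the expected values (background, not part of the diff): a logprob is the natural log of a token's probability, so the fake data above decodes as follows.

import math

math.exp(-0.1)  # ≈ 0.905, the probability attached to "hello"
math.exp(-0.2)  # ≈ 0.819, the probability attached to the runner-up "hola"

The new tests run in isolation with: pytest plugins/huggingface/tests/test_huggingface_chat_completion.py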
