Add olmo and qwen suts #601

Merged: 5 commits on Oct 17, 2024
plugins/huggingface/modelgauge/suts/huggingface_api.py (new file, 75 additions)
@@ -0,0 +1,75 @@
import requests
from typing import Optional

from pydantic import BaseModel

from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
from modelgauge.prompt import TextPrompt
from modelgauge.secret_values import InjectSecret
from modelgauge.sut import PromptResponseSUT, SUTCompletion, SUTResponse
from modelgauge.sut_capabilities import AcceptsTextPrompt
from modelgauge.sut_decorator import modelgauge_sut
from modelgauge.sut_registry import SUTS


class HuggingFaceChatParams(BaseModel):
    max_new_tokens: Optional[int] = None
    temperature: Optional[float] = None


class HuggingFaceChatRequest(BaseModel):
    inputs: str
    parameters: HuggingFaceChatParams


class HuggingFaceResponse(BaseModel):
    generated_text: str


@modelgauge_sut(capabilities=[AcceptsTextPrompt])
class HuggingFaceSUT(PromptResponseSUT[HuggingFaceChatRequest, HuggingFaceResponse]):
    """A Hugging Face SUT that is hosted on a dedicated inference endpoint."""

    def __init__(self, uid: str, api_url: str, token: HuggingFaceInferenceToken):
        super().__init__(uid)
        self.token = token.value
        self.api_url = api_url

    def translate_text_prompt(self, prompt: TextPrompt) -> HuggingFaceChatRequest:
        return HuggingFaceChatRequest(
            inputs=prompt.text,
            parameters=HuggingFaceChatParams(
                max_new_tokens=prompt.options.max_tokens, temperature=prompt.options.temperature
            ),
        )

    def evaluate(self, request: HuggingFaceChatRequest) -> HuggingFaceResponse:
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {self.token}",
            "Content-Type": "application/json",
        }
        payload = request.model_dump(exclude_none=True)
        response = requests.post(self.api_url, headers=headers, json=payload)
        response.raise_for_status()  # fail loudly on HTTP errors before parsing the body
        response_json = response.json()[0]
        return HuggingFaceResponse(**response_json)

    def translate_response(self, request: HuggingFaceChatRequest, response: HuggingFaceResponse) -> SUTResponse:
        return SUTResponse(completions=[SUTCompletion(text=response.generated_text)])


HF_SECRET = InjectSecret(HuggingFaceInferenceToken)

SUTS.register(
    HuggingFaceSUT,
    "olmo-7b-0724-instruct-hf",
    "https://flakwttqzmq493dw.us-east-1.aws.endpoints.huggingface.cloud",
    HF_SECRET,
)
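For reviewers who want to poke at the new SUT locally, here is a minimal round-trip sketch. It is not part of the diff; the prompt text, option values, and the placeholder token are illustrative only.

from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
from modelgauge.prompt import SUTOptions, TextPrompt
from modelgauge.suts.huggingface_api import HuggingFaceSUT

# "hf_..." stands in for a real Hugging Face inference token.
sut = HuggingFaceSUT(
    "olmo-7b-0724-instruct-hf",
    "https://flakwttqzmq493dw.us-east-1.aws.endpoints.huggingface.cloud",
    HuggingFaceInferenceToken("hf_..."),
)
prompt = TextPrompt(text="Say hello.", options=SUTOptions(max_tokens=32, temperature=0.7))
request = sut.translate_text_prompt(prompt)  # -> HuggingFaceChatRequest
raw = sut.evaluate(request)  # POSTs to the dedicated endpoint
print(sut.translate_response(request, raw).completions[0].text)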
Review comment (Contributor), on the hardcoded endpoint URL:
@bkorycki to add to @wpietri's comment, this worries me a little. Maybe we could make a mlcommons.org cname to put in the code, and update the cname if the HF endpoint changes? Or in secrets.toml?
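A rough sketch of the secrets.toml variant, assuming a hypothetical OlmoEndpointUrl secret class; the scope and key names below are made up, not part of this PR:

from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription


class OlmoEndpointUrl(RequiredSecret):
    """Hypothetical secret so the endpoint URL lives in secrets.toml instead of source."""

    @classmethod
    def description(cls) -> SecretDescription:
        return SecretDescription(
            scope="hugging_face",
            key="olmo_endpoint_url",
            instructions="URL of the dedicated HF inference endpoint serving olmo.",
        )


# HuggingFaceSUT.__init__ would then accept the secret and unwrap it with .value:
SUTS.register(HuggingFaceSUT, "olmo-7b-0724-instruct-hf", InjectSecret(OlmoEndpointUrl), HF_SECRET)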

Review comment (Contributor):
Long term, I'd like to create a layer that lets a more user-y person say they want olmo-7b-0724-instruct for something, and then lets an ops person say, "we prefer to use huggingface for that and here's how". And presumably that latter bit will be via a configuration file. That would help us, and it would also help other users of our code, because they may want to use a different llamaguard provider, including an internal one, rather than the one we've hardcoded.

But for now updating the source isn't much harder than updating a cname. And a cname only solves the problem for us, not for other people using this code. So I'm fine with just leaving this like this and then fixing it when we have more information about where we need things to flex.
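For concreteness, the indirection described above could start as small as the sketch below; the mapping and names are hypothetical, not something this PR adds.

# Hypothetical user-facing-name -> SUT-uid mapping. Long term this would be
# loaded from an ops-managed configuration file rather than hardcoded.
PREFERRED_PROVIDERS = {
    "olmo-7b-0724-instruct": "olmo-7b-0724-instruct-hf",
    "qwen2-5-7b-instruct": "qwen2-5-7b-instruct-hf",
}


def resolve_sut_uid(model_name: str) -> str:
    """Return the provider-specific SUT uid preferred for a user-facing model name."""
    return PREFERRED_PROVIDERS.get(model_name, model_name)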

plugins/huggingface/modelgauge/suts/huggingface_chat_completion.py (renamed from huggingface_inference.py)
@@ -25,16 +25,16 @@ class ChatMessage(BaseModel):
     role: str


-class HuggingFaceInferenceChatRequest(BaseModel):
+class HuggingFaceChatCompletionRequest(BaseModel):
     messages: List[ChatMessage]
     max_tokens: Optional[int] = None
     temperature: Optional[float] = None
     top_p: Optional[float] = None


 @modelgauge_sut(capabilities=[AcceptsTextPrompt])
-class HuggingFaceInferenceSUT(PromptResponseSUT[HuggingFaceInferenceChatRequest, ChatCompletionOutput]):
-    """A Hugging Face SUT that is hosted on a dedicated inference endpoint."""
+class HuggingFaceChatCompletionSUT(PromptResponseSUT[HuggingFaceChatCompletionRequest, ChatCompletionOutput]):
+    """A Hugging Face SUT that is hosted on a dedicated inference endpoint and uses the chat_completion API."""

     def __init__(self, uid: str, inference_endpoint: str, token: HuggingFaceInferenceToken):
         super().__init__(uid)
@@ -67,21 +67,21 @@ def _create_client(self):

         self.client = InferenceClient(base_url=endpoint.url, token=self.token.value)

-    def translate_text_prompt(self, prompt: TextPrompt) -> HuggingFaceInferenceChatRequest:
-        return HuggingFaceInferenceChatRequest(
+    def translate_text_prompt(self, prompt: TextPrompt) -> HuggingFaceChatCompletionRequest:
+        return HuggingFaceChatCompletionRequest(
             messages=[ChatMessage(role="user", content=prompt.text)],
             **prompt.options.model_dump(),
         )

-    def evaluate(self, request: HuggingFaceInferenceChatRequest) -> ChatCompletionOutput:
+    def evaluate(self, request: HuggingFaceChatCompletionRequest) -> ChatCompletionOutput:
         if self.client is None:
             self._create_client()

         request_dict = request.model_dump(exclude_none=True)
         return self.client.chat_completion(**request_dict)  # type: ignore

     def translate_response(
-        self, request: HuggingFaceInferenceChatRequest, response: ChatCompletionOutput
+        self, request: HuggingFaceChatCompletionRequest, response: ChatCompletionOutput
     ) -> SUTResponse:
         completions = []
         for choice in response.choices:
Expand All @@ -94,15 +94,22 @@ def translate_response(
HF_SECRET = InjectSecret(HuggingFaceInferenceToken)

SUTS.register(
HuggingFaceInferenceSUT,
HuggingFaceChatCompletionSUT,
"gemma-9b-it-hf",
"gemma-2-9b-it-qfa",
HF_SECRET,
)

SUTS.register(
HuggingFaceInferenceSUT,
HuggingFaceChatCompletionSUT,
"mistral-nemo-instruct-2407-hf",
"mistral-nemo-instruct-2407-mgt",
HF_SECRET,
)

SUTS.register(
HuggingFaceChatCompletionSUT,
"qwen2-5-7b-instruct-hf",
"qwen2-5-7b-instruct-hgy",
HF_SECRET,
)
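Once registered, the new SUTs are reachable by uid through the registry. A hedged sketch, assuming modelgauge's load_secrets_from_config helper and the registry's make_instance factory behave here as they do for other plugins:

from modelgauge.config import load_secrets_from_config
from modelgauge.sut_registry import SUTS

secrets = load_secrets_from_config()  # reads the Hugging Face token from secrets.toml
sut = SUTS.make_instance("qwen2-5-7b-instruct-hf", secrets=secrets)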
(test file, renamed to match huggingface_chat_completion)
@@ -7,10 +7,10 @@
 from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
 from modelgauge.prompt import SUTOptions, TextPrompt
 from modelgauge.sut import SUTCompletion, SUTResponse
-from modelgauge.suts.huggingface_inference import (
+from modelgauge.suts.huggingface_chat_completion import (
     ChatMessage,
-    HuggingFaceInferenceChatRequest,
-    HuggingFaceInferenceSUT,
+    HuggingFaceChatCompletionRequest,
+    HuggingFaceChatCompletionSUT,
 )
 from pydantic import BaseModel

Expand All @@ -24,11 +24,11 @@ def mock_endpoint():


@pytest.fixture
@patch("modelgauge.suts.huggingface_inference.get_inference_endpoint")
@patch("modelgauge.suts.huggingface_chat_completion.get_inference_endpoint")
def fake_sut(mock_get_inference_endpoint, mock_endpoint):
mock_get_inference_endpoint.return_value = mock_endpoint

sut = HuggingFaceInferenceSUT("fake_uid", "fake_endpoint", HuggingFaceInferenceToken("fake_token"))
sut = HuggingFaceChatCompletionSUT("fake_uid", "fake_endpoint", HuggingFaceInferenceToken("fake_token"))
return sut


Expand All @@ -42,17 +42,17 @@ def prompt():

@pytest.fixture
def sut_request():
return HuggingFaceInferenceChatRequest(
return HuggingFaceChatCompletionRequest(
messages=[ChatMessage(role="user", content="some text prompt")],
max_tokens=5,
temperature=1.0,
)


def test_huggingface_inference_translate_text_prompt_request(fake_sut, prompt, sut_request):
def test_huggingface_chat_completion_translate_text_prompt_request(fake_sut, prompt, sut_request):
request = fake_sut.translate_text_prompt(prompt)

assert isinstance(request, HuggingFaceInferenceChatRequest)
assert isinstance(request, HuggingFaceChatCompletionRequest)
assert request == sut_request


Expand All @@ -64,17 +64,21 @@ def test_huggingface_inference_translate_text_prompt_request(fake_sut, prompt, s
InferenceEndpointStatus.UPDATING,
],
)
@patch("modelgauge.suts.huggingface_inference.get_inference_endpoint")
def test_huggingface_inference_connect_endpoint(mock_get_inference_endpoint, fake_sut, mock_endpoint, endpoint_status):
@patch("modelgauge.suts.huggingface_chat_completion.get_inference_endpoint")
def test_huggingface_chat_completion_connect_endpoint(
mock_get_inference_endpoint, fake_sut, mock_endpoint, endpoint_status
):
mock_get_inference_endpoint.return_value = mock_endpoint
mock_endpoint.status = endpoint_status

fake_sut._create_client()
mock_endpoint.wait.assert_called_once()


@patch("modelgauge.suts.huggingface_inference.get_inference_endpoint")
def test_huggingface_inference_connect_endpoint_scaled_to_zero(mock_get_inference_endpoint, fake_sut, mock_endpoint):
@patch("modelgauge.suts.huggingface_chat_completion.get_inference_endpoint")
def test_huggingface_chat_completion_connect_endpoint_scaled_to_zero(
mock_get_inference_endpoint, fake_sut, mock_endpoint
):
mock_get_inference_endpoint.return_value = mock_endpoint
mock_endpoint.status = InferenceEndpointStatus.SCALED_TO_ZERO

Expand All @@ -84,8 +88,10 @@ def test_huggingface_inference_connect_endpoint_scaled_to_zero(mock_get_inferenc
mock_endpoint.wait.assert_called_once()


@patch("modelgauge.suts.huggingface_inference.get_inference_endpoint")
def test_huggingface_inference_connect_endpoint_fails_to_resume(mock_get_inference_endpoint, fake_sut, mock_endpoint):
@patch("modelgauge.suts.huggingface_chat_completion.get_inference_endpoint")
def test_huggingface_chat_completion_connect_endpoint_fails_to_resume(
mock_get_inference_endpoint, fake_sut, mock_endpoint
):
mock_get_inference_endpoint.return_value = mock_endpoint
mock_endpoint.status = InferenceEndpointStatus.SCALED_TO_ZERO
mock_endpoint.resume.side_effect = HfHubHTTPError("Failure.")
Expand All @@ -95,18 +101,18 @@ def test_huggingface_inference_connect_endpoint_fails_to_resume(mock_get_inferen
mock_endpoint.wait.assert_not_called()


-@patch("modelgauge.suts.huggingface_inference.get_inference_endpoint")
-def test_huggingface_inference_connect_failed_endpoint(mock_get_inference_endpoint, fake_sut, mock_endpoint):
+@patch("modelgauge.suts.huggingface_chat_completion.get_inference_endpoint")
+def test_huggingface_chat_completion_connect_failed_endpoint(mock_get_inference_endpoint, fake_sut, mock_endpoint):
     mock_get_inference_endpoint.return_value = mock_endpoint
     mock_endpoint.status = InferenceEndpointStatus.FAILED

     with pytest.raises(ConnectionError):
         fake_sut._create_client()


-@patch("modelgauge.suts.huggingface_inference.get_inference_endpoint")
-@patch("modelgauge.suts.huggingface_inference.InferenceClient")
-def test_huggingface_inference_lazy_load_client(
+@patch("modelgauge.suts.huggingface_chat_completion.get_inference_endpoint")
+@patch("modelgauge.suts.huggingface_chat_completion.InferenceClient")
+def test_huggingface_chat_completion_lazy_load_client(
     mock_client, mock_get_inference_endpoint, fake_sut, mock_endpoint, sut_request
 ):
     mock_get_inference_endpoint.return_value = mock_endpoint
Expand All @@ -118,8 +124,8 @@ def test_huggingface_inference_lazy_load_client(
assert fake_sut.client is not None


-@patch("modelgauge.suts.huggingface_inference.InferenceClient")
-def test_huggingface_inference_evaluate(mock_client, fake_sut, sut_request):
+@patch("modelgauge.suts.huggingface_chat_completion.InferenceClient")
+def test_huggingface_chat_completion_evaluate(mock_client, fake_sut, sut_request):
     fake_sut.client = mock_client

     fake_sut.evaluate(sut_request)
Expand All @@ -141,7 +147,7 @@ class FakeResponse(BaseModel):
choices: list[FakeChoice]


-def test_huggingface_inference_translate_response(fake_sut, sut_request):
+def test_huggingface_chat_completion_translate_response(fake_sut, sut_request):
     evaluate_output = FakeResponse(choices=[FakeChoice(message=ChatMessage(content="response", role="assistant"))])

     response = fake_sut.translate_response(sut_request, evaluate_output)
plugins/validation_tests/test_object_creation.py (1 addition, 1 deletion)
@@ -14,7 +14,7 @@
 from modelgauge.sut import PromptResponseSUT, SUTResponse
 from modelgauge.sut_capabilities import AcceptsTextPrompt
 from modelgauge.sut_registry import SUTS
-from modelgauge.suts.huggingface_inference import HUGGING_FACE_TIMEOUT
+from modelgauge.suts.huggingface_chat_completion import HUGGING_FACE_TIMEOUT
 from modelgauge.test_registry import TESTS

 # Ensure all the plugins are available during testing.