diff --git a/plugins/huggingface/modelgauge/suts/huggingface_api.py b/plugins/huggingface/modelgauge/suts/huggingface_api.py
new file mode 100644
index 00000000..782546bc
--- /dev/null
+++ b/plugins/huggingface/modelgauge/suts/huggingface_api.py
@@ -0,0 +1,75 @@
+import requests
+from typing import List, Optional
+
+from huggingface_hub import (  # type: ignore
+    ChatCompletionOutput,
+    get_inference_endpoint,
+    InferenceClient,
+    InferenceEndpointStatus,
+)
+from huggingface_hub.utils import HfHubHTTPError  # type: ignore
+from pydantic import BaseModel
+
+from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
+from modelgauge.prompt import TextPrompt
+from modelgauge.secret_values import InjectSecret
+from modelgauge.sut import PromptResponseSUT, SUTCompletion, SUTResponse
+from modelgauge.sut_capabilities import AcceptsTextPrompt
+from modelgauge.sut_decorator import modelgauge_sut
+from modelgauge.sut_registry import SUTS
+
+
+class HuggingFaceChatParams(BaseModel):
+    max_new_tokens: Optional[int] = None
+    temperature: Optional[float] = None
+
+
+class HuggingFaceChatRequest(BaseModel):
+    inputs: str
+    parameters: HuggingFaceChatParams
+
+
+class HuggingFaceResponse(BaseModel):
+    generated_text: str
+
+
+@modelgauge_sut(capabilities=[AcceptsTextPrompt])
+class HuggingFaceSUT(PromptResponseSUT[HuggingFaceChatRequest, HuggingFaceResponse]):
+    """A Hugging Face SUT that is hosted on a dedicated inference endpoint."""
+
+    def __init__(self, uid: str, api_url: str, token: HuggingFaceInferenceToken):
+        super().__init__(uid)
+        self.token = token.value
+        self.api_url = api_url
+
+    def translate_text_prompt(self, prompt: TextPrompt) -> HuggingFaceChatRequest:
+        return HuggingFaceChatRequest(
+            inputs=prompt.text,
+            parameters=HuggingFaceChatParams(
+                max_new_tokens=prompt.options.max_tokens, temperature=prompt.options.temperature
+            ),
+        )
+
+    def evaluate(self, request: HuggingFaceChatRequest) -> HuggingFaceResponse:
+        headers = {
+            "Accept": "application/json",
+            "Authorization": f"Bearer {self.token}",
+            "Content-Type": "application/json",
+        }
+        payload = request.model_dump(exclude_none=True)
+        response = requests.post(self.api_url, headers=headers, json=payload)
+        response_json = response.json()[0]
+        return HuggingFaceResponse(**response_json)
+
+    def translate_response(self, request: HuggingFaceChatRequest, response: HuggingFaceResponse) -> SUTResponse:
+        return SUTResponse(completions=[SUTCompletion(text=response.generated_text)])
+
+
+HF_SECRET = InjectSecret(HuggingFaceInferenceToken)
+
+SUTS.register(
+    HuggingFaceSUT,
+    "olmo-7b-0724-instruct-hf",
+    "https://flakwttqzmq493dw.us-east-1.aws.endpoints.huggingface.cloud",
+    HF_SECRET,
+)
diff --git a/plugins/huggingface/modelgauge/suts/huggingface_inference.py b/plugins/huggingface/modelgauge/suts/huggingface_chat_completion.py
similarity index 84%
rename from plugins/huggingface/modelgauge/suts/huggingface_inference.py
rename to plugins/huggingface/modelgauge/suts/huggingface_chat_completion.py
index 4d41fb98..56cc5a9e 100644
--- a/plugins/huggingface/modelgauge/suts/huggingface_inference.py
+++ b/plugins/huggingface/modelgauge/suts/huggingface_chat_completion.py
@@ -25,7 +25,7 @@ class ChatMessage(BaseModel):
     role: str
 
 
-class HuggingFaceInferenceChatRequest(BaseModel):
+class HuggingFaceChatCompletionRequest(BaseModel):
    messages: List[ChatMessage]
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
@@ -33,8 +33,8 @@
 
 
 @modelgauge_sut(capabilities=[AcceptsTextPrompt])
-class HuggingFaceInferenceSUT(PromptResponseSUT[HuggingFaceInferenceChatRequest, ChatCompletionOutput]):
-    """A Hugging Face SUT that is hosted on a dedicated inference endpoint."""
+class HuggingFaceChatCompletionSUT(PromptResponseSUT[HuggingFaceChatCompletionRequest, ChatCompletionOutput]):
+    """A Hugging Face SUT that is hosted on a dedicated inference endpoint and uses the chat_completion API."""
 
     def __init__(self, uid: str, inference_endpoint: str, token: HuggingFaceInferenceToken):
         super().__init__(uid)
@@ -67,13 +67,13 @@ def _create_client(self):
 
         self.client = InferenceClient(base_url=endpoint.url, token=self.token.value)
 
-    def translate_text_prompt(self, prompt: TextPrompt) -> HuggingFaceInferenceChatRequest:
-        return HuggingFaceInferenceChatRequest(
+    def translate_text_prompt(self, prompt: TextPrompt) -> HuggingFaceChatCompletionRequest:
+        return HuggingFaceChatCompletionRequest(
             messages=[ChatMessage(role="user", content=prompt.text)],
             **prompt.options.model_dump(),
         )
 
-    def evaluate(self, request: HuggingFaceInferenceChatRequest) -> ChatCompletionOutput:
+    def evaluate(self, request: HuggingFaceChatCompletionRequest) -> ChatCompletionOutput:
         if self.client is None:
             self._create_client()
 
@@ -81,7 +81,7 @@ def evaluate(self, request: HuggingFaceInferenceChatRequest) -> ChatCompletionOu
         return self.client.chat_completion(**request_dict)  # type: ignore
 
     def translate_response(
-        self, request: HuggingFaceInferenceChatRequest, response: ChatCompletionOutput
+        self, request: HuggingFaceChatCompletionRequest, response: ChatCompletionOutput
     ) -> SUTResponse:
         completions = []
         for choice in response.choices:
@@ -94,15 +94,22 @@ def translate_response(
 
 HF_SECRET = InjectSecret(HuggingFaceInferenceToken)
 
 SUTS.register(
-    HuggingFaceInferenceSUT,
+    HuggingFaceChatCompletionSUT,
     "gemma-9b-it-hf",
     "gemma-2-9b-it-qfa",
     HF_SECRET,
 )
 
 SUTS.register(
-    HuggingFaceInferenceSUT,
+    HuggingFaceChatCompletionSUT,
     "mistral-nemo-instruct-2407-hf",
     "mistral-nemo-instruct-2407-mgt",
     HF_SECRET,
 )
+
+SUTS.register(
+    HuggingFaceChatCompletionSUT,
+    "qwen2-5-7b-instruct-hf",
+    "qwen2-5-7b-instruct-hgy",
+    HF_SECRET,
+)
diff --git a/plugins/huggingface/tests/test_huggingface_inference.py b/plugins/huggingface/tests/test_huggingface_chat_completion.py
similarity index 65%
rename from plugins/huggingface/tests/test_huggingface_inference.py
rename to plugins/huggingface/tests/test_huggingface_chat_completion.py
index 2dec0040..7b2aeb05 100644
--- a/plugins/huggingface/tests/test_huggingface_inference.py
+++ b/plugins/huggingface/tests/test_huggingface_chat_completion.py
@@ -7,10 +7,10 @@
 from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
 from modelgauge.prompt import SUTOptions, TextPrompt
 from modelgauge.sut import SUTCompletion, SUTResponse
-from modelgauge.suts.huggingface_inference import (
+from modelgauge.suts.huggingface_chat_completion import (
     ChatMessage,
-    HuggingFaceInferenceChatRequest,
-    HuggingFaceInferenceSUT,
+    HuggingFaceChatCompletionRequest,
+    HuggingFaceChatCompletionSUT,
 )
 from pydantic import BaseModel
 
@@ -24,11 +24,11 @@ def mock_endpoint():
 
 
 @pytest.fixture
-@patch("modelgauge.suts.huggingface_inference.get_inference_endpoint")
+@patch("modelgauge.suts.huggingface_chat_completion.get_inference_endpoint")
 def fake_sut(mock_get_inference_endpoint, mock_endpoint):
     mock_get_inference_endpoint.return_value = mock_endpoint
-    sut = HuggingFaceInferenceSUT("fake_uid", "fake_endpoint", HuggingFaceInferenceToken("fake_token"))
+    sut = HuggingFaceChatCompletionSUT("fake_uid", "fake_endpoint", HuggingFaceInferenceToken("fake_token"))
     return sut
 
 
@@ -42,17 +42,17 @@ def prompt():
 
 @pytest.fixture
 def sut_request():
-    return HuggingFaceInferenceChatRequest(
+    return HuggingFaceChatCompletionRequest(
         messages=[ChatMessage(role="user", content="some text prompt")],
         max_tokens=5,
         temperature=1.0,
     )
 
 
-def test_huggingface_inference_translate_text_prompt_request(fake_sut, prompt, sut_request):
+def test_huggingface_chat_completion_translate_text_prompt_request(fake_sut, prompt, sut_request):
     request = fake_sut.translate_text_prompt(prompt)
 
-    assert isinstance(request, HuggingFaceInferenceChatRequest)
+    assert isinstance(request, HuggingFaceChatCompletionRequest)
     assert request == sut_request
 
 
@@ -64,8 +64,10 @@ def test_huggingface_inference_translate_text_prompt_request(fake_sut, prompt, s
         InferenceEndpointStatus.UPDATING,
     ],
 )
-@patch("modelgauge.suts.huggingface_inference.get_inference_endpoint")
-def test_huggingface_inference_connect_endpoint(mock_get_inference_endpoint, fake_sut, mock_endpoint, endpoint_status):
+@patch("modelgauge.suts.huggingface_chat_completion.get_inference_endpoint")
+def test_huggingface_chat_completion_connect_endpoint(
+    mock_get_inference_endpoint, fake_sut, mock_endpoint, endpoint_status
+):
     mock_get_inference_endpoint.return_value = mock_endpoint
     mock_endpoint.status = endpoint_status
 
@@ -73,8 +75,10 @@ def test_huggingface_inference_connect_endpoint(mock_get_inference_endpoint, fak
     mock_endpoint.wait.assert_called_once()
 
 
-@patch("modelgauge.suts.huggingface_inference.get_inference_endpoint")
-def test_huggingface_inference_connect_endpoint_scaled_to_zero(mock_get_inference_endpoint, fake_sut, mock_endpoint):
+@patch("modelgauge.suts.huggingface_chat_completion.get_inference_endpoint")
+def test_huggingface_chat_completion_connect_endpoint_scaled_to_zero(
+    mock_get_inference_endpoint, fake_sut, mock_endpoint
+):
     mock_get_inference_endpoint.return_value = mock_endpoint
     mock_endpoint.status = InferenceEndpointStatus.SCALED_TO_ZERO
 
@@ -84,8 +88,10 @@ def test_huggingface_inference_connect_endpoint_scaled_to_zero(mock_get_inferenc
     mock_endpoint.wait.assert_called_once()
 
 
-@patch("modelgauge.suts.huggingface_inference.get_inference_endpoint")
-def test_huggingface_inference_connect_endpoint_fails_to_resume(mock_get_inference_endpoint, fake_sut, mock_endpoint):
+@patch("modelgauge.suts.huggingface_chat_completion.get_inference_endpoint")
+def test_huggingface_chat_completion_connect_endpoint_fails_to_resume(
+    mock_get_inference_endpoint, fake_sut, mock_endpoint
+):
     mock_get_inference_endpoint.return_value = mock_endpoint
     mock_endpoint.status = InferenceEndpointStatus.SCALED_TO_ZERO
     mock_endpoint.resume.side_effect = HfHubHTTPError("Failure.")
 
@@ -95,8 +101,8 @@ def test_huggingface_inference_connect_endpoint_fails_to_resume(mock_get_inferen
     mock_endpoint.wait.assert_not_called()
 
 
-@patch("modelgauge.suts.huggingface_inference.get_inference_endpoint")
-def test_huggingface_inference_connect_failed_endpoint(mock_get_inference_endpoint, fake_sut, mock_endpoint):
+@patch("modelgauge.suts.huggingface_chat_completion.get_inference_endpoint")
+def test_huggingface_chat_completion_connect_failed_endpoint(mock_get_inference_endpoint, fake_sut, mock_endpoint):
     mock_get_inference_endpoint.return_value = mock_endpoint
     mock_endpoint.status = InferenceEndpointStatus.FAILED
 
@@ -104,9 +110,9 @@ def test_huggingface_inference_connect_failed_endpoint(mock_get_inference_endpoi
         fake_sut._create_client()
 
 
-@patch("modelgauge.suts.huggingface_inference.get_inference_endpoint")
-@patch("modelgauge.suts.huggingface_inference.InferenceClient")
-def test_huggingface_inference_lazy_load_client(
+@patch("modelgauge.suts.huggingface_chat_completion.get_inference_endpoint")
+@patch("modelgauge.suts.huggingface_chat_completion.InferenceClient")
+def test_huggingface_chat_completion_lazy_load_client(
     mock_client, mock_get_inference_endpoint, fake_sut, mock_endpoint, sut_request
 ):
     mock_get_inference_endpoint.return_value = mock_endpoint
@@ -118,8 +124,8 @@ def test_huggingface_inference_lazy_load_client(
     assert fake_sut.client is not None
 
 
-@patch("modelgauge.suts.huggingface_inference.InferenceClient")
-def test_huggingface_inference_evaluate(mock_client, fake_sut, sut_request):
+@patch("modelgauge.suts.huggingface_chat_completion.InferenceClient")
+def test_huggingface_chat_completion_evaluate(mock_client, fake_sut, sut_request):
     fake_sut.client = mock_client
 
     fake_sut.evaluate(sut_request)
@@ -141,7 +147,7 @@ class FakeResponse(BaseModel):
     choices: list[FakeChoice]
 
 
-def test_huggingface_inference_translate_response(fake_sut, sut_request):
+def test_huggingface_chat_completion_translate_response(fake_sut, sut_request):
     evaluate_output = FakeResponse(choices=[FakeChoice(message=ChatMessage(content="response", role="assistant"))])
 
     response = fake_sut.translate_response(sut_request, evaluate_output)
diff --git a/plugins/validation_tests/test_object_creation.py b/plugins/validation_tests/test_object_creation.py
index 70e77c98..166c2b86 100644
--- a/plugins/validation_tests/test_object_creation.py
+++ b/plugins/validation_tests/test_object_creation.py
@@ -14,7 +14,7 @@
 from modelgauge.sut import PromptResponseSUT, SUTResponse
 from modelgauge.sut_capabilities import AcceptsTextPrompt
 from modelgauge.sut_registry import SUTS
-from modelgauge.suts.huggingface_inference import HUGGING_FACE_TIMEOUT
+from modelgauge.suts.huggingface_chat_completion import HUGGING_FACE_TIMEOUT
 from modelgauge.test_registry import TESTS
 
 # Ensure all the plugins are available during testing.
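
For reviewers, a minimal sketch of how a SUT registered above might be driven end to end. The translate_text_prompt / evaluate / translate_response hops come straight from this diff; the SUTS.make_instance call and the secrets layout (scope "hugging_face", key "token") are assumptions about modelgauge's registry conventions, so adjust to your setup.

# Hedged usage sketch, not part of the patch. Assumes SUTS.make_instance(uid,
# secrets=...) is the registry factory and that HuggingFaceInferenceToken reads
# from the "hugging_face"/"token" secret scope.
from modelgauge.prompt import SUTOptions, TextPrompt
from modelgauge.sut_registry import SUTS

secrets = {"hugging_face": {"token": "hf_..."}}  # placeholder token
sut = SUTS.make_instance("olmo-7b-0724-instruct-hf", secrets=secrets)

prompt = TextPrompt(text="Say hello.", options=SUTOptions(max_tokens=32, temperature=0.7))
request = sut.translate_text_prompt(prompt)   # -> HuggingFaceChatRequest
raw = sut.evaluate(request)                   # POSTs to the inference endpoint
response = sut.translate_response(request, raw)
print(response.completions[0].text)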