Fix most flaky Python tests in 5.0-dev branch #9550

Merged · 21 commits · Oct 7, 2024
6 changes: 6 additions & 0 deletions .changeset/busy-lizards-heal.md
@@ -0,0 +1,6 @@
---
"gradio": minor
"gradio_client": minor
---

feat:Fix most flaky Python tests in `5.0-dev` branch
5 changes: 2 additions & 3 deletions client/python/gradio_client/client.py
@@ -156,7 +156,7 @@ def __init__(
)
api_prefix: str = self.config.get("api_prefix", "")
self.api_prefix = api_prefix.lstrip("/") + "/"
self.src_prefixed = urllib.parse.urljoin(self.src, api_prefix) + "/"
self.src_prefixed = urllib.parse.urljoin(self.src, api_prefix).rstrip("/") + "/"
Collaborator:

Why do we rstrip just to add back /?

Member Author:

It guarantees that the URL ends in one and only one forward slash. Previously, we were seeing issues where the URL would sometimes end in two forward slashes.
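
A minimal sketch (not from the PR) of the failure mode, using a hypothetical src and api_prefix; urllib.parse.urljoin preserves the prefix's trailing slash, so naively appending "/" can yield two:

import urllib.parse

src = "https://example.com"   # hypothetical server URL
api_prefix = "/gradio_api/"   # hypothetical prefix that already ends in "/"

old = urllib.parse.urljoin(src, api_prefix) + "/"              # "https://example.com/gradio_api//"
new = urllib.parse.urljoin(src, api_prefix).rstrip("/") + "/"  # "https://example.com/gradio_api/"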


self.api_url = urllib.parse.urljoin(self.src_prefixed, utils.API_URL)
self.sse_url = urllib.parse.urljoin(
@@ -1079,8 +1079,7 @@ def __init__(
self._get_component_type(id_) for id_ in dependency["outputs"]
]
self.parameters_info = self._get_parameters_info()

self.root_url = client.src.rstrip("/") + "/" + client.api_prefix
self.root_url = self.client.src_prefixed

# Disallow hitting endpoints that the Gradio app has disabled
self.is_valid = self.api_name is not False
10 changes: 6 additions & 4 deletions client/python/test/test_client.py
@@ -1,7 +1,6 @@
from __future__ import annotations

import json
import os
import pathlib
import tempfile
import time
@@ -16,7 +15,6 @@
import httpx
import huggingface_hub
import pytest
from huggingface_hub import HfFolder
from huggingface_hub.utils import RepositoryNotFoundError

from gradio_client import Client, handle_file
@@ -30,7 +28,7 @@
StatusUpdate,
)

HF_TOKEN = os.getenv("HF_TOKEN") or HfFolder.get_token()
HF_TOKEN = huggingface_hub.get_token()
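
As a reference, a rough sketch of the lookup this now relies on (per the huggingface_hub docs, get_token() prefers the HF_TOKEN environment variable, then falls back to the token cached by `huggingface-cli login`; paths are simplified here, so treat this as an approximation):

import os
from pathlib import Path

def get_token_sketch() -> str | None:
    # Environment variable takes precedence.
    token = os.environ.get("HF_TOKEN")
    if token:
        return token
    # Fall back to the token file written by `huggingface-cli login`.
    token_file = Path.home() / ".cache" / "huggingface" / "token"
    if token_file.is_file():
        return token_file.read_text().strip()
    return None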


@contextmanager
@@ -130,7 +128,7 @@ def test_private_space(self):
space_id = "gradio-tests/not-actually-private-space"
api = huggingface_hub.HfApi()
assert api.space_info(space_id).private
client = Client(space_id)
client = Client(space_id, hf_token=HF_TOKEN)
output = client.predict("abc", api_name="/predict")
assert output == "abc"

@@ -141,6 +139,7 @@ def test_private_space_v4(self):
assert api.space_info(space_id).private
client = Client(
space_id,
hf_token=HF_TOKEN,
)
output = client.predict("abc", api_name="/predict")
assert output == "abc"
@@ -152,6 +151,7 @@ def test_private_space_v4_sse_v1(self):
assert api.space_info(space_id).private
client = Client(
space_id,
hf_token=HF_TOKEN,
)
output = client.predict("abc", api_name="/predict")
assert output == "abc"
@@ -1017,6 +1017,7 @@ def test_state_does_not_appear(self, state_demo):
def test_private_space(self):
client = Client(
"gradio-tests/not-actually-private-space",
hf_token=HF_TOKEN,
)
assert len(client.endpoints) == 3
assert len([e for e in client.endpoints.values() if e.is_valid]) == 2
@@ -1252,6 +1253,7 @@ class TestEndpoints:
def test_upload(self):
client = Client(
src="gradio-tests/not-actually-private-file-upload",
hf_token=HF_TOKEN,
)
response = MagicMock(status_code=200)
response.json.return_value = [
1 change: 1 addition & 0 deletions gradio/external.py
@@ -248,6 +248,7 @@ def custom_post_binary(data):
elif p == "translation":
inputs = components.Textbox(label="Input")
outputs = components.Textbox(label="Translation")
postprocess = lambda x: x.translation_text # noqa: E731
examples = ["Hello, how are you?"]
fn = client.translation
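
A hedged illustration of the added postprocess (assuming huggingface_hub's typed TranslationOutput, which the diff implies exposes a translation_text attribute):

from huggingface_hub import TranslationOutput

postprocess = lambda x: x.translation_text  # noqa: E731
print(postprocess(TranslationOutput(translation_text="Bonjour")))  # -> "Bonjour"
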
# Example: facebook/bart-large-mnli
6 changes: 3 additions & 3 deletions gradio/external_utils.py
@@ -9,7 +9,7 @@

import httpx
import yaml
from huggingface_hub import HfApi, InferenceClient
from huggingface_hub import HfApi, ImageClassificationOutputElement, InferenceClient

from gradio import components

@@ -89,8 +89,8 @@ def rows_to_cols(incoming_data: dict) -> dict[str, dict[str, dict[str, list[str]
##################


def postprocess_label(scores: list[dict[str, str | float]]) -> dict:
return {c["label"]: c["score"] for c in scores}
def postprocess_label(scores: list[ImageClassificationOutputElement]) -> dict:
return {c.label: c.score for c in scores}
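
A hedged usage sketch (scores made up): newer huggingface_hub releases return typed ImageClassificationOutputElement objects rather than plain dicts, which is why the lookup switches from c["label"] to c.label:

from huggingface_hub import ImageClassificationOutputElement

scores = [
    ImageClassificationOutputElement(label="lion", score=0.98),
    ImageClassificationOutputElement(label="tiger", score=0.02),
]
assert postprocess_label(scores) == {"lion": 0.98, "tiger": 0.02}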


def postprocess_mask_tokens(scores: list[dict[str, str | float]]) -> dict:
2 changes: 1 addition & 1 deletion gradio/themes/base.py
@@ -188,7 +188,7 @@ def from_hub(cls, repo_name: str, hf_token: str | None = None):

try:
space_info = api.space_info(name)
except huggingface_hub.utils._errors.RepositoryNotFoundError as e:
except huggingface_hub.utils.RepositoryNotFoundError as e:
raise ValueError(f"The space {name} does not exist") from e

assets = get_theme_assets(space_info)
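
For context, a minimal sketch of the import move (old path taken from the removed line above; recent huggingface_hub releases expose the error class from the public utils namespace, as the test_client.py import earlier in this diff also shows):

# from huggingface_hub.utils._errors import RepositoryNotFoundError  # old private module
from huggingface_hub.utils import RepositoryNotFoundError            # public path used above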
2 changes: 1 addition & 1 deletion requirements.txt
@@ -4,7 +4,7 @@ fastapi<1.0
ffmpy
gradio_client==1.4.0-beta.3
httpx>=0.24.1
huggingface_hub>=0.22.0
huggingface_hub>=0.25.1
Jinja2<4.0
markupsafe~=2.0
numpy>=1.0,<3.0
8 changes: 0 additions & 8 deletions test/requirements.txt
@@ -98,14 +98,6 @@ httpx==0.23.0
# gradio
# gradio-client
# respx
huggingface-hub==0.21.4
# via
# -r requirements.in
# diffusers
# gradio
# gradio-client
# tokenizers
# transformers
hypothesis==6.108.9
idna==3.3
# via
63 changes: 23 additions & 40 deletions test/test_external.py
@@ -1,14 +1,13 @@
import os
import tempfile
import textwrap
import warnings
from pathlib import Path
from unittest.mock import MagicMock, patch

import huggingface_hub
import pytest
from fastapi.testclient import TestClient
from gradio_client import media_data
from huggingface_hub import HfFolder

import gradio as gr
from gradio.context import Context
@@ -31,7 +30,7 @@

os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

HF_TOKEN = os.getenv("HF_TOKEN") or HfFolder.get_token()
HF_TOKEN = huggingface_hub.get_token()


class TestLoadInterface:
@@ -76,7 +75,7 @@ def test_text_generation(self):
def test_summarization(self):
model_type = "summarization"
interface = gr.load(
"models/facebook/bart-large-cnn", hf_token=None, alias=model_type
"models/facebook/bart-large-cnn", hf_token=HF_TOKEN, alias=model_type
)
assert interface.__name__ == model_type
assert interface.input_components and interface.output_components
@@ -86,7 +85,7 @@ def test_summarization(self):
def test_translation(self):
model_type = "translation"
interface = gr.load(
"models/facebook/bart-large-cnn", hf_token=None, alias=model_type
"models/facebook/bart-large-cnn", hf_token=HF_TOKEN, alias=model_type
)
assert interface.__name__ == model_type
assert interface.input_components and interface.output_components
@@ -96,7 +95,7 @@ def test_translation(self):
def test_text2text_generation(self):
model_type = "text2text-generation"
interface = gr.load(
"models/sshleifer/tiny-mbart", hf_token=None, alias=model_type
"models/sshleifer/tiny-mbart", hf_token=HF_TOKEN, alias=model_type
)
assert interface.__name__ == model_type
assert interface.input_components and interface.output_components
@@ -107,7 +106,7 @@ def test_text_classification(self):
model_type = "text-classification"
interface = gr.load(
"models/distilbert-base-uncased-finetuned-sst-2-english",
hf_token=None,
hf_token=HF_TOKEN,
alias=model_type,
)
assert interface.__name__ == model_type
@@ -117,7 +116,7 @@

def test_fill_mask(self):
model_type = "fill-mask"
interface = gr.load("models/bert-base-uncased", hf_token=None, alias=model_type)
interface = gr.load(
"models/bert-base-uncased", hf_token=HF_TOKEN, alias=model_type
)
assert interface.__name__ == model_type
assert interface.input_components and interface.output_components
assert isinstance(interface.input_components[0], gr.Textbox)
@@ -126,7 +127,7 @@ def test_fill_mask(self):
def test_zero_shot_classification(self):
model_type = "zero-shot-classification"
interface = gr.load(
"models/facebook/bart-large-mnli", hf_token=None, alias=model_type
"models/facebook/bart-large-mnli", hf_token=HF_TOKEN, alias=model_type
)
assert interface.__name__ == model_type
assert interface.input_components and interface.output_components
@@ -138,7 +139,7 @@ def test_zero_shot_classification(self):
def test_automatic_speech_recognition(self):
model_type = "automatic-speech-recognition"
interface = gr.load(
"models/facebook/wav2vec2-base-960h", hf_token=None, alias=model_type
"models/facebook/wav2vec2-base-960h", hf_token=HF_TOKEN, alias=model_type
)
assert interface.__name__ == model_type
assert interface.input_components and interface.output_components
@@ -148,7 +149,7 @@ def test_automatic_speech_recognition(self):
def test_image_classification(self):
model_type = "image-classification"
interface = gr.load(
"models/google/vit-base-patch16-224", hf_token=None, alias=model_type
"models/google/vit-base-patch16-224", hf_token=HF_TOKEN, alias=model_type
)
assert interface.__name__ == model_type
assert interface.input_components and interface.output_components
@@ -159,7 +160,7 @@ def test_feature_extraction(self):
model_type = "feature-extraction"
interface = gr.load(
"models/sentence-transformers/distilbert-base-nli-mean-tokens",
hf_token=None,
hf_token=HF_TOKEN,
alias=model_type,
)
assert interface.__name__ == model_type
@@ -171,7 +172,7 @@ def test_sentence_similarity(self):
model_type = "text-to-speech"
interface = gr.load(
"models/julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train",
hf_token=None,
hf_token=HF_TOKEN,
alias=model_type,
)
assert interface.__name__ == model_type
@@ -183,7 +184,7 @@ def test_text_to_speech(self):
model_type = "text-to-speech"
interface = gr.load(
"models/julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train",
hf_token=None,
hf_token=HF_TOKEN,
alias=model_type,
)
assert interface.__name__ == model_type
@@ -204,22 +205,22 @@ def test_english_to_spanish_v4(self):

def test_sentiment_model(self):
io = gr.load(
"models/distilbert-base-uncased-finetuned-sst-2-english", hf_token=None
"models/distilbert-base-uncased-finetuned-sst-2-english", hf_token=HF_TOKEN
)
try:
assert io("I am happy, I love you")["label"] == "POSITIVE"
except TooManyRequestsError:
pass

def test_image_classification_model(self):
io = gr.load(name="models/google/vit-base-patch16-224", hf_token=None)
io = gr.load(name="models/google/vit-base-patch16-224", hf_token=HF_TOKEN)
try:
assert io("gradio/test_data/lion.jpg")["label"].startswith("lion")
except TooManyRequestsError:
pass

def test_translation_model(self):
io = gr.load(name="models/t5-base", hf_token=None)
io = gr.load(name="models/t5-base", hf_token=HF_TOKEN)
try:
output = io("My name is Sarah and I live in London")
assert output == "Mein Name ist Sarah und ich lebe in London"
@@ -239,47 +240,29 @@ def test_numerical_to_label_space(self):
pass

def test_visual_question_answering(self):
io = gr.load("models/dandelin/vilt-b32-finetuned-vqa", hf_token=None)
io = gr.load("models/dandelin/vilt-b32-finetuned-vqa", hf_token=HF_TOKEN)
try:
output = io("gradio/test_data/lion.jpg", "What is in the image?")
assert isinstance(output, dict) and "label" in output
except TooManyRequestsError:
pass

def test_image_to_text(self):
io = gr.load("models/nlpconnect/vit-gpt2-image-captioning", hf_token=None)
io = gr.load("models/nlpconnect/vit-gpt2-image-captioning", hf_token=HF_TOKEN)
try:
output = io("gradio/test_data/lion.jpg")
assert isinstance(output, str)
except TooManyRequestsError:
pass

def test_speech_recognition_model(self):
io = gr.load("models/facebook/wav2vec2-base-960h", hf_token=None)
io = gr.load("models/facebook/wav2vec2-base-960h", hf_token=HF_TOKEN)
try:
output = io("gradio/test_data/test_audio.wav")
assert output is not None
except TooManyRequestsError:
pass

app, _, _ = io.launch(prevent_thread_lock=True, show_error=True)
client = TestClient(app)
resp = client.post(
"api/predict",
json={"fn_index": 0, "data": [media_data.BASE64_AUDIO], "name": "sample"},
)
try:
if resp.status_code != 200:
warnings.warn("Request for speech recognition model failed!")
assert (
"Could not complete request to HuggingFace API"
not in resp.json()["error"]
)
else:
assert resp.json()["data"] is not None
finally:
io.close()

def test_private_space(self):
io = gr.load(
"spaces/gradio-tests/not-actually-private-spacev4-sse", hf_token=HF_TOKEN
@@ -329,7 +312,7 @@ def test_loading_files_via_proxy_works(self):
def test_private_space_v4_sse_v1(self):
io = gr.load(
"spaces/gradio-tests/not-actually-private-spacev4-sse-v1",
hf_token=HfFolder.get_token(),
hf_token=HF_TOKEN,
)
try:
output = io("abc")
@@ -359,7 +342,7 @@ def test_interface_load_cache_examples(self, tmp_path):
name="models/google/vit-base-patch16-224",
examples=[Path(test_file_dir, "cheetah1.jpg")],
cache_examples=True,
hf_token=None,
hf_token=HF_TOKEN,
)
except TooManyRequestsError:
pass