Skip to content

Commit

Permalink
update gradio demo layout
Browse files Browse the repository at this point in the history
  • Loading branch information
dnth committed Oct 25, 2024
1 parent 987e19c commit 0317a26
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 14 deletions.
2 changes: 1 addition & 1 deletion xinfer/timm/timm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def load_model(self, **kwargs):
self.model = timm.create_model(self.model_id, pretrained=True, **kwargs).to(
self.device, self.dtype
)
self.model = torch.compile(self.model, mode="max-autotune")
# self.model = torch.compile(self.model, mode="max-autotune")
self.model.eval()

def preprocess(self, images: str | list[str]):
Expand Down
49 changes: 36 additions & 13 deletions xinfer/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ def launch_gradio_demo():
"""
available_models = [model.id for model in model_registry.list_models()]

# Add example image URLs
example_images = [
"https://raw.githubusercontent.com/dnth/x.infer/refs/heads/main/assets/demo/000b9c365c9e307a.jpg",
"https://raw.githubusercontent.com/dnth/x.infer/refs/heads/main/assets/demo/00aa2580828a9009.jpg",
"https://raw.githubusercontent.com/dnth/x.infer/refs/heads/main/assets/demo/0a6ee446579d2885.jpg",
]

def load_model_and_infer(model_id, image, prompt, device, dtype):
model = create_model(model_id, device=device, dtype=dtype)
model_info = model_registry.get_model_info(model_id)
Expand All @@ -92,19 +99,35 @@ def load_model_and_infer(model_id, image, prompt, device, dtype):
with gr.Blocks() as demo:
gr.Markdown("# x.infer Gradio Demo")

model_dropdown = gr.Dropdown(choices=available_models, label="Select a model")
image_input = gr.Image(type="filepath", label="Input Image")
prompt_input = gr.Textbox(
label="Prompt (for image-text to text models)", visible=False
)
device_dropdown = gr.Dropdown(
choices=["cuda", "cpu"], label="Device", value="cuda"
)
dtype_dropdown = gr.Dropdown(
choices=["float32", "float16", "bfloat16"], label="Dtype", value="float16"
)
run_button = gr.Button("Run Inference")
output = gr.Textbox(label="Result", lines=10)
with gr.Row():
# Left column: Input controls
with gr.Column(scale=1):
model_dropdown = gr.Dropdown(
choices=available_models, label="Select a model"
)
with gr.Row():
device_dropdown = gr.Dropdown(
choices=["cuda", "cpu"], label="Device", value="cuda"
)
dtype_dropdown = gr.Dropdown(
choices=["float32", "float16", "bfloat16"],
label="Dtype",
value="float16",
)
prompt_input = gr.Textbox(
label="Prompt (for image-text to text models)", visible=False
)
run_button = gr.Button("Run Inference", variant="primary")

# Right column: Image input
with gr.Column(scale=1):
image_input = gr.Image(type="filepath", label="Input Image", height=400)

# Results section
output = gr.Textbox(label="Result", lines=5)

# Add examples
gr.Examples(examples=example_images, inputs=image_input, label="Example Images")

def update_prompt_visibility(model_id):
model_info = model_registry.get_model_info(model_id)
Expand Down

0 comments on commit 0317a26

Please sign in to comment.