diff --git a/xinfer/timm/timm_model.py b/xinfer/timm/timm_model.py
index 82d60da..d0a841b 100644
--- a/xinfer/timm/timm_model.py
+++ b/xinfer/timm/timm_model.py
@@ -21,7 +21,7 @@ def load_model(self, **kwargs):
         self.model = timm.create_model(self.model_id, pretrained=True, **kwargs).to(
             self.device, self.dtype
         )
-        self.model = torch.compile(self.model, mode="max-autotune")
+        # self.model = torch.compile(self.model, mode="max-autotune")
         self.model.eval()
 
     def preprocess(self, images: str | list[str]):
diff --git a/xinfer/viz.py b/xinfer/viz.py
index 858373b..1c93e27 100644
--- a/xinfer/viz.py
+++ b/xinfer/viz.py
@@ -66,6 +66,13 @@ def launch_gradio_demo():
     """
     available_models = [model.id for model in model_registry.list_models()]
 
+    # Add example image URLs
+    example_images = [
+        "https://raw.githubusercontent.com/dnth/x.infer/refs/heads/main/assets/demo/000b9c365c9e307a.jpg",
+        "https://raw.githubusercontent.com/dnth/x.infer/refs/heads/main/assets/demo/00aa2580828a9009.jpg",
+        "https://raw.githubusercontent.com/dnth/x.infer/refs/heads/main/assets/demo/0a6ee446579d2885.jpg",
+    ]
+
     def load_model_and_infer(model_id, image, prompt, device, dtype):
         model = create_model(model_id, device=device, dtype=dtype)
         model_info = model_registry.get_model_info(model_id)
@@ -92,19 +99,35 @@ def load_model_and_infer(model_id, image, prompt, device, dtype):
     with gr.Blocks() as demo:
         gr.Markdown("# x.infer Gradio Demo")
 
-        model_dropdown = gr.Dropdown(choices=available_models, label="Select a model")
-        image_input = gr.Image(type="filepath", label="Input Image")
-        prompt_input = gr.Textbox(
-            label="Prompt (for image-text to text models)", visible=False
-        )
-        device_dropdown = gr.Dropdown(
-            choices=["cuda", "cpu"], label="Device", value="cuda"
-        )
-        dtype_dropdown = gr.Dropdown(
-            choices=["float32", "float16", "bfloat16"], label="Dtype", value="float16"
-        )
-        run_button = gr.Button("Run Inference")
-        output = gr.Textbox(label="Result", lines=10)
+        with gr.Row():
+            # Left column: Input controls
+            with gr.Column(scale=1):
+                model_dropdown = gr.Dropdown(
+                    choices=available_models, label="Select a model"
+                )
+                with gr.Row():
+                    device_dropdown = gr.Dropdown(
+                        choices=["cuda", "cpu"], label="Device", value="cuda"
+                    )
+                    dtype_dropdown = gr.Dropdown(
+                        choices=["float32", "float16", "bfloat16"],
+                        label="Dtype",
+                        value="float16",
+                    )
+                prompt_input = gr.Textbox(
+                    label="Prompt (for image-text to text models)", visible=False
+                )
+                run_button = gr.Button("Run Inference", variant="primary")
+
+            # Right column: Image input
+            with gr.Column(scale=1):
+                image_input = gr.Image(type="filepath", label="Input Image", height=400)
+
+        # Results section
+        output = gr.Textbox(label="Result", lines=5)
+
+        # Add examples
+        gr.Examples(examples=example_images, inputs=image_input, label="Example Images")
 
         def update_prompt_visibility(model_id):
             model_info = model_registry.get_model_info(model_id)