Enhance image_predictor_example.ipynb with Interactive Point Addition Using Matplotlib #421

Open
future-158 opened this issue Oct 28, 2024 · 0 comments


Issue:
The current image_predictor_example.ipynb is a good example of using the image predictor, but it is not interactive. Adding interactivity would improve the user experience and make experimentation easier.

Proposed Enhancement:
Introduce an interactive feature that allows users to add positive and negative points by clicking on the image:

  • Left Click: Add a positive point.
  • Right Click: Add a negative point.

This can be achieved with a simple Matplotlib-based script of fewer than 110 lines of code, eliminating the need for complex third-party tools for simple testing.
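
As a rough illustration of the click-to-prompt mapping described above, the sketch below turns a sequence of clicks into the coordinate and label lists that a point-prompted predictor would receive. The `Click` type and the `clicks_to_prompts` helper are hypothetical names introduced only for this example. The button codes (1 = left, 3 = right) are Matplotlib's, and the labels (1 = positive, 0 = negative) follow the convention used in the example code in this issue.

```python
from typing import List, NamedTuple, Tuple


class Click(NamedTuple):
    """Hypothetical record of one mouse click (not a Matplotlib or SAM 2 type)."""
    x: float
    y: float
    button: int  # Matplotlib button code: 1 = left, 2 = middle, 3 = right


# Left click -> positive point (label 1), right click -> negative point (label 0)
BUTTON_TO_LABEL = {1: 1, 3: 0}


def clicks_to_prompts(
    clicks: List[Click],
) -> Tuple[List[Tuple[float, float]], List[int]]:
    """Convert clicks into parallel lists of (x, y) coordinates and labels."""
    coords: List[Tuple[float, float]] = []
    labels: List[int] = []
    for click in clicks:
        label = BUTTON_TO_LABEL.get(click.button)
        if label is None:
            continue  # ignore middle clicks and any other buttons
        coords.append((click.x, click.y))
        labels.append(label)
    return coords, labels


coords, labels = clicks_to_prompts(
    [Click(120.0, 80.0, 1), Click(30.5, 40.0, 3), Click(5.0, 5.0, 2)]
)
print(coords)
print(labels)
```

Keeping the mapping in a small pure function like this makes it easy to check without a GPU or a notebook session.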

Benefits:

  • Simplicity: Easier to understand and modify due to fewer lines of code.
  • User-Friendly: Enhances interactivity without adding external dependencies.
  • Educational Value: Helps users learn by directly interacting with the model.

Implementation:
I have made a concise script that demonstrates this functionality. You can view the complete code in the following Gist.

Example Code Snippet:

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor
import requests

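def load_image(url: str) -> Image.Image:
    """Download an image and return it as an RGB PIL image."""
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/85.0.4183.83 Safari/537.36"
    }
    response = requests.get(url, stream=True, headers=headers)
    response.raise_for_status()  # fail early on HTTP errors
    # Convert to RGB so the blend/composite calls below see matching image modes
    return Image.open(response.raw).convert("RGB")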


# Jupyter magic: enable the interactive widget backend (requires the ipympl package)
%matplotlib widget

predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny")
# predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")


# load example image
url = "https://images.pexels.com/photos/529782/pexels-photo-529782.jpeg?auto=compress&cs=tinysrgb&w=800"
base_img = load_image(url)

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(base_img)




img = np.array(base_img)
fig, ax = plt.subplots()
im = ax.imshow(img)

# Hide the ticks (this also removes the tick labels; axis labels are empty by default)
ax.set_xticks([])
ax.set_yticks([])

plt.tight_layout()


positive_points = []
negative_points = []
mask = None


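def inference() -> np.ndarray:
    """Run the predictor on the current points and return the image with the mask overlaid."""
    global mask
    # (N, 2) array of (x, y) pixel coordinates; labels: 1 = positive, 0 = negative
    point_coords = np.array([*positive_points, *negative_points], dtype=np.float32)
    point_labels = np.array(
        [1] * len(positive_points) + [0] * len(negative_points), dtype=np.int32
    )

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        masks, _, _ = predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=False,
        )

    mask = masks[0] > 0  # boolean (H, W) mask
    # Tint the whole image blue, then keep the tint only where the mask is set
    blended = Image.blend(
        Image.new("RGB", base_img.size, (0, 0, 255)), base_img, alpha=0.5
    )
    composited = Image.composite(blended, base_img, Image.fromarray(mask)).convert(
        "RGB"
    )

    return np.array(composited)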


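def on_click(event):
    """
    Event handler for mouse click events on the plot.
    Records the clicked point, reruns inference, and redraws the image.
    Parameters:
    - event: The Matplotlib mouse event.
    """
    # Ignore clicks outside the image axes and buttons other than left/right
    if event.inaxes is not ax or event.button not in (1, 3):
        return

    x, y = event.xdata, event.ydata
    if event.button == 1:  # Left mouse button
        positive_points.append((x, y))
    else:  # Right mouse button
        negative_points.append((x, y))

    new_rgb = inference()
    # Draw a square marker per point: green = positive, red = negative
    markers = [(p, (0, 255, 0)) for p in positive_points]
    markers += [(p, (255, 0, 0)) for p in negative_points]
    for (px, py), color in markers:
        px, py = int(px), int(py)
        # Clamp the start index: a negative slice start would count from the
        # end of the array and leave markers near the top/left edge undrawn
        new_rgb[max(py - 5, 0) : py + 5, max(px - 5, 0) : px + 5] = color
    im.set_data(new_rgb)
    fig.canvas.draw_idle()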


cid = fig.canvas.mpl_connect("button_press_event", on_click)
plt.show()
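
One detail worth checking when drawing the point markers is what happens near the image border: a slice start such as `y - 5` becomes negative when `y < 5`, and NumPy then counts from the end of the array, so the marker can silently disappear. The following is a minimal sketch of clamping the marker box to the image bounds. The helper `marker_box` is a hypothetical name for illustration and is not part of SAM 2 or Matplotlib.

```python
from typing import Tuple


def marker_box(
    x: float, y: float, width: int, height: int, half: int = 5
) -> Tuple[int, int, int, int]:
    """Return (row_start, row_stop, col_start, col_stop) for a square marker
    centered on (x, y), clamped to an image of the given width and height."""
    col, row = int(x), int(y)
    row_start = min(max(row - half, 0), height)
    row_stop = min(max(row + half, 0), height)
    col_start = min(max(col - half, 0), width)
    col_stop = min(max(col + half, 0), width)
    return row_start, row_stop, col_start, col_stop


# A click close to the left edge of an 800x600 image
print(marker_box(2.7, 100.2, width=800, height=600))
```

The returned bounds could then be used as `new_rgb[r0:r1, c0:c1] = color` when painting a marker.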

Conclusion:
When I first ran the demo with my own image, I drew the image with Plotly first (because it shows the mouse pointer coordinates) and manually updated the positive and negative points one by one.
I think adding this interactive example would make image_predictor_example.ipynb more engaging.
I'm happy to contribute this example to the repository or provide further assistance if needed.
