Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added centroid calculation and plot feature #472

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions ultralytics/engine/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None, ke
self.probs = Probs(probs) if probs is not None else None
self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None
self.obb = OBB(obb, self.orig_shape) if obb is not None else None
self.centroids = Centroids(self.boxes.xyxy, self.orig_shape) if self.boxes is not None else None
self.speed = {"preprocess": None, "inference": None, "postprocess": None} # milliseconds per image
self.names = names
self.path = path
Expand Down Expand Up @@ -198,6 +199,7 @@ def plot(
boxes=True,
masks=True,
probs=True,
centroids=True,
show=False,
save=False,
filename=None,
Expand All @@ -219,6 +221,7 @@ def plot(
boxes (bool): Whether to plot the bounding boxes.
masks (bool): Whether to plot the masks.
probs (bool): Whether to plot classification probability
centroids (bool): Whether to plot the centroids.
show (bool): Whether to display the annotated image directly.
save (bool): Whether to save the annotated image to `filename`.
filename (str): Filename to save image to if save is True.
Expand Down Expand Up @@ -248,6 +251,7 @@ def plot(
pred_boxes, show_boxes = self.obb if is_obb else self.boxes, boxes
pred_masks, show_masks = self.masks, masks
pred_probs, show_probs = self.probs, probs
pred_centroids, show_centroids = self.centroids, centroids
annotator = Annotator(
deepcopy(self.orig_img if img is None else img),
line_width,
Expand Down Expand Up @@ -291,6 +295,16 @@ def plot(
for k in reversed(self.keypoints.data):
annotator.kpts(k, self.orig_shape, radius=kpt_radius, kpt_line=kpt_line)

# Plot centroids
if pred_centroids is not None and show_centroids:
for i, centroid in enumerate(pred_centroids.xy):
if pred_boxes is not None:
c = int(pred_boxes.cls[i])
color = colors(c, True)
else:
color = (0, 255, 0) # Default to green if no class information
annotator.dot(centroid, color=color, radius=5)

# Show results
if show:
annotator.show(self.path)
Expand Down Expand Up @@ -419,6 +433,13 @@ def summary(self, normalize=False, decimals=5):
"y": (y / h).numpy().round(decimals).tolist(),
"visible": visible.numpy().round(decimals).tolist(),
}

if self.centroids is not None:
result["centroid"] = {
"x": round(self.centroids.xy[i][0].item() / w, decimals),
"y": round(self.centroids.xy[i][1].item() / h, decimals)
}

results.append(result)

return results
Expand Down Expand Up @@ -741,3 +762,47 @@ def xyxy(self):
y2 = self.xyxyxyxy[..., 1].max(1).values
xyxy = [x1, y1, x2, y2]
return np.stack(xyxy, axis=-1) if isinstance(self.data, np.ndarray) else torch.stack(xyxy, dim=-1)


class Centroids(BaseTensor):
    """
    A class for storing and manipulating detection centroids.

    A centroid is the midpoint of an axis-aligned bounding box given in xyxy format,
    i.e. ((x1 + x2) / 2, (y1 + y2) / 2). The stored data keeps the array type
    (torch.Tensor or numpy.ndarray), device and floating dtype of the input boxes.

    Attributes:
        xy (torch.Tensor | numpy.ndarray): x, y coordinates of the centroid of each detection, shape (N, 2).
        xyn (torch.Tensor | numpy.ndarray): Copy of xy divided by the image width (x) and height (y).

    Methods:
        cpu(): Returns a copy of the centroids tensor on CPU memory.
        numpy(): Returns a copy of the centroids tensor as a numpy array.
        cuda(): Returns a copy of the centroids tensor on GPU memory.
        to(device, dtype): Returns a copy of the centroids tensor with the specified device and dtype.
    """

    def __init__(self, boxes, orig_shape) -> None:
        """
        Initialize the Centroids object with bounding boxes and original image size.

        Args:
            boxes (torch.Tensor | numpy.ndarray): Bounding boxes in xyxy format with shape (N, 4) or (4,);
                only the first four columns are read. Data whose last dimension is 2 is treated as
                already-computed centroids and stored unchanged.
            orig_shape (tuple): Original image size in (height, width) format.
        """
        if boxes.ndim == 1:
            # Promote a single box / single centroid to a batch of one.
            boxes = boxes[None, :]
        if boxes.shape[-1] == 2:
            # Already (x, y) centroids. NOTE(review): the base-class helpers listed in the class
            # docstring (cpu/numpy/cuda/to) presumably rebuild the object from `self.data`; this
            # branch keeps such a round trip from indexing columns 2 and 3 -- confirm in BaseTensor.
            centroids = boxes
        else:
            # Slicing + arithmetic works for both torch tensors and numpy arrays and preserves
            # device and dtype, unlike allocating a fresh float32 torch buffer.
            centroids = (boxes[:, :2] + boxes[:, 2:4]) / 2
        super().__init__(centroids, orig_shape)

    @property
    def xy(self):
        """Returns x, y coordinates of centroids."""
        # Plain attribute access: caching it would only add hashing overhead and keep the instance alive.
        return self.data

    @property
    @lru_cache(maxsize=1)
    def xyn(self):
        """Returns normalized x, y coordinates of centroids."""
        # Work on a copy so the in-place division below never touches self.data.
        xy = self.xy.clone() if isinstance(self.xy, torch.Tensor) else np.copy(self.xy)
        xy[..., 0] /= self.orig_shape[1]  # x / width
        xy[..., 1] /= self.orig_shape[0]  # y / height
        return xy
17 changes: 17 additions & 0 deletions ultralytics/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,23 @@ def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0,
cv2.circle(self.im, center_bbox, pins_radius, color, -1)
cv2.line(self.im, center_point, center_bbox, color, thickness)

def dot(self, xy, color=(255, 0, 0), radius=5):
    """
    Draw a filled dot on the image.

    Args:
        xy (Sequence): The (x, y) coordinates of the dot's center. Elements may be Python numbers
            or tensor/NumPy scalars; they are truncated to integer pixel coordinates.
        color (tuple, optional): The fill color of the dot. Defaults to (255, 0, 0), which is blue
            when the image is in BGR order (the cv2 path). On the PIL path the tuple is passed to
            PIL unchanged, so it follows the channel order of the PIL image.
        radius (int, optional): The radius of the dot in pixels. Defaults to 5.
    """
    # Coerce once so both back ends receive plain ints; the PIL branch previously got the raw
    # elements of `xy`, while only the cv2 branch converted them.
    x, y = int(xy[0]), int(xy[1])
    if self.pil:
        # PIL back end: draw a filled ellipse bounded by the square around the center.
        draw = ImageDraw.Draw(self.im)
        draw.ellipse([x - radius, y - radius, x + radius, y + radius], fill=color, outline=color)
    else:
        # cv2 back end: thickness -1 fills the circle; LINE_AA anti-aliases the edge.
        cv2.circle(self.im, (x, y), radius, color, -1, lineType=cv2.LINE_AA)



@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
@plt_settings()
Expand Down