From a45ab7029012da80547dc864fc7f6509c93afb67 Mon Sep 17 00:00:00 2001 From: Riza Semih Koca Date: Thu, 24 Oct 2024 20:30:43 +0300 Subject: [PATCH] Added centroid calculation and plot feature --- ultralytics/engine/results.py | 65 +++++++++++++++++++++++++++++++++++ ultralytics/utils/plotting.py | 17 +++++++++ 2 files changed, 82 insertions(+) diff --git a/ultralytics/engine/results.py b/ultralytics/engine/results.py index 85849c34d..c6459e87b 100644 --- a/ultralytics/engine/results.py +++ b/ultralytics/engine/results.py @@ -115,6 +115,7 @@ def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None, ke self.probs = Probs(probs) if probs is not None else None self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None self.obb = OBB(obb, self.orig_shape) if obb is not None else None + self.centroids = Centroids(self.boxes.xyxy, self.orig_shape) if self.boxes is not None else None self.speed = {"preprocess": None, "inference": None, "postprocess": None} # milliseconds per image self.names = names self.path = path @@ -198,6 +199,7 @@ def plot( boxes=True, masks=True, probs=True, + centroids=True, show=False, save=False, filename=None, @@ -219,6 +221,7 @@ def plot( boxes (bool): Whether to plot the bounding boxes. masks (bool): Whether to plot the masks. probs (bool): Whether to plot classification probability + centroids (bool): Whether to plot the centroids. show (bool): Whether to display the annotated image directly. save (bool): Whether to save the annotated image to `filename`. filename (str): Filename to save image to if save is True. @@ -248,6 +251,7 @@ def plot( pred_boxes, show_boxes = self.obb if is_obb else self.boxes, boxes pred_masks, show_masks = self.masks, masks pred_probs, show_probs = self.probs, probs + pred_centroids, show_centroids = self.centroids, centroids annotator = Annotator( deepcopy(self.orig_img if img is None else img), line_width, @@ -291,6 +295,16 @@ def plot( for k in reversed(self.keypoints.data): annotator.kpts(k, self.orig_shape, radius=kpt_radius, kpt_line=kpt_line) + # Plot centroids + if pred_centroids is not None and show_centroids: + for i, centroid in enumerate(pred_centroids.xy): + if pred_boxes is not None: + c = int(pred_boxes.cls[i]) + color = colors(c, True) + else: + color = (0, 255, 0) # Default to green if no class information + annotator.dot(centroid, color=color, radius=5) + # Show results if show: annotator.show(self.path) @@ -419,6 +433,13 @@ def summary(self, normalize=False, decimals=5): "y": (y / h).numpy().round(decimals).tolist(), "visible": visible.numpy().round(decimals).tolist(), } + + if self.centroids is not None: + result["centroid"] = { + "x": round(self.centroids.xy[i][0].item() / w, decimals), + "y": round(self.centroids.xy[i][1].item() / h, decimals) + } + results.append(result) return results @@ -741,3 +762,47 @@ def xyxy(self): y2 = self.xyxyxyxy[..., 1].max(1).values xyxy = [x1, y1, x2, y2] return np.stack(xyxy, axis=-1) if isinstance(self.data, np.ndarray) else torch.stack(xyxy, dim=-1) + + +class Centroids(BaseTensor): + """ + A class for storing and manipulating detection centroids. + + Attributes: + xy (torch.Tensor): A tensor containing x, y coordinates of centroids for each detection. + xyn (torch.Tensor): A normalized version of xy with coordinates in the range [0, 1]. + + Methods: + cpu(): Returns a copy of the centroids tensor on CPU memory. + numpy(): Returns a copy of the centroids tensor as a numpy array. + cuda(): Returns a copy of the centroids tensor on GPU memory. + to(device, dtype): Returns a copy of the centroids tensor with the specified device and dtype. + """ + + def __init__(self, boxes, orig_shape) -> None: + """ + Initialize the Centroids object with bounding boxes and original image size. + + Args: + boxes (torch.Tensor): A tensor of bounding boxes in xyxy format. + orig_shape (tuple): Original image size in (height, width) format. + """ + centroids = torch.zeros(boxes.shape[0], 2, device=boxes.device) + centroids[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2 # x-coordinate + centroids[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2 # y-coordinate + super().__init__(centroids, orig_shape) + + @property + @lru_cache(maxsize=1) + def xy(self): + """Returns x, y coordinates of centroids.""" + return self.data + + @property + @lru_cache(maxsize=1) + def xyn(self): + """Returns normalized x, y coordinates of centroids.""" + xy = self.xy.clone() if isinstance(self.xy, torch.Tensor) else np.copy(self.xy) + xy[:, 0] /= self.orig_shape[1] + xy[:, 1] /= self.orig_shape[0] + return xy \ No newline at end of file diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py index d0215ba5e..d3aa74e17 100644 --- a/ultralytics/utils/plotting.py +++ b/ultralytics/utils/plotting.py @@ -593,6 +593,23 @@ def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, cv2.circle(self.im, center_bbox, pins_radius, color, -1) cv2.line(self.im, center_point, center_bbox, color, thickness) + def dot(self, xy, color=(255, 0, 0), radius=5): + """ + Draw a dot on the image. + + Args: + xy (tuple): The (x, y) coordinates of the dot's center. + color (tuple, optional): The color of the dot in BGR format. Defaults to (255, 0, 0) (blue). + radius (int, optional): The radius of the dot. Defaults to 5. + """ + if self.pil: + # Convert to PIL ImageDraw + draw = ImageDraw.Draw(self.im) + draw.ellipse([xy[0]-radius, xy[1]-radius, xy[0]+radius, xy[1]+radius], fill=color, outline=color) + else: + cv2.circle(self.im, (int(xy[0]), int(xy[1])), radius, color, -1, lineType=cv2.LINE_AA) + + @TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395 @plt_settings()