Skip to content

Commit

Permalink
<feat>Support ultralytics classification (#44)
Browse files Browse the repository at this point in the history
* <feat>Support ultralytics classification

* <feat>render detect results

* Remove invalid package

---------

Co-authored-by: Dickson Neoh <dickson.neoh@gmail.com>
  • Loading branch information
315386775 and dnth authored Oct 31, 2024
1 parent adbc1d6 commit c5d1a45
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 17 deletions.
10 changes: 8 additions & 2 deletions nbs/yolo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -510,14 +510,20 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"model = xinfer.create_model(\"ultralytics/yolov11n-cls\", device=\"cuda\", dtype=\"float16\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"results = model.infer(\"https://ultralytics.com/images/bus.jpg\")\n",
"model.render('./images/')\n",
"results"
]
}
],
"metadata": {
Expand Down
55 changes: 40 additions & 15 deletions xinfer/ultralytics/ultralytics_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Dict, List

import torch
Expand All @@ -12,36 +13,60 @@ def __init__(
self, model_id: str, device: str = "cpu", dtype: str = "float32", **kwargs
):
super().__init__(model_id, device, dtype)
self.model_type = 'classification' if 'cls' in model_id else 'detection'
self.load_model(**kwargs)

def load_model(self, **kwargs):
self.model = YOLO(self.model_id.replace("ultralytics/", ""), **kwargs)
if self.model_type == 'classification':
self.model = YOLO(self.model_id.replace("ultralytics/", ""), task='classification', **kwargs)
else:
self.model = YOLO(self.model_id.replace("ultralytics/", ""), **kwargs)

@track_inference
def infer_batch(self, images: str | List[str], **kwargs) -> List[List[Dict]]:
half = self.dtype == torch.float16
results = self.model.predict(images, device=self.device, half=half, **kwargs)

self.results = self.model.predict(images, device=self.device, half=half, **kwargs)
batch_results = []
for result in results:
coco_format_results = []
boxes = result.boxes
for box in boxes:
x1, y1, x2, y2 = box.xyxy[0].tolist()
width = x2 - x1
height = y2 - y1
coco_format_results.append(
{
for result in self.results:

if self.model_type == 'classification':
classification_results = []
probs = result.probs
classification_results.append({
"class_id": int(probs.top1),
"score": float(probs.top1conf.cpu().numpy()),
"class_name": result.names[int(probs.top1)],
})
batch_results.append(classification_results)

else:
detection_results = []
boxes = result.boxes
for box in boxes:
x1, y1, x2, y2 = box.xyxy[0].tolist()
width = x2 - x1
height = y2 - y1
detection_results.append({
"bbox": [x1, y1, width, height],
"category_id": int(box.cls),
"score": float(box.conf),
"class_name": result.names[int(box.cls)],
}
)
batch_results.append(coco_format_results)
})
batch_results.append(detection_results)
return batch_results

@track_inference
def infer(self, image: str, **kwargs) -> List[List[Dict]]:
results = self.infer_batch([image], **kwargs)
return results[0]

def render(self, save_path: str = './', **kwargs):
for _, r in enumerate(self.results):
# im_bgr = r.plot()
# im_rgb = Image.fromarray(im_bgr[..., ::-1])

# plot results (such as bounding boxes, masks, keypoints, and probabilities)
file_name = os.path.basename(r.path)
file_name = os.path.join(save_path, file_name)
r.save(filename=f"{file_name}")
4 changes: 4 additions & 0 deletions xinfer/ultralytics/yolov11.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
@register_model("ultralytics/yolov11s", "ultralytics", ModelInputOutput.IMAGE_TO_BOXES)
@register_model("ultralytics/yolov11m", "ultralytics", ModelInputOutput.IMAGE_TO_BOXES)
@register_model("ultralytics/yolov11l", "ultralytics", ModelInputOutput.IMAGE_TO_BOXES)
@register_model("ultralytics/yolov11n-cls", "ultralytics", ModelInputOutput.IMAGE_TO_CATEGORIES)
@register_model("ultralytics/yolov11s-cls", "ultralytics", ModelInputOutput.IMAGE_TO_CATEGORIES)
@register_model("ultralytics/yolov11m-cls", "ultralytics", ModelInputOutput.IMAGE_TO_CATEGORIES)
@register_model("ultralytics/yolov11l-cls", "ultralytics", ModelInputOutput.IMAGE_TO_CATEGORIES)
class YOLOv11(UltralyticsModel):
def __init__(self, model_id: str, **kwargs):
model_id = model_id.replace("v", "")
Expand Down
5 changes: 5 additions & 0 deletions xinfer/ultralytics/yolov8.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
@register_model("ultralytics/yolov8l", "ultralytics", ModelInputOutput.IMAGE_TO_BOXES)
@register_model("ultralytics/yolov8m", "ultralytics", ModelInputOutput.IMAGE_TO_BOXES)
@register_model("ultralytics/yolov8x", "ultralytics", ModelInputOutput.IMAGE_TO_BOXES)
@register_model("ultralytics/yolov8n-cls", "ultralytics", ModelInputOutput.IMAGE_TO_CATEGORIES)
@register_model("ultralytics/yolov8s-cls", "ultralytics", ModelInputOutput.IMAGE_TO_CATEGORIES)
@register_model("ultralytics/yolov8l-cls", "ultralytics", ModelInputOutput.IMAGE_TO_CATEGORIES)
@register_model("ultralytics/yolov8m-cls", "ultralytics", ModelInputOutput.IMAGE_TO_CATEGORIES)
@register_model("ultralytics/yolov8x-cls", "ultralytics", ModelInputOutput.IMAGE_TO_CATEGORIES)
class YOLOv8(UltralyticsModel):
def __init__(self, model_id: str, **kwargs):
super().__init__(model_id, **kwargs)

0 comments on commit c5d1a45

Please sign in to comment.