Skip to content

Commit

Permalink
filter out combined class names on server-side for GDINO
Browse files Browse the repository at this point in the history
  • Loading branch information
naokiyokoyamabd committed Jul 19, 2023
1 parent 8364129 commit a9471b2
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions zsos/vlm/grounding_dino.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,20 @@ def predict(self, image: np.ndarray, visualize: bool = False) -> ObjectDetection
box_threshold=self.box_threshold,
text_threshold=self.text_threshold,
)
det = ObjectDetections(
detections = ObjectDetections(
boxes, logits, phrases, image_source=image, visualize=visualize
)
return det

classes = self.classes.split(" . ")
keep = torch.tensor(
[p in classes for p in detections.phrases], dtype=torch.bool
)

detections.boxes = detections.boxes[keep]
detections.logits = detections.logits[keep]
detections.phrases = [p for i, p in enumerate(detections.phrases) if keep[i]]

return detections


class GroundingDINOClient:
Expand Down

0 comments on commit a9471b2

Please sign in to comment.