Commit 83d053a

making coco and non coco thresh configurable
naokiyokoyamabd committed Sep 5, 2023
1 parent a096e80 commit 83d053a
Showing 2 changed files with 11 additions and 6 deletions.
zsos/policy/base_objectnav_policy.py (8 additions, 2 deletions)

@@ -52,6 +52,8 @@ def __init__(
         obstacle_map_area_threshold: float = 1.5,
         use_vqa: bool = True,
         vqa_prompt: str = "Is this ",
+        coco_threshold: float = 0.6,
+        non_coco_threshold: float = 0.4,
         *args,
         **kwargs,
     ):
@@ -76,6 +78,8 @@ def __init__(
         self._pointnav_stop_radius = pointnav_stop_radius
         self._visualize = visualize
         self._vqa_prompt = vqa_prompt
+        self._coco_threshold = coco_threshold
+        self._non_coco_threshold = non_coco_threshold

         self._num_steps = 0
         self._did_reset = False
@@ -229,7 +233,7 @@ def _get_policy_info(self, detections: ObjectDetections) -> Dict[str, Any]:
     def _get_object_detections(self, img: np.ndarray) -> ObjectDetections:
         if self._target_object in COCO_CLASSES:
             detections = self._coco_object_detector.predict(img)
-            self._det_conf_threshold = 0.6
+            self._det_conf_threshold = self._coco_threshold
         else:
             detections = self._object_detector.predict(img)
             detections.phrases = [
@@ -240,7 +244,7 @@ def _get_object_detections(self, img: np.ndarray) -> ObjectDetections:
             detections.phrases = [
                 p.replace("dining table", "table") for p in detections.phrases
             ]
-            self._det_conf_threshold = 0.4
+            self._det_conf_threshold = self._non_coco_threshold
         if self._detect_target_only:
             detections.filter_by_class([self._target_object])
         detections.filter_by_conf(self._det_conf_threshold)
@@ -400,6 +404,8 @@ class ZSOSConfig:
     min_obstacle_height: float = 0.61
     max_obstacle_height: float = 0.88
     vqa_prompt: str = "Is this "
+    coco_threshold: float = 0.6
+    non_coco_threshold: float = 0.4

     @classmethod
     @property
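Taken together, these hunks turn the two hard-coded detector confidence thresholds (0.6 for COCO classes, 0.4 for open-vocabulary detections) into constructor arguments and ZSOSConfig fields with the same defaults. A minimal sketch of how the new knobs might be set, assuming the ZSOSConfig dataclass fields are forwarded as keyword arguments to the policy constructor (the matching names suggest this, but the wiring is not shown in this diff):

    # Hypothetical usage; only coco_threshold/non_coco_threshold come from this
    # commit -- constructing the config with just these two overrides assumes
    # the remaining fields all have defaults, as the visible ones do.
    from zsos.policy.base_objectnav_policy import ZSOSConfig

    config = ZSOSConfig(
        coco_threshold=0.7,       # stricter filtering for COCO-class detections
        non_coco_threshold=0.35,  # looser filtering for open-vocabulary detections
    )
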
zsos/vlm/blip2.py (3 additions, 4 deletions)

@@ -39,11 +39,10 @@ def ask(self, image, prompt=None) -> str:
"""
pil_img = Image.fromarray(image)
processed_image = (
self.vis_processors["eval"](pil_img).unsqueeze(0).to(self.device)
)

with torch.inference_mode():
processed_image = (
self.vis_processors["eval"](pil_img).unsqueeze(0).to(self.device)
)
if prompt is None or prompt == "":
out = self.model.generate({"image": processed_image})[0]
else:
Expand Down
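The blip2.py change moves the image preprocessing call inside the existing torch.inference_mode() block, so the processed tensor is also created with autograd disabled. A small self-contained sketch of the PyTorch behavior this relies on (placeholder tensors stand in for the repository's BLIP-2 processor output):

    import torch

    image = torch.rand(3, 224, 224)  # stand-in for a converted PIL image

    with torch.inference_mode():
        # Ops executed here skip autograd tracking entirely; the result is an
        # "inference tensor", which avoids grad bookkeeping overhead.
        processed_image = (image * 255.0).unsqueeze(0)

    print(processed_image.is_inference())  # True
    print(processed_image.requires_grad)   # False
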
