diff --git a/zsos/policy/base_objectnav_policy.py b/zsos/policy/base_objectnav_policy.py
index 6fd3d86..3f4d5c0 100644
--- a/zsos/policy/base_objectnav_policy.py
+++ b/zsos/policy/base_objectnav_policy.py
@@ -52,6 +52,8 @@ def __init__(
         obstacle_map_area_threshold: float = 1.5,
         use_vqa: bool = True,
         vqa_prompt: str = "Is this ",
+        coco_threshold: float = 0.6,
+        non_coco_threshold: float = 0.4,
         *args,
         **kwargs,
     ):
@@ -76,6 +78,8 @@ def __init__(
         self._pointnav_stop_radius = pointnav_stop_radius
         self._visualize = visualize
         self._vqa_prompt = vqa_prompt
+        self._coco_threshold = coco_threshold
+        self._non_coco_threshold = non_coco_threshold
         self._num_steps = 0
         self._did_reset = False
@@ -229,7 +233,7 @@ def _get_policy_info(self, detections: ObjectDetections) -> Dict[str, Any]:
     def _get_object_detections(self, img: np.ndarray) -> ObjectDetections:
         if self._target_object in COCO_CLASSES:
             detections = self._coco_object_detector.predict(img)
-            self._det_conf_threshold = 0.6
+            self._det_conf_threshold = self._coco_threshold
         else:
             detections = self._object_detector.predict(img)
             detections.phrases = [
@@ -240,7 +244,7 @@ def _get_object_detections(self, img: np.ndarray) -> ObjectDetections:
             detections.phrases = [
                 p.replace("dining table", "table") for p in detections.phrases
             ]
-            self._det_conf_threshold = 0.4
+            self._det_conf_threshold = self._non_coco_threshold
         if self._detect_target_only:
             detections.filter_by_class([self._target_object])
         detections.filter_by_conf(self._det_conf_threshold)
@@ -400,6 +404,8 @@ class ZSOSConfig:
     min_obstacle_height: float = 0.61
     max_obstacle_height: float = 0.88
     vqa_prompt: str = "Is this "
+    coco_threshold: float = 0.6
+    non_coco_threshold: float = 0.4

     @classmethod
     @property
diff --git a/zsos/vlm/blip2.py b/zsos/vlm/blip2.py
index 1bd6847..f2dccd3 100644
--- a/zsos/vlm/blip2.py
+++ b/zsos/vlm/blip2.py
@@ -39,11 +39,10 @@ def ask(self, image, prompt=None) -> str:
         """
         pil_img = Image.fromarray(image)
-        processed_image = (
-            self.vis_processors["eval"](pil_img).unsqueeze(0).to(self.device)
-        )
         with torch.inference_mode():
+            processed_image = (
+                self.vis_processors["eval"](pil_img).unsqueeze(0).to(self.device)
+            )
             if prompt is None or prompt == "":
                 out = self.model.generate({"image": processed_image})[0]
             else: