diff --git a/config/experiments/llm_objectnav_hm3d.yaml b/config/experiments/llm_objectnav_hm3d.yaml
index c398783..c7beea7 100644
--- a/config/experiments/llm_objectnav_hm3d.yaml
+++ b/config/experiments/llm_objectnav_hm3d.yaml
@@ -11,6 +11,7 @@ defaults:
       - frontier_sensor
   - /habitat/task/measurements:
       - frontier_exploration_map
+  - /habitat_baselines/rl/policy: zsos_policy
   - _self_
 
 habitat:
diff --git a/zsos/policy/base_objectnav_policy.py b/zsos/policy/base_objectnav_policy.py
index f7b6704..a2a6aa0 100644
--- a/zsos/policy/base_objectnav_policy.py
+++ b/zsos/policy/base_objectnav_policy.py
@@ -11,7 +11,6 @@
     WrappedPointNavResNetPolicy,
     rho_theta_from_gps_compass_goal,
 )
-from zsos.utils.geometry_utils import xyz_yaw_to_tf_matrix
 from zsos.vlm.grounding_dino import GroundingDINOClient, ObjectDetections
 
 try:
@@ -25,45 +24,50 @@ class BasePolicy:
 
 
 class BaseObjectNavPolicy(BasePolicy):
-    target_object: str = ""
-    camera_height: float = 0.88
-    depth_image_shape: Tuple[int, int] = (244, 224)
-    det_conf_threshold: float = 0.50
-    pointnav_stop_radius: float = 0.85
-    visualize: bool = True
-    policy_info: Dict[str, Any] = {}
-    id_to_padding: Dict[str, float] = {}
-    _stop_action: Tensor = None  # must be set by subclass
-    # ObjectMap parameters; these must be set by subclass
-    min_depth: float = None
-    max_depth: float = None
-    hfov: float = None
-    proximity_threshold: float = None
-
-    def __init__(self, *args, **kwargs):
+    _target_object: str = ""
+    _policy_info: Dict[str, Any] = {}
+    _id_to_padding: Dict[str, float] = {}
+    _stop_action: Tensor = None  # MUST BE SET BY SUBCLASS
+
+    def __init__(
+        self,
+        pointnav_policy_path: str,
+        depth_image_shape: Tuple[int, int],
+        det_conf_threshold: float,
+        pointnav_stop_radius: float,
+        object_map_min_depth: float,
+        object_map_max_depth: float,
+        object_map_hfov: float,
+        object_map_proximity_threshold: float,
+        visualize: bool = True,
+        *args,
+        **kwargs,
+    ):
         super().__init__()
-        self.object_detector = GroundingDINOClient()
-        self.pointnav_policy = WrappedPointNavResNetPolicy(
-            os.environ["POINTNAV_POLICY_PATH"]
-        )
-        self.object_map: ObjectMap = ObjectMap(
-            min_depth=self.min_depth,
-            max_depth=self.max_depth,
-            hfov=self.hfov,
-            proximity_threshold=self.proximity_threshold,
+        self._object_detector = GroundingDINOClient()
+        self._pointnav_policy = WrappedPointNavResNetPolicy(pointnav_policy_path)
+        self._object_map: ObjectMap = ObjectMap(
+            min_depth=object_map_min_depth,
+            max_depth=object_map_max_depth,
+            hfov=object_map_hfov,
+            proximity_threshold=object_map_proximity_threshold,
         )
+        self._depth_image_shape = tuple(depth_image_shape)
+        self._det_conf_threshold = det_conf_threshold
+        self._pointnav_stop_radius = pointnav_stop_radius
+        self._visualize = visualize
 
-        self.num_steps = 0
-        self.last_goal = np.zeros(2)
-        self.done_initializing = False
+        self._num_steps = 0
+        self._last_goal = np.zeros(2)
+        self._done_initializing = False
 
     def _reset(self):
-        self.target_object = ""
-        self.pointnav_policy.reset()
-        self.object_map.reset()
-        self.last_goal = np.zeros(2)
-        self.num_steps = 0
-        self.done_initializing = False
+        self._target_object = ""
+        self._pointnav_policy.reset()
+        self._object_map.reset()
+        self._last_goal = np.zeros(2)
+        self._num_steps = 0
+        self._done_initializing = False
 
     def act(
         self, observations, rnn_hidden_states, prev_actions, masks, deterministic=False
@@ -78,17 +82,15 @@ def act(
         assert masks.shape[1] == 1, "Currently only supporting one env at a time"
         if masks[0] == 0:
             self._reset()
-            self.target_object = observations["objectgoal"]
+            self._target_object = observations["objectgoal"]
 
-        self.policy_info = {}
+        self._policy_info = {}
 
-        rgb, depth, tf_camera_to_episodic = self._get_detection_camera_info(
-            observations
-        )
+        rgb, depth, tf_camera_to_episodic = self._get_object_camera_info(observations)
         detections = self._update_object_map(rgb, depth, tf_camera_to_episodic)
         goal = self._get_target_object_location()
 
-        if not self.done_initializing:  # Initialize
+        if not self._done_initializing:  # Initialize
             pointnav_action = self._initialize()
         elif goal is None:  # Haven't found target object yet
             pointnav_action = self._explore(observations)
@@ -97,8 +99,8 @@ def act(
             observations, goal[:2], deterministic=deterministic, stop=True
         )
 
-        self.policy_info = self._get_policy_info(observations, detections)
-        self.num_steps += 1
+        self._policy_info = self._get_policy_info(observations, detections)
+        self._num_steps += 1
 
         return pointnav_action, rnn_hidden_states
 
@@ -110,7 +112,7 @@ def _explore(self, observations: "TensorDict") -> Tensor:
 
     def _get_target_object_location(self) -> Union[None, np.ndarray]:
         try:
-            return self.object_map.get_best_object(self.target_object)
+            return self._object_map.get_best_object(self._target_object)
         except ValueError:
             # Target object has not been spotted
             return None
@@ -118,10 +120,10 @@ def _get_target_object_location(self) -> Union[None, np.ndarray]:
     def _get_policy_info(
         self, observations: "TensorDict", detections: ObjectDetections
     ) -> Dict[str, Any]:
-        seen_objects = set(i.class_name for i in self.object_map.map)
+        seen_objects = set(i.class_name for i in self._object_map.map)
         seen_objects_str = ", ".join(seen_objects)
         policy_info = {
-            "target_object": "target: " + self.target_object,
+            "target_object": "target: " + self._target_object,
             "visualized_detections": detections.annotated_frame,
             "seen_objects": seen_objects_str,
             "gps": str(observations["gps"][0].cpu().numpy()),
@@ -140,8 +142,8 @@ def _get_policy_info(
         return policy_info
 
     def _get_object_detections(self, img: np.ndarray) -> ObjectDetections:
-        detections = self.object_detector.predict(img, visualize=self.visualize)
-        detections.filter_by_conf(self.det_conf_threshold)
+        detections = self._object_detector.predict(img, visualize=self._visualize)
+        detections.filter_by_conf(self._det_conf_threshold)
 
         return detections
 
@@ -160,48 +162,38 @@ def _pointnav(
         Args:
            observations ("TensorDict"): The observations from the current timestep.
""" - masks = torch.tensor([self.num_steps != 0], dtype=torch.bool, device="cuda") - if not np.array_equal(goal, self.last_goal): - self.last_goal = goal - self.pointnav_policy.reset() + masks = torch.tensor([self._num_steps != 0], dtype=torch.bool, device="cuda") + if not np.array_equal(goal, self._last_goal): + self._last_goal = goal + self._pointnav_policy.reset() masks = torch.zeros_like(masks) rho_theta = rho_theta_from_gps_compass_goal(observations, goal) obs_pointnav = { "depth": image_resize( observations["depth"], - self.depth_image_shape, + self._depth_image_shape, channels_last=True, interpolation_mode="area", ), "pointgoal_with_gps_compass": rho_theta.unsqueeze(0), } - stop_dist = self.pointnav_stop_radius + self.id_to_padding.get( - self.target_object, 0.0 + stop_dist = self._pointnav_stop_radius + self._id_to_padding.get( + self._target_object, 0.0 ) if rho_theta[0] < stop_dist and stop: return self._stop_action - action = self.pointnav_policy.act( + action = self._pointnav_policy.act( obs_pointnav, masks, deterministic=deterministic ) return action - def _get_detection_camera_info(self, observations: "TensorDict") -> Tuple: - rgb = observations["rgb"][0].cpu().numpy() - depth = observations["depth"][0].cpu().numpy() - x, y = observations["gps"][0].cpu().numpy() - camera_yaw = observations["compass"][0].cpu().item() - # Habitat GPS makes west negative, so flip y - camera_position = np.array([x, -y, self.camera_height]) - tf_camera_to_episodic = xyz_yaw_to_tf_matrix(camera_position, camera_yaw) - return rgb, depth, tf_camera_to_episodic - def _update_object_map( self, rgb: np.ndarray, depth: np.ndarray, tf_camera_to_episodic: np.ndarray ) -> ObjectDetections: detections = self._get_object_detections(rgb) for idx, confidence in enumerate(detections.logits): - self.object_map.update_map( + self._object_map.update_map( detections.phrases[idx], detections.boxes[idx], depth, @@ -209,6 +201,22 @@ def _update_object_map( confidence, ) - self.object_map.update_explored(tf_camera_to_episodic) + self._object_map.update_explored(tf_camera_to_episodic) return detections + + def _get_object_camera_info( + self, observations: "TensorDict" + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Extracts the rgb, depth, and camera transform from the observations. + + Args: + observations ("TensorDict"): The observations from the current timestep. + + Returns: + Tuple[np.ndarray, np.ndarray, np.ndarray]: The rgb image, depth image, and + camera transform. The depth image is normalized to be between 0 and 1. + The camera transform is the transform from the camera to the episodic + frame, a 4x4 transformation matrix. 
+ """ + raise NotImplementedError diff --git a/zsos/policy/habitat_policies.py b/zsos/policy/habitat_policies.py index ea1ad2e..a5d5bb2 100644 --- a/zsos/policy/habitat_policies.py +++ b/zsos/policy/habitat_policies.py @@ -1,16 +1,24 @@ -from typing import Any, Dict, Union +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple, Union +import numpy as np import torch from habitat.tasks.nav.nav import HeadingSensor from habitat.tasks.nav.object_nav_task import ObjectGoalSensor from habitat_baselines.common.baseline_registry import baseline_registry from habitat_baselines.common.tensor_dict import TensorDict +from habitat_baselines.config.default_structured_configs import ( + PolicyConfig, +) from habitat_baselines.rl.ppo.policy import PolicyActionData +from hydra.core.config_store import ConfigStore +from omegaconf import DictConfig from torch import Tensor from frontier_exploration.base_explorer import BaseExplorer +from zsos.utils.geometry_utils import xyz_yaw_to_tf_matrix +from zsos.vlm.grounding_dino import ObjectDetections -from ..vlm.detections import ObjectDetections from .base_objectnav_policy import BaseObjectNavPolicy from .itm_policy import ITMPolicy @@ -25,17 +33,35 @@ class TorchActionIDs: class HabitatMixin: - id_to_padding: Dict[str, float] = { + """This Python mixin only contains code relevant for running a BaseObjectNavPolicy + explicitly within Habitat (vs. the real world, etc.) and will endow any parent class + (that is a subclass of BaseObjectNavPolicy) with the necessary methods to run in + Habitat. + """ + + _id_to_padding: Dict[str, float] = { "bed": 0.3, "couch": 0.15, } _stop_action: Tensor = TorchActionIDs.STOP _start_yaw: Union[float, None] = None # must be set by _reset() method - # ObjectMap parameters - min_depth: float = 0.5 - max_depth: float = 5.0 - hfov: float = 79.0 - proximity_threshold: float = 1.5 + + def __init__(self, camera_height: float, *args: Any, **kwargs: Any) -> None: + self._camera_height = camera_height + super().__init__(*args, **kwargs) + + @classmethod + def from_config(cls, config: DictConfig, *args_unused, **kwargs_unused): + policy_config: ZSOSPolicyConfig = config.habitat_baselines.rl.policy + kwargs = { + k: policy_config[k] for k in ZSOSPolicyConfig.arg_names() # type: ignore + } + + # In habitat, we need the height of the camera to generate the camera transform + agent_config = config.habitat.simulator.agents.main_agent + kwargs["camera_height"] = agent_config.sim_sensors.rgb_sensor.position[1] + + return cls(**kwargs) def act( self: BaseObjectNavPolicy, @@ -56,12 +82,12 @@ def act( return PolicyActionData( actions=action, rnn_hidden_states=rnn_hidden_states, - policy_info=[self.policy_info], + policy_info=[self._policy_info], ) def _initialize(self) -> Tensor: """Turn left 30 degrees 12 times to get a 360 view at the beginning""" - self.done_initializing = not self.num_steps < 11 # type: ignore + self._done_initializing = not self._num_steps < 11 # type: ignore return TorchActionIDs.TURN_LEFT def _reset(self) -> None: @@ -70,7 +96,7 @@ def _reset(self) -> None: self._start_yaw = None def _get_policy_info( - self, observations: "TensorDict", detections: ObjectDetections + self, observations: TensorDict, detections: ObjectDetections ) -> Dict[str, Any]: """Get policy info for logging""" parent_cls: BaseObjectNavPolicy = super() # type: ignore @@ -80,6 +106,29 @@ def _get_policy_info( info["start_yaw"] = self._start_yaw return info + def _get_object_camera_info( + self, observations: TensorDict + ) -> 
Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Extracts the rgb, depth, and camera transform from the observations. + + Args: + observations (TensorDict): The observations from the current timestep. + + Returns: + Tuple[np.ndarray, np.ndarray, np.ndarray]: The rgb image, depth image, and + camera transform. The depth image is normalized to be between 0 and 1. + The camera transform is the transform from the camera to the episodic + frame, a 4x4 transformation matrix. + """ + rgb = observations["rgb"][0].cpu().numpy() + depth = observations["depth"][0].cpu().numpy() + x, y = observations["gps"][0].cpu().numpy() + camera_yaw = observations["compass"][0].cpu().item() + # Habitat GPS makes west negative, so flip y + camera_position = np.array([x, -y, self._camera_height]) + tf_camera_to_episodic = xyz_yaw_to_tf_matrix(camera_position, camera_yaw) + return rgb, depth, tf_camera_to_episodic + @baseline_registry.register_policy class OracleFBEPolicy(HabitatMixin, BaseObjectNavPolicy): @@ -97,10 +146,47 @@ def act( return PolicyActionData( actions=observations[BaseExplorer.cls_uuid], rnn_hidden_states=rnn_hidden_states, - policy_info=[self.policy_info], + policy_info=[self._policy_info], ) @baseline_registry.register_policy class HabitatITMPolicy(HabitatMixin, ITMPolicy): pass + + +@dataclass +class ZSOSPolicyConfig(PolicyConfig): + name: str = "HabitatITMPolicy" + pointnav_policy_path: str = "data/pointnav_weights.pth" + depth_image_shape: Tuple[int, int] = (244, 224) + det_conf_threshold: float = 0.6 + pointnav_stop_radius: float = 0.85 + object_map_min_depth: float = 0.5 + object_map_max_depth: float = 5.0 + object_map_hfov: float = 79.0 + value_map_max_depth: float = 5.0 + value_map_hfov: float = 79.0 + object_map_proximity_threshold: float = 1.5 + visualize: bool = True + + @classmethod + def arg_names(cls) -> List[str]: + # All the above except "name". Also excludes all attributes from parent classes. 
+        return [
+            "pointnav_policy_path",
+            "depth_image_shape",
+            "det_conf_threshold",
+            "pointnav_stop_radius",
+            "object_map_min_depth",
+            "object_map_max_depth",
+            "object_map_hfov",
+            "object_map_proximity_threshold",
+            "value_map_max_depth",
+            "value_map_hfov",
+            "visualize",
+        ]
+
+
+cs = ConfigStore.instance()
+cs.store(group="habitat_baselines/rl/policy", name="zsos_policy", node=ZSOSPolicyConfig)
diff --git a/zsos/policy/itm_policy.py b/zsos/policy/itm_policy.py
index af0bc05..c403e46 100644
--- a/zsos/policy/itm_policy.py
+++ b/zsos/policy/itm_policy.py
@@ -18,23 +18,25 @@
 
 
 class ITMPolicy(BaseObjectNavPolicy):
-    target_object_color: Tuple[int, int, int] = (0, 255, 0)
-    selected_frontier_color: Tuple[int, int, int] = (0, 255, 255)
-    frontier_color: Tuple[int, int, int] = (0, 0, 255)
-    circle_marker_thickness: int = 2
-    circle_marker_radius: int = 5
+    _target_object_color: Tuple[int, int, int] = (0, 255, 0)
+    _selected__frontier_color: Tuple[int, int, int] = (0, 255, 255)
+    _frontier_color: Tuple[int, int, int] = (0, 0, 255)
+    _circle_marker_thickness: int = 2
+    _circle_marker_radius: int = 5
 
-    def __init__(self, *args, **kwargs):
-        super().__init__()
+    def __init__(
+        self, value_map_max_depth: float, value_map_hfov: float, *args, **kwargs
+    ):
+        super().__init__(*args, **kwargs)
         self.itm = BLIP2ITMClient()
         self.frontier_map: FrontierMap = FrontierMap()
-        self.value_map: ValueMap = ValueMap(fov=self.hfov, max_depth=self.max_depth)
+        self.value_map: ValueMap = ValueMap(
+            fov=value_map_hfov, max_depth=value_map_max_depth
+        )
 
     def act(self, observations: TensorDict, *args, **kwargs) -> Tuple[Tensor, Tensor]:
-        rgb, depth, tf_camera_to_episodic = self._get_detection_camera_info(
-            observations
-        )
-        text = f"Seems like there is a {self.target_object} ahead."
+        rgb, depth, tf_camera_to_episodic = self._get_object_camera_info(observations)
+        text = f"Seems like there is a {self._target_object} ahead."
         curr_cosine = self.frontier_map._encode(rgb, text)
         self.value_map.update_map(depth, tf_camera_to_episodic, curr_cosine)
 
@@ -55,22 +57,22 @@ def _get_policy_info(
 
         # Draw frontiers on to the cost map
         base_kwargs = {
-            "radius": self.circle_marker_radius,
-            "thickness": self.circle_marker_thickness,
+            "radius": self._circle_marker_radius,
+            "thickness": self._circle_marker_thickness,
         }
         frontiers = observations["frontier_sensor"][0].cpu().numpy()
         for frontier in frontiers:
-            marker_kwargs = {"color": self.frontier_color, **base_kwargs}
+            marker_kwargs = {"color": self._frontier_color, **base_kwargs}
             markers.append((frontier[:2], marker_kwargs))
-        if not np.array_equal(self.last_goal, np.zeros(2)):
+        if not np.array_equal(self._last_goal, np.zeros(2)):
             # Draw the pointnav goal on to the cost map
-            if any(np.array_equal(self.last_goal, frontier) for frontier in frontiers):
-                color = self.selected_frontier_color
+            if any(np.array_equal(self._last_goal, frontier) for frontier in frontiers):
+                color = self._selected__frontier_color
             else:
-                color = self.target_object_color
+                color = self._target_object_color
             marker_kwargs = {"color": color, **base_kwargs}
-            markers.append((self.last_goal, marker_kwargs))
+            markers.append((self._last_goal, marker_kwargs))
         policy_info["cost_map"] = cv2.cvtColor(
             self.value_map.visualize(markers), cv2.COLOR_BGR2RGB
         )
 
@@ -80,11 +82,11 @@ def _get_policy_info(
     def _explore(self, observations: Union[Dict[str, Tensor], "TensorDict"]) -> Tensor:
         frontiers = observations["frontier_sensor"][0].cpu().numpy()
         rgb = observations["rgb"][0].cpu().numpy()
-        text = f"Seems like there is a {self.target_object} ahead."
+        text = f"Seems like there is a {self._target_object} ahead."
         self.frontier_map.update(frontiers, rgb, text)
         goal, cosine = self.frontier_map.get_best_frontier()
         os.environ["DEBUG_INFO"] = f"Best frontier: {cosine:.3f}"
-        print(f"Step: {self.num_steps} Best frontier: {cosine}")
+        print(f"Step: {self._num_steps} Best frontier: {cosine}")
         pointnav_action = self._pointnav(
             observations, goal[:2], deterministic=True, stop=False
         )
diff --git a/zsos/policy/llm_policy.py b/zsos/policy/llm_policy.py
index a732cd8..ffcdb1f 100644
--- a/zsos/policy/llm_policy.py
+++ b/zsos/policy/llm_policy.py
@@ -14,7 +14,7 @@
 @baseline_registry.register_policy
 class LLMPolicy(BaseObjectNavPolicy):
     llm: BaseLLM = None
-    visualize: bool = True
+    _visualize: bool = True
 
     def __init__(self, *args, **kwargs):
         super().__init__()
@@ -26,7 +26,7 @@ def __init__(self, *args, **kwargs):
     def _explore(self, observations: TensorDict) -> Tensor:
         curr_pos = observations["gps"][0].cpu().numpy() * np.array([1, -1])
         baseline = True
-        if np.linalg.norm(self.last_goal - curr_pos) < 0.25:
+        if np.linalg.norm(self._last_goal - curr_pos) < 0.25:
             frontiers = observations["frontier_sensor"][0].cpu().numpy()
             if baseline:
                 goal = frontiers[0]
@@ -34,7 +34,7 @@ def _explore(self, observations: TensorDict) -> Tensor:
                 # Ask LLM which waypoint to head to next
                 goal, _ = self._get_llm_goal(curr_pos, frontiers)
         else:
-            goal = self.last_goal
+            goal = self._last_goal
 
         pointnav_action = self._pointnav(
             observations, goal[:2], deterministic=True, stop=False
@@ -46,7 +46,7 @@ def _get_llm_goal(
         self, current_pos: np.ndarray, frontiers: np.ndarray
     ) -> Tuple[np.ndarray, str]:
         """
-        Asks LLM which object or frontier to go to next. self.object_map is used to
+        Asks LLM which object or frontier to go to next. self._object_map is used to
         generate the prompt for the LLM.
 
         Args:
@@ -58,8 +58,8 @@ def _get_llm_goal(
         Returns:
             Tuple[np.ndarray, str]: A tuple containing the goal and the LLM response.
         """
-        prompt, waypoints = self.object_map.get_textual_map_prompt(
-            self.target_object, current_pos, frontiers
+        prompt, waypoints = self._object_map.get_textual_map_prompt(
+            self._target_object, current_pos, frontiers
         )
         resp = self.llm.ask(prompt)
         int_resp = extract_integer(resp) - 1
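
How the refactored constructors compose: `HabitatMixin.from_config` builds one flat kwargs dict (the `ZSOSPolicyConfig.arg_names()` fields plus `camera_height`), and each class in the MRO peels off the keyword arguments it declares before forwarding the rest via `super().__init__(*args, **kwargs)`. Below is a minimal sketch of that pattern; the class and argument names mirror the diff, but the bodies are abbreviated stand-ins, not the real zsos code.

```python
# Minimal, self-contained sketch of the cooperative-__init__ pattern the diff
# introduces. Class names mirror the diff; bodies are illustrative stand-ins.
class BasePolicy:
    def __init__(self, *args, **kwargs) -> None:
        pass


class BaseObjectNavPolicy(BasePolicy):
    # Consumes the pointnav/object-map kwargs (abbreviated here to two).
    def __init__(self, pointnav_policy_path: str, visualize: bool = True, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self._pointnav_policy_path = pointnav_policy_path
        self._visualize = visualize


class ITMPolicy(BaseObjectNavPolicy):
    # Peels off its own value-map kwargs, forwards the rest up the MRO.
    def __init__(self, value_map_max_depth: float, value_map_hfov: float, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self._value_map_args = (value_map_hfov, value_map_max_depth)


class HabitatMixin:
    # Consumes camera_height; super() here resolves to ITMPolicy in the MRO below.
    def __init__(self, camera_height: float, *args, **kwargs) -> None:
        self._camera_height = camera_height
        super().__init__(*args, **kwargs)


class HabitatITMPolicy(HabitatMixin, ITMPolicy):
    pass


# from_config builds one flat kwargs dict (ZSOSPolicyConfig.arg_names() plus
# camera_height); each layer keeps its own keys and forwards the remainder.
policy = HabitatITMPolicy(
    camera_height=0.88,
    value_map_max_depth=5.0,
    value_map_hfov=79.0,
    pointnav_policy_path="data/pointnav_weights.pth",
    visualize=True,
)
assert policy._camera_height == 0.88 and policy._visualize
```

Note that `HabitatMixin` must come first in the bases of `HabitatITMPolicy` so that its `__init__` runs first and each `super()` call continues down the chain.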