From 05281111aeecb10ee0d2cc810b81c47870bd3eec Mon Sep 17 00:00:00 2001 From: Naoki Yokoyama Date: Fri, 11 Aug 2023 20:54:42 -0400 Subject: [PATCH] use hydra to turn off all visualization features when video_option is [] --- zsos/policy/habitat_policies.py | 9 +++++++-- zsos/policy/itm_policy.py | 4 ++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/zsos/policy/habitat_policies.py b/zsos/policy/habitat_policies.py index a5d5bb2..d4d76fc 100644 --- a/zsos/policy/habitat_policies.py +++ b/zsos/policy/habitat_policies.py @@ -61,6 +61,9 @@ def from_config(cls, config: DictConfig, *args_unused, **kwargs_unused): agent_config = config.habitat.simulator.agents.main_agent kwargs["camera_height"] = agent_config.sim_sensors.rgb_sensor.position[1] + # Only bother visualizing if we're actually going to save the video + kwargs["visualize"] = len(config.habitat_baselines.eval.video_option) > 0 + return cls(**kwargs) def act( @@ -101,6 +104,10 @@ def _get_policy_info( """Get policy info for logging""" parent_cls: BaseObjectNavPolicy = super() # type: ignore info = parent_cls._get_policy_info(observations, detections) + + if not self._visualize: # type: ignore + return info + if self._start_yaw is None: self._start_yaw = observations[HeadingSensor.cls_uuid][0].item() info["start_yaw"] = self._start_yaw @@ -168,7 +175,6 @@ class ZSOSPolicyConfig(PolicyConfig): value_map_max_depth: float = 5.0 value_map_hfov: float = 79.0 object_map_proximity_threshold: float = 1.5 - visualize: bool = True @classmethod def arg_names(cls) -> List[str]: @@ -184,7 +190,6 @@ def arg_names(cls) -> List[str]: "object_map_proximity_threshold", "value_map_max_depth", "value_map_hfov", - "visualize", ] diff --git a/zsos/policy/itm_policy.py b/zsos/policy/itm_policy.py index c403e46..365f2a9 100644 --- a/zsos/policy/itm_policy.py +++ b/zsos/policy/itm_policy.py @@ -53,6 +53,10 @@ def _get_policy_info( detections: ObjectDetections, ) -> Dict[str, Any]: policy_info = super()._get_policy_info(observations, detections) + + if not self._visualize: + return policy_info + markers = [] # Draw frontiers on to the cost map