From 1231c786b6f848e5dcd781186a1ad565b6b2344c Mon Sep 17 00:00:00 2001 From: Naoki Yokoyama Date: Tue, 12 Sep 2023 17:59:04 -0400 Subject: [PATCH] wrapping video generation in try except, adding confidence fusion ablations --- zsos/mapping/value_map.py | 22 ++++++++++++++++++++++ zsos/semexp_env/eval.py | 5 ++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/zsos/mapping/value_map.py b/zsos/mapping/value_map.py index 37b69fe..6945f9b 100644 --- a/zsos/mapping/value_map.py +++ b/zsos/mapping/value_map.py @@ -43,6 +43,7 @@ def __init__( value_channels: int, size: int = 1000, use_max_confidence: bool = True, + fusion_type: str = "default", ): """ Args: @@ -55,6 +56,7 @@ def __init__( self._value_map = np.zeros((size, size, value_channels), np.float32) self._value_channels = value_channels self._use_max_confidence = use_max_confidence + self._fusion_type = fusion_type if RECORDING: if osp.isdir(RECORDING_DIR): @@ -356,6 +358,26 @@ def _fuse_new_data(self, new_map: np.ndarray, values: np.ndarray) -> None: f"({len(values)}). Expected {self._value_channels}." ) + if self._fusion_type == "replace": + # Ablation. The values from the current observation will overwrite any + # existing values + print("VALUE MAP ABLATION:", self._fusion_type) + new_value_map = np.zeros_like(self._value_map) + new_value_map[new_map > 0] = values + self._map[new_map > 0] = new_map[new_map > 0] + self._value_map[new_map > 0] = new_value_map[new_map > 0] + return + elif self._fusion_type == "equal_weighting": + # Ablation. Updated values will always be the mean of the current and + # new values, meaning that confidence scores are forced to be the same. + print("VALUE MAP ABLATION:", self._fusion_type) + self._map[self._map > 0] = 1 + new_map[new_map > 0] = 1 + else: + assert ( + self._fusion_type == "default" + ), f"Unknown fusion type {self._fusion_type}" + # Any values in the given map that are less confident than # self._decision_threshold AND less than the new_map in the existing map # will be silenced into 0s diff --git a/zsos/semexp_env/eval.py b/zsos/semexp_env/eval.py index d029e71..9cebf7f 100644 --- a/zsos/semexp_env/eval.py +++ b/zsos/semexp_env/eval.py @@ -113,7 +113,10 @@ def main(): "target_object": target_object, } if "VIDEO_DIR" in os.environ: - generate_video(vis_imgs, ep_id, scene_id, data) + try: + generate_video(vis_imgs, ep_id, scene_id, data) + except Exception: + print("Error generating video") if "ZSOS_LOG_DIR" in os.environ and not is_evaluated(ep_id, scene_id): log_episode(ep_id, scene_id, data) break