diff --git a/zsos/policy/base_objectnav_policy.py b/zsos/policy/base_objectnav_policy.py index 7218b26..68e1600 100644 --- a/zsos/policy/base_objectnav_policy.py +++ b/zsos/policy/base_objectnav_policy.py @@ -143,7 +143,12 @@ def _pre_step(self, observations: "TensorDict", masks: Tensor) -> None: if not self._did_reset and masks[0] == 0: self._reset() self._target_object = observations["objectgoal"] - self._cache_observations(observations) + try: + self._cache_observations(observations) + except IndexError as e: + print(e) + print("Reached edge of map, stopping.") + raise StopIteration self._policy_info = {} def _initialize(self) -> Tensor: diff --git a/zsos/policy/habitat_policies.py b/zsos/policy/habitat_policies.py index 7e57cc5..de00456 100644 --- a/zsos/policy/habitat_policies.py +++ b/zsos/policy/habitat_policies.py @@ -133,9 +133,12 @@ def act( else: raise ValueError(f"Dataset type {self._dataset_type} not recognized") parent_cls: BaseObjectNavPolicy = super() # type: ignore - action, rnn_hidden_states = parent_cls.act( - obs_dict, rnn_hidden_states, prev_actions, masks, deterministic - ) + try: + action, rnn_hidden_states = parent_cls.act( + obs_dict, rnn_hidden_states, prev_actions, masks, deterministic + ) + except StopIteration: + action = self._stop_action return PolicyActionData( actions=action, rnn_hidden_states=rnn_hidden_states,