diff --git a/zsos/utils/episode_stats_logger.py b/zsos/utils/episode_stats_logger.py index e2f5878..5d5b32c 100644 --- a/zsos/utils/episode_stats_logger.py +++ b/zsos/utils/episode_stats_logger.py @@ -62,9 +62,7 @@ def determine_failure_cause(infos: Dict) -> str: Returns: A string describing the cause of failure. """ - if not infos["top_down_map"]["is_feasible"]: - return "infeasible" - elif infos["target_detected"]: + if infos["target_detected"]: if was_false_positive(infos): return "false_positive" else: @@ -77,9 +75,13 @@ def determine_failure_cause(infos: Dict) -> str: return "false_negative" else: if infos["traveled_stairs"]: - return "never_saw_target_traveled_stairs" + cause = "never_saw_target_traveled_stairs" else: - return "never_saw_target" + cause = "never_saw_target_did_not_travel_stairs" + if not infos["top_down_map"]["is_feasible"]: + return cause + "_likely_infeasible" + else: + return cause + "_feasible" def was_target_seen(infos: Dict[str, Any]) -> bool: @@ -116,7 +118,11 @@ def was_false_positive(infos: Dict[str, Any]) -> bool: remove_duplicates=True, ) - return target_bboxes_mask[grid_xy[0, 0], grid_xy[0, 1]] == 0 + try: + return target_bboxes_mask[grid_xy[0, 0], grid_xy[0, 1]] == 0 + except IndexError: + # If the point goal is outside the map, assume it is a false positive + return True def remove_numpy_arrays(d: Dict) -> Dict: