Skip to content

Commit

Permalink
more MP3D class remapping
Browse files Browse the repository at this point in the history
  • Loading branch information
naokiyokoyamabd committed Sep 2, 2023
1 parent fd10075 commit a9919dd
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 1 deletion.
94 changes: 94 additions & 0 deletions scripts/parse_jsons.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,89 @@ def calculate_avg_performance(stats: List[Dict[str, Any]]) -> None:
print(table)


def calculate_avg_fail_per_category(stats: List[Dict[str, Any]]) -> None:
"""
For each possible "target_object", calculate the average failure rate.
Args:
stats (List[Dict[str, Any]]): A list of stats for each episode.
"""
# Create a dictionary to store the fail count and total count for each category
category_stats = {}

for episode in stats:
category = episode["target_object"]
success = int(episode["success"]) == 1

if category not in category_stats:
category_stats[category] = {"fail_count": 0, "total_count": 0}

category_stats[category]["total_count"] += 1
if not success:
category_stats[category]["fail_count"] += 1

# Create a table with headers
table = PrettyTable(["Category", "Average Failure Rate"])

# Add each row to the table
for category, stats in sorted(
category_stats.items(),
key=lambda x: (x[1]["fail_count"] / x[1]["total_count"]),
reverse=True,
):
avg_failure_rate = (stats["fail_count"] / stats["total_count"]) * 100
table.add_row(
[
category,
(
f"{avg_failure_rate:.2f}% ({stats['fail_count']}/"
f"{stats['total_count']})"
),
]
)

print(table)


def calculate_avg_fail_rate_per_category(
stats: List[Dict[str, Any]], failure_cause: str
) -> None:
"""
For each possible "target_object", count the number of times the agent failed due to
the given failure cause. Then, sum the counts across all categories and use it to
divide the per category failure count to get the average failure rate for each
category.
Args:
stats (List[Dict[str, Any]]): A list of stats for each episode.
"""
category_to_fail_count = {}
total_fail_count = 0
for episode in stats:
if episode["failure_cause"] != failure_cause:
continue
total_fail_count += 1
category = episode["target_object"]
if category not in category_to_fail_count:
category_to_fail_count[category] = 0
category_to_fail_count[category] += 1

# Create a table with headers
table = PrettyTable(["Category", f"% Occurrence for {failure_cause}"])

# Sort the categories by their failure count in descending order
sorted_categories = sorted(
category_to_fail_count.items(), key=lambda x: x[1], reverse=True
)

# Add each row to the table
for category, count in sorted_categories:
percentage = (count / total_fail_count) * 100
table.add_row([category, f"{percentage:.2f}% ({count})"])

print(table)


def main() -> None:
"""
Main function to parse command line arguments and process the directory.
Expand All @@ -87,6 +170,17 @@ def main() -> None:
print()
calculate_avg_performance(episode_stats)

print()
calculate_avg_fail_per_category(episode_stats)

print()
print("Conditioned on failure cause: false_positive")
calculate_avg_fail_rate_per_category(episode_stats, "false_positive")

print()
print("Conditioned on failure cause: false_negative")
calculate_avg_fail_rate_per_category(episode_stats, "false_negative")


if __name__ == "__main__":
main()
10 changes: 10 additions & 0 deletions zsos/policy/base_objectnav_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,18 @@ def _get_policy_info(self, detections: ObjectDetections) -> Dict[str, Any]:
def _get_object_detections(self, img: np.ndarray) -> ObjectDetections:
if self._target_object in COCO_CLASSES:
detections = self._coco_object_detector.predict(img)
self._det_conf_threshold = 0.8
detections.phrases = [
p.replace("cupboard", "cabinet") for p in detections.phrases
]
else:
detections = self._object_detector.predict(img)
if self._target_object == "table" and detections.num_detections == 0:
detections = self._coco_object_detector.predict(img)
detections.phrases = [
p.replace("dining table", "table") for p in detections.phrases
]
self._det_conf_threshold = 0.6
if self._detect_target_only:
detections.filter_by_class([self._target_object])
detections.filter_by_conf(self._det_conf_threshold)
Expand Down
2 changes: 1 addition & 1 deletion zsos/policy/habitat_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
HM3D_ID_TO_NAME = ["chair", "bed", "potted plant", "toilet", "tv", "couch"]
MP3D_ID_TO_NAME = [
"chair",
"dining table", # "table",
"table",
"picture",
"cabinet",
"pillow", # "cushion",
Expand Down
1 change: 1 addition & 0 deletions zsos/vlm/classes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ fireplace
gym equipment
seating
clothes
cupboard

0 comments on commit a9919dd

Please sign in to comment.