Skip to content

Commit

Permalink
adding frontier values to acyclic check, adding step id to gibson vid…
Browse files Browse the repository at this point in the history
…eos, treating detections at fringes of frame as lesser than others
  • Loading branch information
naokiyokoyamabd committed Sep 8, 2023
1 parent 82ac534 commit 8f070c8
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 24 deletions.
9 changes: 5 additions & 4 deletions scripts/parse_jsons.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def calculate_avg_performance(stats: List[Dict[str, Any]]) -> None:
stats (List[Dict[str, Any]]): A list of stats for each episode.
"""
success, spl, soft_spl = [
[episode[k] for episode in stats] for k in ["success", "spl", "soft_spl"]
[episode.get(k, -1) for episode in stats]
for k in ["success", "spl", "soft_spl"]
]

# Create a table with headers
Expand Down Expand Up @@ -168,12 +169,12 @@ def main() -> None:
episode_stats = read_json_files(args.directory)
print(f"\nTotal episodes: {len(episode_stats)}\n")

failure_causes = [episode["failure_cause"] for episode in episode_stats]
calculate_frequencies(failure_causes)

print()
calculate_avg_performance(episode_stats)

failure_causes = [episode["failure_cause"] for episode in episode_stats]
calculate_frequencies(failure_causes)

if args.compact:
return

Expand Down
31 changes: 28 additions & 3 deletions zsos/mapping/object_point_cloud_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,12 @@ def update_map(
if len(local_cloud) == 0:
return

# Mark all points of local_cloud whose distance from the camera is too far
# as being out of range
within_range = local_cloud[:, 0] <= max_depth * 0.95 # 5% margin
if too_offset(object_mask):
within_range = np.zeros_like(local_cloud[:, 0])
else:
# Mark all points of local_cloud whose distance from the camera is too far
# as being out of range
within_range = local_cloud[:, 0] <= max_depth * 0.95 # 5% margin
global_cloud = transform_points(tf_camera_to_episodic, local_cloud)
global_cloud = np.concatenate((global_cloud, within_range[:, None]), axis=1)

Expand Down Expand Up @@ -234,3 +237,25 @@ def get_random_subarray(points: np.ndarray, size: int) -> np.ndarray:
return points
indices = np.random.choice(len(points), size, replace=False)
return points[indices]


def too_offset(mask: np.ndarray) -> bool:
"""
This will return true if the entire bounding rectangle of the mask is either on the
left or right third of the mask. This is used to determine if the object is too far
to the side of the image to be a reliable detection.
Args:
mask (numpy array): A 2D numpy array of 0s and 1s representing the mask of the
object.
Returns:
bool: True if the object is too offset, False otherwise.
"""
# Find the bounding rectangle of the mask
x, y, w, h = cv2.boundingRect(mask)

# Calculate the thirds of the mask
third = mask.shape[1] // 3

# Check if the entire bounding rectangle is in the left or right third of the mask
return x + w <= third or x >= 2 * third
12 changes: 7 additions & 5 deletions zsos/policy/itm_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def _get_best_frontier(
)
robot_xy = self._observations_cache["robot_xy"]
best_frontier_idx = None
top_two_values = tuple(sorted_values[:2])

os.environ["DEBUG_INFO"] = ""
# If there is a last point pursued, then we consider sticking to pursuing it
Expand Down Expand Up @@ -118,21 +119,22 @@ def _get_best_frontier(
if curr_value + 0.01 > self._last_value:
# The last point pursued is still in the list of frontiers and its
# value is not much worse than self._last_value
print("Sticking to last point.")
os.environ["DEBUG_INFO"] += "Sticking to last point. "
best_frontier_idx = curr_index

# If there is no last point pursued, then just take the best point, given that
# it is not cyclic.
if best_frontier_idx is None:
for idx, frontier in enumerate(sorted_pts):
cyclic = self._acyclic_enforcer.check_cyclic(robot_xy, frontier)
cyclic = self._acyclic_enforcer.check_cyclic(
robot_xy, frontier, top_two_values
)
if cyclic:
print("Suppressed cyclic frontier.")
continue
best_frontier_idx = idx
break
else:
print("Sticking to last point.")
os.environ["DEBUG_INFO"] += "Sticking to last point. "

if best_frontier_idx is None:
print("All frontiers are cyclic. Just choosing the closest one.")
Expand All @@ -144,7 +146,7 @@ def _get_best_frontier(

best_frontier = sorted_pts[best_frontier_idx]
best_value = sorted_values[best_frontier_idx]
self._acyclic_enforcer.add_state_action(robot_xy, best_frontier)
self._acyclic_enforcer.add_state_action(robot_xy, best_frontier, top_two_values)
self._last_value = best_value
self._last_frontier = best_frontier
os.environ["DEBUG_INFO"] += f" Best value: {best_value*100:.2f}%"
Expand Down
20 changes: 9 additions & 11 deletions zsos/policy/utils/acyclic_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,26 @@


class StateAction:
def __init__(self, position: np.ndarray, action: Any):
def __init__(self, position: np.ndarray, action: Any, other: Any = None):
self.position = position
self.action = action

def __eq__(self, other: "StateAction") -> bool:
dist1 = np.linalg.norm(self.position - other.position)
dist2 = np.linalg.norm(self.action - other.action)
return dist1 < 0.5 and dist2 < 0.5
self.other = other

def __hash__(self) -> int:
string_repr = f"{self.position}_{self.action}"
string_repr = f"{self.position}_{self.action}_{self.other}"
return hash(string_repr)


class AcyclicEnforcer:
history: Set[StateAction] = set()

def check_cyclic(self, position: np.ndarray, action: Any) -> bool:
state_action = StateAction(position, action)
def check_cyclic(
self, position: np.ndarray, action: Any, other: Any = None
) -> bool:
state_action = StateAction(position, action, other)
cyclic = state_action in self.history
return cyclic

def add_state_action(self, position: np.ndarray, action: Any):
state_action = StateAction(position, action)
def add_state_action(self, position: np.ndarray, action: Any, other: Any = None):
state_action = StateAction(position, action, other)
self.history.add(state_action)
5 changes: 4 additions & 1 deletion zsos/semexp_env/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from zsos.semexp_env.semexp_policy import SemExpITMPolicyV3
from zsos.utils.img_utils import reorient_rescale_map, resize_images
from zsos.utils.log_saver import is_evaluated, log_episode
from zsos.utils.visualization import add_text_to_image

os.environ["OMP_NUM_THREADS"] = "1"

Expand Down Expand Up @@ -79,7 +80,9 @@ def main():
action, policy_infos = policy.act(obs_dict, masks)

if "VIDEO_DIR" in os.environ:
vis_imgs.append(create_frame(policy_infos))
frame = create_frame(policy_infos)
frame = add_text_to_image(frame, "Step: " + str(step), top=True)
vis_imgs.append(frame)

action = action.squeeze(0)

Expand Down

0 comments on commit 8f070c8

Please sign in to comment.