Skip to content

Commit

Permalink
AcyclicEnforcer tracks trajectory, rejects action if already attempted from current state to mitigate flip-flopping
Browse files Browse the repository at this point in the history
  • Loading branch information
naokiyokoyamabd committed Aug 14, 2023
1 parent 8dd849b commit 97751ad
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 17 deletions.
28 changes: 14 additions & 14 deletions zsos/mapping/value_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,34 +123,34 @@ def update_map(
with open(JSON_PATH, "w") as f:
json.dump(data, f)

def sort_waypoints(
    self, waypoints: np.ndarray, radius: float
) -> Tuple[List[np.ndarray], List[float]]:
    """Rank the given waypoints from highest to lowest value.

    Each waypoint is scored with max_pixel_value_within_radius() evaluated on
    self.value_map around the waypoint's pixel location.

    Args:
        waypoints (np.ndarray): An array of 2D waypoints to rank.
        radius (float): Search radius around each waypoint; it is multiplied by
            self.pixels_per_meter to obtain the radius in pixels.

    Returns:
        Tuple[List[np.ndarray], List[float]]: The waypoints ordered by descending
            value, and their values in that same order.
    """
    search_radius_px = int(radius * self.pixels_per_meter)

    scores = []
    for waypoint in waypoints:
        # Convert the waypoint to pixel units relative to the episode origin.
        x, y = waypoint
        col = int(-x * self.pixels_per_meter) + self.episode_pixel_origin[0]
        row = int(-y * self.pixels_per_meter) + self.episode_pixel_origin[1]
        pixel_location = (self.value_map.shape[0] - col, row)
        scores.append(
            max_pixel_value_within_radius(
                self.value_map, pixel_location, search_radius_px
            )
        )

    # Arg-sorting the negated scores yields indices in descending score order.
    descending_order = np.argsort([-score for score in scores])

    ranked_waypoints = []
    ranked_scores = []
    for idx in descending_order:
        ranked_scores.append(scores[idx])
        ranked_waypoints.append(waypoints[idx])

    return ranked_waypoints, ranked_scores

def visualize(
self, markers: Optional[List[Tuple[np.ndarray, Dict[str, Any]]]] = None
Expand Down
52 changes: 49 additions & 3 deletions zsos/policy/itm_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from zsos.mapping.frontier_map import FrontierMap
from zsos.mapping.value_map import ValueMap
from zsos.policy.base_objectnav_policy import BaseObjectNavPolicy
from zsos.policy.utils.acyclic_enforcer import AcyclicEnforcer
from zsos.vlm.blip2itm import BLIP2ITMClient
from zsos.vlm.detections import ObjectDetections

Expand Down Expand Up @@ -111,17 +112,62 @@ def _explore(self, observations: Union[Dict[str, Tensor], "TensorDict"]) -> Tens


class ITMPolicyV2(BaseITMPolicy):
    """ITM policy that explores toward the best frontier not yet attempted.

    The value map is updated on every call to act(). While exploring, frontiers
    are ranked with self._value_map.sort_waypoints() and an AcyclicEnforcer
    rejects any frontier that was already selected from the current position,
    which mitigates flip-flopping between frontiers.
    """

    _acyclic_enforcer: AcyclicEnforcer = None  # must be set by ._reset()

    def act(self, observations: "TensorDict", *args, **kwargs) -> Tuple[Tensor, Tensor]:
        """Update the value map from the observations, then defer to the parent."""
        self._update_value_map(observations)
        return super().act(observations, *args, **kwargs)

    def _reset(self):
        """Reset the parent state and start a fresh state-action history."""
        super()._reset()
        self._acyclic_enforcer = AcyclicEnforcer()

    def _explore(self, observations: Union[Dict[str, Tensor], "TensorDict"]) -> Tensor:
        """Return a pointnav action toward the best frontier, or the stop action.

        Args:
            observations (Union[Dict[str, Tensor], "TensorDict"]): The observations
                from the environment. Must contain "frontier_sensor" and "gps".

        Returns:
            Tensor: self._stop_action when there is no frontier to pursue, otherwise
                the action produced by self._pointnav() for the chosen frontier.
        """
        frontiers = observations["frontier_sensor"][0].cpu().numpy()
        # A single all-zero row is presumably the sensor's "no frontiers"
        # placeholder (TODO confirm against the sensor). An empty array is treated
        # the same way so the selection below never runs on zero candidates.
        if len(frontiers) == 0 or np.array_equal(frontiers, np.zeros((1, 2))):
            return self._stop_action
        best_frontier, best_value = self._get_best_frontier(observations, frontiers)
        os.environ["DEBUG_INFO"] = f"Best value: {best_value*100:.2f}%"
        print(f"Step: {self._num_steps} Best value: {best_value*100:.2f}%")
        pointnav_action = self._pointnav(
            observations, best_frontier, deterministic=True, stop=False
        )

        return pointnav_action

    def _get_best_frontier(
        self,
        observations: Union[Dict[str, Tensor], "TensorDict"],
        frontiers: np.ndarray,
    ) -> Tuple[np.ndarray, float]:
        """Returns the best frontier and its value based on self._value_map.

        Frontiers are tried in descending value order; the first one that has not
        already been selected from the current position wins. If every frontier
        has already been tried from here, the closest frontier is used instead.
        The chosen (position, frontier) pair is recorded in the acyclic enforcer.

        Args:
            observations (Union[Dict[str, Tensor], "TensorDict"]): The observations from
                the environment. Must contain "gps"
            frontiers (np.ndarray): The frontiers to choose from, array of 2D points.

        Returns:
            Tuple[np.ndarray, float]: The best frontier and its value.
        """
        sorted_pts, sorted_values = self._value_map.sort_waypoints(frontiers, 0.5)

        position = observations["gps"].squeeze(1).cpu().numpy()[0]
        best_frontier, best_value = None, None
        for frontier, value in zip(sorted_pts, sorted_values):
            cyclic = self._acyclic_enforcer.check_cyclic(position, frontier)
            if not cyclic:
                best_frontier, best_value = frontier, value
                break

        if best_frontier is None:
            print("All frontiers are cyclic. Choosing the closest one.")
            # Pick from the (frontier, value) pairs so the reported value always
            # belongs to the chosen frontier, and use min() so the choice really
            # is the closest frontier, as the message states.
            best_frontier, best_value = min(
                zip(sorted_pts, sorted_values),
                key=lambda pair: np.linalg.norm(pair[0] - position),
            )

        self._acyclic_enforcer.add_state_action(position, best_frontier)

        return best_frontier, best_value
29 changes: 29 additions & 0 deletions zsos/policy/utils/acyclic_enforcer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Any, Set

import numpy as np


class StateAction:
    """A (position, action) pair usable as a set member or dict key.

    Two instances are equal when ``f"{position}_{action}"`` produces the same
    text for both.

    NOTE: NumPy arrays are formatted with NumPy's print options, which round
    floats for display, so positions that differ only beyond the printed
    precision are treated as the same state.

    Args:
        position (np.ndarray): The position the action was taken from.
        action (Any): The action taken at that position.
    """

    def __init__(self, position: np.ndarray, action: Any):
        self.position = position
        self.action = action

    def _key(self) -> str:
        """Return the text that defines this pair's identity."""
        return f"{self.position}_{self.action}"

    def __eq__(self, other: object) -> bool:
        # Compare the keys themselves rather than their hashes: a hash collision
        # (or an unrelated object with the same hash) must not count as equal.
        if not isinstance(other, StateAction):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self) -> int:
        return hash(self._key())


class AcyclicEnforcer:
    """Remembers which (position, action) pairs have already been attempted.

    Callers use check_cyclic() to ask whether an action was already taken from a
    position and add_state_action() to record a new attempt.
    """

    def __init__(self) -> None:
        # The history is created per instance. A class-level set would be shared
        # by every enforcer, so constructing a new enforcer would not actually
        # clear the previously recorded state-action pairs.
        self.history: Set[StateAction] = set()

    def check_cyclic(self, position: np.ndarray, action: Any) -> bool:
        """Return True if `action` was already recorded for `position`."""
        return StateAction(position, action) in self.history

    def add_state_action(self, position: np.ndarray, action: Any):
        """Record that `action` was taken from `position`."""
        self.history.add(StateAction(position, action))

0 comments on commit 97751ad

Please sign in to comment.