Skip to content

Commit

Permalink
finally supporting full configurability via hydra, prefixing all privates with underscores
Browse files Browse the repository at this point in the history
  • Loading branch information
naokiyokoyamabd committed Aug 11, 2023
1 parent d528e64 commit 85eacd6
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 108 deletions.
1 change: 1 addition & 0 deletions config/experiments/llm_objectnav_hm3d.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ defaults:
- frontier_sensor
- /habitat/task/measurements:
- frontier_exploration_map
- /habitat_baselines/rl/policy: zsos_policy
- _self_

habitat:
Expand Down
144 changes: 76 additions & 68 deletions zsos/policy/base_objectnav_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
WrappedPointNavResNetPolicy,
rho_theta_from_gps_compass_goal,
)
from zsos.utils.geometry_utils import xyz_yaw_to_tf_matrix
from zsos.vlm.grounding_dino import GroundingDINOClient, ObjectDetections

try:
Expand All @@ -25,45 +24,50 @@ class BasePolicy:


class BaseObjectNavPolicy(BasePolicy):
target_object: str = ""
camera_height: float = 0.88
depth_image_shape: Tuple[int, int] = (244, 224)
det_conf_threshold: float = 0.50
pointnav_stop_radius: float = 0.85
visualize: bool = True
policy_info: Dict[str, Any] = {}
id_to_padding: Dict[str, float] = {}
_stop_action: Tensor = None # must be set by subclass
# ObjectMap parameters; these must be set by subclass
min_depth: float = None
max_depth: float = None
hfov: float = None
proximity_threshold: float = None

def __init__(self, *args, **kwargs):
_target_object: str = ""
_policy_info: Dict[str, Any] = {}
_id_to_padding: Dict[str, float] = {}
_stop_action: Tensor = None # MUST BE SET BY SUBCLASS

def __init__(
self,
pointnav_policy_path: str,
depth_image_shape: Tuple[int, int],
det_conf_threshold: float,
pointnav_stop_radius: float,
object_map_min_depth: float,
object_map_max_depth: float,
object_map_hfov: float,
object_map_proximity_threshold: float,
visualize: bool = True,
*args,
**kwargs,
):
super().__init__()
self.object_detector = GroundingDINOClient()
self.pointnav_policy = WrappedPointNavResNetPolicy(
os.environ["POINTNAV_POLICY_PATH"]
)
self.object_map: ObjectMap = ObjectMap(
min_depth=self.min_depth,
max_depth=self.max_depth,
hfov=self.hfov,
proximity_threshold=self.proximity_threshold,
self._object_detector = GroundingDINOClient()
self._pointnav_policy = WrappedPointNavResNetPolicy(pointnav_policy_path)
self._object_map: ObjectMap = ObjectMap(
min_depth=object_map_min_depth,
max_depth=object_map_max_depth,
hfov=object_map_hfov,
proximity_threshold=object_map_proximity_threshold,
)
self._depth_image_shape = tuple(depth_image_shape)
self._det_conf_threshold = det_conf_threshold
self._pointnav_stop_radius = pointnav_stop_radius
self._visualize = visualize

self.num_steps = 0
self.last_goal = np.zeros(2)
self.done_initializing = False
self._num_steps = 0
self._last_goal = np.zeros(2)
self._done_initializing = False

def _reset(self):
self.target_object = ""
self.pointnav_policy.reset()
self.object_map.reset()
self.last_goal = np.zeros(2)
self.num_steps = 0
self.done_initializing = False
self._target_object = ""
self._pointnav_policy.reset()
self._object_map.reset()
self._last_goal = np.zeros(2)
self._num_steps = 0
self._done_initializing = False

def act(
self, observations, rnn_hidden_states, prev_actions, masks, deterministic=False
Expand All @@ -78,17 +82,15 @@ def act(
assert masks.shape[1] == 1, "Currently only supporting one env at a time"
if masks[0] == 0:
self._reset()
self.target_object = observations["objectgoal"]
self._target_object = observations["objectgoal"]

self.policy_info = {}
self._policy_info = {}

rgb, depth, tf_camera_to_episodic = self._get_detection_camera_info(
observations
)
rgb, depth, tf_camera_to_episodic = self._get_object_camera_info(observations)
detections = self._update_object_map(rgb, depth, tf_camera_to_episodic)
goal = self._get_target_object_location()

if not self.done_initializing: # Initialize
if not self._done_initializing: # Initialize
pointnav_action = self._initialize()
elif goal is None: # Haven't found target object yet
pointnav_action = self._explore(observations)
Expand All @@ -97,8 +99,8 @@ def act(
observations, goal[:2], deterministic=deterministic, stop=True
)

self.policy_info = self._get_policy_info(observations, detections)
self.num_steps += 1
self._policy_info = self._get_policy_info(observations, detections)
self._num_steps += 1

return pointnav_action, rnn_hidden_states

Expand All @@ -110,18 +112,18 @@ def _explore(self, observations: "TensorDict") -> Tensor:

def _get_target_object_location(self) -> Union[None, np.ndarray]:
try:
return self.object_map.get_best_object(self.target_object)
return self._object_map.get_best_object(self._target_object)
except ValueError:
# Target object has not been spotted
return None

def _get_policy_info(
self, observations: "TensorDict", detections: ObjectDetections
) -> Dict[str, Any]:
seen_objects = set(i.class_name for i in self.object_map.map)
seen_objects = set(i.class_name for i in self._object_map.map)
seen_objects_str = ", ".join(seen_objects)
policy_info = {
"target_object": "target: " + self.target_object,
"target_object": "target: " + self._target_object,
"visualized_detections": detections.annotated_frame,
"seen_objects": seen_objects_str,
"gps": str(observations["gps"][0].cpu().numpy()),
Expand All @@ -140,8 +142,8 @@ def _get_policy_info(
return policy_info

def _get_object_detections(self, img: np.ndarray) -> ObjectDetections:
detections = self.object_detector.predict(img, visualize=self.visualize)
detections.filter_by_conf(self.det_conf_threshold)
detections = self._object_detector.predict(img, visualize=self._visualize)
detections.filter_by_conf(self._det_conf_threshold)

return detections

Expand All @@ -160,55 +162,61 @@ def _pointnav(
Args:
observations ("TensorDict"): The observations from the current timestep.
"""
masks = torch.tensor([self.num_steps != 0], dtype=torch.bool, device="cuda")
if not np.array_equal(goal, self.last_goal):
self.last_goal = goal
self.pointnav_policy.reset()
masks = torch.tensor([self._num_steps != 0], dtype=torch.bool, device="cuda")
if not np.array_equal(goal, self._last_goal):
self._last_goal = goal
self._pointnav_policy.reset()
masks = torch.zeros_like(masks)
rho_theta = rho_theta_from_gps_compass_goal(observations, goal)
obs_pointnav = {
"depth": image_resize(
observations["depth"],
self.depth_image_shape,
self._depth_image_shape,
channels_last=True,
interpolation_mode="area",
),
"pointgoal_with_gps_compass": rho_theta.unsqueeze(0),
}
stop_dist = self.pointnav_stop_radius + self.id_to_padding.get(
self.target_object, 0.0
stop_dist = self._pointnav_stop_radius + self._id_to_padding.get(
self._target_object, 0.0
)
if rho_theta[0] < stop_dist and stop:
return self._stop_action
action = self.pointnav_policy.act(
action = self._pointnav_policy.act(
obs_pointnav, masks, deterministic=deterministic
)
return action

def _get_detection_camera_info(self, observations: "TensorDict") -> Tuple:
rgb = observations["rgb"][0].cpu().numpy()
depth = observations["depth"][0].cpu().numpy()
x, y = observations["gps"][0].cpu().numpy()
camera_yaw = observations["compass"][0].cpu().item()
# Habitat GPS makes west negative, so flip y
camera_position = np.array([x, -y, self.camera_height])
tf_camera_to_episodic = xyz_yaw_to_tf_matrix(camera_position, camera_yaw)
return rgb, depth, tf_camera_to_episodic

def _update_object_map(
self, rgb: np.ndarray, depth: np.ndarray, tf_camera_to_episodic: np.ndarray
) -> ObjectDetections:
detections = self._get_object_detections(rgb)

for idx, confidence in enumerate(detections.logits):
self.object_map.update_map(
self._object_map.update_map(
detections.phrases[idx],
detections.boxes[idx],
depth,
tf_camera_to_episodic,
confidence,
)

self.object_map.update_explored(tf_camera_to_episodic)
self._object_map.update_explored(tf_camera_to_episodic)

return detections

def _get_object_camera_info(
self, observations: "TensorDict"
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Extract the rgb image, depth image, and camera transform from the observations.

This base implementation is a placeholder that always raises; it is
presumably meant to be overridden by subclasses (confirm against the
concrete policies).

Args:
observations ("TensorDict"): The observations from the current timestep.

Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray]: The rgb image, depth image, and
camera transform. The depth image is normalized to be between 0 and 1.
The camera transform is the transform from the camera to the episodic
frame, a 4x4 transformation matrix.

Raises:
NotImplementedError: Always, in this base implementation.
"""
raise NotImplementedError
Loading

0 comments on commit 85eacd6

Please sign in to comment.