finally fixed rho-theta calculations for both frontiers and objects
naokiyokoyamabd committed Jul 28, 2023
1 parent 0fcd08b commit d1111fb
Showing 3 changed files with 171 additions and 109 deletions.
71 changes: 45 additions & 26 deletions zsos/mapping/object_map.py
@@ -69,7 +69,9 @@ def update_map(

def get_best_object(self, target_class: str) -> np.ndarray:
"""
Returns the closest object to the agent that matches the given object name.
Returns the closest object to the agent that matches the given object name. It
will not go towards objects that are too far away, unless there are no other
objects of the same class.
Args:
target_class (str): The name of the object class to search for.
@@ -78,19 +80,22 @@ def get_best_object(self, target_class: str) -> np.ndarray:
np.ndarray: The location of the closest object to the agent that matches the
given object name [x, y, z].
"""
matches = [obj for obj in self.map if obj.class_name == target_class]
if len(matches) == 0:
raise ValueError(
f"No object of type {target_class} found in the object map."
)

ignore_too_far = any([not obj.too_far for obj in matches])
best_loc, best_conf = None, -float("inf")
for object_inst in self.map:
if (
target_class == object_inst.class_name
and object_inst.confidence > best_conf
):
for object_inst in matches:
if object_inst.confidence > best_conf:
if ignore_too_far and object_inst.too_far:
continue
best_loc = object_inst.location
best_conf = object_inst.confidence

if best_loc is None:
raise ValueError(
f"No object of type {target_class} found in the object map."
)
assert best_loc is not None, "This error should never be reached."

return best_loc

@@ -107,10 +112,12 @@ def update_explored(
obj.location,
):
obj.explored = True
# Remove objects that are both too far and explored
self.map = [obj for obj in self.map if not (obj.explored and obj.too_far)]

def get_textual_map_prompt(
self, target: str, current_pos: np.ndarray, frontiers: np.ndarray
) -> str:
) -> Tuple[str, np.ndarray]:
"""
Returns a textual representation of the object map. The {target} field will
still be unfilled.
@@ -123,23 +130,37 @@ def get_textual_map_prompt(
# 'unexplored_objects' is a list of strings, where each string represents the
# object's name and location
unexplored_objects = [obj for obj in self.map if not obj.explored]
unexplored_objects = objects_to_str(unexplored_objects, current_pos)
unexplored_objects_strs = objects_to_str(unexplored_objects, current_pos)
# For object_options, only return a list of objects that have not been explored
object_options = numbered_list(unexplored_objects)
object_options = numbered_list(unexplored_objects_strs)

# 'frontiers_list' is a list of strings, where each string represents the
# frontier's location
frontiers_list = [
f"({frontier[0]:.2f}, {frontier[1]:.2f})" for frontier in frontiers
]
frontier_options = numbered_list(
frontiers_list, start=len(unexplored_objects) + 1
frontiers_list, start=len(unexplored_objects_strs) + 1
)

return get_textual_map_prompt(
target, textual_map, object_options, frontier_options
curr_pos_str = f"({current_pos[0]:.2f}, {current_pos[1]:.2f})"

prompt = get_textual_map_prompt(
target,
textual_map,
object_options,
frontier_options,
curr_position=curr_pos_str,
)

waypoints = []
for obj in unexplored_objects:
waypoints.append(obj.location)
waypoints.extend(list(frontiers))
waypoints = np.array(waypoints)

return prompt, waypoints
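
The prompt and the returned waypoints are built from the same ordering: unexplored objects are numbered first and frontiers continue the numbering (note the `start=len(unexplored_objects_strs) + 1` above). A minimal sketch of how a caller can map the LLM's numeric answer back onto a coordinate; the option list and `llm_answer` below are hypothetical, not taken from the commit:

import numpy as np

# Hypothetical options, in the same order used for both the prompt and `waypoints`:
#   1. chair at (2.10, 0.50)    <- unexplored object
#   2. (4.00, -1.25)            <- frontier
#   3. (0.75, 3.30)             <- frontier
waypoints = np.array([[2.10, 0.50],
                      [4.00, -1.25],
                      [0.75, 3.30]])

llm_answer = 2                    # integer the LLM is asked to return (1-indexed)
goal = waypoints[llm_answer - 1]  # -> array([ 4.  , -1.25]); same -1 offset as _get_llm_goal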

def visualize(self, frontiers: np.ndarray) -> np.ndarray:
"""
Visualizes the object map by plotting the history of the camera coordinates
@@ -196,18 +217,18 @@ def _estimate_object_location(
Args:
bounding_box (np.ndarray): The bounding box coordinates of the detected
object in the image [x_min, y_min, x_max, y_max]. These coordinates are
normalized to the range [0, 1].
object in the image [x_min, y_min, x_max, y_max]. These coordinates are
normalized to the range [0, 1].
depth_image (np.ndarray): The depth image captured by the RGBD camera.
camera_coordinates (np.ndarray): The global coordinates of the camera
[x, y, z].
[x, y, z].
camera_yaw (float): The yaw angle of the camera in radians.
Returns:
np.ndarray: The estimated 3D location of the detected object in the global
coordinate frame [x, y, z].
coordinate frame [x, y, z].
bool: True if the object is too far away for the depth camera, False
otherwise.
otherwise.
"""
# Get the depth value of the object
pixel_x, pixel_y, depth_value = self._get_object_depth(
@@ -225,7 +246,7 @@ def _estimate_object_location(
)
# Yaw from compass sensor must be negated to work properly
object_coord_global = convert_to_global_frame(
camera_coordinates, -camera_yaw, object_coord_agent
camera_coordinates, camera_yaw, object_coord_agent
)

return object_coord_global, too_far
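
The yaw sign change above (from `-camera_yaw` to `camera_yaw`) is part of the rho-theta fix named in the commit title. `convert_to_global_frame` itself is not included in this diff; the following is only a hedged sketch of the kind of transform such a helper typically performs (a yaw rotation about the vertical axis followed by a translation), to illustrate why the sign of the yaw matters:

import numpy as np

def yaw_then_translate(camera_xyz, yaw, point_in_agent_frame):
    # Hypothetical stand-in, not the project's implementation: rotate the
    # agent-frame point by `yaw` about the z-axis, then add the camera position.
    # Passing a negated yaw rotates the opposite way and mirrors the map.
    c, s = np.cos(yaw), np.sin(yaw)
    rotation = np.array([[c, -s, 0.0],
                         [s,  c, 0.0],
                         [0.0, 0.0, 1.0]])
    return np.asarray(camera_xyz) + rotation @ np.asarray(point_in_agent_frame)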
@@ -350,9 +371,7 @@ def calculate_3d_coordinates(

hor_distance = depth_value * math.cos(phi)
x = hor_distance * math.cos(theta)

y = hor_distance * math.sin(theta)

y = -hor_distance * math.sin(theta)
ver_distance = depth_value * math.sin(theta)
z = ver_distance * math.sin(phi)

@@ -462,7 +481,7 @@ def objects_to_str(objs: List[Object], current_pos: np.ndarray) -> List[str]:
List[str]: A list where each string represents an object and its location in
relation to the agent's position.
"""
objs.sort(key=lambda obj: np.linalg.norm(obj.location - current_pos))
objs.sort(key=lambda obj: np.linalg.norm(obj.location[:2] - current_pos))
objs = [f"{obj.class_name} at {obj_loc_to_str(obj.location)}" for obj in objs]
return objs
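
The switch to `obj.location[:2]` makes the sort purely planar: elsewhere in this commit `current_pos` is a 2D (x, y) GPS position, while object locations are 3D, so subtracting the full 3D location from it would not broadcast. A short illustration with made-up values:

import numpy as np

location = np.array([2.0, 1.0, 0.4])   # hypothetical 3D object location [x, y, z]
current_pos = np.array([0.5, -0.5])    # 2D (x, y) agent position

# location - current_pos  ->  ValueError: shapes (3,) and (2,) cannot broadcast
dist = np.linalg.norm(location[:2] - current_pos)   # planar distance, ~2.12 m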

124 changes: 70 additions & 54 deletions zsos/policy/llm_policy.py
@@ -3,13 +3,13 @@

import numpy as np
import torch
from frontier_exploration.policy import FrontierExplorationPolicy
from habitat.tasks.nav.object_nav_task import ObjectGoalSensor
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.common.tensor_dict import TensorDict
from habitat_baselines.rl.ppo.policy import PolicyActionData
from torch import Tensor

from frontier_exploration.policy import FrontierExplorationPolicy
from zsos.llm.llm import BaseLLM, ClientFastChat
from zsos.mapping.object_map import ObjectMap
from zsos.obs_transformers.resize import image_resize
@@ -76,6 +76,7 @@ def act(
assert masks.shape[1] == 1, "Currently only supporting one env at a time"
if masks[0] == 0:
self._reset()
self.target_object = ID_TO_NAME[observations[ObjectGoalSensor.cls_uuid][0]]

# Get action_data from FrontierExplorationPolicy
action_data = super().act(
@@ -86,29 +87,50 @@ def act(
deterministic=deterministic,
)

self.target_object = ID_TO_NAME[observations[ObjectGoalSensor.cls_uuid][0]]
detections = self._update_object_map(observations)

image_numpy = observations["rgb"][0].cpu().numpy()
detections = self._get_object_detections(image_numpy)
self._update_object_map(observations, detections)
llm_responses = self._get_llm_responses()
try:
# Target object has been spotted
goal = self.object_map.get_best_object(self.target_object)
except ValueError:
# Target object has not been spotted
goal = None

# baseline = True
baseline = False

if self.start_steps < 12:
self.start_steps += 1
pointnav_action = TorchActionIDs.TURN_LEFT
elif self._should_explore():
pointnav_action = action_data.actions
llm_responses = "Spinning..."
elif goal is not None:
pointnav_action = self._pointnav(
observations, masks, goal[:2], deterministic=deterministic, stop=True
)
llm_responses = "Beelining to target!"
else:
goal = self.object_map.get_best_object(self.target_object)
# PointNav only cares about x, y
curr_pos = observations["gps"][0].cpu().numpy() * np.array([1, -1])
llm_responses = "Closest exploration" if baseline else "LLM exploration"
if np.linalg.norm(self.last_goal - curr_pos) < 0.25:
frontiers = observations["frontier_sensor"][0].cpu().numpy()
if baseline:
goal = frontiers[0]
else:
# Ask LLM which waypoint to head to next
goal, llm_responses = self._get_llm_goal(curr_pos, frontiers)
else:
goal = self.last_goal

pointnav_action = self._pointnav(
observations, masks, goal[:2], deterministic=deterministic
observations, masks, goal[:2], deterministic=deterministic, stop=False
)

action_data.actions = pointnav_action

action_data.policy_info = self._get_policy_info(
observations, detections, llm_responses
)
print(llm_responses)

return action_data

@@ -146,65 +168,52 @@ def _get_policy_info(
def _get_object_detections(self, img: np.ndarray) -> ObjectDetections:
detections = self.object_detector.predict(img, visualize=self.visualize)
detections.filter_by_conf(self.det_conf_threshold)
objects = self._extract_detected_names(detections)

self.seen_objects.update(objects)

return detections

def _extract_detected_names(self, detections: ObjectDetections) -> List[str]:
# Filter out detections that are not verbatim in the classes.txt file
objects = [
phrase
for phrase in detections.phrases
if phrase in self.object_detector.classes
]
return objects

def _get_llm_responses(self) -> str:
def _get_llm_goal(
self, current_pos: np.ndarray, frontiers: np.ndarray
) -> Tuple[np.ndarray, str]:
"""
Asks LLM which object to go to next, conditioned on the target object.
Asks LLM which object or frontier to go to next. self.object_map is used to
generate the prompt for the LLM.
Args:
current_pos (np.ndarray): A 1D array of shape (2,) containing the current
position of the robot.
frontiers (np.ndarray): A 2D array of shape (num_frontiers, 2) containing
the coordinates of the frontiers.
Returns:
List[str]: A list containing the responses generated by the LLM.
Tuple[np.ndarray, str]: A tuple containing the goal and the LLM response.
"""
if len(self.seen_objects) == 0:
return ""

if self.target_object in self.seen_objects:
return self.target_object

choices = list(self.seen_objects)
choices_str = ""
for i, category in enumerate(choices):
choices_str += f"{i}. {category}\n"

prompt = (
"Question: Which object category from the following options would be most "
f"likely to be found near a '{self.target_object}'?\n\n"
f"{choices_str}"
"\nYour response must be ONLY ONE integer (ex. '0', '15', etc.).\n"
"Answer: "
prompt, waypoints = self.object_map.get_textual_map_prompt(
self.target_object, current_pos, frontiers
)
resp = self.llm.ask(prompt)
int_resp = extract_integer(resp) - 1

llm_resp = self.llm.ask(prompt)
obj_idx = extract_integer(llm_resp)
if obj_idx != -1:
self.current_best_object = choices[obj_idx]
# TODO: IndexError here
try:
waypoint = waypoints[int_resp]
except IndexError:
print("Seems like the LLM returned an invalid response:\n")
print(resp)
waypoint = waypoints[-1]

return self.current_best_object
return waypoint, resp

def _pointnav(
self,
observations: TensorDict,
masks: Tensor,
goal: np.ndarray,
deterministic=False,
stop=False,
) -> Tensor:
if not np.array_equal(goal, self.last_goal):
self.last_goal = goal
self.pointnav_policy.reset()
masks = torch.zeros_like(masks)
rho_theta = rho_theta_from_gps_compass_goal(observations, goal)
obs_pointnav = {
"depth": image_resize(
@@ -215,16 +224,14 @@ def _pointnav(
),
"pointgoal_with_gps_compass": rho_theta.unsqueeze(0),
}
if rho_theta[0] < self.pointnav_stop_radius:
if rho_theta[0] < self.pointnav_stop_radius and stop:
return TorchActionIDs.STOP
action = self.pointnav_policy.get_actions(
obs_pointnav, masks, deterministic=deterministic
)
return action
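
`rho_theta_from_gps_compass_goal`, the helper behind the commit title, is not included in this diff. As a hedged sketch of the quantity the surrounding code expects — `rho_theta[0]` as the distance to the goal and `rho_theta[1]` as the goal's bearing relative to the agent's heading — something along these lines (not the project's implementation):

import numpy as np
import torch

def rho_theta_sketch(agent_xy, agent_yaw, goal_xy):
    # Hypothetical stand-in: rho is the planar distance to the goal, theta the
    # goal's bearing relative to the agent's current heading, wrapped to [-pi, pi).
    delta = np.asarray(goal_xy, dtype=np.float64) - np.asarray(agent_xy, dtype=np.float64)
    rho = np.linalg.norm(delta)
    theta = np.arctan2(delta[1], delta[0]) - agent_yaw
    theta = (theta + np.pi) % (2 * np.pi) - np.pi
    return torch.tensor([rho, theta], dtype=torch.float32)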

def _update_object_map(
self, observations: TensorDict, detections: ObjectDetections
) -> None:
def _update_object_map(self, observations: TensorDict) -> ObjectDetections:
"""
Updates the object map with the detections from the current timestep.
@@ -233,12 +240,16 @@ def _update_object_map(
detections (ObjectDetections): The detections from the current
timestep.
"""
rgb = observations["rgb"][0].cpu().numpy()
depth = observations["depth"][0].cpu().numpy()
x, y = observations["gps"][0].cpu().numpy()
camera_coordinates = np.array(
[*observations["gps"][0].cpu().numpy(), self.camera_height]
[x, -y, self.camera_height] # Habitat GPS makes west negative, so flip y
)
yaw = observations["compass"][0].item()

detections = self._get_object_detections(rgb)

for idx, confidence in enumerate(detections.logits):
self.object_map.update_map(
detections.phrases[idx],
Expand All @@ -250,6 +261,11 @@ def _update_object_map(
)
self.object_map.update_explored(camera_coordinates, yaw)

seen_objects = set(i.class_name for i in self.object_map.map)
self.seen_objects.update(seen_objects)

return detections
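
The inline comment above ("Habitat GPS makes west negative, so flip y") is the same convention applied to `curr_pos` in `act()` via `* np.array([1, -1])`. A tiny example of the conversion with made-up sensor values:

import numpy as np

gps = np.array([3.0, -1.5])   # hypothetical Habitat GPS reading (west negative)
camera_height = 0.88          # hypothetical camera height, meters

x, y = gps
camera_coordinates = np.array([x, -y, camera_height])   # -> [3.0, 1.5, 0.88]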

def _should_explore(self) -> bool:
return self.target_object not in self.seen_objects
