diff --git a/zsos/llm/prompts.py b/zsos/llm/prompts.py
new file mode 100644
index 0000000..ba25370
--- /dev/null
+++ b/zsos/llm/prompts.py
@@ -0,0 +1,92 @@
+from typing import Optional
+
+
+def get_textual_map_prompt(
+    target: str,
+    textual_map: str,
+    object_options: str,
+    frontier_options: str,
+    curr_position: Optional[str] = None,
+) -> str:
+    """
+    Builds the prompt that asks the LLM to choose the robot's next navigation goal.
+
+    Args:
+        target (str): Name of the object that the robot is trying to find.
+        textual_map (str): Names and locations of the objects seen so far.
+        object_options (str): Text listing the objects that the robot can go to. If
+            empty, the prompt only offers frontiers.
+        frontier_options (str): Text listing the frontiers that the robot can go to.
+        curr_position (Optional[str]): The robot's current x-y coordinates. If None,
+            the prompt does not mention the current position.
+
+    Returns:
+        str: The assembled prompt, which asks for a single integer as the response.
+    """
+    prompt = (
+        "You are a robot exploring an unfamiliar home. Your task is to find a "
+        f"{target}.\n"
+    )
+
+    if curr_position is not None:
+        prompt += (
+            f"You are currently at the following x-y coordinates: {curr_position}.\n"
+        )
+
+    prompt += (
+        "This is a list of the names and locations of the objects that you have seen "
+        "so far:\n\n"
+        f"{textual_map}\n\n"
+    )
+
+    if object_options:
+        prompt += (
+            "Here are a list of possible objects that you can go to:\n\n"
+            f"{object_options}\n\n"
+            "Alternatively, you can navigate to the following frontiers to explore "
+            "unexplored areas of the home:\n\n"
+        )
+        choice = "EITHER one object or frontier"
+    else:
+        prompt += (
+            "You can navigate to the following frontiers to explore unexplored areas "
+            "of the home:\n\n"
+        )
+        choice = "the frontier"
+
+    prompt += (
+        f"{frontier_options}\n\n"
+        "Carefully think about the layout of the objects and their categories, and "
+        f"then select {choice} that best represents the location to go to in order to "
+        f"find a {target} with the highest likelihood.\n"
+        "Your response must be ONLY ONE integer (ex. '1', '23', etc.).\n"
+    )
+
+    return prompt
+
+
+def unnumbered_list(items: list[str]) -> str:
+    """
+    Formats the given items as a bulleted list, one "- item" entry per line.
+
+    Args:
+        items (list[str]): The items to list.
+
+    Returns:
+        str: The entries joined by newlines; an empty string if there are no items.
+    """
+    return "\n".join(f"- {item}" for item in items)
+
+
+def numbered_list(items: list[str], start: int = 1) -> str:
+    """
+    Formats the given items as a numbered list, one "N. item" entry per line.
+
+    Args:
+        items (list[str]): The items to list.
+        start (int): The number assigned to the first item.
+
+    Returns:
+        str: The entries joined by newlines; an empty string if there are no items.
+    """
+    return "\n".join(f"{i}. {item}" for i, item in enumerate(items, start=start))
diff --git a/zsos/mapping/object_map.py b/zsos/mapping/object_map.py
index 23f13c2..24f0831 100644
--- a/zsos/mapping/object_map.py
+++ b/zsos/mapping/object_map.py
@@ -4,6 +4,7 @@
 import cv2
 import numpy as np
 
+from zsos.llm.prompts import get_textual_map_prompt, numbered_list, unnumbered_list
 from zsos.policy.utils.pointnav_policy import wrap_heading
 
 
@@ -107,7 +108,49 @@ def update_explored(
             ):
                 obj.explored = True
 
-    def visualize(self) -> np.ndarray:
+    def get_textual_map_prompt(
+        self, target: str, current_pos: np.ndarray, frontiers: np.ndarray
+    ) -> str:
+        """
+        Returns a prompt that describes the object map and asks which unexplored
+        object or frontier to go to next in order to find the target.
+
+        Args:
+            target (str): Name of the object being searched for.
+            current_pos (np.ndarray): Current position of the agent, used to sort the
+                objects by distance.
+            frontiers (np.ndarray): Frontier locations; the first two values of each
+                frontier are listed as its coordinates.
+
+        Returns:
+            str: The prompt, in which object options are numbered from 1 and frontier
+                options continue the numbering after the last object option.
+        """
+        # Describe every object in the map, explored or not, closest first
+        seen_objects = objects_to_str(self.map, current_pos)
+        textual_map = unnumbered_list(seen_objects)
+
+        # Only objects that have not been explored yet are offered as options
+        unexplored_objects = [obj for obj in self.map if not obj.explored]
+        unexplored_strs = objects_to_str(unexplored_objects, current_pos)
+        object_options = numbered_list(unexplored_strs)
+
+        # Frontier options continue the numbering after the object options, so that
+        # a single integer identifies either an object or a frontier
+        # NOTE(review): visualize() skips all-zero frontiers, but they are still
+        # listed here -- confirm whether they should be filtered out as well
+        frontiers_list = [
+            f"({frontier[0]:.2f}, {frontier[1]:.2f})" for frontier in frontiers
+        ]
+        frontier_options = numbered_list(
+            frontiers_list, start=len(unexplored_strs) + 1
+        )
+
+        return get_textual_map_prompt(
+            target, textual_map, object_options, frontier_options
+        )
+
+    def visualize(self, frontiers: np.ndarray) -> np.ndarray:
         """
         Visualizes the object map by plotting the history of the camera coordinates
         and the location of each object in a 2D top-down view. If the object is
@@ -140,6 +183,11 @@ def plot_circle(coordinates, circle_color):
         for camera_c, _ in self.camera_history:
             plot_circle(camera_c, (0, 255, 0))
 
+        for frontier in frontiers:
+            if np.all(frontier == 0):  # ignore all zeros frontiers
+                continue
+            plot_circle(frontier, (255, 255, 255))
+
         visual_map = cv2.flip(visual_map, 0)
         visual_map = cv2.rotate(visual_map, cv2.ROTATE_90_COUNTERCLOCKWISE)
 
@@ -409,3 +457,42 @@ def within_fov_cone(
     angle_diff = wrap_heading(angle - cone_angle)
 
     return dist <= cone_range and abs(angle_diff) <= cone_fov / 2
+
+
+def objects_to_str(objs: List[Object], current_pos: np.ndarray) -> List[str]:
+    """
+    This function converts a list of object locations into strings. The strings are
+    ordered by the distance of each object from the agent's current position, from
+    closest to farthest. The given list itself is left unmodified.
+
+    Args:
+        objs (List[Object]): A list of Object instances representing the objects.
+        current_pos (np.ndarray): Current position of the agent.
+
+    Returns:
+        List[str]: A list where each string represents an object and its location,
+            ordered by the object's distance from the agent's position.
+    """
+    # Use sorted() rather than list.sort() so that the caller's list (e.g., the
+    # object map itself) is not reordered as a side effect
+    objs_by_dist = sorted(
+        objs, key=lambda obj: np.linalg.norm(obj.location - current_pos)
+    )
+    return [
+        f"{obj.class_name} at {obj_loc_to_str(obj.location)}" for obj in objs_by_dist
+    ]
+
+
+def obj_loc_to_str(arr: np.ndarray) -> str:
+    """
+    Converts a numpy array representing an object's location into a string. Only the
+    first two elements of the array are used.
+
+    Args:
+        arr (np.ndarray): Object's coordinates.
+
+    Returns:
+        str: A string representation of the object's location with precision up to two
+        decimal places.
+    """
+    return f"({arr[0]:.2f}, {arr[1]:.2f})"
diff --git a/zsos/policy/llm_policy.py b/zsos/policy/llm_policy.py
index 51b662a..d198855 100644
--- a/zsos/policy/llm_policy.py
+++ b/zsos/policy/llm_policy.py
@@ -248,6 +248,7 @@ def _update_object_map(
                 yaw,
                 confidence,
             )
+        self.object_map.update_explored(camera_coordinates, yaw)
 
     def _should_explore(self) -> bool:
         return self.target_object not in self.seen_objects