diff --git a/predicators/approaches/bilevel_planning_approach.py b/predicators/approaches/bilevel_planning_approach.py index dbf9eb8a3..5cadda34c 100644 --- a/predicators/approaches/bilevel_planning_approach.py +++ b/predicators/approaches/bilevel_planning_approach.py @@ -59,7 +59,8 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: nsrts = self._get_current_nsrts() preds = self._get_current_predicates() utils.abstract(task.init, preds, self._vlm) - import pdb; pdb.set_trace() + import pdb + pdb.set_trace() # utils.abstract(task.init, preds, self._vlm) # Run task planning only and then greedily sample and execute in the # policy. diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index 374c0114a..b702638bf 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -48,7 +48,7 @@ update_pbrspot_robot_conf, verify_estop from predicators.structs import Action, EnvironmentTask, GoalDescription, \ GroundAtom, LiftedAtom, Object, Observation, Predicate, \ - SpotActionExtraInfo, State, STRIPSOperator, Type, Variable + SpotActionExtraInfo, State, STRIPSOperator, Type, Variable, _Option ############################################################################### # Base Class # @@ -106,6 +106,7 @@ class _TruncatedSpotObservation: # # A placeholder until all predicates have classifiers # nonpercept_atoms: Set[GroundAtom] # nonpercept_predicates: Set[Predicate] + executed_skill: Optional[_Option] = None # Object detections per camera in self.rgbd_images. object_detections_per_camera: Dict[str, List[Tuple[ObjectDetectionID, SegmentedBoundingBox]]] @@ -197,7 +198,9 @@ def get_robot_only() -> Tuple[Optional[Robot], Optional[LeaseClient]]: verify_estop(robot) lease_client = robot.ensure_client(LeaseClient.default_service_name) lease_client.take() - lease_keepalive = LeaseKeepAlive(lease_client, must_acquire=True, return_at_exit=True) + lease_keepalive = LeaseKeepAlive(lease_client, + must_acquire=True, + return_at_exit=True) return robot, lease_client @@ -1484,45 +1487,37 @@ def _get_sweeping_surface_for_container(container: Object, "IsSemanticallyGreaterThan", [_base_object_type, _base_object_type], _is_semantically_greater_than_classifier) + def _get_vlm_query_str(pred_name: str, objects: Sequence[Object]) -> str: - return pred_name + "(" + ", ".join(str(obj.name) for obj in objects) + ")" # pragma: no cover -_VLMOn = utils.create_vlm_predicate( - "VLMOn", - [_movable_object_type, _base_object_type], - lambda o: _get_vlm_query_str("VLMOn", o) -) + return pred_name + "(" + ", ".join( + str(obj.name) for obj in objects) + ")" # pragma: no cover + + +_VLMOn = utils.create_vlm_predicate("VLMOn", + [_movable_object_type, _base_object_type], + lambda o: _get_vlm_query_str("VLMOn", o)) _Upright = utils.create_vlm_predicate( - "Upright", - [_movable_object_type], - lambda o: _get_vlm_query_str("Upright", o) -) + "Upright", [_movable_object_type], + lambda o: _get_vlm_query_str("Upright", o)) _Toasted = utils.create_vlm_predicate( - "Toasted", - [_movable_object_type], - lambda o: _get_vlm_query_str("Toasted", o) -) + "Toasted", [_movable_object_type], + lambda o: _get_vlm_query_str("Toasted", o)) _VLMIn = utils.create_vlm_predicate( - "VLMIn", - [_movable_object_type, _immovable_object_type], - lambda o: _get_vlm_query_str("In", o) -) -_Open = utils.create_vlm_predicate( - "Open", - [_movable_object_type], - lambda o: _get_vlm_query_str("Open", o) -) + "VLMIn", [_movable_object_type, _immovable_object_type], + lambda o: _get_vlm_query_str("In", o)) +_Open = utils.create_vlm_predicate("Open", [_movable_object_type], + lambda o: _get_vlm_query_str("Open", o)) _Stained = utils.create_vlm_predicate( - "Stained", - [_movable_object_type], - lambda o: _get_vlm_query_str("Stained", o) -) + "Stained", [_movable_object_type], + lambda o: _get_vlm_query_str("Stained", o)) _ALL_PREDICATES = { _NEq, _On, _TopAbove, _Inside, _NotInsideAnyContainer, _FitsInXY, _HandEmpty, _Holding, _NotHolding, _InHandView, _InView, _Reachable, _Blocking, _NotBlocked, _ContainerReadyForSweeping, _IsPlaceable, _IsNotPlaceable, _IsSweeper, _HasFlatTopSurface, _RobotReadyForSweeping, - _IsSemanticallyGreaterThan, _VLMOn, _Upright, _Toasted, _VLMIn, _Open, _Stained + _IsSemanticallyGreaterThan, _VLMOn, _Upright, _Toasted, _VLMIn, _Open, + _Stained } _NONPERCEPT_PREDICATES: Set[Predicate] = set() @@ -2455,8 +2450,9 @@ class VLMTestEnv(SpotRearrangementEnv): @property def predicates(self) -> Set[Predicate]: # return set(p for p in _ALL_PREDICATES if p.name in ["VLMOn", "Holding", "HandEmpty", "Pourable", "Toasted", "VLMIn", "Open"]) - return set(p for p in _ALL_PREDICATES if p.name in ["VLMOn", "Holding", "HandEmpty", "Upright"]) - + return set(p for p in _ALL_PREDICATES + if p.name in ["VLMOn", "Holding", "HandEmpty", "Upright"]) + @property def goal_predicates(self) -> Set[Predicate]: return self.predicates @@ -2464,7 +2460,7 @@ def goal_predicates(self) -> Set[Predicate]: @classmethod def get_name(cls) -> str: return "spot_vlm_test_env" - + def _get_dry_task(self, train_or_test: str, task_idx: int) -> EnvironmentTask: raise NotImplementedError("No dry task for VLMTestEnv.") @@ -2495,35 +2491,31 @@ def _create_operators(self) -> Iterator[STRIPSOperator]: LiftedAtom(_NotHolding, [robot, obj]), LiftedAtom(_VLMOn, [obj, table]) } - add_effs: Set[LiftedAtom] = { - LiftedAtom(_Holding, [robot, obj]) - } + add_effs: Set[LiftedAtom] = {LiftedAtom(_Holding, [robot, obj])} del_effs: Set[LiftedAtom] = { LiftedAtom(_HandEmpty, [robot]), LiftedAtom(_NotHolding, [robot, obj]), LiftedAtom(_VLMOn, [obj, table]) } ignore_effs: Set[LiftedAtom] = set() - yield STRIPSOperator("Pick", parameters, preconds, add_effs, del_effs, ignore_effs) + yield STRIPSOperator("Pick", parameters, preconds, add_effs, del_effs, + ignore_effs) # Place object robot = Variable("?robot", _robot_type) obj = Variable("?object", _movable_object_type) pan = Variable("?pan", _container_type) parameters = [robot, obj, pan] - preconds: Set[LiftedAtom] = { - LiftedAtom(_Holding, [robot, obj]) - } + preconds: Set[LiftedAtom] = {LiftedAtom(_Holding, [robot, obj])} add_effs: Set[LiftedAtom] = { LiftedAtom(_HandEmpty, [robot]), LiftedAtom(_NotHolding, [robot, obj]), LiftedAtom(_VLMOn, [obj, pan]) } - del_effs: Set[LiftedAtom] = { - LiftedAtom(_Holding, [robot, obj]) - } + del_effs: Set[LiftedAtom] = {LiftedAtom(_Holding, [robot, obj])} ignore_effs: Set[LiftedAtom] = set() - yield STRIPSOperator("Place", parameters, preconds, add_effs, del_effs, ignore_effs) + yield STRIPSOperator("Place", parameters, preconds, add_effs, del_effs, + ignore_effs) # def _generate_train_tasks(self) -> List[EnvironmentTask]: # goal = self._generate_goal_description() # currently just one goal @@ -2555,10 +2547,7 @@ def __init__(self, use_gui: bool = True) -> None: # Create constant objects. self._spot_object = Object("robot", _robot_type) op_to_name = {o.name: o for o in self._create_operators()} - op_names_to_keep = { - "Pick", - "Place" - } + op_names_to_keep = {"Pick", "Place"} self._strips_operators = {op_to_name[o] for o in op_names_to_keep} self._train_tasks = [] self._test_tasks = [] @@ -2649,7 +2638,7 @@ def _actively_construct_env_task(self) -> EnvironmentTask: def _generate_goal_description(self) -> GoalDescription: return "put the cup in the pan" - + def reset(self, train_or_test: str, task_idx: int) -> Observation: prompt = f"Please set up {train_or_test} task {task_idx}!" utils.prompt_user(prompt) @@ -2667,7 +2656,7 @@ def reset(self, train_or_test: str, task_idx: int) -> Observation: self._current_task_goal_reached = False self._last_action = None return self._current_task.init_obs - + def step(self, action: Action) -> Observation: assert self._robot is not None action_name = action.extra_info.action_name @@ -2716,6 +2705,7 @@ def step(self, action: Action) -> Observation: ) return obs + ############################################################################### # Cube Table Env # ############################################################################### diff --git a/predicators/ground_truth_models/spot_env/nsrts.py b/predicators/ground_truth_models/spot_env/nsrts.py index 7d4e5dad4..849854dbc 100644 --- a/predicators/ground_truth_models/spot_env/nsrts.py +++ b/predicators/ground_truth_models/spot_env/nsrts.py @@ -285,11 +285,11 @@ class SpotEnvsGroundTruthNSRTFactory(GroundTruthNSRTFactory): @classmethod def get_env_names(cls) -> Set[str]: return { - "spot_vlm_test_env", - "spot_cube_env", "spot_soda_floor_env", "spot_soda_table_env", - "spot_soda_bucket_env", "spot_soda_chair_env", - "spot_main_sweep_env", "spot_ball_and_cup_sticky_table_env", - "spot_brush_shelf_env", "lis_spot_block_floor_env" + "spot_vlm_test_env", "spot_cube_env", "spot_soda_floor_env", + "spot_soda_table_env", "spot_soda_bucket_env", + "spot_soda_chair_env", "spot_main_sweep_env", + "spot_ball_and_cup_sticky_table_env", "spot_brush_shelf_env", + "lis_spot_block_floor_env" } @staticmethod diff --git a/predicators/ground_truth_models/spot_env/options.py b/predicators/ground_truth_models/spot_env/options.py index 59749ab03..6e26a7ccf 100644 --- a/predicators/ground_truth_models/spot_env/options.py +++ b/predicators/ground_truth_models/spot_env/options.py @@ -899,9 +899,11 @@ def _move_to_ready_sweep_policy(state: State, memory: Dict, robot_obj_idx, target_obj_idx, do_gaze, state, memory, objects, params) -def _teleop_policy(state: State, memory: Dict, objects: Sequence[Object], params: Array) -> Action: + +def _teleop_policy(state: State, memory: Dict, objects: Sequence[Object], + params: Array) -> Action: del state, memory, params - + robot, lease_client = get_robot_only() def _teleop(robot: Robot, lease_client: LeaseClient): @@ -914,15 +916,14 @@ def _teleop(robot: Robot, lease_client: LeaseClient): # Take back control. robot, lease_client = get_robot_only() lease_client.take() - + fn = _teleop fn_args = (robot, lease_client) sim_fn = lambda _: None sim_fn_args = () name = "teleop" - action_extra_info = SpotActionExtraInfo( - name, objects, fn, fn_args, sim_fn, sim_fn_args - ) + action_extra_info = SpotActionExtraInfo(name, objects, fn, fn_args, sim_fn, + sim_fn_args) return utils.create_spot_env_action(action_extra_info) @@ -958,7 +959,7 @@ def _teleop(robot: Robot, lease_client: LeaseClient): "MoveToReadySweep": Box(0, 1, (0, )), # empty "Pick": Box(0, 1, (0, )), # empty "Place": Box(0, 1, (0, )) # empty -} +} # NOTE: the policies MUST be unique because they output actions with extra info # that includes the name of the operators. diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index b9e56d719..62394889b 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -26,7 +26,7 @@ get_allowed_map_regions, load_spot_metadata, object_to_top_down_geom from predicators.structs import Action, DefaultState, EnvironmentTask, \ GoalDescription, GroundAtom, Object, Observation, Predicate, \ - SpotActionExtraInfo, State, Task, Video + SpotActionExtraInfo, State, Task, Video, _Option class SpotPerceiver(BasePerceiver): @@ -581,7 +581,7 @@ def render_mental_images(self, observation: Observation, logging.info(f"Wrote out to {outfile}") plt.close() return [img] - + class SpotMinimalPerceiver(BasePerceiver): """A perceiver for spot envs with minimal functionality.""" @@ -622,6 +622,8 @@ def __init__(self) -> None: self._curr_env: Optional[BaseEnv] = None self._waiting_for_observation = True self._ordered_objects: List[Object] = [] # list of all known objects + self._state_history: List[State] = [] + self._executed_skill_history: List[_Option] = [] # # Keep track of objects that are contained (out of view) in another # # object, like a bag or bucket. This is important not only for gremlins # # but also for small changes in the container's perceived pose. @@ -629,14 +631,14 @@ def __init__(self) -> None: # Load static, hard-coded features of objects, like their shapes. # meta = load_spot_metadata() # self._static_object_features = meta.get("static-object-features", {}) - + def update_perceiver_with_action(self, action: Action) -> None: # NOTE: we need to keep track of the previous action # because the step function (where we need knowledge # of the previous action) occurs *after* the action # has already been taken. self._prev_action = action - + def _create_goal(self, state: State, goal_description: GoalDescription) -> Set[GroundAtom]: del state # not used @@ -720,14 +722,18 @@ def step(self, observation: Observation) -> State: pil_img = PIL.Image.fromarray(img) draw = ImageDraw.Draw(pil_img) font = utils.get_scaled_default_font(draw, 4) - annotated_pil_img = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[img_name], font) + annotated_pil_img = utils.add_text_to_draw_img( + draw, (0, 0), self.camera_name_to_annotation[img_name], font) annotated_pil_imgs.append(pil_img) annotated_imgs = [np.array(img) for img in annotated_pil_imgs] - # import pdb; pdb.set_trace() self._gripper_open_percentage = observation.gripper_open_percentage self._curr_state = self._create_state() self._curr_state.simulator_state["images"] = annotated_imgs ret_state = self._curr_state.copy() + ret_state.simulator_state["state_history"] = list(self._state_history) + self._state_history.append(ret_state) + self._executed_skill_history.append(observation.executed_skill) + ret_state.simulator_state["skill_history"] = list(self._executed_skill_history) return ret_state def _create_state(self) -> State: @@ -778,7 +784,7 @@ def _create_state(self) -> State: "qz": 0, "shape": 0, "height": 0, - "width" : 0, + "width": 0, "length": 0, "object_id": 2, "placeable": 1, @@ -885,7 +891,7 @@ def _create_state(self) -> State: # } } state_dict = {k: list(v.values()) for k, v in state_dict.items()} - ret_state = State(state_dict) + ret_state = State(state_dict) ret_state.simulator_state = {} ret_state.simulator_state["images"] = [] return ret_state diff --git a/predicators/settings.py b/predicators/settings.py index 1af1ae0b6..f2c4d9ed2 100644 --- a/predicators/settings.py +++ b/predicators/settings.py @@ -434,7 +434,7 @@ class GlobalSettings: # parameters for vision language models # gemini-1.5-pro-latest, gpt-4-turbo, gpt-4o - vlm_model_name = "gemini-1.5-flash" #"gemini-1.5-pro-latest" + vlm_model_name = "gemini-1.5-flash" #"gemini-1.5-pro-latest" vlm_temperature = 0.0 vlm_num_completions = 1 vlm_include_cropped_images = False @@ -719,6 +719,9 @@ class GlobalSettings: # saved_vlm_img_demos_folder vlm_trajs_folder_name = "" vlm_predicate_vision_api_generate_ground_atoms = False + # At test-time, we will use the below number of states + # as part of labelling the current state's VLM atoms. + vlm_test_time_atom_label_prompt_type = "per_scene_naive" @classmethod def get_arg_specific_settings(cls, args: Dict[str, Any]) -> Dict[str, Any]: diff --git a/predicators/spot_utils/perception/perception_structs.py b/predicators/spot_utils/perception/perception_structs.py index d676d20cf..a39f78f38 100644 --- a/predicators/spot_utils/perception/perception_structs.py +++ b/predicators/spot_utils/perception/perception_structs.py @@ -29,9 +29,10 @@ def rotated_rgb(self) -> NDArray[np.uint8]: """The image rotated to be upright.""" return ndimage.rotate(self.rgb, self.image_rot, reshape=False) + @dataclass class RGBDImage: - """An RGBD image""" + """An RGBD image.""" rgb: NDArray[np.uint8] depth: NDArray[np.uint16] image_rot: float diff --git a/predicators/spot_utils/perception/spot_cameras.py b/predicators/spot_utils/perception/spot_cameras.py index 036dc089e..9f85bcc83 100644 --- a/predicators/spot_utils/perception/spot_cameras.py +++ b/predicators/spot_utils/perception/spot_cameras.py @@ -193,7 +193,8 @@ def capture_images_without_context( # world_tform_camera, depth_scale, # transforms_snapshot, # frame_name_image_sensor, camera_model) - rgbd = RGBDImage(rgb_img, depth_img, rot, camera_name, depth_scale, camera_model) + rgbd = RGBDImage(rgb_img, depth_img, rot, camera_name, depth_scale, + camera_model) rgbds[camera_name] = rgbd # _LAST_CAPTURED_IMAGES = rgbds diff --git a/predicators/spot_utils/spot_localization.py b/predicators/spot_utils/spot_localization.py index 145178cd7..a85399270 100644 --- a/predicators/spot_utils/spot_localization.py +++ b/predicators/spot_utils/spot_localization.py @@ -55,8 +55,10 @@ def __init__(self, robot: Robot, upload_path: Path, self._robot_pose = math_helpers.SE3Pose(0, 0, 0, math_helpers.Quat()) # Initialize the robot's position in the map. robot_state = get_robot_state(self._robot) - z_position = robot_state.kinematic_state.transforms_snapshot.child_to_parent_edge_map["gpe"].parent_tform_child.position.z - import pdb; pdb.set_trace() + z_position = robot_state.kinematic_state.transforms_snapshot.child_to_parent_edge_map[ + "gpe"].parent_tform_child.position.z + import pdb + pdb.set_trace() # current_odom_tform_body = get_odom_tform_body( # robot_state.kinematic_state.transforms_snapshot).to_proto() # localization = nav_pb2.Localization() diff --git a/predicators/structs.py b/predicators/structs.py index 185deecdf..5309fc4ee 100644 --- a/predicators/structs.py +++ b/predicators/structs.py @@ -492,7 +492,8 @@ def __post_init__(self) -> None: def goal_holds(self, state: State, vlm: Optional[Any] = None) -> bool: """Return whether the goal of this task holds in the given state.""" from predicators.utils import query_vlm_for_atom_vals - vlm_atoms = set(atom for atom in self.goal if isinstance(atom.predicate, VLMPredicate)) + vlm_atoms = set(atom for atom in self.goal + if isinstance(atom.predicate, VLMPredicate)) for atom in self.goal: if atom not in vlm_atoms: if not atom.holds(state): diff --git a/predicators/utils.py b/predicators/utils.py index 1aed539f3..b13201b3c 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -2494,6 +2494,73 @@ def parse_model_output_into_option_plan( return option_plan +def get_prompt_for_vlm_state_labelling( + prompt_type: str, atoms_list: List[str], label_history: List[str], + imgs_history: List[List[PIL.Image.Image]], + cropped_imgs_history: List[List[PIL.Image.Image]], + skill_history: List[Action]) -> Tuple[str, List[PIL.Image.Image]]: + """Prompt for generating labels for an entire trajectory. Similar to the + above prompting method, this outputs a list of prompts to label the state + at each timestep of traj with atom values). + + Note that all our prompts are saved as separate txt files under the + 'vlm_input_data_prompts/atom_labelling' folder. + """ + # Load the pre-specified prompt. + filepath_prefix = get_path_to_predicators_root() + \ + "/predicators/datasets/vlm_input_data_prompts/atom_proposal/" + try: + with open(filepath_prefix + + CFG.grammar_search_vlm_atom_label_prompt_type + ".txt", + "r", + encoding="utf-8") as f: + prompt = f.read() + except FileNotFoundError: + raise ValueError("Unknown VLM prompting option " + + f"{CFG.grammar_search_vlm_atom_label_prompt_type}") + # The prompt ends with a section for 'Predicates', so list these. + for atom_str in atoms_list: + prompt += f"\n{atom_str}" + + if "img_option_diffs" in prompt_type: + # In this case, we need to load the 'per_scene_naive' prompt as well + # for the first timestep. + with open(filepath_prefix + "per_scene_naive.txt", + "r", + encoding="utf-8") as f: + init_prompt = f.read() + for atom_str in atoms_list: + init_prompt += f"\n{atom_str}" + if len(label_history) == 0: + return (init_prompt, imgs_history[0]) + # Now, we use actual difference-based prompting for the second timestep + # and beyond. + curr_prompt = prompt[:] + curr_prompt_imgs = [ + imgs_timestep for imgs_timestep in imgs_history[-1] + ] + if CFG.vlm_include_cropped_images: + if CFG.env in ["burger", "burger_no_move"]: # pragma: no cover + curr_prompt_imgs.extend( + [cropped_imgs_history[-1][1], cropped_imgs_history[-1][0]]) + else: + raise NotImplementedError( + f"Cropped images not implemented for {CFG.env}.") + curr_prompt += "\n\nSkill executed between states: " + skill_name = skill_history[-1].name + str(skill_history[-1].objects) + curr_prompt += skill_name + if "label_history" in prompt_type: + curr_prompt += "\n\nPredicate values in the first scene, " \ + "before the skill was executed: \n" + curr_prompt += label_history[-1] + return (curr_prompt, curr_prompt_imgs) + else: + # NOTE: we rip out only the first image from each trajectory + # which is fine for most domains, but will be problematic for + # situations in which there is more than one image per state. + return (prompt, [imgs_history[-1][0]]) + + def query_vlm_for_atom_vals( vlm_atoms: Collection[GroundAtom], state: State, @@ -2505,17 +2572,22 @@ def query_vlm_for_atom_vals( # vlm can be called on. assert state.simulator_state is not None assert isinstance(state.simulator_state["images"], List) + if "vlm_atoms_history" not in state.simulator_state: + state.simulator_state["vlm_atoms_history"] = [] imgs = state.simulator_state["images"] + previous_states = [] + # We assume the state.simulator_state contains a list of previous states. + if "state_history" in state.simulator_state: + previous_states = state.simulator_state["state_history"] + state_imgs_history = [ + state.simulator_state["images"] for state in previous_states + ] vlm_atoms = sorted(vlm_atoms) - atom_queries_str = "\n* " - atom_queries_str += "\n* ".join(atom.get_vlm_query_str() - for atom in vlm_atoms) - filepath_to_vlm_prompt = get_path_to_predicators_root() + \ - "/predicators/datasets/vlm_input_data_prompts/atom_labelling/" + \ - "per_scene_naive.txt" - with open(filepath_to_vlm_prompt, "r", encoding="utf-8") as f: - vlm_query_str = f.read() - vlm_query_str += atom_queries_str + atom_queries_str = [atom.get_vlm_query_str() for atom in vlm_atoms] + vlm_query_str, imgs = get_prompt_for_vlm_state_labelling( + CFG.vlm_test_time_atom_label_prompt_type, atom_queries_str, + state.simulator_state["vlm_atoms_history"], state_imgs_history, [], + state.simulator_state["skill_history"]) if vlm is None: vlm = create_vlm_by_name(CFG.vlm_model_name) # pragma: no cover. vlm_input_imgs = \ @@ -2530,7 +2602,6 @@ def query_vlm_for_atom_vals( print(f"VLM output: {vlm_output_str}") all_atom_queries = atom_queries_str.strip().split("\n") all_vlm_responses = vlm_output_str.strip().split("\n") - # NOTE: this assumption is likely too brittle; if this is breaking, feel # free to remove/adjust this and change the below parsing loop accordingly! assert len(all_atom_queries) == len(all_vlm_responses) @@ -2542,6 +2613,8 @@ def query_vlm_for_atom_vals( if curr_vlm_output_line[len(atom_query + ":"):period_idx].lower().strip() == "true": true_atoms.add(vlm_atoms[i]) + # Add the text of the VLM's response to the state, to be used in the future! + state.simulator_state["vlm_atoms_history"].append(all_vlm_responses) return true_atoms diff --git a/tests/envs/test_spot_envs.py b/tests/envs/test_spot_envs.py index f05deceff..f723f4f13 100644 --- a/tests/envs/test_spot_envs.py +++ b/tests/envs/test_spot_envs.py @@ -31,6 +31,7 @@ def test_spot_vlm_debug(): }) pass + def test_spot_env_dry_run(): """Dry run tests (do not require access to robot).""" utils.reset_config({ @@ -271,13 +272,15 @@ def real_robot_cube_env_test() -> None: "test_task_json_dir": args.get("test_task_json_dir", None), }) - import pdb; pdb.set_trace() + import pdb + pdb.set_trace() rng = np.random.default_rng(123) import os os.environ["BOSDYN_CLIENT_USERNAME"] = "user" os.environ["BOSDYN_CLIENT_PASSWORD"] = "bbbdddaaaiii" env = SpotCubeEnv() - import pdb; pdb.set_trace() + import pdb + pdb.set_trace() perceiver = SpotPerceiver() nsrts = get_gt_nsrts(env.get_name(), env.predicates, get_gt_options(env.get_name())) @@ -286,7 +289,8 @@ def real_robot_cube_env_test() -> None: task = env.get_test_tasks()[0] obs = env.reset("test", 0) perceiver.reset(task) - import pdb; pdb.set_trace() + import pdb + pdb.set_trace() assert len(obs.objects_in_view) == 4 cube, floor, table1, table2 = sorted(obs.objects_in_view) assert cube.name == "cube"