From 71fe6d390db9afc5b044d11ef5246917be349849 Mon Sep 17 00:00:00 2001
From: Ashay Athalye
Date: Tue, 17 Sep 2024 16:21:04 -0400
Subject: [PATCH] Update and fix computation of history and how it gets passed
 into the VLM query.

---
 predicators/main.py                       |  1 -
 predicators/perception/spot_perceiver.py  | 67 ++++++-----------------
 predicators/utils.py                      | 36 ++++++++++--------
 3 files changed, 30 insertions(+), 74 deletions(-)

diff --git a/predicators/main.py b/predicators/main.py
index 08748d681c..173b660d8f 100644
--- a/predicators/main.py
+++ b/predicators/main.py
@@ -362,7 +362,6 @@ def _run_testing(env: BaseEnv, cogman: CogMan) -> Metrics:
     metrics: Metrics = defaultdict(float)
     curr_num_nodes_created = 0.0
     curr_num_nodes_expanded = 0.0
-    import pdb; pdb.set_trace()
     for test_task_idx, env_task in enumerate(test_tasks):
         solve_start = time.perf_counter()
         try:
diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py
index b9325aadfe..db177f3fba 100644
--- a/predicators/perception/spot_perceiver.py
+++ b/predicators/perception/spot_perceiver.py
@@ -730,12 +730,8 @@ def step(self, observation: Observation) -> State:
         # annotated_pil_img = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[img_name], font)
         # annotated_pil_imgs.append(pil_img)
         annotated_imgs = [np.array(img) for img in pil_imgs]
-
         self._gripper_open_percentage = observation.gripper_open_percentage
-
-        # check if self._curr_state is what we expect it to be.
-        import pdb; pdb.set_trace()
 
         self._curr_state = self._create_state()
         # This state is a default/empty. We have to set the attributes
         # of the objects and set the simulator state properly.
@@ -756,7 +752,7 @@ def step(self, observation: Observation) -> State:
         # in planning.
         assert self._curr_env is not None
         preds = self._curr_env.predicates
-        state_copy = self._curr_env.copy()
+        state_copy = self._curr_state.copy()
         abstract_state = utils.abstract(state_copy, preds)
         self._curr_state.simulator_state["abstract_state"] = abstract_state
         # Compute all the VLM atoms. `utils.abstract()` only returns the ones that
@@ -767,57 +763,18 @@ def step(self, observation: Observation) -> State:
             for choice in utils.get_object_combinations(list(state_copy), pred.types):
                 vlm_atoms.add(GroundAtom(pred, choice))
         vlm_atoms = sorted(vlm_atoms)
-        import pdb; pdb.set_trace()
-
+        reconstructed_all_vlm_responses = []
+        for atom in vlm_atoms:
+            if atom in abstract_state:
+                truth_value = 'True'
+            else:
+                truth_value = 'False'
+            atom_label = f"* {atom.get_vlm_query_str()}: {truth_value}"
+            reconstructed_all_vlm_responses.append(atom_label)
+        self._vlm_label_history.append(reconstructed_all_vlm_responses)
         self._state_history.append(self._curr_state.copy())
-        # The executed skill will be `None` in the first timestep.
-        # This should be handled in the function that processes the
-        # history when passing it to the VLM.
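Note: the `reconstructed_all_vlm_responses` block above rebuilds the exact label lines the VLM prompt format uses from nothing but the true atoms returned by `utils.abstract()`, so the label history can grow without issuing a second VLM query. A minimal standalone sketch of the idea, with plain strings standing in for `GroundAtom` objects (an illustrative simplification, not the repo's real types):

```python
from typing import List, Set

def reconstruct_vlm_labels(vlm_queries: List[str],
                           true_queries: Set[str]) -> List[str]:
    """Rebuild '* <query>: <True/False>' label lines from the true atoms,
    mirroring the reconstruction in step() above."""
    labels = []
    for query in sorted(vlm_queries):
        truth_value = "True" if query in true_queries else "False"
        labels.append(f"* {query}: {truth_value}")
    return labels

# Two candidate atom queries, one of which was labeled true.
queries = ["Holding(robot, brush)", "On(brush, table)"]
print(reconstruct_vlm_labels(queries, {"On(brush, table)"}))
# -> ['* Holding(robot, brush): False', '* On(brush, table): True']
```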
-        self._executed_skill_history.append(observation.executed_skill)
-
-        #############################
-
-
-
-        curr_state = self._create_state
-        self._curr_state = self._create_state()
-        self._curr_state.simulator_state["images"] = annotated_imgs
-        ret_state = self._curr_state.copy()
-        self._state_history.append(ret_state)
-        ret_state.simulator_state["state_history"] = list(self._state_history)
-        self._executed_skill_history.append(observation.executed_skill)
-        ret_state.simulator_state["skill_history"] = list(self._executed_skill_history)
-
-        # Save "all_vlm_responses" towards building vlm atom history.
-        # Any time utils.abstract() is called, e.g. approach or planner,
-        # we may (depending on flags) want to pass in the vlm atom history
-        # into the prompt to the VLM.
-        # We could save `all_vlm_responses` computed internally by
-        # utils.query_vlm_for_aotm_vals(), but that would require us to
-        # change how utils.abstract() works. Instead, we'll re-compute the
-        # `all_vlm_responses` based on the true atoms returned by utils.abstract().
-        assert self._curr_env is not None
-        preds = self._curr_env.predicates
-        state_copy = ret_state.copy() # temporary, to ease debugging
-        abstract_state = utils.abstract(state_copy, preds)
-        # We should avoid recomputing the abstract state (VLM noise?) so let's store it in
-        # the state.
-        ret_state.simulator_state["abstract_state"] = abstract_state
-        # Re-compute the VLM labeling for the VLM atoms in this state to store in our
-        # vlm atom history.
-        # This code also appears in utils.abstract()
-        if self._curr_state is not None:
-            vlm_preds = set(pred for pred in preds if isinstance(pred, VLMPredicate))
-            vlm_atoms = set()
-            for pred in vlm_preds:
-                for choice in utils.get_object_combinations(list(state_copy), pred.types):
-                    vlm_atoms.add(GroundAtom(pred, choice))
-            vlm_atoms = sorted(vlm_atoms)
-            import pdb; pdb.set_trace()
-            ret_state.simulator_state["vlm_atoms_history"].append(abstract_state)
-        else:
-            self._curr_state = ret_state.copy()
-        return ret_state
+        self._executed_skill_history.append(observation.executed_skill)  # None in first timestep.
+        return self._curr_state.copy()
 
     def _create_state(self) -> State:
         if self._waiting_for_observation:
diff --git a/predicators/utils.py b/predicators/utils.py
index d41f2c09a2..18172d4e8d 100644
--- a/predicators/utils.py
+++ b/predicators/utils.py
@@ -2499,7 +2499,6 @@ def get_prompt_for_vlm_state_labelling(
     imgs_history: List[List[PIL.Image.Image]],
     cropped_imgs_history: List[List[PIL.Image.Image]],
     skill_history: List[Action]) -> Tuple[str, List[PIL.Image.Image]]:
-    # import pdb; pdb.set_trace()
     """Prompt for generating labels for an entire trajectory. Similar to
     the above prompting method, this outputs a list of prompts to label the
     state at each timestep of traj with atom values).
@@ -2517,7 +2516,6 @@ def get_prompt_for_vlm_state_labelling(
                   encoding="utf-8") as f:
             prompt = f.read()
     except FileNotFoundError:
-        import pdb; pdb.set_trace()
         raise ValueError("Unknown VLM prompting option " +
                          f"{CFG.grammar_search_vlm_atom_label_prompt_type}")
     # The prompt ends with a section for 'Predicates', so list these.
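Note: after the `step()` rewrite above, the perceiver maintains three parallel histories (`_state_history`, `_executed_skill_history`, `_vlm_label_history`) with exactly one entry per timestep, where the executed skill is `None` at the first timestep. A small sketch of that bookkeeping invariant, using simplified stand-ins rather than the real perceiver types:

```python
from dataclasses import dataclass, field
from typing import Any, List, Optional

@dataclass
class PerTimestepHistories:
    """Parallel per-timestep histories, standing in for the perceiver's
    _state_history, _executed_skill_history, and _vlm_label_history."""
    states: List[Any] = field(default_factory=list)
    skills: List[Optional[str]] = field(default_factory=list)
    vlm_labels: List[List[str]] = field(default_factory=list)

    def record(self, state: Any, executed_skill: Optional[str],
               labels: List[str]) -> None:
        # One entry per timestep, appended in lockstep; the executed
        # skill is None on the very first timestep.
        self.states.append(state)
        self.skills.append(executed_skill)
        self.vlm_labels.append(labels)
        assert len(self.states) == len(self.skills) == len(self.vlm_labels)

hist = PerTimestepHistories()
hist.record("state_t0", None, ["* On(brush, table): True"])
hist.record("state_t1", "PickBrush", ["* On(brush, table): False"])
assert hist.skills[0] is None  # handled downstream when building the prompt
```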
@@ -2574,22 +2572,26 @@ def query_vlm_for_atom_vals(
     # vlm can be called on.
     assert state.simulator_state is not None
     assert isinstance(state.simulator_state["images"], List)
-    if "vlm_atoms_history" not in state.simulator_state:
-        state.simulator_state["vlm_atoms_history"] = []
-    imgs = state.simulator_state["images"]
-    previous_states = []
-    # We assume the state.simulator_state contains a list of previous states.
-    if "state_history" in state.simulator_state:
-        previous_states = state.simulator_state["state_history"]
-    state_imgs_history = [
-        state.simulator_state["images"] for state in previous_states
-    ]
     vlm_atoms = sorted(vlm_atoms)
     atom_queries_list = [atom.get_vlm_query_str() for atom in vlm_atoms]
+    # All "history" fields in the simulator state cover previous states
+    # only -- not the current state -- so the image history passed to the
+    # VLM must end with the current state's images.
+    curr_state_images = state.simulator_state["images"]
+    images_history = [curr_state_images]
+    if "state_history" in state.simulator_state:
+        prev_states = state.simulator_state["state_history"]
+        prev_states_imgs_history = [s.simulator_state["images"] for s in prev_states]
+        images_history = prev_states_imgs_history + [curr_state_images]
+    skill_history = []
+    if "skill_history" in state.simulator_state:
+        skill_history = state.simulator_state["skill_history"]
+    label_history = []
+    if "vlm_label_history" in state.simulator_state:
+        label_history = state.simulator_state["vlm_label_history"]
     vlm_query_str, imgs = get_prompt_for_vlm_state_labelling(
         CFG.vlm_test_time_atom_label_prompt_type, atom_queries_list,
-        state.simulator_state["vlm_atoms_history"], state_imgs_history, [],
-        state.simulator_state["skill_history"])
+        label_history, images_history, [], skill_history)
     if vlm is None:
         vlm = create_vlm_by_name(CFG.vlm_model_name)  # pragma: no cover.
     vlm_input_imgs = \
@@ -2603,7 +2605,6 @@ def query_vlm_for_atom_vals(
     vlm_output_str = vlm_output[0]
     print(f"VLM output: {vlm_output_str}")
     all_vlm_responses = vlm_output_str.strip().split("\n")
-    # import pdb; pdb.set_trace()
     # NOTE: this assumption is likely too brittle; if this is breaking, feel
     # free to remove/adjust this and change the below parsing loop accordingly!
     assert len(atom_queries_list) == len(all_vlm_responses)
@@ -2612,8 +2613,8 @@ def query_vlm_for_atom_vals(
         assert atom_query + ":" in curr_vlm_output_line
         assert "." in curr_vlm_output_line
         period_idx = curr_vlm_output_line.find(".")
-        if curr_vlm_output_line[len(atom_query +
-                                    ":"):period_idx].lower().strip() == "true":
+        value = curr_vlm_output_line.split(': ')[-1].strip('.').lower()
+        if value == "true":
             true_atoms.add(vlm_atoms[i])
     # breakpoint()
 
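Note: the new parsing swaps the old slice arithmetic for a `split`, which leaves `period_idx` (and the `"." in ...` assert above it) vestigial; the `atom_query + ":"` assert remains a useful guard, since a query string that itself contained `": "` would confuse the split. A quick check of the new rule on an illustrative response (not captured from a real VLM):

```python
# One response line per queried atom, in the same order as the queries.
vlm_output_str = ("* Holding(robot, brush): False.\n"
                  "* On(brush, table): True.")
all_vlm_responses = vlm_output_str.strip().split("\n")
values = [line.split(": ")[-1].strip(".").lower()
          for line in all_vlm_responses]
assert values == ["false", "true"]
```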
@@ -2625,7 +2626,6 @@ def query_vlm_for_atom_vals(
 
     # The appending of vlm atom history is currently done in query_vlm_for_atom_vals() in utils.py,
     # and utils.ground calls that.
     # state.simulator_state["vlm_atoms_history"].append(all_vlm_responses)
-
     return true_atoms
 
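Note: taken together, the intended flow is that the perceiver stores the per-timestep histories in `simulator_state`, and `query_vlm_for_atom_vals()` assembles them, appending the current state's images at the end of the image history. A hedged end-to-end sketch, with plain dicts standing in for `State.simulator_state` (the dictionary keys match this patch; everything else is simplified):

```python
from typing import Any, Dict, List

def build_history_args(sim_state: Dict[str, Any]) -> Dict[str, List[Any]]:
    """Assemble the three history arguments the way the rewritten
    query_vlm_for_atom_vals() does."""
    images_history: List[Any] = [sim_state["images"]]
    if "state_history" in sim_state:
        prev_imgs = [s["images"] for s in sim_state["state_history"]]
        images_history = prev_imgs + images_history
    return {
        "images_history": images_history,
        "skill_history": sim_state.get("skill_history", []),
        "label_history": sim_state.get("vlm_label_history", []),
    }

# First timestep: no histories yet; only the current images are passed.
t0 = build_history_args({"images": ["img_t0"]})
assert t0["images_history"] == [["img_t0"]]
assert t0["skill_history"] == [] and t0["label_history"] == []

# Second timestep: one previous state, one skill (None), one label set.
t1 = build_history_args({
    "images": ["img_t1"],
    "state_history": [{"images": ["img_t0"]}],
    "skill_history": [None],
    "vlm_label_history": [["* On(brush, table): True"]],
})
assert t1["images_history"] == [["img_t0"], ["img_t1"]]
```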