From 4cc03364e7e8d407a52a7dfd3fb677aa7a0dbcec Mon Sep 17 00:00:00 2001
From: Ashay Athalye
Date: Mon, 16 Sep 2024 22:24:48 -0400
Subject: [PATCH] Some progress towards fixing vlm atom history.

---
 predicators/perception/spot_perceiver.py | 48 +++++++++++++++++++++---
 predicators/utils.py                     |  5 +++
 2 files changed, 49 insertions(+), 4 deletions(-)

diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py
index 201df39f2..54e0f513b 100644
--- a/predicators/perception/spot_perceiver.py
+++ b/predicators/perception/spot_perceiver.py
@@ -26,7 +26,7 @@
     get_allowed_map_regions, load_spot_metadata, object_to_top_down_geom
 from predicators.structs import Action, DefaultState, EnvironmentTask, \
     GoalDescription, GroundAtom, Object, Observation, Predicate, \
-    SpotActionExtraInfo, State, Task, Video, _Option
+    SpotActionExtraInfo, State, Task, Video, VLMPredicate, _Option


 class SpotPerceiver(BasePerceiver):
@@ -624,6 +624,8 @@ def __init__(self) -> None:
         self._ordered_objects: List[Object] = []  # list of all known objects
         self._state_history: List[State] = []
         self._executed_skill_history: List[_Option] = []
+        self._vlm_label_history: List[str] = []
+        self._curr_state: Optional[State] = None
         # # Keep track of objects that are contained (out of view) in another
         # # object, like a bag or bucket. This is important not only for gremlins
         # # but also for small changes in the container's perceived pose.
@@ -631,6 +633,7 @@ def __init__(self) -> None:
         # Load static, hard-coded features of objects, like their shapes.
         # meta = load_spot_metadata()
         # self._static_object_features = meta.get("static-object-features", {})
+
     def update_perceiver_with_action(self, action: Action) -> None:
         # NOTE: we need to keep track of the previous action
@@ -679,7 +682,7 @@ def reset(self, env_task: EnvironmentTask) -> Task:
         state = self._create_state()
         state.simulator_state = {}
         state.simulator_state["images"] = []
-        self._curr_state = state
+        self._curr_state = None  # this will get set by self.step()
         goal = self._create_goal(state, env_task.goal_description)
         return Task(state, goal)
@@ -713,8 +716,7 @@ def step(self, observation: Observation) -> State:
                 text_bbox = [(x0, y0 - 1.5*text_height), (x0 + text_width + 1, y0)]
                 draw.rectangle(text_bbox, fill='green')
                 draw.text((x0 + 1, y0 - 1.5*text_height), text, fill='white', font=font)
-
-        # import pdb; pdb.set_trace()
+
         import PIL
         from PIL import ImageDraw
         annotated_pil_imgs = []
@@ -722,11 +724,14 @@ def step(self, observation: Observation) -> State:
             pil_img = PIL.Image.fromarray(img)
             draw = ImageDraw.Draw(pil_img)
             font = utils.get_scaled_default_font(draw, 4)
             annotated_pil_img = utils.add_text_to_draw_img(
                 draw, (0, 0), self.camera_name_to_annotation[img_name], font)
             annotated_pil_imgs.append(pil_img)
         annotated_imgs = [np.array(img) for img in annotated_pil_imgs]
+        self._gripper_open_percentage = observation.gripper_open_percentage
+
+        self._curr_state = self._create_state()
         self._curr_state.simulator_state["images"] = annotated_imgs
         ret_state = self._curr_state.copy()
@@ -734,6 +739,41 @@ def step(self, observation: Observation) -> State:
         ret_state.simulator_state["state_history"] = list(self._state_history)
         self._executed_skill_history.append(observation.executed_skill)
         ret_state.simulator_state["skill_history"] = list(self._executed_skill_history)
+
+        # Save the VLM labels ("all_vlm_responses") towards building the
+        # vlm atom history. Any time utils.abstract() is called, e.g. by
+        # the approach or the planner, we may (depending on flags) want to
+        # pass the vlm atom history into the prompt to the VLM.
+        # We could save the `all_vlm_responses` computed internally by
+        # utils.query_vlm_for_atom_vals(), but that would require us to
+        # change how utils.abstract() works. Instead, we re-compute the
+        # labels from the true atoms returned by utils.abstract().
+        assert self._curr_env is not None
+        preds = self._curr_env.predicates
+        state_copy = ret_state.copy()  # temporary, to ease debugging
+        abstract_state = utils.abstract(state_copy, preds)
+        # Store the abstract state in the state itself so that we avoid
+        # recomputing it (VLM queries are expensive and noisy).
+        ret_state.simulator_state["abstract_state"] = abstract_state
+        # Re-compute the VLM labeling for the VLM atoms in this state to
+        # store in our vlm atom history (mirrors utils.abstract()).
+        if self._curr_state is not None:
+            vlm_preds = {p for p in preds if isinstance(p, VLMPredicate)}
+            vlm_atoms = set()
+            for pred in vlm_preds:
+                for choice in utils.get_object_combinations(
+                        list(state_copy), pred.types):
+                    vlm_atoms.add(GroundAtom(pred, choice))
+            # Label each VLM atom with its truth value in abstract_state;
+            # the label format should match utils.query_vlm_for_atom_vals().
+            vlm_labels = [
+                f"{atom}: {atom in abstract_state}"
+                for atom in sorted(vlm_atoms)
+            ]
+            ret_state.simulator_state.setdefault(
+                "vlm_atoms_history", []).append(vlm_labels)
+        else:
+            self._curr_state = ret_state.copy()
         return ret_state

     def _create_state(self) -> State:
diff --git a/predicators/utils.py b/predicators/utils.py
index f4e32a3b7..9d3f3f723 100644
--- a/predicators/utils.py
+++ b/predicators/utils.py
@@ -2618,6 +2618,11 @@ def query_vlm_for_atom_vals(
     breakpoint()
     # Add the text of the VLM's response to the state, to be used in the future!
     # REMOVE THIS -> AND PUT IT IN THE PERCEIVER
+    # The perceiver calls utils.abstract() once per step and puts the
+    # result in the state history. According to a flag, anywhere else we
+    # would normally call utils.abstract(), we instead pull the abstract
+    # state from the simulator_state field that already has it. The vlm
+    # atom history is appended here, and utils.ground calls this function.
     state.simulator_state["vlm_atoms_history"].append(all_vlm_responses)
     return true_atoms
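
Note (not part of the patch): the utils.py comments above describe a
caller-side pattern that this patch does not implement yet. Depending on a
flag, code that would normally call utils.abstract() should instead reuse the
abstract state the perceiver cached in simulator_state. Below is a minimal
sketch of such a helper; the flag name `use_cached_abstract_state` and the
helper name `abstract_or_cached` are hypothetical, and the sketch assumes the
perceiver stored "abstract_state" as in the spot_perceiver.py hunk above.

from typing import Set

from predicators import utils
from predicators.settings import CFG
from predicators.structs import GroundAtom, Predicate, State


def abstract_or_cached(state: State,
                       preds: Set[Predicate]) -> Set[GroundAtom]:
    """Return the perceiver's cached abstract state if available; otherwise
    fall back to recomputing it with utils.abstract()."""
    # Hypothetical flag; not an existing predicators setting.
    use_cache = getattr(CFG, "use_cached_abstract_state", False)
    sim = state.simulator_state
    if use_cache and isinstance(sim, dict) and "abstract_state" in sim:
        # The perceiver already queried the VLM once in step(), so reuse
        # that answer instead of paying for repeated (noisy) VLM calls.
        return sim["abstract_state"]
    return utils.abstract(state, preds)

With this helper, the approach and the planner would call
abstract_or_cached() wherever they currently call utils.abstract(), leaving
query_vlm_for_atom_vals() as the single place where the vlm atom history is
appended.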