From 5232fa2b26d31f23c16deb0cd5ec5128e3cc4446 Mon Sep 17 00:00:00 2001
From: Ashay Athalye
Date: Tue, 17 Sep 2024 17:43:30 -0400
Subject: [PATCH] Fix bugs in passing image history to the VLM and in
 abstracting the state.

---
 predicators/envs/spot_env.py               |  2 +-
 predicators/perception/spot_perceiver.py   | 23 ++++++++++++-------
 .../spot_utils/perception/spot_cameras.py  |  2 +-
 predicators/structs.py                     |  3 +++
 predicators/utils.py                       |  4 +++-
 5 files changed, 23 insertions(+), 11 deletions(-)

diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py
index 800c44fd4..60cfeab20 100644
--- a/predicators/envs/spot_env.py
+++ b/predicators/envs/spot_env.py
@@ -2485,7 +2485,7 @@ def _create_operators(self) -> Iterator[STRIPSOperator]:
         # Pick object
         robot = Variable("?robot", _robot_type)
         obj = Variable("?object", _movable_object_type)
-        table = Variable("?table", _immovable_object_type)
+        table = Variable("?table", _movable_object_type)
         parameters = [robot, obj, table]
         preconds: Set[LiftedAtom] = {
             LiftedAtom(_HandEmpty, [robot]),
diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py
index db177f3fb..44c1b3922 100644
--- a/predicators/perception/spot_perceiver.py
+++ b/predicators/perception/spot_perceiver.py
@@ -739,6 +739,9 @@ def step(self, observation: Observation) -> State:
         # At the first timestep, these histories will be empty due to self.reset().
         # But at every timestep that isn't the first one, they will be non-empty.
         self._curr_state.simulator_state["state_history"] = list(self._state_history)
+        # We do this here so the call to `utils.abstract()` a few lines later has the skill
+        # that was just run.
+        self._executed_skill_history.append(observation.executed_skill)  # None in first timestep.
         self._curr_state.simulator_state["skill_history"] = list(self._executed_skill_history)
         self._curr_state.simulator_state["vlm_label_history"] = list(self._vlm_label_history)
@@ -763,7 +766,6 @@ def step(self, observation: Observation) -> State:
             for choice in utils.get_object_combinations(list(state_copy), pred.types):
                 vlm_atoms.add(GroundAtom(pred, choice))
         vlm_atoms = sorted(vlm_atoms)
-        atom_queries_list = [atom.get_vlm_query_str() for atom in vlm_atoms]
         reconstructed_all_vlm_responses = []
         for atom in vlm_atoms:
             if atom in abstract_state:
@@ -772,9 +774,9 @@ def step(self, observation: Observation) -> State:
                 truth_value = 'False'
             atom_label = f"* {atom.get_vlm_query_str()}: {truth_value}"
             reconstructed_all_vlm_responses.append(atom_label)
-        self._vlm_label_history.append(reconstructed_all_vlm_responses)
+        str_vlm_response = '\n'.join(reconstructed_all_vlm_responses)
+        self._vlm_label_history.append(str_vlm_response)
         self._state_history.append(self._curr_state.copy())
-        self._executed_skill_history.append(observation.executed_skill)  # None in first timestep.
         return self._curr_state.copy()

     def _create_state(self) -> State:
@@ -782,7 +784,7 @@ def step(self, observation: Observation) -> State:
             return DefaultState
         # Build the continuous part of the state.
         assert self._robot is not None
-        table = Object("table", _immovable_object_type)
+        table = Object("table", _movable_object_type)
         cup = Object("cup", _movable_object_type)
         pan = Object("pan", _container_type)
         # bread = Object("bread", _movable_object_type)
@@ -812,8 +814,13 @@
                 "height": 0,
                 "width" : 0,
                 "length": 0,
-                "object_id": 1,
-                "flat_top_surface": 1
+                "object_id": 0,
+                "placeable": 1,
+                "held": 0,
+                "lost": 0,
+                "in_hand_view": 0,
+                "in_view": 1,
+                "is_sweeper": 0
             },
             cup: {
                 "x": 0,
@@ -827,7 +834,7 @@
                 "height": 0,
                 "width": 0,
                 "length": 0,
-                "object_id": 2,
+                "object_id": 1,
                 "placeable": 1,
                 "held": 0,
                 "lost": 0,
@@ -922,7 +929,7 @@
                 "height": 0,
                 "width" : 0,
                 "length": 0,
-                "object_id": 3,
+                "object_id": 2,
                 "placeable": 1,
                 "held": 0,
                 "lost": 0,
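Note on the perceiver hunks above: the executed skill is appended to its history *before* the abstraction call so that the labelling prompt can pair the newest image with the skill that produced it, and each timestep's VLM labels are stored as one newline-joined string rather than a list of lists. A minimal sketch of that bookkeeping, using a hypothetical `_HistorySketch` stand-in rather than the real perceiver classes:

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class _HistorySketch:
    """Hypothetical stand-in for the perceiver's per-episode histories."""
    states: List[Dict] = field(default_factory=list)
    skills: List[Optional[str]] = field(default_factory=list)
    vlm_labels: List[str] = field(default_factory=list)

    def step(self, images: Dict, executed_skill: Optional[str],
             atom_labels: List[str]) -> Dict:
        # Append the skill first, so anything downstream that consumes the
        # histories (e.g. the abstraction call) sees the skill just run.
        self.skills.append(executed_skill)  # None on the first timestep.
        state = {
            "images": images,
            "state_history": list(self.states),
            "skill_history": list(self.skills),
            "vlm_label_history": list(self.vlm_labels),
        }
        # One newline-joined string per timestep, matching the diff above.
        self.vlm_labels.append("\n".join(atom_labels))
        self.states.append(state)
        return state
```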
diff --git a/predicators/spot_utils/perception/spot_cameras.py b/predicators/spot_utils/perception/spot_cameras.py
index f39dbc0cc..a1a9780e9 100644
--- a/predicators/spot_utils/perception/spot_cameras.py
+++ b/predicators/spot_utils/perception/spot_cameras.py
@@ -24,7 +24,7 @@
 }
 RGB_TO_DEPTH_CAMERAS = {
     # "hand_color_image": "hand_depth_in_hand_color_frame",
-    # "left_fisheye_image": "left_depth_in_visual_frame",
+    "left_fisheye_image": "left_depth_in_visual_frame",
     # "right_fisheye_image": "right_depth_in_visual_frame",
     # "frontleft_fisheye_image": "frontleft_depth_in_visual_frame",
     "frontright_fisheye_image": "frontright_depth_in_visual_frame",
diff --git a/predicators/structs.py b/predicators/structs.py
index 5309fc4ee..d1c8e1880 100644
--- a/predicators/structs.py
+++ b/predicators/structs.py
@@ -491,6 +491,9 @@ def __post_init__(self) -> None:

     def goal_holds(self, state: State, vlm: Optional[Any] = None) -> bool:
         """Return whether the goal of this task holds in the given state."""
+        if "abstract_state" in state.simulator_state:
+            abstract_state = state.simulator_state["abstract_state"]
+            return self.goal.issubset(abstract_state)
         from predicators.utils import query_vlm_for_atom_vals
         vlm_atoms = set(atom for atom in self.goal
                         if isinstance(atom.predicate, VLMPredicate))
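The `goal_holds` change short-circuits on the abstract state that the perceiver now caches in `simulator_state`, so the goal check becomes a plain subset test instead of a second round of VLM queries per step. A toy version of the check, with a string standing in for the real `GroundAtom` struct:

```python
from typing import Dict, FrozenSet, Set

GroundAtom = str  # hypothetical stand-in for predicators.structs.GroundAtom


def goal_holds(goal: FrozenSet[GroundAtom], simulator_state: Dict) -> bool:
    """Goal check that reuses a cached abstract state when present."""
    if "abstract_state" in simulator_state:
        abstract_state: Set[GroundAtom] = simulator_state["abstract_state"]
        return goal.issubset(abstract_state)
    # Without a cached abstract state, the real code falls back to
    # querying the VLM for each goal atom.
    raise NotImplementedError


assert goal_holds(frozenset({"On(cup,table)"}),
                  {"abstract_state": {"On(cup,table)", "HandEmpty(robot)"}})
```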
diff --git a/predicators/utils.py b/predicators/utils.py
index 18172d4e8..b8ea1e97e 100644
--- a/predicators/utils.py
+++ b/predicators/utils.py
@@ -2591,7 +2591,7 @@ def query_vlm_for_atom_vals(
     if "state_history" in state.simulator_state:
         prev_states = state.simulator_state["state_history"]
         prev_states_imgs_history = [s.simulator_state["images"] for s in prev_states]
-    images_history = [curr_state_images] + prev_states_imgs_history
+    images_history = prev_states_imgs_history + [curr_state_images]
     skill_history = []
     if "skill_history" in state.simulator_state:
         skill_history = state.simulator_state["skill_history"]
@@ -2652,6 +2652,8 @@ def abstract(state: State,

     Duplicate arguments in predicates are allowed.
     """
+    if "abstract_state" in state.simulator_state:
+        return state.simulator_state["abstract_state"]
     # Start by pulling out all VLM predicates.
     vlm_preds = set(pred for pred in preds if isinstance(pred, VLMPredicate))
     # Next, classify all non-VLM predicates.
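The two `utils.py` hunks are the core fixes named in the subject: images must reach the VLM in chronological order (history first, current frame last), and `abstract()` returns the perceiver's cached abstract state instead of recomputing it. A quick self-check of the ordering fix, with placeholder per-timestep image lists:

```python
# Placeholder per-timestep image lists (oldest -> newest).
prev_states_imgs_history = [["img_t0"], ["img_t1"]]
curr_state_images = ["img_t2"]

# Before the fix, the current images came first, so the VLM saw the
# episode out of order: [t2, t0, t1].
buggy_order = [curr_state_images] + prev_states_imgs_history

# After the fix, the history comes first and the current frame last.
images_history = prev_states_imgs_history + [curr_state_images]

assert images_history == [["img_t0"], ["img_t1"], ["img_t2"]]
assert buggy_order != images_history
```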