Fix bugs related to passing image history to the VLM and abstracting the state.
ashay-bdai committed Sep 17, 2024
1 parent 71fe6d3 commit 5232fa2
Showing 6 changed files with 26 additions and 12 deletions.
2 changes: 1 addition & 1 deletion predicators/approaches/bilevel_planning_approach.py
@@ -65,7 +65,7 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]:
if self._plan_without_sim:
nsrt_plan, atoms_seq, metrics = self._run_task_plan(
task, nsrts, preds, timeout, seed)
# import pdb; pdb.set_trace()
import pdb; pdb.set_trace()
self._last_nsrt_plan = nsrt_plan
self._last_atoms_seq = atoms_seq
policy = utils.nsrt_plan_to_greedy_policy(nsrt_plan, task.goal,
2 changes: 1 addition & 1 deletion predicators/envs/spot_env.py
@@ -2485,7 +2485,7 @@ def _create_operators(self) -> Iterator[STRIPSOperator]:
# Pick object
robot = Variable("?robot", _robot_type)
obj = Variable("?object", _movable_object_type)
table = Variable("?table", _immovable_object_type)
table = Variable("?table", _movable_object_type)
parameters = [robot, obj, table]
preconds: Set[LiftedAtom] = {
LiftedAtom(_HandEmpty, [robot]),
24 changes: 16 additions & 8 deletions predicators/perception/spot_perceiver.py
@@ -739,6 +739,9 @@ def step(self, observation: Observation) -> State:
# At the first timestep, these histories will be empty due to self.reset().
# But at every timestep that isn't the first one, they will be non-empty.
self._curr_state.simulator_state["state_history"] = list(self._state_history)
# We do this here so the call to `utils.abstract()` a few lines later has the skill
# that was just run.
self._executed_skill_history.append(observation.executed_skill) # None in first timestep.
self._curr_state.simulator_state["skill_history"] = list(self._executed_skill_history)
self._curr_state.simulator_state["vlm_label_history"] = list(self._vlm_label_history)

@@ -753,6 +756,7 @@ def step(self, observation: Observation) -> State:
assert self._curr_env is not None
preds = self._curr_env.predicates
state_copy = self._curr_state.copy()
print(f"Right before abstract state, skill in obs: {observation.executed_skill}")
abstract_state = utils.abstract(state_copy, preds)
self._curr_state.simulator_state["abstract_state"] = abstract_state
# Compute all the VLM atoms. `utils.abstract()` only returns the ones that
@@ -763,7 +767,6 @@
for choice in utils.get_object_combinations(list(state_copy), pred.types):
vlm_atoms.add(GroundAtom(pred, choice))
vlm_atoms = sorted(vlm_atoms)
atom_queries_list = [atom.get_vlm_query_str() for atom in vlm_atoms]
reconstructed_all_vlm_responses = []
for atom in vlm_atoms:
if atom in abstract_state:
@@ -772,17 +775,17 @@
truth_value = 'False'
atom_label = f"* {atom.get_vlm_query_str()}: {truth_value}"
reconstructed_all_vlm_responses.append(atom_label)
self._vlm_label_history.append(reconstructed_all_vlm_responses)
str_vlm_response = '\n'.join(reconstructed_all_vlm_responses)
self._vlm_label_history.append(str_vlm_response)
self._state_history.append(self._curr_state.copy())
self._executed_skill_history.append(observation.executed_skill) # None in first timestep.
return self._curr_state.copy()

def _create_state(self) -> State:
if self._waiting_for_observation:
return DefaultState
# Build the continuous part of the state.
assert self._robot is not None
table = Object("table", _immovable_object_type)
table = Object("table", _movable_object_type)
cup = Object("cup", _movable_object_type)
pan = Object("pan", _container_type)
# bread = Object("bread", _movable_object_type)
@@ -812,8 +815,13 @@ def _create_state(self) -> State:
"height": 0,
"width" : 0,
"length": 0,
"object_id": 1,
"flat_top_surface": 1
"object_id": 0,
"placeable": 1,
"held": 0,
"lost": 0,
"in_hand_view": 0,
"in_view": 1,
"is_sweeper": 0
},
cup: {
"x": 0,
@@ -827,7 +835,7 @@ def _create_state(self) -> State:
"height": 0,
"width": 0,
"length": 0,
"object_id": 2,
"object_id": 1,
"placeable": 1,
"held": 0,
"lost": 0,
@@ -922,7 +930,7 @@ def _create_state(self) -> State:
"height": 0,
"width" : 0,
"length": 0,
"object_id": 3,
"object_id": 2,
"placeable": 1,
"held": 0,
"lost": 0,
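For context on the spot_perceiver.py changes above, here is a minimal, hypothetical sketch of the bookkeeping order the revised step() follows: the executed skill is appended to its history before the state is abstracted (so the VLM prompt sees the skill that was just run), the per-step VLM labels are joined into a single string before being stored, and only then is the current state pushed onto the state history. The class and field names below are simplified stand-ins, not the real SpotPerceiver API.

from typing import Dict, List, Optional


class HistoryBookkeepingSketch:
    """Hypothetical, simplified stand-in for the perceiver's history handling."""

    def __init__(self) -> None:
        self._state_history: List[Dict] = []
        self._executed_skill_history: List[Optional[str]] = []
        self._vlm_label_history: List[str] = []

    def step(self, curr_sim_state: Dict, executed_skill: Optional[str],
             vlm_responses: List[str]) -> Dict:
        # Record the skill that was just run *before* abstracting the state,
        # so the abstraction / VLM query can condition on it. None at t=0.
        self._executed_skill_history.append(executed_skill)
        curr_sim_state["state_history"] = list(self._state_history)
        curr_sim_state["skill_history"] = list(self._executed_skill_history)
        curr_sim_state["vlm_label_history"] = list(self._vlm_label_history)
        # ... abstract the state here, querying the VLM with these histories ...
        # Store this step's VLM labels as one joined string per timestep.
        self._vlm_label_history.append("\n".join(vlm_responses))
        # Only now does the current state become part of the history.
        self._state_history.append(dict(curr_sim_state))
        return curr_sim_state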
2 changes: 1 addition & 1 deletion predicators/spot_utils/perception/spot_cameras.py
@@ -24,7 +24,7 @@
}
RGB_TO_DEPTH_CAMERAS = {
# "hand_color_image": "hand_depth_in_hand_color_frame",
# "left_fisheye_image": "left_depth_in_visual_frame",
"left_fisheye_image": "left_depth_in_visual_frame",
# "right_fisheye_image": "right_depth_in_visual_frame",
# "frontleft_fisheye_image": "frontleft_depth_in_visual_frame",
"frontright_fisheye_image": "frontright_depth_in_visual_frame",
3 changes: 3 additions & 0 deletions predicators/structs.py
@@ -491,6 +491,9 @@ def __post_init__(self) -> None:

def goal_holds(self, state: State, vlm: Optional[Any] = None) -> bool:
"""Return whether the goal of this task holds in the given state."""
if "abstract_state" in state.simulator_state:
abstract_state = state.simulator_state["abstract_state"]
return self.goal.issubset(abstract_state)
from predicators.utils import query_vlm_for_atom_vals
vlm_atoms = set(atom for atom in self.goal
if isinstance(atom.predicate, VLMPredicate))
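The structs.py hunk short-circuits the goal check whenever the perceiver has already cached an abstract state, so no fresh VLM query is needed. A hedged sketch of that logic, using simplified placeholder types rather than the real Task/State classes:

from typing import Any, Dict, Set


def goal_holds_sketch(goal: Set[Any], simulator_state: Dict[str, Any]) -> bool:
    # Reuse the cached abstract state if the perceiver already computed it.
    if "abstract_state" in simulator_state:
        return goal.issubset(simulator_state["abstract_state"])
    # Otherwise fall back to querying the VLM for the goal atoms
    # (omitted in this sketch).
    raise NotImplementedError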
5 changes: 4 additions & 1 deletion predicators/utils.py
@@ -2591,7 +2591,7 @@ def query_vlm_for_atom_vals(
if "state_history" in state.simulator_state:
prev_states = state.simulator_state["state_history"]
prev_states_imgs_history = [s.simulator_state["images"] for s in prev_states]
images_history = [curr_state_images] + prev_states_imgs_history
images_history = prev_states_imgs_history + [curr_state_images]
skill_history = []
if "skill_history" in state.simulator_state:
skill_history = state.simulator_state["skill_history"]
@@ -2606,6 +2606,7 @@
vlm_query_str, imgs = get_prompt_for_vlm_state_labelling(
CFG.vlm_test_time_atom_label_prompt_type, atom_queries_list,
label_history, images_history, [], skill_history)
import pdb; pdb.set_trace()
if vlm is None:
vlm = create_vlm_by_name(CFG.vlm_model_name) # pragma: no cover.
vlm_input_imgs = \
@@ -2652,6 +2653,8 @@ def abstract(state: State,
Duplicate arguments in predicates are allowed.
"""
if "abstract_state" in state.simulator_state:
return state.simulator_state["abstract_state"]
# Start by pulling out all VLM predicates.
vlm_preds = set(pred for pred in preds if isinstance(pred, VLMPredicate))
# Next, classify all non-VLM predicates.
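The key ordering fix in query_vlm_for_atom_vals is simply which end of the list the current images go on: the history handed to the VLM should be chronological, oldest first and current last. A small illustration with placeholder data (not the real image objects):

prev_states_imgs_history = [["img_t0"], ["img_t1"]]  # placeholder image lists
curr_state_images = ["img_t2"]

# Before the fix: the current images came first (reverse chronological).
wrong_order = [curr_state_images] + prev_states_imgs_history
# After the fix: previous images first, current images last (chronological).
images_history = prev_states_imgs_history + [curr_state_images]

assert images_history == [["img_t0"], ["img_t1"], ["img_t2"]]

The same file also makes utils.abstract() return the cached "abstract_state" from simulator_state when present, which avoids recomputing (and re-querying the VLM for) an abstract state the perceiver already built.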
