Update and fix computation of history and how it gets passed into the VLM query.
ashay-bdai committed Sep 17, 2024
1 parent 41b58ea commit 71fe6d3
Showing 3 changed files with 46 additions and 74 deletions.
1 change: 0 additions & 1 deletion predicators/main.py
@@ -362,7 +362,6 @@ def _run_testing(env: BaseEnv, cogman: CogMan) -> Metrics:
     metrics: Metrics = defaultdict(float)
     curr_num_nodes_created = 0.0
     curr_num_nodes_expanded = 0.0
-    import pdb; pdb.set_trace()
     for test_task_idx, env_task in enumerate(test_tasks):
         solve_start = time.perf_counter()
         try:
68 changes: 13 additions & 55 deletions predicators/perception/spot_perceiver.py
@@ -730,12 +730,8 @@ def step(self, observation: Observation) -> State:
         # annotated_pil_img = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[img_name], font)
         # annotated_pil_imgs.append(pil_img)
         annotated_imgs = [np.array(img) for img in pil_imgs]

         self._gripper_open_percentage = observation.gripper_open_percentage
-
-        # check if self._curr_state is what we expect it to be.
-        import pdb; pdb.set_trace()
-
         self._curr_state = self._create_state()
         # This state is a default/empty. We have to set the attributes
         # of the objects and set the simulator state properly.
@@ -756,7 +752,7 @@ def step(self, observation: Observation) -> State:
         # in planning.
         assert self._curr_env is not None
         preds = self._curr_env.predicates
-        state_copy = self._curr_env.copy()
+        state_copy = self._curr_state.copy()
         abstract_state = utils.abstract(state_copy, preds)
         self._curr_state.simulator_state["abstract_state"] = abstract_state
         # Compute all the VLM atoms. `utils.abstract()` only returns the ones that
@@ -767,57 +763,19 @@ def step(self, observation: Observation) -> State:
             for choice in utils.get_object_combinations(list(state_copy), pred.types):
                 vlm_atoms.add(GroundAtom(pred, choice))
         vlm_atoms = sorted(vlm_atoms)
-        import pdb; pdb.set_trace()
-
         atom_queries_list = [atom.get_vlm_query_str() for atom in vlm_atoms]
+        reconstructed_all_vlm_responses = []
+        for atom in vlm_atoms:
+            if atom in abstract_state:
+                truth_value = 'True'
+            else:
+                truth_value = 'False'
+            atom_label = f"* {atom.get_vlm_query_str()}: {truth_value}"
+            reconstructed_all_vlm_responses.append(atom_label)
+        self._vlm_label_history.append(reconstructed_all_vlm_responses)
+        self._state_history.append(self._curr_state.copy())
-        # The executed skill will be `None` in the first timestep.
-        # This should be handled in the function that processes the
-        # history when passing it to the VLM.
-        self._executed_skill_history.append(observation.executed_skill)
-
-        #############################
-
-        curr_state = self._create_state
-        self._curr_state = self._create_state()
-        self._curr_state.simulator_state["images"] = annotated_imgs
-        ret_state = self._curr_state.copy()
-        self._state_history.append(ret_state)
-        ret_state.simulator_state["state_history"] = list(self._state_history)
-        self._executed_skill_history.append(observation.executed_skill)
-        ret_state.simulator_state["skill_history"] = list(self._executed_skill_history)
-
-        # Save "all_vlm_responses" towards building vlm atom history.
-        # Any time utils.abstract() is called, e.g. approach or planner,
-        # we may (depending on flags) want to pass in the vlm atom history
-        # into the prompt to the VLM.
-        # We could save `all_vlm_responses` computed internally by
-        # utils.query_vlm_for_atom_vals(), but that would require us to
-        # change how utils.abstract() works. Instead, we'll re-compute the
-        # `all_vlm_responses` based on the true atoms returned by utils.abstract().
-        assert self._curr_env is not None
-        preds = self._curr_env.predicates
-        state_copy = ret_state.copy()  # temporary, to ease debugging
-        abstract_state = utils.abstract(state_copy, preds)
-        # We should avoid recomputing the abstract state (VLM noise?) so let's
-        # store it in the state.
-        ret_state.simulator_state["abstract_state"] = abstract_state
-        # Re-compute the VLM labeling for the VLM atoms in this state to store
-        # in our vlm atom history. This code also appears in utils.abstract().
-        if self._curr_state is not None:
-            vlm_preds = set(pred for pred in preds if isinstance(pred, VLMPredicate))
-            vlm_atoms = set()
-            for pred in vlm_preds:
-                for choice in utils.get_object_combinations(list(state_copy), pred.types):
-                    vlm_atoms.add(GroundAtom(pred, choice))
-            vlm_atoms = sorted(vlm_atoms)
-            import pdb; pdb.set_trace()
-            ret_state.simulator_state["vlm_atoms_history"].append(abstract_state)
-        else:
-            self._curr_state = ret_state.copy()
-            return ret_state
+        self._executed_skill_history.append(observation.executed_skill)  # None in first timestep.
+        return self._curr_state.copy()

     def _create_state(self) -> State:
         if self._waiting_for_observation:
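
Aside: the bookkeeping pattern introduced in `step()` above is easy to state in isolation. Below is a minimal, self-contained sketch of it — re-deriving the `* <query>: <True/False>` label lines from the abstract state returned by `utils.abstract()`, then appending the three parallel histories. `FakeAtom` and `PerceiverHistory` are hypothetical stand-ins for illustration, not names from the repo:

```python
from dataclasses import dataclass
from typing import List, Optional, Set


@dataclass(frozen=True, order=True)
class FakeAtom:
    """Stand-in for GroundAtom; only what the sketch needs."""
    query: str

    def get_vlm_query_str(self) -> str:
        return self.query


class PerceiverHistory:
    """Hypothetical helper mirroring the bookkeeping in step()."""

    def __init__(self) -> None:
        self.vlm_label_history: List[List[str]] = []
        self.state_history: List[dict] = []
        self.executed_skill_history: List[Optional[str]] = []

    def record_step(self, state: dict, abstract_state: Set[FakeAtom],
                    vlm_atoms: List[FakeAtom],
                    executed_skill: Optional[str]) -> None:
        # Rebuild the "* <query>: <True/False>" label lines from the abstract
        # state instead of plumbing raw VLM responses out of utils.abstract().
        labels = [
            f"* {atom.get_vlm_query_str()}: "
            f"{'True' if atom in abstract_state else 'False'}"
            for atom in sorted(vlm_atoms)
        ]
        self.vlm_label_history.append(labels)
        self.state_history.append(dict(state))
        # None on the first timestep; the prompt builder must tolerate that.
        self.executed_skill_history.append(executed_skill)


if __name__ == "__main__":
    h = PerceiverHistory()
    atoms = [FakeAtom("Inside(cup, box)"), FakeAtom("On(cup, table)")]
    h.record_step({"t": 0}, {atoms[1]}, atoms, executed_skill=None)
    print(h.vlm_label_history[-1])
    # ['* Inside(cup, box): False', '* On(cup, table): True']
```

The design choice, per the deleted comments in this hunk, is to recompute the label strings from the true atoms rather than change `utils.abstract()` to also return the raw VLM responses.
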
51 changes: 33 additions & 18 deletions predicators/utils.py
@@ -2499,7 +2499,6 @@ def get_prompt_for_vlm_state_labelling(
         imgs_history: List[List[PIL.Image.Image]],
         cropped_imgs_history: List[List[PIL.Image.Image]],
         skill_history: List[Action]) -> Tuple[str, List[PIL.Image.Image]]:
-    # import pdb; pdb.set_trace()
     """Prompt for generating labels for an entire trajectory. Similar to the
     above prompting method, this outputs a list of prompts to label the state
     at each timestep of the trajectory with atom values.
Expand All @@ -2517,7 +2516,6 @@ def get_prompt_for_vlm_state_labelling(
                   encoding="utf-8") as f:
             prompt = f.read()
     except FileNotFoundError:
-        import pdb; pdb.set_trace()
         raise ValueError("Unknown VLM prompting option " +
                          f"{CFG.grammar_search_vlm_atom_label_prompt_type}")
     # The prompt ends with a section for 'Predicates', so list these.
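
For context, the surrounding prompt-assembly logic (mostly unchanged by this commit) amounts to the following sketch. `load_labelling_prompt` and the template-directory layout are assumptions for illustration, not the repo's actual API:

```python
from pathlib import Path
from typing import List


def load_labelling_prompt(prompt_dir: str, prompt_type: str,
                          atom_queries: List[str]) -> str:
    """Hypothetical sketch: read a template keyed by the prompt-type flag,
    then append a 'Predicates' section listing one VLM query per atom."""
    try:
        template = (Path(prompt_dir) / f"{prompt_type}.txt").read_text(
            encoding="utf-8")
    except FileNotFoundError:
        # Unknown prompt-type flags surface as a ValueError, as in the diff.
        raise ValueError(f"Unknown VLM prompting option {prompt_type}")
    return template + "\n" + "\n".join(atom_queries)
```
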
@@ -2574,22 +2572,40 @@ def query_vlm_for_atom_vals(
     # vlm can be called on.
     assert state.simulator_state is not None
     assert isinstance(state.simulator_state["images"], List)
-    if "vlm_atoms_history" not in state.simulator_state:
-        state.simulator_state["vlm_atoms_history"] = []
-    imgs = state.simulator_state["images"]
-    previous_states = []
-    # We assume the state.simulator_state contains a list of previous states.
-    if "state_history" in state.simulator_state:
-        previous_states = state.simulator_state["state_history"]
-    state_imgs_history = [
-        state.simulator_state["images"] for state in previous_states
-    ]
+    # if "vlm_atoms_history" not in state.simulator_state:
+    #     state.simulator_state["vlm_atoms_history"] = []
+    # imgs = state.simulator_state["images"]
+    # previous_states = []
+    # # We assume the state.simulator_state contains a list of previous states.
+    # if "state_history" in state.simulator_state:
+    #     previous_states = state.simulator_state["state_history"]
+    # state_imgs_history = [
+    #     state.simulator_state["images"] for state in previous_states
+    # ]
     vlm_atoms = sorted(vlm_atoms)
     atom_queries_list = [atom.get_vlm_query_str() for atom in vlm_atoms]
+    # All "history" fields in the simulator state contain things from
+    # previous states -- not the current state. We want the image history
+    # to include the images from the current state as well.
+    curr_state_images = state.simulator_state["images"]
+    prev_states = state.simulator_state.get("state_history", [])
+    prev_states_imgs_history = [
+        s.simulator_state["images"] for s in prev_states]
+    images_history = [curr_state_images] + prev_states_imgs_history
+    skill_history = []
+    if "skill_history" in state.simulator_state:
+        skill_history = state.simulator_state["skill_history"]
+    label_history = []
+    if "vlm_label_history" in state.simulator_state:
+        label_history = state.simulator_state["vlm_label_history"]
+
+    # vlm_query_str, imgs = get_prompt_for_vlm_state_labelling(
+    #     CFG.vlm_test_time_atom_label_prompt_type, atom_queries_list,
+    #     state.simulator_state["vlm_atoms_history"], state_imgs_history, [],
+    #     state.simulator_state["skill_history"])
     vlm_query_str, imgs = get_prompt_for_vlm_state_labelling(
         CFG.vlm_test_time_atom_label_prompt_type, atom_queries_list,
-        state.simulator_state["vlm_atoms_history"], state_imgs_history, [],
-        state.simulator_state["skill_history"])
+        label_history, images_history, [], skill_history)
     if vlm is None:
         vlm = create_vlm_by_name(CFG.vlm_model_name)  # pragma: no cover.
     vlm_input_imgs = \
@@ -2603,7 +2619,6 @@ def query_vlm_for_atom_vals(
     vlm_output_str = vlm_output[0]
     print(f"VLM output: {vlm_output_str}")
     all_vlm_responses = vlm_output_str.strip().split("\n")
-    # import pdb; pdb.set_trace()
     # NOTE: this assumption is likely too brittle; if this is breaking, feel
     # free to remove/adjust this and change the below parsing loop accordingly!
     assert len(atom_queries_list) == len(all_vlm_responses)
@@ -2612,8 +2627,9 @@ def query_vlm_for_atom_vals(
         assert atom_query + ":" in curr_vlm_output_line
         assert "." in curr_vlm_output_line
         period_idx = curr_vlm_output_line.find(".")
-        if curr_vlm_output_line[len(atom_query +
-                                    ":"):period_idx].lower().strip() == "true":
+        # value = curr_vlm_output_line[len(atom_query + ":"):period_idx].lower().strip()
+        value = curr_vlm_output_line.split(': ')[-1].strip('.').lower()
+        if value == "true":
             true_atoms.add(vlm_atoms[i])

     # breakpoint()
@@ -2625,7 +2641,6 @@ def query_vlm_for_atom_vals(
     # The appending of vlm atom history is currently done in
     # query_vlm_for_atom_vals() in utils.py, and utils.ground calls that.
     # state.simulator_state["vlm_atoms_history"].append(all_vlm_responses)
-
     return true_atoms


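
The parsing change in the last hunk — taking the text after the final `': '` instead of slicing up to the first period — can be exercised in isolation. A minimal sketch (hypothetical helper and made-up responses, not code from the repo):

```python
from typing import List, Set


def parse_true_atom_indices(atom_queries: List[str],
                            vlm_output: str) -> Set[int]:
    """Sketch of the new parsing rule: one '* <query>: <True/False>.'
    response line per queried atom, in the same order."""
    responses = vlm_output.strip().split("\n")
    # Brittle by design, mirroring the NOTE in the diff: the VLM must
    # return exactly one line per query, in query order.
    assert len(responses) == len(atom_queries)
    true_indices: Set[int] = set()
    for i, (query, line) in enumerate(zip(atom_queries, responses)):
        assert query + ":" in line
        # Old approach: slice from len(query + ":") up to the first "." in
        # the line. That misfires when the line carries a "* " bullet prefix,
        # since the slice offsets no longer line up with the value text.
        # New approach: take the text after the final ": ", drop the
        # trailing period, and lowercase.
        value = line.split(': ')[-1].strip('.').lower()
        if value == "true":
            true_indices.add(i)
    return true_indices


output = "* On(cup, table): True.\n* Inside(cup, box): False."
print(parse_true_atom_indices(["On(cup, table)", "Inside(cup, box)"], output))
# {0}
```
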
