Skip to content

Commit

Permalink
Some progress towards fixing vlm atom history.
Browse files Browse the repository at this point in the history
  • Loading branch information
ashay-bdai committed Sep 17, 2024
1 parent fe7da30 commit 4cc0336
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 6 deletions.
48 changes: 42 additions & 6 deletions predicators/perception/spot_perceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
get_allowed_map_regions, load_spot_metadata, object_to_top_down_geom
from predicators.structs import Action, DefaultState, EnvironmentTask, \
GoalDescription, GroundAtom, Object, Observation, Predicate, \
SpotActionExtraInfo, State, Task, Video, _Option
SpotActionExtraInfo, State, Task, Video, _Option, VLMPredicate


class SpotPerceiver(BasePerceiver):
Expand Down Expand Up @@ -624,13 +624,16 @@ def __init__(self) -> None:
self._ordered_objects: List[Object] = [] # list of all known objects
self._state_history: List[State] = []
self._executed_skill_history: List[_Option] = []
self._vlm_label_history: List[str] = []
self._curr_state = None
# # Keep track of objects that are contained (out of view) in another
# # object, like a bag or bucket. This is important not only for gremlins
# # but also for small changes in the container's perceived pose.
# self._container_to_contained_objects: Dict[Object, Set[Object]] = {}
# Load static, hard-coded features of objects, like their shapes.
# meta = load_spot_metadata()
# self._static_object_features = meta.get("static-object-features", {})


def update_perceiver_with_action(self, action: Action) -> None:
# NOTE: we need to keep track of the previous action
Expand Down Expand Up @@ -679,7 +682,8 @@ def reset(self, env_task: EnvironmentTask) -> Task:
state = self._create_state()
state.simulator_state = {}
state.simulator_state["images"] = []
self._curr_state = state
# self._curr_state = state
self._curr_state = None # this will get set by self.step()
goal = self._create_goal(state, env_task.goal_description)
return Task(state, goal)

Expand Down Expand Up @@ -713,27 +717,59 @@ def step(self, observation: Observation) -> State:
text_bbox = [(x0, y0 - 1.5*text_height), (x0 + text_width + 1, y0)]
draw.rectangle(text_bbox, fill='green')
draw.text((x0 + 1, y0 - 1.5*text_height), text, fill='white', font=font)

# import pdb; pdb.set_trace()

import PIL
from PIL import ImageDraw
annotated_pil_imgs = []
for img, img_name in zip(imgs, img_names):
pil_img = PIL.Image.fromarray(img)
draw = ImageDraw.Draw(pil_img)
font = utils.get_scaled_default_font(draw, 4)
annotated_pil_img = utils.add_text_to_draw_img(
draw, (0, 0), self.camera_name_to_annotation[img_name], font)
annotated_pil_img = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[img_name], font)
annotated_pil_imgs.append(pil_img)
annotated_imgs = [np.array(img) for img in annotated_pil_imgs]

self._gripper_open_percentage = observation.gripper_open_percentage

curr_state = self._create_state

self._curr_state = self._create_state()
self._curr_state.simulator_state["images"] = annotated_imgs
ret_state = self._curr_state.copy()
self._state_history.append(ret_state)
ret_state.simulator_state["state_history"] = list(self._state_history)
self._executed_skill_history.append(observation.executed_skill)
ret_state.simulator_state["skill_history"] = list(self._executed_skill_history)

# Save "all_vlm_responses" towards building vlm atom history.
# Any time utils.abstract() is called, e.g. approach or planner,
# we may (depending on flags) want to pass in the vlm atom history
# into the prompt to the VLM.
# We could save `all_vlm_responses` computed internally by
# utils.query_vlm_for_atom_vals(), but that would require us to
# change how utils.abstract() works. Instead, we'll re-compute the
# `all_vlm_responses` based on the true atoms returned by utils.abstract().
assert self._curr_env is not None
preds = self._curr_env.predicates
state_copy = ret_state.copy() # temporary, to ease debugging
abstract_state = utils.abstract(state_copy, preds)
# Store the abstract state in the state itself so it is never recomputed:
# re-querying the VLM is expensive and, because VLM outputs are noisy, a
# recomputation could yield a different abstract state.
ret_state.simulator_state["abstract_state"] = abstract_state
# Re-compute the VLM labeling for the VLM atoms in this state to store in our
# vlm atom history.
# This code also appears in utils.abstract()
if self._curr_state is not None:
vlm_preds = set(pred for pred in preds if isinstance(pred, VLMPredicate))
vlm_atoms = set()
for pred in vlm_preds:
for choice in utils.get_object_combinations(list(state_copy), pred.types):
vlm_atoms.add(GroundAtom(pred, choice))
vlm_atoms = sorted(vlm_atoms)
import pdb; pdb.set_trace()
ret_state.simulator_state["vlm_atoms_history"].append(abstract_state)
else:
self._curr_state = ret_state.copy()
return ret_state

def _create_state(self) -> State:
Expand Down
5 changes: 5 additions & 0 deletions predicators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2618,6 +2618,11 @@ def query_vlm_for_atom_vals(
breakpoint()
# Add the text of the VLM's response to the state, to be used in the future!
# REMOVE THIS -> AND PUT IT IN THE PERCEIVER
# Plan: the perceiver calls utils.abstract() once per step and appends the
# result to the state history. When the corresponding flag is set, every
# other place that would normally call utils.abstract() instead pulls the
# cached abstract state from the state's simulator_state field.
# The VLM atom history is currently appended here in
# query_vlm_for_atom_vals() (utils.py), which utils.ground() calls.
state.simulator_state["vlm_atoms_history"].append(all_vlm_responses)

return true_atoms
Expand Down

0 comments on commit 4cc0336

Please sign in to comment.