Progress towards passing history into VLM at test time.
ashay-bdai committed Sep 17, 2024
1 parent 4cc0336 commit 41b58ea
Showing 5 changed files with 122 additions and 68 deletions.
3 changes: 2 additions & 1 deletion predicators/envs/spot_env.py
@@ -2451,7 +2451,7 @@ class VLMTestEnv(SpotRearrangementEnv):
def predicates(self) -> Set[Predicate]:
# return set(p for p in _ALL_PREDICATES if p.name in ["VLMOn", "Holding", "HandEmpty", "Pourable", "Toasted", "VLMIn", "Open"])
return set(p for p in _ALL_PREDICATES
if p.name in ["VLMOn", "Holding", "HandEmpty", "Upright"])
if p.name in ["VLMOn", "Holding", "HandEmpty"])

@property
def goal_predicates(self) -> Set[Predicate]:
@@ -2474,6 +2474,7 @@ def _detection_id_to_obj(self) -> Dict[ObjectDetectionID, Object]:
Object("cup", _movable_object_type),
Object("chair", _movable_object_type),
Object("bowl", _movable_object_type),
Object("table", _movable_object_type),
}
for o in objects:
detection_id = LanguageObjectDetectionID(o.name)
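
Note on the hunk above: each listed object is paired with a language-based detection query. A self-contained sketch of the mapping the loop builds, assuming `LanguageObjectDetectionID` simply wraps the object's natural-language name; the dataclass stand-ins below are hypothetical, not the repo's real classes:

from dataclasses import dataclass

@dataclass(frozen=True)
class LanguageObjectDetectionID:
    language_id: str  # natural-language query for the open-vocabulary detector

@dataclass(frozen=True)
class Object:
    name: str
    type_name: str

# Mirrors the objects listed in this hunk, including the newly added table.
objects = {
    Object("cup", "movable"),
    Object("chair", "movable"),
    Object("bowl", "movable"),
    Object("table", "movable"),
}
detection_id_to_obj = {LanguageObjectDetectionID(o.name): o for o in objects}
assert len(detection_id_to_obj) == 4
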
1 change: 1 addition & 0 deletions predicators/main.py
@@ -362,6 +362,7 @@ def _run_testing(env: BaseEnv, cogman: CogMan) -> Metrics:
metrics: Metrics = defaultdict(float)
curr_num_nodes_created = 0.0
curr_num_nodes_expanded = 0.0
import pdb; pdb.set_trace()
for test_task_idx, env_task in enumerate(test_tasks):
solve_start = time.perf_counter()
try:
176 changes: 113 additions & 63 deletions predicators/perception/spot_perceiver.py
@@ -635,13 +635,6 @@ def __init__(self) -> None:
# self._static_object_features = meta.get("static-object-features", {})


def update_perceiver_with_action(self, action: Action) -> None:
# NOTE: we need to keep track of the previous action
# because the step function (where we need knowledge
# of the previous action) occurs *after* the action
# has already been taken.
self._prev_action = action

def _create_goal(self, state: State,
goal_description: GoalDescription) -> Set[GroundAtom]:
del state # not used
@@ -680,11 +673,20 @@ def reset(self, env_task: EnvironmentTask) -> Task:
# self._curr_state = state
self._curr_env = get_or_create_env(CFG.env)
state = self._create_state()
state.simulator_state = {}
state.simulator_state["images"] = []
# self._curr_state = state
self._curr_state = None # this will get set by self.step()
# state.simulator_state = {}
# state.simulator_state["images"] = []
# state.simulator_state["state_history"] = []
# state.simulator_state["skill_history"] = []
# state.simulator_state["vlm_atoms_history"] = []
self._curr_state = state
goal = self._create_goal(state, env_task.goal_description)

# Reset run-specific things.
self._state_history = []
self._executed_skill_history = []
self._vlm_label_history = []
self._prev_action = None

return Task(state, goal)

def step(self, observation: Observation) -> State:
@@ -718,21 +720,66 @@ def step(self, observation: Observation) -> State:
draw.rectangle(text_bbox, fill='green')
draw.text((x0 + 1, y0 - 1.5*text_height), text, fill='white', font=font)

import PIL
from PIL import ImageDraw
annotated_pil_imgs = []
for img, img_name in zip(imgs, img_names):
pil_img = PIL.Image.fromarray(img)
draw = ImageDraw.Draw(pil_img)
font = utils.get_scaled_default_font(draw, 4)
annotated_pil_img = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[img_name], font)
annotated_pil_imgs.append(pil_img)
annotated_imgs = [np.array(img) for img in annotated_pil_imgs]
# import PIL
# from PIL import ImageDraw
# annotated_pil_imgs = []
# for img, img_name in zip(imgs, img_names):
# pil_img = PIL.Image.fromarray(img)
# draw = ImageDraw.Draw(pil_img)
# font = utils.get_scaled_default_font(draw, 4)
# annotated_pil_img = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[img_name], font)
# annotated_pil_imgs.append(pil_img)
annotated_imgs = [np.array(img) for img in pil_imgs]

self._gripper_open_percentage = observation.gripper_open_percentage

curr_state = self._create_state()
# check if self._curr_state is what we expect it to be.
import pdb; pdb.set_trace()

self._curr_state = self._create_state()
# This state is default/empty; we still have to set the objects'
# attributes and populate the simulator state properly.
self._curr_state.simulator_state["images"] = annotated_imgs
# At the first timestep these histories are empty, since self.reset()
# just cleared them; at every later timestep they are non-empty.
self._curr_state.simulator_state["state_history"] = list(self._state_history)
self._curr_state.simulator_state["skill_history"] = list(self._executed_skill_history)
self._curr_state.simulator_state["vlm_label_history"] = list(self._vlm_label_history)

# Add to histories.
# A bit of extra work is required to build the VLM label history.
# We want to keep `utils.abstract()` as straightforward as possible,
# so we'll "rebuild" the VLM labels from the abstract state
# returned by `utils.abstract()`. And since we call this function,
# we might as well store the abstract state as a part of the simulator
# state so that we don't need to recompute it later in the approach or
# in planning.
assert self._curr_env is not None
preds = self._curr_env.predicates
state_copy = self._curr_state.copy()
abstract_state = utils.abstract(state_copy, preds)
self._curr_state.simulator_state["abstract_state"] = abstract_state
# Compute all the VLM atoms; `utils.abstract()` returns only the ones
# that are True, so the remaining ones are False.
vlm_preds = set(pred for pred in preds if isinstance(pred, VLMPredicate))
vlm_atoms = set()
for pred in vlm_preds:
for choice in utils.get_object_combinations(list(state_copy), pred.types):
vlm_atoms.add(GroundAtom(pred, choice))
vlm_atoms = sorted(vlm_atoms)
import pdb; pdb.set_trace()

self._state_history.append(self._curr_state.copy())
# The executed skill will be `None` in the first timestep.
# This should be handled in the function that processes the
# history when passing it to the VLM.
self._executed_skill_history.append(observation.executed_skill)

#############################



curr_state = self._create_state()
self._curr_state = self._create_state()
self._curr_state.simulator_state["images"] = annotated_imgs
ret_state = self._curr_state.copy()
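
Note on the hunk above: `step()` now snapshots three parallel histories (states, executed skills, VLM labels) into `simulator_state` each timestep. A minimal sketch of how a downstream prompt builder might replay those histories for the VLM; the function name and prompt format below are assumptions, not part of this commit:

from typing import List, Optional

def build_vlm_history_prompt(
        skill_history: List[Optional[str]],
        vlm_label_history: List[str]) -> str:
    """Hypothetical consumer of simulator_state's histories: flatten the
    per-timestep skill/label pairs into one prompt prefix for the VLM."""
    assert len(skill_history) == len(vlm_label_history)
    lines = []
    for t, (skill, labels) in enumerate(zip(skill_history, vlm_label_history)):
        # The executed skill is None at the first timestep (see the comment
        # in step() above), so render a placeholder for it.
        skill_str = skill if skill is not None else "<no skill executed>"
        lines.append(f"[t={t}] skill: {skill_str}; atom labels: {labels}")
    lines.append("Now label the queried atoms for the current timestep.")
    return "\n".join(lines)
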
@@ -777,9 +824,9 @@ def _create_state(self) -> State:
return DefaultState
# Build the continuous part of the state.
assert self._robot is not None
# table = Object("table", _immovable_object_type)
table = Object("table", _immovable_object_type)
cup = Object("cup", _movable_object_type)
# pan = Object("pan", _container_type)
pan = Object("pan", _container_type)
# bread = Object("bread", _movable_object_type)
# toaster = Object("toaster", _immovable_object_type)
# microwave = Object("microwave", _movable_object_type)
@@ -795,21 +842,21 @@
"qy": 0,
"qz": 0,
},
# table: {
# "x": 0,
# "y": 0,
# "z": 0,
# "qw": 0,
# "qx": 0,
# "qy": 0,
# "qz": 0,
# "shape": 0,
# "height": 0,
# "width" : 0,
# "length": 0,
# "object_id": 1,
# "flat_top_surface": 1
# },
table: {
"x": 0,
"y": 0,
"z": 0,
"qw": 0,
"qx": 0,
"qy": 0,
"qz": 0,
"shape": 0,
"height": 0,
"width" : 0,
"length": 0,
"object_id": 1,
"flat_top_surface": 1
},
cup: {
"x": 0,
"y": 0,
@@ -905,29 +952,32 @@ def _create_state(self) -> State:
# "object_id": 1,
# "flat_top_surface": 1
# },
# pan: {
# "x": 0,
# "y": 0,
# "z": 0,
# "qw": 0,
# "qx": 0,
# "qy": 0,
# "qz": 0,
# "shape": 0,
# "height": 0,
# "width" : 0,
# "length": 0,
# "object_id": 3,
# "placeable": 1,
# "held": 0,
# "lost": 0,
# "in_hand_view": 0,
# "in_view": 1,
# "is_sweeper": 0
# }
pan: {
"x": 0,
"y": 0,
"z": 0,
"qw": 0,
"qx": 0,
"qy": 0,
"qz": 0,
"shape": 0,
"height": 0,
"width" : 0,
"length": 0,
"object_id": 3,
"placeable": 1,
"held": 0,
"lost": 0,
"in_hand_view": 0,
"in_view": 1,
"is_sweeper": 0
}
}
state_dict = {k: list(v.values()) for k, v in state_dict.items()}
ret_state = State(state_dict)
ret_state.simulator_state = {}
ret_state.simulator_state["images"] = []
return ret_state
state = State(state_dict)
state.simulator_state = {}
state.simulator_state["images"] = []
state.simulator_state["state_history"] = []
state.simulator_state["skill_history"] = []
state.simulator_state["vlm_atoms_history"] = []
return state
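
Note on `_create_state`: it now seeds an empty per-run `simulator_state`. A plain-dict sketch of the schema this commit converges on (key names copied from the diff; `vlm_atoms_history` here vs. `vlm_label_history` in `step()` looks like an inconsistency worth unifying):

# Plain-dict sketch of the simulator_state payload; the real object lives
# on a State instance, this is just the shape of its contents.
simulator_state = {
    "images": [],             # annotated camera images for the current step
    "state_history": [],      # copies of past State objects
    "skill_history": [],      # executed skill per past timestep (None at t=0)
    "vlm_atoms_history": [],  # raw VLM responses, one entry per past timestep
}
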
4 changes: 2 additions & 2 deletions predicators/spot_utils/perception/spot_cameras.py
@@ -26,8 +26,8 @@
# "hand_color_image": "hand_depth_in_hand_color_frame",
# "left_fisheye_image": "left_depth_in_visual_frame",
# "right_fisheye_image": "right_depth_in_visual_frame",
"frontleft_fisheye_image": "frontleft_depth_in_visual_frame",
# "frontright_fisheye_image": "frontright_depth_in_visual_frame",
# "frontleft_fisheye_image": "frontleft_depth_in_visual_frame",
"frontright_fisheye_image": "frontright_depth_in_visual_frame",
# "back_fisheye_image": "back_depth_in_visual_frame"
}

6 changes: 4 additions & 2 deletions predicators/utils.py
@@ -2603,6 +2603,7 @@ def query_vlm_for_atom_vals(
vlm_output_str = vlm_output[0]
print(f"VLM output: {vlm_output_str}")
all_vlm_responses = vlm_output_str.strip().split("\n")
# import pdb; pdb.set_trace()
# NOTE: this assumption is likely too brittle; if this is breaking, feel
# free to remove/adjust this and change the below parsing loop accordingly!
assert len(atom_queries_list) == len(all_vlm_responses)
@@ -2615,15 +2616,15 @@
":"):period_idx].lower().strip() == "true":
true_atoms.add(vlm_atoms[i])

breakpoint()
# breakpoint()
# Add the text of the VLM's response to the state, to be used in the future!
# REMOVE THIS -> AND PUT IT IN THE PERCEIVER.
# The perceiver calls utils.abstract() once and puts the result in the
# state history. Behind a flag, anywhere else we would normally call
# utils.abstract(), we instead pull the abstract state from the
# simulator_state field that already stores it.
# Appending to the VLM atom history is currently done in
# query_vlm_for_atom_vals() in utils.py, which utils.ground calls.
state.simulator_state["vlm_atoms_history"].append(all_vlm_responses)
# state.simulator_state["vlm_atoms_history"].append(all_vlm_responses)

return true_atoms
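
Note on the parsing loop above: the NOTE in the diff concedes that the one-response-per-line assumption is brittle. A hedged sketch of a more defensive parse, assuming each response line has the shape `Predicate(obj1, obj2): True.`; the helper name and exact format are assumptions, not the repo's API:

from typing import List

def parse_vlm_truth_values(vlm_output_str: str,
                           atom_queries: List[str]) -> List[bool]:
    """Hypothetical replacement for the strict length assert: match each
    expected atom query by name rather than by line position."""
    truth_values = [False] * len(atom_queries)
    for line in vlm_output_str.strip().splitlines():
        if ":" not in line:
            continue  # tolerate blank lines or extra VLM chatter
        query, _, answer = line.partition(":")
        answer = answer.strip().rstrip(".").lower()
        for i, atom_query in enumerate(atom_queries):
            if query.strip() == atom_query.strip():
                truth_values[i] = (answer == "true")
    return truth_values
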

@@ -2652,6 +2653,7 @@ def abstract(state: State,
for pred in vlm_preds:
for choice in get_object_combinations(list(state), pred.types):
vlm_atoms.add(GroundAtom(pred, choice))
# import pdb; pdb.set_trace()
true_vlm_atoms = query_vlm_for_atom_vals(vlm_atoms, state, vlm)
atoms |= true_vlm_atoms
return atoms
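
Note on the caching design described in the comments above: the perceiver calls `utils.abstract()` once per step and stores the result, and everything else reads the cache behind a flag. A minimal sketch of that flag-guarded lookup, assuming a hypothetical `use_cached_abstract_state` config flag (not an existing option):

def abstract_or_cached(state: State, preds: Set[Predicate]) -> Set[GroundAtom]:
    """Hypothetical wrapper around abstract(): reuse the abstract state the
    perceiver stashed in simulator_state instead of re-querying the VLM."""
    if getattr(CFG, "use_cached_abstract_state", False):
        cached = (state.simulator_state or {}).get("abstract_state")
        if cached is not None:
            return cached
    return abstract(state, preds)  # fall back to the normal VLM query path
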
