Skip to content

Commit

Permalink
Play-test other VLM predicates and update annotation in spot images.
Browse files Browse the repository at this point in the history
  • Loading branch information
ashay-bdai committed Sep 12, 2024
1 parent 0b8d125 commit d255810
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 42 deletions.
2 changes: 2 additions & 0 deletions predicators/approaches/bilevel_planning_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]:
seed = self._seed + self._num_calls
nsrts = self._get_current_nsrts()
preds = self._get_current_predicates()
utils.abstract(task.init, preds, self._vlm)
import pdb; pdb.set_trace()
# utils.abstract(task.init, preds, self._vlm)

# Run task planning only and then greedily sample and execute in the
Expand Down
30 changes: 28 additions & 2 deletions predicators/envs/spot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -1488,13 +1488,38 @@ def _get_vlm_query_str(pred_name: str, objects: Sequence[Object]) -> str:
[_movable_object_type, _base_object_type],
lambda o: _get_vlm_query_str("VLMOn", o)
)
_Upright = utils.create_vlm_predicate(
"Upright",
[_movable_object_type],
lambda o: _get_vlm_query_str("Upright", o)
)
_Toasted = utils.create_vlm_predicate(
"Toasted",
[_movable_object_type],
lambda o: _get_vlm_query_str("Toasted", o)
)
_VLMIn = utils.create_vlm_predicate(
"VLMIn",
[_movable_object_type, _immovable_object_type],
lambda o: _get_vlm_query_str("In", o)
)
_Open = utils.create_vlm_predicate(
"Open",
[_movable_object_type],
lambda o: _get_vlm_query_str("Open", o)
)
_Stained = utils.create_vlm_predicate(
"Stained",
[_movable_object_type],
lambda o: _get_vlm_query_str("Stained", o)
)

_ALL_PREDICATES = {
_NEq, _On, _TopAbove, _Inside, _NotInsideAnyContainer, _FitsInXY,
_HandEmpty, _Holding, _NotHolding, _InHandView, _InView, _Reachable,
_Blocking, _NotBlocked, _ContainerReadyForSweeping, _IsPlaceable,
_IsNotPlaceable, _IsSweeper, _HasFlatTopSurface, _RobotReadyForSweeping,
_IsSemanticallyGreaterThan, _VLMOn
_IsSemanticallyGreaterThan, _VLMOn, _Upright, _Toasted, _VLMIn, _Open, _Stained
}
_NONPERCEPT_PREDICATES: Set[Predicate] = set()

Expand Down Expand Up @@ -2426,7 +2451,8 @@ class VLMTestEnv(SpotRearrangementEnv):

@property
def predicates(self) -> Set[Predicate]:
return set(p for p in _ALL_PREDICATES if p.name in ["VLMOn", "Holding", "HandEmpty"])
# return set(p for p in _ALL_PREDICATES if p.name in ["VLMOn", "Holding", "HandEmpty", "Pourable", "Toasted", "VLMIn", "Open"])
return set(p for p in _ALL_PREDICATES if p.name in ["VLMOn", "Holding", "HandEmpty", "Upright"])

@property
def goal_predicates(self) -> Set[Predicate]:
Expand Down
179 changes: 139 additions & 40 deletions predicators/perception/spot_perceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,15 @@ def render_mental_images(self, observation: Observation,
class SpotMinimalPerceiver(BasePerceiver):
"""A perceiver for spot envs with minimal functionality."""

camera_name_to_annotation = {
'hand_color_image': "Hand Camera Image",
'back_fisheye_image': "Back Camera Image",
'frontleft_fisheye_image': "Front Left Camera Image",
'frontright_fisheye_image': "Front Right Camera Image",
'left_fisheye_image': "Left Camera Image",
'right_fisheye_image': "Right Camera Image"
}

def render_mental_images(self, observation: Observation,
env_task: EnvironmentTask) -> Video:
raise NotImplementedError()
Expand Down Expand Up @@ -672,12 +681,23 @@ def step(self, observation: Observation) -> State:
self._waiting_for_observation = False
self._robot = observation.robot
imgs = observation.rgbd_images
img_names = [v.camera_name for _, v in imgs.items()]
imgs = [v.rgb for _, v in imgs.items()]
# import PIL
# PIL.Image.fromarray(imgs[0]).show()
import pdb; pdb.set_trace()
import PIL
from PIL import ImageDraw
annotated_pil_imgs = []
for img, img_name in zip(imgs, img_names):
pil_img = PIL.Image.fromarray(img)
draw = ImageDraw.Draw(pil_img)
font = utils.get_scaled_default_font(draw, 4)
annotated_pil_img = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[img_name], font)
annotated_pil_imgs.append(pil_img)
annotated_imgs = [np.array(img) for img in annotated_pil_imgs]
import pdb; pdb.set_trace()
self._gripper_open_percentage = observation.gripper_open_percentage
self._curr_state = self._create_state()
self._curr_state.simulator_state["images"] = imgs
self._curr_state.simulator_state["images"] = annotated_imgs
ret_state = self._curr_state.copy()
return ret_state

Expand All @@ -686,9 +706,13 @@ def _create_state(self) -> State:
return DefaultState
# Build the continuous part of the state.
assert self._robot is not None
table = Object("table", _immovable_object_type)
# table = Object("table", _immovable_object_type)
cup = Object("cup", _movable_object_type)
pan = Object("pan", _container_type)
# pan = Object("pan", _container_type)
# bread = Object("bread", _movable_object_type)
# toaster = Object("toaster", _immovable_object_type)
# microwave = Object("microwave", _movable_object_type)
# napkin = Object("napkin", _movable_object_type)
state_dict = {
self._robot: {
"gripper_open_percentage": self._gripper_open_percentage,
Expand All @@ -700,21 +724,21 @@ def _create_state(self) -> State:
"qy": 0,
"qz": 0,
},
table: {
"x": 0,
"y": 0,
"z": 0,
"qw": 0,
"qx": 0,
"qy": 0,
"qz": 0,
"shape": 0,
"height": 0,
"width" : 0,
"length": 0,
"object_id": 1,
"flat_top_surface": 1
},
# table: {
# "x": 0,
# "y": 0,
# "z": 0,
# "qw": 0,
# "qx": 0,
# "qy": 0,
# "qz": 0,
# "shape": 0,
# "height": 0,
# "width" : 0,
# "length": 0,
# "object_id": 1,
# "flat_top_surface": 1
# },
cup: {
"x": 0,
"y": 0,
Expand All @@ -735,26 +759,101 @@ def _create_state(self) -> State:
"in_view": 1,
"is_sweeper": 0
},
pan: {
"x": 0,
"y": 0,
"z": 0,
"qw": 0,
"qx": 0,
"qy": 0,
"qz": 0,
"shape": 0,
"height": 0,
"width" : 0,
"length": 0,
"object_id": 3,
"placeable": 1,
"held": 0,
"lost": 0,
"in_hand_view": 0,
"in_view": 1,
"is_sweeper": 0
}
# napkin: {
# "x": 0,
# "y": 0,
# "z": 0,
# "qw": 0,
# "qx": 0,
# "qy": 0,
# "qz": 0,
# "shape": 0,
# "height": 0,
# "width" : 0,
# "length": 0,
# "object_id": 2,
# "placeable": 1,
# "held": 0,
# "lost": 0,
# "in_hand_view": 0,
# "in_view": 1,
# "is_sweeper": 0
# },
# microwave: {
# "x": 0,
# "y": 0,
# "z": 0,
# "qw": 0,
# "qx": 0,
# "qy": 0,
# "qz": 0,
# "shape": 0,
# "height": 0,
# "width" : 0,
# "length": 0,
# "object_id": 2,
# "placeable": 1,
# "held": 0,
# "lost": 0,
# "in_hand_view": 0,
# "in_view": 1,
# "is_sweeper": 0
# },
# bread: {
# "x": 0,
# "y": 0,
# "z": 0,
# "qw": 0,
# "qx": 0,
# "qy": 0,
# "qz": 0,
# "shape": 0,
# "height": 0,
# "width" : 0,
# "length": 0,
# "object_id": 2,
# "placeable": 1,
# "held": 0,
# "lost": 0,
# "in_hand_view": 0,
# "in_view": 1,
# "is_sweeper": 0
# },
# toaster: {
# "x": 0,
# "y": 0,
# "z": 0,
# "qw": 0,
# "qx": 0,
# "qy": 0,
# "qz": 0,
# "shape": 0,
# "height": 0,
# "width" : 0,
# "length": 0,
# "object_id": 1,
# "flat_top_surface": 1
# },
# pan: {
# "x": 0,
# "y": 0,
# "z": 0,
# "qw": 0,
# "qx": 0,
# "qy": 0,
# "qz": 0,
# "shape": 0,
# "height": 0,
# "width" : 0,
# "length": 0,
# "object_id": 3,
# "placeable": 1,
# "held": 0,
# "lost": 0,
# "in_hand_view": 0,
# "in_view": 1,
# "is_sweeper": 0
# }
}
state_dict = {k: list(v.values()) for k, v in state_dict.items()}
ret_state = State(state_dict)
Expand Down

0 comments on commit d255810

Please sign in to comment.