Skip to content

Commit

Permalink
Fix handling holding in minimal perceiver.
Browse files Browse the repository at this point in the history
  • Loading branch information
ashay-bdai committed Sep 19, 2024
1 parent 29242a8 commit 7298b93
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 30 deletions.
59 changes: 31 additions & 28 deletions predicators/ground_truth_models/spot_env/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,31 +900,34 @@ def _move_to_ready_sweep_policy(state: State, memory: Dict,
state, memory, objects, params)


def _teleop_policy(state: State, memory: Dict, objects: Sequence[Object],
params: Array) -> Action:
del state, memory, params

robot, lease_client = get_robot_only()

def _teleop(robot: Robot, lease_client: LeaseClient):
prompt = "Press (y) when you are done with teleop."
while True:
response = utils.prompt_user(prompt).strip()
if response == "y":
break
logging.info("Invalid input. Press (y) when y")
# Take back control.
def _create_teleop_policy_with_name(name: str) -> Callable[[State, Dict, Sequence[Object], Array], Action]:
def _teleop_policy(state: State, memory: Dict, objects: Sequence[Object],
params: Array) -> Action:
nonlocal name
del state, memory, params

robot, lease_client = get_robot_only()
lease_client.take()

fn = _teleop
fn_args = (robot, lease_client)
sim_fn = lambda _: None
sim_fn_args = ()
name = "teleop"
action_extra_info = SpotActionExtraInfo(name, objects, fn, fn_args, sim_fn,
sim_fn_args)
return utils.create_spot_env_action(action_extra_info)
def _teleop(robot: Robot, lease_client: LeaseClient):
prompt = "Press (y) when you are done with teleop."
while True:
response = utils.prompt_user(prompt).strip()
if response == "y":
break
logging.info("Invalid input. Press (y) when y")
# Take back control.
robot, lease_client = get_robot_only()
lease_client.take()

fn = _teleop
fn_args = (robot, lease_client)
sim_fn = lambda _: None
sim_fn_args = ()
name = name
action_extra_info = SpotActionExtraInfo(name, objects, fn, fn_args, sim_fn,
sim_fn_args)
return utils.create_spot_env_action(action_extra_info)
return _teleop_policy


###############################################################################
Expand Down Expand Up @@ -985,11 +988,11 @@ def _teleop(robot: Robot, lease_client: LeaseClient):
"PrepareContainerForSweeping": _prepare_container_for_sweeping_policy,
"DropNotPlaceableObject": _drop_not_placeable_object_policy,
"MoveToReadySweep": _move_to_ready_sweep_policy,
"Pick1": _teleop_policy,
"PlaceNextTo": _teleop_policy,
"Pick2": _teleop_policy,
"Sweep": _teleop_policy,
"PlaceOnFloor": _teleop_policy
"Pick1": _create_teleop_policy_with_name("Pick1"),
"PlaceNextTo": _create_teleop_policy_with_name("PlaceNextTo"),
"Pick2": _create_teleop_policy_with_name("Pick2"),
"Sweep": _create_teleop_policy_with_name("Sweep"),
"PlaceOnFloor": _create_teleop_policy_with_name("PlaceOnFloor")
}


Expand Down
27 changes: 26 additions & 1 deletion predicators/perception/spot_perceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,15 +658,26 @@ def _create_goal(self, state: State,
# GroundAtom(VLMOn, [cup, pan])
}
return goal
# if goal_description == "put the mess in the dustpan":
# robot = Object("robot", _robot_type)
# dustpan = Object("dustpan", _dustpan_type)
# wrappers = Object("wrappers", _wrappers_type)
# goal = {
# GroundAtom(Inside, [wrappers, dustpan]),
# GroundAtom(Holding, [robot, dustpan])
# }
# return goal

if goal_description == "put the mess in the dustpan":
robot = Object("robot", _robot_type)
dustpan = Object("dustpan", _dustpan_type)
wrappers = Object("wrappers", _wrappers_type)
goal = {
GroundAtom(Inside, [wrappers, dustpan]),
# GroundAtom(Inside, [wrappers, dustpan]),
GroundAtom(Holding, [robot, dustpan])
}
return goal

raise NotImplementedError("Unrecognized goal description")

def update_perceiver_with_action(self, action: Action) -> None:
Expand Down Expand Up @@ -747,6 +758,19 @@ def step(self, observation: Observation) -> State:
self._gripper_open_percentage = observation.gripper_open_percentage

self._curr_state = self._create_state()
if observation.executed_skill is not None:
if "Pick" in observation.executed_skill.extra_info.action_name:
for obj in observation.executed_skill.extra_info.operator_objects:
if not obj.is_instance(_robot_type):
# Turn the held feature on
self._curr_state.set(obj, "held", 1.0)
if "Place" in observation.executed_skill.extra_info.action_name:
for obj in observation.executed_skill.extra_info.operator_objects:
if not obj.is_instance(_robot_type):
# Turn the held feature off
self._curr_state.set(obj, "held", 0.0)

# import pdb; pdb.set_trace()
# This state is a default/empty. We have to set the attributes
# of the objects and set the simulator state properly.
self._curr_state.simulator_state["images"] = annotated_imgs
Expand Down Expand Up @@ -800,6 +824,7 @@ def step(self, observation: Observation) -> State:
str_vlm_response = '\n'.join(reconstructed_all_vlm_responses)
self._vlm_label_history.append(str_vlm_response)
self._state_history.append(self._curr_state.copy())

return self._curr_state.copy()

def _create_state(self) -> State:
Expand Down
1 change: 0 additions & 1 deletion predicators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2625,7 +2625,6 @@ def query_vlm_for_atom_vals(
# ALTERNATIVE WAY TO PARSE
if len(label_history) > 0:
truth_values = re.findall(r'\* (.*): (True|False)', vlm_output_str)
import pdb; pdb.set_trace()
for i, (atom_query, pred_label) in enumerate(zip(atom_queries_list, truth_values)):
pred, label = pred_label
assert pred in atom_query
Expand Down

0 comments on commit 7298b93

Please sign in to comment.