Skip to content

Commit

Permalink
Expose op name to sampler dict
Browse files Browse the repository at this point in the history
  • Loading branch information
tkusnur-bdai committed May 2, 2024
1 parent b4e7808 commit 62e3970
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions predicators/ground_truth_models/spot_env/nsrts.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,20 @@ def _prepare_sweeping_sampler(state: State, goal: Set[GroundAtom],
return np.array([-0.8, -0.4, home_pose.angle])


_OPERATOR_NAME_TO_SAMPLER: Dict[str, NSRTSampler] = {
"MoveToHandViewObject": _move_to_hand_view_object_sampler,
"MoveToBodyViewObject": _move_to_body_view_object_sampler,
"MoveToReachObject": _move_to_reach_object_sampler,
"PickObjectFromTop": _pick_object_from_top_sampler,
"PlaceObjectOnTop": _place_object_on_top_sampler,
"DropObjectInside": _drop_object_inside_sampler,
"DropObjectInsideContainerOnTop": _drop_object_inside_sampler,
"DragToUnblockObject": _drag_to_unblock_object_sampler,
"SweepIntoContainer": _sweep_into_container_sampler,
"PrepareContainerForSweeping": _prepare_sweeping_sampler,
}


class SpotCubeEnvGroundTruthNSRTFactory(GroundTruthNSRTFactory):
"""Ground-truth NSRTs for the Spot Env."""

Expand All @@ -192,21 +206,8 @@ def get_nsrts(env_name: str, types: Dict[str, Type],

nsrts = set()

operator_name_to_sampler: Dict[str, NSRTSampler] = {
"MoveToHandViewObject": _move_to_hand_view_object_sampler,
"MoveToBodyViewObject": _move_to_body_view_object_sampler,
"MoveToReachObject": _move_to_reach_object_sampler,
"PickObjectFromTop": _pick_object_from_top_sampler,
"PlaceObjectOnTop": _place_object_on_top_sampler,
"DropObjectInside": _drop_object_inside_sampler,
"DropObjectInsideContainerOnTop": _drop_object_inside_sampler,
"DragToUnblockObject": _drag_to_unblock_object_sampler,
"SweepIntoContainer": _sweep_into_container_sampler,
"PrepareContainerForSweeping": _prepare_sweeping_sampler,
}

for strips_op in env.strips_operators:
sampler = operator_name_to_sampler[strips_op.name]
sampler = _OPERATOR_NAME_TO_SAMPLER[strips_op.name]
option = options[strips_op.name]
nsrt = strips_op.make_nsrt(
option=option,
Expand Down

0 comments on commit 62e3970

Please sign in to comment.