Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
tsilver-bdai committed Jul 18, 2023
1 parent 7680cbb commit 5aa41a2
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 31 deletions.
15 changes: 3 additions & 12 deletions predicators/envs/kitchen.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ def __init__(self, use_gui: bool = True) -> None:
https://github.com/Learning-and-Intelligent-Systems/mujoco_kitchen"

# Predicates
self._At, self._OnTop, self._TurnedOn, self._CanTurnDial = \
self.get_goal_at_predicates()
self._At, self._OnTop, self._TurnedOn = self.get_goal_at_predicates()

# NOTE: we can change the level by modifying what we pass
# into gym.make here.
Expand Down Expand Up @@ -110,11 +109,11 @@ def render(self,

@property
def predicates(self) -> Set[Predicate]:
return {self._At, self._TurnedOn, self._OnTop, self._CanTurnDial}
return {self._At, self._TurnedOn, self._OnTop}

@property
def goal_predicates(self) -> Set[Predicate]:
return {self._At, self._TurnedOn, self._OnTop, self._CanTurnDial}
return {self._At, self._TurnedOn, self._OnTop}

@property
def types(self) -> Set[Type]:
Expand Down Expand Up @@ -247,12 +246,6 @@ def _On_holds(cls, state: State, objects: Sequence[Object]) -> bool:
return state.get(obj, "angle") < -0.8
return False

@classmethod
def _CanTurnDial_holds(cls, state: State,
objects: Sequence[Object]) -> bool:
gripper = objects[0]
return state.get(gripper, "y") < 0.2 and state.get(gripper, "z") < 2.2

def _copy_observation(self, obs: Observation) -> Observation:
return copy.deepcopy(obs)

Expand All @@ -264,6 +257,4 @@ def get_goal_at_predicates(self: Any) -> Sequence[Predicate]:
Predicate("OnTop", [self.object_type, self.object_type],
self._OnTop_holds),
Predicate("TurnedOn", [self.object_type], self._On_holds),
Predicate("CanTurnDial", [self.gripper_type],
self._CanTurnDial_holds)
]
4 changes: 1 addition & 3 deletions predicators/ground_truth_models/kitchen/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def get_operators(env_name: str, types: Dict[str, Type],
At = predicates["At"]
TurnedOn = predicates["TurnedOn"]
OnTop = predicates["OnTop"]
CanTurnDial = predicates["CanTurnDial"]

operators = set()

Expand All @@ -44,7 +43,7 @@ def get_operators(env_name: str, types: Dict[str, Type],
parameters = [gripper, obj, obj2]
preconditions = {LiftedAtom(At, [gripper, obj])}
add_effects = {LiftedAtom(OnTop, [obj, obj2])}
delete_effects = {LiftedAtom(CanTurnDial, [gripper])}
delete_effects = set()
push_obj_on_obj_forward_operator = STRIPSOperator(
"PushObjOnObjForward", parameters, preconditions, add_effects,
delete_effects, {OnTop})
Expand All @@ -54,7 +53,6 @@ def get_operators(env_name: str, types: Dict[str, Type],
parameters = [gripper, obj]
preconditions = {
LiftedAtom(At, [gripper, obj]),
LiftedAtom(CanTurnDial, [gripper])
}
add_effects = {LiftedAtom(TurnedOn, [obj])}
delete_effects = set()
Expand Down
64 changes: 48 additions & 16 deletions tests/envs/test_kitchen.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,12 @@ def test_kitchen():
task = perceiver.reset(env_task)
for obj in task.init:
assert len(obj.type.feature_names) == len(task.init[obj])
assert len(env.predicates) == 4
At, CanTurnDial, OnTop, TurnedOn = sorted(env.predicates)
assert len(env.predicates) == 3
At, OnTop, TurnedOn = sorted(env.predicates)
assert At.name == "At"
assert CanTurnDial.name == "CanTurnDial"
assert OnTop.name == "OnTop"
assert TurnedOn.name == "TurnedOn"
assert env.goal_predicates == {At, CanTurnDial, OnTop, TurnedOn}
assert env.goal_predicates == {At, OnTop, TurnedOn}
options = get_gt_options(env.get_name())
assert len(options) == 3
moveto_option, pushobjonobjforward_option, pushobjturnonright_option = \
Expand Down Expand Up @@ -124,27 +123,60 @@ def test_kitchen():
burner2 = obj_name_to_obj["burner2"]

# Test moving to and pushing knob3.
# TODO

# Test moving to and pushing the kettle on top of burner2.
move_to_kettle_nsrt = MoveTo.ground([gripper, kettle])
for atom in move_to_kettle_nsrt.preconditions:
move_to_knob3_nsrt = MoveTo.ground([gripper, knob3])
for atom in move_to_knob3_nsrt.preconditions:
assert atom.holds(init_state)
# This sampler should always succeed.
move_to_kettle_option = move_to_kettle_nsrt.sample_option(
move_to_knob3_option = move_to_knob3_nsrt.sample_option(
init_state, set(), rng)
assert move_to_kettle_option.initiable(init_state)
assert move_to_knob3_option.initiable(init_state)
obs = env.reset("test", 0)
state = env.state_info_to_state(obs["state_info"])
assert state.allclose(init_state)
for _ in range(100):
act = move_to_kettle_option.policy(state)
act = move_to_knob3_option.policy(state)
obs = env.step(act)
state = env.state_info_to_state(obs["state_info"])
if move_to_knob3_option.terminal(state):
break
for atom in move_to_knob3_nsrt.add_effects:
assert atom.holds(state)
for atom in move_to_knob3_nsrt.delete_effects:
assert not atom.holds(state)

push_knob3_nsrt = PushObjTurnOnRight.ground([gripper, knob3])
for atom in push_knob3_nsrt.preconditions:
assert atom.holds(state)
push_knob3_option = push_knob3_nsrt.sample_option(state, set(), rng)
assert push_knob3_option.initiable(state)
for _ in range(100):
act = push_knob3_option.policy(state)
obs = env.step(act)
state = env.state_info_to_state(obs["state_info"])
if move_to_kettle_option.terminal(state):
if push_knob3_option.terminal(state):
break
for atom in move_to_kettle_nsrt.add_effects:
for atom in push_knob3_nsrt.add_effects:
assert atom.holds(state)
for atom in move_to_kettle_nsrt.delete_effects:
for atom in push_knob3_nsrt.delete_effects:
assert not atom.holds(state)

# # Test moving to and pushing the kettle on top of burner2 from init state.
# move_to_kettle_nsrt = MoveTo.ground([gripper, kettle])
# for atom in move_to_kettle_nsrt.preconditions:
# assert atom.holds(init_state)
# move_to_kettle_option = move_to_kettle_nsrt.sample_option(
# init_state, set(), rng)
# assert move_to_kettle_option.initiable(init_state)
# obs = env.reset("test", 0)
# state = env.state_info_to_state(obs["state_info"])
# assert state.allclose(init_state)
# for _ in range(100):
# act = move_to_kettle_option.policy(state)
# obs = env.step(act)
# state = env.state_info_to_state(obs["state_info"])
# if move_to_kettle_option.terminal(state):
# break
# for atom in move_to_kettle_nsrt.add_effects:
# assert atom.holds(state)
# for atom in move_to_kettle_nsrt.delete_effects:
# assert not atom.holds(state)
# TODO add pushing the kettle on tpo of burner2.

0 comments on commit 5aa41a2

Please sign in to comment.