diff --git a/predicators/envs/kitchen.py b/predicators/envs/kitchen.py index 1a6d5d8845..057e807ee9 100644 --- a/predicators/envs/kitchen.py +++ b/predicators/envs/kitchen.py @@ -33,8 +33,7 @@ def __init__(self, use_gui: bool = True) -> None: https://github.com/Learning-and-Intelligent-Systems/mujoco_kitchen" # Predicates - self._At, self._OnTop, self._TurnedOn, self._CanTurnDial = \ - self.get_goal_at_predicates() + self._At, self._OnTop, self._TurnedOn = self.get_goal_at_predicates() # NOTE: we can change the level by modifying what we pass # into gym.make here. @@ -110,11 +109,11 @@ def render(self, @property def predicates(self) -> Set[Predicate]: - return {self._At, self._TurnedOn, self._OnTop, self._CanTurnDial} + return {self._At, self._TurnedOn, self._OnTop} @property def goal_predicates(self) -> Set[Predicate]: - return {self._At, self._TurnedOn, self._OnTop, self._CanTurnDial} + return {self._At, self._TurnedOn, self._OnTop} @property def types(self) -> Set[Type]: @@ -247,12 +246,6 @@ def _On_holds(cls, state: State, objects: Sequence[Object]) -> bool: return state.get(obj, "angle") < -0.8 return False - @classmethod - def _CanTurnDial_holds(cls, state: State, - objects: Sequence[Object]) -> bool: - gripper = objects[0] - return state.get(gripper, "y") < 0.2 and state.get(gripper, "z") < 2.2 - def _copy_observation(self, obs: Observation) -> Observation: return copy.deepcopy(obs) @@ -264,6 +257,4 @@ def get_goal_at_predicates(self: Any) -> Sequence[Predicate]: Predicate("OnTop", [self.object_type, self.object_type], self._OnTop_holds), Predicate("TurnedOn", [self.object_type], self._On_holds), - Predicate("CanTurnDial", [self.gripper_type], - self._CanTurnDial_holds) ] diff --git a/predicators/ground_truth_models/kitchen/operators.py b/predicators/ground_truth_models/kitchen/operators.py index 2605bb37e6..649d393f61 100644 --- a/predicators/ground_truth_models/kitchen/operators.py +++ b/predicators/ground_truth_models/kitchen/operators.py @@ -27,7 +27,6 @@ def get_operators(env_name: str, types: Dict[str, Type], At = predicates["At"] TurnedOn = predicates["TurnedOn"] OnTop = predicates["OnTop"] - CanTurnDial = predicates["CanTurnDial"] operators = set() @@ -44,7 +43,7 @@ def get_operators(env_name: str, types: Dict[str, Type], parameters = [gripper, obj, obj2] preconditions = {LiftedAtom(At, [gripper, obj])} add_effects = {LiftedAtom(OnTop, [obj, obj2])} - delete_effects = {LiftedAtom(CanTurnDial, [gripper])} + delete_effects = set() push_obj_on_obj_forward_operator = STRIPSOperator( "PushObjOnObjForward", parameters, preconditions, add_effects, delete_effects, {OnTop}) @@ -54,7 +53,6 @@ def get_operators(env_name: str, types: Dict[str, Type], parameters = [gripper, obj] preconditions = { LiftedAtom(At, [gripper, obj]), - LiftedAtom(CanTurnDial, [gripper]) } add_effects = {LiftedAtom(TurnedOn, [obj])} delete_effects = set() diff --git a/tests/envs/test_kitchen.py b/tests/envs/test_kitchen.py index 926a445657..3fe4d2caed 100644 --- a/tests/envs/test_kitchen.py +++ b/tests/envs/test_kitchen.py @@ -36,13 +36,12 @@ def test_kitchen(): task = perceiver.reset(env_task) for obj in task.init: assert len(obj.type.feature_names) == len(task.init[obj]) - assert len(env.predicates) == 4 - At, CanTurnDial, OnTop, TurnedOn = sorted(env.predicates) + assert len(env.predicates) == 3 + At, OnTop, TurnedOn = sorted(env.predicates) assert At.name == "At" - assert CanTurnDial.name == "CanTurnDial" assert OnTop.name == "OnTop" assert TurnedOn.name == "TurnedOn" - assert env.goal_predicates == {At, CanTurnDial, OnTop, TurnedOn} + assert env.goal_predicates == {At, OnTop, TurnedOn} options = get_gt_options(env.get_name()) assert len(options) == 3 moveto_option, pushobjonobjforward_option, pushobjturnonright_option = \ @@ -124,27 +123,60 @@ def test_kitchen(): burner2 = obj_name_to_obj["burner2"] # Test moving to and pushing knob3. - # TODO - - # Test moving to and pushing the kettle on top of burner2. - move_to_kettle_nsrt = MoveTo.ground([gripper, kettle]) - for atom in move_to_kettle_nsrt.preconditions: + move_to_knob3_nsrt = MoveTo.ground([gripper, knob3]) + for atom in move_to_knob3_nsrt.preconditions: assert atom.holds(init_state) - # This sampler should always succeed. - move_to_kettle_option = move_to_kettle_nsrt.sample_option( + move_to_knob3_option = move_to_knob3_nsrt.sample_option( init_state, set(), rng) - assert move_to_kettle_option.initiable(init_state) + assert move_to_knob3_option.initiable(init_state) obs = env.reset("test", 0) state = env.state_info_to_state(obs["state_info"]) assert state.allclose(init_state) for _ in range(100): - act = move_to_kettle_option.policy(state) + act = move_to_knob3_option.policy(state) + obs = env.step(act) + state = env.state_info_to_state(obs["state_info"]) + if move_to_knob3_option.terminal(state): + break + for atom in move_to_knob3_nsrt.add_effects: + assert atom.holds(state) + for atom in move_to_knob3_nsrt.delete_effects: + assert not atom.holds(state) + + push_knob3_nsrt = PushObjTurnOnRight.ground([gripper, knob3]) + for atom in push_knob3_nsrt.preconditions: + assert atom.holds(state) + push_knob3_option = push_knob3_nsrt.sample_option(state, set(), rng) + assert push_knob3_option.initiable(state) + for _ in range(100): + act = push_knob3_option.policy(state) obs = env.step(act) state = env.state_info_to_state(obs["state_info"]) - if move_to_kettle_option.terminal(state): + if push_knob3_option.terminal(state): break - for atom in move_to_kettle_nsrt.add_effects: + for atom in push_knob3_nsrt.add_effects: assert atom.holds(state) - for atom in move_to_kettle_nsrt.delete_effects: + for atom in push_knob3_nsrt.delete_effects: assert not atom.holds(state) + + # # Test moving to and pushing the kettle on top of burner2 from init state. + # move_to_kettle_nsrt = MoveTo.ground([gripper, kettle]) + # for atom in move_to_kettle_nsrt.preconditions: + # assert atom.holds(init_state) + # move_to_kettle_option = move_to_kettle_nsrt.sample_option( + # init_state, set(), rng) + # assert move_to_kettle_option.initiable(init_state) + # obs = env.reset("test", 0) + # state = env.state_info_to_state(obs["state_info"]) + # assert state.allclose(init_state) + # for _ in range(100): + # act = move_to_kettle_option.policy(state) + # obs = env.step(act) + # state = env.state_info_to_state(obs["state_info"]) + # if move_to_kettle_option.terminal(state): + # break + # for atom in move_to_kettle_nsrt.add_effects: + # assert atom.holds(state) + # for atom in move_to_kettle_nsrt.delete_effects: + # assert not atom.holds(state) # TODO add pushing the kettle on tpo of burner2.