From 1e655b38eda0e344cacb94d441a9e20a777e49b5 Mon Sep 17 00:00:00 2001 From: Tom Silver Date: Tue, 18 Jul 2023 11:57:12 -0400 Subject: [PATCH] switch to torque --- predicators/envs/kitchen.py | 5 ++--- tests/envs/test_kitchen.py | 21 ++++++++++++++++++++- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/predicators/envs/kitchen.py b/predicators/envs/kitchen.py index 359891f224..088abb13f0 100644 --- a/predicators/envs/kitchen.py +++ b/predicators/envs/kitchen.py @@ -44,7 +44,7 @@ def __init__(self, use_gui: bool = True) -> None: "use_raw_action_wrappers": False, "unflatten_images": False, }, - "control_mode": "end_effector", + "control_mode": "torque", }) def _generate_train_tasks(self) -> List[EnvironmentTask]: @@ -125,8 +125,7 @@ def types(self) -> Set[Type]: @property def action_space(self) -> Box: - # end-effector control mode: 3 for xyz, 3 for rpy, 1 for gripper - assert self._gym_env.action_space.shape == (7, ) + assert self._gym_env.action_space.shape == (9, ) return self._gym_env.action_space def reset(self, train_or_test: str, task_idx: int) -> Observation: diff --git a/tests/envs/test_kitchen.py b/tests/envs/test_kitchen.py index a1b719b215..742717632c 100644 --- a/tests/envs/test_kitchen.py +++ b/tests/envs/test_kitchen.py @@ -54,7 +54,7 @@ def test_kitchen(): gripper_type, object_type = env.types assert gripper_type.name == "gripper" assert object_type.name == "obj" - assert env.action_space.shape == (7, ) + assert env.action_space.shape == (9, ) nsrts = get_gt_nsrts(env.get_name(), env.predicates, options) assert len(nsrts) == 3 env_train_tasks = env.get_train_tasks() @@ -82,6 +82,25 @@ def test_kitchen(): env.simulate(obs, env.action_space.sample()) assert "Simulate not implemented for gym envs." in str(e) + # Test action space. + obs = env.reset("test", 0) + + # Move rotate gripper. + import imageio + imgs = [] + act_arr = np.zeros(9) + # act_arr[3] = 1.0 + act = Action(act_arr) + imgs.append(env.render()[0]) + for _ in range(500): + obs = env.step(act) + # state = env.state_info_to_state(obs["state_info"]) + # print(state.pretty_str()) + imgs.append(env.render()[0]) + + imageio.mimwrite("kitchen_noop.mp4", imgs, fps=5) + import ipdb + ipdb.set_trace() # Test NSRTs. MoveTo, PushObjOnObjForward, PushObjTurnOnRight = sorted(nsrts)