Skip to content

Commit

Permalink
switch to torque
Browse files Browse the repository at this point in the history
  • Loading branch information
tsilver-bdai committed Jul 18, 2023
1 parent e5ae880 commit 1e655b3
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
5 changes: 2 additions & 3 deletions predicators/envs/kitchen.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, use_gui: bool = True) -> None:
"use_raw_action_wrappers": False,
"unflatten_images": False,
},
"control_mode": "end_effector",
"control_mode": "torque",
})

def _generate_train_tasks(self) -> List[EnvironmentTask]:
Expand Down Expand Up @@ -125,8 +125,7 @@ def types(self) -> Set[Type]:

@property
def action_space(self) -> Box:
# end-effector control mode: 3 for xyz, 3 for rpy, 1 for gripper
assert self._gym_env.action_space.shape == (7, )
assert self._gym_env.action_space.shape == (9, )
return self._gym_env.action_space

def reset(self, train_or_test: str, task_idx: int) -> Observation:
Expand Down
21 changes: 20 additions & 1 deletion tests/envs/test_kitchen.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_kitchen():
gripper_type, object_type = env.types
assert gripper_type.name == "gripper"
assert object_type.name == "obj"
assert env.action_space.shape == (7, )
assert env.action_space.shape == (9, )
nsrts = get_gt_nsrts(env.get_name(), env.predicates, options)
assert len(nsrts) == 3
env_train_tasks = env.get_train_tasks()
Expand Down Expand Up @@ -82,6 +82,25 @@ def test_kitchen():
env.simulate(obs, env.action_space.sample())
assert "Simulate not implemented for gym envs." in str(e)

# Test action space.
obs = env.reset("test", 0)

# Move rotate gripper.
import imageio
imgs = []
act_arr = np.zeros(9)
# act_arr[3] = 1.0
act = Action(act_arr)
imgs.append(env.render()[0])
for _ in range(500):
obs = env.step(act)
# state = env.state_info_to_state(obs["state_info"])
# print(state.pretty_str())
imgs.append(env.render()[0])

imageio.mimwrite("kitchen_noop.mp4", imgs, fps=5)
import ipdb
ipdb.set_trace()

# Test NSRTs.
MoveTo, PushObjOnObjForward, PushObjTurnOnRight = sorted(nsrts)
Expand Down

0 comments on commit 1e655b3

Please sign in to comment.