From 94d614e0cbe1d3f78d85c0b2ccecb05aadb99eb9 Mon Sep 17 00:00:00 2001 From: Tom Silver Date: Tue, 18 Jul 2023 15:34:05 -0400 Subject: [PATCH] wip --- scripts/kitchen_precompute_prm.py | 54 +++++++++++++++++++------------ 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/scripts/kitchen_precompute_prm.py b/scripts/kitchen_precompute_prm.py index 8442630198..8d2004e524 100644 --- a/scripts/kitchen_precompute_prm.py +++ b/scripts/kitchen_precompute_prm.py @@ -93,7 +93,7 @@ def _main() -> None: # Sample random trajectories in 7 DOF space. noise_scale = 0.1 - max_steps_per_traj = 100 + max_steps_per_traj = 500 num_trajs = 10 noise = OrnsteinUhlenbeckActionNoise(np.zeros(7), sigma=noise_scale) @@ -116,31 +116,45 @@ def _main() -> None: all_poses.append(pose) # TODO quit if something changes... - _reset_gym_env(gym_env) - init_pose = _get_pose_from_env(gym_env) - _add_pose_to_graph(init_pose, graph, distance_thresh) - reachable_nodes = set(nx.shortest_path(graph, init_pose)) + obs = env.reset("train", 0) + state = env.state_info_to_state(obs["state_info"]) + objs = set(state) + total_dist = 0.0 - target_xyz = np.array([5.0, 0.0, 2.0]) - target = min(reachable_nodes, key=lambda p: np.linalg.norm(p.xyz - target_xyz)) - print("init:", init_pose) - print("target_xyz:", target_xyz) - print("target:", target) + for obj in objs: + if "gripper" in str(obj): + continue + print("TRYING TO REACH", obj) + obs = env.reset("train", 0) + state = env.state_info_to_state(obs["state_info"]) + init_pose = _get_pose_from_env(env._gym_env) + x, y, z = state.get(obj, "x"), state.get(obj, "y"), state.get(obj, "z") + target_xyz = np.array([x, y, z]) + graph_copy = graph.copy() - path = nx.shortest_path(graph, init_pose, target, weight="weight") + _add_pose_to_graph(init_pose, graph_copy, distance_thresh) + reachable_nodes = set(nx.shortest_path(graph_copy, init_pose)) - for pose in path: - main_act = pose.joints - act = np.concatenate([main_act, gym_env.sim.data.qpos[7:9]]) - gym_env.step(act) - gym_env.render() + target = min(reachable_nodes, key=lambda p: np.linalg.norm(p.xyz - target_xyz)) + print("Closest node in graph distance:", np.linalg.norm(target.xyz - target_xyz)) + + path = nx.shortest_path(graph_copy, init_pose, target, weight="weight") + + for pose in path: + main_act = pose.joints + act = np.concatenate([main_act, gym_env.sim.data.qpos[7:9]]) + gym_env.step(act) + gym_env.render() - final_pose = _get_pose_from_env(gym_env) - print("Distance to target xyz:", np.linalg.norm(final_pose.xyz - target_xyz)) + final_pose = _get_pose_from_env(gym_env) + dist = np.linalg.norm(final_pose.xyz - target_xyz) + print("Distance to target xyz:", dist) + total_dist += dist + + print("TOTAL DIST:", total_dist) # TODO NEXT - # visualize important target locations in the environment, make sure they - # are reachable. + # check distance to important objects in the environment # incorporate collisions # incorporate orientation of end effector # sparsify graph