Skip to content

Commit

Permalink
Cartpole (#296)
Browse files Browse the repository at this point in the history
* work

* add cartpole robot urdf

* work

* Update cartpole.py

* align reward functions, change camera pose

* align to dm control time horizons, example ppo solving scripts

* fix bug with record episode and reward computations

* work

* remove old assets

* Update index.md

---------

Co-authored-by: chenbao <im.b.c@live.com>
  • Loading branch information
StoneT2000 and Kami-code authored Apr 29, 2024
1 parent 4cb96d1 commit 1ae8d65
Show file tree
Hide file tree
Showing 10 changed files with 202 additions and 61 deletions.
40 changes: 27 additions & 13 deletions docs/source/tasks/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -280,28 +280,42 @@ Using the TriFingerPro robot, rotate a cube

## Control Tasks

### MS-CartPole-v1

### MS-CartpoleBalance-v1

:::{dropdown} Task Card
:icon: note
:color: primary

**Task Description:**
Keep the CartPole stable and up right by sliding it left and right
Use the Cartpole robot to balance a pole on a cart.

**Supported Robots: None**

**Randomizations:**
- TODO
**Supported Robots: Cartpole**

**Success Conditions:**
- the cart is within 0.25m of the center of the rail (which is at 0)
- the cosine of the hinge angle attaching the pole is between 0.995 and 1
**Randomizations:**
- Pole direction is randomized around the vertical axis. the range is [-0.05, 0.05] radians.

**Goal Specification:**
- None
**Fail Conditions:**
- Pole is lower than the horizontal plane

<video preload="auto" controls="True" width="100%">
<source src="https://github.com/haosulab/ManiSkill/raw/main/figures/environment_demos/MS-CartPole-v1_rt.mp4" type="video/mp4">
</video>
<source src="https://github.com/haosulab/ManiSkill/raw/main/figures/environment_demos/MS-CartpoleBalance-v1_rt.mp4" type="video/mp4">
</video>

### MS-CartpoleSwingup-v1

:::{dropdown} Task Card
:icon: note
:color: primary

**Task Description:**
Use the Cartpole robot to swing up a pole on a cart.


**Supported Robots: Cartpole**

**Randomizations:**
- Pole direction is randomized around the whole circle. the range is [-pi, pi] radians.

**Success Conditions:**
- No specific success conditions. The task is considered successful if the pole is upright for the whole episode. We can threshold the episode accumulated reward to determine success.
10 changes: 8 additions & 2 deletions examples/baselines/ppo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,15 @@ python ppo.py --env_id="RotateCubeLevel4-v1" \
--num_envs=1024 --update_epochs=8 --num_minibatches=32 \
--total_timesteps=500_000_000 --num-steps=250 --num-eval-steps=250

python ppo.py --env_id="MS-CartPole-v1" \
python ppo.py --env_id="MS-CartpoleBalance-v1" \
--num_envs=1024 --update_epochs=8 --num_minibatches=32 \
--total_timesteps=10_000_000 --num-steps=500 --num-eval-steps=500 \
--total_timesteps=4_000_000 --num-steps=250 --num-eval-steps=1000 \
--gamma=0.99 --gae_lambda=0.95 \
--eval_freq=5

python ppo.py --env_id="MS-CartpoleSwingUp-v1" \
--num_envs=1024 --update_epochs=8 --num_minibatches=32 \
--total_timesteps=10_000_000 --num-steps=250 --num-eval-steps=1000 \
--gamma=0.99 --gae_lambda=0.95 \
--eval_freq=5

Expand Down
Binary file removed figures/environment_demos/MS-CartPole-v1.mp4
Binary file not shown.
Binary file not shown.
22 changes: 14 additions & 8 deletions mani_skill/envs/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,17 +485,23 @@ def get_sim_state(self) -> torch.Tensor:
state_dict["articulations"][
articulation.name
] = articulation.get_state().clone()
if len(state_dict["actors"]) == 0:
del state_dict["actors"]
if len(state_dict["articulations"]) == 0:
del state_dict["articulations"]
return state_dict

def set_sim_state(self, state: Dict):
for actor_id, actor_state in state["actors"].items():
if len(actor_state.shape) == 1:
actor_state = actor_state[None, :]
self.actors[actor_id].set_state(actor_state)
for art_id, art_state in state["articulations"].items():
if len(art_state.shape) == 1:
art_state = art_state[None, :]
self.articulations[art_id].set_state(art_state)
if "actors" in state:
for actor_id, actor_state in state["actors"].items():
if len(actor_state.shape) == 1:
actor_state = actor_state[None, :]
self.actors[actor_id].set_state(actor_state)
if "articulations" in state:
for art_id, art_state in state["articulations"].items():
if len(art_state.shape) == 1:
art_state = art_state[None, :]
self.articulations[art_id].set_state(art_state)

# ---------------------------------------------------------------------------- #
# GPU Simulation Management
Expand Down
2 changes: 1 addition & 1 deletion mani_skill/envs/tasks/control/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .cartpole import CartPoleEnv
from .cartpole import CartpoleBalanceEnv, CartpoleSwingUpEnv
122 changes: 88 additions & 34 deletions mani_skill/envs/tasks/control/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,23 @@
from mani_skill.agents.base_agent import BaseAgent
from mani_skill.agents.controllers import *
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.envs.utils import randomization
from mani_skill.envs.utils import randomization, rewards
from mani_skill.sensors.camera import CameraConfig
from mani_skill.utils import common, sapien_utils
from mani_skill.utils.registration import register_env
from mani_skill.utils.structs.types import SceneConfig, SimConfig
from mani_skill.utils.structs.pose import Pose
from mani_skill.utils.structs.types import (
Array,
GPUMemoryConfig,
SceneConfig,
SimConfig,
)

MJCF_FILE = f"{os.path.join(os.path.dirname(__file__), 'assets/cartpole.xml')}"


class CartPoleRobot(BaseAgent):
uid = "cartpole"
uid = "cart_pole"
mjcf_path = MJCF_FILE

@property
Expand Down Expand Up @@ -57,15 +63,23 @@ def _load_articulation(self):
self.robot_link_ids = [link.name for link in self.robot.get_links()]


@register_env("MS-CartPole-v1", max_episode_steps=500)
class CartPoleEnv(BaseEnv):
SUPPORTED_REWARD_MODES = ["sparse", "none"]
# @register_env("MS-CartPole-v1", max_episode_steps=500)
# class CartPoleEnv(BaseEnv):
# SUPPORTED_REWARD_MODES = ["sparse", "none"]

SUPPORTED_ROBOTS = [CartPoleRobot]
agent: Union[CartPoleRobot]
# SUPPORTED_ROBOTS = [CartPoleRobot]
# agent: Union[CartPoleRobot]

# CART_RANGE = [-0.25, 0.25]
# ANGLE_COSINE_RANGE = [0.995, 1]

# def __init__(self, *args, robot_uids=CartPoleRobot, **kwargs):
# super().__init__(*args, robot_uids=robot_uids, **kwargs)

CART_RANGE = [-0.25, 0.25]
ANGLE_COSINE_RANGE = [0.995, 1]

class CartpoleEnv(BaseEnv):

agent: Union[CartPoleRobot]

def __init__(self, *args, robot_uids=CartPoleRobot, **kwargs):
super().__init__(*args, robot_uids=robot_uids, **kwargs)
Expand Down Expand Up @@ -94,6 +108,60 @@ def _load_scene(self, options: dict):
for a in actor_builders:
a.build(a.name)

def evaluate(self):
return dict()

def _get_obs_extra(self, info: Dict):
obs = dict(
velocity=self.agent.robot.links_map["pole_1"].linear_velocity,
angular_velocity=self.agent.robot.links_map["pole_1"].angular_velocity,
)
return obs

@property
def pole_angle_cosine(self):
return torch.cos(self.agent.robot.joints_map["hinge_1"].qpos)

def compute_dense_reward(self, obs: Any, action: Array, info: Dict):
cart_pos = self.agent.robot.links_map["cart"].pose.p[
:, 0
] # (B, ), we only care about x position
centered = rewards.tolerance(cart_pos, margin=2)
centered = (1 + centered) / 2 # (B, )

small_control = rewards.tolerance(
action, margin=1, value_at_margin=0, sigmoid="quadratic"
)[:, 0]
small_control = (4 + small_control) / 5

angular_vel = self.agent.robot.get_qvel()[:, 1]
small_velocity = rewards.tolerance(angular_vel, margin=5)
small_velocity = (1 + small_velocity) / 2 # (B, )

upright = (self.pole_angle_cosine + 1) / 2 # (B, )

# upright is 1 when the pole is upright, 0 when the pole is upside down
# small_control is 1 when the action is small, 0.8 when the action is large
# small_velocity is 1 when the angular velocity is small, 0.5 when the angular velocity is large
# centered is 1 when the cart is centered, 0 when the cart is at the edge of the screen

reward = upright * centered * small_control * small_velocity
return reward

def compute_normalized_dense_reward(self, obs: Any, action: Array, info: Dict):
# this should be equal to compute_dense_reward / max possible reward
max_reward = 1.0
return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward


@register_env("MS-CartpoleBalance-v1", max_episode_steps=1000)
class CartpoleBalanceEnv(CartpoleEnv):
def __init__(self, *args, **kwargs):
super().__init__(
*args,
**kwargs,
)

def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
with torch.device(self.device):
b = len(env_idx)
Expand All @@ -104,38 +172,24 @@ def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
self.agent.robot.set_qpos(qpos)
self.agent.robot.set_qvel(qvel)

@property
def pole_angle_cosine(self):
return torch.cos(self.agent.robot.joints_map["hinge_1"].qpos)

def evaluate(self):
cart_pos = self.agent.robot.joints_map["slider"].qpos
pole_angle_cosine = self.pole_angle_cosine
cart_in_bounds = cart_pos < self.CART_RANGE[1]
cart_in_bounds = cart_in_bounds & (cart_pos > self.CART_RANGE[0])
angle_in_bounds = pole_angle_cosine < self.ANGLE_COSINE_RANGE[1]
angle_in_bounds = angle_in_bounds & (
pole_angle_cosine > self.ANGLE_COSINE_RANGE[0]
)
return {"cart_in_bounds": cart_in_bounds, "angle_in_bounds": angle_in_bounds}
return dict(fail=self.pole_angle_cosine < 0)

def _get_obs_extra(self, info: Dict):
return dict()

def compute_sparse_reward(self, obs: Any, action: torch.Tensor, info: Dict):
return info["cart_in_bounds"] * info["angle_in_bounds"]

@register_env("MS-CartpoleSwingUp-v1", max_episode_steps=1000)
class CartpoleSwingUpEnv(CartpoleEnv):
def __init__(self, *args, **kwargs):
super().__init__(
*args,
**kwargs,
)

@register_env("CartPoleSwingUp-v1", max_episode_steps=500, override=True)
class CartPoleSwingUpEnv(CartPoleEnv):
def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
with torch.device(self.device):
b = len(env_idx)
qpos = torch.zeros((b, 2))
qpos[:, 0] = 0.01 * torch.randn(size=(b,))
qpos[:, 1] = torch.pi + 0.01 * torch.randn(size=(b,))
qpos[:, 0] = torch.randn((b,)) * 0.01
qpos[:, 1] = torch.randn((b,)) * 0.01 + torch.pi
qvel = torch.randn(size=(b, 2)) * 0.01
self.agent.robot.set_qpos(qpos)
self.agent.robot.set_qvel(qvel)
# Note DM-Control sets some randomness to other qpos values but am not sure what they are
# as cartpole.xml seems to only load two joints
59 changes: 58 additions & 1 deletion mani_skill/envs/utils/rewards/common.py
Original file line number Diff line number Diff line change
@@ -1 +1,58 @@
"""Useful utilities for reward functions"""
import torch


def tolerance(
x, lower=0.0, upper=0.0, margin=0.0, sigmoid="gaussian", value_at_margin=0.1
):
# modified from https://github.com/google-deepmind/dm_control/blob/554ad2753df914372597575505249f22c255979d/dm_control/utils/rewards.py#L93
"""Returns 1 when `x` falls inside the bounds, between 0 and 1 otherwise.
Args:
x: A torch array. (B, 3)
lower, upper: specifying inclusive `(lower, upper)` bounds for
the target interval. These can be infinite if the interval is unbounded
at one or both ends, or they can be equal to one another if the target
value is exact.
margin: Float. Parameter that controls how steeply the output decreases as
`x` moves out-of-bounds.
* If `margin == 0` then the output will be 0 for all values of `x`
outside of `bounds`.
* If `margin > 0` then the output will decrease sigmoidally with
increasing distance from the nearest bound.
sigmoid: String, choice of sigmoid type. Valid values are: 'gaussian',
'linear', 'hyperbolic', 'long_tail', 'cosine', 'tanh_squared'.
value_at_margin: A float between 0 and 1 specifying the output value when
the distance from `x` to the nearest bound is equal to `margin`. Ignored
if `margin == 0`. todo: not implemented yet
Returns:
A torch array with values between 0.0 and 1.0.
Raises:
ValueError: If `bounds[0] > bounds[1]`.
ValueError: If `margin` is negative.
"""
if lower > upper:
raise ValueError("Lower bound must be <= upper bound.")

if margin < 0:
raise ValueError("`margin` must be non-negative.")

in_bounds = torch.logical_and(lower <= x, x <= upper)

if margin == 0:
value = torch.where(in_bounds, torch.tensor(1.0), torch.tensor(0.0))
else:
d = torch.where(x < lower, lower - x, x - upper) / margin
if sigmoid == "gaussian":
value = torch.where(
in_bounds, torch.tensor(1.0), torch.exp(-0.5 * (d**2))
)
elif sigmoid == "hyperbolic":
value = torch.where(in_bounds, torch.tensor(1.0), 1 / (1 + torch.exp(d)))
elif sigmoid == "quadratic":
value = torch.where(in_bounds, torch.tensor(1.0), 1 - d**2)
else:
raise ValueError(f"Unknown sigmoid type {sigmoid!r}.")

return value
4 changes: 4 additions & 0 deletions mani_skill/trajectory/replay_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,9 @@ def parse_args(args=None):
type=str,
help="Change shader used for rendering. Default is 'default' which is very fast. Can also be 'rt' for ray tracing and generating photo-realistic renders. Can also be 'rt-fast' for a faster but lower quality ray-traced renderer",
)
parser.add_argument(
"--video-fps", default=30, type=int, help="The FPS of saved videos"
)

return parser.parse_args(args)

Expand Down Expand Up @@ -418,6 +421,7 @@ def _main(args, proc_id: int = 0, num_procs=1, pbar=None):
save_trajectory=args.save_traj,
trajectory_name=new_traj_name,
save_video=args.save_video,
video_fps=args.video_fps,
record_reward=args.record_rewards,
)

Expand Down
4 changes: 2 additions & 2 deletions mani_skill/utils/wrappers/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def __init__(
max_steps_per_video=None,
clean_on_close=True,
record_reward=True,
video_fps=20,
video_fps=30,
source_type=None,
source_desc=None,
):
Expand Down Expand Up @@ -597,7 +597,7 @@ def recursive_add_to_h5py(group: h5py.Group, data: dict, key):
dtype=bool,
)
episode_info.update(
fail=self._trajectory_buffer.success[end_ptr - 1, env_idx]
fail=self._trajectory_buffer.fail[end_ptr - 1, env_idx]
)
recursive_add_to_h5py(group, self._trajectory_buffer.state, "env_states")
if self.record_reward:
Expand Down

0 comments on commit 1ae8d65

Please sign in to comment.