
Commit

add ppo and sac training for racecar
lenarttreven committed Aug 14, 2023
1 parent a87d5bd commit aa3f2dc
Showing 4 changed files with 268 additions and 3 deletions.
Empty file added sim_transfer/rl/__init__.py
Empty file.
129 changes: 129 additions & 0 deletions sim_transfer/rl/rac_car_ppo.py
@@ -0,0 +1,129 @@
from datetime import datetime

import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import matplotlib.pyplot as plt
from brax.training.replay_buffers import UniformSamplingQueue
from brax.training.types import Transition
from jax import jit
from jax.lax import scan
from mbpo.optimizers.policy_optimizers.ppo.ppo import PPO
from mbpo.systems.brax_wrapper import BraxWrapper

from sim_transfer.sims.car_system import CarSystem
from sim_transfer.sims.util import plot_rc_trajectory

ENCODE_ANGLE = False
system = CarSystem(encode_angle=ENCODE_ANGLE,
                   action_delay=0.07,
                   use_tire_model=True)

# Create replay buffer
init_sys_state = system.reset(key=jr.PRNGKey(0))

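# A single dummy transition fixes the pytree structure (shapes and dtypes) that the replay buffer stores.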
dummy_sample = Transition(observation=init_sys_state.x_next,
                          action=jnp.zeros(shape=(system.u_dim,)),
                          reward=init_sys_state.reward,
                          discount=jnp.array(0.99),
                          next_observation=init_sys_state.x_next)

sampling_buffer = UniformSamplingQueue(max_replay_size=1,
                                       dummy_data_sample=dummy_sample,
                                       sample_batch_size=1)

sampling_buffer_state = sampling_buffer.init(jr.PRNGKey(0))
sampling_buffer_state = sampling_buffer.insert(sampling_buffer_state,
                                               jtu.tree_map(lambda x: x[None, ...], dummy_sample))

# Create brax environment
env = BraxWrapper(system=system,
                  sample_buffer_state=sampling_buffer_state,
                  sample_buffer=sampling_buffer,
                  system_params=system.init_params(jr.PRNGKey(0)))

state = jit(env.reset)(rng=jr.PRNGKey(0))

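# PPO trainer configuration: clipping_epsilon is the threshold of the clipped surrogate objective and
# gae_lambda the lambda used for generalized advantage estimation; policy and critic are 3-layer MLPs.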
ppo_trainer = PPO(
    environment=env,
    num_timesteps=100_000,
    episode_length=200,
    action_repeat=1,
    num_envs=4,
    num_eval_envs=1,
    lr=3e-3,
    wd=0,
    entropy_cost=0.0,
    discounting=0.99,
    seed=0,
    unroll_length=20,
    batch_size=64,
    num_minibatches=16,
    num_updates_per_batch=4,
    num_evals=20,
    normalize_observations=True,
    reward_scaling=1,
    clipping_epsilon=0.2,
    gae_lambda=0.95,
    deterministic_eval=True,
    normalize_advantage=True,
    policy_hidden_layer_sizes=(64, 64, 64),
    critic_hidden_layer_sizes=(64, 64, 64),
    wandb_logging=False,
)

max_y = 0
min_y = -100

xdata, ydata = [], []
times = [datetime.now()]


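# Progress callback passed to the trainer: it records wall-clock time and the evaluation episode reward,
# then plots the learning curve so far.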
def progress(num_steps, metrics):
    times.append(datetime.now())
    xdata.append(num_steps)
    ydata.append(metrics['eval/episode_reward'])
    # plt.xlim([0, ppo_trainer.num_timesteps])
    # plt.ylim([min_y, max_y])
    plt.xlabel('# environment steps')
    plt.ylabel('reward per episode')
    plt.plot(xdata, ydata)
    plt.show()


params, metrics = ppo_trainer.run_training(key=jr.PRNGKey(0), progress_fn=progress)

make_inference_fn = ppo_trainer.make_policy

print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')


def policy(x):
    return make_inference_fn(params, deterministic=True)(x, jr.PRNGKey(0))[0]


system_state_init = system.reset(key=jr.PRNGKey(0))
x_init = system_state_init.x_next
system_params = system_state_init.system_params


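# Roll out the learned policy in closed loop: the carry is the current system state and the
# per-step outputs (observation, action, reward) are stacked by `scan` into a trajectory.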
def step(system_state, _):
    u = policy(system_state.x_next)
    next_sys_state = system.step(system_state.x_next, u, system_state.system_params)
    return next_sys_state, (system_state.x_next, u, next_sys_state.reward)


horizon = 200
x_last, trajectory = scan(step, system_state_init, None, length=horizon)

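# `trajectory` is a tuple of stacked (observations, actions, rewards), each with leading dimension `horizon`.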
plt.plot(trajectory[0], label='Xs')
plt.plot(trajectory[1], label='Us')
plt.plot(trajectory[2], label='Rewards')
plt.legend()
plt.show()

traj = trajectory[0]
actions = trajectory[1]

plot_rc_trajectory(traj, actions, encode_angle=ENCODE_ANGLE)
130 changes: 130 additions & 0 deletions sim_transfer/rl/race_car_sac.py
@@ -0,0 +1,130 @@
from datetime import datetime

import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import matplotlib.pyplot as plt
from brax.training.replay_buffers import UniformSamplingQueue
from brax.training.types import Transition
from jax import jit
from jax.lax import scan
from mbpo.optimizers.policy_optimizers.sac.sac import SAC
from mbpo.systems.brax_wrapper import BraxWrapper

from sim_transfer.sims.car_system import CarSystem
from sim_transfer.sims.util import plot_rc_trajectory

ENCODE_ANGLE = False
system = CarSystem(encode_angle=ENCODE_ANGLE,
                   action_delay=0.07,
                   use_tire_model=True)

# Create replay buffer
init_sys_state = system.reset(key=jr.PRNGKey(0))

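# As in the PPO script, a single dummy transition fixes the pytree structure the replay buffer stores.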
dummy_sample = Transition(observation=init_sys_state.x_next,
                          action=jnp.zeros(shape=(system.u_dim,)),
                          reward=init_sys_state.reward,
                          discount=jnp.array(0.99),
                          next_observation=init_sys_state.x_next)

sampling_buffer = UniformSamplingQueue(max_replay_size=1,
                                       dummy_data_sample=dummy_sample,
                                       sample_batch_size=1)

sampling_buffer_state = sampling_buffer.init(jr.PRNGKey(0))
sampling_buffer_state = sampling_buffer.insert(sampling_buffer_state,
                                               jtu.tree_map(lambda x: x[None, ...], dummy_sample))

# Create brax environment
env = BraxWrapper(system=system,
                  sample_buffer_state=sampling_buffer_state,
                  sample_buffer=sampling_buffer,
                  system_params=system.init_params(jr.PRNGKey(0)))

state = jit(env.reset)(rng=jr.PRNGKey(0))

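# SAC trainer configuration: tau is the Polyak-averaging coefficient for the target critics, and
# grad_updates_per_step the number of actor/critic gradient updates performed per batch of environment steps.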
sac_trainer = SAC(
    environment=env,
    num_timesteps=20_000,
    num_evals=20,
    reward_scaling=1,
    episode_length=200,
    normalize_observations=True,
    action_repeat=1,
    discounting=0.99,
    lr_policy=3e-4,
    lr_alpha=3e-4,
    lr_q=3e-4,
    num_envs=16,
    batch_size=64,
    grad_updates_per_step=2 * 16,
    max_replay_size=2 ** 14,
    min_replay_size=2 ** 7,
    num_eval_envs=1,
    deterministic_eval=True,
    tau=0.005,
    wd_policy=0,
    wd_q=0,
    wd_alpha=0,
    wandb_logging=False,
    num_env_steps_between_updates=2,
    policy_hidden_layer_sizes=(64, 64, 64),
    critic_hidden_layer_sizes=(64, 64, 64),
)

max_y = 0
min_y = -100

xdata, ydata = [], []
times = [datetime.now()]


def progress(num_steps, metrics):
    times.append(datetime.now())
    xdata.append(num_steps)
    ydata.append(metrics['eval/episode_reward'])
    # plt.xlim([0, sac_trainer.num_timesteps])
    # plt.ylim([min_y, max_y])
    plt.xlabel('# environment steps')
    plt.ylabel('reward per episode')
    plt.plot(xdata, ydata)
    plt.show()


params, metrics = sac_trainer.run_training(key=jr.PRNGKey(0), progress_fn=progress)

make_inference_fn = sac_trainer.make_policy

print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')


def policy(x):
    return make_inference_fn(params, deterministic=True)(x, jr.PRNGKey(0))[0]


system_state_init = system.reset(key=jr.PRNGKey(0))
x_init = system_state_init.x_next
system_params = system_state_init.system_params


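# Closed-loop evaluation of the learned SAC policy, analogous to the rollout in the PPO script.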
def step(system_state, _):
    u = policy(system_state.x_next)
    next_sys_state = system.step(system_state.x_next, u, system_state.system_params)
    return next_sys_state, (system_state.x_next, u, next_sys_state.reward)


horizon = 200
x_last, trajectory = scan(step, system_state_init, None, length=horizon)

plt.plot(trajectory[0], label='Xs')
plt.plot(trajectory[1], label='Us')
plt.plot(trajectory[2], label='Rewards')
plt.legend()
plt.show()

traj = trajectory[0]
actions = trajectory[1]

plot_rc_trajectory(traj, actions, encode_angle=ENCODE_ANGLE)
12 changes: 9 additions & 3 deletions sim_transfer/sims/car_system.py
@@ -153,10 +153,10 @@ def _get_delayed_action(self, action: jnp.array, action_buffer: chex.PRNGKey) ->
         assert delayed_action.shape == self.dim_action
         return delayed_action, new_action_buffer
 
-    def reset(self, rng_key: chex.PRNGKey) -> jnp.array:
+    def reset(self, key: chex.PRNGKey) -> jnp.array:
         """ Resets the environment to a random initial state close to the initial pose """
         # sample random initial state
-        key_pos, key_theta, key_vel, key_obs = jr.split(rng_key, 4)
+        key_pos, key_theta, key_vel, key_obs = jr.split(key, 4)
         init_pos = self._init_pose[:2] + jr.uniform(key_pos, shape=(2,), minval=-0.10, maxval=0.10)
         init_theta = self._init_pose[2:] + \
                      jr.uniform(key_pos, shape=(1,), minval=-0.10 * jnp.pi, maxval=0.10 * jnp.pi)
@@ -234,6 +234,12 @@ def step(self,
                                key=key, ),
         )
 
+    def reset(self, key: chex.PRNGKey) -> SystemState:
+        return SystemState(
+            x_next=self.dynamics.reset(key=key),
+            reward=jnp.array([0.0]).squeeze(),
+            system_params=self.init_params(key=key))
+
 
 if __name__ == '__main__':
     ENCODE_ANGLE = False
@@ -244,7 +250,7 @@ def step(self,
 
     t_start = time.time()
     system_params = system.init_params(key=jr.PRNGKey(0))
-    s = system.dynamics.reset(rng_key=jr.PRNGKey(0))
+    s = system.dynamics.reset(key=jr.PRNGKey(0))
 
     traj = [s]
     rewards = []
