Running (but not performing well) model-based RL loop
lenarttreven committed Aug 17, 2023
1 parent 5366f0f commit b8f8173
Showing 4 changed files with 34 additions and 19 deletions.
3 changes: 2 additions & 1 deletion experiments/car_sac/hyperparams_exp.py
@@ -23,7 +23,8 @@ def experiment(num_envs: int,
 ENCODE_ANGLE = False
 system = CarSystem(encode_angle=ENCODE_ANGLE,
 action_delay=0.00,
-use_tire_model=True)
+use_tire_model=True,
+ctrl_cost_weight=0.005)

 # Create replay buffer
 init_sys_state = system.reset(key=jr.PRNGKey(0))
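The only functional change in this file is the new ctrl_cost_weight=0.005 argument. As a rough illustration of what such a weight typically does (the exact reward of CarSystem is not shown in this diff, so the function below is an assumed sketch, not the actual implementation), the per-step reward is reduced by the weighted squared action norm:

import jax.numpy as jnp

def reward_with_ctrl_cost(state_reward: jnp.ndarray,
                          action: jnp.ndarray,
                          ctrl_cost_weight: float = 0.005) -> jnp.ndarray:
    # Illustrative only: a small quadratic penalty on the action discourages
    # aggressive control inputs without dominating the task reward.
    return state_reward - ctrl_cost_weight * jnp.sum(jnp.square(action))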
2 changes: 1 addition & 1 deletion experiments/car_sac/hyperparams_launcher.py
@@ -1,7 +1,7 @@
 import hyperparams_exp
 from experiments.util import generate_run_commands, generate_base_command

-PROJECT_NAME = 'RaceCarSACHyperparams'
+PROJECT_NAME = 'RaceCarSACHyperparamsCTRLCost0.005'

 applicable_configs = {
 'num_envs': [32, ],
6 changes: 3 additions & 3 deletions sim_transfer/rl/model_based_rl/learned_system.py
@@ -35,10 +35,10 @@ def next_state(self,
 # Create state-action pair
 z = jnp.concatenate([x, u])
 z = z.reshape((1, -1))
-x_next_dist = self.model.predict_dist(z, include_noise=self.include_noise)
+delta_x_dist = self.model.predict_dist(z, include_noise=self.include_noise)
 next_key, key_sample_x_next = jr.split(dynamics_params.key)
-x_next = x_next_dist.sample(seed=key_sample_x_next)
-x_next = x_next.reshape((self.x_dim,))
+delta_x = delta_x_dist.sample(seed=key_sample_x_next)
+x_next = x + delta_x.reshape((self.x_dim,))
 new_dynamics_params = dynamics_params.replace(key=next_key)
 return Normal(loc=x_next, scale=jnp.zeros_like(x_next)), new_dynamics_params

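This is the substantive modeling change of the commit: the learned dynamics model is now interpreted as predicting the state difference delta_x rather than the absolute next state, and the next state is reconstructed as x + delta_x. A minimal sketch of the pattern (predict_delta_dist stands in for self.model.predict_dist and is assumed to return a distribution with a .sample(seed=...) method, as in the diff; key is a jax.random PRNG key):

import jax.numpy as jnp

def delta_dynamics_step(x, u, predict_delta_dist, key):
    # Sketch of delta-state prediction: the model outputs a distribution over
    # the change in state; the next state is the current state plus one sample.
    z = jnp.concatenate([x, u]).reshape((1, -1))  # state-action pair, batch of 1
    delta_x = predict_delta_dist(z).sample(seed=key)
    return x + delta_x.reshape(x.shape)

Predicting deltas instead of absolute next states is a common choice in model-based RL, since consecutive states are close together, which keeps the regression targets small and better conditioned.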
42 changes: 28 additions & 14 deletions sim_transfer/rl/race_car_sac.py
@@ -6,7 +6,7 @@
 import matplotlib.pyplot as plt
 from brax.training.replay_buffers import UniformSamplingQueue
 from brax.training.types import Transition
-from jax import jit
+from jax import jit, vmap
 from jax.lax import scan
 from mbpo.optimizers.policy_optimizers.sac.sac import SAC
 from mbpo.systems.brax_wrapper import BraxWrapper
@@ -23,21 +23,27 @@
 )

 # Create replay buffer
-init_sys_state = system.reset(key=jr.PRNGKey(0))
+num_init_states = 500
+keys = jr.split(jr.PRNGKey(0), num_init_states)
+init_sys_state = vmap(system.reset)(key=keys)

-dummy_sample = Transition(observation=init_sys_state.x_next,
-action=jnp.zeros(shape=(system.u_dim,)),
+init_samples = Transition(observation=init_sys_state.x_next,
+action=jnp.zeros(shape=(num_init_states, system.u_dim,)),
 reward=init_sys_state.reward,
-discount=jnp.array(0.99),
+discount=0.99 * jnp.ones(shape=(num_init_states,)),
 next_observation=init_sys_state.x_next)

-sampling_buffer = UniformSamplingQueue(max_replay_size=1,
+dummy_sample = jtu.tree_map(lambda x: x[0], init_samples)
+
+sampling_buffer = UniformSamplingQueue(max_replay_size=num_init_states,
 dummy_data_sample=dummy_sample,
 sample_batch_size=1)

 sampling_buffer_state = sampling_buffer.init(jr.PRNGKey(0))
-sampling_buffer_state = sampling_buffer.insert(sampling_buffer_state,
-jtu.tree_map(lambda x: x[None, ...], dummy_sample))
+
+
+
+sampling_buffer_state = sampling_buffer.insert(sampling_buffer_state, init_samples)

 # Create brax environment
 env = BraxWrapper(system=system,
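The replay buffer seeding above changes from a single dummy transition to vmapping system.reset over 500 PRNG keys and inserting all resulting initial-state transitions; raising max_replay_size from 1 to num_init_states is what lets the 500 seed transitions fit. A self-contained sketch of this seeding pattern (toy_reset, the 7-dimensional observation, the 2-dimensional action, and the zero reward are placeholders; the buffer and Transition APIs are the brax ones used in the diff):

import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
from jax import vmap
from brax.training.replay_buffers import UniformSamplingQueue
from brax.training.types import Transition

num_init_states = 500
keys = jr.split(jr.PRNGKey(0), num_init_states)

def toy_reset(key):
    # Placeholder for system.reset: one random 7-dim initial observation per key.
    return jr.normal(key, shape=(7,))

obs = vmap(toy_reset)(keys)  # shape (500, 7)

init_samples = Transition(observation=obs,
                          action=jnp.zeros((num_init_states, 2)),  # u_dim = 2 assumed
                          reward=jnp.zeros((num_init_states,)),
                          discount=0.99 * jnp.ones((num_init_states,)),
                          next_observation=obs)

# The queue infers shapes/dtypes from a single unbatched example.
dummy_sample = jtu.tree_map(lambda x: x[0], init_samples)

sampling_buffer = UniformSamplingQueue(max_replay_size=num_init_states,
                                       dummy_data_sample=dummy_sample,
                                       sample_batch_size=1)
buffer_state = sampling_buffer.init(jr.PRNGKey(1))
buffer_state = sampling_buffer.insert(buffer_state, init_samples)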
@@ -47,14 +53,14 @@

 state = jit(env.reset)(rng=jr.PRNGKey(0))

-num_env_steps_between_updates = 4
-num_envs = 32
+num_env_steps_between_updates = 16
+num_envs = 8

 sac_trainer = SAC(
 environment=env,
-num_timesteps=300_000,
+num_timesteps=1_000_000,
 num_evals=20,
-reward_scaling=10,
+reward_scaling=1,
 episode_length=200,
 action_repeat=1,
 discounting=0.99,
@@ -71,7 +77,7 @@
 wd_alpha=0,
 num_eval_envs=1,
 max_replay_size=5 * 10 ** 4,
-min_replay_size=2 ** 11,
+min_replay_size=10 ** 3,
 policy_hidden_layer_sizes=(64, 64),
 critic_hidden_layer_sizes=(64, 64),
 normalize_observations=True,
@@ -110,14 +116,21 @@ def policy(x):
 return make_inference_fn(params, deterministic=True)(x, jr.PRNGKey(0))[0]


+test_system = CarSystem(encode_angle=ENCODE_ANGLE,
+action_delay=0.00,
+use_tire_model=True,
+use_obs_noise=True,
+ctrl_cost_weight=0.005,
+)
+
 system_state_init = system.reset(key=jr.PRNGKey(0))
 x_init = system_state_init.x_next
 system_params = system_state_init.system_params


 def step(system_state, _):
 u = policy(system_state.x_next)
-next_sys_state = system.step(system_state.x_next, u, system_state.system_params)
+next_sys_state = test_system.step(system_state.x_next, u, system_state.system_params)
 return next_sys_state, (system_state.x_next, u, next_sys_state.reward)


@@ -129,6 +142,7 @@ def step(system_state, _):
 plt.plot(trajectory[2], label='Rewards')
 plt.legend()
 plt.show()
+print('Reward: ', jnp.sum(trajectory[2]))

 traj = trajectory[0]
 actions = trajectory[1]
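The evaluation at the end of the script now rolls the trained policy out on a separate test_system with observation noise enabled and prints the summed reward of the trajectory. The scan call that drives step is outside the shown hunks; the sketch below reproduces the rollout pattern with stand-in policy and dynamics (the horizon of 200 matches episode_length above, everything else is illustrative):

import jax.numpy as jnp
from jax.lax import scan

horizon = 200  # matches episode_length in the SAC config above

def toy_policy(x):
    # Placeholder for the trained SAC policy: 2-dim action from a 4-dim state.
    return jnp.tanh(x[:2])

def toy_step_dynamics(x, u):
    # Placeholder for test_system.step: nudge the controlled dims, quadratic reward.
    x_next = x.at[:u.shape[0]].add(0.01 * u)
    return x_next, -jnp.sum(jnp.square(x_next))

def step(x, _):
    u = toy_policy(x)
    x_next, reward = toy_step_dynamics(x, u)
    # Carry the next state; stack (state, action, reward) along the scan axis.
    return x_next, (x, u, reward)

x_init = jnp.ones(4)
x_last, trajectory = scan(step, x_init, None, length=horizon)
states, actions, rewards = trajectory  # each has a leading axis of length 200
print('Reward: ', jnp.sum(rewards))

Rolling out on a system with use_obs_noise=True while SAC was trained on the noiseless system gives a quick robustness check of the learned policy.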
