Commit

preparation of reset buffer with initial state distribution, adding hardware car env

sukhijab committed Nov 28, 2023
1 parent 072d2db commit 9909610

Showing 3 changed files with 98 additions and 30 deletions.
90 changes: 65 additions & 25 deletions experiments/online_rl_hardware/online_rl_loop.py
@@ -16,14 +16,14 @@
from experiments.online_rl_hardware.train_policy import ModelBasedRLConfig
from experiments.online_rl_hardware.train_policy import train_model_based_policy
from experiments.online_rl_hardware.utils import (set_up_bnn_dynamics_model, set_up_dummy_sac_trainer,
-                                                   dump_trajectory_summary, execute)
+                                                   dump_trajectory_summary, execute,
+                                                   prepare_init_transitions_for_car_env)
from experiments.util import Logger, RESULT_DIR

from sim_transfer.sims.envs import RCCarSimEnv
from sim_transfer.sims.util import plot_rc_trajectory


-WANDB_ENTITY = 'jonasrothfuss'
+WANDB_ENTITY = 'sukhijab'
EULER_ENTITY = 'rojonas'
WANDB_LOG_DIR_EULER = '/cluster/scratch/' + EULER_ENTITY
PRIORS = {'none_FVSGD',
@@ -139,6 +139,7 @@ class MainConfig(NamedTuple):
include_aleatoric_noise: int = 1
best_bnn_model: int = 1
best_policy: int = 1
+deterministic_policy: int = 1
predict_difference: int = 1
margin_factor: float = 20.0
ctrl_cost_weight: float = 0.005
@@ -151,19 +152,40 @@ class MainConfig(NamedTuple):
length_scale_aditive_sim_gp: float = 10.0
num_f_samples: int = 512
num_measurement_points: int = 16
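+# fraction of the SAC training buffer to be filled with synthetic initial-state transitions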
+initial_state_fraction: float = 0.5
+sim: int = 1
+control_time_ms: float = 24.
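+# sim=1 uses the RCCarSimEnv simulation; sim=0 drives the hardware CarEnv with control period control_time_ms (milliseconds)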


def main(config: MainConfig = MainConfig(), encode_angle: bool = True,
machine: str = 'local'):
rng_key_env, rng_key_model, rng_key_rollouts = jax.random.split(jax.random.PRNGKey(config.seed), 3)

-env = RCCarSimEnv(encode_angle=encode_angle,
-                  action_delay=config.delay,
-                  use_tire_model=True,
-                  use_obs_noise=True,
-                  ctrl_cost_weight=config.ctrl_cost_weight,
-                  margin_factor=config.margin_factor,
-                  )
"""Setup car reward kwargs"""
car_reward_kwargs = dict(encode_angle=encode_angle,
ctrl_cost_weight=config.ctrl_cost_weight,
margin_factor=config.margin_factor)
"""Set up env"""
if bool(config.sim):
env = RCCarSimEnv(encode_angle=encode_angle,
action_delay=config.delay,
use_tire_model=True,
use_obs_noise=True,
ctrl_cost_weight=config.ctrl_cost_weight,
margin_factor=config.margin_factor,
)
else:
from sim_transfer.hardware.car_env import CarEnv
# We do not perform frame stacking in the env and do it manually here in the rollout function.
env = CarEnv(
encode_angle=encode_angle,
car_id=2,
control_time_ms=config.control_time_ms,
max_throttle=0.4,
car_reward_kwargs=car_reward_kwargs,
num_frame_stacks=0

)

# initialize train_data as empty arrays
train_data = {
@@ -176,12 +198,7 @@ def main(config: MainConfig = MainConfig(), encode_angle: bool = True,
################################################################################
"""Setup key"""
key = jr.PRNGKey(config.seed)
-key_bnn, key_run_episodes, key_dummy_sac_trainer = jr.split(key, 3)

"""Setup car reward kwargs"""
car_reward_kwargs = dict(encode_angle=encode_angle,
ctrl_cost_weight=config.ctrl_cost_weight,
margin_factor=config.margin_factor)
+key_bnn, key_run_episodes, key_dummy_sac_trainer, key = jr.split(key, 4)

"""Setup SAC config dict"""
num_env_steps_between_updates = 16
@@ -257,6 +274,10 @@ def main(config: MainConfig = MainConfig(), encode_angle: bool = True,
bnn_training_test_ratio=0.2,
max_num_episodes=100)

+initial_states_fraction = max(min(config.initial_state_fraction, 0.9999), 0.0)
+init_state_points = lambda true_buffer_points: int(initial_states_fraction * true_buffer_points
+                                                   / (1 - initial_states_fraction))
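+# i.e. n_init / (n_init + n_true) = initial_state_fraction; with the default fraction 0.5 this
+# adds one synthetic initial-state transition per real transition in the buffer.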

""" Set up dummy SAC trainer for getting the policy from policy params """
dummy_sac_trainer = set_up_dummy_sac_trainer(main_config=config, mbrl_config=mbrl_config, key=key)

@@ -270,16 +291,26 @@ def main(config: MainConfig = MainConfig(), encode_angle: bool = True,
print('\n\n------- Episode', episode_id)

key, key_episode = jr.split(key)

+key_episode, key_init_buffer = jr.split(key_episode)

+num_points = train_data['x_train'].shape[0]
+num_init_state_points = init_state_points(num_points)
+if num_init_state_points > 0:
+    init_transitions = prepare_init_transitions_for_car_env(key=key_init_buffer,
+                                                            number_of_samples=num_init_state_points,
+                                                            num_frame_stack=config.num_stacked_actions)
+else:
+    init_transitions = None
# train model & policy
policy_params, bnn = train_model_based_policy_remote(
train_data=train_data, bnn_model=bnn, config=mbrl_config, key=key_episode,
episode_idx=episode_id, machine=machine, wandb_config=wandb_config_remote,
-    remote_training=remote_training)
+    remote_training=remote_training, reset_buffer_transitions=init_transitions)

# get callable policy from policy params
-def policy(x):
-    return dummy_sac_trainer.make_policy(policy_params, deterministic=True)(x, jr.PRNGKey(0))[0]
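+# deterministic_policy=1 selects deterministic actions; otherwise actions are sampled with the passed key.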
+def policy(x, key: jr.PRNGKey = jr.PRNGKey(0)):
+    return dummy_sac_trainer.make_policy(policy_params,
+                                         deterministic=bool(config.deterministic_policy))(x, key)[0]

# perform policy rollout on the car
stacked_actions = jnp.zeros(shape=(config.num_stacked_actions * mbrl_config.u_dim,))
@@ -288,7 +319,7 @@ def policy(x):
actions, rewards, pure_obs = [], [], []
for i in range(config.num_env_steps):
rng_key_rollouts, rng_key_act = jr.split(rng_key_rollouts)
-act = policy(obs)
+act = policy(obs, rng_key_act)
obs, reward, _, _ = env.step(act)
rewards.append(reward)
actions.append(act)
@@ -300,7 +331,7 @@ def policy(x):

# logging and saving
trajectory, actions, rewards, pure_obs = map(lambda arr: jnp.array(arr),
-                                             [trajectory, actions, rewards, pure_obs])
+                                                 [trajectory, actions, rewards, pure_obs])

traj_summary = {'episode_id': episode_id, 'trajectory': trajectory, 'actions': actions, 'rewards': rewards,
'obs': pure_obs, 'return': jnp.sum(rewards)}
@@ -328,12 +359,16 @@ def policy(x):
parser = argparse.ArgumentParser(description='Meta-BO run')
parser.add_argument('--seed', type=int, default=914)
parser.add_argument('--project_name', type=str, default='OnlineRL_RCCar')
-parser.add_argument('--machine', type=str, default='optimality')
+parser.add_argument('--machine', type=str, default='local')
parser.add_argument('--gpu', type=int, default=1)
+parser.add_argument('--sim', type=int, default=1)
+parser.add_argument('--control_time_ms', type=float, default=24.)

parser.add_argument('--prior', type=str, default='none_FVSGD')
-parser.add_argument('--num_env_steps', type=int, default=200, info='number of steps in the environment per episode')
+parser.add_argument('--num_env_steps', type=int, default=200)
parser.add_argument('--reset_bnn', type=int, default=0)
+parser.add_argument('--deterministic_policy', type=int, default=1)
+parser.add_argument('--initial_state_fraction', type=float, default=0.5)
args = parser.parse_args()

if not args.gpu:
@@ -345,5 +380,10 @@
seed=args.seed,
project_name=args.project_name,
num_env_steps=args.num_env_steps,
-    reset_bnn=args.reset_bnn),
+    reset_bnn=args.reset_bnn,
+    sim=args.sim,
+    control_time_ms=args.control_time_ms,
+    deterministic_policy=args.deterministic_policy,
+    initial_state_fraction=args.initial_state_fraction,
+    ),
machine=args.machine)
15 changes: 12 additions & 3 deletions experiments/online_rl_hardware/train_policy.py
@@ -2,25 +2,29 @@
import copy
import time
import chex
+import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import wandb

+from brax.training.types import Transition
from sim_transfer.models.abstract_model import BatchedNeuralNetworkModel
from typing import Dict

from sim_transfer.rl.model_based_rl.utils import split_data
from experiments.online_rl_hardware.utils import (load_data, dump_model, ModelBasedRLConfig, init_transition_buffer,
add_data_to_buffer, set_up_model_based_sac_trainer)


def train_model_based_policy(train_data: Dict,
bnn_model: BatchedNeuralNetworkModel,
key: chex.PRNGKey,
episode_idx: int,
config: ModelBasedRLConfig,
wandb_config: Dict,
-                             remote_training: bool = False):
+                             remote_training: bool = False,
+                             reset_buffer_transitions: Transition | None = None,
+                             ):
"""
train_data = {'x_train': jnp.empty((0, state_dim + (1 + num_framestacks) * action_dim)),
'y_train': jnp.empty((0, state_dim))}
@@ -60,6 +64,11 @@ def train_model_based_policy(train_data: Dict,

"""Train policy"""
t = time.time()
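+    # Seed the SAC training buffer with the synthetic initial-state transitions (if provided),
+    # so policy training also sees the initial state distribution.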
+    if reset_buffer_transitions:
+        sac_buffer_state = true_data_buffer.insert(true_data_buffer_state, reset_buffer_transitions)
+    else:
+        sac_buffer_state = true_data_buffer_state

_sac_kwargs = config.sac_kwargs
# TODO: Be careful!!
if num_training_points == 0:
@@ -69,7 +78,7 @@

key, key_sac_training, key_sac_trainer_init = jr.split(key, 3)
sac_trainer = set_up_model_based_sac_trainer(
-        bnn_model=bnn_model, data_buffer=true_data_buffer, data_buffer_state=true_data_buffer_state,
+        bnn_model=bnn_model, data_buffer=true_data_buffer, data_buffer_state=sac_buffer_state,
key=key_sac_trainer_init, config=config, sac_kwargs=_sac_kwargs)

policy_params, metrics = sac_trainer.run_training(key=key_sac_training)
23 changes: 21 additions & 2 deletions experiments/online_rl_hardware/utils.py
@@ -8,7 +8,7 @@
from sim_transfer.models import BNN_FSVGD_SimPrior, BNN_FSVGD, BNN_SVGD
from sim_transfer.sims.simulators import AdditiveSim, PredictStateChangeWrapper, GaussianProcessSim
from sim_transfer.sims.simulators import RaceCarSim, StackedActionSimWrapper

+from sim_transfer.sims.envs import RCCarSimEnv
from mbpo.optimizers.policy_optimizers.sac.sac import SAC
from mbpo.systems.brax_wrapper import BraxWrapper

@@ -42,6 +42,7 @@ def execute(cmd: str, verbosity: int = 0) -> None:
print(cmd)
os.system(cmd)


def load_data(data_load_path: str) -> Any:
# loads the pkl file
with open(data_load_path, 'rb') as f:
@@ -207,4 +208,22 @@ def set_up_dummy_sac_trainer(main_config, mbrl_config: ModelBasedRLConfig, key:
bnn_model=bnn, data_buffer=true_data_buffer,
data_buffer_state=true_data_buffer_state, key=key_bnn, config=mbrl_config)

    return sac_trainer


+def prepare_init_transitions_for_car_env(key: jax.random.PRNGKey, number_of_samples: int, num_frame_stack: int = 3):
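+    """Sample initial states from the RC car simulator and wrap them as zero-action, zero-reward
+    transitions for seeding the SAC training buffer with the initial state distribution."""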
+    sim = RCCarSimEnv(encode_angle=True, use_tire_model=True)
+    action_dim = 2
+    key_init_state = jax.random.split(key, number_of_samples)
+    state_obs = jax.vmap(sim.reset)(rng_key=key_init_state)
+    framestacked_actions = jnp.zeros(
+        shape=(number_of_samples, num_frame_stack * action_dim))
+    actions = jnp.zeros(shape=(number_of_samples, action_dim))
+    rewards = jnp.zeros(shape=(number_of_samples,))
+    discounts = 0.99 * jnp.ones(shape=(number_of_samples,))
+    transitions = Transition(observation=jnp.concatenate([state_obs, framestacked_actions], axis=-1),
+                             action=actions,
+                             reward=rewards,
+                             discount=discounts,
+                             next_observation=jnp.concatenate([state_obs, framestacked_actions], axis=-1))
+    return transitions
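
A minimal usage sketch (not part of the commit; it assumes the encoded car state is 7-dimensional, as for RCCarSimEnv with encode_angle=True):

    import jax
    from experiments.online_rl_hardware.utils import prepare_init_transitions_for_car_env

    transitions = prepare_init_transitions_for_car_env(key=jax.random.PRNGKey(0),
                                                       number_of_samples=100,
                                                       num_frame_stack=3)
    # transitions.observation has shape (100, 7 + 3 * 2): initial state plus 3 stacked zero actions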
