
Commit

Merge remote-tracking branch 'origin/working_version' into main
# Conflicts:
#	experiments/data_provider.py
#	experiments/offline_dataset_analysis/sample_offline_dataset.py
#	experiments/offline_rl_from_recorded_data/exp.py
#	experiments/offline_rl_from_recorded_data/launcher.py
lenarttreven committed Oct 16, 2023
2 parents f92b28a + 804f7d0 commit 4fd5963
Showing 6 changed files with 110 additions and 32 deletions.
16 changes: 15 additions & 1 deletion experiments/data_provider.py
@@ -2,10 +2,21 @@
import pickle
from functools import partial
from typing import Dict, Any
import os
import pickle
from functools import partial
from typing import Dict, Any

import jax
import jax.numpy as jnp

import jax
import jax.numpy as jnp

from experiments.util import load_csv_recordings
from sim_transfer.sims.car_sim_config import OBS_NOISE_STD_SIM_CAR
from sim_transfer.sims.simulators import PredictStateChangeWrapper, StackedActionSimWrapper
from sim_transfer.sims.util import encode_angles as encode_angles_fn
from experiments.util import load_csv_recordings
from sim_transfer.sims.car_sim_config import OBS_NOISE_STD_SIM_CAR
from sim_transfer.sims.simulators import PredictStateChangeWrapper, StackedActionSimWrapper
@@ -247,6 +258,7 @@ def provide_data_and_sim(data_source: str, data_spec: Dict[str, Any], data_seed:
if num_stacked_actions > 0:
sim_sample = StackedActionSimWrapper(sim_sample, num_stacked_actions=num_stacked_actions, action_size=2)

        # Prepare simulator for bnn_training (the only difference is that here we can also use the low-fidelity sim)
sim = RaceCarSim(encode_angle=True, use_blend=use_hf_sim, car_id=car_id)
if num_stacked_actions > 0:
sim = StackedActionSimWrapper(sim, num_stacked_actions=num_stacked_actions, action_size=2)
@@ -257,7 +269,9 @@ def provide_data_and_sim(data_source: str, data_spec: Dict[str, Any], data_seed:
num_samples_test=num_test,
obs_noise_std=data_spec.get('obs_noise_std', defaults['obs_noise_std']),
x_support_mode_train=data_spec.get('x_support_mode_train', defaults['x_support_mode_train']),
param_mode=data_spec.get('param_mode', defaults['param_mode'])
param_mode='typical'
            # Used to be the line below, but then we don't sample the right model:
# param_mode=data_spec.get('param_mode', defaults['param_mode'])
)

return x_train, y_train, x_test, y_test, sim
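For context on the param_mode change above, here is a minimal sketch (the 'defaults' value is an assumption, not part of this commit) of how the old data_spec.get lookup differs from the new hard-coded 'typical' mode:

    # Illustrative sketch only; 'defaults' and 'data_spec' values are hypothetical stand-ins.
    defaults = {'param_mode': 'random'}          # assumed fallback defined elsewhere in data_provider.py
    data_spec = {'num_samples_train': 10_000}    # caller did not specify 'param_mode'

    # Old behaviour: fall back to the defaults dict when the key is missing.
    param_mode_old = data_spec.get('param_mode', defaults['param_mode'])   # -> 'random'

    # New behaviour: always sample with the 'typical' parameter mode, regardless of data_spec.
    param_mode_new = 'typical'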
5 changes: 3 additions & 2 deletions experiments/offline_dataset_analysis/sample_offline_dataset.py
@@ -11,8 +11,9 @@

def sample_dataset_from_simulator(num_samples_train: int,
data_seed: int) -> List[chex.PRNGKey]:
x_train, y_train, x_test, y_test, sim = provide_data_and_sim(data_source='racecar_hf',
data_spec={'num_samples_train': num_samples_train, },
x_train, y_train, x_test, y_test, sim = provide_data_and_sim(data_source='racecar_actionstack',
data_spec={'num_samples_train': num_samples_train,
'num_stacked_actions': 0},
data_seed=data_seed)
return x_train, y_train, x_test, y_test

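A usage sketch of the updated sampler (illustrative values; note the function returns the four data arrays shown in the diff, despite the List[chex.PRNGKey] annotation):

    # Illustrative only: calls the function as changed in this diff.
    from experiments.offline_dataset_analysis.sample_offline_dataset import sample_dataset_from_simulator

    x_train, y_train, x_test, y_test = sample_dataset_from_simulator(num_samples_train=1000, data_seed=42)
    # With 'num_stacked_actions': 0, no previous actions are appended to the inputs.
    print(x_train.shape, y_train.shape)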
15 changes: 12 additions & 3 deletions experiments/offline_rl_from_recorded_data/exp.py
@@ -1,5 +1,6 @@
import argparse

import jax.nn
import jax.random as jr
import wandb

@@ -40,7 +41,11 @@ def experiment(horizon_len: int,
default_num_init_points_to_bs_for_sac_learning=1000,
data_from_simulation: int = 0,
num_frame_stack: int = 3,
bandwidth_svgd: float = 2.0,
):
# TODO: Not clear how many steps to train the BNN for. We should probably train it for a fixed number of steps
bnn_train_steps = min(50 * num_offline_collected_transitions, 100_000)

if not data_from_simulation:
assert num_frame_stack == 3, "Frame stacking has to be set to 3 if not using simulation data"
config_dict = dict(use_sim_prior=use_sim_prior,
@@ -109,6 +114,7 @@ def experiment(horizon_len: int,
eval_on_all_offline_data=eval_on_all_offline_data,
train_sac_only_from_init_states=train_sac_only_from_init_states,
num_frame_stack=num_frame_stack,
bandwidth_svgd=bandwidth_svgd,
)

total_config = SAC_KWARGS | config_dict
@@ -126,7 +132,7 @@ def experiment(horizon_len: int,
data_spec={'num_samples_train': num_offline_collected_transitions,
'use_hf_sim': bool(high_fidelity),
'num_stacked_actions': num_frame_stack},
data_seed=seed
data_seed=12345,
)

else:
@@ -153,6 +159,7 @@ def experiment(horizon_len: int,
'hidden_layer_sizes': [64, 64, 64],
'normalization_stats': sim.normalization_stats,
'data_batch_size': bnn_batch_size,
'hidden_activation': jax.nn.leaky_relu
}

if use_sim_prior:
@@ -175,7 +182,7 @@ def experiment(horizon_len: int,
score_estimator='gp',
num_train_steps=bnn_train_steps,
num_f_samples=256,
bandwidth_svgd=1.0,
bandwidth_svgd=bandwidth_svgd,
num_measurement_points=num_measurement_points,
)
elif use_grey_box:
@@ -191,7 +198,7 @@ def experiment(horizon_len: int,
**standard_params,
num_train_steps=bnn_train_steps,
domain=sim.domain,
bandwidth_svgd=1.0,
bandwidth_svgd=bandwidth_svgd,
)

s = share_of_x0s_in_sac_buffer
@@ -261,6 +268,7 @@ def main(args):
train_sac_only_from_init_states=args.train_sac_only_from_init_states,
data_from_simulation=args.data_from_simulation,
num_frame_stack=args.num_frame_stack,
bandwidth_svgd=args.bandwidth_svgd,
)


@@ -293,5 +301,6 @@ def main(args):
parser.add_argument('--likelihood_exponent', type=float, default=1.0)
parser.add_argument('--data_from_simulation', type=int, default=1)
parser.add_argument('--num_frame_stack', type=int, default=0)
parser.add_argument('--bandwidth_svgd', type=float, default=2.0)
args = parser.parse_args()
main(args)
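The new bnn_train_steps heuristic in exp.py scales the number of BNN gradient steps with the dataset size but caps it; a quick worked example with illustrative transition counts:

    # 50 gradient steps per collected transition, capped at 100k steps.
    for num_offline_collected_transitions in (200, 2_000, 20_000):
        bnn_train_steps = min(50 * num_offline_collected_transitions, 100_000)
        print(num_offline_collected_transitions, bnn_train_steps)
    # 200    -> 10_000
    # 2_000  -> 100_000
    # 20_000 -> 100_000 (capped)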
16 changes: 11 additions & 5 deletions experiments/offline_rl_from_recorded_data/launcher.py
@@ -1,16 +1,16 @@
import exp
from experiments.util import generate_run_commands, generate_base_command, dict_permutations

PROJECT_NAME = 'SimulatedOfflineRLNoActionStacking'
PROJECT_NAME = 'OfflineRLLeakyReluBandwidthSVGDN1'

_applicable_configs = {
'horizon_len': [200],
'seed': list(range(3)),
'seed': list(range(5)),
'project_name': [PROJECT_NAME],
'sac_num_env_steps': [2_000_000],
'bnn_train_steps': [20_000, 100_000],
'bnn_train_steps': [100_000],
'learnable_likelihood_std': ['yes'],
'include_aleatoric_noise': [1],
'include_aleatoric_noise': [0],
'best_bnn_model': [1],
'best_policy': [1],
'margin_factor': [20.0],
@@ -23,15 +23,17 @@
'share_of_x0s_in_sac_buffer': [0.5],
'bnn_batch_size': [32],
'likelihood_exponent': [1.0],
'train_sac_only_from_init_states': [1],
'train_sac_only_from_init_states': [0],
'data_from_simulation': [1],
'num_frame_stack': [0],
'bandwidth_svgd': [0.05, 0.1, 0.3]
}

_applicable_configs_no_sim_prior = {'use_sim_prior': [0],
'use_grey_box': [0],
'high_fidelity': [0],
'predict_difference': [1],
'num_measurement_points': [8]
} | _applicable_configs
_applicable_configs_high_fidelity = {'use_sim_prior': [1],
'use_grey_box': [0],
@@ -50,6 +52,10 @@
'predict_difference': [0],
'num_measurement_points': [8]} | _applicable_configs

# all_flags_combinations = dict_permutations(_applicable_configs_no_sim_prior) + dict_permutations(
# _applicable_configs_high_fidelity) + dict_permutations(_applicable_configs_low_fidelity) + dict_permutations(
# _applicable_configs_grey_box)

all_flags_combinations = dict_permutations(_applicable_configs_no_sim_prior) + dict_permutations(
_applicable_configs_high_fidelity) + dict_permutations(_applicable_configs_low_fidelity)

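The launcher builds its sweep from dict_permutations in experiments.util. The real helper is not shown in this diff; a minimal sketch of the presumed behaviour (Cartesian product over the per-key value lists) is:

    import itertools
    from typing import Any, Dict, List

    def dict_permutations_sketch(config: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
        # Hypothetical re-implementation for illustration; the actual helper is experiments.util.dict_permutations.
        keys = list(config.keys())
        return [dict(zip(keys, values)) for values in itertools.product(*config.values())]

    # e.g. {'seed': [0, 1], 'bandwidth_svgd': [0.05, 0.1, 0.3]} -> 6 run configurations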
80 changes: 64 additions & 16 deletions sim_transfer/rl/rl_on_offline_data.py
@@ -350,14 +350,62 @@ def prepare_policy_from_offline_data(self,
return policy, params, metrics, bnn_model

@staticmethod
def arg_median(a):
if len(a) % 2 == 1:
return jnp.where(a == jnp.median(a))[0][0]
else:
l, r = len(a) // 2 - 1, len(a) // 2
left = jnp.partition(a, l)[l]
# right = jnp.partition(a, r)[r]
return jnp.where(a == left)[0][0]
def arg_mean(a: chex.Array):
# Return index of the element that is closest to the mean
return jnp.argmin(jnp.abs(a - jnp.mean(a)))

def evaluate_policy_on_the_simulator(self,
policy: Callable,
key: chex.PRNGKey = jr.PRNGKey(0),
num_evals: int = 1, ):
def reward_on_simulator(key: chex.PRNGKey):
actions_buffer = jnp.zeros(shape=(self.action_dim * self.num_frame_stack))
sim = RCCarSimEnv(encode_angle=True, use_tire_model=True,
margin_factor=self.car_reward_kwargs['margin_factor'],
ctrl_cost_weight=self.car_reward_kwargs['ctrl_cost_weight'], )
obs = sim.reset(key)
done = False
transitions_for_plotting = []
while not done:
policy_input = jnp.concatenate([obs, actions_buffer], axis=-1)
action = policy(policy_input)
next_obs, reward, done, info = sim.step(action)
# Prepare new actions buffer
if self.num_frame_stack > 0:
next_actions_buffer = jnp.concatenate([actions_buffer[self.action_dim:], action])
else:
next_actions_buffer = jnp.zeros(shape=(0,))

transitions_for_plotting.append(Transition(observation=obs,
action=action,
reward=jnp.array(reward),
discount=jnp.array(0.99),
next_observation=next_obs)
)
actions_buffer = next_actions_buffer
obs = next_obs

concatenated_transitions_for_plotting = jtu.tree_map(lambda *xs: jnp.stack(xs, axis=0),
*transitions_for_plotting)
reward_on_simulator = jnp.sum(concatenated_transitions_for_plotting.reward)
return reward_on_simulator, concatenated_transitions_for_plotting

rewards, trajectories = vmap(reward_on_simulator)(jr.split(key, num_evals))

reward_mean = jnp.mean(rewards)
reward_std = jnp.std(rewards)

reward_mean_index = self.arg_mean(rewards)

transitions_mean = jtu.tree_map(lambda x: x[reward_mean_index], trajectories)
fig, axes = plot_rc_trajectory(transitions_mean.next_observation,
transitions_mean.action, encode_angle=True,
show=False)
model_name = 'simulator'
wandb.log({f'Mean_trajectory_on_{model_name}': wandb.Image(fig),
f'reward_mean_on_{model_name}': float(reward_mean),
f'reward_std_on_{model_name}': float(reward_std)})
plt.close('all')

def evaluate_policy_on_the_simulator(self,
policy: Callable,
@@ -458,20 +506,20 @@ def get_trajectory_transitions(init_obs, key):

trajectories = vmap(get_trajectory_transitions)(obs, key_generate_trajectories)

# Now we calculate median reward and std of rewards
# Now we calculate mean reward and std of rewards
rewards = jnp.sum(trajectories.reward, axis=-1)
reward_median = jnp.median(rewards)
reward_mean = jnp.mean(rewards)
reward_std = jnp.std(rewards)

reward_median_index = self.arg_median(rewards)
reward_mean_index = self.arg_mean(rewards)

transitions_median = jtu.tree_map(lambda x: x[reward_median_index], trajectories)
fig, axes = plot_rc_trajectory(transitions_median.next_observation,
transitions_median.action, encode_angle=True,
transitions_mean = jtu.tree_map(lambda x: x[reward_mean_index], trajectories)
fig, axes = plot_rc_trajectory(transitions_mean.next_observation,
transitions_mean.action, encode_angle=True,
show=False)

wandb.log({f'Median_trajectory_on_{model_name}': wandb.Image(fig),
f'reward_median_on_{model_name}': float(reward_median),
wandb.log({f'Mean_trajectory_on_{model_name}': wandb.Image(fig),
f'reward_mean_on_{model_name}': float(reward_mean),
f'reward_std_on_{model_name}': float(reward_std)})
plt.close('all')

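The switch from arg_median to arg_mean above selects the rollout whose return is closest to the mean return for plotting; a small sanity check of that selection rule (illustrative values only):

    import jax.numpy as jnp

    rewards = jnp.array([10.0, 40.0, 70.0, 100.0])                # per-rollout returns, mean is 55.0
    mean_idx = jnp.argmin(jnp.abs(rewards - jnp.mean(rewards)))   # 40.0 and 70.0 are equally close; argmin returns the first
    trajectory_to_plot_idx = int(mean_idx)                        # -> 1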
10 changes: 5 additions & 5 deletions sim_transfer/sims/car_system.py
@@ -303,14 +303,14 @@ def step(self,
) -> SystemState:
# Decompose to x and last actions
_x = x[:self._system.x_dim]
_us = x[self._system.x_dim:]
stacked_us = x[self._system.x_dim:]

next_sys_step = self._system.step(_x, u, system_params)
# We roll last actions and append the new action
_us = jnp.roll(_us, shift=self._system.u_dim)
_us = _us.at[:self._system.u_dim].set(u)
if self._num_frame_stack > 0:
stacked_us = jnp.concatenate([stacked_us[self._system.u_dim:], u])

# We add last actions to the state
x_next = jnp.concatenate([next_sys_step.x_next, _us], axis=0)
x_next = jnp.concatenate([next_sys_step.x_next, stacked_us], axis=0)
next_sys_step = next_sys_step.replace(x_next=x_next)
return next_sys_step

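The step change in car_system.py replaces the roll-and-set update of the stacked actions with an explicit concatenation; a small illustration of the new buffer update (hypothetical dimensions and values):

    import jax.numpy as jnp

    u_dim = 2                                   # e.g. steering + throttle
    stacked_us = jnp.array([1., 1., 2., 2.])    # two previously applied actions
    u = jnp.array([3., 3.])                     # action applied this step

    # Drop the leading action slot and append the newest action at the end.
    stacked_us = jnp.concatenate([stacked_us[u_dim:], u])
    # -> [2., 2., 3., 3.]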
