Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
lenarttreven committed Jan 17, 2024
2 parents 208cff9 + d713ad0 commit 47d170f
Show file tree
Hide file tree
Showing 18 changed files with 514 additions and 87 deletions.
64 changes: 64 additions & 0 deletions experiments/1d_visualization/1d_visualization_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from matplotlib import pyplot as plt
from sim_transfer.sims.simulators import SinusoidsSim

import os
import pickle
import jax

PLOT_POST_SAMPLES = True

PLOTS_1D_DIR = os.path.dirname(os.path.abspath(__file__))
PLOT_DICT_DIR = os.path.join(PLOTS_1D_DIR, 'plot_dicts')
PLOT_DICT_PATHS = [
('BNN_SVGD', 'SinusoidsSim_BNN_SVGD_2.pkl'),
('BNN_FSVGD', 'SinusoidsSim_BNN_FSVGD_2.pkl'),
('BNN_FSVGD_SimPrior_gp', 'SinusoidsSim_BNN_FSVGD_SimPrior_gp_2.pkl'),
('BNN_FSVGD_SimPrior_nu-method', 'SinusoidsSim_BNN_FSVGD_SimPrior_nu-method_2.pkl'),
('BNN_FSVGD_SimPrior_kde', 'SinusoidsSim_BNN_FSVGD_SimPrior_kde_2.pkl'),
]
PLOT_DICT_PATHS = map(lambda x: (x[0], os.path.join(PLOT_DICT_DIR, x[1])), PLOT_DICT_PATHS)

PLOT_MODELS = ['BNN_SVGD', 'BNN_FSVGD', 'BNN_FSVGD_SimPrior_gp']


# draw the plot
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(3 * 4, 6))

sim = SinusoidsSim(output_size=1)

for i, (model, load_path) in enumerate(PLOT_DICT_PATHS):
with open(load_path, 'rb') as f:
plot_dict = pickle.load(f)
print(f'Plot dict loaded from {load_path}')
plot_data = plot_dict['plot_data']

if i == 0:
""" plot samples from the simulation env """
f_sim = sim.sample_function_vals(plot_data['x_plot'], num_samples=10, rng_key=jax.random.PRNGKey(234234))
for j in range(f_sim.shape[0]):
axes[0][0].plot(plot_data['x_plot'], f_sim[j])
axes[0][0].set_title('sampled functions from sim prior')
axes[0][0].set_ylim((-14, 14))


ax = axes[(i+1)//3][(i+1)%3]
if PLOT_POST_SAMPLES:
for k, y in enumerate(plot_data['y_post_samples']):
ax.plot(plot_data['x_plot'], y[:, i], linewidth=0.2, color='tab:green', alpha=0.5,
label='BNN particles' if k == 0 else None)

ax.scatter(plot_data['x_train'].flatten(), plot_data['y_train'][:, i], 100, label='train points', marker='x',
linewidths=2.5, color='tab:blue')
ax.plot(plot_data['x_plot'], plot_data['true_fun'], label='true fun')
ax.plot(plot_data['x_plot'].flatten(), plot_data['pred_mean'][:, i], label='pred mean')
ax.fill_between(plot_data['x_plot'].flatten(), plot_data['pred_mean'][:, i] - 2 * plot_data['pred_std'][:, i],
plot_data['pred_mean'][:, i] + 2 * plot_data['pred_std'][:, i], alpha=0.2,
label='95 % CI', color='tab:orange')

if i == 4:
ax.legend()
ax.set_title(model)
ax.set_ylim((-14, 14))
fig.tight_layout()
fig.show()
fig.savefig(os.path.join(PLOTS_1D_DIR, '1d_visualization.pdf'))
146 changes: 146 additions & 0 deletions experiments/1d_visualization/1d_visualization_run_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from sim_transfer.sims.simulators import SinusoidsSim, QuadraticSim, LinearSim, ShiftedSinusoidsSim
from sim_transfer.models import BNN_FSVGD_SimPrior, BNN_SVGD, BNN_FSVGD
from matplotlib import pyplot as plt

import pickle
import os
import jax
import jax.numpy as jnp


# determine plot_dict_dir
plot_dict_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'plot_dicts')
os.makedirs(plot_dict_dir, exist_ok=True)


def _key_iter(data_seed: int = 24359):
key = jax.random.PRNGKey(data_seed)
while True:
key, new_key = jax.random.split(key)
yield new_key


def main(sim_type: str = 'SinusoidsSim', model: str = 'BNN_FSVGD_SimPrior_gp', num_train_points: int = 1,
plot_post_samples: bool = True, fun_seed: int = 24359):
key_iter = _key_iter()

if sim_type == 'QuadraticSim':
sim = QuadraticSim()
elif sim_type == 'LinearSim':
sim = LinearSim()
elif sim_type == 'SinusoidsSim':
sim = SinusoidsSim(output_size=1)
else:
raise NotImplementedError

x_plot = jnp.linspace(sim.domain.l, sim.domain.u, 100).reshape(-1, 1)

# """ plot samples from the simulation env """
# f_sim = sim.sample_function_vals(x_plot, num_samples=10, rng_key=jax.random.PRNGKey(234234))
# for i in range(f_sim.shape[0]):
# plt.plot(x_plot, f_sim[i])
# plt.show()

""" generate data """
fun = sim.sample_function(rng_key=jax.random.PRNGKey(291)) # 764
x_train = jax.random.uniform(key=next(key_iter), shape=(50,),
minval=sim.domain.l, maxval=sim.domain.u).reshape(-1, 1)
x_train = x_train[:num_train_points]
y_train = fun(x_train)
x_test = jnp.linspace(sim.domain.l, sim.domain.u, 100).reshape(-1, 1)
y_test = fun(x_test)

""" fit the model """
common_kwargs = {
'input_size': 1,
'output_size': 1,
'rng_key': next(key_iter),
'hidden_layer_sizes': [64, 64, 64],
'data_batch_size': 4,
'num_particles': 20,
'likelihood_std': 0.05,
'normalization_stats': sim.normalization_stats,
}
if model == 'BNN_SVGD':
bnn = BNN_SVGD(**common_kwargs, bandwidth_svgd=10., num_train_steps=2)
elif model == 'BNN_FSVGD':
bnn = BNN_FSVGD(**common_kwargs, domain=sim.domain, bandwidth_svgd=0.5, num_measurement_points=8)
elif model == 'BNN_FSVGD_SimPrior_gp':
bnn = BNN_FSVGD_SimPrior(**common_kwargs, domain=sim.domain, function_sim=sim,
num_train_steps=20000, num_f_samples=256, num_measurement_points=8,
bandwidth_svgd=1., score_estimator='gp')
elif model == 'BNN_FSVGD_SimPrior_kde':
bnn = BNN_FSVGD_SimPrior(**common_kwargs, domain=sim.domain, function_sim=sim,
num_train_steps=40000, num_f_samples=256, num_measurement_points=16,
bandwidth_svgd=1., score_estimator='kde')
elif model == 'BNN_FSVGD_SimPrior_nu-method':
bnn = BNN_FSVGD_SimPrior(**common_kwargs, domain=sim.domain, function_sim=sim,
num_train_steps=20000, num_f_samples=256, num_measurement_points=16,
bandwidth_svgd=1., score_estimator='nu-method', bandwidth_score_estim=1.0)

else:
raise NotImplementedError

bnn.fit(x_train, y_train, x_eval=x_test, y_eval=y_test)

""" make predictions and save the plot """
x_plot = jnp.linspace(sim.domain.l, sim.domain.u, 200).reshape((-1, 1))

# make predictions
pred_mean, pred_std = bnn.predict(x_plot)
y_post_samples = bnn.predict_post_samples(x_plot)

# get true function value
true_fun = fun(x_plot)
typical_fun = sim._typical_f(x_plot)

plot_dict = {
'model': model,
'plot_data': {
'x_train': x_train,
'y_train': y_train,
'x_plot': x_plot,
'true_fun': true_fun,
'pred_mean': pred_mean,
'pred_std': pred_std,
'y_post_samples': y_post_samples,
}
}
dump_path = os.path.join(plot_dict_dir, f'{sim_type}_{model}_{num_train_points}.pkl')

with open(dump_path, 'wb') as f:
pickle.dump(plot_dict, f)
print(f'Plot dict saved to {dump_path}')

# draw the plot
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(1 * 4, 4))
if bnn.output_size == 1:
ax = [ax]
for i in range(1):
ax[i].scatter(x_train.flatten(), y_train[:, i], label='train points')
ax[i].plot(x_plot, fun(x_plot)[:, i], label='true fun')
ax[i].plot(x_plot, typical_fun, label='typical fun')
ax[i].plot(x_plot.flatten(), pred_mean[:, i], label='pred mean')
ax[i].fill_between(x_plot.flatten(), pred_mean[:, i] - 2 * pred_std[:, i],
pred_mean[:, i] + 2 * pred_std[:, i], alpha=0.3)

if plot_post_samples:
y_post_samples = bnn.predict_post_samples(x_plot)
for y in y_post_samples:
ax[i].plot(x_plot, y[:, i], linewidth=0.2, color='green')

ax[i].legend()
fig.suptitle(model)
fig.show()


if __name__ == '__main__':
for num_train_points in [2]: #, 3, 5]:
for model in [
'BNN_SVGD',
'BNN_FSVGD',
'BNN_FSVGD_SimPrior_gp',
'BNN_FSVGD_SimPrior_nu-method',
'BNN_FSVGD_SimPrior_kde'
]:
main(model=model, num_train_points=num_train_points)
32 changes: 19 additions & 13 deletions experiments/data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
'likelihood_std': {'value': _RACECAR_NOISE_STD_ENCODED.tolist()},
'num_samples_train': {'value': 200},
} for name in ['real_racecar_new', 'real_racecar_new_only_pose', 'real_racecar_new_no_angvel',
'real_racecar_new_actionstack', 'real_racecar_v2']
'real_racecar_new_actionstack', 'real_racecar_v2', 'real_racecar_v3', 'real_racecar_v4']
})


Expand Down Expand Up @@ -348,20 +348,26 @@ def provide_data_and_sim(data_source: str, data_spec: Dict[str, Any], data_seed:
y_test = sim_for_sampling_data._typical_f(x_test)
return x_train, y_train, x_test, y_test, sim
elif data_source == 'racecar_hf':
sim_hf = RaceCarSim(encode_angle=True, use_blend=True, only_pose=False)
sim_lf = RaceCarSim(encode_angle=True, use_blend=False, only_pose=False)
car_id = data_spec.get('car_id', 2)
sim_hf = RaceCarSim(encode_angle=True, use_blend=True, only_pose=False, car_id=car_id)
sim_lf = RaceCarSim(encode_angle=True, use_blend=False, only_pose=False, car_id=car_id)
elif data_source == 'racecar_hf_only_pose':
sim_hf = RaceCarSim(encode_angle=True, use_blend=True, only_pose=True)
sim_lf = RaceCarSim(encode_angle=True, use_blend=False, only_pose=True)
car_id = data_spec.get('car_id', 2)
sim_hf = RaceCarSim(encode_angle=True, use_blend=True, only_pose=True, car_id=car_id)
sim_lf = RaceCarSim(encode_angle=True, use_blend=False, only_pose=True, car_id=car_id)
elif data_source == 'racecar_hf_no_angvel':
sim_hf = RaceCarSim(encode_angle=True, use_blend=True, no_angular_velocity=True)
sim_lf = RaceCarSim(encode_angle=True, use_blend=False, no_angular_velocity=True)
car_id = data_spec.get('car_id', 2)
sim_hf = RaceCarSim(encode_angle=True, use_blend=True, no_angular_velocity=True, car_id=car_id)
sim_lf = RaceCarSim(encode_angle=True, use_blend=False, no_angular_velocity=True, car_id=car_id)
elif data_source == 'racecar_only_pose':
sim_hf = sim_lf = RaceCarSim(encode_angle=True, use_blend=True, only_pose=True)
car_id = data_spec.get('car_id', 2)
sim_hf = sim_lf = RaceCarSim(encode_angle=True, use_blend=True, only_pose=True, car_id=car_id)
elif data_source == 'racecar_no_angvel':
sim_hf = sim_lf = RaceCarSim(encode_angle=True, use_blend=True, no_angular_velocity=True)
car_id = data_spec.get('car_id', 2)
sim_hf = sim_lf = RaceCarSim(encode_angle=True, use_blend=True, no_angular_velocity=True, car_id=car_id)
elif data_source == 'racecar':
sim_hf = sim_lf = RaceCarSim(encode_angle=True, use_blend=True, only_pose=False)
car_id = data_spec.get('car_id', 2)
sim_hf = sim_lf = RaceCarSim(encode_angle=True, use_blend=True, only_pose=False, car_id=car_id)
else:
raise ValueError(f'Unknown data source {data_source}')
assert {'num_samples_train'} <= set(data_spec.keys()) <= {'num_samples_train'}.union(DEFAULTS_RACECAR.keys())
Expand All @@ -385,10 +391,10 @@ def provide_data_and_sim(data_source: str, data_spec: Dict[str, Any], data_seed:
elif data_source.startswith('real_racecar_new'):
x_train, y_train, x_test, y_test = get_rccar_recorded_data_new(encode_angle=True, action_stacking=False,
action_delay=3, car_id=car_id)
elif data_source.startswith('real_racecar_v2'):
elif data_source.startswith('real_racecar_v3'):
x_train, y_train, x_test, y_test = get_rccar_recorded_data_new(encode_angle=True, action_stacking=False,
action_delay=3, car_id=car_id,
dataset='v2')
dataset='v3')
else:
x_train, y_train, x_test, y_test = get_rccar_recorded_data(encode_angle=True)

Expand All @@ -411,7 +417,7 @@ def provide_data_and_sim(data_source: str, data_spec: Dict[str, Any], data_seed:
elif sampling_scheme == 'consecutive':
# sample random sub-trajectory (datapoints are adjacent in time -> highly correlated)
offset_train = jax.random.choice(key_train, jnp.arange(num_train_available - num_train))
offset_test = jax.random.choice(key_test, jnp.arange(num_test_available - num_test))
offset_test = jax.random.choice(key_test, jnp.arange(num_test_available - num_test + 1))
idx_train = jnp.arange(num_train) + offset_train
idx_test = jnp.arange(num_test) + offset_test
else:
Expand Down
2 changes: 1 addition & 1 deletion experiments/meta_learning_exp/run_meta_learning_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,11 @@ def generate_task(key):
rng_key=jax.random.PRNGKey(model_seed),
num_train_steps=num_iter_meta_train,
use_cross_attention=use_cross_attention,
num_points_context=args.num_samples_train//2,
latent_dim=latent_dim,
num_f_samples=16,
num_z_samples=4,
num_points_target=16,
num_points_context=8,
hidden_dim=hidden_dim,
use_self_attention=use_self_attention,
likelihood_std=likelihood_std,
Expand Down
13 changes: 7 additions & 6 deletions experiments/meta_learning_exp/sweep_meta_learning_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@
MODEL_SPECIFIC_CONFIG = {
'PACOH': {
'prior_weight': {'distribution': 'log_uniform_10', 'min': -1., 'max': 0.},
'num_iter_meta_train': {'values': [20000, 30000, 40000]},
'meta_batch_size': {'values': [4, 8, 16]},
'bandwidth': {'distribution': 'log_uniform_10', 'min': 0., 'max': 2.},
'lr': {'distribution': 'log_uniform_10', 'min': -4., 'max': -3}
'hyper_prior_weight': {'distribution': 'log_uniform_10', 'min': -4., 'max': 0.},
'num_iter_meta_train': {'values': [40000]},
'meta_batch_size': {'values': [4]},
'bandwidth': {'values': [10.]},
'lr': {'distribution': 'log_uniform_10', 'min': -3.5, 'max': -3}
},
'NP': {
'num_iter_meta_train': {'values': [60000]},
'num_iter_meta_train': {'values': [60000, 100000]},
'latent_dim': {'values': [64, 128, 256]},
'hidden_dim': {'values': [32, 64, 128]},
'hidden_dim': {'values': [128, 256, 512]},
'lr': {'distribution': 'log_uniform_10', 'min': -4., 'max': -3},
},
}
Expand Down
9 changes: 7 additions & 2 deletions experiments/offline_rl_from_recorded_data/exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def experiment(horizon_len: int,
input_from_recorded_data: int = 1,
obtain_consecutive_data: int = 1,
lr: float = 3e-4,
car_id: int = 2,
):
bnn_train_steps = min(num_epochs * num_offline_collected_transitions, max_train_steps)
bnn_train_steps = max(bnn_train_steps, min_train_steps)
Expand Down Expand Up @@ -77,7 +78,7 @@ def experiment(horizon_len: int,
num_evals=20,
reward_scaling=1,
episode_length=horizon_len,
episode_length_eval=200,
episode_length_eval=horizon_len,
action_repeat=1,
discounting=0.99,
lr_policy=1e-4,
Expand Down Expand Up @@ -135,6 +136,7 @@ def experiment(horizon_len: int,
input_from_recorded_data=input_from_recorded_data,
data_from_simulation=data_from_simulation,
likelihood_exponent=likelihood_exponent,
car_id=car_id,
)

total_config = SAC_KWARGS | config_dict | car_reward_kwargs
Expand All @@ -160,7 +162,8 @@ def experiment(horizon_len: int,
data_spec={'num_samples_train': 20_000,
'use_hf_sim': bool(high_fidelity),
'sampling': 'iid',
'num_stacked_actions': num_frame_stack},
'num_stacked_actions': num_frame_stack,
'car_id': car_id},
data_seed=int(int_data_seed),
)

Expand Down Expand Up @@ -347,6 +350,7 @@ def main(args):
input_from_recorded_data=args.input_from_recorded_data,
obtain_consecutive_data=args.obtain_consecutive_data,
lr=args.lr,
car_id=args.car_id,
)


Expand Down Expand Up @@ -387,5 +391,6 @@ def main(args):
parser.add_argument('--input_from_recorded_data', type=int, default=1)
parser.add_argument('--obtain_consecutive_data', type=int, default=1)
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--car_id', type=int, default=3)
args = parser.parse_args()
main(args)
Loading

0 comments on commit 47d170f

Please sign in to comment.