Skip to content

Commit

Permalink
1d visualization scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasrothfuss committed Jan 15, 2024
1 parent 0dc6f02 commit 81c496c
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 4 deletions.
46 changes: 46 additions & 0 deletions experiments/1d_visualization/1d_visualization_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import os
import pickle
from matplotlib import pyplot as plt

PLOT_POST_SAMPLES = True

PLOT_DICT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'plot_dicts')
PLOT_DICT_PATHS = [
('BNN_SVGD', 'SinusoidsSim_BNN_SVGD_2.pkl'),
('BNN_FSVGD', 'SinusoidsSim_BNN_FSVGD_2.pkl'),
('BNN_FSVGD_SimPrior_gp', 'SinusoidsSim_BNN_FSVGD_SimPrior_gp_2.pkl'),
('BNN_FSVGD_SimPrior_nu-method', 'SinusoidsSim_BNN_FSVGD_SimPrior_nu-method_2.pkl'),
('BNN_FSVGD_SimPrior_kde', 'SinusoidsSim_BNN_FSVGD_SimPrior_kde_2.pkl'),
]
PLOT_DICT_PATHS = map(lambda x: (x[0], os.path.join(PLOT_DICT_DIR, x[1])), PLOT_DICT_PATHS)

PLOT_MODELS = ['BNN_SVGD', 'BNN_FSVGD', 'BNN_FSVGD_SimPrior_gp']


# draw the plot
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(3 * 4, 6))


for i, (model, load_path) in enumerate(PLOT_DICT_PATHS):
with open(load_path, 'rb') as f:
plot_dict = pickle.load(f)
print(f'Plot dict loaded from {load_path}')
plot_data = plot_dict['plot_data']

ax = axes[i//3][i%3]
ax.scatter(plot_data['x_train'].flatten(), plot_data['y_train'][:, i], label='train points')
ax.plot(plot_data['x_plot'], plot_data['true_fun'], label='true fun')
ax.plot(plot_data['x_plot'].flatten(), plot_data['pred_mean'][:, i], label='pred mean')
ax.fill_between(plot_data['x_plot'].flatten(), plot_data['pred_mean'][:, i] - 2 * plot_data['pred_std'][:, i],
plot_data['pred_mean'][:, i] + 2 * plot_data['pred_std'][:, i], alpha=0.3)

if PLOT_POST_SAMPLES:
for y in plot_data['y_post_samples']:
ax.plot(plot_data['x_plot'], y[:, i], linewidth=0.2, color='green')

if i == 2:
ax.legend()
ax.set_title(model)
ax.set_ylim((-14, 14))
fig.tight_layout()
fig.show()
146 changes: 146 additions & 0 deletions experiments/1d_visualization/1d_visualization_run_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from sim_transfer.sims.simulators import SinusoidsSim, QuadraticSim, LinearSim, ShiftedSinusoidsSim
from sim_transfer.models import BNN_FSVGD_SimPrior, BNN_SVGD, BNN_FSVGD
from matplotlib import pyplot as plt

import pickle
import os
import jax
import jax.numpy as jnp


# determine plot_dict_dir
plot_dict_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'plot_dicts')
os.makedirs(plot_dict_dir, exist_ok=True)


def _key_iter(data_seed: int = 24359):
key = jax.random.PRNGKey(data_seed)
while True:
key, new_key = jax.random.split(key)
yield new_key


def main(sim_type: str = 'SinusoidsSim', model: str = 'BNN_FSVGD_SimPrior_gp', num_train_points: int = 1,
plot_post_samples: bool = True, fun_seed: int = 24359):
key_iter = _key_iter()

if sim_type == 'QuadraticSim':
sim = QuadraticSim()
elif sim_type == 'LinearSim':
sim = LinearSim()
elif sim_type == 'SinusoidsSim':
sim = SinusoidsSim(output_size=1)
else:
raise NotImplementedError

x_plot = jnp.linspace(sim.domain.l, sim.domain.u, 100).reshape(-1, 1)

# """ plot samples from the simulation env """
# f_sim = sim.sample_function_vals(x_plot, num_samples=10, rng_key=jax.random.PRNGKey(234234))
# for i in range(f_sim.shape[0]):
# plt.plot(x_plot, f_sim[i])
# plt.show()

""" generate data """
fun = sim.sample_function(rng_key=jax.random.PRNGKey(291)) # 764
x_train = jax.random.uniform(key=next(key_iter), shape=(50,),
minval=sim.domain.l, maxval=sim.domain.u).reshape(-1, 1)
x_train = x_train[:num_train_points]
y_train = fun(x_train)
x_test = jnp.linspace(sim.domain.l, sim.domain.u, 100).reshape(-1, 1)
y_test = fun(x_test)

""" fit the model """
common_kwargs = {
'input_size': 1,
'output_size': 1,
'rng_key': next(key_iter),
'hidden_layer_sizes': [64, 64, 64],
'data_batch_size': 4,
'num_particles': 20,
'likelihood_std': 0.05,
'normalization_stats': sim.normalization_stats,
}
if model == 'BNN_SVGD':
bnn = BNN_SVGD(**common_kwargs, bandwidth_svgd=10., num_train_steps=2)
elif model == 'BNN_FSVGD':
bnn = BNN_FSVGD(**common_kwargs, domain=sim.domain, bandwidth_svgd=0.5, num_measurement_points=8)
elif model == 'BNN_FSVGD_SimPrior_gp':
bnn = BNN_FSVGD_SimPrior(**common_kwargs, domain=sim.domain, function_sim=sim,
num_train_steps=20000, num_f_samples=256, num_measurement_points=8,
bandwidth_svgd=1., score_estimator='gp')
elif model == 'BNN_FSVGD_SimPrior_kde':
bnn = BNN_FSVGD_SimPrior(**common_kwargs, domain=sim.domain, function_sim=sim,
num_train_steps=40000, num_f_samples=256, num_measurement_points=16,
bandwidth_svgd=1., score_estimator='kde')
elif model == 'BNN_FSVGD_SimPrior_nu-method':
bnn = BNN_FSVGD_SimPrior(**common_kwargs, domain=sim.domain, function_sim=sim,
num_train_steps=20000, num_f_samples=256, num_measurement_points=16,
bandwidth_svgd=1., score_estimator='nu-method', bandwidth_score_estim=1.0)

else:
raise NotImplementedError

bnn.fit(x_train, y_train, x_eval=x_test, y_eval=y_test)

""" make predictions and save the plot """
x_plot = jnp.linspace(sim.domain.l, sim.domain.u, 200).reshape((-1, 1))

# make predictions
pred_mean, pred_std = bnn.predict(x_plot)
y_post_samples = bnn.predict_post_samples(x_plot)

# get true function value
true_fun = fun(x_plot)
typical_fun = sim._typical_f(x_plot)

plot_dict = {
'model': model,
'plot_data': {
'x_train': x_train,
'y_train': y_train,
'x_plot': x_plot,
'true_fun': true_fun,
'pred_mean': pred_mean,
'pred_std': pred_std,
'y_post_samples': y_post_samples,
}
}
dump_path = os.path.join(plot_dict_dir, f'{sim_type}_{model}_{num_train_points}.pkl')

with open(dump_path, 'wb') as f:
pickle.dump(plot_dict, f)
print(f'Plot dict saved to {dump_path}')

# draw the plot
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(1 * 4, 4))
if bnn.output_size == 1:
ax = [ax]
for i in range(1):
ax[i].scatter(x_train.flatten(), y_train[:, i], label='train points')
ax[i].plot(x_plot, fun(x_plot)[:, i], label='true fun')
ax[i].plot(x_plot, typical_fun, label='typical fun')
ax[i].plot(x_plot.flatten(), pred_mean[:, i], label='pred mean')
ax[i].fill_between(x_plot.flatten(), pred_mean[:, i] - 2 * pred_std[:, i],
pred_mean[:, i] + 2 * pred_std[:, i], alpha=0.3)

if plot_post_samples:
y_post_samples = bnn.predict_post_samples(x_plot)
for y in y_post_samples:
ax[i].plot(x_plot, y[:, i], linewidth=0.2, color='green')

ax[i].legend()
fig.suptitle(model)
fig.show()


if __name__ == '__main__':
for num_train_points in [2]: #, 3, 5]:
for model in [
'BNN_SVGD',
'BNN_FSVGD',
'BNN_FSVGD_SimPrior_gp',
'BNN_FSVGD_SimPrior_nu-method',
'BNN_FSVGD_SimPrior_kde'
]:
main(model=model, num_train_points=num_train_points)
8 changes: 4 additions & 4 deletions experiments/data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
'likelihood_std': {'value': _RACECAR_NOISE_STD_ENCODED.tolist()},
'num_samples_train': {'value': 200},
} for name in ['real_racecar_new', 'real_racecar_new_only_pose', 'real_racecar_new_no_angvel',
'real_racecar_new_actionstack', 'real_racecar_v2']
'real_racecar_new_actionstack', 'real_racecar_v2', 'real_racecar_v3', 'real_racecar_v4']
})


Expand Down Expand Up @@ -385,10 +385,10 @@ def provide_data_and_sim(data_source: str, data_spec: Dict[str, Any], data_seed:
elif data_source.startswith('real_racecar_new'):
x_train, y_train, x_test, y_test = get_rccar_recorded_data_new(encode_angle=True, action_stacking=False,
action_delay=3, car_id=car_id)
elif data_source.startswith('real_racecar_v2'):
elif data_source.startswith('real_racecar_v3'):
x_train, y_train, x_test, y_test = get_rccar_recorded_data_new(encode_angle=True, action_stacking=False,
action_delay=3, car_id=car_id,
dataset='v2')
dataset='v3')
else:
x_train, y_train, x_test, y_test = get_rccar_recorded_data(encode_angle=True)

Expand All @@ -411,7 +411,7 @@ def provide_data_and_sim(data_source: str, data_spec: Dict[str, Any], data_seed:
elif sampling_scheme == 'consecutive':
# sample random sub-trajectory (datapoints are adjacent in time -> highly correlated)
offset_train = jax.random.choice(key_train, jnp.arange(num_train_available - num_train))
offset_test = jax.random.choice(key_test, jnp.arange(num_test_available - num_test))
offset_test = jax.random.choice(key_test, jnp.arange(num_test_available - num_test + 1))
idx_train = jnp.arange(num_train) + offset_train
idx_test = jnp.arange(num_test) + offset_test
else:
Expand Down
17 changes: 17 additions & 0 deletions sim_transfer/sims/simulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,23 @@ def _f1(self, amp, freq, slope, x):
def _f2(self, amp, freq, slope, x):
return amp * jnp.cos(freq * x) - slope * x

def sample_function(self, rng_key: jax.random.PRNGKey) -> Callable:
key1, key2, key3, key4 = jax.random.split(rng_key, 4)
freq = jax.random.uniform(key1, minval=self.freq1_mid - self.freq1_spread,
maxval=self.freq1_mid + self.freq1_spread)
amp = self.amp_mean + self.amp_std * jax.random.normal(key2)
slope = self.slope_mean + self.slope_std * jax.random.normal(key3)
f = lambda x: self._f1(amp, freq, slope, x)
if self.output_size == 1:
return f
elif self.output_size == 2:
freq2 = jax.random.uniform(key4, minval=self.freq2_mid - self.freq2_spread,
maxval=self.freq2_mid + self.freq2_spread)
f2 = lambda x: self._f2(amp, freq2, slope, x)
return lambda x: jnp.concatenate([f(x)[:, None], f2(x)[:, None]], axis=-1)
else:
raise NotImplementedError

def _typical_f(self, x: jnp.array) -> jnp.array:
assert x.ndim == 2 and x.shape[-1] == self.input_size
f = self._f1(self.amp_mean, self.freq1_mid, self.slope_mean, x)
Expand Down

0 comments on commit 81c496c

Please sign in to comment.