Skip to content

Commit

Permalink
fix pred_diff data generation issue in meta_learning exps
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasrothfuss committed Jan 10, 2024
1 parent c360188 commit be3bfcb
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions experiments/meta_learning_exp/run_meta_learning_exp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from experiments.data_provider import provide_data_and_sim
from experiments.util import hash_dict, NumpyArrayEncoder
from experiments.data_provider import DATASET_CONFIGS
from sim_transfer.sims.simulators import PredictStateChangeWrapper

import os
import tensorflow as tf
Expand All @@ -23,6 +24,7 @@ def meta_learning_experiment(
likelihood_std: float = 0.1,
num_iter_meta_train: int = 20000,
prior_weight: float = 1.0,
hyper_prior_weight: float = 1.0,
meta_batch_size: int = 16,
batch_size: int = 32,
bandwidth: float = 10.,
Expand All @@ -36,26 +38,41 @@ def meta_learning_experiment(
# provide data and sim
x_train, y_train, x_test, y_test, sim = provide_data_and_sim(
data_source=data_source,
data_spec={'num_samples_train': num_samples_train},
data_spec={'num_samples_train': 10000},
data_seed=data_seed)

# only take num_samples_train datapoints
assert num_samples_train <= 10000
x_train, y_train = x_train[:num_samples_train], y_train[:num_samples_train]

if bool(pred_diff):
assert x_train.shape[-1] == sim.input_size and y_train.shape[-1] == sim.output_size
y_train = y_train - x_train[..., :sim.output_size]
y_test = y_test - x_test[..., :sim.output_size]

meta_test_data = [tuple(map(lambda arr: np.array(arr), [x_train, y_train, x_test, y_test]))]


# generate meta-training data from sim
sim_data_key = jax.random.PRNGKey(model_seed + 1)
meta_training_data = []
for key in jax.random.split(sim_data_key, num_tasks):

@jax.jit
def generate_task(key):
key_x, key_y = jax.random.split(key)
x = sim.domain.sample_uniformly(key=key_x, sample_shape=500)
y = sim.sample_function_vals(x=x, num_samples=1, rng_key=key_y)[0]
y = y - x[..., :sim.output_size]
if pred_diff:
y = y - x[..., :sim.output_size]
return x, y

meta_training_data = []
for key in jax.random.split(sim_data_key, num_tasks):
x, y = generate_task(key)
meta_training_data.append((np.array(x), np.array(y)))

if bool(pred_diff):
sim = PredictStateChangeWrapper(sim)

if model == 'PACOH':
from baselines.pacoh_nn.pacoh_nn_regression import PACOH_NN_Regression
# run meta-learning
Expand All @@ -68,19 +85,21 @@ def meta_learning_experiment(
activation='leaky_relu',
likelihood_std=likelihood_std,
prior_weight=prior_weight,
hyper_prior_weight=hyper_prior_weight,
meta_batch_size=meta_batch_size,
batch_size=batch_size,
bandwidth=bandwidth,
lr=lr)

# run meta-testing
pacoh_model.meta_fit(meta_val_data=meta_test_data, eval_period=5000)
pacoh_model.meta_fit(meta_val_data=meta_test_data, eval_period=4000)

y_preds, pred_dist = pacoh_model.meta_predict(x_train, y_train, x_test)
nll = - float(tf.reduce_mean(pred_dist.log_prob(y_test)))
rmse = float(tf.sqrt(tf.reduce_mean(tf.reduce_sum((pred_dist.mean() - y_test)**2, axis=-1))))
avg_std = float(tf.reduce_mean(pred_dist.stddev()))
eval_stats = {'nll': nll, 'rmse': rmse, 'avg_std': avg_std}

elif model == 'BNN':
from baselines.pacoh_nn.bnn import BayesianNeuralNetworkSVGD
pacoh_model = BayesianNeuralNetworkSVGD(x_train, y_train,
Expand Down Expand Up @@ -186,7 +205,7 @@ def main(args):

# data parameters
parser.add_argument('--data_source', type=str, default='racecar')
parser.add_argument('--num_samples_train', type=int, default=20)
parser.add_argument('--num_samples_train', type=int, default=100)
parser.add_argument('--data_seed', type=int, default=77698)
parser.add_argument('--num_tasks', type=int, default=200)

Expand All @@ -196,7 +215,7 @@ def main(args):
parser.add_argument('--model', type=str, default='PACOH')
parser.add_argument('--model_seed', type=int, default=892616)
parser.add_argument('--likelihood_std', type=float, default=None)
parser.add_argument('--num_iter_meta_train', type=int, default=50000)
parser.add_argument('--num_iter_meta_train', type=int, default=100)
parser.add_argument('--pred_diff', type=int, default=1)

# -- PACOH
Expand All @@ -205,6 +224,7 @@ def main(args):
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--bandwidth', type=float, default=10.)
parser.add_argument('--learn_likelihood_std', type=int, default=0)
parser.add_argument('--hyper_prior_weight', type=float, default=1e-3)

# -- NP
parser.add_argument('--use_cross_attention', type=int, default=1)
Expand Down

0 comments on commit be3bfcb

Please sign in to comment.