Skip to content

Commit

Permalink
Merge branch 'main' of github.com:lasgroup/simulation_transfer
Browse files Browse the repository at this point in the history
  • Loading branch information
sukhijab committed Jan 29, 2024
2 parents c1dc56b + a05da7e commit ecebb0e
Showing 1 changed file with 20 additions and 16 deletions.
36 changes: 20 additions & 16 deletions experiments/sim_real_transfer_exp/run_regression_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
'swish': jax.nn.swish,
}

OUTPUTSCALES_RCCAR = [0.008, 0.008, 0.009, 0.009, 0.05, 0.05, 0.2]



def regression_experiment(
Expand Down Expand Up @@ -89,8 +89,8 @@ def regression_experiment(
# provide data and sim
x_train, y_train, x_test, y_test, sim = provide_data_and_sim(
data_source=data_source,
data_spec={'num_samples_train': 6500, 'sampling': 'iid',
'use_hf_sim': bool(use_hf_sim), 'num_samples_test': 6000},
data_spec={'num_samples_train': 6500, 'sampling': 'iid', 'use_hf_sim': bool(use_hf_sim),
'num_samples_test': 6000},
data_seed=data_seed)
x_train = x_train[:num_samples_train]
y_train = y_train[:num_samples_train]
Expand Down Expand Up @@ -184,7 +184,7 @@ def regression_experiment(
base_bnn=base_bnn,
sim=sim,
use_base_bnn=(model == 'GreyBox'),
num_sim_model_train_steps=20_000,
num_sim_model_train_steps=5_000,
)
elif model == 'BNN_MMD_SimPrior':
model = BNN_MMD_SimPrior(domain=sim.domain,
Expand Down Expand Up @@ -238,6 +238,11 @@ def main(args):
print(f"Setting likelihood_std to data_source default value from DATASET_CONFIGS "
f"which is {exp_params['likelihood_std']}")

if args.use_hf_sim:
OUTPUTSCALES_RCCAR = [0.008, 0.008, 0.009, 0.009, 0.05, 0.05, 0.2]
else:
OUTPUTSCALES_RCCAR = [0.008, 0.008, 0.03, 0.03, 0.3, 0.3, 1.5]

# custom gp outputscale for racecar_hf
if 'real_racecar' in exp_params['data_source']:
outputscales_racecar = exp_params['added_gp_outputscale'] * jnp.array(OUTPUTSCALES_RCCAR)
Expand Down Expand Up @@ -309,17 +314,17 @@ def main(args):
# data parameters
parser.add_argument('--data_source', type=str, default='real_racecar_v3')
parser.add_argument('--pred_diff', type=int, default=1)
parser.add_argument('--num_samples_train', type=int, default=50)
parser.add_argument('--data_seed', type=int, default=77698)
parser.add_argument('--num_samples_train', type=int, default=6400)
parser.add_argument('--data_seed', type=int, default=127748)

# standard BNN parameters
parser.add_argument('--model', type=str, default='SysID')
parser.add_argument('--model_seed', type=int, default=892616)
parser.add_argument('--model', type=str, default='BNN_FSVGD_SimPrior_gp')
parser.add_argument('--model_seed', type=int, default=2342)
parser.add_argument('--likelihood_std', type=float, default=None)
parser.add_argument('--learn_likelihood_std', type=int, default=1)
parser.add_argument('--likelihood_reg', type=float, default=0.0)
parser.add_argument('--likelihood_reg', type=float, default=10.0)
parser.add_argument('--data_batch_size', type=int, default=8)
parser.add_argument('--min_train_steps', type=int, default=2500)
parser.add_argument('--min_train_steps', type=int, default=8_000)
parser.add_argument('--num_epochs', type=int, default=60)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--hidden_activation', type=str, default='leaky_relu')
Expand All @@ -330,26 +335,25 @@ def main(args):

# SVGD parameters
parser.add_argument('--num_particles', type=int, default=20)
parser.add_argument('--bandwidth_svgd', type=float, default=10.0)
parser.add_argument('--bandwidth_svgd', type=float, default=10.)
parser.add_argument('--weight_prior_std', type=float, default=0.5)
parser.add_argument('--bias_prior_std', type=float, default=1.0)

# FSVGD parameters
parser.add_argument('--bandwidth_gp_prior', type=float, default=0.1)
parser.add_argument('--outputscale_gp_prior', type=float, default=1.0)
parser.add_argument('--bandwidth_gp_prior', type=float, default=0.4)
parser.add_argument('--outputscale_gp_prior', type=float, default=0.1)
parser.add_argument('--num_measurement_points', type=int, default=32)

# FSVGD_SimPrior parameters
parser.add_argument('--bandwidth_score_estim', type=float, default=None)
parser.add_argument('--ssge_kernel_type', type=str, default='IMQ')
parser.add_argument('--num_f_samples', type=int, default=128)
parser.add_argument('--num_f_samples', type=int, default=1028)
parser.add_argument('--switch_score_estimator_frac', type=float, default=0.6667)
parser.add_argument('--use_hf_sim', type=int, default=1)


# Additive SimPrior GP parameters
parser.add_argument('--added_gp_lengthscale', type=float, default=5.)
parser.add_argument('--added_gp_outputscale', type=float, default=1.0)
parser.add_argument('--added_gp_outputscale', type=float, default=20.)

# FSVGD_SimPrior parameters
parser.add_argument('--num_distill_steps', type=int, default=50000)
Expand Down

0 comments on commit ecebb0e

Please sign in to comment.