Skip to content

Commit

Permalink
minor change: set kernel params earlier for nu method
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasrothfuss committed Aug 24, 2023
1 parent 0a4c59e commit 6402dc8
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 13 deletions.
2 changes: 1 addition & 1 deletion experiments/meta_learning_exp/sweep_meta_learning_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
'bandwidth': {'distribution': 'log_uniform_10', 'min': 0., 'max': 2.},
},
'NP': {
'num_iter_meta_train': {'values': [40000]},
'num_iter_meta_train': {'values': [100000]},
'latent_dim': {'values': [64, 128]},
'hidden_dim': {'values': [32, 64]},
'lr': {'distribution': 'log_uniform_10', 'min': -4., 'max': -3},
Expand Down
14 changes: 8 additions & 6 deletions experiments/regression_exp/inspect_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,21 @@ def main(args, drop_nan=False):

print('Models:', set(df_agg['model']))

different_method_plot(df_agg, metric='nll')
different_method_plot(df_agg, metric='rmse')
different_method_plot(df_agg, metric='nll', filter_std_higher_than=50.)
different_method_plot(df_agg, metric='rmse', filter_std_higher_than=0.2)

df_method = df_agg[(df_agg['model'] == 'BNN_FSVGD_SimPrior_gp')]
df_method = df_agg[(df_agg['model'] == 'BNN_FSVGD_SimPrior_nu-method')]
#df_method = df_method[(df_method['num_train_steps'] == 80000)]

#df_method = df_method[df_method['bandwidth_score_estim'] > 1.0]

metric = 'nll'
for param in ['num_f_samples', 'bandwidth_score_estim', 'bandwidth_svgd', 'num_measurement_points']:
plt.scatter(df_method[param], df_method[(metric, 'mean')])
plt.xlabel(param)
plt.xscale('log')
#plt.xscale('log')
plt.ylabel(metric)
plt.ylim(-15, 20)
plt.show()

QUANTILE_BASED_CI = True
Expand Down Expand Up @@ -157,7 +159,7 @@ def main(args, drop_nan=False):

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Inspect results of a regression experiment.')
parser.add_argument('--exp_name', type=str, default='aug22')
parser.add_argument('--data_source', type=str, default='racecar')
parser.add_argument('--exp_name', type=str, default='june09')
parser.add_argument('--data_source', type=str, default='pendulum')
args = parser.parse_args()
main(args)
9 changes: 5 additions & 4 deletions experiments/regression_exp/sweep_regression_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,23 @@
},
'BNN_FSVGD_SimPrior_nu-method': {
'bandwidth_svgd': {'distribution': 'log_uniform_10', 'min': -1.0, 'max': 0.0},
'num_train_steps': {'values': [40000]},
'num_train_steps': {'values': [30000, 40000, 50000]},
'num_measurement_points': {'values': [16, 32]},
'num_f_samples': {'values': [512]},
'bandwidth_score_estim': {'distribution': 'log_uniform_10', 'min': 0.0, 'max': 0.5},
#'bandwidth_score_estim': {'distribution': 'log_uniform_10', 'min': -0.5, 'max': 0.5},
'bandwidth_score_estim': {'distribution': 'uniform', 'min': 0.5, 'max': 4.0},
},
'BNN_FSVGD_SimPrior_gp+nu-method': {
'bandwidth_svgd': {'distribution': 'log_uniform_10', 'min': -1.0, 'max': 0.0},
'num_train_steps': {'values': [40000]},
'num_train_steps': {'values': [80000]},
'num_measurement_points': {'values': [16, 32]},
'num_f_samples': {'values': [512]},
'switch_score_estimator_frac': {'values': [0.6667]},
'bandwidth_score_estim': {'distribution': 'log_uniform_10', 'min': 0.0, 'max': 0.5},
},
'BNN_FSVGD_SimPrior_kde': {
'bandwidth_svgd': {'distribution': 'log_uniform', 'min': -2., 'max': 2.},
'num_train_steps': {'values': [20000]},
'num_train_steps': {'values': [40000]},
'num_measurement_points': {'values': [16, 32]},
'num_f_samples': {'values': [512, 1024, 2056]},
},
Expand Down
4 changes: 2 additions & 2 deletions sim_transfer/score_estimation/nu_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def __init__(self,
num_iter = jnp.array(1.0 / jnp.sqrt(lam)).astype(jnp.int32) + 1

if kernel_type == 'curl_free_imq':
self._kernel = CurlFreeIMQKernel()
self._kernel = CurlFreeIMQKernel(kernel_hyperparams=bandwidth)
elif kernel_type == 'curl_free_se':
self._kernel = CurlFreeSEKernel()
self._kernel = CurlFreeSEKernel(kernel_hyperparams=bandwidth)
else:
raise NotImplementedError(f'Kernel type {kernel_type} is not implemented.')

Expand Down

0 comments on commit 6402dc8

Please sign in to comment.