From 306da1c8e554e691f3a7f99ef5b93142353808f0 Mon Sep 17 00:00:00 2001 From: sukhijab Date: Tue, 27 Feb 2024 13:21:59 +0100 Subject: [PATCH] working sergio version with lf sim and hyperparams --- experiments/data_provider.py | 2 +- .../lf_hf_transfer_exp/run_regression_exp.py | 12 +- .../sweep_regression_exp_num_data_sergio.py | 156 ++++++++++++++++++ sim_transfer/sims/dynamics_models.py | 7 +- sim_transfer/sims/simulators.py | 8 +- 5 files changed, 170 insertions(+), 15 deletions(-) create mode 100644 experiments/lf_hf_transfer_exp/sweep_regression_exp_num_data_sergio.py diff --git a/experiments/data_provider.py b/experiments/data_provider.py index 54c5e32..4355326 100644 --- a/experiments/data_provider.py +++ b/experiments/data_provider.py @@ -28,7 +28,7 @@ } DEFAULTS_SERGIO = { - 'obs_noise_std': 0.02, + 'obs_noise_std': 0.05, 'x_support_mode_train': 'full', 'param_mode': 'random', 'num_cells': 5, diff --git a/experiments/lf_hf_transfer_exp/run_regression_exp.py b/experiments/lf_hf_transfer_exp/run_regression_exp.py index 87bb113..4399d7f 100644 --- a/experiments/lf_hf_transfer_exp/run_regression_exp.py +++ b/experiments/lf_hf_transfer_exp/run_regression_exp.py @@ -215,7 +215,6 @@ def regression_experiment( # train model model.fit_with_scan(x_train, y_train, x_test, y_test, log_to_wandb=use_wandb, log_period=1000) - # eval model eval_metrics = model.eval(x_test, y_test, per_dim_metrics=True) return eval_metrics @@ -254,7 +253,7 @@ def main(args): exp_params['added_gp_outputscale'] = [factor * 0.05, 0.05, 0.5] elif 'Sergio' in exp_params['data_source']: from experiments.data_provider import DEFAULTS_SERGIO - exp_params['added_gp_outputscale'] = [factor * 0.05 for _ in range(DEFAULTS_SERGIO['sergio_dim'])] + exp_params['added_gp_outputscale'] = [factor * 0.1 for _ in range(DEFAULTS_SERGIO['sergio_dim'])] else: raise AssertionError('passed negative value for added_gp_outputscale') # set likelihood_std to default value if not specified @@ -271,9 +270,10 @@ def main(args): from pprint import pprint - print('\nExperiment parameters:') - pprint(exp_params) - print('') + if not 'Sergio' in exp_params['data_source']: + print('\nExperiment parameters:') + pprint(exp_params) + print('') """ Experiment core """ t_start = time.time() @@ -334,7 +334,7 @@ def main(args): parser.add_argument('--data_seed', type=int, default=77698) # standard BNN parameters - parser.add_argument('--model', type=str, default='BNN_FSVGD_SimPrior_hf_gp') + parser.add_argument('--model', type=str, default='SysID') parser.add_argument('--model_seed', type=int, default=892616) parser.add_argument('--likelihood_std', type=float, default=None) parser.add_argument('--learn_likelihood_std', type=int, default=0) diff --git a/experiments/lf_hf_transfer_exp/sweep_regression_exp_num_data_sergio.py b/experiments/lf_hf_transfer_exp/sweep_regression_exp_num_data_sergio.py new file mode 100644 index 0000000..fbd3947 --- /dev/null +++ b/experiments/lf_hf_transfer_exp/sweep_regression_exp_num_data_sergio.py @@ -0,0 +1,156 @@ +from experiments.util import (generate_run_commands, generate_base_command, RESULT_DIR, sample_param_flags, hash_dict) +from experiments.data_provider import DATASET_CONFIGS + +import experiments.lf_hf_transfer_exp.run_regression_exp +from experiments.lf_hf_transfer_exp.run_regression_exp import OUTPUTSCALES_RCCAR +import numpy as np +import datetime +import itertools +import argparse +import os +import jax.numpy as jnp + +MODEL_SPECIFIC_CONFIG = { + 'BNN_SVGD': { + 'bandwidth_svgd': {'values': [10.]}, + 'min_train_steps': {'values': [5_000]}, + 'num_epochs': {'values': [200]}, + 'lr': {'values': [1e-3]}, + # 'likelihood_reg': {'values': [10.0]}, + }, + 'BNN_FSVGD': { + 'bandwidth_svgd': {'values': [2.0]}, + 'bandwidth_gp_prior': {'values': [0.4]}, + 'min_train_steps': {'values': [5_000]}, + 'num_epochs': {'values': [200]}, + 'num_measurement_points': {'values': [128]}, + 'lr': {'values': [1e-3]}, + 'likelihood_reg': {'values': [10.0]}, + }, + + 'BNN_FSVGD_SimPrior_gp': { + 'bandwidth_svgd': {'values': [2.0]}, + 'min_train_steps': {'values': [5_000]}, + 'num_epochs': {'values': [200]}, + 'num_measurement_points': {'values': [128]}, + 'num_f_samples': {'values': [1028]}, + 'added_gp_lengthscale': {'values': [2.]}, + 'added_gp_outputscale': {'values': [2.0]}, + 'lr': {'values': [1e-3]}, + 'likelihood_reg': {'values': [10.0]}, + }, + + 'BNN_FSVGD_SimPrior_no_add_gp': { + 'bandwidth_svgd': {'values': [2.0]}, + 'min_train_steps': {'values': [5_000]}, + 'num_epochs': {'values': [200]}, + 'num_measurement_points': {'values': [128]}, + 'num_f_samples': {'values': [1028]}, + 'added_gp_lengthscale': {'values': [2.]}, + 'added_gp_outputscale': {'values': [2.0]}, + 'lr': {'values': [1e-3]}, + 'likelihood_reg': {'values': [10.0]}, + }, + + 'SysID': { + }, + 'GreyBox': { + 'bandwidth_svgd': {'values': [2.0]}, + 'bandwidth_gp_prior': {'values': [0.4]}, + 'min_train_steps': {'values': [5_000]}, + 'num_epochs': {'values': [200]}, + 'num_measurement_points': {'values': [128]}, + 'lr': {'values': [1e-3]}, + 'likelihood_reg': {'values': [10.0]}, + }, +} + + +def main(args): + # setup random seeds + rds = np.random.RandomState(args.seed) + model_seeds = list(rds.randint(0, 10 ** 6, size=(100,))) + data_seeds = list(rds.randint(0, 10 ** 6, size=(100,))) + + sweep_config = { + 'data_source': {'value': args.data_source}, + # 'num_samples_train': DATASET_CONFIGS[args.data_source]['num_samples_train'], + 'model': {'value': args.model}, + 'learn_likelihood_std': {'value': args.learn_likelihood_std}, + # 'likelihood_std': {'value': None}, + 'num_particles': {'value': 20}, + 'data_batch_size': {'value': 8}, + 'pred_diff': {'value': args.pred_diff}, + 'max_train_steps': {'value': 300_000}, + 'num_sim_model_train_steps': {'value': 5_000}, + } + # update with model specific sweep ranges + model_name = args.model.replace('_no_add_gp', '') + model_name = model_name.replace('_hf', '') + assert model_name in MODEL_SPECIFIC_CONFIG + sweep_config.update(MODEL_SPECIFIC_CONFIG[model_name]) + + # determine name of experiment + exp_base_path = os.path.join(RESULT_DIR, args.exp_name) + exp_path = os.path.join(exp_base_path, f'{args.data_source}_{args.model}') + + if args.data_source == 'racecar_hf': + N_SAMPLES_LIST = [50, 100, 200, 400, 800, 1600, 3200, 6400] + elif args.data_source == 'pendulum_hf': + N_SAMPLES_LIST = [10, 20, 40, 80, 160, 320, 640, 1280] + elif args.data_source == 'real_racecar_v3': + N_SAMPLES_LIST = [50, 100, 200, 400, 800, 1600, 3200, 6400] + elif args.data_source == 'Sergio_hf': + N_SAMPLES_LIST = [200, 400, 800, 1600, 3200, 4800, 6400, 12800] + else: + raise NotImplementedError(f'Unknown data source {args.data_source}.') + + command_list = [] + output_file_list = [] + for _ in range(args.num_hparam_samples): + flags = sample_param_flags(sweep_config) + exp_hash = hash_dict(flags) + for num_samples_train in N_SAMPLES_LIST: + exp_result_folder = os.path.join(exp_path, f'{exp_hash}_{num_samples_train}') + flags['exp_result_folder'] = exp_result_folder + + for model_seed, data_seed in itertools.product(model_seeds[:args.num_model_seeds], + data_seeds[:args.num_data_seeds]): + cmd = generate_base_command(experiments.lf_hf_transfer_exp.run_regression_exp, + flags=dict(**flags, **{'model_seed': model_seed, 'data_seed': data_seed, + 'num_samples_train': num_samples_train, + })) + command_list.append(cmd) + output_file_list.append(os.path.join(exp_result_folder, f'{model_seed}_{data_seed}.out')) + + generate_run_commands(command_list, output_file_list, num_cpus=args.num_cpus, + num_gpus=1 if args.gpu else 0, mode=args.run_mode, prompt=not args.yes) + + +if __name__ == '__main__': + current_date = datetime.datetime.now().strftime("%b%d").lower() + parser = argparse.ArgumentParser(description='Meta-BO run') + + # sweep args + parser.add_argument('--num_hparam_samples', type=int, default=1) + parser.add_argument('--num_model_seeds', type=int, default=5, help='number of model seeds per hparam') + parser.add_argument('--num_data_seeds', type=int, default=5, help='number of model seeds per hparam') + parser.add_argument('--num_cpus', type=int, default=4, help='number of cpus to use') + parser.add_argument('--run_mode', type=str, default='euler') + + # general args + parser.add_argument('--exp_name', type=str, default=f'test_{current_date}') + parser.add_argument('--seed', type=int, default=94563) + parser.add_argument('--gpu', default=True, action='store_true') + parser.add_argument('--yes', default=False, action='store_true') + + # data parameters + parser.add_argument('--data_source', type=str, default='pendulum_hf') + parser.add_argument('--pred_diff', type=int, default=0) + + # # standard BNN parameters + parser.add_argument('--model', type=str, default='BNN_SVGD') + parser.add_argument('--learn_likelihood_std', type=int, default=0) + + args = parser.parse_args() + main(args) diff --git a/sim_transfer/sims/dynamics_models.py b/sim_transfer/sims/dynamics_models.py index 083e79e..6af9fec 100644 --- a/sim_transfer/sims/dynamics_models.py +++ b/sim_transfer/sims/dynamics_models.py @@ -626,7 +626,7 @@ def hill_function(x: jnp.array): hills = hill_function(x) hills = hills[:, :, None] - masked_contribution = params.contribution_rates * params.graph + masked_contribution = params.graph[None] * params.contribution_rates # [n_cell_types, n_genes, n_genes] # switching mechanism between activation and repression, @@ -672,11 +672,10 @@ def sample_single_params(self, key: jax.random.PRNGKey, lower_bound: SergioParam lower_bound_graph = jnp.clip(lower_bound.graph, 0, 2) upper_bound_graph = jnp.clip(upper_bound.graph, 0, 2) graph = jax.random.randint(graph_key, shape=(self.n_genes, self.n_genes), minval=lower_bound_graph, - maxval=upper_bound_graph) + maxval=upper_bound_graph) * 1.0 diag_elements = jnp.diag_indices_from(graph) - graph = graph.at[diag_elements].set(1) + graph = graph.at[diag_elements].set(1.0) power = jax.random.uniform(power_key, shape=(1, ), minval=lower_bound.power, maxval=upper_bound.power) - return SergioParams( lam=lam, contribution_rates=contribution_rates, diff --git a/sim_transfer/sims/simulators.py b/sim_transfer/sims/simulators.py index 787426c..3e6fa6c 100644 --- a/sim_transfer/sims/simulators.py +++ b/sim_transfer/sims/simulators.py @@ -1021,7 +1021,7 @@ def domain(self) -> Domain: return self._domain def _setup_params(self): - self.lower_bound_param_hf = SergioParams(lam=jnp.array(0.7), + self.lower_bound_param_hf = SergioParams(lam=jnp.array(0.79), contribution_rates=jnp.array(-5.0), basal_rates=jnp.array(1.0), power=jnp.array(2.0), @@ -1034,15 +1034,15 @@ def _setup_params(self): self.default_param_hf = self.model.sample_single_params(jax.random.PRNGKey(0), self.lower_bound_param_hf, self.upper_bound_param_hf) - self.lower_bound_param_lf = SergioParams(lam=jnp.array(0.1), + self.lower_bound_param_lf = SergioParams(lam=jnp.array(0.79), contribution_rates=jnp.array(-5.0), basal_rates=jnp.array(1.0), - power=jnp.array(0.0), + power=jnp.array(1.0), graph=jnp.array(0)) self.upper_bound_param_lf = SergioParams(lam=jnp.array(0.8), contribution_rates=jnp.array(5.0), basal_rates=jnp.array(5.0), - power=jnp.array(0.0), + power=jnp.array(1.0), graph=jnp.array(0)) self.default_param_lf = self.model.sample_single_params(jax.random.PRNGKey(0), self.lower_bound_param_lf, self.upper_bound_param_lf)