Skip to content

Commit

Permalink
working sergio version with lf sim and hyperparams
Browse files Browse the repository at this point in the history
  • Loading branch information
sukhijab committed Feb 27, 2024
1 parent db4ead8 commit 306da1c
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 15 deletions.
2 changes: 1 addition & 1 deletion experiments/data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
}

DEFAULTS_SERGIO = {
'obs_noise_std': 0.02,
'obs_noise_std': 0.05,
'x_support_mode_train': 'full',
'param_mode': 'random',
'num_cells': 5,
Expand Down
12 changes: 6 additions & 6 deletions experiments/lf_hf_transfer_exp/run_regression_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ def regression_experiment(

# train model
model.fit_with_scan(x_train, y_train, x_test, y_test, log_to_wandb=use_wandb, log_period=1000)

# eval model
eval_metrics = model.eval(x_test, y_test, per_dim_metrics=True)
return eval_metrics
Expand Down Expand Up @@ -254,7 +253,7 @@ def main(args):
exp_params['added_gp_outputscale'] = [factor * 0.05, 0.05, 0.5]
elif 'Sergio' in exp_params['data_source']:
from experiments.data_provider import DEFAULTS_SERGIO
exp_params['added_gp_outputscale'] = [factor * 0.05 for _ in range(DEFAULTS_SERGIO['sergio_dim'])]
exp_params['added_gp_outputscale'] = [factor * 0.1 for _ in range(DEFAULTS_SERGIO['sergio_dim'])]
else:
raise AssertionError('passed negative value for added_gp_outputscale')
# set likelihood_std to default value if not specified
Expand All @@ -271,9 +270,10 @@ def main(args):


from pprint import pprint
print('\nExperiment parameters:')
pprint(exp_params)
print('')
if not 'Sergio' in exp_params['data_source']:
print('\nExperiment parameters:')
pprint(exp_params)
print('')

""" Experiment core """
t_start = time.time()
Expand Down Expand Up @@ -334,7 +334,7 @@ def main(args):
parser.add_argument('--data_seed', type=int, default=77698)

# standard BNN parameters
parser.add_argument('--model', type=str, default='BNN_FSVGD_SimPrior_hf_gp')
parser.add_argument('--model', type=str, default='SysID')
parser.add_argument('--model_seed', type=int, default=892616)
parser.add_argument('--likelihood_std', type=float, default=None)
parser.add_argument('--learn_likelihood_std', type=int, default=0)
Expand Down
156 changes: 156 additions & 0 deletions experiments/lf_hf_transfer_exp/sweep_regression_exp_num_data_sergio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
from experiments.util import (generate_run_commands, generate_base_command, RESULT_DIR, sample_param_flags, hash_dict)
from experiments.data_provider import DATASET_CONFIGS

import experiments.lf_hf_transfer_exp.run_regression_exp
from experiments.lf_hf_transfer_exp.run_regression_exp import OUTPUTSCALES_RCCAR
import numpy as np
import datetime
import itertools
import argparse
import os
import jax.numpy as jnp

MODEL_SPECIFIC_CONFIG = {
'BNN_SVGD': {
'bandwidth_svgd': {'values': [10.]},
'min_train_steps': {'values': [5_000]},
'num_epochs': {'values': [200]},
'lr': {'values': [1e-3]},
# 'likelihood_reg': {'values': [10.0]},
},
'BNN_FSVGD': {
'bandwidth_svgd': {'values': [2.0]},
'bandwidth_gp_prior': {'values': [0.4]},
'min_train_steps': {'values': [5_000]},
'num_epochs': {'values': [200]},
'num_measurement_points': {'values': [128]},
'lr': {'values': [1e-3]},
'likelihood_reg': {'values': [10.0]},
},

'BNN_FSVGD_SimPrior_gp': {
'bandwidth_svgd': {'values': [2.0]},
'min_train_steps': {'values': [5_000]},
'num_epochs': {'values': [200]},
'num_measurement_points': {'values': [128]},
'num_f_samples': {'values': [1028]},
'added_gp_lengthscale': {'values': [2.]},
'added_gp_outputscale': {'values': [2.0]},
'lr': {'values': [1e-3]},
'likelihood_reg': {'values': [10.0]},
},

'BNN_FSVGD_SimPrior_no_add_gp': {
'bandwidth_svgd': {'values': [2.0]},
'min_train_steps': {'values': [5_000]},
'num_epochs': {'values': [200]},
'num_measurement_points': {'values': [128]},
'num_f_samples': {'values': [1028]},
'added_gp_lengthscale': {'values': [2.]},
'added_gp_outputscale': {'values': [2.0]},
'lr': {'values': [1e-3]},
'likelihood_reg': {'values': [10.0]},
},

'SysID': {
},
'GreyBox': {
'bandwidth_svgd': {'values': [2.0]},
'bandwidth_gp_prior': {'values': [0.4]},
'min_train_steps': {'values': [5_000]},
'num_epochs': {'values': [200]},
'num_measurement_points': {'values': [128]},
'lr': {'values': [1e-3]},
'likelihood_reg': {'values': [10.0]},
},
}


def main(args):
# setup random seeds
rds = np.random.RandomState(args.seed)
model_seeds = list(rds.randint(0, 10 ** 6, size=(100,)))
data_seeds = list(rds.randint(0, 10 ** 6, size=(100,)))

sweep_config = {
'data_source': {'value': args.data_source},
# 'num_samples_train': DATASET_CONFIGS[args.data_source]['num_samples_train'],
'model': {'value': args.model},
'learn_likelihood_std': {'value': args.learn_likelihood_std},
# 'likelihood_std': {'value': None},
'num_particles': {'value': 20},
'data_batch_size': {'value': 8},
'pred_diff': {'value': args.pred_diff},
'max_train_steps': {'value': 300_000},
'num_sim_model_train_steps': {'value': 5_000},
}
# update with model specific sweep ranges
model_name = args.model.replace('_no_add_gp', '')
model_name = model_name.replace('_hf', '')
assert model_name in MODEL_SPECIFIC_CONFIG
sweep_config.update(MODEL_SPECIFIC_CONFIG[model_name])

# determine name of experiment
exp_base_path = os.path.join(RESULT_DIR, args.exp_name)
exp_path = os.path.join(exp_base_path, f'{args.data_source}_{args.model}')

if args.data_source == 'racecar_hf':
N_SAMPLES_LIST = [50, 100, 200, 400, 800, 1600, 3200, 6400]
elif args.data_source == 'pendulum_hf':
N_SAMPLES_LIST = [10, 20, 40, 80, 160, 320, 640, 1280]
elif args.data_source == 'real_racecar_v3':
N_SAMPLES_LIST = [50, 100, 200, 400, 800, 1600, 3200, 6400]
elif args.data_source == 'Sergio_hf':
N_SAMPLES_LIST = [200, 400, 800, 1600, 3200, 4800, 6400, 12800]
else:
raise NotImplementedError(f'Unknown data source {args.data_source}.')

command_list = []
output_file_list = []
for _ in range(args.num_hparam_samples):
flags = sample_param_flags(sweep_config)
exp_hash = hash_dict(flags)
for num_samples_train in N_SAMPLES_LIST:
exp_result_folder = os.path.join(exp_path, f'{exp_hash}_{num_samples_train}')
flags['exp_result_folder'] = exp_result_folder

for model_seed, data_seed in itertools.product(model_seeds[:args.num_model_seeds],
data_seeds[:args.num_data_seeds]):
cmd = generate_base_command(experiments.lf_hf_transfer_exp.run_regression_exp,
flags=dict(**flags, **{'model_seed': model_seed, 'data_seed': data_seed,
'num_samples_train': num_samples_train,
}))
command_list.append(cmd)
output_file_list.append(os.path.join(exp_result_folder, f'{model_seed}_{data_seed}.out'))

generate_run_commands(command_list, output_file_list, num_cpus=args.num_cpus,
num_gpus=1 if args.gpu else 0, mode=args.run_mode, prompt=not args.yes)


if __name__ == '__main__':
current_date = datetime.datetime.now().strftime("%b%d").lower()
parser = argparse.ArgumentParser(description='Meta-BO run')

# sweep args
parser.add_argument('--num_hparam_samples', type=int, default=1)
parser.add_argument('--num_model_seeds', type=int, default=5, help='number of model seeds per hparam')
parser.add_argument('--num_data_seeds', type=int, default=5, help='number of model seeds per hparam')
parser.add_argument('--num_cpus', type=int, default=4, help='number of cpus to use')
parser.add_argument('--run_mode', type=str, default='euler')

# general args
parser.add_argument('--exp_name', type=str, default=f'test_{current_date}')
parser.add_argument('--seed', type=int, default=94563)
parser.add_argument('--gpu', default=True, action='store_true')
parser.add_argument('--yes', default=False, action='store_true')

# data parameters
parser.add_argument('--data_source', type=str, default='pendulum_hf')
parser.add_argument('--pred_diff', type=int, default=0)

# # standard BNN parameters
parser.add_argument('--model', type=str, default='BNN_SVGD')
parser.add_argument('--learn_likelihood_std', type=int, default=0)

args = parser.parse_args()
main(args)
7 changes: 3 additions & 4 deletions sim_transfer/sims/dynamics_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ def hill_function(x: jnp.array):

hills = hill_function(x)
hills = hills[:, :, None]
masked_contribution = params.contribution_rates * params.graph
masked_contribution = params.graph[None] * params.contribution_rates

# [n_cell_types, n_genes, n_genes]
# switching mechanism between activation and repression,
Expand Down Expand Up @@ -672,11 +672,10 @@ def sample_single_params(self, key: jax.random.PRNGKey, lower_bound: SergioParam
lower_bound_graph = jnp.clip(lower_bound.graph, 0, 2)
upper_bound_graph = jnp.clip(upper_bound.graph, 0, 2)
graph = jax.random.randint(graph_key, shape=(self.n_genes, self.n_genes), minval=lower_bound_graph,
maxval=upper_bound_graph)
maxval=upper_bound_graph) * 1.0
diag_elements = jnp.diag_indices_from(graph)
graph = graph.at[diag_elements].set(1)
graph = graph.at[diag_elements].set(1.0)
power = jax.random.uniform(power_key, shape=(1, ), minval=lower_bound.power, maxval=upper_bound.power)

return SergioParams(
lam=lam,
contribution_rates=contribution_rates,
Expand Down
8 changes: 4 additions & 4 deletions sim_transfer/sims/simulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,7 @@ def domain(self) -> Domain:
return self._domain

def _setup_params(self):
self.lower_bound_param_hf = SergioParams(lam=jnp.array(0.7),
self.lower_bound_param_hf = SergioParams(lam=jnp.array(0.79),
contribution_rates=jnp.array(-5.0),
basal_rates=jnp.array(1.0),
power=jnp.array(2.0),
Expand All @@ -1034,15 +1034,15 @@ def _setup_params(self):
self.default_param_hf = self.model.sample_single_params(jax.random.PRNGKey(0), self.lower_bound_param_hf,
self.upper_bound_param_hf)

self.lower_bound_param_lf = SergioParams(lam=jnp.array(0.1),
self.lower_bound_param_lf = SergioParams(lam=jnp.array(0.79),
contribution_rates=jnp.array(-5.0),
basal_rates=jnp.array(1.0),
power=jnp.array(0.0),
power=jnp.array(1.0),
graph=jnp.array(0))
self.upper_bound_param_lf = SergioParams(lam=jnp.array(0.8),
contribution_rates=jnp.array(5.0),
basal_rates=jnp.array(5.0),
power=jnp.array(0.0),
power=jnp.array(1.0),
graph=jnp.array(0))
self.default_param_lf = self.model.sample_single_params(jax.random.PRNGKey(0), self.lower_bound_param_lf,
self.upper_bound_param_lf)
Expand Down

0 comments on commit 306da1c

Please sign in to comment.