diff --git a/experiments/data_provider.py b/experiments/data_provider.py
index 8babc24..b38a8b6 100644
--- a/experiments/data_provider.py
+++ b/experiments/data_provider.py
@@ -37,8 +37,8 @@
     'obs_noise_std': 0.05,
     'x_support_mode_train': 'full',
     'param_mode': 'random',
-    'num_cells': 10,
-    'num_genes': 200,
+    'num_cells': 200,
+    'num_genes': 10,
 }
 
 DEFAULTS_RACECAR = {
@@ -86,11 +86,11 @@
     },
     'Sergio': {
-        'likelihood_std': {'value': [0.05 for _ in range(2 * DEFAULTS_SERGIO['num_cells'])]},
+        'likelihood_std': {'value': [0.05 for _ in range(2 * DEFAULTS_SERGIO['num_genes'])]},
         'num_samples_train': {'value': 20},
     },
     'Sergio_hf': {
-        'likelihood_std': {'value': [0.05 for _ in range(2 * DEFAULTS_SERGIO['num_cells'])]},
+        'likelihood_std': {'value': [0.05 for _ in range(2 * DEFAULTS_SERGIO['num_genes'])]},
         'num_samples_train': {'value': 20},
     },
diff --git a/experiments/lf_hf_transfer_exp/run_regression_exp.py b/experiments/lf_hf_transfer_exp/run_regression_exp.py
index b8b10ca..c8d23a3 100644
--- a/experiments/lf_hf_transfer_exp/run_regression_exp.py
+++ b/experiments/lf_hf_transfer_exp/run_regression_exp.py
@@ -253,7 +253,7 @@ def main(args):
         exp_params['added_gp_outputscale'] = [factor * 0.05, factor * 0.05, factor * 0.5]
     elif 'Sergio' in exp_params['data_source']:
         from experiments.data_provider import DEFAULTS_SERGIO
-        exp_params['added_gp_outputscale'] = [factor * 0.1 for _ in range(2 * DEFAULTS_SERGIO['num_cells'])]
+        exp_params['added_gp_outputscale'] = [factor * 0.1 for _ in range(2 * DEFAULTS_SERGIO['num_genes'])]
     elif 'Greenhouse' in exp_params['data_source']:
         exp_params['added_gp_outputscale'] = [factor * 0.05 for _ in range(16)]  # We are quite confident about exogenous effects
diff --git a/sim_transfer/sims/simulators.py b/sim_transfer/sims/simulators.py
index 8b8fc44..846926d 100644
--- a/sim_transfer/sims/simulators.py
+++ b/sim_transfer/sims/simulators.py
@@ -996,7 +996,7 @@ class SergioSim(FunctionSimulator):
     sample_x_max: float = 3
 
     def __init__(self, n_genes: int = 20, n_cells: int = 20, use_hf: bool = False):
-        FunctionSimulator.__init__(self, input_size=2 * n_cells, output_size=2 * n_cells)
+        FunctionSimulator.__init__(self, input_size=2 * n_genes, output_size=2 * n_genes)
         self.model = SergioDynamics(self._dt, n_genes, n_cells, state_ub=self.state_ub)
         self.n_cells = n_cells
         self.n_genes = n_genes
@@ -1016,8 +1016,8 @@ def __init__(self, n_genes: int = 20, n_cells: int = 20, use_hf: bool = False):
             'lower bounds have to be smaller than upper bounds'
 
         # setup domain
-        self.domain_lower = -self.sample_x_max * jnp.ones(shape=(2 * self.n_cells,))
-        self.domain_upper = self.sample_x_max * jnp.ones(shape=(2 * self.n_cells,))
+        self.domain_lower = -self.sample_x_max * jnp.ones(shape=(2 * self.n_genes,))
+        self.domain_upper = self.sample_x_max * jnp.ones(shape=(2 * self.n_genes,))
         self._domain = HypercubeDomain(lower=self.domain_lower, upper=self.domain_upper)
 
     @property
@@ -1069,8 +1069,8 @@ def predict_next_state(self, x: jnp.array, params: NamedTuple,
                            key: jax.random.PRNGKey = jax.random.PRNGKey(0)) -> jnp.array:
         assert x.ndim == 1
         mu, log_std = jnp.split(x, 2, axis=-1)
-        x = mu + jax.random.normal(key=key, shape=(self.n_genes, self.n_cells)) * jax.nn.softplus(log_std)
-        x = x.reshape(self.n_genes * self.n_cells)
+        x = mu + jax.random.normal(key=key, shape=(self.n_cells, self.n_genes)) * jax.nn.softplus(log_std)
+        x = x.reshape(self.n_cells * self.n_genes)
         # clip state to be between -3, 3
         x = jnp.clip(x, -self.sample_x_max, self.sample_x_max)
         # scale it to be between [0, 1]
@@ -1080,7 +1080,7 @@ def predict_next_state(self, x: jnp.array, params: NamedTuple,
         # rescale it back to be between [-3, 3]
         f = (f - 0.5) * (2 * self.sample_x_max)
-        # take the mean and std over genes
-        f = f.reshape(self.n_genes, self.n_cells)
+        # take the mean and std over cells
+        f = f.reshape(self.n_cells, self.n_genes)
         mu_f, std_f = jnp.mean(f, axis=0), jnp.std(f, axis=0)
         # clip std so that its positive and take log std
         std_f = jnp.clip(std_f, 1e-6)
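
Sanity check (a minimal sketch, not part of the patch; it assumes only the names visible
in the hunks above: SergioSim, the input_size/output_size set through
FunctionSimulator.__init__, and domain_lower). With the patch applied, the simulator's
dimensionality should track n_genes rather than n_cells:

    from sim_transfer.sims.simulators import SergioSim

    # Input/output hold a per-gene mean and log-std, so both sizes are 2 * n_genes.
    sim = SergioSim(n_genes=10, n_cells=200)
    assert sim.input_size == 2 * 10
    assert sim.output_size == 2 * 10
    assert sim.domain_lower.shape == (2 * 10,)

The reshape change in predict_next_state follows the same convention: with f laid out as
(n_cells, n_genes), aggregating along axis=0 averages over cells and leaves one statistic
per gene, which is what makes the output match 2 * n_genes:

    import jax.numpy as jnp

    n_cells, n_genes = 200, 10
    f = jnp.zeros((n_cells, n_genes))  # stand-in for the simulator output
    mu_f, std_f = jnp.mean(f, axis=0), jnp.std(f, axis=0)
    assert mu_f.shape == (n_genes,)   # one mean per gene
    assert std_f.shape == (n_genes,)  # one std per gene; concatenated -> 2 * n_genes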