Skip to content

Commit

Permalink
swaped genes and cells for sergio sim distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
sukhijab committed May 27, 2024
1 parent 7264ed4 commit f11813f
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
8 changes: 4 additions & 4 deletions experiments/data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
'obs_noise_std': 0.05,
'x_support_mode_train': 'full',
'param_mode': 'random',
'num_cells': 10,
'num_genes': 200,
'num_cells': 200,
'num_genes': 10,
}

DEFAULTS_RACECAR = {
Expand Down Expand Up @@ -86,11 +86,11 @@
},

'Sergio': {
'likelihood_std': {'value': [0.05 for _ in range(2 * DEFAULTS_SERGIO['num_cells'])]},
'likelihood_std': {'value': [0.05 for _ in range(2 * DEFAULTS_SERGIO['num_genes'])]},
'num_samples_train': {'value': 20},
},
'Sergio_hf': {
'likelihood_std': {'value': [0.05 for _ in range(2 * DEFAULTS_SERGIO['num_cells'])]},
'likelihood_std': {'value': [0.05 for _ in range(2 * DEFAULTS_SERGIO['num_genes'])]},
'num_samples_train': {'value': 20},
},

Expand Down
2 changes: 1 addition & 1 deletion experiments/lf_hf_transfer_exp/run_regression_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def main(args):
exp_params['added_gp_outputscale'] = [factor * 0.05, factor * 0.05, factor * 0.5]
elif 'Sergio' in exp_params['data_source']:
from experiments.data_provider import DEFAULTS_SERGIO
exp_params['added_gp_outputscale'] = [factor * 0.1 for _ in range(2 * DEFAULTS_SERGIO['num_cells'])]
exp_params['added_gp_outputscale'] = [factor * 0.1 for _ in range(2 * DEFAULTS_SERGIO['num_genes'])]
elif 'Greenhouse' in exp_params['data_source']:
exp_params['added_gp_outputscale'] = [factor * 0.05 for _ in range(16)]
# We are quite confident about exogenous effects
Expand Down
12 changes: 6 additions & 6 deletions sim_transfer/sims/simulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,7 @@ class SergioSim(FunctionSimulator):
sample_x_max: float = 3

def __init__(self, n_genes: int = 20, n_cells: int = 20, use_hf: bool = False):
FunctionSimulator.__init__(self, input_size=2 * n_cells, output_size=2 * n_cells)
FunctionSimulator.__init__(self, input_size=2 * n_genes, output_size=2 * n_genes)
self.model = SergioDynamics(self._dt, n_genes, n_cells, state_ub=self.state_ub)
self.n_cells = n_cells
self.n_genes = n_genes
Expand All @@ -1016,8 +1016,8 @@ def __init__(self, n_genes: int = 20, n_cells: int = 20, use_hf: bool = False):
'lower bounds have to be smaller than upper bounds'

# setup domain
self.domain_lower = -self.sample_x_max * jnp.ones(shape=(2 * self.n_cells,))
self.domain_upper = self.sample_x_max * jnp.ones(shape=(2 * self.n_cells,))
self.domain_lower = -self.sample_x_max * jnp.ones(shape=(2 * self.n_genes,))
self.domain_upper = self.sample_x_max * jnp.ones(shape=(2 * self.n_genes,))
self._domain = HypercubeDomain(lower=self.domain_lower, upper=self.domain_upper)

@property
Expand Down Expand Up @@ -1069,8 +1069,8 @@ def predict_next_state(self, x: jnp.array, params: NamedTuple,
key: jax.random.PRNGKey = jax.random.PRNGKey(0)) -> jnp.array:
assert x.ndim == 1
mu, log_std = jnp.split(x, 2, axis=-1)
x = mu + jax.random.normal(key=key, shape=(self.n_genes, self.n_cells)) * jax.nn.softplus(log_std)
x = x.reshape(self.n_genes * self.n_cells)
x = mu + jax.random.normal(key=key, shape=(self.n_cells, self.n_genes)) * jax.nn.softplus(log_std)
x = x.reshape(self.n_cells * self.n_genes)
# clip state to be between -3, 3
x = jnp.clip(x, -self.sample_x_max, self.sample_x_max)
# scale it to be between [0, 1]
Expand All @@ -1080,7 +1080,7 @@ def predict_next_state(self, x: jnp.array, params: NamedTuple,
# rescale it back to be between [-3, 3]
f = (f - 0.5) * (2 * self.sample_x_max)
# take the mean and std over genes
f = f.reshape(self.n_genes, self.n_cells)
f = f.reshape(self.n_cells, self.n_genes)
mu_f, std_f = jnp.mean(f, axis=0), jnp.std(f, axis=0)
# clip std so that its positive and take log std
std_f = jnp.clip(std_f, 1e-6)
Expand Down

0 comments on commit f11813f

Please sign in to comment.