Skip to content

Commit

Permalink
bug fixes for sergio
Browse files Browse the repository at this point in the history
  • Loading branch information
sukhijab committed Feb 26, 2024
1 parent bbfc60b commit db4ead8
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 37 deletions.
6 changes: 3 additions & 3 deletions experiments/data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@
'obs_noise_std': 0.02,
'x_support_mode_train': 'full',
'param_mode': 'random',
'num_cells': 10,
'num_genes': 10,
'sergio_dim': 10 * 10,
'num_cells': 5,
'num_genes': 15,
'sergio_dim': 5 * 15,
}

DEFAULTS_RACECAR = {
Expand Down
45 changes: 30 additions & 15 deletions sim_transfer/sims/dynamics_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class SergioParams(NamedTuple):
lam: jax.Array = jnp.array(0.8)
contribution_rates: jax.Array = jnp.array(2.0)
basal_rates: jax.Array = jnp.array(0.0)
power: jax.Array = jnp.array(2.0)
graph: jax.Array = jnp.array(1.0)


class DynamicsModel(ABC):
Expand Down Expand Up @@ -575,7 +577,6 @@ def _ode(self, x, u, params: CarParams):

class SergioDynamics(ABC):
l_b: float = 0
u_b: float = 500
lam_lb: float = 0.2
lam_ub: float = 0.9

Expand All @@ -601,10 +602,10 @@ def __init__(self,
def next_step(self, x: jax.Array, params: PyTree) -> jax.Array:
def body(carry, _):
q = carry + self.dt_integration * self.ode(carry, params)
q = jnp.clip(q, a_min=self.l_b)
return q, None

next_state, _ = jax.lax.scan(body, x, xs=None, length=self._num_steps_integrate)
next_state = jnp.clip(next_state, self.l_b, self.u_b)
return next_state

def ode(self, x: jax.Array, params) -> jax.Array:
Expand All @@ -614,17 +615,18 @@ def ode(self, x: jax.Array, params) -> jax.Array:
def production_rate(self, x: jnp.array, params: SergioParams):
assert x.shape == (self.n_cells, self.n_genes)

def hill_function(x: jnp.array, n: float = 2.0):
def hill_function(x: jnp.array):
h = x.mean(0)
hill_numerator = jnp.power(x, n)
hill_denominator = jnp.power(h, n) + hill_numerator
hill_numerator = jnp.power(x, params.power)
hill_denominator = jnp.power(h, params.power) + hill_numerator
hill = jnp.where(
jnp.abs(hill_denominator) < 1e-6, 0, hill_numerator / hill_denominator
)
return hill

hills = hill_function(x)
masked_contribution = params.contribution_rates
hills = hills[:, :, None]
masked_contribution = params.contribution_rates * params.graph

# [n_cell_types, n_genes, n_genes]
# switching mechanism between activation and repression,
Expand Down Expand Up @@ -652,8 +654,8 @@ def _split_key_like_tree(self, key: jax.random.PRNGKey):
keys = jax.random.split(key, treedef.num_leaves)
return jtu.tree_unflatten(treedef, keys)

def sample_single_params(self, key: jax.random.PRNGKey, lower_bound: NamedTuple, upper_bound: NamedTuple):
lam_key, contrib_key, basal_key = jax.random.split(key, 3)
def sample_single_params(self, key: jax.random.PRNGKey, lower_bound: SergioParams, upper_bound: SergioParams):
lam_key, contrib_key, basal_key, graph_key, power_key = jax.random.split(key, 5)
lam = jax.random.uniform(lam_key, shape=(self.n_cells, self.n_genes), minval=lower_bound.lam,
maxval=upper_bound.lam)

Expand All @@ -667,36 +669,49 @@ def sample_single_params(self, key: jax.random.PRNGKey, lower_bound: NamedTuple,
minval=lower_bound.basal_rates,
maxval=upper_bound.basal_rates)

lower_bound_graph = jnp.clip(lower_bound.graph, 0, 2)
upper_bound_graph = jnp.clip(upper_bound.graph, 0, 2)
graph = jax.random.randint(graph_key, shape=(self.n_genes, self.n_genes), minval=lower_bound_graph,
maxval=upper_bound_graph)
diag_elements = jnp.diag_indices_from(graph)
graph = graph.at[diag_elements].set(1)
power = jax.random.uniform(power_key, shape=(1, ), minval=lower_bound.power, maxval=upper_bound.power)

return SergioParams(
lam=lam,
contribution_rates=contribution_rates,
basal_rates=basal_rates
basal_rates=basal_rates,
power=power,
graph=graph,
)

def sample_params_uniform(self, key: jax.random.PRNGKey, sample_shape: Union[int, Tuple[int]],
lower_bound: NamedTuple, upper_bound: NamedTuple):
if isinstance(sample_shape, int):
keys = jax.random.split(key, sample_shape)
else:
keys = jax.random.split(key, jnp.prod(sample_shape.shape))
keys = jax.random.split(key, np.prod(sample_shape))
sampled_params = jax.vmap(self.sample_single_params,
in_axes=(0, None, None))(keys, lower_bound, upper_bound)
return sampled_params


if __name__ == "__main__":
sim = SergioDynamics(0.1, 10, 10)
x_next = sim.next_step(x=jnp.ones(10 * 10), params=sim.params)
dim_x, dim_y = 10, 10
sim = SergioDynamics(0.1, dim_x, dim_y)
x_next = sim.next_step(x=jnp.ones(dim_x * dim_y), params=sim.params)
lower_bound = SergioParams(lam=jnp.array(0.2),
contribution_rates=jnp.array(1.0),
basal_rates=jnp.array(1.0))
basal_rates=jnp.array(1.0),
graph=jnp.array(0))
upper_bound = SergioParams(lam=jnp.array(0.9),
contribution_rates=jnp.array(5.0),
basal_rates=jnp.array(5.0))
basal_rates=jnp.array(5.0),
graph=jnp.array(2))
key = jax.random.PRNGKey(0)
keys = random.split(key, 4)
params = vmap(sim.sample_params_uniform, in_axes=(0, None, None, None))(keys, 1, lower_bound, upper_bound)
x_next = vmap(vmap(lambda p: sim.next_step(x=jnp.ones(20 * 20), params=p)))(params)
x_next = vmap(vmap(lambda p: sim.next_step(x=jnp.ones(dim_x * dim_y), params=p)))(params)
pendulum = Pendulum(0.1)
pendulum.next_step(x=jnp.array([0., 0., 0.]), u=jnp.array([1.0]), params=pendulum.params)

Expand Down
57 changes: 38 additions & 19 deletions sim_transfer/sims/simulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ class SergioSim(FunctionSimulator):

# domain for generating data
state_lb: float = 0.0
state_ub: float = 400
state_ub: float = 500

def __init__(self, n_genes: int = 20, n_cells: int = 20, use_hf: bool = False):
FunctionSimulator.__init__(self, input_size=n_genes * n_cells, output_size=n_genes * n_cells)
Expand Down Expand Up @@ -1021,33 +1021,44 @@ def domain(self) -> Domain:
return self._domain

def _setup_params(self):
self.lower_bound_param_hf = SergioParams(lam=jnp.array(0.1),
self.lower_bound_param_hf = SergioParams(lam=jnp.array(0.7),
contribution_rates=jnp.array(-5.0),
basal_rates=jnp.array(0.0))
self.upper_bound_param_hf = SergioParams(lam=jnp.array(0.9),
basal_rates=jnp.array(1.0),
power=jnp.array(2.0),
graph=jnp.array(0))
self.upper_bound_param_hf = SergioParams(lam=jnp.array(0.8),
contribution_rates=jnp.array(5.0),
basal_rates=jnp.array(5.0))
basal_rates=jnp.array(5.0),
power=jnp.array(2.0),
graph=jnp.array(2))
self.default_param_hf = self.model.sample_single_params(jax.random.PRNGKey(0), self.lower_bound_param_hf,
self.upper_bound_param_hf)

self.lower_bound_param_lf = SergioParams(lam=jnp.array(0.2),
contribution_rates=jnp.array(0.0),
basal_rates=jnp.array(0.0))
self.upper_bound_param_lf = SergioParams(lam=jnp.array(0.9),
contribution_rates=jnp.array(0.0),
basal_rates=jnp.array(0.0))
self.lower_bound_param_lf = SergioParams(lam=jnp.array(0.1),
contribution_rates=jnp.array(-5.0),
basal_rates=jnp.array(1.0),
power=jnp.array(0.0),
graph=jnp.array(0))
self.upper_bound_param_lf = SergioParams(lam=jnp.array(0.8),
contribution_rates=jnp.array(5.0),
basal_rates=jnp.array(5.0),
power=jnp.array(0.0),
graph=jnp.array(0))
self.default_param_lf = self.model.sample_single_params(jax.random.PRNGKey(0), self.lower_bound_param_lf,
self.upper_bound_param_lf)

def init_params(self):
return self._typical_params

def sample_params(self, rng_key: jax.random.PRNGKey):
params = self.model.sample_params_uniform(rng_key, sample_shape=1,
lower_bound=self._lower_bound_params,
upper_bound=self._upper_bound_params)
params = jtu.tree_map(lambda x: x.item(), params)
params = self.model.sample_single_params(rng_key,
lower_bound=self._lower_bound_params,
upper_bound=self._upper_bound_params)
train_params = jtu.tree_map(lambda x: 1, params)
if self.use_hf:
train_params = train_params._replace(power=0)
else:
train_params = train_params._replace(power=0, graph=0)
return params, train_params

def sample_function_vals(self, x: jnp.ndarray, num_samples: int, rng_key: jax.random.PRNGKey) -> jnp.ndarray:
Expand Down Expand Up @@ -1248,12 +1259,20 @@ def evaluate_sim(self, x: jnp.array, params: NamedTuple) -> jnp.array:

if __name__ == '__main__':
key1, key2 = jax.random.split(jax.random.PRNGKey(435345), 2)
function_sim = SergioSim(use_hf=False)
function_sim = SergioSim(5, 10, use_hf=False)
function_sim.sample_params(key1)
x, _ = function_sim._sample_x_data(key1, 1, 1)

f1 = function_sim.sample_function_vals(x, num_samples=10, rng_key=key2)
param1 = function_sim._typical_params
f1 = function_sim.sample_function_vals(x, num_samples=1000, rng_key=key2)
f2 = function_sim._typical_f(x)

function_sim = SergioSim(5, 10, use_hf=True)
params = function_sim._typical_params
params = params._replace(
lam=param1.lam,
)
f3 = function_sim.evaluate_sim(x, params)
print(jnp.isnan(f1).any())
print(jnp.isnan(f2).any())
function_sim = RaceCarSim(use_blend=False, no_angular_velocity=True)
x, _ = function_sim._sample_x_data(key1, 1000, 1000)

Expand Down

0 comments on commit db4ead8

Please sign in to comment.