From bbfc60ba0544a90daa7996eef32151957a75c5a6 Mon Sep 17 00:00:00 2001 From: sukhijab Date: Mon, 19 Feb 2024 13:34:07 +0100 Subject: [PATCH] bug fix in time integral of sergio dynamics --- sim_transfer/sims/dynamics_models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sim_transfer/sims/dynamics_models.py b/sim_transfer/sims/dynamics_models.py index b896c38..11c38a6 100644 --- a/sim_transfer/sims/dynamics_models.py +++ b/sim_transfer/sims/dynamics_models.py @@ -604,6 +604,7 @@ def body(carry, _): return q, None next_state, _ = jax.lax.scan(body, x, xs=None, length=self._num_steps_integrate) + next_state = jnp.clip(next_state, self.l_b, self.u_b) return next_state def ode(self, x: jax.Array, params) -> jax.Array: @@ -642,9 +643,8 @@ def _ode(self, x: jax.Array, params) -> jax.Array: assert x.shape == (self.n_cells * self.n_genes,) x = x.reshape(self.n_cells, self.n_genes) production_rate = self.production_rate(x, params) - x_next = (production_rate - params.lam * x) * self.dt + x_next = (production_rate - params.lam * x) x_next = x_next.reshape(self.n_cells * self.n_genes) - x_next = jnp.clip(x_next, self.l_b, self.u_b) return x_next def _split_key_like_tree(self, key: jax.random.PRNGKey): @@ -685,8 +685,8 @@ def sample_params_uniform(self, key: jax.random.PRNGKey, sample_shape: Union[int if __name__ == "__main__": - sim = SergioDynamics(0.1, 20, 20) - x_next = sim.next_step(x=jnp.ones(20 * 20), params=sim.params) + sim = SergioDynamics(0.1, 10, 10) + x_next = sim.next_step(x=jnp.ones(10 * 10), params=sim.params) lower_bound = SergioParams(lam=jnp.array(0.2), contribution_rates=jnp.array(1.0), basal_rates=jnp.array(1.0))