Skip to content

Commit

Permalink
added sergio for experimenting
Browse files Browse the repository at this point in the history
  • Loading branch information
sukhijab committed Feb 19, 2024
1 parent 97b56db commit 6d264a6
Show file tree
Hide file tree
Showing 4 changed files with 300 additions and 12 deletions.
28 changes: 26 additions & 2 deletions experiments/data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from sim_transfer.sims.simulators import StackedActionSimWrapper
from sim_transfer.sims.util import encode_angles as encode_angles_fn


DATA_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data')

DEFAULTS_SINUSOIDS = {
Expand All @@ -28,6 +27,15 @@
'param_mode': 'random',
}

DEFAULTS_SERGIO = {
'obs_noise_std': 0.02,
'x_support_mode_train': 'full',
'param_mode': 'random',
'num_cells': 10,
'num_genes': 10,
'sergio_dim': 10 * 10,
}

DEFAULTS_RACECAR = {
'obs_noise_std': OBS_NOISE_STD_SIM_CAR,
'x_support_mode_train': 'full',
Expand Down Expand Up @@ -61,6 +69,15 @@
'likelihood_std': {'value': [0.05, 0.05, 0.5]},
'num_samples_train': {'value': 20},
},
'Sergio': {
'likelihood_std': {'value': [0.05 for _ in range(DEFAULTS_SERGIO['sergio_dim'])]},
'num_samples_train': {'value': 20},
},
'Sergio_hf': {
'likelihood_std': {'value': [0.05 for _ in range(DEFAULTS_SERGIO['sergio_dim'])]},
'num_samples_train': {'value': 20},
},

'pendulum_bimodal': {
'likelihood_std': {'value': [0.05, 0.05, 0.5]},
'num_samples_train': {'value': 20},
Expand Down Expand Up @@ -190,7 +207,6 @@ def get_rccar_recorded_data_new(encode_angle: bool = True, skip_first_n_points:
dataset: str = 'all',
action_delay: int = 3, action_stacking: bool = False,
car_id: int = 2):

assert car_id in [1, 2, 3]
if car_id == 1:
assert dataset in ['all', 'v1']
Expand Down Expand Up @@ -247,6 +263,14 @@ def provide_data_and_sim(data_source: str, data_spec: Dict[str, Any], data_seed:
else:
sim_hf = sim_lf = PendulumSim(encode_angle=True, high_fidelity=True)
assert {'num_samples_train'} <= set(data_spec.keys()) <= {'num_samples_train'}.union(DEFAULTS_PENDULUM.keys())
elif data_source == 'Sergio' or data_source == 'Sergio_hf':
defaults = DEFAULTS_SERGIO
from sim_transfer.sims.simulators import SergioSim
if data_source == 'Sergio_hf':
sim_hf = SergioSim(n_genes=defaults['num_genes'], n_cells=defaults['num_cells'], use_hf=True)
sim_lf = SergioSim(n_genes=defaults['num_genes'], n_cells=defaults['num_cells'], use_hf=False)
else:
sim_hf = sim_lf = SergioSim(n_genes=defaults['num_genes'], n_cells=defaults['n_cells'], use_hf=False)
elif data_source == 'pendulum_bimodal' or data_source == 'pendulum_bimodal_hf':
from sim_transfer.sims.simulators import PendulumBiModalSim
defaults = DEFAULTS_PENDULUM
Expand Down
13 changes: 8 additions & 5 deletions experiments/lf_hf_transfer_exp/run_regression_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,13 @@ def main(args):
elif 'only_pose' in exp_params['data_source']:
outputscales_racecar = outputscales_racecar[:-3]
exp_params['added_gp_outputscale'] = outputscales_racecar.tolist()
print(f"Setting added_gp_outputscale to data_source default value from DATASET_CONFIGS "
f"which is {exp_params['added_gp_outputscale']}")
# print(f"Setting added_gp_outputscale to data_source default value from DATASET_CONFIGS "
# f"which is {exp_params['added_gp_outputscale']}")
elif 'pendulum' in exp_params['data_source']:
exp_params['added_gp_outputscale'] = [factor * 0.05, 0.05, 0.5]
elif 'Sergio' in exp_params['data_source']:
from experiments.data_provider import DEFAULTS_SERGIO
exp_params['added_gp_outputscale'] = [factor * 0.05 for _ in range(DEFAULTS_SERGIO['sergio_dim'])]
else:
raise AssertionError('passed negative value for added_gp_outputscale')
# set likelihood_std to default value if not specified
Expand All @@ -262,8 +265,8 @@ def main(args):
elif 'only_pose' in exp_params['data_source']:
likelihood_std = likelihood_std[:-3]
exp_params['likelihood_std'] = likelihood_std
print(f"Setting likelihood_std to data_source default value from DATASET_CONFIGS "
f"which is {exp_params['likelihood_std']}")
# print(f"Setting likelihood_std to data_source default value from DATASET_CONFIGS "
# f"which is {exp_params['likelihood_std']}")



Expand Down Expand Up @@ -325,7 +328,7 @@ def main(args):
parser.add_argument('--use_wandb', type=bool, default=False)

# data parameters
parser.add_argument('--data_source', type=str, default='real_racecar_v3')
parser.add_argument('--data_source', type=str, default='Sergio_hf')
parser.add_argument('--pred_diff', type=int, default=1)
parser.add_argument('--num_samples_train', type=int, default=5000)
parser.add_argument('--data_seed', type=int, default=77698)
Expand Down
131 changes: 130 additions & 1 deletion sim_transfer/sims/dynamics_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ class CarParams(NamedTuple):
angle_offset: Union[jax.Array, float] = jnp.array([0.02791893])


class SergioParams(NamedTuple):
lam: jax.Array = jnp.array(0.8)
contribution_rates: jax.Array = jnp.array(2.0)
basal_rates: jax.Array = jnp.array(0.0)


class DynamicsModel(ABC):
def __init__(self,
dt: float,
Expand Down Expand Up @@ -567,7 +573,130 @@ def _ode(self, x, u, params: CarParams):
return dx


class SergioDynamics(ABC):
l_b: float = 0
u_b: float = 500
lam_lb: float = 0.2
lam_ub: float = 0.9

def __init__(self,
dt: float,
n_cells: int,
n_genes: int,
params: SergioParams = SergioParams(),
dt_integration: float = 0.01,
):
super().__init__()
self.dt = dt
self.n_cells = n_cells
self.n_genes = n_genes
self.params = params
self.x_dim = self.n_cells * self.n_genes

self.dt_integration = dt_integration
assert dt >= dt_integration
assert (dt / dt_integration - int(dt / dt_integration)) < 1e-4, 'dt must be multiple of dt_integration'
self._num_steps_integrate = int(dt / dt_integration)

def next_step(self, x: jax.Array, params: PyTree) -> jax.Array:
def body(carry, _):
q = carry + self.dt_integration * self.ode(carry, params)
return q, None

next_state, _ = jax.lax.scan(body, x, xs=None, length=self._num_steps_integrate)
return next_state

def ode(self, x: jax.Array, params) -> jax.Array:
assert x.shape[-1] == self.x_dim
return self._ode(x, params)

def production_rate(self, x: jnp.array, params: SergioParams):
assert x.shape == (self.n_cells, self.n_genes)

def hill_function(x: jnp.array, n: float = 2.0):
h = x.mean(0)
hill_numerator = jnp.power(x, n)
hill_denominator = jnp.power(h, n) + hill_numerator
hill = jnp.where(
jnp.abs(hill_denominator) < 1e-6, 0, hill_numerator / hill_denominator
)
return hill

hills = hill_function(x)
masked_contribution = params.contribution_rates

# [n_cell_types, n_genes, n_genes]
# switching mechanism between activation and repression,
# which is decided via the sign of the contribution rates
intermediate = jnp.where(
masked_contribution > 0,
jnp.abs(masked_contribution) * hills,
jnp.abs(masked_contribution) * (1 - hills),
)
# [n_cell_types, n_genes]
# sum over regulators, i.e. sum over i in [b, i, j]
production_rate = params.basal_rates + intermediate.sum(1)
return production_rate

def _ode(self, x: jax.Array, params) -> jax.Array:
assert x.shape == (self.n_cells * self.n_genes,)
x = x.reshape(self.n_cells, self.n_genes)
production_rate = self.production_rate(x, params)
x_next = (production_rate - params.lam * x) * self.dt
x_next = x_next.reshape(self.n_cells * self.n_genes)
x_next = jnp.clip(x_next, self.l_b, self.u_b)
return x_next

def _split_key_like_tree(self, key: jax.random.PRNGKey):
treedef = jtu.tree_structure(self.params)
keys = jax.random.split(key, treedef.num_leaves)
return jtu.tree_unflatten(treedef, keys)

def sample_single_params(self, key: jax.random.PRNGKey, lower_bound: NamedTuple, upper_bound: NamedTuple):
lam_key, contrib_key, basal_key = jax.random.split(key, 3)
lam = jax.random.uniform(lam_key, shape=(self.n_cells, self.n_genes), minval=lower_bound.lam,
maxval=upper_bound.lam)

contribution_rates = jax.random.uniform(contrib_key, shape=(self.n_cells,
self.n_genes,
self.n_genes),
minval=lower_bound.contribution_rates,
maxval=upper_bound.contribution_rates)
basal_rates = jax.random.uniform(basal_key, shape=(self.n_cells,
self.n_genes),
minval=lower_bound.basal_rates,
maxval=upper_bound.basal_rates)

return SergioParams(
lam=lam,
contribution_rates=contribution_rates,
basal_rates=basal_rates
)

def sample_params_uniform(self, key: jax.random.PRNGKey, sample_shape: Union[int, Tuple[int]],
lower_bound: NamedTuple, upper_bound: NamedTuple):
if isinstance(sample_shape, int):
keys = jax.random.split(key, sample_shape)
else:
keys = jax.random.split(key, jnp.prod(sample_shape.shape))
sampled_params = jax.vmap(self.sample_single_params,
in_axes=(0, None, None))(keys, lower_bound, upper_bound)
return sampled_params


if __name__ == "__main__":
sim = SergioDynamics(0.1, 20, 20)
x_next = sim.next_step(x=jnp.ones(20 * 20), params=sim.params)
lower_bound = SergioParams(lam=jnp.array(0.2),
contribution_rates=jnp.array(1.0),
basal_rates=jnp.array(1.0))
upper_bound = SergioParams(lam=jnp.array(0.9),
contribution_rates=jnp.array(5.0),
basal_rates=jnp.array(5.0))
key = jax.random.PRNGKey(0)
keys = random.split(key, 4)
params = vmap(sim.sample_params_uniform, in_axes=(0, None, None, None))(keys, 1, lower_bound, upper_bound)
x_next = vmap(vmap(lambda p: sim.next_step(x=jnp.ones(20 * 20), params=p)))(params)
pendulum = Pendulum(0.1)
pendulum.next_step(x=jnp.array([0., 0., 0.]), u=jnp.array([1.0]), params=pendulum.params)

Expand All @@ -577,7 +706,7 @@ def _ode(self, x, u, params: CarParams):
c_d=jnp.array(0.1))
key = jax.random.PRNGKey(0)
keys = random.split(key, 4)
params = vmap(pendulum.sample_params_uniform, in_axes=(0, None, None, None))(keys, 1, upper_bound, lower_bound)
params = vmap(pendulum.sample_params_uniform, in_axes=(0, None, None, None))(keys, 1, lower_bound, upper_bound)


def simulate_car(init_pos=jnp.zeros(2), horizon=150):
Expand Down
Loading

0 comments on commit 6d264a6

Please sign in to comment.