diff --git a/experiments/data_provider.py b/experiments/data_provider.py index f54cf09..eb2c0f6 100644 --- a/experiments/data_provider.py +++ b/experiments/data_provider.py @@ -13,7 +13,6 @@ from sim_transfer.sims.simulators import StackedActionSimWrapper from sim_transfer.sims.util import encode_angles as encode_angles_fn - DATA_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data') DEFAULTS_SINUSOIDS = { @@ -28,6 +27,15 @@ 'param_mode': 'random', } +DEFAULTS_SERGIO = { + 'obs_noise_std': 0.02, + 'x_support_mode_train': 'full', + 'param_mode': 'random', + 'num_cells': 10, + 'num_genes': 10, + 'sergio_dim': 10 * 10, +} + DEFAULTS_RACECAR = { 'obs_noise_std': OBS_NOISE_STD_SIM_CAR, 'x_support_mode_train': 'full', @@ -61,6 +69,15 @@ 'likelihood_std': {'value': [0.05, 0.05, 0.5]}, 'num_samples_train': {'value': 20}, }, + 'Sergio': { + 'likelihood_std': {'value': [0.05 for _ in range(DEFAULTS_SERGIO['sergio_dim'])]}, + 'num_samples_train': {'value': 20}, + }, + 'Sergio_hf': { + 'likelihood_std': {'value': [0.05 for _ in range(DEFAULTS_SERGIO['sergio_dim'])]}, + 'num_samples_train': {'value': 20}, + }, + 'pendulum_bimodal': { 'likelihood_std': {'value': [0.05, 0.05, 0.5]}, 'num_samples_train': {'value': 20}, @@ -190,7 +207,6 @@ def get_rccar_recorded_data_new(encode_angle: bool = True, skip_first_n_points: dataset: str = 'all', action_delay: int = 3, action_stacking: bool = False, car_id: int = 2): - assert car_id in [1, 2, 3] if car_id == 1: assert dataset in ['all', 'v1'] @@ -247,6 +263,14 @@ def provide_data_and_sim(data_source: str, data_spec: Dict[str, Any], data_seed: else: sim_hf = sim_lf = PendulumSim(encode_angle=True, high_fidelity=True) assert {'num_samples_train'} <= set(data_spec.keys()) <= {'num_samples_train'}.union(DEFAULTS_PENDULUM.keys()) + elif data_source == 'Sergio' or data_source == 'Sergio_hf': + defaults = DEFAULTS_SERGIO + from sim_transfer.sims.simulators import SergioSim + if data_source == 'Sergio_hf': + 
sim_hf = SergioSim(n_genes=defaults['num_genes'], n_cells=defaults['num_cells'], use_hf=True) + sim_lf = SergioSim(n_genes=defaults['num_genes'], n_cells=defaults['num_cells'], use_hf=False) + else: + sim_hf = sim_lf = SergioSim(n_genes=defaults['num_genes'], n_cells=defaults['num_cells'], use_hf=False) elif data_source == 'pendulum_bimodal' or data_source == 'pendulum_bimodal_hf': from sim_transfer.sims.simulators import PendulumBiModalSim defaults = DEFAULTS_PENDULUM diff --git a/experiments/lf_hf_transfer_exp/run_regression_exp.py b/experiments/lf_hf_transfer_exp/run_regression_exp.py index a72cab4..87bb113 100644 --- a/experiments/lf_hf_transfer_exp/run_regression_exp.py +++ b/experiments/lf_hf_transfer_exp/run_regression_exp.py @@ -248,10 +248,13 @@ def main(args): elif 'only_pose' in exp_params['data_source']: outputscales_racecar = outputscales_racecar[:-3] exp_params['added_gp_outputscale'] = outputscales_racecar.tolist() - print(f"Setting added_gp_outputscale to data_source default value from DATASET_CONFIGS " - f"which is {exp_params['added_gp_outputscale']}") + # print(f"Setting added_gp_outputscale to data_source default value from DATASET_CONFIGS " + # f"which is {exp_params['added_gp_outputscale']}") elif 'pendulum' in exp_params['data_source']: exp_params['added_gp_outputscale'] = [factor * 0.05, 0.05, 0.5] + elif 'Sergio' in exp_params['data_source']: + from experiments.data_provider import DEFAULTS_SERGIO + exp_params['added_gp_outputscale'] = [factor * 0.05 for _ in range(DEFAULTS_SERGIO['sergio_dim'])] else: raise AssertionError('passed negative value for added_gp_outputscale') # set likelihood_std to default value if not specified @@ -262,8 +265,8 @@ def main(args): elif 'only_pose' in exp_params['data_source']: likelihood_std = likelihood_std[:-3] exp_params['likelihood_std'] = likelihood_std - print(f"Setting likelihood_std to data_source default value from DATASET_CONFIGS " - f"which is {exp_params['likelihood_std']}") + # print(f"Setting 
likelihood_std to data_source default value from DATASET_CONFIGS " + # f"which is {exp_params['likelihood_std']}") @@ -325,7 +328,7 @@ def main(args): parser.add_argument('--use_wandb', type=bool, default=False) # data parameters - parser.add_argument('--data_source', type=str, default='real_racecar_v3') + parser.add_argument('--data_source', type=str, default='Sergio_hf') parser.add_argument('--pred_diff', type=int, default=1) parser.add_argument('--num_samples_train', type=int, default=5000) parser.add_argument('--data_seed', type=int, default=77698) diff --git a/sim_transfer/sims/dynamics_models.py b/sim_transfer/sims/dynamics_models.py index 23df801..b896c38 100644 --- a/sim_transfer/sims/dynamics_models.py +++ b/sim_transfer/sims/dynamics_models.py @@ -60,6 +60,12 @@ class CarParams(NamedTuple): angle_offset: Union[jax.Array, float] = jnp.array([0.02791893]) +class SergioParams(NamedTuple): + lam: jax.Array = jnp.array(0.8) + contribution_rates: jax.Array = jnp.array(2.0) + basal_rates: jax.Array = jnp.array(0.0) + + class DynamicsModel(ABC): def __init__(self, dt: float, @@ -567,7 +573,130 @@ def _ode(self, x, u, params: CarParams): return dx +class SergioDynamics(ABC): + l_b: float = 0 + u_b: float = 500 + lam_lb: float = 0.2 + lam_ub: float = 0.9 + + def __init__(self, + dt: float, + n_cells: int, + n_genes: int, + params: SergioParams = SergioParams(), + dt_integration: float = 0.01, + ): + super().__init__() + self.dt = dt + self.n_cells = n_cells + self.n_genes = n_genes + self.params = params + self.x_dim = self.n_cells * self.n_genes + + self.dt_integration = dt_integration + assert dt >= dt_integration + assert (dt / dt_integration - int(dt / dt_integration)) < 1e-4, 'dt must be multiple of dt_integration' + self._num_steps_integrate = int(dt / dt_integration) + + def next_step(self, x: jax.Array, params: PyTree) -> jax.Array: + def body(carry, _): + q = carry + self.dt_integration * self.ode(carry, params) + return q, None + + next_state, _ = 
jax.lax.scan(body, x, xs=None, length=self._num_steps_integrate) + return next_state + + def ode(self, x: jax.Array, params) -> jax.Array: + assert x.shape[-1] == self.x_dim + return self._ode(x, params) + + def production_rate(self, x: jnp.array, params: SergioParams): + assert x.shape == (self.n_cells, self.n_genes) + + def hill_function(x: jnp.array, n: float = 2.0): + h = x.mean(0) + hill_numerator = jnp.power(x, n) + hill_denominator = jnp.power(h, n) + hill_numerator + hill = jnp.where( + jnp.abs(hill_denominator) < 1e-6, 0, hill_numerator / hill_denominator + ) + return hill + + hills = hill_function(x) + masked_contribution = params.contribution_rates + + # [n_cell_types, n_genes, n_genes] + # switching mechanism between activation and repression, + # which is decided via the sign of the contribution rates + intermediate = jnp.where( + masked_contribution > 0, + jnp.abs(masked_contribution) * hills, + jnp.abs(masked_contribution) * (1 - hills), + ) + # [n_cell_types, n_genes] + # sum over regulators, i.e. 
sum over i in [b, i, j] + production_rate = params.basal_rates + intermediate.sum(1) + return production_rate + + def _ode(self, x: jax.Array, params) -> jax.Array: + assert x.shape == (self.n_cells * self.n_genes,) + x = x.reshape(self.n_cells, self.n_genes) + production_rate = self.production_rate(x, params) + x_next = (production_rate - params.lam * x) * self.dt + x_next = x_next.reshape(self.n_cells * self.n_genes) + x_next = jnp.clip(x_next, self.l_b, self.u_b) + return x_next + + def _split_key_like_tree(self, key: jax.random.PRNGKey): + treedef = jtu.tree_structure(self.params) + keys = jax.random.split(key, treedef.num_leaves) + return jtu.tree_unflatten(treedef, keys) + + def sample_single_params(self, key: jax.random.PRNGKey, lower_bound: NamedTuple, upper_bound: NamedTuple): + lam_key, contrib_key, basal_key = jax.random.split(key, 3) + lam = jax.random.uniform(lam_key, shape=(self.n_cells, self.n_genes), minval=lower_bound.lam, + maxval=upper_bound.lam) + + contribution_rates = jax.random.uniform(contrib_key, shape=(self.n_cells, + self.n_genes, + self.n_genes), + minval=lower_bound.contribution_rates, + maxval=upper_bound.contribution_rates) + basal_rates = jax.random.uniform(basal_key, shape=(self.n_cells, + self.n_genes), + minval=lower_bound.basal_rates, + maxval=upper_bound.basal_rates) + + return SergioParams( + lam=lam, + contribution_rates=contribution_rates, + basal_rates=basal_rates + ) + + def sample_params_uniform(self, key: jax.random.PRNGKey, sample_shape: Union[int, Tuple[int]], + lower_bound: NamedTuple, upper_bound: NamedTuple): + if isinstance(sample_shape, int): + keys = jax.random.split(key, sample_shape) + else: + keys = jax.random.split(key, int(jnp.prod(jnp.array(sample_shape)))) + sampled_params = jax.vmap(self.sample_single_params, + in_axes=(0, None, None))(keys, lower_bound, upper_bound) + return sampled_params + + if __name__ == "__main__": + sim = SergioDynamics(0.1, 20, 20) + x_next = sim.next_step(x=jnp.ones(20 * 20), 
params=sim.params) + lower_bound = SergioParams(lam=jnp.array(0.2), + contribution_rates=jnp.array(1.0), + basal_rates=jnp.array(1.0)) + upper_bound = SergioParams(lam=jnp.array(0.9), + contribution_rates=jnp.array(5.0), + basal_rates=jnp.array(5.0)) + key = jax.random.PRNGKey(0) + keys = random.split(key, 4) + params = vmap(sim.sample_params_uniform, in_axes=(0, None, None, None))(keys, 1, lower_bound, upper_bound) + x_next = vmap(vmap(lambda p: sim.next_step(x=jnp.ones(20 * 20), params=p)))(params) pendulum = Pendulum(0.1) pendulum.next_step(x=jnp.array([0., 0., 0.]), u=jnp.array([1.0]), params=pendulum.params) @@ -577,7 +706,7 @@ def _ode(self, x, u, params: CarParams): c_d=jnp.array(0.1)) key = jax.random.PRNGKey(0) keys = random.split(key, 4) - params = vmap(pendulum.sample_params_uniform, in_axes=(0, None, None, None))(keys, 1, upper_bound, lower_bound) + params = vmap(pendulum.sample_params_uniform, in_axes=(0, None, None, None))(keys, 1, lower_bound, upper_bound) def simulate_car(init_pos=jnp.zeros(2), horizon=150): diff --git a/sim_transfer/sims/simulators.py b/sim_transfer/sims/simulators.py index 2d0e1bc..e032149 100644 --- a/sim_transfer/sims/simulators.py +++ b/sim_transfer/sims/simulators.py @@ -10,7 +10,7 @@ from tensorflow_probability.substrates import jax as tfp from sim_transfer.sims.domain import Domain, HypercubeDomain, HypercubeDomainWithAngles -from sim_transfer.sims.dynamics_models import Pendulum, PendulumParams, RaceCar, CarParams +from sim_transfer.sims.dynamics_models import Pendulum, PendulumParams, RaceCar, CarParams, SergioParams, SergioDynamics from sim_transfer.sims.util import encode_angles, decode_angles @@ -351,7 +351,7 @@ def sample_params(self, rng_key: jax.random.PRNGKey): freq_key, amp_key, slope_key, rng_key = jax.random.split(rng_key, 4) sim_params = { 'freq': jax.random.uniform(freq_key, minval=self.freq1_mid - self.freq1_spread, - maxval=self.freq1_mid + self.freq1_spread), + maxval=self.freq1_mid + self.freq1_spread), 
'amp': self.amp_mean + self.amp_std * jax.random.normal(amp_key), 'slope': self.slope_mean + self.slope_std * jax.random.normal(slope_key), } @@ -366,7 +366,7 @@ def sample_params(self, rng_key: jax.random.PRNGKey): return sim_params, train_params def evaluate_sim(self, x: jnp.array, params: NamedTuple) -> jnp.array: - f = self._f1(amp=params.amp, freq=params.freq, slope=params.slope,x=x) + f = self._f1(amp=params.amp, freq=params.freq, slope=params.slope, x=x) if self.output_size == 1: return f elif self.output_size == 2: @@ -607,7 +607,7 @@ def _split_state_action(self, z: jnp.array) -> Tuple[jnp.array, jnp.array]: return z[..., :self._state_action_spit_idx], z[..., self._state_action_spit_idx:] def sample_params(self, rng_key: jax.random.PRNGKey): - params = self.model.sample_params_uniform(rng_key, sample_shape=(1, ), + params = self.model.sample_params_uniform(rng_key, sample_shape=(1,), lower_bound=self._lower_bound_params, upper_bound=self._upper_bound_params) @@ -986,6 +986,132 @@ def _set_default_params(self): raise ValueError(f'Car id {self.car_id} not supported.') +class SergioSim(FunctionSimulator): + _dt: float = 1 / 10 + + # domain for generating data + state_lb: float = 0.0 + state_ub: float = 400 + + def __init__(self, n_genes: int = 20, n_cells: int = 20, use_hf: bool = False): + FunctionSimulator.__init__(self, input_size=n_genes * n_cells, output_size=n_genes * n_cells) + self.model = SergioDynamics(self._dt, n_cells=n_cells, n_genes=n_genes) + self._setup_params() + self.use_hf = use_hf + if self.use_hf: + self._typical_params = self.default_param_hf + self._lower_bound_params = self.lower_bound_param_hf + self._upper_bound_params = self.upper_bound_param_hf + else: + self._typical_params = self.default_param_lf + self._lower_bound_params = self.lower_bound_param_lf + self._upper_bound_params = self.upper_bound_param_lf + + assert jnp.all(jnp.stack(jtu.tree_flatten( + jtu.tree_map(lambda l, u: l <= u, self._lower_bound_params, self._upper_bound_params))[0])), \ 
+ 'lower bounds have to be smaller than upper bounds' + + # setup domain + self.domain_lower = jnp.ones(shape=(n_genes * n_cells,)) * self.state_lb + self.domain_upper = jnp.ones(shape=(n_genes * n_cells,)) * self.state_ub + self._domain = HypercubeDomain(lower=self.domain_lower, upper=self.domain_upper) + + @property + def domain(self) -> Domain: + return self._domain + + def _setup_params(self): + self.lower_bound_param_hf = SergioParams(lam=jnp.array(0.1), + contribution_rates=jnp.array(-5.0), + basal_rates=jnp.array(0.0)) + self.upper_bound_param_hf = SergioParams(lam=jnp.array(0.9), + contribution_rates=jnp.array(5.0), + basal_rates=jnp.array(5.0)) + self.default_param_hf = self.model.sample_single_params(jax.random.PRNGKey(0), self.lower_bound_param_hf, + self.upper_bound_param_hf) + + self.lower_bound_param_lf = SergioParams(lam=jnp.array(0.2), + contribution_rates=jnp.array(0.0), + basal_rates=jnp.array(0.0)) + self.upper_bound_param_lf = SergioParams(lam=jnp.array(0.9), + contribution_rates=jnp.array(0.0), + basal_rates=jnp.array(0.0)) + self.default_param_lf = self.model.sample_single_params(jax.random.PRNGKey(0), self.lower_bound_param_lf, + self.upper_bound_param_lf) + + def init_params(self): + return self._typical_params + + def sample_params(self, rng_key: jax.random.PRNGKey): + params = self.model.sample_params_uniform(rng_key, sample_shape=1, + lower_bound=self._lower_bound_params, + upper_bound=self._upper_bound_params) + params = jtu.tree_map(lambda x: x.item(), params) + train_params = jtu.tree_map(lambda x: 1, params) + return params, train_params + + def sample_function_vals(self, x: jnp.ndarray, num_samples: int, rng_key: jax.random.PRNGKey) -> jnp.ndarray: + assert x.ndim == 2 and x.shape[-1] == self.input_size + params = self.model.sample_params_uniform(rng_key, sample_shape=num_samples, + lower_bound=self._lower_bound_params, + upper_bound=self._upper_bound_params) + + def batched_fun(z, params): + f = vmap(self.model.next_step, 
in_axes=(0, None))(z, params) + return f + + f = vmap(batched_fun, in_axes=(None, 0))(x, params) + assert f.shape == (num_samples, x.shape[0], self.output_size) + return f + + def sample_functions(self, num_samples: int, rng_key: jax.random.PRNGKey) -> Callable: + params = self.model.sample_params_uniform(rng_key, sample_shape=(num_samples,), + lower_bound=self._lower_bound_params, + upper_bound=self._upper_bound_params) + + def stacked_fun(z): + f = vmap(self.model.next_step, in_axes=(0, 0))(z, params) + return f + + return stacked_fun + + @property + def domain(self) -> Domain: + return self._domain + + @property + def normalization_stats(self) -> Dict[str, jnp.ndarray]: + + stats = {'x_mean': jnp.ones(self.input_size) * (self.state_ub + self.state_lb) / 2, + 'x_std': jnp.ones(self.input_size) * (self.state_ub - self.state_lb) / jnp.sqrt(12), + 'y_mean': jnp.ones(self.output_size) * (self.state_ub + self.state_lb) / 2, + 'y_std': jnp.ones(self.output_size) * (self.state_ub - self.state_lb) / jnp.sqrt(12)} + return stats + + def _typical_f(self, x: jnp.array) -> jnp.array: + f = jax.vmap(self.model.next_step, in_axes=(0, None))(x, self._typical_params) + return f + + def evaluate_sim(self, x: jnp.array, params: NamedTuple) -> jnp.array: + f = jax.vmap(self.model.next_step, in_axes=(0, None))(x, params) + return f + + def _add_observation_noise(self, f_vals: jnp.ndarray, obs_noise_std: Union[jnp.ndarray, float], + rng_key: jax.random.PRNGKey) -> jnp.ndarray: + + y = f_vals + obs_noise_std * jax.random.normal(rng_key, shape=f_vals.shape) + assert f_vals.shape == y.shape + return y + + def _sample_x_data(self, rng_key: jax.random.PRNGKey, num_samples_train: int, num_samples_test: int, + support_mode_train: str = 'full') -> Tuple[jnp.ndarray, jnp.ndarray]: + """ Sample inputs for training and testing. 
""" + dataset_domain = HypercubeDomain(lower=self.domain_lower, upper=self.domain_upper) + x_train = dataset_domain.sample_uniformly(rng_key, num_samples_train, support_mode=support_mode_train) + x_test = dataset_domain.sample_uniformly(rng_key, num_samples_test, support_mode='full') + return x_train, x_test + + class PredictStateChangeWrapper(FunctionSimulator): def __init__(self, function_simulator: FunctionSimulator): """ @@ -1122,6 +1248,12 @@ def evaluate_sim(self, x: jnp.array, params: NamedTuple) -> jnp.array: if __name__ == '__main__': key1, key2 = jax.random.split(jax.random.PRNGKey(435345), 2) + function_sim = SergioSim(use_hf=False) + x, _ = function_sim._sample_x_data(key1, 1, 1) + + f1 = function_sim.sample_function_vals(x, num_samples=10, rng_key=key2) + f2 = function_sim._typical_f(x) + function_sim = RaceCarSim(use_blend=False, no_angular_velocity=True) x, _ = function_sim._sample_x_data(key1, 1000, 1000)