Skip to content

Commit

Permalink
fixes to greenhouse sim and changed sergio to predict next state dist…
Browse files Browse the repository at this point in the history
…ribution
  • Loading branch information
sukhijab committed May 27, 2024
1 parent 30fddd1 commit 7264ed4
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 91 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ old_results/
log/
tf_summaries/
time_lines/

lenart_internal/
*.pdf
*.csv
mars/

# Byte-compiled / optimized / DLL files
Expand Down
13 changes: 6 additions & 7 deletions experiments/data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@
'obs_noise_std': 0.05,
'x_support_mode_train': 'full',
'param_mode': 'random',
'num_cells': 5,
'num_genes': 15,
'sergio_dim': 5 * 15,
'num_cells': 10,
'num_genes': 200,
}

DEFAULTS_RACECAR = {
Expand Down Expand Up @@ -77,21 +76,21 @@
},

'Greenhouse': {
'likelihood_std': {'value': [0.05 for _ in range(16)]},
'likelihood_std': {'value': [0.01 for _ in range(16)]},
'num_samples_train': {'value': 20},
},

'Greenhouse_hf': {
'likelihood_std': {'value': [0.05 for _ in range(16)]},
'likelihood_std': {'value': [0.01 for _ in range(16)]},
'num_samples_train': {'value': 20},
},

'Sergio': {
'likelihood_std': {'value': [0.05 for _ in range(DEFAULTS_SERGIO['sergio_dim'])]},
'likelihood_std': {'value': [0.05 for _ in range(2 * DEFAULTS_SERGIO['num_cells'])]},
'num_samples_train': {'value': 20},
},
'Sergio_hf': {
'likelihood_std': {'value': [0.05 for _ in range(DEFAULTS_SERGIO['sergio_dim'])]},
'likelihood_std': {'value': [0.05 for _ in range(2 * DEFAULTS_SERGIO['num_cells'])]},
'num_samples_train': {'value': 20},
},

Expand Down
16 changes: 8 additions & 8 deletions experiments/lf_hf_transfer_exp/run_regression_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,14 +250,14 @@ def main(args):
# print(f"Setting added_gp_outputscale to data_source default value from DATASET_CONFIGS "
# f"which is {exp_params['added_gp_outputscale']}")
elif 'pendulum' in exp_params['data_source']:
exp_params['added_gp_outputscale'] = [factor * 0.05, 0.05, 0.5]
exp_params['added_gp_outputscale'] = [factor * 0.05, factor * 0.05, factor * 0.5]
elif 'Sergio' in exp_params['data_source']:
from experiments.data_provider import DEFAULTS_SERGIO
exp_params['added_gp_outputscale'] = [factor * 0.1 for _ in range(DEFAULTS_SERGIO['sergio_dim'])]
exp_params['added_gp_outputscale'] = [factor * 0.1 for _ in range(2 * DEFAULTS_SERGIO['num_cells'])]
elif 'Greenhouse' in exp_params['data_source']:
exp_params['added_gp_outputscale'] = [factor * 0.1 for _ in range(16)]
exp_params['added_gp_outputscale'] = [factor * 0.05 for _ in range(16)]
# We are quite confident about exogenous effects
exp_params['added_gp_outputscale'][-6:] = [0.0001 for _ in range(6)]
exp_params['added_gp_outputscale'][5:] = [0.005 for _ in range(11)]
else:
raise AssertionError('passed negative value for added_gp_outputscale')
# set likelihood_std to default value if not specified
Expand Down Expand Up @@ -332,16 +332,16 @@ def main(args):
parser.add_argument('--use_wandb', type=bool, default=False)

# data parameters
parser.add_argument('--data_source', type=str, default='Greenhouse_hf')
parser.add_argument('--pred_diff', type=int, default=1)
parser.add_argument('--num_samples_train', type=int, default=100)
parser.add_argument('--data_source', type=str, default='Sergio_hf')
parser.add_argument('--pred_diff', type=int, default=0)
parser.add_argument('--num_samples_train', type=int, default=6400)
parser.add_argument('--data_seed', type=int, default=77698)

# standard BNN parameters
parser.add_argument('--model', type=str, default='BNN_FSVGD_SimPrior_gp')
parser.add_argument('--model_seed', type=int, default=892616)
parser.add_argument('--likelihood_std', type=float, default=None)
parser.add_argument('--learn_likelihood_std', type=int, default=0)
parser.add_argument('--learn_likelihood_std', type=int, default=1)
parser.add_argument('--likelihood_reg', type=float, default=0.0)
parser.add_argument('--data_batch_size', type=int, default=8)
parser.add_argument('--min_train_steps', type=int, default=2500)
Expand Down
67 changes: 40 additions & 27 deletions sim_transfer/sims/dynamics_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ class GreenHouseParams(NamedTuple):
cg: Union[jax.Array, float] = jnp.array(32 * (10 ** 3)) # green_house_heat_capacity
cp_w: Union[jax.Array, float] = jnp.array(4180.0) # specific_heat_water
cs: Union[jax.Array, float] = jnp.array(120 * (10 ** 3)) # green_house_soil_heat_capacity
cp_a: Union[jax.Array, float] = 1010 # air_specific_heat_water
cp_a: Union[jax.Array, float] = jnp.array(1010) # air_specific_heat_water
d1: Union[jax.Array, float] = jnp.array(2.1332 * (10 ** (-7))) # plant development rate 1
d2: Union[jax.Array, float] = jnp.array(2.4664 * (10 ** (-1))) # plant development rate 2
d2: Union[jax.Array, float] = jnp.array(2.4664 * (10 ** (-7))) # plant development rate 2
d3: Union[jax.Array, float] = jnp.array(20) # plant development rate 3
d4: Union[jax.Array, float] = jnp.array(7.4966 * (10 ** (-11))) # plant development rate 4
f: Union[jax.Array, float] = jnp.array(1.2) # fruit assimilate requirment
Expand Down Expand Up @@ -663,7 +663,6 @@ def _ode(self, x, u, params: CarParams):


class SergioDynamics(ABC):
l_b: float = 0
lam_lb: float = 0.2
lam_ub: float = 0.9

Expand All @@ -673,28 +672,42 @@ def __init__(self,
n_genes: int,
params: SergioParams = SergioParams(),
dt_integration: float = 0.01,
state_ub: float = 500.0,
):
super().__init__()
self.dt = dt
self.n_cells = n_cells
self.n_genes = n_genes
self.params = params
self.x_dim = self.n_cells * self.n_genes

self.state_ub = state_ub
self.dt_integration = dt_integration
assert dt >= dt_integration
assert (dt / dt_integration - int(dt / dt_integration)) < 1e-4, 'dt must be multiple of dt_integration'
self._num_steps_integrate = int(dt / dt_integration)

def next_step(self, x: jax.Array, params: PyTree) -> jax.Array:
x = self.transform_state(x)

def body(carry, _):
q = carry + self.dt_integration * self.ode(carry, params)
q = jnp.clip(q, a_min=self.l_b)
q = jnp.clip(q, a_min=0)
return q, None

next_state, _ = jax.lax.scan(body, x, xs=None, length=self._num_steps_integrate)
next_state = self.inv_transform_state(next_state)
return next_state

def transform_state(self, x):
# x is between [0, 1] -> [0, state_ub]
x = x * self.state_ub
return x

def inv_transform_state(self, x):
# [0, state_ub] -> [0, 1]
x = x / self.state_ub
return x

def ode(self, x: jax.Array, params) -> jax.Array:
assert x.shape[-1] == self.x_dim
return self._ode(x, params)
Expand Down Expand Up @@ -809,7 +822,7 @@ class GreenHouseDynamics(DynamicsModel):
]
)

input_ub = jnp.array([60, 1.0, 1.0, 2.0])
input_ub = jnp.array([80, 1.0, 1.0, 2.1])
input_lb = jnp.array([0, 0.0, 0.0, 0.0])
noise_to = 5
noise_td = 0.01
Expand All @@ -821,11 +834,11 @@ class GreenHouseDynamics(DynamicsModel):
noise_std = jnp.array(
[
0.05, 0.1, 0.05, 0.01, 0.05, 0.05, 0.1, 0.1, 0.01, noise_to,
noise_td, noise_co, noise_vo, noise_w, noise_g, 2,
noise_td, noise_co, noise_vo, noise_w, noise_g, 0.1,
]
)

def __init__(self, use_hf: bool = False, dt: float = 60):
def __init__(self, use_hf: bool = False, dt: float = 300):

self.use_hf = use_hf
self.greenhouse_state_dim = 5
Expand All @@ -847,6 +860,7 @@ def __init__(self, use_hf: bool = False, dt: float = 60):

def next_step(self, x: jnp.array, u: jnp.array, params: GreenHouseParams) -> jnp.array:
x, u = self.transform_state(x), self.transform_action(u)

def body(carry, _):
q = carry + self.dt_integration * self.ode(carry, u, params)
q = jnp.clip(q, a_min=self.constraint_lb, a_max=self.constraint_ub)
Expand Down Expand Up @@ -883,17 +897,22 @@ def buffer_switching_func(B, b1):
return 1 - jnp.exp(-b1 * B)

def get_respiration_param(self, x, u, params: GreenHouseParams):
ml = x[self.greenhouse_state_dim + 2]
l_lai = (ml / params.wr) ** (params.laim) / (1 + (ml / params.wr) ** (params.laim))
R = - params.p1 - params.p5 * l_lai
# ml = x[self.greenhouse_state_dim + 2]
# l_lai = (ml / params.wr) ** (params.laim) / (1 + (ml / params.wr) ** (params.laim))
R = - params.p1 - params.p5
return R

def get_crop_photosynthesis(self, x, u, params: GreenHouseParams):
t_g, c_i = x[0], x[3]
ml = x[self.greenhouse_state_dim + 2]
G = x[-2]
i_par = params.eta * G * params.mp * params.pg
c_ppm = (10 ** 6) * params.rg / (params.patm * params.Mco2) * (t_g + params.T0) * c_i
# note patm is in kPa. c_i is in g/m^3
# c_ppm units: m^3 Pa/(mol K) * K * g/m^3 /(kPa * kg/mol)
# c_ppm: units 10^-3 m^3 kPa/mol * 10^-3 kg/m^3 /(kPa * kg/mol)
# c_ppm: units 10^-6 [] -> need to multiply with 10^-6 to get right units
# c_ppm = 1/mol * g/kg
c_ppm = params.rg / (params.patm * params.Mco2) * (t_g + params.T0) * c_i
l_lai = (ml / params.wr) ** (params.laim) / (1 + (ml / params.wr) ** (params.laim))
p_g = params.pm * l_lai * i_par / (params.p3 + i_par) * c_ppm / (params.p4 + c_ppm)
return p_g
Expand All @@ -902,7 +921,8 @@ def get_harvest_coefficient(self, x, u, params: GreenHouseParams):
d_p = x[self.greenhouse_state_dim + 3]
t = x[-1]
t_g = x[0]
h = (d_p >= 1) * (params.d1 + params.d2 * jnp.log(t_g / params.d3) - params.d4 * t)
temp_ratio = jnp.clip(t_g / params.d3, a_min=1e-6)
h = (d_p >= 1) * (params.d1 + params.d2 * jnp.log(temp_ratio) - params.d4 * t)
return h

def transform_state(self, x):
Expand All @@ -915,12 +935,12 @@ def transform_action(self, u):
return u

def inv_transform_state(self, x):
x = (x - self.state_lb)/(self.state_ub - self.state_lb)
x = (x - self.state_lb) / (self.state_ub - self.state_lb)
return x

def _greenhouse_dynamics_hf(self, x, u, params: GreenHouseParams):
# C, C, C, m, g/m^3, kg/m^-3
t_g, t_p, t_s, c_i, v_i = x[0], x[1], x[2], x[3], x[5]
t_g, t_p, t_s, c_i, v_i = x[0], x[1], x[2], x[3], x[4]
# g/m^-2, g/m^-2, g/m^-2, []
mb, mf, ml, d_p = x[self.greenhouse_state_dim], x[self.greenhouse_state_dim + 1], \
x[self.greenhouse_state_dim + 2], x[self.greenhouse_state_dim + 3]
Expand Down Expand Up @@ -956,10 +976,6 @@ def _greenhouse_dynamics_hf(self, x, u, params: GreenHouseParams):
dt_g_dt = (k_v + params.kr) * (t_o - t_g) + alpha * (t_p - t_g) + params.ks * (t_s - t_g) \
+ G * params.eta - l * E + l / (1 + params.epsilon) * Mc
dt_g_dt = dt_g_dt / params.cg
# jax.debug.print('G {x}', x=G/params.cg)
# jax.debug.print('Mc{x}', x=Mc)
# jax.debug.print('t_g {x}', x=t_g)
# jax.debug.print('dt_g {x}', x=dt_g_dt)

phi = params.phi
rh = params.rh
Expand Down Expand Up @@ -991,8 +1007,8 @@ def _greenhouse_dynamics_hf(self, x, u, params: GreenHouseParams):
hf, hl = h * params.yf, h * params.yl
dmf_dt = (b * g_f - (1 - b) * rf - hf) * mf
dml_dt = (b * g_l - (1 - b) * rl - hl) * ml
dd_p_dt = params.d1 + params.d2 * jnp.log(t_g / params.d3) - params.d4 * t - h

temp_ratio = jnp.clip(t_g / params.d3, a_min=1e-6)
dd_p_dt = params.d1 + params.d2 * jnp.log(temp_ratio) - params.d4 * t - h
# Exogenous effects

dt_o_dt = jnp.zeros_like(dt_g_dt)
Expand All @@ -1013,8 +1029,7 @@ def _greenhouse_dynamics_hf(self, x, u, params: GreenHouseParams):
return dx_dt

def _greenhouse_dynamics_lf(self, x, u, params: GreenHouseParams):

t_g, t_p, t_s, c_i, v_i = x[0], x[1], x[2], x[3], x[5]
t_g, t_p, t_s, c_i, v_i = x[0], x[1], x[2], x[3], x[4]
# g/m^-2, g/m^-2, g/m^-2, []
mb, mf, ml, d_p = x[self.greenhouse_state_dim], x[self.greenhouse_state_dim + 1], \
x[self.greenhouse_state_dim + 2], x[self.greenhouse_state_dim + 3]
Expand All @@ -1031,9 +1046,6 @@ def _greenhouse_dynamics_lf(self, x, u, params: GreenHouseParams):
dt_g_dt = (k_v + params.kr) * (t_o - t_g) + params.ks * (t_s - t_g) \
+ G * params.eta
dt_g_dt = dt_g_dt / params.cg
# jax.debug.print('l {x}', x=l)
# jax.debug.print('p_g_star{x}', x=p_g_star)
# jax.debug.print('t_g {x}', x=t_g)

dt_p_dt = jnp.zeros_like(dt_g_dt)

Expand All @@ -1060,7 +1072,8 @@ def _greenhouse_dynamics_lf(self, x, u, params: GreenHouseParams):
hf, hl = h * params.yf, h * params.yl
dmf_dt = (b * g_f - (1 - b) * rf - hf) * mf
dml_dt = (b * g_l - (1 - b) * rl - hl) * ml
dd_p_dt = params.d1 + params.d2 * jnp.log(t_g / params.d3) - params.d4 * t - h
temp_ratio = jnp.clip(t_g / params.d3, a_min=1e-6)
dd_p_dt = params.d1 + params.d2 * jnp.log(temp_ratio) - params.d4 * t - h

# Exogenous effects

Expand Down
Loading

0 comments on commit 7264ed4

Please sign in to comment.