Skip to content

Commit

Permalink
flag to stop bnn training and prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
sukhijab committed Nov 23, 2023
1 parent f7f8de0 commit 59e2f09
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions sim_transfer/models/bnn_grey_box_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def __init__(self,
sim: FunctionSimulator,
lr_sim: float = None,
weight_decay_sim: float = 0.0,
num_sim_model_train_steps: int = 10_000):
num_sim_model_train_steps: int = 10_000,
use_base_bnn: bool = True):
self.base_bnn = base_bnn
super().__init__(
input_size=self.base_bnn.input_size,
Expand All @@ -48,6 +49,7 @@ def __init__(self,
self.optim_sim = None
self._x_mean_sim, self._x_std_sim = jnp.copy(self.base_bnn._x_mean), jnp.copy(self.base_bnn._x_std)
self._y_mean_sim, self._y_std_sim = jnp.copy(self.base_bnn._y_mean), jnp.copy(self.base_bnn._y_std)
self.use_base_bnn = use_base_bnn
if lr_sim:
self.lr_sim = lr_sim
else:
Expand Down Expand Up @@ -246,7 +248,9 @@ def predict_post_samples(self, x: jnp.ndarray) -> jnp.ndarray:
sim_model_prediction = self.sim_model_step(x, self.params_sim['sim_params'])
x = self._normalize_data(x)
y_pred_raw = self.batched_model(x)
y_pred = jax.tree_util.tree_map(lambda y: self._unnormalize_y(y) + sim_model_prediction, y_pred_raw)
y_pred = jax.tree_util.tree_map(lambda y:
int(self.use_base_bnn) * self._unnormalize_y(y) + sim_model_prediction,
y_pred_raw)
assert y_pred.ndim == 3 and y_pred.shape[-2:] == (x.shape[0], self.output_size)
return y_pred

Expand Down Expand Up @@ -394,11 +398,12 @@ def fit(self, x_train: jnp.ndarray, y_train: jnp.ndarray, x_eval: Optional[jnp.n
log_to_wandb, metrics_objective, keep_the_best, per_dim_metrics)
y_train = y_train - self.sim_model_step(x_train, self.params_sim['sim_params'])
y_eval = y_eval - self.sim_model_step(x_eval, self.params_sim['sim_params'])
self.base_bnn.fit(
x_train, y_train, x_eval, y_eval, num_steps, log_period, log_to_wandb,
metrics_objective,
keep_the_best, per_dim_metrics
)
if self.use_base_bnn:
self.base_bnn.fit(
x_train, y_train, x_eval, y_eval, num_steps, log_period, log_to_wandb,
metrics_objective,
keep_the_best, per_dim_metrics
)


if __name__ == '__main__':
Expand Down Expand Up @@ -427,6 +432,10 @@ def key_iter():
bandwidth_svgd=1.0, likelihood_std=obs_noise_std, likelihood_exponent=1.0,
normalize_likelihood_std=True, learn_likelihood_std=False, weight_decay=weight_decay, domain=sim.domain,
)
bnn = BNNGreyBox(base_bnn=base_bnn, sim=sim)
bnn = BNNGreyBox(base_bnn=base_bnn, sim=sim, use_base_bnn=True, lr_sim=3e-4)
for i in range(10):
bnn.fit(x_train, y_train, x_eval=x_test, y_eval=y_test, num_steps=2000, per_dim_metrics=True)
bnn.fit(x_train, y_train, x_eval=x_test, y_eval=y_test, num_steps=2000, per_dim_metrics=True,
num_sim_model_train_steps=2000)
y_pred, _ = bnn.predict(x_test)
loss = jnp.sqrt(jnp.square(y_pred - y_test).sum(axis=-1)).mean(0)
print('loss: ', loss)

0 comments on commit 59e2f09

Please sign in to comment.