code review grey_box model
lenarttreven committed Nov 22, 2023
1 parent 88a5bb6 commit ad4afb7
Showing 1 changed file with 27 additions and 10 deletions.
37 changes: 27 additions & 10 deletions sim_transfer/models/bnn_fsvd_grey_box_model.py
@@ -18,18 +18,25 @@

 class BNN_FSVGD_GreyBox(BNN_FSVGD):

-    def __init__(self, sim: FunctionSimulator, lr_sim: float = None, weight_decay_sim: float = 0.0,
+    def __init__(self,
+                 sim: FunctionSimulator,
+                 lr_sim: float = None,
+                 weight_decay_sim: float = 0.0,
                  num_sim_model_train_steps: int = 10_000,
-                 *args, **kwargs):
+                 *args,
+                 **kwargs):
         super().__init__(domain=sim.domain, *args, **kwargs)
         self.sim = sim
         param_key = self._next_rng_key()
         self.num_sim_model_train_steps = num_sim_model_train_steps
         sim_params = self.sim.sample_params(param_key)
+        # Pseudo-likelihood std.
         likelihood_std = -1. * jnp.ones(self.output_size)
         self.params_sim = {'sim_params': sim_params, 'likelihood_std': likelihood_std}
         self.optim_sim = None
+        # Use the same normalization stats for the sim model as for the pure BNN_FSVGD model.
         self._y_mean_sim, self._y_std_sim = jnp.copy(self._y_mean), jnp.copy(self._y_std)
+        # Set up learning rate and weight decay for the sim model.
         if lr_sim:
             self.lr_sim = lr_sim
         else:
@@ -38,30 +38,32 @@ def __init__(self, sim: FunctionSimulator, lr_sim: float = None, weight_decay_si
             self.weight_decay_sim = weight_decay_sim
         else:
             self.weight_decay_sim = self.weight_decay
+        # Initialize optimizer for the BNN model
         self._init_optim()
+        # Initialize optimizer for the sim model
         self._init_sim_optim()
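
The constructor resolves sim-specific hyperparameters with a fallback to the BNN's own settings (lr_sim to self.lr, weight_decay_sim to self.weight_decay). A minimal, self-contained sketch of that pattern; the function name and values are illustrative, not from the repository:

def resolve_sim_hyperparams(lr, weight_decay, lr_sim=None, weight_decay_sim=0.0):
    # Mirrors the truthiness checks above: a value of 0 (or None) falls back to the BNN's setting.
    lr_sim = lr_sim if lr_sim else lr
    weight_decay_sim = weight_decay_sim if weight_decay_sim else weight_decay
    return lr_sim, weight_decay_sim

assert resolve_sim_hyperparams(lr=1e-3, weight_decay=0.0) == (1e-3, 0.0)
assert resolve_sim_hyperparams(lr=1e-3, weight_decay=0.0, lr_sim=5e-4) == (5e-4, 0.0)
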

     def reinit(self, rng_key: Optional[jax.random.PRNGKey] = None):
         """ Reinitializes the model parameters and the optimizer state."""
         if rng_key is None:
             key_rng = self._next_rng_key()
             key_model = self._next_rng_key()
-            param_key = self._next_rng_key()
+            key_param = self._next_rng_key()
         else:
-            key_model, key_rng, param_key = jax.random.split(rng_key, 3)
+            key_model, key_rng, key_param = jax.random.split(rng_key, 3)
         self._rng_key = key_rng  # reinitialize rng_key
         self.batched_model.reinit_params(key_model)  # reinitialize model parameters
         self.params['nn_params_stacked'] = self.batched_model.param_vectors_stacked
-        sim_params = self.sim.sample_params(param_key)
-        likelihood_std = -1. * jnp.ones(self.output_size)
+        sim_params = self.sim.sample_params(key_param)  # Sample one set of sim params
+        likelihood_std = -1. * jnp.ones(self.output_size)  # Initialize pseudo likelihood std
         self.params_sim = {'sim_params': sim_params, 'likelihood_std': likelihood_std}
         self._init_likelihood()  # reinitialize likelihood std
         self._init_optim()  # reinitialize optimizer
-        self._init_sim_optim()
+        self._init_sim_optim()  # reinitialize optimizer for sim model
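
reinit follows the standard JAX key discipline: either draw three fresh subkeys internally or split the caller's key once, so the model weights, the stored rng state, and the sim parameters are reinitialized with uncorrelated randomness. A self-contained illustration:

import jax

rng_key = jax.random.PRNGKey(0)
key_model, key_rng, key_param = jax.random.split(rng_key, 3)
# Each subkey drives an independent random stream; reusing rng_key for all three
# consumers would make their draws identical.
weights = jax.random.normal(key_model, (4,))   # stands in for reinit_params
sim_draw = jax.random.normal(key_param, (4,))  # stands in for sim.sample_params
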

     def _init_sim_optim(self):
         """ Initializes the optimizer and the optimizer state.
-            Sets the attributes self.optim and self.opt_state. """
+            Sets the attributes self.optim_sim and self.opt_state_sim. """
         if self.weight_decay_sim > 0:
             self.optim_sim = optax.adamw(learning_rate=self.lr_sim, weight_decay=self.weight_decay_sim)
         else:
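
The else-branch is hidden by the fold; a sketch of the presumable full method, assuming it falls back to plain Adam when weight decay is zero and initializes the optimizer state from params_sim (only the adamw branch is visible in this hunk):

import jax.numpy as jnp
import optax

def init_sim_optim(lr_sim: float, weight_decay_sim: float, params_sim: dict):
    # adamw applies decoupled weight decay; with zero decay, plain adam is equivalent.
    if weight_decay_sim > 0:
        optim_sim = optax.adamw(learning_rate=lr_sim, weight_decay=weight_decay_sim)
    else:
        optim_sim = optax.adam(learning_rate=lr_sim)
    opt_state_sim = optim_sim.init(params_sim)
    return optim_sim, opt_state_sim

optim_sim, opt_state_sim = init_sim_optim(
    1e-3, 0.0, {'sim_params': jnp.zeros(3), 'likelihood_std': -1. * jnp.ones(2)})
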
@@ -79,7 +88,7 @@ def _preprocess_train_data_sim(self, x_train: jnp.ndarray, y_train: jnp.ndarray)
         return x_train, y_train

     def _compute_normalization_stats_sim(self, x: jnp.ndarray, y: jnp.ndarray) -> None:
-        # computes the empirical normalization stats and stores as private variables
+        # Computes the empirical normalization stats and stores as private variables
         x, y = self._ensure_atleast_2d_float(x, y)
         self._x_mean = jnp.mean(x, axis=0)
         self._y_mean_sim = jnp.mean(y, axis=0)
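
The fold cuts the method off after the means; a sketch of the full computation, assuming the stds are computed analogously with a small floor against zero variance (only the mean lines are confirmed by the diff):

import jax.numpy as jnp

def compute_normalization_stats(x: jnp.ndarray, y: jnp.ndarray, eps: float = 1e-8):
    # Per-dimension empirical means and stds over the batch axis.
    x_mean, y_mean = jnp.mean(x, axis=0), jnp.mean(y, axis=0)
    x_std = jnp.std(x, axis=0) + eps   # eps keeps later divisions well-defined
    y_std = jnp.std(y, axis=0) + eps
    return x_mean, x_std, y_mean, y_std

x = jnp.arange(12.).reshape(6, 2)
y = 2. * x[:, :1] + 1.
x_mean, x_std, y_mean, y_std = compute_normalization_stats(x, y)
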
@@ -150,6 +159,12 @@ def _step_grey_box(self, opt_state_sim: optax.OptState, params_sim: Dict,

     def sim_model_step(self, x: jnp.array, params_sim: NamedTuple, normalized_x: bool = False,
                        normalized_y: bool = False):
+        """ Evaluates the sim model with parameters params_sim on the input x.
+        Args: x: input
+              params_sim: parameters of the sim model
+              normalized_x: whether the input is normalized
+              normalized_y: whether the output should be normalized
+        """
         if normalized_x:
             x = self._unnormalize_data_sim(x)
         y = self.sim.evaluate_sim(x, params_sim)
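
A self-contained sketch of the round trip this method performs, with explicit mean/std transforms standing in for _unnormalize_data_sim and the (folded) output normalization; the helper signature and the exact form of the transforms are assumptions:

import jax.numpy as jnp

def sim_model_step(x, evaluate_sim, x_mean, x_std, y_mean, y_std,
                   normalized_x=False, normalized_y=False):
    if normalized_x:
        x = x * x_std + x_mean        # map normalized input back to raw space
    y = evaluate_sim(x)               # the simulator operates on raw inputs
    if normalized_y:
        y = (y - y_mean) / y_std      # optionally return the output in normalized space
    return y

y = sim_model_step(jnp.zeros((5, 1)), lambda x: 3. * x,
                   x_mean=jnp.array([1.]), x_std=jnp.array([2.]),
                   y_mean=jnp.array([0.]), y_std=jnp.array([1.]),
                   normalized_x=True)
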
@@ -167,7 +182,7 @@ def predict_dist(self, x: jnp.ndarray, include_noise: bool = True):
         return pred_dist

     def normalized_batched_predictions(self, x: jnp.ndarray):
-        # sim model takes unormalized input and predicts unnormalized output.
+        # Sim model takes unnormalized input and predicts normalized output.
         y_sim_raw = self.sim_model_step(x, params_sim=self.params_sim['sim_params'], normalized_x=False,
                                         normalized_y=False)
         # We scale the output of the sim model with y_std to obtain y_sim_raw
@@ -180,6 +195,8 @@ def normalized_batched_predictions(self, x: jnp.ndarray):
         y_pred_raw = self.batched_model(x)
         # total normalized output is output of NNs + y_sim_scaled. To unnormalize we get (y_nn + y_sim_scaled) * std
         # + mean -> y_nn * std + mean + y_sim_scaled * std -> y_nn_unnormalized + y_sim_raw
+        # TODO: it looks like there could be a bug here! Check this.
+        # TODO: Why is y_sim_raw not scaled with self._y_std? Is the output really unnormalized?
         y_pred_raw = jax.vmap(lambda z: z + y_sim_scaled)(y_pred_raw)
         return y_pred_raw
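
The identity in the comment holds only if y_sim_scaled = y_sim_raw / y_std, which is exactly what the TODOs question: then (y_nn + y_sim_scaled) * y_std + y_mean = y_nn * y_std + y_mean + y_sim_raw, i.e. the unnormalized NN output plus the raw sim output. A numeric check of that algebra (values are arbitrary):

import jax.numpy as jnp

y_mean, y_std = jnp.array([5.0]), jnp.array([2.0])
y_nn, y_sim_raw = jnp.array([0.3]), jnp.array([1.2])

y_sim_scaled = y_sim_raw / y_std                  # the scaling the comment presupposes
lhs = (y_nn + y_sim_scaled) * y_std + y_mean      # unnormalize the summed prediction
rhs = (y_nn * y_std + y_mean) + y_sim_raw         # unnormalized NN output + raw sim output
assert jnp.allclose(lhs, rhs)
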

