Skip to content

Commit

Permalink
changed callable to property for grey box dist
Browse files Browse the repository at this point in the history
  • Loading branch information
sukhijab committed Nov 22, 2023
1 parent 4f40963 commit a0bfd4b
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion sim_transfer/models/bnn_grey_box_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from sim_transfer.modules.util import aggregate_stats
import wandb
import numpy as np
from tensorflow_probability.substrates import jax as tfp


class BNNGreyBox(AbstractRegressionModel):
Expand Down Expand Up @@ -254,12 +255,17 @@ def _to_pred_dist(self, y_pred_raw: jnp.ndarray, likelihood_std: jnp.ndarray, in
jnp.std(y_pred_raw, axis=0))
return pred_dist

def predict_dist(self, x: jnp.ndarray, include_noise: bool = True):
def predict_dist(self, x: jnp.ndarray, include_noise: bool = True) -> tfp.distributions.Distribution:
self.batched_model.param_vectors_stacked = self.params['nn_params_stacked']
y_pred = self.predict_post_samples(x)
pred_dist = self._to_pred_dist(y_pred, likelihood_std=self.likelihood_std, include_noise=include_noise)
assert pred_dist.batch_shape == x.shape[:-1]
assert pred_dist.event_shape == (self.output_size,)
if callable(pred_dist.mean):
mean, stddev, var = pred_dist.mean(), pred_dist.stddev(), pred_dist.variance()
pred_dist.mean = mean
pred_dist.stddev = stddev
pred_dist.variance = var
return pred_dist

def predict(self, x: jnp.ndarray, include_noise: bool = False) -> Tuple[jnp.ndarray, jnp.ndarray]:
Expand Down

0 comments on commit a0bfd4b

Please sign in to comment.