Skip to content

Commit

Permalink
add outputscale_gp_prior to BNN_FSVGD
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasrothfuss committed Jan 16, 2024
1 parent 877bf41 commit 0967a5c
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions sim_transfer/models/bnn_fsvgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(self,
num_particles: int = 10,
bandwidth_svgd: float = 0.2,
bandwidth_gp_prior: float = 0.2,
outputscale_gp_prior: float = 1.0,
data_batch_size: int = 16,
num_measurement_points: int = 16,
num_train_steps: int = 10000,
Expand All @@ -47,6 +48,7 @@ def __init__(self,
likelihood_reg=likelihood_reg)
self._save_init_args(locals())
self.bandwidth_gp_prior = bandwidth_gp_prior
self.outputscale_gp_prior = outputscale_gp_prior
self.num_measurement_points = num_measurement_points

# initialize kernel
Expand All @@ -71,7 +73,7 @@ def _neg_log_posterior(self, pred_raw: jnp.ndarray, likelihood_std: jnp.array, x

def _gp_prior_log_prob(self, x: jnp.array, y: jnp.array, eps: float = 1e-3) -> jnp.ndarray:
k = self.kernel_gp_prior.matrix(x, x) + eps * jnp.eye(x.shape[0])
dist = tfd.MultivariateNormalFullCovariance(jnp.zeros(x.shape[0]), k)
dist = tfd.MultivariateNormalFullCovariance(jnp.zeros(x.shape[0]), self.outputscale_gp_prior**2 * k)
return jnp.mean(jnp.sum(dist.log_prob(jnp.swapaxes(y, -1, -2)), axis=-1)) / x.shape[0]


Expand Down Expand Up @@ -113,7 +115,7 @@ def key_iter():
domain = sim.domain
x_measurement = jnp.linspace(domain.l[0], domain.u[0], 50).reshape(-1, 1)

num_train_points = 10
num_train_points = 1

x_train = jax.random.uniform(key=next(key_iter), shape=(num_train_points,),
minval=domain.l, maxval=domain.u).reshape(-1, 1)
Expand All @@ -124,7 +126,7 @@ def key_iter():

bnn = BNN_FSVGD(NUM_DIM_X, NUM_DIM_Y, domain=domain, rng_key=next(key_iter), num_train_steps=20000,
data_batch_size=10, num_measurement_points=16, normalize_data=True, bandwidth_svgd=1.0,
likelihood_std=0.2, learn_likelihood_std=False,
likelihood_std=0.2, learn_likelihood_std=False, outputscale_gp_prior=0.2,
bandwidth_gp_prior=0.2, hidden_layer_sizes=[64, 64, 64],
normalization_stats=sim.normalization_stats,
hidden_activation=jax.nn.tanh)
Expand Down

0 comments on commit 0967a5c

Please sign in to comment.