Skip to content

Commit

Permalink
fix norm stats issues in additive sim
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasrothfuss committed Aug 16, 2023
1 parent 1a89f13 commit 2cb8f5b
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions sim_transfer/sims/simulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,9 @@ def domain(self) -> Domain:

@property
def normalization_stats(self) -> Dict[str, jnp.ndarray]:
norm_stats = {}
for stat_name in ['x_mean', 'x_std', 'y_mean', 'y_std']:
norm_stats = self.base_sims[0].normalization_stats # take x stats from first sim
# for the y stats, combine stats from all sims
for stat_name in ['y_mean', 'y_std']:
stats_stack = jnp.stack([sim.normalization_stats[stat_name] for sim in self.base_sims], axis=0)
if 'mean' in stat_name:
norm_stats[stat_name] = jnp.sum(stats_stack, axis=0)
Expand Down Expand Up @@ -196,19 +197,17 @@ def domain(self) -> Domain:
return HypercubeDomain(lower=lower, upper=upper)

def _typical_f(self, x: jnp.array) -> jnp.array:
return self.mean_fn(x)
return jnp.repeat(self.mean_fn(x)[:, None], self.output_size, axis=-1)

@cached_property
def normalization_stats(self) -> Dict[str, jnp.ndarray]:
key1, key2 = jax.random.split(jax.random.PRNGKey(23423), 2)
x = self.domain.sample_uniformly(key1, sample_shape=1000)
y = self.sample_function_vals(x, num_samples=50, rng_key=key2)
y = y.reshape((-1, self.output_size))
norm_stats = {
'x_mean': jnp.mean(x, axis=0),
'x_std': jnp.std(x, axis=0),
'y_mean': jnp.mean(y, axis=0),
'y_std': 1.5 * jnp.std(y, axis=0),
'y_mean': jnp.mean(self._typical_f(x), axis=0),
'y_std': 1.5 * self.output_scales,
}
return norm_stats

Expand Down

0 comments on commit 2cb8f5b

Please sign in to comment.