From 2cb8f5b432c3e47c0ac3380c489f42eb188ca18d Mon Sep 17 00:00:00 2001 From: Jonas Rothfuss Date: Wed, 16 Aug 2023 15:52:02 +0200 Subject: [PATCH] fix norm stats issues in additive sim --- sim_transfer/sims/simulators.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/sim_transfer/sims/simulators.py b/sim_transfer/sims/simulators.py index a05df7c..844d4b6 100644 --- a/sim_transfer/sims/simulators.py +++ b/sim_transfer/sims/simulators.py @@ -103,8 +103,9 @@ def domain(self) -> Domain: @property def normalization_stats(self) -> Dict[str, jnp.ndarray]: - norm_stats = {} - for stat_name in ['x_mean', 'x_std', 'y_mean', 'y_std']: + norm_stats = self.base_sims[0].normalization_stats # take x stats from first sim + # for the y stats, combine stats from all sims + for stat_name in ['y_mean', 'y_std']: stats_stack = jnp.stack([sim.normalization_stats[stat_name] for sim in self.base_sims], axis=0) if 'mean' in stat_name: norm_stats[stat_name] = jnp.sum(stats_stack, axis=0) @@ -196,19 +197,17 @@ def domain(self) -> Domain: return HypercubeDomain(lower=lower, upper=upper) def _typical_f(self, x: jnp.array) -> jnp.array: - return self.mean_fn(x) + return jnp.repeat(self.mean_fn(x)[:, None], self.output_size, axis=-1) @cached_property def normalization_stats(self) -> Dict[str, jnp.ndarray]: key1, key2 = jax.random.split(jax.random.PRNGKey(23423), 2) x = self.domain.sample_uniformly(key1, sample_shape=1000) - y = self.sample_function_vals(x, num_samples=50, rng_key=key2) - y = y.reshape((-1, self.output_size)) norm_stats = { 'x_mean': jnp.mean(x, axis=0), 'x_std': jnp.std(x, axis=0), - 'y_mean': jnp.mean(y, axis=0), - 'y_std': 1.5 * jnp.std(y, axis=0), + 'y_mean': jnp.mean(self._typical_f(x), axis=0), + 'y_std': 1.5 * self.output_scales, } return norm_stats