From a0f24de95c4def176294adef42ac6487f6979af2 Mon Sep 17 00:00:00 2001 From: lenarttreven Date: Wed, 30 Aug 2023 20:06:28 +0200 Subject: [PATCH] fix reshaping in predict next state --- experiments/model_based_rl_sim_transfer_comparison/exp.py | 2 +- .../model_based_rl_sim_transfer_comparison/launcher.py | 4 ++-- sim_transfer/rl/model_based_rl/learned_system.py | 1 + 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/experiments/model_based_rl_sim_transfer_comparison/exp.py b/experiments/model_based_rl_sim_transfer_comparison/exp.py index e8e2f2d..0269675 100644 --- a/experiments/model_based_rl_sim_transfer_comparison/exp.py +++ b/experiments/model_based_rl_sim_transfer_comparison/exp.py @@ -199,6 +199,6 @@ def main(args): parser.add_argument('--best_bnn_model', type=int, default=1) parser.add_argument('--best_policy', type=int, default=0) parser.add_argument('--margin_factor', type=float, default=20.0) - parser.add_argument('--predict_difference', type=int, default=1) + parser.add_argument('--predict_difference', type=int, default=0) args = parser.parse_args() main(args) diff --git a/experiments/model_based_rl_sim_transfer_comparison/launcher.py b/experiments/model_based_rl_sim_transfer_comparison/launcher.py index 6f4c231..f6b595d 100644 --- a/experiments/model_based_rl_sim_transfer_comparison/launcher.py +++ b/experiments/model_based_rl_sim_transfer_comparison/launcher.py @@ -1,7 +1,7 @@ import exp from experiments.util import generate_run_commands, generate_base_command, dict_permutations -PROJECT_NAME = 'PredictNextStateInsteadofDifference' +PROJECT_NAME = 'PredictNextStateInsteadofDifferenceN2' applicable_configs = { 'horizon_len': [50, 2 ** 6, 100], @@ -10,7 +10,7 @@ 'num_episodes': [40], 'sac_num_env_steps': [1_000_000, 2_000_000], 'bnn_train_steps': [20_000, 40_000], - 'learnable_likelihood_std': ['yes', 'no'], + 'learnable_likelihood_std': ['yes'], 'reset_bnn': ['no'], 'use_sim_prior': [1], 'include_aleatoric_noise': [1], diff --git a/sim_transfer/rl/model_based_rl/learned_system.py b/sim_transfer/rl/model_based_rl/learned_system.py index e6d8b3d..f9c4ff8 100644 --- a/sim_transfer/rl/model_based_rl/learned_system.py +++ b/sim_transfer/rl/model_based_rl/learned_system.py @@ -46,6 +46,7 @@ def next_state(self, x_next_dist = self.model.predict_dist(z, include_noise=self.include_noise) next_key, key_sample_x_next = jr.split(dynamics_params.key) x_next = x_next_dist.sample(seed=key_sample_x_next) + x_next = x_next.reshape((self.x_dim,)) new_dynamics_params = dynamics_params.replace(key=next_key) return Normal(loc=x_next, scale=jnp.zeros_like(x_next)), new_dynamics_params