diff --git a/experiments/online_rl_hardware/online_rl_loop.py b/experiments/online_rl_hardware/online_rl_loop.py
index 00bbb81..520734d 100644
--- a/experiments/online_rl_hardware/online_rl_loop.py
+++ b/experiments/online_rl_hardware/online_rl_loop.py
@@ -238,7 +238,8 @@ def main(config: MainConfig = MainConfig(), encode_angle: bool = True,
                     'dir': WANDB_LOG_DIR_EULER if os.path.isdir(WANDB_LOG_DIR_EULER) else '/tmp/',
                     'config': total_config, 'settings': {'_service_wait': 300}}
     wandb.init(**wandb_config)
-    wandb_config['id'] = wandb.run.id
+    run_id = wandb.run.id
+    wandb_config['id'] = run_id
 
     remote_training = not (machine == 'local')
     dump_dir = os.path.join(RESULT_DIR, 'online_rl_hardware', wandb_config['id'])
@@ -284,8 +285,6 @@ def main(config: MainConfig = MainConfig(), encode_angle: bool = True,
 
     """ Main loop over episodes """
     for episode_id in range(1, config.num_episodes + 1):
-        if remote_training:
-            wandb.init(**wandb_config)
         sys.stdout = Logger(log_path, stream=sys.stdout)
         sys.stderr = Logger(log_path, stream=sys.stderr)
         print('\n\n------- Episode', episode_id)
@@ -302,6 +301,8 @@ def main(config: MainConfig = MainConfig(), encode_angle: bool = True,
         else:
             init_transitions = None
         # train model & policy
+        if remote_training:
+            wandb_config_remote['id'] = run_id + '_episode_{}'.format(episode_id)
         policy_params, bnn = train_model_based_policy_remote(
             train_data=train_data, bnn_model=bnn, config=mbrl_config, key=key_episode,
             episode_idx=episode_id, machine=machine, wandb_config=wandb_config_remote,
@@ -349,9 +350,6 @@ def policy(x, key: jr.PRNGKey = jr.PRNGKey(0)):
         train_data['y_train'] = jnp.concatenate([train_data['y_train'],
                                                  trajectory[1:, :mbrl_config.x_dim]], axis=0)
         print(f'Size of train_data in episode {episode_id}:', train_data['x_train'].shape[0])
-        if remote_training:
-            wandb.finish()
-
 
 if __name__ == '__main__':
     import argparse
@@ -365,7 +363,7 @@ def policy(x, key: jr.PRNGKey = jr.PRNGKey(0)):
     parser.add_argument('--control_time_ms', type=float, default=24.)
     parser.add_argument('--prior', type=str, default='none_FVSGD')
 
-    parser.add_argument('--num_env_steps', type=int, default=200, info='number of steps in the environment per episode')
+    parser.add_argument('--num_env_steps', type=int, default=200)
     parser.add_argument('--reset_bnn', type=int, default=0)
     parser.add_argument('--deterministic_policy', type=int, default=1)
    parser.add_argument('--initial_state_fraction', type=float, default=0.5)
diff --git a/experiments/online_rl_hardware/train_policy.py b/experiments/online_rl_hardware/train_policy.py
index c53342a..f608e16 100644
--- a/experiments/online_rl_hardware/train_policy.py
+++ b/experiments/online_rl_hardware/train_policy.py
@@ -58,7 +58,7 @@ def train_model_based_policy(train_data: Dict,
     # Train model
     if config.reset_bnn:
         bnn_model.reinit(rng_key=key_reinit_model)
-    bnn_model.fit(x_train=x_train, y_train=y_train, x_eval=x_test, y_eval=y_test, log_to_wandb=True,
+    bnn_model.fit_with_scan(x_train=x_train, y_train=y_train, x_eval=x_test, y_eval=y_test, log_to_wandb=True,
                   keep_the_best=config.return_best_bnn, metrics_objective='eval_nll',
                   log_period=2000)
     print(f'Time fo training the transition model: {time.time() - t:.2f} seconds')
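# A minimal sketch (not part of the patch) of the wandb bookkeeping pattern the
# diff introduces in online_rl_loop.py: keep a single parent run for the whole
# session and hand each remote training job a run id derived from it, instead
# of calling wandb.init()/wandb.finish() inside every episode. The project
# name, offline mode, and episode count below are assumptions for illustration.
import wandb

wandb.init(project='online_rl_hardware', mode='offline')  # parent run
run_id = wandb.run.id

num_episodes = 3  # placeholder
for episode_id in range(1, num_episodes + 1):
    # Mirrors the `wandb_config_remote['id']` assignment in the diff: each
    # remote training job logs to its own run id derived from the parent.
    wandb_config_remote = {'project': 'online_rl_hardware',
                           'id': run_id + '_episode_{}'.format(episode_id)}
    print(wandb_config_remote['id'])

wandb.finish()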