Skip to content

Commit

Permalink
Switch to fit_with_scan and append the episode number to the run id
Browse files Browse the repository at this point in the history
  • Loading branch information
sukhijab authored and jonasrothfuss committed Nov 28, 2023
1 parent d7d3672 commit c90dc13
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
12 changes: 5 additions & 7 deletions experiments/online_rl_hardware/online_rl_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ def main(config: MainConfig = MainConfig(), encode_angle: bool = True,
'dir': WANDB_LOG_DIR_EULER if os.path.isdir(WANDB_LOG_DIR_EULER) else '/tmp/',
'config': total_config, 'settings': {'_service_wait': 300}}
wandb.init(**wandb_config)
wandb_config['id'] = wandb.run.id
run_id = wandb.run.id
wandb_config['id'] = run_id
remote_training = not (machine == 'local')

dump_dir = os.path.join(RESULT_DIR, 'online_rl_hardware', wandb_config['id'])
Expand Down Expand Up @@ -284,8 +285,6 @@ def main(config: MainConfig = MainConfig(), encode_angle: bool = True,
""" Main loop over episodes """
for episode_id in range(1, config.num_episodes + 1):

if remote_training:
wandb.init(**wandb_config)
sys.stdout = Logger(log_path, stream=sys.stdout)
sys.stderr = Logger(log_path, stream=sys.stderr)
print('\n\n------- Episode', episode_id)
Expand All @@ -302,6 +301,8 @@ def main(config: MainConfig = MainConfig(), encode_angle: bool = True,
else:
init_transitions = None
# train model & policy
if remote_training:
wandb_config_remote['id'] = run_id + '_episode_{}'.format(episode_id)
policy_params, bnn = train_model_based_policy_remote(
train_data=train_data, bnn_model=bnn, config=mbrl_config, key=key_episode,
episode_idx=episode_id, machine=machine, wandb_config=wandb_config_remote,
Expand Down Expand Up @@ -349,9 +350,6 @@ def policy(x, key: jr.PRNGKey = jr.PRNGKey(0)):
train_data['y_train'] = jnp.concatenate([train_data['y_train'], trajectory[1:, :mbrl_config.x_dim]], axis=0)
print(f'Size of train_data in episode {episode_id}:', train_data['x_train'].shape[0])

if remote_training:
wandb.finish()


if __name__ == '__main__':
import argparse
Expand All @@ -365,7 +363,7 @@ def policy(x, key: jr.PRNGKey = jr.PRNGKey(0)):
parser.add_argument('--control_time_ms', type=float, default=24.)

parser.add_argument('--prior', type=str, default='none_FVSGD')
parser.add_argument('--num_env_steps', type=int, default=200, info='number of steps in the environment per episode')
parser.add_argument('--num_env_steps', type=int, default=200)
parser.add_argument('--reset_bnn', type=int, default=0)
parser.add_argument('--deterministic_policy', type=int, default=1)
parser.add_argument('--initial_state_fraction', type=float, default=0.5)
Expand Down
2 changes: 1 addition & 1 deletion experiments/online_rl_hardware/train_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def train_model_based_policy(train_data: Dict,
# Train model
if config.reset_bnn:
bnn_model.reinit(rng_key=key_reinit_model)
bnn_model.fit(x_train=x_train, y_train=y_train, x_eval=x_test, y_eval=y_test, log_to_wandb=True,
bnn_model.fit_with_scan(x_train=x_train, y_train=y_train, x_eval=x_test, y_eval=y_test, log_to_wandb=True,
keep_the_best=config.return_best_bnn, metrics_objective='eval_nll', log_period=2000)
print(f'Time fo training the transition model: {time.time() - t:.2f} seconds')

Expand Down

0 comments on commit c90dc13

Please sign in to comment.