Skip to content

Commit

Permalink
add wandb entity and euler entity to remote config system
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasrothfuss committed Nov 29, 2023
1 parent be32295 commit 3537876
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 37 deletions.
41 changes: 21 additions & 20 deletions experiments/online_rl_hardware/online_rl_loop.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import os
import pickle
import random
import sys
from pprint import pprint
from typing import Any, NamedTuple
Expand All @@ -17,15 +16,12 @@
from experiments.online_rl_hardware.train_policy import train_model_based_policy
from experiments.online_rl_hardware.utils import (set_up_bnn_dynamics_model, set_up_dummy_sac_trainer,
dump_trajectory_summary, execute,
prepare_init_transitions_for_car_env)
prepare_init_transitions_for_car_env, get_random_hash)
from experiments.util import Logger, RESULT_DIR

from sim_transfer.sims.envs import RCCarSimEnv
from sim_transfer.sims.util import plot_rc_trajectory

WANDB_ENTITY = 'sukhijab'
EULER_ENTITY = 'sukhijab'
WANDB_LOG_DIR_EULER = '/cluster/scratch/' + EULER_ENTITY
PRIORS = {'none_FVSGD',
'none_SVGD',
'high_fidelity',
Expand All @@ -38,26 +34,28 @@
def _load_remote_config(machine: str):
    """Load the machine-specific and the user-specific settings from ``remote_config.json``.

    The JSON file is expected to sit next to this module and to contain two
    top-level keys: ``remote_machines`` (one entry per machine name) and
    ``user_config`` (must provide ``euler_entity``).

    Args:
        machine: Name of the machine whose settings should be returned. The
            name ``'local'`` means "no remote machine": if the config file has
            no ``'local'`` entry, an empty machine config is returned instead
            of failing, so that purely local runs can still read the user
            config.

    Returns:
        A tuple ``(remote_config, user_config)``. ``user_config`` is extended
        with the derived key ``wandb_log_dir_euler``
        (``'/cluster/scratch/' + euler_entity``).

    Raises:
        AssertionError: If ``machine`` is not ``'local'`` and is not listed
            under ``remote_machines``, or if ``user_config`` is missing.
    """
    # load remote config
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'remote_config.json')
    with open(config_path, 'r') as f:
        config = json.load(f)

    # choose machine
    machines = config['remote_machines']
    if machine == 'local' and machine not in machines:
        # Local runs do not need any ssh/scp settings; only the user config is relevant.
        remote_config = {}
    else:
        assert machine in machines, \
            f'Machine {machine} not found in remote config. Available machines: {list(machines.keys())}'
        remote_config = machines[machine]

    assert 'user_config' in config, 'No user config found in remote config.'
    user_config = config['user_config']
    user_config['wandb_log_dir_euler'] = '/cluster/scratch/' + user_config['euler_entity']

    # create local directory if it does not exist
    if 'local_dir' in remote_config:
        os.makedirs(remote_config['local_dir'], exist_ok=True)

    # print remote config
    print(f'Remote config [{machine}]:')
    pprint(remote_config)

    print('\nUser config:')
    pprint(user_config)

    return remote_config, user_config


def train_model_based_policy_remote(*args,
Expand All @@ -82,14 +80,14 @@ def train_model_based_policy_remote(*args,
if machine == 'local':
# if not running remotely, just run the function locally and return the result
return train_model_based_policy(*args, **kwargs)
rmt_cfg = _load_remote_config(machine=machine)
rmt_cfg, _ = _load_remote_config(machine=machine)

# copy latest version of train_policy.py to remote and make sure remote directory exists
execute(f'scp {rmt_cfg["local_script"]} {rmt_cfg["remote_machine"]}:{rmt_cfg["remote_script"]}', verbosity)
execute(f'ssh {rmt_cfg["remote_machine"]} "mkdir -p {rmt_cfg["remote_dir"]}"', verbosity)

# dump train_data to local pkl file
run_hash = _get_random_hash()
run_hash = get_random_hash()
train_data_path_local = os.path.join(rmt_cfg['local_dir'], f'train_data_{run_hash}.pkl')
with open(train_data_path_local, 'wb') as f:
pickle.dump({'args': args, 'kwargs': kwargs}, f)
Expand Down Expand Up @@ -162,6 +160,8 @@ def main(config: MainConfig = MainConfig(), encode_angle: bool = True,
machine: str = 'local'):
rng_key_env, rng_key_model, rng_key_rollouts = jax.random.split(jax.random.PRNGKey(config.seed), 3)

_, user_cfg = _load_remote_config(machine=machine)

"""Setup car reward kwargs"""
car_reward_kwargs = dict(encode_angle=encode_angle,
ctrl_cost_weight=config.ctrl_cost_weight,
Expand Down Expand Up @@ -235,8 +235,9 @@ def main(config: MainConfig = MainConfig(), encode_angle: bool = True,
total_config = sac_kwargs | config._asdict() | car_reward_kwargs

""" WANDB & Logging configuration """
wandb_config = {'project': config.project_name, 'entity': WANDB_ENTITY, 'resume': 'allow',
'dir': WANDB_LOG_DIR_EULER if os.path.isdir(WANDB_LOG_DIR_EULER) else '/tmp/',
wandb_config = {'project': config.project_name, 'entity': user_cfg['wandb_entity'], 'resume': 'allow',
'dir': user_cfg['wandb_log_dir_euler'] if os.path.isdir(user_cfg['wandb_log_dir_euler']) \
else '/tmp/',
'config': total_config, 'settings': {'_service_wait': 300}}
wandb.init(**wandb_config)
run_id = wandb.run.id
Expand All @@ -248,7 +249,7 @@ def main(config: MainConfig = MainConfig(), encode_angle: bool = True,
log_path = os.path.join(dump_dir, f"{wandb_config['id']}.log")

if machine == 'euler':
wandb_config_remote = wandb_config | {'dir': '/cluster/scratch/' + EULER_ENTITY}
wandb_config_remote = wandb_config | {'dir': '/cluster/scratch/' + user_cfg['euler_entity']}
else:
wandb_config_remote = wandb_config | {'dir': '/tmp/'}

Expand Down
40 changes: 23 additions & 17 deletions experiments/online_rl_hardware/remote_config.json
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
{
"euler": {
"local_dir": "/tmp/ssh_remote/",
"local_script": "/Users/rojonas/Dropbox/Eigene_Dateien/ETH/02_Projects/20_sim_transfer/code_sim_transfer/experiments/online_rl_hardware/train_policy.py",
"remote_machine": "rojonas@euler",
"remote_dir": "/cluster/scratch/rojonas/",
"remote_pre_cmd": "srun --gpus=1 --gres=gpumem:10240m --cpus-per-task=4 --time=01:00:00",
"remote_interpreter": "/cluster/project/infk/krause/rojonas/.venv/sim_transfer_gpu/bin/python",
"remote_script": "/cluster/project/infk/krause/rojonas/sim_transfer/experiments/online_rl_hardware/train_policy.py"
},
"optimality": {
"local_dir": "/tmp/ssh_remote/",
"local_script": "/Users/rojonas/Dropbox/Eigene_Dateien/ETH/02_Projects/20_sim_transfer/code_sim_transfer/experiments/online_rl_hardware/train_policy.py",
"remote_machine": "rojonas@optimality.inf.ethz.ch",
"remote_dir": "/tmp/ssh_remote/",
"remote_pre_cmd": "",
"remote_interpreter": "/local/rojonas/miniconda3/envs/sim_transfer_gpu/bin/python",
"remote_script": "/local/rojonas/sim_transfer/experiments/online_rl_hardware/train_policy.py"
"remote_machines": {
"euler": {
"local_dir": "/tmp/ssh_remote/",
"local_script": "/Users/rojonas/Dropbox/Eigene_Dateien/ETH/02_Projects/20_sim_transfer/code_sim_transfer/experiments/online_rl_hardware/train_policy.py",
"remote_machine": "rojonas@euler",
"remote_dir": "/cluster/scratch/rojonas/",
"remote_pre_cmd": "srun --gpus=1 --gres=gpumem:10240m --cpus-per-task=4 --mem-per-cpu=8192m --time=01:00:00",
"remote_interpreter": "/cluster/project/infk/krause/rojonas/.venv/sim_transfer_gpu/bin/python",
"remote_script": "/cluster/project/infk/krause/rojonas/sim_transfer/experiments/online_rl_hardware/train_policy.py"
},
"optimality": {
"local_dir": "/tmp/ssh_remote/",
"local_script": "/Users/rojonas/Dropbox/Eigene_Dateien/ETH/02_Projects/20_sim_transfer/code_sim_transfer/experiments/online_rl_hardware/train_policy.py",
"remote_machine": "rojonas@optimality.inf.ethz.ch",
"remote_dir": "/tmp/ssh_remote/",
"remote_pre_cmd": "",
"remote_interpreter": "/local/rojonas/miniconda3/envs/sim_transfer_gpu/bin/python",
"remote_script": "/local/rojonas/sim_transfer/experiments/online_rl_hardware/train_policy.py"
}
},
"user_config": {
"wandb_entity": "jonasrothfuss",
"euler_entity": "rojonas"
}
}
5 changes: 5 additions & 0 deletions experiments/online_rl_hardware/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import random
from typing import Any, NamedTuple, Dict

from brax.training.replay_buffers import UniformSamplingQueue
Expand Down Expand Up @@ -237,3 +238,7 @@ def prepare_init_transitions_for_car_env(key: jax.random.PRNGKey, number_of_samp
discount=discounts,
next_observation=jnp.concatenate([state_obs, framestacked_actions], axis=-1))
return transitions


def get_random_hash() -> str:
    """Return 128 random bits rendered as a zero-padded, 32-character lowercase hex string."""
    # Drawn from the module-level `random` generator.
    bits = random.getrandbits(128)
    return format(bits, '032x')

0 comments on commit 3537876

Please sign in to comment.