Skip to content

Commit

Permalink
Merge branch 'main' of github.com:lasgroup/simulation_transfer
Browse files Browse the repository at this point in the history
  • Loading branch information
sukhijab committed Nov 30, 2023
2 parents fbc30bf + 3537876 commit 4fcb13f
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 42 deletions.
46 changes: 23 additions & 23 deletions experiments/online_rl_hardware/online_rl_loop.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import os
import pickle
import random
import sys
from pprint import pprint
from typing import Any, NamedTuple
Expand All @@ -17,15 +16,12 @@
from experiments.online_rl_hardware.train_policy import train_model_based_policy
from experiments.online_rl_hardware.utils import (set_up_bnn_dynamics_model, set_up_dummy_sac_trainer,
dump_trajectory_summary, execute,
prepare_init_transitions_for_car_env)
prepare_init_transitions_for_car_env, get_random_hash)
from experiments.util import Logger, RESULT_DIR

from sim_transfer.sims.envs import RCCarSimEnv
from sim_transfer.sims.util import plot_rc_trajectory

WANDB_ENTITY = 'sukhijab'
EULER_ENTITY = 'sukhijab'
WANDB_LOG_DIR_EULER = '/cluster/scratch/' + EULER_ENTITY
PRIORS = {'none_FVSGD',
'none_SVGD',
'high_fidelity',
Expand All @@ -38,26 +34,28 @@
def _load_remote_config(machine: str):
    """Load the per-machine and per-user settings from ``remote_config.json``.

    The JSON file is read from the directory that contains this module. It must
    provide a ``remote_machines`` mapping (one entry per machine) and a
    ``user_config`` mapping that contains at least ``euler_entity``.

    Args:
        machine: key into ``remote_machines``. ``'local'`` is accepted even when
            it has no entry there, so that callers running without a remote
            machine can still obtain the user config; the returned remote
            config is then an empty dict.

    Returns:
        Tuple ``(remote_config, user_config)``. ``user_config`` additionally
        carries the derived key ``wandb_log_dir_euler``.

    Raises:
        AssertionError: if ``machine`` is neither ``'local'`` nor a key of
            ``remote_machines``, or if ``user_config`` is missing.
    """
    # load remote config
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'remote_config.json')
    with open(config_path, 'r') as f:
        config = json.load(f)

    # choose machine
    remote_machines = config['remote_machines']
    if machine == 'local' and machine not in remote_machines:
        # Local runs need no ssh/scp settings, only the user config below.
        remote_config = {}
    else:
        assert machine in remote_machines, \
            f'Machine {machine} not found in remote config. Available machines: {list(remote_machines.keys())}'
        remote_config = remote_machines[machine]

    assert 'user_config' in config, 'No user config found in remote config.'
    user_config = config['user_config']
    user_config['wandb_log_dir_euler'] = '/cluster/scratch/' + user_config['euler_entity']

    # create local directory if it does not exist (nothing to create for the
    # empty config used for local runs)
    if remote_config:
        os.makedirs(remote_config['local_dir'], exist_ok=True)

    # print remote config
    print(f'Remote config [{machine}]:')
    pprint(remote_config)
    print('')

    print('\nUser config:')
    pprint(user_config)

    return remote_config, user_config


def train_model_based_policy_remote(*args,
Expand All @@ -82,14 +80,14 @@ def train_model_based_policy_remote(*args,
if machine == 'local':
# if not running remotely, just run the function locally and return the result
return train_model_based_policy(*args, **kwargs)
rmt_cfg = _load_remote_config(machine=machine)
rmt_cfg, _ = _load_remote_config(machine=machine)

# copy latest version of train_policy.py to remote and make sure remote directory exists
execute(f'scp {rmt_cfg["local_script"]} {rmt_cfg["remote_machine"]}:{rmt_cfg["remote_script"]}', verbosity)
execute(f'ssh {rmt_cfg["remote_machine"]} "mkdir -p {rmt_cfg["remote_dir"]}"', verbosity)

# dump train_data to local pkl file
run_hash = _get_random_hash()
run_hash = get_random_hash()
train_data_path_local = os.path.join(rmt_cfg['local_dir'], f'train_data_{run_hash}.pkl')
with open(train_data_path_local, 'wb') as f:
pickle.dump({'args': args, 'kwargs': kwargs}, f)
Expand All @@ -104,12 +102,11 @@ def train_model_based_policy_remote(*args,

# run the train_policy.py script on the remote machine
result_path_remote = os.path.join(rmt_cfg['remote_dir'], f'result_{run_hash}.pkl')
command = f'export PYTHONPATH={rmt_cfg["remote_pythonpath"]} && ' \
f'{rmt_cfg["remote_interpreter"]} {rmt_cfg["remote_script"]} ' \
command = f'{rmt_cfg["remote_interpreter"]} {rmt_cfg["remote_script"]} ' \
f'--data_load_path {train_data_path_remote} --model_dump_path {result_path_remote}'
if verbosity:
print('[Local] Executing command:', command)
execute(f'ssh {rmt_cfg["remote_machine"]} "{command}"', verbosity)
execute(f'ssh -tt {rmt_cfg["remote_machine"]} "{rmt_cfg["remote_pre_cmd"]} {command}"', verbosity)

# transfer result back to local
result_path_local = os.path.join(rmt_cfg['local_dir'], f'result_{run_hash}.pkl')
Expand Down Expand Up @@ -163,6 +160,8 @@ def main(config: MainConfig = MainConfig(), encode_angle: bool = True,
machine: str = 'local'):
rng_key_env, rng_key_model, rng_key_rollouts = jax.random.split(jax.random.PRNGKey(config.seed), 3)

_, user_cfg = _load_remote_config(machine=machine)

"""Setup car reward kwargs"""
car_reward_kwargs = dict(encode_angle=encode_angle,
ctrl_cost_weight=config.ctrl_cost_weight,
Expand Down Expand Up @@ -236,8 +235,9 @@ def main(config: MainConfig = MainConfig(), encode_angle: bool = True,
total_config = sac_kwargs | config._asdict() | car_reward_kwargs

""" WANDB & Logging configuration """
wandb_config = {'project': config.project_name, 'entity': WANDB_ENTITY, 'resume': 'allow',
'dir': WANDB_LOG_DIR_EULER if os.path.isdir(WANDB_LOG_DIR_EULER) else '/tmp/',
wandb_config = {'project': config.project_name, 'entity': user_cfg['wandb_entity'], 'resume': 'allow',
'dir': user_cfg['wandb_log_dir_euler'] if os.path.isdir(user_cfg['wandb_log_dir_euler']) \
else '/tmp/',
'config': total_config, 'settings': {'_service_wait': 300}}
wandb.init(**wandb_config)
run_id = wandb.run.id
Expand All @@ -249,7 +249,7 @@ def main(config: MainConfig = MainConfig(), encode_angle: bool = True,
log_path = os.path.join(dump_dir, f"{wandb_config['id']}.log")

if machine == 'euler':
wandb_config_remote = wandb_config | {'dir': '/cluster/scratch/' + EULER_ENTITY}
wandb_config_remote = wandb_config | {'dir': '/cluster/scratch/' + user_cfg['euler_entity']}
else:
wandb_config_remote = wandb_config | {'dir': '/tmp/'}

Expand Down
42 changes: 23 additions & 19 deletions experiments/online_rl_hardware/remote_config.json
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
{
"euler": {
"local_dir": "/tmp/ssh_remote/",
"local_script": "/Users/rojonas/Dropbox/Eigene_Dateien/ETH/02_Projects/20_sim_transfer/code_sim_transfer/experiments/online_rl_hardware/train_policy.py",
"remote_machine": "rojonas@euler",
"remote_dir": "/cluster/scratch/rojonas/",
"remote_pre_cmd": "srun --gpus=1 --gres=gpumem:10240m --cpus-per-task=4 --time=01:00:00",
"remote_interpreter": "/cluster/project/infk/krause/rojonas/.venv/sim_transfer_gpu/bin/python",
"remote_script": "/cluster/project/infk/krause/rojonas/sim_transfer/experiments/online_rl_hardware/train_policy.py",
"remote_pythonpath": "/cluster/project/infk/krause/rojonas/sim_transfer"
},
"optimality": {
"local_dir": "/tmp/ssh_remote/",
"local_script": "/Users/rojonas/Dropbox/Eigene_Dateien/ETH/02_Projects/20_sim_transfer/code_sim_transfer/experiments/online_rl_hardware/train_policy.py",
"remote_machine": "rojonas@optimality.inf.ethz.ch",
"remote_dir": "/tmp/ssh_remote/",
"remote_pre_cmd": "",
"remote_interpreter": "/local/rojonas/miniconda3/envs/sim_transfer_gpu/bin/python",
"remote_script": "/local/rojonas/sim_transfer/experiments/online_rl_hardware/train_policy.py",
"remote_pythonpath": "/local/rojonas/sim_transfer"
"remote_machines": {
"euler": {
"local_dir": "/tmp/ssh_remote/",
"local_script": "/Users/rojonas/Dropbox/Eigene_Dateien/ETH/02_Projects/20_sim_transfer/code_sim_transfer/experiments/online_rl_hardware/train_policy.py",
"remote_machine": "rojonas@euler",
"remote_dir": "/cluster/scratch/rojonas/",
"remote_pre_cmd": "srun --gpus=1 --gres=gpumem:10240m --cpus-per-task=4 --mem-per-cpu=8192m --time=01:00:00",
"remote_interpreter": "/cluster/project/infk/krause/rojonas/.venv/sim_transfer_gpu/bin/python",
"remote_script": "/cluster/project/infk/krause/rojonas/sim_transfer/experiments/online_rl_hardware/train_policy.py"
},
"optimality": {
"local_dir": "/tmp/ssh_remote/",
"local_script": "/Users/rojonas/Dropbox/Eigene_Dateien/ETH/02_Projects/20_sim_transfer/code_sim_transfer/experiments/online_rl_hardware/train_policy.py",
"remote_machine": "rojonas@optimality.inf.ethz.ch",
"remote_dir": "/tmp/ssh_remote/",
"remote_pre_cmd": "",
"remote_interpreter": "/local/rojonas/miniconda3/envs/sim_transfer_gpu/bin/python",
"remote_script": "/local/rojonas/sim_transfer/experiments/online_rl_hardware/train_policy.py"
}
},
"user_config": {
"wandb_entity": "jonasrothfuss",
"euler_entity": "rojonas"
}
}
5 changes: 5 additions & 0 deletions experiments/online_rl_hardware/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import random
from typing import Any, NamedTuple, Dict

from brax.training.replay_buffers import UniformSamplingQueue
Expand Down Expand Up @@ -237,3 +238,7 @@ def prepare_init_transitions_for_car_env(key: jax.random.PRNGKey, number_of_samp
discount=discounts,
next_observation=jnp.concatenate([state_obs, framestacked_actions], axis=-1))
return transitions


def get_random_hash() -> str:
    """Return a random 128-bit value as a zero-padded 32-character hex string.

    Draws from the module-level ``random`` generator, so the result is
    reproducible after ``random.seed`` and is not suitable as a security token.
    """
    return f"{random.getrandbits(128):032x}"

0 comments on commit 4fcb13f

Please sign in to comment.