Skip to content

Commit

Permalink
introduced selection of configs for hardware experiments and updated …
Browse files Browse the repository at this point in the history
…data provider for simulation experiment to include no frame stacking for true data evaluated on sim
  • Loading branch information
sukhijab committed Oct 25, 2023
1 parent c82166f commit d310274
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 20 deletions.
7 changes: 4 additions & 3 deletions experiments/data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def provide_data_and_sim(data_source: str, data_spec: Dict[str, Any], data_seed:
use_hf_sim = data_spec.get('use_hf_sim', True)
car_id = data_spec.get('car_id', 2)
num_stacked_actions = data_spec.get('num_stacked_actions', 3)
assert num_stacked_actions == 3, "We only support 3 stacked actions for now"
# assert num_stacked_actions == 3, "We only support 3 stacked actions for now"

# Prepare simulator for bnn_training (the only difference is that here we can have also low fidelity sim)
sim = RaceCarSim(encode_angle=True, use_blend=use_hf_sim, car_id=car_id)
Expand All @@ -278,7 +278,8 @@ def provide_data_and_sim(data_source: str, data_spec: Dict[str, Any], data_seed:
# Now we prepare data
# 1.st load data from the real car
x_train, y_train, x_test, y_test = get_rccar_recorded_data_new(encode_angle=True, action_stacking=True,
action_delay=3, car_id=car_id)
action_delay=num_stacked_actions,
car_id=car_id)

# We delete y_train, y_test and replace it with the simulator output
del y_train, y_test
Expand Down Expand Up @@ -317,7 +318,7 @@ def provide_data_and_sim(data_source: str, data_spec: Dict[str, Any], data_seed:
sim_for_sampling_data = RaceCarSim(encode_angle=True, use_blend=True, car_id=car_id)
if num_stacked_actions > 0:
sim_for_sampling_data = StackedActionSimWrapper(sim_for_sampling_data,
num_stacked_actions=3,
num_stacked_actions=num_stacked_actions,
action_size=2)

y_train = sim_for_sampling_data._typical_f(x_train)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from sim_transfer.hardware.car_env import CarEnv
from sim_transfer.rl.rl_on_offline_data import RLFromOfflineData
from sim_transfer.sims.util import plot_rc_trajectory
import pickle

ENTITY = 'trevenl'
ENTITY = 'sukhijab'


class RunSpec(NamedTuple):
Expand All @@ -22,7 +23,9 @@ class RunSpec(NamedTuple):


def run_all_hardware_experiments(project_name_load: str,
project_name_save: str | None = None, ):
project_name_save: str | None = None,
desired_config: dict | None = None,
):
api = wandb.Api()
project_name = ENTITY + '/' + project_name_load
local_dir = "saved_data"
Expand All @@ -40,18 +43,28 @@ def run_all_hardware_experiments(project_name_load: str,
# Download all models
runs = api.runs(project_name)
for run in runs:
config = {k: v for k, v in run.config.items() if not k.startswith('_')}
correct_config = 1
if desired_config:
for key in desired_config.keys():
if config[key] != desired_config[key]:
correct_config = 0
break
if not correct_config:
continue
for file in run.files():
if file.name.startswith(dir_to_save):
file.download(replace=True, root=os.path.join(local_dir, run.group, run.id))
runs_spec.append(RunSpec(group_name=run.group,
run_id=run.id))
break

# Run all models on hardware
for run_spec in runs_spec:
# We open the file with pickle
pre_path = os.path.join(local_dir, run_spec.group_name, run_spec.run_id)
policy_name = 'parameters.pkl'
bnn_name = 'bnn_model.pkl'
policy_name = 'models/parameters.pkl'
bnn_name = 'models/bnn_model.pkl'

with open(os.path.join(pre_path, bnn_name), 'rb') as handle:
bnn_model = pickle.load(handle)
Expand Down Expand Up @@ -255,20 +268,26 @@ def plot_error_on_the_trajectory(data):


if __name__ == '__main__':
import pickle
# import pickle

filename_policy = 'parameters.pkl'
filename_bnn_model = 'bnn_model.pkl'
# filename_policy = 'parameters.pkl'
# filename_bnn_model = 'bnn_model.pkl'

with open(filename_bnn_model, 'rb') as handle:
bnn_model = pickle.load(handle)
# with open(filename_bnn_model, 'rb') as handle:
# bnn_model = pickle.load(handle)

with open(filename_policy, 'rb') as handle:
policy_params = pickle.load(handle)
# with open(filename_policy, 'rb') as handle:
# policy_params = pickle.load(handle)

observations_for_plotting, actions_for_plotting = run_with_learned_policy(bnn_model=bnn_model,
policy_params=policy_params,
project_name='Test',
group_name='MyGroup',
run_name='Butterfly'
)
# observations_for_plotting, actions_for_plotting = run_with_learned_policy(bnn_model=bnn_model,
# policy_params=policy_params,
# project_name='Test',
# group_name='MyGroup',
# run_name='Butterfly'
# )

run_all_hardware_experiments(
project_name_load='OfflineRLHW_without_frame_stack',
project_name_save='OfflineRLHW_without_frame_stack_evaluation',
desired_config={'bandwidth_svgd': 0.2}
)

0 comments on commit d310274

Please sign in to comment.