Skip to content

Commit

Permalink
add only car dataset v2 option to data_provider
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasrothfuss committed Jan 12, 2024
1 parent be3bfcb commit a460bbe
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 16 deletions.
29 changes: 19 additions & 10 deletions experiments/data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
'likelihood_std': {'value': _RACECAR_NOISE_STD_ENCODED.tolist()},
'num_samples_train': {'value': 200},
} for name in ['real_racecar_new', 'real_racecar_new_only_pose', 'real_racecar_new_no_angvel',
'real_racecar_new_actionstack']
'real_racecar_new_actionstack', 'real_racecar_v2']
})


Expand Down Expand Up @@ -187,18 +187,25 @@ def _rccar_transitions_to_dataset(transitions: Transition, encode_angles: bool =


def get_rccar_recorded_data_new(encode_angle: bool = True, skip_first_n_points: int = 10,
dataset: str = 'all',
action_delay: int = 3, action_stacking: bool = False,
car_id: int = 2):

assert car_id in [1, 2]
if car_id == 1:
num_train_traj = 8
assert dataset in ['all', 'v1']
recordings_dir = [os.path.join(DATA_DIR, 'recordings_rc_car_v1')]
elif car_id == 2:
num_train_traj = 12
recordings_dir = [os.path.join(DATA_DIR, 'recordings_rc_car_v2'),
os.path.join(DATA_DIR, 'recordings_rc_car_v3'),
os.path.join(DATA_DIR, 'recordings_rc_car_v4')]
if dataset == 'all':
recordings_dir = [os.path.join(DATA_DIR, 'recordings_rc_car_v2'),
os.path.join(DATA_DIR, 'recordings_rc_car_v3'),
os.path.join(DATA_DIR, 'recordings_rc_car_v4')]
num_test_points = 20_000
elif dataset in ['v2', 'v3', 'v4']:
recordings_dir = [os.path.join(DATA_DIR, f'recordings_rc_car_{dataset}')]
num_test_points = 10_000
else:
raise ValueError(f"Unknown dataset {dataset} for car_id {car_id}")
else:
raise ValueError(f"Unknown car id {car_id}")
files = [sorted(glob.glob(rd + '/*.pickle')) for rd in recordings_dir]
Expand All @@ -208,18 +215,16 @@ def get_rccar_recorded_data_new(encode_angle: bool = True, skip_first_n_points:

# load and shuffle transitions
transitions = _load_transitions(file_names)
# indices = jax.random.permutation(key=jax.random.PRNGKey(9345), x=jnp.arange(0, len(transitions)))
# transitions = [transitions[idx] for idx in indices]

# transform transitions into supervised learning datasets
prep_fn = partial(_rccar_transitions_to_dataset, encode_angles=encode_angle, skip_first_n=skip_first_n_points,
action_delay=action_delay, action_stacking=action_stacking)
x, y = map(lambda x: jnp.concatenate(x, axis=0), zip(*map(prep_fn, transitions)))
# x_test, y_test = map(lambda x: jnp.concatenate(x, axis=0), zip(*map(prep_fn, transitions[num_train_traj:])))
indices = jnp.arange(start=0, stop=x.shape[0], step=1)
indices = jax.random.shuffle(key=jax.random.PRNGKey(9345), x=indices)
x, y = x[indices], y[indices]
num_test_points = 20_000

# split into train and test
x_train, y_train, x_test, y_test = x[:-num_test_points], y[:-num_test_points], \
x[-num_test_points:], y[-num_test_points:]
return x_train, y_train, x_test, y_test
Expand Down Expand Up @@ -380,6 +385,10 @@ def provide_data_and_sim(data_source: str, data_spec: Dict[str, Any], data_seed:
elif data_source.startswith('real_racecar_new'):
x_train, y_train, x_test, y_test = get_rccar_recorded_data_new(encode_angle=True, action_stacking=False,
action_delay=3, car_id=car_id)
elif data_source.startswith('real_racecar_v2'):
x_train, y_train, x_test, y_test = get_rccar_recorded_data_new(encode_angle=True, action_stacking=False,
action_delay=3, car_id=car_id,
dataset='v2')
else:
x_train, y_train, x_test, y_test = get_rccar_recorded_data(encode_angle=True)

Expand Down
6 changes: 0 additions & 6 deletions experiments/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,20 +321,14 @@ def collect_exp_results(exp_name: str, dir_tree_depth: int = 3, verbose: bool =


def ucb(row):
if row.shape[0] > 1:
return np.nan
return np.quantile(row, q=0.95, axis=0)


def lcb(row):
if row.shape[0] > 1:
return np.nan
return np.quantile(row, q=0.05, axis=0)


def median(row):
if row.shape[0] > 1:
return np.nan
return np.quantile(row, q=0.5, axis=0)


Expand Down

0 comments on commit a460bbe

Please sign in to comment.