Skip to content

Commit

Permalink
changed assert error to warning for iid sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
sukhijab committed Oct 6, 2023
1 parent 8f24245 commit c6fb6a9
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions experiments/data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,11 @@ def provide_data_and_sim(data_source: str, data_spec: Dict[str, Any], data_seed:
sampling_scheme = data_spec.get('sampling', DEFAULTS_RACECAR_REAL['sampling'])
if sampling_scheme == 'iid':
# sample random subset (datapoints are not adjacent in time)
assert num_train <= num_train_available / 4., f'Not enough data for {num_train} iid samples.' \
f'Requires at lest 4 times as much data as requested iid samples.'
import warnings
if num_train > num_train_available / 4.:
warnings.warn(f'Not enough data for {num_train} iid samples.'
f'Requires at lest 4 times as much data as requested '
f'iid samples.')
idx_train = jax.random.choice(key_train, jnp.arange(num_train_available), shape=(num_train,), replace=False)
idx_test = jax.random.choice(key_test, jnp.arange(num_test_available), shape=(num_test,), replace=False)
elif sampling_scheme == 'consecutive':
Expand Down

0 comments on commit c6fb6a9

Please sign in to comment.