Commit

fixes in offline rl on hardware file
sukhijab committed Oct 26, 2023
1 parent d310274 commit 129a628
Showing 1 changed file with 56 additions and 43 deletions.
@@ -25,6 +25,7 @@ class RunSpec(NamedTuple):
def run_all_hardware_experiments(project_name_load: str,
project_name_save: str | None = None,
desired_config: dict | None = None,
control_time_ms: float = 32,
):
api = wandb.Api()
project_name = ENTITY + '/' + project_name_load
@@ -57,7 +58,6 @@ def run_all_hardware_experiments(project_name_load: str,
file.download(replace=True, root=os.path.join(local_dir, run.group, run.id))
runs_spec.append(RunSpec(group_name=run.group,
run_id=run.id))
break

# Run all models on hardware
for run_spec in runs_spec:
@@ -76,19 +76,22 @@ def run_all_hardware_experiments(project_name_load: str,
bnn_model=bnn_model,
project_name=project_name_save,
group_name=run_spec.group_name,
run_id=run_spec.run_id)
run_id=run_spec.run_id,
control_time_ms=control_time_ms)


def run_with_learned_policy(policy_params,
bnn_model,
project_name: str,
group_name: str,
run_id: str,
encode_angle: bool = True,
control_time_ms: float = 32,
):
"""
Num stacked frames: 3
"""
car_reward_kwargs = dict(encode_angle=True,
car_reward_kwargs = dict(encode_angle=encode_angle,
ctrl_cost_weight=0.005,
margin_factor=20)

@@ -105,8 +108,10 @@ def run_with_learned_policy(policy_params,
policy = rl_from_offline_data.prepare_policy(params=policy_params)

# replay action sequence on car
env = CarEnv(encode_angle=True, num_frame_stacks=0, max_throttle=0.4,
control_time_ms=27.9)
# env = CarEnv(encode_angle=True, num_frame_stacks=0, max_throttle=0.4,
# control_time_ms=27.9)
env = CarEnv(car_id=2, encode_angle=encode_angle, max_throttle=0.4, control_time_ms=control_time_ms,
num_frame_stacks=3)
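# assumption: with num_frame_stacks=3 the environment appends the last 3 actions to the observation, which is why obs is sliced by state_dim below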
obs, _ = env.reset()
print(obs)
observations = []
@@ -116,7 +121,7 @@ def run_with_learned_policy(policy_params,

num_frame_stack = 3
action_dim = 2
state_dim = 7
state_dim = 6 + int(encode_angle)
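# assumed state layout: [x, y, heading, v_x, v_y, yaw rate]; with encode_angle the heading is represented as (sin, cos), giving 7 dimensions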

stacked_actions = jnp.zeros(shape=(num_frame_stack * action_dim,))
time_diffs = []
@@ -137,20 +142,15 @@ def run_with_learned_policy(policy_params,
rewards = []

for i in range(200):
action = policy(jnp.concatenate([obs, stacked_actions], axis=-1))
action = np.array(action)
action = np.array(policy(obs))
actions.append(action)
obs, reward, terminate, info = env.step(action)
t = time.time()
time_diff = t - t_prev
t_prev = t
print(i, action, reward, time_diff)
time_diffs.append(time_diff)

# Now we shift the actions
stacked_actions = jnp.roll(stacked_actions, shift=action_dim)
stacked_actions = stacked_actions.at[:action_dim].set(action)

stacked_actions = obs[state_dim:]
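# the env already returns the stacked past actions as part of obs, so they are sliced off instead of being rolled manually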
observations.append(obs)
rewards.append(reward)
all_stacked_actions.append(stacked_actions)
@@ -163,8 +163,15 @@ def run_with_learned_policy(policy_params,
observations = np.array(observations)
actions = np.array(actions)
time_diffs = np.array(time_diffs)

print('Avg time per iter:', np.mean(time_diffs[1:]))
mean_time_diff = np.mean(time_diffs[1:])
print('Avg time per iter:', mean_time_diff)
time_diff_std = np.std(time_diffs[1:])
print('Std time per iter:', time_diff_std)
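# note: with control_time_ms=32 the expected period is ~0.032 s; the 1/30 s target below is an approximation of that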

if time_diff_std > 0.001:
warnings.warn('Variability in time difference is too high')  # a bare Warning(...) call is a no-op; warnings.warn requires `import warnings`
if abs(mean_time_diff - 1 / 30.) > 0.001:
warnings.warn('Control frequency is not maintained with the time difference')
plt.plot(time_diffs[1:])
plt.title('time diffs')
plt.show()
@@ -186,12 +193,13 @@ def run_with_learned_policy(policy_params,
})

# We plot the error between the predicted next state and the true next state on the true model
all_stacked_actions = np.stack(all_stacked_actions, axis=0)
extended_state = np.concatenate([observations, all_stacked_actions], axis=-1)
# all_stacked_actions = np.stack(all_stacked_actions, axis=0)
# extended_state = np.concatenate([observations, all_stacked_actions], axis=-1)
extended_state = observations
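# observations already contain the stacked actions, so the explicit concatenation above is no longer needed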
state_action_pairs = np.concatenate([extended_state, actions], axis=-1)

all_inputs = state_action_pairs[:-1, :]
target_outputs = observations[1:, :] - observations[:-1, :]
target_outputs = observations[1:, :state_dim] - observations[:-1, :state_dim]
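# targets are one-step differences of the car state only; the stacked-action part of the observation is excluded via :state_dim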

"""
We test the model error on the predicted trajectory
@@ -208,26 +216,27 @@ def run_with_learned_policy(policy_params,
wandb.log({'Error of state difference prediction': wandb.Image(fig)})

# We plot the true trajectory
fig, axes = plot_rc_trajectory(observations,
fig, axes = plot_rc_trajectory(observations[:, :state_dim],
actions,
encode_angle=True,
encode_angle=encode_angle,
show=True)
wandb.log({'Trajectory_on_true_model': wandb.Image(fig)})

sim_obs = sim_obs[:state_dim]
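# roll out the same policy on the learned BNN dynamics model for 200 steps, mirroring the hardware episode above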
for i in range(200):
sim_action = policy(jnp.concatenate([sim_obs, sim_stacked_actions], axis=-1))
sim_action = np.array(sim_action)

z = jnp.concatenate([sim_obs, sim_stacked_actions, sim_action], axis=-1)
obs = jnp.concatenate([sim_obs, sim_stacked_actions], axis=-1)  # concatenate (not stack): state and stacked actions have different lengths
sim_action = policy(obs)
# sim_action = np.array(sim_action)
z = jnp.concatenate([obs, sim_action], axis=-1)
z = z.reshape(1, -1)
delta_x_dist = bnn_model.predict_dist(z, include_noise=True)
sim_key, subkey = jr.split(sim_key)
delta_x = delta_x_dist.sample(seed=subkey)
sim_obs = sim_obs + delta_x.reshape(-1)

# Now we shift the actions
sim_stacked_actions = jnp.roll(sim_stacked_actions, shift=action_dim)
sim_stacked_actions = sim_stacked_actions.at[:action_dim].set(sim_action)
old_sim_stacked_actions = sim_stacked_actions
sim_stacked_actions = sim_stacked_actions.at[:-action_dim].set(old_sim_stacked_actions[action_dim:])  # .at[...].set returns a new array, so reassign
sim_stacked_actions = sim_stacked_actions.at[-action_dim:].set(sim_action)
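# assumption: the newest action goes last in the stacked-action vector, matching the ordering produced by the env's frame stacking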
all_sim_actions.append(sim_action)
all_sim_obs.append(sim_obs)
all_sim_stacked_actions.append(sim_stacked_actions)
@@ -236,7 +245,7 @@ def run_with_learned_policy(policy_params,
sim_actions_for_plotting = np.stack(all_sim_actions, axis=0)
fig, axes = plot_rc_trajectory(sim_observations_for_plotting,
sim_actions_for_plotting,
encode_angle=True,
encode_angle=encode_angle,
show=True)
wandb.log({'Trajectory_on_learned_model': wandb.Image(fig)})
wandb.finish()
@@ -268,26 +277,30 @@ def plot_error_on_the_trajectory(data):


if __name__ == '__main__':
# import pickle
import pickle

# filename_policy = 'parameters.pkl'
# filename_bnn_model = 'bnn_model.pkl'
filename_policy = 'saved_data/use_sim_prior=1_use_grey_box=0_high_fidelity=0_num_offline_data' \
'=2500_share_of_x0s=0.5_train_sac_only_from_init_states=0_0.5/tshlnhs0/models/parameters.pkl'
filename_bnn_model = 'saved_data/use_sim_prior=1_use_grey_box=0_high_fidelity=0_num_offline_data' \
'=2500_share_of_x0s=0.5_train_sac_only_from_init_states=0_0.5/tshlnhs0/models/bnn_model.pkl'

# with open(filename_bnn_model, 'rb') as handle:
# bnn_model = pickle.load(handle)
with open(filename_policy, 'rb') as handle:
policy_params = pickle.load(handle)

# with open(filename_policy, 'rb') as handle:
# policy_params = pickle.load(handle)
with open(filename_bnn_model, 'rb') as handle:
bnn_model = pickle.load(handle)

# observations_for_plotting, actions_for_plotting = run_with_learned_policy(bnn_model=bnn_model,
# policy_params=policy_params,
# project_name='Test',
# group_name='MyGroup',
# run_name='Butterfly'
# )
observations_for_plotting, actions_for_plotting = run_with_learned_policy(bnn_model=bnn_model,
policy_params=policy_params,
project_name='Test',
group_name='MyGroup',
run_id='Butterfly',  # the function signature takes run_id, not run_name
control_time_ms=32,
)

run_all_hardware_experiments(
project_name_load='OfflineRLHW_without_frame_stack',
project_name_save='OfflineRLHW_without_frame_stack_evaluation',
desired_config={'bandwidth_svgd': 0.2}
)
desired_config={'bandwidth_svgd': 0.2},
control_time_ms=32,
)
