From f8fcabd7de30e3e7970863e0a87e6ab093b1908a Mon Sep 17 00:00:00 2001 From: lenarttreven Date: Fri, 16 Feb 2024 13:09:47 +0100 Subject: [PATCH] update plotting_constants.py --- .../offline_rl_from_recorded_data/launcher.py | 2 +- lenart_internal/analysis_simulator_data.py | 118 ++++++------------ plotting_hyperdata/plotting_constants.py | 7 ++ 3 files changed, 43 insertions(+), 84 deletions(-) diff --git a/experiments/offline_rl_from_recorded_data/launcher.py b/experiments/offline_rl_from_recorded_data/launcher.py index 8f2bb8a..b4cc791 100644 --- a/experiments/offline_rl_from_recorded_data/launcher.py +++ b/experiments/offline_rl_from_recorded_data/launcher.py @@ -4,7 +4,7 @@ PROJECT_NAME = 'OfflineRLSim_Jan12' _applicable_configs = { - 'horizon_len': [100], + 'horizon_len': [200], 'model_seed': list(range(5)), 'data_seed': list(range(5)), 'project_name': [PROJECT_NAME], diff --git a/lenart_internal/analysis_simulator_data.py b/lenart_internal/analysis_simulator_data.py index 9e4e046..c1edd3f 100644 --- a/lenart_internal/analysis_simulator_data.py +++ b/lenart_internal/analysis_simulator_data.py @@ -3,9 +3,10 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd +from plotting_hyperdata import plotting_constants +LINEWIDTH = 3 SIMULATION = True -BEST_MODEL_FREE = 164.53915 if SIMULATION: reward_column = 'reward_mean_on_simulator' @@ -15,41 +16,37 @@ folder = 'real_data' + def to_pretty_group_name(group_name: str): if 'use_grey_box=1' in group_name: return 'Grey box' elif 'use_sim_prior=0' in group_name: - return 'No sim prior' + return 'No sim prior model' elif 'use_sim_prior=1' in group_name and 'high_fidelity=0' in group_name: - return 'Low fidelity' + return 'Low fidelity model' elif 'use_sim_prior=1' in group_name and 'high_fidelity=1' in group_name: - return 'High fidelity' + return 'High fidelity model' -def join_group_names_and_prepare_statistics_mean(data: pd.DataFrame): +def join_group_names_and_prepare_statistics_mean_SIM(data: pd.DataFrame, ): # Summary of rewards - summary_rewards = data.groupby('Group')[reward_column].agg(['mean', 'std']) - - summary = pd.concat([summary_rewards], axis=1) - summary.reset_index(inplace=True) - summary.columns = ['group_name', 'mean', 'std', ] - return summary + summary_rewards = data.groupby('Group')[reward_column].agg([ + lambda x: smooth_curve(x, type='mean'), + lambda x: smooth_curve(x, type='std_err'), + 'max', - -def join_group_names_and_prepare_statistics(data: pd.DataFrame): - summary_rewards = data.groupby('Group')[reward_column].agg(['median', - lambda x: x.quantile(0.2), # 0.2 quantile - lambda x: x.quantile(0.8) # 0.8 quantile - ]) + ]) summary = pd.concat([summary_rewards], axis=1) summary.reset_index(inplace=True) - summary.columns = ['group_name', - 'median_rewards', '0.2_quantile_rewards', '0.8_quantile_rewards'] + summary.columns = ['group_name', 'mean', 'std', 'max'] return summary -def plot_rewards_mean(data: pd.DataFrame, max_offline_data: int | None = None, sup_title: str = None): +def plot_rewards_mean_SIM(data: pd.DataFrame, + max_offline_data: int | None = None, + sup_title: str = None, + width: float = 2.0): offline_transitions = data['num_offline_collected_transitions'].unique() offline_transitions.sort() @@ -60,7 +57,7 @@ def plot_rewards_mean(data: pd.DataFrame, max_offline_data: int | None = None, s for index, offline_transition in enumerate(offline_transitions): cur_data = data.loc[data['num_offline_collected_transitions'] == offline_transition] - summary = join_group_names_and_prepare_statistics_mean(cur_data) + summary = join_group_names_and_prepare_statistics_mean_SIM(cur_data) means.append(summary['mean'].to_numpy()) stds.append(summary['std'].to_numpy()) @@ -76,15 +73,21 @@ def plot_rewards_mean(data: pd.DataFrame, max_offline_data: int | None = None, s stds = stds[idx] for index, group in enumerate(group_names): - ax.plot(offline_transitions, means[:, index], label=group) - ax.fill_between(offline_transitions, - means[:, index] - (stds[:, index]) / (1), - means[:, index] + stds[:, index] / (1), alpha=0.2) - if max_offline_data is None: - max_offline_data = offline_transitions[-1] - ax.hlines(y=BEST_MODEL_FREE, xmin=0, xmax=max_offline_data, linewidth=2, linestyle='--', label='Best model free') + if group in plotting_constants.offline_rl_names_transfer.keys(): + method_name = plotting_constants.offline_rl_names_transfer[group] + ax.plot(offline_transitions, means[:, index], + label=method_name, + color=plotting_constants.COLORS[method_name], + linestyle=plotting_constants.LINE_STYLES[method_name], + linewidth=LINEWIDTH) + ax.fill_between(offline_transitions, + means[:, index] - width * stds[:, index], + means[:, index] + width * stds[:, index], + alpha=0.2, + color=plotting_constants.COLORS[method_name], ) ax.set_title(r'Mean $\pm$ std error') ax.set_xlabel('Number of offline transitions') + ax.set_ylim(0) ax.set_ylabel('Reward') fig.suptitle(sup_title) dir = folder @@ -96,65 +99,14 @@ def plot_rewards_mean(data: pd.DataFrame, max_offline_data: int | None = None, s plt.show() -def plot_rewards(data: pd.DataFrame, max_offline_data: int | None = None, sup_title: str = None): - offline_transitions = data['num_offline_collected_transitions'].unique() - offline_transitions.sort() - - medians = [] - quantile_20 = [] - quantile_80 = [] - group_names = [] - - fig, ax = plt.subplots(1, 1, figsize=(10, 4)) - - for index, offline_transition in enumerate(offline_transitions): - cur_data = data.loc[data['num_offline_collected_transitions'] == offline_transition] - summary = join_group_names_and_prepare_statistics(cur_data) - - medians.append(summary['median_rewards'].to_numpy()) - quantile_20.append(summary['0.2_quantile_rewards'].to_numpy()) - quantile_80.append(summary['0.8_quantile_rewards'].to_numpy()) - if index == 0: - group_names = list(map(lambda x: to_pretty_group_name(x), summary['group_name'])) - - medians = np.stack(medians, axis=0) - quantile_20 = np.stack(quantile_20, axis=0) - quantile_80 = np.stack(quantile_80, axis=0) - if max_offline_data: - idx = offline_transitions <= max_offline_data - offline_transitions = offline_transitions[idx] - medians = medians[idx] - quantile_20 = quantile_20[idx] - quantile_80 = quantile_80[idx] - - for index, group in enumerate(group_names): - ax.plot(offline_transitions, medians[:, index], label=group) - ax.fill_between(offline_transitions, quantile_20[:, index], quantile_80[:, index], alpha=0.2) - - if max_offline_data is None: - max_offline_data = offline_transitions[-1] - - ax.hlines(y=BEST_MODEL_FREE, xmin=0, xmax=max_offline_data, linewidth=2, linestyle='--', label='Best model free') - ax.set_title(f'Median and 0.2-0.8 confidence interval') - ax.set_xlabel('Number of offline transitions') - ax.set_ylabel('Reward') - fig.suptitle(sup_title) - dir = folder - if not os.path.exists(dir): - os.makedirs(dir) - title = sup_title + 'offline_rl_simulated_data.pdf' - plt.legend() - plt.savefig(os.path.join(dir, title)) - plt.show() - if __name__ == '__main__': - max_offline_data = 3000 - data = pd.read_csv('wandb_runs.csv') + max_offline_data = 5_000 + data = pd.read_csv('wandb_runs_sim_final.csv') bandwidth_svgd = data.bandwidth_svgd.unique() length_scale_aditive_sim_gp = data.length_scale_aditive_sim_gp.unique() for i in bandwidth_svgd: for j in length_scale_aditive_sim_gp: filtered_data = data[(data.bandwidth_svgd == i) & (data.length_scale_aditive_sim_gp == j)] - plot_rewards_mean(filtered_data, max_offline_data=max_offline_data, - sup_title=f'bandwidth_svgd={i}, length_scale_aditive_sim_gp={j}') + plot_rewards_mean_SIM(filtered_data, max_offline_data=max_offline_data, + sup_title=f'bandwidth_svgd={i}, length_scale_aditive_sim_gp={j}') diff --git a/plotting_hyperdata/plotting_constants.py b/plotting_hyperdata/plotting_constants.py index 2a41675..af6241d 100644 --- a/plotting_hyperdata/plotting_constants.py +++ b/plotting_hyperdata/plotting_constants.py @@ -126,6 +126,13 @@ 'nll': 'NLL', } +offline_rl_names_transfer = { + 'No sim prior model': METHODS[1], + 'Low fidelity model': METHODS[11], + 'High fidelity model': METHODS[12], + 'Parameteric model': METHODS[16] +} + online_rl_name_transfer = { 'No sim prior': METHODS[1], 'Low fidelity prior': METHODS[11]