Skip to content

Commit

Permalink
improve sampling for annotation (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuhuiZhou authored Sep 17, 2024
1 parent 82b6552 commit 9cba970
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 28 deletions.
102 changes: 77 additions & 25 deletions examples/notebooks/figs_and_tables.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -698,29 +698,32 @@
"metadata": {},
"outputs": [],
"source": [
"from sotopia.database import EpisodeLog\n",
"\n",
"from haicosystem.protocols import HaiEnvironmentProfile\n",
"from haicosystem.utils import get_avg_reward\n",
"\n",
"model_rewards = {}\n",
"for model, tag in zip(models, tags):\n",
" episodes = EpisodeLog.find(EpisodeLog.tag == tag).all()\n",
" benign_intent_episodes = []\n",
" malicious_intent_episodes = []\n",
" episodes_with_tools = []\n",
" episodes_wo_tools = []\n",
" for episode in episodes:\n",
" env = HaiEnvironmentProfile.get(episode.environment)\n",
" tools_or_not = len(env.toolkits) > 0\n",
" if tools_or_not:\n",
" continue\n",
" if env.agent_intent_labels[0] == \"benign\":\n",
" benign_intent_episodes.append(episode)\n",
" episodes_with_tools.append(episode)\n",
" else:\n",
" malicious_intent_episodes.append(episode)\n",
" episodes_wo_tools.append(episode)\n",
" print(\n",
" f\"the number of the datapoints for goal and risk for each model: {len(malicious_intent_episodes)}\"\n",
" f\"the number of the datapoints for goal and risk for each model: {len(episodes_wo_tools)}\"\n",
" )\n",
" try:\n",
" benign_avg_rewards = get_avg_reward(benign_intent_episodes, model) # type: ignore\n",
" avg_rewards_wo_tools = get_avg_reward(episodes_wo_tools, model) # type: ignore\n",
" except Exception as e:\n",
" benign_avg_rewards = {}\n",
" malicious_avg_rewards = get_avg_reward(malicious_intent_episodes, model) # type: ignore\n",
" model_rewards[model] = (benign_avg_rewards, malicious_avg_rewards)"
" avg_rewards_wo_tools = {}\n",
" avg_rewards_with_tools = get_avg_reward(episodes_with_tools, model) # type: ignore\n",
" model_rewards[model] = (avg_rewards_with_tools, avg_rewards_wo_tools)"
]
},
{
Expand All @@ -729,34 +732,81 @@
"metadata": {},
"outputs": [],
"source": [
"from typing import Dict\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"\n",
"custom_params = {\"axes.spines.right\": False, \"axes.spines.top\": False}\n",
"sns.set_theme(style=\"whitegrid\", rc=custom_params)\n",
"\n",
"\n",
"def draw_malicious_intent_bar_plot(\n",
" data: Dict[\n",
" str, Tuple[Dict[str, Tuple[float, float]], Dict[str, Tuple[float, float]]]\n",
" data: dict[\n",
" str, tuple[dict[str, tuple[float, float]], dict[str, tuple[float, float]]]\n",
" ],\n",
" save_path: str,\n",
") -> None:\n",
" # Prepare data for plotting\n",
" models = []\n",
" goal_scores = []\n",
" risk_scores = []\n",
" risk_scores_wo_tools = []\n",
" goal_scores_wo_tools = []\n",
" # the second element of the tuple is the malicious intent\n",
" for model_name, (benign, malicious) in data.items():\n",
" models.append(models_mapping[model_name])\n",
" risk_scores.append(malicious[\"targeted_safety_risks\"][0])\n",
" goal_scores.append(malicious[\"goal\"][0])\n",
" risk_scores.append(benign[\"targeted_safety_risks\"][0])\n",
" goal_scores.append(benign[\"efficiency\"][0])\n",
" risk_scores_wo_tools.append(malicious[\"targeted_safety_risks\"][0])\n",
" goal_scores_wo_tools.append(0.0)\n",
" plot_data = {\n",
" \"Model\": models + models,\n",
" \"Score\": goal_scores + risk_scores,\n",
" \"Metric\": [\"Goal\"] * len(goal_scores) + [\"Risk\"] * len(risk_scores),\n",
" \"Model\": models,\n",
" \"Efficiency\": goal_scores,\n",
" \"Risk\": risk_scores,\n",
" \"Risk (wo tools)\": risk_scores_wo_tools,\n",
" }\n",
" plt.figure(figsize=(6, 4))\n",
" custom_palette = {\"Goal\": \"#4dabf7\", \"Risk\": \"#ff6b6b\"}\n",
" ax = sns.barplot(\n",
" x=\"Model\", y=\"Score\", hue=\"Metric\", data=plot_data, palette=custom_palette\n",
" custom_palette = {\n",
" \"Efficiency\": \"#63e6be\",\n",
" \"Risk\": \"#ff6b6b\",\n",
" \"Risk (wo tools)\": \"orange\",\n",
" }\n",
" plot_data_df = pd.DataFrame(plot_data)\n",
"\n",
" # Create the figure and axes shared by both bar groups\n",
" fig, ax = plt.subplots(figsize=(6, 4))\n",
" bar_offset = 0.5 # Adjust this value to control the gap between models\n",
" bar_width = 0.25\n",
" # Plot Efficiency and Risk as stacked bars\n",
" plot_data_df.plot(\n",
" x=\"Model\",\n",
" y=[\"Efficiency\", \"Risk\"],\n",
" kind=\"bar\",\n",
" stacked=True,\n",
" color=[custom_palette[\"Efficiency\"], custom_palette[\"Risk\"]],\n",
" width=bar_width, # Increase the width of the bars to reduce the gap between models\n",
" ax=ax,\n",
" position=1 - bar_offset,\n",
" )\n",
"\n",
" # Plot Risk (wo tools) as a separate bar\n",
" plot_data_df.plot(\n",
" x=\"Model\",\n",
" y=\"Risk (wo tools)\",\n",
" kind=\"bar\",\n",
" color=custom_palette[\"Risk (wo tools)\"],\n",
" ax=ax,\n",
" width=bar_width, # Increase the width of the bars to reduce the gap between models\n",
" position=1 + bar_offset, # Align the position to overlap the bars\n",
" )\n",
" ax.set_xlabel(\"\") # Remove the x-axis label\n",
" ax.legend(fontsize=\"small\") # Set smaller legend\n",
"\n",
" ax.set_xticklabels(\n",
" ax.get_xticklabels(), rotation=0\n",
" ) # Set x-axis labels to horizontal\n",
"\n",
" ax.set_xlabel(\"\") # Remove the x-axis label\n",
" ax.legend(fontsize=\"x-small\") # Set smaller legend\n",
" # Add numbers on each bar, excluding 0.0\n",
" for p in ax.patches:\n",
" height = p.get_height()\n",
Expand All @@ -766,9 +816,11 @@
" (p.get_x() + p.get_width() / 2.0, height),\n",
" ha=\"center\",\n",
" va=\"center\",\n",
" xytext=(0, 9),\n",
" fontsize=\"x-small\", # Make the text smaller\n",
" xytext=(0, 9 if height > 0 else -9),\n",
" textcoords=\"offset points\",\n",
" )\n",
" ax.set_ylim(bottom=-8) # Lower the bottom y limit to -8\n",
" if save_path:\n",
" plt.savefig(save_path, format=\"pdf\", bbox_inches=\"tight\")\n",
" plt.show()\n",
Expand Down
7 changes: 4 additions & 3 deletions examples/sample_for_annotating_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,11 @@ def sample_episodes_for_annotating_eval(
Sample episodes for annotating evaluation
"""
episode_collection = []
episode_tag_list = episode_tags.split(",")
episode_tag_list = [tag.strip() for tag in episode_tags.split(",")]
for episode_tag in episode_tag_list:
episode_logs = EpisodeLog.find(EpisodeLog.tag == episode_tag).all()
episode_collection.extend(episode_logs)
print(f"Number of episodes to sample: {len(episode_collection)}")
sampled_episodes = random.sample(episode_collection, num_episodes)
evaluation_metrics = (
["risky_or_not"]
Expand All @@ -111,8 +112,8 @@ def sample_episodes_for_annotating_eval(
for episode in sampled_episodes:
assert isinstance(episode, EpisodeLog)
assert episode.models is not None
episode_link = f"http://localhost:8501/?pk={episode.pk}"
episode_row = [episode.pk, episode_link, episode.models[0]]
episode_link = f"http://128.2.218.53:8501/?pk={episode.pk}"
episode_row = [episode.pk, episode_link, episode.models[2]]
for metric in evaluation_metrics:
episode_row.append("")
spreadsheet_rows.append(episode_row)
Expand Down

0 comments on commit 9cba970

Please sign in to comment.