diff --git a/mdp/measure-rtdp.ipynb b/mdp/measure-rtdp.ipynb new file mode 100644 index 00000000..c468e40f --- /dev/null +++ b/mdp/measure-rtdp.ipynb @@ -0,0 +1,147 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "22ce969b-6f42-4379-b338-d7555cf06d28", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas\n", + "import pickle\n", + "import seaborn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32dc8e26-cc01-4caa-9438-b345fd9b6ded", + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"measure-rtdp.pkl\", \"rb\") as f:\n", + " results = pickle.load(f)\n", + "d = results[\"data\"]\n", + "list(d.columns)\n", + "d[\"fullmodel\"] = d.apply(lambda x: f\"{x.protocol}-{x.model}-trunc{x.trunc}\", axis=1)\n", + "d.columns" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ebad0902-19e2-4787-be2d-fc694d45890f", + "metadata": {}, + "outputs": [], + "source": [ + "log = []\n", + "for _, row in d.query('algo == \"rtdp\"').iterrows():\n", + " for entry in row.log:\n", + " params = dict()\n", + " for p in [\"fullmodel\", \"attacker\", \"hyperparams\"]:\n", + " params[p] = row[p]\n", + " log.append(dict(algo=\"rtdp\") | params | entry)\n", + "\n", + " ref_row = d.query(\n", + " f'fullmodel == \"{row.fullmodel}\" and attacker == \"{row.attacker}\" and algo == \"aft20\"'\n", + " )\n", + " assert ref_row.shape[0] == 1\n", + " ref_row = ref_row.iloc[0]\n", + " ref_entry = dict(\n", + " step=entry[\"step\"],\n", + " start_value=ref_row.start_value,\n", + " start_progress=ref_row.start_progress,\n", + " n_states=ref_row.pimc_n_states,\n", + " )\n", + " log.append(dict(algo=\"aft20\") | params | ref_entry)\n", + "\n", + "log = pandas.DataFrame(log)\n", + "log" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba762257-9446-404d-b354-cd99ca428ad8", + "metadata": {}, + "outputs": [], + "source": [ + "seaborn.relplot(\n", + " data=log,\n", + " kind=\"line\",\n", + " x=\"step\",\n", + " y=\"start_value\",\n", + " col=\"attacker\",\n", + " hue=\"algo\",\n", + " row=\"fullmodel\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d62b4a1d-9a1c-491f-9af1-aa1684bb2620", + "metadata": {}, + "outputs": [], + "source": [ + "seaborn.relplot(\n", + " data=log.assign(rpp=lambda x: x.start_value / x.start_progress),\n", + " kind=\"line\",\n", + " x=\"step\",\n", + " y=\"rpp\",\n", + " col=\"attacker\",\n", + " hue=\"algo\",\n", + " row=\"fullmodel\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d8e098c-a54d-42d9-9c85-0d22f8206460", + "metadata": {}, + "outputs": [], + "source": [ + "seaborn.relplot(\n", + " # data = log.query('algo == \"rtdp\"'),\n", + " data=log,\n", + " kind=\"line\",\n", + " x=\"step\",\n", + " y=\"n_states\",\n", + " col=\"attacker\",\n", + " hue=\"algo\",\n", + " row=\"fullmodel\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "05232b9e-754b-4b1a-a02c-27ec48a5a760", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}