From 24ab973e8d3d2168a766d7908b339e77972c46cf Mon Sep 17 00:00:00 2001 From: SVJ_Vitor Date: Tue, 27 Aug 2024 14:02:21 +0200 Subject: [PATCH] ENH: add sliders in sympy etapipi notebook (#95) --- docs/eta-pi-p/automated.ipynb | 4 +- docs/eta-pi-p/manual.ipynb | 328 +++++++++++++++++++++++++++++----- 2 files changed, 287 insertions(+), 45 deletions(-) diff --git a/docs/eta-pi-p/automated.ipynb b/docs/eta-pi-p/automated.ipynb index b861d25..7395bb8 100644 --- a/docs/eta-pi-p/automated.ipynb +++ b/docs/eta-pi-p/automated.ipynb @@ -489,7 +489,7 @@ "fig_2d.colorbar(mesh, ax=ax_2d)\n", "\n", "if STATIC_PAGE:\n", - " filename = \"dalitz-plot.png\"\n", + " filename = \"dalitz-plot-auto.png\"\n", " fig_2d.savefig(filename)\n", " plt.close(fig_2d)\n", " display(UI, Image(filename))\n", @@ -580,7 +580,7 @@ "fig.tight_layout()\n", "\n", "if STATIC_PAGE:\n", - " filename = \"1d-histograms.svg\"\n", + " filename = \"1d-histograms-auto.svg\"\n", " fig.savefig(filename)\n", " plt.close(fig)\n", " display(UI, SVG(filename))\n", diff --git a/docs/eta-pi-p/manual.ipynb b/docs/eta-pi-p/manual.ipynb index af29991..cd4b6c0 100644 --- a/docs/eta-pi-p/manual.ipynb +++ b/docs/eta-pi-p/manual.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "This section is a follow-up to formulate the amplitude model for the $\\gamma p \\to \\eta\\pi^0 p$ channel example symbolically. See **[TR‑033](https://compwa.github.io/report/033)** for a purely numerical tutorial.\n", + "This section is a follow-up of previous chapter:[Reaction and Models](reaction-model.md), to formulate the amplitude model for the $\\gamma p \\to \\eta\\pi^0 p$ channel example symbolically. See **[TR‑033](https://compwa.github.io/report/033)** for a purely numerical tutorial.\n", "\n", "The model we want to implement is" ] @@ -24,7 +24,7 @@ "\\begin{array}{rcl}\n", "I &=& \\left|A^{12} + A^{23} + A^{31}\\right|^2 \\\\\n", "A^{12} &=& \\frac{\\sum a_m Y_2^m (\\Omega_1)}{s_{12}-m^2_{a_2}+im_{a_2} \\Gamma_{a_2}} \\\\\n", - "A^{23} &=& \\frac{\\sum b_m Y_1^m (\\Omega_2)}{s_{23}-m^2_{\\Delta}+im_{\\Delta} \\Gamma_{\\Delta}} \\\\\n", + "A^{23} &=& \\frac{\\sum b_m Y_1^m (\\Omega_2)}{s_{23}-m^2_{\\Delta^+}+im_{\\Delta^+} \\Gamma_{\\Delta^+}} \\\\\n", "A^{31} &=& \\frac{c_0}{s_{31}-m^2_{N^*}+im_{N^*} \\Gamma_{N^*}} \\,,\n", "\\end{array}\n", "$$" @@ -41,6 +41,9 @@ "cell_type": "code", "execution_count": null, "metadata": { + "jupyter": { + "source_hidden": true + }, "mystnb": { "code_prompt_show": "Import Python libraries" }, @@ -52,6 +55,13 @@ "source": [ "from __future__ import annotations\n", "\n", + "import logging\n", + "import os\n", + "import warnings\n", + "from collections import defaultdict\n", + "\n", + "import ipywidgets as w\n", + "import jax.numpy as jnp\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import sympy as sp\n", @@ -71,13 +81,19 @@ ")\n", "from ampform.sympy import unevaluated\n", "from ampform.sympy._array_expressions import ArraySum\n", - "from IPython.display import Latex\n", + "from IPython.display import SVG, Image, Latex, display\n", "from tensorwaves.data import (\n", " SympyDataTransformer,\n", " TFPhaseSpaceGenerator,\n", " TFUniformRealNumberGenerator,\n", ")\n", - "from tensorwaves.function.sympy import create_parametrized_function" + "from tensorwaves.function.sympy import create_parametrized_function\n", + "\n", + "STATIC_PAGE = \"EXECUTE_NB\" in os.environ\n", + "\n", + "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"\n", + "logging.disable(logging.WARNING)\n", + "warnings.filterwarnings(\"ignore\")" ] }, { @@ -171,7 +187,7 @@ "metadata": {}, "outputs": [], "source": [ - "s23, m_delta, gamma_delta = sp.symbols(\"s_{23} m_Delta Gamma_Delta\")\n", + "s23, m_delta, gamma_delta = sp.symbols(r\"s_{23} m_{\\Delta^+} \\Gamma_{\\Delta^+}\")\n", "b = sp.IndexedBase(\"b\")\n", "m = sp.symbols(\"m\", cls=sp.Idx)\n", "theta2, phi2 = sp.symbols(\"theta_2 phi_2\")\n", @@ -471,6 +487,15 @@ "Latex(aslatex(parameters_default))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ":::{note}\n", + "The mass and width of each resonance is customized to make the resonance bands in the Dalitz plot more visible.\n", + ":::" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -601,6 +626,103 @@ "### Dalitz Plot" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-input", + "scroll-input" + ] + }, + "outputs": [], + "source": [ + "sliders = {}\n", + "categorized_sliders_m = defaultdict(list)\n", + "categorized_sliders_gamma = defaultdict(list)\n", + "categorized_cphi_pair = defaultdict(list)\n", + "\n", + "for symbol, value in parameters_default.items():\n", + " if symbol.name.startswith(R\"\\Gamma_{\"):\n", + " slider = w.FloatSlider(\n", + " description=Rf\"\\({sp.latex(symbol)}\\)\",\n", + " min=0.0,\n", + " max=1.0,\n", + " step=0.01,\n", + " value=value,\n", + " continuous_update=False,\n", + " )\n", + " sliders[symbol.name] = slider\n", + " if symbol.name.startswith(R\"\\Gamma_{N\"):\n", + " categorized_sliders_gamma[0].append(slider)\n", + " elif symbol.name.startswith(R\"\\Gamma_{\\D\"):\n", + " categorized_sliders_gamma[1].append(slider)\n", + " elif symbol.name.startswith(R\"\\Gamma_{a\"):\n", + " categorized_sliders_gamma[2].append(slider)\n", + "\n", + " elif symbol.name.startswith(\"m_{\"):\n", + " slider = w.FloatSlider(\n", + " description=Rf\"\\({sp.latex(symbol)}\\)\",\n", + " min=0.63,\n", + " max=4,\n", + " step=0.01,\n", + " value=value,\n", + " continuous_update=False,\n", + " )\n", + " sliders[symbol.name] = slider\n", + " if symbol.name.startswith(\"m_{N\"):\n", + " categorized_sliders_m[0].append(slider)\n", + " elif symbol.name.startswith(R\"m_{\\D\"):\n", + " categorized_sliders_m[1].append(slider)\n", + " elif symbol.name.startswith(\"m_{a\"):\n", + " categorized_sliders_m[2].append(slider)\n", + "\n", + " else:\n", + " c_latex = sp.latex(symbol)\n", + " phi_latex = Rf\"\\phi_{{{c_latex}}}\"\n", + " slider_c = w.FloatSlider(\n", + " description=Rf\"\\({c_latex}\\)\",\n", + " min=0,\n", + " max=10,\n", + " step=0.01,\n", + " value=abs(value),\n", + " continuous_update=False,\n", + " )\n", + " slider_phi = w.FloatSlider(\n", + " description=Rf\"\\({phi_latex}\\)\",\n", + " min=-np.pi,\n", + " max=+np.pi,\n", + " step=0.01,\n", + " value=np.angle(value),\n", + " continuous_update=False,\n", + " )\n", + "\n", + " sliders[symbol.name] = slider_c\n", + " sliders[f\"phi_{symbol.name}\"] = slider_phi\n", + " cphi_hbox = w.HBox([slider_c, slider_phi])\n", + " if symbol.base is a:\n", + " categorized_cphi_pair[2].append(cphi_hbox)\n", + " elif symbol.base is b:\n", + " categorized_cphi_pair[1].append(cphi_hbox)\n", + " elif symbol.base is c:\n", + " categorized_cphi_pair[0].append(cphi_hbox)\n", + " else:\n", + " raise NotImplementedError(symbol.name)\n", + "\n", + "tab_contents = []\n", + "resonances_name = [\"N*\", \"Δ⁺\", \"a₂\"]\n", + "for i in range(len(resonances_name)):\n", + " tab_content = w.VBox([\n", + " w.HBox(categorized_sliders_m[i] + categorized_sliders_gamma[i]),\n", + " w.VBox(categorized_cphi_pair[i]),\n", + " ])\n", + " tab_contents.append(tab_content)\n", + "UI = w.Tab(tab_contents, titles=resonances_name)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -624,6 +746,34 @@ "intensities = intensity_func(phsp)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "def insert_phi(parameters: dict) -> dict:\n", + " updated_parameters = {}\n", + " for key, value in parameters.items():\n", + " if key.startswith(\"phi_\"):\n", + " continue\n", + " if key.startswith((\"a\", \"b\", \"c\")):\n", + " phi_key = f\"phi_{key}\"\n", + " if phi_key in parameters:\n", + " phi = parameters[phi_key]\n", + " value *= np.exp(1j * phi) # noqa:PLW2901\n", + " updated_parameters[key] = value\n", + "\n", + " return updated_parameters" + ] + }, { "cell_type": "code", "execution_count": null, @@ -638,25 +788,48 @@ }, "outputs": [], "source": [ + "%matplotlib widget\n", "%config InlineBackend.figure_formats = ['png']\n", + "fig_2d, ax_2d = plt.subplots(dpi=200)\n", + "ax_2d.set_title(\"Model-weighted Phase space Dalitz Plot\")\n", + "ax_2d.set_xlabel(R\"$m^2(\\eta \\pi^0)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n", + "ax_2d.set_ylabel(R\"$m^2(\\pi^0 p)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n", "\n", - "fig, ax = plt.subplots(dpi=200)\n", - "hist = ax.hist2d(\n", - " phsp[\"s_{12}\"],\n", - " phsp[\"s_{23}\"],\n", - " bins=200,\n", - " cmin=1e-6,\n", - " density=True,\n", - " cmap=\"jet\",\n", - " vmax=0.15,\n", - " weights=intensities,\n", - ")\n", - "ax.set_title(\"Model-weighted Phase space Dalitz Plot\")\n", - "ax.set_xlabel(R\"$m^2(\\eta \\pi^0)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n", - "ax.set_ylabel(R\"$m^2(\\pi^0 p)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n", - "fig.colorbar(hist[3], ax=ax)\n", - "fig.tight_layout()\n", - "plt.show()" + "mesh = None\n", + "\n", + "\n", + "def update_histogram(**parameters):\n", + " global mesh\n", + " parameters = insert_phi(parameters)\n", + " intensity_func.update_parameters(parameters)\n", + " intensity_weights = intensity_func(phsp)\n", + " bin_values, xedges, yedges = jnp.histogram2d(\n", + " phsp[\"s_{12}\"],\n", + " phsp[\"s_{23}\"],\n", + " bins=200,\n", + " weights=intensity_weights,\n", + " density=True,\n", + " )\n", + " bin_values = jnp.where(bin_values < 1e-6, jnp.nan, bin_values)\n", + " x, y = jnp.meshgrid(xedges[:-1], yedges[:-1])\n", + " if mesh is None:\n", + " mesh = ax_2d.pcolormesh(x, y, bin_values.T, cmap=\"jet\", vmax=0.15)\n", + " else:\n", + " mesh.set_array(bin_values.T)\n", + " fig_2d.canvas.draw_idle()\n", + "\n", + "\n", + "interactive_plot = w.interactive_output(update_histogram, sliders)\n", + "fig_2d.tight_layout()\n", + "fig_2d.colorbar(mesh, ax=ax_2d)\n", + "\n", + "if STATIC_PAGE:\n", + " filename = \"dalitz-plot-man.png\"\n", + " fig_2d.savefig(filename)\n", + " plt.close(fig_2d)\n", + " display(UI, Image(filename))\n", + "else:\n", + " display(UI, interactive_plot)" ] }, { @@ -680,6 +853,7 @@ }, "outputs": [], "source": [ + "%matplotlib widget\n", "%config InlineBackend.figure_formats = ['svg']\n", "\n", "theta_subtitles = [\n", @@ -698,33 +872,101 @@ " R\"$m_{31} \\equiv m_{p \\eta}$\",\n", "]\n", "\n", - "fig, (theta_ax, phi_ax, mass_ax) = plt.subplots(figsize=(12, 8), ncols=3, nrows=3)\n", - "for i, ax1 in enumerate(theta_ax, 1):\n", - " ax1.set_title(theta_subtitles[i - 1])\n", - " ax1.set_xticks([-1, 0, 1])\n", + "fig, (theta_axes, phi_axes, mass_axes) = plt.subplots(figsize=(12, 8), ncols=3, nrows=3)\n", + "fig.canvas.toolbar_visible = False\n", + "fig.canvas.header_visible = False\n", + "fig.canvas.footer_visible = False\n", + "\n", + "for i, ax in enumerate(theta_axes):\n", + " ax.set_title(theta_subtitles[i])\n", + " ax.set_xticks([-1, 0, 1])\n", + "\n", + "for i, ax in enumerate(phi_axes):\n", + " ax.set_title(phi_subtitles[i])\n", + " ax.set_xticks([-np.pi, 0, np.pi])\n", + " ax.set_xticklabels([R\"-$\\pi$\", 0, R\"$\\pi$\"])\n", + "\n", + "for i, ax in enumerate(mass_axes):\n", + " ax.set_title(mass_subtitles[i])\n", + " ax.set_xlabel(\"Mass [GeV]\")\n", + "\n", + "LINES = 3 * [None]\n", + "THETA_LINES = 3 * [None]\n", + "PHI_LINES = 3 * [None]\n", + "RESONANCE_LINES = defaultdict(dict)\n", + "RESONANCES_MASS_NAME = [m_a2, m_delta, m_nstar]\n", + "\n", "\n", - "for i, ax2 in enumerate(phi_ax, 1):\n", - " ax2.set_title(phi_subtitles[i - 1])\n", - " ax2.set_xticks([-np.pi, 0, np.pi])\n", - " ax2.set_xticklabels([R\"-$\\pi$\", 0, R\"$\\pi$\"])\n", + "def update_plot(**parameters): # noqa: C901, PLR0912, PLR0914\n", + " parameters = insert_phi(parameters)\n", + " intensity_func.update_parameters(parameters)\n", + " intensities = intensity_func(phsp)\n", + " max_value_theta = 0.0\n", + " max_value_phi = 0.0\n", + " max_value_mass = 0.0\n", + " theta_keys = [\"theta_1\", \"theta_2\", \"theta_3\"]\n", + " phi_keys = [\"phi_1\", \"phi_2\", \"phi_3\"]\n", + " m2_keys = [\"s_{12}\", \"s_{23}\", \"s_{31}\"]\n", + " line_labels = [R\"$m_{a_2}$\", R\"$m_{\\Delta}$\", R\"$m_{N^*}$\"]\n", + " line_colors = [\"r\", \"g\", \"b\"]\n", + " plot_style = dict(bins=120, weights=intensities, density=True)\n", "\n", - "for i, ax3 in enumerate(mass_ax, 1):\n", - " ax3.set_title(mass_subtitles[i - 1])\n", + " for i, ax in enumerate(mass_axes):\n", + " bin_values, bin_edges = jnp.histogram(np.sqrt(phsp[m2_keys[i]]), **plot_style)\n", + " max_value_mass = max(max_value_mass, bin_values.max())\n", "\n", - "plot_style = dict(bins=100, weights=intensities, density=True)\n", + " if LINES[i] is None:\n", + " LINES[i] = ax.step(bin_edges[:-1], bin_values, alpha=0.5)[0]\n", + " else:\n", + " LINES[i].set_ydata(bin_values)\n", "\n", - "theta_ax[0].hist(np.cos(phsp[\"theta_1\"]), **plot_style)\n", - "theta_ax[1].hist(np.cos(phsp[\"theta_2\"]), **plot_style)\n", - "theta_ax[2].hist(np.cos(phsp[\"theta_3\"]), **plot_style)\n", - "phi_ax[0].hist(phsp[\"phi_1\"], **plot_style)\n", - "phi_ax[1].hist(phsp[\"phi_2\"], **plot_style)\n", - "phi_ax[2].hist(phsp[\"phi_3\"], **plot_style)\n", - "mass_ax[0].hist(np.sqrt(phsp[\"s_{12}\"]), **plot_style)\n", - "mass_ax[1].hist(np.sqrt(phsp[\"s_{23}\"]), **plot_style)\n", - "mass_ax[2].hist(np.sqrt(phsp[\"s_{31}\"]), **plot_style)\n", + " symbol_key = sp.latex([m_a2, m_delta, m_nstar][i])\n", + " val = parameters[symbol_key]\n", + " resonance_line = RESONANCE_LINES[i].get(symbol_key)\n", + " if resonance_line is None:\n", + " RESONANCE_LINES[i][symbol_key] = ax.axvline(\n", + " val, color=line_colors[i], linestyle=\"--\", label=line_labels[i]\n", + " )\n", + " else:\n", + " resonance_line.set_xdata([val, val])\n", "\n", + " for i, ax in enumerate(theta_axes):\n", + " bin_values, bin_edges = jnp.histogram(np.cos(phsp[theta_keys[i]]), **plot_style)\n", + " max_value_theta = max(max_value_theta, bin_values.max())\n", + " if THETA_LINES[i] is None:\n", + " THETA_LINES[i] = ax.step(bin_edges[:-1], bin_values, alpha=0.5)[0]\n", + " else:\n", + " THETA_LINES[i].set_ydata(bin_values)\n", + "\n", + " for i, ax in enumerate(phi_axes):\n", + " bin_values, bin_edges = jnp.histogram(phsp[phi_keys[i]], **plot_style)\n", + " max_value_phi = max(max_value_phi, bin_values.max())\n", + " if PHI_LINES[i] is None:\n", + " PHI_LINES[i] = ax.step(bin_edges[:-1], bin_values, alpha=0.5)[0]\n", + " else:\n", + " PHI_LINES[i].set_ydata(bin_values)\n", + "\n", + " for ax in mass_axes:\n", + " ax.set_ylim(0, max_value_mass * 1.1)\n", + " ax.legend()\n", + "\n", + " for ax in theta_axes:\n", + " ax.set_ylim(0, max_value_theta * 1.1)\n", + "\n", + " for ax in phi_axes:\n", + " ax.set_ylim(0, max_value_phi * 1.1)\n", + "\n", + "\n", + "interactive_plot = w.interactive_output(update_plot, sliders)\n", "fig.tight_layout()\n", - "plt.show()" + "\n", + "if STATIC_PAGE:\n", + " filename = \"1d-histograms-man.svg\"\n", + " fig.savefig(filename)\n", + " plt.close(fig)\n", + " display(UI, SVG(filename))\n", + "else:\n", + " display(UI, interactive_plot)" ] } ],