From ec5215ba79a8f4e14e7e89eafa041e7323d437df Mon Sep 17 00:00:00 2001 From: Georg Schramm <40211162+gschramm@users.noreply.github.com> Date: Thu, 16 May 2024 19:08:41 +0200 Subject: [PATCH] listmode DL notebooks for PSMR 2024 (#225) * init commit of first LM DL PET notebooks * rename notebook * ignore all notebooks (auto generated from jupytext) * update README * add notebook file and learning objectives * add sirf wsl build help * update help * add first figure * work on intro * work on intro * work on intro * move solution into separate file * add array notebook * add LM recon mini example * notebooks/Deep_Learning_listmode_PET/01_SIRF_listmode_recon.py * work on LM recon notebook * format notebook * move solutions to snippets * use inverse sens images in updates * correct LM OSEM solution * add test script for hessian in sinogram mode * cosmetics * add manual Hessian test * add LM Hessian test * use acq storage memory * add skeleton for 3rd notebook * work on 3rd notebook * work on 3rd notebook * add custom layer notebook * structure notebooks * structure notebooks * work on notebook 4 * add 04/05 notebooks * WIP * add TODO * add new figures * add new figure * wip * update plots * black reformat * update figure * clean up * clean up * clean up * update TODO * WIP to get 60min recons * use 60min ref recon * clean up * clean up * reformat * add ipynb versions * fix typos --- .../Deep_Learning_listmode_PET/.gitignore | 5 + .../00_introduction.ipynb | 226 ++++++ .../00_introduction.py | 156 ++++ .../01_SIRF_listmode_recon.ipynb | 679 ++++++++++++++++++ .../01_SIRF_listmode_recon.py | 409 +++++++++++ .../02_SIRF_vs_torch_arrays.ipynb | 257 +++++++ .../02_SIRF_vs_torch_arrays.py | 160 +++++ .../03_custom_torch_layers.ipynb | 418 +++++++++++ .../03_custom_torch_layers.py | 263 +++++++ .../04_custom_sirf_Poisson_logL_layer.ipynb | 496 +++++++++++++ .../04_custom_sirf_Poisson_logL_layer.py | 363 ++++++++++ .../05_custrom_unrolled_varnet.ipynb | 609 ++++++++++++++++ .../05_custrom_unrolled_varnet.py | 484 +++++++++++++ .../06_outlook.ipynb | 43 ++ .../Deep_Learning_listmode_PET/06_outlook.py | 24 + .../Deep_Learning_listmode_PET/README.md | 31 + notebooks/Deep_Learning_listmode_PET/TODO.txt | 3 + .../figs/.gitignore | 1 + .../figs/osem_layer.drawio | 109 +++ .../figs/osem_layer.drawio.svg | 4 + .../figs/osem_varnet.drawio | 156 ++++ .../figs/osem_varnet.drawio.svg | 4 + .../figs/poisson_logL_grad_layer.drawio | 40 ++ .../figs/poisson_logL_grad_layer.drawio.svg | 4 + .../figs/varnet.drawio | 127 ++++ .../figs/varnet.drawio.svg | 4 + .../snippets/solution_0_1.md | 17 + .../snippets/solution_1_1.py | 1 + .../snippets/solution_1_2.py | 10 + .../snippets/solution_1_3.py | 23 + .../snippets/solution_1_4.py | 32 + .../snippets/solution_2_1.py | 12 + .../snippets/solution_3_1.py | 0 .../snippets/solution_4_1.py | 97 +++ .../snippets/solution_4_2.py | 60 ++ .../snippets/solution_5_1.py | 68 ++ .../test/lm_data_fid.py | 122 ++++ .../test/stir_torch_lm_em_layer.py | 102 +++ .../test/test_grad_layer.py | 211 ++++++ .../test/test_hessian.py | 318 ++++++++ .../test/torch_em_layers.py | 202 ++++++ .../test/train_varnet.py | 421 +++++++++++ 42 files changed, 6771 insertions(+) create mode 100644 notebooks/Deep_Learning_listmode_PET/.gitignore create mode 100644 notebooks/Deep_Learning_listmode_PET/00_introduction.ipynb create mode 100644 notebooks/Deep_Learning_listmode_PET/00_introduction.py create mode 100644 notebooks/Deep_Learning_listmode_PET/01_SIRF_listmode_recon.ipynb create mode 100644 
notebooks/Deep_Learning_listmode_PET/01_SIRF_listmode_recon.py create mode 100644 notebooks/Deep_Learning_listmode_PET/02_SIRF_vs_torch_arrays.ipynb create mode 100644 notebooks/Deep_Learning_listmode_PET/02_SIRF_vs_torch_arrays.py create mode 100644 notebooks/Deep_Learning_listmode_PET/03_custom_torch_layers.ipynb create mode 100644 notebooks/Deep_Learning_listmode_PET/03_custom_torch_layers.py create mode 100644 notebooks/Deep_Learning_listmode_PET/04_custom_sirf_Poisson_logL_layer.ipynb create mode 100644 notebooks/Deep_Learning_listmode_PET/04_custom_sirf_Poisson_logL_layer.py create mode 100644 notebooks/Deep_Learning_listmode_PET/05_custrom_unrolled_varnet.ipynb create mode 100644 notebooks/Deep_Learning_listmode_PET/05_custrom_unrolled_varnet.py create mode 100644 notebooks/Deep_Learning_listmode_PET/06_outlook.ipynb create mode 100644 notebooks/Deep_Learning_listmode_PET/06_outlook.py create mode 100644 notebooks/Deep_Learning_listmode_PET/README.md create mode 100644 notebooks/Deep_Learning_listmode_PET/TODO.txt create mode 100644 notebooks/Deep_Learning_listmode_PET/figs/.gitignore create mode 100644 notebooks/Deep_Learning_listmode_PET/figs/osem_layer.drawio create mode 100644 notebooks/Deep_Learning_listmode_PET/figs/osem_layer.drawio.svg create mode 100644 notebooks/Deep_Learning_listmode_PET/figs/osem_varnet.drawio create mode 100644 notebooks/Deep_Learning_listmode_PET/figs/osem_varnet.drawio.svg create mode 100644 notebooks/Deep_Learning_listmode_PET/figs/poisson_logL_grad_layer.drawio create mode 100644 notebooks/Deep_Learning_listmode_PET/figs/poisson_logL_grad_layer.drawio.svg create mode 100644 notebooks/Deep_Learning_listmode_PET/figs/varnet.drawio create mode 100644 notebooks/Deep_Learning_listmode_PET/figs/varnet.drawio.svg create mode 100644 notebooks/Deep_Learning_listmode_PET/snippets/solution_0_1.md create mode 100644 notebooks/Deep_Learning_listmode_PET/snippets/solution_1_1.py create mode 100644 notebooks/Deep_Learning_listmode_PET/snippets/solution_1_2.py create mode 100644 notebooks/Deep_Learning_listmode_PET/snippets/solution_1_3.py create mode 100644 notebooks/Deep_Learning_listmode_PET/snippets/solution_1_4.py create mode 100644 notebooks/Deep_Learning_listmode_PET/snippets/solution_2_1.py create mode 100644 notebooks/Deep_Learning_listmode_PET/snippets/solution_3_1.py create mode 100644 notebooks/Deep_Learning_listmode_PET/snippets/solution_4_1.py create mode 100644 notebooks/Deep_Learning_listmode_PET/snippets/solution_4_2.py create mode 100644 notebooks/Deep_Learning_listmode_PET/snippets/solution_5_1.py create mode 100644 notebooks/Deep_Learning_listmode_PET/test/lm_data_fid.py create mode 100644 notebooks/Deep_Learning_listmode_PET/test/stir_torch_lm_em_layer.py create mode 100644 notebooks/Deep_Learning_listmode_PET/test/test_grad_layer.py create mode 100644 notebooks/Deep_Learning_listmode_PET/test/test_hessian.py create mode 100644 notebooks/Deep_Learning_listmode_PET/test/torch_em_layers.py create mode 100644 notebooks/Deep_Learning_listmode_PET/test/train_varnet.py diff --git a/notebooks/Deep_Learning_listmode_PET/.gitignore b/notebooks/Deep_Learning_listmode_PET/.gitignore new file mode 100644 index 00000000..aedeafce --- /dev/null +++ b/notebooks/Deep_Learning_listmode_PET/.gitignore @@ -0,0 +1,5 @@ +*.ipynb +lm_recons/* +recons/* +recons_1min/* +recons_60min/* diff --git a/notebooks/Deep_Learning_listmode_PET/00_introduction.ipynb b/notebooks/Deep_Learning_listmode_PET/00_introduction.ipynb new file mode 100644 index 00000000..cb13c546 --- 
/dev/null +++ b/notebooks/Deep_Learning_listmode_PET/00_introduction.ipynb @@ -0,0 +1,226 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "62b3400b", + "metadata": {}, + "source": [ + "Introduction & Motivation\n", + "=========================\n", + "\n", + "In this series of SIRF exercises, we will learn how to build and train a deep neural\n", + "network for listmode PET reconstruction. As a concrete example,\n", + "we will focus on unrolled variational networks that can be trained in a supervised manner.\n", + "The general architecture of such a network is shown below.\n", + "\n", + "![](figs/osem_varnet.drawio.svg)\n", + "\n", + "The aim of an unrolled variational PET listmode network is to create \"high quality\" PET reconstructions\n", + "from \"low-quality\" input listmode data using supervised training." + ] + }, + { + "cell_type": "markdown", + "id": "9338c20c", + "metadata": {}, + "source": [ + "Question\n", + "--------\n", + "\n", + "Which (realistic) circumstances can lead to \"low-quality\" PET listmode data?\n", + "How can we obtain paired \"high-quality\" PET reconstructions needed for supervised training?" + ] + }, + { + "cell_type": "markdown", + "id": "dccdc765", + "metadata": {}, + "source": [ + "Learning objectives of this notebook\n", + "------------------------------------\n", + "\n", + "1. Understanding what listmode PET reconstruction is and why it is attractive for combining DL and reconstruction.\n", + "2. Understanding architectures of unrolled reconstruction networks.\n", + "3. Understanding the essential blocks of training a neural network in pytorch (model setup, data loading, gradient backpropagation)\n", + " and what we are missing from pytorch to build an unrolled PET reconstruction network." + ] + }, + { + "cell_type": "markdown", + "id": "3ad534e9", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "What is listmode PET reconstruction? Why is it attractive for combining DL and reconstruction?\n", + "----------------------------------------------------------------------------------------------\n", + "\n", + "In listmode PET reconstruction, the emission data is stored in a list of events. Each event contains\n", + "the detector numbers of the two detectors that detected the photon pair, and possibly also the\n", + "arrival time difference between the two photons (time-of-flight or TOF).\n", + "\n", + "In contrast to histogrammed emission data (sinogram mode), reconstruction of listmode data has the following advantages:\n", + "1. For low and normal count acquisitions with modern TOF scanners, forward and back projections in listmode are usually faster\n", + " compared to projections in sinogram mode. **Question: Why?**\n", + "2. Storage of (low count) listmode data requires less memory compared to storing full TOF sinograms. **Question: Why?**\n", + "3. Listmode data also preserves the timing information of the detected photon pairs."
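To make the listmode-vs-sinogram distinction concrete, here is a minimal standalone toy sketch in plain numpy (not part of the notebook; the sizes and names are invented for illustration). A handful of events, each carrying a bin index and an arrival time, are histogrammed into a tiny "sinogram", which keeps the counts but discards the per-event timing:

```python
# toy illustration of listmode vs. histogrammed (sinogram) data - invented sizes
import numpy as np

rng = np.random.default_rng(0)
num_bins = 5     # a tiny "sinogram" with 5 bins
num_events = 12  # a short acquisition

# listmode: one entry per detected event (bin index + arrival time in ms)
event_bins = rng.integers(0, num_bins, size=num_events)
event_times_ms = np.sort(rng.uniform(0, 1000, size=num_events))
print("listmode events (bin, time):", list(zip(event_bins, event_times_ms.round(1))))

# histogramming the events into a sinogram discards the per-event timing
sinogram = np.zeros(num_bins)
np.add.at(sinogram, event_bins, 1)
print("sinogram counts per bin:", sinogram)
```

Real listmode events store the detector (crystal) pair numbers rather than a precomputed bin index, but the principle is the same.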
+ ] + }, + { + "cell_type": "markdown", + "id": "bdcd2c5e", + "metadata": {}, + "source": [ + "Architecture of unrolled reconstruction networks\n", + "------------------------------------------------\n", + "\n", + "Unrolled variational networks are a class of deep neural networks that are designed to solve inverse problems.\n", + "They consist of a series of layers that are repeated multiple times.\n", + "Each layer contains an update with respect to the data fidelity term (blue boxes in the figure above)\n", + "and a regularization term (red boxes in the figure above).\n", + "The latter can be represented by a neural network (e.g. a CNN) containing learnable parameters which are optimized\n", + "during (supervised) training.\n", + "\n", + "There are many ways of implementing the data fidelity update block.\n", + "One simple possibility is to implement a gradient ascent step with respect to the Poisson log-likelihood.\n", + "$$ x^+ = x + \\alpha \\nabla_x \\log L(y|x) ,$$\n", + "where the Poisson log-likelihood is given by\n", + "$$ \\log L(y|x) = \\sum_{i} y_i \\log(\\bar{y}_i(x)) - \\bar{y}_i(x) ,$$\n", + "where $y$ is the measured emission sinogram, and $\\bar{y}(x) = Ax + s$ is the expectation of the measured data given the current\n", + "estimate of the image $x$ and a linear (affine) forward model $A$ including the mean of known additive contaminations (randoms and scatter) $s$.\n" + ] + }, + { + "cell_type": "markdown", + "id": "c069ae51", + "metadata": {}, + "source": [ + "Exercise 0.1\n", + "------------\n", + "\n", + "Given the equations above, derive the formula for the gradient of the Poisson log-likelihood (using sinogram data).\n", + "\n", + "(bonus question) How does the update formula change if we use listmode data instead of sinogram data?\n", + "\n", + "YOUR SOLUTION GOES HERE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc9dff44", + "metadata": {}, + "outputs": [], + "source": [ + "# #TO SHOW THE SOLUTION, UNCOMMENT THE NEXT TWO LINES AND RUN THE CELL\n", + "# from IPython.display import Markdown, display\n", + "# display(Markdown(filename=\"snippets/solution_0_1.md\"))" + ] + }, + { + "cell_type": "markdown", + "id": "061b92da", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "Training a neural network in pytorch\n", + "------------------------------------\n", + "\n", + "Pytorch is a popular deep learning framework that provides a flexible and efficient way to build and train neural networks.\n", + "The essential steps to train a neural network in pytorch are summarized in the train loop, see\n", + "[here](https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html#optimizing-the-model-parameters) for more details."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "84f00b91", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "# DO NOT RUN THIS CELL - CODE SNIPPET ONLY\n", + "import torch\n", + "\n", + "\n", + "def train(\n", + " dataloader: torch.utils.data.DataLoader,\n", + " model: torch.nn.Module,\n", + " loss_fn: torch.nn.Module,\n", + " optimizer: torch.optim.Optimizer,\n", + " device: torch.device,\n", + "):\n", + " model.train()\n", + " # loop over the dataset and sample mini-batches\n", + " for batch_num, (input_data_batch, target_image_batch) in enumerate(dataloader):\n", + " # move input and target data to device\n", + " input_data_batch = input_data_batch.to(device)\n", + " target_image_batch = target_image_batch.to(device)\n", + "\n", + " # Compute prediction error\n", + " predicted_image_batch = model(input_data_batch)\n", + " loss = loss_fn(predicted_image_batch, target_image_batch)\n", + "\n", + " # calculate gradients using backpropagation\n", + " loss.backward()\n", + " # update model parameters\n", + " optimizer.step()\n", + " # reset gradients\n", + " optimizer.zero_grad()\n", + "\n", + "\n", + "# model and data loader to be defined\n", + "my_model = myModel()\n", + "my_data_loader = myDataLoader()\n", + "\n", + "# compute device - use cuda GPU if available\n", + "dev = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "# the loss function we optimize during training\n", + "my_loss_fn = torch.nn.MSELoss()\n", + "# the optimizer we use to update the model parameters\n", + "my_optimizer = torch.optim.Adam(my_model.parameters(), lr=1e-3)\n", + "\n", + "# run a single epoch of training\n", + "train(my_data_loader, my_model, my_loss_fn, my_optimizer, dev)" + ] + }, + { + "cell_type": "markdown", + "id": "c125902e", + "metadata": {}, + "source": [ + "**The essential blocks for supervised training of a neural network in pytorch are:**\n", + "1. Sampling of mini-batches of input and target (label) images from the training dataset.\n", + "2. Forward pass: Compute the prediction of the model given the input data.\n", + "3. Compute the loss (error) between the prediction and the target images.\n", + "4. Backward pass: Compute the gradient of the loss with respect to the model parameters using backpropagation.\n", + "5. Update the model parameters using an optimizer.\n", + "\n", + "Fortunately, pytorch provides many high-level functions that simplify the implementation of all these steps.\n", + "(e.g. 
pytorch's data loader classes, pytorch's convolutional layers and non-linear activation functions, pytorch's\n", + "autograd functionality for backpropagation of gradients, and optimizers like Adam)\n", + "To train a listmode PET unrolled variational network, the only thing we need to implement ourselves\n", + "is the forward pass of our model, including the data fidelity update blocks which are not directly available in pytorch.\n", + "\n", + "**The aim of the remaining exercises is:**\n", + "- to learn how to couple SIRF/STIR's PET listmode classes into a pytorch feedforward model\n", + "- to learn how to backpropagate gradients through our custom model\n", + "\n", + "**The following is beyond the scope of the exercises:**\n", + "- training a real world unrolled variational listmode PET reconstruction network on a\n", + " large amount of data" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/Deep_Learning_listmode_PET/00_introduction.py b/notebooks/Deep_Learning_listmode_PET/00_introduction.py new file mode 100644 index 00000000..fcb9c00c --- /dev/null +++ b/notebooks/Deep_Learning_listmode_PET/00_introduction.py @@ -0,0 +1,156 @@ +# %% [markdown] +# Introduction & Motivation +# ========================= +# +# In this series of SIRF exercises, we will learn how to build and train a deep neural +# network for listmode PET reconstruction. As a concrete example, +# we will focus on unrolled variational networks that can be trained in a supervised manner. +# The general architecture of such a network is shown below. +# +# ![](figs/osem_varnet.drawio.svg) +# +# The aim of an unrolled variational PET listmode network is to create "high quality" PET reconstructions +# from "low-quality" input listmode data using supervised training. + +# %% [markdown] +# Question +# -------- +# +# Which (realistic) circumstances can lead to "low-quality" PET listmode data? +# How can we obtain paired "high-quality" PET reconstructions needed for supervised training? + +# %% [markdown] +# Learning objectives of this notebook +# ------------------------------------ +# +# 1. Understanding what listmode PET reconstruction is and why it is attractive for combining DL and reconstruction. +# 2. Understanding architectures of unrolled reconstruction networks. +# 3. Understanding the essential blocks of training a neural network in pytorch (model setup, data loading, gradient backpropagation) +# and what we are missing from pytorch to build an unrolled PET reconstruction network. + +# %% [markdown] +# What is listmode PET reconstruction? Why is it attractive for combining DL and reconstruction? +# ---------------------------------------------------------------------------------------------- +# +# In listmode PET reconstruction, the emission data is stored in a list of events. Each event contains +# the detector numbers of the two detectors that detected the photon pair, and possibly also the +# arrival time difference between the two photons (time-of-flight or TOF). +# +# In contrast to histogrammed emission data (sinogram mode), reconstruction of listmode data has the following advantages: +# 1. For low and normal count acquisitions with modern TOF scanners, forward and back projections in listmode are usually faster +# compared to projections in sinogram mode. **Question: Why?** +# 2. 
Storage of (low count) listmode data requires less memory compared to storing full TOF sinograms. **Question: Why?** +# 3. Listmode data also preserves the timing information of the detected photon pairs. + + +# %% [markdown] +# Architecture of unrolled reconstruction networks +# ------------------------------------------------ +# +# Unrolled variational networks are a class of deep neural networks that are designed to solve inverse problems. +# They consist of a series of layers that are repeated multiple times. +# Each layer contains an update with respect to the data fidelity term (blue boxes in the figure above) +# and a regularization term (red boxes in the figure above). +# The latter can be represented by a neural network (e.g. a CNN) containing learnable parameters which are optimized +# during (supervised) training. +# +# There are many ways of implementing the data fidelity update block. +# One simple possibility is to implement a gradient ascent step with respect to the Poisson log-likelihood. +# $$ x^+ = x + \alpha \nabla_x \log L(y|x) ,$$ +# where the Poisson log-likelihood is given by +# $$ \log L(y|x) = \sum_{i} y_i \log(\bar{y}_i(x)) - \bar{y}_i(x) ,$$ +# where $y$ is the measured emission sinogram, and $\bar{y}(x) = Ax + s$ is the expectation of the measured data given the current +# estimate of the image $x$ and a linear (affine) forward model $A$ including the mean of known additive contaminations (randoms and scatter) $s$. +# + +# %% [markdown] +# Exercise 0.1 +# ------------ +# +# Given the equations above, derive the formula for the gradient of the Poisson log-likelihood (using sinogram data). +# +# (bonus question) How does the update formula change if we use listmode data instead of sinogram data? + +# YOUR SOLUTION GOES HERE + +# %% +# #TO SHOW THE SOLUTION, UNCOMMENT THE NEXT TWO LINES AND RUN THE CELL +# from IPython.display import Markdown, display +# display(Markdown(filename="snippets/solution_0_1.md")) + +# %% [markdown] +# Training a neural network in pytorch +# ------------------------------------ +# +# Pytorch is a popular deep learning framework that provides a flexible and efficient way to build and train neural networks. +# The essential steps to train a neural network in pytorch are summarized in the train loop, see +# [here](https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html#optimizing-the-model-parameters) for more details. 
+ + +# %% +# DO NOT RUN THIS CELL - CODE SNIPPET ONLY +import torch + + +def train( + dataloader: torch.utils.data.DataLoader, + model: torch.nn.Module, + loss_fn: torch.nn.Module, + optimizer: torch.optim.Optimizer, + device: torch.device, +): + model.train() + # loop over the dataset and sample mini-batches + for batch_num, (input_data_batch, target_image_batch) in enumerate(dataloader): + # move input and target data to device + input_data_batch = input_data_batch.to(device) + target_image_batch = target_image_batch.to(device) + + # Compute prediction error + predicted_image_batch = model(input_data_batch) + loss = loss_fn(predicted_image_batch, target_image_batch) + + # calculate gradients using backpropagation + loss.backward() + # update model parameters + optimizer.step() + # reset gradients + optimizer.zero_grad() + + +# model and data loader to be defined +my_model = myModel() +my_data_loader = myDataLoader() + +# compute device - use cuda GPU if available +dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# the loss function we optimize during training +my_loss_fn = torch.nn.MSELoss() +# the optimizer we use to update the model parameters +my_optimizer = torch.optim.Adam(my_model.parameters(), lr=1e-3) + +# run a single epoch of training +train(my_data_loader, my_model, my_loss_fn, my_optimizer, dev) + + +# %% [markdown] +# **The essential blocks for supervised training a neural network in pytorch are:** +# 1. Sampling of mini-batches of input and target (label) images from the training dataset. +# 2. Forward pass: Compute the prediction of the model given the input data. +# 3. Compute the loss (error) between the prediction and the target images. +# 4. Backward pass: Compute the gradient of the loss with respect to the model parameters using backpropagation. +# 5. Update the model parameters using an optimizer. +# +# Fortunately, pytorch provides many high-level functions that simplify the implementation of all these steps. +# (e.g. pytorch's data loader classes, pytorch's convolutional layers and non-linear activation function, pytorch's +# autograd functionality for backpropagation of gradients, and optimizers like Adam) +# To train a listmode PET unrolled variational network, the only thing we need to implement ourselves +# is the forward pass of our model, including the data fidelity update blocks which are not directly available pytorch. 
+# +# **The aim of the remaining exercises is:** +# - to learn how to couple SIRF/STIR's PET listmode classes into a pytorch feedforward model +# - to learn how to backpropagate gradients through our custom model +# +# **The following is beyond the scope of the exercises:** +# - training a real world unrolled variational listmode PET reconstruction network on a +# large amount of data diff --git a/notebooks/Deep_Learning_listmode_PET/01_SIRF_listmode_recon.ipynb b/notebooks/Deep_Learning_listmode_PET/01_SIRF_listmode_recon.ipynb new file mode 100644 index 00000000..f6bf5fff --- /dev/null +++ b/notebooks/Deep_Learning_listmode_PET/01_SIRF_listmode_recon.ipynb @@ -0,0 +1,679 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "683c2283", + "metadata": {}, + "source": [ + "Sinogram and Listmode OSEM using sirf.STIR\n", + "==========================================\n", + "\n", + "Building on the previous \"theory\" notebook, we will now learn how to perform\n", + "PET reconstruction of emission data in listmode and sinogram format using (sinogram and listmode)\n", + "objective function objects of the sirf.STIR library.\n", + "\n", + "We will see that standard OSEM reconstruction can be viewed as a sequence of image update \"blocks\",\n", + "where the update in each block is related to the gradient of the Poisson loglikelihood objective function.\n", + "\n", + "Understanding these OSEM update blocks is the first key step for implementing a pytorch-based feed-forward\n", + "neural network for PET image reconstruction also containing OSEM-like update blocks.\n", + "\n", + "Learning objectives of this notebook\n", + "------------------------------------\n", + "1. Understanding how to set up Poisson loglikelihood objective functions in sinogram and listmode mode.\n", + "2. Understanding how to perform sinogram / listmode OSEM reconstruction using the sirf.STIR high-level API.\n", + "3. Implementing a simple DIY OSEM reconstruction using the gradient of the Poisson loglikelihood."
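As a warm-up for these update blocks, the following standalone torch sketch (not part of the notebooks; the tiny forward model `A`, the background `s` and all sizes are invented for illustration) performs a single gradient ascent step $x^+ = x + \alpha \nabla_x \log L(y|x)$ on a toy Poisson problem. It uses autograd for the gradient, so the analytic expression asked for in Exercise 0.1 is not given away:

```python
# toy sketch: one Poisson logL gradient ascent step via autograd - invented model/sizes
import torch

torch.manual_seed(0)
A = torch.rand(6, 3)               # toy linear forward model: 6 data bins, 3 voxels
s = 0.1 * torch.ones(6)            # known additive background (randoms + scatter)
x_true = torch.tensor([1.0, 2.0, 0.5])
y = torch.poisson(A @ x_true + s)  # simulated Poisson-distributed measured data

x = torch.ones(3, requires_grad=True)  # current image estimate
ybar = A @ x + s                       # expected data given x
logL = (y * torch.log(ybar) - ybar).sum()
logL.backward()                        # autograd computes nabla_x logL

alpha = 0.001
with torch.no_grad():
    x_plus = x + alpha * x.grad        # one gradient ascent step
print("updated toy image:", x_plus)
```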
+ ] + }, + { + "cell_type": "markdown", + "id": "d4edb034", + "metadata": {}, + "source": [ + "Import modules\n", + "--------------" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "819ec782", + "metadata": {}, + "outputs": [], + "source": [ + "import sirf.STIR\n", + "import numpy as np\n", + "import subprocess\n", + "import matplotlib.pyplot as plt\n", + "from pathlib import Path\n", + "from sirf.Utilities import examples_data_path" + ] + }, + { + "cell_type": "markdown", + "id": "f5192289", + "metadata": {}, + "source": [ + "Download the 60min mMR NEMA data, if not present\n", + "------------------------------------------------" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "deee0164", + "metadata": {}, + "outputs": [], + "source": [ + "if not (\n", + " Path(\"..\")\n", + " / \"..\"\n", + " / \"data\"\n", + " / \"PET\"\n", + " / \"mMR\"\n", + " / \"NEMA_IQ\"\n", + " / \"20170809_NEMA_60min_UCL.l.hdr\"\n", + ").exists():\n", + " retval = subprocess.call(\"../../scripts/download_PET_data.sh\", shell=True)" + ] + }, + { + "cell_type": "markdown", + "id": "ed0b73e1", + "metadata": {}, + "source": [ + "Define variables and file names\n", + "-------------------------------" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44c4c932", + "metadata": {}, + "outputs": [], + "source": [ + "# we have a 1min and 60min acquisition of the NEMA IQ phantom acquired on a Siemens mMR\n", + "# choose the acquisition time \"1min\" or \"60min\" - start with \"1min\"\n", + "acq_time: str = \"1min\"\n", + "\n", + "data_path: Path = Path(examples_data_path(\"PET\")) / \"mMR\"\n", + "\n", + "if acq_time == \"1min\":\n", + " list_file: str = str(data_path / \"list.l.hdr\")\n", + "elif acq_time == \"60min\":\n", + " # you need to run the \"download_PET_data.sh\" script to get the data of the long 60min acq.\n", + " list_file: str = str(\n", + " Path(\"..\")\n", + " / \"..\"\n", + " / \"data\"\n", + " / \"PET\"\n", + " / \"mMR\"\n", + " / \"NEMA_IQ\"\n", + " / \"20170809_NEMA_60min_UCL.l.hdr\"\n", + " )\n", + "else:\n", + " raise ValueError(\"Please choose acq_time to be either '1min' or '60min'\")\n", + "\n", + "attn_file: str = str(data_path / \"mu_map.hv\")\n", + "norm_file: str = str(data_path / \"norm.n.hdr\")\n", + "output_path: Path = Path(f\"recons_{acq_time}\")\n", + "emission_sinogram_output_prefix: str = str(output_path / \"emission_sinogram\")\n", + "scatter_sinogram_output_prefix: str = str(output_path / \"scatter_sinogram\")\n", + "randoms_sinogram_output_prefix: str = str(output_path / \"randoms_sinogram\")\n", + "attenuation_sinogram_output_prefix: str = str(output_path / \"acf_sinogram\")\n", + "recon_output_file: str = str(output_path / \"recon\")\n", + "lm_recon_output_file: str = str(output_path / \"lm_recon\")\n", + "nxny: tuple[int, int] = (127, 127)\n", + "num_subsets: int = 21\n", + "num_iter: int = 1\n", + "num_scatter_iter: int = 3\n", + "\n", + "# create the output directory\n", + "output_path.mkdir(exist_ok=True)\n", + "\n", + "# engine's messages go to files, except error messages, which go to stdout\n", + "_ = sirf.STIR.MessageRedirector(\"info.txt\", \"warn.txt\")" + ] + }, + { + "cell_type": "markdown", + "id": "2f5967c5", + "metadata": {}, + "source": [ + "Read the listmode data and create a sinogram template\n", + "-----------------------------------------------------" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d08a555", + "metadata": {}, + "outputs": [], + "source": [ 
"sirf.STIR.AcquisitionData.set_storage_scheme(\"memory\")\n", + "listmode_data = sirf.STIR.ListmodeData(list_file)\n", + "acq_data_template = listmode_data.acquisition_data_template()\n", + "print(acq_data_template.get_info())" + ] + }, + { + "cell_type": "markdown", + "id": "4c75f3bd", + "metadata": {}, + "source": [ + "Conversion of listmode to sinogram data (needed for scatter estimation)\n", + "-----------------------------------------------------------------------" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "faf9cc0e", + "metadata": {}, + "outputs": [], + "source": [ + "# create listmode-to-sinograms converter object\n", + "lm2sino = sirf.STIR.ListmodeToSinograms()\n", + "\n", + "# set input, output and template files\n", + "lm2sino.set_input(listmode_data)\n", + "lm2sino.set_output_prefix(emission_sinogram_output_prefix)\n", + "lm2sino.set_template(acq_data_template)\n", + "\n", + "# get the start and end time of the listmode data\n", + "frame_start = float(\n", + " [\n", + " x\n", + " for x in listmode_data.get_info().split(\"\\n\")\n", + " if x.startswith(\"Time frame start\")\n", + " ][0]\n", + " .split(\": \")[1]\n", + " .split(\"-\")[0]\n", + ")\n", + "frame_end = float(\n", + " [\n", + " x\n", + " for x in listmode_data.get_info().split(\"\\n\")\n", + " if x.startswith(\"Time frame start\")\n", + " ][0]\n", + " .split(\": \")[1]\n", + " .split(\"-\")[1]\n", + " .split(\"(\")[0]\n", + ")\n", + "# set interval\n", + "lm2sino.set_time_interval(frame_start, frame_end)\n", + "# set up the converter\n", + "lm2sino.set_up()\n", + "\n", + "# convert (need it for the scatter estimate)\n", + "lm2sino.process()\n", + "acq_data = lm2sino.get_output()" + ] + }, + { + "cell_type": "markdown", + "id": "b101d3cc", + "metadata": {}, + "source": [ + "Estimation of random coincidences\n", + "---------------------------------" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38ad8c17", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "randoms_filepath = Path(f\"{randoms_sinogram_output_prefix}.hs\")\n", + "\n", + "if not randoms_filepath.exists():\n", + " print(\"estimting randoms\")\n", + " randoms = lm2sino.estimate_randoms()\n", + " randoms.write(randoms_sinogram_output_prefix)\n", + "else:\n", + " print(\"reading randoms from {randoms_filepath}\")\n", + " randoms = sirf.STIR.AcquisitionData(str(randoms_filepath))" + ] + }, + { + "cell_type": "markdown", + "id": "973d4184", + "metadata": {}, + "source": [ + "Setup of the acquisition model\n", + "------------------------------" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5038649b", + "metadata": {}, + "outputs": [], + "source": [ + "# select acquisition model that implements the geometric\n", + "# forward projection by a ray tracing matrix multiplication\n", + "acq_model = sirf.STIR.AcquisitionModelUsingRayTracingMatrix()\n", + "# acq_model.set_num_tangential_LORs(10)\n", + "acq_model.set_num_tangential_LORs(1)" + ] + }, + { + "cell_type": "markdown", + "id": "f23c18e4", + "metadata": {}, + "source": [ + "Calculation of the attenuation sinogram\n", + "---------------------------------------" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "95fc6371", + "metadata": {}, + "outputs": [], + "source": [ + "# read attenuation image and display a single slice\n", + "attn_image = sirf.STIR.ImageData(attn_file)\n", + "\n", + "# create attenuation factors\n", + "asm_attn = sirf.STIR.AcquisitionSensitivityModel(attn_image, 
acq_model)\n", + "# converting attenuation image into attenuation factors (one for every bin)\n", + "asm_attn.set_up(acq_data)\n", + "\n", + "acf_filepath = Path(f\"{attenuation_sinogram_output_prefix}.hs\")\n", + "\n", + "if not acf_filepath.exists():\n", + " ac_factors = acq_data.get_uniform_copy(value=1)\n", + " print(\"applying attenuation (please wait, may take a while)...\")\n", + " asm_attn.unnormalise(ac_factors)\n", + " ac_factors.write(attenuation_sinogram_output_prefix)\n", + "else:\n", + " print(f\"reading attenuation factors from {acf_filepath}\")\n", + " ac_factors = sirf.STIR.AcquisitionData(str(acf_filepath))\n", + "\n", + "asm_attn = sirf.STIR.AcquisitionSensitivityModel(ac_factors)" + ] + }, + { + "cell_type": "markdown", + "id": "22067c43", + "metadata": {}, + "source": [ + "Creation of the normalization factors (sensitivity sinogram)\n", + "------------------------------------------------------------" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9fe9f642", + "metadata": {}, + "outputs": [], + "source": [ + "# create acquisition sensitivity model from normalisation data\n", + "asm_norm = sirf.STIR.AcquisitionSensitivityModel(norm_file)\n", + "\n", + "asm = sirf.STIR.AcquisitionSensitivityModel(asm_norm, asm_attn)\n", + "asm.set_up(acq_data)\n", + "acq_model.set_acquisition_sensitivity(asm)" + ] + }, + { + "cell_type": "markdown", + "id": "a1c3b169", + "metadata": {}, + "source": [ + "Estimation of scattered coincidences\n", + "------------------------------------" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cce2d650", + "metadata": {}, + "outputs": [], + "source": [ + "scatter_filepath: Path = Path(f\"{scatter_sinogram_output_prefix}_{num_scatter_iter}.hs\")\n", + "\n", + "if not scatter_filepath.exists():\n", + " print(\"estimating scatter (this will take a while!)\")\n", + " scatter_estimator = sirf.STIR.ScatterEstimator()\n", + " scatter_estimator.set_input(acq_data)\n", + " scatter_estimator.set_attenuation_image(attn_image)\n", + " scatter_estimator.set_randoms(randoms)\n", + " scatter_estimator.set_asm(asm_norm)\n", + " # invert attenuation factors to get the correction factors,\n", + " # as this is unfortunately what a ScatterEstimator needs\n", + " acf_factors = acq_data.get_uniform_copy()\n", + " acf_factors.fill(1 / ac_factors.as_array())\n", + " scatter_estimator.set_attenuation_correction_factors(acf_factors)\n", + " scatter_estimator.set_output_prefix(scatter_sinogram_output_prefix)\n", + " scatter_estimator.set_num_iterations(num_scatter_iter)\n", + " scatter_estimator.set_up()\n", + " scatter_estimator.process()\n", + " scatter_estimate = scatter_estimator.get_output()\n", + "else:\n", + " print(f\"reading scatter from file {scatter_filepath}\")\n", + " scatter_estimate = sirf.STIR.AcquisitionData(str(scatter_filepath))\n", + "\n", + "# add scatter plus randoms estimated to the background term of the acquisition model\n", + "acq_model.set_background_term(randoms + scatter_estimate)" + ] + }, + { + "cell_type": "markdown", + "id": "097ecd59", + "metadata": {}, + "source": [ + "Setup of the Poisson loglikelihood objective function in sinogram mode\n", + "----------------------------------------------------------------------" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "02a60875", + "metadata": {}, + "outputs": [], + "source": [ + "initial_image = acq_data.create_uniform_image(value=1, xy=nxny)\n", + "\n", + "# create objective function\n", + "obj_fun = 
sirf.STIR.make_Poisson_loglikelihood(acq_data)\n", + "obj_fun.set_acquisition_model(acq_model)\n", + "obj_fun.set_num_subsets(num_subsets)\n", + "obj_fun.set_up(initial_image)" + ] + }, + { + "cell_type": "markdown", + "id": "7ca398f1", + "metadata": {}, + "source": [ + "Image reconstruction (optimization of the Poisson logL objective function) using sinogram OSEM\n", + "----------------------------------------------------------------------------------------------" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cb2442af", + "metadata": {}, + "outputs": [], + "source": [ + "if not Path(f\"{recon_output_file}.hv\").exists():\n", + " reconstructor = sirf.STIR.OSMAPOSLReconstructor()\n", + " reconstructor.set_objective_function(obj_fun)\n", + " reconstructor.set_num_subsets(num_subsets)\n", + " reconstructor.set_num_subiterations(num_iter * num_subsets)\n", + " reconstructor.set_input(acq_data)\n", + " reconstructor.set_up(initial_image)\n", + " reconstructor.set_current_estimate(initial_image)\n", + " reconstructor.process()\n", + " ref_recon = reconstructor.get_output()\n", + " ref_recon.write(recon_output_file)\n", + "else:\n", + " ref_recon = sirf.STIR.ImageData(f\"{recon_output_file}.hv\")\n", + "\n", + "vmax = np.percentile(ref_recon.as_array(), 99.999)\n", + "\n", + "fig, ax = plt.subplots(1, 1, figsize=(4, 4), tight_layout=True)\n", + "ax.imshow(ref_recon.as_array()[71, :, :], cmap=\"Greys\", vmin=0, vmax=vmax)\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "id": "4ca63eb9", + "metadata": {}, + "source": [ + "Exercise 1.1\n", + "------------\n", + "\n", + "Perform the gradient ascent step\n", + "$$ x^+ = x + \\alpha \\nabla_x \\log L(y|x) $$\n", + "on the initial image $x$ using a constant scalar step size $\\alpha=0.001$ by calling\n", + "the `gradient()` method of the objective function.\n", + "Use the first (0th) subset of the data for the gradient calculation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "abb57d02", + "metadata": {}, + "outputs": [], + "source": [ + "#\n", + "# ==============\n", + "# YOUR CODE HERE\n", + "# ==============\n", + "#" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d016524b", + "metadata": {}, + "outputs": [], + "source": [ + "# to view the solution, execute this cell\n", + "%load snippets/solution_1_1.py" + ] + }, + { + "cell_type": "markdown", + "id": "b433390d", + "metadata": {}, + "source": [ + "Exercise 1.2\n", + "------------\n", + "\n", + "Given the fact that the OSEM update can be written as\n", + "$$ x^+ = x + t \\nabla_x \\log L(y|x) $$\n", + "with the non-scalar step size\n", + "$$ t = \\frac{x}{s} $$\n", + "where $s$ is the (subset) \"sensitivity image\", perform an OSEM update on the initial image\n", + "by using the `get_subset_sensitivity()` method of the objective function and the first subset.\n", + "Print the maximum value of the updated image. What do you observe?"
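If you have already derived the gradient of the Poisson log-likelihood in Exercise 0.1, the following standalone numpy sketch (a toy problem with invented sizes, not using SIRF) verifies the claim of Exercise 1.2 algebraically: the additive update with the non-scalar step size $t = x/s$ coincides with the classic multiplicative (OS)EM update:

```python
# toy check: additive update with t = x/s equals the multiplicative EM update
import numpy as np

rng = np.random.default_rng(0)
A = rng.uniform(0.1, 1.0, size=(8, 4))  # toy forward model
s_bg = 0.05 * np.ones(8)                # additive background (randoms + scatter)
y = rng.poisson(A @ np.array([1.0, 3.0, 2.0, 0.5]) + s_bg)  # toy measured data

x = np.ones(4)                 # current estimate
sens = A.T @ np.ones(8)        # "sensitivity image" s = A^T 1
ybar = A @ x + s_bg            # expected data
grad = A.T @ (y / ybar - 1.0)  # Poisson logL gradient (the result of Exercise 0.1)

x_additive = x + (x / sens) * grad      # additive update with step size t = x / s
x_em = (x / sens) * (A.T @ (y / ybar))  # classic multiplicative EM update
print(np.allclose(x_additive, x_em))    # True: the two updates are identical
```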
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a2c2b56", + "metadata": {}, + "outputs": [], + "source": [ + "#\n", + "# ==============\n", + "# YOUR CODE HERE\n", + "# ==============\n", + "#" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e710c35f", + "metadata": {}, + "outputs": [], + "source": [ + "# to view the solution, execute this cell\n", + "%load snippets/solution_1_2.py" + ] + }, + { + "cell_type": "markdown", + "id": "34b0df04", + "metadata": {}, + "source": [ + "Exercise 1.3\n", + "------------\n", + "\n", + "Implement your own OSEM reconstruction by looping over the subsets and performing the\n", + "OSEM update for each subset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "941893c7", + "metadata": {}, + "outputs": [], + "source": [ + "# initialize the reconstruction with ones where the sensitivity image is greater than 0\n", + "# all other values are set to zero and are not updated during reconstruction\n", + "recon = initial_image.copy()\n", + "recon.fill(obj_fun.get_subset_sensitivity(0).as_array() > 0)\n", + "#\n", + "# ==============\n", + "# YOUR CODE HERE\n", + "# ==============\n", + "#" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e13940c4", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "# to view the solution, execute this cell\n", + "%load snippets/solution_1_3.py" + ] + }, + { + "cell_type": "markdown", + "id": "5b573d43", + "metadata": {}, + "source": [ + "Setup of the Poisson loglikelihood objective function logL(y|x) in listmode\n", + "---------------------------------------------------------------------------" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c098f286", + "metadata": {}, + "outputs": [], + "source": [ + "# define the listmode objective function\n", + "lm_obj_fun = (\n", + " sirf.STIR.PoissonLogLikelihoodWithLinearModelForMeanAndListModeDataWithProjMatrixByBin()\n", + ")\n", + "lm_obj_fun.set_acquisition_model(acq_model)\n", + "lm_obj_fun.set_acquisition_data(listmode_data)\n", + "lm_obj_fun.set_num_subsets(num_subsets)\n", + "lm_obj_fun.set_cache_max_size(1000000000)\n", + "lm_obj_fun.set_cache_path(str(output_path))" + ] + }, + { + "cell_type": "markdown", + "id": "633568f9", + "metadata": {}, + "source": [ + "Reconstruction (optimization of the Poisson logL objective function) using listmode OSEM\n", + "----------------------------------------------------------------------------------------" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cdad4c41", + "metadata": {}, + "outputs": [], + "source": [ + "if not Path(f\"{lm_recon_output_file}.hv\").exists():\n", + " lm_reconstructor = sirf.STIR.OSMAPOSLReconstructor()\n", + " lm_reconstructor.set_objective_function(lm_obj_fun)\n", + " lm_reconstructor.set_num_subsets(num_subsets)\n", + " lm_reconstructor.set_num_subiterations(num_iter * num_subsets)\n", + " lm_reconstructor.set_up(initial_image)\n", + " lm_reconstructor.set_current_estimate(initial_image)\n", + " lm_reconstructor.process()\n", + " lm_ref_recon = lm_reconstructor.get_output()\n", + " lm_ref_recon.write(lm_recon_output_file)\n", + "else:\n", + " lm_ref_recon = sirf.STIR.ImageData(f\"{lm_recon_output_file}.hv\")\n", + "\n", + "fig3, ax3 = plt.subplots(1, 1, figsize=(4, 4), tight_layout=True)\n", + "ax3.imshow(lm_ref_recon.as_array()[71, :, :], cmap=\"Greys\", vmin=0, vmax=vmax)\n", + "fig3.show()" + ] + }, + { + "cell_type": "markdown", + "id": 
"1e891ec5", + "metadata": {}, + "source": [ + "Exercise 1.4\n", + "------------\n", + "Repeat exercise 1.3 (OSEM reconstruction) using the listmode objective function to\n", + "learn how to do a listmode OSEM update step." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6eac8f6", + "metadata": {}, + "outputs": [], + "source": [ + "#\n", + "# ==============\n", + "# YOUR CODE HERE\n", + "# ==============" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8136729d", + "metadata": {}, + "outputs": [], + "source": [ + "# to view the solution, execute the cell below\n", + "%load snippets/solution_1_4.py" + ] + }, + { + "cell_type": "markdown", + "id": "a99979ab", + "metadata": {}, + "source": [ + "Exercise 1.5\n", + "------------\n", + "Rerun the sinogram and listmode reconstruction (first cells of the notebook)\n", + "using the 60min acquisition data by adapting the `acq_time` variable.\n", + "Make sure that you restart the kernel before running the cells and to rerun\n", + "the all cells (including scatter and random estimation).\n", + "We wil use the 60min reconstruction in our last notebook." + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/Deep_Learning_listmode_PET/01_SIRF_listmode_recon.py b/notebooks/Deep_Learning_listmode_PET/01_SIRF_listmode_recon.py new file mode 100644 index 00000000..12944934 --- /dev/null +++ b/notebooks/Deep_Learning_listmode_PET/01_SIRF_listmode_recon.py @@ -0,0 +1,409 @@ +# %% [markdown] +# Sinogram and Listmode OSEM using sirf.STIR +# ========================================== +# +# Using the learnings from the previous "theory" notebook, we will now learn how to perform +# PET reconstruction of emission data in listmode and sinogram format using (sinogram and listmode) +# objective function objects of the sirf.STIR library. +# +# We will see that standard OSEM reconstruction can be seen as a sequence of image update "blocks", +# where the update in each block is related to the gradient of the Poisson loglikelihood objective function. +# +# Understanding these OSEM update blocks is the first key step for implementing a pytorch-based feed-forward +# neural network for PET image reconstruction also containing OSEM-like update blocks. +# +# Learning objectives of this notebook +# ------------------------------------ +# 1. Understanding how to setup a Poisson loglikelihood objective functions in sinogram and listmode mode. +# 2. Understanding how to perform sinogram / listmode OSEM reconstruction using sirf.STIR high-level API. +# 3. Implementing a simple DIY OSEM reconstruction using the gradient of the Poisson loglikelihood. + +# %% [markdown] +# Import modules +# -------------- + +# %% +import sirf.STIR +import numpy as np +import subprocess +import matplotlib.pyplot as plt +from pathlib import Path +from sirf.Utilities import examples_data_path + +# %% [markdown] +# Download the 60min mMR NEMA data, if not present +# ------------------------------------------------ + +# %% +if not ( + Path("..") + / ".." 
+ / "data" + / "PET" + / "mMR" + / "NEMA_IQ" + / "20170809_NEMA_60min_UCL.l.hdr" +).exists(): + retval = subprocess.call("../../scripts/download_PET_data.sh", shell=True) + +# %% [markdown] +# Define variables and file names +# ------------------------------- + +# %% +# we have a 1min and 60min acquisition of the NEMA IQ phantom acquired on a Siemens mMR +# choose the acquisition time "1min" or "60min" - start with "1min" +acq_time: str = "1min" + +data_path: Path = Path(examples_data_path("PET")) / "mMR" + +if acq_time == "1min": + list_file: str = str(data_path / "list.l.hdr") +elif acq_time == "60min": + # you need to run the "download_data.sh" script to get the data of the long 60min acq. + list_file: str = str( + Path("..") + / ".." + / "data" + / "PET" + / "mMR" + / "NEMA_IQ" + / "20170809_NEMA_60min_UCL.l.hdr" + ) +else: + raise ValueError("Please choose acq_time to be either '1min' or '60min'") + +attn_file: str = str(data_path / "mu_map.hv") +norm_file: str = str(data_path / "norm.n.hdr") +output_path: Path = Path(f"recons_{acq_time}") +emission_sinogram_output_prefix: str = str(output_path / "emission_sinogram") +scatter_sinogram_output_prefix: str = str(output_path / "scatter_sinogram") +randoms_sinogram_output_prefix: str = str(output_path / "randoms_sinogram") +attenuation_sinogram_output_prefix: str = str(output_path / "acf_sinogram") +recon_output_file: str = str(output_path / "recon") +lm_recon_output_file: str = str(output_path / "lm_recon") +nxny: tuple[int, int] = (127, 127) +num_subsets: int = 21 +num_iter: int = 1 +num_scatter_iter: int = 3 + +# create the output directory +output_path.mkdir(exist_ok=True) + +# engine's messages go to files, except error messages, which go to stdout +_ = sirf.STIR.MessageRedirector("info.txt", "warn.txt") + +# %% [markdown] +# Read the listmode data and create a sinogram template +# ----------------------------------------------------- + +# %% +sirf.STIR.AcquisitionData.set_storage_scheme("memory") +listmode_data = sirf.STIR.ListmodeData(list_file) +acq_data_template = listmode_data.acquisition_data_template() +print(acq_data_template.get_info()) + +# %% [markdown] +# Conversion of listmode to sinogram data (needed for scatter estimation) +# ----------------------------------------------------------------------- + +# %% +# create listmode-to-sinograms converter object +lm2sino = sirf.STIR.ListmodeToSinograms() + +# set input, output and template files +lm2sino.set_input(listmode_data) +lm2sino.set_output_prefix(emission_sinogram_output_prefix) +lm2sino.set_template(acq_data_template) + +# get the start and end time of the listmode data +frame_start = float( + [ + x + for x in listmode_data.get_info().split("\n") + if x.startswith("Time frame start") + ][0] + .split(": ")[1] + .split("-")[0] +) +frame_end = float( + [ + x + for x in listmode_data.get_info().split("\n") + if x.startswith("Time frame start") + ][0] + .split(": ")[1] + .split("-")[1] + .split("(")[0] +) +# set interval +lm2sino.set_time_interval(frame_start, frame_end) +# set up the converter +lm2sino.set_up() + +# convert (need it for the scatter estimate) +lm2sino.process() +acq_data = lm2sino.get_output() + +# %% [markdown] +# Estimation of random coincidences +# --------------------------------- + +# %% +randoms_filepath = Path(f"{randoms_sinogram_output_prefix}.hs") + +if not randoms_filepath.exists(): + print("estimting randoms") + randoms = lm2sino.estimate_randoms() + randoms.write(randoms_sinogram_output_prefix) +else: + print("reading randoms from 
{randoms_filepath}") + randoms = sirf.STIR.AcquisitionData(str(randoms_filepath)) + + +# %% [markdown] +# Setup of the acquisition model +# ------------------------------ + +# %% +# select acquisition model that implements the geometric +# forward projection by a ray tracing matrix multiplication +acq_model = sirf.STIR.AcquisitionModelUsingRayTracingMatrix() +# acq_model.set_num_tangential_LORs(10) +acq_model.set_num_tangential_LORs(1) + +# %% [markdown] +# Calculation of the attenuation sinogram +# --------------------------------------- + +# %% +# read attenuation image and display a single slice +attn_image = sirf.STIR.ImageData(attn_file) + +# create attenuation factors +asm_attn = sirf.STIR.AcquisitionSensitivityModel(attn_image, acq_model) +# converting attenuation image into attenuation factors (one for every bin) +asm_attn.set_up(acq_data) + +acf_filepath = Path(f"{attenuation_sinogram_output_prefix}.hs") + +if not acf_filepath.exists(): + ac_factors = acq_data.get_uniform_copy(value=1) + print("applying attenuation (please wait, may take a while)...") + asm_attn.unnormalise(ac_factors) + ac_factors.write(attenuation_sinogram_output_prefix) +else: + print(f"reading attenuation factors from {acf_filepath}") + ac_factors = sirf.STIR.AcquisitionData(str(acf_filepath)) + +asm_attn = sirf.STIR.AcquisitionSensitivityModel(ac_factors) + +# %% [markdown] +# Creation of the normalization factors (sensitivity sinogram) +# ------------------------------------------------------------ + +# %% +# create acquisition sensitivity model from normalisation data +asm_norm = sirf.STIR.AcquisitionSensitivityModel(norm_file) + +asm = sirf.STIR.AcquisitionSensitivityModel(asm_norm, asm_attn) +asm.set_up(acq_data) +acq_model.set_acquisition_sensitivity(asm) + +# %% [markdown] +# Estimation of scattered coincidences +# ------------------------------------ + +# %% +scatter_filepath: Path = Path(f"{scatter_sinogram_output_prefix}_{num_scatter_iter}.hs") + +if not scatter_filepath.exists(): + print("estimating scatter (this will take a while!)") + scatter_estimator = sirf.STIR.ScatterEstimator() + scatter_estimator.set_input(acq_data) + scatter_estimator.set_attenuation_image(attn_image) + scatter_estimator.set_randoms(randoms) + scatter_estimator.set_asm(asm_norm) + # invert attenuation factors to get the correction factors, + # as this is unfortunately what a ScatterEstimator needs + acf_factors = acq_data.get_uniform_copy() + acf_factors.fill(1 / ac_factors.as_array()) + scatter_estimator.set_attenuation_correction_factors(acf_factors) + scatter_estimator.set_output_prefix(scatter_sinogram_output_prefix) + scatter_estimator.set_num_iterations(num_scatter_iter) + scatter_estimator.set_up() + scatter_estimator.process() + scatter_estimate = scatter_estimator.get_output() +else: + print(f"reading scatter from file {scatter_filepath}") + scatter_estimate = sirf.STIR.AcquisitionData(str(scatter_filepath)) + +# add scatter plus randoms estimated to the background term of the acquisition model +acq_model.set_background_term(randoms + scatter_estimate) + +# %% [markdown] +# Setup of the Poisson loglikelihood objective function in sinogram mode +# ---------------------------------------------------------------------- + +# %% +initial_image = acq_data.create_uniform_image(value=1, xy=nxny) + +# create objective function +obj_fun = sirf.STIR.make_Poisson_loglikelihood(acq_data) +obj_fun.set_acquisition_model(acq_model) +obj_fun.set_num_subsets(num_subsets) +obj_fun.set_up(initial_image) + +# %% [markdown] +# Image 
reconstruction (optimization of the Poisson logL objective function) using sinogram OSEM +# ---------------------------------------------------------------------------------------------- + +# %% +if not Path(f"{recon_output_file}.hv").exists(): + reconstructor = sirf.STIR.OSMAPOSLReconstructor() + reconstructor.set_objective_function(obj_fun) + reconstructor.set_num_subsets(num_subsets) + reconstructor.set_num_subiterations(num_iter * num_subsets) + reconstructor.set_input(acq_data) + reconstructor.set_up(initial_image) + reconstructor.set_current_estimate(initial_image) + reconstructor.process() + ref_recon = reconstructor.get_output() + ref_recon.write(recon_output_file) +else: + ref_recon = sirf.STIR.ImageData(f"{recon_output_file}.hv") + +vmax = np.percentile(ref_recon.as_array(), 99.999) + +fig, ax = plt.subplots(1, 1, figsize=(4, 4), tight_layout=True) +ax.imshow(ref_recon.as_array()[71, :, :], cmap="Greys", vmin=0, vmax=vmax) +fig.show() + +# %% [markdown] +# Exercise 1.1 +# ------------ +# +# Perform the gradient ascent step +# $$ x^+ = x + \alpha \nabla_x \log L(y|x) $$ +# on the initial image $x$ using a constant scalar step size $\alpha=0.001$ by calling +# the `gradient()` method of the objective function. +# Use the first (0th) subset of the data for the gradient calculation. + +# %% +# +# ============== +# YOUR CODE HERE +# ============== +# + +# %% +# to view the solution, execute this cell +# %load snippets/solution_1_1.py + +# %% [markdown] +# Exercise 1.2 +# ------------ +# +# Given the fact that the OSEM update can be written as +# $$ x^+ = x + t \nabla_x \log L(y|x) $$ +# with the non-scalar step size +# $$ t = \frac{x}{s} $$ +# where $s$ is the (subset) "sensitivity image", perform an OSEM update on the initial image +# by using the `get_subset_sensitivity()` method of the objective function and the first subset. +# Print the maximum value of the updated image. What do you observe? + +# %% +# +# ============== +# YOUR CODE HERE +# ============== +# + +# %% +# to view the solution, execute this cell +# %load snippets/solution_1_2.py + +# %% [markdown] +# Exercise 1.3 +# ------------ +# +# Implement your own OSEM reconstruction by looping over the subsets and performing the +# OSEM update for each subset. 
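For orientation before you start, this standalone numpy sketch (toy forward model with invented sizes, not the SIRF solution) shows the generic loop structure of OSEM: each sub-iteration applies an EM update that uses only one subset of the data bins and the corresponding subset sensitivity image:

```python
# toy OSEM: loop over interleaved data subsets, one EM update per subset
import numpy as np

rng = np.random.default_rng(1)
A = rng.uniform(0.1, 1.0, size=(12, 4))  # toy forward model: 12 data bins, 4 voxels
s_bg = 0.05 * np.ones(12)                # additive background
y = rng.poisson(A @ np.array([1.0, 3.0, 2.0, 0.5]) + s_bg)

num_subsets = 3
# interleaved subsets of the data bins, mimicking view-interleaved sinogram subsets
subsets = [np.arange(i, 12, num_subsets) for i in range(num_subsets)]

x = np.ones(4)
for _ in range(2):        # 2 full iterations over all subsets
    for inds in subsets:  # one EM update per subset
        As, ys, bs = A[inds], y[inds], s_bg[inds]
        subset_sens = As.T @ np.ones(inds.size)  # subset sensitivity image
        x = (x / subset_sens) * (As.T @ (ys / (As @ x + bs)))
print("toy OSEM estimate:", x)
```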
+ +# %% +# initialize the reconstruction with ones where the sensitivity image is greater than 0 +# all other values are set to zero and are not updated during reconstruction +recon = initial_image.copy() +recon.fill(obj_fun.get_subset_sensitivity(0).as_array() > 0) +# +# ============== +# YOUR CODE HERE +# ============== +# + +# %% +# to view the solution, execute the this cell +# %load snippets/solution_1_3.py + + +# %% [markdown] +# Setup of the Poisson loglikelihood objective function logL(y|x) in listmode +# --------------------------------------------------------------------------- + +# %% +# define the listmode objective function +lm_obj_fun = ( + sirf.STIR.PoissonLogLikelihoodWithLinearModelForMeanAndListModeDataWithProjMatrixByBin() +) +lm_obj_fun.set_acquisition_model(acq_model) +lm_obj_fun.set_acquisition_data(listmode_data) +lm_obj_fun.set_num_subsets(num_subsets) +lm_obj_fun.set_cache_max_size(1000000000) +lm_obj_fun.set_cache_path(str(output_path)) + +# %% [markdown] +# Reconstruction (optimization of the Poisson logL objective function) using listmode OSEM +# ---------------------------------------------------------------------------------------- + +# %% +if not Path(f"{lm_recon_output_file}.hv").exists(): + lm_reconstructor = sirf.STIR.OSMAPOSLReconstructor() + lm_reconstructor.set_objective_function(lm_obj_fun) + lm_reconstructor.set_num_subsets(num_subsets) + lm_reconstructor.set_num_subiterations(num_iter * num_subsets) + lm_reconstructor.set_up(initial_image) + lm_reconstructor.set_current_estimate(initial_image) + lm_reconstructor.process() + lm_ref_recon = lm_reconstructor.get_output() + lm_ref_recon.write(lm_recon_output_file) +else: + lm_ref_recon = sirf.STIR.ImageData(f"{lm_recon_output_file}.hv") + +fig3, ax3 = plt.subplots(1, 1, figsize=(4, 4), tight_layout=True) +ax3.imshow(lm_ref_recon.as_array()[71, :, :], cmap="Greys", vmin=0, vmax=vmax) +fig3.show() + +# %% [markdown] +# Exercise 1.4 +# ------------ +# Repeat exercise 1.3 (OSEM reconstruction) using the listmode objective function to +# learn how to do a listmode OSEM update step. + +# %% +# +# ============== +# YOUR CODE HERE +# ============== + +# %% +# to view the solution, execute the cell below +# %load snippets/solution_1_4.py + +# %% [markdown] +# Exercise 1.5 +# ------------ +# Rerun the sinogram and listmode reconstruction (first cells of the notebook) +# using the 60min acquisition data by adapting the `acq_time` variable. +# Make sure that you restart the kernel before running the cells and to rerun +# the all cells (including scatter and random estimation). +# We wil use the 60min reconstruction in our last notebook. diff --git a/notebooks/Deep_Learning_listmode_PET/02_SIRF_vs_torch_arrays.ipynb b/notebooks/Deep_Learning_listmode_PET/02_SIRF_vs_torch_arrays.ipynb new file mode 100644 index 00000000..5c1416d6 --- /dev/null +++ b/notebooks/Deep_Learning_listmode_PET/02_SIRF_vs_torch_arrays.ipynb @@ -0,0 +1,257 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9e1f726e", + "metadata": {}, + "source": [ + "SIRF.STIR ImageData objects vs numpy arrays vs torch tensors\n", + "============================================================" + ] + }, + { + "cell_type": "markdown", + "id": "5bca5247", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "Learning objectives of this notebook\n", + "------------------------------------\n", + "\n", + "1. Understanding the differences between SIRF ImageData, numpy arrays and torch tensors.\n", + "2. 
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "530a8fc0",
+   "metadata": {},
+   "source": [
+    "SIRF.STIR ImageData objects vs numpy arrays\n",
+    "-------------------------------------------"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9dc7b071",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# create a SIRF image template\n",
+    "\n",
+    "import sirf.STIR\n",
+    "from sirf.Utilities import examples_data_path\n",
+    "\n",
+    "# read an example PET acquisition data set that we can use\n",
+    "# to set up a compatible image data set\n",
+    "acq_data: sirf.STIR.AcquisitionData = sirf.STIR.AcquisitionData(\n",
+    "    examples_data_path(\"PET\") + \"/brain/template_sinogram.hs\"\n",
+    ")\n",
+    "\n",
+    "# create a SIRF image compatible with the acquisition data\n",
+    "# uses default voxel sizes and dimensions\n",
+    "sirf_image_1: sirf.STIR.ImageData = acq_data.create_uniform_image(1.0)\n",
+    "sirf_image_2: sirf.STIR.ImageData = acq_data.create_uniform_image(2.0)\n",
+    "\n",
+    "image_shape: tuple[int, int, int] = sirf_image_1.shape\n",
+    "\n",
+    "print()\n",
+    "print(f\"sirf_image_1 shape .: {sirf_image_1.shape}\")\n",
+    "print(f\"sirf_image_1 spacing .: {sirf_image_1.spacing}\")\n",
+    "print(f\"sirf_image_1 max .: {sirf_image_1.max()}\")\n",
+    "print()\n",
+    "print(f\"sirf_image_2 shape .: {sirf_image_2.shape}\")\n",
+    "print(f\"sirf_image_2 spacing .: {sirf_image_2.spacing}\")\n",
+    "print(f\"sirf_image_2 max .: {sirf_image_2.max()}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8e81cef5",
+   "metadata": {
+    "lines_to_next_cell": 2
+   },
+   "outputs": [],
+   "source": [
+    "# you retrieve the data behind a SIRF.STIR image as a numpy array using the as_array() method\n",
+    "import numpy as np\n",
+    "\n",
+    "numpy_image_1: np.ndarray = sirf_image_1.as_array()\n",
+    "numpy_image_2: np.ndarray = sirf_image_2.as_array()\n",
+    "\n",
+    "numpy_image_2_modified = numpy_image_2.copy()\n",
+    "numpy_image_2_modified[0, 0, 0] = 5.0\n",
+    "numpy_image_2_modified[-1, -1, -1] = -4.0\n",
+    "\n",
+    "print()\n",
+    "print(f\"numpy_image_1 shape .: {numpy_image_1.shape}\")\n",
+    "print(f\"numpy_image_1 max .: {numpy_image_1.max()}\")\n",
+    "print()\n",
+    "print(f\"numpy_image_2 shape .: {numpy_image_2.shape}\")\n",
+    "print(f\"numpy_image_2 max .: {numpy_image_2.max()}\")\n",
+    "print()\n",
+    "print(f\"numpy_image_2_modified shape .: {numpy_image_2_modified.shape}\")\n",
+    "print(f\"numpy_image_2_modified max .: {numpy_image_2_modified.max()}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "099104c6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# you can convert a numpy array into a SIRF.STIR image using the fill() method\n",
+    "\n",
+    "# create an image compatible with sirf_image_2 and fill it with the modified numpy array\n",
+    "sirf_image_2_modified = sirf_image_2.get_uniform_copy()\n",
+    "sirf_image_2_modified.fill(numpy_image_2_modified)\n",
+    "\n",
+    "print()\n",
+    "print(f\"sirf_image_2_modified shape .: {sirf_image_2_modified.shape}\")\n",
+    "print(f\"sirf_image_2_modified spacing .: {sirf_image_2_modified.spacing}\")\n",
+    "print(f\"sirf_image_2_modified max .: {sirf_image_2_modified.max()}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "fcd8d734",
+   "metadata": {
+    "lines_to_next_cell": 2
+   },
+   "source": [
+    "Exercise 2.1\n",
+    "------------\n",
+    "\n",
+    "Create a SIRF.STIR image that is compatible with the acquisition data\n",
+    "where every image \"plane\" contains the \"plane number squared\"."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4873bdd4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# run this cell to load the solution\n",
+    "%load snippets/solution_2_1.py"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "00d4c5f4",
+   "metadata": {},
+   "source": [
+    "torch tensors vs numpy arrays\n",
+    "-----------------------------"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0be9bf93",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "\n",
+    "# torch tensors can live on different devices\n",
+    "if torch.cuda.is_available():\n",
+    "    # if cuda is available, we want our torch tensor on the first CUDA device\n",
+    "    dev = torch.device(\"cuda:0\")\n",
+    "else:\n",
+    "    # otherwise we select the CPU as device\n",
+    "    dev = torch.device(\"cpu\")\n",
+    "\n",
+    "torch_image_1: torch.Tensor = torch.ones(image_shape, dtype=torch.float32, device=dev)\n",
+    "\n",
+    "print()\n",
+    "print(f\"torch_image_1 shape .: {torch_image_1.shape}\")\n",
+    "print(f\"torch_image_1 max .: {torch_image_1.max()}\")\n",
+    "print(f\"torch_image_1 dtype .: {torch_image_1.dtype}\")\n",
+    "print(f\"torch_image_1 device .: {torch_image_1.device}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "dfc6d1e4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# you can convert torch (GPU or CPU) tensors to numpy arrays using the numpy() method\n",
+    "numpy_image_from_torch_1: np.ndarray = torch_image_1.cpu().numpy()\n",
+    "# see here: https://pytorch.org/docs/stable/generated/torch.Tensor.numpy.html\n",
+    "\n",
+    "# Attention: If the torch tensor lives on the CPU, the underlying array is not copied\n",
+    "# but shared between the numpy and torch objects!\n",
+    "print()\n",
+    "print(f\"numpy data pointer {numpy_image_from_torch_1.ctypes.data}\")\n",
+    "print(f\"torch data pointer {torch_image_1.data_ptr()}\")\n",
+    "\n",
+    "if torch_image_1.data_ptr() == numpy_image_from_torch_1.ctypes.data:\n",
+    "    print(\"numpy array and torch tensor share the same data\")\n",
+    "else:\n",
+    "    print(\"numpy array and torch tensor don't share the same data\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "36bbc8f4",
+   "metadata": {
+    "lines_to_next_cell": 2
+   },
+   "outputs": [],
+   "source": [
+    "# You can create torch tensors from numpy arrays using torch.from_numpy()\n",
+    "torch_image_from_numpy_1: torch.Tensor = torch.from_numpy(numpy_image_2)\n",
+    "print()\n",
+    "print(f\"torch_image_from_numpy_1 shape .: {torch_image_from_numpy_1.shape}\")\n",
+    "print(f\"torch_image_from_numpy_1 max .: {torch_image_from_numpy_1.max()}\")\n",
+    "\n",
+    "# torch.from_numpy() will create a Tensor living on the CPU\n",
+    "print()\n",
+    "print(f\"device of torch tensor from numpy {torch_image_from_numpy_1.device}\")\n",
+    "\n",
+    "# we can send the tensor to our preferred device using the .to() method\n",
+    "# note that .to() returns a new tensor, so we have to re-assign the result\n",
+    "print(f\"sending tensor to device {dev.type}\")\n",
+    "torch_image_from_numpy_1 = torch_image_from_numpy_1.to(dev)\n",
+    "print(f\"device of torch tensor from numpy {torch_image_from_numpy_1.device}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "46894319",
+   "metadata": {},
+   "source": [
+    "Exercise 2.2\n",
+    "------------\n",
+    "\n",
+    "Now that we know how to convert between SIRF.STIR images and numpy arrays,\n",
+    "and between numpy arrays and torch tensors, do the following:\n",
+    "1. convert a torch tensor full of \"3s\" into a SIRF.STIR ImageData object compatible\n",
+    "   with the acquisition data\n",
+    "2. convert the SIRF.STIR ImageData object \"sirf_image_1\" into a torch tensor on the\n",
+    "   device \"dev\"\n",
+    "3. predict whether the different image objects should share data and test your\n",
+    "   hypothesis\n",
+    "4. try to convert the torch tensor `torch.ones(image_shape, dtype=torch.float32, device=dev, requires_grad=True)`\n",
+    "   into a numpy array. What do you observe?"
+   ]
+  }
+ ],
+ "metadata": {
+  "jupytext": {
+   "cell_metadata_filter": "-all",
+   "main_language": "python",
+   "notebook_metadata_filter": "-all"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/notebooks/Deep_Learning_listmode_PET/02_SIRF_vs_torch_arrays.py b/notebooks/Deep_Learning_listmode_PET/02_SIRF_vs_torch_arrays.py
new file mode 100644
index 00000000..bf74c435
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/02_SIRF_vs_torch_arrays.py
@@ -0,0 +1,160 @@
+# %% [markdown]
+# SIRF.STIR ImageData objects vs numpy arrays vs torch tensors
+# ============================================================
+
+# %% [markdown]
+# Learning objectives of this notebook
+# ------------------------------------
+#
+# 1. Understanding the differences between SIRF ImageData, numpy arrays and torch tensors.
+# 2. Learn how to convert between these different data types.
+
+
+# %% [markdown]
+# SIRF.STIR ImageData objects vs numpy arrays
+# -------------------------------------------
+
+# %%
+# create a SIRF image template
+
+import sirf.STIR
+from sirf.Utilities import examples_data_path
+
+# read an example PET acquisition data set that we can use
+# to set up a compatible image data set
+acq_data: sirf.STIR.AcquisitionData = sirf.STIR.AcquisitionData(
+    examples_data_path("PET") + "/brain/template_sinogram.hs"
+)
+
+# create a SIRF image compatible with the acquisition data
+# uses default voxel sizes and dimensions
+sirf_image_1: sirf.STIR.ImageData = acq_data.create_uniform_image(1.0)
+sirf_image_2: sirf.STIR.ImageData = acq_data.create_uniform_image(2.0)
+
+image_shape: tuple[int, int, int] = sirf_image_1.shape
+
+print()
+print(f"sirf_image_1 shape .: {sirf_image_1.shape}")
+print(f"sirf_image_1 spacing .: {sirf_image_1.spacing}")
+print(f"sirf_image_1 max .: {sirf_image_1.max()}")
+print()
+print(f"sirf_image_2 shape .: {sirf_image_2.shape}")
+print(f"sirf_image_2 spacing .: {sirf_image_2.spacing}")
+print(f"sirf_image_2 max .: {sirf_image_2.max()}")
+
+# %%
+# you retrieve the data behind a SIRF.STIR image as a numpy array using the as_array() method
+import numpy as np
+
+numpy_image_1: np.ndarray = sirf_image_1.as_array()
+numpy_image_2: np.ndarray = sirf_image_2.as_array()
+
+numpy_image_2_modified = numpy_image_2.copy()
+numpy_image_2_modified[0, 0, 0] = 5.0
+numpy_image_2_modified[-1, -1, -1] = -4.0
+
+print()
+print(f"numpy_image_1 shape .: {numpy_image_1.shape}")
+print(f"numpy_image_1 max .: {numpy_image_1.max()}")
+print()
+print(f"numpy_image_2 shape .: {numpy_image_2.shape}")
+print(f"numpy_image_2 max .: {numpy_image_2.max()}")
+print()
+print(f"numpy_image_2_modified shape .: {numpy_image_2_modified.shape}")
+print(f"numpy_image_2_modified max .: {numpy_image_2_modified.max()}")
+
+
+# %%
+# you can convert a numpy array into a SIRF.STIR image using the fill() method
+
+# create an image compatible with sirf_image_2 and fill it with the modified numpy array
+sirf_image_2_modified = sirf_image_2.get_uniform_copy()
+sirf_image_2_modified.fill(numpy_image_2_modified)
+
+print()
+print(f"sirf_image_2_modified shape .: {sirf_image_2_modified.shape}")
+print(f"sirf_image_2_modified spacing .: {sirf_image_2_modified.spacing}")
+print(f"sirf_image_2_modified max .: {sirf_image_2_modified.max()}")
+
+# %% [markdown]
+# Exercise 2.1
+# ------------
+#
+# Create a SIRF.STIR image that is compatible with the acquisition data
+# where every image "plane" contains the "plane number squared".
+
+
+# %%
+# uncomment the next line and run this cell
+# %load snippets/solution_2_1.py
+
+# %% [markdown]
+# torch tensors vs numpy arrays
+# -----------------------------
+
+# %%
+import torch
+
+# torch tensors can live on different devices
+if torch.cuda.is_available():
+    # if cuda is available, we want our torch tensor on the first CUDA device
+    dev = torch.device("cuda:0")
+else:
+    # otherwise we select the CPU as device
+    dev = torch.device("cpu")
+
+torch_image_1: torch.Tensor = torch.ones(image_shape, dtype=torch.float32, device=dev)
+
+print()
+print(f"torch_image_1 shape .: {torch_image_1.shape}")
+print(f"torch_image_1 max .: {torch_image_1.max()}")
+print(f"torch_image_1 dtype .: {torch_image_1.dtype}")
+print(f"torch_image_1 device .: {torch_image_1.device}")
+
+# %%
+# you can convert torch (GPU or CPU) tensors to numpy arrays using the numpy() method
+numpy_image_from_torch_1: np.ndarray = torch_image_1.cpu().numpy()
+# see here: https://pytorch.org/docs/stable/generated/torch.Tensor.numpy.html
+
+# Attention: If the torch tensor lives on the CPU, the underlying array is not copied
+# but shared between the numpy and torch objects!
+print()
+print(f"numpy data pointer {numpy_image_from_torch_1.ctypes.data}")
+print(f"torch data pointer {torch_image_1.data_ptr()}")
+
+if torch_image_1.data_ptr() == numpy_image_from_torch_1.ctypes.data:
+    print("numpy array and torch tensor share the same data")
+else:
+    print("numpy array and torch tensor don't share the same data")
+
+# %%
+# You can create torch tensors from numpy arrays using torch.from_numpy()
+torch_image_from_numpy_1: torch.Tensor = torch.from_numpy(numpy_image_2)
+print()
+print(f"torch_image_from_numpy_1 shape .: {torch_image_from_numpy_1.shape}")
+print(f"torch_image_from_numpy_1 max .: {torch_image_from_numpy_1.max()}")
+
+# torch.from_numpy() will create a Tensor living on the CPU
+print()
+print(f"device of torch tensor from numpy {torch_image_from_numpy_1.device}")
+
+# we can send the tensor to our preferred device using the .to() method
+# note that .to() returns a new tensor, so we have to re-assign the result
+print(f"sending tensor to device {dev.type}")
+torch_image_from_numpy_1 = torch_image_from_numpy_1.to(dev)
+print(f"device of torch tensor from numpy {torch_image_from_numpy_1.device}")
+
+
+# %% [markdown]
+# Exercise 2.2
+# ------------
+#
+# Now that we know how to convert between SIRF.STIR images and numpy arrays,
+# and between numpy arrays and torch tensors, do the following:
+# 1. convert a torch tensor full of "3s" into a SIRF.STIR ImageData object compatible
+#    with the acquisition data
+# 2. convert the SIRF.STIR ImageData object "sirf_image_1" into a torch tensor on the
+#    device "dev"
+# 3. predict whether the different image objects should share data and test your
+#    hypothesis
+# 4. try to convert the torch tensor `torch.ones(image_shape, dtype=torch.float32, device=dev, requires_grad=True)`
+#    into a numpy array. What do you observe?
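+
+# %% [markdown]
+# (Aside: a minimal sketch illustrating the memory sharing discussed above. Since
+# `torch.from_numpy()` shares memory with the source array on the CPU, in-place
+# modifications are visible on both sides. The variable names are illustrative only.)
+
+# %%
+shared_np = np.zeros(3, dtype=np.float32)
+shared_t = torch.from_numpy(shared_np)  # shares memory with shared_np
+shared_np[0] = 42.0  # modify the numpy array in place
+print(shared_t)  # the torch tensor sees the change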
diff --git a/notebooks/Deep_Learning_listmode_PET/03_custom_torch_layers.ipynb b/notebooks/Deep_Learning_listmode_PET/03_custom_torch_layers.ipynb
new file mode 100644
index 00000000..628d206f
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/03_custom_torch_layers.ipynb
@@ -0,0 +1,418 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "863d394a",
+   "metadata": {},
+   "source": [
+    "Creating custom layers in pytorch\n",
+    "=================================\n",
+    "\n",
+    "In this notebook, we will learn how to create custom layers in pytorch that use functions outside the pytorch framework.\n",
+    "For demonstration purposes, we will create a simple layer that multiplies a 1D torch input vector with a square matrix,\n",
+    "where the matrix multiplication is done using numpy functions.\n",
+    "\n",
+    "Learning objectives of this notebook\n",
+    "------------------------------------\n",
+    "\n",
+    "1. Learn how to create custom layers in pytorch that are compatible with the autograd framework.\n",
+    "2. Understand the importance of implementing the backward pass of the custom layer correctly.\n",
+    "3. Learn how to test the gradient backpropagation through the custom layer using the `torch.autograd.gradcheck` function."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "dea6bbea",
+   "metadata": {
+    "lines_to_next_cell": 1
+   },
+   "outputs": [],
+   "source": [
+    "# import modules\n",
+    "import torch\n",
+    "import numpy as np\n",
+    "\n",
+    "# seed all torch random generators\n",
+    "torch.manual_seed(0)\n",
+    "\n",
+    "# choose the torch device\n",
+    "if torch.cuda.is_available():\n",
+    "    dev = \"cuda:0\"\n",
+    "else:\n",
+    "    dev = \"cpu\"\n",
+    "\n",
+    "# length of the input vector\n",
+    "n = 7\n",
+    "\n",
+    "# define our square matrix\n",
+    "A: np.ndarray = np.arange(n ** 2).reshape(n, n).astype(np.float64) / (n ** 2)\n",
+    "# define the 1D pytorch tensor: note that the shape is (1,1,n) including the batch and channel dimensions\n",
+    "x_t = torch.tensor(np.arange(n).reshape(1, 1, n).astype(np.float64), device=dev) / n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1b9094b8",
+   "metadata": {},
+   "source": [
+    "Approach 1: The naive approach\n",
+    "------------------------------\n",
+    "\n",
+    "We will first try a naive approach where we create a custom layer by subclassing torch.nn.Module\n",
+    "and implementing the forward pass by conversion between numpy and torch tensors."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fb86f287",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class SquareMatrixMultiplicationLayer(torch.nn.Module):\n",
+    "    def __init__(self, mat: np.ndarray) -> None:\n",
+    "        super().__init__()\n",
+    "        self._mat: np.ndarray = mat\n",
+    "\n",
+    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
+    "        # convert the input tensor to numpy\n",
+    "        x_np = x.detach().cpu().numpy()\n",
+    "        # numpy matrix multiplication\n",
+    "        y_np = self._mat @ x_np[0, 0, ...]\n",
+    "        # convert back to torch tensor\n",
+    "        y = torch.tensor(y_np, device=x.device).unsqueeze(0).unsqueeze(0)\n",
+    "\n",
+    "        return y"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ceb070a2",
+   "metadata": {},
+   "source": [
+    "We setup a simple feedforward network that interlaces a minimal CNN (applied 3 times)\n",
+    "with our square matrix multiplication layer (also applied 3 times)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "803723d9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class Net1(torch.nn.Module):\n",
+    "    def __init__(self, mat, cnn) -> None:\n",
+    "        super().__init__()\n",
+    "        self._matrix_layer = SquareMatrixMultiplicationLayer(mat)\n",
+    "        self._cnn = cnn\n",
+    "\n",
+    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
+    "        x1 = self._cnn(x)\n",
+    "        x2 = self._matrix_layer(x1)\n",
+    "        x3 = self._cnn(x2)\n",
+    "        x4 = self._matrix_layer(x3)\n",
+    "        x5 = self._cnn(x4)\n",
+    "        x6 = self._matrix_layer(x5)\n",
+    "\n",
+    "        return x6"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "900bf860",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# setup a simple CNN consisting of 2 convolutional layers and 1 ReLU activation\n",
+    "cnn1 = torch.nn.Sequential(\n",
+    "    torch.nn.Conv1d(1, 3, (3,), padding=\"same\", bias=False, dtype=torch.float64),\n",
+    "    torch.nn.ReLU(),\n",
+    "    torch.nn.Conv1d(3, 1, (3,), padding=\"same\", bias=False, dtype=torch.float64),\n",
+    ").to(dev)\n",
+    "\n",
+    "# setup the network\n",
+    "net1 = Net1(A, cnn1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5473fe3c",
+   "metadata": {
+    "lines_to_next_cell": 2
+   },
+   "outputs": [],
+   "source": [
+    "# forward pass of our input vector through the network\n",
+    "pred1 = net1(x_t)\n",
+    "print(f\"pred1: {pred1}\\n\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e56d51ca",
+   "metadata": {},
+   "source": [
+    "We see that the forward pass works as expected. Now we will setup a dummy loss and try to backpropagate the gradients\n",
+    "using the naive approach for our custom matrix multiplication layer.\n",
+    "Backpropagation of the gradients is the central step in training neural networks. It involves calculating the gradients of\n",
+    "the loss function with respect to the weights of the network."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2b96d045",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# setup a dummy target (label / high quality reference image) tensor\n",
+    "target = 2 * x_t\n",
+    "# define an MSE loss\n",
+    "loss_fct = torch.nn.MSELoss()\n",
+    "# calculate the loss between the prediction and the target\n",
+    "loss1 = loss_fct(pred1, target)\n",
+    "print(f\"loss1: {loss1.item()}\\n\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f6458bf9",
+   "metadata": {},
+   "source": [
+    "Calculation of the loss still runs fine. Now let's try to backpropagate the gradients."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "47c51d54",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "try:\n",
+    "    loss1.backward()\n",
+    "except RuntimeError:\n",
+    "    print(\"Error in gradient backpropagation using naive approach\\n\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "7d290dd2",
+   "metadata": {},
+   "source": [
+    "Exercise 3.1\n",
+    "------------\n",
+    "We see that the backpropagation of the gradients fails with the naive approach.\n",
+    "Why is that?"
+ ] + }, + { + "cell_type": "markdown", + "id": "4f8810a3", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "Approach 2: Subclassing torch.autograd.Function\n", + "-----------------------------------------------\n", + "\n", + "The correct way to create custom layers in pytorch is to subclass torch.autograd.Function\n", + "which involves implementing the forward and backward pass of the layer.\n", + "In the backward pass we have to implement the Jacobian transpose vector product of the layer.\n", + "For details, see [here](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#optional-reading-vector-calculus-using-autograd)\n", + "and [here](https://pytorch.org/docs/stable/notes/extending.func.html)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e3e3cd2d", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "# define the custom layer by subclassing torch.autograd.Function and implementing the forward and backward pass\n", + "\n", + "\n", + "class NPSquareMatrixMultiplicationLayer(torch.autograd.Function):\n", + " @staticmethod\n", + " def forward(ctx, x: torch.Tensor, mat: np.ndarray) -> torch.Tensor:\n", + "\n", + " # we use the context object ctx to store the matrix and other variables that we need in the backward pass\n", + " ctx.mat = mat\n", + " ctx.device = x.device\n", + " ctx.shape = x.shape\n", + " ctx.dtype = x.dtype\n", + "\n", + " # convert to numpy\n", + " x_np = x.cpu().numpy()\n", + " # numpy matrix multiplication\n", + " y_np = mat @ x_np[0, 0, ...]\n", + " # convert back to torch tensor\n", + " y = torch.tensor(y_np, device=ctx.device).unsqueeze(0).unsqueeze(0)\n", + "\n", + " return y\n", + "\n", + " @staticmethod\n", + " def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, None]:\n", + " if grad_output is None:\n", + " return None, None\n", + " else:\n", + " # convert to numpy\n", + " grad_output_np = grad_output.cpu().numpy()\n", + " # calculate the Jacobian transpose vector product in numpy and convert back to torch tensor\n", + " back = (\n", + " torch.tensor(\n", + " ctx.mat.T @ grad_output_np[0, 0, ...],\n", + " device=ctx.device,\n", + " dtype=ctx.dtype,\n", + " )\n", + " .unsqueeze(0)\n", + " .unsqueeze(0)\n", + " )\n", + "\n", + " return back, None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b9cdd45", + "metadata": {}, + "outputs": [], + "source": [ + "# define a new network incl. 
the custom matrix multiplication layer using the \"correct\" approach\n",
+    "# To use our custom layer in the network, we have to use the apply method of the custom layer class.\n",
+    "\n",
+    "\n",
+    "class Net2(torch.nn.Module):\n",
+    "    def __init__(self, mat, cnn) -> None:\n",
+    "        super().__init__()\n",
+    "        self._matrix_layer = NPSquareMatrixMultiplicationLayer.apply\n",
+    "        self._mat = mat\n",
+    "        self._cnn = cnn\n",
+    "\n",
+    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
+    "        x1 = self._cnn(x)\n",
+    "        x2 = self._matrix_layer(x1, self._mat)\n",
+    "        x3 = self._cnn(x2)\n",
+    "        x4 = self._matrix_layer(x3, self._mat)\n",
+    "        x5 = self._cnn(x4)\n",
+    "        x6 = self._matrix_layer(x5, self._mat)\n",
+    "\n",
+    "        return x6"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "cef93f9d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# setup the same CNN as above\n",
+    "cnn2 = torch.nn.Sequential(\n",
+    "    torch.nn.Conv1d(1, 3, (3,), padding=\"same\", bias=False, dtype=torch.float64),\n",
+    "    torch.nn.ReLU(),\n",
+    "    torch.nn.Conv1d(3, 1, (3,), padding=\"same\", bias=False, dtype=torch.float64),\n",
+    ").to(dev)\n",
+    "cnn2.load_state_dict(cnn1.state_dict())\n",
+    "\n",
+    "# setup the network - only difference is the custom layer\n",
+    "net2 = Net2(A, cnn2)\n",
+    "\n",
+    "# predict again\n",
+    "pred2 = net2(x_t)\n",
+    "print(f\"pred2: {pred2}\\n\")\n",
+    "\n",
+    "loss2 = loss_fct(pred2, target)\n",
+    "print(f\"loss2: {loss2.item()}\\n\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "28d1c4d5",
+   "metadata": {},
+   "source": [
+    "Note that the prediction still works and gives the same result as before. Also the loss calculation yields the same result as before."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f90f98a4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "loss2.backward()\n",
+    "\n",
+    "# print the backpropagated gradients of all parameters of the CNN layers of our network\n",
+    "print(\"backpropagated gradients using correct approach\")\n",
+    "print([p.grad for p in net2._cnn.parameters()])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1b7ff1c4",
+   "metadata": {},
+   "source": [
+    "In contrast to the naive approach, the backpropagation of the gradients works fine now, meaning that this network is ready for training."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e5b383db",
+   "metadata": {},
+   "source": [
+    "Testing gradient backpropagation through the layer\n",
+    "--------------------------------------------------\n",
+    "\n",
+    "When defining new custom layers, it is crucial to test whether the backward pass is implemented correctly.\n",
+    "Otherwise the gradient backpropagation through the layer will be incorrect, and optimizing the model parameters will not work.\n",
+    "To test the gradient backpropagation, we can use the `torch.autograd.gradcheck` function."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c3def037",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# setup a test input tensor - requires grad must be True!\n",
+    "t_t = torch.rand(x_t.shape, device=dev, dtype=torch.float64, requires_grad=True)\n",
+    "\n",
+    "# test the gradient backpropagation through the custom numpy matrix multiplication layer\n",
+    "matrix_layer = NPSquareMatrixMultiplicationLayer.apply\n",
+    "gradcheck = torch.autograd.gradcheck(matrix_layer, (t_t, A), fast_mode=True)\n",
+    "\n",
+    "print(f\"gradient check of NPSquareMatrixMultiplicationLayer: {gradcheck}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c33cb3ba",
+   "metadata": {},
+   "source": [
+    "Exercise 3.2\n",
+    "------------\n",
+    "Temporarily change the backward pass of the custom layer such that it is not correct anymore\n",
+    "(e.g. by multiplying the output with 0.95) and rerun the gradient check. What do you observe?"
+   ]
+  }
+ ],
+ "metadata": {
+  "jupytext": {
+   "cell_metadata_filter": "-all",
+   "main_language": "python",
+   "notebook_metadata_filter": "-all"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/notebooks/Deep_Learning_listmode_PET/03_custom_torch_layers.py b/notebooks/Deep_Learning_listmode_PET/03_custom_torch_layers.py
new file mode 100644
index 00000000..1b2b7fdd
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/03_custom_torch_layers.py
@@ -0,0 +1,263 @@
+# %% [markdown]
+# Creating custom layers in pytorch
+# =================================
+#
+# In this notebook, we will learn how to create custom layers in pytorch that use functions outside the pytorch framework.
+# For demonstration purposes, we will create a simple layer that multiplies a 1D torch input vector with a square matrix,
+# where the matrix multiplication is done using numpy functions.
+#
+# Learning objectives of this notebook
+# ------------------------------------
+#
+# 1. Learn how to create custom layers in pytorch that are compatible with the autograd framework.
+# 2. Understand the importance of implementing the backward pass of the custom layer correctly.
+# 3. Learn how to test the gradient backpropagation through the custom layer using the `torch.autograd.gradcheck` function.
+
+# %%
+# import modules
+import torch
+import numpy as np
+
+# seed all torch random generators
+torch.manual_seed(0)
+
+# choose the torch device
+if torch.cuda.is_available():
+    dev = "cuda:0"
+else:
+    dev = "cpu"
+
+# length of the input vector
+n = 7
+
+# define our square matrix
+A: np.ndarray = np.arange(n ** 2).reshape(n, n).astype(np.float64) / (n ** 2)
+# define the 1D pytorch tensor: note that the shape is (1,1,n) including the batch and channel dimensions
+x_t = torch.tensor(np.arange(n).reshape(1, 1, n).astype(np.float64), device=dev) / n
+
+# %% [markdown]
+# Approach 1: The naive approach
+# ------------------------------
+#
+# We will first try a naive approach where we create a custom layer by subclassing torch.nn.Module
+# and implementing the forward pass by conversion between numpy and torch tensors.
+
+# %%
+class SquareMatrixMultiplicationLayer(torch.nn.Module):
+    def __init__(self, mat: np.ndarray) -> None:
+        super().__init__()
+        self._mat: np.ndarray = mat
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # convert the input tensor to numpy
+        x_np = x.detach().cpu().numpy()
+        # numpy matrix multiplication
+        y_np = self._mat @ x_np[0, 0, ...]
+        # convert back to torch tensor
+        y = torch.tensor(y_np, device=x.device).unsqueeze(0).unsqueeze(0)
+
+        return y
+
+
+# %% [markdown]
+# We setup a simple feedforward network that interlaces a minimal CNN (applied 3 times)
+# with our square matrix multiplication layer (also applied 3 times).
+
+# %%
+class Net1(torch.nn.Module):
+    def __init__(self, mat, cnn) -> None:
+        super().__init__()
+        self._matrix_layer = SquareMatrixMultiplicationLayer(mat)
+        self._cnn = cnn
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x1 = self._cnn(x)
+        x2 = self._matrix_layer(x1)
+        x3 = self._cnn(x2)
+        x4 = self._matrix_layer(x3)
+        x5 = self._cnn(x4)
+        x6 = self._matrix_layer(x5)
+
+        return x6
+
+
+# %%
+# setup a simple CNN consisting of 2 convolutional layers and 1 ReLU activation
+cnn1 = torch.nn.Sequential(
+    torch.nn.Conv1d(1, 3, (3,), padding="same", bias=False, dtype=torch.float64),
+    torch.nn.ReLU(),
+    torch.nn.Conv1d(3, 1, (3,), padding="same", bias=False, dtype=torch.float64),
+).to(dev)
+
+# setup the network
+net1 = Net1(A, cnn1)
+
+# %%
+# forward pass of our input vector through the network
+pred1 = net1(x_t)
+print(f"pred1: {pred1}\n")
+
+
+# %% [markdown]
+# We see that the forward pass works as expected. Now we will setup a dummy loss and try to backpropagate the gradients
+# using the naive approach for our custom matrix multiplication layer.
+# Backpropagation of the gradients is the central step in training neural networks. It involves calculating the gradients of
+# the loss function with respect to the weights of the network.
+
+# %%
+# setup a dummy target (label / high quality reference image) tensor
+target = 2 * x_t
+# define an MSE loss
+loss_fct = torch.nn.MSELoss()
+# calculate the loss between the prediction and the target
+loss1 = loss_fct(pred1, target)
+print(f"loss1: {loss1.item()}\n")
+
+# %% [markdown]
+# Calculation of the loss still runs fine. Now let's try to backpropagate the gradients.
+
+# %%
+try:
+    loss1.backward()
+except RuntimeError:
+    print("Error in gradient backpropagation using naive approach\n")
+
+# %% [markdown]
+# Exercise 3.1
+# ------------
+# We see that the backpropagation of the gradients fails with the naive approach.
+# Why is that?
+
+# %% [markdown]
+# Approach 2: Subclassing torch.autograd.Function
+# -----------------------------------------------
+#
+# The correct way to create custom layers in pytorch is to subclass torch.autograd.Function
+# which involves implementing the forward and backward pass of the layer.
+# In the backward pass we have to implement the Jacobian transpose vector product of the layer.
+# For details, see [here](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#optional-reading-vector-calculus-using-autograd)
+# and [here](https://pytorch.org/docs/stable/notes/extending.func.html).
+
+
+# %%
+# define the custom layer by subclassing torch.autograd.Function and implementing the forward and backward pass
+
+
+class NPSquareMatrixMultiplicationLayer(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, mat: np.ndarray) -> torch.Tensor:
+
+        # we use the context object ctx to store the matrix and other variables that we need in the backward pass
+        ctx.mat = mat
+        ctx.device = x.device
+        ctx.shape = x.shape
+        ctx.dtype = x.dtype
+
+        # convert to numpy
+        x_np = x.cpu().numpy()
+        # numpy matrix multiplication
+        y_np = mat @ x_np[0, 0, ...]
+        # convert back to torch tensor
+        y = torch.tensor(y_np, device=ctx.device).unsqueeze(0).unsqueeze(0)
+
+        return y
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, None]:
+        if grad_output is None:
+            return None, None
+        else:
+            # convert to numpy
+            grad_output_np = grad_output.cpu().numpy()
+            # calculate the Jacobian transpose vector product in numpy and convert back to torch tensor
+            back = (
+                torch.tensor(
+                    ctx.mat.T @ grad_output_np[0, 0, ...],
+                    device=ctx.device,
+                    dtype=ctx.dtype,
+                )
+                .unsqueeze(0)
+                .unsqueeze(0)
+            )
+
+            return back, None
+
+
+# %%
+# define a new network incl. the custom matrix multiplication layer using the "correct" approach
+# To use our custom layer in the network, we have to use the apply method of the custom layer class.
+
+
+class Net2(torch.nn.Module):
+    def __init__(self, mat, cnn) -> None:
+        super().__init__()
+        self._matrix_layer = NPSquareMatrixMultiplicationLayer.apply
+        self._mat = mat
+        self._cnn = cnn
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x1 = self._cnn(x)
+        x2 = self._matrix_layer(x1, self._mat)
+        x3 = self._cnn(x2)
+        x4 = self._matrix_layer(x3, self._mat)
+        x5 = self._cnn(x4)
+        x6 = self._matrix_layer(x5, self._mat)
+
+        return x6
+
+
+# %%
+# setup the same CNN as above
+cnn2 = torch.nn.Sequential(
+    torch.nn.Conv1d(1, 3, (3,), padding="same", bias=False, dtype=torch.float64),
+    torch.nn.ReLU(),
+    torch.nn.Conv1d(3, 1, (3,), padding="same", bias=False, dtype=torch.float64),
+).to(dev)
+cnn2.load_state_dict(cnn1.state_dict())
+
+# setup the network - only difference is the custom layer
+net2 = Net2(A, cnn2)
+
+# predict again
+pred2 = net2(x_t)
+print(f"pred2: {pred2}\n")
+
+loss2 = loss_fct(pred2, target)
+print(f"loss2: {loss2.item()}\n")
+
+# %% [markdown]
+# Note that the prediction still works and gives the same result as before. Also the loss calculation yields the same result as before.
+
+# %%
+loss2.backward()
+
+# print the backpropagated gradients of all parameters of the CNN layers of our network
+print("backpropagated gradients using correct approach")
+print([p.grad for p in net2._cnn.parameters()])
+
+# %% [markdown]
+# In contrast to the naive approach, the backpropagation of the gradients works fine now, meaning that this network is ready for training.
+
+# %% [markdown]
+# Testing gradient backpropagation through the layer
+# --------------------------------------------------
+#
+# When defining new custom layers, it is crucial to test whether the backward pass is implemented correctly.
+# Otherwise the gradient backpropagation through the layer will be incorrect, and optimizing the model parameters will not work.
+# To test the gradient backpropagation, we can use the `torch.autograd.gradcheck` function.
+
+# %%
+# setup a test input tensor - requires grad must be True!
+t_t = torch.rand(x_t.shape, device=dev, dtype=torch.float64, requires_grad=True)
+
+# test the gradient backpropagation through the custom numpy matrix multiplication layer
+matrix_layer = NPSquareMatrixMultiplicationLayer.apply
+gradcheck = torch.autograd.gradcheck(matrix_layer, (t_t, A), fast_mode=True)
+
+print(f"gradient check of NPSquareMatrixMultiplicationLayer: {gradcheck}")
+
+# %% [markdown]
+# Exercise 3.2
+# ------------
+# Temporarily change the backward pass of the custom layer such that it is not correct anymore
+# (e.g. by multiplying the output with 0.95) and rerun the gradient check. What do you observe?
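+
+# %% [markdown]
+# (Aside: a minimal sanity check complementing `gradcheck`, assuming the objects defined
+# above. Since the layer computes $y = A x$, we can compare its backpropagated input
+# gradient against a pure-torch implementation of the same matrix multiplication.
+# The variable names `A_t`, `u1` and `u2` are illustrative only.)
+
+# %%
+A_t = torch.from_numpy(A)  # the same matrix as a (CPU, float64) torch tensor
+
+# input gradient obtained through the custom numpy-based layer
+u1 = t_t.detach().clone().requires_grad_(True)
+matrix_layer(u1, A).sum().backward()
+
+# input gradient obtained through the equivalent pure-torch operation
+u2 = t_t.detach().cpu().clone().requires_grad_(True)
+(A_t @ u2[0, 0, :]).sum().backward()
+
+# both gradients should agree (up to floating point precision)
+print(torch.allclose(u1.grad.cpu(), u2.grad))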
diff --git a/notebooks/Deep_Learning_listmode_PET/04_custom_sirf_Poisson_logL_layer.ipynb b/notebooks/Deep_Learning_listmode_PET/04_custom_sirf_Poisson_logL_layer.ipynb
new file mode 100644
index 00000000..4883c2d8
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/04_custom_sirf_Poisson_logL_layer.ipynb
@@ -0,0 +1,496 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "e927b2e5",
+   "metadata": {},
+   "source": [
+    "Creating custom Poisson log likelihood gradient step and OSEM update layers\n",
+    "===========================================================================\n",
+    "\n",
+    "Learning objectives\n",
+    "-------------------\n",
+    "\n",
+    "1. Implement the forward and backward pass of a custom (pytorch autograd compatible) layer that\n",
+    "   calculates the gradient of the Poisson log-likelihood.\n",
+    "2. Understand how to test whether the backward pass of the custom layer is implemented correctly,\n",
+    "   such that gradient backpropagation works as expected."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "642a9eb8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sirf.STIR\n",
+    "import torch\n",
+    "import matplotlib.pyplot as plt\n",
+    "from pathlib import Path\n",
+    "from sirf.Utilities import examples_data_path\n",
+    "\n",
+    "# acq_time must be 1min\n",
+    "acq_time: str = \"1min\"\n",
+    "\n",
+    "data_path: Path = Path(examples_data_path(\"PET\")) / \"mMR\"\n",
+    "list_file: str = str(data_path / \"list.l.hdr\")\n",
+    "norm_file: str = str(data_path / \"norm.n.hdr\")\n",
+    "attn_file: str = str(data_path / \"mu_map.hv\")\n",
+    "\n",
+    "output_path: Path = Path(f\"recons_{acq_time}\")\n",
+    "emission_sinogram_output_prefix: str = str(output_path / \"emission_sinogram\")\n",
+    "scatter_sinogram_output_prefix: str = str(output_path / \"scatter_sinogram\")\n",
+    "randoms_sinogram_output_prefix: str = str(output_path / \"randoms_sinogram\")\n",
+    "attenuation_sinogram_output_prefix: str = str(output_path / \"acf_sinogram\")\n",
+    "num_scatter_iter: int = 3\n",
+    "\n",
+    "lm_recon_output_file: str = str(output_path / \"lm_recon\")\n",
+    "nxny: tuple[int, int] = (127, 127)\n",
+    "num_subsets: int = 21\n",
+    "\n",
+    "if torch.cuda.is_available():\n",
+    "    dev = \"cuda:0\"\n",
+    "else:\n",
+    "    dev = \"cpu\"\n",
+    "\n",
+    "# engine's messages go to files, except error messages, which go to stdout\n",
+    "_ = sirf.STIR.MessageRedirector(\"info.txt\", \"warn.txt\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f33f071a",
+   "metadata": {},
+   "source": [
+    "Load listmode data and create the acquisition model\n",
+    "---------------------------------------------------\n",
+    "\n",
+    "In this example, we use an acquisition model that implements the geometric forward projection\n",
+    "by a ray tracing matrix multiplication, extended by the effects of normalization, attenuation,\n",
+    "scatter and randoms (using the estimates computed in notebook 01)."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ad999928", + "metadata": {}, + "outputs": [], + "source": [ + "sirf.STIR.AcquisitionData.set_storage_scheme(\"memory\")\n", + "listmode_data = sirf.STIR.ListmodeData(list_file)\n", + "acq_data_template = listmode_data.acquisition_data_template()\n", + "\n", + "acq_data = sirf.STIR.AcquisitionData(\n", + " str(Path(f\"{emission_sinogram_output_prefix}_f1g1d0b0.hs\"))\n", + ")\n", + "\n", + "# select acquisition model that implements the geometric\n", + "# forward projection by a ray tracing matrix multiplication\n", + "acq_model = sirf.STIR.AcquisitionModelUsingRayTracingMatrix()\n", + "acq_model.set_num_tangential_LORs(1)\n", + "\n", + "randoms = sirf.STIR.AcquisitionData(str(Path(f\"{randoms_sinogram_output_prefix}.hs\")))\n", + "\n", + "ac_factors = sirf.STIR.AcquisitionData(\n", + " str(Path(f\"{attenuation_sinogram_output_prefix}.hs\"))\n", + ")\n", + "asm_attn = sirf.STIR.AcquisitionSensitivityModel(ac_factors)\n", + "\n", + "asm_norm = sirf.STIR.AcquisitionSensitivityModel(norm_file)\n", + "asm = sirf.STIR.AcquisitionSensitivityModel(asm_norm, asm_attn)\n", + "\n", + "asm.set_up(acq_data)\n", + "acq_model.set_acquisition_sensitivity(asm)\n", + "\n", + "scatter_estimate = sirf.STIR.AcquisitionData(\n", + " str(Path(f\"{scatter_sinogram_output_prefix}_{num_scatter_iter}.hs\"))\n", + ")\n", + "acq_model.set_background_term(randoms + scatter_estimate)\n", + "\n", + "# setup an initial (template) image based on the acquisition data template\n", + "initial_image = acq_data_template.create_uniform_image(value=1, xy=nxny)\n", + "\n", + "# load the reconstructed image from notebook 01\n", + "lm_ref_recon = sirf.STIR.ImageData(f\"{lm_recon_output_file}.hv\")" + ] + }, + { + "cell_type": "markdown", + "id": "b17e4472", + "metadata": {}, + "source": [ + "Setup of the Poisson log likelihood listmode objective function\n", + "---------------------------------------------------------------\n", + "\n", + "Using the listmode data and the acquisition model, we can now setup the Poisson log likelihood objective function." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b9d20277",
+   "metadata": {
+    "lines_to_next_cell": 1
+   },
+   "outputs": [],
+   "source": [
+    "lm_obj_fun = (\n",
+    "    sirf.STIR.PoissonLogLikelihoodWithLinearModelForMeanAndListModeDataWithProjMatrixByBin()\n",
+    ")\n",
+    "lm_obj_fun.set_acquisition_model(acq_model)\n",
+    "lm_obj_fun.set_acquisition_data(listmode_data)\n",
+    "lm_obj_fun.set_num_subsets(num_subsets)\n",
+    "lm_obj_fun.set_cache_max_size(1000000000)\n",
+    "lm_obj_fun.set_cache_path(str(output_path))\n",
+    "print(\"setting up listmode objective function ...\")\n",
+    "lm_obj_fun.set_up(initial_image)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "773f9910",
+   "metadata": {},
+   "source": [
+    "Setup of a pytorch layer that computes the gradient of the Poisson log likelihood objective function\n",
+    "----------------------------------------------------------------------------------------------------\n",
+    "\n",
+    "Using our listmode objective function, we can now implement a custom pytorch layer that computes the gradient\n",
+    "of the Poisson log likelihood via the objective function's `gradient()` method, evaluated on a subset of the listmode data.\n",
+    "\n",
+    "This layer maps a torch minibatch tensor to another torch tensor of the same shape.\n",
+    "The shape of the minibatch tensor is [batch_size=1, channel_size=1, spatial dimensions].\n",
+    "For the implementation we subclass `torch.autograd.Function` and implement the `forward()` and\n",
+    "`backward()` methods."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "7e08c001",
+   "metadata": {},
+   "source": [
+    "Exercise 4.1\n",
+    "------------\n",
+    "\n",
+    "Using your knowledge of the Poisson log likelihood gradient (exercise 0.1) and the content of notebook 03\n",
+    "on custom layers, implement the forward and backward pass of a custom layer that calculates the gradient of the\n",
+    "Poisson log likelihood using a SIRF objective function as shown in the figure below.\n",
+    "\n",
+    "# ![](figs/poisson_logL_grad_layer.drawio.svg)\n",
+    "\n",
+    "The next cell contains the skeleton of the custom layer. You need to fill in the missing parts in the forward and\n",
+    "backward pass."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b2049496",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class SIRFPoissonlogLGradLayer(torch.autograd.Function):\n",
+    "    @staticmethod\n",
+    "    def forward(\n",
+    "        ctx,\n",
+    "        x: torch.Tensor,\n",
+    "        objective_function,\n",
+    "        sirf_template_image: sirf.STIR.ImageData,\n",
+    "        subset: int,\n",
+    "    ) -> torch.Tensor:\n",
+    "        \"\"\"(listmode) Poisson loglikelihood gradient layer forward pass\n",
+    "\n",
+    "        Parameters\n",
+    "        ----------\n",
+    "        ctx : context object\n",
+    "            used to store objects that we need in the backward pass\n",
+    "        x : torch.Tensor\n",
+    "            minibatch tensor of shape [1,1,spatial_dimensions] containing the image\n",
+    "        objective_function : sirf (listmode) objective function\n",
+    "            the objective function that we use to calculate the gradient\n",
+    "        sirf_template_image : sirf.STIR.ImageData\n",
+    "            image template that we use to convert between torch tensors and sirf images\n",
+    "        subset : int\n",
+    "            subset number used for the gradient calculation\n",
+    "\n",
+    "        Returns\n",
+    "        -------\n",
+    "        torch.Tensor\n",
+    "            minibatch tensor of shape [1,1,spatial_dimensions]\n",
+    "            containing the gradient of the (listmode) Poisson log likelihood at x\n",
+    "        \"\"\"\n",
+    "        # we use the context object ctx to store objects that we need in the backward pass\n",
+    "        ctx.device = x.device\n",
+    "        ctx.objective_function = objective_function\n",
+    "        ctx.dtype = x.dtype\n",
+    "        ctx.subset = subset\n",
+    "        ctx.sirf_template_image = sirf_template_image\n",
+    "\n",
+    "        # ==============================================================\n",
+    "        # ==============================================================\n",
+    "        # YOUR CODE HERE\n",
+    "        # ==============================================================\n",
+    "        # ==============================================================\n",
+    "\n",
+    "    @staticmethod\n",
+    "    def backward(\n",
+    "        ctx, grad_output: torch.Tensor | None\n",
+    "    ) -> tuple[torch.Tensor | None, None, None, None]:\n",
+    "        \"\"\"(listmode) Poisson loglikelihood gradient layer backward pass\n",
+    "\n",
+    "        Parameters\n",
+    "        ----------\n",
+    "        ctx : context object\n",
+    "            used to retrieve objects that we stored in the forward pass\n",
+    "        grad_output : torch.Tensor | None\n",
+    "            minibatch tensor of shape [1,1,spatial_dimensions] containing the gradient (called v in the autograd tutorial)\n",
+    "            https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#optional-reading-vector-calculus-using-autograd\n",
+    "\n",
+    "        Returns\n",
+    "        -------\n",
+    "        tuple[torch.Tensor | None, None, None, None]\n",
+    "            the Jacobian-vector product of the Poisson log likelihood gradient layer\n",
+    "        \"\"\"\n",
+    "\n",
+    "        if grad_output is None:\n",
+    "            return None, None, None, None\n",
+    "        else:\n",
+    "            ctx.sirf_template_image.fill(grad_output.cpu().numpy()[0, 0, ...])\n",
+    "\n",
+    "            # ==============================================================\n",
+    "            # ==============================================================\n",
+    "            # YOUR CODE HERE\n",
+    "            # --------------\n",
+    "            #\n",
+    "            # calculate the Jacobian-vector product of the Poisson log likelihood gradient layer\n",
+    "            # Hints: (1) try to derive the Jacobian of the gradient of the Poisson log likelihood first\n",
+    "            #        (2) the sirf.STIR objective function has a method called `multiply_with_Hessian`\n",
+    "            #\n",
+    "            # ==============================================================\n",
+    "            # =============================================================="
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "95f9dae6",
+   "metadata": {},
+   "source": [
+    "To view the solution to the exercise, execute the next cell."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e947f4f0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load snippets/solution_4_1.py"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ff42b23b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# convert to torch tensor and add the minibatch and channel dimensions\n",
+    "x_t = (\n",
+    "    torch.tensor(\n",
+    "        lm_ref_recon.as_array(), device=dev, dtype=torch.float32, requires_grad=False\n",
+    "    )\n",
+    "    .unsqueeze(0)\n",
+    "    .unsqueeze(0)\n",
+    ")\n",
+    "\n",
+    "# setup our custom Poisson log likelihood gradient layer\n",
+    "poisson_logL_grad_layer = SIRFPoissonlogLGradLayer.apply\n",
+    "# perform the forward pass (calculate the gradient of the Poisson log likelihood at x_t)\n",
+    "grad_x = poisson_logL_grad_layer(x_t, lm_obj_fun, initial_image, 0)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5c71a89e",
+   "metadata": {},
+   "source": [
+    "Implementing an OSEM update layer using our custom Poisson log likelihood gradient layer\n",
+    "========================================================================================\n",
+    "\n",
+    "Using our custom Poisson log likelihood gradient layer, we can now implement a custom OSEM update layer.\n",
+    "Note that the OSEM update can be decomposed into a simple feedforward network consisting of basic arithmetic\n",
+    "operations that are implemented in pytorch (pointwise multiplication and addition) as shown in the figure below.\n",
+    "\n",
+    "# ![](figs/osem_layer.drawio.svg)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "42b8eb52",
+   "metadata": {},
+   "source": [
+    "Exercise 4.2\n",
+    "------------\n",
+    "Implement the forward pass of an OSEM update layer using the Poisson log likelihood gradient layer that we implemented\n",
+    "above."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6a4db256",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class OSEMUpdateLayer(torch.nn.Module):\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        objective_function,\n",
+    "        sirf_template_image: sirf.STIR.ImageData,\n",
+    "        subset: int,\n",
+    "        device: str,\n",
+    "    ) -> None:\n",
+    "        \"\"\"OSEM update layer\n",
+    "\n",
+    "        Parameters\n",
+    "        ----------\n",
+    "        objective_function : sirf (listmode) objective function\n",
+    "            the objective function that we use to calculate the gradient\n",
+    "        sirf_template_image : sirf.STIR.ImageData\n",
+    "            image template that we use to convert between torch tensors and sirf images\n",
+    "        subset : int\n",
+    "            subset number used for the gradient calculation\n",
+    "        device : str\n",
+    "            device used for the calculations\n",
+    "\n",
+    "        Returns\n",
+    "        -------\n",
+    "        torch.Tensor\n",
+    "            minibatch tensor of shape [1,1,spatial_dimensions] containing the OSEM\n",
+    "            update of the input image using the Poisson log likelihood objective function\n",
+    "        \"\"\"\n",
+    "        super().__init__()\n",
+    "        self._objective_function = objective_function\n",
+    "        self._sirf_template_image: sirf.STIR.ImageData = sirf_template_image\n",
+    "        self._subset: int = subset\n",
+    "\n",
+    "        self._poisson_logL_grad_layer = SIRFPoissonlogLGradLayer.apply\n",
+    "\n",
+    "        # setup a tensor containing the inverse of the subset sensitivity image, adding the minibatch and channel dimensions\n",
+    "        self._inv_sens_image: torch.Tensor = 1.0 / torch.tensor(\n",
+    "            objective_function.get_subset_sensitivity(subset).as_array(),\n",
+    "            dtype=torch.float32,\n",
+    "            device=device,\n",
+    "        ).unsqueeze(0).unsqueeze(0)\n",
+    "        # replace positive infinity values with 0 (voxels with 0 sensitivity)\n",
+    "        torch.nan_to_num(self._inv_sens_image, posinf=0, out=self._inv_sens_image)\n",
+    "\n",
+    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
+    "        \"\"\"forward pass of the OSEM update layer\n",
+    "\n",
+    "        Parameters\n",
+    "        ----------\n",
+    "        x : torch.Tensor\n",
+    "            minibatch tensor of shape [1,1,spatial_dimensions] containing the image\n",
+    "\n",
+    "        Returns\n",
+    "        -------\n",
+    "        torch.Tensor\n",
+    "            OSEM update image\n",
+    "        \"\"\"\n",
+    "\n",
+    "        # =======================================================================\n",
+    "        # =======================================================================\n",
+    "        # YOUR CODE HERE\n",
+    "        # USE ONLY BASIC ARITHMETIC OPERATIONS BETWEEN TORCH TENSORS!\n",
+    "        # =======================================================================\n",
+    "        # ======================================================================="
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "cafe0398",
+   "metadata": {},
+   "source": [
+    "To view the solution to the exercise, execute the next cell."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d3fe1d29",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load snippets/solution_4_2.py"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "98031064",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# define the OSEM update layer for subset 0\n",
+    "osem_layer0 = OSEMUpdateLayer(lm_obj_fun, initial_image, 0, dev)\n",
+    "# perform the forward pass\n",
+    "osem_updated_x_t = osem_layer0(x_t)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9d8639e1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "# show the input and output of the OSEM update layer\n",
+    "fig, ax = plt.subplots(1, 3, figsize=(12, 4), tight_layout=True)\n",
+    "ax[0].imshow(x_t.cpu().numpy()[0, 0, 71, ...], cmap=\"Greys\")\n",
+    "ax[1].imshow(osem_updated_x_t.cpu().numpy()[0, 0, 71, ...], cmap=\"Greys\")\n",
+    "ax[2].imshow(\n",
+    "    osem_updated_x_t.cpu().numpy()[0, 0, 71, ...] - x_t.cpu().numpy()[0, 0, 71, ...],\n",
+    "    cmap=\"seismic\",\n",
+    "    vmin=-0.01,\n",
+    "    vmax=0.01,\n",
+    ")\n",
+    "ax[0].set_title(\"input image\")\n",
+    "ax[1].set_title(\"OSEM updated image\")\n",
+    "ax[2].set_title(\"difference image\")\n",
+    "fig.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "566bd485",
+   "metadata": {},
+   "source": [
+    "Testing the backward pass of the custom layers\n",
+    "----------------------------------------------\n",
+    "\n",
+    "As mentioned in the previous notebook, it is important to test whether the backward pass\n",
+    "of the custom layer is implemented correctly using the `torch.autograd.gradcheck` function.\n",
+    "**However, we won't do this here** - but rather discuss the implementation - because:\n",
+    "- it can take a long time\n",
+    "- we are using float32, so we have to adapt the tolerances\n",
+    "- the sirf.STIR gradient calculation is not exactly deterministic, due to parallelization and numerical precision,\n",
+    "  which also requires adapting the tolerances for non-deterministic functions\n",
+    "\n",
+    "**If you implement a new layer, and you are not 100% sure that the backward pass is correct, you should always test it!**"
+   ]
+  }
+ ],
+ "metadata": {
+  "jupytext": {
+   "cell_metadata_filter": "-all",
+   "main_language": "python",
+   "notebook_metadata_filter": "-all"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/notebooks/Deep_Learning_listmode_PET/04_custom_sirf_Poisson_logL_layer.py b/notebooks/Deep_Learning_listmode_PET/04_custom_sirf_Poisson_logL_layer.py
new file mode 100644
index 00000000..5f70e033
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/04_custom_sirf_Poisson_logL_layer.py
@@ -0,0 +1,363 @@
+# %% [markdown]
+# Creating custom Poisson log likelihood gradient step and OSEM update layers
+# ===========================================================================
+#
+# Learning objectives
+# -------------------
+#
+# 1. Implement the forward and backward pass of a custom (pytorch autograd compatible) layer that
+#    calculates the gradient of the Poisson log-likelihood.
+# 2. Understand how to test whether the backward pass of the custom layer is implemented correctly,
+#    such that gradient backpropagation works as expected.
+
+# %%
+import sirf.STIR
+import torch
+import matplotlib.pyplot as plt
+from pathlib import Path
+from sirf.Utilities import examples_data_path
+
+# acq_time must be 1min
+acq_time: str = "1min"
+
+data_path: Path = Path(examples_data_path("PET")) / "mMR"
+list_file: str = str(data_path / "list.l.hdr")
+norm_file: str = str(data_path / "norm.n.hdr")
+attn_file: str = str(data_path / "mu_map.hv")
+
+output_path: Path = Path(f"recons_{acq_time}")
+emission_sinogram_output_prefix: str = str(output_path / "emission_sinogram")
+scatter_sinogram_output_prefix: str = str(output_path / "scatter_sinogram")
+randoms_sinogram_output_prefix: str = str(output_path / "randoms_sinogram")
+attenuation_sinogram_output_prefix: str = str(output_path / "acf_sinogram")
+num_scatter_iter: int = 3
+
+lm_recon_output_file: str = str(output_path / "lm_recon")
+nxny: tuple[int, int] = (127, 127)
+num_subsets: int = 21
+
+if torch.cuda.is_available():
+    dev = "cuda:0"
+else:
+    dev = "cpu"
+
+# engine's messages go to files, except error messages, which go to stdout
+_ = sirf.STIR.MessageRedirector("info.txt", "warn.txt")
+
+# %% [markdown]
+# Load listmode data and create the acquisition model
+# ---------------------------------------------------
+#
+# In this example, we use an acquisition model that implements the geometric forward projection
+# by a ray tracing matrix multiplication, extended by the effects of normalization, attenuation,
+# scatter and randoms (using the estimates computed in notebook 01).
+
+# %%
+sirf.STIR.AcquisitionData.set_storage_scheme("memory")
+listmode_data = sirf.STIR.ListmodeData(list_file)
+acq_data_template = listmode_data.acquisition_data_template()
+
+acq_data = sirf.STIR.AcquisitionData(
+    str(Path(f"{emission_sinogram_output_prefix}_f1g1d0b0.hs"))
+)
+
+# select acquisition model that implements the geometric
+# forward projection by a ray tracing matrix multiplication
+acq_model = sirf.STIR.AcquisitionModelUsingRayTracingMatrix()
+acq_model.set_num_tangential_LORs(1)
+
+randoms = sirf.STIR.AcquisitionData(str(Path(f"{randoms_sinogram_output_prefix}.hs")))
+
+ac_factors = sirf.STIR.AcquisitionData(
+    str(Path(f"{attenuation_sinogram_output_prefix}.hs"))
+)
+asm_attn = sirf.STIR.AcquisitionSensitivityModel(ac_factors)
+
+asm_norm = sirf.STIR.AcquisitionSensitivityModel(norm_file)
+asm = sirf.STIR.AcquisitionSensitivityModel(asm_norm, asm_attn)
+
+asm.set_up(acq_data)
+acq_model.set_acquisition_sensitivity(asm)
+
+scatter_estimate = sirf.STIR.AcquisitionData(
+    str(Path(f"{scatter_sinogram_output_prefix}_{num_scatter_iter}.hs"))
+)
+acq_model.set_background_term(randoms + scatter_estimate)
+
+# setup an initial (template) image based on the acquisition data template
+initial_image = acq_data_template.create_uniform_image(value=1, xy=nxny)
+
+# load the reconstructed image from notebook 01
+lm_ref_recon = sirf.STIR.ImageData(f"{lm_recon_output_file}.hv")
+
+# %% [markdown]
+# Setup of the Poisson log likelihood listmode objective function
+# ---------------------------------------------------------------
+#
+# Using the listmode data and the acquisition model, we can now setup the Poisson log likelihood objective function.
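+
+# %% [markdown]
+# (Background recap, assuming the notation of the earlier notebooks: writing $A$ for the
+# acquisition model and $s$ for the additive background, the gradient of the Poisson logL is
+# $$ \nabla_x logL(y|x) = A^T \frac{y}{Ax + s} - A^T 1 . $$
+# Its Jacobian with respect to $x$ is the (symmetric) Hessian of $logL$, which is what the
+# backward pass of the gradient layer defined below needs to apply.)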
+ +# %% +lm_obj_fun = ( + sirf.STIR.PoissonLogLikelihoodWithLinearModelForMeanAndListModeDataWithProjMatrixByBin() +) +lm_obj_fun.set_acquisition_model(acq_model) +lm_obj_fun.set_acquisition_data(listmode_data) +lm_obj_fun.set_num_subsets(num_subsets) +lm_obj_fun.set_cache_max_size(1000000000) +lm_obj_fun.set_cache_path(str(output_path)) +print("setting up listmode objective function ...") +lm_obj_fun.set_up(initial_image) + +# %% [markdown] +# Setup of a pytorch layer that computes the gradient of the Poisson log likelihood objective function +# ---------------------------------------------------------------------------------------------------- +# +# Using our listmode objective function, we can now implement a custom pytorch layer that computes the gradient +# of the Poisson log likelihood using the `gradient()` method using a subset of the listmode data. +# +# This layer maps a torch minibatch tensor to another torch tensor of the same shape. +# The shape of the minibatch tensor is [batch_size=1, channel_size=1, spatial dimensions]. +# For the implementation we subclass `torch.autograd.Function` and implement the `forward()` and +# `backward()` methods. + +# %% [markdown] +# Exercise 4.1 +# ------------ +# +# Using your knowledge of the Poisson log likelihood gradient (exercise 0.1) and the content of the notebook 03 +# on custom layers, implement the forward and backward pass of a custom layer that calculates the gradient of the +# Poisson log likelihood using a SIRF objective function as shown in the figure below. +# +# # ![](figs/poisson_logL_grad_layer.drawio.svg) +# +# The next cell contains the skeleton of the custom layer. You need to fill in the missing parts in the forward and +# backward pass. + +# %% +class SIRFPoissonlogLGradLayer(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: torch.Tensor, + objective_function, + sirf_template_image: sirf.STIR.ImageData, + subset: int, + ) -> torch.Tensor: + """(listmode) Poisson loglikelihood gradient layer forward pass + + Parameters + ---------- + ctx : context object + used to store objects that we need in the backward pass + x : torch.Tensor + minibatch tensor of shape [1,1,spatial_dimensions] containing the image + objective_function : sirf (listmode) objective function + the objective function that we use to calculate the gradient + sirf_template_image : sirf.STIR.ImageData + image template that we use to convert between torch tensors and sirf images + subset : int + subset number used for the gradient calculation + + Returns + ------- + torch.Tensor + minibatch tensor of shape [1,1,spatial_dimensions] containing the image + containing the gradient of the (listmode) Poisson log likelihood at x + """ + # we use the context object ctx to store objects that we need in the backward pass + ctx.device = x.device + ctx.objective_function = objective_function + ctx.dtype = x.dtype + ctx.subset = subset + ctx.sirf_template_image = sirf_template_image + + # ============================================================== + # ============================================================== + # YOUR CODE HERE + # ============================================================== + # ============================================================== + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor | None + ) -> tuple[torch.Tensor | None, None, None, None]: + """(listmode) Poisson loglikelihood gradient layer backward pass + + Parameters + ---------- + ctx : context object + used to store objects that we need in the 
backward pass
+        grad_output : torch.Tensor | None
+            minibatch tensor of shape [1,1,spatial_dimensions] containing the gradient (called v in the autograd tutorial)
+            https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#optional-reading-vector-calculus-using-autograd
+
+        Returns
+        -------
+        tuple[torch.Tensor | None, None, None, None]
+            the Jacobian-vector product of the Poisson log likelihood gradient layer
+        """
+
+        if grad_output is None:
+            return None, None, None, None
+        else:
+            ctx.sirf_template_image.fill(grad_output.cpu().numpy()[0, 0, ...])
+
+            # ==============================================================
+            # ==============================================================
+            # YOUR CODE HERE
+            # --------------
+            #
+            # calculate the Jacobian-vector product of the Poisson log likelihood gradient layer
+            # Hints: (1) try to derive the Jacobian of the gradient of the Poisson log likelihood first
+            #        (2) the sirf.STIR objective function has a method called `multiply_with_Hessian`
+            #
+            # ==============================================================
+            # ==============================================================
+
+
+# %% [markdown]
+# To view the solution to the exercise, execute the next cell.
+
+# %%
+# %load snippets/solution_4_1.py
+
+# %%
+# convert to torch tensor and add the minibatch and channel dimensions
+x_t = (
+    torch.tensor(
+        lm_ref_recon.as_array(), device=dev, dtype=torch.float32, requires_grad=False
+    )
+    .unsqueeze(0)
+    .unsqueeze(0)
+)
+
+# setup our custom Poisson log likelihood gradient layer
+poisson_logL_grad_layer = SIRFPoissonlogLGradLayer.apply
+# perform the forward pass (calculate the gradient of the Poisson log likelihood at x_t)
+grad_x = poisson_logL_grad_layer(x_t, lm_obj_fun, initial_image, 0)
+
+
+# %% [markdown]
+# Implementing an OSEM update layer using our custom Poisson log likelihood gradient layer
+# ========================================================================================
+#
+# Using our custom Poisson log likelihood gradient layer, we can now implement a custom OSEM update layer.
+# Note that the OSEM update can be decomposed into a simple feedforward network consisting of basic arithmetic
+# operations that are implemented in pytorch (pointwise multiplication and addition) as shown in the figure below.
+#
+# ![](figs/osem_layer.drawio.svg)

+# %% [markdown]
+# Exercise 4.2
+# ------------
+# Implement the forward pass of an OSEM update layer using the Poisson log likelihood gradient layer that we implemented
+# above.
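+
+# %% [markdown]
+# As a reminder, the decomposition shown in the figure above can be written as
+#
+# $x^+ = x + \frac{x}{s} \, \nabla_x \log L(y|x) ,$
+#
+# where $s$ is the sensitivity image of the chosen subset: a pointwise
+# multiplication of the input image with the precomputed inverse sensitivity image
+# and the output of the Poisson logL gradient layer, followed by a pointwise
+# addition of the input image. All of these are basic pytorch tensor operations.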
+
+# %%
+class OSEMUpdateLayer(torch.nn.Module):
+    def __init__(
+        self,
+        objective_function,
+        sirf_template_image: sirf.STIR.ImageData,
+        subset: int,
+        device: str,
+    ) -> None:
+        """OSEM update layer
+
+        Parameters
+        ----------
+        objective_function : sirf (listmode) objective function
+            the objective function that we use to calculate the gradient
+        sirf_template_image : sirf.STIR.ImageData
+            image template that we use to convert between torch tensors and sirf images
+        subset : int
+            subset number used for the gradient calculation
+        device : str
+            device used for the calculations
+
+        Returns
+        -------
+        torch.Tensor
+            minibatch tensor of shape [1,1,spatial_dimensions] containing the OSEM
+            update of the input image using the Poisson log likelihood objective function
+        """
+        super().__init__()
+        self._objective_function = objective_function
+        self._sirf_template_image: sirf.STIR.ImageData = sirf_template_image
+        self._subset: int = subset
+
+        self._poisson_logL_grad_layer = SIRFPoissonlogLGradLayer.apply
+
+        # setup a tensor containing the inverse of the subset sensitivity image adding the minibatch and channel dimensions
+        self._inv_sens_image: torch.Tensor = 1.0 / torch.tensor(
+            objective_function.get_subset_sensitivity(subset).as_array(),
+            dtype=torch.float32,
+            device=device,
+        ).unsqueeze(0).unsqueeze(0)
+        # replace positive infinity values with 0 (voxels with 0 sensitivity)
+        torch.nan_to_num(self._inv_sens_image, posinf=0, out=self._inv_sens_image)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """forward pass of the OSEM update layer
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            minibatch tensor of shape [1,1,spatial_dimensions] containing the image
+
+        Returns
+        -------
+        torch.Tensor
+            OSEM update image
+        """
+
+        # =======================================================================
+        # =======================================================================
+        # YOUR CODE HERE
+        # USE ONLY BASIC ARITHMETIC OPERATIONS BETWEEN TORCH TENSORS!
+        # =======================================================================
+        # =======================================================================
+
+
+# %% [markdown]
+# To view the solution to the exercise, execute the next cell.
+
+# %%
+# %load snippets/solution_4_2.py
+
+# %%
+# define the OSEM update layer for subset 0
+osem_layer0 = OSEMUpdateLayer(lm_obj_fun, initial_image, 0, dev)
+# perform the forward pass
+osem_updated_x_t = osem_layer0(x_t)
+
+# %%
+
+# show the input and output of the OSEM update layer
+fig, ax = plt.subplots(1, 3, figsize=(12, 4), tight_layout=True)
+ax[0].imshow(x_t.cpu().numpy()[0, 0, 71, ...], cmap="Greys")
+ax[1].imshow(osem_updated_x_t.cpu().numpy()[0, 0, 71, ...], cmap="Greys")
+ax[2].imshow(
+    osem_updated_x_t.cpu().numpy()[0, 0, 71, ...] - x_t.cpu().numpy()[0, 0, 71, ...],
+    cmap="seismic",
+    vmin=-0.01,
+    vmax=0.01,
+)
+ax[0].set_title("input image")
+ax[1].set_title("OSEM updated image")
+ax[2].set_title("difference image")
+fig.show()

+# %% [markdown]
+# Testing the backward pass of the custom layers
+# ----------------------------------------------
+#
+# As mentioned in the previous notebook, it is important to test whether the backward pass
+# of the custom layer is implemented correctly using the `torch.autograd.gradcheck` function.
+# **However, we won't do this here** - but rather discuss the implementation - because:
+# - it can take a long time
+# - since we are using float32, we have to adapt the tolerances
+# - the sirf.STIR gradient calculation is not exactly deterministic, due to parallelization and numerical precision,
+#   which also requires adapting the tolerances for non-deterministic functions
+#
+# **If you implement a new layer and you are not 100% sure that the backward pass is correct, you should always test it!**
diff --git a/notebooks/Deep_Learning_listmode_PET/05_custrom_unrolled_varnet.ipynb b/notebooks/Deep_Learning_listmode_PET/05_custrom_unrolled_varnet.ipynb
new file mode 100644
index 00000000..dc8fae3c
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/05_custrom_unrolled_varnet.ipynb
@@ -0,0 +1,609 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "63066d5f",
+ "metadata": {},
+ "source": [
+ "Creating a custom unrolled variational network for listmode PET data\n",
+ "====================================================================\n",
+ "\n",
+ "Learning objectives\n",
+ "-------------------\n",
+ "\n",
+ "1. Learn how to implement and train a custom unrolled variational network fusing updates\n",
+ "   from listmode OSEM blocks and CNN blocks"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "42b19ea3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sirf.STIR\n",
+ "import torch\n",
+ "import matplotlib.pyplot as plt\n",
+ "from pathlib import Path\n",
+ "from sirf.Utilities import examples_data_path\n",
+ "\n",
+ "# acq_time must be 1min\n",
+ "acq_time: str = \"1min\"\n",
+ "\n",
+ "data_path: Path = Path(examples_data_path(\"PET\")) / \"mMR\"\n",
+ "list_file: str = str(data_path / \"list.l.hdr\")\n",
+ "norm_file: str = str(data_path / \"norm.n.hdr\")\n",
+ "attn_file: str = str(data_path / \"mu_map.hv\")\n",
+ "\n",
+ "output_path: Path = Path(f\"recons_{acq_time}\")\n",
+ "emission_sinogram_output_prefix: str = str(output_path / \"emission_sinogram\")\n",
+ "scatter_sinogram_output_prefix: str = str(output_path / \"scatter_sinogram\")\n",
+ "randoms_sinogram_output_prefix: str = str(output_path / \"randoms_sinogram\")\n",
+ "attenuation_sinogram_output_prefix: str = str(output_path / \"acf_sinogram\")\n",
+ "\n",
+ "num_scatter_iter: int = 3\n",
+ "\n",
+ "lm_recon_output_file: str = str(output_path / \"lm_recon\")\n",
+ "lm_60min_recon_output_file: str = str(Path(f\"recons_60min\") / \"lm_recon\")\n",
+ "nxny: tuple[int, int] = (127, 127)\n",
+ "num_subsets: int = 21\n",
+ "\n",
+ "if torch.cuda.is_available():\n",
+ "    dev = \"cuda:0\"\n",
+ "else:\n",
+ "    dev = \"cpu\"\n",
+ "\n",
+ "# engine's messages go to files, except error messages, which go to stdout\n",
+ "_ = sirf.STIR.MessageRedirector(\"info.txt\", \"warn.txt\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2d3bd078",
+ "metadata": {},
+ "source": [
+ "Load listmode data and create the acquisition model\n",
+ "---------------------------------------------------\n",
+ "\n",
+ "In this demo example, we use an acquisition model that implements the geometric forward projection\n",
+ "by a ray tracing matrix multiplication. The effects of normalization, attenuation, scatter and randoms\n",
+ "are included via an acquisition sensitivity model and an additive background term."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70df707f", + "metadata": {}, + "outputs": [], + "source": [ + "sirf.STIR.AcquisitionData.set_storage_scheme(\"memory\")\n", + "listmode_data = sirf.STIR.ListmodeData(list_file)\n", + "acq_data_template = listmode_data.acquisition_data_template()\n", + "\n", + "acq_data = sirf.STIR.AcquisitionData(\n", + " str(Path(f\"{emission_sinogram_output_prefix}_f1g1d0b0.hs\"))\n", + ")\n", + "\n", + "# select acquisition model that implements the geometric\n", + "# forward projection by a ray tracing matrix multiplication\n", + "acq_model = sirf.STIR.AcquisitionModelUsingRayTracingMatrix()\n", + "acq_model.set_num_tangential_LORs(1)\n", + "\n", + "randoms = sirf.STIR.AcquisitionData(str(Path(f\"{randoms_sinogram_output_prefix}.hs\")))\n", + "\n", + "ac_factors = sirf.STIR.AcquisitionData(\n", + " str(Path(f\"{attenuation_sinogram_output_prefix}.hs\"))\n", + ")\n", + "asm_attn = sirf.STIR.AcquisitionSensitivityModel(ac_factors)\n", + "\n", + "asm_norm = sirf.STIR.AcquisitionSensitivityModel(norm_file)\n", + "asm = sirf.STIR.AcquisitionSensitivityModel(asm_norm, asm_attn)\n", + "\n", + "asm.set_up(acq_data)\n", + "acq_model.set_acquisition_sensitivity(asm)\n", + "\n", + "scatter_estimate = sirf.STIR.AcquisitionData(\n", + " str(Path(f\"{scatter_sinogram_output_prefix}_{num_scatter_iter}.hs\"))\n", + ")\n", + "acq_model.set_background_term(randoms + scatter_estimate)\n", + "\n", + "# setup an initial (template) image based on the acquisition data template\n", + "initial_image = acq_data_template.create_uniform_image(value=1, xy=nxny)" + ] + }, + { + "cell_type": "markdown", + "id": "612f5c06", + "metadata": {}, + "source": [ + "Setup of the Poisson log likelihood listmode objective function\n", + "---------------------------------------------------------------\n", + "\n", + "Using the listmode data and the acquisition model, we can now setup the Poisson log likelihood objective function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "85f6bc7a", + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "lm_obj_fun = (\n", + " sirf.STIR.PoissonLogLikelihoodWithLinearModelForMeanAndListModeDataWithProjMatrixByBin()\n", + ")\n", + "lm_obj_fun.set_acquisition_model(acq_model)\n", + "lm_obj_fun.set_acquisition_data(listmode_data)\n", + "lm_obj_fun.set_num_subsets(num_subsets)\n", + "lm_obj_fun.set_cache_max_size(1000000000)\n", + "lm_obj_fun.set_cache_path(str(output_path))\n", + "print(\"setting up listmode objective function ...\")\n", + "lm_obj_fun.set_up(initial_image)" + ] + }, + { + "cell_type": "markdown", + "id": "405a16c8", + "metadata": {}, + "source": [ + "Setup of OSEM update layer\n", + "--------------------------\n", + "\n", + "See notebook 04." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "388324fa", + "metadata": {}, + "outputs": [], + "source": [ + "class SIRFPoissonlogLGradLayer(torch.autograd.Function):\n", + " @staticmethod\n", + " def forward(\n", + " ctx,\n", + " x: torch.Tensor,\n", + " objective_function,\n", + " sirf_template_image: sirf.STIR.ImageData,\n", + " subset: int,\n", + " ) -> torch.Tensor:\n", + " \"\"\"(listmode) Poisson loglikelihood gradient layer forward pass\n", + "\n", + " Parameters\n", + " ----------\n", + " ctx : context object\n", + " used to store objects that we need in the backward pass\n", + " x : torch.Tensor\n", + " minibatch tensor of shape [1,1,spatial_dimensions] containing the image\n", + " objective_function : sirf (listmode) objective function\n", + " the objective function that we use to calculate the gradient\n", + " sirf_template_image : sirf.STIR.ImageData\n", + " image template that we use to convert between torch tensors and sirf images\n", + " subset : int\n", + " subset number used for the gradient calculation\n", + "\n", + " Returns\n", + " -------\n", + " torch.Tensor\n", + " minibatch tensor of shape [1,1,spatial_dimensions] containing the image\n", + " containing the gradient of the (listmode) Poisson log likelihood at x\n", + " \"\"\"\n", + "\n", + " # we use the context object ctx to store the matrix and other variables that we need in the backward pass\n", + " ctx.device = x.device\n", + " ctx.objective_function = objective_function\n", + " ctx.dtype = x.dtype\n", + " ctx.subset = subset\n", + " ctx.sirf_template_image = sirf_template_image\n", + "\n", + " # setup a new sirf.STIR ImageData object\n", + " x_sirf = sirf_template_image.clone()\n", + " # convert torch tensor to sirf image via numpy\n", + " x_sirf.fill(x.cpu().numpy()[0, 0, ...])\n", + "\n", + " # save the input sirf.STIR ImageData for the backward pass\n", + " ctx.x_sirf = x_sirf\n", + "\n", + " # calculate the gradient of the Poisson log likelihood using SIRF\n", + " g_np = objective_function.gradient(x_sirf, subset).as_array()\n", + "\n", + " # convert back to torch tensor\n", + " y = (\n", + " torch.tensor(g_np, device=ctx.device, dtype=ctx.dtype)\n", + " .unsqueeze(0)\n", + " .unsqueeze(0)\n", + " )\n", + "\n", + " return y\n", + "\n", + " @staticmethod\n", + " def backward(\n", + " ctx, grad_output: torch.Tensor | None\n", + " ) -> tuple[torch.Tensor | None, None, None, None]:\n", + " \"\"\"(listmode) Poisson loglikelihood gradient layer backward pass\n", + "\n", + " Parameters\n", + " ----------\n", + " ctx : context object\n", + " used to store objects that we need in the backward pass\n", + " grad_output : torch.Tensor | None\n", + " minibatch tensor of shape [1,1,spatial_dimensions] containing the gradient (called v in the autograd tutorial)\n", + " https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#optional-reading-vector-calculus-using-autograd\n", + "\n", + " Returns\n", + " -------\n", + " tuple[torch.Tensor | None, None, None, None]\n", + " the Jacobian-vector product of the Poisson log likelihood gradient layer\n", + " \"\"\"\n", + "\n", + " if grad_output is None:\n", + " return None, None, None, None\n", + " else:\n", + " # convert torch tensor to sirf image via numpy\n", + " ctx.sirf_template_image.fill(grad_output.cpu().numpy()[0, 0, ...])\n", + "\n", + " # calculate the Jacobian vector product (the Hessian applied to an image) using SIRF\n", + " back_sirf = ctx.objective_function.multiply_with_Hessian(\n", + " ctx.x_sirf, ctx.sirf_template_image, 
ctx.subset\n", + " )\n", + "\n", + " # convert back to torch tensor via numpy\n", + " back = (\n", + " torch.tensor(back_sirf.as_array(), device=ctx.device, dtype=ctx.dtype)\n", + " .unsqueeze(0)\n", + " .unsqueeze(0)\n", + " )\n", + "\n", + " return back, None, None, None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f0b0950", + "metadata": {}, + "outputs": [], + "source": [ + "class OSEMUpdateLayer(torch.nn.Module):\n", + " def __init__(\n", + " self,\n", + " objective_function,\n", + " sirf_template_image: sirf.STIR.ImageData,\n", + " subset: int,\n", + " device: str,\n", + " ) -> None:\n", + " \"\"\"OSEM update layer\n", + "\n", + " Parameters\n", + " ----------\n", + " objective_function : sirf (listmode) objective function\n", + " the objective function that we use to calculate the gradient\n", + " sirf_template_image : sirf.STIR.ImageData\n", + " image template that we use to convert between torch tensors and sirf images\n", + " subset : int\n", + " subset number used for the gradient calculation\n", + " device : str\n", + " device used for the calculations\n", + "\n", + " Returns\n", + " -------\n", + " torch.Tensor\n", + " minibatch tensor of shape [1,1,spatial_dimensions] containing the OSEM\n", + " update of the input image using the Poisson log likelihood objective function\n", + " \"\"\"\n", + " super().__init__()\n", + " self._objective_function = objective_function\n", + " self._sirf_template_image: sirf.STIR.ImageData = sirf_template_image\n", + " self._subset: int = subset\n", + "\n", + " self._poisson_logL_grad_layer = SIRFPoissonlogLGradLayer.apply\n", + "\n", + " # setup a tensor containng the inverse of the subset sensitivity image adding the minibatch and channel dimensions\n", + " self._inv_sens_image: torch.Tensor = 1.0 / torch.tensor(\n", + " objective_function.get_subset_sensitivity(subset).as_array(),\n", + " dtype=torch.float32,\n", + " device=device,\n", + " ).unsqueeze(0).unsqueeze(0)\n", + " # replace positive infinity values with 0 (voxels with 0 sensitivity)\n", + " torch.nan_to_num(self._inv_sens_image, posinf=0, out=self._inv_sens_image)\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"forward pass of the OSEM update layer\n", + "\n", + " Parameters\n", + " ----------\n", + " x : torch.Tensor\n", + " minibatch tensor of shape [1,1,spatial_dimensions] containing the image\n", + "\n", + " Returns\n", + " -------\n", + " torch.Tensor\n", + " OSEM update image\n", + " \"\"\"\n", + " grad_x: torch.Tensor = self._poisson_logL_grad_layer(\n", + " x, self._objective_function, self._sirf_template_image, self._subset\n", + " )\n", + " return x + x * self._inv_sens_image * grad_x" + ] + }, + { + "cell_type": "markdown", + "id": "4b57ce96", + "metadata": {}, + "source": [ + "Exercise 5.1\n", + "------------\n", + "\n", + "Implement the forward pass of the unrolled OSEM Variational Network with 2 blocks shown below.\n", + "Start from the code below and fill in the missing parts." 
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ce158c2e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class UnrolledOSEMVarNet(torch.nn.Module):\n",
+ "    def __init__(\n",
+ "        self,\n",
+ "        objective_function,\n",
+ "        sirf_template_image: sirf.STIR.ImageData,\n",
+ "        convnet: torch.nn.Module,\n",
+ "        device: str,\n",
+ "    ) -> None:\n",
+ "        \"\"\"Unrolled OSEM Variational Network with 2 blocks\n",
+ "\n",
+ "        Parameters\n",
+ "        ----------\n",
+ "        objective_function : sirf.STIR objective function\n",
+ "            (listmode) Poisson logL objective function\n",
+ "            that we use for the OSEM updates\n",
+ "        sirf_template_image : sirf.STIR.ImageData\n",
+ "            used for the conversion between torch tensors and sirf images\n",
+ "        convnet : torch.nn.Module\n",
+ "            a (convolutional) neural network that maps a minibatch tensor\n",
+ "            of shape [1,1,spatial_dimensions] onto a minibatch tensor of the same shape\n",
+ "        device : str\n",
+ "            device used for the calculations\n",
+ "        \"\"\"\n",
+ "        super().__init__()\n",
+ "\n",
+ "        # OSEM update layer using the 1st subset of the listmode data\n",
+ "        self._osem_step_layer0 = OSEMUpdateLayer(\n",
+ "            objective_function, sirf_template_image, 0, device\n",
+ "        )\n",
+ "\n",
+ "        # OSEM update layer using the 2nd subset of the listmode data\n",
+ "        self._osem_step_layer1 = OSEMUpdateLayer(\n",
+ "            objective_function, sirf_template_image, 1, device\n",
+ "        )\n",
+ "        self._convnet = convnet\n",
+ "        self._relu = torch.nn.ReLU()\n",
+ "\n",
+ "        # trainable parameters for the fusion of the OSEM update and the CNN output in the two blocks\n",
+ "        self._fusion_weight0 = torch.nn.Parameter(\n",
+ "            torch.ones(1, device=device, dtype=torch.float32)\n",
+ "        )\n",
+ "        self._fusion_weight1 = torch.nn.Parameter(\n",
+ "            torch.ones(1, device=device, dtype=torch.float32)\n",
+ "        )\n",
+ "\n",
+ "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
+ "        \"\"\"forward pass of the Unrolled OSEM Variational Network\n",
+ "\n",
+ "        Parameters\n",
+ "        ----------\n",
+ "        x : torch.Tensor\n",
+ "            minibatch tensor of shape [1,1,spatial_dimensions] containing the image\n",
+ "\n",
+ "        Returns\n",
+ "        -------\n",
+ "        torch.Tensor\n",
+ "            minibatch tensor of shape [1,1,spatial_dimensions] containing the output of the network\n",
+ "        \"\"\"\n",
+ "\n",
+ "        # =============================================================\n",
+ "        # =============================================================\n",
+ "        # YOUR CODE HERE\n",
+ "        #\n",
+ "        # forward pass contains two blocks where each block\n",
+ "        # consists of a fusion of the OSEM update and the CNN output\n",
+ "        #\n",
+ "        # the fusion is a weighted sum of the OSEM update and the CNN output\n",
+ "        # using the respective fusion weights\n",
+ "        #\n",
+ "        # =============================================================\n",
+ "        # ============================================================="
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8575de9b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# load the reference OSEM reconstruction that we use as input to our network\n",
+ "lm_ref_recon = sirf.STIR.ImageData(f\"{lm_recon_output_file}.hv\")\n",
+ "x_t = (\n",
+ "    torch.tensor(\n",
+ "        lm_ref_recon.as_array(), device=dev, dtype=torch.float32, requires_grad=False\n",
+ "    )\n",
+ "    .unsqueeze(0)\n",
+ "    .unsqueeze(0)\n",
+ ")\n",
+ "\n",
+ "# define a minimal CNN that we use within the unrolled OSEM Variational Network\n",
+ "cnn = torch.nn.Sequential(\n",
+ "    torch.nn.Conv3d(1, 5, 5, 
padding=\"same\", bias=False),\n", + " torch.nn.Conv3d(5, 5, 5, padding=\"same\", bias=False),\n", + " torch.nn.PReLU(device=dev),\n", + " torch.nn.Conv3d(5, 5, 5, padding=\"same\", bias=False),\n", + " torch.nn.Conv3d(5, 5, 5, padding=\"same\", bias=False),\n", + " torch.nn.PReLU(device=dev),\n", + " torch.nn.Conv3d(5, 1, 1, padding=\"same\", bias=False),\n", + ").to(dev)\n", + "\n", + "\n", + "# setup the unrolled OSEM Variational Network using the sirf.STIR listmode objective function\n", + "# and the CNN\n", + "varnet = UnrolledOSEMVarNet(lm_obj_fun, initial_image, cnn, dev)" + ] + }, + { + "cell_type": "markdown", + "id": "f375a774", + "metadata": {}, + "source": [ + "\n", + "Supervised optimization the network parameters\n", + "----------------------------------------------\n", + "\n", + "The following cells demonstrate how to optimize the network parameters\n", + "using a high quality target image (supervised learning).\n", + "Here, we use the reconstruction of the 60min listmode data as the target image.\n", + "\n", + "**The purpose of the following cells is to demonstrate how the training of a network,\n", + "works in principle. The aim is not to train a network that is actually useful!**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5453c3aa", + "metadata": {}, + "outputs": [], + "source": [ + "# define the high quality target image (mini-batch)\n", + "lm_60min_ref_recon = sirf.STIR.ImageData(f\"{lm_60min_recon_output_file}.hv\")\n", + "\n", + "# we have to scale the 60min reconstruction, since it is not reconcstructed in kBq/ml\n", + "scale_factor = lm_ref_recon.as_array().mean() / lm_60min_ref_recon.as_array().mean()\n", + "lm_60min_ref_recon *= scale_factor\n", + "\n", + "target = (\n", + " torch.tensor(\n", + " lm_60min_ref_recon.as_array(),\n", + " device=dev,\n", + " dtype=torch.float32,\n", + " requires_grad=False,\n", + " )\n", + " .unsqueeze(0)\n", + " .unsqueeze(0)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e1b15626", + "metadata": {}, + "source": [ + "To train the network weights, we need to define an optimizer and a loss function.\n", + "Here we use the Adam optimizer with a learning rate of 1e-3 and the Mean Squared Error (MSE) loss function." 
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b73afa7a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "optimizer = torch.optim.Adam(varnet._convnet.parameters(), lr=1e-3)\n",
+ "# define the loss function\n",
+ "loss_fct = torch.nn.MSELoss()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1f1e99f4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# run num_epochs updates of the model parameters using backpropagation of the\n",
+ "# gradient of the loss function and the Adam optimizer\n",
+ "\n",
+ "num_epochs = 50\n",
+ "training_loss = torch.zeros(num_epochs)\n",
+ "\n",
+ "for i in range(num_epochs):\n",
+ "    # pass the input mini-batch through the network\n",
+ "    prediction = varnet(x_t)\n",
+ "    # calculate the MSE loss between the prediction and the target\n",
+ "    loss = loss_fct(prediction, target)\n",
+ "    # backpropagate the gradient of the loss through the network\n",
+ "    # (needed to update the trainable parameters of the network with an optimizer)\n",
+ "    optimizer.zero_grad()\n",
+ "    loss.backward()\n",
+ "    # update the trainable parameters of the network with the optimizer\n",
+ "    optimizer.step()\n",
+ "    print(i, loss.item())\n",
+ "    # save the training loss\n",
+ "    training_loss[i] = loss.item()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c542fe44",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# visualize the results\n",
+ "vmax = float(target.max())\n",
+ "sl = 71\n",
+ "\n",
+ "fig1, ax1 = plt.subplots(2, 3, figsize=(9, 6), tight_layout=True)\n",
+ "ax1[0, 0].imshow(x_t.cpu().numpy()[0, 0, sl, :, :], cmap=\"Greys\", vmin=0, vmax=vmax)\n",
+ "ax1[0, 1].imshow(\n",
+ "    prediction.detach().cpu().numpy()[0, 0, sl, :, :], cmap=\"Greys\", vmin=0, vmax=vmax\n",
+ ")\n",
+ "ax1[0, 2].imshow(target.cpu().numpy()[0, 0, sl, :, :], cmap=\"Greys\", vmin=0, vmax=vmax)\n",
+ "ax1[1, 0].imshow(\n",
+ "    x_t.cpu().numpy()[0, 0, sl, :, :] - target.cpu().numpy()[0, 0, sl, :, :],\n",
+ "    cmap=\"seismic\",\n",
+ "    vmin=-0.01,\n",
+ "    vmax=0.01,\n",
+ ")\n",
+ "ax1[1, 1].imshow(\n",
+ "    prediction.detach().cpu().numpy()[0, 0, sl, :, :]\n",
+ "    - target.cpu().numpy()[0, 0, sl, :, :],\n",
+ "    cmap=\"seismic\",\n",
+ "    vmin=-0.01,\n",
+ "    vmax=0.01,\n",
+ ")\n",
+ "\n",
+ "ax1[0, 0].set_title(\"network input\")\n",
+ "ax1[0, 1].set_title(\"network output\")\n",
+ "ax1[0, 2].set_title(\"target\")\n",
+ "ax1[1, 0].set_title(\"network input - target\")\n",
+ "ax1[1, 1].set_title(\"network output - target\")\n",
+ "fig1.show()\n",
+ "\n",
+ "# plot the training loss\n",
+ "fig2, ax2 = plt.subplots()\n",
+ "ax2.plot(training_loss.cpu().numpy())\n",
+ "ax2.set_xlabel(\"epoch\")\n",
+ "ax2.set_ylabel(\"training loss\")\n",
+ "fig2.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "jupytext": {
+ "cell_metadata_filter": "-all",
+ "main_language": "python",
+ "notebook_metadata_filter": "-all"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/notebooks/Deep_Learning_listmode_PET/05_custrom_unrolled_varnet.py b/notebooks/Deep_Learning_listmode_PET/05_custrom_unrolled_varnet.py
new file mode 100644
index 00000000..aff6e501
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/05_custrom_unrolled_varnet.py
@@ -0,0 +1,484 @@
+# %% [markdown]
+# Creating a custom unrolled variational network for listmode PET data
+# ====================================================================
+#
+# Learning objectives
+# -------------------
+#
+# 1. 
Learn how to implement and train a custom unrolled variational network fusing updates
+#    from listmode OSEM blocks and CNN blocks
+
+# %%
+import sirf.STIR
+import torch
+import matplotlib.pyplot as plt
+from pathlib import Path
+from sirf.Utilities import examples_data_path
+
+# acq_time must be 1min
+acq_time: str = "1min"
+
+data_path: Path = Path(examples_data_path("PET")) / "mMR"
+list_file: str = str(data_path / "list.l.hdr")
+norm_file: str = str(data_path / "norm.n.hdr")
+attn_file: str = str(data_path / "mu_map.hv")
+
+output_path: Path = Path(f"recons_{acq_time}")
+emission_sinogram_output_prefix: str = str(output_path / "emission_sinogram")
+scatter_sinogram_output_prefix: str = str(output_path / "scatter_sinogram")
+randoms_sinogram_output_prefix: str = str(output_path / "randoms_sinogram")
+attenuation_sinogram_output_prefix: str = str(output_path / "acf_sinogram")
+
+num_scatter_iter: int = 3
+
+lm_recon_output_file: str = str(output_path / "lm_recon")
+lm_60min_recon_output_file: str = str(Path(f"recons_60min") / "lm_recon")
+nxny: tuple[int, int] = (127, 127)
+num_subsets: int = 21
+
+if torch.cuda.is_available():
+    dev = "cuda:0"
+else:
+    dev = "cpu"
+
+# engine's messages go to files, except error messages, which go to stdout
+_ = sirf.STIR.MessageRedirector("info.txt", "warn.txt")
+
+# %% [markdown]
+# Load listmode data and create the acquisition model
+# ---------------------------------------------------
+#
+# In this demo example, we use an acquisition model that implements the geometric forward projection
+# by a ray tracing matrix multiplication. The effects of normalization, attenuation, scatter and randoms
+# are included via an acquisition sensitivity model and an additive background term.
+
+# %%
+sirf.STIR.AcquisitionData.set_storage_scheme("memory")
+listmode_data = sirf.STIR.ListmodeData(list_file)
+acq_data_template = listmode_data.acquisition_data_template()
+
+acq_data = sirf.STIR.AcquisitionData(
+    str(Path(f"{emission_sinogram_output_prefix}_f1g1d0b0.hs"))
+)
+
+# select acquisition model that implements the geometric
+# forward projection by a ray tracing matrix multiplication
+acq_model = sirf.STIR.AcquisitionModelUsingRayTracingMatrix()
+acq_model.set_num_tangential_LORs(1)
+
+randoms = sirf.STIR.AcquisitionData(str(Path(f"{randoms_sinogram_output_prefix}.hs")))
+
+ac_factors = sirf.STIR.AcquisitionData(
+    str(Path(f"{attenuation_sinogram_output_prefix}.hs"))
+)
+asm_attn = sirf.STIR.AcquisitionSensitivityModel(ac_factors)
+
+asm_norm = sirf.STIR.AcquisitionSensitivityModel(norm_file)
+asm = sirf.STIR.AcquisitionSensitivityModel(asm_norm, asm_attn)
+
+asm.set_up(acq_data)
+acq_model.set_acquisition_sensitivity(asm)
+
+scatter_estimate = sirf.STIR.AcquisitionData(
+    str(Path(f"{scatter_sinogram_output_prefix}_{num_scatter_iter}.hs"))
+)
+acq_model.set_background_term(randoms + scatter_estimate)
+
+# setup an initial (template) image based on the acquisition data template
+initial_image = acq_data_template.create_uniform_image(value=1, xy=nxny)
+
+# %% [markdown]
+# Setup of the Poisson log likelihood listmode objective function
+# ---------------------------------------------------------------
+#
+# Using the listmode data and the acquisition model, we can now set up the Poisson log likelihood objective function.
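+
+# %% [markdown]
+# The custom layer definitions further below (copied from notebook 04) repeatedly
+# convert between torch tensors and sirf.STIR images via numpy. As a reminder, a
+# minimal sketch of this conversion pattern (assuming a minibatch tensor `x` of
+# shape [1,1,nz,ny,nx] and a sirf.STIR template image `tmpl` - both placeholder
+# names):
+#
+# ```python
+# # torch -> sirf: strip the batch and channel dimensions
+# x_sirf = tmpl.clone()
+# x_sirf.fill(x.cpu().numpy()[0, 0, ...])
+#
+# # sirf -> torch: add the batch and channel dimensions back
+# x_torch = (
+#     torch.tensor(x_sirf.as_array(), device=x.device, dtype=x.dtype)
+#     .unsqueeze(0)
+#     .unsqueeze(0)
+# )
+# ```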
+ +# %% +lm_obj_fun = ( + sirf.STIR.PoissonLogLikelihoodWithLinearModelForMeanAndListModeDataWithProjMatrixByBin() +) +lm_obj_fun.set_acquisition_model(acq_model) +lm_obj_fun.set_acquisition_data(listmode_data) +lm_obj_fun.set_num_subsets(num_subsets) +lm_obj_fun.set_cache_max_size(1000000000) +lm_obj_fun.set_cache_path(str(output_path)) +print("setting up listmode objective function ...") +lm_obj_fun.set_up(initial_image) + +# %% [markdown] +# Setup of OSEM update layer +# -------------------------- +# +# See notebook 04. + +# %% +class SIRFPoissonlogLGradLayer(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: torch.Tensor, + objective_function, + sirf_template_image: sirf.STIR.ImageData, + subset: int, + ) -> torch.Tensor: + """(listmode) Poisson loglikelihood gradient layer forward pass + + Parameters + ---------- + ctx : context object + used to store objects that we need in the backward pass + x : torch.Tensor + minibatch tensor of shape [1,1,spatial_dimensions] containing the image + objective_function : sirf (listmode) objective function + the objective function that we use to calculate the gradient + sirf_template_image : sirf.STIR.ImageData + image template that we use to convert between torch tensors and sirf images + subset : int + subset number used for the gradient calculation + + Returns + ------- + torch.Tensor + minibatch tensor of shape [1,1,spatial_dimensions] containing the image + containing the gradient of the (listmode) Poisson log likelihood at x + """ + + # we use the context object ctx to store the matrix and other variables that we need in the backward pass + ctx.device = x.device + ctx.objective_function = objective_function + ctx.dtype = x.dtype + ctx.subset = subset + ctx.sirf_template_image = sirf_template_image + + # setup a new sirf.STIR ImageData object + x_sirf = sirf_template_image.clone() + # convert torch tensor to sirf image via numpy + x_sirf.fill(x.cpu().numpy()[0, 0, ...]) + + # save the input sirf.STIR ImageData for the backward pass + ctx.x_sirf = x_sirf + + # calculate the gradient of the Poisson log likelihood using SIRF + g_np = objective_function.gradient(x_sirf, subset).as_array() + + # convert back to torch tensor + y = ( + torch.tensor(g_np, device=ctx.device, dtype=ctx.dtype) + .unsqueeze(0) + .unsqueeze(0) + ) + + return y + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor | None + ) -> tuple[torch.Tensor | None, None, None, None]: + """(listmode) Poisson loglikelihood gradient layer backward pass + + Parameters + ---------- + ctx : context object + used to store objects that we need in the backward pass + grad_output : torch.Tensor | None + minibatch tensor of shape [1,1,spatial_dimensions] containing the gradient (called v in the autograd tutorial) + https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#optional-reading-vector-calculus-using-autograd + + Returns + ------- + tuple[torch.Tensor | None, None, None, None] + the Jacobian-vector product of the Poisson log likelihood gradient layer + """ + + if grad_output is None: + return None, None, None, None + else: + # convert torch tensor to sirf image via numpy + ctx.sirf_template_image.fill(grad_output.cpu().numpy()[0, 0, ...]) + + # calculate the Jacobian vector product (the Hessian applied to an image) using SIRF + back_sirf = ctx.objective_function.multiply_with_Hessian( + ctx.x_sirf, ctx.sirf_template_image, ctx.subset + ) + + # convert back to torch tensor via numpy + back = ( + torch.tensor(back_sirf.as_array(), device=ctx.device, 
dtype=ctx.dtype) + .unsqueeze(0) + .unsqueeze(0) + ) + + return back, None, None, None + + +# %% +class OSEMUpdateLayer(torch.nn.Module): + def __init__( + self, + objective_function, + sirf_template_image: sirf.STIR.ImageData, + subset: int, + device: str, + ) -> None: + """OSEM update layer + + Parameters + ---------- + objective_function : sirf (listmode) objective function + the objective function that we use to calculate the gradient + sirf_template_image : sirf.STIR.ImageData + image template that we use to convert between torch tensors and sirf images + subset : int + subset number used for the gradient calculation + device : str + device used for the calculations + + Returns + ------- + torch.Tensor + minibatch tensor of shape [1,1,spatial_dimensions] containing the OSEM + update of the input image using the Poisson log likelihood objective function + """ + super().__init__() + self._objective_function = objective_function + self._sirf_template_image: sirf.STIR.ImageData = sirf_template_image + self._subset: int = subset + + self._poisson_logL_grad_layer = SIRFPoissonlogLGradLayer.apply + + # setup a tensor containng the inverse of the subset sensitivity image adding the minibatch and channel dimensions + self._inv_sens_image: torch.Tensor = 1.0 / torch.tensor( + objective_function.get_subset_sensitivity(subset).as_array(), + dtype=torch.float32, + device=device, + ).unsqueeze(0).unsqueeze(0) + # replace positive infinity values with 0 (voxels with 0 sensitivity) + torch.nan_to_num(self._inv_sens_image, posinf=0, out=self._inv_sens_image) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """forward pass of the OSEM update layer + + Parameters + ---------- + x : torch.Tensor + minibatch tensor of shape [1,1,spatial_dimensions] containing the image + + Returns + ------- + torch.Tensor + OSEM update image + """ + grad_x: torch.Tensor = self._poisson_logL_grad_layer( + x, self._objective_function, self._sirf_template_image, self._subset + ) + return x + x * self._inv_sens_image * grad_x + + +# %% [markdown] +# Exercise 5.1 +# ------------ +# +# Implement the forward pass of the unrolled OSEM Variational Network with 2 blocks shown below. +# Start from the code below and fill in the missing parts. 
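+
+# %% [markdown]
+# To make the structure of one unrolled block explicit, here is a minimal sketch of
+# a single block (assuming the fusion is a weighted sum with a single scalar weight
+# `w` - one of several reasonable choices, not necessarily identical to the
+# reference solution):
+#
+# ```python
+# def unrolled_block(x, osem_layer, convnet, w, relu):
+#     x_osem = osem_layer(x)  # data fidelity update (xOSEM in the figure)
+#     x_nn = convnet(x)       # learned regularization (xNN in the figure)
+#     return relu(x_osem + w * x_nn)  # fusion + ReLU to keep the image non-negative
+# ```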
+
+# %%
+class UnrolledOSEMVarNet(torch.nn.Module):
+    def __init__(
+        self,
+        objective_function,
+        sirf_template_image: sirf.STIR.ImageData,
+        convnet: torch.nn.Module,
+        device: str,
+    ) -> None:
+        """Unrolled OSEM Variational Network with 2 blocks
+
+        Parameters
+        ----------
+        objective_function : sirf.STIR objective function
+            (listmode) Poisson logL objective function
+            that we use for the OSEM updates
+        sirf_template_image : sirf.STIR.ImageData
+            used for the conversion between torch tensors and sirf images
+        convnet : torch.nn.Module
+            a (convolutional) neural network that maps a minibatch tensor
+            of shape [1,1,spatial_dimensions] onto a minibatch tensor of the same shape
+        device : str
+            device used for the calculations
+        """
+        super().__init__()
+
+        # OSEM update layer using the 1st subset of the listmode data
+        self._osem_step_layer0 = OSEMUpdateLayer(
+            objective_function, sirf_template_image, 0, device
+        )
+
+        # OSEM update layer using the 2nd subset of the listmode data
+        self._osem_step_layer1 = OSEMUpdateLayer(
+            objective_function, sirf_template_image, 1, device
+        )
+        self._convnet = convnet
+        self._relu = torch.nn.ReLU()
+
+        # trainable parameters for the fusion of the OSEM update and the CNN output in the two blocks
+        self._fusion_weight0 = torch.nn.Parameter(
+            torch.ones(1, device=device, dtype=torch.float32)
+        )
+        self._fusion_weight1 = torch.nn.Parameter(
+            torch.ones(1, device=device, dtype=torch.float32)
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """forward pass of the Unrolled OSEM Variational Network
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            minibatch tensor of shape [1,1,spatial_dimensions] containing the image
+
+        Returns
+        -------
+        torch.Tensor
+            minibatch tensor of shape [1,1,spatial_dimensions] containing the output of the network
+        """
+
+        # =============================================================
+        # =============================================================
+        # YOUR CODE HERE
+        #
+        # forward pass contains two blocks where each block
+        # consists of a fusion of the OSEM update and the CNN output
+        #
+        # the fusion is a weighted sum of the OSEM update and the CNN output
+        # using the respective fusion weights
+        #
+        # =============================================================
+        # =============================================================
+
+
+# %%
+# load the reference OSEM reconstruction that we use as input to our network
+lm_ref_recon = sirf.STIR.ImageData(f"{lm_recon_output_file}.hv")
+x_t = (
+    torch.tensor(
+        lm_ref_recon.as_array(), device=dev, dtype=torch.float32, requires_grad=False
+    )
+    .unsqueeze(0)
+    .unsqueeze(0)
+)
+
+# define a minimal CNN that we use within the unrolled OSEM Variational Network
+cnn = torch.nn.Sequential(
+    torch.nn.Conv3d(1, 5, 5, padding="same", bias=False),
+    torch.nn.Conv3d(5, 5, 5, padding="same", bias=False),
+    torch.nn.PReLU(device=dev),
+    torch.nn.Conv3d(5, 5, 5, padding="same", bias=False),
+    torch.nn.Conv3d(5, 5, 5, padding="same", bias=False),
+    torch.nn.PReLU(device=dev),
+    torch.nn.Conv3d(5, 1, 1, padding="same", bias=False),
+).to(dev)
+
+
+# setup the unrolled OSEM Variational Network using the sirf.STIR listmode objective function
+# and the CNN
+varnet = UnrolledOSEMVarNet(lm_obj_fun, initial_image, cnn, dev)
+
+# %% [markdown]
+#
+# Supervised optimization of the network parameters
+# -------------------------------------------------
+#
+# The following cells demonstrate how to optimize the network parameters
+# using a high quality target image (supervised learning).
+# Here, we use the reconstruction of the 60min listmode data as the target image.
+#
+# **The purpose of the following cells is to demonstrate how the training of a network
+# works in principle. The aim is not to train a network that is actually useful!**
+
+# %%
+# define the high quality target image (mini-batch)
+lm_60min_ref_recon = sirf.STIR.ImageData(f"{lm_60min_recon_output_file}.hv")
+
+# we have to scale the 60min reconstruction, since it is not reconstructed in kBq/ml
+scale_factor = lm_ref_recon.as_array().mean() / lm_60min_ref_recon.as_array().mean()
+lm_60min_ref_recon *= scale_factor
+
+target = (
+    torch.tensor(
+        lm_60min_ref_recon.as_array(),
+        device=dev,
+        dtype=torch.float32,
+        requires_grad=False,
+    )
+    .unsqueeze(0)
+    .unsqueeze(0)
+)
+
+# %% [markdown]
+# To train the network weights, we need to define an optimizer and a loss function.
+# Here we use the Adam optimizer with a learning rate of 1e-3 and the Mean Squared Error (MSE) loss function.
+
+# %%
+optimizer = torch.optim.Adam(varnet._convnet.parameters(), lr=1e-3)
+# define the loss function
+loss_fct = torch.nn.MSELoss()
+
+# %%
+# run num_epochs updates of the model parameters using backpropagation of the
+# gradient of the loss function and the Adam optimizer
+
+num_epochs = 50
+training_loss = torch.zeros(num_epochs)
+
+for i in range(num_epochs):
+    # pass the input mini-batch through the network
+    prediction = varnet(x_t)
+    # calculate the MSE loss between the prediction and the target
+    loss = loss_fct(prediction, target)
+    # backpropagate the gradient of the loss through the network
+    # (needed to update the trainable parameters of the network with an optimizer)
+    optimizer.zero_grad()
+    loss.backward()
+    # update the trainable parameters of the network with the optimizer
+    optimizer.step()
+    print(i, loss.item())
+    # save the training loss
+    training_loss[i] = loss.item()
+
+# %%
+# visualize the results
+vmax = float(target.max())
+sl = 71
+
+fig1, ax1 = plt.subplots(2, 3, figsize=(9, 6), tight_layout=True)
+ax1[0, 0].imshow(x_t.cpu().numpy()[0, 0, sl, :, :], cmap="Greys", vmin=0, vmax=vmax)
+ax1[0, 1].imshow(
+    prediction.detach().cpu().numpy()[0, 0, sl, :, :], cmap="Greys", vmin=0, vmax=vmax
+)
+ax1[0, 2].imshow(target.cpu().numpy()[0, 0, sl, :, :], cmap="Greys", vmin=0, vmax=vmax)
+ax1[1, 0].imshow(
+    x_t.cpu().numpy()[0, 0, sl, :, :] - target.cpu().numpy()[0, 0, sl, :, :],
+    cmap="seismic",
+    vmin=-0.01,
+    vmax=0.01,
+)
+ax1[1, 1].imshow(
+    prediction.detach().cpu().numpy()[0, 0, sl, :, :]
+    - target.cpu().numpy()[0, 0, sl, :, :],
+    cmap="seismic",
+    vmin=-0.01,
+    vmax=0.01,
+)
+
+ax1[0, 0].set_title("network input")
+ax1[0, 1].set_title("network output")
+ax1[0, 2].set_title("target")
+ax1[1, 0].set_title("network input - target")
+ax1[1, 1].set_title("network output - target")
+fig1.show()
+
+# plot the training loss
+fig2, ax2 = plt.subplots()
+ax2.plot(training_loss.cpu().numpy())
+ax2.set_xlabel("epoch")
+ax2.set_ylabel("training loss")
+fig2.show()
diff --git a/notebooks/Deep_Learning_listmode_PET/06_outlook.ipynb b/notebooks/Deep_Learning_listmode_PET/06_outlook.ipynb
new file mode 100644
index 00000000..39e1a2ab
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/06_outlook.ipynb
@@ -0,0 +1,43 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "046e9138",
+ "metadata": {},
+ "source": [
+ "Outlook\n",
+ "=======\n",
+ "\n",
+ "Congratulations, you made it through all the notebooks and should\n",
+ "now be able to understand the basics of deep learning for (listmode) PET
reconstruction.\n",
+ "\n",
+ "Follow-up questions / proposals to think about\n",
+ "----------------------------------------------\n",
+ "\n",
+ "1. How could we incorporate training on mini-batches with a batch size greater than 1?\n",
+ "2. What happens if we increase the number of unrolled blocks in the variational network?\n",
+ "3. What is the impact on the CNN size / architecture and the loss function in the variational network?\n",
+ "4. Most neural networks work best if the image intensities are normalized. How could we include this in the training process?\n",
+ "\n",
+ "Training speed\n",
+ "--------------\n",
+ "\n",
+ "Currently, the training is quite slow due to several reasons:\n",
+ "- the listmode projections are currently performed on the CPU\n",
+ "- there is a lot of memory transfer between CPU and GPU during training (OSEM block on CPU, CNN on GPU)\n",
+ "\n",
+ "**However,** SIRF and STIR are community projects that are constantly being improved.\n",
+ "So, it is likely that the training speed will increase in the future."
+ ]
+ }
+ ],
+ "metadata": {
+ "jupytext": {
+ "cell_metadata_filter": "-all",
+ "main_language": "python",
+ "notebook_metadata_filter": "-all"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/notebooks/Deep_Learning_listmode_PET/06_outlook.py b/notebooks/Deep_Learning_listmode_PET/06_outlook.py
new file mode 100644
index 00000000..61b8dec3
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/06_outlook.py
@@ -0,0 +1,24 @@
+# %% [markdown]
+# Outlook
+# =======
+#
+# Congratulations, you made it through all the notebooks and should
+# now be able to understand the basics of deep learning for (listmode) PET reconstruction.
+#
+# Follow-up questions / proposals to think about
+# ----------------------------------------------
+#
+# 1. How could we incorporate training on mini-batches with a batch size greater than 1?
+# 2. What happens if we increase the number of unrolled blocks in the variational network?
+# 3. What is the impact on the CNN size / architecture and the loss function in the variational network?
+# 4. Most neural networks work best if the image intensities are normalized. How could we include this in the training process?
+#
+# Training speed
+# --------------
+#
+# Currently, the training is quite slow due to several reasons:
+# - the listmode projections are currently performed on the CPU
+# - there is a lot of memory transfer between CPU and GPU during training (OSEM block on CPU, CNN on GPU)
+#
+# **However,** SIRF and STIR are community projects that are constantly being improved.
+# So, it is likely that the training speed will increase in the future.
diff --git a/notebooks/Deep_Learning_listmode_PET/README.md b/notebooks/Deep_Learning_listmode_PET/README.md
new file mode 100644
index 00000000..486bfa2d
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/README.md
@@ -0,0 +1,31 @@
+# STIR listmode (LM) Deep learning (DL) reconstruction notebooks
+
+## Structure of the exercises
+
+- Intro / motivation: `00_introduction.ipynb`
+  - problem setting (DL recon network that maps from "low quality" to "higher quality" images)
+  - Why listmode and not sinograms?
+
+- Running sinogram and listmode OSEM reconstruction in sirf.STIR: `01_SIRF_listmode_recon.ipynb`
+  - learn how to run (listmode) OSEM in sirf.STIR
+  - understand the relation between the OSEM update and the gradient of the Poisson logL
+
+- A deep dive into different (image) array classes (SIRF images vs numpy arrays vs torch tensors): `02_SIRF_vs_torch_arrays.ipynb`
+  - differences between sirf.STIR.ImageData, numpy arrays and pytorch tensors
+  - how to convert from SIRF images to torch tensors and back
+
+- Defining custom (non-pytorch) layers that are compatible with pytorch's autograd functionality: `03_custom_torch_layers.ipynb`
+  - basics of gradient backpropagation
+  - understand what needs to be implemented in the backward pass based on a simple numpy matrix
+    multiplication layer
+
+- Defining a custom (listmode) Poisson logL gradient step layer using sirf.STIR and pytorch: `04_custom_sirf_Poisson_logL_layer.ipynb`
+  - use of sirf.STIR for a step in the direction of the gradient of the Poisson logL
+  - understanding of the respective Jacobian vector product for the backward pass
+  - combining the Poisson logL gradient step layer and basic tensor operations into an OSEM update layer
+
+- Demo training of a minimal unrolled variational network: `05_custrom_unrolled_varnet.ipynb`
+  - combination of OSEM update layers and a CNN into an unrolled variational network
+  - demo supervised training based on a single low count data set + high count reference image
+
+- Outlook: `06_outlook.ipynb`
diff --git a/notebooks/Deep_Learning_listmode_PET/TODO.txt b/notebooks/Deep_Learning_listmode_PET/TODO.txt
new file mode 100644
index 00000000..5dbd3ee2
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/TODO.txt
@@ -0,0 +1,3 @@
+- outlook
+- use single norm / mu file for 1/60min
+- rescale 60min ref
\ No newline at end of file
diff --git a/notebooks/Deep_Learning_listmode_PET/figs/.gitignore b/notebooks/Deep_Learning_listmode_PET/figs/.gitignore
new file mode 100644
index 00000000..8d71bf9c
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/figs/.gitignore
@@ -0,0 +1 @@
+*.bkp
diff --git a/notebooks/Deep_Learning_listmode_PET/figs/osem_layer.drawio b/notebooks/Deep_Learning_listmode_PET/figs/osem_layer.drawio
new file mode 100644
index 00000000..6cc394c7
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/figs/osem_layer.drawio
@@ -0,0 +1,109 @@
diff --git a/notebooks/Deep_Learning_listmode_PET/figs/osem_layer.drawio.svg b/notebooks/Deep_Learning_listmode_PET/figs/osem_layer.drawio.svg
new file mode 100644
index 00000000..2a03da73
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/figs/osem_layer.drawio.svg
@@ -0,0 +1,4 @@
+[SVG figure: "Decomposition of the OSEM update layer into feed forward network using custom Poisson gradient logL layer and basic pytorch tensor operations" - a pytorch input image tensor (mini batch) xin of shape (1,1,nz,ny,nx) passes through the custom pytorch Poisson logL gradient layer (using the sirf.STIR (listmode) objective function, including the sirf.STIR acquired (listmode) PET data and acquisition model), a pointwise multiplication with the inverse sensitivity image, and a pointwise addition, giving the output tensor xout = xin + (xin/s) (grad logL)(xin) of shape (1,1,nz,ny,nx)]
\ No newline at end of file
diff --git a/notebooks/Deep_Learning_listmode_PET/figs/osem_varnet.drawio b/notebooks/Deep_Learning_listmode_PET/figs/osem_varnet.drawio
new file mode 100644
index 00000000..33b7c137
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/figs/osem_varnet.drawio
@@ -0,0 +1,156 @@
diff --git a/notebooks/Deep_Learning_listmode_PET/figs/osem_varnet.drawio.svg b/notebooks/Deep_Learning_listmode_PET/figs/osem_varnet.drawio.svg
new file mode 100644
index 00000000..97afcf27
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/figs/osem_varnet.drawio.svg
@@ -0,0 +1,4 @@
+[SVG figure: "Unrolled variational network for reconstruction" - a pytorch input image tensor (mini batch) xin of shape (1,1,nz,ny,nx) containing "low quality" images is processed by UNROLLED UPDATE 1..N; each update fuses a custom pytorch OSEM update layer (using the sirf.STIR (listmode) objective function, including the sirf.STIR acquired (listmode) PET data and acquisition model) with a NN with trainable weights via a fusion (e.g. weighted sum) of xOSEM and xNN followed by a ReLU; the prediction output tensor xout of shape (1,1,nz,ny,nx) contains "high quality" images]
\ No newline at end of file
diff --git a/notebooks/Deep_Learning_listmode_PET/figs/poisson_logL_grad_layer.drawio b/notebooks/Deep_Learning_listmode_PET/figs/poisson_logL_grad_layer.drawio
new file mode 100644
index 00000000..894b207b
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/figs/poisson_logL_grad_layer.drawio
@@ -0,0 +1,40 @@
diff --git a/notebooks/Deep_Learning_listmode_PET/figs/poisson_logL_grad_layer.drawio.svg b/notebooks/Deep_Learning_listmode_PET/figs/poisson_logL_grad_layer.drawio.svg
new file mode 100644
index 00000000..8ba5d909
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/figs/poisson_logL_grad_layer.drawio.svg
@@ -0,0 +1,4 @@
[figure poisson_logL_grad_layer.drawio.svg: a custom pytorch Poisson logL gradient layer maps a pytorch input image tensor xin (mini batch, shape (1,1,nz,ny,nx)) to the output tensor xout = (∇logL)(xin) of the same shape, using a sirf.STIR (listmode) objective function that includes the sirf.STIR acquired (listmode) PET data and the sirf.STIR acquisition model.]
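Not part of the patch, but useful for orientation: with the affine model $\bar{y}(x) = Ax + s$ from snippets/solution_0_1.md below, and with $v$ denoting the tensor arriving in the backward pass, the two maps this layer implements are

$$x_{\text{out}} = \nabla_x \log L(y|x) \big|_{x_{\text{in}}} = A^T \left( \frac{y}{A x_{\text{in}} + s} - 1 \right)$$

$$\bar{v} = \nabla^2_x \log L(y|x) \big|_{x_{\text{in}}} \, v = -A^T \left( \frac{y \odot (A v)}{(A x_{\text{in}} + s)^2} \right) ,$$

which the solutions below evaluate via `objective_function.gradient` and `objective_function.multiply_with_Hessian`.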
\ No newline at end of file
diff --git a/notebooks/Deep_Learning_listmode_PET/figs/varnet.drawio b/notebooks/Deep_Learning_listmode_PET/figs/varnet.drawio
new file mode 100644
index 00000000..ed1b00f3
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/figs/varnet.drawio
@@ -0,0 +1,127 @@
[drawio XML source omitted]
diff --git a/notebooks/Deep_Learning_listmode_PET/figs/varnet.drawio.svg b/notebooks/Deep_Learning_listmode_PET/figs/varnet.drawio.svg
new file mode 100644
index 00000000..23835362
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/figs/varnet.drawio.svg
@@ -0,0 +1,4 @@
[figure varnet.drawio.svg: unrolled variational PET listmode network. INPUT: an initial image estimate x0, together with the acquired (listmode) PET data plus quantitative corrections (attenuation, scatter, randoms, normalization). Layers 1 ... n each consist of a (listmode) PET data fidelity update layer (outputs xd1 ... xdn) and a CNN with trainable parameters acting as a learned regularization (outputs xr1 ... xrn); combining both gives the intermediate images x1, x2, ..., xn. OUTPUT: xn, a "high quality" reconstruction obtained from "low quality" data. Caption: "Unrolled variational PET listmode network."]
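A hedged sketch of the n-layer structure in this figure (all names are illustrative, and the fusion of xd_k and xr_k is left abstract here since the figure does not fix it):

import torch

def unrolled_varnet(x0: torch.Tensor, data_fid_layers, cnns, fuse) -> torch.Tensor:
    # alternate (listmode) data fidelity updates and learned regularization
    x = x0
    for data_fid_layer, cnn in zip(data_fid_layers, cnns):
        xd = data_fid_layer(x)  # (listmode) PET data fidelity update
        xr = cnn(x)             # learned regularization
        x = fuse(xd, xr)        # e.g. a weighted sum followed by a ReLU
    return x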
\ No newline at end of file
diff --git a/notebooks/Deep_Learning_listmode_PET/snippets/solution_0_1.md b/notebooks/Deep_Learning_listmode_PET/snippets/solution_0_1.md
new file mode 100644
index 00000000..313cd31f
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/snippets/solution_0_1.md
@@ -0,0 +1,17 @@
+In matrix notation, the gradient of the Poisson log-likelihood is given by:
+$$ \nabla_x \log L(y|x) = A^T \left( \frac{y}{\bar{y}(x)} - 1 \right) = A^T \left( \frac{y}{Ax + s} - 1 \right) .$$
+
+For a given image voxel $j$, the corresponding expression reads:
+$$ \frac{\partial \log L(y|x)}{\partial x_j} = \sum_{i=1}^m a_{ij} \left( \frac{y_i}{\sum_{k=1}^n a_{ik} x_k + s_i} - 1 \right) .$$
+
+Using a list of detected coincidence events $e$ instead of a sinogram, the gradient of the Poisson log-likelihood becomes:
+$$ \frac{\partial \log L(y|x)}{\partial x_j} = \sum_{\text{events} \, e} \frac{a_{i_e j}}{\sum_{k=1}^n a_{i_e k} x_k + s_{i_e}} - \sum_{i=1}^m a_{ij} ,$$
+where $i_e$ is the (TOF) sinogram bin corresponding to event $e$.
+
+**Note:**
+- SIRF (using the STIR backend) already provides implementations of the (TOF) PET acquisition forward model and
+  the gradient of the Poisson log-likelihood, so we do not have to re-implement these.
+- Using SIRF with STIR, this gradient can be evaluated in listmode.
+- If the number of listmode events is much smaller than the number of (TOF) sinogram bins, evaluating the gradient
+  in listmode can be more efficient.
+
diff --git a/notebooks/Deep_Learning_listmode_PET/snippets/solution_1_1.py b/notebooks/Deep_Learning_listmode_PET/snippets/solution_1_1.py
new file mode 100644
index 00000000..af7fdef1
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/snippets/solution_1_1.py
@@ -0,0 +1 @@
+new_image = initial_image + 0.001 * obj_fun.gradient(initial_image, 0)
diff --git a/notebooks/Deep_Learning_listmode_PET/snippets/solution_1_2.py b/notebooks/Deep_Learning_listmode_PET/snippets/solution_1_2.py
new file mode 100644
index 00000000..bde74e77
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/snippets/solution_1_2.py
@@ -0,0 +1,10 @@
+subset = 0
+subset_grad = obj_fun.gradient(initial_image, subset)
+# this is only correct if the sensitivity image is greater than 0 everywhere
+# (see the next exercise for more details)
+step = initial_image / obj_fun.get_subset_sensitivity(subset)
+osem_update = initial_image + step * subset_grad
+
+# the maximum value of the updated image is nan, because the sensitivity image is 0
+# in some places, which needs special attention
+print(osem_update.max())
diff --git a/notebooks/Deep_Learning_listmode_PET/snippets/solution_1_3.py b/notebooks/Deep_Learning_listmode_PET/snippets/solution_1_3.py
new file mode 100644
index 00000000..68e9b231
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/snippets/solution_1_3.py
@@ -0,0 +1,23 @@
+# calculate the inverse of the subset sensitivity images, correctly accounting for voxels
+# where the sensitivity images are zero
+
+inverse_sens_images = []
+
+for i in range(num_subsets):
+    inverse_sens_image = acq_data.create_uniform_image(value=0, xy=nxny)
+    inverse_sens_image_np = np.zeros(
+        inverse_sens_image.shape, dtype=inverse_sens_image.as_array().dtype
+    )
+    sens_image_np = obj_fun.get_subset_sensitivity(i).as_array()
+    np.divide(1, sens_image_np, out=inverse_sens_image_np, where=sens_image_np > 0)
+    inverse_sens_image.fill(inverse_sens_image_np)
+    inverse_sens_images.append(inverse_sens_image)
+
+for it in range(num_iter):
+    for i in range(num_subsets):
+        subset_grad = obj_fun.gradient(recon, i)
+        recon = recon + recon * inverse_sens_images[i] * subset_grad
+
+fig2, ax2 = plt.subplots(1, 1, figsize=(4, 4), tight_layout=True)
+ax2.imshow(recon.as_array()[71, :, :], cmap="Greys", vmin=0, vmax=vmax)
+fig2.show()
diff --git a/notebooks/Deep_Learning_listmode_PET/snippets/solution_1_4.py b/notebooks/Deep_Learning_listmode_PET/snippets/solution_1_4.py
new file mode 100644
index 00000000..151149b8
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/snippets/solution_1_4.py
@@ -0,0 +1,32 @@
+lm_obj_fun.set_up(initial_image)
+
+# initialize the reconstruction with ones where the sensitivity image is greater than 0
+# all other values are set to zero and are not updated during the reconstruction
+lm_recon = initial_image.copy()
+lm_recon.fill(lm_obj_fun.get_subset_sensitivity(0).as_array() > 0)
+
+lm_inverse_sens_images = []
+
+for i in range(num_subsets):
+    lm_inverse_sens_image = acq_data.create_uniform_image(value=0, xy=nxny)
+    lm_inverse_sens_image_np = np.zeros(
+        lm_inverse_sens_image.shape, dtype=lm_inverse_sens_image.as_array().dtype
+    )
+
+    lm_sens_image_np = lm_obj_fun.get_subset_sensitivity(i).as_array()
+
+    np.divide(
+        1, lm_sens_image_np, out=lm_inverse_sens_image_np, where=lm_sens_image_np > 0
+    )
+    lm_inverse_sens_image.fill(lm_inverse_sens_image_np)
+    lm_inverse_sens_images.append(lm_inverse_sens_image)
+
+for it in range(num_iter):
+    for i in range(num_subsets):
+        subset_grad = lm_obj_fun.gradient(lm_recon, i)
+        lm_recon = lm_recon + lm_recon * lm_inverse_sens_images[i] * subset_grad
+
+
+fig4, ax4 = plt.subplots(1, 1, figsize=(4, 4), tight_layout=True)
+ax4.imshow(lm_recon.as_array()[71, :, :], cmap="Greys", vmin=0, vmax=vmax)
+fig4.show()
diff --git a/notebooks/Deep_Learning_listmode_PET/snippets/solution_2_1.py b/notebooks/Deep_Learning_listmode_PET/snippets/solution_2_1.py
new file mode 100644
index 00000000..123587d3
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/snippets/solution_2_1.py
@@ -0,0 +1,12 @@
+# we first create a numpy array that we fill in a loop
+numpy_image_3: np.ndarray = np.zeros(sirf_image_1.shape, dtype=sirf_image_1.dtype)
+for i in range(numpy_image_3.shape[0]):
+    numpy_image_3[i, :, :] = i ** 2
+
+sirf_image_3: sirf.STIR.ImageData = acq_data.create_uniform_image(0.0)
+sirf_image_3.fill(numpy_image_3)
+
+print()
+print(f"sirf_image_3 shape   .: {sirf_image_3.shape}")
+print(f"sirf_image_3 spacing .: {sirf_image_3.spacing}")
+print(f"sirf_image_3 max     .: {sirf_image_3.max()}")
diff --git a/notebooks/Deep_Learning_listmode_PET/snippets/solution_3_1.py b/notebooks/Deep_Learning_listmode_PET/snippets/solution_3_1.py
new file mode 100644
index 00000000..e69de29b
diff --git a/notebooks/Deep_Learning_listmode_PET/snippets/solution_4_1.py b/notebooks/Deep_Learning_listmode_PET/snippets/solution_4_1.py
new file mode 100644
index 00000000..aa7ea434
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/snippets/solution_4_1.py
@@ -0,0 +1,97 @@
+class SIRFPoissonlogLGradLayer(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        x: torch.Tensor,
+        objective_function,
+        sirf_template_image: sirf.STIR.ImageData,
+        subset: int,
+    ) -> torch.Tensor:
+        """(listmode) Poisson loglikelihood gradient layer forward pass
+
+        Parameters
+        ----------
+        ctx : context object
+            used to store objects that we need in the backward pass
+        x : torch.Tensor
+            minibatch tensor of shape [1,1,spatial_dimensions] containing the image
+        objective_function : sirf (listmode) objective function
+            the objective function that we use to calculate the gradient
+        sirf_template_image : sirf.STIR.ImageData
+            image template that we use to convert between torch tensors and sirf images
+        subset : int
+            subset number used for the gradient calculation
+
+        Returns
+        -------
+        torch.Tensor
+            minibatch tensor of shape [1,1,spatial_dimensions] containing the gradient
+            of the (listmode) Poisson log likelihood at x
+        """
+
+        # we use the context object ctx to store the objective function and other
+        # objects that we need in the backward pass
+        ctx.device = x.device
+        ctx.objective_function = objective_function
+        ctx.dtype = x.dtype
+        ctx.subset = subset
+        # clone the template so that filling it in the backward pass does not
+        # overwrite the caller's image
+        ctx.sirf_template_image = sirf_template_image.clone()
+
+        # setup a new sirf.STIR ImageData object
+        x_sirf = sirf_template_image.clone()
+        # convert torch tensor to sirf image via numpy
+        x_sirf.fill(x.cpu().numpy()[0, 0, ...])
+
+        # save the input sirf.STIR ImageData for the backward pass
+        ctx.x_sirf = x_sirf
+
+        # calculate the gradient of the Poisson log likelihood using SIRF
+        g_np = objective_function.gradient(x_sirf, subset).as_array()
+
+        # convert back to torch tensor
+        y = (
+            torch.tensor(g_np, device=ctx.device, dtype=ctx.dtype)
+            .unsqueeze(0)
+            .unsqueeze(0)
+        )
+
+        return y
+
+    @staticmethod
+    def backward(
+        ctx, grad_output: torch.Tensor | None
+    ) -> tuple[torch.Tensor | None, None, None, None]:
+        """(listmode) Poisson loglikelihood gradient layer backward pass
+
+        Parameters
+        ----------
+        ctx : context object
+            used to store objects that we need in the backward pass
+        grad_output : torch.Tensor | None
+            minibatch tensor of shape [1,1,spatial_dimensions] containing the gradient (called v in the autograd tutorial)
+            https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#optional-reading-vector-calculus-using-autograd
+
+        Returns
+        -------
+        tuple[torch.Tensor | None, None, None, None]
+            the Jacobian-vector product of the Poisson log likelihood gradient layer
+        """
+
+        if grad_output is None:
+            return None, None, None, None
+        else:
+            # convert torch tensor to sirf image via numpy
+            ctx.sirf_template_image.fill(grad_output.cpu().numpy()[0, 0, ...])
+
+            # calculate the Jacobian vector product (the Hessian applied to an image) using SIRF
+            back_sirf = ctx.objective_function.multiply_with_Hessian(
+                ctx.x_sirf, ctx.sirf_template_image, ctx.subset
+            )
+
+            # convert back to torch tensor via numpy
+            back = (
+                torch.tensor(back_sirf.as_array(), device=ctx.device, dtype=ctx.dtype)
+                .unsqueeze(0)
+                .unsqueeze(0)
+            )
+
+            return back, None, None, None
diff --git a/notebooks/Deep_Learning_listmode_PET/snippets/solution_4_2.py b/notebooks/Deep_Learning_listmode_PET/snippets/solution_4_2.py
new file mode 100644
index 00000000..dfcdc9dc
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/snippets/solution_4_2.py
@@ -0,0 +1,60 @@
+class OSEMUpdateLayer(torch.nn.Module):
+    def __init__(
+        self,
+        objective_function,
+        sirf_template_image: sirf.STIR.ImageData,
+        subset: int,
+        device: str,
+    ) -> None:
+        """OSEM update layer
+
+        Parameters
+        ----------
+        objective_function : sirf (listmode) objective function
+            the objective function that we use to calculate the gradient
+        sirf_template_image : sirf.STIR.ImageData
+            image template that we use to convert between torch tensors and sirf images
+        subset : int
+            subset number used for the gradient calculation
+        device : str
+            device used for the calculations
+        """
+        super().__init__()
+        self._objective_function = objective_function
+        self._sirf_template_image: sirf.STIR.ImageData = sirf_template_image
+        self._subset: int = subset
+
+        self._poisson_logL_grad_layer = SIRFPoissonlogLGradLayer.apply
+
+        # setup a tensor containing the inverse of the subset sensitivity image,
+        # adding the minibatch and channel dimensions
+        self._inv_sens_image: torch.Tensor = 1.0 / torch.tensor(
+            objective_function.get_subset_sensitivity(subset).as_array(),
+            dtype=torch.float32,
+            device=device,
+        ).unsqueeze(0).unsqueeze(0)
+        # replace positive infinity values with 0 (voxels with 0 sensitivity)
+        torch.nan_to_num(self._inv_sens_image, posinf=0, out=self._inv_sens_image)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """forward pass of the OSEM update layer
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            minibatch tensor of shape [1,1,spatial_dimensions] containing the image
+
+        Returns
+        -------
+        torch.Tensor
+            minibatch tensor of shape [1,1,spatial_dimensions] containing the OSEM
+            update of the input image using the Poisson log likelihood objective function
+        """
+        grad_x: torch.Tensor = self._poisson_logL_grad_layer(
+            x, self._objective_function, self._sirf_template_image, self._subset
+        )
+        return x + x * self._inv_sens_image * grad_x
diff --git a/notebooks/Deep_Learning_listmode_PET/snippets/solution_5_1.py b/notebooks/Deep_Learning_listmode_PET/snippets/solution_5_1.py
new file mode 100644
index 00000000..d77ac9e6
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/snippets/solution_5_1.py
@@ -0,0 +1,68 @@
+class UnrolledOSEMVarNet(torch.nn.Module):
+    def __init__(
+        self,
+        objective_function,
+        sirf_template_image: sirf.STIR.ImageData,
+        convnet: torch.nn.Module,
+        device: str,
+    ) -> None:
+        """Unrolled OSEM Variational Network with 2 blocks
+
+        Parameters
+        ----------
+        objective_function : sirf.STIR objective function
+            (listmode) Poisson logL objective function
+            that we use for the OSEM updates
+        sirf_template_image : sirf.STIR.ImageData
+            used for the conversion between torch tensors and sirf images
+        convnet : torch.nn.Module
+            a (convolutional) neural network that maps a minibatch tensor
+            of shape [1,1,spatial_dimensions] onto a minibatch tensor of the same shape
+        device : str
+            device used for the calculations
+        """
+        super().__init__()
+
+        # OSEM update layer using the 1st subset of the listmode data
+        self._osem_step_layer0 = OSEMUpdateLayer(
+            objective_function, sirf_template_image, 0, device
+        )
+
+        # OSEM update layer using the 2nd subset of the listmode data
+        self._osem_step_layer1 = OSEMUpdateLayer(
+            objective_function, sirf_template_image, 1, device
+        )
+        self._convnet = convnet
+        self._relu = torch.nn.ReLU()
+
+        # trainable parameters for the fusion of the OSEM update and the CNN output in the two blocks
+        # we start with a weight of 10 for the fusion
+        # a good starting value depends on the scale of the input image
+        self._fusion_weight0 = torch.nn.Parameter(
+            10 * torch.ones(1, device=device, dtype=torch.float32)
+        )
+        self._fusion_weight1 = torch.nn.Parameter(
+            10 * torch.ones(1, device=device, dtype=torch.float32)
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """forward pass of the Unrolled OSEM Variational Network
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            minibatch tensor of shape [1,1,spatial_dimensions] containing the image
+
+        Returns
+        -------
+        torch.Tensor
+            minibatch tensor of shape [1,1,spatial_dimensions] containing the output of the network
+        """
+        x1 = self._relu(
+            self._fusion_weight0 * self._convnet(x) + self._osem_step_layer0(x)
+        )
+        x2 = self._relu(
+            self._fusion_weight1 * self._convnet(x1) + self._osem_step_layer1(x1)
+        )
+
+        return x2
diff --git a/notebooks/Deep_Learning_listmode_PET/test/lm_data_fid.py b/notebooks/Deep_Learning_listmode_PET/test/lm_data_fid.py
new file mode 100644
index 00000000..80546251
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/test/lm_data_fid.py
@@ -0,0 +1,122 @@
+import torch
+
+
+class LMLinearOperator:
+    def __init__(self, A: torch.Tensor) -> None:
+        self._A = A
+        self._lm_data = None
+        self._adjoint_ones = None
+
+    @property
+    def in_shape(self) -> tuple[int]:
+        return (self._A.shape[1],)
+
+    @property
+    def out_shape(self) -> tuple[int]:
+        return (self._A.shape[0],)
+
+    @property
+    def A(self) -> torch.Tensor:
+        return self._A
+
+    @property
+    def lm_data(self) -> torch.Tensor:
+        return self._lm_data
+
+    @lm_data.setter
+    def lm_data(self, value: torch.Tensor) -> None:
+        self._lm_data = value
+
+    @property
+    def adjoint_ones(self) -> torch.Tensor:
+        if self._adjoint_ones is None:
+            self._adjoint_ones = self._A.T @ torch.ones(
+                self._A.shape[0], device=self._A.device, dtype=torch.float64
+            )
+        return self._adjoint_ones
+
+    @property
+    def data(self) -> torch.Tensor:
+        # histogram the listmode events into bins (the "sinogram" view of the data)
+        data = torch.zeros(self._A.shape[0], device=self._A.device, dtype=torch.int)
+        for i in range(self._A.shape[0]):
+            data[i] = (self._lm_data == i).sum().item()
+
+        return data
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return self.fwd(x)
+
+    def fwd(self, x: torch.Tensor) -> torch.Tensor:
+        return self._A @ x
+
+    def adjoint(self, y: torch.Tensor) -> torch.Tensor:
+        return self._A.T @ y.double()
+
+    def fwd_lm(self, x: torch.Tensor) -> torch.Tensor:
+        if self._lm_data is None:
+            raise ValueError("must set lm data first")
+        return self._A[self._lm_data, :] @ x
+
+    def adjoint_lm(self, lst: torch.Tensor) -> torch.Tensor:
+        return self._A[self._lm_data, :].T @ lst.double()
+
+
+class PoissonLogL:
+    def __init__(self, data: torch.Tensor, op: LMLinearOperator) -> None:
+        self._data = data
+        self._op = op
+
+    def __call__(self, x: torch.Tensor) -> float:
+        exp = self._op.fwd(x)
+        return float((self._data * torch.log(exp) - exp).sum())
+
+    def gradient(self, x: torch.Tensor) -> torch.Tensor:
+        exp = self._op.fwd(x)
+        return self._op.adjoint((self._data / exp) - 1)
+
+    def hessian_applied(self, x: torch.Tensor, x_prime: torch.Tensor) -> torch.Tensor:
+        exp = self._op.fwd(x)
+        exp_prime = self._op.fwd(x_prime)
+        return -self._op.adjoint(self._data * exp_prime / (exp ** 2))
+
+
+class LMPoissonLogL:
+    def __init__(self, op: LMLinearOperator) -> None:
+        self._op = op
+
+    @property
+    def op(self) -> LMLinearOperator:
+        return self._op
+
+    def gradient(self, x: torch.Tensor) -> torch.Tensor:
+        exp = self._op.fwd_lm(x)
+        return self._op.adjoint_lm(1 / exp) - self._op.adjoint_ones
+
+    def hessian_applied(self, x: torch.Tensor, x_prime: torch.Tensor) -> torch.Tensor:
+        exp = self._op.fwd_lm(x)
+        exp_prime = self._op.fwd_lm(x_prime)
+        return -self._op.adjoint_lm(exp_prime / (exp ** 2))
+
+
+def test_lmlogl(dev: str, nx: int = 2, ny: int = 2):
+
+    x = torch.rand(nx, device=dev, dtype=torch.float64)
+    op = LMLinearOperator(
+        6 * torch.rand(ny, nx, device=dev, dtype=torch.float64)
+        + 24 * torch.eye(ny, nx, device=dev, dtype=torch.float64)
+    )
+
+    y_noisefree = op(x)
+    y = torch.poisson(y_noisefree).int()
+
+    lm_data = torch.repeat_interleave(torch.arange(ny, device=dev), y)
+
+    # shuffle the LM data using a random permutation
+    shuffled_inds = torch.randperm(lm_data.shape[0])
+    lm_data = lm_data[shuffled_inds]
+
+    op.lm_data = lm_data
+    lmdata_fid = LMPoissonLogL(op)
+
+    return lmdata_fid
diff --git a/notebooks/Deep_Learning_listmode_PET/test/stir_torch_lm_em_layer.py b/notebooks/Deep_Learning_listmode_PET/test/stir_torch_lm_em_layer.py
new file mode 100644
index 00000000..db2c2610
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/test/stir_torch_lm_em_layer.py
@@ -0,0 +1,102 @@
+# %% [markdown]
+# Skeleton for STIR-based listmode Poisson logL gradient data fidelity layer
+# ==========================================================================
+
+# %%
+import torch
+from lm_data_fid import test_lmlogl
+
+
+class LMPoissonLogLGradLayer(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, lm_objective_list) -> torch.Tensor:
+
+        ctx.set_materialize_grads(False)
+        ctx.x = x.detach()
+        ctx.lm_objective_list = lm_objective_list
+
+        y = torch.zeros_like(x)
+
+        for i in range(x.shape[0]):
+            y[i, 0, ...] = lm_objective_list[i].gradient(ctx.x[i, 0, ...])
+
+        return y
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, None]:
+        if grad_output is None:
+            return None, None
+        else:
+            back = torch.zeros_like(grad_output)
+
+            for i in range(grad_output.shape[0]):
+                back[i, 0, ...] = ctx.lm_objective_list[i].hessian_applied(
+                    ctx.x[i, 0, ...], grad_output[i, 0, ...].detach()
+                )
+
+            return back, None
+
+
+# %%
+class LMEMNet(torch.nn.Module):
+    def __init__(self, lm_obj_list, num_blocks: int = 20) -> None:
+        super().__init__()
+        self._data_fid_gradient_layer = LMPoissonLogLGradLayer.apply
+        self._lm_obj_list = lm_obj_list
+        self._sens_imgs = torch.stack(
+            [l.op.adjoint_ones.unsqueeze(0) for l in lm_obj_list]
+        )
+        self._num_blocks = num_blocks
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+
+        for _ in range(self._num_blocks):
+            x = x + (x / self._sens_imgs) * self._data_fid_gradient_layer(
+                x, self._lm_obj_list
+            )
+
+        return x
+
+
+# %%
+torch.manual_seed(0)
+if torch.cuda.is_available():
+    # note: the CUDA device is commented out here, so this test always runs on the CPU
+    # dev = "cuda:0"
+    dev = "cpu"
+else:
+    dev = "cpu"
+
+
+# %%
+
+# setup a list of dummy LM objective functions corresponding to a mini-batch of LM acquisitions
+lm_obj1 = test_lmlogl(dev, nx=4, ny=4)
+lm_obj2 = test_lmlogl(dev, nx=4, ny=4)
+lm_obj_list = [lm_obj1, lm_obj2]
+batch_size = len(lm_obj_list)
+
+# setup a test input mini-batch of images that we use to test our network
+x_t = torch.rand(
+    (batch_size, 1) + lm_obj_list[0].op.in_shape,
+    device=dev,
+    dtype=torch.float64,
+    requires_grad=True,
+)
+
+# setup the LM Network using 20 blocks (20 LM EM updates)
+lmem_net = LMEMNet(lm_obj_list, num_blocks=20)
+
+# feed our test image through the network
+x_fwd = lmem_net(x_t)
+# calculate the analytic ML solution (possible since we have invertible square forward operators)
+x_ml = torch.stack(
+    [(torch.linalg.inv(l.op.A) @ l.op.data.double()).unsqueeze(0) for l in lm_obj_list],
+    dim=0,
+)
+
+# test the gradient backpropagation through the network
+test_lmemnet = torch.autograd.gradcheck(lmem_net, (x_t,))
diff --git a/notebooks/Deep_Learning_listmode_PET/test/test_grad_layer.py b/notebooks/Deep_Learning_listmode_PET/test/test_grad_layer.py
new file mode 100644
index 00000000..de9d9d28
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/test/test_grad_layer.py
@@ -0,0 +1,211 @@
+# %% [markdown]
+# Learning objectives
+# ===================
+#
+# 1. Implement a custom layer that calculates the gradient of the Poisson log-likelihood.
+#    How do we define the backward pass?
+# 2. Using the custom Poisson logL gradient layer, define an EM step layer.
+
+# %%
+import sirf.STIR
+import torch
+import numpy as np
+from pathlib import Path
+from sirf.Utilities import examples_data_path
+
+data_path: Path = Path(examples_data_path("PET")) / "mMR"
+output_path: Path = Path("recons")
+list_file: str = str(data_path / "list.l.hdr")
+norm_file: str = str(data_path / "norm.n.hdr")
+attn_file: str = str(data_path / "mu_map.hv")
+emission_sinogram_output_prefix: str = str(output_path / "emission_sinogram")
+scatter_sinogram_output_prefix: str = str(output_path / "scatter_sinogram")
+randoms_sinogram_output_prefix: str = str(output_path / "randoms_sinogram")
+attenuation_sinogram_output_prefix: str = str(output_path / "acf_sinogram")
+num_scatter_iter: int = 5
+
+lm_recon_output_file: str = str(output_path / "lm_recon")
+n: int = 7
+nxny: tuple[int, int] = (n, n)
+num_subsets: int = 21
+
+# should run on the CPU using OMP_NUM_THREADS=1 to get deterministic behaviour
+dev = "cpu"
+
+# engine's messages go to files, except error messages, which go to stdout
+_ = sirf.STIR.MessageRedirector("info.txt", "warn.txt")
+
+# %% [markdown]
+# Load listmode data and create the acquisition model
+# ---------------------------------------------------
+#
+# In this demo example, the geometric forward projection is implemented via a ray tracing
+# matrix multiplication. Normalization, attenuation, scatter and randoms are included via an
+# acquisition sensitivity model and an additive background term, read from files created earlier.
+
+# %%
+sirf.STIR.AcquisitionData.set_storage_scheme("memory")
+listmode_data = sirf.STIR.ListmodeData(list_file)
+acq_data_template = listmode_data.acquisition_data_template()
+
+acq_data = sirf.STIR.AcquisitionData(
+    str(Path(f"{emission_sinogram_output_prefix}_f1g1d0b0.hs"))
+)
+
+# select acquisition model that implements the geometric
+# forward projection by a ray tracing matrix multiplication
+acq_model = sirf.STIR.AcquisitionModelUsingRayTracingMatrix()
+acq_model.set_num_tangential_LORs(1)
+
+randoms = sirf.STIR.AcquisitionData(str(Path(f"{randoms_sinogram_output_prefix}.hs")))
+
+ac_factors = sirf.STIR.AcquisitionData(
+    str(Path(f"{attenuation_sinogram_output_prefix}.hs"))
+)
+asm_attn = sirf.STIR.AcquisitionSensitivityModel(ac_factors)
+
+asm_norm = sirf.STIR.AcquisitionSensitivityModel(norm_file)
+asm = sirf.STIR.AcquisitionSensitivityModel(asm_norm, asm_attn)
+
+asm.set_up(acq_data)
+acq_model.set_acquisition_sensitivity(asm)
+
+scatter_estimate = sirf.STIR.AcquisitionData(
+    str(Path(f"{scatter_sinogram_output_prefix}_{num_scatter_iter}.hs"))
+)
+acq_model.set_background_term(randoms + scatter_estimate)
+
+# setup an initial (template) image based on the acquisition data template
+initial_image = acq_data_template.create_uniform_image(value=0, xy=nxny)
+
+# %% [markdown]
+# Setup of the Poisson log likelihood listmode objective function
+# ---------------------------------------------------------------
+#
+# Using the listmode data and the acquisition model, we can now set up the Poisson log likelihood objective function.
+
+# %%
+lm_obj_fun = (
+    sirf.STIR.PoissonLogLikelihoodWithLinearModelForMeanAndListModeDataWithProjMatrixByBin()
+)
+lm_obj_fun.set_acquisition_model(acq_model)
+lm_obj_fun.set_acquisition_data(listmode_data)
+lm_obj_fun.set_num_subsets(num_subsets)
+print("setting up listmode objective function ...")
+lm_obj_fun.set_up(initial_image)
+
+# %% [markdown]
+# Setup of a pytorch layer that computes the gradient of the Poisson log likelihood objective function
+# ----------------------------------------------------------------------------------------------------
+#
+# Using our listmode objective function, we can now implement a custom pytorch layer that computes the gradient
+# of the Poisson log likelihood via the `gradient()` method, evaluated on a subset of the listmode data.
+#
+# This layer maps a torch minibatch tensor to another torch tensor of the same shape.
+# The shape of the minibatch tensor is [batch_size=1, channel_size=1, spatial dimensions].
+# For the implementation we subclass `torch.autograd.Function` and implement the `forward()` and
+# `backward()` methods.
+
+
+class SIRFPoissonlogLGradLayer(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        x: torch.Tensor,
+        objective_function,
+        sirf_template_image: sirf.STIR.ImageData,
+        subset: int,
+    ) -> torch.Tensor:
+
+        # we use the context object ctx to store the objective function and other
+        # objects that we need in the backward pass
+        ctx.device = x.device
+        ctx.objective_function = objective_function
+        ctx.dtype = x.dtype
+        ctx.subset = subset
+        ctx.sirf_template_image = sirf_template_image.clone()
+
+        # setup a new sirf.STIR ImageData object
+        x_sirf = sirf_template_image.clone()
+        # convert torch tensor to sirf image via numpy
+        x_sirf.fill(x.cpu().numpy()[0, 0, ...])
+
+        # save the input sirf.STIR ImageData for the backward pass
+        ctx.x_sirf = x_sirf
+
+        # calculate the gradient of the Poisson log likelihood using SIRF
+        g_np = objective_function.gradient(x_sirf, subset).as_array()
+
+        # convert back to torch tensor
+        y = (
+            torch.tensor(g_np, device=ctx.device, dtype=ctx.dtype)
+            .unsqueeze(0)
+            .unsqueeze(0)
+        )
+
+        return y
+
+    @staticmethod
+    def backward(
+        ctx, grad_output: torch.Tensor | None
+    ) -> tuple[torch.Tensor | None, None, None, None]:
+        if grad_output is None:
+            return None, None, None, None
+        else:
+            # convert torch tensor to sirf image via numpy
+            ctx.sirf_template_image.fill(grad_output.cpu().numpy()[0, 0, ...])
+
+            # calculate the Jacobian vector product (the Hessian applied to an image) using SIRF
+            back_sirf = ctx.objective_function.multiply_with_Hessian(
+                ctx.x_sirf, ctx.sirf_template_image, ctx.subset
+            )
+
+            # convert back to torch tensor via numpy
+            back = (
+                torch.tensor(back_sirf.as_array(), device=ctx.device, dtype=ctx.dtype)
+                .unsqueeze(0)
+                .unsqueeze(0)
+            )
+
+            return back, None, None, None
+
+
+# %%
+tmp = np.zeros(initial_image.shape, dtype=np.float32)
+tmp[:, n // 2, n // 2] = 0.1
+tmp[:, n // 2 - 1, n // 2] = 0.1
+tmp[:, n // 2, n // 2 - 1] = 0.1
+tmp[:, n // 2 + 1, n // 2] = 0.1
+tmp[:, n // 2, n // 2 + 1] = 0.1
+
+lm_ref_recon = initial_image.clone()
+lm_ref_recon.fill(tmp)
+x_t = (
+    torch.tensor(
+        lm_ref_recon.as_array(), device=dev, dtype=torch.float32, requires_grad=True
+    )
+    .unsqueeze(0)
+    .unsqueeze(0)
+)
+
+poisson_logL_grad_layer = SIRFPoissonlogLGradLayer.apply
+g = poisson_logL_grad_layer(x_t, lm_obj_fun, initial_image, 0)
+g2 = poisson_logL_grad_layer(x_t, lm_obj_fun, initial_image, 0)
+
+# relative difference between two evaluations with identical input (checks determinism)
+r = torch.abs(g2 - g) / (torch.abs(g) + 1e-6)
+print(r.max())
+
+# %%
+#
+# Currently raises
+# GradcheckError: Backward is not reentrant,
+# i.e., running backward with the same input and grad_output multiple times gives different values,
+# although the analytical gradient matches the numerical gradient. The tolerance for nondeterminism was 0.0.
+#
+res = torch.autograd.gradcheck(
+    poisson_logL_grad_layer,
+    (x_t, lm_obj_fun, initial_image, 0),
+    eps=1e-3,
+    atol=1e-4,
+    rtol=1e-3,
+    fast_mode=True,
+)
diff --git a/notebooks/Deep_Learning_listmode_PET/test/test_hessian.py b/notebooks/Deep_Learning_listmode_PET/test/test_hessian.py
new file mode 100644
index 00000000..7aebf6b4
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/test/test_hessian.py
@@ -0,0 +1,318 @@
+# %% [markdown]
+# Sinogram and Listmode OSEM using sirf.STIR
+# ==========================================
+#
+# Using the theory from the previous "theory" notebook, we will now learn how to perform
+# PET reconstruction of emission data in listmode and sinogram format using (sinogram and listmode)
+# objective function objects of the sirf.STIR library.
+#
+# We will see that standard OSEM reconstruction can be seen as a sequence of image update blocks,
+# where the update in each block is related to the gradient of the Poisson loglikelihood objective function.
+#
+# Understanding these OSEM update blocks is the first key step for implementing a pytorch-based feed-forward
+# neural network for PET image reconstruction that also contains OSEM-like update blocks.
+
+# %% [markdown]
+# Import modules and define file names
+# ------------------------------------
+
+# %%
+import sirf.STIR
+import numpy as np
+import matplotlib.pyplot as plt
+from pathlib import Path
+from sirf.Utilities import examples_data_path
+
+data_path: Path = Path(examples_data_path("PET")) / "mMR"
+output_path: Path = Path("recons")
+list_file: str = str(data_path / "list.l.hdr")
+norm_file: str = str(data_path / "norm.n.hdr")
+attn_file: str = str(data_path / "mu_map.hv")
+emission_sinogram_output_prefix: str = str(output_path / "emission_sinogram")
+scatter_sinogram_output_prefix: str = str(output_path / "scatter_sinogram")
+randoms_sinogram_output_prefix: str = str(output_path / "randoms_sinogram")
+attenuation_sinogram_output_prefix: str = str(output_path / "acf_sinogram")
+recon_output_file: str = str(output_path / "recon")
+lm_recon_output_file: str = str(output_path / "lm_recon")
+nxny: tuple[int, int] = (127, 127)
+num_subsets: int = 21
+num_iter: int = 1
+num_scatter_iter: int = 3
+
+# create the output directory
+output_path.mkdir(exist_ok=True)
+
+# engine's messages go to files, except error messages, which go to stdout
+_ = sirf.STIR.MessageRedirector("info.txt", "warn.txt")
+
+# %% [markdown]
+# Read the listmode data and create a sinogram template
+# -----------------------------------------------------
+
+# %%
+sirf.STIR.AcquisitionData.set_storage_scheme("memory")
+listmode_data = sirf.STIR.ListmodeData(list_file)
+acq_data_template = listmode_data.acquisition_data_template()
+print(acq_data_template.get_info())
+
+# %% [markdown]
+# Conversion of listmode to sinogram data (needed for scatter estimation)
+# -----------------------------------------------------------------------
+
+# %%
+# create listmode-to-sinograms converter object
+lm2sino = sirf.STIR.ListmodeToSinograms()
+
+# set input, output and template files
+lm2sino.set_input(listmode_data)
+lm2sino.set_output_prefix(emission_sinogram_output_prefix)
+lm2sino.set_template(acq_data_template)
+
+# get the start and end time of the listmode data
+frame_start = float(
+    [
+        x
+        for x in listmode_data.get_info().split("\n")
+        if x.startswith("Time frame start")
+    ][0]
+    .split(": ")[1]
+    .split("-")[0]
+)
+frame_end = float(
+    [
+        x
+        for x in listmode_data.get_info().split("\n")
+        if x.startswith("Time frame start")
+    ][0]
+    .split(": ")[1]
+    .split("-")[1]
+    .split("(")[0]
+)
+# set the time interval
+lm2sino.set_time_interval(frame_start, frame_end)
+# set up the converter
+lm2sino.set_up()
+
+# convert (we need it for the scatter estimate)
+lm2sino.process()
+acq_data = lm2sino.get_output()
+
+# %% [markdown]
+# Estimation of random coincidences
+# ---------------------------------
+
+# %%
+randoms_filepath = Path(f"{randoms_sinogram_output_prefix}.hs")
+
+if not randoms_filepath.exists():
+    print("estimating randoms")
+    randoms = lm2sino.estimate_randoms()
+    randoms.write(randoms_sinogram_output_prefix)
+else:
+    print(f"reading randoms from {randoms_filepath}")
+    randoms = sirf.STIR.AcquisitionData(str(randoms_filepath))
+
+
+# %% [markdown]
+# Setup of the acquisition model
+# ------------------------------
+
+# %%
+# select acquisition model that implements the geometric
+# forward projection by a ray tracing matrix multiplication
+acq_model = sirf.STIR.AcquisitionModelUsingRayTracingMatrix()
+acq_model.set_num_tangential_LORs(1)
+
+
+# %% [markdown]
+# Calculation of the attenuation sinogram
+# ---------------------------------------
+
+# %%
+# read the attenuation image
+attn_image = sirf.STIR.ImageData(attn_file)
+
+# create attenuation factors
+asm_attn = sirf.STIR.AcquisitionSensitivityModel(attn_image, acq_model)
+# converting attenuation image into attenuation factors (one for every bin)
+asm_attn.set_up(acq_data)
+
+acf_filepath = Path(f"{attenuation_sinogram_output_prefix}.hs")
+
+if not acf_filepath.exists():
+    ac_factors = acq_data.get_uniform_copy(value=1)
+    print("applying attenuation (please wait, may take a while)...")
+    asm_attn.unnormalise(ac_factors)
+    ac_factors.write(attenuation_sinogram_output_prefix)
+else:
+    print(f"reading attenuation factors from {acf_filepath}")
+    ac_factors = sirf.STIR.AcquisitionData(str(acf_filepath))
+
+asm_attn = sirf.STIR.AcquisitionSensitivityModel(ac_factors)
+
+# %% [markdown]
+# Creation of the normalization factors (sensitivity sinogram)
+# ------------------------------------------------------------
+
+# %%
+# create acquisition sensitivity model from normalisation data
+asm_norm = sirf.STIR.AcquisitionSensitivityModel(norm_file)
+
+asm = sirf.STIR.AcquisitionSensitivityModel(asm_norm, asm_attn)
+asm.set_up(acq_data)
+acq_model.set_acquisition_sensitivity(asm)
+
+# %% [markdown]
+# Estimation of scattered coincidences
+# ------------------------------------
+
+# %%
+scatter_filepath: Path = Path(f"{scatter_sinogram_output_prefix}_{num_scatter_iter}.hs")
+
+if not scatter_filepath.exists():
+    print("estimating scatter (this will take a while!)")
+    scatter_estimator = sirf.STIR.ScatterEstimator()
+    scatter_estimator.set_input(acq_data)
+    scatter_estimator.set_attenuation_image(attn_image)
+    scatter_estimator.set_randoms(randoms)
+    scatter_estimator.set_asm(asm_norm)
+    # invert the attenuation factors to get the attenuation correction factors,
+    # as this is unfortunately what a ScatterEstimator needs
+    acf_factors = acq_data.get_uniform_copy()
+    acf_factors.fill(1 / ac_factors.as_array())
+    scatter_estimator.set_attenuation_correction_factors(acf_factors)
+    scatter_estimator.set_output_prefix(scatter_sinogram_output_prefix)
+    scatter_estimator.set_num_iterations(num_scatter_iter)
+    scatter_estimator.set_up()
+    scatter_estimator.process()
+    scatter_estimate = scatter_estimator.get_output()
+else:
+    print(f"reading scatter from file {scatter_filepath}")
+    scatter_estimate = sirf.STIR.AcquisitionData(str(scatter_filepath))
+
+# set the additive background term (randoms + scatter estimate)
+acq_model.set_background_term(randoms + scatter_estimate)
+
+# %% [markdown]
+# Setup of the Poisson loglikelihood objective function ($\log L(y|x)$) in sinogram mode
+# --------------------------------------------------------------------------------------
+
+# %%
+initial_image = acq_data.create_uniform_image(value=1, xy=nxny)
+
+# create objective function
+obj_fun = sirf.STIR.make_Poisson_loglikelihood(acq_data)
+obj_fun.set_acquisition_model(acq_model)
+obj_fun.set_num_subsets(num_subsets)
+obj_fun.set_up(initial_image)
+
+# %% [markdown]
+# Image reconstruction (optimization of the Poisson logL objective function) using sinogram OSEM
+# ----------------------------------------------------------------------------------------------
+
+# %%
+if not Path(f"{recon_output_file}.hv").exists():
+    reconstructor = sirf.STIR.OSMAPOSLReconstructor()
+    reconstructor.set_objective_function(obj_fun)
+    reconstructor.set_num_subsets(num_subsets)
+    reconstructor.set_num_subiterations(num_iter * num_subsets)
+    reconstructor.set_input(acq_data)
+    reconstructor.set_up(initial_image)
+    reconstructor.set_current_estimate(initial_image)
+    reconstructor.process()
+    ref_recon = reconstructor.get_output()
+    ref_recon.write(recon_output_file)
+else:
+    ref_recon = sirf.STIR.ImageData(f"{recon_output_file}.hv")
+
+vmax = np.percentile(ref_recon.as_array(), 99.999)
+
+# %%
+current_estimate = ref_recon.copy()
+input_img = acq_data.create_uniform_image(value=1, xy=nxny)
+np.random.seed(0)
+input_img.fill(
+    np.random.rand(*input_img.shape)
+    * (obj_fun.get_subset_sensitivity(0).as_array() > 0)
+    * current_estimate.max()
+)
+
+hess_out_img = obj_fun.multiply_with_Hessian(current_estimate, input_img, subset=0)
+
+# %%
+# repeat the calculation using the LM objective function
+# define the objective function to be maximized as the
+# Poisson logarithmic likelihood (with linear model for mean)
+lm_obj_fun = (
+    sirf.STIR.PoissonLogLikelihoodWithLinearModelForMeanAndListModeDataWithProjMatrixByBin()
+)
+lm_obj_fun.set_acquisition_model(acq_model)
+lm_obj_fun.set_acquisition_data(listmode_data)
+lm_obj_fun.set_num_subsets(num_subsets)
+lm_obj_fun.set_up(initial_image)
+
+hess_out_img_lm = lm_obj_fun.multiply_with_Hessian(
+    current_estimate, input_img, subset=0
+)
+
+# %%
+# verify the Hessian calculation manually
+
+acq_model.set_up(acq_data, initial_image)
+acq_model.num_subsets = num_subsets
+acq_model.subset_num = 0
+
+# get the linear (Ax) part of the Ax + b affine acq.
model +lin_acq_model = acq_model.get_linear_acquisition_model() +lin_acq_model.num_subsets = num_subsets +lin_acq_model.subset_num = 0 + +# for the Hessian "multiply" we need the linear part of the acquisition model applied to the input image +input_img_fwd = lin_acq_model.forward(input_img) +current_estimate_fwd = acq_model.forward(current_estimate) +h = -acq_model.backward( + acq_data * input_img_fwd / (current_estimate_fwd * current_estimate_fwd) +) +h2 = -acq_model.backward( + acq_data * input_img_fwd / (current_estimate_fwd * current_estimate_fwd + 1e-8) +) + + +# %% + +fig, ax = plt.subplots(2, 6, figsize=(18, 6), tight_layout=True) +ax[0, 0].imshow(current_estimate.as_array()[71, :, :], cmap="Greys") +ax[0, 1].imshow(input_img.as_array()[71, :, :], cmap="Greys") +ax[0, 2].imshow(hess_out_img.as_array()[71, :, :], cmap="Greys", vmin=-5000, vmax=-1000) +ax[0, 3].imshow( + hess_out_img_lm.as_array()[71, :, :], cmap="Greys", vmin=-5000, vmax=-1000 +) +ax[0, 4].imshow(h.as_array()[71, :, :], cmap="Greys", vmin=-5000, vmax=-1000) +ax[0, 5].imshow(h2.as_array()[71, :, :], cmap="Greys", vmin=-5000, vmax=-1000) +ax[1, 2].imshow( + hess_out_img.as_array()[71, :, :], + cmap="Greys", + vmin=-100000, + vmax=hess_out_img.max(), +) +ax[1, 3].imshow( + hess_out_img_lm.as_array()[71, :, :], + cmap="Greys", + vmin=-100000, + vmax=hess_out_img.max(), +) +ax[1, 4].imshow( + h.as_array()[71, :, :], cmap="Greys", vmin=-100000, vmax=hess_out_img.max() +) +ax[1, 5].imshow( + h2.as_array()[71, :, :], cmap="Greys", vmin=-100000, vmax=hess_out_img.max() +) +ax[0, 0].set_title("current estimate", fontsize="medium") +ax[0, 1].set_title("input", fontsize="medium") +ax[0, 2].set_title("sino Hessian multiply", fontsize="medium") +ax[0, 3].set_title("neg. LM Hessian multiply", fontsize="medium") +ax[0, 4].set_title("manual Hessian multiply", fontsize="medium") +ax[0, 5].set_title("manual Hessian multiply + eps", fontsize="medium") +ax[1, 0].set_axis_off() +ax[1, 1].set_axis_off() +fig.show() diff --git a/notebooks/Deep_Learning_listmode_PET/test/torch_em_layers.py b/notebooks/Deep_Learning_listmode_PET/test/torch_em_layers.py new file mode 100644 index 00000000..387280e2 --- /dev/null +++ b/notebooks/Deep_Learning_listmode_PET/test/torch_em_layers.py @@ -0,0 +1,202 @@ +# %% +from __future__ import annotations + +import torch +import parallelproj + + +# device variable (cpu or cuda) that determines whether calculations +# are performed on the cpu or cuda gpu +if parallelproj.cuda_present: + dev = "cuda" +else: + dev = "cpu" + +# %% + + +class PoissonLogLGradLayer(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: torch.Tensor, + operator: parallelproj.LinearOperator, + data: torch.Tensor, + contam: torch.Tensor, + ) -> torch.Tensor: + + ctx.set_materialize_grads(False) + ctx.operator = operator + + y = torch.zeros_like(x) + ratio2 = torch.zeros_like(data) + + for i in range(x.shape[0]): + exp = operator(x[i, 0, ...].detach()) + contam[i, ...] + ratio = data[i, ...] / exp + ratio2[i, ...] = data[i, ...] / (exp ** 2) + y[i, 0, ...] = operator.adjoint(ratio - 1) + + ctx.ratio2 = ratio2 + + return y + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor + ) -> tuple[torch.Tensor | None, None, None, None]: + if grad_output is None: + return None, None, None, None + else: + operator = ctx.operator + ratio2 = ctx.ratio2 + + x = torch.zeros_like(grad_output) + + for i in range(grad_output.shape[0]): + exp = operator(grad_output[i, 0, ...].detach()) + x[i, 0, ...] 
= -operator.adjoint(ratio2[i, ...] * exp)
+
+            return x, None, None, None
+
+
+# %%
+
+
+class PoissonEMOperator(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        x: torch.Tensor,
+        operator: parallelproj.LinearOperator,
+        data: torch.Tensor,
+        contam: torch.Tensor,
+        sens_img: torch.Tensor,
+    ) -> torch.Tensor:
+
+        ctx.set_materialize_grads(False)
+        ctx.operator = operator
+
+        x_sens_ratio = torch.zeros_like(x)
+        ratio2 = torch.zeros_like(data)
+        mult_update = torch.zeros_like(x)
+
+        for i in range(x.shape[0]):
+            exp = operator(x[i, 0, ...].detach()) + contam[i, ...]
+            ratio = data[i, ...] / exp
+            ratio2[i, ...] = data[i, ...] / (exp ** 2)
+            x_sens_ratio[i, 0, ...] = x[i, 0, ...] / sens_img[i, 0, ...]
+            mult_update[i, 0, ...] = operator.adjoint(ratio) / sens_img[i, 0, ...]
+
+        ctx.ratio2 = ratio2
+        ctx.x_sens_ratio = x_sens_ratio
+        ctx.mult_update = mult_update
+
+        return mult_update * x
+
+    @staticmethod
+    def backward(
+        ctx, grad_output: torch.Tensor
+    ) -> tuple[torch.Tensor | None, None, None, None, None]:
+        if grad_output is None:
+            return None, None, None, None, None
+        else:
+            operator = ctx.operator
+            ratio2 = ctx.ratio2
+            mult_update = ctx.mult_update
+            x_sens_ratio = ctx.x_sens_ratio
+
+            x = torch.zeros_like(grad_output)
+
+            for i in range(grad_output.shape[0]):
+                exp = operator(
+                    x_sens_ratio[i, 0, ...] * grad_output[i, 0, ...].detach()
+                )
+                x[i, 0, ...] = grad_output[i, 0, ...] * mult_update[
+                    i, 0, ...
+                ] - operator.adjoint(ratio2[i, ...] * exp)
+
+            return x, None, None, None, None
+
+
+# %%
+
+
+class EMNet(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self._data_fid_layer = PoissonLogLGradLayer.apply
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        operator: parallelproj.LinearOperator,
+        data: torch.Tensor,
+        contam: torch.Tensor,
+        sens_img: torch.Tensor,
+    ) -> torch.Tensor:
+
+        return x + (x / sens_img) * self._data_fid_layer(x, operator, data, contam)
+
+
+# %%
+torch.manual_seed(0)
+
+A = torch.tensor(
+    [[1.5, 0.5, 0.1], [0.3, 2.1, 0.2], [0.9, 1.2, 2.1], [1.0, 2.0, 0.5]],
+    dtype=torch.float64,
+    device=dev,
+)
+proj = parallelproj.MatrixOperator(A)
+
+# %%
+# Define a mini batch of input and output tensors
+# -----------------------------------------------
+
+batch_size = 2
+
+xt = torch.rand(
+    (batch_size, 1) + proj.in_shape, device=dev, dtype=torch.float64, requires_grad=True
+)
+
+data_t = torch.rand(
+    (batch_size,) + proj.out_shape, device=dev, dtype=torch.float64, requires_grad=False
+)
+
+contam_t = torch.rand(
+    (batch_size,) + proj.out_shape, device=dev, dtype=torch.float64, requires_grad=False
+)
+
+sens_t = torch.zeros_like(xt)
+ones_data = torch.ones(proj.out_shape, device=dev, dtype=torch.float64)
+for i in range(batch_size):
+    sens_t[i, 0, ...] = proj.adjoint(ones_data)
+
+# %%
+# Apply the logL gradient and EM update layers
+# --------------------------------------------
+
+logLgrad_layer = PoissonLogLGradLayer.apply
+f2 = logLgrad_layer(xt, proj, data_t, contam_t)
+
+em_layer = PoissonEMOperator.apply
+em_update_1 = em_layer(xt, proj, data_t, contam_t, sens_t)
+
+em_net = EMNet()
+em_update_2 = em_net(xt, proj, data_t, contam_t, sens_t)
+
+manual_em_update = torch.zeros_like(xt)
+for i in range(batch_size):
+    manual_em_update[i, 0, ...] = (xt[i, 0, ...] / sens_t[i, 0, ...]) * (
+        A.T @ (data_t[i, ...] / (A @ xt[i, 0, ...] + contam_t[i, ...]))
+    )
+
+# %%
+# Check whether the gradients are calculated correctly
+# ----------------------------------------------------
+
+test_logLgrad = torch.autograd.gradcheck(logLgrad_layer, (xt, proj, data_t, contam_t))
+
+test_em = torch.autograd.gradcheck(em_layer, (xt, proj, data_t, contam_t, sens_t))
+
+test_em2 = torch.autograd.gradcheck(em_net, (xt, proj, data_t, contam_t, sens_t))
diff --git a/notebooks/Deep_Learning_listmode_PET/test/train_varnet.py b/notebooks/Deep_Learning_listmode_PET/test/train_varnet.py
new file mode 100644
index 00000000..54578123
--- /dev/null
+++ b/notebooks/Deep_Learning_listmode_PET/test/train_varnet.py
@@ -0,0 +1,421 @@
+# %% [markdown]
+# Learning objectives
+# ===================
+#
+# 1. Combine custom listmode OSEM update layers and a CNN into a minimal unrolled variational network.
+# 2. Train the network in a supervised way on a single low-count data set with a high-count reference image.
+
+# %%
+import sirf.STIR
+import torch
+import matplotlib.pyplot as plt
+from pathlib import Path
+from sirf.Utilities import examples_data_path
+
+acq_time: str = "1min"
+
+if acq_time == "1min":
+    data_path: Path = Path(examples_data_path("PET")) / "mMR"
+    list_file: str = str(data_path / "list.l.hdr")
+    norm_file: str = str(data_path / "norm.n.hdr")
+    attn_file: str = str(data_path / "mu_map.hv")
+elif acq_time == "60min":
+    data_path: Path = Path("..") / ".." / "data" / "PET" / "mMR" / "NEMA_IQ"
+    list_file: str = str(data_path / "20170809_NEMA_60min_UCL.l.hdr")
+    norm_file: str = str(data_path / "20170809_NEMA_UCL.n.hdr")
+    attn_file: str = str(data_path / "20170809_NEMA_MUMAP_UCL.v.hdr")
+else:
+    raise ValueError("Please choose acq_time to be either '1min' or '60min'")
+
+output_path: Path = Path(f"recons_{acq_time}")
+emission_sinogram_output_prefix: str = str(output_path / "emission_sinogram")
+scatter_sinogram_output_prefix: str = str(output_path / "scatter_sinogram")
+randoms_sinogram_output_prefix: str = str(output_path / "randoms_sinogram")
+attenuation_sinogram_output_prefix: str = str(output_path / "acf_sinogram")
+num_scatter_iter: int = 3
+
+lm_recon_output_file: str = str(output_path / "lm_recon")
+nxny: tuple[int, int] = (127, 127)
+num_subsets: int = 21
+
+if torch.cuda.is_available():
+    dev = "cuda:0"
+else:
+    dev = "cpu"
+
+# engine's messages go to files, except error messages, which go to stdout
+_ = sirf.STIR.MessageRedirector("info.txt", "warn.txt")
+
+# %% [markdown]
+# Load listmode data and create the acquisition model
+# ---------------------------------------------------
+#
+# In this example, the geometric forward projection is implemented via a ray tracing
+# matrix multiplication. Normalization, attenuation, scatter and randoms are included via an
+# acquisition sensitivity model and an additive background term, read from files created earlier.
+ +# %% +sirf.STIR.AcquisitionData.set_storage_scheme("memory") +listmode_data = sirf.STIR.ListmodeData(list_file) +acq_data_template = listmode_data.acquisition_data_template() + +acq_data = sirf.STIR.AcquisitionData( + str(Path(f"{emission_sinogram_output_prefix}_f1g1d0b0.hs")) +) + +# select acquisition model that implements the geometric +# forward projection by a ray tracing matrix multiplication +acq_model = sirf.STIR.AcquisitionModelUsingRayTracingMatrix() +acq_model.set_num_tangential_LORs(1) + +randoms = sirf.STIR.AcquisitionData(str(Path(f"{randoms_sinogram_output_prefix}.hs"))) + +ac_factors = sirf.STIR.AcquisitionData( + str(Path(f"{attenuation_sinogram_output_prefix}.hs")) +) +asm_attn = sirf.STIR.AcquisitionSensitivityModel(ac_factors) + +asm_norm = sirf.STIR.AcquisitionSensitivityModel(norm_file) +asm = sirf.STIR.AcquisitionSensitivityModel(asm_norm, asm_attn) + +asm.set_up(acq_data) +acq_model.set_acquisition_sensitivity(asm) + +scatter_estimate = sirf.STIR.AcquisitionData( + str(Path(f"{scatter_sinogram_output_prefix}_{num_scatter_iter}.hs")) +) +acq_model.set_background_term(randoms + scatter_estimate) + +# setup an initial (template) image based on the acquisition data template +initial_image = acq_data_template.create_uniform_image(value=1, xy=nxny) + +# %% [markdown] +# Setup of the Poisson log likelihood listmode objective function +# --------------------------------------------------------------- +# +# Using the listmode data and the acquisition model, we can now setup the Poisson log likelihood objective function. + +# %% +lm_obj_fun = ( + sirf.STIR.PoissonLogLikelihoodWithLinearModelForMeanAndListModeDataWithProjMatrixByBin() +) +lm_obj_fun.set_acquisition_model(acq_model) +lm_obj_fun.set_acquisition_data(listmode_data) +lm_obj_fun.set_num_subsets(num_subsets) +lm_obj_fun.set_cache_max_size(1000000000) +lm_obj_fun.set_cache_path(str(output_path)) +print("setting up listmode objective function ...") +lm_obj_fun.set_up(initial_image) + +# %% [markdown] +# Setup of OSEM update layer +# -------------------------- +# +# See notebook 04. 
+
+# %%
+class SIRFPoissonlogLGradLayer(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        x: torch.Tensor,
+        objective_function,
+        sirf_template_image: sirf.STIR.ImageData,
+        subset: int,
+    ) -> torch.Tensor:
+        """(listmode) Poisson loglikelihood gradient layer forward pass
+
+        Parameters
+        ----------
+        ctx : context object
+            used to store objects that we need in the backward pass
+        x : torch.Tensor
+            minibatch tensor of shape [1,1,spatial_dimensions] containing the image
+        objective_function : sirf (listmode) objective function
+            the objective function that we use to calculate the gradient
+        sirf_template_image : sirf.STIR.ImageData
+            image template that we use to convert between torch tensors and sirf images
+        subset : int
+            subset number used for the gradient calculation
+
+        Returns
+        -------
+        torch.Tensor
+            minibatch tensor of shape [1,1,spatial_dimensions] containing the gradient
+            of the (listmode) Poisson log likelihood at x
+        """
+
+        # we use the context object ctx to store the objective function and other
+        # objects that we need in the backward pass
+        ctx.device = x.device
+        ctx.objective_function = objective_function
+        ctx.dtype = x.dtype
+        ctx.subset = subset
+        # clone the template so that filling it in the backward pass does not
+        # overwrite the caller's image
+        ctx.sirf_template_image = sirf_template_image.clone()
+
+        # setup a new sirf.STIR ImageData object
+        x_sirf = sirf_template_image.clone()
+        # convert torch tensor to sirf image via numpy
+        x_sirf.fill(x.cpu().numpy()[0, 0, ...])
+
+        # save the input sirf.STIR ImageData for the backward pass
+        ctx.x_sirf = x_sirf
+
+        # calculate the gradient of the Poisson log likelihood using SIRF
+        g_np = objective_function.gradient(x_sirf, subset).as_array()
+
+        # convert back to torch tensor
+        y = (
+            torch.tensor(g_np, device=ctx.device, dtype=ctx.dtype)
+            .unsqueeze(0)
+            .unsqueeze(0)
+        )
+
+        return y
+
+    @staticmethod
+    def backward(
+        ctx, grad_output: torch.Tensor | None
+    ) -> tuple[torch.Tensor | None, None, None, None]:
+        """(listmode) Poisson loglikelihood gradient layer backward pass
+
+        Parameters
+        ----------
+        ctx : context object
+            used to store objects that we need in the backward pass
+        grad_output : torch.Tensor | None
+            minibatch tensor of shape [1,1,spatial_dimensions] containing the gradient (called v in the autograd tutorial)
+            https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#optional-reading-vector-calculus-using-autograd
+
+        Returns
+        -------
+        tuple[torch.Tensor | None, None, None, None]
+            the Jacobian-vector product of the Poisson log likelihood gradient layer
+        """
+
+        if grad_output is None:
+            return None, None, None, None
+        else:
+            # convert torch tensor to sirf image via numpy
+            ctx.sirf_template_image.fill(grad_output.cpu().numpy()[0, 0, ...])
+
+            # calculate the Jacobian vector product (the Hessian applied to an image) using SIRF
+            back_sirf = ctx.objective_function.multiply_with_Hessian(
+                ctx.x_sirf, ctx.sirf_template_image, ctx.subset
+            )
+
+            # convert back to torch tensor via numpy
+            back = (
+                torch.tensor(back_sirf.as_array(), device=ctx.device, dtype=ctx.dtype)
+                .unsqueeze(0)
+                .unsqueeze(0)
+            )
+
+            return back, None, None, None
+
+
+# %%
+class OSEMUpdateLayer(torch.nn.Module):
+    def __init__(
+        self,
+        objective_function,
+        sirf_template_image: sirf.STIR.ImageData,
+        subset: int,
+        device: str,
+    ) -> None:
+        """OSEM update layer
+
+        Parameters
+        ----------
+        objective_function : sirf (listmode) objective function
+            the objective function that we use to calculate the gradient
+        sirf_template_image : sirf.STIR.ImageData
+            image template that we use to convert between torch tensors and sirf images
+        subset : int
+            subset number used for the gradient calculation
+        device : str
+            device used for the calculations
+        """
+        super().__init__()
+        self._objective_function = objective_function
+        self._sirf_template_image: sirf.STIR.ImageData = sirf_template_image
+        self._subset: int = subset
+
+        self._poisson_logL_grad_layer = SIRFPoissonlogLGradLayer.apply
+
+        # setup a tensor containing the inverse of the subset sensitivity image,
+        # adding the minibatch and channel dimensions
+        self._inv_sens_image: torch.Tensor = 1.0 / torch.tensor(
+            objective_function.get_subset_sensitivity(subset).as_array(),
+            dtype=torch.float32,
+            device=device,
+        ).unsqueeze(0).unsqueeze(0)
+        # replace positive infinity values with 0 (voxels with 0 sensitivity)
+        torch.nan_to_num(self._inv_sens_image, posinf=0, out=self._inv_sens_image)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """forward pass of the OSEM update layer
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            minibatch tensor of shape [1,1,spatial_dimensions] containing the image
+
+        Returns
+        -------
+        torch.Tensor
+            minibatch tensor of shape [1,1,spatial_dimensions] containing the OSEM
+            update of the input image using the Poisson log likelihood objective function
+        """
+        grad_x: torch.Tensor = self._poisson_logL_grad_layer(
+            x, self._objective_function, self._sirf_template_image, self._subset
+        )
+        return x + x * self._inv_sens_image * grad_x
+
+
+# %%
+class UnrolledOSEMVarNet(torch.nn.Module):
+    def __init__(
+        self,
+        objective_function,
+        sirf_template_image: sirf.STIR.ImageData,
+        convnet: torch.nn.Module,
+        device: str,
+    ) -> None:
+        super().__init__()
+        self._osem_step_layer0 = OSEMUpdateLayer(
+            objective_function, sirf_template_image, 0, device
+        )
+        self._osem_step_layer1 = OSEMUpdateLayer(
+            objective_function, sirf_template_image, 1, device
+        )
+        self._convnet = convnet
+        self._relu = torch.nn.ReLU()
+
+        self._fusion_weight0 = torch.nn.Parameter(
+            10 * torch.ones(1, device=device, dtype=torch.float32)
+        )
+        self._fusion_weight1 = torch.nn.Parameter(
+            10 * torch.ones(1, device=device, dtype=torch.float32)
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x1 = self._relu(
+            self._fusion_weight0 * self._convnet(x) + self._osem_step_layer0(x)
+        )
+        x2 = self._relu(
+            self._fusion_weight1 * self._convnet(x1) + self._osem_step_layer1(x1)
+        )
+
+        return x2
+
+
+# %%
+
+lm_ref_recon = sirf.STIR.ImageData(f"{lm_recon_output_file}.hv")
+x_t = (
+    torch.tensor(
+        lm_ref_recon.as_array(), device=dev, dtype=torch.float32, requires_grad=False
+    )
+    .unsqueeze(0)
+    .unsqueeze(0)
+)
+
+cnn = torch.nn.Sequential(
+    torch.nn.Conv3d(1, 5, 5, padding="same", bias=False),
+    torch.nn.Conv3d(5, 5, 5, padding="same", bias=False),
+    torch.nn.PReLU(device=dev),
+    torch.nn.Conv3d(5, 5, 5, padding="same", bias=False),
+    torch.nn.Conv3d(5, 5, 5, padding="same", bias=False),
+    torch.nn.PReLU(device=dev),
+    torch.nn.Conv3d(5, 1, 1, padding="same", bias=False),
+).to(dev)
+
+
+varnet = UnrolledOSEMVarNet(lm_obj_fun, initial_image, cnn, dev)
+
+# define the high quality target image (mini-batch)
+lm_60min_recon_output_file: str = str(Path(f"recons_60min") / "lm_recon")
+lm_60min_ref_recon = sirf.STIR.ImageData(f"{lm_60min_recon_output_file}.hv")
+
+# we have to scale the 60min reconstruction, since it is not reconstructed in kBq/ml
+scale_factor = lm_ref_recon.as_array().mean() / lm_60min_ref_recon.as_array().mean()
+lm_60min_ref_recon *= scale_factor
+
+target = (
+    torch.tensor(
+        lm_60min_ref_recon.as_array(), device=dev, dtype=torch.float32, requires_grad=False
+    )
+    .unsqueeze(0)
+    .unsqueeze(0)
+)
+
+optimizer = torch.optim.Adam(varnet._convnet.parameters(), lr=1e-3)
+# define the loss function
+loss_fct = torch.nn.MSELoss()
+
+# %%
+# run num_epochs updates of the model parameters using backpropagation of the
+# gradient of the loss function and the Adam optimizer
+
+num_epochs = 50
+training_loss = torch.zeros(num_epochs)
+
+for i in range(num_epochs):
+    # pass the input mini-batch through the network
+    prediction = varnet(x_t)
+    # calculate the MSE loss between the prediction and the target
+    loss = loss_fct(prediction, target)
+    # backpropagate the gradient of the loss through the network
+    # (needed to update the trainable parameters of the network with an optimizer)
+    optimizer.zero_grad()
+    loss.backward()
+    # update the trainable parameters of the network with the optimizer
+    optimizer.step()
+    print(i, loss.item())
+    # save the training loss
+    training_loss[i] = loss.item()
+
+
+# %%
+# visualize the results
+vmax = float(target.max())
+sl = 71
+
+fig1, ax1 = plt.subplots(2, 3, figsize=(9, 6), tight_layout=True)
+ax1[0, 0].imshow(x_t.cpu().numpy()[0, 0, sl, :, :], cmap="Greys", vmin=0, vmax=vmax)
+ax1[0, 1].imshow(
+    prediction.detach().cpu().numpy()[0, 0, sl, :, :], cmap="Greys", vmin=0, vmax=vmax
+)
+ax1[0, 2].imshow(target.cpu().numpy()[0, 0, sl, :, :], cmap="Greys", vmin=0, vmax=vmax)
+ax1[1, 0].imshow(
+    x_t.cpu().numpy()[0, 0, sl, :, :] - target.cpu().numpy()[0, 0, sl, :, :],
+    cmap="seismic",
+    vmin=-0.01,
+    vmax=0.01,
+)
+ax1[1, 1].imshow(
+    prediction.detach().cpu().numpy()[0, 0, sl, :, :]
+    - target.cpu().numpy()[0, 0, sl, :, :],
+    cmap="seismic",
+    vmin=-0.01,
+    vmax=0.01,
+)
+
+ax1[0, 0].set_title("network input")
+ax1[0, 1].set_title("network output")
+ax1[0, 2].set_title("target")
+ax1[1, 0].set_title("network input - target")
+ax1[1, 1].set_title("network output - target")
+fig1.show()
+
+fig2, ax2 = plt.subplots()
+ax2.plot(training_loss.cpu().numpy())
+ax2.set_xlabel("epoch")
+ax2.set_ylabel("training loss")
+fig2.show()
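Not part of the patch: once training has converged, one might want to persist the trained CNN and reuse it later. A minimal sketch, assuming only the CNN weights need to be saved (the SIRF objects are rebuilt from the data when the script is rerun); the file name trained_cnn.pt is illustrative:

# save only the trained CNN weights
torch.save(cnn.state_dict(), "trained_cnn.pt")

# later / elsewhere: rebuild the same CNN architecture, load the weights,
# wrap it in the varnet again and run inference without gradient tracking
cnn.load_state_dict(torch.load("trained_cnn.pt", map_location=dev))
with torch.no_grad():
    prediction = varnet(x_t)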