diff --git a/life-science/00_monai_decathlon.ipynb b/life-science/00_monai_decathlon.ipynb new file mode 100644 index 00000000..18b2cd0d --- /dev/null +++ b/life-science/00_monai_decathlon.ipynb @@ -0,0 +1,195 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "33e2ac00-22be-4f7c-9445-5c3220d0f1bf", + "metadata": {}, + "source": [ + "# Fetching the Brain Tumor Segmentation Dataset\n", + "\n", + "In this notebook, we will learn:\n", + "- how we can use [MONAI Core APIs](https://github.com/Project-MONAI/MONAI) to download the brain tumor segmentation data from the [Medical Segmentation Decathlon](http://medicaldecathlon.com) challenge.\n", + "- how we can upload the dataset to Weights & Biases and use it as a dataset artifact." + ] + }, + { + "cell_type": "markdown", + "id": "813a28eb-8d05-412c-b3d4-9e64eb2962dc", + "metadata": {}, + "source": [ + "## 🌴 Setup and Installation\n", + "\n", + "First, let us install the latest versions of both MONAI and Weights & Biases." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d8a4eaa-6c15-44f0-81f8-b0c2800b1017", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -q -U monai wandb" + ] + }, + { + "cell_type": "markdown", + "id": "752e1f77-a825-4eb7-afb7-5c2807b29ada", + "metadata": {}, + "source": [ + "## 🌳 Initialize a W&B Run\n", + "\n", + "We will start a new W&B run to track our experiment." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2315b79-8c0a-4cfd-aa6d-4fca55d78137", + "metadata": {}, + "outputs": [], + "source": [ + "import wandb\n", + "\n", + "wandb.init(\n", + " project=\"brain-tumor-segmentation\",\n", + " entity=\"lifesciences\",\n", + " job_type=\"fetch_dataset\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "308bd1ff-0999-4b85-b9a7-2a9d5753e69e", + "metadata": {}, + "source": [ + "## 🍁 Fetching the Dataset using MONAI\n", + "\n", + "The [`monai.apps.DecathlonDataset`](https://docs.monai.io/en/stable/apps.html#monai.apps.DecathlonDataset) lets us automatically download the data of the [Medical Segmentation Decathlon challenge](http://medicaldecathlon.com/) and generate items for training, validation, or testing. We will use this API in the later notebooks to load and transform our datasets automatically.\n",
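+ "\n", + "Once the next cell has downloaded the data, each item yielded by `DecathlonDataset` is a dictionary keyed by `\"image\"` and `\"label\"` (the test split typically contains only images). Here is a minimal sketch for peeking at a single sample; the exact keys and shapes depend on the task and on the transforms applied, so treat the comments as illustrative rather than exact:\n", + "\n", + "```python\n", + "sample = train_dataset[0]\n", + "print(sample.keys())          # e.g. dict_keys(['image', 'label', ...])\n", + "print(sample[\"image\"].shape)  # multi-modal MRI volume\n", + "print(sample[\"label\"].shape)  # voxel-wise segmentation labels\n", + "```"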
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42189439-2c3d-403b-915a-98f897d049e4", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "# Make the dataset directory\n", + "os.makedirs(\"./dataset/\", exist_ok=True)\n", + "\n", + "\n", + "from monai.apps import DecathlonDataset\n", + "\n", + "# Fetch the training split of the brain tumor segmentation dataset\n", + "train_dataset = DecathlonDataset(\n", + " root_dir=\"./dataset/\",\n", + " task=\"Task01_BrainTumour\",\n", + " section=\"training\",\n", + " download=True,\n", + " cache_rate=0.0,\n", + " num_workers=4,\n", + ")\n", + "\n", + "# Fetch the validation split of the brain tumor segmentation dataset\n", + "val_dataset = DecathlonDataset(\n", + " root_dir=\"./dataset/\",\n", + " task=\"Task01_BrainTumour\",\n", + " section=\"validation\",\n", + " download=False,\n", + " cache_rate=0.0,\n", + " num_workers=4,\n", + ")\n", + "\n", + "# Fetch the test split of the brain tumor segmentation dataset\n", + "test_dataset = DecathlonDataset(\n", + " root_dir=\"./dataset/\",\n", + " task=\"Task01_BrainTumour\",\n", + " section=\"test\",\n", + " download=False,\n", + " cache_rate=0.0,\n", + " num_workers=4,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07461dbc-3056-4f06-bb1a-462246a35791", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Train Set Size:\", len(train_dataset))\n", + "print(\"Validation Set Size:\", len(val_dataset))\n", + "print(\"Test Set Size:\", len(test_dataset))" + ] + }, + { + "cell_type": "markdown", + "id": "93e0609f-3009-4bd0-baf9-e8e10084801c", + "metadata": {}, + "source": [ + "## 💿 Upload the Dataset to W&B as an Artifact\n", + "\n", + "[W&B Artifacts](https://docs.wandb.ai/guides/artifacts) can be used to track and version any serialized data as the inputs and outputs of your W&B Runs. For example, a model training run might take in a dataset as input and a trained model as output.\n", + "\n", + "![](https://docs.wandb.ai/assets/images/artifacts_landing_page2-b6bd49ea5db62eff00f582a95845fed9.png)\n", + "\n", + "Let us now see how we can upload this dataset as a W&B artifact." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f1f35e5-927e-4baf-a351-652e7e99fe76", + "metadata": {}, + "outputs": [], + "source": [ + "artifact = wandb.Artifact(name=\"decathlon_brain_tumor\", type=\"dataset\")\n", + "artifact.add_dir(local_path=\"./dataset/\")\n", + "wandb.log_artifact(artifact)" + ] + }, + { + "cell_type": "markdown", + "id": "e1cbbe47-f83f-4db3-9c81-879121041881", + "metadata": {}, + "source": [ + "Now we end the experiment by calling `wandb.finish()`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25ea852b-04d7-4e94-97c3-45d972b21886", + "metadata": {}, + "outputs": [], + "source": [ + "wandb.finish()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/life-science/01_monai_decathlon_visualize.ipynb b/life-science/01_monai_decathlon_visualize.ipynb new file mode 100644 index 00000000..78f2c0cd --- /dev/null +++ b/life-science/01_monai_decathlon_visualize.ipynb @@ -0,0 +1,482 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "13953901-66c4-437c-b3a5-fadf8136d54c", + "metadata": {}, + "source": [ + "# Visualize Brain Tumor Segmentation Data\n", + "\n", + "In this notebook we will learn:\n", + "- how we can use the MONAI transform API:\n", + " - MONAI Transforms for dictionary format data.\n", + " - Creating custom transforms using the [`monai.transforms`](https://docs.monai.io/en/stable/transforms.html) API.\n", + "- how we can visualize the brain tumor segmentation dataset using W&B image overlays.\n", + "- how we can analyze our data using W&B Tables." + ] + }, + { + "cell_type": "markdown", + "id": "f4023f9a-1e58-468d-8ea3-56a694fa89ec", + "metadata": {}, + "source": [ + "## 🌴 Setup and Installation\n", + "\n", + "First, let us install the latest versions of both MONAI and Weights & Biases." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4def9c4c-89b9-4f02-9853-91624690dc4f", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -q -U monai wandb" + ] + }, + { + "cell_type": "markdown", + "id": "48509346-08a2-41e3-bc98-1aea79fe42d3", + "metadata": {}, + "source": [ + "## 🌳 Initialize a W&B Run\n", + "\n", + "We will start a new W&B run to track our experiment.\n",
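+ "\n", + "If this machine has not been authenticated with W&B before, `wandb.init` will prompt you for an API key; you can also log in explicitly beforehand (a standard `wandb` call, shown here as an optional hint rather than a required step):\n", + "\n", + "```python\n", + "import wandb\n", + "\n", + "wandb.login()  # prompts for, or reuses, your W&B API key\n", + "```"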
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "69d72b50-1a24-4a32-97c8-6f859cc203df", + "metadata": {}, + "outputs": [], + "source": [ + "import wandb\n", + "\n", + "wandb.init(\n", + " project=\"brain-tumor-segmentation\",\n", + " entity=\"lifesciences\",\n", + " job_type=\"visualize_dataset\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "245625a8-be24-40f8-8cef-615917611a49", + "metadata": {}, + "source": [ + "## 💿 Loading and Transforming the Data\n", + "\n", + "We will now learn to use the [`monai.transforms`](https://docs.monai.io/en/stable/transforms.html) API to create and apply transforms to our data.\n", + "\n", + "### Creating a Custom Transform\n", + "\n", + "First, we demonstrate the creation of a custom transform `ConvertToMultiChannelBasedOnBratsClassesd` using [`monai.transforms.MapTransform`](https://docs.monai.io/en/stable/transforms.html#maptransform) that converts labels to multi-channel tensors based on the brats18 classes:\n", + "- label 1 is the peritumoral edema\n", + "- label 2 is the GD-enhancing tumor\n", + "- label 3 is the necrotic and non-enhancing tumor core.\n", + "\n", + "The target classes for the semantic segmentation task after applying this transform on the dataset will be:\n", + "- Tumor core\n", + "- Whole tumor\n", + "- Enhancing tumor" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1ce8d513-cd46-43c1-839e-0ae15f750a5e", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from monai.transforms import MapTransform\n", + "\n", + "\n", + "class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):\n", + " \"\"\"\n", + " Convert labels to multi-channels based on brats classes:\n", + " label 1 is the peritumoral edema\n", + " label 2 is the GD-enhancing tumor\n", + " label 3 is the necrotic and non-enhancing tumor core\n", + " The possible classes are TC (Tumor core), WT (Whole tumor), and ET (Enhancing tumor).\n", + "\n", + " Reference: https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/brats_segmentation_3d.ipynb\n", + "\n", + " \"\"\"\n", + "\n", + " def __call__(self, data):\n", + " data_dict = dict(data)\n", + " for key in self.keys:\n", + " result = []\n", + " # merge label 2 and label 3 to construct Tumor Core\n", + " result.append(torch.logical_or(data_dict[key] == 2, data_dict[key] == 3))\n", + " # merge labels 1, 2 and 3 to construct Whole Tumor\n", + " result.append(\n", + " torch.logical_or(\n", + " torch.logical_or(data_dict[key] == 2, data_dict[key] == 3), data_dict[key] == 1\n", + " )\n", + " )\n", + " # label 2 is Enhancing Tumor\n", + " result.append(data_dict[key] == 2)\n", + " data_dict[key] = torch.stack(result, axis=0).float()\n", + " return data_dict" + ] + }, + { + "cell_type": "markdown", + "id": "d72a65c1-701c-4700-88a9-df30b51eb10e", + "metadata": {}, + "source": [ + "Next, we compose all the necessary transforms for visualizing the data using [`monai.transforms.Compose`](https://docs.monai.io/en/stable/transforms.html#monai.transforms.Compose).\n", + "\n", + "**Note:** During training, we will apply a different set of transforms to the data.\n",
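+ "\n", + "Before composing the full pipeline, here is a quick sanity check of the custom transform defined above, using a small synthetic label tensor (the values are made up purely for illustration):\n", + "\n", + "```python\n", + "import torch\n", + "\n", + "dummy = {\"label\": torch.tensor([[0, 1], [2, 3]])}\n", + "converted = ConvertToMultiChannelBasedOnBratsClassesd(keys=\"label\")(dummy)\n", + "print(converted[\"label\"].shape)  # torch.Size([3, 2, 2]): one binary channel per class (TC, WT, ET)\n", + "```"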
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a542d505-3713-43f1-adf1-c33ade5696b1", + "metadata": {}, + "outputs": [], + "source": [ + "from monai.transforms import (\n", + " Compose,\n", + " LoadImaged,\n", + " NormalizeIntensityd,\n", + " Orientationd,\n", + " Spacingd,\n", + " EnsureTyped,\n", + " EnsureChannelFirstd,\n", + ")\n", + "\n", + "\n", + "transforms = Compose(\n", + " [\n", + " # Load 4 Nifti images and stack them together\n", + " LoadImaged(keys=[\"image\", \"label\"]),\n", + " # Ensure loaded images are in channels-first format\n", + " EnsureChannelFirstd(keys=\"image\"),\n", + " # Ensure the input data to be a PyTorch Tensor or numpy array\n", + " EnsureTyped(keys=[\"image\", \"label\"]),\n", + " # Convert labels to multi-channels based on brats18 classes\n", + " ConvertToMultiChannelBasedOnBratsClassesd(keys=\"label\"),\n", + " # Change the input image’s orientation into the specified based on axis codes\n", + " Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n", + " # Resample the input images to the specified pixel dimension\n", + " Spacingd(\n", + " keys=[\"image\", \"label\"],\n", + " pixdim=(1.0, 1.0, 1.0),\n", + " mode=(\"bilinear\", \"nearest\"),\n", + " ),\n", + " # Normalize input image intensity\n", + " NormalizeIntensityd(keys=\"image\", nonzero=True, channel_wise=True),\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "3925b068-31fd-4a66-873d-850e6cd7fd87", + "metadata": {}, + "source": [ + "For loading the dataset, we first fetch it from the W&B dataset artifact that we had created earlier. This enables us to use the dataset as an input artifact to our visualization run, and establish the necessary lineage for our experiment.\n", + "\n", + "![](./assets/artifact_usage.png)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b6bf7fc-dc39-4247-9001-cd8832045b84", + "metadata": {}, + "outputs": [], + "source": [ + "artifact = wandb.use_artifact(\n", + " \"lifesciences/brain-tumor-segmentation/decathlon_brain_tumor:v0\", type=\"dataset\"\n", + ")\n", + "artifact_dir = artifact.download()" + ] + }, + { + "cell_type": "markdown", + "id": "54920325-570f-4c19-a299-66d4ff21ca15", + "metadata": {}, + "source": [ + "We now use the [`monai.apps.DecathlonDataset`](https://docs.monai.io/en/stable/apps.html#monai.apps.DecathlonDataset) to load our dataset and apply the transforms we defined on the data samples so that we can visualize it." 
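+ "\n", + "**Note:** Since the dataset files were just restored from the W&B artifact into `artifact_dir`, MONAI should find them on disk and skip re-downloading them in the next cell; passing `download=False` there would be an equally valid choice."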
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc829ac0-ee0a-4924-889d-72fba089ec7b", + "metadata": {}, + "outputs": [], + "source": [ + "from monai.apps import DecathlonDataset\n", + "\n", + "\n", + "# Create the dataset for the training split\n", + "# of the brain tumor segmentation dataset\n", + "train_dataset = DecathlonDataset(\n", + " root_dir=artifact_dir,\n", + " task=\"Task01_BrainTumour\",\n", + " transform=transforms,\n", + " section=\"training\",\n", + " download=True,\n", + " cache_rate=0.0,\n", + " num_workers=4,\n", + ")\n", + "\n", + "# Create the dataset for the validation split\n", + "# of the brain tumor segmentation dataset\n", + "val_dataset = DecathlonDataset(\n", + " root_dir=artifact_dir,\n", + " task=\"Task01_BrainTumour\",\n", + " transform=transforms,\n", + " section=\"validation\",\n", + " download=False,\n", + " cache_rate=0.0,\n", + " num_workers=4,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "fd917fe9-2fcd-482c-9e00-243d81e26dc7", + "metadata": {}, + "source": [ + "## 📸 Visualizing the Dataset\n", + "\n", + "Weights & Biases supports images, video, audio, and more. You can log rich media to explore your results and visually compare your runs, models, and datasets. Now, you will learn to use the [segmentation mask overlay](https://docs.wandb.ai/guides/track/log/media#image-overlays-in-tables) system to visualize our data volumes. To log segmentation masks in [W&B tables](https://docs.wandb.ai/guides/tables), you must provide a [`wandb.Image`](https://docs.wandb.ai/ref/python/data-types/image) object containing the segmentation annotations for each row in the table.\n", + "\n", + "![](https://docs.wandb.ai/assets/images/viz-2-e3652d015abbf1d6d894e8edb1424eac.gif)\n", + "\n", + "An example is provided in the pseudocode below:\n", + "\n", + "```python\n", + "table = wandb.Table(columns=[\"ID\", \"Image\"])\n", + "\n", + "for id, img, label in zip(ids, images, labels):\n", + " mask_img = wandb.Image(\n", + " img,\n", + " masks={\n", + " \"ground-truth\": {\"mask_data\": label, \"class_labels\": class_labels}\n", + " # ...\n", + " },\n", + " )\n", + "\n", + " table.add_data(id, mask_img)\n", + "\n", + "wandb.log({\"Table\": table})\n", + "```\n", + "\n", + "However, in our case, since the volumes of the target classes might overlap one another, we will log each class as a separate overlay on the same image, so that we do not lose any relevant information.\n", + "\n", + "An example is provided in the pseudocode below:\n", + "\n", + "```python\n", + "mask_img = wandb.Image(\n", + " img,\n", + " masks={\n", + " \"ground-truth/Tumor-Core\": {\n", + " \"mask_data\": label_tumor_core,\n", + " \"class_labels\": {0: \"background\", 1: \"Tumor Core\"}\n", + " },\n", + " \"ground-truth/Whole-Tumor\": {\n", + " \"mask_data\": label_whole_tumor * 2,\n", + " \"class_labels\": {0: \"background\", 2: \"Whole-Tumor\"}\n", + " },\n", + " \"ground-truth/Enhancing-Tumor\": {\n", + " \"mask_data\": label_enhancing_tumor * 3,\n", + " \"class_labels\": {0: \"background\", 3: \"Enhancing-Tumor\"}\n", + " },\n", + " },\n", + ")\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9acc72e-7091-40d7-a97b-023df48d9e12", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from tqdm.auto import tqdm\n", + "\n", + "\n", + "def get_target_area_percentage(segmentation_map):\n", + " segmentation_map_list = segmentation_map.flatten().tolist()\n", + " return segmentation_map_list.count(1.0) * 100 / len(segmentation_map_list)\n", +
"\n", + "\n", + "def log_data_samples_into_tables(\n", + " sample_image: np.array,\n", + " sample_label: np.array,\n", + " split: str = None,\n", + " data_idx: int = None,\n", + " table: wandb.Table = None,\n", + "):\n", + " \"\"\"Utility function for logging a data sample into a W&B Table\"\"\"\n", + " num_channels, _, _, num_slices = sample_image.shape\n", + " with tqdm(total=num_slices, leave=False) as progress_bar:\n", + " for slice_idx in range(num_slices):\n", + " ground_truth_wandb_images, tumor_area_percentages = [], []\n", + " for channel_idx in range(num_channels):\n", + " masks = {\n", + " \"ground-truth/Tumor-Core\": {\n", + " \"mask_data\": sample_label[0, :, :, slice_idx],\n", + " \"class_labels\": {0: \"background\", 1: \"Tumor Core\"},\n", + " },\n", + " \"ground-truth/Whole-Tumor\": {\n", + " \"mask_data\": sample_label[1, :, :, slice_idx] * 2,\n", + " \"class_labels\": {0: \"background\", 2: \"Whole Tumor\"},\n", + " },\n", + " \"ground-truth/Enhancing-Tumor\": {\n", + " \"mask_data\": sample_label[2, :, :, slice_idx] * 3,\n", + " \"class_labels\": {0: \"background\", 3: \"Enhancing Tumor\"},\n", + " },\n", + " }\n", + "\n", + " ground_truth_wandb_images.append(\n", + " wandb.Image(\n", + " sample_image[channel_idx, :, :, slice_idx],\n", + " masks=masks,\n", + " )\n", + " )\n", + " tumor_area_percentages.append(\n", + " {\n", + " \"Tumor-Core-Area-Percentage\": get_target_area_percentage(\n", + " sample_label[0, :, :, slice_idx]\n", + " ),\n", + " \"Whole-Tumor-Area-Percentage\": get_target_area_percentage(\n", + " sample_label[1, :, :, slice_idx]\n", + " ),\n", + " \"Enhancing-Tumor-Area-Percentage\": get_target_area_percentage(\n", + " sample_label[2, :, :, slice_idx]\n", + " ),\n", + " }\n", + " )\n", + " table.add_data(\n", + " split,\n", + " data_idx,\n", + " slice_idx,\n", + " *tumor_area_percentages,\n", + " *ground_truth_wandb_images\n", + " )\n", + " progress_bar.update(1)\n", + " return table" + ] + }, + { + "cell_type": "markdown", + "id": "7919a2fc-7e44-4283-a862-93ff8cdcfa5f", + "metadata": {}, + "source": [ + "Next, we iterate over our respective datasets and populate the table on our W&B dashboard." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "85601b14-e693-4cca-b8cf-5e3ce863bb18", + "metadata": {}, + "outputs": [], + "source": [ + "# Define the schema of the table\n", + "table = wandb.Table(\n", + " columns=[\n", + " \"Split\",\n", + " \"Data Index\",\n", + " \"Slice Index\",\n", + " \"Tumor-Area-Pixel-Percentages-Channel-0\",\n", + " \"Tumor-Area-Pixel-Percentages-Channel-1\",\n", + " \"Tumor-Area-Pixel-Percentages-Channel-2\",\n", + " \"Tumor-Area-Pixel-Percentages-Channel-3\",\n", + " \"Image-Channel-0\",\n", + " \"Image-Channel-1\",\n", + " \"Image-Channel-2\",\n", + " \"Image-Channel-3\",\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "386b6d4b-ebf6-4d9d-bc78-7ce489e0a926", + "metadata": {}, + "outputs": [], + "source": [ + "# Generate visualizations for train_dataset\n", + "for data_idx, sample in tqdm(\n", + " enumerate(train_dataset),\n", + " total=len(train_dataset),\n", + " desc=\"Generating Train Dataset Visualizations:\",\n", + "):\n", + " sample_image = sample[\"image\"].detach().cpu().numpy()\n", + " sample_label = sample[\"label\"].detach().cpu().numpy()\n", + " table = log_data_samples_into_tables(\n", + " sample_image,\n", + " sample_label,\n", + " split=\"train\",\n", + " data_idx=data_idx,\n", + " table=table,\n", + " )\n", + "\n", + "# Generate visualizations for val_dataset\n", + "for data_idx, sample in tqdm(\n", + " enumerate(val_dataset),\n", + " total=len(val_dataset),\n", + " desc=\"Generating Validation Dataset Visualizations:\",\n", + "):\n", + " sample_image = sample[\"image\"].detach().cpu().numpy()\n", + " sample_label = sample[\"label\"].detach().cpu().numpy()\n", + " table = log_data_samples_into_tables(\n", + " sample_image,\n", + " sample_label,\n", + " split=\"val\",\n", + " data_idx=data_idx,\n", + " table=table,\n", + " )\n", + "\n", + "# Log the table to your dashboard\n", + "wandb.log({\"tumor_segmentation_data\": table})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc83d9f8-b50f-48b8-a6c3-e8920ac44285", + "metadata": {}, + "outputs": [], + "source": [ + "wandb.finish()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/life-science/02_monai_train_baseline_model.ipynb b/life-science/02_monai_train_baseline_model.ipynb new file mode 100644 index 00000000..df8660cc --- /dev/null +++ b/life-science/02_monai_train_baseline_model.ipynb @@ -0,0 +1,801 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a93e6b75-b1d9-4703-8834-ab3fcd8f934b", + "metadata": {}, + "source": [ + "# Train a Baseline Segmentation Model\n", + "\n", + "In this notebook we will learn:\n", + "\n", + "- how to use specific MONAI APIs to write our training workflow, including a SoTA neural network architecture, a loss function, and metrics for our task.\n", + "- how to use Weights & Biases to track our experiments and to log and version our model checkpoints."
+ ] + }, + { + "cell_type": "markdown", + "id": "95118e9e-e6d0-4bde-bd5b-af792ca8153a", + "metadata": {}, + "source": [ + "## 🌴 Setup and Installation\n", + "\n", + "First, let us install the latest versions of both MONAI and Weights & Biases." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3ab1bc3-d503-4e4c-afd1-ad77d72ac472", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -q -U monai wandb" + ] + }, + { + "cell_type": "markdown", + "id": "6a4f71b2-ffb6-475c-a563-1cb631e33d84", + "metadata": {}, + "source": [ + "## 🌳 Initialize a W&B Run\n", + "\n", + "We will start a new W&B run to track our experiment." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b258cfa2-3795-4502-8878-aff469ba5077", + "metadata": {}, + "outputs": [], + "source": [ + "import wandb\n", + "\n", + "wandb.init(\n", + " project=\"brain-tumor-segmentation\",\n", + " entity=\"lifesciences\",\n", + " job_type=\"train_baseline\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "7ad41880-1613-489b-a18a-80986a74964a", + "metadata": {}, + "source": [ + "## 🌼 Reproducibility and Configuration Management\n", + "\n", + "`wandb.config` allows us to easily define and manage the configurations of our experiments. This includes hyperparameters, model settings, and any other experiment variables that we use in a particular run. By centralizing this information, we can ensure consistency across runs and make our experiments more organized and reproducible." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e4b1832f-43fb-41b7-bd4e-ee7d31b44b39", + "metadata": {}, + "outputs": [], + "source": [ + "config = wandb.config" + ] + }, + { + "cell_type": "markdown", + "id": "ab7693aa-5686-463d-aa8f-be92bec82c31", + "metadata": {}, + "source": [ + "Next, we enable deterministic training by setting a global random seed using `monai.utils.set_determinism`. By setting the seed and storing it in the run configuration, we make sure that a particular run is reproducible."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59460776-36a9-4e9e-8955-a93453aaa310", + "metadata": {}, + "outputs": [], + "source": [ + "from monai.utils import set_determinism\n", + "\n", + "config.seed = 0\n", + "set_determinism(seed=config.seed)" + ] + }, + { + "cell_type": "markdown", + "id": "0d2da2f5-b1a8-4215-927c-dc5884b0b41f", + "metadata": {}, + "source": [ + "## 💿 Loading and Transforming the Data\n", + "\n", + "We will now learn to use the [`monai.transforms`](https://docs.monai.io/en/stable/transforms.html) API to create and apply transforms to our data.\n", + "\n", + "### Creating a Custom Transform\n", + "\n", + "First, we demonstrate the creation of a custom transform `ConvertToMultiChannelBasedOnBratsClassesd` using [`monai.transforms.MapTransform`](https://docs.monai.io/en/stable/transforms.html#maptransform) that converts labels to multi-channel tensors based on the brats18 classes:\n", + "- label 1 is the peritumoral edema\n", + "- label 2 is the GD-enhancing tumor\n", + "- label 3 is the necrotic and non-enhancing tumor core.\n", + "\n", + "The target classes for the semantic segmentation task after applying this transform on the dataset will be:\n", + "- Tumor core\n", + "- Whole tumor\n", + "- Enhancing tumor" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4cef0b54-8bdc-4bee-9b77-c16745bcec2a", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from monai.transforms import MapTransform\n", + "\n", + "\n", + "class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):\n", + " \"\"\"\n", + " Convert labels to multi-channels based on brats classes:\n", + " label 1 is the peritumoral edema\n", + " label 2 is the GD-enhancing tumor\n", + " label 3 is the necrotic and non-enhancing tumor core\n", + " The possible classes are TC (Tumor core), WT (Whole tumor), and ET (Enhancing tumor).\n", + "\n", + " Reference: https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/brats_segmentation_3d.ipynb\n", + "\n", + " \"\"\"\n", + "\n", + " def __call__(self, data):\n", + " data_dict = dict(data)\n", + " for key in self.keys:\n", + " result = []\n", + " # merge label 2 and label 3 to construct Tumor Core\n", + " result.append(torch.logical_or(data_dict[key] == 2, data_dict[key] == 3))\n", + " # merge labels 1, 2 and 3 to construct Whole Tumor\n", + " result.append(\n", + " torch.logical_or(\n", + " torch.logical_or(data_dict[key] == 2, data_dict[key] == 3),\n", + " data_dict[key] == 1,\n", + " )\n", + " )\n", + " # label 2 is Enhancing Tumor\n", + " result.append(data_dict[key] == 2)\n", + " data_dict[key] = torch.stack(result, axis=0).float()\n", + " return data_dict" + ] + }, + { + "cell_type": "markdown", + "id": "1f38e5d1-e3de-4bda-9b89-67193a6baca2", + "metadata": {}, + "source": [ + "Next, we compose all the necessary transforms for training and validation using [`monai.transforms.Compose`](https://docs.monai.io/en/stable/transforms.html#monai.transforms.Compose).\n", + "\n", + "**Note:** The training transforms include random augmentations (spatial cropping, flipping, and intensity scaling and shifting) that are not applied to the validation data."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62a9e73a-aeb1-49b2-b06b-d09857b74966", + "metadata": {}, + "outputs": [], + "source": [ + "from monai.transforms import (\n", + " Activations,\n", + " AsDiscrete,\n", + " Compose,\n", + " LoadImaged,\n", + " NormalizeIntensityd,\n", + " Orientationd,\n", + " RandFlipd,\n", + " RandScaleIntensityd,\n", + " RandShiftIntensityd,\n", + " RandSpatialCropd,\n", + " Spacingd,\n", + " EnsureTyped,\n", + " EnsureChannelFirstd,\n", + ")\n", + "\n", + "\n", + "config.roi_size = [224, 224, 144]\n", + "\n", + "train_transform = Compose(\n", + " [\n", + " # load 4 Nifti images and stack them together\n", + " LoadImaged(keys=[\"image\", \"label\"]),\n", + " # Ensure loaded images are in channels-first format\n", + " EnsureChannelFirstd(keys=\"image\"),\n", + " # Ensure the input data to be a PyTorch Tensor or numpy array\n", + " EnsureTyped(keys=[\"image\", \"label\"]),\n", + " # Convert labels to multi-channels based on brats18 classes\n", + " ConvertToMultiChannelBasedOnBratsClassesd(keys=\"label\"),\n", + " # Change the input image’s orientation into the specified based on axis codes\n", + " Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n", + " # Resample the input images to the specified pixel dimension\n", + " Spacingd(\n", + " keys=[\"image\", \"label\"],\n", + " pixdim=(1.0, 1.0, 1.0),\n", + " mode=(\"bilinear\", \"nearest\"),\n", + " ),\n", + " # Augmentation: Crop image with random size or specific size ROI\n", + " RandSpatialCropd(\n", + " keys=[\"image\", \"label\"], roi_size=config.roi_size, random_size=False\n", + " ),\n", + " \n", + " # Augmentation: Randomly flip the image on the specified axes\n", + " RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=0),\n", + " RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=1),\n", + " RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=2),\n", + " \n", + " # Normalize input image intensity\n", + " NormalizeIntensityd(keys=\"image\", nonzero=True, channel_wise=True),\n", + " \n", + " # Augmentation: Randomly scale the image intensity\n", + " RandScaleIntensityd(keys=\"image\", factors=0.1, prob=1.0),\n", + " RandShiftIntensityd(keys=\"image\", offsets=0.1, prob=1.0),\n", + " ]\n", + ")\n", + "val_transform = Compose(\n", + " [\n", + " # load 4 Nifti images and stack them together\n", + " LoadImaged(keys=[\"image\", \"label\"]),\n", + " # Ensure loaded images are in channels-first format\n", + " EnsureChannelFirstd(keys=\"image\"),\n", + " # Ensure the input data to be a PyTorch Tensor or numpy array\n", + " EnsureTyped(keys=[\"image\", \"label\"]),\n", + " # Convert labels to multi-channels based on brats18 classes\n", + " ConvertToMultiChannelBasedOnBratsClassesd(keys=\"label\"),\n", + " # Change the input image’s orientation into the specified based on axis codes\n", + " Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n", + " # Resample the input images to the specified pixel dimension\n", + " Spacingd(\n", + " keys=[\"image\", \"label\"],\n", + " pixdim=(1.0, 1.0, 1.0),\n", + " mode=(\"bilinear\", \"nearest\"),\n", + " ),\n", + " # Normalize input image intensity\n", + " NormalizeIntensityd(keys=\"image\", nonzero=True, channel_wise=True),\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "14df1019-7cf6-4baa-baa1-fa2085da17bb", + "metadata": {}, + "source": [ + "For loading the dataset, we first fetch it from the W&B dataset artifact that we had created earlier. 
This enables us to use the dataset as an input artifact to our visualization run, and establish the necessary lineage for our experiment.\n", + "\n", + "![](./assets/artifact_usage.png)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35b2d2c9-8e7f-479b-a822-deb0984d22ae", + "metadata": {}, + "outputs": [], + "source": [ + "artifact = wandb.use_artifact(\n", + " \"lifesciences/brain-tumor-segmentation/decathlon_brain_tumor:v0\", type=\"dataset\"\n", + ")\n", + "artifact_dir = artifact.download()" + ] + }, + { + "cell_type": "markdown", + "id": "8fd9de15-afa4-4ccf-9d4f-78b90ae3db8f", + "metadata": {}, + "source": [ + "We now use the [`monai.apps.DecathlonDataset`](https://docs.monai.io/en/stable/apps.html#monai.apps.DecathlonDataset) to load our dataset and apply the transforms we defined on the data samples so that we use them for training and validation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16faedc0-f50a-4289-9fa4-6948549ea74f", + "metadata": {}, + "outputs": [], + "source": [ + "from monai.apps import DecathlonDataset\n", + "\n", + "config.num_workers = 4\n", + "\n", + "# Create the dataset for the training split\n", + "# of the brain tumor segmentation dataset\n", + "train_dataset = DecathlonDataset(\n", + " root_dir=artifact_dir,\n", + " task=\"Task01_BrainTumour\",\n", + " transform=train_transform,\n", + " section=\"training\",\n", + " download=False,\n", + " cache_rate=0.0,\n", + " num_workers=config.num_workers,\n", + ")\n", + "\n", + "# Create the dataset for the validation split\n", + "# of the brain tumor segmentation dataset\n", + "val_dataset = DecathlonDataset(\n", + " root_dir=artifact_dir,\n", + " task=\"Task01_BrainTumour\",\n", + " transform=val_transform,\n", + " section=\"validation\",\n", + " download=False,\n", + " cache_rate=0.0,\n", + " num_workers=config.num_workers,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e8eb84cf-eea0-431b-a66e-9afd2b7144ba", + "metadata": {}, + "source": [ + "We now create DataLoaders for the train and validation datasets respectively using [`monai.data.DataLoader`](https://docs.monai.io/en/stable/data.html#dataloader) which provides an iterable over the given dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4ad33748-060f-4652-87ed-f8b56de02824", + "metadata": {}, + "outputs": [], + "source": [ + "from monai.data import DataLoader\n", + "\n", + "config.batch_size = 2\n", + "\n", + "# create the train_loader\n", + "train_loader = DataLoader(\n", + " train_dataset,\n", + " batch_size=config.batch_size,\n", + " shuffle=True,\n", + " num_workers=config.num_workers,\n", + ")\n", + "\n", + "# create the val_loader\n", + "val_loader = DataLoader(\n", + " val_dataset,\n", + " batch_size=config.batch_size,\n", + " shuffle=False,\n", + " num_workers=config.num_workers,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "ba68b8b8-9527-482d-8f4f-da8c71dcbd87", + "metadata": {}, + "source": [ + "## 🤖 Creating the Model, Loss, and Optimizer\n", + "\n", + "We will be training a **SegResNet** model based on the paper [3D MRI brain tumor segmentation using auto-encoder regularization](https://arxiv.org/pdf/1810.11654.pdf). 
The [SegResNet](https://docs.monai.io/en/stable/networks.html#segresnet) model comes implemented as a PyTorch Module in the [`monai.networks.nets`](https://docs.monai.io/en/stable/networks.html#nets) API, which provides out-of-the-box implementations of SoTA neural network architectures for different medical imaging tasks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1d2a5e5a-84ee-4fbe-9b06-1fcf9639d443", + "metadata": {}, + "outputs": [], + "source": [ + "from monai.networks.nets import SegResNet\n", + "\n", + "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "\n", + "config.model_blocks_down = [1, 2, 2, 4]\n", + "config.model_blocks_up = [1, 1, 1]\n", + "config.model_init_filters = 16\n", + "config.model_in_channels = 4\n", + "config.model_out_channels = 3\n", + "config.model_dropout_prob = 0.2\n", + "\n", + "# create model\n", + "model = SegResNet(\n", + " blocks_down=config.model_blocks_down,\n", + " blocks_up=config.model_blocks_up,\n", + " init_filters=config.model_init_filters,\n", + " in_channels=config.model_in_channels,\n", + " out_channels=config.model_out_channels,\n", + " dropout_prob=config.model_dropout_prob,\n", + ").to(device)" + ] + }, + { + "cell_type": "markdown", + "id": "20863f94-f7d0-4b82-8a55-41a1952d1cae", + "metadata": {}, + "source": [ + "We will be using the [Adam optimizer](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html) and a [cosine annealing schedule](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html) for the learning rate. Cosine annealing smoothly decays the learning rate over the course of training, which tends to stabilize the later epochs and can improve the performance of the model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a9a9b1a6-d127-4fe9-8831-a33b71dbe438", + "metadata": {}, + "outputs": [], + "source": [ + "config.initial_learning_rate = 1e-4\n", + "config.weight_decay = 1e-5\n", + "config.max_train_epochs = 25\n", + "\n", + "# create optimizer\n", + "optimizer = torch.optim.Adam(\n", + " model.parameters(),\n", + " config.initial_learning_rate,\n", + " weight_decay=config.weight_decay,\n", + ")\n", + "\n", + "# create learning rate scheduler\n", + "lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n", + " optimizer, T_max=config.max_train_epochs\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "541dc8c6-009e-4115-81c5-d35386a92b6d", + "metadata": {}, + "source": [ + "Next, we define the loss as a multi-label Dice loss, as proposed in the paper [V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation](https://arxiv.org/abs/1606.04797), using the [`monai.losses`](https://docs.monai.io/en/stable/losses.html) API, and the corresponding Dice metrics using the [`monai.metrics`](https://docs.monai.io/en/stable/metrics.html) API.\n",
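+ "\n", + "For intuition, with a sigmoid applied to the model outputs $p$ and ground-truth masks $g$, the configuration below (squared predictions in the denominator plus small smoothing terms $\\epsilon_{nr}$ and $\\epsilon_{dr}$) corresponds roughly to the following per-channel loss, averaged over the three target classes. This is a sketch of the standard formulation, not a line-by-line copy of the MONAI implementation:\n", + "\n", + "$$\\mathcal{L}_{\\text{Dice}} = 1 - \\frac{2\\sum_i p_i g_i + \\epsilon_{nr}}{\\sum_i p_i^2 + \\sum_i g_i^2 + \\epsilon_{dr}}$$"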
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2276a62-24a9-4352-8058-eb807e16819e", + "metadata": {}, + "outputs": [], + "source": [ + "config.dice_loss_smoothen_numerator = 0\n", + "config.dice_loss_smoothen_denominator = 1e-5\n", + "config.dice_loss_squared_prediction = True\n", + "config.dice_loss_target_onehot = False\n", + "config.dice_loss_apply_sigmoid = True\n", + "\n", + "from monai.losses import DiceLoss\n", + "\n", + "loss_function = DiceLoss(\n", + " smooth_nr=config.dice_loss_smoothen_numerator,\n", + " smooth_dr=config.dice_loss_smoothen_denominator,\n", + " squared_pred=config.dice_loss_squared_prediction,\n", + " to_onehot_y=config.dice_loss_target_onehot,\n", + " sigmoid=config.dice_loss_apply_sigmoid,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "187d9b11-3df9-4ae8-96c0-e7c16468878d", + "metadata": {}, + "outputs": [], + "source": [ + "from monai.metrics import DiceMetric\n", + "\n", + "dice_metric = DiceMetric(include_background=True, reduction=\"mean\")\n", + "dice_metric_batch = DiceMetric(include_background=True, reduction=\"mean_batch\")\n", + "post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])" + ] + }, + { + "cell_type": "markdown", + "id": "70aa2a71-9ad1-43f8-97e8-00ea458b7c47", + "metadata": {}, + "source": [ + "## 🚀 Automatic Mixed Precision\n", + "\n", + "Mixed precision training is a technique used in training neural networks that utilizes both 16-bit and 32-bit floating-point types for different parts of the computation, rather than using a single precision type throughout the entire process. This method is primarily aimed at accelerating the training process while also reducing the memory usage of the models.\n", + "\n", + "We will be using [`torch.amp`](https://pytorch.org/docs/stable/amp.html#module-torch.amp), which provides convenience methods for mixed precision, where some operations use the `torch.float32` datatype and other operations use a lower-precision floating point datatype such as `torch.float16` or `torch.bfloat16`.\n", + "\n", + "### ⚖️ Gradient and Loss Scaling\n", + "\n", + "If the forward pass for a particular op has float16 inputs, the backward pass for that op will produce float16 gradients. Gradient values with small magnitudes may not be representable in float16. These values will underflow to zero, so the update for the corresponding parameters will be lost.\n", + "\n", + "In order to counteract the gradient underflow issues of FP16, especially in handling small gradient values, gradient and loss scaling is applied. This involves scaling up the loss before the gradient computation and scaling it back down afterwards. We will be using [`torch.cuda.amp.GradScaler`](https://pytorch.org/docs/stable/amp.html#gradient-scaling) to perform the scaling." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9fd9b23e-d40d-462d-860a-b0303621e250", + "metadata": {}, + "outputs": [], + "source": [ + "# use automatic mixed-precision to accelerate training\n", + "scaler = torch.cuda.amp.GradScaler()\n", + "torch.backends.cudnn.benchmark = True" + ] + }, + { + "cell_type": "markdown", + "id": "7ff081e5-de2a-42e6-ba16-bc9048b9aaee", + "metadata": {}, + "source": [ + "Next, we write a utility function to perform sliding window inference using [`monai.inferers.sliding_window_inference`](https://docs.monai.io/en/stable/inferers.html#sliding-window-inference-function) and AMP autocast.
This function will be used during the validation step in our training and validation loop." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6f51bb2-1c5b-4a16-b078-d6179c2bd083", + "metadata": {}, + "outputs": [], + "source": [ + "from monai.inferers import sliding_window_inference\n", + "\n", + "config.inference_roi_size = (240, 240, 160)\n", + "\n", + "\n", + "def inference(model, input):\n", + " def _compute(input):\n", + " return sliding_window_inference(\n", + " inputs=input,\n", + " roi_size=config.inference_roi_size,\n", + " sw_batch_size=1,\n", + " predictor=model,\n", + " overlap=0.5,\n", + " )\n", + "\n", + " with torch.cuda.amp.autocast():\n", + " return _compute(input)" + ] + }, + { + "cell_type": "markdown", + "id": "55423f02-8686-4a56-9778-09070b7edcc5", + "metadata": {}, + "source": [ + "## 🦾 Training the Model\n", + "\n", + "Let's finally get to training the model!\n", + "\n", + "### 🐝 Customize Log Axes on W&B\n", + "\n", + "We will use [`wandb.define_metric`](https://docs.wandb.ai/guides/track/log/customize-logging-axes) to set custom x-axes for our W&B charts. Custom x-axes are useful when different metrics need to be logged against different step counters during training, possibly asynchronously. For example, when training our brain tumor segmentation model, we can log the training loss and metrics at every training step but log the validation metrics only once per epoch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "02d4563f-1522-416d-ad3d-fb9b936bb368", + "metadata": {}, + "outputs": [], + "source": [ + "wandb.define_metric(\"epoch/epoch_step\")\n", + "wandb.define_metric(\"epoch/*\", step_metric=\"epoch/epoch_step\")\n", + "wandb.define_metric(\"batch/batch_step\")\n", + "wandb.define_metric(\"batch/*\", step_metric=\"batch/batch_step\")\n", + "wandb.define_metric(\"validation/validation_step\")\n", + "wandb.define_metric(\"validation/*\", step_metric=\"validation/validation_step\")" + ] + }, + { + "cell_type": "markdown", + "id": "094a1fd5-0113-4e22-be81-7a30f754a10b", + "metadata": {}, + "source": [ + "### 🏋️ Training and Validation Loop\n", + "\n", + "Next, we will proceed to write the training and validation loop for the brain tumor segmentation model. The training loop consists of 3 logical steps:\n", + "\n", + "1. **The training step**: In this step, we actually train the model by looping over the `train_loader`. Note that we use autocast to speed up the forward pass and loss calculation, and during backpropagation we use the gradient scaler to avoid float16 gradient underflow. At the end of each batch step, we log the batch step under `batch/batch_step` and the training loss under `batch/train_loss`. This ensures that the training loss is logged under its section against the batch step on the x-axis in the W&B workspace.
Here's how the training step is written:\n", + " \n", + " ```python\n", + " for batch_data in train_loader:\n", + " inputs, labels = (\n", + " batch_data[\"image\"].to(device),\n", + " batch_data[\"label\"].to(device),\n", + " )\n", + " optimizer.zero_grad()\n", + " with torch.cuda.amp.autocast():\n", + " outputs = model(inputs)\n", + " loss = loss_function(outputs, labels)\n", + " scaler.scale(loss).backward()\n", + " scaler.step(optimizer)\n", + " scaler.update()\n", + " epoch_loss += loss.item()\n", + " batch_progress_bar.set_description(f\"train_loss: {loss.item():.4f}:\")\n", + " ## Log batch-wise training loss to W&B\n", + " wandb.log({\"batch/batch_step\": batch_step, \"batch/train_loss\": loss.item()})\n", + " batch_step += 1\n", + " \n", + " ```\n", + "\n", + "2. **The epoch-wise logging step:** In this step, we log the learning rate and mean training loss for the epoch under the section `epoch/*`. We also update the learning rate using our learning rate scheduler after logging.\n", + "\n", + " ```python\n", + " wandb.log(\n", + " {\n", + " \"epoch/epoch_step\": epoch,\n", + " \"epoch/mean_train_loss\": total_epoch_loss / total_batch_steps,\n", + " \"epoch/learning_rate\": lr_scheduler.get_last_lr()[0],\n", + " }\n", + " )\n", + " lr_scheduler.step()\n", + " \n", + " ```\n", + "\n", + "3. **The validation step:** This step is executed at the interval of a certain number of epochs. In this step, we use the aforementioned `inference` function to predict the segmentation masks for the images from the validation dataloader and use `dice_metric` to calculate the dice coefficients for each of our target classes and log the dice coefficient values under the `validation/*` section. We also save our model checkpoint to W&B using `wandb.log_model` .\n", + "\n", + " ```python\n", + " for val_data in val_loader:\n", + " val_inputs, val_labels = (\n", + " val_data[\"image\"].to(device),\n", + " val_data[\"label\"].to(device),\n", + " )\n", + " val_outputs = inference(model, val_inputs)\n", + " val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]\n", + " dice_metric(y_pred=val_outputs, y=val_labels)\n", + " dice_metric_batch(y_pred=val_outputs, y=val_labels)\n", + "\n", + " wandb.log(\n", + " {\n", + " \"validation/validation_step\": validation_step,\n", + " \"validation/mean_dice\": metric_values[-1],\n", + " \"validation/mean_dice_tumor_core\": metric_values_tumor_core[-1],\n", + " \"validation/mean_dice_whole_tumor\": metric_values_whole_tumor[-1],\n", + " \"validation/mean_dice_enhanced_tumor\": metric_values_enhanced_tumor[-1],\n", + " }\n", + " )\n", + "\n", + " checkpoint_path = os.path.join(config.checkpoint_dir, \"model.pth\")\n", + " torch.save(model.state_dict(), checkpoint_path)\n", + " wandb.log_model(\n", + " checkpoint_path,\n", + " name=f\"{wandb.run.id}-checkpoint\",\n", + " aliases=[f\"epoch_{epoch}\"],\n", + " )\n", + " ```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7df12631-8e09-48d3-b0b0-655775ec22e4", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from tqdm.auto import tqdm\n", + "from monai.data import decollate_batch\n", + "\n", + "config.validation_intervals = 1\n", + "config.checkpoint_dir = \"./checkpoints\"\n", + "\n", + "# Create checkpoint directory\n", + "os.makedirs(config.checkpoint_dir, exist_ok=True)\n", + "\n", + "batch_step = 0\n", + "validation_step = 0\n", + "metric_values = []\n", + "metric_values_tumor_core = []\n", + "metric_values_whole_tumor = []\n", + "metric_values_enhanced_tumor = 
[]\n", + "\n", + "epoch_progress_bar = tqdm(range(config.max_train_epochs), desc=\"Training:\")\n", + "\n", + "for epoch in epoch_progress_bar:\n", + " model.train()\n", + " epoch_loss = 0\n", + "\n", + " total_batch_steps = len(train_dataset) // train_loader.batch_size\n", + " batch_progress_bar = tqdm(train_loader, total=total_batch_steps, leave=False)\n", + "\n", + " # Training Step\n", + " for batch_data in batch_progress_bar:\n", + " inputs, labels = (\n", + " batch_data[\"image\"].to(device),\n", + " batch_data[\"label\"].to(device),\n", + " )\n", + " optimizer.zero_grad()\n", + " with torch.cuda.amp.autocast():\n", + " outputs = model(inputs)\n", + " loss = loss_function(outputs, labels)\n", + " scaler.scale(loss).backward()\n", + " scaler.step(optimizer)\n", + " scaler.update()\n", + " epoch_loss += loss.item()\n", + " batch_progress_bar.set_description(f\"train_loss: {loss.item():.4f}:\")\n", + " ## Log batch-wise training loss to W&B\n", + " wandb.log({\"batch/batch_step\": batch_step, \"batch/train_loss\": loss.item()})\n", + " batch_step += 1\n", + "\n", + " epoch_loss /= total_batch_steps\n", + " ## Log batch-wise training loss and learning rate to W&B\n", + " wandb.log(\n", + " {\n", + " \"epoch/epoch_step\": epoch,\n", + " \"epoch/mean_train_loss\": epoch_loss,\n", + " \"epoch/learning_rate\": lr_scheduler.get_last_lr()[0],\n", + " }\n", + " )\n", + " lr_scheduler.step()\n", + " epoch_progress_bar.set_description(f\"Training: train_loss: {epoch_loss:.4f}:\")\n", + "\n", + " # Validation and model checkpointing step\n", + " if (epoch + 1) % config.validation_intervals == 0:\n", + " model.eval()\n", + " with torch.no_grad():\n", + " for val_data in val_loader:\n", + " val_inputs, val_labels = (\n", + " val_data[\"image\"].to(device),\n", + " val_data[\"label\"].to(device),\n", + " )\n", + " val_outputs = inference(model, val_inputs)\n", + " val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]\n", + " dice_metric(y_pred=val_outputs, y=val_labels)\n", + " dice_metric_batch(y_pred=val_outputs, y=val_labels)\n", + "\n", + " metric_values.append(dice_metric.aggregate().item())\n", + " metric_batch = dice_metric_batch.aggregate()\n", + " metric_values_tumor_core.append(metric_batch[0].item())\n", + " metric_values_whole_tumor.append(metric_batch[1].item())\n", + " metric_values_enhanced_tumor.append(metric_batch[2].item())\n", + " dice_metric.reset()\n", + " dice_metric_batch.reset()\n", + "\n", + " # Log and versison model checkpoints using W&B artifacts.\n", + " checkpoint_path = os.path.join(config.checkpoint_dir, \"model.pth\")\n", + " torch.save(model.state_dict(), checkpoint_path)\n", + " wandb.log_model(\n", + " checkpoint_path,\n", + " name=f\"{wandb.run.id}-checkpoint\",\n", + " aliases=[f\"epoch_{epoch}\"],\n", + " )\n", + "\n", + " # Log validation metrics to W&B dashboard.\n", + " wandb.log(\n", + " {\n", + " \"validation/validation_step\": validation_step,\n", + " \"validation/mean_dice\": metric_values[-1],\n", + " \"validation/mean_dice_tumor_core\": metric_values_tumor_core[-1],\n", + " \"validation/mean_dice_whole_tumor\": metric_values_whole_tumor[-1],\n", + " \"validation/mean_dice_enhanced_tumor\": metric_values_enhanced_tumor[-1],\n", + " }\n", + " )\n", + " validation_step += 1" + ] + }, + { + "cell_type": "markdown", + "id": "2b161a87-6558-4fa4-8f1b-578f754a95bc", + "metadata": {}, + "source": [ + "Now we end the experiment by calling `wandb.finish()`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3600d23f-76d7-4702-b566-b17b9f412b4b", + "metadata": {}, + "outputs": [], + "source": [ + "wandb.finish()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/life-science/03_monai_evaluate_model.ipynb b/life-science/03_monai_evaluate_model.ipynb new file mode 100644 index 00000000..7c5f29be --- /dev/null +++ b/life-science/03_monai_evaluate_model.ipynb @@ -0,0 +1,465 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Evaluate the Brain Tumor Segmentation Model\n", + "\n", + "In this notebook we will learn:\n", + "- how we can evaluate a pre-trained model checkpoint for brain tumor segmentation using MONAI and Weights & Biases.\n", + "- how we can visually compare the ground-truth labels with the predicted labels." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 🌴 Setup and Installation\n", + "\n", + "First, let us install the latest versions of both MONAI and Weights & Biases." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -q -U monai wandb" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 🌳 Initialize a W&B Run\n", + "\n", + "We will start a new W&B run to track our experiment. Note that we set the job type for this run to `evaluate`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "import wandb\n", + "from monai.utils import set_determinism\n", + "\n", + "wandb.init(\n", + " project=\"brain-tumor-segmentation\",\n", + " entity=\"lifesciences\",\n", + " job_type=\"evaluate\"\n", + ")\n", + "\n", + "config = wandb.config\n", + "\n", + "# Ensure deterministic behavior and reproducibility\n", + "config.seed = 0\n", + "set_determinism(seed=config.seed)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 💿 Loading and Transforming the Data\n", + "\n", + "We will use the validation transforms from the previous lessons to load and transform the validation dataset using the Decathlon dataset artifact on W&B.\n",
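+ "\n", + "**Note:** The next cell imports `ConvertToMultiChannelBasedOnBratsClassesd` from a local `utils` module. This is assumed to be a small helper file that simply contains the same custom transform we defined in the earlier notebooks, along the lines of the sketch below:\n", + "\n", + "```python\n", + "# utils.py (assumed helper module)\n", + "import torch\n", + "from monai.transforms import MapTransform\n", + "\n", + "\n", + "class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):\n", + "    \"\"\"Same multi-channel BraTS label transform defined in the earlier notebooks.\"\"\"\n", + "\n", + "    def __call__(self, data):\n", + "        d = dict(data)\n", + "        for key in self.keys:\n", + "            tumor_core = torch.logical_or(d[key] == 2, d[key] == 3)\n", + "            whole_tumor = torch.logical_or(tumor_core, d[key] == 1)\n", + "            enhancing_tumor = d[key] == 2\n", + "            d[key] = torch.stack([tumor_core, whole_tumor, enhancing_tumor], axis=0).float()\n", + "        return d\n", + "```"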
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from utils import ConvertToMultiChannelBasedOnBratsClassesd\n", + "from monai.apps import DecathlonDataset\n", + "from monai.transforms import (\n", + " Compose,\n", + " LoadImaged,\n", + " NormalizeIntensityd,\n", + " Orientationd,\n", + " Spacingd,\n", + " EnsureTyped,\n", + " EnsureChannelFirstd,\n", + ")\n", + "\n", + "\n", + "transforms = Compose(\n", + " [\n", + " # load 4 Nifti images and stack them together\n", + " LoadImaged(keys=[\"image\", \"label\"]),\n", + " # Ensure loaded images are in channels-first format\n", + " EnsureChannelFirstd(keys=\"image\"),\n", + " # Ensure the input data to be a PyTorch Tensor or numpy array\n", + " EnsureTyped(keys=[\"image\", \"label\"]),\n", + " # Convert labels to multi-channels based on brats18 classes\n", + " ConvertToMultiChannelBasedOnBratsClassesd(keys=\"label\"),\n", + " # Change the input image’s orientation into the specified based on axis codes\n", + " Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n", + " # Resample the input images to the specified pixel dimension\n", + " Spacingd(\n", + " keys=[\"image\", \"label\"],\n", + " pixdim=(1.0, 1.0, 1.0),\n", + " mode=(\"bilinear\", \"nearest\"),\n", + " ),\n", + " # Normalize input image intensity\n", + " NormalizeIntensityd(keys=\"image\", nonzero=True, channel_wise=True),\n", + " ]\n", + ")\n", + "\n", + "\n", + "# Fetch the brain tumor segmentation dataset artifact from W&B\n", + "artifact = wandb.use_artifact(\n", + " \"lifesciences/brain-tumor-segmentation/decathlon_brain_tumor:latest\",\n", + " type=\"dataset\",\n", + ")\n", + "artifact_dir = artifact.download()\n", + "\n", + "\n", + "# Create the dataset for the test split\n", + "# of the brain tumor segmentation dataset\n", + "val_dataset = DecathlonDataset(\n", + " root_dir=artifact_dir,\n", + " task=\"Task01_BrainTumour\",\n", + " transform=transforms,\n", + " section=\"validation\",\n", + " download=False,\n", + " cache_rate=0.0,\n", + " num_workers=4,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 🤖 Loading the Model Checkpoint\n", + "\n", + "We are going to fetch the model checkpoints from the training run and load them." 
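+ "\n", + "**Note:** The checkpoint artifact referenced below is named after the id of the training run (the training notebook logged it with `name=f\"{wandb.run.id}-checkpoint\"`). If you trained your own baseline, replace `8vmqcqao` with the id of your training run, for example:\n", + "\n", + "```python\n", + "# hypothetical: substitute the id of your own training run\n", + "model_artifact = wandb.use_artifact(\n", + "    \"lifesciences/brain-tumor-segmentation/<your-run-id>-checkpoint:latest\",\n", + "    type=\"model\",\n", + ")\n", + "```"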
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import torch\n", + "from monai.networks.nets import SegResNet\n", + "\n", + "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "\n", + "config.model_blocks_down = [1, 2, 2, 4]\n", + "config.model_blocks_up = [1, 1, 1]\n", + "config.model_init_filters = 16\n", + "config.model_in_channels = 4\n", + "config.model_out_channels = 3\n", + "config.model_dropout_prob = 0.2\n", + "\n", + "# create model\n", + "model = SegResNet(\n", + " blocks_down=config.model_blocks_down,\n", + " blocks_up=config.model_blocks_up,\n", + " init_filters=config.model_init_filters,\n", + " in_channels=config.model_in_channels,\n", + " out_channels=config.model_out_channels,\n", + " dropout_prob=config.model_dropout_prob,\n", + ").to(device)\n", + "\n", + "\n", + "# Fetch the latest model checkpoint artifact from the training run\n", + "model_artifact = wandb.use_artifact(\n", + " \"lifesciences/brain-tumor-segmentation/8vmqcqao-checkpoint:latest\",\n", + " type=\"model\",\n", + ")\n", + "model_artifact_dir = model_artifact.download()\n", + "\n", + "\n", + "# Load the model checkpoint\n", + "model.load_state_dict(torch.load(os.path.join(model_artifact_dir, \"model.pth\")))\n", + "model.eval()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 📈 Evaluating the Model\n", + "\n", + "First we define some instances of `monai.metrics.DiceMetric` for all the metrics that we will be evaluating the model against on the validation split of our dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from monai.metrics import DiceMetric\n", + "from monai.transforms import Activations, AsDiscrete\n", + "\n", + "# Dice score for each class\n", + "tumor_core_dice_metric = DiceMetric(include_background=True, reduction=\"mean\")\n", + "enhancing_tumor_dice_metric = DiceMetric(include_background=True, reduction=\"mean\")\n", + "whole_tumor_dice_metric = DiceMetric(include_background=True, reduction=\"mean\")\n", + "\n", + "# Mean dice score across all classes\n", + "dice_metric_batch = DiceMetric(include_background=True, reduction=\"mean_batch\")\n", + "\n", + "# transforms to postprocess the outputs of the model for evaluation and visualization\n", + "postprocessing_transforms = Compose(\n", + " [Activations(sigmoid=True), AsDiscrete(threshold=0.5)]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we write some utility functions for evaluating each data-point from the validation dataset by logging dice score for each target class and the ground-truth and predicted segmentation labels (for granular visual comparison and analysis) to a W&B Table." 
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from tqdm.auto import tqdm\n",
+ "\n",
+ "\n",
+ "def get_target_area_percentage(segmentation_map):\n",
+ "    # Percentage of pixels in the slice that belong to the target class\n",
+ "    segmentation_map_list = segmentation_map.flatten().tolist()\n",
+ "    return segmentation_map_list.count(1.0) * 100 / len(segmentation_map_list)\n",
+ "\n",
+ "\n",
+ "def get_class_wise_dice_scores(sample_label, predicted_label, slice_idx):\n",
+ "    # Channel order from ConvertToMultiChannelBasedOnBratsClassesd:\n",
+ "    # 0 = Tumor Core, 1 = Whole Tumor, 2 = Enhancing Tumor\n",
+ "    sample_label = torch.from_numpy(sample_label).to(device)\n",
+ "    predicted_label = torch.from_numpy(predicted_label).to(device)\n",
+ "    tumor_core_dice_metric(\n",
+ "        y_pred=torch.unsqueeze(predicted_label[0, :, :, slice_idx], dim=0),\n",
+ "        y=torch.unsqueeze(sample_label[0, :, :, slice_idx], dim=0),\n",
+ "    )\n",
+ "    whole_tumor_dice_metric(\n",
+ "        y_pred=torch.unsqueeze(predicted_label[1, :, :, slice_idx], dim=0),\n",
+ "        y=torch.unsqueeze(sample_label[1, :, :, slice_idx], dim=0),\n",
+ "    )\n",
+ "    enhancing_tumor_dice_metric(\n",
+ "        y_pred=torch.unsqueeze(predicted_label[2, :, :, slice_idx], dim=0),\n",
+ "        y=torch.unsqueeze(sample_label[2, :, :, slice_idx], dim=0),\n",
+ "    )\n",
+ "    dice_scores = {\n",
+ "        \"Tumor-Core\": tumor_core_dice_metric.aggregate().item(),\n",
+ "        \"Enhancing-Tumor\": enhancing_tumor_dice_metric.aggregate().item(),\n",
+ "        \"Whole-Tumor\": whole_tumor_dice_metric.aggregate().item(),\n",
+ "    }\n",
+ "    tumor_core_dice_metric.reset()\n",
+ "    whole_tumor_dice_metric.reset()\n",
+ "    enhancing_tumor_dice_metric.reset()\n",
+ "    return dice_scores\n",
+ "\n",
+ "\n",
+ "def log_predictions_into_tables(\n",
+ "    sample_image,\n",
+ "    sample_label,\n",
+ "    predicted_label,\n",
+ "    split: str = None,\n",
+ "    data_idx: int = None,\n",
+ "    table: wandb.Table = None,\n",
+ "):\n",
+ "    sample_image = sample_image.cpu().numpy()\n",
+ "    sample_label = sample_label.cpu().numpy()\n",
+ "    predicted_label = predicted_label.cpu().numpy()\n",
+ "    _, _, _, num_slices = sample_image.shape\n",
+ "    with tqdm(total=num_slices, leave=False) as progress_bar:\n",
+ "        for slice_idx in range(num_slices):\n",
+ "            wandb_images = [\n",
+ "                wandb.Image(\n",
+ "                    sample_image[0, :, :, slice_idx],\n",
+ "                    masks={\n",
+ "                        \"ground-truth/Tumor-Core\": {\n",
+ "                            \"mask_data\": sample_label[0, :, :, slice_idx],\n",
+ "                            \"class_labels\": {0: \"background\", 1: \"Tumor Core\"},\n",
+ "                        },\n",
+ "                        \"prediction/Tumor-Core\": {\n",
+ "                            \"mask_data\": predicted_label[0, :, :, slice_idx] * 2,\n",
+ "                            \"class_labels\": {0: \"background\", 2: \"Tumor Core\"},\n",
+ "                        },\n",
+ "                    },\n",
+ "                ),\n",
+ "                wandb.Image(\n",
+ "                    sample_image[0, :, :, slice_idx],\n",
+ "                    masks={\n",
+ "                        \"ground-truth/Whole-Tumor\": {\n",
+ "                            \"mask_data\": sample_label[1, :, :, slice_idx],\n",
+ "                            \"class_labels\": {0: \"background\", 1: \"Whole Tumor\"},\n",
+ "                        },\n",
+ "                        \"prediction/Whole-Tumor\": {\n",
+ "                            \"mask_data\": predicted_label[1, :, :, slice_idx] * 2,\n",
+ "                            \"class_labels\": {0: \"background\", 2: \"Whole Tumor\"},\n",
+ "                        },\n",
+ "                    },\n",
+ "                ),\n",
+ "                wandb.Image(\n",
+ "                    sample_image[0, :, :, slice_idx],\n",
+ "                    masks={\n",
+ "                        \"ground-truth/Enhancing-Tumor\": {\n",
+ "                            \"mask_data\": sample_label[2, :, :, slice_idx],\n",
+ "                            \"class_labels\": {0: \"background\", 1: \"Enhancing Tumor\"},\n",
+ "                        },\n",
+ "                        \"prediction/Enhancing-Tumor\": {\n",
+ "                            \"mask_data\": predicted_label[2, :, :, slice_idx] * 2,\n",
+ "                            \"class_labels\": {0: \"background\", 2: \"Enhancing Tumor\"},\n",
+ "                        },\n",
+ "                    },\n",
+ "                ),\n",
+ "            ]\n",
" tumor_area_percentage = {\n", + " \"Ground-Truth\": {\n", + " \"Tumor-Core\": get_target_area_percentage(\n", + " sample_label[0, :, :, slice_idx]\n", + " ),\n", + " \"Whole-Tumor\": get_target_area_percentage(\n", + " sample_label[1, :, :, slice_idx]\n", + " ),\n", + " \"Enhancing-Tumor\": get_target_area_percentage(\n", + " sample_label[2, :, :, slice_idx]\n", + " ),\n", + " },\n", + " \"Prediction\": {\n", + " \"Tumor-Core\": get_target_area_percentage(\n", + " predicted_label[0, :, :, slice_idx]\n", + " ),\n", + " \"Whole-Tumor\": get_target_area_percentage(\n", + " predicted_label[1, :, :, slice_idx]\n", + " ),\n", + " \"Enhancing-Tumor\": get_target_area_percentage(\n", + " predicted_label[2, :, :, slice_idx]\n", + " ),\n", + " },\n", + " }\n", + " dice_scores = get_class_wise_dice_scores(\n", + " sample_label, predicted_label, slice_idx\n", + " )\n", + " table.add_data(\n", + " split,\n", + " data_idx,\n", + " slice_idx,\n", + " dice_scores,\n", + " tumor_area_percentage,\n", + " *wandb_images\n", + " )\n", + " progress_bar.update(1)\n", + " return table" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we create the prediction table." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "evaluation_table = wandb.Table(\n", + " columns=[\n", + " \"Split\",\n", + " \"Data Index\",\n", + " \"Slice Index\",\n", + " \"Dice-Score\",\n", + " \"Tumor-Area-Pixel-Percentage\",\n", + " \"Prediction/Tumor-Core\",\n", + " \"Prediction/Whole-Tumor\",\n", + " \"Prediction/Enhancing-Tumor\",\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we loop over the validation dataset and log the evaluation table and the mean dice scores for each class across the entore validation set to W&B." 
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from utils import inference\n",
+ "\n",
+ "total_tumor_core_dice_score = 0.0\n",
+ "total_whole_tumor_dice_score = 0.0\n",
+ "total_enhancing_tumor_dice_score = 0.0\n",
+ "\n",
+ "config.inference_roi_size = (240, 240, 160)\n",
+ "\n",
+ "# Perform inference and visualization\n",
+ "with torch.no_grad():\n",
+ "    for data_idx, sample in tqdm(enumerate(val_dataset), total=len(val_dataset), desc=\"Evaluating:\"):\n",
+ "        test_input, test_labels = (\n",
+ "            torch.unsqueeze(sample[\"image\"], 0).to(device),\n",
+ "            torch.unsqueeze(sample[\"label\"], 0).to(device),\n",
+ "        )\n",
+ "        test_output = inference(model, test_input, config.inference_roi_size)\n",
+ "        test_output = postprocessing_transforms(test_output[0])\n",
+ "        dice_metric_batch(y_pred=torch.unsqueeze(test_output, dim=0), y=test_labels)\n",
+ "        metric_batch = dice_metric_batch.aggregate()\n",
+ "        evaluation_table = log_predictions_into_tables(\n",
+ "            sample_image=torch.squeeze(test_input),\n",
+ "            sample_label=torch.squeeze(test_labels),\n",
+ "            predicted_label=test_output,\n",
+ "            data_idx=data_idx,\n",
+ "            split=\"validation\",\n",
+ "            table=evaluation_table,\n",
+ "        )\n",
+ "        total_tumor_core_dice_score += metric_batch[0].item()\n",
+ "        total_whole_tumor_dice_score += metric_batch[1].item()\n",
+ "        total_enhancing_tumor_dice_score += metric_batch[2].item()\n",
+ "        # Reset the metric so that each sample contributes its own per-class dice score\n",
+ "        dice_metric_batch.reset()\n",
+ "\n",
+ "    wandb.log({\"Tumor-Segmentation-Evaluation\": evaluation_table})\n",
+ "    wandb.summary[\"Tumor-Core-Dice-Score\"] = total_tumor_core_dice_score / len(val_dataset)\n",
+ "    wandb.summary[\"Whole-Tumor-Dice-Score\"] = total_whole_tumor_dice_score / len(val_dataset)\n",
+ "    wandb.summary[\"Enhancing-Tumor-Dice-Score\"] = total_enhancing_tumor_dice_score / len(val_dataset)\n",
+ "\n",
+ "# End the experiment\n",
+ "wandb.finish()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/life-science/04_monai_train_improved_model.ipynb b/life-science/04_monai_train_improved_model.ipynb
new file mode 100644
index 00000000..e0da229a
--- /dev/null
+++ b/life-science/04_monai_train_improved_model.ipynb
@@ -0,0 +1,568 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "a93e6b75-b1d9-4703-8834-ab3fcd8f934b",
+ "metadata": {},
+ "source": [
+ "# Train an Improved Segmentation Model\n",
+ "In this notebook, we will learn:\n",
+ "\n",
+ "- how to use specific MONAI APIs to write our training workflow, including a SoTA neural network architecture, loss function, and metrics for our task.\n",
+ "- how to use Weights & Biases to track our experiments and to log and version our model checkpoints."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "95118e9e-e6d0-4bde-bd5b-af792ca8153a",
+ "metadata": {},
+ "source": [
+ "## 🌴 Setup and Installation\n",
+ "\n",
+ "First, let us install the latest version of both MONAI and Weights and Biases."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d3ab1bc3-d503-4e4c-afd1-ad77d72ac472",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install -q -U monai wandb"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "281b7942",
+ "metadata": {},
+ "source": [
+ "## 🦄 Getting the Best Configs\n",
+ "\n",
+ "To train a model that improves on the baseline, we first need to fetch the config of the best-performing run from the hyperparameter sweep."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d004c6bf",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import wandb\n",
+ "\n",
+ "\n",
+ "def get_best_config_from_sweep(entity: str, project: str, sweep_id: str, metric: str):\n",
+ "    api = wandb.Api()\n",
+ "    sweep = api.sweep(f\"{entity}/{project}/{sweep_id}\")\n",
+ "    runs = sorted(\n",
+ "        sweep.runs, key=lambda run: run.summary.get(metric, 0), reverse=True\n",
+ "    )\n",
+ "    best_run = runs[0]\n",
+ "    return best_run.config\n",
+ "\n",
+ "\n",
+ "config = get_best_config_from_sweep(\n",
+ "    entity=\"lifesciences\",\n",
+ "    project=\"brain-tumor-segmentation\",\n",
+ "    sweep_id=\"580gsolt\",\n",
+ "    metric=\"validation/mean_dice\",\n",
+ ")\n",
+ "\n",
+ "# Override a couple of configs for this longer training run\n",
+ "config[\"initial_learning_rate\"] = 1e-4\n",
+ "config[\"max_train_epochs\"] = 25"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6a4f71b2-ffb6-475c-a563-1cb631e33d84",
+ "metadata": {},
+ "source": [
+ "Next, we start a new W&B run to track our experiment."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b258cfa2-3795-4502-8878-aff469ba5077",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from monai.utils import set_determinism\n",
+ "\n",
+ "wandb.init(\n",
+ "    project=\"brain-tumor-segmentation\",\n",
+ "    entity=\"lifesciences\",\n",
+ "    job_type=\"train_improved\",\n",
+ "    config=config,\n",
+ ")\n",
+ "config = wandb.config\n",
+ "\n",
+ "set_determinism(seed=config.seed)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0d2da2f5-b1a8-4215-927c-dc5884b0b41f",
+ "metadata": {},
+ "source": [
+ "## 💿 Loading and Transforming the Data\n",
+ "\n",
+ "We will now use the [`monai.transforms`](https://docs.monai.io/en/stable/transforms.html) API to create and apply transforms to our data."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1f38e5d1-e3de-4bda-9b89-67193a6baca2",
+ "metadata": {},
+ "source": [
+ "Next, we compose all the necessary transforms for the training and validation data using [`monai.transforms.Compose`](https://docs.monai.io/en/stable/transforms.html#monai.transforms.Compose).\n",
+ "\n",
+ "**Note:** The training transforms include random augmentations (cropping, flipping, and intensity scaling/shifting) that are not applied to the validation data.\n",
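+ "\n",
+ "One of the transforms below, `ConvertToMultiChannelBasedOnBratsClassesd`, is a custom transform defined in `utils.py`. It regroups the raw BraTS label values into the three (overlapping) target regions that we train on and report metrics for, one channel per region:\n",
+ "\n",
+ "```python\n",
+ "# Sketch of what the transform produces for each \"label\" tensor\n",
+ "# (see utils.py for the actual implementation):\n",
+ "# channel 0: Tumor Core      (TC) = labels 2 and 3\n",
+ "# channel 1: Whole Tumor     (WT) = labels 1, 2, and 3\n",
+ "# channel 2: Enhancing Tumor (ET) = label 2\n",
+ "```"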
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62a9e73a-aeb1-49b2-b06b-d09857b74966", + "metadata": {}, + "outputs": [], + "source": [ + "from monai.transforms import (\n", + " Activations,\n", + " AsDiscrete,\n", + " Compose,\n", + " LoadImaged,\n", + " NormalizeIntensityd,\n", + " Orientationd,\n", + " RandFlipd,\n", + " RandScaleIntensityd,\n", + " RandShiftIntensityd,\n", + " RandSpatialCropd,\n", + " Spacingd,\n", + " EnsureTyped,\n", + " EnsureChannelFirstd,\n", + ")\n", + "from utils import ConvertToMultiChannelBasedOnBratsClassesd\n", + "\n", + "\n", + "config.roi_size = [224, 224, 144]\n", + "\n", + "train_transform = Compose(\n", + " [\n", + " # load 4 Nifti images and stack them together\n", + " LoadImaged(keys=[\"image\", \"label\"]),\n", + " # Ensure loaded images are in channels-first format\n", + " EnsureChannelFirstd(keys=\"image\"),\n", + " # Ensure the input data to be a PyTorch Tensor or numpy array\n", + " EnsureTyped(keys=[\"image\", \"label\"]),\n", + " # Convert labels to multi-channels based on brats18 classes\n", + " ConvertToMultiChannelBasedOnBratsClassesd(keys=\"label\"),\n", + " # Change the input image’s orientation into the specified based on axis codes\n", + " Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n", + " # Resample the input images to the specified pixel dimension\n", + " Spacingd(\n", + " keys=[\"image\", \"label\"],\n", + " pixdim=(1.0, 1.0, 1.0),\n", + " mode=(\"bilinear\", \"nearest\"),\n", + " ),\n", + " # Augmentation: Crop image with random size or specific size ROI\n", + " RandSpatialCropd(\n", + " keys=[\"image\", \"label\"], roi_size=config.roi_size, random_size=False\n", + " ),\n", + " \n", + " # Augmentation: Randomly flip the image on the specified axes\n", + " RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=0),\n", + " RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=1),\n", + " RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=2),\n", + " \n", + " # Normalize input image intensity\n", + " NormalizeIntensityd(keys=\"image\", nonzero=True, channel_wise=True),\n", + " \n", + " # Augmentation: Randomly scale the image intensity\n", + " RandScaleIntensityd(keys=\"image\", factors=0.1, prob=1.0),\n", + " RandShiftIntensityd(keys=\"image\", offsets=0.1, prob=1.0),\n", + " ]\n", + ")\n", + "val_transform = Compose(\n", + " [\n", + " # load 4 Nifti images and stack them together\n", + " LoadImaged(keys=[\"image\", \"label\"]),\n", + " # Ensure loaded images are in channels-first format\n", + " EnsureChannelFirstd(keys=\"image\"),\n", + " # Ensure the input data to be a PyTorch Tensor or numpy array\n", + " EnsureTyped(keys=[\"image\", \"label\"]),\n", + " # Convert labels to multi-channels based on brats18 classes\n", + " ConvertToMultiChannelBasedOnBratsClassesd(keys=\"label\"),\n", + " # Change the input image’s orientation into the specified based on axis codes\n", + " Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n", + " # Resample the input images to the specified pixel dimension\n", + " Spacingd(\n", + " keys=[\"image\", \"label\"],\n", + " pixdim=(1.0, 1.0, 1.0),\n", + " mode=(\"bilinear\", \"nearest\"),\n", + " ),\n", + " # Normalize input image intensity\n", + " NormalizeIntensityd(keys=\"image\", nonzero=True, channel_wise=True),\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "14df1019-7cf6-4baa-baa1-fa2085da17bb", + "metadata": {}, + "source": [ + "For loading the dataset, we first fetch it from the W&B dataset artifact 
that we had created earlier. This enables us to use the dataset as an input artifact to our training run and establish the necessary lineage for our experiment.\n",
+ "\n",
+ "![](./assets/artifact_usage.png)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "35b2d2c9-8e7f-479b-a822-deb0984d22ae",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "artifact = wandb.use_artifact(\n",
+ "    \"lifesciences/brain-tumor-segmentation/decathlon_brain_tumor:v0\", type=\"dataset\"\n",
+ ")\n",
+ "artifact_dir = artifact.download()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8fd9de15-afa4-4ccf-9d4f-78b90ae3db8f",
+ "metadata": {},
+ "source": [
+ "We now use the [`monai.apps.DecathlonDataset`](https://docs.monai.io/en/stable/apps.html#monai.apps.DecathlonDataset) to load our dataset and apply the transforms we defined to the data samples for training and validation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "16faedc0-f50a-4289-9fa4-6948549ea74f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from monai.apps import DecathlonDataset\n",
+ "\n",
+ "\n",
+ "# Create the dataset for the training split\n",
+ "# of the brain tumor segmentation dataset\n",
+ "train_dataset = DecathlonDataset(\n",
+ "    root_dir=artifact_dir,\n",
+ "    task=\"Task01_BrainTumour\",\n",
+ "    transform=train_transform,\n",
+ "    section=\"training\",\n",
+ "    download=False,\n",
+ "    cache_rate=0.0,\n",
+ "    num_workers=config.num_workers,\n",
+ ")\n",
+ "\n",
+ "# Create the dataset for the validation split\n",
+ "# of the brain tumor segmentation dataset\n",
+ "val_dataset = DecathlonDataset(\n",
+ "    root_dir=artifact_dir,\n",
+ "    task=\"Task01_BrainTumour\",\n",
+ "    transform=val_transform,\n",
+ "    section=\"validation\",\n",
+ "    download=False,\n",
+ "    cache_rate=0.0,\n",
+ "    num_workers=config.num_workers,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e8eb84cf-eea0-431b-a66e-9afd2b7144ba",
+ "metadata": {},
+ "source": [
+ "We now create DataLoaders for the train and validation datasets using [`monai.data.DataLoader`](https://docs.monai.io/en/stable/data.html#dataloader), which provides an iterable over a given dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4ad33748-060f-4652-87ed-f8b56de02824",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from monai.data import DataLoader\n",
+ "\n",
+ "\n",
+ "# create the train_loader\n",
+ "train_loader = DataLoader(\n",
+ "    train_dataset,\n",
+ "    batch_size=config.batch_size,\n",
+ "    shuffle=True,\n",
+ "    num_workers=config.num_workers,\n",
+ ")\n",
+ "\n",
+ "# create the val_loader\n",
+ "val_loader = DataLoader(\n",
+ "    val_dataset,\n",
+ "    batch_size=config.batch_size,\n",
+ "    shuffle=False,\n",
+ "    num_workers=config.num_workers,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ba68b8b8-9527-482d-8f4f-da8c71dcbd87",
+ "metadata": {},
+ "source": [
+ "## 🤖 Creating the Model, Loss, and Optimizer\n",
+ "\n",
+ "We will be training a **SegResNet** model based on the paper [3D MRI brain tumor segmentation using auto-encoder regularization](https://arxiv.org/pdf/1810.11654.pdf). [SegResNet](https://docs.monai.io/en/stable/networks.html#segresnet) comes implemented as a PyTorch Module in the [`monai.networks.nets`](https://docs.monai.io/en/stable/networks.html#nets) API, which provides out-of-the-box implementations of SoTA neural network architectures for a range of medical imaging tasks.\n",
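+ "\n",
+ "Once the model is created in the next cell, an optional sanity check is to pass a random 4-channel volume through it and confirm that the output has 3 channels (one per tumor sub-region) at the same spatial size. A minimal sketch, assuming the ROI size we use for training:\n",
+ "\n",
+ "```python\n",
+ "x = torch.randn(1, 4, 224, 224, 144, device=device)\n",
+ "with torch.no_grad():\n",
+ "    print(model(x).shape)  # expected: torch.Size([1, 3, 224, 224, 144])\n",
+ "```"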
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1d2a5e5a-84ee-4fbe-9b06-1fcf9639d443", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from monai.networks.nets import SegResNet\n", + "\n", + "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "\n", + "# create model\n", + "model = SegResNet(\n", + " blocks_down=config.model_blocks_down,\n", + " blocks_up=config.model_blocks_up,\n", + " init_filters=config.model_init_filters,\n", + " in_channels=config.model_in_channels,\n", + " out_channels=config.model_out_channels,\n", + " dropout_prob=config.model_dropout_prob,\n", + ").to(device)" + ] + }, + { + "cell_type": "markdown", + "id": "20863f94-f7d0-4b82-8a55-41a1952d1cae", + "metadata": {}, + "source": [ + "We will be using [Adam Optimizer](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html) and the [cosine annealing schedule](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html) to schedule our learning rate. This approach is designed to help in finding global minima in the optimization landscape and to provide a form of reset mechanism during training, which can improve the performance of the model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a9a9b1a6-d127-4fe9-8831-a33b71dbe438", + "metadata": {}, + "outputs": [], + "source": [ + "# create optimizer\n", + "optimizer = torch.optim.Adam(\n", + " model.parameters(),\n", + " config.initial_learning_rate,\n", + " weight_decay=config.weight_decay,\n", + ")\n", + "\n", + "# create learning rate scheduler\n", + "lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n", + " optimizer, T_max=config.max_train_epochs\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "541dc8c6-009e-4115-81c5-d35386a92b6d", + "metadata": {}, + "source": [ + "Next, we would define the loss as multi-label DiceLoss as proposed by the paper [V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation](https://arxiv.org/abs/1606.04797) using the [`monai.losses`](https://docs.monai.io/en/stable/losses.html) API and the corresponding dice metrics using the [`monai.metrics`](https://docs.monai.io/en/stable/metrics.html) API." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2276a62-24a9-4352-8058-eb807e16819e", + "metadata": {}, + "outputs": [], + "source": [ + "from monai.losses import DiceLoss\n", + "\n", + "loss_function = DiceLoss(\n", + " smooth_nr=config.dice_loss_smoothen_numerator,\n", + " smooth_dr=config.dice_loss_smoothen_denominator,\n", + " squared_pred=config.dice_loss_squared_prediction,\n", + " to_onehot_y=config.dice_loss_target_onehot,\n", + " sigmoid=config.dice_loss_apply_sigmoid,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "187d9b11-3df9-4ae8-96c0-e7c16468878d", + "metadata": {}, + "outputs": [], + "source": [ + "from monai.metrics import DiceMetric\n", + "\n", + "dice_metric = DiceMetric(include_background=True, reduction=\"mean\")\n", + "dice_metric_batch = DiceMetric(include_background=True, reduction=\"mean_batch\")\n", + "post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])" + ] + }, + { + "cell_type": "markdown", + "id": "55423f02-8686-4a56-9778-09070b7edcc5", + "metadata": {}, + "source": [ + "## 🦾 Training the Model\n", + "\n", + "Finally, we proceed to writing the training and validation loop for the brain tumor segmentation model." 
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7df12631-8e09-48d3-b0b0-655775ec22e4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from tqdm.auto import tqdm\n",
+ "from monai.data import decollate_batch\n",
+ "from utils import inference\n",
+ "\n",
+ "\n",
+ "# Define custom x-axes\n",
+ "wandb.define_metric(\"epoch/epoch_step\")\n",
+ "wandb.define_metric(\"epoch/*\", step_metric=\"epoch/epoch_step\")\n",
+ "wandb.define_metric(\"batch/batch_step\")\n",
+ "wandb.define_metric(\"batch/*\", step_metric=\"batch/batch_step\")\n",
+ "wandb.define_metric(\"validation/validation_step\")\n",
+ "wandb.define_metric(\"validation/*\", step_metric=\"validation/validation_step\")\n",
+ "\n",
+ "# use automatic mixed-precision to accelerate training\n",
+ "scaler = torch.cuda.amp.GradScaler()\n",
+ "torch.backends.cudnn.benchmark = True\n",
+ "\n",
+ "# Create checkpoint directory\n",
+ "checkpoint_dir = \"./checkpoints\"\n",
+ "os.makedirs(checkpoint_dir, exist_ok=True)\n",
+ "\n",
+ "batch_step = 0\n",
+ "validation_step = 0\n",
+ "metric_values = []\n",
+ "metric_values_tumor_core = []\n",
+ "metric_values_whole_tumor = []\n",
+ "metric_values_enhanced_tumor = []\n",
+ "\n",
+ "epoch_progress_bar = tqdm(range(config.max_train_epochs), desc=\"Training:\")\n",
+ "\n",
+ "for epoch in epoch_progress_bar:\n",
+ "    model.train()\n",
+ "    epoch_loss = 0\n",
+ "\n",
+ "    total_batch_steps = len(train_dataset) // train_loader.batch_size\n",
+ "    batch_progress_bar = tqdm(train_loader, total=total_batch_steps, leave=False)\n",
+ "\n",
+ "    # Training Step\n",
+ "    for batch_data in batch_progress_bar:\n",
+ "        inputs, labels = (\n",
+ "            batch_data[\"image\"].to(device),\n",
+ "            batch_data[\"label\"].to(device),\n",
+ "        )\n",
+ "        optimizer.zero_grad()\n",
+ "        with torch.cuda.amp.autocast():\n",
+ "            outputs = model(inputs)\n",
+ "            loss = loss_function(outputs, labels)\n",
+ "        scaler.scale(loss).backward()\n",
+ "        scaler.step(optimizer)\n",
+ "        scaler.update()\n",
+ "        epoch_loss += loss.item()\n",
+ "        batch_progress_bar.set_description(f\"train_loss: {loss.item():.4f}:\")\n",
+ "        ## Log batch-wise training loss to W&B\n",
+ "        wandb.log({\"batch/batch_step\": batch_step, \"batch/train_loss\": loss.item()})\n",
+ "        batch_step += 1\n",
+ "\n",
+ "    epoch_loss /= total_batch_steps\n",
+ "    ## Log epoch-wise mean training loss and learning rate to W&B\n",
+ "    wandb.log(\n",
+ "        {\n",
+ "            \"epoch/epoch_step\": epoch,\n",
+ "            \"epoch/mean_train_loss\": epoch_loss,\n",
+ "            \"epoch/learning_rate\": lr_scheduler.get_last_lr()[0],\n",
+ "        }\n",
+ "    )\n",
+ "    lr_scheduler.step()\n",
+ "    epoch_progress_bar.set_description(f\"Training: train_loss: {epoch_loss:.4f}:\")\n",
+ "\n",
+ "    # Validation and model checkpointing step\n",
+ "    if (epoch + 1) % config.validation_intervals == 0:\n",
+ "        model.eval()\n",
+ "        with torch.no_grad():\n",
+ "            for val_data in val_loader:\n",
+ "                val_inputs, val_labels = (\n",
+ "                    val_data[\"image\"].to(device),\n",
+ "                    val_data[\"label\"].to(device),\n",
+ "                )\n",
+ "                val_outputs = inference(model, val_inputs, config.roi_size)\n",
+ "                val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]\n",
+ "                dice_metric(y_pred=val_outputs, y=val_labels)\n",
+ "                dice_metric_batch(y_pred=val_outputs, y=val_labels)\n",
+ "\n",
+ "            metric_values.append(dice_metric.aggregate().item())\n",
+ "            metric_batch = dice_metric_batch.aggregate()\n",
+ "            metric_values_tumor_core.append(metric_batch[0].item())\n",
+ "            metric_values_whole_tumor.append(metric_batch[1].item())\n",
+ "            
metric_values_enhanced_tumor.append(metric_batch[2].item())\n", + " dice_metric.reset()\n", + " dice_metric_batch.reset()\n", + "\n", + " # Log and versison model checkpoints using W&B artifacts.\n", + " checkpoint_path = os.path.join(checkpoint_dir, \"model.pth\")\n", + " torch.save(model.state_dict(), checkpoint_path)\n", + " wandb.log_model(\n", + " checkpoint_path,\n", + " name=f\"{wandb.run.id}-checkpoint\",\n", + " aliases=[f\"epoch_{epoch}\"],\n", + " )\n", + "\n", + " # Log validation metrics to W&B dashboard.\n", + " wandb.log(\n", + " {\n", + " \"validation/validation_step\": validation_step,\n", + " \"validation/mean_dice\": metric_values[-1],\n", + " \"validation/mean_dice_tumor_core\": metric_values_tumor_core[-1],\n", + " \"validation/mean_dice_whole_tumor\": metric_values_whole_tumor[-1],\n", + " \"validation/mean_dice_enhanced_tumor\": metric_values_enhanced_tumor[-1],\n", + " }\n", + " )\n", + " validation_step += 1\n", + "\n", + "\n", + "# Finish the experiment\n", + "wandb.finish()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/life-science/assets/artifact_usage.png b/life-science/assets/artifact_usage.png new file mode 100644 index 00000000..949348ac Binary files /dev/null and b/life-science/assets/artifact_usage.png differ diff --git a/life-science/config.yaml b/life-science/config.yaml new file mode 100644 index 00000000..ae1f094e --- /dev/null +++ b/life-science/config.yaml @@ -0,0 +1,29 @@ +program: train.py +name: segmentation_sweep +method: bayes +metric: + goal: maximize + name: validation/mean_dice +parameters: + model_dropout_prob: + distribution: uniform + max: 0.5 + min: 0.1 + model_init_filters: + values: [8, 16, 32] + initial_learning_rate: + distribution: uniform + max: 0.001 + min: 1e-06 + dice_loss_smoothen_denominator: + distribution: uniform + max: 0.0001 + min: 1e-06 + dice_loss_smoothen_numerator: + distribution: uniform + max: 0.0001 + min: 1e-06 + weight_decay: + distribution: uniform + max: 0.001 + min: 1e-06 \ No newline at end of file diff --git a/life-science/train.py b/life-science/train.py new file mode 100644 index 00000000..ef2f7438 --- /dev/null +++ b/life-science/train.py @@ -0,0 +1,294 @@ +import torch +import wandb +from tqdm.auto import tqdm + +from monai.apps import DecathlonDataset +from monai.data import DataLoader, decollate_batch +from monai.losses import DiceLoss +from monai.metrics import DiceMetric +from monai.networks.nets import SegResNet +from monai.transforms import ( + Activations, + AsDiscrete, + Compose, + LoadImaged, + NormalizeIntensityd, + Orientationd, + RandFlipd, + RandScaleIntensityd, + RandShiftIntensityd, + RandSpatialCropd, + Spacingd, + EnsureTyped, + EnsureChannelFirstd, +) +from monai.utils import set_determinism + +from utils import ConvertToMultiChannelBasedOnBratsClassesd, inference + + +def main(): + wandb.init(project="brain-tumor-segmentation", entity="lifesciences") + config = wandb.config + + # Manually setting the values of the configs unaffected by the sweep + config.seed = 0 + config.roi_size = [224, 224, 144] + config.num_workers = 4 + config.batch_size = 2 + config.model_blocks_down = [1, 2, 2, 
4] + config.model_blocks_up = [1, 1, 1] + config.model_in_channels = 4 + config.model_out_channels = 3 + config.max_train_epochs = 5 + config.dice_loss_squared_prediction = True + config.dice_loss_target_onehot = False + config.dice_loss_apply_sigmoid = True + config.inference_roi_size = (240, 240, 160) + config.validation_intervals = 1 + + # We are not setting the values of the following configs as their values + # will be determined the sweep + # - config.model_dropout_prob = 0.2 + # - config.model_init_filters = 16 + # - config.initial_learning_rate = 1e-4 + # - config.dice_loss_smoothen_denominator = 1e-5 + # - config.dice_loss_smoothen_numerator = 0 + # - config.weight_decay = 1e-5 + + set_determinism(seed=config.seed) + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + train_transform = Compose( + [ + # load 4 Nifti images and stack them together + LoadImaged(keys=["image", "label"]), + # Ensure loaded images are in channels-first format + EnsureChannelFirstd(keys="image"), + # Ensure the input data to be a PyTorch Tensor or numpy array + EnsureTyped(keys=["image", "label"]), + # Convert labels to multi-channels based on brats18 classes + ConvertToMultiChannelBasedOnBratsClassesd(keys="label"), + # Change the input image’s orientation into the specified based on axis codes + Orientationd(keys=["image", "label"], axcodes="RAS"), + # Resample the input images to the specified pixel dimension + Spacingd( + keys=["image", "label"], + pixdim=(1.0, 1.0, 1.0), + mode=("bilinear", "nearest"), + ), + # Augmentation: Crop image with random size or specific size ROI + RandSpatialCropd( + keys=["image", "label"], roi_size=config.roi_size, random_size=False + ), + # Augmentation: Randomly flip the image on the specified axes + RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0), + RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1), + RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2), + # Normalize input image intensity + NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True), + # Augmentation: Randomly scale the image intensity + RandScaleIntensityd(keys="image", factors=0.1, prob=1.0), + RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0), + ] + ) + val_transform = Compose( + [ + # load 4 Nifti images and stack them together + LoadImaged(keys=["image", "label"]), + # Ensure loaded images are in channels-first format + EnsureChannelFirstd(keys="image"), + # Ensure the input data to be a PyTorch Tensor or numpy array + EnsureTyped(keys=["image", "label"]), + # Convert labels to multi-channels based on brats18 classes + ConvertToMultiChannelBasedOnBratsClassesd(keys="label"), + # Change the input image’s orientation into the specified based on axis codes + Orientationd(keys=["image", "label"], axcodes="RAS"), + # Resample the input images to the specified pixel dimension + Spacingd( + keys=["image", "label"], + pixdim=(1.0, 1.0, 1.0), + mode=("bilinear", "nearest"), + ), + # Normalize input image intensity + NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True), + ] + ) + + # Create the dataset for the training split + # of the brain tumor segmentation dataset + train_dataset = DecathlonDataset( + root_dir="./artifacts/decathlon_brain_tumor:v0", + task="Task01_BrainTumour", + transform=train_transform, + section="training", + download=False, + cache_rate=0.0, + num_workers=config.num_workers, + ) + + # Create the dataset for the validation split + # of the brain tumor segmentation dataset + val_dataset 
= DecathlonDataset( + root_dir="./artifacts/decathlon_brain_tumor:v0", + task="Task01_BrainTumour", + transform=val_transform, + section="validation", + download=False, + cache_rate=0.0, + num_workers=config.num_workers, + ) + + # create the train_loader + train_loader = DataLoader( + train_dataset, + batch_size=config.batch_size, + shuffle=True, + num_workers=config.num_workers, + ) + + # create the val_loader + val_loader = DataLoader( + val_dataset, + batch_size=config.batch_size, + shuffle=False, + num_workers=config.num_workers, + ) + + # create model + model = SegResNet( + blocks_down=config.model_blocks_down, + blocks_up=config.model_blocks_up, + init_filters=config.model_init_filters, + in_channels=config.model_in_channels, + out_channels=config.model_out_channels, + dropout_prob=config.model_dropout_prob, + ).to(device) + + # create optimizer + optimizer = torch.optim.Adam( + model.parameters(), + config.initial_learning_rate, + weight_decay=config.weight_decay, + ) + + # create learning rate scheduler + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=config.max_train_epochs + ) + + loss_function = DiceLoss( + smooth_nr=config.dice_loss_smoothen_numerator, + smooth_dr=config.dice_loss_smoothen_denominator, + squared_pred=config.dice_loss_squared_prediction, + to_onehot_y=config.dice_loss_target_onehot, + sigmoid=config.dice_loss_apply_sigmoid, + ) + + dice_metric = DiceMetric(include_background=True, reduction="mean") + dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch") + postprocessing_transforms = Compose( + [Activations(sigmoid=True), AsDiscrete(threshold=0.5)] + ) + + scaler = torch.cuda.amp.GradScaler() + torch.backends.cudnn.benchmark = True + + wandb.define_metric("epoch/epoch_step") + wandb.define_metric("epoch/*", step_metric="epoch/epoch_step") + wandb.define_metric("batch/batch_step") + wandb.define_metric("batch/*", step_metric="batch/batch_step") + wandb.define_metric("validation/validation_step") + wandb.define_metric("validation/*", step_metric="validation/validation_step") + + batch_step = 0 + validation_step = 0 + metric_values = [] + metric_values_tumor_core = [] + metric_values_whole_tumor = [] + metric_values_enhanced_tumor = [] + + epoch_progress_bar = tqdm(range(config.max_train_epochs), desc="Training:") + + for epoch in epoch_progress_bar: + model.train() + epoch_loss = 0 + + total_batch_steps = len(train_dataset) // train_loader.batch_size + batch_progress_bar = tqdm(train_loader, total=total_batch_steps, leave=False) + + # Training Step + for batch_data in batch_progress_bar: + inputs, labels = ( + batch_data["image"].to(device), + batch_data["label"].to(device), + ) + optimizer.zero_grad() + with torch.cuda.amp.autocast(): + outputs = model(inputs) + loss = loss_function(outputs, labels) + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + epoch_loss += loss.item() + batch_progress_bar.set_description(f"train_loss: {loss.item():.4f}:") + ## Log batch-wise training loss to W&B + wandb.log({"batch/batch_step": batch_step, "batch/train_loss": loss.item()}) + batch_step += 1 + + epoch_loss /= total_batch_steps + ## Log batch-wise training loss and learning rate to W&B + wandb.log( + { + "epoch/epoch_step": epoch, + "epoch/mean_train_loss": epoch_loss, + "epoch/learning_rate": lr_scheduler.get_last_lr()[0], + } + ) + lr_scheduler.step() + epoch_progress_bar.set_description(f"Training: train_loss: {epoch_loss:.4f}:") + + # Validation and model checkpointing step + if (epoch + 1) 
% config.validation_intervals == 0: + model.eval() + with torch.no_grad(): + for val_data in val_loader: + val_inputs, val_labels = ( + val_data["image"].to(device), + val_data["label"].to(device), + ) + val_outputs = inference(model, val_inputs, config.roi_size) + val_outputs = [ + postprocessing_transforms(i) + for i in decollate_batch(val_outputs) + ] + dice_metric(y_pred=val_outputs, y=val_labels) + dice_metric_batch(y_pred=val_outputs, y=val_labels) + + metric_values.append(dice_metric.aggregate().item()) + metric_batch = dice_metric_batch.aggregate() + metric_values_tumor_core.append(metric_batch[0].item()) + metric_values_whole_tumor.append(metric_batch[1].item()) + metric_values_enhanced_tumor.append(metric_batch[2].item()) + dice_metric.reset() + dice_metric_batch.reset() + + # Log validation metrics to W&B dashboard. + wandb.log( + { + "validation/validation_step": validation_step, + "validation/mean_dice": metric_values[-1], + "validation/mean_dice_tumor_core": metric_values_tumor_core[-1], + "validation/mean_dice_whole_tumor": metric_values_whole_tumor[ + -1 + ], + "validation/mean_dice_enhanced_tumor": metric_values_enhanced_tumor[ + -1 + ], + } + ) + validation_step += 1 + + +if __name__ == "__main__": + main() diff --git a/life-science/utils.py b/life-science/utils.py new file mode 100644 index 00000000..5eb5ade0 --- /dev/null +++ b/life-science/utils.py @@ -0,0 +1,60 @@ +from monai.inferers import sliding_window_inference +from monai.transforms import MapTransform + +import torch +import wandb + + +class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform): + """ + Convert labels to multi-channels based on brats classes: + label 1 is the peritumoral edema + label 2 is the GD-enhancing tumor + label 3 is the necrotic and non-enhancing tumor core + The possible classes are TC (Tumor core), WT (Whole tumor), and ET (Enhancing tumor). + + Reference: https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/brats_segmentation_3d.ipynb + + """ + + def __call__(self, data): + data_dict = dict(data) + for key in self.keys: + result = [] + # merge label 2 and label 3 to construct Tumor Core + result.append(torch.logical_or(data_dict[key] == 2, data_dict[key] == 3)) + # merge labels 1, 2 and 3 to construct Whole Tumor + result.append( + torch.logical_or( + torch.logical_or(data_dict[key] == 2, data_dict[key] == 3), + data_dict[key] == 1, + ) + ) + # label 2 is Enhancing Tumor + result.append(data_dict[key] == 2) + data_dict[key] = torch.stack(result, axis=0).float() + return data_dict + + +def inference(model, input, roi_size): + def _compute(input): + return sliding_window_inference( + inputs=input, + roi_size=roi_size, + sw_batch_size=1, + predictor=model, + overlap=0.5, + ) + + with torch.cuda.amp.autocast(): + return _compute(input) + + +def get_best_config_from_sweep( + entity: str, project: str, sweep_id: str, metric: str = "validation/mean_dice" +): + api = wandb.Api() + sweep = api.sweep(f"{entity}/{project}/{sweep_id}") + runs = sorted(sweep.runs, key=lambda run: run.summary.get(metric, 0), reverse=True) + best_run = runs[0] + return best_run.config
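+
+
+# Example usage of `get_best_config_from_sweep` (the entity, project, and sweep id
+# below mirror the ones used in these notebooks):
+#
+# best_config = get_best_config_from_sweep(
+#     entity="lifesciences",
+#     project="brain-tumor-segmentation",
+#     sweep_id="580gsolt",
+# )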