From 2571c69c0bd514d9ae401b7530cb77810e754292 Mon Sep 17 00:00:00 2001 From: samdow Date: Tue, 14 Jun 2022 11:38:04 -0400 Subject: [PATCH] simple functorch with modes --- simple_functorch_modes.ipynb | 2198 ++++++++++++++++++++++++++++++++++ simple_functorch_modes.py | 1302 ++++++++++++++++++++ 2 files changed, 3500 insertions(+) create mode 100644 simple_functorch_modes.ipynb create mode 100644 simple_functorch_modes.py diff --git a/simple_functorch_modes.ipynb b/simple_functorch_modes.ipynb new file mode 100644 index 0000000..67d8a82 --- /dev/null +++ b/simple_functorch_modes.ipynb @@ -0,0 +1,2198 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "## Simple Functorch but Make it Modes\n", + "\n", + "This notebook is a rewrite of the simple functorch notebook that uses torch dispatch modes instead of the Dispatcher object" + ], + "metadata": { + "id": "BXUThCq6-x9r" + }, + "id": "BXUThCq6-x9r" + }, + { + "cell_type": "code", + "source": [ + "!pip uninstall -y torch\n", + "!pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --upgrade" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 347 + }, + "id": "HIonKOQXUoyK", + "outputId": "21ad9757-1587-4a1c-d196-1ace81b96e50" + }, + "id": "HIonKOQXUoyK", + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Found existing installation: torch 1.13.0.dev20220613+cpu\n", + "Uninstalling torch-1.13.0.dev20220613+cpu:\n", + " Successfully uninstalled torch-1.13.0.dev20220613+cpu\n", + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Looking in links: https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n", + "Collecting torch\n", + " Using cached https://download.pytorch.org/whl/nightly/cpu/torch-1.13.0.dev20220613%2Bcpu-cp37-cp37m-linux_x86_64.whl (190.8 MB)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch) (4.2.0)\n", + "Installing collected packages: torch\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "torchvision 0.12.0+cu113 requires torch==1.11.0, but you have torch 1.13.0.dev20220613+cpu which is incompatible.\n", + "torchtext 0.12.0 requires torch==1.11.0, but you have torch 1.13.0.dev20220613+cpu which is incompatible.\n", + "torchaudio 0.11.0+cu113 requires torch==1.11.0, but you have torch 1.13.0.dev20220613+cpu which is incompatible.\u001b[0m\n", + "Successfully installed torch-1.13.0.dev20220613+cpu\n" + ] + }, + { + "output_type": "display_data", + "data": { + "application/vnd.colab-display-data+json": { + "pip_warning": { + "packages": [ + "torch" + ] + } + } + }, + "metadata": {} + } + ] + }, + { + "cell_type": "markdown", + "id": "1cb2ffc9", + "metadata": { + "id": "1cb2ffc9" + }, + "source": [ + "This notebook walks through a self-contained implementation of\n", + "functorch, including support for both vjp and vmap combinators (using\n", + "PyTorch only to implement primitive tensor operations). It follows\n", + "the tradition of\n", + "[Autodidax](https://jax.readthedocs.io/en/latest/autodidax.html) (a\n", + "pedagogical reimplementation of JAX, the library functorch is inspired\n", + "by) and [Simple\n", + "Autograd](https://colab.research.google.com/drive/1VpeE6UvEPRz9HmsHh1KS0XxXjYu533EC?usp=sharing)\n", + "(Zachary Devito's pedagogical reimplementation of autograd, which the\n", + "autograd system in this notebook is based off of.) You can [open this\n", + "file in\n", + "Colab](https://colab.research.google.com/github/albanD/subclass_zoo/blob/main/simple_functorch.ipynb)\n", + "and play around with the examples.\n", + "\n", + "As a simplified implementation of functorch, this notebook also makes\n", + "it easier to investigate some more subtle aspects of how PyTorch's\n", + "native autograd system interacts with composable transforms. In\n", + "particular, we will see that PyTorch's native implementation of double\n", + "backwards (which shares the same tape through multiple levels of\n", + "differentiation) differs from functorch's nested grad implementation\n", + "(which maintains a separate tape per level)." + ] + }, + { + "cell_type": "markdown", + "id": "6ed150b0", + "metadata": { + "id": "6ed150b0" + }, + "source": [ + "To get started, we replicate some of the data structures and helper functions\n", + "from Simple Autograd." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba464220", + "metadata": { + "lines_to_end_of_cell_marker": 2, + "id": "ba464220" + }, + "outputs": [], + "source": [ + "import contextlib\n", + "import functools\n", + "from dataclasses import dataclass\n", + "from typing import Callable, Dict, List, NamedTuple, Optional\n", + "\n", + "import torch\n", + "from torch import Tensor\n", + "from torch.utils._python_dispatch import TorchDispatchMode\n", + "\n", + "\n", + "class TapeEntry(NamedTuple):\n", + " # names of the inputs to the original computation\n", + " inputs: List[str]\n", + " # names of the outputs of the original computation\n", + " outputs: List[str]\n", + " # apply chain rule\n", + " propagate: Callable[[List[Tensor]], List[Tensor]]\n", + "\n", + "\n", + "_name = 0\n", + "\n", + "\n", + "def fresh_name() -> str:\n", + " \"\"\"create a new unique name for a variable: v0, v1, v2\"\"\"\n", + " global _name\n", + " r = f\"v{_name}\"\n", + " _name += 1\n", + " return r" + ] + }, + { + "cell_type": "markdown", + "id": "0af3ca0b", + "metadata": { + "lines_to_next_cell": 2, + "id": "0af3ca0b" + }, + "source": [ + "This is a little helper function for converting the dim argument in\n", + "sum into an explicit list of dimensions that will be reduced over.\n", + "It takes the dim of the tensor we are summing over and the dim\n", + "argument itself." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "836a1635", + "metadata": { + "lines_to_next_cell": 2, + "id": "836a1635" + }, + "outputs": [], + "source": [ + "def sum_dims(*, input_dim, dim):\n", + " if dim is None:\n", + " return tuple(range(0, input_dim))\n", + " elif isinstance(dim, int):\n", + " return (dim,)\n", + " else:\n", + " return tuple(sorted(dim))" + ] + }, + { + "cell_type": "markdown", + "source": [ + "This is another little helper function that we might want to incorporate into the default behavior of restore. But for now we need this in order to not error when restoring a mode" + ], + "metadata": { + "id": "wMJ7m3UuExhV" + }, + "id": "wMJ7m3UuExhV" + }, + { + "cell_type": "code", + "source": [ + "from torch._C import _get_torch_dispatch_mode\n", + "\n", + "def restore_dispatcher_or_nop(dispatcher):\n", + " # another argument that maybe we should let .restore() keep the current mode\n", + " if dispatcher == _get_torch_dispatch_mode():\n", + " # we'll just no-op here since restoring the current mode will error\n", + " return contextlib.nullcontext\n", + " return dispatcher.restore" + ], + "metadata": { + "id": "5VM_8NkcEvS0" + }, + "id": "5VM_8NkcEvS0", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "5f74fb48", + "metadata": { + "lines_to_next_cell": 2, + "id": "5f74fb48" + }, + "source": [ + "In Simple Autograd, we provided a Variable wrapper class which\n", + "provided a traditional Tensor style interface for our objects; in\n", + "functorch proper, objects are repeatedly wrapped in this way to\n", + "implement multipler layers of transformations.\n", + "\n", + "In my opinion, this sort of wrapper makes it more difficult to\n", + "understand the flow of logic. So in Simple Functorch, we take a\n", + "different approach: we won't make use of a wrapper class at all,\n", + "instead showing how to add it in the end as syntax sugar on top of our\n", + "system.\n", + "\n", + "For debuggability purposes, however, it is nice to have a way to\n", + "identify variables by a human readable name. We'll do this by setting\n", + "a t_name attribute on PyTorch tensors whenever we allocate a new\n", + "tensor." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f9d17aef", + "metadata": { + "lines_to_next_cell": 2, + "id": "f9d17aef" + }, + "outputs": [], + "source": [ + "def label(t: Tensor, name: str = None):\n", + " if not hasattr(t, \"t_name\"):\n", + " t.t_name = name or fresh_name()\n", + " return t" + ] + }, + { + "cell_type": "markdown", + "id": "cd046576", + "metadata": { + "lines_to_next_cell": 2, + "id": "cd046576" + }, + "source": [ + "So if we aren't going to have a wrapper around each tensor, how will\n", + "we actually implement our logic? We will organize our various layers\n", + "of transformations as separate Dispatcher objects, which inherit from mode and define methods for performing operations on tensors, but are not Tensors\n", + "themselves. For example, instead of defining Tensor.add(Tensor), the mode will catch the add(Tensor, Tensor) call when it hits the Pytorch dispatcher. In order to avoid the same boilerplate in every dispatcher object, we define a parent object that catches all functions and redispatches it to the correct rule based on the child Dispatcher's implementation\n", + "\n", + "Notice that unlike with the original simple functorch, we don't have to set an\n", + "inner parameter. This logic is handeled by the underlying mode implementation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59a69d9d", + "metadata": { + "lines_to_next_cell": 2, + "id": "59a69d9d" + }, + "outputs": [], + "source": [ + "class Dispatcher(TorchDispatchMode):\n", + " def apply(self, func):\n", + " if func.__name__ == \"add.Tensor\":\n", + " return self.add\n", + " if func.__name__ == \"mul.Tensor\":\n", + " return self.mul\n", + " if func.__name__ in [\"sum.default\", \"sum.dim_IntList\"]:\n", + " return self.sum\n", + " if func.__name__ == \"expand.default\":\n", + " return self.expand\n", + " if func.__name__ == \"unsqueeze.default\":\n", + " return self.unsqueeze\n", + " if func.__name__ == \"squeeze.dim\":\n", + " return self.squeeze\n", + " if func.__name__ == \"size\":\n", + " return self.size\n", + " if func.__name__ == \"ones.default\":\n", + " return self.ones\n", + " else:\n", + " raise RuntimeError(f\"Simple functorch doesn't support {func.__name__}\")\n", + "\n", + " def mul(self, lhs, rhs):\n", + " raise NotImplementedError\n", + "\n", + " def add(self, lhs, rhs):\n", + " raise NotImplementedError\n", + "\n", + " # Sum has been generalized to take an optional dim argument, which we\n", + " # will need for Batched tensors\n", + " def sum(self, input, dim=None):\n", + " raise NotImplementedError\n", + "\n", + " def expand(self, input, sizes):\n", + " raise NotImplementedError\n", + "\n", + " # For closure under Batched tensors, we need these operations...\n", + " def unsqueeze(self, input, dim):\n", + " raise NotImplementedError\n", + "\n", + " def squeeze(self, input, dim):\n", + " raise NotImplementedError\n", + "\n", + " # ...and we also need to overload the meaning of size/ones to\n", + " # hide/reinsert batch dimensions. We also introduce a concept\n", + " # of \"lifting\" a tensor to be batched by broadcasting it on\n", + " # a dimension\n", + " def size(self, input):\n", + " raise NotImplementedError\n", + "\n", + " def ones(self, size, **kwargs):\n", + " raise NotImplementedError\n", + "\n", + " def lift(self, input, d):\n", + " raise NotImplementedError\n", + "\n", + " # For convenience, we provide dim, which just returns the length of\n", + " # the sizes\n", + " def dim(self, input):\n", + " return len(self.size(input))\n", + " \n", + " def custom_vjp(self, fwd_fn, bwd_fn, *args):\n", + " # really gross but because we don't have the torch_dispatch for this, we\n", + " # need something to mimic what the mode does in torch_dispatch\n", + " old = torch._C._get_torch_dispatch_mode()\n", + " try:\n", + " torch._C._set_torch_dispatch_mode(None) # BUG: should be able to be done with enable\n", + " with self.inner.restore():\n", + " return self.inner.custom_vjp(fwd_fn, bwd_fn, *args)\n", + " finally:\n", + " torch._C._set_torch_dispatch_mode(old)\n", + "\n", + " def __torch_dispatch__(self, func, types, args=(), kwargs=None):\n", + " kwargs = kwargs if kwargs else {}\n", + " return self.apply(func)(*args, **kwargs)" + ] + }, + { + "cell_type": "markdown", + "id": "d0e21984", + "metadata": { + "lines_to_next_cell": 2, + "id": "d0e21984" + }, + "source": [ + "To start with, we can implement a labeler layer, which just labels all inputs\n", + "and outputs. This will be necessary for autograd so it should be the bottom\n", + "most layer to everything. Specifically, we'll set it's inner to be None so that\n", + "if it's used as not the innermost layer, it will error" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27101073", + "metadata": { + "lines_to_next_cell": 2, + "id": "27101073" + }, + "outputs": [], + "source": [ + "class Labeler(Dispatcher):\n", + " def mul(self, lhs, rhs):\n", + " return label(torch.mul(lhs, rhs))\n", + "\n", + " def add(self, lhs, rhs):\n", + " return label(torch.add(lhs, rhs))\n", + "\n", + " def sum(self, input, dim=None):\n", + " if dim is None:\n", + " return label(torch.sum(input))\n", + " else:\n", + " return label(torch.sum(input, dim))\n", + "\n", + " def expand(self, input, sizes):\n", + " return label(input.expand(sizes))\n", + "\n", + " def unsqueeze(self, input, dim):\n", + " return label(torch.unsqueeze(input, dim))\n", + "\n", + " def squeeze(self, input, dim):\n", + " return label(torch.squeeze(input, dim))\n", + "\n", + " def size(self, input, **kwargs):\n", + " # Return size a tuple for marginally more compact printing\n", + " assert isinstance(input, torch.Tensor)\n", + " return input.size()\n", + "\n", + " def ones(self, size, **kwargs):\n", + " return label(torch.ones(size))\n", + " \n", + " def lift(self, input, d):\n", + " assert self == d\n", + " return input\n", + "\n", + " def custom_vjp(self, fwd_fn, bwd_fn, *args):\n", + " # The backend layer for custom_vjp just calls fwd_fn.\n", + " # Why doesn't it create an autograd.Function? We're assuming the backend\n", + " # layer doesn't need to handle Autograd.\n", + " assert self.inner == None\n", + " a, b = fwd_fn(*args)\n", + " result = label(a), label(b)\n", + " return result" + ] + }, + { + "cell_type": "markdown", + "id": "e428f835", + "metadata": { + "lines_to_next_cell": 2, + "id": "e428f835" + }, + "source": [ + "Dispatcher layers are composable via object composition: we can\n", + "imagine a stack of dispatchers, each one calling into the next.\n", + "For example, the Logger dispatcher simply prints out what operation\n", + "was called on it, and then forwards on the operation to the inner\n", + "dispatcher. Unlike with simple functorch, we're able to rely on the modes to forward the call to the inner dispatcher by just calling the function again" + ] + }, + { + "cell_type": "code", + "source": [ + "def custom_vjp_str(r, fwd_fn, bwd_fn, args):\n", + " arg_names = \", \".join([a.t_name for a in args])\n", + " r_is_tensor = isinstance(r, torch.Tensor)\n", + " if r_is_tensor:\n", + " result_names = r.t_name\n", + " else:\n", + " result_names = [r.t_name for r in r]\n", + " if len(result_names) == 1:\n", + " result_names = f\"{result_names[0]},\"\n", + " else:\n", + " result_names = \", \".join(result_names)\n", + "\n", + " print(\n", + " f\"{result_names} = custom_vjp({fwd_fn.__name__}, {bwd_fn.__name__}, {arg_names})\"\n", + " )" + ], + "metadata": { + "id": "KdSPXN8GvyKw" + }, + "id": "KdSPXN8GvyKw", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1500269c", + "metadata": { + "id": "1500269c" + }, + "outputs": [], + "source": [ + "class Logger(Dispatcher):\n", + " def __init__(self, *, name):\n", + " self.name = f\" {name}\"\n", + "\n", + " def size(self, input):\n", + " # don't log size calls\n", + " return self.inner.size(input)\n", + "\n", + " def ones(self, size, **kwargs):\n", + " r = torch.ones(size)\n", + " print(f\"{self.name} {r.t_name}: {self.size(r)} = ones({size})\")\n", + " return r\n", + "\n", + " def mul(self, lhs, rhs):\n", + " r = lhs.mul(rhs)\n", + " if isinstance(rhs, float):\n", + " print(f\"{self.name} {r.t_name}: {self.size(r)} = {lhs.t_name} * {rhs}\")\n", + " else:\n", + " print(\n", + " f\"{self.name} {r.t_name}: {self.size(r)} = {lhs.t_name} * {rhs.t_name}\"\n", + " )\n", + " return r\n", + "\n", + " def add(self, lhs, rhs):\n", + " r = lhs.add(rhs)\n", + " print(f\"{self.name} {r.t_name}: {self.size(r)} = {lhs.t_name} + {rhs.t_name}\")\n", + " return r\n", + "\n", + " def sum(self, input, dim=None):\n", + " if dim is None:\n", + " r = input.sum()\n", + " else:\n", + " r = input.sum(dim)\n", + " print(f\"{self.name} {r.t_name}: {self.size(r)} = {input.t_name}.sum(dim={dim})\")\n", + " return r\n", + "\n", + " def unsqueeze(self, input, dim):\n", + " r = input.unsqueeze(dim)\n", + " print(\n", + " f\"{self.name} {r.t_name}: {self.size(r)} = {input.t_name}.unsqueeze({dim})\"\n", + " )\n", + " return r\n", + "\n", + " def squeeze(self, input, dim):\n", + " r = input.squeeze(dim)\n", + " print(f\"{self.name} {r.t_name}: {self.size(r)} = {input.t_name}.squeeze({dim})\")\n", + " return r\n", + "\n", + " def expand(self, input, sizes):\n", + " r = input.expand(sizes)\n", + " print(\n", + " f\"{self.name} {r.t_name}: {self.size(r)} = {input.t_name}.expand({sizes})\"\n", + " )\n", + " return r\n", + "\n", + " def custom_vjp(self, fwd_fn, bwd_fn, *args):\n", + " # because custom_vjp is not an aten function, we have to explicitly send\n", + " # it to its inner \n", + " r = super().custom_vjp(fwd_fn, bwd_fn, *args)\n", + " print(custom_vjp_str(r, fwd_fn, bwd_fn, args))\n", + " return r\n", + " \n", + " def lift(self, input, d):\n", + " if self == d:\n", + " return input\n", + " else:\n", + " return self.inner.lift(input, d)" + ] + }, + { + "cell_type": "markdown", + "id": "e323b552", + "metadata": { + "id": "e323b552" + }, + "source": [ + "Here is a simple example of using Logger and Torch together. Whenever\n", + "we make calls to operations, we must do so via the Dispatcher object.\n", + "We will explicitly write out all of these calls before we add wrapper\n", + "class sugaring." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26b4c5e5", + "metadata": { + "lines_to_next_cell": 2, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "26b4c5e5", + "outputId": "50d1f97d-dc8a-4ee4-9144-4046ea113350" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " Torch v338: torch.Size([2]) = ones([2])\n", + " Torch v339: torch.Size([2]) = ones([2])\n", + " Torch v340: torch.Size([2]) = v338 + v339\n", + "tensor([2., 2.])\n" + ] + } + ], + "source": [ + "with Labeler():\n", + " with Logger(name=\"Torch\"):\n", + " z = torch.ones(2) + torch.ones(2)\n", + "print(z)\n", + "assert(isinstance(z, torch.Tensor))" + ] + }, + { + "cell_type": "markdown", + "id": "1dcd422b", + "metadata": { + "lines_to_next_cell": 2, + "id": "1dcd422b" + }, + "source": [ + "With the Dispatcher structure in hand, we are now in a good place to\n", + "port the autograd implementation from Simple Autograd into our new\n", + "framework." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54293b9b", + "metadata": { + "id": "54293b9b" + }, + "outputs": [], + "source": [ + "from torch.utils._mode_utils import no_dispatch\n", + "\n", + "class Autograd(Dispatcher):\n", + " # create_graph here corresponds to the create_graph kwarg in traditional\n", + " # PyTorch, which controls whether or not the graph of the derivative\n", + " # will be constructed, allowing computing higher order derivatives.\n", + " # We will see that although create_graph=True allows Autograd to directly\n", + " # support higher order derivatives, layering an Autograd to another\n", + " # Autograd will also allow higher order derivatives.\n", + " def __init__(self, *, name=\"Autograd\", create_graph: bool = False):\n", + " self.gradient_tape = []\n", + " self.name = name\n", + " self.create_graph = create_graph\n", + "\n", + " # create_graph controls where add/mul/etc calls from the backwards\n", + " # propagators go: if you create_graph, they we're going to have you\n", + " # the current Autograd dispatcher; otherwise they're going to\n", + " # move on to the inner layer. This restores the right mode to reset (and\n", + " # the proper context manager to use)\n", + " def backward_inner(self):\n", + " if self.create_graph:\n", + " mode = self\n", + " else:\n", + " mode = self.inner\n", + " return restore_dispatcher_or_nop(mode)\n", + "\n", + " def mul(self, lhs, rhs):\n", + " if isinstance(rhs, float) and rhs == 1.0:\n", + " # peephole optimization\n", + " return lhs\n", + "\n", + " # define forward\n", + " # first, run the operation in the inner layer to get the initial\n", + " # result\n", + " r = lhs.mul(rhs)\n", + " # We directly implement printing here as it indicates whether or not\n", + " # this operation was saved to the tape or not\n", + " if isinstance(rhs, float):\n", + " print(f\"{self.name} {r.t_name}: {self.size(r)} = {lhs.t_name} * {rhs}\")\n", + " else:\n", + " print(f\"{self.name} {r.t_name}: {self.size(r)} = {lhs.t_name} * {rhs.t_name}\")\n", + "\n", + " # record what the inputs and outputs of the op were\n", + " inputs = [lhs.t_name, rhs] if isinstance(rhs, float) else [lhs.t_name, rhs.t_name]\n", + " outputs = [r.t_name]\n", + "\n", + " # define backprop\n", + " def propagate(dL_doutputs: List[Tensor]):\n", + " (dL_dr,) = dL_doutputs\n", + "\n", + " dr_dlhs = rhs # partial derivative of r = lhs*rhs\n", + " dr_drhs = lhs # partial derivative of r = lhs*rhs\n", + "\n", + " # chain rule propagation from outputs to inputs of multiply.\n", + " # Notice that the propagation rule may itself call\n", + " # other operations; depending on create_graph, they may\n", + " # either be dispatched with self or self.inner; self.backward_inner()\n", + " # controls which one we go to.\n", + " with self.backward_inner()():\n", + " dL_dlhs = dL_dr.mul(dr_dlhs)\n", + " dL_drhs = dL_dr.mul(dr_drhs)\n", + " dL_dinputs = [dL_dlhs, dL_drhs]\n", + " return dL_dinputs\n", + "\n", + " # finally, we record the compute we did on the tape\n", + " self.gradient_tape.append(\n", + " TapeEntry(inputs=inputs, outputs=outputs, propagate=propagate)\n", + " )\n", + " return r\n", + "\n", + " # The rest of the implementations follow in the same way and can\n", + " # be skipped\n", + "\n", + " def add(self, lhs, rhs):\n", + " # Add follows a similar pattern to Mul, but it doesn't end up\n", + " # capturing any variables.\n", + " r = lhs.add(rhs)\n", + " print(f\"{self.name} {r.t_name}: {self.size(r)} = {lhs.t_name} + {rhs.t_name}\")\n", + "\n", + " def propagate(dL_doutputs: List[Tensor]):\n", + " (dL_dr,) = dL_doutputs\n", + " dr_dlhs = 1.0\n", + " dr_drhs = 1.0\n", + " with self.backward_inner()():\n", + " dL_dlhs = dL_dr.mul(dr_dlhs)\n", + " dL_drhs = dL_dr.mul(dr_drhs)\n", + " return [dL_dlhs, dL_drhs]\n", + "\n", + " self.gradient_tape.append(\n", + " TapeEntry(\n", + " inputs=[lhs.t_name, rhs.t_name], outputs=[r.t_name], propagate=propagate\n", + " )\n", + " )\n", + " return r\n", + "\n", + " # Extended to handle dim argument for Batched (later)\n", + " def sum(self, input: Tensor, dim=None):\n", + " if dim is None:\n", + " r = input.sum()\n", + " else:\n", + " r = input.sum(dim)\n", + " print(f\"{self.name} {r.t_name}: {self.size(r)} = {input.t_name}.sum(dim={dim})\")\n", + "\n", + " def propagate(dL_doutputs: List[Tensor]):\n", + " (dL_dr,) = dL_doutputs\n", + " size = self.size(input)\n", + " res = dL_dr\n", + " # Broadcast over all dimensions that were reduced over\n", + " input_dim = self.inner.dim(input) # this needs to be done in inner\n", + " print(input_dim)\n", + " print(dim)\n", + " with self.backward_inner()():\n", + " for i in sum_dims(input_dim=input_dim, dim=dim):\n", + " res = res.unsqueeze(i)\n", + " out = res.expand(size)\n", + " return [out]\n", + "\n", + " self.gradient_tape.append(\n", + " TapeEntry(inputs=[input.t_name], outputs=[r.t_name], propagate=propagate)\n", + " )\n", + " return r\n", + "\n", + " # Unlike Simple Autograd, this expand requires the input to have\n", + " # been unsqueezed before hand. This lets us avoid having to do\n", + " # at::sum_to for the nontrivial case (which is more complicated)\n", + " def expand(self, input: Tensor, sizes: List[int]):\n", + " print(self.inner.dim(input))\n", + " print(len(sizes))\n", + " assert self.inner.dim(input) == len(sizes) # only works if dims match\n", + " r = input.expand(sizes)\n", + " print(\n", + " f\"{self.name} {r.t_name}: {self.size(r)} = {input.t_name}.expand({sizes})\"\n", + " )\n", + "\n", + " def propagate(dL_doutputs: List[Tensor]):\n", + " (dL_dr,) = dL_doutputs\n", + " input_size = input.size()\n", + " dims = tuple(\n", + " i for i in range(input.dim()) if input_size[i] != sizes[i]\n", + " )\n", + " # We wanted a sum keepdim=True, but I didn't want to force\n", + " # everyone to support it so manually unsqueeze\n", + " with self.backward_inner()():\n", + " res = dL_dr.sum(dims)\n", + " for d in dims:\n", + " res = res.unsqueeze(d)\n", + " return [res]\n", + "\n", + " self.gradient_tape.append(\n", + " TapeEntry(inputs=[input.t_name], outputs=[r.t_name], propagate=propagate)\n", + " )\n", + " return r\n", + "\n", + " def squeeze(self, input: Tensor, dim):\n", + " r = input.squeeze(dim)\n", + " print(\n", + " f\"{self.name} {r.t_name}: {self.size(r)} = {input.t_name}.squeeze(dim={dim})\"\n", + " )\n", + "\n", + " def propagate(dL_doutputs: List[Tensor]):\n", + " (dL_dr,) = dL_outputs\n", + " with self.backward_inner()():\n", + " res = dL_dr.unsqueeze(dim)\n", + " return [res]\n", + "\n", + " self.gradient_tape.append(\n", + " TapeEntry(inputs=[input.t_name], outputs=[r.t_name], propagate=propagate)\n", + " )\n", + " return r\n", + "\n", + " def unsqueeze(self, input: Tensor, dim):\n", + " r = input.unsqueeze(dim)\n", + " print(\n", + " f\"{self.name} {r.t_name}: {self.size(r)} = {input.t_name}.unsqueeze(dim={dim})\"\n", + " )\n", + "\n", + " def propagate(dL_doutputs: List[Tensor]):\n", + " (dL_dr,) = dL_doutputs\n", + " with self.backward_inner()():\n", + " out = dL_dr.squeeze(dim)\n", + " return [out]\n", + "\n", + " self.gradient_tape.append(\n", + " TapeEntry(inputs=[input.t_name], outputs=[r.t_name], propagate=propagate)\n", + " )\n", + " return r\n", + "\n", + " def ones(self, size, **kwargs):\n", + " return torch.ones(size)\n", + " \n", + " def custom_vjp(self, fwd_fn, bwd_fn, *args):\n", + " # To support Autograd(Autograd(Torch()), custom_vjp MUST call custom_vjp\n", + " # on the inner dispatcher. If it instead called fwd_fn(*args), then\n", + " # the inner Autograd dispatcher would not use bwd_fn in its backward pass.\n", + "\n", + " r, saved = super().custom_vjp(fwd_fn, bwd_fn, *args)\n", + " print(custom_vjp_str(r, fwd_fn, bwd_fn, args))\n", + "\n", + " # To preserve custom backward semantics, we create a lambda that calls\n", + " # bwd_fn. This lambda is then saved on the gradient tape.\n", + " def propagate(dL_doutputs: List[Tensor]):\n", + " with self.backward_inner()():\n", + " return bwd_fn(dL_doutputs, saved)\n", + "\n", + " self.gradient_tape.append(\n", + " TapeEntry(\n", + " inputs=[arg.t_name for arg in args], outputs=[r.t_name], propagate=propagate\n", + " )\n", + " )\n", + " return r, saved\n", + "\n", + " def size(self, input):\n", + " return self.inner.size(input)\n", + " \n", + " def lift(self, input, d):\n", + " if self == d:\n", + " return input\n", + " else:\n", + " return self.inner.lift(input, d)\n", + "\n", + " def grad(self, L, desired_results: List[Tensor]) -> List[Tensor]:\n", + " # this map holds dL/dX for all values X\n", + " dL_d: Dict[str, Tensor] = {}\n", + " # It starts by initializing the 'seed' dL/dL, which is 1\n", + " # TODO: indirect this via the backend\n", + " with restore_dispatcher_or_nop(self.inner)():\n", + " dL_d[L.t_name] = torch.ones(self.inner.size(L)) \n", + " print(f\"-- {self.name} d{L.t_name} -------\")\n", + "\n", + " # look up dL_dentries. If a variable is never used to compute the loss,\n", + " # we consider its gradient None, see the note below about zeros for more information.\n", + " def gather_grad(entries: List[str]):\n", + " return [dL_d[entry] if entry in dL_d else None for entry in entries]\n", + "\n", + " # propagate the gradient information backward\n", + " for entry in reversed(self.gradient_tape):\n", + " dL_doutputs = gather_grad(entry.outputs)\n", + " if all(dL_doutput is None for dL_doutput in dL_doutputs):\n", + " # optimize for the case where some gradient pathways are zero. See\n", + " # The note below for more details.\n", + " continue\n", + "\n", + " # perform chain rule propagation specific to each compute\n", + " dL_dinputs = entry.propagate(dL_doutputs)\n", + "\n", + " # Accululate the gradient produced for each input.\n", + " # Each use of a variable produces some gradient dL_dinput for that\n", + " # use. The multivariate chain rule tells us it is safe to sum\n", + " # all the contributions together.\n", + " for input, dL_dinput in zip(entry.inputs, dL_dinputs):\n", + " if input not in dL_d:\n", + " dL_d[input] = dL_dinput\n", + " else:\n", + " with self.backward_inner()():\n", + " dL_d[input] = dL_d[input].add(dL_dinput)\n", + "\n", + " # print some information to understand the values of each intermediate\n", + " # for name, value in dL_d.items():\n", + " # print(f'{self.name} d{L.t_name}_d{name} = {value.t_name}')\n", + " print(f\"------------------------\")\n", + "\n", + " return gather_grad(desired.t_name for desired in desired_results)" + ] + }, + { + "cell_type": "markdown", + "id": "901c6b4a", + "metadata": { + "id": "901c6b4a" + }, + "source": [ + "To calculate some simple gradients, we can compose Autograd with\n", + "Torch and get the result we expect." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b725c318", + "metadata": { + "id": "b725c318", + "outputId": "ab223bba-6b1a-4b52-af23-2b5cd7989e13", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Autograd v5: torch.Size([4]) = v3 + v4\n", + "Autograd v6: torch.Size([4]) = v5 * v4\n", + "a tensor([0.4963, 0.7682, 0.0885, 0.1320])\n", + "b tensor([0.3074, 0.6341, 0.4901, 0.8964])\n", + "-- Autograd dv6 -------\n", + "------------------------\n", + "da tensor([0.3074, 0.6341, 0.4901, 0.8964])\n", + "db tensor([1.1111, 2.0364, 1.0687, 1.9249])\n" + ] + } + ], + "source": [ + "torch.manual_seed(0)\n", + "a, b = label(torch.rand(4)), label(torch.rand(4))\n", + "\n", + "def simple(a, b):\n", + " t = a + b\n", + " return t.mul(b)\n", + "\n", + "\n", + "grad_dispatcher = Autograd()\n", + "\n", + "with Labeler():\n", + " with grad_dispatcher:\n", + " loss = simple(a, b)\n", + "\n", + "print(\"a\", a)\n", + "print(\"b\", b)\n", + "da, db = grad_dispatcher.grad(loss, [a, b])\n", + "print(\"da\", da)\n", + "print(\"db\", db)" + ] + }, + { + "cell_type": "markdown", + "id": "fdd64990", + "metadata": { + "id": "fdd64990" + }, + "source": [ + "To compute higher order gradients, we have two options. First,\n", + "we can do traditional PyTorch style higher order differentiation\n", + "with `create_graph=True`, writing the backpropagation computations directly\n", + "into the tape so they can be further differentiated over. This is also\n", + "what the original Simple Autograd implementation does." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7e1a5342", + "metadata": { + "id": "7e1a5342", + "outputId": "95abe5d7-1ac0-4f82-9d0a-d8a97e06d055", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Autograd v13: torch.Size([4]) = v3 + v4\n", + "Autograd v14: torch.Size([4]) = v13 * v4\n", + "Autograd v15: torch.Size([]) = v14.sum(dim=None)\n", + "-- Autograd dv15 -------\n", + "Autograd v17: torch.Size([1]) = v16.unsqueeze(dim=0)\n", + "torch.Size([1])\n", + "torch.Size([4])\n", + "1\n", + "1\n", + "Autograd v18: torch.Size([4]) = v17.expand([4])\n", + "Autograd v19: torch.Size([4]) = v18 * v4\n", + "Autograd v20: torch.Size([4]) = v18 * v13\n", + "Autograd v21: torch.Size([4]) = v20 + v19\n", + "------------------------\n", + "tensor([0.3074, 0.6341, 0.4901, 0.8964])\n", + "Autograd v22: torch.Size([4]) = v19 * v19\n", + "Autograd v23: torch.Size([4]) = v21 * v21\n", + "Autograd v24: torch.Size([4]) = v22 + v23\n", + "Autograd v25: torch.Size([]) = v24.sum(dim=None)\n", + "-- Autograd dv25 -------\n", + "Autograd v27: torch.Size([1]) = v26.unsqueeze(dim=0)\n", + "torch.Size([1])\n", + "torch.Size([4])\n", + "1\n", + "1\n", + "Autograd v28: torch.Size([4]) = v27.expand([4])\n", + "Autograd v29: torch.Size([4]) = v28 * v21\n", + "Autograd v30: torch.Size([4]) = v28 * v21\n", + "Autograd v31: torch.Size([4]) = v29 + v30\n", + "Autograd v32: torch.Size([4]) = v28 * v19\n", + "Autograd v33: torch.Size([4]) = v28 * v19\n", + "Autograd v34: torch.Size([4]) = v32 + v33\n", + "Autograd v35: torch.Size([4]) = v34 + v31\n", + "Autograd v36: torch.Size([4]) = v31 * v13\n", + "Autograd v37: torch.Size([4]) = v31 * v18\n", + "Autograd v38: torch.Size([4]) = v35 * v4\n", + "Autograd v39: torch.Size([4]) = v35 * v18\n", + "Autograd v40: torch.Size([4]) = v36 + v38\n", + "Autograd v41: torch.Size([]) = v40.sum(dim=[0])\n", + "Autograd v42: torch.Size([1]) = v41.unsqueeze(dim=0)\n", + "Autograd v43: torch.Size([]) = v42.squeeze(dim=0)\n", + "Autograd v44: torch.Size([4]) = v39 + v37\n", + "------------------------\n", + "da tensor([2.2222, 4.0728, 2.1373, 3.8498])\n", + "db tensor([5.0593, 9.4137, 5.2548, 9.4926])\n" + ] + } + ], + "source": [ + "def run_gradients(d1, d2):\n", + " with Labeler():\n", + " with d1:\n", + " with d2 if d1 != d2 else contextlib.nullcontext():\n", + " # our first loss\n", + " L0 = simple(a, b).sum()\n", + "\n", + " # compute derivatives of our inputs\n", + " dL0_da, dL0_db = d2.grad(L0, [a, b])\n", + " print(dL0_da)\n", + "\n", + " # In real code, how would we switch from executing from d2 to d1?\n", + " # In functorch, the d2 dispatch calls would happen in the inside of\n", + " # a higher-order grad() or vjp() call; when we exit from this call, all\n", + " # of the involved tensors are unwrapped.\n", + "\n", + " # now lets compute the L2 norm of our derivatives\n", + " with d1.restore():\n", + " L1 = torch.sum(torch.add(dL0_da.mul(dL0_da), dL0_db.mul(dL0_db)))\n", + "\n", + " # and take the gradient of that.\n", + " # notice there are two losses involved1.\n", + " return d1.grad(L1, [a, b])\n", + "\n", + "grad_dispatcher = Autograd(create_graph=True)\n", + "da, db = run_gradients(grad_dispatcher, grad_dispatcher)\n", + "\n", + "print(\"da\", da)\n", + "print(\"db\", db)" + ] + }, + { + "cell_type": "markdown", + "id": "5b17d02c", + "metadata": { + "id": "5b17d02c" + }, + "source": [ + "Our second option is to follow functorch's implementation strategy, which\n", + "is to stack two Autograd dispatchers on top of each other. Here, it is\n", + "not necessary to `create_graph=True`, because when the backpropagator forwards\n", + "to the inner dispatcher, it will record those operations on the tape too.\n", + "But if you look at the output, you will notice something very interesting:\n", + "the first portion of the tape is exactly replicated between Autograd1 and\n", + "Autograd2: we're duplicating the tape in this case! So PyTorch's default\n", + "implementation of backwards is more efficient, because it avoids having to\n", + "record the tape twice (although this doesn't matter too much, because the\n", + "saved tensors themselves can be shared between the two tapes, so it is just\n", + "the operator graph that is duplicated).\n", + "\n", + "This is our first example of using two dispatchers. While we are\n", + "performing the inner grad, we perform our operations on the outer\n", + "dispatcher `d2`; after we are done with the inner grad we switch to\n", + "`d1` by restoring `d1` without `d2`. Intuitively, this corresponds from\n", + "passing out of the inner `grad` call to the outer `grad` call." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a7c47f5e", + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1, + "id": "a7c47f5e", + "outputId": "8ac686d8-4307-44d5-c13b-886ca7221ec1", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Autograd v45: torch.Size([4]) = v3 + v4\n", + "Autograd v45: torch.Size([4]) = v3 + v4\n", + "Autograd v46: torch.Size([4]) = v45 * v4\n", + "Autograd v46: torch.Size([4]) = v45 * v4\n", + "Autograd v47: torch.Size([]) = v46.sum(dim=None)\n", + "Autograd v47: torch.Size([]) = v46.sum(dim=None)\n", + "-- Autograd dv47 -------\n", + "Autograd v49: torch.Size([1]) = v48.unsqueeze(dim=0)\n", + "torch.Size([1])\n", + "torch.Size([4])\n", + "1\n", + "1\n", + "Autograd v50: torch.Size([4]) = v49.expand([4])\n", + "Autograd v51: torch.Size([4]) = v50 * v4\n", + "Autograd v52: torch.Size([4]) = v50 * v45\n", + "Autograd v53: torch.Size([4]) = v52 + v51\n", + "------------------------\n", + "tensor([0.3074, 0.6341, 0.4901, 0.8964])\n", + "Autograd v54: torch.Size([4]) = v51 * v51\n", + "Autograd v55: torch.Size([4]) = v53 * v53\n", + "Autograd v56: torch.Size([4]) = v54 + v55\n", + "Autograd v57: torch.Size([]) = v56.sum(dim=None)\n", + "-- Autograd dv57 -------\n", + "torch.Size([1])\n", + "torch.Size([4])\n", + "------------------------\n", + "da tensor([2.2222, 4.0728, 2.1373, 3.8498])\n", + "db tensor([5.0593, 9.4137, 5.2548, 9.4926])\n" + ] + } + ], + "source": [ + "d1 = Autograd(create_graph=False)\n", + "d2 = Autograd(create_graph=False)\n", + "\n", + "da, db = run_gradients(d2, d1)\n", + "print(\"da\", da)\n", + "print(\"db\", db)" + ] + }, + { + "cell_type": "markdown", + "id": "f343af81", + "metadata": { + "id": "f343af81" + }, + "source": [ + "Under what situations might it be profitable to keep the two tapes separate?\n", + "One guess we might have is if there is another functional transformation\n", + "wedged between the two autograd transformations. We would then expect the\n", + "backwards formula we save to be different between the two tapes. To do this, I\n", + "first need to implement batched tensors.\n", + "\n", + "One unusual thing about this implementation is that we do not need to wrap\n", + "tensors to change their sizes; instead, we just override the meaning of\n", + "size() on the dispatcher to hide batch dimensions. These calls are not\n", + "sent to the Pytorch dispatcher, so we need to explicitly call Dispatcher.size.\n", + "\n", + "One case we do not\n", + "exercise in this example is implicit broadcasting when you combine a tensor\n", + "that is not batched with a tensor that is batched: without wrappers, a user\n", + "must explicitly lift (e.g., unsqueeze and expand) tensors they wish to\n", + "replicate across the batch dimension. The code below will blindly attempt to\n", + "reinterpret a tensor as a batched tensor, even when it may not make sense (if\n", + "there is a size mismatch, however, you will get an assert failure). Similarly,\n", + "once you exit a vmap region, all previously vmap'ed tensors \"magically\" become\n", + "unbatched. functorch did not pursue this implementation because at the time\n", + "Tensor.size() was not virtual and thus it was not possible to override (this\n", + "will be changing soon)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd1d0688", + "metadata": { + "id": "dd1d0688" + }, + "outputs": [], + "source": [ + "# This implementation of Batched only supports inserting a dimension\n", + "# at the very front\n", + "class Batched(Dispatcher):\n", + " def __init__(self, *, length, name=\"Batched\"):\n", + " self.name = name\n", + " self.length = length\n", + "\n", + " def _pad_to_size(self, lhs, rhs):\n", + " lhs_size, rhs_size = self.inner.size(lhs), self.inner.size(rhs)\n", + " lhs_dim, rhs_dim = len(lhs_size), len(rhs_size)\n", + " if lhs_dim == rhs_dim:\n", + " return (lhs, rhs)\n", + " diff = rhs_dim - lhs_dim\n", + " assert diff != 0 # sanity check\n", + " new_final = rhs if diff < 0 else lhs\n", + " for _ in range(abs(diff)): # could be done as a reshape if we added that\n", + " new_final = self.unsqueeze(new_final, 0)\n", + " return (lhs, new_final) if diff < 0 else (new_final, rhs)\n", + "\n", + " def size(self, input):\n", + " sizes = self.inner.size(input)\n", + " print(sizes)\n", + " assert sizes[0] == self.length\n", + " return sizes[1:]\n", + "\n", + " def ones(self, size, **kwargs):\n", + " return torch.ones([self.length,] + size)\n", + "\n", + " def mul(self, lhs, rhs):\n", + " assert self.inner.size(lhs)[0] == self.length\n", + " if not isinstance(rhs, float):\n", + " assert self.inner.size(rhs)[0] == self.length\n", + " lhs, rhs = self._pad_to_size(lhs, rhs)\n", + " return self.inner.mul(lhs, rhs)\n", + "\n", + " def add(self, lhs, rhs):\n", + " assert self.inner.size(lhs)[0] == self.length\n", + " assert self.inner.size(rhs)[0] == self.length\n", + " lhs, rhs = self._pad_to_size(lhs, rhs)\n", + " return torch.add(lhs, rhs)\n", + "\n", + " def sum(self, input, dim=None):\n", + " # offset all the summed over dimensions by one\n", + " assert self.inner.size(input)[0] == self.length\n", + " dim = tuple(\n", + " i + 1 for i in sum_dims(input_dim=self.inner.dim(input) - 1, dim=dim)\n", + " )\n", + " return torch.sum(input, dim)\n", + "\n", + " def expand(self, input, sizes):\n", + " # offset sizes by one\n", + " assert self.inner.size(input)[0] == self.length\n", + " new_sizes = [self.inner.size(input)[0]] + sizes\n", + " return input.expand(new_sizes)\n", + "\n", + " def squeeze(self, input, dim):\n", + " # offset dim by one\n", + " assert self.inner.size(input)[0] == self.length\n", + " return torch.squeeze(input, dim + 1)\n", + "\n", + " def unsqueeze(self, input, dim):\n", + " # offset dim by one\n", + " assert self.inner.size(input)[0] == self.length\n", + " return torch.unsqueeze(input, dim + 1)\n", + "\n", + " def custom_vjp(self, fwd_fn, bwd_fn, *args):\n", + " def batchify(fn):\n", + " def new_fn(*args):\n", + " with Batched(length=self.length, name='GeneratedBatched'):\n", + " return fn(*args)\n", + " return new_fn\n", + "\n", + " # If we have Batched(Autograd(Torch()), then we would like the inner\n", + " # dispatcher to receive a call to custom_vjp so that it preserves the\n", + " # backward semantics. However, since this is the Batched dispatcher,\n", + " # we want the innermost Torch dispatcher to run a batched version of fwd_fn\n", + " # function! The way we get this to work is to create a new fwd_fn, that,\n", + " # when executed, executes a batched version of fwd_fn.\n", + " #\n", + " # Same thing for the bwd_fn.\n", + " # NB: currently simple_functorch assumes that all Tensors are batched at\n", + " # dimension 0. I'm not sure how this logic would look like without\n", + " # this assumption (in functorch tensors may not be batched).\n", + " r, saved = super().custom_vjp(batchify(fwd_fn), batchify(bwd_fn), *args)\n", + " return r, saved\n", + "\n", + " # The lift operation takes a tensor associated with some inner\n", + " # dispatcher, and \"lifts\" it so that it is interpreted neutrally\n", + " # for the outer dispatcher. For most dispatchers this is trivial,\n", + " # but for batched tensor it is not: given a tensor x, to interpret\n", + " # it as x under the Batching dispatcher, we have to expand it so\n", + " # that it is broadcasted along its first dimension.\n", + " def lift(self, input, d):\n", + " if d is self:\n", + " return input\n", + " b_input = torch.unsqueeze(input, 0)\n", + " b_input = b_input.expand((self.length,) + self.inner.size(input))\n", + " return self.inner.lift(b_input, d)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dffb9a7e", + "metadata": { + "id": "dffb9a7e", + "outputId": "7f83b456-4bdf-41ff-c2f6-997d5313e3c8", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Autograd v449: torch.Size([2, 4]) = v447 + v448\n", + "torch.Size([2, 4])\n", + "Autograd v449: torch.Size([4]) = v447 + v448\n", + "Autograd v450: torch.Size([2, 4]) = v449 * v448\n", + "Autograd v450: torch.Size([2, 4]) = v449 * v448\n", + "torch.Size([2, 4])\n", + "Autograd v450: torch.Size([4]) = v449 * v448\n", + "Autograd v451: torch.Size([2]) = v450.sum(dim=[1])\n", + "torch.Size([2])\n", + "Autograd v451: torch.Size([]) = v450.sum(dim=[0])\n", + "-- Autograd dv451 -------\n", + "torch.Size([2, 4])\n", + "torch.Size([2, 4])\n", + "1\n", + "[0]\n", + "Autograd v453: torch.Size([2, 1]) = v452.unsqueeze(dim=1)\n", + "2\n", + "2\n", + "Autograd v454: torch.Size([2, 4]) = v453.expand([2, 4])\n", + "Autograd v455: torch.Size([2, 4]) = v454 * v448\n", + "Autograd v455: torch.Size([2, 4]) = v454 * v448\n", + "Autograd v456: torch.Size([2, 4]) = v454 * v449\n", + "Autograd v456: torch.Size([2, 4]) = v454 * v449\n", + "Autograd v457: torch.Size([2, 4]) = v456 + v455\n", + "------------------------\n", + "Autograd v458: torch.Size([2, 4]) = v455 * v455\n", + "Autograd v459: torch.Size([2, 4]) = v457 * v457\n", + "Autograd v460: torch.Size([2, 4]) = v458 + v459\n", + "Autograd v461: torch.Size([]) = v460.sum(dim=None)\n", + "-- Autograd dv461 -------\n", + "2\n", + "None\n", + "------------------------\n", + "va tensor([[0.4423, 0.2768, 0.8998, 0.0960],\n", + " [0.5537, 0.3953, 0.8571, 0.6396]])\n", + "vb tensor([[0.7403, 0.6766, 0.3798, 0.3948],\n", + " [0.0880, 0.7709, 0.8970, 0.8421]])\n", + "dva tensor([[ 7.6912, 6.5198, 6.6374, 3.5426],\n", + " [ 2.9183, 7.7486, 10.6041, 9.2952]])\n", + "dvb tensor([[18.3435, 15.7460, 14.7939, 8.6646],\n", + " [ 6.1884, 18.5810, 24.7962, 21.9588]])\n" + ] + } + ], + "source": [ + "# Our inputs are batched this time!\n", + "va, vb = label(torch.rand(2, 4)), label(torch.rand(2, 4))\n", + "\n", + "def run_batched_gradients():\n", + " # our first loss\n", + " # we write the dimension we reduce on explicitly for clarity\n", + " d1 = Autograd(create_graph=False)\n", + " d3 = Autograd(create_graph=False)\n", + " with Labeler():\n", + " with d1:\n", + " with Batched(length=2):\n", + " with d3:\n", + " L0 = torch.sum(simple(va, vb), dim=0)\n", + "\n", + " # compute derivatives of our inputs\n", + " dL0_da, dL0_db = d3.grad(L0, [va, vb])\n", + "\n", + " # now lets compute the L2 norm of our derivatives\n", + " with d1.restore():\n", + " L1 = torch.sum(torch.add(dL0_da.mul(dL0_da), dL0_db.mul(dL0_db)))\n", + "\n", + " # and take the gradient of that.\n", + " # notice there are two losses involved1.\n", + " dL1_da, dL1_db = d1.grad(L1, [va, vb])\n", + " return dL1_da, dL1_db\n", + "\n", + "\n", + "dva, dvb = run_batched_gradients()\n", + "print(\"va\", va)\n", + "print(\"vb\", vb)\n", + "print(\"dva\", dva)\n", + "print(\"dvb\", dvb)" + ] + }, + { + "cell_type": "markdown", + "id": "c66d013e", + "metadata": { + "lines_to_next_cell": 2, + "id": "c66d013e" + }, + "source": [ + "To see that we have done this correctly, we could run the corresponding JAX:\n", + "\n", + "```\n", + "from jax import grad, vmap\n", + "import jax.numpy as np\n", + "\n", + "def simple(a, b):\n", + " t = a + b\n", + " return t * b\n", + "\n", + "def L0(a, b):\n", + " return np.sum(simple(a, b))\n", + "\n", + "def L1(a, b):\n", + " dL0_da, dL0_db = vmap(grad(L0, argnums=(0,1)), in_axes=0)(a, b)\n", + " return (dL0_da * dL0_da + dL0_db * dL0_db).sum()\n", + "\n", + "va = np.asarray([[0.4556, 0.6323, 0.3489, 0.4017],\n", + " [0.0223, 0.1689, 0.2939, 0.5185]])\n", + "vb = np.asarray([[0.6977, 0.8000, 0.1610, 0.2823],\n", + " [0.6816, 0.9152, 0.3971, 0.8742]])\n", + "dva, dvb = grad(L1, argnums=(0,1))(va, vb)\n", + "print(\"dva\", dva)\n", + "print(\"dvb\", dvb)\n", + "```\n", + "\n", + "Looking over the output, the tapes look similar, but we can see that the sizes\n", + "and the arguments of the operations in question differ (after all, Autograd3 is\n", + "on the inside of the vmap, while Autograd1 is outside). But it is still very\n", + "similar: we could imagine simply varying the dispatcher we use to process backwards\n", + "depending on when we are executing the tape. In fact, this is exactly what an\n", + "initial, non-functorch implementation of PyTorch did to support per-sample\n", + "gradients.\n", + "\n", + "Exercise: modify Autograd.grad to accept a dispatcher, and use that dispatcher\n", + "instead of self.backward_inner() when running propagator functions. Then, rewrite\n", + "the above example so that it only has one level of Autograd:\n", + "Batched(Autograd(Torch(), create_graph=True)) and show you still get the same\n", + "result." + ] + }, + { + "cell_type": "markdown", + "id": "2706fca7", + "metadata": { + "id": "2706fca7" + }, + "source": [ + "OK, so all of this dispatcher business is all nice and explicit, but\n", + "that's not what JAX/functorch's interface looks like. How do we\n", + "bridge the gap? Unlike with simple functorch 1.0, we don't have to\n", + "set the global mode since we're using the context managers and modes\n", + "to set that" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0d9f0bb", + "metadata": { + "id": "c0d9f0bb" + }, + "outputs": [], + "source": [ + "# now unnecessary :)\n", + "# DISPATCHER = Labeler()\n", + "\n", + "\n", + "# @contextlib.contextmanager\n", + "# def dispatcher(d):\n", + "# global DISPATCHER\n", + "# old_d = DISPATCHER\n", + "# DISPATCHER = d\n", + "# try:\n", + "# yield\n", + "# finally:\n", + "# DISPATCHER = old_d" + ] + }, + { + "cell_type": "markdown", + "id": "76f9d0e9", + "metadata": { + "id": "76f9d0e9" + }, + "source": [ + "A dispatcher mode, however, is not enough. Remember that in our\n", + "implementation of Batched, we blindly assumed that all tensors were\n", + "batched, even if this did not necessarily make sense. If I have\n", + "`vmap(lambda bx: bx + y)(x)`, with `x: (B,X)` and `y: (X,)`, the\n", + "underlying operation should broadcast y to `(B,X)` and then do the\n", + "addition with x (bx advertises that it has size `(X,)` inside of the\n", + "vmap'd lambda). To know this should happen, it is necessary for\n", + "us to know that y is not a batched tensor, but x is a batched tensor.\n", + "We'll resolve this with a wrapper class called FuncTensor, which\n", + "records both the underlying Tensor, as well as the Dispatcher which\n", + "this tensor is associated with. In the above example, `bx.dispatcher`\n", + "might be `Batched(Torch())`, whereas `x.dispatcher` is `Torch()`.\n", + "\n", + "So our general strategy is as follows:\n", + " 1. Every tensor is associated with a dispatcher\n", + " 2. You can lift tensors to dispatchers which wrap them (which can\n", + " trigger some operations, like expand for Batched); this is\n", + " implemented by `dispatcher_wraps`\n", + " 3. To perform an operation between to tensors, lift them so that\n", + " they all have the same dispatcher, then do the operation on\n", + " that dispatcher." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c80652b4", + "metadata": { + "lines_to_end_of_cell_marker": 2, + "lines_to_next_cell": 2, + "id": "c80652b4" + }, + "outputs": [], + "source": [ + "# A dispatcher d1 wraps another dispatcher d2 if d2 is an ancestor of\n", + "# d1 in the tree structure. We've defined this relation to be\n", + "# reflexive, in the same way issubclass(A, A) == True.\n", + "def dispatcher_wraps(d1, d2):\n", + " # Treat this as a reflexive relation\n", + " if d1 is d2 or (d1 is not None and d2 in d1.ancestors) or d2 is None:\n", + " return True\n", + " return False\n", + "\n", + "\n", + "# Given a list of arguments, lift them all up to a common dispatcher\n", + "# level, returning that dispatcher as well as the lifted arguments.\n", + "# Note that the current dispatcher is also accounted for by getting the current\n", + "# mode! In autodidax, this is `find_top_trace`.\n", + "def lift_and_unwrap_args(*args):\n", + " outermost = _get_torch_dispatch_mode()\n", + " for a in args:\n", + " if dispatcher_wraps(outermost, a.dispatcher):\n", + " pass\n", + " elif dispatcher_wraps(a.dispatcher, outermost):\n", + " # You can make this case an error as well if you don't\n", + " # want to support non-lexical functorch tensors\n", + " outermost = a.dispatcher\n", + " else:\n", + " raise TypeError(\"incompatible dispatcher trees\")\n", + " return (outermost,) + tuple(a.lift(outermost).tensor for a in args)" + ] + }, + { + "cell_type": "markdown", + "id": "d9ad2cfe", + "metadata": { + "lines_to_next_cell": 2, + "id": "d9ad2cfe" + }, + "source": [ + "The actual implementation of the wrapper tensor which tracks the\n", + "Dispatcher for a tensor" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "454659ce", + "metadata": { + "id": "454659ce" + }, + "outputs": [], + "source": [ + "@dataclass\n", + "class FuncTensor:\n", + " tensor: Tensor\n", + " dispatcher: Dispatcher\n", + "\n", + " # Lift a FuncTensor to an outer dispatcher\n", + " def lift(self, d):\n", + " # You can only lift to a dispatcher which wraps the dispatcher\n", + " # this FuncTensor is associated with (not vice versa, or between\n", + " # unrelated FuncTensors).\n", + " assert dispatcher_wraps(d, self.dispatcher)\n", + " return FuncTensor(d.lift(self.tensor, self.dispatcher), d)\n", + "\n", + " # The general strategy for any operation performed on a tensor, we\n", + " # lift all the arguments so that they live on the same dispatcher\n", + " # level, and then perform the operation on that dispatcher. The\n", + " # resulting tensor is tagged at whatever dispatcher we had run the\n", + " # tensor on.\n", + " def __mul__(self, other):\n", + " d, self, other = lift_and_unwrap_args(self, other)\n", + " with restore_dispatcher_or_nop(d)():\n", + " return FuncTensor(self.mul(other), d)\n", + "\n", + " def __add__(self, other):\n", + " d, self, other = lift_and_unwrap_args(self, other)\n", + " with restore_dispatcher_or_nop(d)():\n", + " return FuncTensor(self.add(other), d)\n", + "\n", + " def sum(self, dim=None):\n", + " d, self = lift_and_unwrap_args(self)\n", + " with restore_dispatcher_or_nop(d)():\n", + " if dim is None:\n", + " res = self.sum()\n", + " else:\n", + " res = self.sum(dim)\n", + " return FuncTensor(res, d)\n", + "\n", + " def expand(self, sizes):\n", + " d, self = lift_and_unwrap_args(self)\n", + " with restore_dispatcher_or_nop(d)():\n", + " return FuncTensor(self.expand(sizes), d)\n", + "\n", + " def unsqueeze(self, dim):\n", + " d, self = lift_and_unwrap_args(self)\n", + " with restore_dispatcher_or_nop(d)():\n", + " return FuncTensor(self.unsqueeze(dim), d)\n", + "\n", + " def squeeze(self, dim):\n", + " d, self = lift_and_unwrap_args(self)\n", + " with restore_dispatcher_or_nop(d)():\n", + " return FuncTensor(self.squeeze(dim), d)\n", + "\n", + " def size(self):\n", + " d, self = lift_and_unwrap_args(self)\n", + " return d.size(self)\n", + "\n", + " def dim(self):\n", + " d, self = lift_and_unwrap_args(self)\n", + " return d.size(self)\n", + "\n", + " # Factory functions like ones do not have any Tensor arguments,\n", + " # so they rely solely on the current mode\n", + " @staticmethod\n", + " def ones(size):\n", + " return torch.ones(size)" + ] + }, + { + "cell_type": "markdown", + "id": "6063992f", + "metadata": { + "id": "6063992f" + }, + "source": [ + "Now we are ready to implement grad. First, we need some helper\n", + "functions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0926b942", + "metadata": { + "id": "0926b942" + }, + "outputs": [], + "source": [ + "# When we are done doing a vmap/grad, we need to take the results and\n", + "# lower them back to a lower dispatcher on the stack (this is always\n", + "# a no-op, in particular, in the vmap case, when we exit vmap the user\n", + "# gets to see the batched dimension again.)\n", + "def unlift(t, d):\n", + " if isinstance(t, list):\n", + " return [unlift(x, d) for x in t]\n", + " elif isinstance(t, tuple):\n", + " return tuple(unlift(x, d) for x in t)\n", + " else:\n", + " if t.dispatcher is d:\n", + " return t\n", + " return unlift(FuncTensor(t.tensor, t.dispatcher.inner), d)\n", + "\n", + "\n", + "# This lets us easily pick out arguments as specified by argnums\n", + "def filter_argnums(args, argnums):\n", + " if isinstance(argnums, int):\n", + " return (args[argnums],)\n", + " else:\n", + " return tuple(args[i] for i in argnums)" + ] + }, + { + "cell_type": "markdown", + "id": "3cc16dc6", + "metadata": { + "id": "3cc16dc6" + }, + "source": [ + "Now grad and vmap!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8778b72f", + "metadata": { + "id": "8778b72f" + }, + "outputs": [], + "source": [ + "# For simplicity, these functions only take tuples, not pytrees\n", + "def grad(f, argnums=0):\n", + " @functools.wraps(f)\n", + " def wrapped_f(*args):\n", + " # We first lift and unwrap all of the arguments which we want\n", + " # to pass into the function\n", + " old_d, *args = lift_and_unwrap_args(*args)\n", + " assert old_d == _get_torch_dispatch_mode()\n", + " d = Autograd()\n", + " with d:\n", + " # We pass in the functions at the new Autograd level (they\n", + " # were lifted to old_d, and lifting to d is a noop)\n", + " L = f(*(FuncTensor(a, d) for a in args))\n", + " assert L.dispatcher is d\n", + " # Run the autograd pass, getting the grads for the inputs\n", + " # as specified by argnums\n", + " grads = d.grad(L.tensor, filter_argnums(args, argnums))\n", + " # Finally, construct the grads at the lower level and return\n", + " # them\n", + " return [FuncTensor(r, old_d) for r in grads]\n", + "\n", + " return wrapped_f" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0134b6c8", + "metadata": { + "id": "0134b6c8" + }, + "outputs": [], + "source": [ + "def vmap(f):\n", + " @functools.wraps(f)\n", + " def wrapped_f(*args):\n", + " # cannot vmap over no arguments as this function uses the\n", + " # arguments to determine how large the batch dimension is\n", + " # (hypothetically, you could explicitly pass in the batch\n", + " # size, and then use this to control factory functions;\n", + " # JAX doesn't seem to have a knob to do this)\n", + " assert args\n", + " old_d, *args = lift_and_unwrap_args(*args)\n", + " d = Batched(length=args[0].size()[0])\n", + " for a in args:\n", + " assert a.size()[0] == d.length\n", + " with d:\n", + " # Rewrap all the arguments as batched tensors, then\n", + " # unwrap any batched tensors that escape\n", + " return unlift(f(*(FuncTensor(a, d) for a in args)), old_d)\n", + "\n", + " return wrapped_f" + ] + }, + { + "cell_type": "markdown", + "id": "159d3937", + "metadata": { + "id": "159d3937" + }, + "source": [ + "Now we can rerun our example using the high level grad/vmap functions!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a24e765", + "metadata": { + "id": "2a24e765", + "outputId": "93bdeda6-d19a-4748-95fa-4c8cae61f93c", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "tensor([[0.4194, 0.5529, 0.9527, 0.0362],\n", + " [0.1852, 0.3734, 0.3051, 0.9320]])\n", + "tensor([[0.1759, 0.2698, 0.1507, 0.0317],\n", + " [0.2081, 0.9298, 0.7231, 0.7423]])\n", + "Autograd v165: torch.Size([2, 4]) = v124 + v125\n", + "torch.Size([2, 4])\n", + "Autograd v165: torch.Size([4]) = v124 + v125\n", + "Autograd v166: torch.Size([2, 4]) = v165 * v125\n", + "torch.Size([2, 4])\n", + "Autograd v166: torch.Size([4]) = v165 * v125\n", + "Autograd v167: torch.Size([2]) = v166.sum(dim=[1])\n", + "torch.Size([2])\n", + "Autograd v167: torch.Size([]) = v166.sum(dim=None)\n", + "-- Autograd dv167 -------\n", + "torch.Size([2, 4])\n", + "torch.Size([2, 4])\n", + "Autograd v169: torch.Size([2, 1]) = v168.unsqueeze(dim=1)\n", + "torch.Size([2, 1])\n", + "torch.Size([1])\n", + "torch.Size([4])\n", + "2\n", + "2\n", + "Autograd v170: torch.Size([2, 4]) = v169.expand([2, 4])\n", + "Autograd v171: torch.Size([2, 4]) = v170 * v125\n", + "Autograd v172: torch.Size([2, 4]) = v170 * v165\n", + "Autograd v173: torch.Size([2, 4]) = v172 + v171\n", + "------------------------\n", + "Autograd v174: torch.Size([2, 4]) = v171 * v171\n", + "Autograd v175: torch.Size([2, 4]) = v173 * v173\n", + "Autograd v176: torch.Size([2, 4]) = v174 + v175\n", + "Autograd v177: torch.Size([]) = v176.sum(dim=None)\n", + "-- Autograd dv177 -------\n", + "torch.Size([1, 1])\n", + "torch.Size([2, 4])\n", + "------------------------\n", + "dva FuncTensor(tensor=tensor([[1.5425, 2.1851, 2.5082, 0.1992],\n", + " [1.2030, 4.4660, 3.5026, 4.8333]]), dispatcher=<__main__.Labeler object at 0x7f6f87dfe790>)\n", + "dvb FuncTensor(tensor=tensor([[ 3.4367, 4.9100, 5.3178, 0.4619],\n", + " [ 2.8222, 10.7917, 8.4515, 11.1514]]), dispatcher=<__main__.Labeler object at 0x7f6f87dfe790>)\n" + ] + } + ], + "source": [ + "def simple(a, b):\n", + " t = a + b\n", + " return t * b\n", + "\n", + "\n", + "def L0(a, b):\n", + " return simple(a, b).sum()\n", + "\n", + "\n", + "def L1(a, b):\n", + " dL0_da, dL0_db = vmap(grad(L0, argnums=(0, 1)))(a, b)\n", + " return (dL0_da * dL0_da + dL0_db * dL0_db).sum()\n", + "\n", + "l = Labeler()\n", + "print(va)\n", + "print(vb)\n", + "with l:\n", + " fva = FuncTensor(va, l)\n", + " fvb = FuncTensor(vb, l)\n", + " dva, dvb = grad(L1, argnums=(0, 1))(fva, fvb)\n", + "print(\"dva\", dva)\n", + "print(\"dvb\", dvb)" + ] + }, + { + "cell_type": "markdown", + "id": "268470a8", + "metadata": { + "id": "268470a8" + }, + "source": [ + "Because FuncTensors are associated with the ambient dispatcher they\n", + "were created from, they are also allowed to escape from the context in\n", + "which they were defined, allowing for non-lexical, imperative\n", + "transform API. For example, batching over module parameters is\n", + "problematic today, but all we need to do is tweak the FuncTensor's\n", + "dispatchers appropriately and everything works out." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29d937c8", + "metadata": { + "id": "29d937c8", + "outputId": "927985e2-baca-46fe-83db-3bfe493b60fd", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "expect tensor([[-0.6632, 0.5145, 0.7165],\n", + " [-0.6439, 0.5264, -0.8546]])\n", + "output tensor([[-0.6632, 0.5145, 0.7165],\n", + " [-0.6439, 0.5264, -0.8546]])\n" + ] + } + ], + "source": [ + "B = 2\n", + "\n", + "# this is a bug, we should be able to set inner in the constructor and have that set the ancestors correctly\n", + "base_dispatcher = Labeler()\n", + "batched_dispatcher = Batched(length=B)\n", + "with base_dispatcher:\n", + " with batched_dispatcher:\n", + " pass\n", + "\n", + "PlainTensor = lambda t: FuncTensor(torch.randn(N), base_dispatcher)\n", + "BatchedTensor = lambda t: FuncTensor(t, batched_dispatcher)\n", + "\n", + "class ScaleBiasModule:\n", + " weight: FuncTensor\n", + " bias: FuncTensor\n", + "\n", + " def __init__(self, N):\n", + " self.weight = PlainTensor(torch.randn(N))\n", + " self.bias = PlainTensor(torch.randn(N))\n", + "\n", + " def forward(self, input):\n", + " return self.weight * input + self.bias\n", + "\n", + "\n", + "B = 2\n", + "N = 3\n", + "m = ScaleBiasModule(N)\n", + "# Ensemble weights only; input is not batched\n", + "m.weight = BatchedTensor(torch.randn(B, N))\n", + "input = PlainTensor(torch.randn(N))\n", + "output = m.forward(input)\n", + "print(\n", + " \"expect\", input.tensor.unsqueeze(0) * m.weight.tensor + m.bias.tensor.unsqueeze(0)\n", + ")\n", + "print(\"output\", output.tensor)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Higher-order operations in simple functorch!\n", + "\n", + "Problem: users want to define functions with custom forward and backward\n", + "passes. These functions call PyTorch operations. When we vmap over such a\n", + "function, we would like for the backward pass to be preserved.\n", + "\n", + "Why is this difficult? In PyTorch today, vmap over an autograd.Function\n", + "effectively runs vmap on the forward pass of the autograd.Function.\n", + "Meanwhile, autograd records the transformed operations for backward, instead\n", + "of the custom backward pass we specified in the autograd.Function!\n", + "\n", + "Solution: We're going to introduce a `custom_vjp` primitive that accepts\n", + "functions and varargs Tensor arguments and demonstrate that it resolves\n", + "the problem.\n", + "\n", + "custom_vjp(fwd_fn, bwd_fn, *args) takes in two functions as arguments.\n", + "We add a little helper function so that the user is not explicitly calling\n", + "this function on the active dispatcher" + ], + "metadata": { + "id": "PwQ4ob6w0cmS" + }, + "id": "PwQ4ob6w0cmS" + }, + { + "cell_type": "code", + "source": [ + "def custom_vjp(fwd_fn, bwd_fn, *args):\n", + " d = _get_torch_dispatch_mode()\n", + " return d.custom_vjp(fwd_fn, bwd_fn, *args)" + ], + "metadata": { + "id": "K64aCQby1300" + }, + "id": "K64aCQby1300", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "For our custom function, we want f(x) = x * x, but we install a custom\n", + "backwards pass that computes 32 * x (instead of 2 * x) so we can tell\n", + "if custom_vjp is working." + ], + "metadata": { + "id": "0k3nmKdL2KNt" + }, + "id": "0k3nmKdL2KNt" + }, + { + "cell_type": "code", + "source": [ + "a = label(torch.rand(4))\n", + "va = label(torch.rand(2, 4))\n", + "\n", + "def f_fwd(x):\n", + " # Our convention is that f_fwd returns (outputs, \"saved\")\n", + " return x.mul(x), x\n", + "\n", + "# Our convention is that f_bwd accepts (dispatcher, gradOutputs, \"saved\")\n", + "def f_bwd(gradOutputs, x):\n", + " gO, = gradOutputs\n", + " # Should be gO * 2 * x, but we're gonna do gO * 32 * x to demonstrate things\n", + " with no_dispatch():\n", + " thirty_two = torch.tensor(32.) # a hack so I don't have to override lift\n", + "\n", + " return [torch.mul(gO.mul(x), label(thirty_two, 'thirty_two'))]" + ], + "metadata": { + "id": "lFmnq0oP1aTz" + }, + "id": "lFmnq0oP1aTz", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "def run_grad():\n", + " grad_dispatcher = Autograd()\n", + " with Labeler():\n", + " with grad_dispatcher:\n", + " # Here's how to invoke custom_vjp.\n", + " r, _ = custom_vjp(f_fwd, f_bwd, a)\n", + " L3 = torch.sum(r)\n", + " print(grad_dispatcher.gradient_tape)\n", + " dL3_a, = grad_dispatcher.grad(L3, [a])\n", + "\n", + " print(dL3_a)\n", + " print(a)\n", + " # Check that the gradients are indeed 32 * a\n", + " assert torch.allclose(dL3_a, 32 * a)\n", + "\n", + "run_grad()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "7PJ1yolV1UW_", + "outputId": "6983dc93-13e6-4227-92b7-675fe917b677" + }, + "id": "7PJ1yolV1UW_", + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "v500 = custom_vjp(f_fwd, f_bwd, v498)\n", + "None\n", + "Autograd v501: torch.Size([]) = v500.sum(dim=None)\n", + "[TapeEntry(inputs=['v498'], outputs=['v500'], propagate=.propagate at 0x7f6f7f223cb0>), TapeEntry(inputs=['v500'], outputs=['v501'], propagate=.propagate at 0x7f6f7f223c20>)]\n", + "-- Autograd dv501 -------\n", + "1\n", + "None\n", + "------------------------\n", + "tensor([ 6.7094, 27.1989, 10.2486, 29.4958])\n", + "tensor([0.2097, 0.8500, 0.3203, 0.9217])\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# Multiple layers of autograd\n", + "def run_gradgrad():\n", + " grad_dispatcher_1 = Autograd()\n", + " grad_dispatcher_2 = Autograd()\n", + " with Labeler():\n", + " with Logger(name=\"Torch\"):\n", + " with grad_dispatcher_1:\n", + " with grad_dispatcher_2:\n", + " r, _ = custom_vjp(f_fwd, f_bwd, a)\n", + " L4 = r.sum()\n", + " dL4_a, = grad_dispatcher_2.grad(L4, [a])\n", + "\n", + " # Evidence that d2 respected the custom_vjp's f_bwd\n", + " assert torch.allclose(dL4_a, 32 * a)\n", + "\n", + " assert hasattr(dL4_a, 't_name')\n", + " with grad_dispatcher_1.restore():\n", + " dL4_a_sum = dL4_a.sum()\n", + " ddL4_a_a, = grad_dispatcher_1.grad(dL4_a_sum, [a])\n", + "\n", + " # Evidence that d1 respected the custom_vjp's f_bwd\n", + " assert torch.allclose(ddL4_a_a, torch.ones_like(a) * 32)\n", + "\n", + "run_gradgrad()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "u36V5Du7Dbs3", + "outputId": "95cff424-c51e-4efe-8d6a-1479af4183f4" + }, + "id": "u36V5Du7Dbs3", + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "v507, v498 = custom_vjp(f_fwd, f_bwd, v498)\n", + "None\n", + "v507 = custom_vjp(f_fwd, f_bwd, v498)\n", + "None\n", + "v507 = custom_vjp(f_fwd, f_bwd, v498)\n", + "None\n", + " Torch v508: torch.Size([]) = v507.sum(dim=None)\n", + "Autograd v508: torch.Size([]) = v507.sum(dim=None)\n", + "Autograd v508: torch.Size([]) = v507.sum(dim=None)\n", + " Torch v509: torch.Size([]) = ones([])\n", + "-- Autograd dv508 -------\n", + "1\n", + "None\n", + " Torch v510: torch.Size([1]) = v509.unsqueeze(0)\n", + "Autograd v510: torch.Size([1]) = v509.unsqueeze(dim=0)\n", + "1\n", + "1\n", + " Torch v511: torch.Size([4]) = v510.expand([4])\n", + "Autograd v511: torch.Size([4]) = v510.expand([4])\n", + " Torch v512: torch.Size([4]) = v511 * v498\n", + "Autograd v512: torch.Size([4]) = v511 * v498\n", + " Torch v513: torch.Size([4]) = v512 * thirty_two\n", + "Autograd v513: torch.Size([4]) = v512 * thirty_two\n", + "------------------------\n", + " Torch v514: torch.Size([]) = v513.sum(dim=None)\n", + "Autograd v514: torch.Size([]) = v513.sum(dim=None)\n", + " Torch v515: torch.Size([]) = ones([])\n", + "-- Autograd dv514 -------\n", + "1\n", + "None\n", + " Torch v516: torch.Size([1]) = v515.unsqueeze(0)\n", + " Torch v517: torch.Size([4]) = v516.expand([4])\n", + " Torch v518: torch.Size([4]) = v517 * thirty_two\n", + " Torch v519: torch.Size([4]) = v517 * v512\n", + " Torch v520: torch.Size([4]) = v518 * v498\n", + " Torch v521: torch.Size([4]) = v518 * v511\n", + " Torch v522: torch.Size([]) = v520.sum(dim=[0])\n", + " Torch v523: torch.Size([1]) = v522.unsqueeze(0)\n", + " Torch v524: torch.Size([]) = v523.squeeze(0)\n", + "------------------------\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "And now, let's try that again, with grad(lambda x: vmap(f)(x).sum()).\n", + "The goal of custom_vjp is to make it so that vmap(custom_vjp) still\n", + "preserves the backward semantics." + ], + "metadata": { + "id": "BSipEfS1E7jc" + }, + "id": "BSipEfS1E7jc" + }, + { + "cell_type": "code", + "source": [ + "def f_fwd(x):\n", + " return x.mul(x), x\n", + "\n", + "def f_bwd(gradOutputs, x):\n", + " gO, = gradOutputs\n", + " # Should be gO * 2 * x, but we're gonna do gO * 32 * x to prove a point\n", + " return [torch.mul(gO.mul(x), torch.ones(()) * 32.)]\n", + "\n", + "def run_gradvmap():\n", + " grad_dispatcher = Autograd()\n", + " with Labeler():\n", + " with grad_dispatcher:\n", + " with Batched(length=2):\n", + " r, _ = custom_vjp(f_fwd, f_bwd, va)\n", + " L99 = r.sum()\n", + " dL99_a, = grad_dispatcher.grad(L99, [va])\n", + "\n", + " # As you can see, d1.grad still calls f_bwd.\n", + " # The way we got this to work is that Batched.custom_vjp\n", + " # calls custom_vjp on its inner dispatcher.\n", + " # Scroll up to the implementation of Batched for more details.\n", + " assert torch.allclose(dL99_a, 32 * va)\n", + "\n", + "run_gradvmap()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "_-BPlF2qFBZ8", + "outputId": "2a56dab0-0de0-4fb0-e098-19e4707d52b0" + }, + "id": "_-BPlF2qFBZ8", + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "v561 = custom_vjp(new_fn, new_fn, v499)\n", + "None\n", + "Autograd v562: torch.Size([]) = v561.sum(dim=None)\n", + "-- Autograd dv562 -------\n", + "2\n", + "None\n", + "------------------------\n" + ] + } + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,py:light" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.11" + }, + "colab": { + "name": "simple_functorch_modes.ipynb", + "provenance": [], + "collapsed_sections": [] + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/simple_functorch_modes.py b/simple_functorch_modes.py new file mode 100644 index 0000000..5fcb131 --- /dev/null +++ b/simple_functorch_modes.py @@ -0,0 +1,1302 @@ +# -*- coding: utf-8 -*- +"""simple_functorch_modes.ipynb + +Automatically generated by Colaboratory. + +Original file is located at + https://colab.research.google.com/drive/13zHv0UwdzAPeW07QinIiit27AHbzXJ8L + +## Simple Functorch but Make it Modes + +This notebook is a rewrite of the simple functorch notebook that uses torch dispatch modes instead of the Dispatcher object +""" + +!pip uninstall -y torch +!pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --upgrade + +"""This notebook walks through a self-contained implementation of +functorch, including support for both vjp and vmap combinators (using +PyTorch only to implement primitive tensor operations). It follows +the tradition of +[Autodidax](https://jax.readthedocs.io/en/latest/autodidax.html) (a +pedagogical reimplementation of JAX, the library functorch is inspired +by) and [Simple +Autograd](https://colab.research.google.com/drive/1VpeE6UvEPRz9HmsHh1KS0XxXjYu533EC?usp=sharing) +(Zachary Devito's pedagogical reimplementation of autograd, which the +autograd system in this notebook is based off of.) You can [open this +file in +Colab](https://colab.research.google.com/github/albanD/subclass_zoo/blob/main/simple_functorch.ipynb) +and play around with the examples. + +As a simplified implementation of functorch, this notebook also makes +it easier to investigate some more subtle aspects of how PyTorch's +native autograd system interacts with composable transforms. In +particular, we will see that PyTorch's native implementation of double +backwards (which shares the same tape through multiple levels of +differentiation) differs from functorch's nested grad implementation +(which maintains a separate tape per level). + +To get started, we replicate some of the data structures and helper functions +from Simple Autograd. +""" + +import contextlib +import functools +from dataclasses import dataclass +from typing import Callable, Dict, List, NamedTuple, Optional + +import torch +from torch import Tensor +from torch.utils._python_dispatch import TorchDispatchMode + + +class TapeEntry(NamedTuple): + # names of the inputs to the original computation + inputs: List[str] + # names of the outputs of the original computation + outputs: List[str] + # apply chain rule + propagate: Callable[[List[Tensor]], List[Tensor]] + + +_name = 0 + + +def fresh_name() -> str: + """create a new unique name for a variable: v0, v1, v2""" + global _name + r = f"v{_name}" + _name += 1 + return r + +"""This is a little helper function for converting the dim argument in +sum into an explicit list of dimensions that will be reduced over. +It takes the dim of the tensor we are summing over and the dim +argument itself. +""" + +def sum_dims(*, input_dim, dim): + if dim is None: + return tuple(range(0, input_dim)) + elif isinstance(dim, int): + return (dim,) + else: + return tuple(sorted(dim)) + +"""This is another little helper function that we might want to incorporate into the default behavior of restore. But for now we need this in order to not error when restoring a mode""" + +from torch._C import _get_torch_dispatch_mode + +def restore_dispatcher_or_nop(dispatcher): + # another argument that maybe we should let .restore() keep the current mode + if dispatcher == _get_torch_dispatch_mode(): + # we'll just no-op here since restoring the current mode will error + return contextlib.nullcontext + return dispatcher.restore + +"""In Simple Autograd, we provided a Variable wrapper class which +provided a traditional Tensor style interface for our objects; in +functorch proper, objects are repeatedly wrapped in this way to +implement multipler layers of transformations. + +In my opinion, this sort of wrapper makes it more difficult to +understand the flow of logic. So in Simple Functorch, we take a +different approach: we won't make use of a wrapper class at all, +instead showing how to add it in the end as syntax sugar on top of our +system. + +For debuggability purposes, however, it is nice to have a way to +identify variables by a human readable name. We'll do this by setting +a t_name attribute on PyTorch tensors whenever we allocate a new +tensor. +""" + +def label(t: Tensor, name: str = None): + if not hasattr(t, "t_name"): + t.t_name = name or fresh_name() + return t + +"""So if we aren't going to have a wrapper around each tensor, how will +we actually implement our logic? We will organize our various layers +of transformations as separate Dispatcher objects, which inherit from mode and define methods for performing operations on tensors, but are not Tensors +themselves. For example, instead of defining Tensor.add(Tensor), the mode will catch the add(Tensor, Tensor) call when it hits the Pytorch dispatcher. In order to avoid the same boilerplate in every dispatcher object, we define a parent object that catches all functions and redispatches it to the correct rule based on the child Dispatcher's implementation + +Notice that unlike with the original simple functorch, we don't have to set an +inner parameter. This logic is handeled by the underlying mode implementation +""" + +class Dispatcher(TorchDispatchMode): + def apply(self, func): + if func.__name__ == "add.Tensor": + return self.add + if func.__name__ == "mul.Tensor": + return self.mul + if func.__name__ in ["sum.default", "sum.dim_IntList"]: + return self.sum + if func.__name__ == "expand.default": + return self.expand + if func.__name__ == "unsqueeze.default": + return self.unsqueeze + if func.__name__ == "squeeze.dim": + return self.squeeze + if func.__name__ == "size": + return self.size + if func.__name__ == "ones.default": + return self.ones + else: + raise RuntimeError(f"Simple functorch doesn't support {func.__name__}") + + def mul(self, lhs, rhs): + raise NotImplementedError + + def add(self, lhs, rhs): + raise NotImplementedError + + # Sum has been generalized to take an optional dim argument, which we + # will need for Batched tensors + def sum(self, input, dim=None): + raise NotImplementedError + + def expand(self, input, sizes): + raise NotImplementedError + + # For closure under Batched tensors, we need these operations... + def unsqueeze(self, input, dim): + raise NotImplementedError + + def squeeze(self, input, dim): + raise NotImplementedError + + # ...and we also need to overload the meaning of size/ones to + # hide/reinsert batch dimensions. We also introduce a concept + # of "lifting" a tensor to be batched by broadcasting it on + # a dimension + def size(self, input): + raise NotImplementedError + + def ones(self, size, **kwargs): + raise NotImplementedError + + def lift(self, input, d): + raise NotImplementedError + + # For convenience, we provide dim, which just returns the length of + # the sizes + def dim(self, input): + return len(self.size(input)) + + def custom_vjp(self, fwd_fn, bwd_fn, *args): + # really gross but because we don't have the torch_dispatch for this, we + # need something to mimic what the mode does in torch_dispatch + old = torch._C._get_torch_dispatch_mode() + try: + torch._C._set_torch_dispatch_mode(None) # BUG: should be able to be done with enable + with self.inner.restore(): + return self.inner.custom_vjp(fwd_fn, bwd_fn, *args) + finally: + torch._C._set_torch_dispatch_mode(old) + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + kwargs = kwargs if kwargs else {} + return self.apply(func)(*args, **kwargs) + +"""To start with, we can implement a labeler layer, which just labels all inputs +and outputs. This will be necessary for autograd so it should be the bottom +most layer to everything. Specifically, we'll set it's inner to be None so that +if it's used as not the innermost layer, it will error +""" + +class Labeler(Dispatcher): + def mul(self, lhs, rhs): + return label(torch.mul(lhs, rhs)) + + def add(self, lhs, rhs): + return label(torch.add(lhs, rhs)) + + def sum(self, input, dim=None): + if dim is None: + return label(torch.sum(input)) + else: + return label(torch.sum(input, dim)) + + def expand(self, input, sizes): + return label(input.expand(sizes)) + + def unsqueeze(self, input, dim): + return label(torch.unsqueeze(input, dim)) + + def squeeze(self, input, dim): + return label(torch.squeeze(input, dim)) + + def size(self, input, **kwargs): + # Return size a tuple for marginally more compact printing + assert isinstance(input, torch.Tensor) + return input.size() + + def ones(self, size, **kwargs): + return label(torch.ones(size)) + + def lift(self, input, d): + assert self == d + return input + + def custom_vjp(self, fwd_fn, bwd_fn, *args): + # The backend layer for custom_vjp just calls fwd_fn. + # Why doesn't it create an autograd.Function? We're assuming the backend + # layer doesn't need to handle Autograd. + assert self.inner == None + a, b = fwd_fn(*args) + result = label(a), label(b) + return result + +"""Dispatcher layers are composable via object composition: we can +imagine a stack of dispatchers, each one calling into the next. +For example, the Logger dispatcher simply prints out what operation +was called on it, and then forwards on the operation to the inner +dispatcher. Unlike with simple functorch, we're able to rely on the modes to forward the call to the inner dispatcher by just calling the function again +""" + +def custom_vjp_str(r, fwd_fn, bwd_fn, args): + arg_names = ", ".join([a.t_name for a in args]) + r_is_tensor = isinstance(r, torch.Tensor) + if r_is_tensor: + result_names = r.t_name + else: + result_names = [r.t_name for r in r] + if len(result_names) == 1: + result_names = f"{result_names[0]}," + else: + result_names = ", ".join(result_names) + + print( + f"{result_names} = custom_vjp({fwd_fn.__name__}, {bwd_fn.__name__}, {arg_names})" + ) + +class Logger(Dispatcher): + def __init__(self, *, name): + self.name = f" {name}" + + def size(self, input): + # don't log size calls + return self.inner.size(input) + + def ones(self, size, **kwargs): + r = torch.ones(size) + print(f"{self.name} {r.t_name}: {self.size(r)} = ones({size})") + return r + + def mul(self, lhs, rhs): + r = lhs.mul(rhs) + if isinstance(rhs, float): + print(f"{self.name} {r.t_name}: {self.size(r)} = {lhs.t_name} * {rhs}") + else: + print( + f"{self.name} {r.t_name}: {self.size(r)} = {lhs.t_name} * {rhs.t_name}" + ) + return r + + def add(self, lhs, rhs): + r = lhs.add(rhs) + print(f"{self.name} {r.t_name}: {self.size(r)} = {lhs.t_name} + {rhs.t_name}") + return r + + def sum(self, input, dim=None): + if dim is None: + r = input.sum() + else: + r = input.sum(dim) + print(f"{self.name} {r.t_name}: {self.size(r)} = {input.t_name}.sum(dim={dim})") + return r + + def unsqueeze(self, input, dim): + r = input.unsqueeze(dim) + print( + f"{self.name} {r.t_name}: {self.size(r)} = {input.t_name}.unsqueeze({dim})" + ) + return r + + def squeeze(self, input, dim): + r = input.squeeze(dim) + print(f"{self.name} {r.t_name}: {self.size(r)} = {input.t_name}.squeeze({dim})") + return r + + def expand(self, input, sizes): + r = input.expand(sizes) + print( + f"{self.name} {r.t_name}: {self.size(r)} = {input.t_name}.expand({sizes})" + ) + return r + + def custom_vjp(self, fwd_fn, bwd_fn, *args): + # because custom_vjp is not an aten function, we have to explicitly send + # it to its inner + r = super().custom_vjp(fwd_fn, bwd_fn, *args) + print(custom_vjp_str(r, fwd_fn, bwd_fn, args)) + return r + + def lift(self, input, d): + if self == d: + return input + else: + return self.inner.lift(input, d) + +"""Here is a simple example of using Logger and Torch together. Whenever +we make calls to operations, we must do so via the Dispatcher object. +We will explicitly write out all of these calls before we add wrapper +class sugaring. +""" + +with Labeler(): + with Logger(name="Torch"): + z = torch.ones(2) + torch.ones(2) +print(z) +assert(isinstance(z, torch.Tensor)) + +"""With the Dispatcher structure in hand, we are now in a good place to +port the autograd implementation from Simple Autograd into our new +framework. +""" + +from torch.utils._mode_utils import no_dispatch + +class Autograd(Dispatcher): + # create_graph here corresponds to the create_graph kwarg in traditional + # PyTorch, which controls whether or not the graph of the derivative + # will be constructed, allowing computing higher order derivatives. + # We will see that although create_graph=True allows Autograd to directly + # support higher order derivatives, layering an Autograd to another + # Autograd will also allow higher order derivatives. + def __init__(self, *, name="Autograd", create_graph: bool = False): + self.gradient_tape = [] + self.name = name + self.create_graph = create_graph + + # create_graph controls where add/mul/etc calls from the backwards + # propagators go: if you create_graph, they we're going to have you + # the current Autograd dispatcher; otherwise they're going to + # move on to the inner layer. This restores the right mode to reset (and + # the proper context manager to use) + def backward_inner(self): + if self.create_graph: + mode = self + else: + mode = self.inner + return restore_dispatcher_or_nop(mode) + + def mul(self, lhs, rhs): + if isinstance(rhs, float) and rhs == 1.0: + # peephole optimization + return lhs + + # define forward + # first, run the operation in the inner layer to get the initial + # result + r = lhs.mul(rhs) + # We directly implement printing here as it indicates whether or not + # this operation was saved to the tape or not + if isinstance(rhs, float): + print(f"{self.name} {r.t_name}: {self.size(r)} = {lhs.t_name} * {rhs}") + else: + print(f"{self.name} {r.t_name}: {self.size(r)} = {lhs.t_name} * {rhs.t_name}") + + # record what the inputs and outputs of the op were + inputs = [lhs.t_name, rhs] if isinstance(rhs, float) else [lhs.t_name, rhs.t_name] + outputs = [r.t_name] + + # define backprop + def propagate(dL_doutputs: List[Tensor]): + (dL_dr,) = dL_doutputs + + dr_dlhs = rhs # partial derivative of r = lhs*rhs + dr_drhs = lhs # partial derivative of r = lhs*rhs + + # chain rule propagation from outputs to inputs of multiply. + # Notice that the propagation rule may itself call + # other operations; depending on create_graph, they may + # either be dispatched with self or self.inner; self.backward_inner() + # controls which one we go to. + with self.backward_inner()(): + dL_dlhs = dL_dr.mul(dr_dlhs) + dL_drhs = dL_dr.mul(dr_drhs) + dL_dinputs = [dL_dlhs, dL_drhs] + return dL_dinputs + + # finally, we record the compute we did on the tape + self.gradient_tape.append( + TapeEntry(inputs=inputs, outputs=outputs, propagate=propagate) + ) + return r + + # The rest of the implementations follow in the same way and can + # be skipped + + def add(self, lhs, rhs): + # Add follows a similar pattern to Mul, but it doesn't end up + # capturing any variables. + r = lhs.add(rhs) + print(f"{self.name} {r.t_name}: {self.size(r)} = {lhs.t_name} + {rhs.t_name}") + + def propagate(dL_doutputs: List[Tensor]): + (dL_dr,) = dL_doutputs + dr_dlhs = 1.0 + dr_drhs = 1.0 + with self.backward_inner()(): + dL_dlhs = dL_dr.mul(dr_dlhs) + dL_drhs = dL_dr.mul(dr_drhs) + return [dL_dlhs, dL_drhs] + + self.gradient_tape.append( + TapeEntry( + inputs=[lhs.t_name, rhs.t_name], outputs=[r.t_name], propagate=propagate + ) + ) + return r + + # Extended to handle dim argument for Batched (later) + def sum(self, input: Tensor, dim=None): + if dim is None: + r = input.sum() + else: + r = input.sum(dim) + print(f"{self.name} {r.t_name}: {self.size(r)} = {input.t_name}.sum(dim={dim})") + + def propagate(dL_doutputs: List[Tensor]): + (dL_dr,) = dL_doutputs + size = self.size(input) + res = dL_dr + # Broadcast over all dimensions that were reduced over + input_dim = self.inner.dim(input) # this needs to be done in inner + print(input_dim) + print(dim) + with self.backward_inner()(): + for i in sum_dims(input_dim=input_dim, dim=dim): + res = res.unsqueeze(i) + out = res.expand(size) + return [out] + + self.gradient_tape.append( + TapeEntry(inputs=[input.t_name], outputs=[r.t_name], propagate=propagate) + ) + return r + + # Unlike Simple Autograd, this expand requires the input to have + # been unsqueezed before hand. This lets us avoid having to do + # at::sum_to for the nontrivial case (which is more complicated) + def expand(self, input: Tensor, sizes: List[int]): + print(self.inner.dim(input)) + print(len(sizes)) + assert self.inner.dim(input) == len(sizes) # only works if dims match + r = input.expand(sizes) + print( + f"{self.name} {r.t_name}: {self.size(r)} = {input.t_name}.expand({sizes})" + ) + + def propagate(dL_doutputs: List[Tensor]): + (dL_dr,) = dL_doutputs + input_size = input.size() + dims = tuple( + i for i in range(input.dim()) if input_size[i] != sizes[i] + ) + # We wanted a sum keepdim=True, but I didn't want to force + # everyone to support it so manually unsqueeze + with self.backward_inner()(): + res = dL_dr.sum(dims) + for d in dims: + res = res.unsqueeze(d) + return [res] + + self.gradient_tape.append( + TapeEntry(inputs=[input.t_name], outputs=[r.t_name], propagate=propagate) + ) + return r + + def squeeze(self, input: Tensor, dim): + r = input.squeeze(dim) + print( + f"{self.name} {r.t_name}: {self.size(r)} = {input.t_name}.squeeze(dim={dim})" + ) + + def propagate(dL_doutputs: List[Tensor]): + (dL_dr,) = dL_outputs + with self.backward_inner()(): + res = dL_dr.unsqueeze(dim) + return [res] + + self.gradient_tape.append( + TapeEntry(inputs=[input.t_name], outputs=[r.t_name], propagate=propagate) + ) + return r + + def unsqueeze(self, input: Tensor, dim): + r = input.unsqueeze(dim) + print( + f"{self.name} {r.t_name}: {self.size(r)} = {input.t_name}.unsqueeze(dim={dim})" + ) + + def propagate(dL_doutputs: List[Tensor]): + (dL_dr,) = dL_doutputs + with self.backward_inner()(): + out = dL_dr.squeeze(dim) + return [out] + + self.gradient_tape.append( + TapeEntry(inputs=[input.t_name], outputs=[r.t_name], propagate=propagate) + ) + return r + + def ones(self, size, **kwargs): + return torch.ones(size) + + def custom_vjp(self, fwd_fn, bwd_fn, *args): + # To support Autograd(Autograd(Torch()), custom_vjp MUST call custom_vjp + # on the inner dispatcher. If it instead called fwd_fn(*args), then + # the inner Autograd dispatcher would not use bwd_fn in its backward pass. + + r, saved = super().custom_vjp(fwd_fn, bwd_fn, *args) + print(custom_vjp_str(r, fwd_fn, bwd_fn, args)) + + # To preserve custom backward semantics, we create a lambda that calls + # bwd_fn. This lambda is then saved on the gradient tape. + def propagate(dL_doutputs: List[Tensor]): + with self.backward_inner()(): + return bwd_fn(dL_doutputs, saved) + + self.gradient_tape.append( + TapeEntry( + inputs=[arg.t_name for arg in args], outputs=[r.t_name], propagate=propagate + ) + ) + return r, saved + + def size(self, input): + return self.inner.size(input) + + def lift(self, input, d): + if self == d: + return input + else: + return self.inner.lift(input, d) + + def grad(self, L, desired_results: List[Tensor]) -> List[Tensor]: + # this map holds dL/dX for all values X + dL_d: Dict[str, Tensor] = {} + # It starts by initializing the 'seed' dL/dL, which is 1 + # TODO: indirect this via the backend + with restore_dispatcher_or_nop(self.inner)(): + dL_d[L.t_name] = torch.ones(self.inner.size(L)) + print(f"-- {self.name} d{L.t_name} -------") + + # look up dL_dentries. If a variable is never used to compute the loss, + # we consider its gradient None, see the note below about zeros for more information. + def gather_grad(entries: List[str]): + return [dL_d[entry] if entry in dL_d else None for entry in entries] + + # propagate the gradient information backward + for entry in reversed(self.gradient_tape): + dL_doutputs = gather_grad(entry.outputs) + if all(dL_doutput is None for dL_doutput in dL_doutputs): + # optimize for the case where some gradient pathways are zero. See + # The note below for more details. + continue + + # perform chain rule propagation specific to each compute + dL_dinputs = entry.propagate(dL_doutputs) + + # Accululate the gradient produced for each input. + # Each use of a variable produces some gradient dL_dinput for that + # use. The multivariate chain rule tells us it is safe to sum + # all the contributions together. + for input, dL_dinput in zip(entry.inputs, dL_dinputs): + if input not in dL_d: + dL_d[input] = dL_dinput + else: + with self.backward_inner()(): + dL_d[input] = dL_d[input].add(dL_dinput) + + # print some information to understand the values of each intermediate + # for name, value in dL_d.items(): + # print(f'{self.name} d{L.t_name}_d{name} = {value.t_name}') + print(f"------------------------") + + return gather_grad(desired.t_name for desired in desired_results) + +"""To calculate some simple gradients, we can compose Autograd with +Torch and get the result we expect. +""" + +torch.manual_seed(0) +a, b = label(torch.rand(4)), label(torch.rand(4)) + +def simple(a, b): + t = a + b + return t.mul(b) + + +grad_dispatcher = Autograd() + +with Labeler(): + with grad_dispatcher: + loss = simple(a, b) + +print("a", a) +print("b", b) +da, db = grad_dispatcher.grad(loss, [a, b]) +print("da", da) +print("db", db) + +"""To compute higher order gradients, we have two options. First, +we can do traditional PyTorch style higher order differentiation +with `create_graph=True`, writing the backpropagation computations directly +into the tape so they can be further differentiated over. This is also +what the original Simple Autograd implementation does. +""" + +def run_gradients(d1, d2): + with Labeler(): + with d1: + with d2 if d1 != d2 else contextlib.nullcontext(): + # our first loss + L0 = simple(a, b).sum() + + # compute derivatives of our inputs + dL0_da, dL0_db = d2.grad(L0, [a, b]) + print(dL0_da) + + # In real code, how would we switch from executing from d2 to d1? + # In functorch, the d2 dispatch calls would happen in the inside of + # a higher-order grad() or vjp() call; when we exit from this call, all + # of the involved tensors are unwrapped. + + # now lets compute the L2 norm of our derivatives + with d1.restore(): + L1 = torch.sum(torch.add(dL0_da.mul(dL0_da), dL0_db.mul(dL0_db))) + + # and take the gradient of that. + # notice there are two losses involved1. + return d1.grad(L1, [a, b]) + +grad_dispatcher = Autograd(create_graph=True) +da, db = run_gradients(grad_dispatcher, grad_dispatcher) + +print("da", da) +print("db", db) + +"""Our second option is to follow functorch's implementation strategy, which +is to stack two Autograd dispatchers on top of each other. Here, it is +not necessary to `create_graph=True`, because when the backpropagator forwards +to the inner dispatcher, it will record those operations on the tape too. +But if you look at the output, you will notice something very interesting: +the first portion of the tape is exactly replicated between Autograd1 and +Autograd2: we're duplicating the tape in this case! So PyTorch's default +implementation of backwards is more efficient, because it avoids having to +record the tape twice (although this doesn't matter too much, because the +saved tensors themselves can be shared between the two tapes, so it is just +the operator graph that is duplicated). + +This is our first example of using two dispatchers. While we are +performing the inner grad, we perform our operations on the outer +dispatcher `d2`; after we are done with the inner grad we switch to +`d1` by restoring `d1` without `d2`. Intuitively, this corresponds from +passing out of the inner `grad` call to the outer `grad` call. +""" + +d1 = Autograd(create_graph=False) +d2 = Autograd(create_graph=False) + +da, db = run_gradients(d2, d1) +print("da", da) +print("db", db) + +"""Under what situations might it be profitable to keep the two tapes separate? +One guess we might have is if there is another functional transformation +wedged between the two autograd transformations. We would then expect the +backwards formula we save to be different between the two tapes. To do this, I +first need to implement batched tensors. + +One unusual thing about this implementation is that we do not need to wrap +tensors to change their sizes; instead, we just override the meaning of +size() on the dispatcher to hide batch dimensions. These calls are not +sent to the Pytorch dispatcher, so we need to explicitly call Dispatcher.size. + +One case we do not +exercise in this example is implicit broadcasting when you combine a tensor +that is not batched with a tensor that is batched: without wrappers, a user +must explicitly lift (e.g., unsqueeze and expand) tensors they wish to +replicate across the batch dimension. The code below will blindly attempt to +reinterpret a tensor as a batched tensor, even when it may not make sense (if +there is a size mismatch, however, you will get an assert failure). Similarly, +once you exit a vmap region, all previously vmap'ed tensors "magically" become +unbatched. functorch did not pursue this implementation because at the time +Tensor.size() was not virtual and thus it was not possible to override (this +will be changing soon). +""" + +# This implementation of Batched only supports inserting a dimension +# at the very front +class Batched(Dispatcher): + def __init__(self, *, length, name="Batched"): + self.name = name + self.length = length + + def _pad_to_size(self, lhs, rhs): + lhs_size, rhs_size = self.inner.size(lhs), self.inner.size(rhs) + lhs_dim, rhs_dim = len(lhs_size), len(rhs_size) + if lhs_dim == rhs_dim: + return (lhs, rhs) + diff = rhs_dim - lhs_dim + assert diff != 0 # sanity check + new_final = rhs if diff < 0 else lhs + for _ in range(abs(diff)): # could be done as a reshape if we added that + new_final = self.unsqueeze(new_final, 0) + return (lhs, new_final) if diff < 0 else (new_final, rhs) + + def size(self, input): + sizes = self.inner.size(input) + print(sizes) + assert sizes[0] == self.length + return sizes[1:] + + def ones(self, size, **kwargs): + return torch.ones([self.length,] + size) + + def mul(self, lhs, rhs): + assert self.inner.size(lhs)[0] == self.length + if not isinstance(rhs, float): + assert self.inner.size(rhs)[0] == self.length + lhs, rhs = self._pad_to_size(lhs, rhs) + return self.inner.mul(lhs, rhs) + + def add(self, lhs, rhs): + assert self.inner.size(lhs)[0] == self.length + assert self.inner.size(rhs)[0] == self.length + lhs, rhs = self._pad_to_size(lhs, rhs) + return torch.add(lhs, rhs) + + def sum(self, input, dim=None): + # offset all the summed over dimensions by one + assert self.inner.size(input)[0] == self.length + dim = tuple( + i + 1 for i in sum_dims(input_dim=self.inner.dim(input) - 1, dim=dim) + ) + return torch.sum(input, dim) + + def expand(self, input, sizes): + # offset sizes by one + assert self.inner.size(input)[0] == self.length + new_sizes = [self.inner.size(input)[0]] + sizes + return input.expand(new_sizes) + + def squeeze(self, input, dim): + # offset dim by one + assert self.inner.size(input)[0] == self.length + return torch.squeeze(input, dim + 1) + + def unsqueeze(self, input, dim): + # offset dim by one + assert self.inner.size(input)[0] == self.length + return torch.unsqueeze(input, dim + 1) + + def custom_vjp(self, fwd_fn, bwd_fn, *args): + def batchify(fn): + def new_fn(*args): + with Batched(length=self.length, name='GeneratedBatched'): + return fn(*args) + return new_fn + + # If we have Batched(Autograd(Torch()), then we would like the inner + # dispatcher to receive a call to custom_vjp so that it preserves the + # backward semantics. However, since this is the Batched dispatcher, + # we want the innermost Torch dispatcher to run a batched version of fwd_fn + # function! The way we get this to work is to create a new fwd_fn, that, + # when executed, executes a batched version of fwd_fn. + # + # Same thing for the bwd_fn. + # NB: currently simple_functorch assumes that all Tensors are batched at + # dimension 0. I'm not sure how this logic would look like without + # this assumption (in functorch tensors may not be batched). + r, saved = super().custom_vjp(batchify(fwd_fn), batchify(bwd_fn), *args) + return r, saved + + # The lift operation takes a tensor associated with some inner + # dispatcher, and "lifts" it so that it is interpreted neutrally + # for the outer dispatcher. For most dispatchers this is trivial, + # but for batched tensor it is not: given a tensor x, to interpret + # it as x under the Batching dispatcher, we have to expand it so + # that it is broadcasted along its first dimension. + def lift(self, input, d): + if d is self: + return input + b_input = torch.unsqueeze(input, 0) + b_input = b_input.expand((self.length,) + self.inner.size(input)) + return self.inner.lift(b_input, d) + +# Our inputs are batched this time! +va, vb = label(torch.rand(2, 4)), label(torch.rand(2, 4)) + +def run_batched_gradients(): + # our first loss + # we write the dimension we reduce on explicitly for clarity + d1 = Autograd(create_graph=False) + d3 = Autograd(create_graph=False) + with Labeler(): + with d1: + with Batched(length=2): + with d3: + L0 = torch.sum(simple(va, vb), dim=0) + + # compute derivatives of our inputs + dL0_da, dL0_db = d3.grad(L0, [va, vb]) + + # now lets compute the L2 norm of our derivatives + with d1.restore(): + L1 = torch.sum(torch.add(dL0_da.mul(dL0_da), dL0_db.mul(dL0_db))) + + # and take the gradient of that. + # notice there are two losses involved1. + dL1_da, dL1_db = d1.grad(L1, [va, vb]) + return dL1_da, dL1_db + + +dva, dvb = run_batched_gradients() +print("va", va) +print("vb", vb) +print("dva", dva) +print("dvb", dvb) + +"""To see that we have done this correctly, we could run the corresponding JAX: + +``` +from jax import grad, vmap +import jax.numpy as np + +def simple(a, b): + t = a + b + return t * b + +def L0(a, b): + return np.sum(simple(a, b)) + +def L1(a, b): + dL0_da, dL0_db = vmap(grad(L0, argnums=(0,1)), in_axes=0)(a, b) + return (dL0_da * dL0_da + dL0_db * dL0_db).sum() + +va = np.asarray([[0.4556, 0.6323, 0.3489, 0.4017], + [0.0223, 0.1689, 0.2939, 0.5185]]) +vb = np.asarray([[0.6977, 0.8000, 0.1610, 0.2823], + [0.6816, 0.9152, 0.3971, 0.8742]]) +dva, dvb = grad(L1, argnums=(0,1))(va, vb) +print("dva", dva) +print("dvb", dvb) +``` + +Looking over the output, the tapes look similar, but we can see that the sizes +and the arguments of the operations in question differ (after all, Autograd3 is +on the inside of the vmap, while Autograd1 is outside). But it is still very +similar: we could imagine simply varying the dispatcher we use to process backwards +depending on when we are executing the tape. In fact, this is exactly what an +initial, non-functorch implementation of PyTorch did to support per-sample +gradients. + +Exercise: modify Autograd.grad to accept a dispatcher, and use that dispatcher +instead of self.backward_inner() when running propagator functions. Then, rewrite +the above example so that it only has one level of Autograd: +Batched(Autograd(Torch(), create_graph=True)) and show you still get the same +result. + +OK, so all of this dispatcher business is all nice and explicit, but +that's not what JAX/functorch's interface looks like. How do we +bridge the gap? Unlike with simple functorch 1.0, we don't have to +set the global mode since we're using the context managers and modes +to set that +""" + +# now unnecessary :) +# DISPATCHER = Labeler() + + +# @contextlib.contextmanager +# def dispatcher(d): +# global DISPATCHER +# old_d = DISPATCHER +# DISPATCHER = d +# try: +# yield +# finally: +# DISPATCHER = old_d + +"""A dispatcher mode, however, is not enough. Remember that in our +implementation of Batched, we blindly assumed that all tensors were +batched, even if this did not necessarily make sense. If I have +`vmap(lambda bx: bx + y)(x)`, with `x: (B,X)` and `y: (X,)`, the +underlying operation should broadcast y to `(B,X)` and then do the +addition with x (bx advertises that it has size `(X,)` inside of the +vmap'd lambda). To know this should happen, it is necessary for +us to know that y is not a batched tensor, but x is a batched tensor. +We'll resolve this with a wrapper class called FuncTensor, which +records both the underlying Tensor, as well as the Dispatcher which +this tensor is associated with. In the above example, `bx.dispatcher` +might be `Batched(Torch())`, whereas `x.dispatcher` is `Torch()`. + +So our general strategy is as follows: + 1. Every tensor is associated with a dispatcher + 2. You can lift tensors to dispatchers which wrap them (which can + trigger some operations, like expand for Batched); this is + implemented by `dispatcher_wraps` + 3. To perform an operation between to tensors, lift them so that + they all have the same dispatcher, then do the operation on + that dispatcher. +""" + +# A dispatcher d1 wraps another dispatcher d2 if d2 is an ancestor of +# d1 in the tree structure. We've defined this relation to be +# reflexive, in the same way issubclass(A, A) == True. +def dispatcher_wraps(d1, d2): + # Treat this as a reflexive relation + if d1 is d2 or (d1 is not None and d2 in d1.ancestors) or d2 is None: + return True + return False + + +# Given a list of arguments, lift them all up to a common dispatcher +# level, returning that dispatcher as well as the lifted arguments. +# Note that the current dispatcher is also accounted for by getting the current +# mode! In autodidax, this is `find_top_trace`. +def lift_and_unwrap_args(*args): + outermost = _get_torch_dispatch_mode() + for a in args: + if dispatcher_wraps(outermost, a.dispatcher): + pass + elif dispatcher_wraps(a.dispatcher, outermost): + # You can make this case an error as well if you don't + # want to support non-lexical functorch tensors + outermost = a.dispatcher + else: + raise TypeError("incompatible dispatcher trees") + return (outermost,) + tuple(a.lift(outermost).tensor for a in args) + +"""The actual implementation of the wrapper tensor which tracks the +Dispatcher for a tensor +""" + +@dataclass +class FuncTensor: + tensor: Tensor + dispatcher: Dispatcher + + # Lift a FuncTensor to an outer dispatcher + def lift(self, d): + # You can only lift to a dispatcher which wraps the dispatcher + # this FuncTensor is associated with (not vice versa, or between + # unrelated FuncTensors). + assert dispatcher_wraps(d, self.dispatcher) + return FuncTensor(d.lift(self.tensor, self.dispatcher), d) + + # The general strategy for any operation performed on a tensor, we + # lift all the arguments so that they live on the same dispatcher + # level, and then perform the operation on that dispatcher. The + # resulting tensor is tagged at whatever dispatcher we had run the + # tensor on. + def __mul__(self, other): + d, self, other = lift_and_unwrap_args(self, other) + with restore_dispatcher_or_nop(d)(): + return FuncTensor(self.mul(other), d) + + def __add__(self, other): + d, self, other = lift_and_unwrap_args(self, other) + with restore_dispatcher_or_nop(d)(): + return FuncTensor(self.add(other), d) + + def sum(self, dim=None): + d, self = lift_and_unwrap_args(self) + with restore_dispatcher_or_nop(d)(): + if dim is None: + res = self.sum() + else: + res = self.sum(dim) + return FuncTensor(res, d) + + def expand(self, sizes): + d, self = lift_and_unwrap_args(self) + with restore_dispatcher_or_nop(d)(): + return FuncTensor(self.expand(sizes), d) + + def unsqueeze(self, dim): + d, self = lift_and_unwrap_args(self) + with restore_dispatcher_or_nop(d)(): + return FuncTensor(self.unsqueeze(dim), d) + + def squeeze(self, dim): + d, self = lift_and_unwrap_args(self) + with restore_dispatcher_or_nop(d)(): + return FuncTensor(self.squeeze(dim), d) + + def size(self): + d, self = lift_and_unwrap_args(self) + return d.size(self) + + def dim(self): + d, self = lift_and_unwrap_args(self) + return d.size(self) + + # Factory functions like ones do not have any Tensor arguments, + # so they rely solely on the current mode + @staticmethod + def ones(size): + return torch.ones(size) + +"""Now we are ready to implement grad. First, we need some helper +functions. +""" + +# When we are done doing a vmap/grad, we need to take the results and +# lower them back to a lower dispatcher on the stack (this is always +# a no-op, in particular, in the vmap case, when we exit vmap the user +# gets to see the batched dimension again.) +def unlift(t, d): + if isinstance(t, list): + return [unlift(x, d) for x in t] + elif isinstance(t, tuple): + return tuple(unlift(x, d) for x in t) + else: + if t.dispatcher is d: + return t + return unlift(FuncTensor(t.tensor, t.dispatcher.inner), d) + + +# This lets us easily pick out arguments as specified by argnums +def filter_argnums(args, argnums): + if isinstance(argnums, int): + return (args[argnums],) + else: + return tuple(args[i] for i in argnums) + +"""Now grad and vmap!""" + +# For simplicity, these functions only take tuples, not pytrees +def grad(f, argnums=0): + @functools.wraps(f) + def wrapped_f(*args): + # We first lift and unwrap all of the arguments which we want + # to pass into the function + old_d, *args = lift_and_unwrap_args(*args) + assert old_d == _get_torch_dispatch_mode() + d = Autograd() + with d: + # We pass in the functions at the new Autograd level (they + # were lifted to old_d, and lifting to d is a noop) + L = f(*(FuncTensor(a, d) for a in args)) + assert L.dispatcher is d + # Run the autograd pass, getting the grads for the inputs + # as specified by argnums + grads = d.grad(L.tensor, filter_argnums(args, argnums)) + # Finally, construct the grads at the lower level and return + # them + return [FuncTensor(r, old_d) for r in grads] + + return wrapped_f + +def vmap(f): + @functools.wraps(f) + def wrapped_f(*args): + # cannot vmap over no arguments as this function uses the + # arguments to determine how large the batch dimension is + # (hypothetically, you could explicitly pass in the batch + # size, and then use this to control factory functions; + # JAX doesn't seem to have a knob to do this) + assert args + old_d, *args = lift_and_unwrap_args(*args) + d = Batched(length=args[0].size()[0]) + for a in args: + assert a.size()[0] == d.length + with d: + # Rewrap all the arguments as batched tensors, then + # unwrap any batched tensors that escape + return unlift(f(*(FuncTensor(a, d) for a in args)), old_d) + + return wrapped_f + +"""Now we can rerun our example using the high level grad/vmap functions!""" + +def simple(a, b): + t = a + b + return t * b + + +def L0(a, b): + return simple(a, b).sum() + + +def L1(a, b): + dL0_da, dL0_db = vmap(grad(L0, argnums=(0, 1)))(a, b) + return (dL0_da * dL0_da + dL0_db * dL0_db).sum() + +l = Labeler() +print(va) +print(vb) +with l: + fva = FuncTensor(va, l) + fvb = FuncTensor(vb, l) + dva, dvb = grad(L1, argnums=(0, 1))(fva, fvb) +print("dva", dva) +print("dvb", dvb) + +"""Because FuncTensors are associated with the ambient dispatcher they +were created from, they are also allowed to escape from the context in +which they were defined, allowing for non-lexical, imperative +transform API. For example, batching over module parameters is +problematic today, but all we need to do is tweak the FuncTensor's +dispatchers appropriately and everything works out. +""" + +B = 2 + +# this is a bug, we should be able to set inner in the constructor and have that set the ancestors correctly +base_dispatcher = Labeler() +batched_dispatcher = Batched(length=B) +with base_dispatcher: + with batched_dispatcher: + pass + +PlainTensor = lambda t: FuncTensor(torch.randn(N), base_dispatcher) +BatchedTensor = lambda t: FuncTensor(t, batched_dispatcher) + +class ScaleBiasModule: + weight: FuncTensor + bias: FuncTensor + + def __init__(self, N): + self.weight = PlainTensor(torch.randn(N)) + self.bias = PlainTensor(torch.randn(N)) + + def forward(self, input): + return self.weight * input + self.bias + + +B = 2 +N = 3 +m = ScaleBiasModule(N) +# Ensemble weights only; input is not batched +m.weight = BatchedTensor(torch.randn(B, N)) +input = PlainTensor(torch.randn(N)) +output = m.forward(input) +print( + "expect", input.tensor.unsqueeze(0) * m.weight.tensor + m.bias.tensor.unsqueeze(0) +) +print("output", output.tensor) + +"""Higher-order operations in simple functorch! + +Problem: users want to define functions with custom forward and backward +passes. These functions call PyTorch operations. When we vmap over such a +function, we would like for the backward pass to be preserved. + +Why is this difficult? In PyTorch today, vmap over an autograd.Function +effectively runs vmap on the forward pass of the autograd.Function. +Meanwhile, autograd records the transformed operations for backward, instead +of the custom backward pass we specified in the autograd.Function! + +Solution: We're going to introduce a `custom_vjp` primitive that accepts +functions and varargs Tensor arguments and demonstrate that it resolves +the problem. + +custom_vjp(fwd_fn, bwd_fn, *args) takes in two functions as arguments. +We add a little helper function so that the user is not explicitly calling +this function on the active dispatcher +""" + +def custom_vjp(fwd_fn, bwd_fn, *args): + d = _get_torch_dispatch_mode() + return d.custom_vjp(fwd_fn, bwd_fn, *args) + +"""For our custom function, we want f(x) = x * x, but we install a custom +backwards pass that computes 32 * x (instead of 2 * x) so we can tell +if custom_vjp is working. +""" + +a = label(torch.rand(4)) +va = label(torch.rand(2, 4)) + +def f_fwd(x): + # Our convention is that f_fwd returns (outputs, "saved") + return x.mul(x), x + +# Our convention is that f_bwd accepts (dispatcher, gradOutputs, "saved") +def f_bwd(gradOutputs, x): + gO, = gradOutputs + # Should be gO * 2 * x, but we're gonna do gO * 32 * x to demonstrate things + with no_dispatch(): + thirty_two = torch.tensor(32.) # a hack so I don't have to override lift + + return [torch.mul(gO.mul(x), label(thirty_two, 'thirty_two'))] + +def run_grad(): + grad_dispatcher = Autograd() + with Labeler(): + with grad_dispatcher: + # Here's how to invoke custom_vjp. + r, _ = custom_vjp(f_fwd, f_bwd, a) + L3 = torch.sum(r) + print(grad_dispatcher.gradient_tape) + dL3_a, = grad_dispatcher.grad(L3, [a]) + + print(dL3_a) + print(a) + # Check that the gradients are indeed 32 * a + assert torch.allclose(dL3_a, 32 * a) + +run_grad() + +# Multiple layers of autograd +def run_gradgrad(): + grad_dispatcher_1 = Autograd() + grad_dispatcher_2 = Autograd() + with Labeler(): + with Logger(name="Torch"): + with grad_dispatcher_1: + with grad_dispatcher_2: + r, _ = custom_vjp(f_fwd, f_bwd, a) + L4 = r.sum() + dL4_a, = grad_dispatcher_2.grad(L4, [a]) + + # Evidence that d2 respected the custom_vjp's f_bwd + assert torch.allclose(dL4_a, 32 * a) + + assert hasattr(dL4_a, 't_name') + with grad_dispatcher_1.restore(): + dL4_a_sum = dL4_a.sum() + ddL4_a_a, = grad_dispatcher_1.grad(dL4_a_sum, [a]) + + # Evidence that d1 respected the custom_vjp's f_bwd + assert torch.allclose(ddL4_a_a, torch.ones_like(a) * 32) + +run_gradgrad() + +"""And now, let's try that again, with grad(lambda x: vmap(f)(x).sum()). +The goal of custom_vjp is to make it so that vmap(custom_vjp) still +preserves the backward semantics. +""" + +def f_fwd(x): + return x.mul(x), x + +def f_bwd(gradOutputs, x): + gO, = gradOutputs + # Should be gO * 2 * x, but we're gonna do gO * 32 * x to prove a point + return [torch.mul(gO.mul(x), torch.ones(()) * 32.)] + +def run_gradvmap(): + grad_dispatcher = Autograd() + with Labeler(): + with grad_dispatcher: + with Batched(length=2): + r, _ = custom_vjp(f_fwd, f_bwd, va) + L99 = r.sum() + dL99_a, = grad_dispatcher.grad(L99, [va]) + + # As you can see, d1.grad still calls f_bwd. + # The way we got this to work is that Batched.custom_vjp + # calls custom_vjp on its inner dispatcher. + # Scroll up to the implementation of Batched for more details. + assert torch.allclose(dL99_a, 32 * va) + +run_gradvmap() \ No newline at end of file