Skip to content

Commit

Permalink
Merge branch 'main' into turboooo
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Aug 9, 2023
2 parents a0680da + 641fdd9 commit 69e4261
Show file tree
Hide file tree
Showing 40 changed files with 1,330 additions and 1,129 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
used within the autoclass directive.
"""

if obj.__name__.endswith(("_Weights", "_QuantizedWeights")):
if getattr(obj, ".__name__", "").endswith(("_Weights", "_QuantizedWeights")):

if len(obj) == 0:
lines[:] = ["There are no available pre-trained weights."]
Expand Down
1 change: 1 addition & 0 deletions docs/source/datapoints.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ see e.g. :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`.
BoundingBoxFormat
BoundingBoxes
Mask
Datapoint
12 changes: 12 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ Color

ColorJitter
v2.ColorJitter
v2.RandomChannelPermutation
v2.RandomPhotometricDistort
Grayscale
v2.Grayscale
Expand Down Expand Up @@ -375,3 +376,14 @@ you can use a functional transform to build transform classes with custom behavi
to_pil_image
to_tensor
vflip

Developer tools
---------------

.. currentmodule:: torchvision.transforms.v2.functional

.. autosummary::
:toctree: generated/
:template: function.rst

register_kernel
125 changes: 125 additions & 0 deletions gallery/plot_custom_datapoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""
=====================================
How to write your own Datapoint class
=====================================
This guide is intended for downstream library maintainers. We explain how to
write your own datapoint class, and how to make it compatible with the built-in
Torchvision v2 transforms. Before continuing, make sure you have read
:ref:`sphx_glr_auto_examples_plot_datapoints.py`.
"""

# %%
import torch
import torchvision

# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()

from torchvision import datapoints
from torchvision.transforms import v2

# %%
# We will create a very simple class that just inherits from the base
# :class:`~torchvision.datapoints.Datapoint` class. It will be enough to cover
# what you need to know to implement your more elaborate uses-cases. If you need
# to create a class that carries meta-data, take a look at how the
# :class:`~torchvision.datapoints.BoundingBoxes` class is `implemented
# <https://github.com/pytorch/vision/blob/main/torchvision/datapoints/_bounding_box.py>`_.


class MyDatapoint(datapoints.Datapoint):
pass


my_dp = MyDatapoint([1, 2, 3])
my_dp

# %%
# Now that we have defined our custom Datapoint class, we want it to be
# compatible with the built-in torchvision transforms, and the functional API.
# For that, we need to implement a kernel which performs the core of the
# transformation, and then "hook" it to the functional that we want to support
# via :func:`~torchvision.transforms.v2.functional.register_kernel`.
#
# We illustrate this process below: we create a kernel for the "horizontal flip"
# operation of our MyDatapoint class, and register it to the functional API.

from torchvision.transforms.v2 import functional as F


@F.register_kernel(dispatcher="hflip", datapoint_cls=MyDatapoint)
def hflip_my_datapoint(my_dp, *args, **kwargs):
print("Flipping!")
out = my_dp.flip(-1)
return MyDatapoint.wrap_like(my_dp, out)


# %%
# To understand why ``wrap_like`` is used, see
# :ref:`datapoint_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now,
# we will explain it below in :ref:`param_forwarding`.
#
# .. note::
#
# In our call to ``register_kernel`` above we used a string
# ``dispatcher="hflip"`` to refer to the functional we want to hook into. We
# could also have used the functional *itself*, i.e.
# ``@register_kernel(dispatcher=F.hflip, ...)``.
#
# The functionals that you can be hooked into are the ones in
# ``torchvision.transforms.v2.functional`` and they are documented in
# :ref:`functional_transforms`.
#
# Now that we have registered our kernel, we can call the functional API on a
# ``MyDatapoint`` instance:

my_dp = MyDatapoint(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)

# %%
# And we can also use the
# :class:`~torchvision.transforms.v2.RandomHorizontalFlip` transform, since it relies on :func:`~torchvision.transforms.v2.functional.hflip` internally:
t = v2.RandomHorizontalFlip(p=1)
_ = t(my_dp)

# %%
# .. note::
#
# We cannot register a kernel for a transform class, we can only register a
# kernel for a **functional**. The reason we can't register a transform
# class is because one transform may internally rely on more than one
# functional, so in general we can't register a single kernel for a given
# class.
#
# .. _param_forwarding:
#
# Parameter forwarding, and ensuring future compatibility of your kernels
# -----------------------------------------------------------------------
#
# The functional API that you're hooking into is public and therefore
# **backward** compatible: we guarantee that the parameters of these functionals
# won't be removed or renamed without a proper deprecation cycle. However, we
# don't guarantee **forward** compatibility, and we may add new parameters in
# the future.
#
# Imagine that in a future version, Torchvision adds a new ``inplace`` parameter
# to its :func:`~torchvision.transforms.v2.functional.hflip` functional. If you
# already defined and registered your own kernel as

def hflip_my_datapoint(my_dp): # noqa
print("Flipping!")
out = my_dp.flip(-1)
return MyDatapoint.wrap_like(my_dp, out)


# %%
# then calling ``F.hflip(my_dp)`` will **fail**, because ``hflip`` will try to
# pass the new ``inplace`` parameter to your kernel, but your kernel doesn't
# accept it.
#
# For this reason, we recommend to always define your kernels with
# ``*args, **kwargs`` in their signature, as done above. This way, your kernel
# will be able to accept any new parameter that we may add in the future.
# (Technically, adding `**kwargs` only should be enough).
123 changes: 123 additions & 0 deletions gallery/plot_custom_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""
===================================
How to write your own v2 transforms
===================================
This guide explains how to write transforms that are compatible with the
torchvision transforms V2 API.
"""

# %%
import torch
import torchvision

# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
# some APIs may slightly change in the future
torchvision.disable_beta_transforms_warning()

from torchvision import datapoints
from torchvision.transforms import v2


# %%
# Just create a ``nn.Module`` and override the ``forward`` method
# ===============================================================
#
# In most cases, this is all you're going to need, as long as you already know
# the structure of the input that your transform will expect. For example if
# you're just doing image classification, your transform will typically accept a
# single image as input, or a ``(img, label)`` input. So you can just hard-code
# your ``forward`` method to accept just that, e.g.
#
# .. code:: python
#
# class MyCustomTransform(torch.nn.Module):
# def forward(self, img, label):
# # Do some transformations
# return new_img, new_label
#
# .. note::
#
# This means that if you have a custom transform that is already compatible
# with the V1 transforms (those in ``torchvision.transforms``), it will
# still work with the V2 transforms without any change!
#
# We will illustrate this more completely below with a typical detection case,
# where our samples are just images, bounding boxes and labels:

class MyCustomTransform(torch.nn.Module):
def forward(self, img, bboxes, label): # we assume inputs are always structured like this
print(
f"I'm transforming an image of shape {img.shape} "
f"with bboxes = {bboxes}\n{label = }"
)
# Do some transformations. Here, we're just passing though the input
return img, bboxes, label


transforms = v2.Compose([
MyCustomTransform(),
v2.RandomResizedCrop((224, 224), antialias=True),
v2.RandomHorizontalFlip(p=1),
v2.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
])

H, W = 256, 256
img = torch.rand(3, H, W)
bboxes = datapoints.BoundingBoxes(
torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]),
format="XYXY",
canvas_size=(H, W)
)
label = 3

out_img, out_bboxes, out_label = transforms(img, bboxes, label)
# %%
print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
# %%
# .. note::
# While working with datapoint classes in your code, make sure to
# familiarize yourself with this section:
# :ref:`datapoint_unwrapping_behaviour`
#
# Supporting arbitrary input structures
# =====================================
#
# In the section above, we have assumed that you already know the structure of
# your inputs and that you're OK with hard-coding this expected structure in
# your code. If you want your custom transforms to be as flexible as possible,
# this can be a bit limitting.
#
# A key feature of the builtin Torchvision V2 transforms is that they can accept
# arbitrary input structure and return the same structure as output (with
# transformed entries). For example, transforms can accept a single image, or a
# tuple of ``(img, label)``, or an arbitrary nested dictionary as input:

structured_input = {
"img": img,
"annotations": (bboxes, label),
"something_that_will_be_ignored": (1, "hello")
}
structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)

assert isinstance(structured_output, dict)
assert structured_output["something_that_will_be_ignored"] == (1, "hello")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")

# %%
# If you want to reproduce this behavior in your own transform, we invite you to
# look at our `code
# <https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/_transform.py>`_
# and adapt it to your needs.
#
# In brief, the core logic is to unpack the input into a flat list using `pytree
# <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
# then transform only the entries that can be transformed (the decision is made
# based on the **class** of the entries, as all datapoints are
# tensor-subclasses) plus some custom logic that is out of score here - check the
# code for details. The (potentially transformed) entries are then repacked and
# returned, in the same structure as the input.
#
# We do not provide public dev-facing tools to achieve that at this time, but if
# this is something that would be valuable to you, please let us know by opening
# an issue on our `GitHub repo <https://github.com/pytorch/vision/issues>`_.
Loading

0 comments on commit 69e4261

Please sign in to comment.