From 84db2ac4572dd23b67d93d08660426e44f97ba75 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 4 Aug 2023 17:13:15 +0100 Subject: [PATCH 1/8] Add tuto for custom transforms and custom datapoints in gallery example (#7795) Co-authored-by: Philip Meier --- docs/source/conf.py | 2 +- docs/source/datapoints.rst | 1 + docs/source/transforms.rst | 11 ++ gallery/plot_custom_datapoints.py | 125 ++++++++++++++++ gallery/plot_custom_transforms.py | 123 ++++++++++++++++ gallery/plot_datapoints.py | 137 ++++++++++++++---- torchvision/datapoints/_bounding_box.py | 2 +- torchvision/datapoints/_datapoint.py | 13 +- torchvision/datapoints/_image.py | 9 -- torchvision/datapoints/_mask.py | 12 -- torchvision/datapoints/_video.py | 9 -- torchvision/prototype/datapoints/_label.py | 2 +- .../transforms/v2/functional/_utils.py | 5 + 13 files changed, 385 insertions(+), 66 deletions(-) create mode 100644 gallery/plot_custom_datapoints.py create mode 100644 gallery/plot_custom_transforms.py diff --git a/docs/source/conf.py b/docs/source/conf.py index 7b3e9e8a7f3..fed3884ea27 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -320,7 +320,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines): used within the autoclass directive. """ - if obj.__name__.endswith(("_Weights", "_QuantizedWeights")): + if getattr(obj, ".__name__", "").endswith(("_Weights", "_QuantizedWeights")): if len(obj) == 0: lines[:] = ["There are no available pre-trained weights."] diff --git a/docs/source/datapoints.rst b/docs/source/datapoints.rst index 55d3cda4a8c..ea23a7ff7a6 100644 --- a/docs/source/datapoints.rst +++ b/docs/source/datapoints.rst @@ -17,3 +17,4 @@ see e.g. :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`. BoundingBoxFormat BoundingBoxes Mask + Datapoint diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 73adb3cf3b5..a1858c6b514 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -375,3 +375,14 @@ you can use a functional transform to build transform classes with custom behavi to_pil_image to_tensor vflip + +Developer tools +--------------- + +.. currentmodule:: torchvision.transforms.v2.functional + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + register_kernel diff --git a/gallery/plot_custom_datapoints.py b/gallery/plot_custom_datapoints.py new file mode 100644 index 00000000000..ea757283e86 --- /dev/null +++ b/gallery/plot_custom_datapoints.py @@ -0,0 +1,125 @@ +""" +===================================== +How to write your own Datapoint class +===================================== + +This guide is intended for downstream library maintainers. We explain how to +write your own datapoint class, and how to make it compatible with the built-in +Torchvision v2 transforms. Before continuing, make sure you have read +:ref:`sphx_glr_auto_examples_plot_datapoints.py`. +""" + +# %% +import torch +import torchvision + +# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that +# some APIs may slightly change in the future +torchvision.disable_beta_transforms_warning() + +from torchvision import datapoints +from torchvision.transforms import v2 + +# %% +# We will create a very simple class that just inherits from the base +# :class:`~torchvision.datapoints.Datapoint` class. It will be enough to cover +# what you need to know to implement your more elaborate uses-cases. 
If you need +# to create a class that carries meta-data, take a look at how the +# :class:`~torchvision.datapoints.BoundingBoxes` class is `implemented +# `_. + + +class MyDatapoint(datapoints.Datapoint): + pass + + +my_dp = MyDatapoint([1, 2, 3]) +my_dp + +# %% +# Now that we have defined our custom Datapoint class, we want it to be +# compatible with the built-in torchvision transforms, and the functional API. +# For that, we need to implement a kernel which performs the core of the +# transformation, and then "hook" it to the functional that we want to support +# via :func:`~torchvision.transforms.v2.functional.register_kernel`. +# +# We illustrate this process below: we create a kernel for the "horizontal flip" +# operation of our MyDatapoint class, and register it to the functional API. + +from torchvision.transforms.v2 import functional as F + + +@F.register_kernel(dispatcher="hflip", datapoint_cls=MyDatapoint) +def hflip_my_datapoint(my_dp, *args, **kwargs): + print("Flipping!") + out = my_dp.flip(-1) + return MyDatapoint.wrap_like(my_dp, out) + + +# %% +# To understand why ``wrap_like`` is used, see +# :ref:`datapoint_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now, +# we will explain it below in :ref:`param_forwarding`. +# +# .. note:: +# +# In our call to ``register_kernel`` above we used a string +# ``dispatcher="hflip"`` to refer to the functional we want to hook into. We +# could also have used the functional *itself*, i.e. +# ``@register_kernel(dispatcher=F.hflip, ...)``. +# +# The functionals that you can be hooked into are the ones in +# ``torchvision.transforms.v2.functional`` and they are documented in +# :ref:`functional_transforms`. +# +# Now that we have registered our kernel, we can call the functional API on a +# ``MyDatapoint`` instance: + +my_dp = MyDatapoint(torch.rand(3, 256, 256)) +_ = F.hflip(my_dp) + +# %% +# And we can also use the +# :class:`~torchvision.transforms.v2.RandomHorizontalFlip` transform, since it relies on :func:`~torchvision.transforms.v2.functional.hflip` internally: +t = v2.RandomHorizontalFlip(p=1) +_ = t(my_dp) + +# %% +# .. note:: +# +# We cannot register a kernel for a transform class, we can only register a +# kernel for a **functional**. The reason we can't register a transform +# class is because one transform may internally rely on more than one +# functional, so in general we can't register a single kernel for a given +# class. +# +# .. _param_forwarding: +# +# Parameter forwarding, and ensuring future compatibility of your kernels +# ----------------------------------------------------------------------- +# +# The functional API that you're hooking into is public and therefore +# **backward** compatible: we guarantee that the parameters of these functionals +# won't be removed or renamed without a proper deprecation cycle. However, we +# don't guarantee **forward** compatibility, and we may add new parameters in +# the future. +# +# Imagine that in a future version, Torchvision adds a new ``inplace`` parameter +# to its :func:`~torchvision.transforms.v2.functional.hflip` functional. If you +# already defined and registered your own kernel as + +def hflip_my_datapoint(my_dp): # noqa + print("Flipping!") + out = my_dp.flip(-1) + return MyDatapoint.wrap_like(my_dp, out) + + +# %% +# then calling ``F.hflip(my_dp)`` will **fail**, because ``hflip`` will try to +# pass the new ``inplace`` parameter to your kernel, but your kernel doesn't +# accept it. 
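+#
+# To make that concrete, here is a purely illustrative sketch of the failure
+# (remember that ``inplace`` is hypothetical -- it does not exist on ``hflip``
+# at the time of writing):
+#
+# .. code:: python
+#
+#     # In that hypothetical future, hflip would forward its new ``inplace``
+#     # argument to the registered kernel, and since the kernel defined just
+#     # above only accepts ``my_dp``, the call would raise a TypeError about
+#     # the unexpected 'inplace' keyword argument.
+#     F.hflip(my_dp)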
+# +# For this reason, we recommend to always define your kernels with +# ``*args, **kwargs`` in their signature, as done above. This way, your kernel +# will be able to accept any new parameter that we may add in the future. +# (Technically, adding `**kwargs` only should be enough). diff --git a/gallery/plot_custom_transforms.py b/gallery/plot_custom_transforms.py new file mode 100644 index 00000000000..eba8e91faf4 --- /dev/null +++ b/gallery/plot_custom_transforms.py @@ -0,0 +1,123 @@ +""" +=================================== +How to write your own v2 transforms +=================================== + +This guide explains how to write transforms that are compatible with the +torchvision transforms V2 API. +""" + +# %% +import torch +import torchvision + +# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that +# some APIs may slightly change in the future +torchvision.disable_beta_transforms_warning() + +from torchvision import datapoints +from torchvision.transforms import v2 + + +# %% +# Just create a ``nn.Module`` and override the ``forward`` method +# =============================================================== +# +# In most cases, this is all you're going to need, as long as you already know +# the structure of the input that your transform will expect. For example if +# you're just doing image classification, your transform will typically accept a +# single image as input, or a ``(img, label)`` input. So you can just hard-code +# your ``forward`` method to accept just that, e.g. +# +# .. code:: python +# +# class MyCustomTransform(torch.nn.Module): +# def forward(self, img, label): +# # Do some transformations +# return new_img, new_label +# +# .. note:: +# +# This means that if you have a custom transform that is already compatible +# with the V1 transforms (those in ``torchvision.transforms``), it will +# still work with the V2 transforms without any change! +# +# We will illustrate this more completely below with a typical detection case, +# where our samples are just images, bounding boxes and labels: + +class MyCustomTransform(torch.nn.Module): + def forward(self, img, bboxes, label): # we assume inputs are always structured like this + print( + f"I'm transforming an image of shape {img.shape} " + f"with bboxes = {bboxes}\n{label = }" + ) + # Do some transformations. Here, we're just passing though the input + return img, bboxes, label + + +transforms = v2.Compose([ + MyCustomTransform(), + v2.RandomResizedCrop((224, 224), antialias=True), + v2.RandomHorizontalFlip(p=1), + v2.Normalize(mean=[0, 0, 0], std=[1, 1, 1]) +]) + +H, W = 256, 256 +img = torch.rand(3, H, W) +bboxes = datapoints.BoundingBoxes( + torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]), + format="XYXY", + canvas_size=(H, W) +) +label = 3 + +out_img, out_bboxes, out_label = transforms(img, bboxes, label) +# %% +print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }") +# %% +# .. note:: +# While working with datapoint classes in your code, make sure to +# familiarize yourself with this section: +# :ref:`datapoint_unwrapping_behaviour` +# +# Supporting arbitrary input structures +# ===================================== +# +# In the section above, we have assumed that you already know the structure of +# your inputs and that you're OK with hard-coding this expected structure in +# your code. If you want your custom transforms to be as flexible as possible, +# this can be a bit limitting. 
+# +# A key feature of the builtin Torchvision V2 transforms is that they can accept +# arbitrary input structure and return the same structure as output (with +# transformed entries). For example, transforms can accept a single image, or a +# tuple of ``(img, label)``, or an arbitrary nested dictionary as input: + +structured_input = { + "img": img, + "annotations": (bboxes, label), + "something_that_will_be_ignored": (1, "hello") +} +structured_output = v2.RandomHorizontalFlip(p=1)(structured_input) + +assert isinstance(structured_output, dict) +assert structured_output["something_that_will_be_ignored"] == (1, "hello") +print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}") + +# %% +# If you want to reproduce this behavior in your own transform, we invite you to +# look at our `code +# `_ +# and adapt it to your needs. +# +# In brief, the core logic is to unpack the input into a flat list using `pytree +# `_, and +# then transform only the entries that can be transformed (the decision is made +# based on the **class** of the entries, as all datapoints are +# tensor-subclasses) plus some custom logic that is out of score here - check the +# code for details. The (potentially transformed) entries are then repacked and +# returned, in the same structure as the input. +# +# We do not provide public dev-facing tools to achieve that at this time, but if +# this is something that would be valuable to you, please let us know by opening +# an issue on our `GitHub repo `_. diff --git a/gallery/plot_datapoints.py b/gallery/plot_datapoints.py index 57e29bd86eb..d87575cdb8e 100644 --- a/gallery/plot_datapoints.py +++ b/gallery/plot_datapoints.py @@ -3,13 +3,22 @@ Datapoints FAQ ============== -The :mod:`torchvision.datapoints` namespace was introduced together with ``torchvision.transforms.v2``. This example -showcases what these datapoints are and how they behave. This is a fairly low-level topic that most users will not need -to worry about: you do not need to understand the internals of datapoints to efficiently rely on -``torchvision.transforms.v2``. It may however be useful for advanced users trying to implement their own datasets, -transforms, or work directly with the datapoints. +Datapoints are Tensor subclasses introduced together with +``torchvision.transforms.v2``. This example showcases what these datapoints are +and how they behave. + +.. warning:: + + **Intended Audience** Unless you're writing your own transforms or your own datapoints, you + probably do not need to read this guide. This is a fairly low-level topic + that most users will not need to worry about: you do not need to understand + the internals of datapoints to efficiently rely on + ``torchvision.transforms.v2``. It may however be useful for advanced users + trying to implement their own datasets, transforms, or work directly with + the datapoints. """ +# %% import PIL.Image import torch @@ -35,11 +44,20 @@ assert isinstance(image, torch.Tensor) assert image.data_ptr() == tensor.data_ptr() - # %% # Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function # for the input data. # +# What can I do with a datapoint? +# ------------------------------- +# +# Datapoints look and feel just like regular tensors - they **are** tensors. +# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or +# any ``torch.*`` operator will also works on datapoints. See +# :ref:`datapoint_unwrapping_behaviour` for a few gotchas. 
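+#
+# For instance (a small illustrative snippet -- the shape and values below are
+# arbitrary and only meant to show that plain tensor operations go through):
+
+img_dp = datapoints.Image(torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8))
+
+print(f"{img_dp.shape = }")
+print(f"{img_dp.sum() = }")
+print(f"{img_dp.float().mean() = }")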
+ +# %% +# # What datapoints are supported? # ------------------------------ # @@ -50,9 +68,14 @@ # * :class:`~torchvision.datapoints.BoundingBoxes` # * :class:`~torchvision.datapoints.Mask` # +# .. _datapoint_creation: +# # How do I construct a datapoint? # ------------------------------- # +# Using the constructor +# ^^^^^^^^^^^^^^^^^^^^^ +# # Each datapoint class takes any tensor-like data that can be turned into a :class:`~torch.Tensor` image = datapoints.Image([[[[0, 1], [1, 0]]]]) @@ -68,27 +91,52 @@ # %% -# In addition, :class:`~torchvision.datapoints.Image` and :class:`~torchvision.datapoints.Mask` also take a +# In addition, :class:`~torchvision.datapoints.Image` and :class:`~torchvision.datapoints.Mask` can also take a # :class:`PIL.Image.Image` directly: image = datapoints.Image(PIL.Image.open("assets/astronaut.jpg")) print(image.shape, image.dtype) # %% -# In general, the datapoints can also store additional metadata that complements the underlying tensor. For example, -# :class:`~torchvision.datapoints.BoundingBoxes` stores the coordinate format as well as the spatial size of the -# corresponding image alongside the actual values: - -bounding_box = datapoints.BoundingBoxes( - [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:] +# Some datapoints require additional metadata to be passed in ordered to be constructed. For example, +# :class:`~torchvision.datapoints.BoundingBoxes` requires the coordinate format as well as the size of the +# corresponding image (``canvas_size``) alongside the actual values. These +# metadata are required to properly transform the bounding boxes. + +bboxes = datapoints.BoundingBoxes( + [[17, 16, 344, 495], [0, 10, 0, 10]], + format=datapoints.BoundingBoxFormat.XYXY, + canvas_size=image.shape[-2:] ) -print(bounding_box) +print(bboxes) + +# %% +# Using the ``wrap_like()`` class method +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# You can also use the ``wrap_like()`` class method to wrap a tensor object +# into a datapoint. This is useful when you already have an object of the +# desired type, which typically happens when writing transforms: you just want +# to wrap the output like the input. This API is inspired by utils like +# :func:`torch.zeros_like`: + +new_bboxes = torch.tensor([0, 20, 30, 40]) +new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes) +assert isinstance(new_bboxes, datapoints.BoundingBoxes) +assert new_bboxes.canvas_size == bboxes.canvas_size # %% +# The metadata of ``new_bboxes`` is the same as ``bboxes``, but you could pass +# it as a parameter to override it. Check the +# :meth:`~torchvision.datapoints.BoundingBoxes.wrap_like` documentation for +# more details. +# # Do I have to wrap the output of the datasets myself? # ---------------------------------------------------- # +# TODO: Move this in another guide - this is user-facing, not dev-facing. +# # Only if you are using custom datasets. For the built-in ones, you can use # :func:`torchvision.datasets.wrap_dataset_for_transforms_v2`. Note that the function also supports subclasses of the # built-in datasets. Meaning, if your custom dataset subclasses from a built-in one and the output type is the same, you @@ -105,8 +153,8 @@ class PennFudanDataset(torch.utils.data.Dataset): def __getitem__(self, item): ... 
- target["boxes"] = datapoints.BoundingBoxes( - boxes, + target["bboxes"] = datapoints.BoundingBoxes( + bboxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=F.get_size(img), ) @@ -147,7 +195,7 @@ def get_transform(train): # %% # .. note:: # -# If both :class:`~torchvision.datapoints.BoundingBoxes`'es and :class:`~torchvision.datapoints.Mask`'s are included in +# If both :class:`~torchvision.datapoints.BoundingBoxes` and :class:`~torchvision.datapoints.Mask`'s are included in # the sample, ``torchvision.transforms.v2`` will transform them both. Meaning, if you don't need both, dropping or # at least not wrapping the obsolete parts, can lead to a significant performance boost. # @@ -156,41 +204,66 @@ def get_transform(train): # even better to not load the masks at all, but this is not possible in this example, since the bounding boxes are # generated from the masks. # -# How do the datapoints behave inside a computation? -# -------------------------------------------------- +# .. _datapoint_unwrapping_behaviour: # -# Datapoints look and feel just like regular tensors. Everything that is supported on a plain :class:`torch.Tensor` -# also works on datapoints. -# Since for most operations involving datapoints, it cannot be safely inferred whether the result should retain the -# datapoint type, we choose to return a plain tensor instead of a datapoint (this might change, see note below): +# I had a Datapoint but now I have a Tensor. Help! +# ------------------------------------------------ +# +# For a lot of operations involving datapoints, we cannot safely infer whether +# the result should retain the datapoint type, so we choose to return a plain +# tensor instead of a datapoint (this might change, see note below): -assert isinstance(image, datapoints.Image) +assert isinstance(bboxes, datapoints.BoundingBoxes) -new_image = image + 0 +# Shift bboxes by 3 pixels in both H and W +new_bboxes = bboxes + 3 -assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image) +assert isinstance(new_bboxes, torch.Tensor) and not isinstance(new_bboxes, datapoints.BoundingBoxes) + +# %% +# If you're writing your own custom transforms or code involving datapoints, you +# can re-wrap the output into a datapoint by just calling their constructor, or +# by using the ``.wrap_like()`` class method: + +new_bboxes = bboxes + 3 +new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes) +assert isinstance(new_bboxes, datapoints.BoundingBoxes) # %% +# See more details above in :ref:`datapoint_creation`. +# +# .. note:: +# +# You never need to re-wrap manually if you're using the built-in transforms +# or their functional equivalents: this is automatically taken care of for +# you. +# # .. note:: # # This "unwrapping" behaviour is something we're actively seeking feedback on. If you find this surprising or if you # have any suggestions on how to better support your use-cases, please reach out to us via this issue: # https://github.com/pytorch/vision/issues/7319 # -# There are two exceptions to this rule: +# There are a few exceptions to this "unwrapping" rule: # -# 1. The operations :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`, and :meth:`~torch.Tensor.requires_grad_` -# retain the datapoint type. -# 2. Inplace operations on datapoints cannot change the type of the datapoint they are called on. However, if you use -# the flow style, the returned value will be unwrapped: +# 1. 
Operations like :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`, +# :meth:`torch.Tensor.detach` and :meth:`~torch.Tensor.requires_grad_` retain +# the datapoint type. +# 2. Inplace operations on datapoints like ``.add_()`` preserve they type. However, +# the **returned** value of inplace operations will be unwrapped into a pure +# tensor: image = datapoints.Image([[[0, 1], [1, 0]]]) new_image = image.add_(1).mul_(2) -assert isinstance(image, torch.Tensor) +# image got transformed in-place and is still an Image datapoint, but new_image +# is a Tensor. They share the same underlying data and they're equal, just +# different classes. +assert isinstance(image, datapoints.Image) print(image) assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image) assert (new_image == image).all() +assert new_image.data_ptr() == image.data_ptr() diff --git a/torchvision/datapoints/_bounding_box.py b/torchvision/datapoints/_bounding_box.py index 912cc3bca08..7477b3652dc 100644 --- a/torchvision/datapoints/_bounding_box.py +++ b/torchvision/datapoints/_bounding_box.py @@ -42,7 +42,7 @@ class BoundingBoxes(Datapoint): canvas_size: Tuple[int, int] @classmethod - def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, canvas_size: Tuple[int, int]) -> BoundingBoxes: + def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, canvas_size: Tuple[int, int]) -> BoundingBoxes: # type: ignore[override] bounding_boxes = tensor.as_subclass(cls) bounding_boxes.format = format bounding_boxes.canvas_size = canvas_size diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py index 384273301de..fae3c18656b 100644 --- a/torchvision/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -14,6 +14,13 @@ class Datapoint(torch.Tensor): + """[Beta] Base class for all datapoints. + + You probably don't want to use this class unless you're defining your own + custom Datapoints. See + :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for details. + """ + @staticmethod def _to_tensor( data: Any, @@ -25,9 +32,13 @@ def _to_tensor( requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad) + @classmethod + def _wrap(cls: Type[D], tensor: torch.Tensor) -> D: + return tensor.as_subclass(cls) + @classmethod def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D: - raise NotImplementedError + return cls._wrap(tensor) _NO_WRAPPING_EXCEPTIONS = { torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output), diff --git a/torchvision/datapoints/_image.py b/torchvision/datapoints/_image.py index dccfc81a605..9b635e8e034 100644 --- a/torchvision/datapoints/_image.py +++ b/torchvision/datapoints/_image.py @@ -22,11 +22,6 @@ class Image(Datapoint): ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``. 
""" - @classmethod - def _wrap(cls, tensor: torch.Tensor) -> Image: - image = tensor.as_subclass(cls) - return image - def __new__( cls, data: Any, @@ -48,10 +43,6 @@ def __new__( return cls._wrap(tensor) - @classmethod - def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image: - return cls._wrap(tensor) - def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] return self._make_repr() diff --git a/torchvision/datapoints/_mask.py b/torchvision/datapoints/_mask.py index 2b95eca72e2..95eda077929 100644 --- a/torchvision/datapoints/_mask.py +++ b/torchvision/datapoints/_mask.py @@ -22,10 +22,6 @@ class Mask(Datapoint): ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``. """ - @classmethod - def _wrap(cls, tensor: torch.Tensor) -> Mask: - return tensor.as_subclass(cls) - def __new__( cls, data: Any, @@ -41,11 +37,3 @@ def __new__( tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) return cls._wrap(tensor) - - @classmethod - def wrap_like( - cls, - other: Mask, - tensor: torch.Tensor, - ) -> Mask: - return cls._wrap(tensor) diff --git a/torchvision/datapoints/_video.py b/torchvision/datapoints/_video.py index 11d6e2a854d..842c05bf7e9 100644 --- a/torchvision/datapoints/_video.py +++ b/torchvision/datapoints/_video.py @@ -20,11 +20,6 @@ class Video(Datapoint): ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``. """ - @classmethod - def _wrap(cls, tensor: torch.Tensor) -> Video: - video = tensor.as_subclass(cls) - return video - def __new__( cls, data: Any, @@ -38,10 +33,6 @@ def __new__( raise ValueError return cls._wrap(tensor) - @classmethod - def wrap_like(cls, other: Video, tensor: torch.Tensor) -> Video: - return cls._wrap(tensor) - def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] return self._make_repr() diff --git a/torchvision/prototype/datapoints/_label.py b/torchvision/prototype/datapoints/_label.py index 7ed2f7522b0..ac9b2d8912a 100644 --- a/torchvision/prototype/datapoints/_label.py +++ b/torchvision/prototype/datapoints/_label.py @@ -15,7 +15,7 @@ class _LabelBase(Datapoint): categories: Optional[Sequence[str]] @classmethod - def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L: + def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L: # type: ignore[override] label_base = tensor.as_subclass(cls) label_base.categories = categories return label_base diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 1eaa54102a4..bb3d59b551a 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -47,6 +47,11 @@ def _name_to_dispatcher(name): def register_kernel(dispatcher, datapoint_cls): + """Decorate a kernel to register it for a dispatcher and a (custom) datapoint type. + + See :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for usage + details. 
+ """ if isinstance(dispatcher, str): dispatcher = _name_to_dispatcher(name=dispatcher) return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False) From 2030d208ba1044b97b8ceab91852858672a56cc8 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 7 Aug 2023 09:46:09 +0200 Subject: [PATCH 2/8] register tensor and PIL kernel the same way as datapoints (#7797) Co-authored-by: Nicolas Hug --- test/test_transforms_v2_functional.py | 58 +-- test/test_transforms_v2_refactored.py | 168 ++++--- .../transforms/v2/functional/_augment.py | 25 +- .../transforms/v2/functional/_color.py | 290 +++++------- .../transforms/v2/functional/_geometry.py | 423 ++++++++---------- torchvision/transforms/v2/functional/_meta.py | 88 ++-- torchvision/transforms/v2/functional/_misc.py | 69 ++- .../transforms/v2/functional/_temporal.py | 20 +- .../transforms/v2/functional/_utils.py | 98 ++-- 9 files changed, 552 insertions(+), 687 deletions(-) diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py index a05f1a3c3da..713737abbff 100644 --- a/test/test_transforms_v2_functional.py +++ b/test/test_transforms_v2_functional.py @@ -2,7 +2,6 @@ import math import os import re -from unittest import mock import numpy as np import PIL.Image @@ -25,7 +24,6 @@ from torchvision.transforms.v2 import functional as F from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes -from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY from torchvision.transforms.v2.utils import is_simple_tensor from transforms_v2_dispatcher_infos import DISPATCHER_INFOS from transforms_v2_kernel_infos import KERNEL_INFOS @@ -359,18 +357,6 @@ def test_scripted_smoke(self, info, args_kwargs, device): def test_scriptable(self, dispatcher): script(dispatcher) - @image_sample_inputs - def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on): - (image_datapoint, *other_args), kwargs = args_kwargs.load() - image_simple_tensor = torch.Tensor(image_datapoint) - - kernel_info = info.kernel_infos[datapoints.Image] - spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.id) - - info.dispatcher(image_simple_tensor, *other_args, **kwargs) - - spy.assert_called_once() - @image_sample_inputs def test_simple_tensor_output_type(self, info, args_kwargs): (image_datapoint, *other_args), kwargs = args_kwargs.load() @@ -381,25 +367,6 @@ def test_simple_tensor_output_type(self, info, args_kwargs): # We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well assert type(output) is torch.Tensor - @make_info_args_kwargs_parametrization( - [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None], - args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image), - ) - def test_dispatch_pil(self, info, args_kwargs, spy_on): - (image_datapoint, *other_args), kwargs = args_kwargs.load() - - if image_datapoint.ndim > 3: - pytest.skip("Input is batched") - - image_pil = F.to_image_pil(image_datapoint) - - pil_kernel_info = info.pil_kernel_info - spy = spy_on(pil_kernel_info.kernel, module=info.dispatcher.__module__, name=pil_kernel_info.id) - - info.dispatcher(image_pil, *other_args, **kwargs) - - spy.assert_called_once() - @make_info_args_kwargs_parametrization( [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None], args_kwargs_fn=lambda info: 
info.sample_inputs(datapoints.Image), @@ -416,28 +383,6 @@ def test_pil_output_type(self, info, args_kwargs): assert isinstance(output, PIL.Image.Image) - @make_info_args_kwargs_parametrization( - DISPATCHER_INFOS, - args_kwargs_fn=lambda info: info.sample_inputs(), - ) - def test_dispatch_datapoint(self, info, args_kwargs, spy_on): - (datapoint, *other_args), kwargs = args_kwargs.load() - - input_type = type(datapoint) - - wrapped_kernel = _KERNEL_REGISTRY[info.dispatcher][input_type] - - # In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the - # proper kernel was wrapped - if hasattr(wrapped_kernel, "__wrapped__"): - assert wrapped_kernel.__wrapped__ is info.kernels[input_type] - - spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__) - with mock.patch.dict(_KERNEL_REGISTRY[info.dispatcher], values={input_type: spy}): - info.dispatcher(datapoint, *other_args, **kwargs) - - spy.assert_called_once() - @make_info_args_kwargs_parametrization( DISPATCHER_INFOS, args_kwargs_fn=lambda info: info.sample_inputs(), @@ -449,6 +394,9 @@ def test_datapoint_output_type(self, info, args_kwargs): assert isinstance(output, type(datapoint)) + if isinstance(datapoint, datapoints.BoundingBoxes) and info.dispatcher is not F.convert_format_bounding_boxes: + assert output.format == datapoint.format + @pytest.mark.parametrize( ("dispatcher_info", "datapoint_type", "kernel_info"), [ diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 8a858bf58c2..c910882f9fd 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -39,7 +39,7 @@ from torchvision.transforms._functional_tensor import _max_value as get_max_value from torchvision.transforms.functional import pil_modes_mapping from torchvision.transforms.v2 import functional as F -from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY +from torchvision.transforms.v2.functional._utils import _get_kernel, _KERNEL_REGISTRY, _noop, _register_kernel_internal @pytest.fixture(autouse=True) @@ -173,59 +173,32 @@ def _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs): dispatcher_scripted(input.as_subclass(torch.Tensor), *args, **kwargs) -def _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs): - """Checks if the dispatcher correctly dispatches the input to the corresponding kernel and that the input type is - preserved in doing so. For bounding boxes also checks that the format is preserved. 
- """ - input_type = type(input) - - if isinstance(input, datapoints.Datapoint): - wrapped_kernel = _KERNEL_REGISTRY[dispatcher][input_type] - - # In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the - # proper kernel was wrapped - if hasattr(wrapped_kernel, "__wrapped__"): - assert wrapped_kernel.__wrapped__ is kernel - - spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__) - with mock.patch.dict(_KERNEL_REGISTRY[dispatcher], values={input_type: spy}): - output = dispatcher(input, *args, **kwargs) - - spy.assert_called_once() - else: - with mock.patch(f"{dispatcher.__module__}.{kernel.__name__}", wraps=kernel) as spy: - output = dispatcher(input, *args, **kwargs) - - spy.assert_called_once() - - assert isinstance(output, input_type) - - if isinstance(input, datapoints.BoundingBoxes): - assert output.format == input.format - - def check_dispatcher( dispatcher, + # TODO: remove this parameter kernel, input, *args, check_scripted_smoke=True, - check_dispatch=True, **kwargs, ): unknown_input = object() + with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))): + dispatcher(unknown_input, *args, **kwargs) + with mock.patch("torch._C._log_api_usage_once", wraps=torch._C._log_api_usage_once) as spy: - with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))): - dispatcher(unknown_input, *args, **kwargs) + output = dispatcher(input, *args, **kwargs) spy.assert_any_call(f"{dispatcher.__module__}.{dispatcher.__name__}") + assert isinstance(output, type(input)) + + if isinstance(input, datapoints.BoundingBoxes): + assert output.format == input.format + if check_scripted_smoke: _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs) - if check_dispatch: - _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs) - def check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type): """Checks if the signature of the dispatcher matches the kernel signature.""" @@ -412,18 +385,20 @@ def transform(bbox): @pytest.mark.parametrize( - ("dispatcher", "registered_datapoint_clss"), + ("dispatcher", "registered_input_types"), [(dispatcher, set(registry.keys())) for dispatcher, registry in _KERNEL_REGISTRY.items()], ) -def test_exhaustive_kernel_registration(dispatcher, registered_datapoint_clss): +def test_exhaustive_kernel_registration(dispatcher, registered_input_types): missing = { + torch.Tensor, + PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, datapoints.Video, - } - registered_datapoint_clss + } - registered_input_types if missing: - names = sorted(f"datapoints.{cls.__name__}" for cls in missing) + names = sorted(str(t) for t in missing) raise AssertionError( "\n".join( [ @@ -1753,11 +1728,6 @@ def test_dispatcher(self, kernel, make_input, input_dtype, output_dtype, device, F.to_dtype, kernel, make_input(dtype=input_dtype, device=device), - # TODO: we could leave check_dispatch to True but it currently fails - # in _check_dispatcher_dispatch because there is no to_dtype() method on the datapoints. - # We should be able to put this back if we change the dispatch - # mechanism e.g. 
via https://github.com/pytorch/vision/pull/7733 - check_dispatch=False, dtype=output_dtype, scale=scale, ) @@ -2208,9 +2178,105 @@ def new_resize(dp, *args, **kwargs): t(torch.rand(3, 10, 10)).shape == (3, 224, 224) t(datapoints.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224) - def test_bad_disaptcher_name(self): - class CustomDatapoint(datapoints.Datapoint): + def test_errors(self): + with pytest.raises(ValueError, match="Could not find dispatcher with name"): + F.register_kernel("bad_name", datapoints.Image) + + with pytest.raises(ValueError, match="Kernels can only be registered on dispatchers"): + F.register_kernel(datapoints.Image, F.resize) + + with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"): + F.register_kernel(F.resize, object) + + with pytest.raises(ValueError, match="already has a kernel registered for type"): + F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor) + + +class TestGetKernel: + # We are using F.resize as dispatcher and the kernels below as proxy. Any other dispatcher / kernels combination + # would also be fine + KERNELS = { + torch.Tensor: F.resize_image_tensor, + PIL.Image.Image: F.resize_image_pil, + datapoints.Image: F.resize_image_tensor, + datapoints.BoundingBoxes: F.resize_bounding_boxes, + datapoints.Mask: F.resize_mask, + datapoints.Video: F.resize_video, + } + + def test_unsupported_types(self): + class MyTensor(torch.Tensor): pass - with pytest.raises(ValueError, match="Could not find dispatcher with name"): - F.register_kernel("bad_name", CustomDatapoint) + class MyPILImage(PIL.Image.Image): + pass + + for input_type in [str, int, object, MyTensor, MyPILImage]: + with pytest.raises( + TypeError, + match=( + "supports inputs of type torch.Tensor, PIL.Image.Image, " + "and subclasses of torchvision.datapoints.Datapoint" + ), + ): + _get_kernel(F.resize, input_type) + + def test_exact_match(self): + # We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the + # ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize dispatcher + # here, register the kernels without wrapper, and check the exact matching afterwards. + def resize_with_pure_kernels(): + pass + + for input_type, kernel in self.KERNELS.items(): + _register_kernel_internal(resize_with_pure_kernels, input_type, datapoint_wrapper=False)(kernel) + + assert _get_kernel(resize_with_pure_kernels, input_type) is kernel + + def test_builtin_datapoint_subclass(self): + # We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the + # ideal wrapping. Practically, we have an intermediate wrapper layer. 
Thus, we create a new resize dispatcher + # here, register the kernels without wrapper, and check if subclasses of our builtin datapoints get dispatched + # to the kernel of the corresponding superclass + def resize_with_pure_kernels(): + pass + + class MyImage(datapoints.Image): + pass + + class MyBoundingBoxes(datapoints.BoundingBoxes): + pass + + class MyMask(datapoints.Mask): + pass + + class MyVideo(datapoints.Video): + pass + + for custom_datapoint_subclass in [ + MyImage, + MyBoundingBoxes, + MyMask, + MyVideo, + ]: + builtin_datapoint_class = custom_datapoint_subclass.__mro__[1] + builtin_datapoint_kernel = self.KERNELS[builtin_datapoint_class] + _register_kernel_internal(resize_with_pure_kernels, builtin_datapoint_class, datapoint_wrapper=False)( + builtin_datapoint_kernel + ) + + assert _get_kernel(resize_with_pure_kernels, custom_datapoint_subclass) is builtin_datapoint_kernel + + def test_datapoint_subclass(self): + class MyDatapoint(datapoints.Datapoint): + pass + + # Note that this will be an error in the future + assert _get_kernel(F.resize, MyDatapoint) is _noop + + def resize_my_datapoint(): + pass + + _register_kernel_internal(F.resize, MyDatapoint, datapoint_wrapper=False)(resize_my_datapoint) + + assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index 95b4ed93786..89fa254374d 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -7,7 +7,7 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor +from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal @_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True) @@ -20,23 +20,16 @@ def erase( v: torch.Tensor, inplace: bool = False, ) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]: - if not torch.jit.is_scripting(): - _log_api_usage_once(erase) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(erase, type(inpt)) - return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) - elif isinstance(inpt, PIL.Image.Image): - return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) + + _log_api_usage_once(erase) + + kernel = _get_kernel(erase, type(inpt)) + return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) +@_register_kernel_internal(erase, torch.Tensor) @_register_kernel_internal(erase, datapoints.Image) def erase_image_tensor( image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False @@ -48,7 +41,7 @@ def erase_image_tensor( return image -@torch.jit.unused +@_register_kernel_internal(erase, PIL.Image.Image) def erase_image_pil( image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False ) -> PIL.Image.Image: diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 99dc1936259..71797fd2500 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -10,29 +10,20 @@ from torchvision.utils import _log_api_usage_once from ._misc import _num_value_bits, to_dtype_image_tensor -from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor +from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, datapoints.Video) def rgb_to_grayscale( inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], num_output_channels: int = 1 ) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]: - if not torch.jit.is_scripting(): - _log_api_usage_once(rgb_to_grayscale) - if num_output_channels not in (1, 3): - raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.") - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(rgb_to_grayscale, type(inpt)) - return kernel(inpt, num_output_channels=num_output_channels) - elif isinstance(inpt, PIL.Image.Image): - return rgb_to_grayscale_image_pil(inpt, num_output_channels=num_output_channels) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + + _log_api_usage_once(rgb_to_grayscale) + + kernel = _get_kernel(rgb_to_grayscale, type(inpt)) + return kernel(inpt, num_output_channels=num_output_channels) # `to_grayscale` actually predates `rgb_to_grayscale` in v1, but only handles PIL images. 
Since `rgb_to_grayscale` is a @@ -56,12 +47,19 @@ def _rgb_to_grayscale_image_tensor( return l_img +@_register_kernel_internal(rgb_to_grayscale, torch.Tensor) @_register_kernel_internal(rgb_to_grayscale, datapoints.Image) def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: + if num_output_channels not in (1, 3): + raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.") return _rgb_to_grayscale_image_tensor(image, num_output_channels=num_output_channels, preserve_dtype=True) -rgb_to_grayscale_image_pil = _FP.to_grayscale +@_register_kernel_internal(rgb_to_grayscale, PIL.Image.Image) +def rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image: + if num_output_channels not in (1, 3): + raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.") + return _FP.to_grayscale(image, num_output_channels=num_output_channels) def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor: @@ -74,23 +72,16 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(adjust_brightness) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(adjust_brightness, type(inpt)) - return kernel(inpt, brightness_factor=brightness_factor) - elif isinstance(inpt, PIL.Image.Image): - return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) + _log_api_usage_once(adjust_brightness) + kernel = _get_kernel(adjust_brightness, type(inpt)) + return kernel(inpt, brightness_factor=brightness_factor) + + +@_register_kernel_internal(adjust_brightness, torch.Tensor) @_register_kernel_internal(adjust_brightness, datapoints.Image) def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor: if brightness_factor < 0: @@ -106,6 +97,7 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float return output if fp else output.to(image.dtype) +@_register_kernel_internal(adjust_brightness, PIL.Image.Image) def adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: float) -> PIL.Image.Image: return _FP.adjust_brightness(image, brightness_factor=brightness_factor) @@ -117,23 +109,16 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(adjust_saturation) - - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)): + if torch.jit.is_scripting(): return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(adjust_saturation, type(inpt)) - return kernel(inpt, saturation_factor=saturation_factor) - elif isinstance(inpt, PIL.Image.Image): - return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + _log_api_usage_once(adjust_saturation) + + kernel = _get_kernel(adjust_saturation, type(inpt)) + return kernel(inpt, saturation_factor=saturation_factor) + +@_register_kernel_internal(adjust_saturation, torch.Tensor) @_register_kernel_internal(adjust_saturation, datapoints.Image) def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor: if saturation_factor < 0: @@ -153,7 +138,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float return _blend(image, grayscale_image, saturation_factor) -adjust_saturation_image_pil = _FP.adjust_saturation +adjust_saturation_image_pil = _register_kernel_internal(adjust_saturation, PIL.Image.Image)(_FP.adjust_saturation) @_register_kernel_internal(adjust_saturation, datapoints.Video) @@ -163,23 +148,16 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(adjust_contrast) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(adjust_contrast, type(inpt)) - return kernel(inpt, contrast_factor=contrast_factor) - elif isinstance(inpt, PIL.Image.Image): - return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) + + _log_api_usage_once(adjust_contrast) + + kernel = _get_kernel(adjust_contrast, type(inpt)) + return kernel(inpt, contrast_factor=contrast_factor) +@_register_kernel_internal(adjust_contrast, torch.Tensor) @_register_kernel_internal(adjust_contrast, datapoints.Image) def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor: if contrast_factor < 0: @@ -199,7 +177,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> return _blend(image, mean, contrast_factor) -adjust_contrast_image_pil = _FP.adjust_contrast +adjust_contrast_image_pil = _register_kernel_internal(adjust_contrast, PIL.Image.Image)(_FP.adjust_contrast) @_register_kernel_internal(adjust_contrast, datapoints.Video) @@ -209,23 +187,16 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch. @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def adjust_sharpness(inpt: datapoints._InputTypeJIT, sharpness_factor: float) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(adjust_sharpness) - - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)): + if torch.jit.is_scripting(): return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(adjust_sharpness, type(inpt)) - return kernel(inpt, sharpness_factor=sharpness_factor) - elif isinstance(inpt, PIL.Image.Image): - return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + _log_api_usage_once(adjust_sharpness) + + kernel = _get_kernel(adjust_sharpness, type(inpt)) + return kernel(inpt, sharpness_factor=sharpness_factor) + +@_register_kernel_internal(adjust_sharpness, torch.Tensor) @_register_kernel_internal(adjust_sharpness, datapoints.Image) def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor: num_channels, height, width = image.shape[-3:] @@ -279,7 +250,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) return output -adjust_sharpness_image_pil = _FP.adjust_sharpness +adjust_sharpness_image_pil = _register_kernel_internal(adjust_sharpness, PIL.Image.Image)(_FP.adjust_sharpness) @_register_kernel_internal(adjust_sharpness, datapoints.Video) @@ -289,21 +260,13 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(adjust_hue) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(adjust_hue, type(inpt)) - return kernel(inpt, hue_factor=hue_factor) - elif isinstance(inpt, PIL.Image.Image): - return adjust_hue_image_pil(inpt, hue_factor=hue_factor) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) + + _log_api_usage_once(adjust_hue) + + kernel = _get_kernel(adjust_hue, type(inpt)) + return kernel(inpt, hue_factor=hue_factor) def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor: @@ -370,6 +333,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor: return (a4.mul_(mask.unsqueeze(dim=-4))).sum(dim=-3) +@_register_kernel_internal(adjust_hue, torch.Tensor) @_register_kernel_internal(adjust_hue, datapoints.Image) def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor: if not (-0.5 <= hue_factor <= 0.5): @@ -398,7 +362,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten return to_dtype_image_tensor(image_hue_adj, orig_dtype, scale=True) -adjust_hue_image_pil = _FP.adjust_hue +adjust_hue_image_pil = _register_kernel_internal(adjust_hue, PIL.Image.Image)(_FP.adjust_hue) @_register_kernel_internal(adjust_hue, datapoints.Video) @@ -408,23 +372,16 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(adjust_gamma) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(adjust_gamma, type(inpt)) - return kernel(inpt, gamma=gamma, gain=gain) - elif isinstance(inpt, PIL.Image.Image): - return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + _log_api_usage_once(adjust_gamma) + kernel = _get_kernel(adjust_gamma, type(inpt)) + return kernel(inpt, gamma=gamma, gain=gain) + + +@_register_kernel_internal(adjust_gamma, torch.Tensor) @_register_kernel_internal(adjust_gamma, datapoints.Image) def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor: if gamma < 0: @@ -445,7 +402,7 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1 return to_dtype_image_tensor(output, image.dtype, scale=True) -adjust_gamma_image_pil = _FP.adjust_gamma +adjust_gamma_image_pil = _register_kernel_internal(adjust_gamma, PIL.Image.Image)(_FP.adjust_gamma) @_register_kernel_internal(adjust_gamma, datapoints.Video) @@ -455,23 +412,16 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(posterize) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return posterize_image_tensor(inpt, bits=bits) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(posterize, type(inpt)) - return kernel(inpt, bits=bits) - elif isinstance(inpt, PIL.Image.Image): - return posterize_image_pil(inpt, bits=bits) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) + + _log_api_usage_once(posterize) + + kernel = _get_kernel(posterize, type(inpt)) + return kernel(inpt, bits=bits) +@_register_kernel_internal(posterize, torch.Tensor) @_register_kernel_internal(posterize, datapoints.Image) def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor: if image.is_floating_point(): @@ -486,7 +436,7 @@ def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor: return image & mask -posterize_image_pil = _FP.posterize +posterize_image_pil = _register_kernel_internal(posterize, PIL.Image.Image)(_FP.posterize) @_register_kernel_internal(posterize, datapoints.Video) @@ -496,23 +446,16 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(solarize) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return solarize_image_tensor(inpt, threshold=threshold) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(solarize, type(inpt)) - return kernel(inpt, threshold=threshold) - elif isinstance(inpt, PIL.Image.Image): - return solarize_image_pil(inpt, threshold=threshold) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + _log_api_usage_once(solarize) + + kernel = _get_kernel(solarize, type(inpt)) + return kernel(inpt, threshold=threshold) + +@_register_kernel_internal(solarize, torch.Tensor) @_register_kernel_internal(solarize, datapoints.Image) def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor: if threshold > _max_value(image.dtype): @@ -521,7 +464,7 @@ def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor return torch.where(image >= threshold, invert_image_tensor(image), image) -solarize_image_pil = _FP.solarize +solarize_image_pil = _register_kernel_internal(solarize, PIL.Image.Image)(_FP.solarize) @_register_kernel_internal(solarize, datapoints.Video) @@ -531,25 +474,16 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(autocontrast) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return autocontrast_image_tensor(inpt) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(autocontrast, type(inpt)) - return kernel( - inpt, - ) - elif isinstance(inpt, PIL.Image.Image): - return autocontrast_image_pil(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) + _log_api_usage_once(autocontrast) + + kernel = _get_kernel(autocontrast, type(inpt)) + return kernel(inpt) + +@_register_kernel_internal(autocontrast, torch.Tensor) @_register_kernel_internal(autocontrast, datapoints.Image) def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor: c = image.shape[-3] @@ -580,7 +514,7 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor: return diff.div_(inv_scale).clamp_(0, bound).to(image.dtype) -autocontrast_image_pil = _FP.autocontrast +autocontrast_image_pil = _register_kernel_internal(autocontrast, PIL.Image.Image)(_FP.autocontrast) @_register_kernel_internal(autocontrast, datapoints.Video) @@ -590,25 +524,16 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(equalize) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return equalize_image_tensor(inpt) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(equalize, type(inpt)) - return kernel( - inpt, - ) - elif isinstance(inpt, PIL.Image.Image): - return equalize_image_pil(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + _log_api_usage_once(equalize) + kernel = _get_kernel(equalize, type(inpt)) + return kernel(inpt) + + +@_register_kernel_internal(equalize, torch.Tensor) @_register_kernel_internal(equalize, datapoints.Image) def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: if image.numel() == 0: @@ -679,7 +604,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: return to_dtype_image_tensor(output, output_dtype, scale=True) -equalize_image_pil = _FP.equalize +equalize_image_pil = _register_kernel_internal(equalize, PIL.Image.Image)(_FP.equalize) @_register_kernel_internal(equalize, datapoints.Video) @@ -689,25 +614,16 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def invert(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(invert) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return invert_image_tensor(inpt) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(invert, type(inpt)) - return kernel( - inpt, - ) - elif isinstance(inpt, PIL.Image.Image): - return invert_image_pil(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) + + _log_api_usage_once(invert) + + kernel = _get_kernel(invert, type(inpt)) + return kernel(inpt) +@_register_kernel_internal(invert, torch.Tensor) @_register_kernel_internal(invert, datapoints.Image) def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: if image.is_floating_point(): @@ -719,7 +635,7 @@ def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: return image.bitwise_xor((1 << _num_value_bits(image.dtype)) - 1) -invert_image_pil = _FP.invert +invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert) @_register_kernel_internal(invert, datapoints.Video) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 21f2aa8df0a..bb19def2c93 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -25,13 +25,7 @@ from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil -from ._utils import ( - _get_kernel, - _register_explicit_noop, - _register_five_ten_crop_kernel, - _register_kernel_internal, - is_simple_tensor, -) +from ._utils import _get_kernel, _register_explicit_noop, _register_five_ten_crop_kernel, _register_kernel_internal def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: @@ -46,30 +40,22 @@ def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> Interp def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(horizontal_flip) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return horizontal_flip_image_tensor(inpt) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(horizontal_flip, type(inpt)) - return kernel( - inpt, - ) - elif isinstance(inpt, PIL.Image.Image): - return horizontal_flip_image_pil(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + + _log_api_usage_once(horizontal_flip) + + kernel = _get_kernel(horizontal_flip, type(inpt)) + return kernel(inpt) +@_register_kernel_internal(horizontal_flip, torch.Tensor) @_register_kernel_internal(horizontal_flip, datapoints.Image) def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor: return image.flip(-1) +@_register_kernel_internal(horizontal_flip, PIL.Image.Image) def horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image: return _FP.hflip(image) @@ -110,30 +96,22 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor: def vertical_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(vertical_flip) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return vertical_flip_image_tensor(inpt) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(vertical_flip, type(inpt)) - return kernel( - inpt, - ) - elif isinstance(inpt, PIL.Image.Image): - return vertical_flip_image_pil(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
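
For the flip dispatchers, which follow the same pattern, here is a small usage sketch of my own (not from the diff): since the PIL kernels are now registered on the dispatchers directly, tensors, PIL images and datapoints all take the same code path. The inputs below are placeholders.

    import torch
    import torchvision

    torchvision.disable_beta_transforms_warning()
    import PIL.Image
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    tensor_img = torch.rand(3, 16, 16)
    pil_img = PIL.Image.new("RGB", (16, 16))

    flipped_tensor = F.horizontal_flip(tensor_img)  # torch.Tensor kernel
    flipped_pil = F.horizontal_flip(pil_img)        # PIL kernel registered above
    flipped_dp = F.vertical_flip(datapoints.Image(tensor_img))
    assert isinstance(flipped_dp, datapoints.Image)
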
- ) + + _log_api_usage_once(vertical_flip) + + kernel = _get_kernel(vertical_flip, type(inpt)) + return kernel(inpt) +@_register_kernel_internal(vertical_flip, torch.Tensor) @_register_kernel_internal(vertical_flip, datapoints.Image) def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor: return image.flip(-2) +@_register_kernel_internal(vertical_flip, PIL.Image.Image) def vertical_flip_image_pil(image: PIL.Image) -> PIL.Image: return _FP.vflip(image) @@ -199,24 +177,16 @@ def resize( max_size: Optional[int] = None, antialias: Optional[Union[str, bool]] = "warn", ) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(resize) - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(resize, type(inpt)) - return kernel(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) - elif isinstance(inpt, PIL.Image.Image): - if antialias is False: - warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.") - return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + if torch.jit.is_scripting(): + return resize_image_tensor(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) + + _log_api_usage_once(resize) + + kernel = _get_kernel(resize, type(inpt)) + return kernel(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) +@_register_kernel_internal(resize, torch.Tensor) @_register_kernel_internal(resize, datapoints.Image) def resize_image_tensor( image: torch.Tensor, @@ -297,7 +267,6 @@ def resize_image_tensor( return image.reshape(shape[:-3] + (num_channels, new_height, new_width)) -@torch.jit.unused def resize_image_pil( image: PIL.Image.Image, size: Union[Sequence[int], int], @@ -319,6 +288,19 @@ def resize_image_pil( return image.resize((new_width, new_height), resample=pil_modes_mapping[interpolation]) +@_register_kernel_internal(resize, PIL.Image.Image) +def _resize_image_pil_dispatch( + image: PIL.Image.Image, + size: Union[Sequence[int], int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[Union[str, bool]] = "warn", +) -> PIL.Image.Image: + if antialias is False: + warnings.warn("Anti-alias option is always applied for PIL Image input. 
Argument antialias is ignored.") + return resize_image_pil(image, size=size, interpolation=interpolation, max_size=max_size) + + def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor: if mask.ndim < 3: mask = mask.unsqueeze(0) @@ -391,26 +373,10 @@ def affine( fill: datapoints._FillTypeJIT = None, center: Optional[List[float]] = None, ) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(affine) - - # TODO: consider deprecating integers from angle and shear on the future - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return affine_image_tensor( inpt, - angle, - translate=translate, - scale=scale, - shear=shear, - interpolation=interpolation, - fill=fill, - center=center, - ) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(affine, type(inpt)) - return kernel( - inpt, - angle, + angle=angle, translate=translate, scale=scale, shear=shear, @@ -418,22 +384,20 @@ def affine( fill=fill, center=center, ) - elif isinstance(inpt, PIL.Image.Image): - return affine_image_pil( - inpt, - angle, - translate=translate, - scale=scale, - shear=shear, - interpolation=interpolation, - fill=fill, - center=center, - ) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + + _log_api_usage_once(affine) + + kernel = _get_kernel(affine, type(inpt)) + return kernel( + inpt, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + center=center, + ) def _affine_parse_args( @@ -684,6 +648,7 @@ def _affine_grid( return output_grid.view(1, oh, ow, 2) +@_register_kernel_internal(affine, torch.Tensor) @_register_kernel_internal(affine, datapoints.Image) def affine_image_tensor( image: torch.Tensor, @@ -736,7 +701,7 @@ def affine_image_tensor( return output -@torch.jit.unused +@_register_kernel_internal(affine, PIL.Image.Image) def affine_image_pil( image: PIL.Image.Image, angle: Union[int, float], @@ -983,23 +948,18 @@ def rotate( center: Optional[List[float]] = None, fill: datapoints._FillTypeJIT = None, ) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(rotate) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(rotate, type(inpt)) - return kernel(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) - elif isinstance(inpt, PIL.Image.Image): - return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
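
Note that ``resize`` and ``affine`` now forward their arguments to the resolved kernel by keyword, so kernels registered for custom datapoints should accept these parameters by name. A rough, non-authoritative sketch with arbitrary values:

    import torch
    import torchvision

    torchvision.disable_beta_transforms_warning()
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    img = datapoints.Image(torch.rand(3, 64, 48))

    resized = F.resize(img, size=[32, 24], antialias=True)
    assert resized.shape[-2:] == (32, 24)

    # angle/translate/scale/shear reach the kernel as keyword arguments
    out = F.affine(img, angle=15.0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0])
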
+ if torch.jit.is_scripting(): + return rotate_image_tensor( + inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center ) + _log_api_usage_once(rotate) + kernel = _get_kernel(rotate, type(inpt)) + return kernel(inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center) + + +@_register_kernel_internal(rotate, torch.Tensor) @_register_kernel_internal(rotate, datapoints.Image) def rotate_image_tensor( image: torch.Tensor, @@ -1045,7 +1005,7 @@ def rotate_image_tensor( return output.reshape(shape[:-3] + (num_channels, new_height, new_width)) -@torch.jit.unused +@_register_kernel_internal(rotate, PIL.Image.Image) def rotate_image_pil( image: PIL.Image.Image, angle: float, @@ -1162,22 +1122,13 @@ def pad( fill: Optional[Union[int, float, List[float]]] = None, padding_mode: str = "constant", ) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(pad) + if torch.jit.is_scripting(): + return pad_image_tensor(inpt, padding=padding, fill=fill, padding_mode=padding_mode) - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) + _log_api_usage_once(pad) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(pad, type(inpt)) - return kernel(inpt, padding, fill=fill, padding_mode=padding_mode) - elif isinstance(inpt, PIL.Image.Image): - return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + kernel = _get_kernel(pad, type(inpt)) + return kernel(inpt, padding=padding, fill=fill, padding_mode=padding_mode) def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]: @@ -1204,6 +1155,7 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]: return [pad_left, pad_right, pad_top, pad_bottom] +@_register_kernel_internal(pad, torch.Tensor) @_register_kernel_internal(pad, datapoints.Image) def pad_image_tensor( image: torch.Tensor, @@ -1303,7 +1255,7 @@ def _pad_with_vector_fill( return output -pad_image_pil = _FP.pad +pad_image_pil = _register_kernel_internal(pad, PIL.Image.Image)(_FP.pad) @_register_kernel_internal(pad, datapoints.Mask) @@ -1385,23 +1337,16 @@ def pad_video( def crop(inpt: datapoints._InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(crop) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return crop_image_tensor(inpt, top, left, height, width) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(crop, type(inpt)) - return kernel(inpt, top, left, height, width) - elif isinstance(inpt, PIL.Image.Image): - return crop_image_pil(inpt, top, left, height, width) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
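
One more illustrative call pattern (mine, not part of the change) for the rotate and pad dispatchers rewritten above; the padding order assumed here is left, top, right, bottom, matching ``_parse_pad_padding``:

    import torch
    import torchvision

    torchvision.disable_beta_transforms_warning()
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    img = datapoints.Image(torch.rand(3, 32, 32))

    rotated = F.rotate(img, angle=30.0, expand=True)   # expand=True grows the canvas
    padded = F.pad(img, padding=[2, 2, 4, 4], fill=0)  # left=2, top=2, right=4, bottom=4
    assert padded.shape[-2:] == (38, 38)
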
- ) + if torch.jit.is_scripting(): + return crop_image_tensor(inpt, top=top, left=left, height=height, width=width) + + _log_api_usage_once(crop) + kernel = _get_kernel(crop, type(inpt)) + return kernel(inpt, top=top, left=left, height=height, width=width) + +@_register_kernel_internal(crop, torch.Tensor) @_register_kernel_internal(crop, datapoints.Image) def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: h, w = image.shape[-2:] @@ -1422,6 +1367,7 @@ def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, wid crop_image_pil = _FP.crop +_register_kernel_internal(crop, PIL.Image.Image)(crop_image_pil) def crop_bounding_boxes( @@ -1484,25 +1430,28 @@ def perspective( fill: datapoints._FillTypeJIT = None, coefficients: Optional[List[float]] = None, ) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(perspective) - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return perspective_image_tensor( - inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients - ) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(perspective, type(inpt)) - return kernel(inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients) - elif isinstance(inpt, PIL.Image.Image): - return perspective_image_pil( - inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients - ) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." + inpt, + startpoints=startpoints, + endpoints=endpoints, + interpolation=interpolation, + fill=fill, + coefficients=coefficients, ) + _log_api_usage_once(perspective) + + kernel = _get_kernel(perspective, type(inpt)) + return kernel( + inpt, + startpoints=startpoints, + endpoints=endpoints, + interpolation=interpolation, + fill=fill, + coefficients=coefficients, + ) + def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor: # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/ @@ -1551,6 +1500,7 @@ def _perspective_coefficients( raise ValueError("Either the startpoints/endpoints or the coefficients must have non `None` values.") +@_register_kernel_internal(perspective, torch.Tensor) @_register_kernel_internal(perspective, datapoints.Image) def perspective_image_tensor( image: torch.Tensor, @@ -1598,7 +1548,7 @@ def perspective_image_tensor( return output -@torch.jit.unused +@_register_kernel_internal(perspective, PIL.Image.Image) def perspective_image_pil( image: PIL.Image.Image, startpoints: Optional[List[List[int]]], @@ -1787,29 +1737,19 @@ def elastic( interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, fill: datapoints._FillTypeJIT = None, ) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(elastic) - - if not isinstance(displacement, torch.Tensor): - raise TypeError("Argument displacement should be a Tensor") - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(elastic, type(inpt)) - return kernel(inpt, displacement, interpolation=interpolation, fill=fill) - elif isinstance(inpt, PIL.Image.Image): - return 
elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + if torch.jit.is_scripting(): + return elastic_image_tensor(inpt, displacement=displacement, interpolation=interpolation, fill=fill) + + _log_api_usage_once(elastic) + + kernel = _get_kernel(elastic, type(inpt)) + return kernel(inpt, displacement=displacement, interpolation=interpolation, fill=fill) elastic_transform = elastic +@_register_kernel_internal(elastic, torch.Tensor) @_register_kernel_internal(elastic, datapoints.Image) def elastic_image_tensor( image: torch.Tensor, @@ -1867,7 +1807,7 @@ def elastic_image_tensor( return output -@torch.jit.unused +@_register_kernel_internal(elastic, PIL.Image.Image) def elastic_image_pil( image: PIL.Image.Image, displacement: torch.Tensor, @@ -1990,21 +1930,13 @@ def elastic_video( def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(center_crop) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return center_crop_image_tensor(inpt, output_size) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(center_crop, type(inpt)) - return kernel(inpt, output_size) - elif isinstance(inpt, PIL.Image.Image): - return center_crop_image_pil(inpt, output_size) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + if torch.jit.is_scripting(): + return center_crop_image_tensor(inpt, output_size=output_size) + + _log_api_usage_once(center_crop) + + kernel = _get_kernel(center_crop, type(inpt)) + return kernel(inpt, output_size=output_size) def _center_crop_parse_output_size(output_size: List[int]) -> List[int]: @@ -2034,6 +1966,7 @@ def _center_crop_compute_crop_anchor( return crop_top, crop_left +@_register_kernel_internal(center_crop, torch.Tensor) @_register_kernel_internal(center_crop, datapoints.Image) def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor: crop_height, crop_width = _center_crop_parse_output_size(output_size) @@ -2054,7 +1987,7 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> tor return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)] -@torch.jit.unused +@_register_kernel_internal(center_crop, PIL.Image.Image) def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image: crop_height, crop_width = _center_crop_parse_output_size(output_size) image_height, image_width = get_size_image_pil(image) @@ -2125,25 +2058,34 @@ def resized_crop( interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, antialias: Optional[Union[str, bool]] = "warn", ) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(resized_crop) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return resized_crop_image_tensor( - inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation - ) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(resized_crop, type(inpt)) - return kernel(inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation) - elif isinstance(inpt, PIL.Image.Image): - return resized_crop_image_pil(inpt, top, 
left, height, width, size=size, interpolation=interpolation) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." + inpt, + top=top, + left=left, + height=height, + width=width, + size=size, + interpolation=interpolation, + antialias=antialias, ) + _log_api_usage_once(resized_crop) + + kernel = _get_kernel(resized_crop, type(inpt)) + return kernel( + inpt, + top=top, + left=left, + height=height, + width=width, + size=size, + interpolation=interpolation, + antialias=antialias, + ) + +@_register_kernel_internal(resized_crop, torch.Tensor) @_register_kernel_internal(resized_crop, datapoints.Image) def resized_crop_image_tensor( image: torch.Tensor, @@ -2159,7 +2101,6 @@ def resized_crop_image_tensor( return resize_image_tensor(image, size, interpolation=interpolation, antialias=antialias) -@torch.jit.unused def resized_crop_image_pil( image: PIL.Image.Image, top: int, @@ -2173,6 +2114,30 @@ def resized_crop_image_pil( return resize_image_pil(image, size, interpolation=interpolation) +@_register_kernel_internal(resized_crop, PIL.Image.Image) +def resized_crop_image_pil_dispatch( + image: PIL.Image.Image, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[Union[str, bool]] = "warn", +) -> PIL.Image.Image: + if antialias is False: + warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.") + return resized_crop_image_pil( + image, + top=top, + left=left, + height=height, + width=width, + size=size, + interpolation=interpolation, + ) + + def resized_crop_bounding_boxes( bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, @@ -2244,21 +2209,13 @@ def five_crop( datapoints._InputTypeJIT, datapoints._InputTypeJIT, ]: - if not torch.jit.is_scripting(): - _log_api_usage_once(five_crop) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return five_crop_image_tensor(inpt, size) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(five_crop, type(inpt)) - return kernel(inpt, size) - elif isinstance(inpt, PIL.Image.Image): - return five_crop_image_pil(inpt, size) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
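
The ``resized_crop`` dispatcher above gains a small PIL wrapper so that, as the warning notes, ``antialias`` is ignored for PIL inputs. For completeness, an illustrative call of my own (not from the patch, coordinates made up):

    import torch
    import torchvision

    torchvision.disable_beta_transforms_warning()
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    img = datapoints.Image(torch.rand(3, 128, 128))

    crop = F.resized_crop(img, top=10, left=10, height=100, width=100, size=[64, 64], antialias=True)
    assert crop.shape[-2:] == (64, 64)

    center = F.center_crop(img, output_size=[96, 96])
    assert center.shape[-2:] == (96, 96)
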
- ) + if torch.jit.is_scripting(): + return five_crop_image_tensor(inpt, size=size) + + _log_api_usage_once(five_crop) + + kernel = _get_kernel(five_crop, type(inpt)) + return kernel(inpt, size=size) def _parse_five_crop_size(size: List[int]) -> List[int]: @@ -2275,6 +2232,7 @@ def _parse_five_crop_size(size: List[int]) -> List[int]: return size +@_register_five_ten_crop_kernel(five_crop, torch.Tensor) @_register_five_ten_crop_kernel(five_crop, datapoints.Image) def five_crop_image_tensor( image: torch.Tensor, size: List[int] @@ -2294,7 +2252,7 @@ def five_crop_image_tensor( return tl, tr, bl, br, center -@torch.jit.unused +@_register_five_ten_crop_kernel(five_crop, PIL.Image.Image) def five_crop_image_pil( image: PIL.Image.Image, size: List[int] ) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]: @@ -2335,23 +2293,16 @@ def ten_crop( datapoints._InputTypeJIT, datapoints._InputTypeJIT, ]: - if not torch.jit.is_scripting(): - _log_api_usage_once(ten_crop) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(ten_crop, type(inpt)) - return kernel(inpt, size, vertical_flip=vertical_flip) - elif isinstance(inpt, PIL.Image.Image): - return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + if torch.jit.is_scripting(): + return ten_crop_image_tensor(inpt, size=size, vertical_flip=vertical_flip) + + _log_api_usage_once(ten_crop) + + kernel = _get_kernel(ten_crop, type(inpt)) + return kernel(inpt, size=size, vertical_flip=vertical_flip) +@_register_five_ten_crop_kernel(ten_crop, torch.Tensor) @_register_five_ten_crop_kernel(ten_crop, datapoints.Image) def ten_crop_image_tensor( image: torch.Tensor, size: List[int], vertical_flip: bool = False @@ -2379,7 +2330,7 @@ def ten_crop_image_tensor( return non_flipped + flipped -@torch.jit.unused +@_register_five_ten_crop_kernel(ten_crop, PIL.Image.Image) def ten_crop_image_pil( image: PIL.Image.Image, size: List[int], vertical_flip: bool = False ) -> Tuple[ diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index a4bfe7df8e4..fc1aa05f319 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -13,23 +13,16 @@ @_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask) def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]: - if not torch.jit.is_scripting(): - _log_api_usage_once(get_dimensions) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return get_dimensions_image_tensor(inpt) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(get_dimensions, type(inpt)) - return kernel(inpt) - elif isinstance(inpt, PIL.Image.Image): - return get_dimensions_image_pil(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
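
``five_crop`` and ``ten_crop`` return tuples of crops rather than a single output, which is why they are wired up through the dedicated ``_register_five_ten_crop_kernel`` helper instead of the regular registration path. A quick sketch of my own, assuming the beta v2 API:

    import torch
    import torchvision

    torchvision.disable_beta_transforms_warning()
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    img = datapoints.Image(torch.rand(3, 64, 64))

    crops = F.five_crop(img, size=[32, 32])
    assert len(crops) == 5 and all(c.shape[-2:] == (32, 32) for c in crops)

    crops_and_flips = F.ten_crop(img, size=[32, 32], vertical_flip=True)
    assert len(crops_and_flips) == 10
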
- ) + + _log_api_usage_once(get_dimensions) + + kernel = _get_kernel(get_dimensions, type(inpt)) + return kernel(inpt) +@_register_kernel_internal(get_dimensions, torch.Tensor) @_register_kernel_internal(get_dimensions, datapoints.Image, datapoint_wrapper=False) def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]: chw = list(image.shape[-3:]) @@ -43,7 +36,7 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]: raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}") -get_dimensions_image_pil = _FP.get_dimensions +get_dimensions_image_pil = _register_kernel_internal(get_dimensions, PIL.Image.Image)(_FP.get_dimensions) @_register_kernel_internal(get_dimensions, datapoints.Video, datapoint_wrapper=False) @@ -53,23 +46,16 @@ def get_dimensions_video(video: torch.Tensor) -> List[int]: @_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask) def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int: - if not torch.jit.is_scripting(): - _log_api_usage_once(get_num_channels) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return get_num_channels_image_tensor(inpt) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(get_num_channels, type(inpt)) - return kernel(inpt) - elif isinstance(inpt, PIL.Image.Image): - return get_num_channels_image_pil(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + + _log_api_usage_once(get_num_channels) + + kernel = _get_kernel(get_num_channels, type(inpt)) + return kernel(inpt) +@_register_kernel_internal(get_num_channels, torch.Tensor) @_register_kernel_internal(get_num_channels, datapoints.Image, datapoint_wrapper=False) def get_num_channels_image_tensor(image: torch.Tensor) -> int: chw = image.shape[-3:] @@ -82,7 +68,7 @@ def get_num_channels_image_tensor(image: torch.Tensor) -> int: raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}") -get_num_channels_image_pil = _FP.get_image_num_channels +get_num_channels_image_pil = _register_kernel_internal(get_num_channels, PIL.Image.Image)(_FP.get_image_num_channels) @_register_kernel_internal(get_num_channels, datapoints.Video, datapoint_wrapper=False) @@ -96,23 +82,16 @@ def get_num_channels_video(video: torch.Tensor) -> int: def get_size(inpt: datapoints._InputTypeJIT) -> List[int]: - if not torch.jit.is_scripting(): - _log_api_usage_once(get_size) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return get_size_image_tensor(inpt) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(get_size, type(inpt)) - return kernel(inpt) - elif isinstance(inpt, PIL.Image.Image): - return get_size_image_pil(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) + + _log_api_usage_once(get_size) + + kernel = _get_kernel(get_size, type(inpt)) + return kernel(inpt) +@_register_kernel_internal(get_size, torch.Tensor) @_register_kernel_internal(get_size, datapoints.Image, datapoint_wrapper=False) def get_size_image_tensor(image: torch.Tensor) -> List[int]: hw = list(image.shape[-2:]) @@ -123,7 +102,7 @@ def get_size_image_tensor(image: torch.Tensor) -> List[int]: raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}") -@torch.jit.unused +@_register_kernel_internal(get_size, PIL.Image.Image) def get_size_image_pil(image: PIL.Image.Image) -> List[int]: width, height = _FP.get_image_size(image) return [height, width] @@ -146,21 +125,16 @@ def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int] @_register_unsupported_type(PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask) def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int: - if not torch.jit.is_scripting(): - _log_api_usage_once(get_num_frames) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return get_num_frames_video(inpt) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(get_num_frames, type(inpt)) - return kernel(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + + _log_api_usage_once(get_num_frames) + + kernel = _get_kernel(get_num_frames, type(inpt)) + return kernel(inpt) +@_register_kernel_internal(get_num_frames, torch.Tensor) @_register_kernel_internal(get_num_frames, datapoints.Video, datapoint_wrapper=False) def get_num_frames_video(video: torch.Tensor) -> int: return video.shape[-4] diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 90a3e44e9d3..e3a800ea79c 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -11,13 +11,7 @@ from torchvision.utils import _log_api_usage_once -from ._utils import ( - _get_kernel, - _register_explicit_noop, - _register_kernel_internal, - _register_unsupported_type, - is_simple_tensor, -) +from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, _register_unsupported_type @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) @@ -28,19 +22,16 @@ def normalize( std: List[float], inplace: bool = False, ) -> torch.Tensor: - if not torch.jit.is_scripting(): - _log_api_usage_once(normalize) - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(normalize, type(inpt)) - return kernel(inpt, mean=mean, std=std, inplace=inplace) - else: - raise TypeError( - f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead." 
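
The metadata helpers whose dispatch was just simplified report sizes in ``[height, width]`` and ``[channels, height, width]`` order. A short sanity check (illustrative only, not part of the diff):

    import torch
    import torchvision

    torchvision.disable_beta_transforms_warning()
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    img = datapoints.Image(torch.rand(3, 48, 64))

    assert F.get_size(img) == [48, 64]           # [height, width]
    assert F.get_dimensions(img) == [3, 48, 64]  # [channels, height, width]
    assert F.get_num_channels(img) == 3
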
- ) + + _log_api_usage_once(normalize) + + kernel = _get_kernel(normalize, type(inpt)) + return kernel(inpt, mean=mean, std=std, inplace=inplace) +@_register_kernel_internal(normalize, torch.Tensor) @_register_kernel_internal(normalize, datapoints.Image) def normalize_image_tensor( image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False @@ -86,21 +77,13 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], in def gaussian_blur( inpt: datapoints._InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None ) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(gaussian_blur) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(gaussian_blur, type(inpt)) - return kernel(inpt, kernel_size=kernel_size, sigma=sigma) - elif isinstance(inpt, PIL.Image.Image): - return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + + _log_api_usage_once(gaussian_blur) + + kernel = _get_kernel(gaussian_blur, type(inpt)) + return kernel(inpt, kernel_size=kernel_size, sigma=sigma) def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor: @@ -119,6 +102,7 @@ def _get_gaussian_kernel2d( return kernel2d +@_register_kernel_internal(gaussian_blur, torch.Tensor) @_register_kernel_internal(gaussian_blur, datapoints.Image) def gaussian_blur_image_tensor( image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None @@ -184,7 +168,7 @@ def gaussian_blur_image_tensor( return output -@torch.jit.unused +@_register_kernel_internal(gaussian_blur, PIL.Image.Image) def gaussian_blur_image_pil( image: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None ) -> PIL.Image.Image: @@ -200,21 +184,17 @@ def gaussian_blur_video( return gaussian_blur_image_tensor(video, kernel_size, sigma) +@_register_unsupported_type(PIL.Image.Image) def to_dtype( inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False ) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(to_dtype) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return to_dtype_image_tensor(inpt, dtype, scale=scale) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(to_dtype, type(inpt)) - return kernel(inpt, dtype, scale=scale) - else: - raise TypeError( - f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead." 
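
The ``_misc`` functionals follow the same single-path dispatch; note that ``to_dtype`` now explicitly rejects PIL images via ``_register_unsupported_type``. A hedged example with arbitrary values (not from the patch):

    import torch
    import torchvision

    torchvision.disable_beta_transforms_warning()
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    img = datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))

    floats = F.to_dtype(img, dtype=torch.float32, scale=True)  # rescales values to [0, 1]
    blurred = F.gaussian_blur(floats, kernel_size=[3, 3])
    normalized = F.normalize(floats, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
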
- ) + if torch.jit.is_scripting(): + return to_dtype_image_tensor(inpt, dtype=dtype, scale=scale) + + _log_api_usage_once(to_dtype) + + kernel = _get_kernel(to_dtype, type(inpt)) + return kernel(inpt, dtype=dtype, scale=scale) def _num_value_bits(dtype: torch.dtype) -> int: @@ -232,6 +212,7 @@ def _num_value_bits(dtype: torch.dtype) -> int: raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.") +@_register_kernel_internal(to_dtype, torch.Tensor) @_register_kernel_internal(to_dtype, datapoints.Image) def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor: diff --git a/torchvision/transforms/v2/functional/_temporal.py b/torchvision/transforms/v2/functional/_temporal.py index 52c745f9901..62d12cb4b4e 100644 --- a/torchvision/transforms/v2/functional/_temporal.py +++ b/torchvision/transforms/v2/functional/_temporal.py @@ -5,27 +5,23 @@ from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor +from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal @_register_explicit_noop( PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True ) def uniform_temporal_subsample(inpt: datapoints._VideoTypeJIT, num_samples: int) -> datapoints._VideoTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(uniform_temporal_subsample) + if torch.jit.is_scripting(): + return uniform_temporal_subsample_video(inpt, num_samples=num_samples) - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return uniform_temporal_subsample_video(inpt, num_samples) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(uniform_temporal_subsample, type(inpt)) - return kernel(inpt, num_samples) - else: - raise TypeError( - f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead." - ) + _log_api_usage_once(uniform_temporal_subsample) + kernel = _get_kernel(uniform_temporal_subsample, type(inpt)) + return kernel(inpt, num_samples=num_samples) + +@_register_kernel_internal(uniform_temporal_subsample, torch.Tensor) @_register_kernel_internal(uniform_temporal_subsample, datapoints.Video) def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor: # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19 diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index bb3d59b551a..576a2b99dbf 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -23,15 +23,17 @@ def wrapper(inpt, *args, **kwargs): return wrapper -def _register_kernel_internal(dispatcher, datapoint_cls, *, datapoint_wrapper=True): +def _register_kernel_internal(dispatcher, input_type, *, datapoint_wrapper=True): registry = _KERNEL_REGISTRY.setdefault(dispatcher, {}) - if datapoint_cls in registry: - raise TypeError( - f"Dispatcher '{dispatcher.__name__}' already has a kernel registered for type '{datapoint_cls.__name__}'." 
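
For the temporal dispatcher above, an illustrative call of my own (not part of the diff); it assumes the usual ``(..., T, C, H, W)`` video layout, so subsampling picks frames along the ``T`` dimension:

    import torch
    import torchvision

    torchvision.disable_beta_transforms_warning()
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    video = datapoints.Video(torch.rand(8, 3, 16, 16))  # 8 frames of 3x16x16

    subsampled = F.uniform_temporal_subsample(video, num_samples=4)
    assert subsampled.shape[0] == 4 and isinstance(subsampled, datapoints.Video)
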
- ) + if input_type in registry: + raise ValueError(f"Dispatcher {dispatcher} already has a kernel registered for type {input_type}.") def decorator(kernel): - registry[datapoint_cls] = _kernel_datapoint_wrapper(kernel) if datapoint_wrapper else kernel + registry[input_type] = ( + _kernel_datapoint_wrapper(kernel) + if issubclass(input_type, datapoints.Datapoint) and datapoint_wrapper + else kernel + ) return kernel return decorator @@ -43,7 +45,9 @@ def _name_to_dispatcher(name): try: return getattr(torchvision.transforms.v2.functional, name) except AttributeError: - raise ValueError(f"Could not find dispatcher with name '{name}'.") from None + raise ValueError( + f"Could not find dispatcher with name '{name}' in torchvision.transforms.v2.functional." + ) from None def register_kernel(dispatcher, datapoint_cls): @@ -54,22 +58,57 @@ def register_kernel(dispatcher, datapoint_cls): """ if isinstance(dispatcher, str): dispatcher = _name_to_dispatcher(name=dispatcher) + elif not ( + callable(dispatcher) + and getattr(dispatcher, "__module__", "").startswith("torchvision.transforms.v2.functional") + ): + raise ValueError( + f"Kernels can only be registered on dispatchers from the torchvision.transforms.v2.functional namespace, " + f"but got {dispatcher}." + ) + + if not ( + isinstance(datapoint_cls, type) + and issubclass(datapoint_cls, datapoints.Datapoint) + and datapoint_cls is not datapoints.Datapoint + ): + raise ValueError( + f"Kernels can only be registered for subclasses of torchvision.datapoints.Datapoint, " + f"but got {datapoint_cls}." + ) + return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False) -def _get_kernel(dispatcher, datapoint_cls): +def _get_kernel(dispatcher, input_type): registry = _KERNEL_REGISTRY.get(dispatcher) if not registry: - raise ValueError(f"No kernel registered for dispatcher '{dispatcher.__name__}'.") - - if datapoint_cls in registry: - return registry[datapoint_cls] - - for registered_cls, kernel in registry.items(): - if issubclass(datapoint_cls, registered_cls): - return kernel - - return _noop + raise ValueError(f"No kernel registered for dispatcher {dispatcher.__name__}.") + + # In case we have an exact type match, we take a shortcut. + if input_type in registry: + return registry[input_type] + + # In case of datapoints, we check if we have a kernel for a superclass registered + if issubclass(input_type, datapoints.Datapoint): + # Since we have already checked for an exact match above, we can start the traversal at the superclass. + for cls in input_type.__mro__[1:]: + if cls is datapoints.Datapoint: + # We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicit stop the + # MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we don't + # allow kernels to be registered for datapoints.Datapoint anyway. + break + elif cls in registry: + return registry[cls] + + # Note that in the future we are not going to return a noop here, but rather raise the error below + return _noop + + raise TypeError( + f"Dispatcher {dispatcher} supports inputs of type torch.Tensor, PIL.Image.Image, " + f"and subclasses of torchvision.datapoints.Datapoint, " + f"but got {input_type} instead." + ) # Everything below this block is stuff that we need right now, since it looks like we need to release in an intermediate @@ -101,7 +140,9 @@ def decorator(dispatcher): f"F.{dispatcher.__name__} is currently passing through inputs of type datapoints.{cls.__name__}. 
" f"This will likely change in the future." ) - register_kernel(dispatcher, cls)(functools.partial(_noop, __msg__=msg if warn_passthrough else None)) + _register_kernel_internal(dispatcher, cls, datapoint_wrapper=False)( + functools.partial(_noop, __msg__=msg if warn_passthrough else None) + ) return dispatcher return decorator @@ -115,13 +156,15 @@ def _noop(inpt, *args, __msg__=None, **kwargs): # TODO: we only need this, since our default behavior in case no kernel is found is passthrough. When we change that # to error later, this decorator can be removed, since the error will be raised by _get_kernel -def _register_unsupported_type(*datapoints_classes): +def _register_unsupported_type(*input_types): def kernel(inpt, *args, __dispatcher_name__, **kwargs): raise TypeError(f"F.{__dispatcher_name__} does not support inputs of type {type(inpt)}.") def decorator(dispatcher): - for cls in datapoints_classes: - register_kernel(dispatcher, cls)(functools.partial(kernel, __dispatcher_name__=dispatcher.__name__)) + for input_type in input_types: + _register_kernel_internal(dispatcher, input_type, datapoint_wrapper=False)( + functools.partial(kernel, __dispatcher_name__=dispatcher.__name__) + ) return dispatcher return decorator @@ -129,13 +172,10 @@ def decorator(dispatcher): # This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop # We could get rid of this by letting _register_kernel_internal take arbitrary dispatchers rather than wrap_kernel: bool -# TODO: decide if we want that -def _register_five_ten_crop_kernel(dispatcher, datapoint_cls): +def _register_five_ten_crop_kernel(dispatcher, input_type): registry = _KERNEL_REGISTRY.setdefault(dispatcher, {}) - if datapoint_cls in registry: - raise TypeError( - f"Dispatcher '{dispatcher.__name__}' already has a kernel registered for type '{datapoint_cls.__name__}'." 
- ) + if input_type in registry: + raise TypeError(f"Dispatcher '{dispatcher}' already has a kernel registered for type '{input_type}'.") def wrap(kernel): @functools.wraps(kernel) @@ -147,7 +187,7 @@ def wrapper(inpt, *args, **kwargs): return wrapper def decorator(kernel): - registry[datapoint_cls] = wrap(kernel) + registry[input_type] = wrap(kernel) if issubclass(input_type, datapoints.Datapoint) else kernel return kernel return decorator From 9b82df43341a6891f652be1803abd1d1d05bfbb2 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 7 Aug 2023 14:16:12 +0100 Subject: [PATCH 3/8] Remove `_wrap()` class method from base class Datapoint (#7805) --- test/test_datapoints.py | 20 ++++++++++++++++++++ torchvision/datapoints/_bounding_box.py | 13 ++++--------- torchvision/datapoints/_datapoint.py | 6 +----- torchvision/datapoints/_image.py | 2 +- torchvision/datapoints/_mask.py | 2 +- torchvision/datapoints/_video.py | 2 +- torchvision/prototype/datapoints/_label.py | 2 +- 7 files changed, 29 insertions(+), 18 deletions(-) diff --git a/test/test_datapoints.py b/test/test_datapoints.py index f0a44ec1720..25a2182e050 100644 --- a/test/test_datapoints.py +++ b/test/test_datapoints.py @@ -113,6 +113,26 @@ def test_detach_wrapping(): assert type(image_detached) is datapoints.Image +def test_no_wrapping_exceptions_with_metadata(): + # Sanity checks for the ops in _NO_WRAPPING_EXCEPTIONS and datapoints with metadata + format, canvas_size = "XYXY", (32, 32) + bbox = datapoints.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size) + + bbox = bbox.clone() + assert bbox.format, bbox.canvas_size == (format, canvas_size) + + bbox = bbox.to(torch.float64) + assert bbox.format, bbox.canvas_size == (format, canvas_size) + + bbox = bbox.detach() + assert bbox.format, bbox.canvas_size == (format, canvas_size) + + assert not bbox.requires_grad + bbox.requires_grad_(True) + assert bbox.format, bbox.canvas_size == (format, canvas_size) + assert bbox.requires_grad + + def test_other_op_no_wrapping(): image = datapoints.Image(torch.rand(3, 16, 16)) diff --git a/torchvision/datapoints/_bounding_box.py b/torchvision/datapoints/_bounding_box.py index 7477b3652dc..9677cef21e6 100644 --- a/torchvision/datapoints/_bounding_box.py +++ b/torchvision/datapoints/_bounding_box.py @@ -42,7 +42,9 @@ class BoundingBoxes(Datapoint): canvas_size: Tuple[int, int] @classmethod - def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, canvas_size: Tuple[int, int]) -> BoundingBoxes: # type: ignore[override] + def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int]) -> BoundingBoxes: # type: ignore[override] + if isinstance(format, str): + format = BoundingBoxFormat[format.upper()] bounding_boxes = tensor.as_subclass(cls) bounding_boxes.format = format bounding_boxes.canvas_size = canvas_size @@ -59,10 +61,6 @@ def __new__( requires_grad: Optional[bool] = None, ) -> BoundingBoxes: tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) - - if isinstance(format, str): - format = BoundingBoxFormat[format.upper()] - return cls._wrap(tensor, format=format, canvas_size=canvas_size) @classmethod @@ -71,7 +69,7 @@ def wrap_like( other: BoundingBoxes, tensor: torch.Tensor, *, - format: Optional[BoundingBoxFormat] = None, + format: Optional[Union[BoundingBoxFormat, str]] = None, canvas_size: Optional[Tuple[int, int]] = None, ) -> BoundingBoxes: """Wrap a :class:`torch.Tensor` as :class:`BoundingBoxes` from a reference. 
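
A quick sanity sketch (not part of the patch) of what moving the string handling into ``_wrap`` preserves for users: both the constructor and ``wrap_like`` keep accepting plain strings for ``format``, and the metadata survives re-wrapping. The coordinates below are arbitrary.

    import torch
    import torchvision

    torchvision.disable_beta_transforms_warning()
    from torchvision import datapoints

    boxes = datapoints.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format="xyxy", canvas_size=(32, 32))
    assert boxes.format is datapoints.BoundingBoxFormat.XYXY

    # wrap_like re-attaches format/canvas_size to a plain tensor, optionally overriding the format
    plain = boxes.as_subclass(torch.Tensor)
    wrapped = datapoints.BoundingBoxes.wrap_like(boxes, plain, format="XYWH")
    assert wrapped.format is datapoints.BoundingBoxFormat.XYWH
    assert wrapped.canvas_size == (32, 32)
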
@@ -85,9 +83,6 @@ def wrap_like( omitted, it is taken from the reference. """ - if isinstance(format, str): - format = BoundingBoxFormat[format.upper()] - return cls._wrap( tensor, format=format if format is not None else other.format, diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py index fae3c18656b..9b1c648648d 100644 --- a/torchvision/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -32,13 +32,9 @@ def _to_tensor( requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad) - @classmethod - def _wrap(cls: Type[D], tensor: torch.Tensor) -> D: - return tensor.as_subclass(cls) - @classmethod def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D: - return cls._wrap(tensor) + return tensor.as_subclass(cls) _NO_WRAPPING_EXCEPTIONS = { torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output), diff --git a/torchvision/datapoints/_image.py b/torchvision/datapoints/_image.py index 9b635e8e034..cf7b8b1fccd 100644 --- a/torchvision/datapoints/_image.py +++ b/torchvision/datapoints/_image.py @@ -41,7 +41,7 @@ def __new__( elif tensor.ndim == 2: tensor = tensor.unsqueeze(0) - return cls._wrap(tensor) + return tensor.as_subclass(cls) def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] return self._make_repr() diff --git a/torchvision/datapoints/_mask.py b/torchvision/datapoints/_mask.py index 95eda077929..e2bafcd6883 100644 --- a/torchvision/datapoints/_mask.py +++ b/torchvision/datapoints/_mask.py @@ -36,4 +36,4 @@ def __new__( data = F.pil_to_tensor(data) tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) - return cls._wrap(tensor) + return tensor.as_subclass(cls) diff --git a/torchvision/datapoints/_video.py b/torchvision/datapoints/_video.py index 842c05bf7e9..19ab0aa8de7 100644 --- a/torchvision/datapoints/_video.py +++ b/torchvision/datapoints/_video.py @@ -31,7 +31,7 @@ def __new__( tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) if data.ndim < 4: raise ValueError - return cls._wrap(tensor) + return tensor.as_subclass(cls) def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] return self._make_repr() diff --git a/torchvision/prototype/datapoints/_label.py b/torchvision/prototype/datapoints/_label.py index ac9b2d8912a..7ed2f7522b0 100644 --- a/torchvision/prototype/datapoints/_label.py +++ b/torchvision/prototype/datapoints/_label.py @@ -15,7 +15,7 @@ class _LabelBase(Datapoint): categories: Optional[Sequence[str]] @classmethod - def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L: # type: ignore[override] + def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L: label_base = tensor.as_subclass(cls) label_base.categories = categories return label_base From 8faa1b14d383129877c7d233e8be848330980875 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 8 Aug 2023 10:17:12 +0100 Subject: [PATCH 4/8] Simplify query_bounding_boxes logic (#7786) Co-authored-by: Philip Meier --- test/common_utils.py | 6 +- test/test_datapoints.py | 8 +- test/test_prototype_transforms.py | 6 +- test/test_transforms_v2.py | 12 --- test/test_transforms_v2_functional.py | 78 +++++++------------ test/transforms_v2_kernel_infos.py | 13 +--- torchvision/datapoints/_bounding_box.py | 10 +++ 
torchvision/prototype/transforms/_geometry.py | 4 +- torchvision/transforms/v2/_geometry.py | 4 +- torchvision/transforms/v2/_misc.py | 10 +-- torchvision/transforms/v2/utils.py | 13 ++-- 11 files changed, 67 insertions(+), 97 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index ec694cc8178..8d5eb047534 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -691,7 +691,7 @@ def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORT if isinstance(format, str): format = datapoints.BoundingBoxFormat[format] - spatial_size = _parse_size(spatial_size, name="canvas_size") + spatial_size = _parse_size(spatial_size, name="spatial_size") def fn(shape, dtype, device): *batch_dims, num_coordinates = shape @@ -702,12 +702,12 @@ def fn(shape, dtype, device): format=format, canvas_size=spatial_size, batch_dims=batch_dims, dtype=dtype, device=device ) - return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size) + return BoundingBoxesLoader(fn, shape=(*extra_dims[-1:], 4), dtype=dtype, format=format, spatial_size=spatial_size) def make_bounding_box_loaders( *, - extra_dims=DEFAULT_EXTRA_DIMS, + extra_dims=tuple(d for d in DEFAULT_EXTRA_DIMS if len(d) < 2), formats=tuple(datapoints.BoundingBoxFormat), spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtypes=(torch.float32, torch.float64, torch.int64), diff --git a/test/test_datapoints.py b/test/test_datapoints.py index 25a2182e050..984caa2c345 100644 --- a/test/test_datapoints.py +++ b/test/test_datapoints.py @@ -22,7 +22,7 @@ def test_mask_instance(data): assert mask.ndim == 3 and mask.shape[0] == 1 -@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]]]) +@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]], [1, 2, 3, 4]]) @pytest.mark.parametrize( "format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH] ) @@ -35,6 +35,12 @@ def test_bbox_instance(data, format): assert bboxes.format == format +def test_bbox_dim_error(): + data_3d = [[[1, 2, 3, 4]]] + with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"): + datapoints.BoundingBoxes(data_3d, format="XYXY", canvas_size=(32, 32)) + + @pytest.mark.parametrize( ("data", "input_requires_grad", "expected_requires_grad"), [ diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index b1760f6f965..d395c224785 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -20,7 +20,7 @@ from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video from torchvision.prototype import datapoints, transforms from torchvision.transforms.v2._utils import _convert_fill_arg -from torchvision.transforms.v2.functional import InterpolationMode, pil_to_tensor, to_image_pil +from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_image_pil from torchvision.transforms.v2.utils import check_type, is_simple_tensor BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims] @@ -306,7 +306,9 @@ def test__transform_bounding_boxes_clamping(self, mocker): bounding_boxes = make_bounding_box( format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,) ) - mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes") + mock = mocker.patch( + 
"torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes", wraps=clamp_bounding_boxes + ) transform = transforms.FixedSizeCrop((-1, -1)) mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 49455b05dc5..353cc846bed 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1654,18 +1654,6 @@ def test_sanitize_bounding_boxes_errors(): different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)} transforms.SanitizeBoundingBoxes()(different_sizes) - with pytest.raises(ValueError, match="boxes must be of shape"): - bad_bbox = datapoints.BoundingBoxes( # batch with 2 elements - [ - [[0, 0, 10, 10]], - [[0, 0, 10, 10]], - ], - format=datapoints.BoundingBoxFormat.XYXY, - canvas_size=(20, 20), - ) - different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])} - transforms.SanitizeBoundingBoxes()(different_sizes) - @pytest.mark.parametrize( "import_statement", diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py index 713737abbff..bf447c8ce71 100644 --- a/test/test_transforms_v2_functional.py +++ b/test/test_transforms_v2_functional.py @@ -711,21 +711,20 @@ def _parse_padding(padding): @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("padding", [[1], [1, 1], [1, 1, 2, 2]]) def test_correctness_pad_bounding_boxes(device, padding): - def _compute_expected_bbox(bbox, padding_): + def _compute_expected_bbox(bbox, format, padding_): pad_left, pad_up, _, _ = _parse_padding(padding_) dtype = bbox.dtype - format = bbox.format bbox = ( bbox.clone() if format == datapoints.BoundingBoxFormat.XYXY - else convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY) + else convert_format_bounding_boxes(bbox, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY) ) bbox[0::2] += pad_left bbox[1::2] += pad_up - bbox = convert_format_bounding_boxes(bbox, new_format=format) + bbox = convert_format_bounding_boxes(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format) if bbox.dtype != dtype: # Temporary cast to original dtype # e.g. 
float32 -> int @@ -737,7 +736,7 @@ def _compute_expected_canvas_size(bbox, padding_): height, width = bbox.canvas_size return height + pad_up + pad_down, width + pad_left + pad_right - for bboxes in make_bounding_boxes(): + for bboxes in make_bounding_boxes(extra_dims=((4,),)): bboxes = bboxes.to(device) bboxes_format = bboxes.format bboxes_canvas_size = bboxes.canvas_size @@ -748,18 +747,10 @@ def _compute_expected_canvas_size(bbox, padding_): torch.testing.assert_close(output_canvas_size, _compute_expected_canvas_size(bboxes, padding)) - if bboxes.ndim < 2 or bboxes.shape[0] == 0: - bboxes = [bboxes] - - expected_bboxes = [] - for bbox in bboxes: - bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size) - expected_bboxes.append(_compute_expected_bbox(bbox, padding)) + expected_bboxes = torch.stack( + [_compute_expected_bbox(b, bboxes_format, padding) for b in bboxes.reshape(-1, 4).unbind()] + ).reshape(bboxes.shape) - if len(expected_bboxes) > 1: - expected_bboxes = torch.stack(expected_bboxes) - else: - expected_bboxes = expected_bboxes[0] torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0) @@ -784,7 +775,7 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device): ], ) def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints): - def _compute_expected_bbox(bbox, pcoeffs_): + def _compute_expected_bbox(bbox, format_, canvas_size_, pcoeffs_): m1 = np.array( [ [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]], @@ -798,7 +789,9 @@ def _compute_expected_bbox(bbox, pcoeffs_): ] ) - bbox_xyxy = convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY) + bbox_xyxy = convert_format_bounding_boxes( + bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY + ) points = np.array( [ [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0], @@ -818,14 +811,11 @@ def _compute_expected_bbox(bbox, pcoeffs_): np.max(transformed_points[:, 1]), ] ) - out_bbox = datapoints.BoundingBoxes( - out_bbox, - format=datapoints.BoundingBoxFormat.XYXY, - canvas_size=bbox.canvas_size, - dtype=bbox.dtype, - device=bbox.device, + out_bbox = torch.from_numpy(out_bbox) + out_bbox = convert_format_bounding_boxes( + out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_ ) - return clamp_bounding_boxes(convert_format_bounding_boxes(out_bbox, new_format=bbox.format)) + return clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_).to(bbox) canvas_size = (32, 38) @@ -844,17 +834,13 @@ def _compute_expected_bbox(bbox, pcoeffs_): coefficients=pcoeffs, ) - if bboxes.ndim < 2: - bboxes = [bboxes] + expected_bboxes = torch.stack( + [ + _compute_expected_bbox(b, bboxes.format, bboxes.canvas_size, inv_pcoeffs) + for b in bboxes.reshape(-1, 4).unbind() + ] + ).reshape(bboxes.shape) - expected_bboxes = [] - for bbox in bboxes: - bbox = datapoints.BoundingBoxes(bbox, format=bboxes.format, canvas_size=bboxes.canvas_size) - expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs)) - if len(expected_bboxes) > 1: - expected_bboxes = torch.stack(expected_bboxes) - else: - expected_bboxes = expected_bboxes[0] torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=0, atol=1) @@ -864,9 +850,7 @@ def _compute_expected_bbox(bbox, pcoeffs_): [(18, 18), [18, 15], (16, 19), [12], [46, 48]], ) def test_correctness_center_crop_bounding_boxes(device, output_size): - def _compute_expected_bbox(bbox, output_size_): - format_ = bbox.format - canvas_size_ = bbox.canvas_size + def 
_compute_expected_bbox(bbox, format_, canvas_size_, output_size_): dtype = bbox.dtype bbox = convert_format_bounding_boxes(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH) @@ -895,18 +879,12 @@ def _compute_expected_bbox(bbox, output_size_): bboxes, bboxes_format, bboxes_canvas_size, output_size ) - if bboxes.ndim < 2: - bboxes = [bboxes] - - expected_bboxes = [] - for bbox in bboxes: - bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size) - expected_bboxes.append(_compute_expected_bbox(bbox, output_size)) - - if len(expected_bboxes) > 1: - expected_bboxes = torch.stack(expected_bboxes) - else: - expected_bboxes = expected_bboxes[0] + expected_bboxes = torch.stack( + [ + _compute_expected_bbox(b, bboxes_format, bboxes_canvas_size, output_size) + for b in bboxes.reshape(-1, 4).unbind() + ] + ).reshape(bboxes.shape) torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0) torch.testing.assert_close(output_canvas_size, output_size) diff --git a/test/transforms_v2_kernel_infos.py b/test/transforms_v2_kernel_infos.py index 01605f696b4..ac5651d3217 100644 --- a/test/transforms_v2_kernel_infos.py +++ b/test/transforms_v2_kernel_infos.py @@ -222,16 +222,9 @@ def transform(bbox, affine_matrix_, format_, canvas_size_): out_bbox = out_bbox.to(dtype=in_dtype) return out_bbox - if bounding_boxes.ndim < 2: - bounding_boxes = [bounding_boxes] - - expected_bboxes = [transform(bbox, affine_matrix, format, canvas_size) for bbox in bounding_boxes] - if len(expected_bboxes) > 1: - expected_bboxes = torch.stack(expected_bboxes) - else: - expected_bboxes = expected_bboxes[0] - - return expected_bboxes + return torch.stack( + [transform(b, affine_matrix, format, canvas_size) for b in bounding_boxes.reshape(-1, 4).unbind()] + ).reshape(bounding_boxes.shape) def sample_inputs_convert_format_bounding_boxes(): diff --git a/torchvision/datapoints/_bounding_box.py b/torchvision/datapoints/_bounding_box.py index 9677cef21e6..d459a55448a 100644 --- a/torchvision/datapoints/_bounding_box.py +++ b/torchvision/datapoints/_bounding_box.py @@ -26,6 +26,12 @@ class BoundingBoxFormat(Enum): class BoundingBoxes(Datapoint): """[BETA] :class:`torch.Tensor` subclass for bounding boxes. + .. note:: + There should be only one :class:`~torchvision.datapoints.BoundingBoxes` + instance per sample e.g. ``{"img": img, "bbox": BoundingBoxes(...)}``, + although one :class:`~torchvision.datapoints.BoundingBoxes` object can + contain multiple bounding boxes. + Args: data: Any data that can be turned into a tensor with :func:`torch.as_tensor`. format (BoundingBoxFormat, str): Format of the bounding box. 
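
The docstring note added just above codifies the convention that the rest of this series relies on (and that the `get_bounding_boxes` helper further down assumes): a sample carries at most one ``BoundingBoxes`` entry, and that single entry holds all boxes for the image. A minimal sketch of a conforming sample, assuming the BETA v2 API and arbitrary, made-up tensor sizes:

import torch
import torchvision

torchvision.disable_beta_transforms_warning()
from torchvision import datapoints
from torchvision.transforms import v2

# One image, one BoundingBoxes entry holding *all* boxes for that image.
sample = {
    "img": datapoints.Image(torch.rand(3, 32, 32)),
    "bbox": datapoints.BoundingBoxes(
        [[0, 0, 10, 10], [2, 2, 7, 7]], format="XYXY", canvas_size=(32, 32)
    ),
}
# v2 transforms accept arbitrary structures and transform both entries together.
flipped = v2.RandomHorizontalFlip(p=1)(sample)
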
@@ -43,6 +49,10 @@ class BoundingBoxes(Datapoint): @classmethod def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int]) -> BoundingBoxes: # type: ignore[override] + if tensor.ndim == 1: + tensor = tensor.unsqueeze(0) + elif tensor.ndim != 2: + raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D") if isinstance(format, str): format = BoundingBoxFormat[format.upper()] bounding_boxes = tensor.as_subclass(cls) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index a4023ca2108..e3819554d0b 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -7,7 +7,7 @@ from torchvision.prototype.datapoints import Label, OneHotLabel from torchvision.transforms.v2 import functional as F, Transform from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size -from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_boxes, query_size +from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_simple_tensor, query_size class FixedSizeCrop(Transform): @@ -61,7 +61,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: bounding_boxes: Optional[torch.Tensor] try: - bounding_boxes = query_bounding_boxes(flat_inputs) + bounding_boxes = get_bounding_boxes(flat_inputs) except ValueError: bounding_boxes = None diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index e43aa868a34..23d4c971af0 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -23,7 +23,7 @@ _setup_float_or_seq, _setup_size, ) -from .utils import has_all, has_any, is_simple_tensor, query_bounding_boxes, query_size +from .utils import get_bounding_boxes, has_all, has_any, is_simple_tensor, query_size class RandomHorizontalFlip(_RandomApplyTransform): @@ -1137,7 +1137,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None: def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: orig_h, orig_w = query_size(flat_inputs) - bboxes = query_bounding_boxes(flat_inputs) + bboxes = get_bounding_boxes(flat_inputs) while True: # sample an option diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index a799070ee1e..d2dddd96d5c 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -10,7 +10,7 @@ from torchvision.transforms.v2 import functional as F, Transform from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size -from .utils import has_any, is_simple_tensor, query_bounding_boxes +from .utils import get_bounding_boxes, has_any, is_simple_tensor # TODO: do we want/need to expose this? @@ -384,13 +384,7 @@ def forward(self, *inputs: Any) -> Any: ) flat_inputs, spec = tree_flatten(inputs) - # TODO: this enforces one single BoundingBoxes entry. - # Assuming this transform needs to be called at the end of *any* pipeline that has bboxes... - # should we just enforce it for all transforms?? What are the benefits of *not* enforcing this? 
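
The ``_wrap`` change above makes ``BoundingBoxes`` construction shape-strict: 1D input is promoted to a one-row batch and anything beyond 2D is rejected, matching the new ``test_bbox_dim_error`` test. A quick sketch of the resulting behaviour, assuming the BETA datapoints API:

import torch
import torchvision

torchvision.disable_beta_transforms_warning()
from torchvision import datapoints

# A single box given as a 1D sequence is promoted to a (1, 4) batch.
single = datapoints.BoundingBoxes([1, 2, 3, 4], format="XYXY", canvas_size=(32, 32))
print(single.shape)  # torch.Size([1, 4])

# Anything that is not 1D or 2D is rejected up front.
try:
    datapoints.BoundingBoxes([[[1, 2, 3, 4]]], format="XYXY", canvas_size=(32, 32))
except ValueError as e:
    print(e)  # Expected a 1D or 2D tensor, got 3D
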
- boxes = query_bounding_boxes(flat_inputs) - - if boxes.ndim != 2: - raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}") + boxes = get_bounding_boxes(flat_inputs) if labels is not None and boxes.shape[0] != labels.shape[0]: raise ValueError( diff --git a/torchvision/transforms/v2/utils.py b/torchvision/transforms/v2/utils.py index dd9f4489dee..1d9219fb4f5 100644 --- a/torchvision/transforms/v2/utils.py +++ b/torchvision/transforms/v2/utils.py @@ -9,13 +9,12 @@ from torchvision.transforms.v2.functional import get_dimensions, get_size, is_simple_tensor -def query_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes: - bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes)] - if not bounding_boxes: - raise TypeError("No bounding boxes were found in the sample") - elif len(bounding_boxes) > 1: - raise ValueError("Found multiple bounding boxes instances in the sample") - return bounding_boxes.pop() +def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes: + # This assumes there is only one bbox per sample as per the general convention + try: + return next(inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes)) + except StopIteration: + raise ValueError("No bounding boxes were found in the sample") def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]: From 2ab937a07d6a3d2486edef945ec8a2de16439e95 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 9 Aug 2023 09:38:55 +0100 Subject: [PATCH 5/8] Change default pytest traceback from native to short (#7810) --- pytest.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytest.ini b/pytest.ini index a2f59ecec46..8d52b55d5a6 100644 --- a/pytest.ini +++ b/pytest.ini @@ -3,7 +3,7 @@ addopts = # show tests that (f)ailed, (E)rror, or (X)passed in the summary -rfEX # Make tracebacks shorter - --tb=native + --tb=short # enable all warnings -Wd --ignore=test/test_datasets_download.py From 5d8d61acc907a0e80bfee6dd35a22e4c75f16e82 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Aug 2023 10:48:04 +0200 Subject: [PATCH 6/8] add PermuteChannels transform (#7624) --- docs/source/transforms.rst | 1 + test/test_transforms_v2.py | 1 + test/test_transforms_v2_refactored.py | 58 +++++++++++++++++ torchvision/transforms/v2/__init__.py | 1 + torchvision/transforms/v2/_color.py | 39 ++++++----- .../transforms/v2/functional/__init__.py | 4 ++ .../transforms/v2/functional/_color.py | 65 ++++++++++++++++++- 7 files changed, 151 insertions(+), 18 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index a1858c6b514..0df46c92530 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -155,6 +155,7 @@ Color ColorJitter v2.ColorJitter + v2.RandomChannelPermutation v2.RandomPhotometricDistort Grayscale v2.Grayscale diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 353cc846bed..5f4a9b62898 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -124,6 +124,7 @@ class TestSmoke: (transforms.RandomEqualize(p=1.0), None), (transforms.RandomGrayscale(p=1.0), None), (transforms.RandomInvert(p=1.0), None), + (transforms.RandomChannelPermutation(), None), (transforms.RandomPhotometricDistort(p=1.0), None), (transforms.RandomPosterize(bits=4, p=1.0), None), (transforms.RandomSolarize(threshold=0.5, p=1.0), None), diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index c910882f9fd..fa04d5deb0c 100644 
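
Patch 6 above registers the new ``RandomChannelPermutation`` transform in the docs and smoke tests; the transform class and the ``permute_channels`` functional it dispatches to appear in the hunks that follow. A small usage sketch, assuming the BETA v2 API and an arbitrary input size:

import torch
import torchvision

torchvision.disable_beta_transforms_warning()
from torchvision import datapoints
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F

img = datapoints.Image(torch.rand(3, 16, 16))

# Functional form: an explicit channel permutation (RGB -> BGR here).
bgr = F.permute_channels(img, permutation=[2, 1, 0])

# Transform form: a fresh random permutation is drawn on every call.
shuffled = v2.RandomChannelPermutation()(img)
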
--- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2280,3 +2280,61 @@ def resize_my_datapoint(): _register_kernel_internal(F.resize, MyDatapoint, datapoint_wrapper=False)(resize_my_datapoint) assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint + + +class TestPermuteChannels: + _DEFAULT_PERMUTATION = [2, 0, 1] + + @pytest.mark.parametrize( + ("kernel", "make_input"), + [ + (F.permute_channels_image_tensor, make_image_tensor), + # FIXME + # check_kernel does not support PIL kernel, but it should + (F.permute_channels_image_tensor, make_image), + (F.permute_channels_video, make_video), + ], + ) + @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_kernel(self, kernel, make_input, dtype, device): + check_kernel(kernel, make_input(dtype=dtype, device=device), permutation=self._DEFAULT_PERMUTATION) + + @pytest.mark.parametrize( + ("kernel", "make_input"), + [ + (F.permute_channels_image_tensor, make_image_tensor), + (F.permute_channels_image_pil, make_image_pil), + (F.permute_channels_image_tensor, make_image), + (F.permute_channels_video, make_video), + ], + ) + def test_dispatcher(self, kernel, make_input): + check_dispatcher(F.permute_channels, kernel, make_input(), permutation=self._DEFAULT_PERMUTATION) + + @pytest.mark.parametrize( + ("kernel", "input_type"), + [ + (F.permute_channels_image_tensor, torch.Tensor), + (F.permute_channels_image_pil, PIL.Image.Image), + (F.permute_channels_image_tensor, datapoints.Image), + (F.permute_channels_video, datapoints.Video), + ], + ) + def test_dispatcher_signature(self, kernel, input_type): + check_dispatcher_kernel_signature_match(F.permute_channels, kernel=kernel, input_type=input_type) + + def reference_image_correctness(self, image, permutation): + channel_images = image.split(1, dim=-3) + permuted_channel_images = [channel_images[channel_idx] for channel_idx in permutation] + return datapoints.Image(torch.concat(permuted_channel_images, dim=-3)) + + @pytest.mark.parametrize("permutation", [[2, 0, 1], [1, 2, 0], [2, 0, 1], [0, 1, 2]]) + @pytest.mark.parametrize("batch_dims", [(), (2,), (2, 1)]) + def test_image_correctness(self, permutation, batch_dims): + image = make_image(batch_dims=batch_dims) + + actual = F.permute_channels(image, permutation=permutation) + expected = self.reference_image_correctness(image, permutation=permutation) + + torch.testing.assert_close(actual, expected) diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py index 8ce9bee9b4d..4451cb7a1a2 100644 --- a/torchvision/transforms/v2/__init__.py +++ b/torchvision/transforms/v2/__init__.py @@ -11,6 +11,7 @@ Grayscale, RandomAdjustSharpness, RandomAutocontrast, + RandomChannelPermutation, RandomEqualize, RandomGrayscale, RandomInvert, diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index 7dd8eeae236..8315e2f36b4 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -177,7 +177,27 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return output -# TODO: This class seems to be untested +class RandomChannelPermutation(Transform): + """[BETA] Randomly permute the channels of an image or video + + .. 
v2betastatus:: RandomChannelPermutation transform + """ + + _transformed_types = ( + datapoints.Image, + PIL.Image.Image, + is_simple_tensor, + datapoints.Video, + ) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + num_channels, *_ = query_chw(flat_inputs) + return dict(permutation=torch.randperm(num_channels)) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.permute_channels(inpt, params["permutation"]) + + class RandomPhotometricDistort(Transform): """[BETA] Randomly distorts the image or video as used in `SSD: Single Shot MultiBox Detector `_. @@ -241,21 +261,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None return params - def _permute_channels( - self, inpt: Union[datapoints._ImageType, datapoints._VideoType], permutation: torch.Tensor - ) -> Union[datapoints._ImageType, datapoints._VideoType]: - orig_inpt = inpt - if isinstance(orig_inpt, PIL.Image.Image): - inpt = F.pil_to_tensor(inpt) - - # TODO: Find a better fix than as_subclass??? - output = inpt[..., permutation, :, :].as_subclass(type(inpt)) - - if isinstance(orig_inpt, PIL.Image.Image): - output = F.to_image_pil(output) - - return output - def _transform( self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any] ) -> Union[datapoints._ImageType, datapoints._VideoType]: @@ -270,7 +275,7 @@ def _transform( if params["contrast_factor"] is not None and not params["contrast_before"]: inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"]) if params["channel_permutation"] is not None: - inpt = self._permute_channels(inpt, permutation=params["channel_permutation"]) + inpt = F.permute_channels(inpt, permutation=params["channel_permutation"]) return inpt diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 163a55fad38..f3295860155 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -62,6 +62,10 @@ invert_image_pil, invert_image_tensor, invert_video, + permute_channels, + permute_channels_image_pil, + permute_channels_image_tensor, + permute_channels_video, posterize, posterize_image_pil, posterize_image_tensor, diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 71797fd2500..9b6bf3886fa 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import List, Union import PIL.Image import torch @@ -10,6 +10,8 @@ from torchvision.utils import _log_api_usage_once from ._misc import _num_value_bits, to_dtype_image_tensor + +from ._type_conversion import pil_to_tensor, to_image_pil from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal @@ -641,3 +643,64 @@ def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: @_register_kernel_internal(invert, datapoints.Video) def invert_video(video: torch.Tensor) -> torch.Tensor: return invert_image_tensor(video) + + +@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) +def permute_channels(inpt: datapoints._InputTypeJIT, permutation: List[int]) -> datapoints._InputTypeJIT: + """Permute the channels of the input according to the given permutation. 
+ + This function supports plain :class:`~torch.Tensor`'s, :class:`PIL.Image.Image`'s, and + :class:`torchvision.datapoints.Image` and :class:`torchvision.datapoints.Video`. + + Example: + >>> rgb_image = torch.rand(3, 256, 256) + >>> bgr_image = F.permutate_channels(rgb_image, permutation=[2, 1, 0]) + + Args: + permutation (List[int]): Valid permutation of the input channel indices. The index of the element determines the + channel index in the input and the value determines the channel index in the output. For example, + ``permutation=[2, 0 , 1]`` + + - takes ``ìnpt[..., 0, :, :]`` and puts it at ``output[..., 2, :, :]``, + - takes ``ìnpt[..., 1, :, :]`` and puts it at ``output[..., 0, :, :]``, and + - takes ``ìnpt[..., 2, :, :]`` and puts it at ``output[..., 1, :, :]``. + + Raises: + ValueError: If ``len(permutation)`` doesn't match the number of channels in the input. + """ + if torch.jit.is_scripting(): + return permute_channels_image_tensor(inpt, permutation=permutation) + + _log_api_usage_once(permute_channels) + + kernel = _get_kernel(permute_channels, type(inpt)) + return kernel(inpt, permutation=permutation) + + +@_register_kernel_internal(permute_channels, torch.Tensor) +@_register_kernel_internal(permute_channels, datapoints.Image) +def permute_channels_image_tensor(image: torch.Tensor, permutation: List[int]) -> torch.Tensor: + shape = image.shape + num_channels, height, width = shape[-3:] + + if len(permutation) != num_channels: + raise ValueError( + f"Length of permutation does not match number of channels: " f"{len(permutation)} != {num_channels}" + ) + + if image.numel() == 0: + return image + + image = image.reshape(-1, num_channels, height, width) + image = image[:, permutation, :, :] + return image.reshape(shape) + + +@_register_kernel_internal(permute_channels, PIL.Image.Image) +def permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image: + return to_image_pil(permute_channels_image_tensor(pil_to_tensor(image), permutation=permutation)) + + +@_register_kernel_internal(permute_channels, datapoints.Video) +def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor: + return permute_channels_image_tensor(video, permutation=permutation) From 6b020798524f538b6e5ffe61648d63682695ab91 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Aug 2023 11:34:30 +0200 Subject: [PATCH 7/8] cleanup v2 tests (#7812) --- test/test_transforms_v2_refactored.py | 110 ++++++-------------------- 1 file changed, 25 insertions(+), 85 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index fa04d5deb0c..9028b304c1b 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -173,15 +173,7 @@ def _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs): dispatcher_scripted(input.as_subclass(torch.Tensor), *args, **kwargs) -def check_dispatcher( - dispatcher, - # TODO: remove this parameter - kernel, - input, - *args, - check_scripted_smoke=True, - **kwargs, -): +def check_dispatcher(dispatcher, input, *args, check_scripted_smoke=True, **kwargs): unknown_input = object() with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))): dispatcher(unknown_input, *args, **kwargs) @@ -516,20 +508,12 @@ def test_kernel_video(self): @pytest.mark.parametrize("size", OUTPUT_SIZES) @pytest.mark.parametrize( - ("kernel", "make_input"), - [ - (F.resize_image_tensor, make_image_tensor), - (F.resize_image_pil, make_image_pil), - 
(F.resize_image_tensor, make_image), - (F.resize_bounding_boxes, make_bounding_box), - (F.resize_mask, make_segmentation_mask), - (F.resize_video, make_video), - ], + "make_input", + [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video], ) - def test_dispatcher(self, size, kernel, make_input): + def test_dispatcher(self, size, make_input): check_dispatcher( F.resize, - kernel, make_input(self.INPUT_SIZE), size=size, antialias=True, @@ -805,18 +789,11 @@ def test_kernel_video(self): check_kernel(F.horizontal_flip_video, make_video()) @pytest.mark.parametrize( - ("kernel", "make_input"), - [ - (F.horizontal_flip_image_tensor, make_image_tensor), - (F.horizontal_flip_image_pil, make_image_pil), - (F.horizontal_flip_image_tensor, make_image), - (F.horizontal_flip_bounding_boxes, make_bounding_box), - (F.horizontal_flip_mask, make_segmentation_mask), - (F.horizontal_flip_video, make_video), - ], + "make_input", + [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video], ) - def test_dispatcher(self, kernel, make_input): - check_dispatcher(F.horizontal_flip, kernel, make_input()) + def test_dispatcher(self, make_input): + check_dispatcher(F.horizontal_flip, make_input()) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -988,18 +965,11 @@ def test_kernel_video(self): self._check_kernel(F.affine_video, make_video()) @pytest.mark.parametrize( - ("kernel", "make_input"), - [ - (F.affine_image_tensor, make_image_tensor), - (F.affine_image_pil, make_image_pil), - (F.affine_image_tensor, make_image), - (F.affine_bounding_boxes, make_bounding_box), - (F.affine_mask, make_segmentation_mask), - (F.affine_video, make_video), - ], + "make_input", + [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video], ) - def test_dispatcher(self, kernel, make_input): - check_dispatcher(F.affine, kernel, make_input(), **self._MINIMAL_AFFINE_KWARGS) + def test_dispatcher(self, make_input): + check_dispatcher(F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -1284,18 +1254,11 @@ def test_kernel_video(self): check_kernel(F.vertical_flip_video, make_video()) @pytest.mark.parametrize( - ("kernel", "make_input"), - [ - (F.vertical_flip_image_tensor, make_image_tensor), - (F.vertical_flip_image_pil, make_image_pil), - (F.vertical_flip_image_tensor, make_image), - (F.vertical_flip_bounding_boxes, make_bounding_box), - (F.vertical_flip_mask, make_segmentation_mask), - (F.vertical_flip_video, make_video), - ], + "make_input", + [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video], ) - def test_dispatcher(self, kernel, make_input): - check_dispatcher(F.vertical_flip, kernel, make_input()) + def test_dispatcher(self, make_input): + check_dispatcher(F.vertical_flip, make_input()) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -1441,18 +1404,11 @@ def test_kernel_video(self): check_kernel(F.rotate_video, make_video(), **self._MINIMAL_AFFINE_KWARGS) @pytest.mark.parametrize( - ("kernel", "make_input"), - [ - (F.rotate_image_tensor, make_image_tensor), - (F.rotate_image_pil, make_image_pil), - (F.rotate_image_tensor, make_image), - (F.rotate_bounding_boxes, make_bounding_box), - (F.rotate_mask, make_segmentation_mask), - (F.rotate_video, make_video), - ], + "make_input", + [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video], 
) - def test_dispatcher(self, kernel, make_input): - check_dispatcher(F.rotate, kernel, make_input(), **self._MINIMAL_AFFINE_KWARGS) + def test_dispatcher(self, make_input): + check_dispatcher(F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -1711,22 +1667,14 @@ def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, sca scale=scale, ) - @pytest.mark.parametrize( - ("kernel", "make_input"), - [ - (F.to_dtype_image_tensor, make_image_tensor), - (F.to_dtype_image_tensor, make_image), - (F.to_dtype_video, make_video), - ], - ) + @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video]) @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("scale", (True, False)) - def test_dispatcher(self, kernel, make_input, input_dtype, output_dtype, device, scale): + def test_dispatcher(self, make_input, input_dtype, output_dtype, device, scale): check_dispatcher( F.to_dtype, - kernel, make_input(dtype=input_dtype, device=device), dtype=output_dtype, scale=scale, @@ -1890,17 +1838,9 @@ class TestAdjustBrightness: def test_kernel(self, kernel, make_input, dtype, device): check_kernel(kernel, make_input(dtype=dtype, device=device), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR) - @pytest.mark.parametrize( - ("kernel", "make_input"), - [ - (F.adjust_brightness_image_tensor, make_image_tensor), - (F.adjust_brightness_image_pil, make_image_pil), - (F.adjust_brightness_image_tensor, make_image), - (F.adjust_brightness_video, make_video), - ], - ) - def test_dispatcher(self, kernel, make_input): - check_dispatcher(F.adjust_brightness, kernel, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR) + @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) + def test_dispatcher(self, make_input): + check_dispatcher(F.adjust_brightness, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR) @pytest.mark.parametrize( ("kernel", "input_type"), From 641fdd9f71a17f6088269efdf7b9e311a14ee548 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Aug 2023 12:18:19 +0200 Subject: [PATCH 8/8] remove custom types defintions from datapoints module (#7814) --- torchvision/datapoints/__init__.py | 6 +- torchvision/datapoints/_datapoint.py | 9 +- torchvision/datapoints/_image.py | 6 - torchvision/datapoints/_video.py | 6 - torchvision/prototype/transforms/_augment.py | 10 +- torchvision/prototype/transforms/_geometry.py | 4 +- torchvision/prototype/transforms/_misc.py | 8 +- torchvision/transforms/v2/_auto_augment.py | 24 ++-- torchvision/transforms/v2/_color.py | 4 +- torchvision/transforms/v2/_geometry.py | 18 ++- torchvision/transforms/v2/_misc.py | 4 +- torchvision/transforms/v2/_temporal.py | 3 +- torchvision/transforms/v2/_utils.py | 5 +- .../transforms/v2/functional/_augment.py | 6 +- .../transforms/v2/functional/_color.py | 30 ++-- .../transforms/v2/functional/_deprecated.py | 5 +- .../transforms/v2/functional/_geometry.py | 130 +++++++++--------- torchvision/transforms/v2/functional/_meta.py | 18 +-- torchvision/transforms/v2/functional/_misc.py | 16 +-- .../transforms/v2/functional/_temporal.py | 2 +- .../transforms/v2/functional/_utils.py | 5 +- 21 files changed, 141 insertions(+), 178 deletions(-) diff --git a/torchvision/datapoints/__init__.py 
b/torchvision/datapoints/__init__.py index 03469ca0cde..de6f975e42d 100644 --- a/torchvision/datapoints/__init__.py +++ b/torchvision/datapoints/__init__.py @@ -1,10 +1,10 @@ from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS from ._bounding_box import BoundingBoxes, BoundingBoxFormat -from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT, Datapoint -from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image +from ._datapoint import Datapoint +from ._image import Image from ._mask import Mask -from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video +from ._video import Video if _WARN_ABOUT_BETA_TRANSFORMS: import warnings diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py index 9b1c648648d..af6d5929d10 100644 --- a/torchvision/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -1,16 +1,13 @@ from __future__ import annotations -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union -import PIL.Image import torch from torch._C import DisableTorchFunctionSubclass from torch.types import _device, _dtype, _size D = TypeVar("D", bound="Datapoint") -_FillType = Union[int, float, Sequence[int], Sequence[float], None] -_FillTypeJIT = Optional[List[float]] class Datapoint(torch.Tensor): @@ -132,7 +129,3 @@ def __deepcopy__(self: D, memo: Dict[int, Any]) -> D: # `BoundingBoxes.format` and `BoundingBoxes.canvas_size`, which are immutable and thus implicitly deep-copied by # `BoundingBoxes.clone()`. return self.detach().clone().requires_grad_(self.requires_grad) # type: ignore[return-value] - - -_InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint] -_InputTypeJIT = torch.Tensor diff --git a/torchvision/datapoints/_image.py b/torchvision/datapoints/_image.py index cf7b8b1fccd..609ace90d21 100644 --- a/torchvision/datapoints/_image.py +++ b/torchvision/datapoints/_image.py @@ -45,9 +45,3 @@ def __new__( def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] return self._make_repr() - - -_ImageType = Union[torch.Tensor, PIL.Image.Image, Image] -_ImageTypeJIT = torch.Tensor -_TensorImageType = Union[torch.Tensor, Image] -_TensorImageTypeJIT = torch.Tensor diff --git a/torchvision/datapoints/_video.py b/torchvision/datapoints/_video.py index 19ab0aa8de7..f6cc80fabcb 100644 --- a/torchvision/datapoints/_video.py +++ b/torchvision/datapoints/_video.py @@ -35,9 +35,3 @@ def __new__( def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] return self._make_repr() - - -_VideoType = Union[torch.Tensor, Video] -_VideoTypeJIT = torch.Tensor -_TensorVideoType = Union[torch.Tensor, Video] -_TensorVideoTypeJIT = torch.Tensor diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 95585fe287c..53f3f801303 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -26,15 +26,15 @@ def __init__( def _copy_paste( self, - image: datapoints._TensorImageType, + image: Union[torch.Tensor, datapoints.Image], target: Dict[str, Any], - paste_image: datapoints._TensorImageType, + paste_image: Union[torch.Tensor, datapoints.Image], paste_target: Dict[str, Any], random_selection: torch.Tensor, blending: bool, resize_interpolation: 
F.InterpolationMode, antialias: Optional[bool], - ) -> Tuple[datapoints._TensorImageType, Dict[str, Any]]: + ) -> Tuple[torch.Tensor, Dict[str, Any]]: paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection]) paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection]) @@ -106,7 +106,7 @@ def _copy_paste( def _extract_image_targets( self, flat_sample: List[Any] - ) -> Tuple[List[datapoints._TensorImageType], List[Dict[str, Any]]]: + ) -> Tuple[List[Union[torch.Tensor, datapoints.Image]], List[Dict[str, Any]]]: # fetch all images, bboxes, masks and labels from unstructured input # with List[image], List[BoundingBoxes], List[Mask], List[Label] images, bboxes, masks, labels = [], [], [], [] @@ -137,7 +137,7 @@ def _extract_image_targets( def _insert_outputs( self, flat_sample: List[Any], - output_images: List[datapoints._TensorImageType], + output_images: List[torch.Tensor], output_targets: List[Dict[str, Any]], ) -> None: c0, c1, c2, c3 = 0, 0, 0, 0 diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index e3819554d0b..1a2802db0ac 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -6,7 +6,7 @@ from torchvision import datapoints from torchvision.prototype.datapoints import Label, OneHotLabel from torchvision.transforms.v2 import functional as F, Transform -from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size +from torchvision.transforms.v2._utils import _FillType, _get_fill, _setup_fill_arg, _setup_size from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_simple_tensor, query_size @@ -14,7 +14,7 @@ class FixedSizeCrop(Transform): def __init__( self, size: Union[int, Sequence[int]], - fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, padding_mode: str = "constant", ) -> None: super().__init__() diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 51a2ea9074a..f1b859aac03 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -39,9 +39,7 @@ def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]] ) self.dims = dims - def _transform( - self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any] - ) -> torch.Tensor: + def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: dims = self.dims[type(inpt)] if dims is None: return inpt.as_subclass(torch.Tensor) @@ -63,9 +61,7 @@ def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, i ) self.dims = dims - def _transform( - self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any] - ) -> torch.Tensor: + def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: dims = self.dims[type(inpt)] if dims is None: return inpt.as_subclass(torch.Tensor) diff --git a/torchvision/transforms/v2/_auto_augment.py b/torchvision/transforms/v2/_auto_augment.py index 146c8c236ef..26eb3abbcf9 100644 --- a/torchvision/transforms/v2/_auto_augment.py +++ b/torchvision/transforms/v2/_auto_augment.py @@ -10,17 +10,21 @@ from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform from 
torchvision.transforms.v2.functional._geometry import _check_interpolation from torchvision.transforms.v2.functional._meta import get_size +from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT from ._utils import _get_fill, _setup_fill_arg from .utils import check_type, is_simple_tensor +ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video] + + class _AutoAugmentBase(Transform): def __init__( self, *, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None, ) -> None: super().__init__() self.interpolation = _check_interpolation(interpolation) @@ -35,7 +39,7 @@ def _flatten_and_extract_image_or_video( self, inputs: Any, unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBoxes, datapoints.Mask), - ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints._ImageType, datapoints._VideoType]]: + ) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]: flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0]) needs_transform_list = self._needs_transform_list(flat_inputs) @@ -68,7 +72,7 @@ def _flatten_and_extract_image_or_video( def _unflatten_and_insert_image_or_video( self, flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int], - image_or_video: Union[datapoints._ImageType, datapoints._VideoType], + image_or_video: ImageOrVideo, ) -> Any: flat_inputs, spec, idx = flat_inputs_with_spec flat_inputs[idx] = image_or_video @@ -76,12 +80,12 @@ def _unflatten_and_insert_image_or_video( def _apply_image_or_video_transform( self, - image: Union[datapoints._ImageType, datapoints._VideoType], + image: ImageOrVideo, transform_id: str, magnitude: float, interpolation: Union[InterpolationMode, int], - fill: Dict[Union[Type, str], datapoints._FillTypeJIT], - ) -> Union[datapoints._ImageType, datapoints._VideoType]: + fill: Dict[Union[Type, str], _FillTypeJIT], + ) -> ImageOrVideo: fill_ = _get_fill(fill, type(image)) if transform_id == "Identity": @@ -214,7 +218,7 @@ def __init__( self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self.policy = policy @@ -394,7 +398,7 @@ def __init__( magnitude: int = 9, num_magnitude_bins: int = 31, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self.num_ops = num_ops @@ -467,7 +471,7 @@ def __init__( self, num_magnitude_bins: int = 31, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None, ): super().__init__(interpolation=interpolation, fill=fill) self.num_magnitude_bins = num_magnitude_bins @@ -550,7 +554,7 @@ def __init__( alpha: float = 1.0, all_ops: bool = True, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - fill: Union[datapoints._FillType, 
Dict[Union[Type, str], datapoints._FillType]] = None, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self._PARAMETER_MAX = 10 diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index 8315e2f36b4..90e3ce2ff2c 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -261,9 +261,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None return params - def _transform( - self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any] - ) -> Union[datapoints._ImageType, datapoints._VideoType]: + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["brightness_factor"] is not None: inpt = F.adjust_brightness(inpt, brightness_factor=params["brightness_factor"]) if params["contrast_factor"] is not None and params["contrast_before"]: diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 23d4c971af0..5c285056928 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -11,6 +11,7 @@ from torchvision.transforms.functional import _get_perspective_coeffs from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform from torchvision.transforms.v2.functional._geometry import _check_interpolation +from torchvision.transforms.v2.functional._utils import _FillType from ._transform import _RandomApplyTransform from ._utils import ( @@ -311,9 +312,6 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: ) -ImageOrVideoTypeJIT = Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT] - - class FiveCrop(Transform): """[BETA] Crop the image or video into four corners and the central crop. 
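
The ``fill`` annotations touched above and below keep accepting either a single value or a per-type mapping; a small sketch of both spellings, assuming the BETA v2 transforms and arbitrary padding values:

import torch
import torchvision

torchvision.disable_beta_transforms_warning()
from torchvision import datapoints
from torchvision.transforms import v2

img = datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
mask = datapoints.Mask(torch.zeros(32, 32, dtype=torch.uint8))

# A single fill value is applied to every input type...
pad_uniform = v2.Pad(padding=4, fill=127)

# ...or fill can be a mapping from input type to fill value.
pad_per_type = v2.Pad(padding=4, fill={datapoints.Image: 127, datapoints.Mask: 0})
padded_img, padded_mask = pad_per_type(img, mask)
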
@@ -459,7 +457,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]: def __init__( self, padding: Union[int, Sequence[int]], - fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", ) -> None: super().__init__() @@ -514,7 +512,7 @@ class RandomZoomOut(_RandomApplyTransform): def __init__( self, - fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, side_range: Sequence[float] = (1.0, 4.0), p: float = 0.5, ) -> None: @@ -592,7 +590,7 @@ def __init__( interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, expand: bool = False, center: Optional[List[float]] = None, - fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, ) -> None: super().__init__() self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) @@ -674,7 +672,7 @@ def __init__( scale: Optional[Sequence[float]] = None, shear: Optional[Union[int, float, Sequence[float]]] = None, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, center: Optional[List[float]] = None, ) -> None: super().__init__() @@ -812,7 +810,7 @@ def __init__( size: Union[int, Sequence[int]], padding: Optional[Union[int, Sequence[int]]] = None, pad_if_needed: bool = False, - fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", ) -> None: super().__init__() @@ -931,7 +929,7 @@ def __init__( distortion_scale: float = 0.5, p: float = 0.5, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, ) -> None: super().__init__(p=p) @@ -1033,7 +1031,7 @@ def __init__( alpha: Union[float, Sequence[float]] = 50.0, sigma: Union[float, Sequence[float]] = 5.0, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, ) -> None: super().__init__() self.alpha = _setup_float_or_seq(alpha, "alpha", 2) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index d2dddd96d5c..da71cebb416 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -169,9 +169,7 @@ def _check_inputs(self, sample: Any) -> Any: if has_any(sample, PIL.Image.Image): raise TypeError(f"{type(self).__name__}() does not support PIL images.") - def _transform( - self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any] - ) -> Any: + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace) diff --git a/torchvision/transforms/v2/_temporal.py b/torchvision/transforms/v2/_temporal.py index 868314e9e33..591341e7cc7 100644 --- 
a/torchvision/transforms/v2/_temporal.py +++ b/torchvision/transforms/v2/_temporal.py @@ -1,7 +1,6 @@ from typing import Any, Dict import torch -from torchvision import datapoints from torchvision.transforms.v2 import functional as F, Transform @@ -25,5 +24,5 @@ def __init__(self, num_samples: int): super().__init__() self.num_samples = num_samples - def _transform(self, inpt: datapoints._VideoType, params: Dict[str, Any]) -> datapoints._VideoType: + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return F.uniform_temporal_subsample(inpt, self.num_samples) diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index a7826a6645f..f9d9bae49e9 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -5,9 +5,8 @@ import torch -from torchvision import datapoints -from torchvision.datapoints._datapoint import _FillType, _FillTypeJIT from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 +from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]: @@ -36,7 +35,7 @@ def _check_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) - raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.") -def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT: +def _convert_fill_arg(fill: _FillType) -> _FillTypeJIT: # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 # So, we can't reassign fill to 0 # if fill is None: diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index 89fa254374d..1497638f6b3 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -1,5 +1,3 @@ -from typing import Union - import PIL.Image import torch @@ -12,14 +10,14 @@ @_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True) def erase( - inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], + inpt: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False, -) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]: +) -> torch.Tensor: if torch.jit.is_scripting(): return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 9b6bf3886fa..9ba88d31b94 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import List import PIL.Image import torch @@ -16,9 +16,7 @@ @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, datapoints.Video) -def rgb_to_grayscale( - inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], num_output_channels: int = 1 -) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]: +def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: if torch.jit.is_scripting(): return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels) @@ -73,7 +71,7 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) -def adjust_brightness(inpt: 
datapoints._InputTypeJIT, brightness_factor: float) -> datapoints._InputTypeJIT: +def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Tensor: if torch.jit.is_scripting(): return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) @@ -110,7 +108,7 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) -def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float) -> datapoints._InputTypeJIT: +def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Tensor: if torch.jit.is_scripting(): return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor) @@ -149,7 +147,7 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) -def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> datapoints._InputTypeJIT: +def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor: if torch.jit.is_scripting(): return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) @@ -188,7 +186,7 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch. @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) -def adjust_sharpness(inpt: datapoints._InputTypeJIT, sharpness_factor: float) -> datapoints._InputTypeJIT: +def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tensor: if torch.jit.is_scripting(): return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor) @@ -261,7 +259,7 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) -def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints._InputTypeJIT: +def adjust_hue(inpt: torch.Tensor, hue_factor: float) -> torch.Tensor: if torch.jit.is_scripting(): return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) @@ -373,7 +371,7 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) -def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1) -> datapoints._InputTypeJIT: +def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor: if torch.jit.is_scripting(): return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) @@ -413,7 +411,7 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) -def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTypeJIT: +def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor: if torch.jit.is_scripting(): return posterize_image_tensor(inpt, bits=bits) @@ -447,7 +445,7 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) -def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._InputTypeJIT: +def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor: if torch.jit.is_scripting(): return solarize_image_tensor(inpt, threshold=threshold) @@ -475,7 +473,7 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) -def autocontrast(inpt: 
datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: +def autocontrast(inpt: torch.Tensor) -> torch.Tensor: if torch.jit.is_scripting(): return autocontrast_image_tensor(inpt) @@ -525,7 +523,7 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) -def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: +def equalize(inpt: torch.Tensor) -> torch.Tensor: if torch.jit.is_scripting(): return equalize_image_tensor(inpt) @@ -615,7 +613,7 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) -def invert(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: +def invert(inpt: torch.Tensor) -> torch.Tensor: if torch.jit.is_scripting(): return invert_image_tensor(inpt) @@ -646,7 +644,7 @@ def invert_video(video: torch.Tensor) -> torch.Tensor: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) -def permute_channels(inpt: datapoints._InputTypeJIT, permutation: List[int]) -> datapoints._InputTypeJIT: +def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor: """Permute the channels of the input according to the given permutation. This function supports plain :class:`~torch.Tensor`'s, :class:`PIL.Image.Image`'s, and diff --git a/torchvision/transforms/v2/functional/_deprecated.py b/torchvision/transforms/v2/functional/_deprecated.py index f27d0b29deb..1cb7f50e5c7 100644 --- a/torchvision/transforms/v2/functional/_deprecated.py +++ b/torchvision/transforms/v2/functional/_deprecated.py @@ -1,9 +1,8 @@ import warnings -from typing import Any, List, Union +from typing import Any, List import torch -from torchvision import datapoints from torchvision.transforms import functional as _F @@ -16,7 +15,7 @@ def to_tensor(inpt: Any) -> torch.Tensor: return _F.to_tensor(inpt) -def get_image_size(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]: +def get_image_size(inpt: torch.Tensor) -> List[int]: warnings.warn( "The function `get_image_size(...)` is deprecated and will be removed in a future release. " "Instead, please use `get_size(...)` which returns `[h, w]` instead of `[w, h]`." 
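
The deprecation message above spells out the replacement: ``get_size`` returns ``[h, w]`` where the deprecated ``get_image_size`` returns ``[w, h]``. A quick sketch of the two calls side by side, assuming the BETA v2 functional API:

import torch
import torchvision

torchvision.disable_beta_transforms_warning()
from torchvision.transforms.v2 import functional as F

img = torch.rand(3, 480, 640)
print(F.get_size(img))        # [480, 640] -> [h, w]
print(F.get_image_size(img))  # [640, 480] -> [w, h], with a deprecation warning
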
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index bb19def2c93..6416a143c03 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -25,7 +25,13 @@
 from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil

-from ._utils import _get_kernel, _register_explicit_noop, _register_five_ten_crop_kernel, _register_kernel_internal
+from ._utils import (
+    _FillTypeJIT,
+    _get_kernel,
+    _register_explicit_noop,
+    _register_five_ten_crop_kernel,
+    _register_kernel_internal,
+)


 def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
@@ -39,7 +45,7 @@ def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> Interp
     return interpolation


-def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
+def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
         return horizontal_flip_image_tensor(inpt)
@@ -95,7 +101,7 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
     return horizontal_flip_image_tensor(video)


-def vertical_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
+def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
         return vertical_flip_image_tensor(inpt)
@@ -171,12 +177,12 @@ def _compute_resized_output_size(


 def resize(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     size: List[int],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     max_size: Optional[int] = None,
     antialias: Optional[Union[str, bool]] = "warn",
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return resize_image_tensor(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
@@ -364,15 +370,15 @@ def resize_video(


 def affine(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     angle: Union[int, float],
     translate: List[float],
     scale: float,
     shear: List[float],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     center: Optional[List[float]] = None,
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return affine_image_tensor(
             inpt,
@@ -549,9 +555,7 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[in
     return int(size[0]), int(size[1])  # w, h


-def _apply_grid_transform(
-    img: torch.Tensor, grid: torch.Tensor, mode: str, fill: datapoints._FillTypeJIT
-) -> torch.Tensor:
+def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill: _FillTypeJIT) -> torch.Tensor:
     # We are using context knowledge that grid should have float dtype
     fp = img.dtype == grid.dtype
@@ -592,7 +596,7 @@ def _assert_grid_transform_inputs(
     image: torch.Tensor,
     matrix: Optional[List[float]],
     interpolation: str,
-    fill: datapoints._FillTypeJIT,
+    fill: _FillTypeJIT,
     supported_interpolation_modes: List[str],
     coeffs: Optional[List[float]] = None,
 ) -> None:
@@ -657,7 +661,7 @@ def affine_image_tensor(
     scale: float,
     shear: List[float],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     interpolation = _check_interpolation(interpolation)
@@ -709,7 +713,7 @@ def affine_image_pil(
     scale: float,
     shear: List[float],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> PIL.Image.Image:
     interpolation = _check_interpolation(interpolation)
@@ -868,7 +872,7 @@ def affine_mask(
     translate: List[float],
     scale: float,
     shear: List[float],
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     if mask.ndim < 3:
@@ -901,7 +905,7 @@ def _affine_mask_dispatch(
     translate: List[float],
     scale: float,
     shear: List[float],
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     center: Optional[List[float]] = None,
     **kwargs,
 ) -> datapoints.Mask:
@@ -925,7 +929,7 @@ def affine_video(
     scale: float,
     shear: List[float],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     return affine_image_tensor(
@@ -941,13 +945,13 @@ def affine_video(


 def rotate(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     angle: float,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
     expand: bool = False,
     center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
-) -> datapoints._InputTypeJIT:
+    fill: _FillTypeJIT = None,
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return rotate_image_tensor(
             inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center
@@ -967,7 +971,7 @@ def rotate_image_tensor(
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
     expand: bool = False,
     center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
     interpolation = _check_interpolation(interpolation)
@@ -1012,7 +1016,7 @@ def rotate_image_pil(
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
     expand: bool = False,
     center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> PIL.Image.Image:
     interpolation = _check_interpolation(interpolation)
@@ -1068,7 +1072,7 @@ def rotate_mask(
     angle: float,
     expand: bool = False,
     center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
     if mask.ndim < 3:
         mask = mask.unsqueeze(0)
@@ -1097,7 +1101,7 @@ def _rotate_mask_dispatch(
     angle: float,
     expand: bool = False,
     center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     **kwargs,
 ) -> datapoints.Mask:
     output = rotate_mask(inpt.as_subclass(torch.Tensor), angle=angle, expand=expand, fill=fill, center=center)
@@ -1111,17 +1115,17 @@ def rotate_video(
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
     expand: bool = False,
     center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
     return rotate_image_tensor(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)


 def pad(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     padding: List[int],
     fill: Optional[Union[int, float, List[float]]] = None,
     padding_mode: str = "constant",
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return pad_image_tensor(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
@@ -1336,7 +1340,7 @@ def pad_video(
     return pad_image_tensor(video, padding, fill=fill, padding_mode=padding_mode)


-def crop(inpt: datapoints._InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints._InputTypeJIT:
+def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
     if torch.jit.is_scripting():
         return crop_image_tensor(inpt, top=top, left=left, height=height, width=width)
@@ -1423,13 +1427,13 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int


 def perspective(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return perspective_image_tensor(
             inpt,
@@ -1507,7 +1511,7 @@ def perspective_image_tensor(
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
 ) -> torch.Tensor:
     perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
@@ -1554,7 +1558,7 @@ def perspective_image_pil(
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BICUBIC,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
 ) -> PIL.Image.Image:
     perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
@@ -1679,7 +1683,7 @@ def perspective_mask(
     mask: torch.Tensor,
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
 ) -> torch.Tensor:
     if mask.ndim < 3:
@@ -1703,7 +1707,7 @@ def _perspective_mask_dispatch(
     inpt: datapoints.Mask,
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
     **kwargs,
 ) -> datapoints.Mask:
@@ -1723,7 +1727,7 @@ def perspective_video(
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
 ) -> torch.Tensor:
     return perspective_image_tensor(
@@ -1732,11 +1736,11 @@ def perspective_video(


 def elastic(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     displacement: torch.Tensor,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
-) -> datapoints._InputTypeJIT:
+    fill: _FillTypeJIT = None,
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return elastic_image_tensor(inpt, displacement=displacement, interpolation=interpolation, fill=fill)
@@ -1755,7 +1759,7 @@ def elastic_image_tensor(
     image: torch.Tensor,
     displacement: torch.Tensor,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
     interpolation = _check_interpolation(interpolation)
@@ -1812,7 +1816,7 @@ def elastic_image_pil(
     image: PIL.Image.Image,
     displacement: torch.Tensor,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> PIL.Image.Image:
     t_img = pil_to_tensor(image)
     output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill)
@@ -1895,7 +1899,7 @@ def _elastic_bounding_boxes_dispatch(
 def elastic_mask(
     mask: torch.Tensor,
     displacement: torch.Tensor,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
     if mask.ndim < 3:
         mask = mask.unsqueeze(0)
@@ -1913,7 +1917,7 @@ def elastic_mask(

 @_register_kernel_internal(elastic, datapoints.Mask, datapoint_wrapper=False)
 def _elastic_mask_dispatch(
-    inpt: datapoints.Mask, displacement: torch.Tensor, fill: datapoints._FillTypeJIT = None, **kwargs
+    inpt: datapoints.Mask, displacement: torch.Tensor, fill: _FillTypeJIT = None, **kwargs
 ) -> datapoints.Mask:
     output = elastic_mask(inpt.as_subclass(torch.Tensor), displacement=displacement, fill=fill)
     return datapoints.Mask.wrap_like(inpt, output)
@@ -1924,12 +1928,12 @@ def elastic_video(
     video: torch.Tensor,
     displacement: torch.Tensor,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
+    fill: _FillTypeJIT = None,
 ) -> torch.Tensor:
     return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)


-def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datapoints._InputTypeJIT:
+def center_crop(inpt: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     if torch.jit.is_scripting():
         return center_crop_image_tensor(inpt, output_size=output_size)
@@ -2049,7 +2053,7 @@ def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tens


 def resized_crop(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     top: int,
     left: int,
     height: int,
@@ -2057,7 +2061,7 @@ def resized_crop(
     size: List[int],
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     antialias: Optional[Union[str, bool]] = "warn",
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return resized_crop_image_tensor(
             inpt,
@@ -2201,14 +2205,8 @@ def resized_crop_video(

 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
 def five_crop(
-    inpt: datapoints._InputTypeJIT, size: List[int]
-) -> Tuple[
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-]:
+    inpt: torch.Tensor, size: List[int]
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     if torch.jit.is_scripting():
         return five_crop_image_tensor(inpt, size=size)
@@ -2280,18 +2278,18 @@ def five_crop_video(

 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
 def ten_crop(
-    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], size: List[int], vertical_flip: bool = False
+    inpt: torch.Tensor, size: List[int], vertical_flip: bool = False
 ) -> Tuple[
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
-    datapoints._InputTypeJIT,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
 ]:
     if torch.jit.is_scripting():
         return ten_crop_image_tensor(inpt, size=size, vertical_flip=vertical_flip)
diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py
index fc1aa05f319..a7177ab04e9 100644
--- a/torchvision/transforms/v2/functional/_meta.py
+++ b/torchvision/transforms/v2/functional/_meta.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple

 import PIL.Image
 import torch
@@ -12,7 +12,7 @@

 @_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
-def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
+def get_dimensions(inpt: torch.Tensor) -> List[int]:
     if torch.jit.is_scripting():
         return get_dimensions_image_tensor(inpt)
@@ -45,7 +45,7 @@ def get_dimensions_video(video: torch.Tensor) -> List[int]:


 @_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
-def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int:
+def get_num_channels(inpt: torch.Tensor) -> int:
     if torch.jit.is_scripting():
         return get_num_channels_image_tensor(inpt)
@@ -81,7 +81,7 @@ def get_num_channels_video(video: torch.Tensor) -> int:
 get_image_num_channels = get_num_channels


-def get_size(inpt: datapoints._InputTypeJIT) -> List[int]:
+def get_size(inpt: torch.Tensor) -> List[int]:
     if torch.jit.is_scripting():
         return get_size_image_tensor(inpt)
@@ -124,7 +124,7 @@ def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]


 @_register_unsupported_type(PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)
-def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int:
+def get_num_frames(inpt: torch.Tensor) -> int:
     if torch.jit.is_scripting():
         return get_num_frames_video(inpt)
@@ -201,11 +201,11 @@ def _convert_format_bounding_boxes(


 def convert_format_bounding_boxes(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     old_format: Optional[BoundingBoxFormat] = None,
     new_format: Optional[BoundingBoxFormat] = None,
     inplace: bool = False,
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
     # This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor
     # inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
     # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
@@ -252,10 +252,10 @@ def _clamp_bounding_boxes(


 def clamp_bounding_boxes(
-    inpt: datapoints._InputTypeJIT,
+    inpt: torch.Tensor,
     format: Optional[BoundingBoxFormat] = None,
     canvas_size: Optional[Tuple[int, int]] = None,
-) -> datapoints._InputTypeJIT:
+) -> torch.Tensor:
     if not torch.jit.is_scripting():
         _log_api_usage_once(clamp_bounding_boxes)

diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py
index e3a800ea79c..ec9c194d51d 100644
--- a/torchvision/transforms/v2/functional/_misc.py
+++ b/torchvision/transforms/v2/functional/_misc.py
@@ -1,5 +1,5 @@
 import math
-from typing import List, Optional, Union
+from typing import List, Optional

 import PIL.Image
 import torch
@@ -17,7 +17,7 @@
 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 @_register_unsupported_type(PIL.Image.Image)
 def normalize(
-    inpt: Union[datapoints._TensorImageTypeJIT, datapoints._TensorVideoTypeJIT],
+    inpt: torch.Tensor,
     mean: List[float],
     std: List[float],
     inplace: bool = False,
@@ -74,9 +74,7 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], in


 @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-def gaussian_blur(
-    inpt: datapoints._InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
-) -> datapoints._InputTypeJIT:
+def gaussian_blur(inpt: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> torch.Tensor:
     if torch.jit.is_scripting():
         return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
@@ -185,9 +183,7 @@ def gaussian_blur_video(


 @_register_unsupported_type(PIL.Image.Image)
-def to_dtype(
-    inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False
-) -> datapoints._InputTypeJIT:
+def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
     if torch.jit.is_scripting():
         return to_dtype_image_tensor(inpt, dtype=dtype, scale=scale)
@@ -278,8 +274,6 @@ def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale:

 @_register_kernel_internal(to_dtype, datapoints.BoundingBoxes, datapoint_wrapper=False)
 @_register_kernel_internal(to_dtype, datapoints.Mask, datapoint_wrapper=False)
-def _to_dtype_tensor_dispatch(
-    inpt: datapoints._InputTypeJIT, dtype: torch.dtype, scale: bool = False
-) -> datapoints._InputTypeJIT:
+def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: bool = False) -> torch.Tensor:
     # We don't need to unwrap and rewrap here, since Datapoint.to() preserves the type
     return inpt.to(dtype)
diff --git a/torchvision/transforms/v2/functional/_temporal.py b/torchvision/transforms/v2/functional/_temporal.py
index 62d12cb4b4e..78dcfc1ef92 100644
--- a/torchvision/transforms/v2/functional/_temporal.py
+++ b/torchvision/transforms/v2/functional/_temporal.py
@@ -11,7 +11,7 @@
 @_register_explicit_noop(
     PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True
 )
-def uniform_temporal_subsample(inpt: datapoints._VideoTypeJIT, num_samples: int) -> datapoints._VideoTypeJIT:
+def uniform_temporal_subsample(inpt: torch.Tensor, num_samples: int) -> torch.Tensor:
     if torch.jit.is_scripting():
         return uniform_temporal_subsample_video(inpt, num_samples=num_samples)

diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py
index 576a2b99dbf..ce1c320a745 100644
--- a/torchvision/transforms/v2/functional/_utils.py
+++ b/torchvision/transforms/v2/functional/_utils.py
@@ -1,10 +1,13 @@
 import functools
 import warnings
-from typing import Any, Callable, Dict, Type
+from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

 import torch
 from torchvision import datapoints

+_FillType = Union[int, float, Sequence[int], Sequence[float], None]
+_FillTypeJIT = Optional[List[float]]
+

 def is_simple_tensor(inpt: Any) -> bool:
     return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint)
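For reference, a minimal sketch (not part of the patch) of how the pieces touched in `_utils.py` behave: `is_simple_tensor` is the private helper defined above, and `_FillTypeJIT = Optional[List[float]]` corresponds to fill values such as `None` or one float per channel. This assumes a torchvision build with the v2 beta transforms available; the private import path may change:

import torch
import torchvision

torchvision.disable_beta_transforms_warning()

from torchvision import datapoints
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._utils import is_simple_tensor  # private helper shown in the hunk above

plain = torch.rand(3, 32, 32)
image = datapoints.Image(plain)

# Plain tensors and Datapoint subclasses are told apart before kernel dispatch.
print(is_simple_tensor(plain))  # True
print(is_simple_tensor(image))  # False -- Image is a Datapoint subclass

# A fill matching _FillTypeJIT (= Optional[List[float]]): None or one float per channel.
rotated = F.rotate(image, 30.0, fill=[127.0, 127.0, 127.0])
print(type(rotated))  # stays a datapoints.Image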