Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

register tensor and PIL kernel the same way as datapoints #7797

Merged
merged 14 commits into from
Aug 7, 2023
58 changes: 3 additions & 55 deletions test/test_transforms_v2_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import math
import os
import re
from unittest import mock

import numpy as np
import PIL.Image
Expand All @@ -25,7 +24,6 @@
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY
from torchvision.transforms.v2.utils import is_simple_tensor
from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
from transforms_v2_kernel_infos import KERNEL_INFOS
Expand Down Expand Up @@ -359,18 +357,6 @@ def test_scripted_smoke(self, info, args_kwargs, device):
def test_scriptable(self, dispatcher):
script(dispatcher)

@image_sample_inputs
def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on):
(image_datapoint, *other_args), kwargs = args_kwargs.load()
image_simple_tensor = torch.Tensor(image_datapoint)

kernel_info = info.kernel_infos[datapoints.Image]
spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.id)

info.dispatcher(image_simple_tensor, *other_args, **kwargs)

spy.assert_called_once()

@image_sample_inputs
def test_simple_tensor_output_type(self, info, args_kwargs):
(image_datapoint, *other_args), kwargs = args_kwargs.load()
Expand All @@ -381,25 +367,6 @@ def test_simple_tensor_output_type(self, info, args_kwargs):
# We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
assert type(output) is torch.Tensor

@make_info_args_kwargs_parametrization(
[info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
)
def test_dispatch_pil(self, info, args_kwargs, spy_on):
(image_datapoint, *other_args), kwargs = args_kwargs.load()

if image_datapoint.ndim > 3:
pytest.skip("Input is batched")

image_pil = F.to_image_pil(image_datapoint)

pil_kernel_info = info.pil_kernel_info
spy = spy_on(pil_kernel_info.kernel, module=info.dispatcher.__module__, name=pil_kernel_info.id)

info.dispatcher(image_pil, *other_args, **kwargs)

spy.assert_called_once()

@make_info_args_kwargs_parametrization(
[info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
Expand All @@ -416,28 +383,6 @@ def test_pil_output_type(self, info, args_kwargs):

assert isinstance(output, PIL.Image.Image)

@make_info_args_kwargs_parametrization(
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(),
)
def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
(datapoint, *other_args), kwargs = args_kwargs.load()

input_type = type(datapoint)

wrapped_kernel = _KERNEL_REGISTRY[info.dispatcher][input_type]

# In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
# proper kernel was wrapped
if hasattr(wrapped_kernel, "__wrapped__"):
assert wrapped_kernel.__wrapped__ is info.kernels[input_type]

spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
with mock.patch.dict(_KERNEL_REGISTRY[info.dispatcher], values={input_type: spy}):
info.dispatcher(datapoint, *other_args, **kwargs)

spy.assert_called_once()

@make_info_args_kwargs_parametrization(
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(),
Expand All @@ -449,6 +394,9 @@ def test_datapoint_output_type(self, info, args_kwargs):

assert isinstance(output, type(datapoint))

if isinstance(datapoint, datapoints.BoundingBoxes) and info.dispatcher is not F.convert_format_bounding_boxes:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this test is part of the legacy framework, I was too lazy to handle this more elegantly.

assert output.format == datapoint.format

@pytest.mark.parametrize(
("dispatcher_info", "datapoint_type", "kernel_info"),
[
Expand Down
169 changes: 117 additions & 52 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,13 @@
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.functional import pil_modes_mapping
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY
from torchvision.transforms.v2.functional._utils import (
_get_kernel,
_KERNEL_REGISTRY,
_noop,
_register_kernel_internal,
register_kernel,
)


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -173,59 +179,32 @@ def _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs):
dispatcher_scripted(input.as_subclass(torch.Tensor), *args, **kwargs)


def _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs):
pmeier marked this conversation as resolved.
Show resolved Hide resolved
"""Checks if the dispatcher correctly dispatches the input to the corresponding kernel and that the input type is
preserved in doing so. For bounding boxes also checks that the format is preserved.
"""
input_type = type(input)

if isinstance(input, datapoints.Datapoint):
wrapped_kernel = _KERNEL_REGISTRY[dispatcher][input_type]

# In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
# proper kernel was wrapped
if hasattr(wrapped_kernel, "__wrapped__"):
assert wrapped_kernel.__wrapped__ is kernel

spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
with mock.patch.dict(_KERNEL_REGISTRY[dispatcher], values={input_type: spy}):
output = dispatcher(input, *args, **kwargs)

spy.assert_called_once()
else:
with mock.patch(f"{dispatcher.__module__}.{kernel.__name__}", wraps=kernel) as spy:
output = dispatcher(input, *args, **kwargs)

spy.assert_called_once()

assert isinstance(output, input_type)

if isinstance(input, datapoints.BoundingBoxes):
assert output.format == input.format


def check_dispatcher(
dispatcher,
# TODO: remove this parameter
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We no longer need this parameter. However, we previously parametrized over it together with the function to create the input. Thus, removing it is a chore that I'll deal with after the release. It doesn't have any effect on the runtime, because the number of tests stays exactly the same after this parameter is removed.

kernel,
input,
*args,
check_scripted_smoke=True,
check_dispatch=True,
**kwargs,
):
unknown_input = object()
with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
dispatcher(unknown_input, *args, **kwargs)

with mock.patch("torch._C._log_api_usage_once", wraps=torch._C._log_api_usage_once) as spy:
with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
dispatcher(unknown_input, *args, **kwargs)
output = dispatcher(input, *args, **kwargs)

spy.assert_any_call(f"{dispatcher.__module__}.{dispatcher.__name__}")

assert isinstance(output, type(input))

if isinstance(input, datapoints.BoundingBoxes):
assert output.format == input.format

if check_scripted_smoke:
_check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs)

if check_dispatch:
_check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs)


def check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
"""Checks if the signature of the dispatcher matches the kernel signature."""
Expand Down Expand Up @@ -412,18 +391,20 @@ def transform(bbox):


@pytest.mark.parametrize(
("dispatcher", "registered_datapoint_clss"),
("dispatcher", "registered_input_types"),
[(dispatcher, set(registry.keys())) for dispatcher, registry in _KERNEL_REGISTRY.items()],
)
def test_exhaustive_kernel_registration(dispatcher, registered_datapoint_clss):
def test_exhaustive_kernel_registration(dispatcher, registered_input_types):
missing = {
torch.Tensor,
PIL.Image.Image,
datapoints.Image,
datapoints.BoundingBoxes,
datapoints.Mask,
datapoints.Video,
} - registered_datapoint_clss
} - registered_input_types
if missing:
names = sorted(f"datapoints.{cls.__name__}" for cls in missing)
names = sorted(str(t) for t in missing)
raise AssertionError(
"\n".join(
[
Expand Down Expand Up @@ -1753,11 +1734,6 @@ def test_dispatcher(self, kernel, make_input, input_dtype, output_dtype, device,
F.to_dtype,
kernel,
make_input(dtype=input_dtype, device=device),
# TODO: we could leave check_dispatch to True but it currently fails
# in _check_dispatcher_dispatch because there is no to_dtype() method on the datapoints.
# We should be able to put this back if we change the dispatch
# mechanism e.g. via https://github.com/pytorch/vision/pull/7733
check_dispatch=False,
dtype=output_dtype,
scale=scale,
)
Expand Down Expand Up @@ -2185,7 +2161,9 @@ def test_unsupported_types(self, dispatcher, make_input):

class TestRegisterKernel:
@pytest.mark.parametrize("dispatcher", (F.resize, "resize"))
def test_register_kernel(self, dispatcher):
def test_register_kernel(self, mocker, dispatcher):
mocker.patch.dict(_KERNEL_REGISTRY, values={F.resize: _KERNEL_REGISTRY[F.resize]}, clear=True)
pmeier marked this conversation as resolved.
Show resolved Hide resolved

class CustomDatapoint(datapoints.Datapoint):
pass

Expand All @@ -2208,9 +2186,96 @@ def new_resize(dp, *args, **kwargs):
t(torch.rand(3, 10, 10)).shape == (3, 224, 224)
t(datapoints.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224)

def test_bad_disaptcher_name(self):
class CustomDatapoint(datapoints.Datapoint):
pass
def test_errors(self, mocker):
mocker.patch.dict(_KERNEL_REGISTRY, clear=True)

with pytest.raises(ValueError, match="Could not find dispatcher with name"):
F.register_kernel("bad_name", CustomDatapoint)
F.register_kernel("bad_name", datapoints.Image)

with pytest.raises(ValueError, match="Kernels can only be registered on dispatchers"):
register_kernel(datapoints.Image, F.resize)

with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"):
register_kernel(F.resize, object)

register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor)

with pytest.raises(ValueError, match="already has a kernel registered for type"):
register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor)


class TestGetKernel:
def make_and_register_kernel(self, dispatcher, input_type):
return _register_kernel_internal(dispatcher, input_type, datapoint_wrapper=False)(object())

@pytest.fixture
pmeier marked this conversation as resolved.
Show resolved Hide resolved
def dispatcher_and_kernels(self, mocker):
mocker.patch.dict(_KERNEL_REGISTRY, clear=True)
pmeier marked this conversation as resolved.
Show resolved Hide resolved

dispatcher = object()
pmeier marked this conversation as resolved.
Show resolved Hide resolved

kernels = {
cls: self.make_and_register_kernel(dispatcher, cls)
for cls in [
torch.Tensor,
PIL.Image.Image,
datapoints.Image,
datapoints.BoundingBoxes,
datapoints.Mask,
datapoints.Video,
]
}

yield dispatcher, kernels

def test_unsupported_types(self, dispatcher_and_kernels):
dispatcher, _ = dispatcher_and_kernels

class MyTensor(torch.Tensor):
pass

class MyPILImage(PIL.Image.Image):
pass

for input_type in [str, int, object, MyTensor, MyPILImage]:
with pytest.raises(TypeError, match=re.escape(str(input_type))):
pmeier marked this conversation as resolved.
Show resolved Hide resolved
_get_kernel(dispatcher, input_type)

def test_exact_match(self, dispatcher_and_kernels):
dispatcher, kernels = dispatcher_and_kernels

for input_type, kernel in kernels.items():
assert _get_kernel(dispatcher, input_type) is kernel

def test_builtin_datapoint_subclass(self, dispatcher_and_kernels):
dispatcher, kernels = dispatcher_and_kernels

class MyImage(datapoints.Image):
pass

class MyBoundingBoxes(datapoints.BoundingBoxes):
pass

class MyMask(datapoints.Mask):
pass

class MyVideo(datapoints.Video):
pass

assert _get_kernel(dispatcher, MyImage) is kernels[datapoints.Image]
assert _get_kernel(dispatcher, MyBoundingBoxes) is kernels[datapoints.BoundingBoxes]
assert _get_kernel(dispatcher, MyMask) is kernels[datapoints.Mask]
assert _get_kernel(dispatcher, MyVideo) is kernels[datapoints.Video]

def test_datapoint_subclass(self, dispatcher_and_kernels):
dispatcher, _ = dispatcher_and_kernels

class MyDatapoint(datapoints.Datapoint):
pass

# Note that this will be an error in the future
assert _get_kernel(dispatcher, MyDatapoint) is _noop
pmeier marked this conversation as resolved.
Show resolved Hide resolved

kernel = self.make_and_register_kernel(dispatcher, MyDatapoint)

assert _get_kernel(dispatcher, MyDatapoint) is kernel
25 changes: 9 additions & 16 deletions torchvision/transforms/v2/functional/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once

from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor
from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal


@_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True)
Expand All @@ -20,23 +20,16 @@ def erase(
v: torch.Tensor,
inplace: bool = False,
) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
if not torch.jit.is_scripting():
_log_api_usage_once(erase)

if torch.jit.is_scripting() or is_simple_tensor(inpt):
if torch.jit.is_scripting():
return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(erase, type(inpt))
return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
elif isinstance(inpt, PIL.Image.Image):
return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)

_log_api_usage_once(erase)

kernel = _get_kernel(erase, type(inpt))
return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)


@_register_kernel_internal(erase, torch.Tensor)
@_register_kernel_internal(erase, datapoints.Image)
def erase_image_tensor(
image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
Expand All @@ -48,7 +41,7 @@ def erase_image_tensor(
return image


@torch.jit.unused
@_register_kernel_internal(erase, PIL.Image.Image)
def erase_image_pil(
image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> PIL.Image.Image:
Expand Down
Loading
Loading