
refactor Datapoint dispatch mechanism #7747

Merged Aug 2, 2023 (26 commits; showing changes from 22 commits)

Commits:
d9e1379  [PoC] refactor Datapoint dispatch mechanism (pmeier, Jul 19, 2023)
36b9d36  fix test (pmeier, Jul 19, 2023)
f36c64c  Merge branch 'main' into kernel-registration (pmeier, Jul 26, 2023)
bbaa35c  add dispatch to adjust_brightness (pmeier, Jul 27, 2023)
ca4ad32  enforce no register overwrite (pmeier, Jul 27, 2023)
d23a80e  [PoC] make wrapping interal kernel more convenient (pmeier, Jul 27, 2023)
bf47188  [PoC] enforce explicit no-ops (pmeier, Jul 27, 2023)
74d5054  fix adjust_brightness tests and remove methods (pmeier, Jul 27, 2023)
e88be5e  Merge branch 'main' into kernel-registration (pmeier, Jul 27, 2023)
f178373  address minor comments (pmeier, Jul 27, 2023)
65e80d0  make no-op registration a decorator (pmeier, Jul 28, 2023)
9614477  Merge branch 'main' (pmeier, Aug 1, 2023)
6ac08e4  explicit metadata (pmeier, Aug 1, 2023)
cac079b  implement dispatchers for erase five/ten_crop and temporal_subsample (pmeier, Aug 1, 2023)
c7256b4  make shape getters proper dispatchers (pmeier, Aug 1, 2023)
bf78cd6  fix (pmeier, Aug 1, 2023)
f86f89b  port normalize and to_dtype (pmeier, Aug 2, 2023)
d90daf6  address comments (pmeier, Aug 2, 2023)
09eec9a  address comments and cleanup (pmeier, Aug 2, 2023)
3730811  more cleanup (pmeier, Aug 2, 2023)
7203453  Merge branch 'main' into kernel-registration (pmeier, Aug 2, 2023)
31bee5f  port all remaining dispatchers to the new mechanism (pmeier, Jul 28, 2023)
a924013  put back legacy test_dispatch_datapoint (pmeier, Aug 2, 2023)
b3c2c88  minor test fixes (pmeier, Aug 2, 2023)
a1f5ea4  Update torchvision/transforms/v2/functional/_utils.py (pmeier, Aug 2, 2023)
d29d95b  reinstante antialias tests (pmeier, Aug 2, 2023)
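
Taken together, the commits describe the new mechanism: instead of dispatching through methods on the Datapoint classes, each functional (the "dispatcher") looks up a kernel that was registered for its input's datapoint type, and overwriting an existing registration is rejected. The snippet below is a minimal sketch of that idea only; every name in it (register_kernel, get_kernel, the toy Image class) is a hypothetical stand-in, not torchvision's actual API.

```python
# Minimal sketch of a registration-based dispatch mechanism. All names are
# hypothetical stand-ins; this illustrates the idea, not the torchvision code.
from typing import Any, Callable, Dict

_KERNEL_REGISTRY: Dict[Callable, Dict[type, Callable]] = {}


def register_kernel(dispatcher: Callable, datapoint_cls: type) -> Callable:
    """Decorator that registers a kernel for a (dispatcher, datapoint type) pair."""
    registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
    if datapoint_cls in registry:
        # registering twice is an error, cf. commit "enforce no register overwrite"
        raise ValueError(f"A kernel for {datapoint_cls.__name__} is already registered")

    def decorator(kernel: Callable) -> Callable:
        registry[datapoint_cls] = kernel
        return kernel

    return decorator


def get_kernel(dispatcher: Callable, datapoint_cls: type) -> Callable:
    """Look up the kernel for a datapoint type, honoring subclasses via the MRO."""
    registry = _KERNEL_REGISTRY.get(dispatcher, {})
    for cls in datapoint_cls.__mro__:
        if cls in registry:
            return registry[cls]
    raise TypeError(f"{dispatcher.__name__} does not support inputs of type {datapoint_cls.__name__}")


class Image:
    """Toy stand-in for datapoints.Image."""

    def __init__(self, data):
        self.data = data


def adjust_brightness(inpt: Any, brightness_factor: float) -> Any:
    """Dispatcher: forwards to whichever kernel is registered for type(inpt)."""
    kernel = get_kernel(adjust_brightness, type(inpt))
    return kernel(inpt, brightness_factor=brightness_factor)


@register_kernel(adjust_brightness, Image)
def adjust_brightness_image(image: Image, brightness_factor: float) -> Image:
    return Image([value * brightness_factor for value in image.data])


print(adjust_brightness(Image([0.1, 0.2]), brightness_factor=2.0).data)  # [0.2, 0.4]
```

The no-op and pass-through cases mentioned in the commits would be registered the same way, just with kernels that return their input unchanged.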
4 changes: 4 additions & 0 deletions test/common_utils.py
@@ -829,6 +829,10 @@ def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs):
return datapoints.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs))


def make_video_tensor(*args, **kwargs):
return make_video(*args, **kwargs).as_subclass(torch.Tensor)


def make_video_loader(
size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
*,
6 changes: 4 additions & 2 deletions test/datasets_utils.py
@@ -567,7 +567,7 @@ def test_transforms(self, config):

@test_all_configs
def test_transforms_v2_wrapper(self, config):
from torchvision.datapoints._datapoint import Datapoint
from torchvision import datapoints
from torchvision.datasets import wrap_dataset_for_transforms_v2

try:
@@ -588,7 +588,9 @@ assert len(wrapped_dataset) == info["num_examples"]
assert len(wrapped_dataset) == info["num_examples"]

wrapped_sample = wrapped_dataset[0]
assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
assert tree_any(
lambda item: isinstance(item, (datapoints.Datapoint, PIL.Image.Image)), wrapped_sample
)
except TypeError as error:
msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
if str(error).startswith(msg):
2 changes: 0 additions & 2 deletions test/test_transforms_v2.py
@@ -1346,8 +1346,6 @@ def test_antialias_warning():
with pytest.warns(UserWarning, match=match):
datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20))

with pytest.warns(UserWarning, match=match):
datapoints.Video(tensor_video).resize((20, 20))
with pytest.warns(UserWarning, match=match):
datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20))

49 changes: 6 additions & 43 deletions test/test_transforms_v2_functional.py
@@ -3,8 +3,6 @@
import os
import re

from typing import get_type_hints

import numpy as np
import PIL.Image
import pytest
@@ -417,22 +415,6 @@ def test_pil_output_type(self, info, args_kwargs):

assert isinstance(output, PIL.Image.Image)

@make_info_args_kwargs_parametrization(
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(),
)
def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
(datapoint, *other_args), kwargs = args_kwargs.load()

method_name = info.id
method = getattr(datapoint, method_name)
datapoint_type = type(datapoint)
spy = spy_on(method, module=datapoint_type.__module__, name=f"{datapoint_type.__name__}.{method_name}")

info.dispatcher(datapoint, *other_args, **kwargs)

spy.assert_called_once()

@make_info_args_kwargs_parametrization(
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(),
@@ -462,9 +444,12 @@ def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoi
kernel_params = list(kernel_signature.parameters.values())[1:]

# We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
# explicit passed to the kernel.
datapoint_type_metadata = datapoint_type.__annotations__.keys()
kernel_params = [param for param in kernel_params if param.name not in datapoint_type_metadata]
# explicitly passed to the kernel.
input_type = {v: k for k, v in dispatcher_info.kernels.items()}.get(kernel_info.kernel)
explicit_metadata = {
datapoints.BoundingBoxes: {"format", "canvas_size"},
}
kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]

dispatcher_params = iter(dispatcher_params)
for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
@@ -481,28 +466,6 @@

assert dispatcher_param == kernel_param

@pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
def test_dispatcher_datapoint_signatures_consistency(self, info):
try:
datapoint_method = getattr(datapoints._datapoint.Datapoint, info.id)
except AttributeError:
pytest.skip("Dispatcher doesn't support arbitrary datapoint dispatch.")

dispatcher_signature = inspect.signature(info.dispatcher)
dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

datapoint_signature = inspect.signature(datapoint_method)
datapoint_params = list(datapoint_signature.parameters.values())[1:]

# Because we use `from __future__ import annotations` inside the module where `datapoints._datapoint` is
# defined, the annotations are stored as strings. This makes them concrete again, so they can be compared to the
# natively concrete dispatcher annotations.
datapoint_annotations = get_type_hints(datapoint_method)
for param in datapoint_params:
param._annotation = datapoint_annotations[param.name]

assert dispatcher_params == datapoint_params

@pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
def test_unkown_type(self, info):
unkown_input = object()
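
The metadata comment in test_dispatcher_kernel_signatures_consistency above is the core of the explicit-metadata change: a bounding-box kernel takes format and canvas_size as explicit arguments, while the dispatcher reads them off its BoundingBoxes input, which is why those names are skipped when the signatures are compared. The sketch below illustrates that split with hypothetical names; it is not torchvision's actual implementation.

```python
# Sketch of the implicit-vs-explicit metadata split, with hypothetical names.
# canvas_size is assumed to be (height, width), matching torchvision's convention.
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class BoundingBoxes:  # toy stand-in for datapoints.BoundingBoxes
    data: List[List[float]]
    format: str                  # e.g. "XYXY"
    canvas_size: Tuple[int, int]


def horizontal_flip_bounding_boxes(
    boxes: List[List[float]], *, format: str, canvas_size: Tuple[int, int]
) -> List[List[float]]:
    # Kernel: works on plain data, so the metadata must be passed explicitly.
    assert format == "XYXY"
    width = canvas_size[1]
    return [[width - x2, y1, width - x1, y2] for x1, y1, x2, y2 in boxes]


def horizontal_flip(inpt: BoundingBoxes) -> BoundingBoxes:
    # Dispatcher: the metadata arrives implicitly with the datapoint and is
    # forwarded explicitly to the kernel; the output is re-wrapped afterwards.
    flipped = horizontal_flip_bounding_boxes(
        inpt.data, format=inpt.format, canvas_size=inpt.canvas_size
    )
    return BoundingBoxes(flipped, format=inpt.format, canvas_size=inpt.canvas_size)


boxes = BoundingBoxes([[10.0, 20.0, 30.0, 40.0]], format="XYXY", canvas_size=(100, 50))
print(horizontal_flip(boxes).data)  # [[20.0, 20.0, 40.0, 40.0]]
```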