Skip to content

Commit

Permalink
[fbsync] improve UX for v2 Compose (#7758)
Browse files Browse the repository at this point in the history
Summary: Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

Reviewed By: matteobettini

Differential Revision: D48642249

fbshipit-source-id: 48053ff1f6a19d62264500358a4df125751b71e6
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Aug 25, 2023
1 parent ba0044e commit 7beb6a5
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 3 deletions.
63 changes: 63 additions & 0 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
make_video,
set_rng_seed,
)

from torch import nn
from torch.testing import assert_close
from torchvision import datapoints

Expand Down Expand Up @@ -1634,3 +1636,64 @@ def test_transform_negative_degrees_error(self):
def test_transform_unknown_fill_error(self):
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
transforms.RandomAffine(degrees=0, fill="fill")


class TestCompose:
class BuiltinTransform(transforms.Transform):
def _transform(self, inpt, params):
return inpt

class PackedInputTransform(nn.Module):
def forward(self, sample):
assert len(sample) == 2
return sample

class UnpackedInputTransform(nn.Module):
def forward(self, image, label):
return image, label

@pytest.mark.parametrize(
"transform_clss",
[
[BuiltinTransform],
[PackedInputTransform],
[UnpackedInputTransform],
[BuiltinTransform, BuiltinTransform],
[PackedInputTransform, PackedInputTransform],
[UnpackedInputTransform, UnpackedInputTransform],
[BuiltinTransform, PackedInputTransform, BuiltinTransform],
[BuiltinTransform, UnpackedInputTransform, BuiltinTransform],
[PackedInputTransform, BuiltinTransform, PackedInputTransform],
[UnpackedInputTransform, BuiltinTransform, UnpackedInputTransform],
],
)
@pytest.mark.parametrize("unpack", [True, False])
def test_packed_unpacked(self, transform_clss, unpack):
needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss)
needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss)
assert not (needs_packed_inputs and needs_unpacked_inputs)

transform = transforms.Compose([cls() for cls in transform_clss])

image = make_image()
label = 3
packed_input = (image, label)

def call_transform():
if unpack:
return transform(*packed_input)
else:
return transform(packed_input)

if needs_unpacked_inputs and not unpack:
with pytest.raises(TypeError, match="missing 1 required positional argument"):
call_transform()
elif needs_packed_inputs and unpack:
with pytest.raises(TypeError, match="takes 2 positional arguments but 3 were given"):
call_transform()
else:
output = call_transform()

assert isinstance(output, tuple) and len(output) == 2
assert output[0] is image
assert output[1] is label
9 changes: 6 additions & 3 deletions torchvision/transforms/v2/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,16 @@ def __init__(self, transforms: Sequence[Callable]) -> None:
super().__init__()
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
elif not transforms:
raise ValueError("Pass at least one transform")
self.transforms = transforms

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
needs_unpacking = len(inputs) > 1
for transform in self.transforms:
sample = transform(sample)
return sample
outputs = transform(*inputs)
inputs = outputs if needs_unpacking else (outputs,)
return outputs

def extra_repr(self) -> str:
format_string = []
Expand Down

0 comments on commit 7beb6a5

Please sign in to comment.