diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 69180b99dbc..64a79262f3e 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -26,6 +26,8 @@ make_video, set_rng_seed, ) + +from torch import nn from torch.testing import assert_close from torchvision import datapoints @@ -1634,3 +1636,64 @@ def test_transform_negative_degrees_error(self): def test_transform_unknown_fill_error(self): with pytest.raises(TypeError, match="Got inappropriate fill arg"): transforms.RandomAffine(degrees=0, fill="fill") + + +class TestCompose: + class BuiltinTransform(transforms.Transform): + def _transform(self, inpt, params): + return inpt + + class PackedInputTransform(nn.Module): + def forward(self, sample): + assert len(sample) == 2 + return sample + + class UnpackedInputTransform(nn.Module): + def forward(self, image, label): + return image, label + + @pytest.mark.parametrize( + "transform_clss", + [ + [BuiltinTransform], + [PackedInputTransform], + [UnpackedInputTransform], + [BuiltinTransform, BuiltinTransform], + [PackedInputTransform, PackedInputTransform], + [UnpackedInputTransform, UnpackedInputTransform], + [BuiltinTransform, PackedInputTransform, BuiltinTransform], + [BuiltinTransform, UnpackedInputTransform, BuiltinTransform], + [PackedInputTransform, BuiltinTransform, PackedInputTransform], + [UnpackedInputTransform, BuiltinTransform, UnpackedInputTransform], + ], + ) + @pytest.mark.parametrize("unpack", [True, False]) + def test_packed_unpacked(self, transform_clss, unpack): + needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss) + needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss) + assert not (needs_packed_inputs and needs_unpacked_inputs) + + transform = transforms.Compose([cls() for cls in transform_clss]) + + image = make_image() + label = 3 + packed_input = (image, label) + + def call_transform(): + if unpack: + return transform(*packed_input) + else: + return transform(packed_input) + + if needs_unpacked_inputs and not unpack: + with pytest.raises(TypeError, match="missing 1 required positional argument"): + call_transform() + elif needs_packed_inputs and unpack: + with pytest.raises(TypeError, match="takes 2 positional arguments but 3 were given"): + call_transform() + else: + output = call_transform() + + assert isinstance(output, tuple) and len(output) == 2 + assert output[0] is image + assert output[1] is label diff --git a/torchvision/transforms/v2/_container.py b/torchvision/transforms/v2/_container.py index fffef4157bd..8f591c49707 100644 --- a/torchvision/transforms/v2/_container.py +++ b/torchvision/transforms/v2/_container.py @@ -43,13 +43,16 @@ def __init__(self, transforms: Sequence[Callable]) -> None: super().__init__() if not isinstance(transforms, Sequence): raise TypeError("Argument transforms should be a sequence of callables") + elif not transforms: + raise ValueError("Pass at least one transform") self.transforms = transforms def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] + needs_unpacking = len(inputs) > 1 for transform in self.transforms: - sample = transform(sample) - return sample + outputs = transform(*inputs) + inputs = outputs if needs_unpacking else (outputs,) + return outputs def extra_repr(self) -> str: format_string = []