diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 109d1a409ba..06d8dc98cf8 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -1669,23 +1669,31 @@ def forward(self, image, label): ) @pytest.mark.parametrize("unpack", [True, False]) def test_packed_unpacked(self, transform_clss, unpack): - if any( - unpack - and issubclass(cls, self.PackedInputTransform) - or not unpack - and issubclass(cls, self.UnpackedInputTransform) - for cls in transform_clss - ): - return + needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss) + needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss) + assert not (needs_packed_inputs and needs_unpacked_inputs) transform = transforms.Compose([cls() for cls in transform_clss]) image = make_image() - label = torch.tensor(3) + label = 3 packed_input = (image, label) - output = transform(*packed_input if unpack else (packed_input,)) + def call_transform(): + if unpack: + return transform(*packed_input) + else: + return transform(packed_input) + + if needs_unpacked_inputs and not unpack: + with pytest.raises(TypeError, match="missing 1 required positional argument"): + call_transform() + elif needs_packed_inputs and unpack: + with pytest.raises(TypeError, match="takes 2 positional arguments but 3 were given"): + call_transform() + else: + output = call_transform() - assert isinstance(output, tuple) and len(output) == 2 - assert output[0] is image - assert output[1] is label + assert isinstance(output, tuple) and len(output) == 2 + assert output[0] is image + assert output[1] is label