Skip to content

Commit

Permalink
improve test
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Jul 25, 2023
1 parent 835f40c commit 4bcb488
Showing 1 changed file with 21 additions and 13 deletions.
34 changes: 21 additions & 13 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -1669,23 +1669,31 @@ def forward(self, image, label):
)
@pytest.mark.parametrize("unpack", [True, False])
def test_packed_unpacked(self, transform_clss, unpack):
if any(
unpack
and issubclass(cls, self.PackedInputTransform)
or not unpack
and issubclass(cls, self.UnpackedInputTransform)
for cls in transform_clss
):
return
needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss)
needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss)
assert not (needs_packed_inputs and needs_unpacked_inputs)

transform = transforms.Compose([cls() for cls in transform_clss])

image = make_image()
label = torch.tensor(3)
label = 3
packed_input = (image, label)

output = transform(*packed_input if unpack else (packed_input,))
def call_transform():
if unpack:
return transform(*packed_input)
else:
return transform(packed_input)

if needs_unpacked_inputs and not unpack:
with pytest.raises(TypeError, match="missing 1 required positional argument"):
call_transform()
elif needs_packed_inputs and unpack:
with pytest.raises(TypeError, match="takes 2 positional arguments but 3 were given"):
call_transform()
else:
output = call_transform()

assert isinstance(output, tuple) and len(output) == 2
assert output[0] is image
assert output[1] is label
assert isinstance(output, tuple) and len(output) == 2
assert output[0] is image
assert output[1] is label

0 comments on commit 4bcb488

Please sign in to comment.