improve UX for v2 Compose #7758

Merged · 5 commits · Jul 25, 2023
Changes from 2 commits
55 changes: 55 additions & 0 deletions test/test_transforms_v2_refactored.py
@@ -26,6 +26,8 @@
make_video,
set_rng_seed,
)

from torch import nn
from torch.testing import assert_close
from torchvision import datapoints

@@ -1634,3 +1636,56 @@ def test_transform_negative_degrees_error(self):
def test_transform_unknown_fill_error(self):
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
transforms.RandomAffine(degrees=0, fill="fill")


class TestCompose:
class BuiltinTransform(transforms.Transform):
def _transform(self, inpt, params):
return inpt

class PackedInputTransform(nn.Module):
def forward(self, sample):
image, label = sample
return image, label

class UnpackedInputTransform(nn.Module):
def forward(self, image, label):
return image, label

@pytest.mark.parametrize(
"transform_clss",
Member: Why not call this transform_class?

Collaborator (author): It would have to be transform_classes, since it is a list of classes. And since we use cls for the singular, I usually just append an s to it. I'll leave it up to you.

[
[BuiltinTransform],
[PackedInputTransform],
[UnpackedInputTransform],
[BuiltinTransform, BuiltinTransform],
[PackedInputTransform, PackedInputTransform],
[UnpackedInputTransform, UnpackedInputTransform],
[BuiltinTransform, PackedInputTransform, BuiltinTransform],
[BuiltinTransform, UnpackedInputTransform, BuiltinTransform],
[PackedInputTransform, BuiltinTransform, PackedInputTransform],
[UnpackedInputTransform, BuiltinTransform, UnpackedInputTransform],
],
)
@pytest.mark.parametrize("unpack", [True, False])
def test_packed_unpacked(self, transform_clss, unpack):
if any(
unpack
and issubclass(cls, self.PackedInputTransform)
or not unpack
and issubclass(cls, self.UnpackedInputTransform)
for cls in transform_clss
):
return

transform = transforms.Compose([cls() for cls in transform_clss])

image = make_image()
label = torch.tensor(3)
packed_input = (image, label)

output = transform(*packed_input if unpack else (packed_input,))

assert isinstance(output, tuple) and len(output) == 2
assert output[0] is image
assert output[1] is label
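As a torch-free illustration of the two calling conventions the test above parametrizes over (the function names here are hypothetical stand-ins, not part of the test):

```python
# A "packed" transform receives the whole sample as one tuple argument;
# an "unpacked" transform receives the sample's parts as separate arguments.
def packed_transform(sample):
    image, label = sample
    return image, label

def unpacked_transform(image, label):
    return image, label

sample = ("image", 3)
assert packed_transform(sample) == sample       # called with one argument
assert unpacked_transform(*sample) == sample    # called with two arguments
# Mixing the conventions fails — unpacked_transform(sample) raises
# TypeError — which is why the test returns early for incompatible
# combinations of transform class and call style.
```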
11 changes: 8 additions & 3 deletions torchvision/transforms/v2/_container.py
@@ -46,10 +46,15 @@ def __init__(self, transforms: Sequence[Callable]) -> None:
self.transforms = transforms

 def forward(self, *inputs: Any) -> Any:
-    sample = inputs if len(inputs) > 1 else inputs[0]
+    needs_unpacking = len(inputs) > 1
+
+    if not self.transforms:
+        return inputs if needs_unpacking else inputs[0]
Member: Should we just error in __init__? I don't know if it's super valuable to special-case this.

Collaborator (author): Depends. The old Compose works as a no-op if you don't put any transforms in it. However, that doesn't sound like a valid use case, so I'm OK with moving this check into the constructor.

Collaborator (author): I went with the error. LMK if you want the no-op behavior back.
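Per the thread above, a later commit in this PR replaces the empty-Compose no-op with a constructor error. A minimal, torch-free sketch of that check — the class body and the exact error message here are assumptions for illustration, not torchvision's actual code:

```python
# Sketch: reject an empty transforms list at construction time instead of
# special-casing an empty Compose as a no-op inside forward().
class Compose:
    def __init__(self, transforms):
        if not transforms:
            raise ValueError("Pass at least one transform")
        self.transforms = list(transforms)

try:
    Compose([])
except ValueError as exc:
    print(exc)  # → Pass at least one transform
```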


 for transform in self.transforms:
-    sample = transform(sample)
-return sample
+    outputs = transform(*inputs)
+    inputs = outputs if needs_unpacking else (outputs,)
+return outputs
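Stripped of torchvision specifics, the change to forward() can be sketched as two standalone functions; forward_old, forward_new, and swap are illustrative names, not library API:

```python
# Old strategy: thread one "sample" object through every transform, so a
# transform whose signature takes multiple arguments cannot participate.
def forward_old(transforms, *inputs):
    sample = inputs if len(inputs) > 1 else inputs[0]
    for transform in transforms:
        sample = transform(sample)
    return sample

# New strategy: re-splat each transform's outputs, so a transform written as
# forward(image, label) works when Compose itself is called with two arguments.
def forward_new(transforms, *inputs):
    needs_unpacking = len(inputs) > 1
    for transform in transforms:
        outputs = transform(*inputs)
        inputs = outputs if needs_unpacking else (outputs,)
    return outputs

def swap(image, label):  # an "unpacked input" transform
    return label, image

print(forward_new([swap], "img", 3))  # → (3, 'img')
# forward_old([swap], "img", 3) would raise TypeError: it calls
# swap(("img", 3)), passing one tuple where two arguments are expected.
```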

def extra_repr(self) -> str:
format_string = []