diff --git a/torchvision/transforms/v2/_container.py b/torchvision/transforms/v2/_container.py index 39940e44b34..8f591c49707 100644 --- a/torchvision/transforms/v2/_container.py +++ b/torchvision/transforms/v2/_container.py @@ -43,14 +43,12 @@ def __init__(self, transforms: Sequence[Callable]) -> None: super().__init__() if not isinstance(transforms, Sequence): raise TypeError("Argument transforms should be a sequence of callables") + elif not transforms: + raise ValueError("Pass at least one transform") self.transforms = transforms def forward(self, *inputs: Any) -> Any: needs_unpacking = len(inputs) > 1 - - if not self.transforms: - return inputs if needs_unpacking else inputs[0] - for transform in self.transforms: outputs = transform(*inputs) inputs = outputs if needs_unpacking else (outputs,)