From b99647b9d21cc2574d671e826f9c7b235e052caf Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 25 Jul 2023 16:00:46 +0200 Subject: [PATCH] enforce at least one transform --- torchvision/transforms/v2/_container.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torchvision/transforms/v2/_container.py b/torchvision/transforms/v2/_container.py index 39940e44b34..8f591c49707 100644 --- a/torchvision/transforms/v2/_container.py +++ b/torchvision/transforms/v2/_container.py @@ -43,14 +43,12 @@ def __init__(self, transforms: Sequence[Callable]) -> None: super().__init__() if not isinstance(transforms, Sequence): raise TypeError("Argument transforms should be a sequence of callables") + elif not transforms: + raise ValueError("Pass at least one transform") self.transforms = transforms def forward(self, *inputs: Any) -> Any: needs_unpacking = len(inputs) > 1 - - if not self.transforms: - return inputs if needs_unpacking else inputs[0] - for transform in self.transforms: outputs = transform(*inputs) inputs = outputs if needs_unpacking else (outputs,)