diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index aa11b83b61a..c5cc81a02ca 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -545,38 +545,3 @@ def test_sanitize_bounding_boxes_errors():
     with pytest.raises(ValueError, match="Number of boxes"):
         different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
         transforms.SanitizeBoundingBoxes()(different_sizes)
-
-
-class TestLambda:
-    inputs = pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0])
-
-    @inputs
-    def test_default(self, input):
-        was_applied = False
-
-        def was_applied_fn(input):
-            nonlocal was_applied
-            was_applied = True
-            return input
-
-        transform = transforms.Lambda(was_applied_fn)
-
-        transform(input)
-
-        assert was_applied
-
-    @inputs
-    def test_with_types(self, input):
-        was_applied = False
-
-        def was_applied_fn(input):
-            nonlocal was_applied
-            was_applied = True
-            return input
-
-        types = (torch.Tensor, np.ndarray)
-        transform = transforms.Lambda(was_applied_fn, *types)
-
-        transform(input)
-
-        assert was_applied is isinstance(input, types)
diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py
index 12e76e89f43..fe3dc264e78 100644
--- a/test/test_transforms_v2_consistency.py
+++ b/test/test_transforms_v2_consistency.py
@@ -12,7 +12,6 @@
 import torchvision.transforms.v2 as v2_transforms
 from common_utils import assert_close, assert_equal, set_rng_seed
 from torchvision import transforms as legacy_transforms, tv_tensors
-from torchvision._utils import sequence_to_str
 from torchvision.transforms import functional as legacy_F
 from torchvision.transforms.v2 import functional as prototype_F
 
@@ -70,57 +69,7 @@ def __init__(
 LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
 LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)
 
-CONSISTENCY_CONFIGS = [
-    ConsistencyConfig(
-        v2_transforms.Lambda,
-        legacy_transforms.Lambda,
-        [
-            NotScriptableArgsKwargs(lambda image: image / 2),
-        ],
-        # Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
-        # images given that the transform does nothing but call it anyway.
-        supports_pil=False,
-    ),
-]
-
-
-@pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
-def test_signature_consistency(config):
-    legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
-    prototype_params = dict(inspect.signature(config.prototype_cls).parameters)
-
-    for param in config.removed_params:
-        legacy_params.pop(param, None)
-
-    missing = legacy_params.keys() - prototype_params.keys()
-    if missing:
-        raise AssertionError(
-            f"The prototype transform does not support the parameters "
-            f"{sequence_to_str(sorted(missing), separate_last='and ')}, but the legacy transform does. "
-            f"If that is intentional, e.g. pending deprecation, please add the parameters to the `removed_params` on "
-            f"the `ConsistencyConfig`."
-        )
-
-    extra = prototype_params.keys() - legacy_params.keys()
-    extra_without_default = {
-        param
-        for param in extra
-        if prototype_params[param].default is inspect.Parameter.empty
-        and prototype_params[param].kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
-    }
-    if extra_without_default:
-        raise AssertionError(
-            f"The prototype transform requires the parameters "
-            f"{sequence_to_str(sorted(extra_without_default), separate_last='and ')}, but the legacy transform does "
-            f"not. Please add a default value."
-        )
-
-    legacy_signature = list(legacy_params.keys())
-    # Since we made sure that we don't have any extra parameters without default above, we clamp the prototype signature
-    # to the same number of parameters as the legacy one
-    prototype_signature = list(prototype_params.keys())[: len(legacy_signature)]
-
-    assert prototype_signature == legacy_signature
+CONSISTENCY_CONFIGS = []
 
 
 def check_call_consistency(
diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py
index d9e271ce7b9..fd55083ff55 100644
--- a/test/test_transforms_v2_refactored.py
+++ b/test/test_transforms_v2_refactored.py
@@ -1906,8 +1906,9 @@ def test_random_order(self):
         input = make_image()
 
         actual = check_transform(transform, input)
-        # horizontal and vertical flip are commutative. Meaning, although the order in the transform is indeed random,
-        # we don't need to care here.
+        # We can't really check whether the transforms are actually applied in random order. However, horizontal and
+        # vertical flip are commutative. Meaning, even under the assumption that the transform applies them in random
+        # order, we can use a fixed order to compute the expected value.
         expected = F.vertical_flip(F.horizontal_flip(input))
         assert_equal(actual, expected)
 
@@ -5221,3 +5222,21 @@ def test_functional_and_transform(self, color_space, fn):
     def test_functional_error(self):
         with pytest.raises(TypeError, match="pic should be PIL Image"):
             F.pil_to_tensor(object())
+
+
+class TestLambda:
+    @pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0])
+    @pytest.mark.parametrize("types", [(), (torch.Tensor, np.ndarray)])
+    def test_transform(self, input, types):
+        was_applied = False
+
+        def was_applied_fn(input):
+            nonlocal was_applied
+            was_applied = True
+            return input
+
+        transform = transforms.Lambda(was_applied_fn, *types)
+        output = transform(input)
+
+        assert output is input
+        assert was_applied is (not types or isinstance(input, types))
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index bdf4ccc5912..d176e00a8da 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -1270,7 +1270,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
 
     Note:
         Please, note that this method supports only RGB images as input. For inputs in other color spaces,
-        please, consider using meth:`~torchvision.transforms.functional.to_grayscale` with PIL Image.
+        please, consider using :meth:`~torchvision.transforms.functional.to_grayscale` with PIL Image.
 
     Args:
         img (PIL Image or Tensor): RGB Image to be converted to grayscale.
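
For context on the behavior the consolidated `TestLambda` test above pins down: `transforms.Lambda(fn, *types)` applies `fn` to every input when no `types` are given, applies it only to inputs that are instances of `types` otherwise, and passes non-matching inputs through untouched. A minimal sketch of that contract, assuming the v2 API the test exercises; the `double` example and sample values here are illustrative, not part of this patch:

```python
import torch
import torchvision.transforms.v2 as transforms

# Restrict the lambda to torch.Tensor inputs; anything else passes through untouched.
double = transforms.Lambda(lambda t: t * 2, torch.Tensor)

assert torch.equal(double(torch.ones(3)), torch.full((3,), 2.0))  # tensor: lambda applied
assert double("a string") == "a string"  # not a torch.Tensor: returned as-is

# With no types given, the lambda is applied to every input, mirroring v1 semantics.
assert transforms.Lambda(str.upper)("abc") == "ABC"
```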