From 45bf28c4b1045d248217f7efe46827274213a4e0 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Tue, 25 Jul 2023 11:15:12 +0100
Subject: [PATCH] Use bigger images

---
 test/test_transforms_v2_refactored.py | 8 ++++----
 torchvision/transforms/v2/_augment.py | 4 ++--
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py
index bc1cfa471ee..671d1bbe91a 100644
--- a/test/test_transforms_v2_refactored.py
+++ b/test/test_transforms_v2_refactored.py
@@ -1647,8 +1647,8 @@ def __init__(self, size, one_hot, num_categories):
             assert size < num_categories
 
         def __getitem__(self, idx):
-            img = torch.rand(3, 12, 12)
-            label = idx
+            img = torch.rand(3, 100, 100)
+            label = idx  # This ensures all labels in a batch are unique and makes testing easier
             if self.one_hot:
                 label = torch.nn.functional.one_hot(torch.tensor(label), num_classes=self.num_categories)
             return img, label
@@ -1672,7 +1672,6 @@ def test_supported_input_structure(self, T, one_hot):
                 cutmix_mixup = transforms.Compose([cutmix, mixup])
             else:
                 cutmix_mixup = transforms.Compose([mixup, cutmix])
-            # When both CutMix and MixUp
             expected_num_non_zero_labels = 3
         else:
             cutmix_mixup = T(alpha=0.5, num_categories=num_categories)
@@ -1682,11 +1681,12 @@ def test_supported_input_structure(self, T, one_hot):
 
         # Input sanity checks
         img, target = next(iter(dl))
+        input_img_size = img.shape[-3:]
         assert isinstance(img, torch.Tensor) and isinstance(target, torch.Tensor)
         assert target.shape == (batch_size, num_categories) if one_hot else (batch_size,)
 
         def check_output(img, target):
-            assert img.shape == (batch_size, 3, 12, 12)
+            assert img.shape == (batch_size, *input_img_size)
             assert target.shape == (batch_size, num_categories)
             torch.testing.assert_close(target.sum(axis=-1), torch.ones(batch_size))
             num_non_zero_labels = (target != 0).sum(axis=-1)
diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py
index c38deef24a1..faaab263c88 100644
--- a/torchvision/transforms/v2/_augment.py
+++ b/torchvision/transforms/v2/_augment.py
@@ -243,8 +243,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
 
         H, W = query_spatial_size(flat_inputs)
 
-        r_x = torch.randint(W, ())
-        r_y = torch.randint(H, ())
+        r_x = torch.randint(W, size=(1,))
+        r_y = torch.randint(H, size=(1,))
 
         r = 0.5 * math.sqrt(1.0 - lam)
         r_w_half = int(r * W)
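
Note (not part of the patch): a minimal standalone sketch of the two behaviors the `_augment.py` hunk touches. It shows the shape difference between `torch.randint(W, ())` and `torch.randint(W, size=(1,))`, and the box half-width derivation `r = 0.5 * math.sqrt(1.0 - lam)` / `r_w_half = int(r * W)` that appears in the diff context. The Beta-sampled `lam` and the `r_h_half` counterpart are assumptions based on the standard CutMix formulation; this patch does not change how `lam` is sampled.

```python
import math
import torch

H, W = 100, 100  # matches the bigger test images used in the patch

# torch.randint(W, ()) yields a 0-d (scalar) tensor;
# torch.randint(W, size=(1,)) yields a 1-d tensor with one element.
r_x_scalar = torch.randint(W, ())
r_x_vector = torch.randint(W, size=(1,))
print(r_x_scalar.shape)  # torch.Size([])
print(r_x_vector.shape)  # torch.Size([1])

# Box half-extents from the mixing coefficient, as in the diff context.
# Assumption: lam is drawn from Beta(alpha, alpha) with alpha=0.5, following
# the standard CutMix formulation; not taken from this patch.
lam = float(torch.distributions.Beta(0.5, 0.5).sample())
r = 0.5 * math.sqrt(1.0 - lam)
r_w_half = int(r * W)
r_h_half = int(r * H)  # assumed symmetric counterpart, not shown in the hunk
print(r_w_half, r_h_half)
```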