From e91e8792ef0263acfb8a92b337fa6168919ffe82 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Tue, 25 Jul 2023 10:49:53 +0100
Subject: [PATCH] Got rid of FakeData, doesn't work tho

---
 test/test_transforms_v2_refactored.py | 55 ++++++++++++---------------
 torchvision/datasets/fakedata.py      |  6 +--
 2 files changed, 26 insertions(+), 35 deletions(-)

diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py
index bc2292b7c9c..bc1cfa471ee 100644
--- a/test/test_transforms_v2_refactored.py
+++ b/test/test_transforms_v2_refactored.py
@@ -30,7 +30,6 @@
 from torch.testing import assert_close
 from torch.utils.data import DataLoader, default_collate
 from torchvision import datapoints
-from torchvision.datasets import FakeData
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
 from torchvision.transforms.functional import pil_modes_mapping
 
@@ -1640,36 +1639,44 @@ def test_transform_unknown_fill_error(self):
 
 
 class TestCutMixMixUp:
+    class DummyDataset:
+        def __init__(self, size, one_hot, num_categories):
+            self.one_hot = one_hot
+            self.size = size
+            self.num_categories = num_categories
+            assert size < num_categories
+
+        def __getitem__(self, idx):
+            img = torch.rand(3, 12, 12)
+            label = idx
+            if self.one_hot:
+                label = torch.nn.functional.one_hot(torch.tensor(label), num_classes=self.num_categories)
+            return img, label
+
+        def __len__(self):
+            return self.size
+
     @pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup, "CutMixMixUp", "MixUpCutMix"])
     @pytest.mark.parametrize("one_hot", [True, False])
     def test_supported_input_structure(self, T, one_hot):
-        num_categories = 10
         batch_size = 32
-        H, W = 12, 12
-
-        preproc = transforms.Compose([transforms.PILToTensor(), transforms.ToDtype(torch.float32)])
-        if one_hot:
-
-            class ToOneHot(torch.nn.Module):
-                def forward(self, inpt):
-                    img, label = inpt
-                    return img, torch.nn.functional.one_hot(label, num_classes=num_categories)
+        num_categories = 100
 
-            preproc = transforms.Compose([preproc, ToOneHot()])
+        dataset = self.DummyDataset(size=batch_size, one_hot=one_hot, num_categories=num_categories)
 
-        dataset = FakeData(size=batch_size, image_size=(3, H, W), num_classes=num_categories, transforms=preproc)
         if isinstance(T, str):
-            expected_num_non_zero_labels = 3  # see common_checks
             cutmix = transforms.Cutmix(alpha=0.5, num_categories=num_categories)
             mixup = transforms.Mixup(alpha=0.5, num_categories=num_categories)
             if T == "CutMixMixUp":
                 cutmix_mixup = transforms.Compose([cutmix, mixup])
             else:
                 cutmix_mixup = transforms.Compose([mixup, cutmix])
+            # When both CutMix and MixUp are applied, expect 3 non-zero labels per sample.
+            expected_num_non_zero_labels = 3
         else:
             cutmix_mixup = T(alpha=0.5, num_categories=num_categories)
-            expected_num_non_zero_labels = 2  # see common_checks
+            expected_num_non_zero_labels = 2
 
         dl = DataLoader(dataset, batch_size=batch_size)
 
@@ -1679,23 +1686,11 @@ def forward(self, inpt):
         assert target.shape == (batch_size, num_categories) if one_hot else (batch_size,)
 
         def check_output(img, target):
-            assert img.shape == (batch_size, 3, H, W)
+            assert img.shape == (batch_size, 3, 12, 12)
             assert target.shape == (batch_size, num_categories)
             torch.testing.assert_close(target.sum(axis=-1), torch.ones(batch_size))
-            # Below we check the number of non-zero values in the target tensor.
-            # When just CutMix() (or just MixUp()) is called, we should expect 2
-            # non-zero label values per sample. Although, it may happen that
-            # only 1 non-zero value is present, basically if the transform had
-            # no effect. Here we make sure that:
-            # - there is at least one sample with 2 non-zero values
-            # - there is no sample with more than 2 non-zero values
-            # When CutMix() and MixUp() are called in sequence together, we
-            # should expect 3 instead of 2. That's the
-            # expected_num_non_zero_labels threshold.
-            num_non_zero_values = (target != 0).sum(axis=-1)
-            assert (num_non_zero_values == expected_num_non_zero_labels).any()
-            assert (num_non_zero_values <= expected_num_non_zero_labels).all()
-            assert (num_non_zero_values > 0).all()  # Note: we already know that from target.sum(axis=-1) check above
+            num_non_zero_labels = (target != 0).sum(axis=-1)
+            assert (num_non_zero_labels == expected_num_non_zero_labels).all()
 
         # After Dataloader, as unpacked input
         img, target = next(iter(dl))

diff --git a/torchvision/datasets/fakedata.py b/torchvision/datasets/fakedata.py
index 94e7b28b892..244af634989 100644
--- a/torchvision/datasets/fakedata.py
+++ b/torchvision/datasets/fakedata.py
@@ -28,11 +28,10 @@ def __init__(
         image_size: Tuple[int, int, int] = (3, 224, 224),
         num_classes: int = 10,
         transform: Optional[Callable] = None,
-        transforms: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
         random_offset: int = 0,
     ) -> None:
-        super().__init__(None, transform=transform, transforms=transforms, target_transform=target_transform)  # type: ignore[arg-type]
+        super().__init__(None, transform=transform, target_transform=target_transform)  # type: ignore[arg-type]
         self.size = size
         self.num_classes = num_classes
         self.image_size = image_size
@@ -61,9 +60,6 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
             img = self.transform(img)
         if self.target_transform is not None:
             target = self.target_transform(target)
-        if self.transforms is not None:
-            img, target = self.transforms(img, target)
-            return img, target  # We don't want to call item() on arbitrarily transformed targets
 
         return img, target.item()
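
The strict (num_non_zero_labels == expected_num_non_zero_labels).all() assertion in the test above relies on each mixed target having exactly 2 non-zero entries after a single CutMix or MixUp pass, and 3 after composing both. The snippet below is a minimal, self-contained sketch in plain PyTorch, not the transforms API from this patch, of why those counts hold: blending one-hot targets keeps every row summing to 1 and bounds the number of non-zero labels. The fixed lam weights and the roll-based pairing are illustrative simplifications; the real transforms sample the weight from a Beta(alpha, alpha) distribution, so a sample can occasionally stay unmixed (weight of 0 or 1), which is what the removed comment cautioned about.

# Minimal sketch (plain PyTorch, not the transforms from this patch): why mixed
# one-hot targets keep summing to 1 and have a bounded number of non-zero labels.
import torch

batch_size, num_categories = 32, 100

# One distinct class per sample, as in DummyDataset above (idx is used as the label).
labels = torch.arange(batch_size)
targets = torch.nn.functional.one_hot(labels, num_classes=num_categories).float()

# MixUp-style blend with a fixed weight; the real transform samples it from
# Beta(alpha, alpha) and pairs each sample with another one from the same batch.
lam = 0.3
mixed = lam * targets + (1 - lam) * targets.roll(1, dims=0)
torch.testing.assert_close(mixed.sum(dim=-1), torch.ones(batch_size))
assert ((mixed != 0).sum(dim=-1) == 2).all()  # exactly 2 non-zero labels per sample

# Blending a second time (as when CutMix and MixUp are composed) can touch one more label.
lam2 = 0.6
twice_mixed = lam2 * mixed + (1 - lam2) * mixed.roll(1, dims=0)
torch.testing.assert_close(twice_mixed.sum(dim=-1), torch.ones(batch_size))
assert ((twice_mixed != 0).sum(dim=-1) == 3).all()  # 3 non-zero labels once both are applied

With non-degenerate weights the counts come out exact here because every label in the batch is distinct; with sampled weights the number of non-zero entries can only be smaller, never larger, than the expected_num_non_zero_labels used by the test.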