Skip to content

Commit

Permalink
Got rid of FakeData, doesn't work tho
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Jul 25, 2023
1 parent 26f55de commit e91e879
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 35 deletions.
55 changes: 25 additions & 30 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from torch.testing import assert_close
from torch.utils.data import DataLoader, default_collate
from torchvision import datapoints
from torchvision.datasets import FakeData

from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.functional import pil_modes_mapping
Expand Down Expand Up @@ -1640,36 +1639,44 @@ def test_transform_unknown_fill_error(self):


class TestCutMixMixUp:
class DummyDataset:
def __init__(self, size, one_hot, num_categories):
self.one_hot = one_hot
self.size = size
self.num_categories = num_categories
assert size < num_categories

def __getitem__(self, idx):
img = torch.rand(3, 12, 12)
label = idx
if self.one_hot:
label = torch.nn.functional.one_hot(torch.tensor(label), num_classes=self.num_categories)
return img, label

def __len__(self):
return self.size

@pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup, "CutMixMixUp", "MixUpCutMix"])
@pytest.mark.parametrize("one_hot", [True, False])
def test_supported_input_structure(self, T, one_hot):

num_categories = 10
batch_size = 32
H, W = 12, 12

preproc = transforms.Compose([transforms.PILToTensor(), transforms.ToDtype(torch.float32)])
if one_hot:

class ToOneHot(torch.nn.Module):
def forward(self, inpt):
img, label = inpt
return img, torch.nn.functional.one_hot(label, num_classes=num_categories)
num_categories = 100

preproc = transforms.Compose([preproc, ToOneHot()])
dataset = self.DummyDataset(size=batch_size, one_hot=one_hot, num_categories=num_categories)

dataset = FakeData(size=batch_size, image_size=(3, H, W), num_classes=num_categories, transforms=preproc)
if isinstance(T, str):
expected_num_non_zero_labels = 3 # see common_checks
cutmix = transforms.Cutmix(alpha=0.5, num_categories=num_categories)
mixup = transforms.Mixup(alpha=0.5, num_categories=num_categories)
if T == "CutMixMixUp":
cutmix_mixup = transforms.Compose([cutmix, mixup])
else:
cutmix_mixup = transforms.Compose([mixup, cutmix])
# When both CutMix and MixUp
expected_num_non_zero_labels = 3
else:
cutmix_mixup = T(alpha=0.5, num_categories=num_categories)
expected_num_non_zero_labels = 2 # see common_checks
expected_num_non_zero_labels = 2

dl = DataLoader(dataset, batch_size=batch_size)

Expand All @@ -1679,23 +1686,11 @@ def forward(self, inpt):
assert target.shape == (batch_size, num_categories) if one_hot else (batch_size,)

def check_output(img, target):
assert img.shape == (batch_size, 3, H, W)
assert img.shape == (batch_size, 3, 12, 12)
assert target.shape == (batch_size, num_categories)
torch.testing.assert_close(target.sum(axis=-1), torch.ones(batch_size))
# Below we check the number of non-zero values in the target tensor.
# When just CutMix() (or just MixUp()) is called, we should expect 2
# non-zero label values per sample. Although, it may happen that
# only 1 non-zero value is present, basically if the transform had
# no effect. Here we make sure that:
# - there is at least one sample with 2 non-zero values
# - there is no sample with more than 2 non-zero values
# When CutMix() and MixUp() are called in sequence together, we
# should expect 3 instead of 2. That's the
# expected_num_non_zero_labels threshold.
num_non_zero_values = (target != 0).sum(axis=-1)
assert (num_non_zero_values == expected_num_non_zero_labels).any()
assert (num_non_zero_values <= expected_num_non_zero_labels).all()
assert (num_non_zero_values > 0).all() # Note: we already know that from target.sum(axis=-1) check above
num_non_zero_labels = (target != 0).sum(axis=-1)
assert (num_non_zero_labels == expected_num_non_zero_labels).all()

# After Dataloader, as unpacked input
img, target = next(iter(dl))
Expand Down
6 changes: 1 addition & 5 deletions torchvision/datasets/fakedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,10 @@ def __init__(
image_size: Tuple[int, int, int] = (3, 224, 224),
num_classes: int = 10,
transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
random_offset: int = 0,
) -> None:
super().__init__(None, transform=transform, transforms=transforms, target_transform=target_transform) # type: ignore[arg-type]
super().__init__(None, transform=transform, target_transform=target_transform) # type: ignore[arg-type]
self.size = size
self.num_classes = num_classes
self.image_size = image_size
Expand Down Expand Up @@ -61,9 +60,6 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
if self.transforms is not None:
img, target = self.transforms(img, target)
return img, target # We don't want to call item() on arbitrarily transformed targets

return img, target.item()

Expand Down

0 comments on commit e91e879

Please sign in to comment.