Skip to content

Commit

Permalink
Use bigger images
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Jul 25, 2023
1 parent e91e879 commit 45bf28c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -1647,8 +1647,8 @@ def __init__(self, size, one_hot, num_categories):
assert size < num_categories

def __getitem__(self, idx):
img = torch.rand(3, 12, 12)
label = idx
img = torch.rand(3, 100, 100)
label = idx # This ensures all labels in a batch are unique and makes testing easier
if self.one_hot:
label = torch.nn.functional.one_hot(torch.tensor(label), num_classes=self.num_categories)
return img, label
Expand All @@ -1672,7 +1672,6 @@ def test_supported_input_structure(self, T, one_hot):
cutmix_mixup = transforms.Compose([cutmix, mixup])
else:
cutmix_mixup = transforms.Compose([mixup, cutmix])
# When both CutMix and MixUp
expected_num_non_zero_labels = 3
else:
cutmix_mixup = T(alpha=0.5, num_categories=num_categories)
Expand All @@ -1682,11 +1681,12 @@ def test_supported_input_structure(self, T, one_hot):

# Input sanity checks
img, target = next(iter(dl))
input_img_size = img.shape[-3:]
assert isinstance(img, torch.Tensor) and isinstance(target, torch.Tensor)
assert target.shape == (batch_size, num_categories) if one_hot else (batch_size,)

def check_output(img, target):
assert img.shape == (batch_size, 3, 12, 12)
assert img.shape == (batch_size, *input_img_size)
assert target.shape == (batch_size, num_categories)
torch.testing.assert_close(target.sum(axis=-1), torch.ones(batch_size))
num_non_zero_labels = (target != 0).sum(axis=-1)
Expand Down
4 changes: 2 additions & 2 deletions torchvision/transforms/v2/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

H, W = query_spatial_size(flat_inputs)

r_x = torch.randint(W, ())
r_y = torch.randint(H, ())
r_x = torch.randint(W, size=(1,))
r_y = torch.randint(H, size=(1,))

r = 0.5 * math.sqrt(1.0 - lam)
r_w_half = int(r * W)
Expand Down

0 comments on commit 45bf28c

Please sign in to comment.