Hopefully fix cuda test
NicolasHug committed Jul 27, 2023
1 parent acc7a98 commit 5e02675
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions test/test_transforms_v2_refactored.py
@@ -16,6 +16,7 @@
     assert_no_warnings,
     cache,
     cpu_and_cuda,
+    freeze_rng_state,
     ignore_jit_no_profile_information_warning,
     make_bounding_box,
     make_detection_mask,
@@ -61,8 +62,10 @@ def _check_kernel_cuda_vs_cpu(kernel, input, *args, rtol, atol, **kwargs):
     input_cuda = input.as_subclass(torch.Tensor)
     input_cpu = input_cuda.to("cpu")
 
-    actual = kernel(input_cuda, *args, **kwargs)
-    expected = kernel(input_cpu, *args, **kwargs)
+    with freeze_rng_state():
+        actual = kernel(input_cuda, *args, **kwargs)
+    with freeze_rng_state():
+        expected = kernel(input_cpu, *args, **kwargs)
 
     assert_close(actual, expected, check_device=False, rtol=rtol, atol=atol)
 
@@ -1772,11 +1775,11 @@ def test_cpu_vs_gpu(self, T):
         batch_size = 3
         H, W = 12, 12
 
-        imgs = torch.rand(batch_size, 3, H, W).to("cuda")
-        labels = torch.randint(0, num_classes, (batch_size,)).to("cuda")
+        imgs = torch.rand(batch_size, 3, H, W)
+        labels = torch.randint(0, num_classes, (batch_size,))
         cutmix_mixup = T(alpha=0.5, num_classes=num_classes)
 
-        _check_kernel_cuda_vs_cpu(cutmix_mixup, input=(imgs, labels), rtol=None, atol=None)
+        _check_kernel_cuda_vs_cpu(cutmix_mixup, imgs, labels, rtol=None, atol=None)
 
     @pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
     def test_error(self, T):
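The fix wraps both kernel invocations in freeze_rng_state(), so a random transform such as CutMix or MixUp consumes identical random draws on the CUDA run and on the CPU run, making the two outputs directly comparable; the test now also creates its inputs on the CPU and lets the helper handle the device transfer. Below is a minimal, self-contained sketch of the same idea; frozen_rng_state and toy_mixup are illustrative stand-ins, not the actual common_utils.freeze_rng_state helper or the torchvision transforms.

# Sketch only: save the global RNG state on entry and restore it on exit,
# so two consecutive `with` blocks start from the same state and a random
# transform draws the same numbers in both runs.
import contextlib

import torch


@contextlib.contextmanager
def frozen_rng_state():
    cpu_state = torch.get_rng_state()
    cuda_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
    try:
        yield
    finally:
        torch.set_rng_state(cpu_state)
        if cuda_states is not None:
            torch.cuda.set_rng_state_all(cuda_states)


def toy_mixup(imgs):
    # Toy stand-in for a random transform: the mixing coefficient comes from
    # the global CPU generator, so it only matches across runs if the RNG
    # state is identical when the transform is called.
    lam = torch.rand(()).item()
    return lam * imgs + (1 - lam) * imgs.roll(shifts=1, dims=0)


if __name__ == "__main__" and torch.cuda.is_available():
    imgs = torch.rand(3, 3, 12, 12)  # created on CPU, as in the updated test
    with frozen_rng_state():
        actual = toy_mixup(imgs.to("cuda"))
    with frozen_rng_state():
        expected = toy_mixup(imgs)
    torch.testing.assert_close(actual, expected, check_device=False)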
