Skip to content

Commit

Permalink
Rename cpu_and_gpu to cpu_and_cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
qqaatw committed Jun 20, 2023
1 parent d609da4 commit 256bd56
Show file tree
Hide file tree
Showing 9 changed files with 147 additions and 140 deletions.
6 changes: 3 additions & 3 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,14 @@ def disable_console_output():
yield


def cpu_and_cuda():
    """Return device parametrizations: plain ``"cpu"`` plus a ``"cuda"``
    param carrying the ``needs_cuda`` mark (skipped on CUDA-less machines).
    """
    import pytest  # noqa

    # Wrap "cuda" in pytest.param so the needs_cuda mark travels with it.
    cuda_param = pytest.param("cuda", marks=pytest.mark.needs_cuda)
    return ("cpu", cuda_param)


def cpu_and_cuda_and_mps():
    """Return device parametrizations: everything from :func:`cpu_and_cuda`
    plus an ``"mps"`` param carrying the ``needs_mps`` mark.
    """
    mps_param = pytest.param("mps", marks=pytest.mark.needs_mps)
    return (*cpu_and_cuda(), mps_param)


def needs_cuda(test_func):
Expand Down
15 changes: 11 additions & 4 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,15 @@

torchvision.disable_beta_transforms_warning()

from common_utils import CUDA_NOT_AVAILABLE_MSG, MPS_NOT_AVAILABLE_MSG, IN_FBCODE, IN_OSS_CI, IN_RE_WORKER, OSS_CI_GPU_NO_CUDA_MSG, OSS_CI_GPU_NO_MPS_MSG
from common_utils import (
CUDA_NOT_AVAILABLE_MSG,
IN_FBCODE,
IN_OSS_CI,
IN_RE_WORKER,
MPS_NOT_AVAILABLE_MSG,
OSS_CI_GPU_NO_CUDA_MSG,
OSS_CI_GPU_NO_MPS_MSG,
)


def pytest_configure(config):
Expand All @@ -34,18 +42,17 @@ def pytest_collection_modifyitems(items):
# The needs_cuda mark will exist if the test was explicitly decorated with
# the @needs_cuda decorator. It will also exist if it was parametrized with a
# parameter that has the mark: for example if a test is parametrized with
# @pytest.mark.parametrize('device', cpu_and_gpu())
# @pytest.mark.parametrize('device', cpu_and_cuda())
# the "instances" of the tests where device == 'cuda' will have the 'needs_cuda' mark,
# and the ones with device == 'cpu' won't have the mark.
needs_cuda = item.get_closest_marker("needs_cuda") is not None
needs_mps = item.get_closest_marker("needs_mps") is not None


if needs_cuda and not torch.cuda.is_available():
# In general, we skip cuda tests on machines without a GPU
# There are special cases though, see below
item.add_marker(pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG))

if needs_mps and not torch.backends.mps.is_available():
item.add_marker(pytest.mark.skip(reason=MPS_NOT_AVAILABLE_MSG))

Expand Down
Loading

0 comments on commit 256bd56

Please sign in to comment.