Add MPS kernels for nms and roi ops (#7643)
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
3 people authored Aug 1, 2023
1 parent f524cd3 commit 16d62e3
Showing 15 changed files with 2,146 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .github/scripts/run-clang-format.py
@@ -48,7 +48,7 @@
DEVNULL = open(os.devnull, "wb")


DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu"
DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu,mm"


class ExitStatus:
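Adding mm to the allow-list brings the new Objective-C++ MPS kernels under clang-format's coverage. A simplified Python sketch of how such an extension allow-list filters candidate files (the real script walks directories recursively and handles many more options):

import os

allowed = set("c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu,mm".split(","))

def should_format(path):
    # Keep only files whose extension is on the allow-list.
    ext = os.path.splitext(path)[1].lstrip(".")
    return ext in allowed

should_format("torchvision/csrc/ops/mps/nms_kernel.mm")  # True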
9 changes: 9 additions & 0 deletions CMakeLists.txt
@@ -4,6 +4,7 @@ set(CMAKE_CXX_STANDARD 17)
file(STRINGS version.txt TORCHVISION_VERSION)

option(WITH_CUDA "Enable CUDA support" OFF)
option(WITH_MPS "Enable MPS support" OFF)
option(WITH_PNG "Enable features requiring LibPNG." ON)
option(WITH_JPEG "Enable features requiring LibJPEG." ON)
option(USE_PYTHON "Link to Python when building" OFF)
@@ -15,6 +16,11 @@ if(WITH_CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
endif()

if(WITH_MPS)
enable_language(OBJC OBJCXX)
add_definitions(-DWITH_MPS)
endif()

find_package(Torch REQUIRED)

if (WITH_PNG)
@@ -79,6 +85,9 @@ list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCP
if(WITH_CUDA)
list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast)
endif()
if(WITH_MPS)
list(APPEND ALLOW_LISTED ${TVCPP}/ops/mps)
endif()

FOREACH(DIR ${ALLOW_LISTED})
file(GLOB ALL_SOURCES ${ALL_SOURCES} ${DIR}/*.*)
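With -DWITH_MPS=ON, the build defines the WITH_MPS macro, enables the Objective-C and Objective-C++ languages, and allow-lists the sources under ops/mps. A quick smoke test of the resulting ops from Python, as a sketch that assumes torchvision was built on an Apple Silicon machine with MPS enabled:

import torch
from torchvision import ops

if torch.backends.mps.is_available():
    boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                          [1.0, 1.0, 11.0, 11.0]], device="mps")
    scores = torch.tensor([0.9, 0.8], device="mps")
    keep = ops.nms(boxes, scores, iou_threshold=0.5)  # dispatches to the new MPS kernel
    print(keep.cpu())  # tensor([0]): the overlapping, lower-scoring box is suppressed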
5 changes: 5 additions & 0 deletions setup.py
@@ -137,10 +137,13 @@ def get_extensions():
+ glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
+ glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
)
source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))

print("Compiling extensions with following flags:")
force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
print(f" FORCE_CUDA: {force_cuda}")
force_mps = os.getenv("FORCE_MPS", "0") == "1"
print(f" FORCE_MPS: {force_mps}")
debug_mode = os.getenv("DEBUG", "0") == "1"
print(f" DEBUG: {debug_mode}")
use_png = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
@@ -202,6 +205,8 @@ def get_extensions():
define_macros += [("WITH_HIP", None)]
nvcc_flags = []
extra_compile_args["nvcc"] = nvcc_flags
elif torch.backends.mps.is_available() or force_mps:
sources += source_mps

if sys.platform == "win32":
define_macros += [("torchvision_EXPORTS", None)]
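The .mm sources join the build only when MPS is available at build time or when FORCE_MPS=1 is set, mirroring the existing FORCE_CUDA escape hatch. A condensed sketch of that gating, using the names from the diff above:

import glob
import os

import torch

source_mps = glob.glob(os.path.join("torchvision", "csrc", "ops", "mps", "*.mm"))
force_mps = os.getenv("FORCE_MPS", "0") == "1"

sources = []  # the CPU sources are already collected at this point
if torch.backends.mps.is_available() or force_mps:
    sources += source_mps  # the Objective-C++ kernels join the extension build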
11 changes: 11 additions & 0 deletions test/common_utils.py
@@ -34,6 +34,7 @@
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
MPS_NOT_AVAILABLE_MSG = "MPS device not available"
OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."


@@ -130,12 +131,22 @@ def cpu_and_cuda():
return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))


def cpu_and_cuda_and_mps():
return cpu_and_cuda() + (pytest.param("mps", marks=pytest.mark.needs_mps),)


def needs_cuda(test_func):
import pytest # noqa

return pytest.mark.needs_cuda(test_func)


def needs_mps(test_func):
import pytest # noqa

return pytest.mark.needs_mps(test_func)


def _create_data(height=3, width=3, channels=3, device="cpu"):
# TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
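A sketch of how a test module would use the new helpers (hypothetical test names; the ops tests below follow the same pattern):

import pytest
import torch
from common_utils import cpu_and_cuda_and_mps, needs_mps

@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
def test_runs_everywhere(device):
    # Runs on CPU unconditionally; the cuda/mps instances carry the
    # needs_cuda/needs_mps marks and are skipped where unavailable.
    x = torch.ones(2, device=device)
    assert (x + x).sum().item() == 4.0

@needs_mps
def test_mps_only():
    # Skipped (or, in fbcode, not collected) without an MPS device.
    assert torch.backends.mps.is_available()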
17 changes: 16 additions & 1 deletion test/conftest.py
@@ -8,12 +8,20 @@

torchvision.disable_beta_transforms_warning()

from common_utils import CUDA_NOT_AVAILABLE_MSG, IN_FBCODE, IN_OSS_CI, IN_RE_WORKER, OSS_CI_GPU_NO_CUDA_MSG
from common_utils import (
CUDA_NOT_AVAILABLE_MSG,
IN_FBCODE,
IN_OSS_CI,
IN_RE_WORKER,
MPS_NOT_AVAILABLE_MSG,
OSS_CI_GPU_NO_CUDA_MSG,
)


def pytest_configure(config):
# register an additional marker (see pytest_collection_modifyitems)
config.addinivalue_line("markers", "needs_cuda: mark for tests that rely on a CUDA device")
config.addinivalue_line("markers", "needs_mps: mark for tests that rely on a MPS device")
config.addinivalue_line("markers", "dont_collect: mark for tests that should not be collected")


@@ -37,12 +45,16 @@ def pytest_collection_modifyitems(items):
# the "instances" of the tests where device == 'cuda' will have the 'needs_cuda' mark,
# and the ones with device == 'cpu' won't have the mark.
needs_cuda = item.get_closest_marker("needs_cuda") is not None
needs_mps = item.get_closest_marker("needs_mps") is not None

if needs_cuda and not torch.cuda.is_available():
# In general, we skip cuda tests on machines without a GPU
# There are special cases though, see below
item.add_marker(pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG))

if needs_mps and not torch.backends.mps.is_available():
item.add_marker(pytest.mark.skip(reason=MPS_NOT_AVAILABLE_MSG))

if IN_FBCODE:
# fbcode doesn't like skipping tests, so instead we just don't collect the test
# so that they don't even "exist", hence the continue statements.
@@ -54,6 +66,9 @@ def pytest_collection_modifyitems(items):
# TODO: something more robust would be to do that only in a sandcastle instance,
# so that we can still see the test being skipped when testing locally from a devvm
continue
if needs_mps and not torch.backends.mps.is_available():
# Same as above, but for MPS
continue
elif IN_OSS_CI:
# Here we're not in fbcode, so we can safely collect and skip tests.
if not needs_cuda and torch.cuda.is_available():
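Distilled to its core, the collection-time skip added here follows the standard pytest pattern; a standalone sketch, not the full conftest:

import pytest
import torch

def pytest_collection_modifyitems(items):
    for item in items:
        needs_mps = item.get_closest_marker("needs_mps") is not None
        if needs_mps and not torch.backends.mps.is_available():
            item.add_marker(pytest.mark.skip(reason="MPS device not available"))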
109 changes: 84 additions & 25 deletions test/test_ops.py
@@ -10,7 +10,7 @@
import torch
import torch.fx
import torch.nn.functional as F
from common_utils import assert_equal, cpu_and_cuda, needs_cuda
from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
from PIL import Image
from torch import nn, Tensor
from torch.autograd import gradcheck
@@ -96,12 +96,33 @@ def forward(self, imgs: Tensor, boxes: List[Tensor]) -> Tensor:

class RoIOpTester(ABC):
dtype = torch.float64
mps_dtype = torch.float32
mps_backward_atol = 2e-2

@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
@pytest.mark.parametrize("contiguous", (True, False))
def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, deterministic=False, **kwargs):
x_dtype = self.dtype if x_dtype is None else x_dtype
rois_dtype = self.dtype if rois_dtype is None else rois_dtype
@pytest.mark.parametrize(
"x_dtype",
(
torch.float16,
torch.float32,
torch.float64,
),
ids=str,
)
def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs):
if device == "mps" and x_dtype is torch.float64:
pytest.skip("MPS does not support float64")

rois_dtype = x_dtype if rois_dtype is None else rois_dtype

tol = 1e-5
if x_dtype is torch.half:
if device == "mps":
tol = 5e-3
else:
tol = 4e-3

pool_size = 5
# n_channels % (pool_size ** 2) == 0 required for PS operations.
n_channels = 2 * (pool_size**2)
@@ -120,10 +141,9 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, determ
# the following should be true whether we're running an autocast test or not.
assert y.dtype == x.dtype
gt_y = self.expected_fn(
x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs
x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=x_dtype, **kwargs
)

tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)

@pytest.mark.parametrize("device", cpu_and_cuda())
@@ -155,16 +175,19 @@ def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.floa
torch.testing.assert_close(output_gt, output_fx, rtol=tol, atol=tol)

@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
@pytest.mark.parametrize("contiguous", (True, False))
def test_backward(self, seed, device, contiguous, deterministic=False):
atol = self.mps_backward_atol if device == "mps" else 1e-05
dtype = self.mps_dtype if device == "mps" else self.dtype

torch.random.manual_seed(seed)
pool_size = 2
x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=self.dtype, device=device, requires_grad=True)
x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=dtype, device=device, requires_grad=True)
if not contiguous:
x = x.permute(0, 1, 3, 2)
rois = torch.tensor(
[[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=self.dtype, device=device # format is (xyxy)
[[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=dtype, device=device # format is (xyxy)
)

def func(z):
@@ -173,9 +196,25 @@ def func(z):
script_func = self.get_script_fn(rois, pool_size)

with DeterministicGuard(deterministic):
gradcheck(func, (x,))
gradcheck(func, (x,), atol=atol)

gradcheck(script_func, (x,), atol=atol)

gradcheck(script_func, (x,))
@needs_mps
def test_mps_error_inputs(self):
pool_size = 2
x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=torch.float16, device="mps", requires_grad=True)
rois = torch.tensor(
[[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=torch.float16, device="mps" # format is (xyxy)
)

def func(z):
return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)

with pytest.raises(
RuntimeError, match="MPS does not support (?:ps_)?roi_(?:align|pool)? backward with float16 inputs."
):
gradcheck(func, (x,))

@needs_cuda
@pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
@@ -271,6 +310,8 @@ def test_jit_boxes_list(self):


class TestPSRoIPool(RoIOpTester):
mps_backward_atol = 5e-2

def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
return ops.PSRoIPool((pool_h, pool_w), 1)(x, rois)

@@ -352,6 +393,8 @@ def bilinear_interpolate(data, y, x, snap_border=False):


class TestRoIAlign(RoIOpTester):
mps_backward_atol = 6e-2

def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs):
return ops.RoIAlign(
(pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
@@ -418,10 +461,11 @@ def test_boxes_shape(self):
self._helper_boxes_shape(ops.roi_align)

@pytest.mark.parametrize("aligned", (True, False))
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
@pytest.mark.parametrize("x_dtype", (torch.float16, torch.float32, torch.float64), ids=str)
@pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.parametrize("deterministic", (True, False))
def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, rois_dtype=None):
def test_forward(self, device, contiguous, deterministic, aligned, x_dtype, rois_dtype=None):
if deterministic and device == "cpu":
pytest.skip("cpu is always deterministic, don't retest")
super().test_forward(
@@ -450,7 +494,7 @@ def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
)

@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
@pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.parametrize("deterministic", (True, False))
def test_backward(self, seed, device, contiguous, deterministic):
@@ -537,6 +581,8 @@ def test_jit_boxes_list(self):


class TestPSRoIAlign(RoIOpTester):
mps_backward_atol = 5e-2

def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois)

@@ -705,40 +751,53 @@ def test_qnms(self, iou, scale, zero_point):

torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou))

@needs_cuda
@pytest.mark.parametrize(
"device",
(
pytest.param("cuda", marks=pytest.mark.needs_cuda),
pytest.param("mps", marks=pytest.mark.needs_mps),
),
)
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
def test_nms_cuda(self, iou, dtype=torch.float64):
def test_nms_gpu(self, iou, device, dtype=torch.float64):
dtype = torch.float32 if device == "mps" else dtype
tol = 1e-3 if dtype is torch.half else 1e-5
err_msg = "NMS incompatible between CPU and CUDA for IoU={}"

boxes, scores = self._create_tensors_with_iou(1000, iou)
r_cpu = ops.nms(boxes, scores, iou)
r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)
r_gpu = ops.nms(boxes.to(device), scores.to(device), iou)

is_eq = torch.allclose(r_cpu, r_cuda.cpu())
is_eq = torch.allclose(r_cpu, r_gpu.cpu())
if not is_eq:
# if the indices are not the same, ensure that it's because the scores
# are duplicate
is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
is_eq = torch.allclose(scores[r_cpu], scores[r_gpu.cpu()], rtol=tol, atol=tol)
assert is_eq, err_msg.format(iou)

@needs_cuda
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
@pytest.mark.parametrize("dtype", (torch.float, torch.half))
def test_autocast(self, iou, dtype):
with torch.cuda.amp.autocast():
self.test_nms_cuda(iou=iou, dtype=dtype)
self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")

@needs_cuda
def test_nms_cuda_float16(self):
@pytest.mark.parametrize(
"device",
(
pytest.param("cuda", marks=pytest.mark.needs_cuda),
pytest.param("mps", marks=pytest.mark.needs_mps),
),
)
def test_nms_float16(self, device):
boxes = torch.tensor(
[
[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019],
]
).cuda()
scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda()
).to(device)
scores = torch.tensor([0.6370, 0.7569, 0.3966]).to(device)

iou_thres = 0.2
keep32 = ops.nms(boxes, scores, iou_thres)
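The renamed test_nms_gpu compares each device's result against the CPU reference, falling back to a score comparison when tied scores make the returned indices legitimately differ. A standalone sketch of that consistency check, assuming an MPS machine:

import torch
from torchvision import ops

boxes = torch.rand(1000, 4) * 100
boxes[:, 2:] += boxes[:, :2]  # make (x2, y2) >= (x1, y1) so every box is valid
scores = torch.rand(1000)

keep_cpu = ops.nms(boxes, scores, iou_threshold=0.5)
keep_mps = ops.nms(boxes.to("mps"), scores.to("mps"), iou_threshold=0.5)
# The kept indices should agree unless two boxes have exactly tied scores.
assert torch.equal(keep_cpu, keep_mps.cpu())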
4 changes: 2 additions & 2 deletions torchvision/csrc/ops/cpu/nms_kernel.cpp
@@ -11,8 +11,8 @@ at::Tensor nms_kernel_impl(
const at::Tensor& dets,
const at::Tensor& scores,
double iou_threshold) {
TORCH_CHECK(!dets.is_cuda(), "dets must be a CPU tensor");
TORCH_CHECK(!scores.is_cuda(), "scores must be a CPU tensor");
TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor");
TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor");
TORCH_CHECK(
dets.scalar_type() == scores.scalar_type(),
"dets should have the same type as scores");
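The switch from !is_cuda() to is_cpu() matters now that a third device is in play: an MPS tensor is not a CUDA tensor, so the old guard would have let it fall through to the CPU kernel. Seen from Python (assuming an MPS machine):

import torch

t = torch.ones(1, device="mps")
t.is_cuda  # False -- the old `!dets.is_cuda()` check would have passed
t.is_cpu   # False -- the new `dets.is_cpu()` check correctly rejects it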
6 changes: 6 additions & 0 deletions torchvision/csrc/ops/mps/mps_helpers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
constexpr int threadsPerBlock = 512;

template <typename T>
constexpr inline T ceil_div(T n, T m) {
return (n + m - 1) / m;
}
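ceil_div is the usual grid-sizing helper: it rounds the division up so the last, partially filled block of 512 threads still gets launched. The same arithmetic in Python, for intuition:

def ceil_div(n, m):
    # Equivalent to math.ceil(n / m) for positive integers.
    return (n + m - 1) // m

threads_per_block = 512
ceil_div(1000, threads_per_block)  # 2 blocks cover 1000 elements
ceil_div(1024, threads_per_block)  # exactly 2 blocks, no remainder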