Add MPS kernels #7643

Status: Merged (32 commits), Aug 1, 2023.
The diff below shows changes from 25 of the 32 commits.

Commits
da9c2de  Draft (qqaatw, May 30, 2023)
7f0d4ce  NMS f32 (qqaatw, Jun 9, 2023)
c7c43dc  roi_align fw (qqaatw, Jun 10, 2023)
ccde29c  roi_align bw (failed) (qqaatw, Jun 13, 2023)
c930e54  roi_pool fw (qqaatw, Jun 13, 2023)
3305cc1  roi_pool bw (failed prec) (qqaatw, Jun 13, 2023)
0f8d2c3  ps_roi_align fw (qqaatw, Jun 13, 2023)
e157c7c  ps_roi_align bw (failed prec) (qqaatw, Jun 13, 2023)
160d5b5  Several improvements (qqaatw, Jun 14, 2023)
40ea525  ps_roi_pool fw (qqaatw, Jun 17, 2023)
2c20036  ps_roi_pool bw (qqaatw, Jun 17, 2023)
a427c2a  Rename kernels header (qqaatw, Jun 17, 2023)
8036dc2  Add atol to RoI backward tests (qqaatw, Jun 17, 2023)
0ae9124  mps kernels formatting (qqaatw, Jun 17, 2023)
1d21cfc  binaryPSO -> visionPSO (qqaatw, Jun 18, 2023)
3018b25  Formatting (qqaatw, Jun 20, 2023)
d609da4  Testing (qqaatw, Jun 20, 2023)
256bd56  Rename cpu_and_gpu to cpu_and_cuda (qqaatw, Jun 20, 2023)
990685f  formatting (qqaatw, Jun 20, 2023)
5dce2d7  mps kernel dtype consistency (qqaatw, Jun 20, 2023)
40ebde5  Kernel improvements (qqaatw, Jun 20, 2023)
8e4d868  Merge branch 'main' of github.com:pytorch/vision into add_mps_kernels (NicolasHug, Jun 21, 2023)
24109d4  Merge branch 'main' into add_mps_kernels (NicolasHug, Jul 3, 2023)
efbb52e  Apply suggestions from code review (qqaatw, Jul 4, 2023)
b36cafa  Test more dtypes for roi forward functions and assert half inputs in … (qqaatw, Jul 4, 2023)
fad54f6  Add mps error inputs check (qqaatw, Jul 7, 2023)
66a00fc  Clean up headers (qqaatw, Jul 10, 2023)
70f3906  Fix dtype parameters (qqaatw, Jul 17, 2023)
b1cf619  parameterize nms gpu test (qqaatw, Jul 17, 2023)
3f82ee4  Merge branch 'main' into add_mps_kernels (NicolasHug, Aug 1, 2023)
108bc15  Allow to skip MPS tests internally on non-MPS machines (NicolasHug, Aug 1, 2023)
c825c53  Merge branch 'main' into add_mps_kernels (NicolasHug, Aug 1, 2023)
2 changes: 1 addition & 1 deletion .github/scripts/run-clang-format.py
@@ -48,7 +48,7 @@
DEVNULL = open(os.devnull, "wb")


DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu"
DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu,mm"


class ExitStatus:
9 changes: 9 additions & 0 deletions CMakeLists.txt
@@ -4,6 +4,7 @@ set(CMAKE_CXX_STANDARD 17)
file(STRINGS version.txt TORCHVISION_VERSION)

option(WITH_CUDA "Enable CUDA support" OFF)
option(WITH_MPS "Enable MPS support" OFF)
option(WITH_PNG "Enable features requiring LibPNG." ON)
option(WITH_JPEG "Enable features requiring LibJPEG." ON)
option(USE_PYTHON "Link to Python when building" OFF)
@@ -15,6 +16,11 @@ if(WITH_CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
endif()

if(WITH_MPS)
enable_language(OBJC OBJCXX)
add_definitions(-DWITH_MPS)
endif()

find_package(Torch REQUIRED)

if (WITH_PNG)
@@ -79,6 +85,9 @@ list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCP
if(WITH_CUDA)
list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast)
endif()
if(WITH_MPS)
list(APPEND ALLOW_LISTED ${TVCPP}/ops/mps)
endif()

FOREACH(DIR ${ALLOW_LISTED})
file(GLOB ALL_SOURCES ${ALL_SOURCES} ${DIR}/*.*)
5 changes: 5 additions & 0 deletions setup.py
@@ -137,10 +137,13 @@ def get_extensions():
+ glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
+ glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
)
source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))

print("Compiling extensions with following flags:")
force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
print(f" FORCE_CUDA: {force_cuda}")
force_mps = os.getenv("FORCE_MPS", "0") == "1"
print(f" FORCE_MPS: {force_mps}")
debug_mode = os.getenv("DEBUG", "0") == "1"
print(f" DEBUG: {debug_mode}")
use_png = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
@@ -202,6 +205,8 @@ def get_extensions():
define_macros += [("WITH_HIP", None)]
nvcc_flags = []
extra_compile_args["nvcc"] = nvcc_flags
elif torch.backends.mps.is_available() or force_mps:
sources += source_mps

if sys.platform == "win32":
define_macros += [("torchvision_EXPORTS", None)]
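For context, with this change the .mm sources under ops/mps are compiled whenever torch.backends.mps.is_available() returns True at build time, or when FORCE_MPS=1 is exported. A quick local check of which path a build would take (a sketch, not part of the diff):

import os
import torch

# Mirrors the conditions used in get_extensions() above.
force_mps = os.getenv("FORCE_MPS", "0") == "1"
print("torch.backends.mps.is_available():", torch.backends.mps.is_available())
print("FORCE_MPS:", force_mps)
print("MPS kernels will be built:", torch.backends.mps.is_available() or force_mps)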
11 changes: 11 additions & 0 deletions test/common_utils.py
@@ -34,6 +34,7 @@
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
MPS_NOT_AVAILABLE_MSG = "MPS device not available"
OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."


@@ -130,12 +131,22 @@ def cpu_and_cuda():
return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))


def cpu_and_cuda_and_mps():
return cpu_and_cuda() + (pytest.param("mps", marks=pytest.mark.needs_mps),)


def needs_cuda(test_func):
import pytest # noqa

return pytest.mark.needs_cuda(test_func)


def needs_mps(test_func):
import pytest # noqa

return pytest.mark.needs_mps(test_func)


def _create_data(height=3, width=3, channels=3, device="cpu"):
# TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
14 changes: 13 additions & 1 deletion test/conftest.py
@@ -8,12 +8,20 @@

torchvision.disable_beta_transforms_warning()

from common_utils import CUDA_NOT_AVAILABLE_MSG, IN_FBCODE, IN_OSS_CI, IN_RE_WORKER, OSS_CI_GPU_NO_CUDA_MSG
from common_utils import (
CUDA_NOT_AVAILABLE_MSG,
IN_FBCODE,
IN_OSS_CI,
IN_RE_WORKER,
MPS_NOT_AVAILABLE_MSG,
OSS_CI_GPU_NO_CUDA_MSG,
)


def pytest_configure(config):
# register an additional marker (see pytest_collection_modifyitems)
config.addinivalue_line("markers", "needs_cuda: mark for tests that rely on a CUDA device")
config.addinivalue_line("markers", "needs_mps: mark for tests that rely on a MPS device")
config.addinivalue_line("markers", "dont_collect: mark for tests that should not be collected")


@@ -37,12 +45,16 @@ def pytest_collection_modifyitems(items):
# the "instances" of the tests where device == 'cuda' will have the 'needs_cuda' mark,
# and the ones with device == 'cpu' won't have the mark.
needs_cuda = item.get_closest_marker("needs_cuda") is not None
needs_mps = item.get_closest_marker("needs_mps") is not None

if needs_cuda and not torch.cuda.is_available():
# In general, we skip cuda tests on machines without a GPU
# There are special cases though, see below
item.add_marker(pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG))

if needs_mps and not torch.backends.mps.is_available():
item.add_marker(pytest.mark.skip(reason=MPS_NOT_AVAILABLE_MSG))

if IN_FBCODE:
# fbcode doesn't like skipping tests, so instead we just don't collect the test
# so that they don't even "exist", hence the continue statements.
97 changes: 74 additions & 23 deletions test/test_ops.py
@@ -10,7 +10,7 @@
import torch
import torch.fx
import torch.nn.functional as F
from common_utils import assert_equal, cpu_and_cuda, needs_cuda
from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
from PIL import Image
from torch import nn, Tensor
from torch.autograd import gradcheck
@@ -96,21 +96,40 @@ def forward(self, imgs: Tensor, boxes: List[Tensor]) -> Tensor:

class RoIOpTester(ABC):
dtype = torch.float64
mps_dtype = torch.float32
mps_backward_atol = 2e-2

@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
@pytest.mark.parametrize("contiguous", (True, False))
def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, deterministic=False, **kwargs):
x_dtype = self.dtype if x_dtype is None else x_dtype
rois_dtype = self.dtype if rois_dtype is None else rois_dtype
@pytest.mark.parametrize(
"dtype",
(
torch.float16,
torch.float32,
torch.float64,
),
ids=str,
)
def test_forward(self, device, contiguous, dtype, deterministic=False, **kwargs):
if device == "mps" and dtype is torch.float64:
pytest.skip("MPS does not support float64")

tol = 1e-5
if dtype is torch.half:
if device == "mps":
tol = 5e-3
else:
tol = 4e-3

pool_size = 5
# n_channels % (pool_size ** 2) == 0 required for PS operations.
n_channels = 2 * (pool_size**2)
x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device)
x = torch.rand(2, n_channels, 10, 10, dtype=dtype, device=device)
if not contiguous:
x = x.permute(0, 1, 3, 2)
rois = torch.tensor(
[[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], # format is (xyxy)
dtype=rois_dtype,
dtype=dtype,
device=device,
)

@@ -120,10 +139,9 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, determ
# the following should be true whether we're running an autocast test or not.
assert y.dtype == x.dtype
gt_y = self.expected_fn(
x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs
x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=dtype, **kwargs
)

tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)

@pytest.mark.parametrize("device", cpu_and_cuda())
@@ -155,16 +173,19 @@ def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.floa
torch.testing.assert_close(output_gt, output_fx, rtol=tol, atol=tol)

@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
@pytest.mark.parametrize("contiguous", (True, False))
def test_backward(self, seed, device, contiguous, deterministic=False):
atol = self.mps_backward_atol if device == "mps" else 1e-05
dtype = self.mps_dtype if device == "mps" else self.dtype

torch.random.manual_seed(seed)
pool_size = 2
x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=self.dtype, device=device, requires_grad=True)
x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=dtype, device=device, requires_grad=True)
if not contiguous:
x = x.permute(0, 1, 3, 2)
rois = torch.tensor(
[[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=self.dtype, device=device # format is (xyxy)
[[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=dtype, device=device # format is (xyxy)
)

def func(z):
@@ -173,9 +194,9 @@ def func(z):
script_func = self.get_script_fn(rois, pool_size)

with DeterministicGuard(deterministic):
gradcheck(func, (x,))
gradcheck(func, (x,), atol=atol)

gradcheck(script_func, (x,))
gradcheck(script_func, (x,), atol=atol)

@needs_cuda
@pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
@@ -271,6 +292,8 @@ def test_jit_boxes_list(self):


class TestPSRoIPool(RoIOpTester):
mps_backward_atol = 5e-2
Review comment from NicolasHug (Member):

@albanD, any thoughts regarding this atol value for gradcheck()?

For reference, we typically use 1e-5 for CPU/CUDA, although we seem to be testing on float64 while the MPS tests are currently running on float32.

Reply:

The gradcheck is a bit tricky here, as we usually only run it in fp64 precision to get accurate results.
Unfortunately, MPS doesn't support fp64, so we can only resort to comparing with CPU results or increasing the tolerance significantly.
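To make the "comparing with CPU results" option concrete, here is a minimal sketch (not part of this PR; the op choice, shapes, and tolerances are illustrative assumptions):

import torch
from torchvision.ops import roi_pool

# Identical float32 inputs on CPU and MPS (MPS has no float64).
x_cpu = torch.rand(1, 8, 5, 5, requires_grad=True)
rois_cpu = torch.tensor([[0.0, 0.0, 0.0, 4.0, 4.0]])  # (batch_idx, x1, y1, x2, y2)
x_mps = x_cpu.detach().to("mps").requires_grad_()
rois_mps = rois_cpu.to("mps")

# Run the same op and backward pass on both devices.
roi_pool(x_cpu, rois_cpu, output_size=2).sum().backward()
roi_pool(x_mps, rois_mps, output_size=2).sum().backward()

# Compare MPS gradients against the CPU reference instead of float64 gradcheck.
torch.testing.assert_close(x_mps.grad.cpu(), x_cpu.grad, rtol=1e-4, atol=1e-4)

The PR instead keeps float32 gradcheck on MPS and raises the tolerance via a per-op mps_backward_atol, as the diff below shows.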


def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
return ops.PSRoIPool((pool_h, pool_w), 1)(x, rois)

@@ -352,6 +375,8 @@ def bilinear_interpolate(data, y, x, snap_border=False):


class TestRoIAlign(RoIOpTester):
mps_backward_atol = 6e-2

def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs):
return ops.RoIAlign(
(pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
@@ -418,18 +443,18 @@ def test_boxes_shape(self):
self._helper_boxes_shape(ops.roi_align)

@pytest.mark.parametrize("aligned", (True, False))
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
@pytest.mark.parametrize("dtype", (torch.float16, torch.float32, torch.float64), ids=str)
@pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.parametrize("deterministic", (True, False))
def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, rois_dtype=None):
def test_forward(self, device, contiguous, deterministic, aligned, dtype):
if deterministic and device == "cpu":
pytest.skip("cpu is always deterministic, don't retest")
super().test_forward(
device=device,
contiguous=contiguous,
deterministic=deterministic,
x_dtype=x_dtype,
rois_dtype=rois_dtype,
dtype=dtype,
aligned=aligned,
)

@@ -450,7 +475,7 @@ def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
)

@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
@pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.parametrize("deterministic", (True, False))
def test_backward(self, seed, device, contiguous, deterministic):
@@ -537,6 +562,8 @@ def test_jit_boxes_list(self):


class TestPSRoIAlign(RoIOpTester):
mps_backward_atol = 5e-2

def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois)

@@ -722,23 +749,47 @@ def test_nms_cuda(self, iou, dtype=torch.float64):
is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
assert is_eq, err_msg.format(iou)

@needs_mps
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
def test_nms_mps(self, iou, dtype=torch.float32):
tol = 1e-3 if dtype is torch.half else 1e-5
err_msg = "NMS incompatible between CPU and MPS for IoU={}"

boxes, scores = self._create_tensors_with_iou(1000, iou)
r_cpu = ops.nms(boxes, scores, iou)
r_mps = ops.nms(boxes.to("mps"), scores.to("mps"), iou)

print(r_cpu.size(), r_mps.size())
is_eq = torch.allclose(r_cpu, r_mps.cpu())
if not is_eq:
# if the indices are not the same, ensure that it's because the scores
# are duplicate
is_eq = torch.allclose(scores[r_cpu], scores[r_mps.cpu()], rtol=tol, atol=tol)
assert is_eq, err_msg.format(iou)

@needs_cuda
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
@pytest.mark.parametrize("dtype", (torch.float, torch.half))
def test_autocast(self, iou, dtype):
with torch.cuda.amp.autocast():
self.test_nms_cuda(iou=iou, dtype=dtype)

@needs_cuda
def test_nms_cuda_float16(self):
@pytest.mark.parametrize(
"device",
(
pytest.param("cuda", marks=pytest.mark.needs_cuda),
pytest.param("mps", marks=pytest.mark.needs_mps),
),
)
def test_nms_float16(self, device):
boxes = torch.tensor(
[
[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019],
]
).cuda()
scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda()
).to(device)
scores = torch.tensor([0.6370, 0.7569, 0.3966]).to(device)

iou_thres = 0.2
keep32 = ops.nms(boxes, scores, iou_thres)
4 changes: 2 additions & 2 deletions torchvision/csrc/ops/cpu/nms_kernel.cpp
@@ -11,8 +11,8 @@ at::Tensor nms_kernel_impl(
const at::Tensor& dets,
const at::Tensor& scores,
double iou_threshold) {
TORCH_CHECK(!dets.is_cuda(), "dets must be a CPU tensor");
TORCH_CHECK(!scores.is_cuda(), "scores must be a CPU tensor");
TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor");
TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor");
TORCH_CHECK(
dets.scalar_type() == scores.scalar_type(),
"dets should have the same type as scores");
6 changes: 6 additions & 0 deletions torchvision/csrc/ops/mps/mps_helpers.h
@@ -0,0 +1,6 @@
constexpr int threadsPerBlock = 512;

template <typename T>
constexpr inline T ceil_div(T n, T m) {
return (n + m - 1) / m;
}
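For intuition (not part of the PR): ceil_div is rounded-up integer division, presumably used to derive the number of threadgroups from an element count and threadsPerBlock. The same arithmetic in Python:

def ceil_div(n, m):
    # Rounded-up integer division, mirroring the C++ helper above.
    return (n + m - 1) // m

threads_per_block = 512
# e.g. 1000 elements with 512 threads per threadgroup -> 2 threadgroups
assert ceil_div(1000, threads_per_block) == 2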