Add MPS kernels #7643

Merged: 32 commits merged into pytorch:main on Aug 1, 2023

Conversation

@qqaatw (Contributor) commented May 30, 2023

Summary:

  1. Prerequisite PR in the PyTorch repository: [MPS] Prerequisite for MPS C++ extension pytorch#102483.
  2. This PR adds nms, roi_align, roi_pool, ps_roi_align, ps_roi_pool and their corresponding backward kernels where applicable. Most implementations are inspired by the CUDA implementations. (A brief usage sketch follows this list.)
  3. All the kernel code is placed in mps_kernels.h to make it easy to share helper functions and macros and to cache PSOs (pipeline state objects).
  4. Atomic operations are used in the backward kernels of the RoI functions. Because atomic_float is only supported in Metal 3 (macOS Ventura; MSL spec, section 2.6) and later, we implement a custom atomic addition function for systems with Metal 2.x.
  5. Apple GPUs natively support 64-bit signed and unsigned integer types, so we unify the integer types in the kernels to 64 bits. This might have performance implications when running the kernels on AMD or Intel GPUs (relevant discussion).
  6. MPS does not support float64, so the absolute tolerances of gradcheck in the RoI backward tests are adjusted accordingly.
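
A minimal usage sketch of the new ops on the MPS backend (illustrative only, not part of the PR; it assumes a torchvision build that includes these kernels and an MPS-enabled PyTorch, and the shapes and values are made up):

```python
import torch
from torchvision.ops import nms, roi_align

device = torch.device("mps")  # requires an MPS-enabled PyTorch build

# Hypothetical inputs: 8 candidate boxes in (x1, y1, x2, y2) format with scores.
boxes = torch.rand(8, 4, device=device).sort(dim=1).values * 100
scores = torch.rand(8, device=device)
keep = nms(boxes, scores, iou_threshold=0.5)  # dispatches to the MPS nms kernel

# RoI align on one feature map; rois are (batch_idx, x1, y1, x2, y2).
feats = torch.rand(1, 16, 32, 32, device=device, requires_grad=True)
rois = torch.tensor([[0.0, 2.0, 2.0, 20.0, 20.0]], device=device)
pooled = roi_align(feats, rois, output_size=(7, 7), spatial_scale=1.0,
                   sampling_ratio=2, aligned=True)
pooled.sum().backward()  # exercises the roi_align backward kernel (atomic adds)
```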

cc @NicolasHug @pmeier @albanD @kulinseth

@qqaatw qqaatw closed this May 30, 2023
@qqaatw qqaatw reopened this May 30, 2023
@qqaatw qqaatw changed the title from Add MPS kernels to [Draft] Add MPS kernels May 30, 2023
@qqaatw qqaatw marked this pull request as draft May 30, 2023 16:33
@qqaatw qqaatw force-pushed the add_mps_kernels branch 2 times, most recently from dd5f42a to 6f32285 June 13, 2023 14:42
@qqaatw qqaatw mentioned this pull request Jun 19, 2023
@NicolasHug (Member) commented:
Hi @qqaatw , I saw this PR isn't in draft state anymore. Is this ready for review?

@qqaatw (Contributor, PR author) commented Jun 26, 2023

> Hi @qqaatw, I saw this PR isn't in draft state anymore. Is this ready for review?

Hi @NicolasHug, yes, please.

There is an issue with f16 inputs for RoI ops, which doesn't have test coverage. Otherwise the added ops are tested.

@NicolasHug (Member) left a review:
Thanks a lot @qqaatw. I gave a quick first glance at the tests and made some minor comments / suggestions, but this looks great overall.

As discussed offline with @albanD, we're OK to introduce these new MPS kernels in torchvision, with the shared understanding that MPS-related support (typically bug reports and fixes) will be the responsibility of the MPS team.

> There is an issue with f16 inputs for RoI ops, which doesn't have test coverage. Otherwise the added ops are tested.

What's the issue? If float16 isn't supported for MPS that's OK, but maybe we should write a small test asserting the error message?

(Resolved review comments on setup.py and test/conftest.py)
int64_t w_stride = grad.stride(3);
int64_t output_size = grad.numel();

at::globalContext().alertNotDeterministic("roi_align_backward_kernel");
@NicolasHug (Member):
I'm curious, what makes this kernel and the other roi align / pool kernels non-deterministic?

For the CUDA kernels, it's the calls to atomicAdd, but I'm curious what the reason is here.

@qqaatw (Contributor, PR author):
The MPS kernels also make use of atomic addition, which is provided either by the Metal standard library or by a custom implementation, depending on the Metal version (see the atomic_add_float function in mps_kernels.h).

I've added a note in the PR description. Hope it properly explains the non-determinism.
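
To illustrate the source of the non-determinism (a standalone illustration, not code from the PR): floating-point addition is not associative, so when many threads accumulate into the same gradient element with atomic adds, the unordered arrival of those adds can change the rounded result between runs.

```python
import torch

# Summing the same float32 values in two different orders usually does not
# give bit-identical results, which is exactly what unordered atomic adds do.
vals = torch.rand(10_000, dtype=torch.float32)
perm = torch.randperm(vals.numel())
a = float(torch.cumsum(vals, 0)[-1])        # one accumulation order
b = float(torch.cumsum(vals[perm], 0)[-1])  # the same values, another order
print(a == b, abs(a - b))                   # typically False, with a tiny difference
```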

(Resolved review comments on torchvision/ops/roi_align.py and test/test_ops.py)
@@ -271,6 +277,8 @@ def test_jit_boxes_list(self):


class TestPSRoIPool(RoIOpTester):
mps_backward_atol = 5e-2
@NicolasHug (Member):
@albanD , any thought regarding this atol value for gradcheck()?

For ref we typically use 1e-5 for CPU/CUDA, although we seem to be testing on float64 while the MPS tests are currently running on float32.

@albanD:
The gradcheck is a bit tricky here, as we usually only run it in fp64 precision to get accurate results. Unfortunately, MPS doesn't support fp64, so we can only resort to comparing with CPU results or increasing the tolerance significantly.
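
A rough sketch of the two options mentioned above (illustrative only; the shapes, tolerances, and the choice of roi_align are assumptions, not taken from the PR):

```python
import torch
from torch.autograd import gradcheck
from torchvision.ops import roi_align

def fn(x, rois):
    return roi_align(x, rois, output_size=(5, 5), spatial_scale=1.0, sampling_ratio=2)

x_mps = torch.rand(1, 4, 16, 16, device="mps", dtype=torch.float32, requires_grad=True)
rois = torch.tensor([[0.0, 1.0, 1.0, 10.0, 10.0]], device="mps")

# Option 1: numeric gradcheck in float32 with loosened tolerances, since
# float64 is unavailable on MPS; nondet_tol accounts for the atomic adds.
gradcheck(fn, (x_mps, rois), atol=5e-2, rtol=1e-2, nondet_tol=1e-5, fast_mode=True)

# Option 2: compare the analytic MPS gradient against the CPU gradient.
x_cpu = x_mps.detach().cpu().clone().requires_grad_(True)
fn(x_mps, rois).sum().backward()
fn(x_cpu, rois.cpu()).sum().backward()
torch.testing.assert_close(x_mps.grad.cpu(), x_cpu.grad, atol=1e-4, rtol=1e-4)
```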

(Resolved review comments on test/test_ops.py)
@qqaatw (Contributor, PR author) left a review:

Thank you for reviewing @NicolasHug!

> What's the issue? If float16 isn't supported for MPS that's OK, but maybe we should write a small test asserting the error message?

The issue is that the atomic operations on MPS do not support half precision, and the RoI backward kernels rely on atomic addition. I've added checks to the RoI backward kernels; the forward kernels work fine!
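
In the spirit of the suggestion above, a hypothetical test sketch asserting that half-precision backward raises on MPS (the exception type is an assumption and this is not a test from the PR):

```python
import pytest
import torch
from torchvision.ops import roi_align

@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="requires MPS")
def test_roi_align_half_backward_raises_on_mps():
    x = torch.rand(1, 4, 16, 16, device="mps", dtype=torch.float16, requires_grad=True)
    rois = torch.tensor([[0.0, 1.0, 1.0, 10.0, 10.0]], device="mps", dtype=torch.float16)

    # Forward in float16 is expected to work...
    out = roi_align(x, rois, output_size=(5, 5), spatial_scale=1.0, sampling_ratio=2)

    # ...while backward needs atomic adds, which the MPS kernels do not
    # support for half precision, so the added checks should raise.
    with pytest.raises(RuntimeError):
        out.sum().backward()
```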


@@ -158,12 +158,12 @@ def from_K(t):
y = (
from_K(roi_start_h)
+ ph[None, :, None] * from_K(bin_size_h)
-  + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h)
+  + (iy[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_h / roi_bin_grid_h)
@qqaatw (Contributor, PR author):
0.5 is f32 by default, so we cast to the input dtype.

) # [K, PH, IY]
x = (
from_K(roi_start_w)
+ pw[None, :, None] * from_K(bin_size_w)
-  + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w)
+  + (ix[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_w / roi_bin_grid_w)
@qqaatw (Contributor, PR author):
Same as above.
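
A standalone illustration of the promotion being handled here (not the PR's code, and the names below are hypothetical): adding a Python float such as 0.5 to an integer index tensor produces the default dtype, float32, so the explicit .to(input.dtype) keeps the term in the input's dtype, e.g. float16.

```python
import torch

iy = torch.arange(4)                      # int64 sampling indices
print((iy + 0.5).dtype)                   # torch.float32 (the default dtype)

input_dtype = torch.float16               # e.g. a half-precision feature map
print((iy + 0.5).to(input_dtype).dtype)   # torch.float16, matching the input
```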

@qqaatw (Contributor, PR author) commented Jul 17, 2023

Hi @NicolasHug, sorry for the delayed update. I've applied all the suggestions.

@qqaatw (Contributor, PR author) commented Jul 31, 2023

Gently pinging @NicolasHug.

@NicolasHug (Member) commented:
Sorry for the delay @qqaatw. I'll do another round of review tomorrow.

@NicolasHug (Member) left a review:

Thanks a lot @qqaatw , I took a last look at the tests and this LGTM.

@albanD, was there anything you wanted to check before merging this?

@kulinseth left a review:

Looks great, thanks @qqaatw.

@NicolasHug NicolasHug merged commit 16d62e3 into pytorch:main Aug 1, 2023
49 of 60 checks passed
@NicolasHug (Member) commented:
Thanks @qqaatw !!

@github-actions (bot) commented Aug 1, 2023

Hey @NicolasHug!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

facebook-github-bot pushed a commit that referenced this pull request Aug 25, 2023
Reviewed By: matteobettini

Differential Revision: D48642285

fbshipit-source-id: 00534d4080565eb66ed6b2dbb8416f8d7526687e

Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>