From dd5f42a4559ef08fa8c3a7c345683736f4c43bcf Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Tue, 13 Jun 2023 19:21:54 +0800 Subject: [PATCH] ps_roi_align bw (failed prec) --- .../csrc/ops/mps/ps_roi_align_kernel.mm | 53 +++---- torchvision/csrc/ops/mps/vision_kernels.h | 132 ++++++++++++++++++ 2 files changed, 159 insertions(+), 26 deletions(-) diff --git a/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm index 9e63067d2cb..1803ec680f0 100644 --- a/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm +++ b/torchvision/csrc/ops/mps/ps_roi_align_kernel.mm @@ -111,29 +111,30 @@ at::Tensor ps_roi_align_backward_kernel( const at::Tensor& grad, const at::Tensor& rois, + const at::Tensor& channel_mapping, double spatial_scale, int64_t pooled_height, int64_t pooled_width, + int64_t sampling_ratio, int64_t batch_size, int64_t channels, int64_t height, - int64_t width, - int64_t sampling_ratio, - bool aligned) { + int64_t width) { using namespace at::native::mps; TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor"); TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor"); + TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor"); - at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}; + at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, channel_mapping_t{channel_mapping, "channel_mapping", 3}; at::CheckedFrom c = "ps_roi_align_backward_kernel"; - at::checkAllSameGPU(c, {grad_t, rois_t}); + at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t}); at::checkAllSameType(c, {grad_t, rois_t}); float spatial_scale_f = static_cast(spatial_scale); - at::Tensor grad_input = at::zeros( + auto grad_input = at::zeros( {batch_size, channels, height, width}, grad.options()); if (grad.numel() == 0) { @@ -146,11 +147,14 @@ int64_t w_stride = grad.stride(3); int64_t output_size = grad.numel(); + int64_t channels_out = channels / (pooled_height * pooled_width); + at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel"); - auto rois_ = rois.contiguous(); + auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); - id inputBuffer = getMTLBufferStorage(grad); + id inputBuffer = getMTLBufferStorage(grad_); id roisBuffer = getMTLBufferStorage(rois_); + id channelMappingBuffer = getMTLBufferStorage(channel_mapping); id outputBuffer = getMTLBufferStorage(grad_input); id device = MPSDevice::getInstance()->device(); MPSStream* mpsStream = getCurrentMPSStream(); @@ -167,23 +171,20 @@ [computeEncoder setComputePipelineState:binaryPSO]; // [N, C, H, W] - [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0]; + [computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0]; [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; - [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2]; + [computeEncoder setBuffer:channelMappingBuffer offset:channel_mapping.storage_offset() * channel_mapping.element_size() atIndex:2]; + [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3]; - [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3]; - [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; - [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; - [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; - [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7]; - [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8]; - [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9]; - [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10]; - [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; - [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:12]; - [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:13]; - [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14]; - [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15]; + [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6]; + [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7]; + [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8]; + [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9]; + [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10]; + [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12]; // A threadGroup is equivalent to a cuda's block. NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup; @@ -206,9 +207,9 @@ m.impl( TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), TORCH_FN(ps_roi_align_forward_kernel)); - //m.impl( - // TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), - // TORCH_FN(ps_roi_align_backward_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), + TORCH_FN(ps_roi_align_backward_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/mps/vision_kernels.h b/torchvision/csrc/ops/mps/vision_kernels.h index c83ced529ea..f8f5033bcd8 100644 --- a/torchvision/csrc/ops/mps/vision_kernels.h +++ b/torchvision/csrc/ops/mps/vision_kernels.h @@ -748,6 +748,138 @@ kernel void ps_roi_align( \ REGISTER_PS_ROI_ALIGN_OP(float); REGISTER_PS_ROI_ALIGN_OP(half); +template +kernel void ps_roi_align_backward( + constant T * grad_output [[buffer(0)]], + constant T * rois [[buffer(1)]], + constant int64_t * channel_mapping [[buffer(2)]], + device T * grad_input [[buffer(3)]], + constant int64_t & output_size [[buffer(4)]], + constant int64_t & channels [[buffer(5)]], + constant int64_t & height [[buffer(6)]], + constant int64_t & width [[buffer(7)]], + constant int64_t & pooled_height [[buffer(8)]], + constant int64_t & pooled_width [[buffer(9)]], + constant int64_t & sampling_ratio [[buffer(10)]], + constant int64_t & channels_out [[buffer(11)]], + constant float & spatial_scale [[buffer(12)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint2 tptg [[threads_per_threadgroup]], + uint2 tid2 [[thread_position_in_threadgroup]]){ + + MPS_1D_KERNEL_LOOP(index, output_size, 1) { + // (n, *, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int n = index / pooled_width / pooled_height / channels_out; + + constant T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale - static_cast(0.5); + T roi_start_h = offset_rois[2] * spatial_scale - static_cast(0.5); + T roi_end_w = offset_rois[3] * spatial_scale - static_cast(0.5); + T roi_end_h = offset_rois[4] * spatial_scale - static_cast(0.5); + + // Force too small ROIs to be 1x1 + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + int c_in = channel_mapping[index]; + + // Do not using floor/ceil; this implementation detail is critical + T hstart = static_cast(ph) * bin_size_h + roi_start_h; + T wstart = static_cast(pw) * bin_size_w + roi_start_w; + + const T grad_output_this_bin = grad_output[index]; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + const T count = roi_bin_grid_h * roi_bin_grid_w; + + const int offset = (roi_batch_ind * channels + c_in) * height * width; + + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = hstart + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = wstart + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient( + height, + width, + y, + x, + w1, + w2, + w3, + w4, + x_low, + x_high, + y_low, + y_high, + index); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + device atomic_uint* xAtomic = (device atomic_uint*)(grad_input + offset + y_low * width + x_low); + device atomic_uint* yAtomic = (device atomic_uint*)(grad_input + offset + y_low * width + x_high); + device atomic_uint* zAtomic = (device atomic_uint*)(grad_input + offset + y_high * width + x_low); + device atomic_uint* wAtomic = (device atomic_uint*)(grad_input + offset + y_high * width + x_high); + + // atomic_float data type is supported on Metal 3 onward. + // TODO: Use native atomic_fetch_add_explicit for Metal 3. + atomic_add_float(xAtomic, static_cast(g1)); + atomic_add_float(yAtomic, static_cast(g2)); + atomic_add_float(zAtomic, static_cast(g3)); + atomic_add_float(wAtomic, static_cast(g4)); + } // if + } // ix + } // iy + } +} + +#define REGISTER_PS_ROI_ALIGN_BACKWARD_OP(DTYPE) \ +template \ +[[host_name("ps_roi_align_backward_" #DTYPE)]] \ +kernel void ps_roi_align_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + constant int64_t * channel_mapping [[buffer(2)]], \ + device DTYPE * grad_input [[buffer(3)]], \ + constant int64_t & output_size [[buffer(4)]], \ + constant int64_t & channels [[buffer(5)]], \ + constant int64_t & height [[buffer(6)]], \ + constant int64_t & width [[buffer(7)]], \ + constant int64_t & pooled_height [[buffer(8)]], \ + constant int64_t & pooled_width [[buffer(9)]], \ + constant int64_t & sampling_ratio [[buffer(10)]], \ + constant int64_t & channels_out [[buffer(11)]], \ + constant float & spatial_scale [[buffer(12)]], \ + uint2 tgid [[threadgroup_position_in_grid]], \ + uint2 tptg [[threads_per_threadgroup]], \ + uint2 tid2 [[thread_position_in_threadgroup]]); + +REGISTER_PS_ROI_ALIGN_BACKWARD_OP(float); +REGISTER_PS_ROI_ALIGN_BACKWARD_OP(half); + )VISION_METAL"; static id compileBinaryOpsLibrary(id device) {