Commit 6f32285

ps_roi_align bw (failed prec)
qqaatw committed Jun 13, 2023
1 parent 195d03a · commit 6f32285
Showing 2 changed files with 159 additions and 26 deletions.
53 changes: 27 additions & 26 deletions torchvision/csrc/ops/mps/ps_roi_align_kernel.mm
@@ -111,29 +111,30 @@
 at::Tensor ps_roi_align_backward_kernel(
     const at::Tensor& grad,
     const at::Tensor& rois,
+    const at::Tensor& channel_mapping,
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width,
+    int64_t sampling_ratio,
     int64_t batch_size,
     int64_t channels,
     int64_t height,
-    int64_t width,
-    int64_t sampling_ratio,
-    bool aligned) {
+    int64_t width) {

   using namespace at::native::mps;
   TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor");
   TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
+  TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor");

-  at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2};
+  at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, channel_mapping_t{channel_mapping, "channel_mapping", 3};

   at::CheckedFrom c = "ps_roi_align_backward_kernel";
-  at::checkAllSameGPU(c, {grad_t, rois_t});
+  at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t});
   at::checkAllSameType(c, {grad_t, rois_t});

   float spatial_scale_f = static_cast<float>(spatial_scale);

-  at::Tensor grad_input = at::zeros(
+  auto grad_input = at::zeros(
       {batch_size, channels, height, width}, grad.options());

   if (grad.numel() == 0) {
@@ -146,11 +147,14 @@
   int64_t w_stride = grad.stride(3);
   int64_t output_size = grad.numel();

+  int64_t channels_out = channels / (pooled_height * pooled_width);
+
   at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel");
-  auto rois_ = rois.contiguous();
+  auto grad_ = grad.contiguous(), rois_ = rois.contiguous();

-  id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad);
+  id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad_);
   id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
+  id<MTLBuffer> channelMappingBuffer = getMTLBufferStorage(channel_mapping);
   id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
   id<MTLDevice> device = MPSDevice::getInstance()->device();
   MPSStream* mpsStream = getCurrentMPSStream();
@@ -167,23 +171,20 @@

       [computeEncoder setComputePipelineState:binaryPSO];
       // [N, C, H, W]
-      [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0];
+      [computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0];
       [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
-      [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2];
+      [computeEncoder setBuffer:channelMappingBuffer offset:channel_mapping.storage_offset() * channel_mapping.element_size() atIndex:2];
+      [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3];

-      [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3];
-      [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4];
-      [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5];
-      [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6];
-      [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7];
-      [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8];
-      [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9];
-      [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10];
-      [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];
-      [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:12];
-      [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:13];
-      [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14];
-      [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15];
+      [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
+      [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
+      [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
+      [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
+      [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
+      [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
+      [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10];
+      [computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11];
+      [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12];

       // A threadGroup is equivalent to a cuda's block.
       NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup;
@@ -206,9 +207,9 @@
   m.impl(
       TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
       TORCH_FN(ps_roi_align_forward_kernel));
-  //m.impl(
-  //    TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"),
-  //    TORCH_FN(ps_roi_align_backward_kernel));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"),
+      TORCH_FN(ps_roi_align_backward_kernel));
 }

 } // namespace ops
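A note on the new channels_out term in the encoder changes above: PS-RoI-Align is position-sensitive, so the input channel axis must factor into pooled_height * pooled_width groups, one slice per pooled cell. A minimal C++ illustration of that constraint follows; it is not part of the commit, and the helper name is ours:

// Illustration only: the position-sensitive channel fold used above.
// PS-RoI-Align requires channels % (pooled_h * pooled_w) == 0; each
// pooled cell (ph, pw) reads its own slice of the channel axis.
#include <cassert>
#include <cstdint>

int64_t ps_channels_out(int64_t channels, int64_t pooled_h, int64_t pooled_w) {
  assert(channels % (pooled_h * pooled_w) == 0);
  return channels / (pooled_h * pooled_w);
}

// e.g. a 392-channel feature map with a 7x7 pooled grid yields
// ps_channels_out(392, 7, 7) == 8 output channels per RoI.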
132 changes: 132 additions & 0 deletions torchvision/csrc/ops/mps/vision_kernels.h
@@ -748,6 +748,138 @@ kernel void ps_roi_align<DTYPE>( \
 REGISTER_PS_ROI_ALIGN_OP(float);
 REGISTER_PS_ROI_ALIGN_OP(half);
+
+template<typename T>
+kernel void ps_roi_align_backward(
+    constant T * grad_output [[buffer(0)]],
+    constant T * rois [[buffer(1)]],
+    constant int64_t * channel_mapping [[buffer(2)]],
+    device T * grad_input [[buffer(3)]],
+    constant int64_t & output_size [[buffer(4)]],
+    constant int64_t & channels [[buffer(5)]],
+    constant int64_t & height [[buffer(6)]],
+    constant int64_t & width [[buffer(7)]],
+    constant int64_t & pooled_height [[buffer(8)]],
+    constant int64_t & pooled_width [[buffer(9)]],
+    constant int64_t & sampling_ratio [[buffer(10)]],
+    constant int64_t & channels_out [[buffer(11)]],
+    constant float & spatial_scale [[buffer(12)]],
+    uint2 tgid [[threadgroup_position_in_grid]],
+    uint2 tptg [[threads_per_threadgroup]],
+    uint2 tid2 [[thread_position_in_threadgroup]]){
+  MPS_1D_KERNEL_LOOP(index, output_size, 1) {
+    // (n, *, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int n = index / pooled_width / pooled_height / channels_out;
+
+    constant T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+
+    // Do not use rounding; this implementation detail is critical
+    T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
+    T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
+    T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
+    T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);
+
+    // Force too small ROIs to be 1x1
+    T roi_width = roi_end_w - roi_start_w;
+    T roi_height = roi_end_h - roi_start_h;
+    T bin_size_h = roi_height / static_cast<T>(pooled_height);
+    T bin_size_w = roi_width / static_cast<T>(pooled_width);
+
+    int c_in = channel_mapping[index];
+
+    // Do not use floor/ceil; this implementation detail is critical
+    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
+    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
+
+    const T grad_output_this_bin = grad_output[index];
+
+    // We use roi_bin_grid to sample the grid and mimic integral
+    int roi_bin_grid_h = (sampling_ratio > 0)
+        ? sampling_ratio
+        : ceil(roi_height / pooled_height); // e.g., = 2
+    int roi_bin_grid_w =
+        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+    const T count = roi_bin_grid_h * roi_bin_grid_w;
+
+    const int offset = (roi_batch_ind * channels + c_in) * height * width;
+
+    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+      const T y = hstart +
+          static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h);
+      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+        const T x = wstart +
+            static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
+
+        T w1, w2, w3, w4;
+        int x_low, x_high, y_low, y_high;
+
+        bilinear_interpolate_gradient(
+            height,
+            width,
+            y,
+            x,
+            w1,
+            w2,
+            w3,
+            w4,
+            x_low,
+            x_high,
+            y_low,
+            y_high,
+            index);
+
+        T g1 = grad_output_this_bin * w1 / count;
+        T g2 = grad_output_this_bin * w2 / count;
+        T g3 = grad_output_this_bin * w3 / count;
+        T g4 = grad_output_this_bin * w4 / count;
+
+        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+          device atomic_uint* xAtomic = (device atomic_uint*)(grad_input + offset + y_low * width + x_low);
+          device atomic_uint* yAtomic = (device atomic_uint*)(grad_input + offset + y_low * width + x_high);
+          device atomic_uint* zAtomic = (device atomic_uint*)(grad_input + offset + y_high * width + x_low);
+          device atomic_uint* wAtomic = (device atomic_uint*)(grad_input + offset + y_high * width + x_high);
+
+          // atomic_float data type is supported on Metal 3 onward.
+          // TODO: Use native atomic_fetch_add_explicit for Metal 3.
+          atomic_add_float(xAtomic, static_cast<T>(g1));
+          atomic_add_float(yAtomic, static_cast<T>(g2));
+          atomic_add_float(zAtomic, static_cast<T>(g3));
+          atomic_add_float(wAtomic, static_cast<T>(g4));
+        } // if
+      } // ix
+    } // iy
+  }
+}
+
+#define REGISTER_PS_ROI_ALIGN_BACKWARD_OP(DTYPE)        \
+template                                                \
+[[host_name("ps_roi_align_backward_" #DTYPE)]]          \
+kernel void ps_roi_align_backward<DTYPE>(               \
+    constant DTYPE * grad_output [[buffer(0)]],         \
+    constant DTYPE * rois [[buffer(1)]],                \
+    constant int64_t * channel_mapping [[buffer(2)]],   \
+    device DTYPE * grad_input [[buffer(3)]],            \
+    constant int64_t & output_size [[buffer(4)]],       \
+    constant int64_t & channels [[buffer(5)]],          \
+    constant int64_t & height [[buffer(6)]],            \
+    constant int64_t & width [[buffer(7)]],             \
+    constant int64_t & pooled_height [[buffer(8)]],     \
+    constant int64_t & pooled_width [[buffer(9)]],      \
+    constant int64_t & sampling_ratio [[buffer(10)]],   \
+    constant int64_t & channels_out [[buffer(11)]],     \
+    constant float & spatial_scale [[buffer(12)]],      \
+    uint2 tgid [[threadgroup_position_in_grid]],        \
+    uint2 tptg [[threads_per_threadgroup]],             \
+    uint2 tid2 [[thread_position_in_threadgroup]]);
+
+REGISTER_PS_ROI_ALIGN_BACKWARD_OP(float);
+REGISTER_PS_ROI_ALIGN_BACKWARD_OP(half);
 )VISION_METAL";

 static id<MTLLibrary> compileBinaryOpsLibrary(id<MTLDevice> device) {
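The backward kernel accumulates into grad_input through atomic_add_float, a helper this diff calls but does not define (it lives elsewhere in vision_kernels.h). As the TODO above notes, a native atomic_float only arrives with Metal 3, so earlier targets emulate the add with a compare-and-swap loop on the raw 32-bit word. A hedged Metal sketch of such a helper; the upstream definition may differ in detail:

// Sketch only, not the verbatim helper: emulate atomic float add by
// compare-and-swap on the 32-bit bit pattern. MSL atomics support only
// memory_order_relaxed, which suffices for gradient accumulation.
inline void atomic_add_float(device atomic_uint* addr, const float value) {
  uint expected = atomic_load_explicit(addr, memory_order_relaxed);
  uint desired = as_type<uint>(as_type<float>(expected) + value);
  // On failure, `expected` is refreshed with the word another thread wrote,
  // so we recompute the sum and retry until the exchange sticks.
  while (!atomic_compare_exchange_weak_explicit(
      addr, &expected, desired, memory_order_relaxed, memory_order_relaxed)) {
    desired = as_type<uint>(as_type<float>(expected) + value);
  }
}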

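The commit title flags a precision failure ("failed prec"). The usual way to reproduce that kind of check is to compare gradients computed on MPS against the CPU reference. A hedged libtorch sketch follows; the vision::ops::ps_roi_align signature and header path are assumptions, not taken from this diff:

// Hypothetical precision check, not part of the commit.
#include <torch/torch.h>
#include <torchvision/ops/ps_roi_align.h> // assumed header path

at::Tensor ps_roi_align_grad_on(const torch::Device& device) {
  auto input = torch::rand(
      {1, 8, 16, 16},
      torch::dtype(torch::kFloat).device(device).requires_grad(true));
  // One RoI per row: (batch_index, x1, y1, x2, y2).
  auto rois = torch::tensor(
      {{0.f, 0.f, 0.f, 8.f, 8.f}}, torch::dtype(torch::kFloat).device(device));
  // channels (8) must divide evenly by pooled_h * pooled_w (2 * 2).
  auto [output, channel_mapping] = vision::ops::ps_roi_align(
      input, rois, /*spatial_scale=*/1.0,
      /*pooled_height=*/2, /*pooled_width=*/2, /*sampling_ratio=*/2);
  output.sum().backward(); // exercises _ps_roi_align_backward, registered above
  return input.grad().cpu();
}

// Usage: torch::allclose(ps_roi_align_grad_on(torch::kCPU),
//                        ps_roi_align_grad_on(torch::Device("mps")))
// returning false is the kind of mismatch the commit title records.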