Commit

ps_roi_align fw

qqaatw committed Jun 13, 2023
1 parent 15b0108 commit aa48cc2

Showing 2 changed files with 321 additions and 12 deletions.
215 changes: 215 additions & 0 deletions torchvision/csrc/ops/mps/ps_roi_align_kernel.mm
@@ -0,0 +1,215 @@
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "vision_kernels.h"
#include "mps_helpers.h"

#include <iostream>
#include <cmath>

namespace vision {
namespace ops {

namespace {

// This must be kept in sync with the corresponding constant in the Metal kernel.
int const threadsPerBlock = 512;

std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_kernel(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio) {
  using namespace at::native::mps;
  TORCH_CHECK(input.is_mps(), "input must be an MPS tensor");
  TORCH_CHECK(rois.is_mps(), "rois must be an MPS tensor");
  TORCH_CHECK(rois.size(1) == 5, "rois must have shape [K, 5]");

  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

  at::CheckedFrom c = "ps_roi_align_forward_kernel";
  at::checkAllSameGPU(c, {input_t, rois_t});
  at::checkAllSameType(c, {input_t, rois_t});

  int64_t num_rois = rois.size(0);
  int64_t channels = input.size(1);
  int64_t height = input.size(2);
  int64_t width = input.size(3);
  float spatial_scale_f = static_cast<float>(spatial_scale);

  TORCH_CHECK(
      channels % (pooled_height * pooled_width) == 0,
      "input channels must be a multiple of pooling height * pooling width");

  int64_t channels_out = channels / (pooled_height * pooled_width);
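  // Example: with pooled_height = pooled_width = 7 and channels = 1029, each of
  // the 49 pooling bins owns its own slice of channels_out = 1029 / 49 = 21
  // channels of the position-sensitive score maps.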

  auto output = at::zeros(
      {num_rois, channels_out, pooled_height, pooled_width}, input.options());
  auto channel_mapping =
      at::zeros(output.sizes(), input.options().dtype(at::kLong));

  int64_t output_size = output.numel();

  if (output_size == 0) {
    return std::make_tuple(output, channel_mapping);
  }

  auto input_ = input.contiguous();
  auto rois_ = rois.contiguous();

  id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_);
  id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
  id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
  id<MTLBuffer> channelMappingBuffer = getMTLBufferStorage(channel_mapping);
  id<MTLDevice> device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();
  dispatch_sync(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      MTLSize threadgroupsPerGrid = MTLSizeMake(
          std::min(ceil_div(output_size, static_cast<int64_t>(threadsPerBlock)),
                   static_cast<int64_t>(4096)),
          1, 1);
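      // Example: output_size = 100000 gives ceil_div(100000, 512) = 196
      // threadgroups, well under the 4096 cap; when the cap does bind, the
      // kernel's MPS_1D_KERNEL_LOOP is presumably what picks up the remainder.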

const std::string kernel = "ps_roi_align_" + scalarToMetalTypeString(input.scalar_type());
id<MTLComputePipelineState> binaryPSO = mps::binaryPipelineState(device, kernel);

// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input_, rois_});

[computeEncoder setComputePipelineState:binaryPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];
[computeEncoder setBuffer:channelMappingBuffer offset:channel_mapping.storage_offset() * channel_mapping.element_size() atIndex:3];

[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10];
[computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12];

// A threadGroup is equivalent to a cuda's block.
NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}

MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];

getMPSProfiler().endProfileKernel(binaryPSO);
}
});
return std::make_tuple(output, channel_mapping);
}

at::Tensor ps_roi_align_backward_kernel(
    const at::Tensor& grad,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width,
    int64_t sampling_ratio,
    bool aligned) {
  using namespace at::native::mps;
  TORCH_CHECK(grad.is_mps(), "grad must be an MPS tensor");
  TORCH_CHECK(rois.is_mps(), "rois must be an MPS tensor");

  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};

  at::CheckedFrom c = "ps_roi_align_backward_kernel";
  at::checkAllSameGPU(c, {grad_t, rois_t});
  at::checkAllSameType(c, {grad_t, rois_t});

  float spatial_scale_f = static_cast<float>(spatial_scale);

  at::Tensor grad_input = at::zeros(
      {batch_size, channels, height, width}, grad.options());

  if (grad.numel() == 0) {
    return grad_input;
  }

  int64_t n_stride = grad.stride(0);
  int64_t c_stride = grad.stride(1);
  int64_t h_stride = grad.stride(2);
  int64_t w_stride = grad.stride(3);
  int64_t output_size = grad.numel();
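  // For a contiguous NCHW grad these strides are channels*height*width,
  // height*width, width, and 1 respectively; passing them explicitly lets the
  // kernel index a non-contiguous grad without forcing a copy.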

  at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel");
  auto rois_ = rois.contiguous();

  id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad);
  id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
  id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
  id<MTLDevice> device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();
  dispatch_sync(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      MTLSize threadgroupsPerGrid = MTLSizeMake(
          std::min(ceil_div(grad.numel(), static_cast<int64_t>(threadsPerBlock)),
                   static_cast<int64_t>(4096)),
          1, 1);

const std::string kernel = "ps_roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type());
id<MTLComputePipelineState> binaryPSO = mps::binaryPipelineState(device, kernel);

// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {grad, rois_});

[computeEncoder setComputePipelineState:binaryPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2];

[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];
[computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:12];
[computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:13];
[computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14];
[computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15];

// A threadGroup is equivalent to a cuda's block.
NSUInteger tgSize = binaryPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}

MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];

getMPSProfiler().endProfileKernel(binaryPSO);
}
});
return grad_input;
}

} // namespace

TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
      TORCH_FN(ps_roi_align_forward_kernel));
  // m.impl(
  //     TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"),
  //     TORCH_FN(ps_roi_align_backward_kernel));
}

} // namespace ops
} // namespace vision
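
For orientation, a minimal host-side sketch of how this registration would be reached through the PyTorch dispatcher; the tensors `input` and `rois` are placeholder names, and this mirrors the boxed-call pattern used by torchvision's C++ frontend rather than code from this commit:

// Sketch only: look up torchvision::ps_roi_align and call it on MPS tensors.
static auto op = c10::Dispatcher::singleton()
    .findSchemaOrThrow("torchvision::ps_roi_align", "")
    .typed<std::tuple<at::Tensor, at::Tensor>(
        const at::Tensor&, const at::Tensor&, double, int64_t, int64_t, int64_t)>();
// input: [N, C, H, W] on MPS; rois: [K, 5] rows of (batch_index, x1, y1, x2, y2).
auto [output, channel_mapping] = op.call(input, rois, /*spatial_scale=*/1.0,
                                         /*pooled_height=*/7, /*pooled_width=*/7,
                                         /*sampling_ratio=*/2);

Because the dispatcher routes on the tensors' device, the same call reaches the CUDA or CPU kernels when the inputs live elsewhere.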
118 changes: 106 additions & 12 deletions torchvision/csrc/ops/mps/vision_kernels.h
@@ -36,24 +36,17 @@ void atomic_add_float( device atomic_uint* atom_var, const float val )
  float fetched_float, assigning_float;

  fetched_uint = atomic_exchange_explicit( atom_var, 0, memory_order_relaxed );
  fetched_float = *( (thread float*) &fetched_uint );

  assigning_float = fetched_float + val;
  assigning_uint = *( (thread uint*) &assigning_float );

  while ( (fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint, memory_order_relaxed ) ) != 0 ) {
    uint fetched_uint_again = atomic_exchange_explicit( atom_var, 0, memory_order_relaxed );
    float fetched_float_again = *( (thread float*) &fetched_uint_again );
    fetched_float = *( (thread float*) &(fetched_uint) );
    assigning_float = fetched_float_again + fetched_float;
    assigning_uint = *( (thread uint*) &assigning_float );
  }
}
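// Note on the loop above: the float add is emulated with uint exchanges
// because Metal lacks a native atomic float add. A zero bit pattern marks an
// "empty" slot, so whenever our store displaces a non-zero value, that value
// is swapped back out and folded into the pending sum before retrying.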
@@ -654,6 +647,107 @@ kernel void roi_pool_backward<DTYPE>( \
REGISTER_ROI_POOL_BACKWARD_OP(float);
REGISTER_ROI_POOL_BACKWARD_OP(half);
template<typename T>
kernel void ps_roi_align(
    constant T * input [[buffer(0)]],
    constant T * rois [[buffer(1)]],
    device T * output [[buffer(2)]],
    device int64_t * channel_mapping [[buffer(3)]],
    constant int64_t & output_size [[buffer(4)]],
    constant int64_t & channels [[buffer(5)]],
    constant int64_t & height [[buffer(6)]],
    constant int64_t & width [[buffer(7)]],
    constant int64_t & pooled_height [[buffer(8)]],
    constant int64_t & pooled_width [[buffer(9)]],
    constant int64_t & sampling_ratio [[buffer(10)]],
    constant int64_t & channels_out [[buffer(11)]],
    constant float & spatial_scale [[buffer(12)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint2 tptg [[threads_per_threadgroup]],
    uint2 tid2 [[thread_position_in_threadgroup]]) {
  MPS_1D_KERNEL_LOOP(index, output_size, 1) {
    // (n, c_out, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c_out = (index / pooled_width / pooled_height) % channels_out;
    int n = index / pooled_width / pooled_height / channels_out;

    // (n, c_in, ph, pw) is the associated element in the input
    int c_in = (c_out * pooled_height + ph) * pooled_width + pw;
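    // Example: with 7x7 pooling, output channel c_out = 2 at bin (ph = 3,
    // pw = 4) reads input channel c_in = (2 * 7 + 3) * 7 + 4 = 123; each bin
    // samples its own dedicated score-map channel, which is what makes the
    // pooling position-sensitive.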
    // [start, end) interval for spatial sampling
    constant T* offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];

    // Do not use rounding; this implementation detail is critical
    T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
    T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
    T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
    T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);

    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    // Do not use floor/ceil; this implementation detail is critical
    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (sampling_ratio > 0)
        ? sampling_ratio
        : ceil(roi_height / pooled_height);
    int roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
    const T count = roi_bin_grid_h * roi_bin_grid_w;
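    // Example: sampling_ratio = 2 fixes a 2 x 2 = 4-point grid per bin; with
    // sampling_ratio <= 0 and an RoI 21 pixels tall pooled to 7 rows,
    // roi_bin_grid_h = ceil(21 / 7) = 3 sample rows per bin instead.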
    constant T* offset_input =
        input + (roi_batch_ind * channels + c_in) * height * width;

    T out_sum = 0;
    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
      const T y = hstart +
          static_cast<T>(iy + .5f) * bin_size_h /
              static_cast<T>(roi_bin_grid_h);
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const T x = wstart +
            static_cast<T>(ix + .5f) * bin_size_w /
                static_cast<T>(roi_bin_grid_w);
        T val = bilinear_interpolate(offset_input, height, width, y, x, index);
        out_sum += val;
      }
    }

    out_sum /= count;
    output[index] = out_sum;
    channel_mapping[index] = c_in;
  }
}
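// Aside: MPS_1D_KERNEL_LOOP comes from the shared MPS helpers; a plausible
// expansion (an assumption here, not verified against mps_helpers.h) is a
// standard grid-stride loop over the flattened output:
//   for (int64_t index = tgid.x * tptg.x + tid2.x; index < output_size;
//        index += tptg.x * threadgroups_per_grid)
// which would be why the host side can safely cap the grid at 4096 threadgroups.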
#define REGISTER_PS_ROI_ALIGN_OP(DTYPE)               \
template                                              \
[[host_name("ps_roi_align_" #DTYPE)]]                 \
kernel void ps_roi_align<DTYPE>(                      \
    constant DTYPE * input [[buffer(0)]],             \
    constant DTYPE * rois [[buffer(1)]],              \
    device DTYPE * output [[buffer(2)]],              \
    device int64_t * channel_mapping [[buffer(3)]],   \
    constant int64_t & output_size [[buffer(4)]],     \
    constant int64_t & channels [[buffer(5)]],        \
    constant int64_t & height [[buffer(6)]],          \
    constant int64_t & width [[buffer(7)]],           \
    constant int64_t & pooled_height [[buffer(8)]],   \
    constant int64_t & pooled_width [[buffer(9)]],    \
    constant int64_t & sampling_ratio [[buffer(10)]], \
    constant int64_t & channels_out [[buffer(11)]],   \
    constant float & spatial_scale [[buffer(12)]],    \
    uint2 tgid [[threadgroup_position_in_grid]],      \
    uint2 tptg [[threads_per_threadgroup]],           \
    uint2 tid2 [[thread_position_in_threadgroup]]);

REGISTER_PS_ROI_ALIGN_OP(float);
REGISTER_PS_ROI_ALIGN_OP(half);
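// The [[host_name]] attribute ties these instantiations back to the host:
// "ps_roi_align_float" and "ps_roi_align_half" match the pipeline name the
// .mm file builds as "ps_roi_align_" + scalarToMetalTypeString(...).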
)VISION_METAL";

static id<MTLLibrary> compileBinaryOpsLibrary(id<MTLDevice> device) {
