Skip to content

Commit

Permalink
Remove intermediate TensorLists. Improve performance
Browse files Browse the repository at this point in the history
Signed-off-by: Rafal Banas <rbanas@nvidia.com>
  • Loading branch information
banasraf committed Oct 29, 2024
1 parent 0ff79a2 commit 44b9ca2
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 60 deletions.
83 changes: 46 additions & 37 deletions dali/operators/image/resize/experimental/resize_op_impl_cvcuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "dali/kernels/imgproc/resample/params.h"
#include "dali/operators/image/resize/resize_op_impl.h"
#include "dali/operators/nvcvop/nvcvop.h"
#include "dali/core/nvtx.h"

namespace dali {

Expand All @@ -33,12 +34,13 @@ class ResizeOpImplCvCuda : public ResizeBase<GPUBackend>::Impl {

static_assert(spatial_ndim == 2 || spatial_ndim == 3, "Only 2D and 3D resizing is supported");


/// Dimensionality of each separate frame. If input contains no channel dimension, one is added
static constexpr int frame_ndim = spatial_ndim + 1;

void Setup(TensorListShape<> &out_shape, const TensorListShape<> &in_shape, int first_spatial_dim,
span<const kernels::ResamplingParams> params) override {
first_spatial_dim_ = first_spatial_dim;

// Calculate output shape of the input, as supplied (sequences, planar images, etc)
GetResizedShape(out_shape, in_shape, params, spatial_ndim, first_spatial_dim);

Expand All @@ -49,31 +51,40 @@ class ResizeOpImplCvCuda : public ResizeBase<GPUBackend>::Impl {
// effective frames (from videos, channel planes, etc).
GetResizedShape(out_shape_, in_shape_, make_cspan(params_), 0);

// Create a map of non-empty samples
SetFrameIdxs();

// Now that we know how many logical frames there are, calculate batch subdivision.
CalculateMinibatchPartition(minibatch_size_);

CalculateSourceSamples(in_shape, first_spatial_dim);

SetupKernel();
}

// Set the frame_idx_ map with indices of samples that are not empty
void SetFrameIdxs() {
frame_idx_.clear();
frame_idx_.reserve(in_shape_.num_samples());
for (int i = 0; i < in_shape_.num_samples(); ++i) {
if (volume(out_shape_.tensor_shape_span(i)) != 0 &&
volume(in_shape_.tensor_shape_span(i)) != 0) {
frame_idx_.push_back(i);
// Assign each minibatch a range of frames in the original input/output TensorLists
void CalculateSourceSamples(const TensorListShape<> &original_shape, int first_spatial_dim) {
int64_t sample_id = 0;
int64_t frame_offset = 0;
for (auto &mb : minibatches_) {
auto v = original_shape[sample_id].num_elements();
while (v == 0) {
sample_id++;
v = original_shape[sample_id].num_elements();
}
mb.sample_offset = sample_id;
mb.frame_offset = frame_offset;
frame_offset = mb.frame_offset + mb.count;
int frames_n = num_frames(original_shape[sample_id], first_spatial_dim);
while (frame_offset >= frames_n) {
frame_offset -= frames_n;
if (++sample_id >= original_shape.num_samples()) {
break;
}
frames_n = num_frames(original_shape[sample_id], first_spatial_dim);
}
total_frames_ = frame_idx_.size();
}
}

// get the index of a frame in the DALI TensorList
int frame_idx(int f) {
return frame_idx_[f];
int64_t num_frames(const TensorShape<> &shape, int first_spatial_dim) {
return volume(&shape[0], &shape[first_spatial_dim]);
}

void SetupKernel() {
Expand All @@ -88,23 +99,22 @@ class ResizeOpImplCvCuda : public ResizeBase<GPUBackend>::Impl {

int end = mb.start + mb.count;
for (int i = mb.start, j = 0; i < end; i++, j++) {
auto f_id = frame_idx(i);
rois_ptr[j] = GetRoi(params_[f_id]);
rois_ptr[j] = GetRoi(params_[i]);
for (int d = 0; d < spatial_ndim; ++d) {
mb_input_shapes[j].extent[d] = static_cast<int32_t>(in_shape_.tensor_shape_span(f_id)[d]);
mb_input_shapes[j].extent[d] = static_cast<int32_t>(in_shape_.tensor_shape_span(i)[d]);
mb_output_shapes[j].extent[d] =
static_cast<int32_t>(out_shape_.tensor_shape_span(f_id)[d]);
static_cast<int32_t>(out_shape_.tensor_shape_span(i)[d]);
}
}
int num_channels = in_shape_[frame_idx(0)][frame_ndim - 1];
int num_channels = in_shape_[0][frame_ndim - 1];
HQResizeTensorShapesI mb_input_shape{mb_input_shapes.data(), mb.count, spatial_ndim,
num_channels};
HQResizeTensorShapesI mb_output_shape{mb_output_shapes.data(), mb.count, spatial_ndim,
num_channels};
mb.rois = HQResizeRoisF{mb.count, spatial_ndim, rois_ptr};
rois_ptr += mb.count;

auto param = params_[frame_idx(mb.start)][0];
auto param = params_[mb.start][0];
mb.min_interpolation = GetInterpolationType(param.min_filter);
mb.mag_interpolation = GetInterpolationType(param.mag_filter);
mb.antialias = param.min_filter.antialias || param.mag_filter.antialias;
Expand Down Expand Up @@ -149,21 +159,17 @@ class ResizeOpImplCvCuda : public ResizeBase<GPUBackend>::Impl {
kernels::DynamicScratchpad scratchpad({}, AccessOrder(ws.stream()));
auto allocator = nvcvop::GetScratchpadAllocator(scratchpad);

in_frames_.ShareData(input);
in_frames_.Resize(in_shape_);

out_frames_.ShareData(output);
out_frames_.Resize(out_shape_);

auto workspace_mem = AllocateWorkspaces(scratchpad);

for (size_t b = 0; b < minibatches_.size(); b++) {
MiniBatch &mb = minibatches_[b];
auto reqs = nvcv::TensorBatch::CalcRequirements(mb.count);
auto mb_output = nvcv::TensorBatch(reqs, allocator);
auto mb_input = nvcv::TensorBatch(reqs, allocator);
nvcvop::PushTensorsToBatch(mb_input, in_frames_, mb.start, mb.count, sample_layout_);
nvcvop::PushTensorsToBatch(mb_output, out_frames_, mb.start, mb.count, sample_layout_);
nvcvop::PushFramesToBatch(mb_input, input, first_spatial_dim_, mb.sample_offset,
mb.frame_offset, mb.count, sample_layout_);
nvcvop::PushFramesToBatch(mb_output, output, first_spatial_dim_, mb.sample_offset,
mb.frame_offset, mb.count, sample_layout_);
resize_op_(ws.stream(), workspace_mem[b % 2], mb_input, mb_output, mb.min_interpolation,
mb.mag_interpolation, mb.antialias, mb.rois);
}
Expand All @@ -179,13 +185,14 @@ class ResizeOpImplCvCuda : public ResizeBase<GPUBackend>::Impl {
}

void CalculateMinibatchPartition(int minibatch_size) {
total_frames_ = in_shape_.num_samples();
std::vector<std::pair<int, int>> continuous_ranges;
kernels::FilterDesc min_filter_desc = params_[frame_idx(0)][0].min_filter;
kernels::FilterDesc mag_filter_desc = params_[frame_idx(0)][0].mag_filter;
kernels::FilterDesc min_filter_desc = params_[0][0].min_filter;
kernels::FilterDesc mag_filter_desc = params_[0][0].mag_filter;
int start_id = 0;
for (int i = 0; i < total_frames_; i++) {
if (params_[frame_idx(i)][0].min_filter != min_filter_desc ||
params_[frame_idx(i)][0].mag_filter != mag_filter_desc) {
if (params_[i][0].min_filter != min_filter_desc ||
params_[i][0].mag_filter != mag_filter_desc) {
// we break the range if different filter types are used
continuous_ranges.emplace_back(start_id, i);
start_id = i;
Expand Down Expand Up @@ -214,25 +221,27 @@ class ResizeOpImplCvCuda : public ResizeBase<GPUBackend>::Impl {
}

TensorListShape<frame_ndim> in_shape_, out_shape_;
std::vector<int> frame_idx_; // map of absolute frame indices in the input TensorList
int total_frames_; // number of non-empty frames
std::vector<ResamplingParamsND<spatial_ndim>> params_;
int first_spatial_dim_;

cvcuda::HQResize resize_op_{};
nvcvop::NVCVOpWorkspace op_workspace_;
std::array<cvcuda::WorkspaceRequirements, 2> workspace_reqs_{};
std::vector<HQResizeRoiF> rois_;
const TensorLayout sample_layout_ = (spatial_ndim == 2) ? "HWC" : "DHWC";

TensorList<GPUBackend> in_frames_;
TensorList<GPUBackend> out_frames_;
std::vector<const void*> in_frames_;
std::vector<const void*> out_frames_;

struct MiniBatch {
int start, count;
NVCVInterpolationType min_interpolation;
NVCVInterpolationType mag_interpolation;
bool antialias;
HQResizeRoisF rois;
int64_t sample_offset; // id of a starting sample in the original IOs
int64_t frame_offset; // id of a starting frame in the starting sample
};

std::vector<MiniBatch> minibatches_;
Expand Down
8 changes: 5 additions & 3 deletions dali/operators/image/resize/resize_op_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ void GetFrameShapesAndParams(

for (int i = 0; i < N; i++) {
auto in_sample_shape = in_shape.tensor_shape_span(i);
total_frames += volume(&in_sample_shape[0], &in_sample_shape[first_spatial_dim]);
if (volume(in_sample_shape) > 0)
total_frames += volume(&in_sample_shape[0], &in_sample_shape[first_spatial_dim]);
}

frame_params.resize(total_frames);
Expand All @@ -72,10 +73,11 @@ void GetFrameShapesAndParams(
int ndim = in_shape.sample_dim();
for (int i = 0, flat_frame_idx = 0; i < N; i++) {
auto in_sample_shape = in_shape.tensor_shape_span(i);
if (volume(in_sample_shape) == 0) {
continue; // skip empty samples
}
// Collapse leading dimensions, if any, as frame dim. This handles channel-first.
int seq_len = volume(&in_sample_shape[0], &in_sample_shape[first_spatial_dim]);
if (seq_len == 0)
continue; // skip empty sequences
TensorShape<out_ndim> frame_shape;
frame_shape.resize(frame_ndim);

Expand Down
60 changes: 44 additions & 16 deletions dali/operators/nvcvop/nvcvop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ nvcv::Tensor AsTensor(void *data, const TensorShape<> &shape, DALIDataType daliD
TensorLayout layout) {
auto dtype = GetDataType(daliDType, 1);
nvcv::TensorDataStridedCuda::Buffer inBuf;
inBuf.basePtr = reinterpret_cast<NVCVByte *>(const_cast<void *>(data));
inBuf.basePtr = static_cast<NVCVByte *>(const_cast<void *>(data));
inBuf.strides[shape.size() - 1] = dtype.strideBytes();
for (int d = shape.size() - 2; d >= 0; --d) {
inBuf.strides[d] = shape[d + 1] * inBuf.strides[d + 1];
Expand All @@ -229,7 +229,7 @@ nvcv::Tensor AsTensor(const void *data, span<const int64_t> shape_data, const nv
const nvcv::TensorLayout &layout) {
int ndim = shape_data.size();
nvcv::TensorDataStridedCuda::Buffer inBuf;
inBuf.basePtr = reinterpret_cast<NVCVByte *>(const_cast<void *>(data));
inBuf.basePtr = static_cast<NVCVByte *>(const_cast<void *>(data));
inBuf.strides[ndim - 1] = dtype.strideBytes();
for (int d = ndim - 2; d >= 0; --d) {
inBuf.strides[d] = shape_data[d + 1] * inBuf.strides[d + 1];
Expand All @@ -239,26 +239,54 @@ nvcv::Tensor AsTensor(const void *data, span<const int64_t> shape_data, const nv
return nvcv::TensorWrapData(inData);
}

int64_t calc_num_frames(const TensorShape<> &shape, int first_spatial_dim) {
return (first_spatial_dim > 0) ?
volume(&shape[0], &shape[first_spatial_dim]) :
1;
}

void PushTensorsToBatch(nvcv::TensorBatch &batch, const TensorList<GPUBackend> &t_list,
int64_t start, int64_t count, const TensorLayout &layout) {
int ndim = t_list.sample_dim();
auto dtype = GetDataType(t_list.type(), 1);
TensorLayout out_layout = layout.empty() ? t_list.GetLayout() : layout;
DALI_ENFORCE(
out_layout.empty() || out_layout.size() == ndim,
make_string("Layout ", out_layout, " does not match the number of dimensions: ", ndim));
auto nvcv_layout = nvcv::TensorLayout(out_layout.c_str());
std::vector<nvcv::Tensor> tensors;
tensors.reserve(count);
void PushFramesToBatch(nvcv::TensorBatch &batch, const TensorList<GPUBackend> &t_list,
int first_spatial_dim, int64_t starting_sample, int64_t frame_offset,
int64_t num_frames, const TensorLayout &layout) {
int ndim = layout.ndim();
auto nvcv_layout = nvcv::TensorLayout(layout.c_str());
auto dtype = GetDataType(t_list.type());

for (int s = 0; s < count; ++s) {
tensors.push_back(AsTensor(t_list.raw_tensor(s + start), t_list.tensor_shape_span(s + start),
dtype, nvcv_layout));
std::vector<nvcv::Tensor> tensors;
tensors.reserve(num_frames);

const auto &input_shape = t_list.shape();
int64_t sample_id = starting_sample - 1;
auto type_size = dtype.strideBytes();
std::vector<int64_t> frame_shape(ndim, 1);

auto frame_stride = 0;
int sample_nframes = 0;
const uint8_t *data = nullptr;

for (int64_t i = 0; i < num_frames; ++i) {
if (frame_offset == sample_nframes) {
frame_offset = 0;
do {
++sample_id;
auto sample_shape = input_shape[sample_id];
DALI_ENFORCE(sample_id < t_list.num_samples());
std::copy(&sample_shape[first_spatial_dim], &sample_shape[input_shape.sample_dim()],
frame_shape.begin());
frame_stride = volume(frame_shape) * type_size;
sample_nframes = calc_num_frames(sample_shape, first_spatial_dim);
} while (sample_nframes * frame_stride == 0); // we skip empty samples
data =
static_cast<const uint8_t *>(t_list.raw_tensor(sample_id)) + frame_stride * frame_offset;
}
tensors.push_back(AsTensor(data, make_span(frame_shape), dtype, nvcv_layout));
data += frame_stride;
frame_offset++;
}
batch.pushBack(tensors.begin(), tensors.end());
}


cvcuda::Workspace NVCVOpWorkspace::Allocate(const cvcuda::WorkspaceRequirements &reqs,
kernels::Scratchpad &scratchpad) {
auto *hostBuffer = scratchpad.AllocateHost<uint8_t>(reqs.hostMem.size, reqs.hostMem.alignment);
Expand Down
23 changes: 19 additions & 4 deletions dali/operators/nvcvop/nvcvop.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,27 @@ void AllocateImagesLike(nvcv::ImageBatchVarShape &output, const TensorList<GPUBa
*/
void PushImagesToBatch(nvcv::ImageBatchVarShape &batch, const TensorList<GPUBackend> &t_list);


/**
* @brief Push samples from a given tensor list to a given TensorBatch.
* [start, start+count) determines the range of samples in the TensorList that will be used.
* @brief Push a range of frames from the input TensorList as samples in the output TensorBatch.
*
* The input TensorList is interpreted as sequence of frames where innermost dimensions
* starting from `first_spatial_dim` are the frames' dimensions.
*
* The range of frames is determined by the `starting_sample`, `frame_offset`
* and `num_frames` arguments.
* `starting_sample` is an index of the first source sample from the input TensorList. All the samples before that are skipped.
* `frame_offset` is an index of a first frame in the starting sample to be taken.
* `num_frames` is the total number of frames that will be pushed to the output TensorBatch.
*
* @param batch output TensorBatch
* @param t_list input TensorList
* @param layout layout of the output TensorBatch
*/
void PushTensorsToBatch(nvcv::TensorBatch &batch, const TensorList<GPUBackend> &t_list,
int64_t start, int64_t count, const TensorLayout &layout);
void PushFramesToBatch(nvcv::TensorBatch &batch, const TensorList<GPUBackend> &t_list,
int first_spatial_dim, int64_t starting_sample, int64_t frame_offset,
int64_t num_frames, const TensorLayout &layout);


class NVCVOpWorkspace {
public:
Expand Down

0 comments on commit 44b9ca2

Please sign in to comment.