Stream aware outputs (#5684)
* Add output order handling to exec2
* Add CUDA stream to Outputs and ShareOutputs in Python bindings for
  Pipeline.
* Refactor stream pointer handling in Python
* Return outputs in stream order

TODO (follow-up): Expose Tensor(List) `set_order` in Python.

---------

Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
mzient authored Oct 30, 2024
1 parent 268704e commit 000ba4d
Showing 20 changed files with 652 additions and 101 deletions.
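For orientation, here is a minimal usage sketch of what stream-aware output retrieval could look like from Python. The existence of a stream parameter on the Outputs/ShareOutputs bindings comes from the commit message above; the `cuda_stream` keyword name and its exposure on `Pipeline.outputs()` are assumptions, not code from this change.

```python
# Hypothetical sketch: fetch pipeline outputs in the order of a user-owned CUDA
# stream instead of forcing a host synchronization. The `cuda_stream` keyword
# name is an assumption based on the commit message.
import cupy as cp
from nvidia.dali import fn, pipeline_def


@pipeline_def(batch_size=8, num_threads=2, device_id=0)
def rng_pipe():
    return fn.random.uniform(range=[0.0, 1.0], shape=[16], device="gpu")


pipe = rng_pipe()
pipe.build()
pipe.schedule_run()

stream = cp.cuda.Stream(non_blocking=True)
# Outputs come back ordered with respect to `stream`, so follow-up work can be
# launched on it without waiting on the host for the pipeline's event.
(out,) = pipe.outputs(cuda_stream=stream.ptr)  # assumed keyword
with stream:
    arr = cp.asarray(out.as_tensor())          # consume on the same stream
```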
3 changes: 3 additions & 0 deletions dali/operators/python_function/dltensor_function.cc
@@ -261,6 +261,9 @@ PYBIND11_MODULE(python_function_plugin, m) {
std::optional<cudaStream_t> cuda_stream{};
if (stream.has_value()) {
cuda_stream = reinterpret_cast<cudaStream_t>(*stream);
} else {
if (std::get<0>(self.dlpack_device()) == kDLCUDA)
cuda_stream = cudaStream_t(cudaStreamDefault);
}
DLManagedTensor *data_ptr = self.dlpack(cuda_stream);
return py::capsule(data_ptr, DLTENSOR_NAME, &DLTensorCapsuleDestructor);
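The hunk above tightens the producer side of the standard `__dlpack__` protocol: when the consumer supplies no stream and the data lives on a CUDA device, the capsule is now exported as if the legacy default stream had been requested. A generic sketch of that handshake follows; CuPy and PyTorch stand in for producer and consumer, and nothing in it is DALI-specific API.

```python
# Generic DLPack stream handshake (illustration only; CuPy/PyTorch are just
# familiar stand-ins for a producer and a consumer).
import cupy as cp
import torch

producer = cp.arange(10, dtype=cp.float32)   # CUDA data owned by the producer

# The consumer announces the stream it will read on; the producer must make the
# data safe to access on that stream before handing over the capsule.
consumer_stream = torch.cuda.Stream()
capsule = producer.__dlpack__(stream=consumer_stream.cuda_stream)
with torch.cuda.stream(consumer_stream):
    t = torch.utils.dlpack.from_dlpack(capsule)

# If the consumer passes no stream at all for CUDA data, the change above makes
# DALI treat it as a request for the legacy default stream rather than exporting
# the capsule without any synchronization.
```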
106 changes: 106 additions & 0 deletions dali/pipeline/data/dltensor.cc
@@ -12,13 +12,119 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cuda_runtime_api.h>
#include <condition_variable>
#include <list>
#include <mutex>
#include <string>
#include <thread>
#include "dali/pipeline/data/dltensor.h"
#include "dali/core/error_handling.h"
#include "dali/core/mm/detail/aux_alloc.h"
#include "dali/core/static_switch.h"

namespace dali {

/** Temporary storage for destroyed DLPack buffers
*
* DLPack doesn't define any stream semantics at the end of an exchange. Some libraries
* (most, actually) call the deleter while the tensor is still in use on the device.
*
* Here we store the shared pointers that were managed by DLTensors issued by DALI.
* A thread collects those shared pointers, calls `cudaDeviceSynchronize`, and only then
* decrements the reference counter, effectively guaranteeing that whatever work was scheduled
* before the deleter was called is complete. While the thread waits for the GPU, new
* buffers may accumulate. Under high load, there can be far fewer calls to
* `cudaDeviceSynchronize` than DLPack deleters.
*/
class DLTensorGraveyard {
public:
~DLTensorGraveyard() {
shutdown();
}

/** Places a device memory pointer in the deletion queue.
*
* The pointer `mem` is kept referenced at least until all GPU work scheduled prior to
* the call to `enqueue` is complete.
*/
void enqueue(std::shared_ptr<void> mem) {
{
std::lock_guard g(mtx_);
if (exit_requested_)
return; // we don't prolong the life of memory on shutdown
if (!started_)
start();
pending_.push_back(std::move(mem));
}
cv_.notify_one();
}

static DLTensorGraveyard &instance(int dev) {
static std::vector<DLTensorGraveyard> inst = []() {
int ndev = 0;
CUDA_CALL(cudaGetDeviceCount(&ndev));
std::vector<DLTensorGraveyard> ret(ndev);
for (int i = 0; i < ndev; i++)
ret[i].device_id_ = i;
return ret;
}();
return inst[dev];
}

private:
void start() {
assert(!exit_requested_);
assert(!started_);
worker_ = std::thread([this]() { run(); });
started_ = true;
}

void shutdown() {
{
std::lock_guard g(mtx_);
exit_requested_ = true;
}
cv_.notify_one();
if (worker_.joinable())
worker_.join();
}

void run() {
CUDA_CALL(cudaSetDevice(device_id_));
std::unique_lock lock(mtx_);
for (;;) {
cv_.wait(lock, [&]() {
return !pending_.empty() || exit_requested_;
});
if (exit_requested_)
break;
list_t tmp = std::move(pending_); // get some pointers
lock.unlock(); // and let new ones accumulate while we wait for the GPU
auto ret = cudaDeviceSynchronize();
if (ret == cudaErrorCudartUnloading) // the process is shutting down - exit
break;
CUDA_CALL(ret);
tmp.clear(); // this actually clears the references, still outside the lock
lock.lock(); // OK, regain the lock and start over
}
}

std::mutex mtx_;
std::condition_variable cv_;
std::thread worker_;
using list_alloc_t = mm::detail::object_pool_allocator<std::shared_ptr<void>>;
using list_t = std::list<std::shared_ptr<void>, list_alloc_t>;
list_t pending_;
int device_id_ = -1;
bool started_ = false;
bool exit_requested_ = false;
};

void EnqueueForDeletion(std::shared_ptr<void> data, int device_id) {
DLTensorGraveyard::instance(device_id).enqueue(std::move(data));
}

DLDataType ToDLType(DALIDataType type) {
DLDataType dl_type{};
TYPE_SWITCH(type, type2id, T, (DALI_NUMERIC_TYPES_FP16, bool), (
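For readers less at home in C++, here is a simplified Python rendition of the same deferred-release pattern. It is an illustration only; the `Graveyard` name and the injected `synchronize` callable are placeholders, not DALI API.

```python
# Simplified rendition of the DLTensorGraveyard idea: deleters enqueue buffers,
# a single worker drains the queue and issues one device-wide synchronization
# per drained batch before dropping the references.
import threading
from collections import deque


class Graveyard:
    def __init__(self, synchronize):
        self._synchronize = synchronize   # e.g. cupy.cuda.Device(0).synchronize
        self._pending = deque()
        self._cv = threading.Condition()
        self._stop = False
        self._worker = threading.Thread(target=self._run, daemon=True)
        self._worker.start()

    def enqueue(self, buf):
        """Keep `buf` alive at least until all GPU work scheduled so far is done."""
        with self._cv:
            if self._stop:
                return
            self._pending.append(buf)
            self._cv.notify()

    def shutdown(self):
        with self._cv:
            self._stop = True
            self._cv.notify()
        self._worker.join()

    def _run(self):
        while True:
            with self._cv:
                self._cv.wait_for(lambda: self._pending or self._stop)
                if self._stop:
                    return
                batch, self._pending = self._pending, deque()
            # One synchronization covers every buffer drained above; buffers that
            # arrive in the meantime are handled by the next iteration.
            self._synchronize()
            batch.clear()   # drop the last references, still outside the lock
```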
17 changes: 17 additions & 0 deletions dali/pipeline/data/dltensor.h
@@ -15,6 +15,7 @@
#ifndef DALI_PIPELINE_DATA_DLTENSOR_H_
#define DALI_PIPELINE_DATA_DLTENSOR_H_

#include <cuda_runtime_api.h>
#include <cassert>
#include <memory>
#include <optional>
@@ -95,6 +96,8 @@ struct TensorViewPayload {
TensorShape<> shape, strides;
};

DLL_PUBLIC void EnqueueForDeletion(std::shared_ptr<void> data, int device_id);

/** Default ownership-sharing payload for DLPack tensors. */
struct SharedTensorPayload : TensorViewPayload {
std::shared_ptr<void> data;
@@ -144,6 +147,8 @@ struct DLTensorResource {
: dlm_tensor{{}, this, dlm_deleter}
, payload{std::forward<PayloadArgs>(args)...} {}

~DLTensorResource() {}


DLManagedTensor dlm_tensor{};
Payload payload;
@@ -162,6 +167,18 @@ struct DLTensorResource {
}
};

template <>
inline DLTensorResource<SharedTensorPayload>::~DLTensorResource() {
if (dlm_tensor.dl_tensor.device.device_type == kDLCUDAHost ||
dlm_tensor.dl_tensor.device.device_type == kDLCUDAManaged) {
int current_dev = 0;
CUDA_DTOR_CALL(cudaGetDevice(&current_dev));
EnqueueForDeletion(std::move(payload.data), current_dev);
} else if (dlm_tensor.dl_tensor.device.device_type == kDLCUDA) {
EnqueueForDeletion(std::move(payload.data), dlm_tensor.dl_tensor.device.device_id);
}
}

/** Type-erases the DLTensorResource and returns a smart pointer to the contained DLManagedTensor.
*/
template <typename Payload>
26 changes: 20 additions & 6 deletions dali/pipeline/data/dltensor_obj.h
@@ -23,7 +23,7 @@
#include "third_party/dlpack/include/dlpack/dlpack.h"

#include "dali/core/common.h"
#include "dali/core/cuda_event_pool.h"
#include "dali/core/cuda_shared_event.h"
#include "dali/core/error_handling.h"
#include "dali/pipeline/data/dltensor.h"

@@ -44,17 +44,31 @@ class DLL_PUBLIC DLTensorObj {
DALI_ENFORCE(dlm_ptr_, "Expected non-null pointer for managed DLTensor");
device_id_ = dlm_ptr_->dl_tensor.device.device_id;
device_type_ = dlm_ptr_->dl_tensor.device.device_type;
DALI_ENFORCE(device_type_ == kDLCPU || device_type_ == kDLCUDA,
DALI_ENFORCE(device_type_ == kDLCPU || device_type_ == kDLCUDAHost || device_type_ == kDLCUDA,
"Currently only DLCPU and DLGPU device types are supported");
if (producer_stream) {
DALI_ENFORCE(device_type_ == kDLCUDA,
"Stream-aware DLTensorObj supports only DLGPU device type.");
DALI_ENFORCE(device_type_ == kDLCUDA || device_type_ == kDLCUDAHost,
"Stream-aware DLTensorObj supports only CUDA and CUDA host device type.");
auto &pool = CUDAEventPool::instance();
data_ready_ = pool.Get(device_id_);
data_ready_ = CUDASharedEvent::GetFromPool(device_id_);
CUDA_CALL(cudaEventRecord(data_ready_, *producer_stream));
}
}

DLL_PUBLIC DLTensorObj(DLMTensorPtr ptr, CUDASharedEvent event)
: dlm_ptr_{std::move(ptr)} {
DALI_ENFORCE(dlm_ptr_, "Expected non-null pointer for managed DLTensor");
device_id_ = dlm_ptr_->dl_tensor.device.device_id;
device_type_ = dlm_ptr_->dl_tensor.device.device_type;
DALI_ENFORCE(device_type_ == kDLCPU || device_type_ == kDLCUDAHost || device_type_ == kDLCUDA,
"Currently only DLCPU and DLGPU device types are supported");
if (event) {
DALI_ENFORCE(device_type_ == kDLCUDA || device_type_ == kDLCUDAHost,
"Stream-aware DLTensorObj supports only CUDA and CUDA host device type.");
}
data_ready_ = std::move(event);
}

DLTensorObj(DLTensorObj &) = delete;
DLL_PUBLIC DLTensorObj(DLTensorObj &&) = default;

@@ -74,7 +88,7 @@ class DLL_PUBLIC DLTensorObj {
DLMTensorPtr dlm_ptr_;
int device_id_;
DLDeviceType device_type_;
CUDAEvent data_ready_;
CUDASharedEvent data_ready_;
};


40 changes: 35 additions & 5 deletions dali/pipeline/executor/executor2/exec2.cc
@@ -136,7 +136,7 @@ class Executor2::Impl {
}
}

Workspace PopOutputs() {
Workspace PopOutputs(AccessOrder output_order, bool set_output_order) {
if (pending_outputs_.empty())
throw std::out_of_range("All pending outputs were already popped.");
DeviceGuard dg(config_.device.value_or(CPU_ONLY_DEVICE_ID));
@@ -145,9 +145,39 @@
auto &pipe_out = fut.Value<const PipelineOutput &>();
auto ws = pipe_out.workspace;
last_iter_data_ = ws.GetIterationData();
if (ws.has_event())
CUDA_CALL(cudaEventSynchronize(ws.event()));
ws.set_event(nullptr);
if (ws.has_event()) {
if (output_order.has_value() && output_order != ws.output_order())
output_order.wait(ws.event());
ws.set_event(nullptr);
}

if (output_order.has_value() && output_order != ws.output_order()) {
// Set the order of the outputs to the requested output_order - no synchronization
// is necessary, the stream has been properly synchronized a few lines above.
for (int i = 0; i < ws.NumOutput(); i++) {
if (ws.OutputIsType<GPUBackend>(i)) {
auto &output = ws.Output<GPUBackend>(i);
assert(output.ready_event() == pipe_out.event.get());
if (set_output_order)
output.set_order(output_order, false);
if (output_order.is_host())
output.set_ready_event({});
} else {
assert(ws.OutputIsType<CPUBackend>(i));
auto &out = ws.Output<CPUBackend>(i);
if (out.is_pinned() && out.order().is_device())
assert(out.ready_event() == pipe_out.event.get());

if (set_output_order && output_order.has_value() &&
out.order().is_device() && out.is_pinned())
out.set_order(output_order, false);
if (output_order.is_host())
out.set_ready_event({});
}
}

ws.set_output_order(output_order);
}
return ws;
}

Expand Down Expand Up @@ -414,7 +444,7 @@ void Executor2::Prefetch() {


void Executor2::Outputs(Workspace *ws) {
*ws = impl_->PopOutputs();
*ws = impl_->PopOutputs(ws->output_order(), false);
}

void Executor2::ShareOutputs(Workspace *ws) {
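A sketch of the lower-level share/release cycle with the new stream argument; as above, the `cuda_stream` keyword name is an assumption based on the commit message, and only the existence of a stream parameter on ShareOutputs comes from this change.

```python
# Hypothetical share/release cycle with stream-ordered outputs (keyword name
# `cuda_stream` is assumed).
import cupy as cp
from nvidia.dali import fn, pipeline_def


@pipeline_def(batch_size=8, num_threads=2, device_id=0)
def rng_pipe():
    return fn.random.uniform(range=[0.0, 1.0], shape=[16], device="gpu")


pipe = rng_pipe()
pipe.build()
stream = cp.cuda.Stream(non_blocking=True)

pipe.schedule_run()
outputs = pipe.share_outputs(cuda_stream=stream.ptr)   # assumed keyword
with stream:
    batch = cp.asarray(outputs[0].as_tensor())          # copy out in stream order
stream.synchronize()        # ensure the copy finished before returning buffers
pipe.release_outputs()      # DALI may now reuse the shared output buffers
```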
12 changes: 12 additions & 0 deletions dali/pipeline/executor/executor2/exec_node_task.cc
@@ -265,6 +265,18 @@ void OpTask::RunOp() {
}
if (ws_->has_stream()) {
assert(ws_->has_event());
assert(event_ == ws_->event());
for (int o = 0; o < ws_->NumOutput(); o++) {
if (ws_->OutputIsType<GPUBackend>(o)) {
auto &out = ws_->Output<GPUBackend>(o);
out.set_ready_event(event_);
} else if (ws_->OutputIsType<CPUBackend>(o)) {
auto &out = ws_->Output<CPUBackend>(o);
if (out.is_pinned() && out.order().is_device()) {
out.set_ready_event(event_);
}
}
}
CUDA_CALL(cudaEventRecord(ws_->event(), ws_->stream()));
}
}
Expand Down
3 changes: 2 additions & 1 deletion dali/python/__init__.py.in
@@ -1,4 +1,4 @@
# Copyright (c) 2017-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@ __version__ = '@DALI_VERSION@'
__cuda_version__ = int('@CUDA_VERSION@'.replace('.', ''))
__git_sha__ = '@GIT_SHA@'

from . import backend
from . import ops
from . import pipeline
from . import tensors
(The remaining changed files in this commit are not shown.)
