Skip to content

Commit

Permalink
Fix stream semantics with default streams.
Browse files Browse the repository at this point in the history
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
  • Loading branch information
mzient committed Oct 25, 2024
1 parent 2d912c3 commit 6571032
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 21 deletions.
53 changes: 38 additions & 15 deletions dali/core/access_order.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -25,6 +25,13 @@ AccessOrder::AccessOrder(cudaStream_t stream) : stream_(stream) {
device_id_ = DeviceFromStream(stream);
}

/**
 * @brief Checks whether `stream` is one of the special CUDA stream handles:
 *        the null stream (0), cudaStreamPerThread or cudaStreamLegacy.
 *
 * Such a handle does not name one particular stream - it is resolved against
 * the current device, so callers have to make the intended device current
 * before passing it to a CUDA call.
 *
 * @param stream the stream handle to examine
 * @return true if the handle is one of the three special values, false otherwise
 */
constexpr bool is_ambiguous_handle(cudaStream_t stream) {
  // The null handle is tested first, exactly as in the chained form.
  if (stream == 0)
    return true;
  return stream == cudaStreamPerThread || stream == cudaStreamLegacy;
}

void AccessOrder::wait(const AccessOrder &other) const {
if (*this == other)
return;
Expand All @@ -33,44 +40,60 @@ void AccessOrder::wait(const AccessOrder &other) const {
// always considered up-to-date.
if (!has_value() || !other.is_device())
return;

auto current_dev = []() {
int dev;
CUDA_CALL(cudaGetDevice(&dev));
return dev;
};

auto need_device_switch = [&]() {
return is_ambiguous_handle(other.stream_) && other.device_id() != current_dev();
};

if (is_device()) {
auto &pool = CUDAEventPool::instance();
int other_dev = other.device_id();
auto event = pool.Get(other_dev);
// Record an event in the preceding stream

auto current_dev = []() {
int dev;
CUDA_CALL(cudaGetDevice(&dev));
return dev;
};

// If the stream handle has a special value, we can't refer to it directly - it is
// inherently associated with the concept of "current device" and it must be switched
if (other_dev != device_id_ ||
((other.stream_ == 0 ||
other.stream_ == cudaStreamPerThread ||
other.stream_ == cudaStreamLegacy) &&
other_dev != current_dev())) {
if (need_device_switch()) {
DeviceGuard dg(other.device_id_);
CUDA_CALL(cudaEventRecord(event, other.stream()));
} else {
CUDA_CALL(cudaEventRecord(event, other.stream()));
}
// and wait for it in this stream
CUDA_CALL(cudaStreamWaitEvent(stream(), event, 0));
if (is_ambiguous_handle(stream())) {
DeviceGuard dg(device_id_);
CUDA_CALL(cudaStreamWaitEvent(stream(), event, 0));
} else {
CUDA_CALL(cudaStreamWaitEvent(stream(), event, 0));
}
pool.Put(std::move(event), other_dev);
} else {
// host order - wait for the preceding stream on host
CUDA_CALL(cudaStreamSynchronize(other.stream()));
if (need_device_switch()) {
DeviceGuard dg(device_id_);
CUDA_CALL(cudaStreamSynchronize(other.stream()));
} else {
CUDA_CALL(cudaStreamSynchronize(other.stream()));
}
}
}

void AccessOrder::wait(cudaEvent_t event) const {
if (!has_value())
throw std::logic_error("A null AccessOrder cannot wait for an event.");
if (is_device()) {
CUDA_DTOR_CALL(cudaStreamWaitEvent(stream(), event, 0));
if (is_ambiguous_handle(stream())) {
DeviceGuard dg(device_id_);
CUDA_DTOR_CALL(cudaStreamWaitEvent(stream(), event, 0));
} else {
CUDA_DTOR_CALL(cudaStreamWaitEvent(stream(), event, 0));
}
} else {
CUDA_DTOR_CALL(cudaEventSynchronize(event));
}
Expand Down
26 changes: 20 additions & 6 deletions dali/pipeline/operator/builtin/input_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ struct InputQueueItem {
if (device_id != device_id_)
Put();
if (!event_) {
event_ = CUDAEventPool::instance().Get(device_id_);
event_ = CUDAEventPool::instance().Get(device_id);
device_id_ = device_id;
}
}
Expand Down Expand Up @@ -413,7 +413,17 @@ class InputOperator : public Operator<Backend>, virtual public BatchSizeProvider
tl_elm->data.Copy(batch, order, use_copy_kernel);
int device_id = order.is_device() ? order.device_id() : tl_elm->data.device_id();
cudaEvent_t event = tl_elm->GetCompletionEvent(order.device_id());
CUDA_CALL(cudaEventRecord(event, order.stream()));

if (order.device_id() != device_id_ &&
(order.stream() == 0 ||
order.stream() == cudaStreamPerThread ||
order.stream() == cudaStreamLegacy)) {
// In case of ambiguous stream handles, we need to switch to the proper device
DeviceGuard dg;
CUDA_CALL(cudaEventRecord(event, order.stream()));
} else {
CUDA_CALL(cudaEventRecord(event, order.stream()));
}

if (zero_copy_noncontiguous_gpu_input_) {
DALI_WARN("ExternalSource operator should not mix contiguous and noncontiguous inputs. "
Expand Down Expand Up @@ -476,10 +486,14 @@ class InputOperator : public Operator<Backend>, virtual public BatchSizeProvider
}
tl_elm->data.Copy(batch, order, use_copy_kernel);
int copy_device = order.is_device() ? order.device_id() : tl_elm->data.device_id();
auto event = tl_elm->GetCompletionEvent(copy_device);
CUDA_CALL(cudaEventRecord(event, order.stream()));
if (sync) {
CUDA_CALL(cudaEventSynchronize(event));

{
DeviceGuard dg(copy_device);
auto event = tl_elm->GetCompletionEvent(copy_device);
CUDA_CALL(cudaEventRecord(event, order.stream()));
if (sync) {
CUDA_CALL(cudaEventSynchronize(event));
}
}

{
Expand Down

0 comments on commit 6571032

Please sign in to comment.