From afc953e40f38382bdb266a59b3ff7c5728014f40 Mon Sep 17 00:00:00 2001 From: Michal Zientkiewicz Date: Mon, 28 Oct 2024 17:06:19 +0100 Subject: [PATCH] Further simplification. Signed-off-by: Michal Zientkiewicz --- .../operator/builtin/input_operator.h | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/dali/pipeline/operator/builtin/input_operator.h b/dali/pipeline/operator/builtin/input_operator.h index 2940bcdacd..c5754460d8 100644 --- a/dali/pipeline/operator/builtin/input_operator.h +++ b/dali/pipeline/operator/builtin/input_operator.h @@ -331,7 +331,7 @@ class InputOperator : public Operator, virtual public BatchSizeProvider } - int device_id_; + int device_id_ = -1; bool blocking_ = true; bool no_copy_ = false; bool running_ = true; @@ -400,11 +400,12 @@ class InputOperator : public Operator, virtual public BatchSizeProvider auto tl_elm = GetEmptyOutputBatch(std::move(data_id)); bool copied_shared_data = false; + if (!order.has_value()) + order = batch.order().is_device() ? batch.order() : tl_elm->data.order(); + // We can share only contiguous tensor lists that are stored on the same device. if (batch.IsContiguousInMemory() && batch.device_id() == device_id_) { - auto &batch_owner = unsafe_sample_owner(const_cast &>(batch), 0); - tl_elm->data.ShareData(batch_owner, batch.nbytes(), batch.is_pinned(), batch.shape(), - batch.type(), batch.device_id(), batch.order()); + tl_elm->data.ShareData(batch); zero_copy_noncontiguous_gpu_input_ = true; } else { // Do not overwrite the buffer it if shares data. @@ -412,16 +413,10 @@ class InputOperator : public Operator, virtual public BatchSizeProvider tl_elm->data.Reset(); tl_elm->data.Copy(batch, order, use_copy_kernel); int device_id = order.is_device() ? order.device_id() : tl_elm->data.device_id(); - cudaEvent_t event = tl_elm->GetCompletionEvent(order.device_id()); - - if (order.device_id() != device_id_ && - (order.stream() == 0 || - order.stream() == cudaStreamPerThread || - order.stream() == cudaStreamLegacy)) { - // In case of ambiguous stream handles, we need to swithch to the proper device - DeviceGuard dg; - CUDA_CALL(cudaEventRecord(event, order.stream())); - } else { + cudaEvent_t event = tl_elm->GetCompletionEvent(device_id); + + { + DeviceGuard dg(order.device_id()); CUDA_CALL(cudaEventRecord(event, order.stream())); } @@ -543,6 +538,7 @@ class InputOperator : public Operator, virtual public BatchSizeProvider */ queue_item_t GetEmptyOutputBatch(std::optional data_id) { auto result = tl_data_.GetEmpty(); + result->data.set_device_id(device_id_); result->data.set_order(internal_copy_order_); result->data_id = (std::move(data_id)); return result;