Skip to content

Commit

Permalink
Further simplification.
Browse files Browse the repository at this point in the history
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
  • Loading branch information
mzient committed Oct 29, 2024
1 parent acbb9d5 commit afc953e
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions dali/pipeline/operator/builtin/input_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ class InputOperator : public Operator<Backend>, virtual public BatchSizeProvider
}


int device_id_;
int device_id_ = -1;
bool blocking_ = true;
bool no_copy_ = false;
bool running_ = true;
Expand Down Expand Up @@ -400,28 +400,23 @@ class InputOperator : public Operator<Backend>, virtual public BatchSizeProvider
auto tl_elm = GetEmptyOutputBatch(std::move(data_id));
bool copied_shared_data = false;

if (!order.has_value())
order = batch.order().is_device() ? batch.order() : tl_elm->data.order();

// We can share only contiguous tensor lists that are stored on the same device.
if (batch.IsContiguousInMemory() && batch.device_id() == device_id_) {
auto &batch_owner = unsafe_sample_owner(const_cast<TensorList<SrcBackend> &>(batch), 0);
tl_elm->data.ShareData(batch_owner, batch.nbytes(), batch.is_pinned(), batch.shape(),
batch.type(), batch.device_id(), batch.order());
tl_elm->data.ShareData(batch);
zero_copy_noncontiguous_gpu_input_ = true;
} else {
// Do not overwrite the buffer if it shares data.
if (tl_elm->data.shares_data())
tl_elm->data.Reset();
tl_elm->data.Copy(batch, order, use_copy_kernel);
int device_id = order.is_device() ? order.device_id() : tl_elm->data.device_id();
cudaEvent_t event = tl_elm->GetCompletionEvent(order.device_id());

if (order.device_id() != device_id_ &&
(order.stream() == 0 ||
order.stream() == cudaStreamPerThread ||
order.stream() == cudaStreamLegacy)) {
// In case of ambiguous stream handles, we need to switch to the proper device
DeviceGuard dg;
CUDA_CALL(cudaEventRecord(event, order.stream()));
} else {
cudaEvent_t event = tl_elm->GetCompletionEvent(device_id);

{
DeviceGuard dg(order.device_id());
CUDA_CALL(cudaEventRecord(event, order.stream()));
}

Expand Down Expand Up @@ -543,6 +538,7 @@ class InputOperator : public Operator<Backend>, virtual public BatchSizeProvider
*/
queue_item_t GetEmptyOutputBatch(std::optional<std::string> data_id) {
auto result = tl_data_.GetEmpty();
result->data.set_device_id(device_id_);
result->data.set_order(internal_copy_order_);
result->data_id = (std::move(data_id));
return result;
Expand Down

0 comments on commit afc953e

Please sign in to comment.