diff --git a/dali/pipeline/operator/builtin/input_operator.h b/dali/pipeline/operator/builtin/input_operator.h index b5b01e5042..27f21fc729 100644 --- a/dali/pipeline/operator/builtin/input_operator.h +++ b/dali/pipeline/operator/builtin/input_operator.h @@ -512,7 +512,10 @@ class InputOperator : public Operator, virtual public BatchSizeProvider order = tl_elm.front()->order(); } tl_elm.front()->Copy(batch, order, use_copy_kernel); - CUDA_CALL(cudaEventRecord(*copy_to_storage_event.front(), order.stream())); + { + DeviceGuard dg(order.device_id()); + CUDA_CALL(cudaEventRecord(*copy_to_storage_event.front(), order.stream())); + } if (sync) { CUDA_CALL(cudaEventSynchronize(*copy_to_storage_event.front())); } diff --git a/dali/python/backend_impl.cc b/dali/python/backend_impl.cc index cac6bb8c19..f49a416509 100644 --- a/dali/python/backend_impl.cc +++ b/dali/python/backend_impl.cc @@ -2245,6 +2245,14 @@ PYBIND11_MODULE(backend_impl, m) { device_id = sample0.device_id(); } AccessOrder order(stream, device_id); + if (order.is_device()) { + CUcontext ctx = nullptr; + CUDA_CALL(cuStreamGetCtx(order.stream(), &ctx)); + CUDA_CALL(cuCtxPushCurrent(ctx)); + CUdevice device; + CUDA_CALL(cuCtxGetDevice(&device)); + CUDA_CALL(cuCtxPopCurrent(&ctx)); + } FeedPipeline(p, name, list, order, cuda_stream.is_none(), use_copy_kernel); } },