From 9b42d288b4a4a8fb26c86f37862aaadefb6f75f9 Mon Sep 17 00:00:00 2001 From: Michal Zientkiewicz Date: Thu, 24 Oct 2024 09:01:35 +0200 Subject: [PATCH] Expose streams to Python in DALI tensors. Signed-off-by: Michal Zientkiewicz --- dali/python/backend_impl.cc | 12 ++++++++++++ dali/python/nvidia/dali/types.py | 3 +++ 2 files changed, 15 insertions(+) diff --git a/dali/python/backend_impl.cc b/dali/python/backend_impl.cc index f49a416509..75a3373fb7 100644 --- a/dali/python/backend_impl.cc +++ b/dali/python/backend_impl.cc @@ -890,6 +890,12 @@ void ExposeTensor(py::module &m) { non_blocking : bool Asynchronous copy. )code") + .def_property_readonly("stream", [](const Tensor &t)->py::object { + if (t.order().is_device()) + return py::reinterpret_borrow(PyLong_FromVoidPtr(t.order().stream())); + else + return py::none(); + }) .def("data_ptr", [](Tensor &t) { return py::reinterpret_borrow(PyLong_FromVoidPtr(t.raw_mutable_data())); @@ -1532,6 +1538,12 @@ void ExposeTensorList(py::module &m) { .def("__repr__", [](TensorList &t) { return FromPythonTrampoline("nvidia.dali.tensors", "_tensorlist_to_string")(t, false); }) + .def_property_readonly("stream", [](const Tensor &t)->py::object { + if (t.order().is_device()) + return py::reinterpret_borrow(PyLong_FromVoidPtr(t.order().stream())); + else + return py::none(); + }) .def_property_readonly("dtype", [](TensorList &tl) { return tl.type(); }, diff --git a/dali/python/nvidia/dali/types.py b/dali/python/nvidia/dali/types.py index 6d33ba7853..9fabc23fc0 100644 --- a/dali/python/nvidia/dali/types.py +++ b/dali/python/nvidia/dali/types.py @@ -16,6 +16,7 @@ from enum import Enum, unique import ctypes import re +from nvidia.dali import backend_impl from nvidia.dali._backend_enums import ( DALIDataType as DALIDataType, @@ -397,6 +398,8 @@ def _raw_cuda_stream(stream_obj): def _get_default_stream_for_array(array): if isinstance(array, list) and len(array): array = array[0] + if isinstance(array, (backend_impl.TensorListGPU, backend_impl.TensorGPU)): + return array.stream if _is_torch_tensor(array): import torch