Skip to content

Commit

Permalink
Expose streams to Python in DALI tensors.
Browse files Browse the repository at this point in the history
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
  • Loading branch information
mzient committed Oct 29, 2024
1 parent 2838655 commit 71854e7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
12 changes: 12 additions & 0 deletions dali/python/backend_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,12 @@ void ExposeTensor(py::module &m) {
non_blocking : bool
Asynchronous copy.
)code")
.def_property_readonly("stream", [](const Tensor<GPUBackend> &t)->py::object {
if (t.order().is_device())
return py::reinterpret_borrow<py::object>(PyLong_FromVoidPtr(t.order().stream()));
else
return py::none();
})
.def("data_ptr",
[](Tensor<GPUBackend> &t) {
return py::reinterpret_borrow<py::object>(PyLong_FromVoidPtr(t.raw_mutable_data()));
Expand Down Expand Up @@ -1532,6 +1538,12 @@ void ExposeTensorList(py::module &m) {
.def("__repr__", [](TensorList<GPUBackend> &t) {
return FromPythonTrampoline("nvidia.dali.tensors", "_tensorlist_to_string")(t, false);
})
.def_property_readonly("stream", [](const Tensor<GPUBackend> &t)->py::object {
if (t.order().is_device())
return py::reinterpret_borrow<py::object>(PyLong_FromVoidPtr(t.order().stream()));
else
return py::none();
})
.def_property_readonly("dtype", [](TensorList<GPUBackend> &tl) {
return tl.type();
},
Expand Down
3 changes: 3 additions & 0 deletions dali/python/nvidia/dali/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from enum import Enum, unique
import ctypes
import re
from nvidia.dali import backend_impl

from nvidia.dali._backend_enums import (
DALIDataType as DALIDataType,
Expand Down Expand Up @@ -397,6 +398,8 @@ def _raw_cuda_stream(stream_obj):
def _get_default_stream_for_array(array):
if isinstance(array, list) and len(array):
array = array[0]
if isinstance(array, (backend_impl.TensorListGPU, backend_impl.TensorGPU)):
return array.stream
if _is_torch_tensor(array):
import torch

Expand Down

0 comments on commit 71854e7

Please sign in to comment.