Add dynamic executor support to TF plugin. (#5686)
* Add dynamic executor support to TF plugin.
* Add tests that wouldn't work with legacy executor

---------

Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
mzient authored Oct 28, 2024
1 parent 52d314c commit a093d74
Showing 6 changed files with 224 additions and 48 deletions.
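For context, a rough usage sketch (not part of this commit; sketch_pipeline and its body are made up for illustration): a pipeline built with experimental_exec_dynamic=True is picked up by the plugin automatically, since the wrapper now checks pipeline._exec_dynamic and forwards the flag to the underlying op.

# Sketch only: sketch_pipeline is hypothetical; the propagation of
# pipeline._exec_dynamic to the TF op is what this commit adds.
import tensorflow as tf

import nvidia.dali.fn as fn
import nvidia.dali.plugin.tf as dali_tf
from nvidia.dali import pipeline_def


@pipeline_def(batch_size=8, num_threads=4, device_id=0, experimental_exec_dynamic=True)
def sketch_pipeline():
    # A GPU -> CPU hop like this is what the dynamic executor permits.
    data = fn.random.uniform(range=[0.0, 1.0], shape=[2])
    return data.gpu().cpu()


# From TensorFlow's perspective this is a CPU dataset, because the pipeline output is on the CPU.
with tf.device("/cpu:0"):
    dataset = dali_tf.DALIDataset(
        pipeline=sketch_pipeline(),
        batch_size=8,
        output_shapes=((8, 2),),
        output_dtypes=(tf.float32,),
        num_threads=4,
        device_id=0,
    )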
19 changes: 18 additions & 1 deletion dali/python/nvidia/dali/plugin/tf.py
@@ -1,4 +1,4 @@
# Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -214,6 +214,7 @@ def DALIIteratorWrapper(
    dtypes=[],
    batch_size=-1,
    prefetch_queue_depth=2,
    exec_dynamic=False,
    **kwargs,
):
    """
@@ -232,6 +233,9 @@ def DALIIteratorWrapper(
    cpu_prefetch_queue_depth = -1  # dummy: won't be used
    gpu_prefetch_queue_depth = prefetch_queue_depth

    if pipeline is not None and pipeline._exec_dynamic:
        exec_dynamic = True

    if serialized_pipeline is None:
        serialized_pipeline = serialize_pipeline(pipeline)

@@ -281,6 +285,7 @@ def DALIIteratorWrapper(
        exec_separated=exec_separated,
        gpu_prefetch_queue_depth=gpu_prefetch_queue_depth,
        cpu_prefetch_queue_depth=cpu_prefetch_queue_depth,
        exec_dynamic=exec_dynamic,
        **kwargs,
    )
    new_out = []
@@ -436,6 +441,7 @@ def __init__(
        num_threads=4,
        device_id=0,
        exec_separated=False,
        exec_dynamic=False,
        prefetch_queue_depth=2,
        cpu_prefetch_queue_depth=2,
        gpu_prefetch_queue_depth=2,
@@ -445,6 +451,9 @@ def __init__(
        output_shapes = self._handle_deprecation(output_shapes, shapes, "shapes")
        output_dtypes = self._handle_deprecation(output_dtypes, dtypes, "dtypes")

        if pipeline._exec_dynamic:
            exec_dynamic = True

        if not self._check_dtypes(output_dtypes, tf.DType):
            raise TypeError(
                "`output_dtypes` should be provided as single tf.DType value "
@@ -475,6 +484,7 @@ def __init__(
            device_id = types.CPU_ONLY_DEVICE_ID
        self._device_id = device_id
        self._exec_separated = exec_separated
        self._exec_dynamic = exec_dynamic
        self._prefetch_queue_depth = prefetch_queue_depth
        self._cpu_prefetch_queue_depth = cpu_prefetch_queue_depth
        self._gpu_prefetch_queue_depth = gpu_prefetch_queue_depth
@@ -805,6 +815,7 @@ def _as_variant_tensor(self):
            num_threads=self._num_threads,
            device_id=self._device_id,
            exec_separated=self._exec_separated,
            exec_dynamic=self._exec_dynamic,
            prefetch_queue_depth=self._prefetch_queue_depth,
            cpu_prefetch_queue_depth=self._cpu_prefetch_queue_depth,
            gpu_prefetch_queue_depth=self._gpu_prefetch_queue_depth,
@@ -865,6 +876,7 @@ def __init__(
        num_threads=4,
        device_id=0,
        exec_separated=False,
        exec_dynamic=False,
        prefetch_queue_depth=2,
        cpu_prefetch_queue_depth=2,
        gpu_prefetch_queue_depth=2,
@@ -984,6 +996,11 @@ def __init__(self, *args, **kwargs):
        Whether to execute the pipeline in a way that enables
        overlapping CPU and GPU computation, typically resulting
        in faster execution speed, but larger memory consumption.
        This flag is incompatible with ``exec_dynamic``.
    exec_dynamic : bool, optional, default = False
        Whether to execute the pipeline with the dynamic executor, which allows flexible mixing
        of CPU and GPU operators and enables aggressive memory reuse.
        This flag is incompatible with ``exec_separated``.
    prefetch_queue_depth : int, optional, default = 2
        depth of the executor queue. Deeper queue makes DALI more
        resistant to uneven execution time of each batch, but it also
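As the docstring above states, the dynamic executor can also be requested explicitly on the dataset instead of on the pipeline. A minimal sketch, not part of this commit, reusing the hypothetical sketch_pipeline from the earlier sketch:

# Sketch only: exec_dynamic requested explicitly on the dataset; sketch_pipeline is
# the hypothetical pipeline defined in the sketch above.
with tf.device("/cpu:0"):
    dataset = dali_tf.DALIDataset(
        pipeline=sketch_pipeline(),
        batch_size=8,
        output_shapes=((8, 2),),
        output_dtypes=(tf.float32,),
        num_threads=4,
        device_id=0,
        exec_dynamic=True,  # use the dynamic executor
        # do not combine with exec_separated; the docstring marks the two flags as incompatible
    )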
133 changes: 133 additions & 0 deletions dali/test/python/test_dali_tf_exec2.py
@@ -0,0 +1,133 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
import numpy as np
import os.path
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import nvidia.dali.plugin.tf as dali_tf
from nose_utils import with_setup
from test_utils_tensorflow import skip_inputs_for_incompatible_tf
from test_utils import get_dali_extra_path


test_data_root = get_dali_extra_path()
lmdb_folder = os.path.join(test_data_root, "db", "lmdb")


@pipeline_def(
    enable_conditionals=True,
    batch_size=5,
    num_threads=4,
    device_id=0,
    experimental_exec_dynamic=True,
)
def dali_exec2_pipeline():
    iter_id = fn.external_source(source=lambda x: np.array(x.iteration), batch=False)
    if iter_id & 1 == 0:
        output = types.Constant(np.array(-1), device="gpu")
    else:
        output = types.Constant(np.array(1), device="gpu")
    return output.cpu()


@with_setup(skip_inputs_for_incompatible_tf)
def test_tf_dataset_exec2():
    """Test that exec_dynamic is propagated to the DALI pipeline from dali_tf.DALIDatasetWithInputs"""
    # From TensorFlow's perspective, this is a CPU pipeline
    with tf.device("/cpu:0"):
        dali_dataset = dali_tf.experimental.DALIDatasetWithInputs(
            pipeline=dali_exec2_pipeline(),
            batch_size=5,
            output_shapes=(5,),
            output_dtypes=(tf.int32),
            num_threads=4,
            device_id=0,
        )

    @tf.function
    def tf_function_with_conditionals(dali_dataset):
        negative = tf.constant(0)
        positive = tf.constant(0)
        for input in dali_dataset:
            if tf.reduce_sum(input) < 0:
                negative = negative + 1
            else:
                positive = positive + 1
        return negative, positive

    neg, pos = tf_function_with_conditionals(dali_dataset.take(5))
    assert neg == 3
    assert pos == 2


@pipeline_def(num_threads=4, experimental_exec_dynamic=True)
def daliop_pipe():
    jpegs, labels = fn.readers.caffe(path=lmdb_folder, random_shuffle=False)
    imgs = fn.decoders.image(jpegs, device="mixed")
    imgs = fn.resize(imgs, size=(100, 100))
    shape = imgs.shape(dtype=types.UINT32)
    return imgs.cpu(), shape


def get_batch_dali(batch_size):
    pipe = daliop_pipe(batch_size=batch_size, num_threads=4, device_id=0)
    pipe.build()

    daliop = dali_tf.DALIIterator()
    images = []
    labels = []
    with tf.device("/cpu:0"):
        image, label = daliop(
            pipeline=pipe,
            shapes=[
                (batch_size, 100, 100, 3),
                (
                    batch_size,
                    3,
                ),
            ],
            dtypes=[tf.uint8, tf.int32],
            device_id=0,
        )
        images.append(image)
        labels.append(label)

    return [images, labels]


def test_tf_op():
"""Test that exec_dynamic is propagated to DALI pipeline from dali_tf.DALIIterator"""
try:
tf.compat.v1.disable_eager_execution()
except ModuleNotFoundError:
pass

batch_size = 8
iterations = 2
test_batch = get_batch_dali(batch_size)
try:
from tensorflow.compat.v1 import Session
except ImportError:
# Older TF versions don't have compat.v1 layer
from tensorflow import Session

with Session() as sess:
for i in range(iterations):
imgs, shapes = sess.run(test_batch)
for img, shape in zip(imgs, shapes):
for i in range(batch_size):
assert tuple(img[i].shape) == tuple(shape[i])
4 changes: 3 additions & 1 deletion dali_tf_plugin/dali_dataset.h
@@ -1,4 +1,4 @@
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -70,6 +70,7 @@ class DALIDatasetOp : public tensorflow::data::DatasetOpKernel {
    int num_threads;
    int device_id;
    bool exec_separated;
    bool exec_dynamic;
    int prefetch_queue_depth;
    int cpu_prefetch_queue_depth;
    int gpu_prefetch_queue_depth;
@@ -99,6 +100,7 @@ class DALIDatasetOp : public tensorflow::data::DatasetOpKernel {
  static constexpr const char* const kNumThreads = "num_threads";
  static constexpr const char* const kDeviceId = "device_id";
  static constexpr const char* const kExecSeparated = "exec_separated";
  static constexpr const char* const kExecDynamic = "exec_dynamic";
  static constexpr const char* const kPrefetchQueueDepth = "prefetch_queue_depth";
  static constexpr const char* const kCpuPrefetchQueueDepth = "cpu_prefetch_queue_depth";
  static constexpr const char* const kGpuPrefetchQueueDepth = "gpu_prefetch_queue_depth";
50 changes: 30 additions & 20 deletions dali_tf_plugin/dali_dataset_op.cc
@@ -220,6 +220,7 @@ class DALIDatasetOp::Dataset : public DatasetBase {
    SerializeField(attrs, b, kNumThreads, pipeline_def_.num_threads);
    SerializeField(attrs, b, kDeviceId, pipeline_def_.device_id);
    SerializeField(attrs, b, kExecSeparated, pipeline_def_.exec_separated);
    SerializeField(attrs, b, kExecDynamic, pipeline_def_.exec_dynamic);
    SerializeField(attrs, b, kPrefetchQueueDepth, pipeline_def_.prefetch_queue_depth);
    SerializeField(attrs, b, kCpuPrefetchQueueDepth, pipeline_def_.cpu_prefetch_queue_depth);
    SerializeField(attrs, b, kGpuPrefetchQueueDepth, pipeline_def_.gpu_prefetch_queue_depth);
@@ -248,10 +249,15 @@ class DALIDatasetOp::Dataset : public DatasetBase {
  }

  Status InitPipeline(daliPipelineHandle *pipeline_handle) const {
    TF_DALI_CALL(daliCreatePipeline(
    dali_exec_flags_t flags = DALI_EXEC_ASYNC_PIPELINED;
    if (pipeline_def_.exec_dynamic)
      flags = flags | DALI_EXEC_IS_DYNAMIC;
    if (pipeline_def_.exec_separated)
      flags = flags | DALI_EXEC_IS_SEPARATED;
    TF_DALI_CALL(daliCreatePipeline3(
        pipeline_handle, pipeline_def_.pipeline.c_str(), pipeline_def_.pipeline.length(),
        pipeline_def_.batch_size, pipeline_def_.num_threads, pipeline_def_.device_id,
        pipeline_def_.exec_separated, pipeline_def_.prefetch_queue_depth,
        flags, pipeline_def_.prefetch_queue_depth,
        pipeline_def_.cpu_prefetch_queue_depth, pipeline_def_.gpu_prefetch_queue_depth,
        pipeline_def_.enable_memory_stats));
    return Status();
@@ -380,26 +386,28 @@ class DALIDatasetOp::Dataset::Iterator : public DatasetIterator<Dataset> {
  }

  ~Iterator() {
    if (enable_memory_stats_) {
      size_t N;
      daliExecutorMetadata *meta;
      daliGetExecutorMetadata(&pipeline_handle_, &meta, &N);
      std::cout << "DALI operator memory statistics: " << std::endl;
      for (size_t i = 0; i < N; ++i) {
        std::cout << "Operator " << meta[i].operator_name;
        for (size_t j = 0; j < meta[i].out_num; ++j) {
          std::cout << " output [ " << j << " ] : " << meta[i].real_size[j] << "B allocated "
                    << meta[i].max_real_size[j] << "B max allocated " << meta[i].reserved[j]
                    << "B reserved" << meta[i].max_reserved[j] << "B max reserved";
          if (j != meta[i].out_num - 1) {
            std::cout << ",";
    if (pipeline_handle_) {
      if (enable_memory_stats_) {
        size_t N;
        daliExecutorMetadata *meta;
        daliGetExecutorMetadata(&pipeline_handle_, &meta, &N);
        std::cout << "DALI operator memory statistics: " << std::endl;
        for (size_t i = 0; i < N; ++i) {
          std::cout << "Operator " << meta[i].operator_name;
          for (size_t j = 0; j < meta[i].out_num; ++j) {
            std::cout << " output [ " << j << " ] : " << meta[i].real_size[j] << "B allocated "
                      << meta[i].max_real_size[j] << "B max allocated " << meta[i].reserved[j]
                      << "B reserved" << meta[i].max_reserved[j] << "B max reserved";
            if (j != meta[i].out_num - 1) {
              std::cout << ",";
            }
          }
          std::cout << std::endl;
        }
        std::cout << std::endl;
        daliFreeExecutorMetadata(meta, N);
      }
      daliFreeExecutorMetadata(meta, N);
      daliDeletePipeline(&pipeline_handle_);
    }
    daliDeletePipeline(&pipeline_handle_);
  }

#if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION >= 3)
@@ -941,8 +949,8 @@ class DALIDatasetOp::Dataset::Iterator : public DatasetIterator<Dataset> {
  std::vector<dali_backend_t> input_ext_src_devices_;
  std::queue<ListOfBatches> alive_batches_;
  InputState iterator_state_ = InputState::in_progress;
  daliPipelineHandle pipeline_handle_;
  bool enable_memory_stats_;
  daliPipelineHandle pipeline_handle_ = nullptr;
  bool enable_memory_stats_ = false;
};

void DALIDatasetOp::MakeDataset(OpKernelContext *context, DatasetBase **output) {
@@ -959,6 +967,7 @@ void DALIDatasetOp::FillPipelineDef(OpKernelConstruction *context, PipelineDef &
  OP_REQUIRES_OK(context, context->GetAttr(kNumThreads, &def.num_threads));
  OP_REQUIRES_OK(context, context->GetAttr(kDeviceId, &def.device_id));
  OP_REQUIRES_OK(context, context->GetAttr(kExecSeparated, &def.exec_separated));
  OP_REQUIRES_OK(context, context->GetAttr(kExecDynamic, &def.exec_dynamic));
  OP_REQUIRES_OK(context, context->GetAttr(kPrefetchQueueDepth, &def.prefetch_queue_depth));
  OP_REQUIRES_OK(context, context->GetAttr(kCpuPrefetchQueueDepth, &def.cpu_prefetch_queue_depth));
  OP_REQUIRES_OK(context, context->GetAttr(kGpuPrefetchQueueDepth, &def.gpu_prefetch_queue_depth));
@@ -1079,6 +1088,7 @@ REGISTER_OP("DALIDataset")
.Attr("num_threads: int")
.Attr("device_id: int")
.Attr("exec_separated: bool")
.Attr("exec_dynamic: bool")
.Attr("prefetch_queue_depth: int")
.Attr("cpu_prefetch_queue_depth: int")
.Attr("gpu_prefetch_queue_depth: int")