From a093d74146c1f1bacef47e58501348acede2ff46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Zientkiewicz?= Date: Mon, 28 Oct 2024 11:04:05 +0100 Subject: [PATCH] Add dynamic executor support to TF plugin. (#5686) * Add dynamic executor support to TF plugin. * Add tests that wouldn't work with legacy executor --------- Signed-off-by: Michal Zientkiewicz --- dali/python/nvidia/dali/plugin/tf.py | 19 +++- dali/test/python/test_dali_tf_exec2.py | 133 +++++++++++++++++++++++++ dali_tf_plugin/dali_dataset.h | 4 +- dali_tf_plugin/dali_dataset_op.cc | 50 ++++++---- dali_tf_plugin/daliop.cc | 63 +++++++----- qa/TL0_tensorflow_plugin/test.sh | 3 + 6 files changed, 224 insertions(+), 48 deletions(-) create mode 100644 dali/test/python/test_dali_tf_exec2.py diff --git a/dali/python/nvidia/dali/plugin/tf.py b/dali/python/nvidia/dali/plugin/tf.py index 7ae887e79d3..28260bf3299 100644 --- a/dali/python/nvidia/dali/plugin/tf.py +++ b/dali/python/nvidia/dali/plugin/tf.py @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -214,6 +214,7 @@ def DALIIteratorWrapper( dtypes=[], batch_size=-1, prefetch_queue_depth=2, + exec_dynamic=False, **kwargs, ): """ @@ -232,6 +233,9 @@ def DALIIteratorWrapper( cpu_prefetch_queue_depth = -1 # dummy: wont' be used gpu_prefetch_queue_depth = prefetch_queue_depth + if pipeline is not None and pipeline._exec_dynamic: + exec_dynamic = True + if serialized_pipeline is None: serialized_pipeline = serialize_pipeline(pipeline) @@ -281,6 +285,7 @@ def DALIIteratorWrapper( exec_separated=exec_separated, gpu_prefetch_queue_depth=gpu_prefetch_queue_depth, cpu_prefetch_queue_depth=cpu_prefetch_queue_depth, + exec_dynamic=exec_dynamic, **kwargs, ) new_out = [] @@ -436,6 +441,7 @@ def __init__( num_threads=4, device_id=0, exec_separated=False, + exec_dynamic=False, prefetch_queue_depth=2, cpu_prefetch_queue_depth=2, gpu_prefetch_queue_depth=2, @@ -445,6 +451,9 @@ def __init__( output_shapes = self._handle_deprecation(output_shapes, shapes, "shapes") output_dtypes = self._handle_deprecation(output_dtypes, dtypes, "dtypes") + if pipeline._exec_dynamic: + exec_dynamic = True + if not self._check_dtypes(output_dtypes, tf.DType): raise TypeError( "`output_dtypes` should be provided as single tf.DType value " @@ -475,6 +484,7 @@ def __init__( device_id = types.CPU_ONLY_DEVICE_ID self._device_id = device_id self._exec_separated = exec_separated + self._exec_dynamic = exec_dynamic self._prefetch_queue_depth = prefetch_queue_depth self._cpu_prefetch_queue_depth = cpu_prefetch_queue_depth self._gpu_prefetch_queue_depth = gpu_prefetch_queue_depth @@ -805,6 +815,7 @@ def _as_variant_tensor(self): num_threads=self._num_threads, device_id=self._device_id, exec_separated=self._exec_separated, + exec_dynamic=self._exec_dynamic, prefetch_queue_depth=self._prefetch_queue_depth, cpu_prefetch_queue_depth=self._cpu_prefetch_queue_depth, gpu_prefetch_queue_depth=self._gpu_prefetch_queue_depth, @@ -865,6 +876,7 @@ def __init__( num_threads=4, device_id=0, exec_separated=False, + exec_dynamic=False, prefetch_queue_depth=2, cpu_prefetch_queue_depth=2, gpu_prefetch_queue_depth=2, @@ -984,6 +996,11 @@ def __init__(self, *args, **kwargs): Whether to execute the pipeline in a way that enables overlapping CPU and GPU computation, typically 
resulting in faster execution speed, but larger memory consumption. + This flag is incompatible with ``exec_dymamic``. + exec_dynamic : bool, optional, default = False + Whether to execute the pipeline with the dynamic executor, which allows flexible mixing + of CPU and GPU operators and enables aggressive memory reuse. + This flag is incompatible with ``exec_separated``. prefetch_queue_depth : int, optional, default = 2 depth of the executor queue. Deeper queue makes DALI more resistant to uneven execution time of each batch, but it also diff --git a/dali/test/python/test_dali_tf_exec2.py b/dali/test/python/test_dali_tf_exec2.py new file mode 100644 index 00000000000..ec83babdc32 --- /dev/null +++ b/dali/test/python/test_dali_tf_exec2.py @@ -0,0 +1,133 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +import numpy as np +import os.path +from nvidia.dali import pipeline_def +import nvidia.dali.fn as fn +import nvidia.dali.types as types +import nvidia.dali.plugin.tf as dali_tf +from nose_utils import with_setup +from test_utils_tensorflow import skip_inputs_for_incompatible_tf +from test_utils import get_dali_extra_path + + +test_data_root = get_dali_extra_path() +lmdb_folder = os.path.join(test_data_root, "db", "lmdb") + + +@pipeline_def( + enable_conditionals=True, + batch_size=5, + num_threads=4, + device_id=0, + experimental_exec_dynamic=True, +) +def dali_exec2_pipeline(): + iter_id = fn.external_source(source=lambda x: np.array(x.iteration), batch=False) + if iter_id & 1 == 0: + output = types.Constant(np.array(-1), device="gpu") + else: + output = types.Constant(np.array(1), device="gpu") + return output.cpu() + + +@with_setup(skip_inputs_for_incompatible_tf) +def test_tf_dataset_exec2(): + """Test that exec_dynamic is propagated to DALI pipeline from dali_tf.DALIDatasetWithInputs""" + # From Tensorflow's perspective, this is a CPU pipeline + with tf.device("/cpu:0"): + dali_dataset = dali_tf.experimental.DALIDatasetWithInputs( + pipeline=dali_exec2_pipeline(), + batch_size=5, + output_shapes=(5,), + output_dtypes=(tf.int32), + num_threads=4, + device_id=0, + ) + + @tf.function + def tf_function_with_conditionals(dali_dataset): + negative = tf.constant(0) + positive = tf.constant(0) + for input in dali_dataset: + if tf.reduce_sum(input) < 0: + negative = negative + 1 + else: + positive = positive + 1 + return negative, positive + + pos, neg = tf_function_with_conditionals(dali_dataset.take(5)) + assert pos == 3 + assert neg == 2 + + +@pipeline_def(num_threads=4, experimental_exec_dynamic=True) +def daliop_pipe(): + jpegs, labels = fn.readers.caffe(path=lmdb_folder, random_shuffle=False) + imgs = fn.decoders.image(jpegs, device="mixed") + imgs = fn.resize(imgs, size=(100, 100)) + shape = imgs.shape(dtype=types.UINT32) + return imgs.cpu(), shape + + +def get_batch_dali(batch_size): + pipe = daliop_pipe(batch_size=batch_size, num_threads=4, device_id=0) + pipe.build() + + daliop = 
dali_tf.DALIIterator() + images = [] + labels = [] + with tf.device("/cpu:0"): + image, label = daliop( + pipeline=pipe, + shapes=[ + (batch_size, 100, 100, 3), + ( + batch_size, + 3, + ), + ], + dtypes=[tf.uint8, tf.int32], + device_id=0, + ) + images.append(image) + labels.append(label) + + return [images, labels] + + +def test_tf_op(): + """Test that exec_dynamic is propagated to DALI pipeline from dali_tf.DALIIterator""" + try: + tf.compat.v1.disable_eager_execution() + except ModuleNotFoundError: + pass + + batch_size = 8 + iterations = 2 + test_batch = get_batch_dali(batch_size) + try: + from tensorflow.compat.v1 import Session + except ImportError: + # Older TF versions don't have compat.v1 layer + from tensorflow import Session + + with Session() as sess: + for i in range(iterations): + imgs, shapes = sess.run(test_batch) + for img, shape in zip(imgs, shapes): + for i in range(batch_size): + assert tuple(img[i].shape) == tuple(shape[i]) diff --git a/dali_tf_plugin/dali_dataset.h b/dali_tf_plugin/dali_dataset.h index f4996f175a0..ef0b4d7b1bd 100644 --- a/dali_tf_plugin/dali_dataset.h +++ b/dali_tf_plugin/dali_dataset.h @@ -1,4 +1,4 @@ -// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -70,6 +70,7 @@ class DALIDatasetOp : public tensorflow::data::DatasetOpKernel { int num_threads; int device_id; bool exec_separated; + bool exec_dynamic; int prefetch_queue_depth; int cpu_prefetch_queue_depth; int gpu_prefetch_queue_depth; @@ -99,6 +100,7 @@ class DALIDatasetOp : public tensorflow::data::DatasetOpKernel { static constexpr const char* const kNumThreads = "num_threads"; static constexpr const char* const kDeviceId = "device_id"; static constexpr const char* const kExecSeparated = "exec_separated"; + static constexpr const char* const kExecDynamic = "exec_dynamic"; static constexpr const char* const kPrefetchQueueDepth = "prefetch_queue_depth"; static constexpr const char* const kCpuPrefetchQueueDepth = "cpu_prefetch_queue_depth"; static constexpr const char* const kGpuPrefetchQueueDepth = "gpu_prefetch_queue_depth"; diff --git a/dali_tf_plugin/dali_dataset_op.cc b/dali_tf_plugin/dali_dataset_op.cc index 0c95a0dd2f0..dbefe62ae6c 100644 --- a/dali_tf_plugin/dali_dataset_op.cc +++ b/dali_tf_plugin/dali_dataset_op.cc @@ -220,6 +220,7 @@ class DALIDatasetOp::Dataset : public DatasetBase { SerializeField(attrs, b, kNumThreads, pipeline_def_.num_threads); SerializeField(attrs, b, kDeviceId, pipeline_def_.device_id); SerializeField(attrs, b, kExecSeparated, pipeline_def_.exec_separated); + SerializeField(attrs, b, kExecDynamic, pipeline_def_.exec_dynamic); SerializeField(attrs, b, kPrefetchQueueDepth, pipeline_def_.prefetch_queue_depth); SerializeField(attrs, b, kCpuPrefetchQueueDepth, pipeline_def_.cpu_prefetch_queue_depth); SerializeField(attrs, b, kGpuPrefetchQueueDepth, pipeline_def_.gpu_prefetch_queue_depth); @@ -248,10 +249,15 @@ class DALIDatasetOp::Dataset : public DatasetBase { } Status InitPipeline(daliPipelineHandle *pipeline_handle) const { - TF_DALI_CALL(daliCreatePipeline( + dali_exec_flags_t flags = DALI_EXEC_ASYNC_PIPELINED; + if (pipeline_def_.exec_dynamic) + flags = flags | DALI_EXEC_IS_DYNAMIC; + if (pipeline_def_.exec_separated) + flags = flags | DALI_EXEC_IS_SEPARATED; + TF_DALI_CALL(daliCreatePipeline3( pipeline_handle, 
pipeline_def_.pipeline.c_str(), pipeline_def_.pipeline.length(), pipeline_def_.batch_size, pipeline_def_.num_threads, pipeline_def_.device_id, - pipeline_def_.exec_separated, pipeline_def_.prefetch_queue_depth, + flags, pipeline_def_.prefetch_queue_depth, pipeline_def_.cpu_prefetch_queue_depth, pipeline_def_.gpu_prefetch_queue_depth, pipeline_def_.enable_memory_stats)); return Status(); @@ -380,26 +386,28 @@ class DALIDatasetOp::Dataset::Iterator : public DatasetIterator { } ~Iterator() { - if (enable_memory_stats_) { - size_t N; - daliExecutorMetadata *meta; - daliGetExecutorMetadata(&pipeline_handle_, &meta, &N); - std::cout << "DALI operator memory statistics: " << std::endl; - for (size_t i = 0; i < N; ++i) { - std::cout << "Operator " << meta[i].operator_name; - for (size_t j = 0; j < meta[i].out_num; ++j) { - std::cout << " output [ " << j << " ] : " << meta[i].real_size[j] << "B allocated " - << meta[i].max_real_size[j] << "B max allocated " << meta[i].reserved[j] - << "B reserved" << meta[i].max_reserved[j] << "B max reserved"; - if (j != meta[i].out_num - 1) { - std::cout << ","; + if (pipeline_handle_) { + if (enable_memory_stats_) { + size_t N; + daliExecutorMetadata *meta; + daliGetExecutorMetadata(&pipeline_handle_, &meta, &N); + std::cout << "DALI operator memory statistics: " << std::endl; + for (size_t i = 0; i < N; ++i) { + std::cout << "Operator " << meta[i].operator_name; + for (size_t j = 0; j < meta[i].out_num; ++j) { + std::cout << " output [ " << j << " ] : " << meta[i].real_size[j] << "B allocated " + << meta[i].max_real_size[j] << "B max allocated " << meta[i].reserved[j] + << "B reserved" << meta[i].max_reserved[j] << "B max reserved"; + if (j != meta[i].out_num - 1) { + std::cout << ","; + } } + std::cout << std::endl; } - std::cout << std::endl; + daliFreeExecutorMetadata(meta, N); } - daliFreeExecutorMetadata(meta, N); + daliDeletePipeline(&pipeline_handle_); } - daliDeletePipeline(&pipeline_handle_); } #if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION >= 3) @@ -941,8 +949,8 @@ class DALIDatasetOp::Dataset::Iterator : public DatasetIterator { std::vector input_ext_src_devices_; std::queue alive_batches_; InputState iterator_state_ = InputState::in_progress; - daliPipelineHandle pipeline_handle_; - bool enable_memory_stats_; + daliPipelineHandle pipeline_handle_ = nullptr; + bool enable_memory_stats_ = false; }; void DALIDatasetOp::MakeDataset(OpKernelContext *context, DatasetBase **output) { @@ -959,6 +967,7 @@ void DALIDatasetOp::FillPipelineDef(OpKernelConstruction *context, PipelineDef & OP_REQUIRES_OK(context, context->GetAttr(kNumThreads, &def.num_threads)); OP_REQUIRES_OK(context, context->GetAttr(kDeviceId, &def.device_id)); OP_REQUIRES_OK(context, context->GetAttr(kExecSeparated, &def.exec_separated)); + OP_REQUIRES_OK(context, context->GetAttr(kExecDynamic, &def.exec_dynamic)); OP_REQUIRES_OK(context, context->GetAttr(kPrefetchQueueDepth, &def.prefetch_queue_depth)); OP_REQUIRES_OK(context, context->GetAttr(kCpuPrefetchQueueDepth, &def.cpu_prefetch_queue_depth)); OP_REQUIRES_OK(context, context->GetAttr(kGpuPrefetchQueueDepth, &def.gpu_prefetch_queue_depth)); @@ -1079,6 +1088,7 @@ REGISTER_OP("DALIDataset") .Attr("num_threads: int") .Attr("device_id: int") .Attr("exec_separated: bool") + .Attr("exec_dynamic: bool") .Attr("prefetch_queue_depth: int") .Attr("cpu_prefetch_queue_depth: int") .Attr("gpu_prefetch_queue_depth: int") diff --git a/dali_tf_plugin/daliop.cc b/dali_tf_plugin/daliop.cc index 7d9e1a9ed55..5123ab1355b 100644 --- 
a/dali_tf_plugin/daliop.cc +++ b/dali_tf_plugin/daliop.cc @@ -67,6 +67,7 @@ REGISTER_OP("Dali") .Attr("num_threads: int = -1") .Attr("device_id: int = -1") .Attr("exec_separated: bool = false") + .Attr("exec_dynamic: bool = false") .Attr("gpu_prefetch_queue_depth: int = 2") .Attr("cpu_prefetch_queue_depth: int = 2") .Attr("sparse: list(bool) = []") @@ -111,6 +112,7 @@ class DaliOp : public tf::OpKernel { int device_id; int max_batch_size; bool exec_separated; + bool exec_dynamic; int cpu_prefetch_queue_depth; OP_REQUIRES_OK(context, context->GetAttr("shapes", &shapes_)); @@ -118,6 +120,7 @@ class DaliOp : public tf::OpKernel { OP_REQUIRES_OK(context, context->GetAttr("num_threads", &num_threads)); OP_REQUIRES_OK(context, context->GetAttr("device_id", &device_id)); OP_REQUIRES_OK(context, context->GetAttr("exec_separated", &exec_separated)); + OP_REQUIRES_OK(context, context->GetAttr("exec_dynamic", &exec_dynamic)); // In exec_separated==false case, gpu_prefetch_queue_depth is the global prefetch_queue_depth_ OP_REQUIRES_OK(context, context->GetAttr("gpu_prefetch_queue_depth", &prefetch_queue_depth_)); OP_REQUIRES_OK(context, context->GetAttr("sparse", &sparse_)); @@ -142,13 +145,19 @@ class DaliOp : public tf::OpKernel { max_batch_size = shapes_[0].dim_size(0); } - TF_DALI_CALL(daliCreatePipeline(&pipe_handle_, + dali_exec_flags_t flags = DALI_EXEC_ASYNC_PIPELINED; + if (exec_dynamic) + flags = flags | DALI_EXEC_IS_DYNAMIC; + if (exec_separated) + flags = flags | DALI_EXEC_IS_SEPARATED; + + TF_DALI_CALL(daliCreatePipeline3(&pipe_handle_, serialized_pipeline.c_str(), serialized_pipeline.length(), max_batch_size, num_threads, device_id, - exec_separated, + flags, prefetch_queue_depth_, cpu_prefetch_queue_depth, prefetch_queue_depth_, @@ -165,28 +174,30 @@ class DaliOp : public tf::OpKernel { } ~DaliOp() override { - if (enable_memory_stats_) { - size_t N; - daliExecutorMetadata *meta; - daliGetExecutorMetadata(&pipe_handle_, &meta, &N); - std::cout << "DALI operator memory statistics: " << std::endl; - for (size_t i = 0; i < N; ++i) { - std::cout << "Operator " << meta[i].operator_name; - for (size_t j = 0; j < meta[i].out_num; ++j) { - std::cout << " output [ " << j << " ] : " - << meta[i].real_size[j] << "B allocated " - << meta[i].max_real_size[j] << "B max allocated " - << meta[i].reserved[j] << "B reserved" - << meta[i].max_reserved[j] << "B max reserved"; - if (j != meta[i].out_num - 1) { - std::cout << ","; + if (pipe_handle_) { + if (enable_memory_stats_) { + size_t N; + daliExecutorMetadata *meta; + daliGetExecutorMetadata(&pipe_handle_, &meta, &N); + std::cout << "DALI operator memory statistics: " << std::endl; + for (size_t i = 0; i < N; ++i) { + std::cout << "Operator " << meta[i].operator_name; + for (size_t j = 0; j < meta[i].out_num; ++j) { + std::cout << " output [ " << j << " ] : " + << meta[i].real_size[j] << "B allocated " + << meta[i].max_real_size[j] << "B max allocated " + << meta[i].reserved[j] << "B reserved" + << meta[i].max_reserved[j] << "B max reserved"; + if (j != meta[i].out_num - 1) { + std::cout << ","; + } } + std::cout << std::endl; } - std::cout << std::endl; + daliFreeExecutorMetadata(meta, N); } - daliFreeExecutorMetadata(meta, N); + daliDeletePipeline(&pipe_handle_); } - daliDeletePipeline(&pipe_handle_); } void Compute(tf::OpKernelContext* context) override { @@ -389,15 +400,15 @@ class DaliOp : public tf::OpKernel { } private: - daliPipelineHandle pipe_handle_; + daliPipelineHandle pipe_handle_ = nullptr; std::vector shapes_; tf::DataTypeVector 
types_; - int device_id_; - int batch_size_; - int prefetch_queue_depth_; - device_type_t device_type_; + int device_id_ = -1; + int batch_size_ = 0; + int prefetch_queue_depth_ = -1; + device_type_t device_type_ = CPU; std::vector sparse_; - bool enable_memory_stats_; + bool enable_memory_stats_ = false; }; using tf::int64; diff --git a/qa/TL0_tensorflow_plugin/test.sh b/qa/TL0_tensorflow_plugin/test.sh index 60092468d72..0f4ef097a46 100755 --- a/qa/TL0_tensorflow_plugin/test.sh +++ b/qa/TL0_tensorflow_plugin/test.sh @@ -45,6 +45,9 @@ test_body() { ${python_invoke_test} test_dali_tf_dataset_eager.py ${python_invoke_test} test_dali_tf_dataset_graph.py fi + + # DALI TF + dynamic executor + ${python_invoke_test} test_dali_tf_exec2.py } pushd ../..
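
For reference, a minimal usage sketch of the new option (illustrative names, shapes, and values, not taken from the patch). The dynamic executor can be requested in the two ways this change wires up: build the pipeline with ``experimental_exec_dynamic=True`` and let the plugin inherit it through ``pipeline._exec_dynamic`` (as the new tests do), or pass ``exec_dynamic=True`` explicitly to ``DALIDataset``/``DALIIterator``. As documented above, ``exec_dynamic`` is incompatible with ``exec_separated``.

import numpy as np
import tensorflow as tf
import nvidia.dali.types as types
import nvidia.dali.plugin.tf as dali_tf
from nvidia.dali import pipeline_def


# Pipeline built with the dynamic executor; returning a GPU result as a CPU
# output is the kind of CPU/GPU mixing that this executor enables.
@pipeline_def(batch_size=8, num_threads=4, device_id=0, experimental_exec_dynamic=True)
def example_pipeline():
    data = types.Constant(np.array(42, dtype=np.int32), device="gpu")
    return data.cpu()


# From TensorFlow's perspective the outputs are CPU data, so the dataset can be
# placed on /cpu:0; exec_dynamic is picked up from the pipeline object.
with tf.device("/cpu:0"):
    dataset = dali_tf.DALIDataset(
        pipeline=example_pipeline(),
        batch_size=8,
        output_shapes=(8,),
        output_dtypes=tf.int32,
        num_threads=4,
        device_id=0,
        # exec_dynamic=True,  # equivalent: request the dynamic executor explicitly
    )

for batch in dataset.take(2):
    print(batch.numpy())  # eight copies of 42

The same inheritance applies to the legacy ``DALIIterator`` path: a pipeline created with ``experimental_exec_dynamic=True`` is enough, as exercised by ``test_tf_op`` in the new test file.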