Add dynamic executor support to TF plugin. #5686

Merged: 2 commits, merged Oct 28, 2024. The diff below shows changes from 1 commit.
19 changes: 18 additions & 1 deletion dali/python/nvidia/dali/plugin/tf.py
@@ -1,4 +1,4 @@
# Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -214,6 +214,7 @@ def DALIIteratorWrapper(
dtypes=[],
batch_size=-1,
prefetch_queue_depth=2,
exec_dynamic=False,
**kwargs,
):
"""
@@ -232,6 +233,9 @@ def DALIIteratorWrapper(
cpu_prefetch_queue_depth = -1  # dummy: won't be used
gpu_prefetch_queue_depth = prefetch_queue_depth

if pipeline is not None and pipeline._exec_dynamic:
exec_dynamic = True

if serialized_pipeline is None:
serialized_pipeline = serialize_pipeline(pipeline)

@@ -281,6 +285,7 @@ def DALIIteratorWrapper(
exec_separated=exec_separated,
gpu_prefetch_queue_depth=gpu_prefetch_queue_depth,
cpu_prefetch_queue_depth=cpu_prefetch_queue_depth,
exec_dynamic=exec_dynamic,
**kwargs,
)
new_out = []
@@ -436,6 +441,7 @@ def __init__(
num_threads=4,
device_id=0,
exec_separated=False,
exec_dynamic=False,
prefetch_queue_depth=2,
cpu_prefetch_queue_depth=2,
gpu_prefetch_queue_depth=2,
@@ -445,6 +451,9 @@ def __init__(
output_shapes = self._handle_deprecation(output_shapes, shapes, "shapes")
output_dtypes = self._handle_deprecation(output_dtypes, dtypes, "dtypes")

if pipeline._exec_dynamic:
exec_dynamic = True

Collaborator commented:
Why not move it to line 487 and do the following there:
self.exec_dynamic = True if pipeline._exec_dynamic else exec_dynamic

Contributor Author replied:
I thought it was more logical to set up arguments first and then do the "business as usual" of populating the fields...
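
For context, a minimal self-contained sketch of the two placements discussed above (hypothetical stand-in classes, not the actual plugin code); both variants resolve to the same field value:

```python
class FakePipeline:
    """Hypothetical stand-in for a DALI pipeline exposing `_exec_dynamic`."""
    def __init__(self, exec_dynamic):
        self._exec_dynamic = exec_dynamic


class DatasetA:
    """Author's approach: normalize the argument first, then populate fields as usual."""
    def __init__(self, pipeline, exec_dynamic=False):
        if pipeline._exec_dynamic:
            exec_dynamic = True
        # ... other argument handling would go here ...
        self._exec_dynamic = exec_dynamic


class DatasetB:
    """Reviewer's alternative: resolve the flag at field-assignment time."""
    def __init__(self, pipeline, exec_dynamic=False):
        self._exec_dynamic = True if pipeline._exec_dynamic else exec_dynamic


p = FakePipeline(exec_dynamic=True)
assert DatasetA(p)._exec_dynamic and DatasetB(p)._exec_dynamic
```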

if not self._check_dtypes(output_dtypes, tf.DType):
raise TypeError(
"`output_dtypes` should be provided as single tf.DType value "
@@ -475,6 +484,7 @@ def __init__(
device_id = types.CPU_ONLY_DEVICE_ID
self._device_id = device_id
self._exec_separated = exec_separated
self._exec_dynamic = exec_dynamic
self._prefetch_queue_depth = prefetch_queue_depth
self._cpu_prefetch_queue_depth = cpu_prefetch_queue_depth
self._gpu_prefetch_queue_depth = gpu_prefetch_queue_depth
@@ -805,6 +815,7 @@ def _as_variant_tensor(self):
num_threads=self._num_threads,
device_id=self._device_id,
exec_separated=self._exec_separated,
exec_dynamic=self._exec_dynamic,
prefetch_queue_depth=self._prefetch_queue_depth,
cpu_prefetch_queue_depth=self._cpu_prefetch_queue_depth,
gpu_prefetch_queue_depth=self._gpu_prefetch_queue_depth,
@@ -865,6 +876,7 @@ def __init__(
num_threads=4,
device_id=0,
exec_separated=False,
exec_dynamic=False,
prefetch_queue_depth=2,
cpu_prefetch_queue_depth=2,
gpu_prefetch_queue_depth=2,
@@ -984,6 +996,11 @@ def __init__(self, *args, **kwargs):
Whether to execute the pipeline in a way that enables
overlapping CPU and GPU computation, typically resulting
in faster execution speed, but larger memory consumption.
This flag is incompatible with ``exec_dynamic``.
exec_dynamic : bool, optional, default = False
Whether to execute the pipeline with the dynamic executor, which allows flexible mixing
of CPU and GPU operators and enables aggressive memory reuse.
This flag is incompatible with ``exec_separated``.
prefetch_queue_depth : int, optional, default = 2
depth of the executor queue. Deeper queue makes DALI more
resistant to uneven execution time of each batch, but it also
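
For reference, a minimal usage sketch of the new flag from the TF side. This assumes the public `dali_tf.DALIDataset` wrapper accepts and forwards the `exec_dynamic` keyword as the diff above suggests; the pipeline itself is a trivial placeholder:

```python
import tensorflow as tf

import nvidia.dali.fn as fn
import nvidia.dali.plugin.tf as dali_tf
from nvidia.dali import pipeline_def


@pipeline_def(batch_size=8, num_threads=4, device_id=0)
def noise_pipeline():
    # A trivial pipeline: one batch of random values, moved to the GPU.
    return fn.random.uniform(range=[0.0, 1.0], shape=[2]).gpu()


with tf.device("/gpu:0"):
    dataset = dali_tf.DALIDataset(
        pipeline=noise_pipeline(),
        batch_size=8,
        output_shapes=((8, 2),),
        output_dtypes=(tf.float32,),
        device_id=0,
        exec_dynamic=True,  # new flag; must not be combined with exec_separated=True
    )

for batch in dataset.take(1):
    print(batch.shape)
```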
4 changes: 3 additions & 1 deletion dali_tf_plugin/dali_dataset.h
@@ -1,4 +1,4 @@
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -70,6 +70,7 @@ class DALIDatasetOp : public tensorflow::data::DatasetOpKernel {
int num_threads;
int device_id;
bool exec_separated;
bool exec_dynamic;
int prefetch_queue_depth;
int cpu_prefetch_queue_depth;
int gpu_prefetch_queue_depth;
@@ -99,6 +100,7 @@ class DALIDatasetOp : public tensorflow::data::DatasetOpKernel {
static constexpr const char* const kNumThreads = "num_threads";
static constexpr const char* const kDeviceId = "device_id";
static constexpr const char* const kExecSeparated = "exec_separated";
static constexpr const char* const kExecDynamic = "exec_dynamic";
static constexpr const char* const kPrefetchQueueDepth = "prefetch_queue_depth";
static constexpr const char* const kCpuPrefetchQueueDepth = "cpu_prefetch_queue_depth";
static constexpr const char* const kGpuPrefetchQueueDepth = "gpu_prefetch_queue_depth";
50 changes: 30 additions & 20 deletions dali_tf_plugin/dali_dataset_op.cc
@@ -220,6 +220,7 @@ class DALIDatasetOp::Dataset : public DatasetBase {
SerializeField(attrs, b, kNumThreads, pipeline_def_.num_threads);
SerializeField(attrs, b, kDeviceId, pipeline_def_.device_id);
SerializeField(attrs, b, kExecSeparated, pipeline_def_.exec_separated);
SerializeField(attrs, b, kExecDynamic, pipeline_def_.exec_dynamic);
SerializeField(attrs, b, kPrefetchQueueDepth, pipeline_def_.prefetch_queue_depth);
SerializeField(attrs, b, kCpuPrefetchQueueDepth, pipeline_def_.cpu_prefetch_queue_depth);
SerializeField(attrs, b, kGpuPrefetchQueueDepth, pipeline_def_.gpu_prefetch_queue_depth);
@@ -248,10 +249,15 @@ class DALIDatasetOp::Dataset : public DatasetBase {
}

Status InitPipeline(daliPipelineHandle *pipeline_handle) const {
TF_DALI_CALL(daliCreatePipeline(
dali_exec_flags_t flags = DALI_EXEC_ASYNC_PIPELINED;
if (pipeline_def_.exec_dynamic)
flags = flags | DALI_EXEC_IS_DYNAMIC;
if (pipeline_def_.exec_separated)
flags = flags | DALI_EXEC_IS_SEPARATED;
TF_DALI_CALL(daliCreatePipeline3(
pipeline_handle, pipeline_def_.pipeline.c_str(), pipeline_def_.pipeline.length(),
pipeline_def_.batch_size, pipeline_def_.num_threads, pipeline_def_.device_id,
pipeline_def_.exec_separated, pipeline_def_.prefetch_queue_depth,
flags, pipeline_def_.prefetch_queue_depth,
pipeline_def_.cpu_prefetch_queue_depth, pipeline_def_.gpu_prefetch_queue_depth,
pipeline_def_.enable_memory_stats));
return Status();
@@ -380,26 +386,28 @@ class DALIDatasetOp::Dataset::Iterator : public DatasetIterator<Dataset> {
}

~Iterator() {
if (enable_memory_stats_) {
size_t N;
daliExecutorMetadata *meta;
daliGetExecutorMetadata(&pipeline_handle_, &meta, &N);
std::cout << "DALI operator memory statistics: " << std::endl;
for (size_t i = 0; i < N; ++i) {
std::cout << "Operator " << meta[i].operator_name;
for (size_t j = 0; j < meta[i].out_num; ++j) {
std::cout << " output [ " << j << " ] : " << meta[i].real_size[j] << "B allocated "
<< meta[i].max_real_size[j] << "B max allocated " << meta[i].reserved[j]
<< "B reserved" << meta[i].max_reserved[j] << "B max reserved";
if (j != meta[i].out_num - 1) {
std::cout << ",";
if (pipeline_handle_) {
Contributor Author commented:
This was actually a bug that could cause a SIGSEGV. One possible trigger was specifying an incomplete output shape in Python.
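
A hypothetical illustration of that failure mode (this is not the exact reproducer from the PR; the pipeline and shapes below are just examples): describing only some of the pipeline outputs in `output_shapes` can make dataset/iterator construction fail part-way, after which the destructor used to call `daliDeletePipeline()` on a handle that was never initialized. The new `if (pipeline_handle_)` guard avoids that.

```python
import tensorflow as tf

import nvidia.dali.fn as fn
import nvidia.dali.plugin.tf as dali_tf
from nvidia.dali import pipeline_def


@pipeline_def(batch_size=4, num_threads=2, device_id=0)
def two_output_pipeline():
    a = fn.random.uniform(range=[0.0, 1.0], shape=[3])
    b = fn.random.uniform(range=[0.0, 1.0], shape=[5])
    return a.gpu(), b.gpu()


with tf.device("/gpu:0"):
    # Incomplete output specification: the pipeline has two outputs, but only
    # one shape/dtype pair is provided. The resulting construction error could
    # previously leave the iterator with an uninitialized pipeline handle.
    dataset = dali_tf.DALIDataset(
        pipeline=two_output_pipeline(),
        batch_size=4,
        output_shapes=((4, 3),),
        output_dtypes=(tf.float32,),
    )
```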

if (enable_memory_stats_) {
size_t N;
daliExecutorMetadata *meta;
daliGetExecutorMetadata(&pipeline_handle_, &meta, &N);
std::cout << "DALI operator memory statistics: " << std::endl;
for (size_t i = 0; i < N; ++i) {
std::cout << "Operator " << meta[i].operator_name;
for (size_t j = 0; j < meta[i].out_num; ++j) {
std::cout << " output [ " << j << " ] : " << meta[i].real_size[j] << "B allocated "
<< meta[i].max_real_size[j] << "B max allocated " << meta[i].reserved[j]
<< "B reserved" << meta[i].max_reserved[j] << "B max reserved";
if (j != meta[i].out_num - 1) {
std::cout << ",";
}
}
std::cout << std::endl;
}
std::cout << std::endl;
daliFreeExecutorMetadata(meta, N);
}
daliFreeExecutorMetadata(meta, N);
daliDeletePipeline(&pipeline_handle_);
}
daliDeletePipeline(&pipeline_handle_);
}

#if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION >= 3)
@@ -941,8 +949,8 @@ class DALIDatasetOp::Dataset::Iterator : public DatasetIterator<Dataset> {
std::vector<dali_backend_t> input_ext_src_devices_;
std::queue<ListOfBatches> alive_batches_;
InputState iterator_state_ = InputState::in_progress;
daliPipelineHandle pipeline_handle_;
bool enable_memory_stats_;
daliPipelineHandle pipeline_handle_ = nullptr;
bool enable_memory_stats_ = false;
};

void DALIDatasetOp::MakeDataset(OpKernelContext *context, DatasetBase **output) {
@@ -959,6 +967,7 @@ void DALIDatasetOp::FillPipelineDef(OpKernelConstruction *context, PipelineDef &
OP_REQUIRES_OK(context, context->GetAttr(kNumThreads, &def.num_threads));
OP_REQUIRES_OK(context, context->GetAttr(kDeviceId, &def.device_id));
OP_REQUIRES_OK(context, context->GetAttr(kExecSeparated, &def.exec_separated));
OP_REQUIRES_OK(context, context->GetAttr(kExecDynamic, &def.exec_dynamic));
OP_REQUIRES_OK(context, context->GetAttr(kPrefetchQueueDepth, &def.prefetch_queue_depth));
OP_REQUIRES_OK(context, context->GetAttr(kCpuPrefetchQueueDepth, &def.cpu_prefetch_queue_depth));
OP_REQUIRES_OK(context, context->GetAttr(kGpuPrefetchQueueDepth, &def.gpu_prefetch_queue_depth));
@@ -1079,6 +1088,7 @@ REGISTER_OP("DALIDataset")
.Attr("num_threads: int")
.Attr("device_id: int")
.Attr("exec_separated: bool")
.Attr("exec_dynamic: bool")
.Attr("prefetch_queue_depth: int")
.Attr("cpu_prefetch_queue_depth: int")
.Attr("gpu_prefetch_queue_depth: int")
63 changes: 37 additions & 26 deletions dali_tf_plugin/daliop.cc
@@ -67,6 +67,7 @@ REGISTER_OP("Dali")
.Attr("num_threads: int = -1")
.Attr("device_id: int = -1")
.Attr("exec_separated: bool = false")
.Attr("exec_dynamic: bool = false")
.Attr("gpu_prefetch_queue_depth: int = 2")
.Attr("cpu_prefetch_queue_depth: int = 2")
.Attr("sparse: list(bool) = []")
@@ -111,13 +112,15 @@ class DaliOp : public tf::OpKernel {
int device_id;
int max_batch_size;
bool exec_separated;
bool exec_dynamic;
int cpu_prefetch_queue_depth;

OP_REQUIRES_OK(context, context->GetAttr("shapes", &shapes_));
OP_REQUIRES_OK(context, context->GetAttr("dtypes", &types_));
OP_REQUIRES_OK(context, context->GetAttr("num_threads", &num_threads));
OP_REQUIRES_OK(context, context->GetAttr("device_id", &device_id));
OP_REQUIRES_OK(context, context->GetAttr("exec_separated", &exec_separated));
OP_REQUIRES_OK(context, context->GetAttr("exec_dynamic", &exec_dynamic));
// In exec_separated==false case, gpu_prefetch_queue_depth is the global prefetch_queue_depth_
OP_REQUIRES_OK(context, context->GetAttr("gpu_prefetch_queue_depth", &prefetch_queue_depth_));
OP_REQUIRES_OK(context, context->GetAttr("sparse", &sparse_));
@@ -142,13 +145,19 @@ class DaliOp : public tf::OpKernel {
max_batch_size = shapes_[0].dim_size(0);
}

TF_DALI_CALL(daliCreatePipeline(&pipe_handle_,
dali_exec_flags_t flags = DALI_EXEC_ASYNC_PIPELINED;
if (exec_dynamic)
flags = flags | DALI_EXEC_IS_DYNAMIC;
if (exec_separated)
flags = flags | DALI_EXEC_IS_SEPARATED;

TF_DALI_CALL(daliCreatePipeline3(&pipe_handle_,
serialized_pipeline.c_str(),
serialized_pipeline.length(),
max_batch_size,
num_threads,
device_id,
exec_separated,
flags,
prefetch_queue_depth_,
cpu_prefetch_queue_depth,
prefetch_queue_depth_,
@@ -165,28 +174,30 @@ class DaliOp : public tf::OpKernel {
}

~DaliOp() override {
if (enable_memory_stats_) {
size_t N;
daliExecutorMetadata *meta;
daliGetExecutorMetadata(&pipe_handle_, &meta, &N);
std::cout << "DALI operator memory statistics: " << std::endl;
for (size_t i = 0; i < N; ++i) {
std::cout << "Operator " << meta[i].operator_name;
for (size_t j = 0; j < meta[i].out_num; ++j) {
std::cout << " output [ " << j << " ] : "
<< meta[i].real_size[j] << "B allocated "
<< meta[i].max_real_size[j] << "B max allocated "
<< meta[i].reserved[j] << "B reserved"
<< meta[i].max_reserved[j] << "B max reserved";
if (j != meta[i].out_num - 1) {
std::cout << ",";
if (pipe_handle_) {
if (enable_memory_stats_) {
size_t N;
daliExecutorMetadata *meta;
daliGetExecutorMetadata(&pipe_handle_, &meta, &N);
std::cout << "DALI operator memory statistics: " << std::endl;
for (size_t i = 0; i < N; ++i) {
std::cout << "Operator " << meta[i].operator_name;
for (size_t j = 0; j < meta[i].out_num; ++j) {
std::cout << " output [ " << j << " ] : "
<< meta[i].real_size[j] << "B allocated "
<< meta[i].max_real_size[j] << "B max allocated "
<< meta[i].reserved[j] << "B reserved"
<< meta[i].max_reserved[j] << "B max reserved";
if (j != meta[i].out_num - 1) {
std::cout << ",";
}
}
std::cout << std::endl;
}
std::cout << std::endl;
daliFreeExecutorMetadata(meta, N);
}
daliFreeExecutorMetadata(meta, N);
daliDeletePipeline(&pipe_handle_);
}
daliDeletePipeline(&pipe_handle_);
}

void Compute(tf::OpKernelContext* context) override {
@@ -389,15 +400,15 @@ class DaliOp : public tf::OpKernel {
}

private:
daliPipelineHandle pipe_handle_;
daliPipelineHandle pipe_handle_ = nullptr;
std::vector<tf::TensorShape> shapes_;
tf::DataTypeVector types_;
int device_id_;
int batch_size_;
int prefetch_queue_depth_;
device_type_t device_type_;
int device_id_ = -1;
int batch_size_ = 0;
int prefetch_queue_depth_ = -1;
device_type_t device_type_ = CPU;
std::vector<bool> sparse_;
bool enable_memory_stats_;
bool enable_memory_stats_ = false;
};

using tf::int64;
3 changes: 3 additions & 0 deletions qa/TL0_tensorflow_plugin/test.sh
@@ -45,6 +45,9 @@ test_body() {
${python_invoke_test} test_dali_tf_dataset_eager.py
${python_invoke_test} test_dali_tf_dataset_graph.py
fi

# DALI TF + dynamic executor
${python_invoke_test} test_dali_tf_exec2.py
}

pushd ../..