Skip to content

Commit

Permalink
Executor 2.0: ExecGraph (#5587)
Browse files Browse the repository at this point in the history
## Overall design
The executor uses `dali::tasking` as the main run-time library. `dali::tasking` ensures that the tasks are executed in correct order.
It's the job of the executor to define the dependencies between the tasks and apply other constraints.

The heart of the system is `ExecGraph` - a graph structure that stores the detailed information about the execution of the pipeline's graph nodes.
`ExecGraph` consists of `ExecNodes` connected with `ExecEdges`.
There are two kinds of graph nodes:
Operator node - an `ExecNode` which stores an instance of a DALI `OperatorBase`. Its inputs and outputs correspond precisely to the ones in operator's `OpSpec`.
Output node - a node which gathers outputs of the operator nodes that comprise pipeline's output.. The inputs of this node are pipeline's outputs. It returns a `dali::Workspace` by value.

The `ExecGraph` is normally created by lowering the `graph::OpGraph`.

The life cycle of the graph is:
* Construction (e.g. by lowering)
* Topological sorting
* Validation
* Analysis & optimization
* Usage (in a loop)
  * PrepareIteration
  * Launch

## Graph structure
ExecGraph is a directed acyclic graph which stores the nodes and edges in linked lists. Each node has an array of input edges and an array of output descriptors.
An output descriptor aggregates a list of consumers and some properties of the output (device, pinnedness and similar).
ExecEdge is a structure used for navigating the graph. It links an output of a producer node to an input of a consumer node.
The graph has no dangling edges. All graph inputs start with an (inputless) ExternalSource node and all pipeline outputs contribute to the output node. Unused outputs have an output descriptor, but no outgoing edges.

After construction the graph is sorted and analyzed. The sorting is a topological sort with an additional partitioning that guarantees that BatchSizeProviders appear first.

## Implementing order of execution
The order of execution is implemented with dali::tasking dependency mechanisms.
See: http://dali-ci-01:7070/docs/17070353/doxygen/html/namespacedali_1_1tasking.html

### Task state and integrity
Each task has the current main task and the previous main task. Since operators are non-reentrant and potentially stateful, the current task succeeds the previous task. Simply adding main_task_->Succeed(prev_task_) ensures that the tasking::Scheduler will not begin the task until the previous iteration is complete.
### Data dependencies
Each task's main task subscribes to the outputs of the producers. This not only guarantees the order, but provides a mechanism by which the data is passed between operators. Each operator node returns one task output per one operator output.
### Concurrency limit
For various reasons we may want to limit the concurrency of operators. Obviously an operator cannot run in parallel with itself due to reasons outlined above - but we may also want to limit the number of concurrently running different operators from various groups. For example, due to technical limitations of `dali::ThreadPool`, it's impossible to run multiple CPU operators simultaneously because concurrent submission of work to the threadpool results in a hang.
Concurrency is limited with a tasking::Semaphore shared pointer stored in a node.
### Output buffer limit
When scheduling multiple iterations ahead, it's possible for "bubbles" to form - if an operator produces its data quickly but its consumers are slow, the operator node would create multiple output buffers which would live inside tasking framework as data being passed between tasks. In order to limit the number of active output buffers we need another semaphore - but this semaphore needs to be lowered until all consumers are done with the data. To achieve this an auxiliary (empty) task is scheduled to succeed all of the consumers and, upon completion, it raises the semaphore.

#### Example
In this example the operator Op1 has a maximum of 2 output buffers. The iteration 1 proceeds without delay (only waiting for the previous iteration of each operator). In iteration 2, however, the operator Op1 has to wait before it allocates an output buffer. The blue boxes represent the operators' "main" tasks, the red boxes - the "release_outputs" task and the green boxes - semaphore operations.
![image](https://github.com/user-attachments/assets/a3025551-96af-42e3-a238-544df2d66121)

_NOTE: This diagram represents task life cycle, not thread activity - with tasking::Scheduler the worker threads never actually wait as long as there are tasks to execute. Resource (e.g. semaphore) acquisition follows "'wait all" semantics, so the extended "acquire" boxes are not an accurate representation of "waiting on a semaphore"._

## Workspace lifecycle
The new executor follows a "linear" memory usage model - buffers are created as needed and thrown away as soon as they're no longer used. The memory pool is solely responsible for efficient memory recycling.
Despite the buffers' being disposable, the workspace object contains some additional structure (e.g. mapping of argument input names to indices) which we don't need to recreate each time the operator runs. Because of that, each ExecNode has a Workspace object which stores the workspace. The workspace is removed from ExecNode at the beginning of the task body and returned to it when the task completes.
Life cycle:
- Get workspace from ExecNode
- (run the task)
- Clear workspace
- Put workspace back in ExecNode

_(*) Clearing the workspace means removing all TensorLists from it_

## ExecNode Tasks
### Operator task
The operator task performs the following operations:
- get the inputs from parent tasks
- wait for inputs in the operator's stream/order
- put the inputs into the workspace
- apply default input layouts, if necessary
- compute the batch size
- create the outputs
- run operator's Setup
- Resize the outputs, if necessary
- run operator's Run
- propagate metadata
- record CUDA events
- restore empty input layouts, if necessary

### Output task
- wait for the inputs
- construct the "output workspace" where the outputs are workspace outputs (and the tasks's inputs)
- move the output workspace to the task's return value

---------

Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
  • Loading branch information
mzient authored Aug 9, 2024
1 parent 7cdfa0e commit bff5aef
Show file tree
Hide file tree
Showing 13 changed files with 2,369 additions and 1 deletion.
5 changes: 4 additions & 1 deletion dali/pipeline/executor/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,6 +13,9 @@
# limitations under the License.

# Get all the source files and dump test files

add_subdirectory(executor2)

collect_headers(DALI_INST_HDRS PARENT_SCOPE)
collect_sources(DALI_SRCS PARENT_SCOPE)
collect_test_sources(DALI_TEST_SRCS PARENT_SCOPE)
18 changes: 18 additions & 0 deletions dali/pipeline/executor/executor2/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Get all the source files and dump test files
collect_headers(DALI_INST_HDRS PARENT_SCOPE)
collect_sources(DALI_SRCS PARENT_SCOPE)
collect_test_sources(DALI_TEST_SRCS PARENT_SCOPE)
107 changes: 107 additions & 0 deletions dali/pipeline/executor/executor2/exec2_ops_for_test.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dali/pipeline/executor/executor2/exec2_ops_for_test.h"
#include <vector>
#include "dali/pipeline/operator/op_schema.h"
#include "dali/kernels/dynamic_scratchpad.h"
#include "dali/kernels/common/scatter_gather.h"
#include "dali/core/span.h"

namespace dali {

DALI_SCHEMA(Exec2TestOp) // DALI_SCHEMA can't take a macro :(
.NumInput(0, 99)
.NumOutput(1)
.AddOptionalArg("delay", "[CPU-only] in milliseconds, to wait inside the operator's Run", 1.0f)
.AddArg("addend", "a value added to the sum of inputs", DALI_INT32, true);

// DALI_REGISTER_OPERATOR can't take a macro for the name
DALI_REGISTER_OPERATOR(Exec2TestOp, exec2::test::DummyOpCPU, CPU);
DALI_REGISTER_OPERATOR(Exec2TestOp, exec2::test::DummyOpGPU, GPU);

DALI_SCHEMA(Exec2Counter)
.NumInput(0)
.NumOutput(1);

DALI_REGISTER_OPERATOR(Exec2Counter, exec2::test::CounterOp, CPU);

namespace exec2 {
namespace test {

__global__ void Sum(int *out, const int **ins, int nbuf, int buf_size) {
int x = blockIdx.x * blockDim.x + threadIdx.x;
if (x >= buf_size)
return;
out[x] = 0;
for (int i = 0; i < nbuf; i++)
out[x] += ins[i][x];
}

void DummyOpGPU::RunImpl(Workspace &ws) {
kernels::DynamicScratchpad scratch({}, ws.stream());
int N = ws.GetRequestedBatchSize(0);
addend_.Acquire(spec_, ws, N);
scratch.Allocate<mm::memory_kind::device, int>(N);
auto addend_cpu = addend_.get();
std::vector<const int *> pointers;

std::vector<int> addend_cpu_cont(N);
for (int i = 0; i < N; i++)
addend_cpu_cont[i] = *addend_cpu[i].data;

// This should go early - if we clobber some memory with that, we'll see it
pointers.push_back(
scratch.ToGPU(ws.stream(), make_cspan(addend_cpu_cont)));

kernels::ScatterGatherGPU sg;
for (int i = 0; i < ws.NumInput(); i++) {
auto &inp = ws.Input<GPUBackend>(i);
if (!inp.IsContiguousInMemory()) {
int *cont = scratch.AllocateGPU<int>(N);
for (int s = 0; s < N; s++)
sg.AddCopy(cont + s, inp[s].data<int>(), sizeof(int));
pointers.push_back(cont);
} else {
pointers.push_back(inp[0].data<int>());
}
}

sg.Run(ws.stream());

// The goal of this part is to introduce a delay in the GPU - and make the results come late
size_t junk_size = 4 << 20;
char *junk = scratch.Allocate<mm::memory_kind::device, char>(junk_size);
for (int i = 0; i < 256; i++)
CUDA_CALL(cudaMemsetAsync(junk, i, junk_size, ws.stream()));

// After the delay, we can finally run the test body
auto &out = ws.Output<GPUBackend>(0);
if (!out.IsContiguousInMemory()) {
out.Reset();
out.SetContiguity(dali::BatchContiguity::Contiguous);
out.Resize(uniform_list_shape(N, TensorShape<0>()), DALI_INT32);
assert(out.IsContiguousInMemory());
}
Sum<<<div_ceil(N, 256), 256, 0, ws.stream()>>>(
out[0].mutable_data<int>(),
scratch.ToGPU(ws.stream(), pointers),
ws.NumInput() + 1,
N);
}


} // namespace test
} // namespace exec2
} // namespace dali
148 changes: 148 additions & 0 deletions dali/pipeline/executor/executor2/exec2_ops_for_test.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_PIPELINE_EXECUTOR_EXECUTOR2_EXEC2_OPS_FOR_TEST_H_
#define DALI_PIPELINE_EXECUTOR_EXECUTOR2_EXEC2_OPS_FOR_TEST_H_

#include <gtest/gtest.h>
#include <string>
#include <thread>
#include <chrono>
#include <utility>
#include <vector>
#include "dali/pipeline/operator/operator.h"
#include "dali/pipeline/operator/arg_helper.h"
#include "dali/pipeline/graph/op_graph2.h"

namespace dali {
namespace exec2 {
namespace test {

constexpr char kTestOpName[] = "Exec2TestOp";

/** A dummy operator that takes a bunch of scalar inputs and returns their sum.
*
* This operator contains a sleep to increase latency and expose bugs.
*/
class DummyOpCPU : public Operator<CPUBackend> {
public:
explicit DummyOpCPU(const OpSpec &spec) : Operator<CPUBackend>(spec) {
instance_name_ = spec_.GetArgument<string>("name");
delay_ms_ = spec_.GetArgument<float>("delay");
}

bool SetupImpl(std::vector<OutputDesc> &outs, const Workspace &ws) override {
int N = ws.GetRequestedBatchSize(0);
outs.resize(ws.NumOutput());
outs[0].shape = uniform_list_shape(N, TensorShape<>{});
outs[0].type = DALI_INT32;
return true;
}

void RunImpl(Workspace &ws) override {
int N = ws.GetRequestedBatchSize(0);
if (delay_ms_)
std::this_thread::sleep_for(std::chrono::duration<double, std::milli>(delay_ms_));
addend_.Acquire(spec_, ws, N);
sample_sums_.resize(N);
auto &tp = ws.GetThreadPool();
for (int s = 0; s < N; s++) {
auto sample_sum = [&, s](int) {
int sum = *addend_[s].data + s;
for (int i = 0; i < ws.NumInput(); i++) {
sum += *ws.Input<CPUBackend>(i)[s].data<int>();
}
sample_sums_[s] = sum;
};
tp.AddWork(sample_sum);
}
tp.RunAll(true);
for (int s = 0; s < N; s++)
*ws.Output<CPUBackend>(0)[s].mutable_data<int>() = sample_sums_[s];
}

bool CanInferOutputs() const override { return true; }
ArgValue<int> addend_{"addend", spec_};
double delay_ms_ = 0;

std::vector<int> sample_sums_;
std::string instance_name_;
};

/** A dummy operator that takes a bunch of scalar inputs and returns their sum.
*
* This operator introduces some pointless GPU work to increase latency and expose bugs.
*/
class DummyOpGPU : public Operator<GPUBackend> {
public:
explicit DummyOpGPU(const OpSpec &spec) : Operator<GPUBackend>(spec) {
instance_name_ = spec_.GetArgument<string>("name");
}

bool SetupImpl(std::vector<OutputDesc> &outs, const Workspace &ws) override {
int N = ws.GetRequestedBatchSize(0);
outs.resize(ws.NumOutput());
outs[0].shape = uniform_list_shape(N, TensorShape<>{});
outs[0].type = DALI_INT32;
return true;
}

void RunImpl(Workspace &ws) override;

bool CanInferOutputs() const override { return true; }

private:
ArgValue<int> addend_{"addend", spec_};

std::string instance_name_;
};


constexpr char kCounterOpName[] = "Exec2Counter";

/** An operator with state.
*
* This operator counts iterations. Its purpose is to check that iterations are executed
* in correct order.
*/
class CounterOp : public Operator<CPUBackend> {
public:
explicit CounterOp(const OpSpec &spec) : Operator<CPUBackend>(spec) {
}

bool SetupImpl(std::vector<OutputDesc> &outs, const Workspace &ws) override {
int N = ws.GetRequestedBatchSize(0);
outs.resize(ws.NumOutput());
outs[0].shape = uniform_list_shape(N, TensorShape<>{});
outs[0].type = DALI_INT32;
return true;
}

void RunImpl(Workspace &ws) override {
int N = ws.GetRequestedBatchSize(0);
for (int s = 0; s < N; s++) {
*ws.Output<CPUBackend>(0)[s].mutable_data<int>() = counter++;
}
}

bool CanInferOutputs() const override { return true; }

int counter = 0;
};

} // namespace test
} // namespace exec2
} // namespace dali

#endif // DALI_PIPELINE_EXECUTOR_EXECUTOR2_EXEC2_OPS_FOR_TEST_H_
Loading

0 comments on commit bff5aef

Please sign in to comment.