Improving the *.tensor.trace op to carry shapes/encodings. (#15228)
This avoids a bunch of IR noise when lowering into the stream dialect
and forces all traced tensors to be transferred into staging resources.
This means we no longer require mapping (or worse, emulated mapping) on
any target in order to trace, as we'll have the compiler implement the
readback asynchronously.
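
For illustration, the new textual form inlines each traced tensor's type and
dynamic dimensions (a sketch with hypothetical SSA names, mirroring the
updated tests below):

  flow.tensor.trace "debug_key" = [
    %tensor0 : tensor<5xf32>,
    %tensor1 : tensor<?x3x?xi32>{%dim0, %dim2}
  ]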

A change I held off on here but should be done soon is making the trace
ops return their tensors so that we can asynchronously schedule them.
Then instead of emitting runtime calls on staging resources we could
batch up the transfers into a ringbuffer and trace out without
introducing dispatch-to-dispatch host synchronization.
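
A purely hypothetical sketch of that follow-up (not part of this commit), in
which the op returns its tensors so consumers naturally order after the trace:

  // Hypothetical result-carrying form; names and syntax are illustrative only.
  %traced = flow.tensor.trace "debug_key" = [%input : tensor<4xf32>] : tensor<4xf32>
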
benvanik authored Oct 19, 2023
1 parent c232aeb commit 3a70dda
Showing 32 changed files with 559 additions and 140 deletions.
74 changes: 74 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -190,6 +190,51 @@ static bool doesSliceSpanWholeTarget(
return true;
}

//===----------------------------------------------------------------------===//
// custom<ShapedOperandList>($values, type($values), $value_dims)
//===----------------------------------------------------------------------===//
// %value : type{%dynamic_dims}, ...

static ParseResult parseShapedOperandList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
SmallVectorImpl<Type> &valueTypes,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &valueDims) {
do {
values.emplace_back();
valueTypes.emplace_back();
if (failed(parser.parseOperand(values.back())) ||
failed(parser.parseColon()) ||
failed(parser.parseType(valueTypes.back())))
return failure();
if (int64_t dynamicDimCount =
cast<ShapedType>(valueTypes.back()).getNumDynamicDims()) {
if (failed(parser.parseOperandList(valueDims, dynamicDimCount,
AsmParser::Delimiter::Braces)))
return failure();
}
} while (succeeded(parser.parseOptionalComma()));
return success();
}

static void printShapedOperandList(OpAsmPrinter &p, Operation *op,
ValueRange values, TypeRange valueTypes,
ValueRange valueDims) {
llvm::interleaveComma(llvm::zip_equal(values, valueTypes), p, [&](auto it) {
auto [value, valueType] = it;
p << value;
p << " : ";
p << valueType;
if (int64_t dynamicDimCount =
cast<ShapedType>(valueType).getNumDynamicDims()) {
p << "{";
llvm::interleaveComma(valueDims.take_front(dynamicDimCount), p);
valueDims = valueDims.drop_front(dynamicDimCount);
p << "}";
}
});
}

//===----------------------------------------------------------------------===//
// custom<WorkgroupCountRegion>($body)
//===----------------------------------------------------------------------===//
@@ -1648,6 +1693,35 @@ SmallVector<int64_t> TensorUpdateOp::getTiedResultOperandIndices() {
return {0}; // target
}

//===----------------------------------------------------------------------===//
// flow.tensor.trace
//===----------------------------------------------------------------------===//

LogicalResult TensorTraceOp::verify() {
TensorTraceOp op = *this;
if (failed(verifyOpDynamicDims(op, op.getValues(), op.getValueDims()))) {
return failure();
}
return success();
}

ValueRange TensorTraceOp::getOperandDynamicDims(unsigned idx) {
auto valueDims = getValueDims();
for (unsigned i = 0; i <= idx; ++i) {
auto valueType = cast<ShapedType>(getValues()[i].getType());
int64_t dynamicDimCount = valueType.getNumDynamicDims();
if (i == idx) {
return valueDims.take_front(dynamicDimCount);
}
valueDims = valueDims.drop_front(dynamicDimCount);
}
return ValueRange{};
}

ValueRange TensorTraceOp::getResultDynamicDims(unsigned idx) {
return ValueRange{};
}
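
As a worked example of the flattened packing this walks (hypothetical IR):
value_dims stores every dynamic dimension in operand order, so the method
skips earlier operands' dims before taking its own.

  flow.tensor.trace "k" = [
    %t0 : tensor<?x3x?xi32>{%d0, %d2},  // getOperandDynamicDims(0) == (%d0, %d2)
    %t1 : tensor<?xf32>{%d3}            // getOperandDynamicDims(1) == (%d3)
  ]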

//===----------------------------------------------------------------------===//
// Public methods
//===----------------------------------------------------------------------===//
30 changes: 23 additions & 7 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -1509,21 +1509,37 @@ def FLOW_TensorUpdateOp : FLOW_PureOp<"tensor.update", [
let hasFolder = 1;
}

-def FLOW_TensorTraceOp : FLOW_Op<"tensor.trace", []> {
-let summary = [{trace value(s) operation}];
+def FLOW_TensorTraceOp : FLOW_Op<"tensor.trace", [
+AttrSizedOperandSegments,
+DeclareOpInterfaceMethods<Util_ShapeAwareOp>,
+]> {
+let summary = [{traces one or more tensor values at runtime}];
let description = [{
Traces out to a runtime trace sink (console, log file, etc) the given
-tensors and titles them with the given key. The key is informational only
-and useful for titling/marking specific sets of tensors for easier
-searching.
+tensors. The key is arbitrary and can be used for identifying the set of
+values being traced.
}];

let arguments = (ins
StrAttr:$key,
-Variadic<FLOW_Tensor>:$operands
+Variadic<FLOW_Tensor>:$values,
+FLOW_ShapeDynamicDims:$value_dims
);

let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
let assemblyFormat = [{
$key `=` `[`
custom<ShapedOperandList>($values, type($values), $value_dims)
`]` attr-dict-with-keyword
}];

let builders = [
OpBuilder<(ins "StringRef":$key, "ValueRange":$values), [{
build($_builder, $_state, key, values,
IREE::Util::buildDynamicDimsForValues($_state.location, values, $_builder));
}]>,
];

let hasVerifier = 1;
}

//===---------------------------------------------------------------------===//
16 changes: 16 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_ops.mlir
@@ -180,3 +180,19 @@ func.func @tensorUpdateDynamic(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x4xf32>,
%0 = flow.tensor.update %arg0, %arg1[%arg2, %arg3] : tensor<?x?xf32>{%c1, %c2} -> %arg1 as tensor<?x4xf32>{%c3}
return %0 : tensor<?x4xf32>
}

// -----

// CHECK-LABEL: @tensorTrace
// CHECK-SAME: (%[[TENSOR0:.+]]: tensor<5xf32>, %[[TENSOR1:.+]]: tensor<?x3x?xi32>, %[[TENSOR1_DIM0:.+]]: index, %[[TENSOR1_DIM2:.+]]: index)
func.func @tensorTrace(%tensor0: tensor<5xf32>, %tensor1: tensor<?x3x?xi32>, %tensor1_dim0: index, %tensor1_dim2: index) {
// CHECK: flow.tensor.trace "FOOBAR" = [
// CHECK-SAME: %[[TENSOR0]] : tensor<5xf32>,
// CHECK-SAME: %[[TENSOR1]] : tensor<?x3x?xi32>{%[[TENSOR1_DIM0]], %[[TENSOR1_DIM2]]}
// CHECK-SAME: ]
flow.tensor.trace "FOOBAR" = [
%tensor0 : tensor<5xf32>,
%tensor1 : tensor<?x3x?xi32>{%tensor1_dim0, %tensor1_dim2}
]
return
}
compiler/src/iree/compiler/Dialect/Flow/Transforms/InjectDispatchTracing.cpp
@@ -46,14 +46,14 @@ class InjectDispatchTracingPass

// Input tensors:
OpBuilder builder(dispatchOp);
-builder.create<TensorTraceOp>(
+builder.create<IREE::Flow::TensorTraceOp>(
dispatchOp.getLoc(),
builder.getStringAttr(entryPointName + " inputs"),
filterTensorValues(dispatchOp.getArguments()));

// Output tensors:
builder.setInsertionPointAfter(dispatchOp);
-builder.create<TensorTraceOp>(
+builder.create<IREE::Flow::TensorTraceOp>(
dispatchOp.getLoc(),
builder.getStringAttr(entryPointName + " outputs"),
filterTensorValues(dispatchOp.getResults()));
@@ -55,19 +55,20 @@ getOrdinalFromDebugTarget(std::string marker) {
}

// Inserts flow.tensor.trace ops around the specified dispatch op.
-static void traceOpWithName(DispatchOp dispatchOp, std::string name) {
+static void traceOpWithName(IREE::Flow::DispatchOp dispatchOp,
+std::string name) {
OpBuilder builder(dispatchOp);
// Input tensors:
-builder.create<TensorTraceOp>(
+builder.create<IREE::Flow::TensorTraceOp>(
dispatchOp.getLoc(), builder.getStringAttr(name + " inputs"),
filterNonTensorValues(dispatchOp.getArguments()));

// Output tensors:
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointAfter(dispatchOp);
-builder.create<TensorTraceOp>(dispatchOp.getLoc(),
-builder.getStringAttr(name + " outputs"),
-filterNonTensorValues(dispatchOp.getResults()));
+builder.create<IREE::Flow::TensorTraceOp>(
+dispatchOp.getLoc(), builder.getStringAttr(name + " outputs"),
+filterNonTensorValues(dispatchOp.getResults()));
}

// Breaks the given function on the specified op by simply returning immediately
@@ -162,7 +163,8 @@ struct InsertDebugTargetAtOrdinalPass
localTraceOrdinal = traceOrdinal;

auto &bodyRegion = op.getFunctionBody();
-auto dispatchOps = llvm::to_vector<8>(bodyRegion.getOps<DispatchOp>());
+auto dispatchOps =
+llvm::to_vector<8>(bodyRegion.getOps<IREE::Flow::DispatchOp>());

// Trace on a valid ordinal.
if (localTraceOrdinal >= 0 && localTraceOrdinal < dispatchOps.size()) {
@@ -222,8 +224,8 @@ struct InsertDebugTargetAtSymbolPass

// Find the target dispatch to break on and trace on all matching
// dispatches.
-DispatchOp breakTarget = DispatchOp();
-funcOp.walk([&](DispatchOp dispatchOp) {
+IREE::Flow::DispatchOp breakTarget;
+funcOp.walk([&](IREE::Flow::DispatchOp dispatchOp) {
std::string entryPointName =
dispatchOp.getEntryPoint().getRootReference().getValue().str();
for (FlatSymbolRefAttr nestedRef :
compiler/src/iree/compiler/Dialect/Flow/Transforms/test/inject_dispatch_tracing.mlir
@@ -4,10 +4,10 @@
// CHECK-SAME: (%[[ARG0:.+]]: tensor<4xf32>)
func.func @singleDispatch(%arg0: tensor<4xf32>) -> tensor<4xf32> {
%c4 = arith.constant 4 : index
-// CHECK: flow.tensor.trace {key = "ex::entry0 inputs"} %[[ARG0]] : tensor<4xf32>
+// CHECK: flow.tensor.trace "ex::entry0 inputs" = [%[[ARG0]] : tensor<4xf32>]
// CHECK-NEXT: %[[RET0:.+]] = flow.dispatch @ex::@entry0[%c4](%[[ARG0]]) : (tensor<4xf32>) -> tensor<4xf32>
%0 = flow.dispatch @ex::@entry0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK-NEXT: flow.tensor.trace {key = "ex::entry0 outputs"} %[[RET0]] : tensor<4xf32>
+// CHECK-NEXT: flow.tensor.trace "ex::entry0 outputs" = [%[[RET0]] : tensor<4xf32>]
// CHECK-NEXT: return %[[RET0]]
return %0 : tensor<4xf32>
}
// CHECK-NEXT: return %[[RET0]]
return %0 : tensor<4xf32>
}
@@ -19,15 +19,15 @@ func.func @singleDispatch(%arg0: tensor<4xf32>) -> tensor<4xf32> {
func.func @multiDispatch(%arg0: tensor<4xf32>) -> tensor<4xf32> {
%c4 = arith.constant 4 : index

-// CHECK: flow.tensor.trace {key = "ex::entry0 inputs"} %[[ARG0]] : tensor<4xf32>
+// CHECK: flow.tensor.trace "ex::entry0 inputs" = [%[[ARG0]] : tensor<4xf32>]
// CHECK-NEXT: %[[RET0:.+]] = flow.dispatch @ex::@entry0[%c4](%[[ARG0]]) : (tensor<4xf32>) -> tensor<4xf32>
%0 = flow.dispatch @ex::@entry0[%c4](%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK-NEXT: flow.tensor.trace {key = "ex::entry0 outputs"} %[[RET0]] : tensor<4xf32>
+// CHECK-NEXT: flow.tensor.trace "ex::entry0 outputs" = [%[[RET0]] : tensor<4xf32>]

-// CHECK: flow.tensor.trace {key = "ex::entry1 inputs"} %[[RET0]] : tensor<4xf32>
+// CHECK: flow.tensor.trace "ex::entry1 inputs" = [%[[RET0]] : tensor<4xf32>]
// CHECK-NEXT: %[[RET1:.+]] = flow.dispatch @ex::@entry1[%c4](%[[RET0]]) : (tensor<4xf32>) -> tensor<4xf32>
%1 = flow.dispatch @ex::@entry1[%c4](%0) : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK-NEXT: flow.tensor.trace {key = "ex::entry1 outputs"} %[[RET1]] : tensor<4xf32>
+// CHECK-NEXT: flow.tensor.trace "ex::entry1 outputs" = [%[[RET1]] : tensor<4xf32>]

// CHECK: return %[[RET1]]
return %1 : tensor<4xf32>
@@ -2,17 +2,17 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-flow-insert-debug-target-at-ordinal{break-debug-target=@target_func:0 trace-debug-target=@target_func:0})" %s | FileCheck %s --check-prefixes=ORDINAL_0
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-flow-insert-debug-target-at-symbol{break-debug-target=dispatch_1 trace-debug-target=dispatch_1[^0-9]})" %s | FileCheck %s --check-prefixes=CHECK,SYMBOL

-/// Multiple functions
+// Multiple functions.

// CHECK-LABEL: func.func @target_func
// ORDINAL_0-LABEL: func.func @target_func
func.func @target_func(%arg0: tensor<4xf32>) -> !hal.buffer_view {
%c4 = arith.constant 4 : index
// CHECK: %[[D0:.+]] = flow.dispatch @dispatch_0::@dispatch_0_entry
-// ORDINAL_0: flow.tensor.trace {key = "dispatch_0::dispatch_0_entry::0 inputs"}
+// ORDINAL_0: flow.tensor.trace "dispatch_0::dispatch_0_entry::0 inputs"
// ORDINAL_0-NEXT: %[[D0:.+]] = flow.dispatch @dispatch_0::@dispatch_0_entry
%0 = flow.dispatch @dispatch_0::@dispatch_0_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-// ORDINAL_0-NEXT: flow.tensor.trace {key = "dispatch_0::dispatch_0_entry::0 outputs"}
+// ORDINAL_0-NEXT: flow.tensor.trace "dispatch_0::dispatch_0_entry::0 outputs"
// CHECK: %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1_entry
%1 = flow.dispatch @dispatch_1::@dispatch_1_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
%2 = flow.dispatch @dispatch_2::@dispatch_2_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
@@ -22,31 +22,31 @@ func.func @target_func(%arg0: tensor<4xf32>) -> !hal.buffer_view {
return %3 : !hal.buffer_view
}


// CHECK-LABEL: func.func @other_func
func.func @other_func(%arg0: tensor<4xf32>) -> !hal.buffer_view {
%c4 = arith.constant 4 : index
-// CHECK: %[[D0:.+]] = flow.dispatch @dispatch_1::@dispatch_1_entry
-%0 = flow.dispatch @dispatch_1::@dispatch_1_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: %[[D3:.+]] = flow.dispatch @dispatch_3::@dispatch_3_entry
+%0 = flow.dispatch @dispatch_3::@dispatch_3_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>

-// CHECK: %[[D1:.+]] = flow.dispatch @dispatch_2::@dispatch_2_entry
-%1 = flow.dispatch @dispatch_2::@dispatch_2_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK: %[[D2:.+]] = flow.dispatch @dispatch_3::@dispatch_3_entry
-%2 = flow.dispatch @dispatch_3::@dispatch_3_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: %[[D4:.+]] = flow.dispatch @dispatch_4::@dispatch_4_entry
+%1 = flow.dispatch @dispatch_4::@dispatch_4_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: %[[D5:.+]] = flow.dispatch @dispatch_5::@dispatch_5_entry
+%2 = flow.dispatch @dispatch_5::@dispatch_5_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>

-// ORDINAL: %[[ORIGINAL_EXPORT:.+]] = hal.tensor.export %[[D2]] : tensor<4xf32> -> !hal.buffer_view
-// SYMBOL: %[[BREAK_EXPORT:.+]] = hal.tensor.export %[[D0]] : tensor<4xf32> -> !hal.buffer_view
+// ORDINAL: %[[ORIGINAL_EXPORT:.+]] = hal.tensor.export %[[D5]] : tensor<4xf32> -> !hal.buffer_view
+// SYMBOL: %[[BREAK_EXPORT:.+]] = hal.tensor.export %[[D5]] : tensor<4xf32> -> !hal.buffer_view
%3 = hal.tensor.export %2 : tensor<4xf32> -> !hal.buffer_view

-/// Only break on the symbol as the ordinal specifies a different function
+// Only break on the symbol as the ordinal specifies a different function.
// SYMBOL: return %[[BREAK_EXPORT]] : !hal.buffer_view
// ORDINAL: return %[[ORIGINAL_EXPORT]] : !hal.buffer_view
return %3 : !hal.buffer_view
}

// -----

-// Break on a dispatch with a different number of results
+// Break on a dispatch with a different number of results.

// CHECK-LABEL: func.func @target_func
func.func @target_func(%arg0: tensor<4xf32>) -> !hal.buffer_view {
%c4 = arith.constant 4 : index
@@ -64,7 +64,8 @@ func.func @target_func(%arg0: tensor<4xf32>) -> !hal.buffer_view {

// -----

-// Break/trace on a dispatch not found in the target function should do nothing
+// Break/trace on a dispatch not found in the target function should do nothing.
+
// CHECK-LABEL: func.func @target_func
func.func @target_func(%arg0: tensor<4xf32>) -> !hal.buffer_view {
%c4 = arith.constant 4 : index
@@ -78,20 +79,21 @@ func.func @target_func(%arg0: tensor<4xf32>) -> !hal.buffer_view {

// -----

-/// Combine tracing and breaking on the same dispatch
+// Combines tracing and breaking on the same dispatch.
+
// CHECK-LABEL: func.func @target_func
// CHECK-SAME: %[[ARG0:.+]]: tensor<4xf32>
func.func @target_func(%arg0: tensor<4xf32>) -> !hal.buffer_view {
%c4 = arith.constant 4 : index
// CHECK: %[[D0:.+]] = flow.dispatch @dispatch_0::@dispatch_0_entry
%0 = flow.dispatch @dispatch_0::@dispatch_0_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>

-// ORDINAL: flow.tensor.trace {key = "dispatch_1::dispatch_1_entry::1 inputs"} %[[ARG0]] : tensor<4xf32>
-// SYMBOL: flow.tensor.trace {key = "dispatch_1::dispatch_1_entry inputs"} %[[ARG0]] : tensor<4xf32>
+// ORDINAL: flow.tensor.trace "dispatch_1::dispatch_1_entry::1 inputs" = [%[[ARG0]] : tensor<4xf32>]
+// SYMBOL: flow.tensor.trace "dispatch_1::dispatch_1_entry inputs" = [%[[ARG0]] : tensor<4xf32>]
// CHECK: %[[D1:.+]] = flow.dispatch @dispatch_1::@dispatch_1_entry
%1 = flow.dispatch @dispatch_1::@dispatch_1_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-// ORDINAL: flow.tensor.trace {key = "dispatch_1::dispatch_1_entry::1 outputs"} %[[D1]] : tensor<4xf32>
-// SYMBOL: flow.tensor.trace {key = "dispatch_1::dispatch_1_entry outputs"} %[[D1]] : tensor<4xf32>
+// ORDINAL: flow.tensor.trace "dispatch_1::dispatch_1_entry::1 outputs" = [%[[D1]] : tensor<4xf32>]
+// SYMBOL: flow.tensor.trace "dispatch_1::dispatch_1_entry outputs" = [%[[D1]] : tensor<4xf32>]

%2 = flow.dispatch @dispatch_2::@dispatch_2_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
%3 = hal.tensor.export %2 : tensor<4xf32> -> !hal.buffer_view
Expand All @@ -103,19 +105,21 @@ func.func @target_func(%arg0: tensor<4xf32>) -> !hal.buffer_view {

// -----

-/// Check regex matching on symbol
+// Checks regex matching on a dispatch symbol.

// CHECK-LABEL: func.func @target_func
func.func @target_func(%arg0: tensor<4xf32>) -> !hal.buffer_view {
%c4 = arith.constant 4 : index
-// SYMBOL: flow.tensor.trace {key = "dispatch_1::dispatch_1_entry inputs"}
+
+// SYMBOL: flow.tensor.trace "dispatch_1::dispatch_1_entry inputs"
// CHECK: flow.dispatch @dispatch_1::@dispatch_1_entry
%0 = flow.dispatch @dispatch_1::@dispatch_1_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-// SYMBOL: flow.tensor.trace {key = "dispatch_1::dispatch_1_entry outputs"}
+// SYMBOL: flow.tensor.trace "dispatch_1::dispatch_1_entry outputs"

-// SYMBOL-NOT: flow.tensor.trace {key = "dispatch_11::dispatch_11_entry inputs"}
+// SYMBOL-NOT: flow.tensor.trace "dispatch_11::dispatch_11_entry inputs"
// CHECK: flow.dispatch @dispatch_11::@dispatch_11_entry
%1 = flow.dispatch @dispatch_11::@dispatch_11_entry[%c4] (%arg0) : (tensor<4xf32>) -> tensor<4xf32>
-// SYMBOL-NOT: flow.tensor.trace {key = "dispatch_11::dispatch_11_entry outputs"}
+// SYMBOL-NOT: flow.tensor.trace "dispatch_11::dispatch_11_entry outputs"

%2 = hal.tensor.export %1 : tensor<4xf32> -> !hal.buffer_view
return %2 : !hal.buffer_view
