Skip to content

Commit

Permalink
Add a preprocessing pass to move entire function into a single dispat…
Browse files Browse the repository at this point in the history
…ch. (#14578)

For cases where the model is very small and does not have much concurrency, it is better to move the entire function body into a single dispatch. Eventually the default heuristics can probably determine when a model is "too small", but for now this PR adds a pass that moves the entire function body into a single dispatch, both to surface the codegen issues such an approach exposes and to experiment with the heuristics needed to find such dispatches automatically.
  • Loading branch information
MaheshRavishankar authored Aug 9, 2023
1 parent dd89a32 commit eeb6e80
Show file tree
Hide file tree
Showing 18 changed files with 337 additions and 99 deletions.
114 changes: 64 additions & 50 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,31 +280,78 @@ static void printDispatchWorkgroupsCountRegion(OpAsmPrinter &p, Operation *op,
// flow.dispatch.region
//===----------------------------------------------------------------------===//

// Verifies that the workgroup count region of `op` is well formed: the region
// must capture exactly one block argument per workload operand (with matching
// types), and every `flow.return` inside it must yield exactly three `index`
// values — the X, Y and Z workgroup counts.
static LogicalResult
verifyWorkgroupCountRegion(Operation *op, ValueRange workload, Region &region) {
  // The region captures one argument per workload operand.
  size_t numWorkload = workload.size();
  if (numWorkload != region.getNumArguments()) {
    return op->emitOpError()
           << "workload operands and workgroup count args mismatch ("
           << numWorkload << " vs " << region.getNumArguments() << ")";
  }
  for (size_t i = 0; i < numWorkload; ++i) {
    Type operandType = workload[i].getType();
    Type captureType = region.getArgument(i).getType();
    if (operandType != captureType) {
      return op->emitOpError()
             << "workload value " << i << " type mismatch; operand is "
             << operandType << " but region captures " << captureType;
    }
  }

  // Every terminator must yield the XYZ dimension counts as `index` values.
  for (auto returnOp : region.getOps<IREE::Flow::ReturnOp>()) {
    bool allIndex = llvm::all_of(returnOp.getOperandTypes(),
                                 [](Type type) { return type.isIndex(); });
    if (returnOp.getNumOperands() != 3 || !allIndex) {
      return returnOp.emitOpError() << "workgroup count region must return "
                                       "the XYZ dimension counts";
    }
  }

  return success();
}

// Verifies flow.dispatch.region: no captured block arguments, every
// `flow.return` terminator's operand types match the op's result types, all
// results are ranked tensors, and — when present — the workgroup count
// region is consistent with the workload operands.
// NOTE: this span in the diff interleaved deleted and added lines; this is
// the reconstructed post-commit version (multi-block bodies are now allowed,
// so the old single-block check is gone).
LogicalResult DispatchRegionOp::verify() {
  // No block arguments.
  if (!getBody().getArguments().empty()) {
    return emitOpError() << "expected no block arguments";
  }

  // Verify terminators: collect the `flow.return` of every block (the body
  // may have multiple blocks, e.g. with cf.cond_br control flow).
  SmallVector<Flow::ReturnOp> returnOps;
  for (Block &block : getBody()) {
    if (auto returnOp =
            dyn_cast_or_null<Flow::ReturnOp>(block.getTerminator())) {
      returnOps.push_back(returnOp);
    }
  }
  // Each return must match the op's result types exactly.
  for (auto returnOp : returnOps) {
    for (const auto [resultType, returnType] :
         llvm::zip_equal(getResultTypes(), returnOp->getOperandTypes())) {
      if (resultType != returnType) {
        return returnOp->emitOpError()
               << "operand types do not match with parent results";
      }
    }
  }

  // Make sure that all returned values are ranked tensors.
  for (Type t : getResultTypes()) {
    if (!llvm::isa<RankedTensorType>(t)) {
      return emitOpError() << "only ranked tensor results are allowed";
    }
  }

  // The workgroup count region is optional.
  Region &workgroupCount = getWorkgroupCount();
  if (workgroupCount.empty()) {
    return success();
  }

  // If the workgroup count region exists, verify it matches the workload and
  // returns the XYZ dimension counts.
  return verifyWorkgroupCountRegion(getOperation(), getWorkload(),
                                    getWorkgroupCount());
}

ParseResult DispatchRegionOp::parse(OpAsmParser &parser,
Expand Down Expand Up @@ -348,7 +395,6 @@ ParseResult DispatchRegionOp::parse(OpAsmParser &parser,
return failure();
if (parser.parseRegion(*bodyRegion))
return failure();
ensureTerminator(*bodyRegion, parser.getBuilder(), result.location);

if (parseDispatchWorkgroupsCountRegion(parser, *workloadCountRegion)) {
return failure();
Expand Down Expand Up @@ -868,38 +914,6 @@ static void printDispatchWorkgroupBody(OpAsmPrinter &p, Operation *op,
/*printBlockTerminators=*/true);
}

// Verifies that `region` (a workgroup count region) is consistent with the
// `workload` operands of `op`: the region must capture one block argument per
// workload value with matching types, and every `flow.return` inside it must
// yield exactly three `index` values — the X, Y and Z workgroup counts.
LogicalResult verifyWorkgroupCountRegion(Operation *op, ValueRange workload,
                                         Region &region) {
  // Verify the workload operands match the expected capture args.
  if (workload.size() != region.getNumArguments()) {
    return op->emitOpError()
           << "workload operands and workgroup count args mismatch ("
           << workload.size() << " vs " << region.getNumArguments() << ")";
  }
  // Pairwise type check; zip_equal asserts equal lengths (guaranteed above).
  for (auto [index, values] :
       llvm::enumerate(llvm::zip_equal(workload, region.getArguments()))) {
    auto [workloadValue, capturedArg] = values;
    if (workloadValue.getType() != capturedArg.getType()) {
      return op->emitOpError()
             << "workload value " << index << " type mismatch; operand is "
             << workloadValue.getType() << " but region captures "
             << capturedArg.getType();
    }
  }

  // Verify the return ops all provide XYZ values.
  for (auto returnOp : region.getOps<IREE::Flow::ReturnOp>()) {
    if (returnOp.getNumOperands() != 3 ||
        !llvm::all_of(returnOp.getOperandTypes(),
                      [](Type type) { return type.isIndex(); })) {
      return returnOp.emitOpError() << "workgroup count region must return "
                                       "the XYZ dimension counts";
    }
  }

  return success();
}

LogicalResult DispatchWorkgroupsOp::verify() {
Operation *op = getOperation();

Expand Down Expand Up @@ -1043,7 +1057,7 @@ DispatchWorkgroupsOp::getOperandAccess(unsigned operandIndex) {

IREE::Util::ValueAccess
DispatchWorkgroupsOp::getResultAccess(unsigned resultIndex) {
unsigned startIndex = getBody()->getNumArguments() - getNumResults();
unsigned startIndex = getWorkgroupBody().getNumArguments() - getNumResults();
BlockArgument arg =
getWorkgroupBody().front().getArgument(startIndex + resultIndex);
if (auto tensorType = llvm::dyn_cast<DispatchTensorType>(arg.getType())) {
Expand Down
4 changes: 1 addition & 3 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ let opDocGroup = OpGroupPartitionedRegionOps in {

def FLOW_DispatchRegionOp : FLOW_PureOp<"dispatch.region", [
Util_ShapeAwareOp,
AttrSizedOperandSegments,
SingleBlockImplicitTerminator<"IREE::Flow::ReturnOp">]> {
AttrSizedOperandSegments]> {
let summary = [{a group of ops}];
let description = [{
This op is a container/grouping of ops. It represents a fusion group before
Expand Down Expand Up @@ -76,7 +75,6 @@ def FLOW_DispatchRegionOp : FLOW_PureOp<"dispatch.region", [
def FLOW_DispatchWorkgroupsOp : FLOW_PureOp<"dispatch.workgroups", [
IsolatedFromAbove,
AttrSizedOperandSegments,
SingleBlockImplicitTerminator<"IREE::Flow::ReturnOp">,
DeclareOpInterfaceMethods<Util_ClosureOpInterface>,
DeclareOpInterfaceMethods<Util_TiedOpInterface, [
"getTiedOperandsIndexAndLength",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ namespace {
// leave the cleanup of redundant work to further optimization passes to keep
// this simple.
static void captureDims(IREE::Flow::DispatchWorkgroupsOp dispatchOp) {
auto *entryBlock = dispatchOp.getBody();
Region &body = dispatchOp.getWorkgroupBody();
if (body.empty()) {
return;
}
auto *entryBlock = &body.front();

// Map of SSA values on the outside of the op to arguments on the inside.
// This lets us avoid capturing duplicate values - they'd be cleaned up
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,15 @@ findFirstTiedValueOutsideOfRegionOp(Flow::DispatchRegionOp regionOp,

while (!isOutside(value)) {
auto tiedOpInterface = value.getDefiningOp<IREE::Util::TiedOpInterface>();
if (!tiedOpInterface)
if (!tiedOpInterface) {
// Reached an op that does not implement the interface.
return std::nullopt;
}
value = tiedOpInterface.getTiedResultOperand(value);
if (!value)
if (!value) {
// Nothing is tied here.
return std::nullopt;
}
}

return value;
Expand All @@ -84,13 +86,13 @@ findFirstTiedValueOutsideOfRegionOp(Flow::DispatchRegionOp regionOp,
FailureOr<Flow::DispatchWorkgroupsOp>
rewriteFlowDispatchRegionToFlowDispatchWorkgroups(
Flow::DispatchRegionOp regionOp, RewriterBase &rewriter) {
// Only ops with a single block are supported.
Region &region = regionOp.getBody();
if (!region.hasOneBlock())
return failure();
Block &body = region.front();
auto terminator = cast<Flow::ReturnOp>(body.getTerminator());
unsigned numResults = terminator->getNumOperands();
// Currently this does not handle empty `flow.dispatch.region` ops.
if (region.empty()) {
return rewriter.notifyMatchFailure(regionOp,
"unhandled op with empty region");
}
unsigned numResults = regionOp->getNumResults();

// Prepare rewriter.
OpBuilder::InsertionGuard guard(rewriter);
Expand Down Expand Up @@ -118,7 +120,14 @@ rewriteFlowDispatchRegionToFlowDispatchWorkgroups(
DenseSet<Value> tiedArgumentsSet;
SmallVector<int64_t> tiedArguments(numResults,
IREE::Util::TiedOpInterface::kUntiedIndex);
for (const auto &it : llvm::enumerate(terminator->getOperands())) {
SmallVector<Flow::ReturnOp> origTerminators;
region.walk(
[&](Flow::ReturnOp returnOp) { origTerminators.push_back(returnOp); });
assert(!origTerminators.empty() && "expected at least one terminator");
// Use one of the terminators to get the `tiedArguments` set.
// TODO: Check that using all terminators gives you the same result.
for (const auto &it :
llvm::enumerate(origTerminators.front()->getOperands())) {
auto tiedArgument =
findFirstTiedValueOutsideOfRegionOp(regionOp, it.value());
if (!tiedArgument.has_value())
Expand Down Expand Up @@ -166,15 +175,20 @@ rewriteFlowDispatchRegionToFlowDispatchWorkgroups(
bvm.map(arguments, workgroupsOp.getInputBlockArguments());

// Create DispatchTensorLoadOp for all tensor arguments.
assert(workgroupsOp.getWorkgroupBody().hasOneBlock() &&
"expected one block after constructor");
Block &newBody = workgroupsOp.getWorkgroupBody().getBlocks().front();
assert(newBody.empty() && "expected empty block after constructor");
rewriter.setInsertionPointToStart(&newBody);
Region &newBody = workgroupsOp.getWorkgroupBody();
assert(llvm::hasSingleElement(newBody) &&
"expected `flow.dispatch.workgroup` op to be created with a single "
"block");

Block *newBodyEntry = &newBody.front();
rewriter.setInsertionPointToStart(newBodyEntry);
SmallVector<Value> argValues;
for (const auto &it : llvm::enumerate(arguments)) {
auto tensorType = llvm::dyn_cast<RankedTensorType>(it.value().getType());
if (!tensorType)
if (!tensorType) {
argValues.push_back(it.value());
continue;
}
auto inputBbArg = workgroupsOp.getInputBlockArgument(it.index());
auto dims =
Util::findVariadicDynamicDims(it.index(), arguments, argumentDims);
Expand All @@ -185,44 +199,55 @@ rewriteFlowDispatchRegionToFlowDispatchWorkgroups(
Value loadedTensor = rewriter.create<IREE::Flow::DispatchTensorLoadOp>(
loc, tensorType, inputBbArg, bbArgDims);
bvm.map(it.value(), loadedTensor);
argValues.push_back(loadedTensor);
}

// Move regionOp body into the workgroupsOp.
newBody.getOperations().splice(newBody.end(), body.getOperations());
rewriter.inlineRegionBefore(region, newBody, newBody.end());
// Merge the entry block of `newBody` with the original entry block from the
// region.
Block *origEntry = &(*(std::next(newBody.begin())));
rewriter.mergeBlocks(origEntry, newBodyEntry);

for (Value argument : arguments) {
argument.replaceUsesWithIf(bvm.lookup(argument), [&](OpOperand &operand) {
return workgroupsOp->isProperAncestor(operand.getOwner());
});
}

// Update terminator.
rewriter.setInsertionPoint(terminator);
for (const auto &it : llvm::enumerate(terminator->getOperands())) {
auto outputBbArg = workgroupsOp.getOutputBlockArgument(it.index());
ValueRange dims;
if (tiedArguments[it.index()] ==
IREE::Util::TiedOpInterface::kUntiedIndex) {
dims = regionOp.getResultDynamicDims(it.index());
} else {
// This assumes that the number of dynamic dims does not change when
// following an SSA use-def chain of tied values.
dims = Util::findVariadicDynamicDims(tiedArguments[it.index()], arguments,
argumentDims);
}
SmallVector<Flow::ReturnOp> terminators;
newBody.walk(
[&](Flow::ReturnOp returnOp) { terminators.push_back(returnOp); });
for (auto terminator : terminators) {
rewriter.setInsertionPoint(terminator);
for (const auto &it : llvm::enumerate(terminator->getOperands())) {
auto outputBbArg = workgroupsOp.getOutputBlockArgument(it.index());
ValueRange dims;
if (tiedArguments[it.index()] ==
IREE::Util::TiedOpInterface::kUntiedIndex) {
dims = regionOp.getResultDynamicDims(it.index());
} else {
// This assumes that the number of dynamic dims does not change when
// following an SSA use-def chain of tied values.
dims = Util::findVariadicDynamicDims(tiedArguments[it.index()],
arguments, argumentDims);
}
#ifndef NDEBUG
auto tensorType = it.value().getType().cast<RankedTensorType>();
assert(dims.size() == tensorType.getNumDynamicDims() &&
"mismatching number of dynamic dims");
auto tensorType = it.value().getType().cast<RankedTensorType>();
assert(dims.size() == tensorType.getNumDynamicDims() &&
"mismatching number of dynamic dims");
#endif // NDEBUG
SmallVector<Value> bbArgDims =
llvm::map_to_vector(dims, [&](Value v) { return bvm.lookup(v); });
rewriter.create<IREE::Flow::DispatchTensorStoreOp>(loc, it.value(),
outputBbArg, bbArgDims);
}
SmallVector<Value> bbArgDims =
llvm::map_to_vector(dims, [&](Value v) { return bvm.lookup(v); });
rewriter.create<IREE::Flow::DispatchTensorStoreOp>(
loc, it.value(), outputBbArg, bbArgDims);
}

// Delete the old terminator and create a new one.
rewriter.create<IREE::Flow::ReturnOp>(loc);
rewriter.eraseOp(terminator);
// Delete the old terminator and create a new one.
rewriter.create<IREE::Flow::ReturnOp>(loc);
rewriter.eraseOp(terminator);
}

rewriter.replaceOp(regionOp, workgroupsOp.getResults());
return workgroupsOp;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,11 @@ createDefaultWorkgroupCountRegion(RewriterBase &rewriter,

// Annotate the values captures as workload with their position in the
// workload list.
rewriter.setInsertionPointToStart(workgroupsOp.getBody());
Region &body = workgroupsOp.getWorkgroupBody();
if (body.empty()) {
return;
}
rewriter.setInsertionPointToStart(&body.front());
int ordinalNumber = 0;
for (auto [index, operand] : llvm::enumerate(workgroupsOp.getArguments())) {
if (!llvm::isa<IndexType>(operand.getType()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ summarizeDispatchWorkgroupsOp(DispatchWorkgroupsOp regionOp) {
Operation *bestOp = NULL;
const int64_t kMinEstimatedCost = -1;
int64_t bestEstimatedCost = kMinEstimatedCost;
regionOp.getBodyRegion().walk([&](Operation *op) {
regionOp.getWorkgroupBody().walk([&](Operation *op) {
TypeSwitch<Operation *>(op)
.Case<linalg::LinalgOp>([&](auto op) {
int64_t estimatedCost = estimateLinalgOpCost(op);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,32 @@ func.func @existing_count_region(%arg0 : index, %arg1 : index) -> tensor<?x?xf32
// CHECK: count(%[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index)
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: flow.return %[[ARG2]], %[[ARG3]], %[[C1]]

// -----

// Dispatch region whose body contains control flow: two successor blocks
// reached via cf.cond_br, each terminated by its own flow.return. Exercises
// conversion of multi-block regions (this commit removed the
// SingleBlockImplicitTerminator constraint from flow.dispatch.region).
func.func @simple_test_with_cfg(%arg0: i1) -> (tensor<10x20xf32>) {
  %cst = arith.constant dense<1.000000e+00> : tensor<10x20xf32>
  %0 = flow.dispatch.region -> (tensor<10x20xf32>) {
    %cst_0 = arith.constant dense<1.000000e+00> : tensor<10x20xf32>
    cf.cond_br %arg0, ^bb1, ^bb2
  ^bb1: // pred: ^bb0
    %2 = tensor.empty() : tensor<10x20xf32>
    flow.return %2 : tensor<10x20xf32>
  ^bb2: // pred: ^bb0
    flow.return %cst_0 : tensor<10x20xf32>
  }
  return %0 : tensor<10x20xf32>
}
// CHECK-LABEL: func @simple_test_with_cfg
// CHECK-SAME: %[[ARG0:.+]]: i1
// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups(%[[ARG0]])
// CHECK-NEXT: %[[ARG1:.+]]: i1, %[[ARG2:.+]]: !flow.dispatch.tensor
// CHECK: %[[CST:.+]] = arith.constant
// CHECK: ^[[BB1:.+]]:
// CHECK: %[[EMPTY:.+]] = tensor.empty()
// CHECK: flow.dispatch.tensor.store %[[EMPTY]], %[[ARG2]]
// CHECK: flow.return
// CHECK: ^[[BB2:.+]]:
// CHECK: flow.dispatch.tensor.store %[[CST]], %[[ARG2]]
// CHECK: flow.return
// CHECK: return %[[RESULT]]
Loading

0 comments on commit eeb6e80

Please sign in to comment.