Skip to content

Commit

Permalink
Moving OutlineConstantsPass to flow and adding parameter support. (ir…
Browse files Browse the repository at this point in the history
…ee-org#17303)

This allows us to hide the stream dialect attributes from frontends and
use inline flow.tensor.constant ops with parameter attrs. Outlining now
also properly preserves hoistable attrs such as stream affinity. By
running IPO at the head of the flow pipeline we gain fusion
opportunities for hoisted (by user or by global opt) constants and then
we clean up the inlined constants at the end of flow so that the stream
dialect can handle all values consistently.
  • Loading branch information
benvanik authored May 7, 2024
1 parent d8f49dc commit 3ca0a49
Show file tree
Hide file tree
Showing 52 changed files with 636 additions and 316 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def do_strip_data(args) -> int:
):
return 1
if not inv.execute_text_pass_pipeline(
"iree-util-outline-constants, iree-util-strip-and-splat-constants"
"iree-flow-outline-constants, iree-util-strip-and-splat-constants"
):
return 2
write_output(inv, output, args)
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ def FLOW_CollectiveReductionOpAttr :
def FLOW_NamedParameterAttr :
AttrDef<Flow_Dialect, "NamedParameter", [
TypedAttrInterface,
DeclareAttrInterfaceMethods<Util_SizedStorageAttr, [
"getStorageSize",
]>,
]> {
let mnemonic = "parameter.named";
let summary = [{named parameter referenced an optional scope and key}];
Expand Down
17 changes: 17 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,4 +310,21 @@ void printParameterReference(AsmPrinter &p, StringAttr scopeAttr,
p << "\"" << keyAttr.getValue() << "\"";
}

//===----------------------------------------------------------------------===//
// #flow.parameter.named<...>
//===----------------------------------------------------------------------===//

int64_t NamedParameterAttr::getStorageSize() const {
if (auto configAttr = getConfig()) {
if (auto lengthAttr = configAttr.getAs<IntegerAttr>("length")) {
return lengthAttr.getInt();
}
}
if (auto shapedType = llvm::dyn_cast<ShapedType>(getType())) {
return IREE::Util::getRoundedPhysicalStorageSize(shapedType);
} else {
return IREE::Util::getTypePhysicalStorageBitWidth(getType());
}
}

} // namespace mlir::iree_compiler::IREE::Flow
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ iree_compiler_cc_library(
"InjectTensorTracing.cpp",
"InsertDispatchDebugTargets.cpp",
"InterchangeTransposeGenericOps.cpp",
"OutlineConstants.cpp",
"OutlineDispatchExterns.cpp",
"OutlineDispatchRegions.cpp",
"Passes.cpp",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ iree_cc_library(
"InjectTensorTracing.cpp"
"InsertDispatchDebugTargets.cpp"
"InterchangeTransposeGenericOps.cpp"
"OutlineConstants.cpp"
"OutlineDispatchExterns.cpp"
"OutlineDispatchRegions.cpp"
"Passes.cpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,19 +178,25 @@ static IREE::Util::GlobalOp createDummyInput(const std::string &namePrefix,
OpBuilder &moduleBuilder,
Explorer &explorer) {
std::string name = namePrefix + "_arg" + std::to_string(arg.getArgNumber());
return TypeSwitch<Type, IREE::Util::GlobalOp>(arg.getType())
.Case([&](IREE::HAL::BufferViewType type) {
return createImportBufferViewGlobalOp(name, arg, symbolTable,
auto globalOp =
TypeSwitch<Type, IREE::Util::GlobalOp>(arg.getType())
.Case([&](IREE::HAL::BufferViewType type) {
return createImportBufferViewGlobalOp(name, arg, symbolTable,
moduleBuilder, explorer);
})
.Case([&](IREE::HAL::BufferType type) {
return createExportBufferGlobalOp(name, arg, symbolTable,
moduleBuilder, explorer);
})
.Case([&](IREE::HAL::BufferType type) {
return createExportBufferGlobalOp(name, arg, symbolTable, moduleBuilder,
explorer);
})
.Default([&](Type type) {
return createPrimitiveDefaultGlobalOp(name, arg.getLoc(), type,
symbolTable, moduleBuilder);
});
})
.Default([&](Type type) {
return createPrimitiveDefaultGlobalOp(name, arg.getLoc(), type,
symbolTable, moduleBuilder);
});
if (globalOp) {
// Prevent globals from folding so that we have unique buffers for each arg.
globalOp->setAttr("flow.unique_id", moduleBuilder.getStringAttr(name));
}
return globalOp;
}

static LogicalResult
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <utility>

#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Utils/StringUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"

namespace mlir::iree_compiler::IREE::Flow {

#define GEN_PASS_DEF_OUTLINECONSTANTSPASS
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"

namespace {

// Returns true if |value| is worth outlining (large, etc).
static bool isOutlinableValue(Attribute value) {
if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(value)) {
// Don't outline splats - we want those fused.
return !elementsAttr.isSplat();
} else if (isa<IREE::Flow::NamedParameterAttr>(value)) {
// Always outline parameter constants.
return true;
}
return false;
}

struct ConstantDef {
Operation *op;
Type type;
TypedAttr value;
};

// Returns a list of all constant-like shaped data ops in the module.
static SmallVector<ConstantDef> findConstantsInModule(mlir::ModuleOp moduleOp) {
SmallVector<ConstantDef> results;
for (auto callableOp : moduleOp.getOps<CallableOpInterface>()) {
auto *region = callableOp.getCallableRegion();
if (!region)
continue;
region->walk([&](Operation *op) {
if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
if (isOutlinableValue(constantOp.getValue())) {
results.push_back(ConstantDef{
constantOp,
constantOp.getType(),
constantOp.getValue(),
});
}
} else if (auto constantOp = dyn_cast<IREE::Flow::TensorConstantOp>(op)) {
if (isOutlinableValue(constantOp.getValue())) {
results.push_back(ConstantDef{
constantOp,
constantOp.getType(),
constantOp.getValue(),
});
}
}
});
}
return results;
}

// Returns the operation containing |childOp| that is a direct child of
// |ancestorOp|. May return |childOp|.
static Operation *getParentInOp(Operation *childOp, Operation *ancestorOp) {
assert(childOp != ancestorOp && "child can't be its own ancestor");
do {
auto *parentOp = childOp->getParentOp();
if (parentOp == ancestorOp)
return childOp;
childOp = parentOp;
} while (childOp);
assert(false && "child must be nested under ancestor");
return nullptr;
}

static std::string getConstantName(ConstantDef &def) {
std::string str;
llvm::raw_string_ostream os(str);
if (auto parameterAttr =
dyn_cast<IREE::Flow::NamedParameterAttr>(def.value)) {
os << "__parameter_";
if (parameterAttr.getScope() && !parameterAttr.getScope().empty())
os << parameterAttr.getScope().getValue() << "_";
os << parameterAttr.getKey().getValue() << "_";
} else {
os << "__constant_";
}
def.type.print(os);
str = sanitizeSymbolName(str);
if (str.substr(str.size() - 1) == "_")
str = str.substr(0, str.size() - 1); // strip trailing _
return str;
}

//===----------------------------------------------------------------------===//
// --iree-flow-outline-constants
//===----------------------------------------------------------------------===//

struct OutlineConstantsPass
: public IREE::Flow::impl::OutlineConstantsPassBase<OutlineConstantsPass> {
void runOnOperation() override {
auto moduleOp = getOperation();
if (moduleOp.getBody()->empty())
return;

SymbolTable moduleSymbols(moduleOp);

// Create all top-level util.globals from constants in the module.
std::vector<std::pair<Operation *, IREE::Util::GlobalOp>> replacements;
for (auto &def : findConstantsInModule(moduleOp)) {
// Position the global immediately preceding the top-level op that
// contains the constant.
OpBuilder moduleBuilder(&moduleOp.getBody()->front());
auto parentFuncOp = getParentInOp(def.op, moduleOp);
if (parentFuncOp)
moduleBuilder.setInsertionPoint(parentFuncOp);

// New immutable global takes the constant attribute in its specified
// encoding.
auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
def.op->getLoc(), getConstantName(def), /*isMutable=*/false, def.type,
def.value);
globalOp.setPrivate();
IREE::Util::HoistableAttrInterface::gatherHoistableAttrs(def.op,
globalOp);
moduleSymbols.insert(globalOp); // uniques name
replacements.emplace_back(def.op, globalOp);

// Prevent the variable from being re-inlined if the canonicalizer runs.
// By the time we've outlined things here we are sure we want them
// outlined even if the user runs an arbitrary number of passes between
// now and when we may use that information (HAL constant pooling, etc).
globalOp.setInliningPolicyAttr(
moduleBuilder.getAttr<IREE::Util::InlineNeverAttr>());
}

// Replace all of the constants with lookups for the new variables.
for (auto pair : replacements) {
auto *originalOp = pair.first;
auto globalOp = pair.second;
OpBuilder builder(moduleOp.getContext());
builder.setInsertionPoint(originalOp);
auto loadOp = globalOp.createLoadOp(originalOp->getLoc(), builder);
loadOp.setGlobalImmutable(true);
originalOp->getResult(0).replaceAllUsesWith(
loadOp.getLoadedGlobalValue());
originalOp->erase();
}
}
};

} // namespace

} // namespace mlir::iree_compiler::IREE::Flow
71 changes: 68 additions & 3 deletions compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,36 @@ namespace mlir::iree_compiler::IREE::Flow {
using FunctionLikeNest =
MultiOpNest<func::FuncOp, IREE::Util::InitializerOp, IREE::Util::FuncOp>;

//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//

static void addCleanupPatterns(OpPassManager &passManager) {
FunctionLikeNest(passManager)
// Standard MLIR cleanup.
.addPass(mlir::createCanonicalizerPass)
.addPass(mlir::createCSEPass)

// Simplify util.global accesses; this can help with data flow tracking as
// redundant store-loads are removed.
.addPass(IREE::Util::createSimplifyGlobalAccessesPass);

// Cleanup and canonicalization of util.global (and other util ops).
passManager.addPass(IREE::Util::createApplyPatternsPass());
passManager.addPass(IREE::Util::createFoldGlobalsPass());
passManager.addPass(IREE::Util::createFuseGlobalsPass());

// Large IPO pass. Note that this can introduce a significant amount of
// duplication/inlined constants and we'll want to ensure we're running
// cleanup again after (this entire set of patterns is run in a fixed-point
// iteration to do that).
passManager.addPass(IREE::Util::createIPOPass());
}

//===----------------------------------------------------------------------===//
// Pipelines
//===----------------------------------------------------------------------===//

void addDispatchRegionCreationPreprocessingPasses(OpPassManager &passManager) {
// 1. Do some simple elementwise op fusion. This could be skipped,
// but could reduce the surface area of ops to handle later.
Expand Down Expand Up @@ -240,6 +270,22 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager,
clEnableFusePaddingIntoLinalgConsumerOps}));
}

{
// We run these under a fixed-point iteration such that we can perform
// inter-procedural, intra-procedural, and canonicalization as separably
// verifiable/reusable passes. IPO will fold duplicate arguments/results
// and inline constants to allow the local optimizations to work more
// effectively.
OpPassManager ipoPipeline(mlir::ModuleOp::getOperationName());

// IPO and other cleanups.
addCleanupPatterns(ipoPipeline);

// Run fixed-point iteration on the IPO pipeline.
passManager.addPass(
IREE::Util::createFixedPointIteratorPass(std::move(ipoPipeline)));
}

addDispatchRegionCreationPasses(passManager, transformOptions);

FunctionLikeNest(passManager)
Expand Down Expand Up @@ -325,9 +371,28 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager,
// passes above after we've formed dispatch regions.
.addPass(IREE::Flow::createInjectTensorTracingPass)
// Cleanup the IR after we are done.
.addPass(IREE::Flow::createCleanupTensorShapesPass)
.addPass(mlir::createCanonicalizerPass)
.addPass(mlir::createCSEPass);
.addPass(IREE::Flow::createCleanupTensorShapesPass);

{
// We run these under a fixed-point iteration such that we can perform
// inter-procedural, intra-procedural, and canonicalization as separably
// verifiable/reusable passes. IPO will fold duplicate arguments/results
// and inline constants to allow the local optimizations to work more
// effectively.
OpPassManager ipoPipeline(mlir::ModuleOp::getOperationName());

// Turn all constant ops into global variables and fix up the IR.
// As many locations change and constants are deduplicated we'll end up with
// a lot of extraneous IR (mostly global loads) and clean those up here.
ipoPipeline.addPass(IREE::Flow::createOutlineConstantsPass());

// IPO and other cleanups.
addCleanupPatterns(ipoPipeline);

// Run fixed-point iteration on the IPO pipeline.
passManager.addPass(
IREE::Util::createFixedPointIteratorPass(std::move(ipoPipeline)));
}

// Cleanup executable contents.
{
Expand Down
14 changes: 14 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,20 @@ def InterchangeTransposeGenericOpsPass :
];
}

def OutlineConstantsPass :
Pass<"iree-flow-outline-constants", "mlir::ModuleOp"> {
let summary = "Outlines tensor constants into util.globals at the module level.";
let description = [{
Outlines tensor constants throughout the program into globals initialized
with stream ops.
}];
let dependentDialects = [
"mlir::arith::ArithDialect",
"IREE::Flow::FlowDialect",
"IREE::Util::UtilDialect",
];
}

def OutlineDispatchExternsPass :
Pass<"iree-flow-outline-dispatch-externs", "mlir::ModuleOp"> {
let summary = "Outlines external dispatches into executables.";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ iree_lit_test_suite(
"inject_tensor_tracing.mlir",
"insert_dispatch_debug_targets.mlir",
"interchange_transpose_generic_ops.mlir",
"outline_constants.mlir",
"outline_dispatch_externs.mlir",
"outline_dispatch_regions.mlir",
"pad_fusion_with_consumer.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ iree_lit_test_suite(
"inject_tensor_tracing.mlir"
"insert_dispatch_debug_targets.mlir"
"interchange_transpose_generic_ops.mlir"
"outline_constants.mlir"
"outline_dispatch_externs.mlir"
"outline_dispatch_regions.mlir"
"pad_fusion_with_consumer.mlir"
Expand Down
Loading

0 comments on commit 3ca0a49

Please sign in to comment.