Skip to content

Commit

Permalink
Improving VM conversion performance. (iree-org#18957)
Browse files Browse the repository at this point in the history
The major change here is using a precomputed import table in VM
conversion patterns. This removes the symbol lookup that was happening
on each call. In models with 100k calls to imports this speeds things up
a lot.

Also squashed a few more perf issues involving symbol lookups while
profiling and made some passes that could nest on function-like ops do
so.

These changes drop VM translation of the 405b model from 3.5mins to
~1.5min. Disabling verification (`-verify-each=0` to iree-opt or
`-verify=false` to iree-compile) takes it to 1min.

Remaining work is mostly around parallelizing some passes that are not
trivially parallelizable (FoldGlobals, DropUnusedCalls, etc) and
parallelizing some analysis (Explorer global init, call graph walking)
that tends to get real expensive when there are 250k calls and 500k ops.
Any place that does a symbol use walk is going to suffer. Many of these
fixes are in our code but there's several upstream components that fall
over with this amount of IR (CallGraph, DataFlowSolver, the verifier,
etc).
  • Loading branch information
benvanik authored Oct 30, 2024
1 parent a744285 commit 2ec9017
Show file tree
Hide file tree
Showing 23 changed files with 261 additions and 185 deletions.
6 changes: 4 additions & 2 deletions compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,12 @@ static void addCleanupPatterns(OpPassManager &passManager) {

// Simplify util.global accesses; this can help with data flow tracking as
// redundant store-loads are removed.
.addPass(IREE::Util::createSimplifyGlobalAccessesPass);
.addPass(IREE::Util::createSimplifyGlobalAccessesPass)

// Aggressive cleanup.
.addPass(IREE::Util::createApplyPatternsPass);

// Cleanup and canonicalization of util.global (and other util ops).
passManager.addPass(IREE::Util::createApplyPatternsPass());
passManager.addPass(IREE::Util::createFoldGlobalsPass());
passManager.addPass(IREE::Util::createFuseGlobalsPass());

Expand Down
18 changes: 10 additions & 8 deletions compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,18 +149,20 @@ using FunctionLikeNest =
//===----------------------------------------------------------------------===//

static void addCleanupPatterns(OpPassManager &passManager) {
// Standard MLIR cleanup.
passManager.addPass(mlir::createCSEPass());
passManager.addPass(mlir::createCanonicalizerPass());
passManager.addPass(mlir::createCSEPass());

// Simplify util.global accesses; this can help with data flow tracking as
// redundant store-loads are removed.
FunctionLikeNest(passManager)
.addPass(IREE::Util::createSimplifyGlobalAccessesPass);
// Standard MLIR cleanup.
.addPass(mlir::createCanonicalizerPass)
.addPass(mlir::createCSEPass)

// Simplify util.global accesses; this can help with data flow tracking as
// redundant store-loads are removed.
.addPass(IREE::Util::createSimplifyGlobalAccessesPass)

// Aggressive cleanup.
.addPass(IREE::Util::createApplyPatternsPass);

// Cleanup and canonicalization of util.global (and other util ops).
passManager.addPass(IREE::Util::createApplyPatternsPass());
passManager.addPass(IREE::Util::createFoldGlobalsPass());
passManager.addPass(IREE::Util::createFuseGlobalsPass());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@ static void addCleanupPatterns(OpPassManager &passManager) {

// Simplify util.global accesses; this can help with data flow tracking as
// redundant store-loads are removed.
.addPass(IREE::Util::createSimplifyGlobalAccessesPass);
.addPass(IREE::Util::createSimplifyGlobalAccessesPass)

// Aggressive cleanup.
.addPass(IREE::Util::createApplyPatternsPass);

// Cleanup and canonicalization of util.global (and other util ops).
passManager.addPass(IREE::Util::createApplyPatternsPass());
passManager.addPass(IREE::Util::createFoldGlobalsPass());
passManager.addPass(IREE::Util::createFuseGlobalsPass());

Expand Down
47 changes: 30 additions & 17 deletions compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,20 +126,33 @@ void Explorer::initializeGlobalInfos() {
// TODO(benvanik): filter the use list by traversal actions; where this runs
// today we don't yet have the actions specified so we can't.

// Initialize the full list of globals.
for (auto globalOp :
symbolTableOp->getRegion(0).getOps<IREE::Util::GlobalOpInterface>()) {
auto globalInfo = std::make_unique<GlobalInfo>();
globalInfo->op = globalOp;
globalInfosByName[globalOp.getGlobalName().getValue()] = globalInfo.get();
globalInfos[globalOp] = std::move(globalInfo);
}

// Walk the module and gather uses.
//
// TODO: find a way to do this more efficiently when the module is large.
// We could parallelize on top-level functions and then merge at the end.
auto allUses = symbolTable.getSymbolUses(&symbolTableOp->getRegion(0));
if (!allUses.has_value())
return;
for (auto use : allUses.value()) {
auto *symbolOp =
symbolTable.lookupNearestSymbolFrom(use.getUser(), use.getSymbolRef());
if (!isa_and_nonnull<IREE::Util::GlobalOpInterface>(symbolOp))
continue;
auto &globalInfo = globalInfos[symbolOp];
globalInfo.op = cast<IREE::Util::GlobalOpInterface>(symbolOp);
if (isa<IREE::Util::GlobalAddressOpInterface>(use.getUser())) {
globalInfo.isIndirect = true;
} else {
globalInfo.uses.push_back(use.getUser());
if (allUses.has_value()) {
for (auto use : allUses.value()) {
auto globalInfoIt = globalInfosByName.find(
use.getSymbolRef().getLeafReference().getValue());
if (globalInfoIt == globalInfosByName.end()) {
continue; // not a global
}
auto *globalInfo = globalInfoIt->second;
if (isa<IREE::Util::GlobalAddressOpInterface>(use.getUser())) {
globalInfo->isIndirect = true;
} else {
globalInfo->uses.push_back(use.getUser());
}
}
}
}
Expand Down Expand Up @@ -175,7 +188,7 @@ Explorer::getGlobalInfo(IREE::Util::GlobalOpInterface globalOp) {
auto it = globalInfos.find(globalOp);
if (it == globalInfos.end())
return nullptr;
return &it->second;
return it->second.get();
}

const Explorer::GlobalInfo *Explorer::queryGlobalInfoFrom(StringRef globalName,
Expand All @@ -189,12 +202,12 @@ const Explorer::GlobalInfo *Explorer::queryGlobalInfoFrom(StringRef globalName,
auto it = globalInfos.find(op);
if (it == globalInfos.end())
return nullptr;
return &it->second;
return it->second.get();
}

void Explorer::forEachGlobal(std::function<void(const GlobalInfo *)> fn) {
for (auto it : globalInfos) {
fn(&it.second);
for (auto &it : globalInfos) {
fn(it.second.get());
}
}

Expand Down
3 changes: 2 additions & 1 deletion compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,8 @@ class Explorer {
DenseMap<StringRef, TraversalAction> dialectActions;
DenseMap<OperationName, TraversalAction> opActions;

DenseMap<Operation *, GlobalInfo> globalInfos;
DenseMap<Operation *, std::unique_ptr<GlobalInfo>> globalInfos;
DenseMap<StringRef, GlobalInfo *> globalInfosByName;
ModuleAnalysisManager analysisManager;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ void GlobalTable::rebuild() {
globalOrder.push_back(globalName);
}

// TODO: parallelize this by gathering on multiple threads per callable and
// then merging at the end.
for (auto callableOp : moduleOp.getOps<CallableOpInterface>()) {
if (auto uses = SymbolTable::getSymbolUses(callableOp)) {
for (auto use : *uses) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ struct FoldBlockArgumentsPattern
if (!op.getCallableRegion())
return failure();
auto &region = *op.getCallableRegion();
if (region.empty())
if (region.empty() || region.hasOneBlock())
return failure();

// Analyze all branches in the op to compute the information we'll need to
Expand Down Expand Up @@ -501,7 +501,6 @@ void populateCommonPatterns(MLIRContext *context, RewritePatternSet &patterns) {
context->getOrLoadDialect<IREE::Util::UtilDialect>()
->getCanonicalizationPatterns(patterns);

// TODO(benvanik): same as branch folding but for calls.
patterns.insert<FoldBlockArgumentsPattern, ElideBranchOperandsPattern>(
context);

Expand Down
50 changes: 50 additions & 0 deletions compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,56 @@

namespace mlir::iree_compiler {

LogicalResult ImportTable::build(Operation *rootOp,
const TypeConverter &typeConverter) {
for (auto funcOp : rootOp->getRegion(0).getOps<FunctionOpInterface>()) {
if (!funcOp.isExternal()) {
continue; // only external functions are imports
}

ImportTable::Import import;
import.name = funcOp.getNameAttr();
import.fallback = funcOp->getAttrOfType<SymbolRefAttr>("vm.fallback");

// Try to use an assigned signature or fall back to converting the input.
if (auto importOp = dyn_cast<IREE::VM::ImportOp>(funcOp.getOperation())) {
// Import ops have their signature used directly.
import.signature = importOp.getFunctionType();
} else if (auto signatureAttr =
funcOp->getAttrOfType<TypeAttr>("vm.signature")) {
// Directly use the specified signature.
import.signature =
dyn_cast_if_present<FunctionType>(signatureAttr.getValue());
}
if (!import.signature) {
// Convert the signature using the type converter.
SmallVector<Type> argumentTypes;
if (failed(typeConverter.convertTypes(funcOp.getArgumentTypes(),
argumentTypes))) {
return funcOp.emitError() << "unable to convert import argument types";
}
SmallVector<Type> resultTypes;
if (failed(typeConverter.convertTypes(funcOp.getResultTypes(),
resultTypes))) {
return funcOp.emitError() << "unable to convert import result types";
}
import.signature =
FunctionType::get(rootOp->getContext(), argumentTypes, resultTypes);
}

symbols[import.name.getValue()] = std::move(import);
}

return success();
}

std::optional<ImportTable::Import> ImportTable::find(StringRef symbolName) {
auto it = symbols.find(symbolName);
if (it == symbols.end())
return std::nullopt;
return it->second;
}

// TODO(benvanik): replace with iree/compiler/Utils/ModuleUtils.h.
// There may be some special insertion order arrangement required based on the
// nested vm.module here.
Expand Down
27 changes: 27 additions & 0 deletions compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,33 @@ namespace mlir::iree_compiler {
// segment_sizes array.
constexpr int kFixedSingleValue = -1;

// A table of import information.
class ImportTable {
public:
// Information about an import function.
struct Import {
// Used to ensure the StringRef in the map stays live.
StringAttr name;
// Function signature derived from the type or overridden by `vm.signature`.
FunctionType signature;
// Optional fallback function that should be used when the import is
// unavailable at runtime taken from `vm.fallback`.
SymbolRefAttr fallback;
};

// Builds a table of all import functions nested within the given |rootOp|.
// Clones any information such that the original ops can be mutated/erased.
// Must only be called once the type converter has been fully populated.
LogicalResult build(Operation *rootOp, const TypeConverter &typeConverter);

// Finds an import with the given name if there exists one.
std::optional<Import> find(StringRef symbolName);

private:
// Map of symbol names within the root op to import symbol info.
DenseMap<StringRef, Import> symbols;
};

// Appends a set of vm.import ops from a module to a target VM module.
// Imports will only be added if they are not already present in the target
// module.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ static void copyImportAttrs(func::FuncOp srcOp, IREE::VM::ImportOp dstOp) {
constexpr const char *kRetainedAttributes[] = {
"nosideeffects",
"vm.fallback",
"vm.signature",
};
auto retainedAttributes = ArrayRef<const char *>(
kRetainedAttributes,
Expand Down Expand Up @@ -241,7 +240,11 @@ struct ExternalFuncOpConversion : public OpConversionPattern<func::FuncOp> {
};

struct CallOpConversion : public OpConversionPattern<func::CallOp> {
using OpConversionPattern::OpConversionPattern;
ImportTable &importTable;
CallOpConversion(const TypeConverter &typeConverter, MLIRContext *context,
ImportTable &importTable, PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
importTable(importTable) {}
LogicalResult
matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Expand All @@ -262,7 +265,8 @@ struct CallOpConversion : public OpConversionPattern<func::CallOp> {
// conversion if imports have fallbacks that are themselves imports.
auto callResults = convertCallOp(
callOp->getParentOfType<IREE::VM::ModuleOp>(), callOp.getLoc(),
callOp.getCallee(), adaptor.getOperands(), resultTypes, rewriter);
callOp.getCallee(), adaptor.getOperands(), resultTypes, importTable,
rewriter);
if (failed(callResults)) {
return rewriter.notifyMatchFailure(
callOp, "unable to convert call (results mismatch)");
Expand All @@ -277,36 +281,12 @@ struct CallOpConversion : public OpConversionPattern<func::CallOp> {
FailureOr<SmallVector<Value>>
convertCallOp(Operation *rootOp, Location loc, StringRef calleeName,
ValueRange operands, TypeRange resultTypes,
ImportTable &importTable,
ConversionPatternRewriter &rewriter) const {
// (Slow) lookup of the target function, which may be an import that we need
// to perform type conversion for.
auto calleeOp = SymbolTable::lookupSymbolIn(rootOp, calleeName);
if (auto funcOp = dyn_cast_or_null<FunctionOpInterface>(calleeOp)) {
if (funcOp.isExternal()) {
// Import that may require conversion.
// This case handles when funcs are declared after the call.
FunctionType convertedSignature;
if (auto signatureAttr =
funcOp->getAttrOfType<TypeAttr>("vm.signature")) {
if (auto importSignature =
llvm::dyn_cast<FunctionType>(signatureAttr.getValue())) {
convertedSignature = importSignature;
}
}
if (!convertedSignature) {
convertedSignature =
rewriter.getFunctionType(TypeRange(operands), resultTypes);
}
return convertImportCallOp(rootOp, loc, calleeName, operands,
resultTypes, convertedSignature, funcOp,
rewriter);
}
} else if (auto importOp = dyn_cast_or_null<IREE::VM::ImportOp>(calleeOp)) {
// Calling an import.
// This case handles when funcs are declared before the call and have
// already been converted.
return convertImportCallOp(rootOp, loc, calleeName, operands, resultTypes,
importOp.getFunctionType(), importOp,
// Lookup the target and detect if it is an import.
auto import = importTable.find(calleeName);
if (import.has_value()) {
return convertImportCallOp(rootOp, loc, *import, operands, resultTypes,
rewriter);
}

Expand All @@ -319,19 +299,19 @@ struct CallOpConversion : public OpConversionPattern<func::CallOp> {
// Converts a call to an import that may be optional.
// Returns the new converted call results.
FailureOr<SmallVector<Value>>
convertImportCallOp(Operation *rootOp, Location loc, StringRef calleeName,
ValueRange operands, TypeRange resultTypes,
FunctionType importSignature, Operation *calleeOp,
convertImportCallOp(Operation *rootOp, Location loc,
ImportTable::Import &import, ValueRange operands,
TypeRange resultTypes,
ConversionPatternRewriter &rewriter) const {
auto fallbackAttr = calleeOp->getAttrOfType<SymbolRefAttr>("vm.fallback");
return fallbackAttr
? convertOptionalImportCallOp(
rootOp, loc, calleeName, operands, resultTypes,
importSignature,
fallbackAttr.getLeafReference().getValue(), rewriter)
: convertMandatoryImportCallOp(rootOp, loc, calleeName, operands,
resultTypes, importSignature,
rewriter);
if (import.fallback) {
return convertOptionalImportCallOp(
rootOp, loc, import.name, operands, resultTypes, import.signature,
import.fallback.getLeafReference().getValue(), rewriter);
} else {
return convertMandatoryImportCallOp(rootOp, loc, import.name, operands,
resultTypes, import.signature,
rewriter);
}
}

// Converts a call to an optional import by adding logic to check whether it
Expand Down Expand Up @@ -374,7 +354,7 @@ struct CallOpConversion : public OpConversionPattern<func::CallOp> {
// Not resolved: call fallback as a normal function.
rewriter.setInsertionPointToStart(fallbackBlock);
auto fallbackResults = convertCallOp(rootOp, loc, fallbackName, operands,
resultTypes, rewriter);
resultTypes, importTable, rewriter);
if (failed(fallbackResults))
return failure();
rewriter.create<IREE::VM::BranchOp>(loc, exitBlock, *fallbackResults);
Expand Down Expand Up @@ -557,12 +537,14 @@ struct SwitchOpConversion : public OpConversionPattern<cf::SwitchOp> {

void populateStandardToVMPatterns(MLIRContext *context,
TypeConverter &typeConverter,
ImportTable &importTable,
RewritePatternSet &patterns) {
patterns
.insert<AssertOpConversion, BranchOpConversion, CallOpConversion,
CondBranchOpConversion, SwitchOpConversion, ModuleOpConversion,
FuncOpConversion, ExternalFuncOpConversion, ReturnOpConversion>(
typeConverter, context);
.insert<AssertOpConversion, BranchOpConversion, CondBranchOpConversion,
SwitchOpConversion, ModuleOpConversion, FuncOpConversion,
ExternalFuncOpConversion, ReturnOpConversion>(typeConverter,
context);
patterns.insert<CallOpConversion>(typeConverter, context, importTable);
patterns.insert<CastingOpConversion<mlir::UnrealizedConversionCastOp>>(
typeConverter, context);
}
Expand Down
Loading

0 comments on commit 2ec9017

Please sign in to comment.