From 2ec9017bb1de7d5ba552c1f016326088864b9cf9 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Wed, 30 Oct 2024 16:18:44 -0700 Subject: [PATCH] Improving VM conversion performance. (#18957) The major change here is using a precomputed import table in VM conversion patterns. This removes the symbol lookup that was happening on each call. In models with 100k calls to imports, this speeds things up a lot. Also squashed a few more perf issues involving symbol lookups while profiling and made some passes that could nest on function-like ops do so. These changes drop VM translation of the 405b model from 3.5min to ~1.5min. Disabling verification (`-verify-each=0` to iree-opt or `-verify=false` to iree-compile) takes it to 1min. Remaining work is mostly around parallelizing some passes that are not trivially parallelizable (FoldGlobals, DropUnusedCalls, etc.) and parallelizing some analysis (Explorer global init, call graph walking) that tends to get really expensive when there are 250k calls and 500k ops. Any place that does a symbol use walk is going to suffer. Many of these fixes are in our code, but there are several upstream components that fall over with this amount of IR (CallGraph, DataFlowSolver, the verifier, etc.). 
--- .../Dialect/Flow/Transforms/Passes.cpp | 6 +- .../Dialect/HAL/Transforms/Passes.cpp | 18 +++-- .../Dialect/Stream/Transforms/Passes.cpp | 6 +- .../Dialect/Util/Analysis/Explorer.cpp | 47 +++++++---- .../compiler/Dialect/Util/Analysis/Explorer.h | 3 +- .../Dialect/Util/Analysis/GlobalTable.cpp | 2 + .../Dialect/Util/Transforms/Patterns.cpp | 3 +- .../Dialect/VM/Conversion/ImportUtils.cpp | 50 ++++++++++++ .../Dialect/VM/Conversion/ImportUtils.h | 27 +++++++ .../VM/Conversion/StandardToVM/Patterns.cpp | 80 +++++++------------ .../VM/Conversion/StandardToVM/Patterns.h | 2 + .../UtilToVM/ConvertStructuralOps.cpp | 75 +++++++---------- .../VM/Conversion/UtilToVM/Patterns.cpp | 4 +- .../Dialect/VM/Conversion/UtilToVM/Patterns.h | 2 + .../src/iree/compiler/Dialect/VM/IR/VMOps.cpp | 10 +-- .../src/iree/compiler/Dialect/VM/IR/VMOps.td | 3 +- .../Dialect/VM/Transforms/Conversion.cpp | 13 ++- .../VM/Transforms/GlobalInitialization.cpp | 33 +++----- .../compiler/Dialect/VM/Transforms/Passes.cpp | 15 ++-- .../iree/compiler/DispatchCreation/Passes.cpp | 6 +- .../compiler/GlobalOptimization/Passes.cpp | 13 +-- .../Modules/HAL/Inline/Transforms/Passes.cpp | 14 ++-- .../Modules/HAL/Loader/Transforms/Passes.cpp | 14 ++-- 23 files changed, 261 insertions(+), 185 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp index 1357e077d86c..a84205a58dd7 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp @@ -88,10 +88,12 @@ static void addCleanupPatterns(OpPassManager &passManager) { // Simplify util.global accesses; this can help with data flow tracking as // redundant store-loads are removed. - .addPass(IREE::Util::createSimplifyGlobalAccessesPass); + .addPass(IREE::Util::createSimplifyGlobalAccessesPass) + + // Aggressive cleanup. 
+ .addPass(IREE::Util::createApplyPatternsPass); // Cleanup and canonicalization of util.global (and other util ops). - passManager.addPass(IREE::Util::createApplyPatternsPass()); passManager.addPass(IREE::Util::createFoldGlobalsPass()); passManager.addPass(IREE::Util::createFuseGlobalsPass()); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp index 07df5d87665d..1de6f1cbeb84 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp @@ -149,18 +149,20 @@ using FunctionLikeNest = //===----------------------------------------------------------------------===// static void addCleanupPatterns(OpPassManager &passManager) { - // Standard MLIR cleanup. - passManager.addPass(mlir::createCSEPass()); - passManager.addPass(mlir::createCanonicalizerPass()); - passManager.addPass(mlir::createCSEPass()); - // Simplify util.global accesses; this can help with data flow tracking as - // redundant store-loads are removed. FunctionLikeNest(passManager) - .addPass(IREE::Util::createSimplifyGlobalAccessesPass); + // Standard MLIR cleanup. + .addPass(mlir::createCanonicalizerPass) + .addPass(mlir::createCSEPass) + + // Simplify util.global accesses; this can help with data flow tracking as + // redundant store-loads are removed. + .addPass(IREE::Util::createSimplifyGlobalAccessesPass) + + // Aggressive cleanup. + .addPass(IREE::Util::createApplyPatternsPass); // Cleanup and canonicalization of util.global (and other util ops). 
- passManager.addPass(IREE::Util::createApplyPatternsPass()); passManager.addPass(IREE::Util::createFoldGlobalsPass()); passManager.addPass(IREE::Util::createFuseGlobalsPass()); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp index 408bb024f2f2..2234c62daa58 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp @@ -44,10 +44,12 @@ static void addCleanupPatterns(OpPassManager &passManager) { // Simplify util.global accesses; this can help with data flow tracking as // redundant store-loads are removed. - .addPass(IREE::Util::createSimplifyGlobalAccessesPass); + .addPass(IREE::Util::createSimplifyGlobalAccessesPass) + + // Aggressive cleanup. + .addPass(IREE::Util::createApplyPatternsPass); // Cleanup and canonicalization of util.global (and other util ops). - passManager.addPass(IREE::Util::createApplyPatternsPass()); passManager.addPass(IREE::Util::createFoldGlobalsPass()); passManager.addPass(IREE::Util::createFuseGlobalsPass()); diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp index 745f0291f5d6..32c7819a66b3 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp @@ -126,20 +126,33 @@ void Explorer::initializeGlobalInfos() { // TODO(benvanik): filter the use list by traversal actions; where this runs // today we don't yet have the actions specified so we can't. + // Initialize the full list of globals. + for (auto globalOp : + symbolTableOp->getRegion(0).getOps()) { + auto globalInfo = std::make_unique(); + globalInfo->op = globalOp; + globalInfosByName[globalOp.getGlobalName().getValue()] = globalInfo.get(); + globalInfos[globalOp] = std::move(globalInfo); + } + + // Walk the module and gather uses. 
+ // + // TODO: find a way to do this more efficiently when the module is large. + // We could parallelize on top-level functions and then merge at the end. auto allUses = symbolTable.getSymbolUses(&symbolTableOp->getRegion(0)); - if (!allUses.has_value()) - return; - for (auto use : allUses.value()) { - auto *symbolOp = - symbolTable.lookupNearestSymbolFrom(use.getUser(), use.getSymbolRef()); - if (!isa_and_nonnull(symbolOp)) - continue; - auto &globalInfo = globalInfos[symbolOp]; - globalInfo.op = cast(symbolOp); - if (isa(use.getUser())) { - globalInfo.isIndirect = true; - } else { - globalInfo.uses.push_back(use.getUser()); + if (allUses.has_value()) { + for (auto use : allUses.value()) { + auto globalInfoIt = globalInfosByName.find( + use.getSymbolRef().getLeafReference().getValue()); + if (globalInfoIt == globalInfosByName.end()) { + continue; // not a global + } + auto *globalInfo = globalInfoIt->second; + if (isa(use.getUser())) { + globalInfo->isIndirect = true; + } else { + globalInfo->uses.push_back(use.getUser()); + } } } } @@ -175,7 +188,7 @@ Explorer::getGlobalInfo(IREE::Util::GlobalOpInterface globalOp) { auto it = globalInfos.find(globalOp); if (it == globalInfos.end()) return nullptr; - return &it->second; + return it->second.get(); } const Explorer::GlobalInfo *Explorer::queryGlobalInfoFrom(StringRef globalName, @@ -189,12 +202,12 @@ const Explorer::GlobalInfo *Explorer::queryGlobalInfoFrom(StringRef globalName, auto it = globalInfos.find(op); if (it == globalInfos.end()) return nullptr; - return &it->second; + return it->second.get(); } void Explorer::forEachGlobal(std::function fn) { - for (auto it : globalInfos) { - fn(&it.second); + for (auto &it : globalInfos) { + fn(it.second.get()); } } diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h index 35ee12aa822b..4c9482b03ade 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h +++ 
b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h @@ -403,7 +403,8 @@ class Explorer { DenseMap dialectActions; DenseMap opActions; - DenseMap globalInfos; + DenseMap> globalInfos; + DenseMap globalInfosByName; ModuleAnalysisManager analysisManager; }; diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.cpp index 7fdc301fb365..70387f6de584 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.cpp @@ -69,6 +69,8 @@ void GlobalTable::rebuild() { globalOrder.push_back(globalName); } + // TODO: parallelize this by gathering on multiple threads per callable and + // then merging at the end. for (auto callableOp : moduleOp.getOps()) { if (auto uses = SymbolTable::getSymbolUses(callableOp)) { for (auto use : *uses) { diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp index ac527d792c0b..e7c722d46cff 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp @@ -46,7 +46,7 @@ struct FoldBlockArgumentsPattern if (!op.getCallableRegion()) return failure(); auto ®ion = *op.getCallableRegion(); - if (region.empty()) + if (region.empty() || region.hasOneBlock()) return failure(); // Analyze all branches in the op to compute the information we'll need to @@ -501,7 +501,6 @@ void populateCommonPatterns(MLIRContext *context, RewritePatternSet &patterns) { context->getOrLoadDialect() ->getCanonicalizationPatterns(patterns); - // TODO(benvanik): same as branch folding but for calls. 
patterns.insert( context); diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp index fa837d41c992..ea5e257a7728 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp @@ -18,6 +18,56 @@ namespace mlir::iree_compiler { +LogicalResult ImportTable::build(Operation *rootOp, + const TypeConverter &typeConverter) { + for (auto funcOp : rootOp->getRegion(0).getOps()) { + if (!funcOp.isExternal()) { + continue; // only external functions are imports + } + + ImportTable::Import import; + import.name = funcOp.getNameAttr(); + import.fallback = funcOp->getAttrOfType("vm.fallback"); + + // Try to use an assigned signature or fall back to converting the input. + if (auto importOp = dyn_cast(funcOp.getOperation())) { + // Import ops have their signature used directly. + import.signature = importOp.getFunctionType(); + } else if (auto signatureAttr = + funcOp->getAttrOfType("vm.signature")) { + // Directly use the specified signature. + import.signature = + dyn_cast_if_present(signatureAttr.getValue()); + } + if (!import.signature) { + // Convert the signature using the type converter. 
+ SmallVector argumentTypes; + if (failed(typeConverter.convertTypes(funcOp.getArgumentTypes(), + argumentTypes))) { + return funcOp.emitError() << "unable to convert import argument types"; + } + SmallVector resultTypes; + if (failed(typeConverter.convertTypes(funcOp.getResultTypes(), + resultTypes))) { + return funcOp.emitError() << "unable to convert import result types"; + } + import.signature = + FunctionType::get(rootOp->getContext(), argumentTypes, resultTypes); + } + + symbols[import.name.getValue()] = std::move(import); + } + + return success(); +} + +std::optional ImportTable::find(StringRef symbolName) { + auto it = symbols.find(symbolName); + if (it == symbols.end()) + return std::nullopt; + return it->second; +} + // TODO(benvanik): replace with iree/compiler/Utils/ModuleUtils.h. // There may be some special insertion order arrangement required based on the // nested vm.module here. diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h index b2f0a8f74c60..c5d557f5b42d 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h @@ -20,6 +20,33 @@ namespace mlir::iree_compiler { // segment_sizes array. constexpr int kFixedSingleValue = -1; +// A table of import information. +class ImportTable { +public: + // Information about an import function. + struct Import { + // Used to ensure the StringRef in the map stays live. + StringAttr name; + // Function signature derived from the type or overridden by `vm.signature`. + FunctionType signature; + // Optional fallback function that should be used when the import is + // unavailable at runtime taken from `vm.fallback`. + SymbolRefAttr fallback; + }; + + // Builds a table of all import functions nested within the given |rootOp|. + // Clones any information such that the original ops can be mutated/erased. 
+ // Must only be called once the type converter has been fully populated. + LogicalResult build(Operation *rootOp, const TypeConverter &typeConverter); + + // Finds an import with the given name if there exists one. + std::optional find(StringRef symbolName); + +private: + // Map of symbol names within the root op to import symbol info. + DenseMap symbols; +}; + // Appends a set of vm.import ops from a module to a target VM module. // Imports will only be added if they are not already present in the target // module. diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.cpp index f9a5c264cba8..1cb4f413883b 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.cpp @@ -170,7 +170,6 @@ static void copyImportAttrs(func::FuncOp srcOp, IREE::VM::ImportOp dstOp) { constexpr const char *kRetainedAttributes[] = { "nosideeffects", "vm.fallback", - "vm.signature", }; auto retainedAttributes = ArrayRef( kRetainedAttributes, @@ -241,7 +240,11 @@ struct ExternalFuncOpConversion : public OpConversionPattern { }; struct CallOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + ImportTable &importTable; + CallOpConversion(const TypeConverter &typeConverter, MLIRContext *context, + ImportTable &importTable, PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), + importTable(importTable) {} LogicalResult matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -262,7 +265,8 @@ struct CallOpConversion : public OpConversionPattern { // conversion if imports have fallbacks that are themselves imports. 
auto callResults = convertCallOp( callOp->getParentOfType(), callOp.getLoc(), - callOp.getCallee(), adaptor.getOperands(), resultTypes, rewriter); + callOp.getCallee(), adaptor.getOperands(), resultTypes, importTable, + rewriter); if (failed(callResults)) { return rewriter.notifyMatchFailure( callOp, "unable to convert call (results mismatch)"); @@ -277,36 +281,12 @@ struct CallOpConversion : public OpConversionPattern { FailureOr> convertCallOp(Operation *rootOp, Location loc, StringRef calleeName, ValueRange operands, TypeRange resultTypes, + ImportTable &importTable, ConversionPatternRewriter &rewriter) const { - // (Slow) lookup of the target function, which may be an import that we need - // to perform type conversion for. - auto calleeOp = SymbolTable::lookupSymbolIn(rootOp, calleeName); - if (auto funcOp = dyn_cast_or_null(calleeOp)) { - if (funcOp.isExternal()) { - // Import that may require conversion. - // This case handles when funcs are declared after the call. - FunctionType convertedSignature; - if (auto signatureAttr = - funcOp->getAttrOfType("vm.signature")) { - if (auto importSignature = - llvm::dyn_cast(signatureAttr.getValue())) { - convertedSignature = importSignature; - } - } - if (!convertedSignature) { - convertedSignature = - rewriter.getFunctionType(TypeRange(operands), resultTypes); - } - return convertImportCallOp(rootOp, loc, calleeName, operands, - resultTypes, convertedSignature, funcOp, - rewriter); - } - } else if (auto importOp = dyn_cast_or_null(calleeOp)) { - // Calling an import. - // This case handles when funcs are declared before the call and have - // already been converted. - return convertImportCallOp(rootOp, loc, calleeName, operands, resultTypes, - importOp.getFunctionType(), importOp, + // Lookup the target and detect if it is an import. 
+ auto import = importTable.find(calleeName); + if (import.has_value()) { + return convertImportCallOp(rootOp, loc, *import, operands, resultTypes, rewriter); } @@ -319,19 +299,19 @@ struct CallOpConversion : public OpConversionPattern { // Converts a call to an import that may be optional. // Returns the new converted call results. FailureOr> - convertImportCallOp(Operation *rootOp, Location loc, StringRef calleeName, - ValueRange operands, TypeRange resultTypes, - FunctionType importSignature, Operation *calleeOp, + convertImportCallOp(Operation *rootOp, Location loc, + ImportTable::Import &import, ValueRange operands, + TypeRange resultTypes, ConversionPatternRewriter &rewriter) const { - auto fallbackAttr = calleeOp->getAttrOfType("vm.fallback"); - return fallbackAttr - ? convertOptionalImportCallOp( - rootOp, loc, calleeName, operands, resultTypes, - importSignature, - fallbackAttr.getLeafReference().getValue(), rewriter) - : convertMandatoryImportCallOp(rootOp, loc, calleeName, operands, - resultTypes, importSignature, - rewriter); + if (import.fallback) { + return convertOptionalImportCallOp( + rootOp, loc, import.name, operands, resultTypes, import.signature, + import.fallback.getLeafReference().getValue(), rewriter); + } else { + return convertMandatoryImportCallOp(rootOp, loc, import.name, operands, + resultTypes, import.signature, + rewriter); + } } // Converts a call to an optional import by adding logic to check whether it @@ -374,7 +354,7 @@ struct CallOpConversion : public OpConversionPattern { // Not resolved: call fallback as a normal function. 
rewriter.setInsertionPointToStart(fallbackBlock); auto fallbackResults = convertCallOp(rootOp, loc, fallbackName, operands, - resultTypes, rewriter); + resultTypes, importTable, rewriter); if (failed(fallbackResults)) return failure(); rewriter.create(loc, exitBlock, *fallbackResults); @@ -557,12 +537,14 @@ struct SwitchOpConversion : public OpConversionPattern { void populateStandardToVMPatterns(MLIRContext *context, TypeConverter &typeConverter, + ImportTable &importTable, RewritePatternSet &patterns) { patterns - .insert( - typeConverter, context); + .insert(typeConverter, + context); + patterns.insert(typeConverter, context, importTable); patterns.insert>( typeConverter, context); } diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.h index 37af2bfc03f0..b26e9ffae46a 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.h +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.h @@ -7,6 +7,7 @@ #ifndef IREE_COMPILER_DIALECT_VM_CONVERSION_STANDARDTOVM_CONVERTSTANDARDTOVM_H_ #define IREE_COMPILER_DIALECT_VM_CONVERSION_STANDARDTOVM_CONVERTSTANDARDTOVM_H_ +#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" @@ -15,6 +16,7 @@ namespace mlir::iree_compiler { // Appends standard dialect to vm dialect patterns to the given pattern list. 
void populateStandardToVMPatterns(MLIRContext *context, TypeConverter &typeConverter, + ImportTable &importTable, RewritePatternSet &patterns); } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStructuralOps.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStructuralOps.cpp index 87f0b561e1bf..bc0ce196209f 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStructuralOps.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStructuralOps.cpp @@ -146,7 +146,6 @@ static void copyImportAttrs(IREE::Util::FuncOp srcOp, constexpr const char *kRetainedAttributes[] = { "nosideeffects", "vm.fallback", - "vm.signature", }; auto retainedAttributes = ArrayRef( kRetainedAttributes, @@ -217,8 +216,12 @@ class ExternalFuncOpConversion } }; -class CallOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct CallOpConversion : public OpConversionPattern { + ImportTable &importTable; + CallOpConversion(const TypeConverter &typeConverter, MLIRContext *context, + ImportTable &importTable, PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), + importTable(importTable) {} LogicalResult matchAndRewrite(IREE::Util::CallOp callOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -255,35 +258,10 @@ class CallOpConversion : public OpConversionPattern { convertCallOp(Operation *rootOp, Location loc, StringRef calleeName, ValueRange operands, TypeRange resultTypes, ConversionPatternRewriter &rewriter) const { - // (Slow) lookup of the target function, which may be an import that we need - // to perform type conversion for. - auto calleeOp = SymbolTable::lookupSymbolIn(rootOp, calleeName); - if (auto funcOp = dyn_cast_or_null(calleeOp)) { - if (funcOp.isExternal()) { - // Import that may require conversion. 
- // This case handles when funcs are declared after the call. - FunctionType convertedSignature; - if (auto signatureAttr = - funcOp->getAttrOfType("vm.signature")) { - if (auto importSignature = - llvm::dyn_cast(signatureAttr.getValue())) { - convertedSignature = importSignature; - } - } - if (!convertedSignature) { - convertedSignature = - rewriter.getFunctionType(TypeRange(operands), resultTypes); - } - return convertImportCallOp(rootOp, loc, calleeName, operands, - resultTypes, convertedSignature, funcOp, - rewriter); - } - } else if (auto importOp = dyn_cast_or_null(calleeOp)) { - // Calling an import. - // This case handles when funcs are declared before the call and have - // already been converted. - return convertImportCallOp(rootOp, loc, calleeName, operands, resultTypes, - importOp.getFunctionType(), importOp, + // Lookup the target and detect if it is an import. + auto import = importTable.find(calleeName); + if (import.has_value()) { + return convertImportCallOp(rootOp, loc, *import, operands, resultTypes, rewriter); } @@ -296,19 +274,19 @@ class CallOpConversion : public OpConversionPattern { // Converts a call to an import that may be optional. // Returns the new converted call results. FailureOr> - convertImportCallOp(Operation *rootOp, Location loc, StringRef calleeName, - ValueRange operands, TypeRange resultTypes, - FunctionType importSignature, Operation *calleeOp, + convertImportCallOp(Operation *rootOp, Location loc, + ImportTable::Import &import, ValueRange operands, + TypeRange resultTypes, ConversionPatternRewriter &rewriter) const { - auto fallbackAttr = calleeOp->getAttrOfType("vm.fallback"); - return fallbackAttr - ? 
convertOptionalImportCallOp( - rootOp, loc, calleeName, operands, resultTypes, - importSignature, - fallbackAttr.getLeafReference().getValue(), rewriter) - : convertMandatoryImportCallOp(rootOp, loc, calleeName, operands, - resultTypes, importSignature, - rewriter); + if (import.fallback) { + return convertOptionalImportCallOp( + rootOp, loc, import.name, operands, resultTypes, import.signature, + import.fallback.getLeafReference().getValue(), rewriter); + } else { + return convertMandatoryImportCallOp(rootOp, loc, import.name, operands, + resultTypes, import.signature, + rewriter); + } } // Converts a call to an optional import by adding logic to check whether it @@ -405,13 +383,14 @@ struct ReturnOpConversion : public OpConversionPattern { void populateUtilStructuralToVMPatterns(MLIRContext *context, ConversionTarget &conversionTarget, TypeConverter &typeConverter, + ImportTable &importTable, RewritePatternSet &patterns) { conversionTarget.addIllegalOp(); - patterns - .insert( - typeConverter, context); + patterns.insert(typeConverter, + context); + patterns.insert(typeConverter, context, importTable); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.cpp index 2d60a81915c9..f257369bffeb 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.cpp @@ -45,6 +45,7 @@ void populateUtilStatusToVMPatterns(MLIRContext *context, void populateUtilStructuralToVMPatterns(MLIRContext *context, ConversionTarget &conversionTarget, TypeConverter &typeConverter, + ImportTable &importTable, RewritePatternSet &patterns); namespace { @@ -127,6 +128,7 @@ struct UnreachableOpConversion void populateUtilToVMPatterns(MLIRContext *context, ConversionTarget &conversionTarget, TypeConverter &typeConverter, + ImportTable &importTable, RewritePatternSet 
&patterns) { patterns.insert(typeConverter, context); patterns.insert(typeConverter, context); @@ -146,7 +148,7 @@ void populateUtilToVMPatterns(MLIRContext *context, populateUtilStatusToVMPatterns(context, conversionTarget, typeConverter, patterns); populateUtilStructuralToVMPatterns(context, conversionTarget, typeConverter, - patterns); + importTable, patterns); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.h index 4faf2d82b156..baa13fac11ee 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.h +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.h @@ -7,6 +7,7 @@ #ifndef IREE_COMPILER_DIALECT_VM_CONVERSION_UTILTOVM_PATTERNS_H_ #define IREE_COMPILER_DIALECT_VM_CONVERSION_UTILTOVM_PATTERNS_H_ +#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" @@ -16,6 +17,7 @@ namespace mlir::iree_compiler { void populateUtilToVMPatterns(MLIRContext *context, ConversionTarget &conversionTarget, TypeConverter &typeConverter, + ImportTable &importTable, RewritePatternSet &patterns); } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp index 266b1e1835ac..203a15b8a927 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp @@ -791,12 +791,12 @@ void RodataOp::build(OpBuilder &builder, OperationState &result, StringRef name, result.addAttributes(attrs); } -LogicalResult ConstRefRodataOp::verify() { +LogicalResult +ConstRefRodataOp::verifySymbolUses(SymbolTableCollection &symbolTable) { Operation *op = getOperation(); - auto *rodataOp = - op->getParentOfType().lookupSymbol(getRodata()); - if (!rodataOp) { - return op->emitOpError() << "Undefined 
rodata section: " << getRodata(); + if (!symbolTable.lookupNearestSymbolFrom(op, getRodataAttr())) { + return op->emitError() << "undefined rodata section: '" << getRodata() + << "'"; } return success(); } diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td index 6e9899d53a85..c23e687a8c6d 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td @@ -1137,6 +1137,7 @@ def VM_RodataOp : VM_Op<"rodata", [ } def VM_ConstRefRodataOp : VM_PureOp<"const.ref.rodata", [ + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ]> { @@ -1170,8 +1171,6 @@ def VM_ConstRefRodataOp : VM_PureOp<"const.ref.rodata", [ OpBuilder<(ins "RodataOp":$rodataOp, CArg<"ArrayRef", "{}">:$attrs)>, ]; - - let hasVerifier = 1; } def VM_RodataInlineOp : VM_PureOp<"rodata.inline", [ diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp index 38c659a160eb..3d59c2b153f2 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp @@ -129,11 +129,14 @@ class ConversionPass } } + // Populated below after all type converters are registered. 
+ ImportTable importTable; + RewritePatternSet patterns(&getContext()); populateUtilConversionPatterns(context, conversionTarget, typeConverter, patterns); populateUtilToVMPatterns(context, conversionTarget, typeConverter, - patterns); + importTable, patterns); conversionTarget.addIllegalDialect(); populateAffineToStdConversionPatterns(patterns); @@ -146,7 +149,7 @@ class ConversionPass populateMathToVMPatterns(context, typeConverter, patterns); conversionTarget.addIllegalDialect(); - populateStandardToVMPatterns(context, typeConverter, patterns); + populateStandardToVMPatterns(context, typeConverter, importTable, patterns); // Populate patterns from all used dialects, providing the imports they // registered earlier. @@ -156,6 +159,12 @@ class ConversionPass importSymbols, patterns, conversionTarget, typeConverter); } + // Build an import table so that we can quickly look up import information + // during conversion. + if (failed(importTable.build(innerModuleOp, typeConverter))) { + return signalPassFailure(); // error emitted already + } + if (failed(applyPartialConversion(outerModuleOp, conversionTarget, std::move(patterns)))) { outerModuleOp.emitError() << "conversion to vm.module failed"; diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp index 871087cbe74f..ffce82e60e58 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp @@ -93,28 +93,22 @@ static void fixupGlobalMutability(Operation *moduleOp, explorer.initialize(); SmallVector deadOps; explorer.forEachGlobal([&](const Explorer::GlobalInfo *globalInfo) { - if (globalInfo->uses.empty()) { - // No uses - erase the global entirely. 
- deadOps.push_back(globalInfo->op); - } else { - // TODO(benvanik): verify we want this behavior - we likely want to change - // this to be mutable only if stores exist outside of initializers. - // - // If there are stores mark the global as mutable. We need to update all - // of the loads if this changes anything. - bool hasStores = !globalInfo->getStores().empty(); - bool didChange = globalInfo->op.isGlobalMutable() != hasStores; + if (globalInfo->uses.empty()) + return; + // TODO(benvanik): verify we want this behavior - we likely want to change + // this to be mutable only if stores exist outside of initializers. + // + // If there are stores mark the global as mutable. We need to update all + // of the loads if this changes anything. + bool hasStores = !globalInfo->getStores().empty(); + bool didChange = globalInfo->op.isGlobalMutable() != hasStores; + if (didChange) { globalInfo->op.setGlobalMutable(hasStores); - if (didChange) { - for (auto loadOp : globalInfo->getLoads()) - loadOp.setGlobalImmutable(!hasStores); + for (auto loadOp : globalInfo->getLoads()) { + loadOp.setGlobalImmutable(!hasStores); } } - for (auto loadOp : globalInfo->getLoads()) - loadOp.setGlobalImmutable(!globalInfo->op.isGlobalMutable()); }); - for (auto *deadOp : deadOps) - deadOp->erase(); } } // namespace @@ -171,8 +165,7 @@ class GlobalInitializationPass InlinerInterface inlinerInterface(&getContext()); SmallVector deadOps; for (auto &op : moduleOp.getBlock().getOperations()) { - if (auto globalOp = dyn_cast(op)) { - } else if (auto globalOp = dyn_cast(op)) { + if (auto globalOp = dyn_cast(op)) { if (llvm::isa(globalOp.getGlobalType())) { if (failed(appendRefInitialization(globalOp, initBuilder))) { globalOp.emitOpError() diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp index 665d4f09cf3b..890d9c5eb0ce 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp +++ 
b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp @@ -35,21 +35,24 @@ static void addCleanupPatterns(OpPassManager &passManager) { // TODO(benvanik): run in a fixed-point iteration pipeline. // Standard MLIR cleanup. - passManager.addPass(mlir::createCanonicalizerPass()); - passManager.addPass(mlir::createCSEPass()); + FunctionLikeNest(passManager) + .addPass(mlir::createCanonicalizerPass) + .addPass(mlir::createCSEPass); // Aggressive MLIR cleanup. passManager.addNestedPass( IREE::VM::createDropUnusedCallsPass()); passManager.addPass(mlir::createSymbolDCEPass()); - // Simplify util.global accesses; this can help with data flow tracking as - // redundant store-loads are removed. FunctionLikeNest(passManager) - .addPass(IREE::Util::createSimplifyGlobalAccessesPass); + // Simplify util.global accesses; this can help with data flow tracking as + // redundant store-loads are removed. + .addPass(IREE::Util::createSimplifyGlobalAccessesPass) + + // Aggressive cleanup. + .addPass(IREE::Util::createApplyPatternsPass); // Cleanup and canonicalization of util.global (and other util ops). - passManager.addPass(IREE::Util::createApplyPatternsPass()); passManager.addPass(IREE::Util::createFoldGlobalsPass()); passManager.addPass(IREE::Util::createFuseGlobalsPass()); } diff --git a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp index afee21cbbcd8..3fc56829d86b 100644 --- a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp @@ -108,10 +108,12 @@ static void addCleanupPatterns(OpPassManager &passManager) { // Simplify util.global accesses; this can help with data flow tracking as // redundant store-loads are removed. - .addPass(IREE::Util::createSimplifyGlobalAccessesPass); + .addPass(IREE::Util::createSimplifyGlobalAccessesPass) + + // Aggressive cleanup. 
+ .addPass(IREE::Util::createApplyPatternsPass); // Cleanup and canonicalization of util.global (and other util ops). - passManager.addPass(IREE::Util::createApplyPatternsPass()); passManager.addPass(IREE::Util::createFoldGlobalsPass()); passManager.addPass(IREE::Util::createFuseGlobalsPass()); diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp index bd61d4b6ce76..94c78b9f0464 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp @@ -182,16 +182,17 @@ void buildGlobalOptimizationPassPipeline( FunctionLikeNest(mainPassManager) .addPass(createGlobalLoopInvariantCodeMotionPass) .addPass(IREE::Flow::createCanonicalizerPass) - .addPass(mlir::createCSEPass); + .addPass(mlir::createCSEPass) - // Simplify util.global accesses early on; this can help with dispatch - // region formation as redundant store-loads are removed. - FunctionLikeNest(mainPassManager) - .addPass(IREE::Util::createSimplifyGlobalAccessesPass); + // Simplify util.global accesses early on; this can help with dispatch + // region formation as redundant store-loads are removed. + .addPass(IREE::Util::createSimplifyGlobalAccessesPass) + + // Aggressive cleanup. + .addPass(IREE::Util::createApplyPatternsPass); // Module level cleanup and canonicalization of util.global (and other // util ops). 
- mainPassManager.addPass(IREE::Util::createApplyPatternsPass()); mainPassManager.addPass(IREE::Util::createFoldGlobalsPass()); mainPassManager.addPass(IREE::Util::createIPOPass()); diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp index 0d68029cb45f..68b479933178 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp @@ -27,17 +27,19 @@ using FunctionLikeNest = //===----------------------------------------------------------------------===// static void addCleanupPatterns(OpPassManager &passManager) { - // Standard MLIR cleanup. - passManager.addPass(mlir::createCanonicalizerPass()); - passManager.addPass(mlir::createCSEPass()); - FunctionLikeNest(passManager) + // Standard MLIR cleanup. + .addPass(mlir::createCanonicalizerPass) + .addPass(mlir::createCSEPass) + // Simplify util.global accesses; this can help with data flow tracking as // redundant store-loads are removed. - .addPass(IREE::Util::createSimplifyGlobalAccessesPass); + .addPass(IREE::Util::createSimplifyGlobalAccessesPass) + + // Aggressive cleanup. + .addPass(IREE::Util::createApplyPatternsPass); // Cleanup and canonicalization of util.global (and other util ops). 
- passManager.addPass(IREE::Util::createApplyPatternsPass()); passManager.addPass(IREE::Util::createFoldGlobalsPass()); passManager.addPass(IREE::Util::createFuseGlobalsPass()); } diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp b/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp index 96c7eb8fbef7..7adccce1514e 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp @@ -27,17 +27,19 @@ using FunctionLikeNest = //===----------------------------------------------------------------------===// static void addCleanupPatterns(OpPassManager &passManager) { - // Standard MLIR cleanup. - passManager.addPass(mlir::createCanonicalizerPass()); - passManager.addPass(mlir::createCSEPass()); - FunctionLikeNest(passManager) + // Standard MLIR cleanup. + .addPass(mlir::createCanonicalizerPass) + .addPass(mlir::createCSEPass) + // Simplify util.global accesses; this can help with data flow tracking as // redundant store-loads are removed. - .addPass(IREE::Util::createSimplifyGlobalAccessesPass); + .addPass(IREE::Util::createSimplifyGlobalAccessesPass) + + // Aggressive cleanup. + .addPass(IREE::Util::createApplyPatternsPass); // Cleanup and canonicalization of util.global (and other util ops). - passManager.addPass(IREE::Util::createApplyPatternsPass()); passManager.addPass(IREE::Util::createFoldGlobalsPass()); passManager.addPass(IREE::Util::createFuseGlobalsPass()); }