diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp index 1357e077d86c..a84205a58dd7 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp @@ -88,10 +88,12 @@ static void addCleanupPatterns(OpPassManager &passManager) { // Simplify util.global accesses; this can help with data flow tracking as // redundant store-loads are removed. - .addPass(IREE::Util::createSimplifyGlobalAccessesPass); + .addPass(IREE::Util::createSimplifyGlobalAccessesPass) + + // Aggressive cleanup. + .addPass(IREE::Util::createApplyPatternsPass); // Cleanup and canonicalization of util.global (and other util ops). - passManager.addPass(IREE::Util::createApplyPatternsPass()); passManager.addPass(IREE::Util::createFoldGlobalsPass()); passManager.addPass(IREE::Util::createFuseGlobalsPass()); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp index 07df5d87665d..1de6f1cbeb84 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp @@ -149,18 +149,20 @@ using FunctionLikeNest = //===----------------------------------------------------------------------===// static void addCleanupPatterns(OpPassManager &passManager) { - // Standard MLIR cleanup. - passManager.addPass(mlir::createCSEPass()); - passManager.addPass(mlir::createCanonicalizerPass()); - passManager.addPass(mlir::createCSEPass()); - // Simplify util.global accesses; this can help with data flow tracking as - // redundant store-loads are removed. FunctionLikeNest(passManager) - .addPass(IREE::Util::createSimplifyGlobalAccessesPass); + // Standard MLIR cleanup. + .addPass(mlir::createCanonicalizerPass) + .addPass(mlir::createCSEPass) + + // Simplify util.global accesses; this can help with data flow tracking as + // redundant store-loads are removed. + .addPass(IREE::Util::createSimplifyGlobalAccessesPass) + + // Aggressive cleanup. + .addPass(IREE::Util::createApplyPatternsPass); // Cleanup and canonicalization of util.global (and other util ops). - passManager.addPass(IREE::Util::createApplyPatternsPass()); passManager.addPass(IREE::Util::createFoldGlobalsPass()); passManager.addPass(IREE::Util::createFuseGlobalsPass()); } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp index 408bb024f2f2..2234c62daa58 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp @@ -44,10 +44,12 @@ static void addCleanupPatterns(OpPassManager &passManager) { // Simplify util.global accesses; this can help with data flow tracking as // redundant store-loads are removed. - .addPass(IREE::Util::createSimplifyGlobalAccessesPass); + .addPass(IREE::Util::createSimplifyGlobalAccessesPass) + + // Aggressive cleanup. + .addPass(IREE::Util::createApplyPatternsPass); // Cleanup and canonicalization of util.global (and other util ops). - passManager.addPass(IREE::Util::createApplyPatternsPass()); passManager.addPass(IREE::Util::createFoldGlobalsPass()); passManager.addPass(IREE::Util::createFuseGlobalsPass()); diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp index 745f0291f5d6..32c7819a66b3 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp @@ -126,20 +126,33 @@ void Explorer::initializeGlobalInfos() { // TODO(benvanik): filter the use list by traversal actions; where this runs // today we don't yet have the actions specified so we can't. + // Initialize the full list of globals. + for (auto globalOp : + symbolTableOp->getRegion(0).getOps()) { + auto globalInfo = std::make_unique(); + globalInfo->op = globalOp; + globalInfosByName[globalOp.getGlobalName().getValue()] = globalInfo.get(); + globalInfos[globalOp] = std::move(globalInfo); + } + + // Walk the module and gather uses. + // + // TODO: find a way to do this more efficiently when the module is large. + // We could parallelize on top-level functions and then merge at the end. auto allUses = symbolTable.getSymbolUses(&symbolTableOp->getRegion(0)); - if (!allUses.has_value()) - return; - for (auto use : allUses.value()) { - auto *symbolOp = - symbolTable.lookupNearestSymbolFrom(use.getUser(), use.getSymbolRef()); - if (!isa_and_nonnull(symbolOp)) - continue; - auto &globalInfo = globalInfos[symbolOp]; - globalInfo.op = cast(symbolOp); - if (isa(use.getUser())) { - globalInfo.isIndirect = true; - } else { - globalInfo.uses.push_back(use.getUser()); + if (allUses.has_value()) { + for (auto use : allUses.value()) { + auto globalInfoIt = globalInfosByName.find( + use.getSymbolRef().getLeafReference().getValue()); + if (globalInfoIt == globalInfosByName.end()) { + continue; // not a global + } + auto *globalInfo = globalInfoIt->second; + if (isa(use.getUser())) { + globalInfo->isIndirect = true; + } else { + globalInfo->uses.push_back(use.getUser()); + } } } } @@ -175,7 +188,7 @@ Explorer::getGlobalInfo(IREE::Util::GlobalOpInterface globalOp) { auto it = globalInfos.find(globalOp); if (it == globalInfos.end()) return nullptr; - return &it->second; + return it->second.get(); } const Explorer::GlobalInfo *Explorer::queryGlobalInfoFrom(StringRef globalName, @@ -189,12 +202,12 @@ const Explorer::GlobalInfo *Explorer::queryGlobalInfoFrom(StringRef globalName, auto it = globalInfos.find(op); if (it == globalInfos.end()) return nullptr; - return &it->second; + return it->second.get(); } void Explorer::forEachGlobal(std::function fn) { - for (auto it : globalInfos) { - fn(&it.second); + for (auto &it : globalInfos) { + fn(it.second.get()); } } diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h index 35ee12aa822b..4c9482b03ade 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h @@ -403,7 +403,8 @@ class Explorer { DenseMap dialectActions; DenseMap opActions; - DenseMap globalInfos; + DenseMap> globalInfos; + DenseMap globalInfosByName; ModuleAnalysisManager analysisManager; }; diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.cpp index 7fdc301fb365..70387f6de584 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.cpp @@ -69,6 +69,8 @@ void GlobalTable::rebuild() { globalOrder.push_back(globalName); } + // TODO: parallelize this by gathering on multiple threads per callable and + // then merging at the end. for (auto callableOp : moduleOp.getOps()) { if (auto uses = SymbolTable::getSymbolUses(callableOp)) { for (auto use : *uses) { diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp index ac527d792c0b..e7c722d46cff 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp @@ -46,7 +46,7 @@ struct FoldBlockArgumentsPattern if (!op.getCallableRegion()) return failure(); auto ®ion = *op.getCallableRegion(); - if (region.empty()) + if (region.empty() || region.hasOneBlock()) return failure(); // Analyze all branches in the op to compute the information we'll need to @@ -501,7 +501,6 @@ void populateCommonPatterns(MLIRContext *context, RewritePatternSet &patterns) { context->getOrLoadDialect() ->getCanonicalizationPatterns(patterns); - // TODO(benvanik): same as branch folding but for calls. patterns.insert( context); diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp index fa837d41c992..ea5e257a7728 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp @@ -18,6 +18,56 @@ namespace mlir::iree_compiler { +LogicalResult ImportTable::build(Operation *rootOp, + const TypeConverter &typeConverter) { + for (auto funcOp : rootOp->getRegion(0).getOps()) { + if (!funcOp.isExternal()) { + continue; // only external functions are imports + } + + ImportTable::Import import; + import.name = funcOp.getNameAttr(); + import.fallback = funcOp->getAttrOfType("vm.fallback"); + + // Try to use an assigned signature or fall back to converting the input. + if (auto importOp = dyn_cast(funcOp.getOperation())) { + // Import ops have their signature used directly. + import.signature = importOp.getFunctionType(); + } else if (auto signatureAttr = + funcOp->getAttrOfType("vm.signature")) { + // Directly use the specified signature. + import.signature = + dyn_cast_if_present(signatureAttr.getValue()); + } + if (!import.signature) { + // Convert the signature using the type converter. + SmallVector argumentTypes; + if (failed(typeConverter.convertTypes(funcOp.getArgumentTypes(), + argumentTypes))) { + return funcOp.emitError() << "unable to convert import argument types"; + } + SmallVector resultTypes; + if (failed(typeConverter.convertTypes(funcOp.getResultTypes(), + resultTypes))) { + return funcOp.emitError() << "unable to convert import result types"; + } + import.signature = + FunctionType::get(rootOp->getContext(), argumentTypes, resultTypes); + } + + symbols[import.name.getValue()] = std::move(import); + } + + return success(); +} + +std::optional ImportTable::find(StringRef symbolName) { + auto it = symbols.find(symbolName); + if (it == symbols.end()) + return std::nullopt; + return it->second; +} + // TODO(benvanik): replace with iree/compiler/Utils/ModuleUtils.h. // There may be some special insertion order arrangement required based on the // nested vm.module here. diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h index b2f0a8f74c60..c5d557f5b42d 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h @@ -20,6 +20,33 @@ namespace mlir::iree_compiler { // segment_sizes array. constexpr int kFixedSingleValue = -1; +// A table of import information. +class ImportTable { +public: + // Information about an import function. + struct Import { + // Used to ensure the StringRef in the map stays live. + StringAttr name; + // Function signature derived from the type or overridden by `vm.signature`. + FunctionType signature; + // Optional fallback function that should be used when the import is + // unavailable at runtime taken from `vm.fallback`. + SymbolRefAttr fallback; + }; + + // Builds a table of all import functions nested within the given |rootOp|. + // Clones any information such that the original ops can be mutated/erased. + // Must only be called once the type converter has been fully populated. + LogicalResult build(Operation *rootOp, const TypeConverter &typeConverter); + + // Finds an import with the given name if there exists one. + std::optional find(StringRef symbolName); + +private: + // Map of symbol names within the root op to import symbol info. + DenseMap symbols; +}; + // Appends a set of vm.import ops from a module to a target VM module. // Imports will only be added if they are not already present in the target // module. diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.cpp index f9a5c264cba8..1cb4f413883b 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.cpp @@ -170,7 +170,6 @@ static void copyImportAttrs(func::FuncOp srcOp, IREE::VM::ImportOp dstOp) { constexpr const char *kRetainedAttributes[] = { "nosideeffects", "vm.fallback", - "vm.signature", }; auto retainedAttributes = ArrayRef( kRetainedAttributes, @@ -241,7 +240,11 @@ struct ExternalFuncOpConversion : public OpConversionPattern { }; struct CallOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + ImportTable &importTable; + CallOpConversion(const TypeConverter &typeConverter, MLIRContext *context, + ImportTable &importTable, PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), + importTable(importTable) {} LogicalResult matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -262,7 +265,8 @@ struct CallOpConversion : public OpConversionPattern { // conversion if imports have fallbacks that are themselves imports. auto callResults = convertCallOp( callOp->getParentOfType(), callOp.getLoc(), - callOp.getCallee(), adaptor.getOperands(), resultTypes, rewriter); + callOp.getCallee(), adaptor.getOperands(), resultTypes, importTable, + rewriter); if (failed(callResults)) { return rewriter.notifyMatchFailure( callOp, "unable to convert call (results mismatch)"); @@ -277,36 +281,12 @@ struct CallOpConversion : public OpConversionPattern { FailureOr> convertCallOp(Operation *rootOp, Location loc, StringRef calleeName, ValueRange operands, TypeRange resultTypes, + ImportTable &importTable, ConversionPatternRewriter &rewriter) const { - // (Slow) lookup of the target function, which may be an import that we need - // to perform type conversion for. - auto calleeOp = SymbolTable::lookupSymbolIn(rootOp, calleeName); - if (auto funcOp = dyn_cast_or_null(calleeOp)) { - if (funcOp.isExternal()) { - // Import that may require conversion. - // This case handles when funcs are declared after the call. - FunctionType convertedSignature; - if (auto signatureAttr = - funcOp->getAttrOfType("vm.signature")) { - if (auto importSignature = - llvm::dyn_cast(signatureAttr.getValue())) { - convertedSignature = importSignature; - } - } - if (!convertedSignature) { - convertedSignature = - rewriter.getFunctionType(TypeRange(operands), resultTypes); - } - return convertImportCallOp(rootOp, loc, calleeName, operands, - resultTypes, convertedSignature, funcOp, - rewriter); - } - } else if (auto importOp = dyn_cast_or_null(calleeOp)) { - // Calling an import. - // This case handles when funcs are declared before the call and have - // already been converted. - return convertImportCallOp(rootOp, loc, calleeName, operands, resultTypes, - importOp.getFunctionType(), importOp, + // Lookup the target and detect if it is an import. + auto import = importTable.find(calleeName); + if (import.has_value()) { + return convertImportCallOp(rootOp, loc, *import, operands, resultTypes, rewriter); } @@ -319,19 +299,19 @@ struct CallOpConversion : public OpConversionPattern { // Converts a call to an import that may be optional. // Returns the new converted call results. FailureOr> - convertImportCallOp(Operation *rootOp, Location loc, StringRef calleeName, - ValueRange operands, TypeRange resultTypes, - FunctionType importSignature, Operation *calleeOp, + convertImportCallOp(Operation *rootOp, Location loc, + ImportTable::Import &import, ValueRange operands, + TypeRange resultTypes, ConversionPatternRewriter &rewriter) const { - auto fallbackAttr = calleeOp->getAttrOfType("vm.fallback"); - return fallbackAttr - ? convertOptionalImportCallOp( - rootOp, loc, calleeName, operands, resultTypes, - importSignature, - fallbackAttr.getLeafReference().getValue(), rewriter) - : convertMandatoryImportCallOp(rootOp, loc, calleeName, operands, - resultTypes, importSignature, - rewriter); + if (import.fallback) { + return convertOptionalImportCallOp( + rootOp, loc, import.name, operands, resultTypes, import.signature, + import.fallback.getLeafReference().getValue(), rewriter); + } else { + return convertMandatoryImportCallOp(rootOp, loc, import.name, operands, + resultTypes, import.signature, + rewriter); + } } // Converts a call to an optional import by adding logic to check whether it @@ -374,7 +354,7 @@ struct CallOpConversion : public OpConversionPattern { // Not resolved: call fallback as a normal function. rewriter.setInsertionPointToStart(fallbackBlock); auto fallbackResults = convertCallOp(rootOp, loc, fallbackName, operands, - resultTypes, rewriter); + resultTypes, importTable, rewriter); if (failed(fallbackResults)) return failure(); rewriter.create(loc, exitBlock, *fallbackResults); @@ -557,12 +537,14 @@ struct SwitchOpConversion : public OpConversionPattern { void populateStandardToVMPatterns(MLIRContext *context, TypeConverter &typeConverter, + ImportTable &importTable, RewritePatternSet &patterns) { patterns - .insert( - typeConverter, context); + .insert(typeConverter, + context); + patterns.insert(typeConverter, context, importTable); patterns.insert>( typeConverter, context); } diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.h index 37af2bfc03f0..b26e9ffae46a 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.h +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.h @@ -7,6 +7,7 @@ #ifndef IREE_COMPILER_DIALECT_VM_CONVERSION_STANDARDTOVM_CONVERTSTANDARDTOVM_H_ #define IREE_COMPILER_DIALECT_VM_CONVERSION_STANDARDTOVM_CONVERTSTANDARDTOVM_H_ +#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" @@ -15,6 +16,7 @@ namespace mlir::iree_compiler { // Appends standard dialect to vm dialect patterns to the given pattern list. void populateStandardToVMPatterns(MLIRContext *context, TypeConverter &typeConverter, + ImportTable &importTable, RewritePatternSet &patterns); } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStructuralOps.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStructuralOps.cpp index 87f0b561e1bf..bc0ce196209f 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStructuralOps.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStructuralOps.cpp @@ -146,7 +146,6 @@ static void copyImportAttrs(IREE::Util::FuncOp srcOp, constexpr const char *kRetainedAttributes[] = { "nosideeffects", "vm.fallback", - "vm.signature", }; auto retainedAttributes = ArrayRef( kRetainedAttributes, @@ -217,8 +216,12 @@ class ExternalFuncOpConversion } }; -class CallOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct CallOpConversion : public OpConversionPattern { + ImportTable &importTable; + CallOpConversion(const TypeConverter &typeConverter, MLIRContext *context, + ImportTable &importTable, PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), + importTable(importTable) {} LogicalResult matchAndRewrite(IREE::Util::CallOp callOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -255,35 +258,10 @@ class CallOpConversion : public OpConversionPattern { convertCallOp(Operation *rootOp, Location loc, StringRef calleeName, ValueRange operands, TypeRange resultTypes, ConversionPatternRewriter &rewriter) const { - // (Slow) lookup of the target function, which may be an import that we need - // to perform type conversion for. - auto calleeOp = SymbolTable::lookupSymbolIn(rootOp, calleeName); - if (auto funcOp = dyn_cast_or_null(calleeOp)) { - if (funcOp.isExternal()) { - // Import that may require conversion. - // This case handles when funcs are declared after the call. - FunctionType convertedSignature; - if (auto signatureAttr = - funcOp->getAttrOfType("vm.signature")) { - if (auto importSignature = - llvm::dyn_cast(signatureAttr.getValue())) { - convertedSignature = importSignature; - } - } - if (!convertedSignature) { - convertedSignature = - rewriter.getFunctionType(TypeRange(operands), resultTypes); - } - return convertImportCallOp(rootOp, loc, calleeName, operands, - resultTypes, convertedSignature, funcOp, - rewriter); - } - } else if (auto importOp = dyn_cast_or_null(calleeOp)) { - // Calling an import. - // This case handles when funcs are declared before the call and have - // already been converted. - return convertImportCallOp(rootOp, loc, calleeName, operands, resultTypes, - importOp.getFunctionType(), importOp, + // Lookup the target and detect if it is an import. + auto import = importTable.find(calleeName); + if (import.has_value()) { + return convertImportCallOp(rootOp, loc, *import, operands, resultTypes, rewriter); } @@ -296,19 +274,19 @@ class CallOpConversion : public OpConversionPattern { // Converts a call to an import that may be optional. // Returns the new converted call results. FailureOr> - convertImportCallOp(Operation *rootOp, Location loc, StringRef calleeName, - ValueRange operands, TypeRange resultTypes, - FunctionType importSignature, Operation *calleeOp, + convertImportCallOp(Operation *rootOp, Location loc, + ImportTable::Import &import, ValueRange operands, + TypeRange resultTypes, ConversionPatternRewriter &rewriter) const { - auto fallbackAttr = calleeOp->getAttrOfType("vm.fallback"); - return fallbackAttr - ? convertOptionalImportCallOp( - rootOp, loc, calleeName, operands, resultTypes, - importSignature, - fallbackAttr.getLeafReference().getValue(), rewriter) - : convertMandatoryImportCallOp(rootOp, loc, calleeName, operands, - resultTypes, importSignature, - rewriter); + if (import.fallback) { + return convertOptionalImportCallOp( + rootOp, loc, import.name, operands, resultTypes, import.signature, + import.fallback.getLeafReference().getValue(), rewriter); + } else { + return convertMandatoryImportCallOp(rootOp, loc, import.name, operands, + resultTypes, import.signature, + rewriter); + } } // Converts a call to an optional import by adding logic to check whether it @@ -405,13 +383,14 @@ struct ReturnOpConversion : public OpConversionPattern { void populateUtilStructuralToVMPatterns(MLIRContext *context, ConversionTarget &conversionTarget, TypeConverter &typeConverter, + ImportTable &importTable, RewritePatternSet &patterns) { conversionTarget.addIllegalOp(); - patterns - .insert( - typeConverter, context); + patterns.insert(typeConverter, + context); + patterns.insert(typeConverter, context, importTable); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.cpp index 2d60a81915c9..f257369bffeb 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.cpp @@ -45,6 +45,7 @@ void populateUtilStatusToVMPatterns(MLIRContext *context, void populateUtilStructuralToVMPatterns(MLIRContext *context, ConversionTarget &conversionTarget, TypeConverter &typeConverter, + ImportTable &importTable, RewritePatternSet &patterns); namespace { @@ -127,6 +128,7 @@ struct UnreachableOpConversion void populateUtilToVMPatterns(MLIRContext *context, ConversionTarget &conversionTarget, TypeConverter &typeConverter, + ImportTable &importTable, RewritePatternSet &patterns) { patterns.insert(typeConverter, context); patterns.insert(typeConverter, context); @@ -146,7 +148,7 @@ void populateUtilToVMPatterns(MLIRContext *context, populateUtilStatusToVMPatterns(context, conversionTarget, typeConverter, patterns); populateUtilStructuralToVMPatterns(context, conversionTarget, typeConverter, - patterns); + importTable, patterns); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.h index 4faf2d82b156..baa13fac11ee 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.h +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.h @@ -7,6 +7,7 @@ #ifndef IREE_COMPILER_DIALECT_VM_CONVERSION_UTILTOVM_PATTERNS_H_ #define IREE_COMPILER_DIALECT_VM_CONVERSION_UTILTOVM_PATTERNS_H_ +#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" @@ -16,6 +17,7 @@ namespace mlir::iree_compiler { void populateUtilToVMPatterns(MLIRContext *context, ConversionTarget &conversionTarget, TypeConverter &typeConverter, + ImportTable &importTable, RewritePatternSet &patterns); } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp index 266b1e1835ac..203a15b8a927 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp @@ -791,12 +791,12 @@ void RodataOp::build(OpBuilder &builder, OperationState &result, StringRef name, result.addAttributes(attrs); } -LogicalResult ConstRefRodataOp::verify() { +LogicalResult +ConstRefRodataOp::verifySymbolUses(SymbolTableCollection &symbolTable) { Operation *op = getOperation(); - auto *rodataOp = - op->getParentOfType().lookupSymbol(getRodata()); - if (!rodataOp) { - return op->emitOpError() << "Undefined rodata section: " << getRodata(); + if (!symbolTable.lookupNearestSymbolFrom(op, getRodataAttr())) { + return op->emitError() << "undefined rodata section: '" << getRodata() + << "'"; } return success(); } diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td index 6e9899d53a85..c23e687a8c6d 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td @@ -1137,6 +1137,7 @@ def VM_RodataOp : VM_Op<"rodata", [ } def VM_ConstRefRodataOp : VM_PureOp<"const.ref.rodata", [ + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ]> { @@ -1170,8 +1171,6 @@ def VM_ConstRefRodataOp : VM_PureOp<"const.ref.rodata", [ OpBuilder<(ins "RodataOp":$rodataOp, CArg<"ArrayRef", "{}">:$attrs)>, ]; - - let hasVerifier = 1; } def VM_RodataInlineOp : VM_PureOp<"rodata.inline", [ diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp index 38c659a160eb..3d59c2b153f2 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp @@ -129,11 +129,14 @@ class ConversionPass } } + // Populated below after all type converters are registered. + ImportTable importTable; + RewritePatternSet patterns(&getContext()); populateUtilConversionPatterns(context, conversionTarget, typeConverter, patterns); populateUtilToVMPatterns(context, conversionTarget, typeConverter, - patterns); + importTable, patterns); conversionTarget.addIllegalDialect(); populateAffineToStdConversionPatterns(patterns); @@ -146,7 +149,7 @@ class ConversionPass populateMathToVMPatterns(context, typeConverter, patterns); conversionTarget.addIllegalDialect(); - populateStandardToVMPatterns(context, typeConverter, patterns); + populateStandardToVMPatterns(context, typeConverter, importTable, patterns); // Populate patterns from all used dialects, providing the imports they // registered earlier. @@ -156,6 +159,12 @@ class ConversionPass importSymbols, patterns, conversionTarget, typeConverter); } + // Build an import table so that we can quickly look up import information + // during conversion. + if (failed(importTable.build(innerModuleOp, typeConverter))) { + return signalPassFailure(); // error emitted already + } + if (failed(applyPartialConversion(outerModuleOp, conversionTarget, std::move(patterns)))) { outerModuleOp.emitError() << "conversion to vm.module failed"; diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp index 871087cbe74f..ffce82e60e58 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp @@ -93,28 +93,22 @@ static void fixupGlobalMutability(Operation *moduleOp, explorer.initialize(); SmallVector deadOps; explorer.forEachGlobal([&](const Explorer::GlobalInfo *globalInfo) { - if (globalInfo->uses.empty()) { - // No uses - erase the global entirely. - deadOps.push_back(globalInfo->op); - } else { - // TODO(benvanik): verify we want this behavior - we likely want to change - // this to be mutable only if stores exist outside of initializers. - // - // If there are stores mark the global as mutable. We need to update all - // of the loads if this changes anything. - bool hasStores = !globalInfo->getStores().empty(); - bool didChange = globalInfo->op.isGlobalMutable() != hasStores; + if (globalInfo->uses.empty()) + return; + // TODO(benvanik): verify we want this behavior - we likely want to change + // this to be mutable only if stores exist outside of initializers. + // + // If there are stores mark the global as mutable. We need to update all + // of the loads if this changes anything. + bool hasStores = !globalInfo->getStores().empty(); + bool didChange = globalInfo->op.isGlobalMutable() != hasStores; + if (didChange) { globalInfo->op.setGlobalMutable(hasStores); - if (didChange) { - for (auto loadOp : globalInfo->getLoads()) - loadOp.setGlobalImmutable(!hasStores); + for (auto loadOp : globalInfo->getLoads()) { + loadOp.setGlobalImmutable(!hasStores); } } - for (auto loadOp : globalInfo->getLoads()) - loadOp.setGlobalImmutable(!globalInfo->op.isGlobalMutable()); }); - for (auto *deadOp : deadOps) - deadOp->erase(); } } // namespace @@ -171,8 +165,7 @@ class GlobalInitializationPass InlinerInterface inlinerInterface(&getContext()); SmallVector deadOps; for (auto &op : moduleOp.getBlock().getOperations()) { - if (auto globalOp = dyn_cast(op)) { - } else if (auto globalOp = dyn_cast(op)) { + if (auto globalOp = dyn_cast(op)) { if (llvm::isa(globalOp.getGlobalType())) { if (failed(appendRefInitialization(globalOp, initBuilder))) { globalOp.emitOpError() diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp index 665d4f09cf3b..890d9c5eb0ce 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp @@ -35,21 +35,24 @@ static void addCleanupPatterns(OpPassManager &passManager) { // TODO(benvanik): run in a fixed-point iteration pipeline. // Standard MLIR cleanup. - passManager.addPass(mlir::createCanonicalizerPass()); - passManager.addPass(mlir::createCSEPass()); + FunctionLikeNest(passManager) + .addPass(mlir::createCanonicalizerPass) + .addPass(mlir::createCSEPass); // Aggressive MLIR cleanup. passManager.addNestedPass( IREE::VM::createDropUnusedCallsPass()); passManager.addPass(mlir::createSymbolDCEPass()); - // Simplify util.global accesses; this can help with data flow tracking as - // redundant store-loads are removed. FunctionLikeNest(passManager) - .addPass(IREE::Util::createSimplifyGlobalAccessesPass); + // Simplify util.global accesses; this can help with data flow tracking as + // redundant store-loads are removed. + .addPass(IREE::Util::createSimplifyGlobalAccessesPass) + + // Aggressive cleanup. + .addPass(IREE::Util::createApplyPatternsPass); // Cleanup and canonicalization of util.global (and other util ops). - passManager.addPass(IREE::Util::createApplyPatternsPass()); passManager.addPass(IREE::Util::createFoldGlobalsPass()); passManager.addPass(IREE::Util::createFuseGlobalsPass()); } diff --git a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp index afee21cbbcd8..3fc56829d86b 100644 --- a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp @@ -108,10 +108,12 @@ static void addCleanupPatterns(OpPassManager &passManager) { // Simplify util.global accesses; this can help with data flow tracking as // redundant store-loads are removed. - .addPass(IREE::Util::createSimplifyGlobalAccessesPass); + .addPass(IREE::Util::createSimplifyGlobalAccessesPass) + + // Aggressive cleanup. + .addPass(IREE::Util::createApplyPatternsPass); // Cleanup and canonicalization of util.global (and other util ops). - passManager.addPass(IREE::Util::createApplyPatternsPass()); passManager.addPass(IREE::Util::createFoldGlobalsPass()); passManager.addPass(IREE::Util::createFuseGlobalsPass()); diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp index bd61d4b6ce76..94c78b9f0464 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp @@ -182,16 +182,17 @@ void buildGlobalOptimizationPassPipeline( FunctionLikeNest(mainPassManager) .addPass(createGlobalLoopInvariantCodeMotionPass) .addPass(IREE::Flow::createCanonicalizerPass) - .addPass(mlir::createCSEPass); + .addPass(mlir::createCSEPass) - // Simplify util.global accesses early on; this can help with dispatch - // region formation as redundant store-loads are removed. - FunctionLikeNest(mainPassManager) - .addPass(IREE::Util::createSimplifyGlobalAccessesPass); + // Simplify util.global accesses early on; this can help with dispatch + // region formation as redundant store-loads are removed. + .addPass(IREE::Util::createSimplifyGlobalAccessesPass) + + // Aggressive cleanup. + .addPass(IREE::Util::createApplyPatternsPass); // Module level cleanup and canonicalization of util.global (and other // util ops). - mainPassManager.addPass(IREE::Util::createApplyPatternsPass()); mainPassManager.addPass(IREE::Util::createFoldGlobalsPass()); mainPassManager.addPass(IREE::Util::createIPOPass()); diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp index 0d68029cb45f..68b479933178 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp @@ -27,17 +27,19 @@ using FunctionLikeNest = //===----------------------------------------------------------------------===// static void addCleanupPatterns(OpPassManager &passManager) { - // Standard MLIR cleanup. - passManager.addPass(mlir::createCanonicalizerPass()); - passManager.addPass(mlir::createCSEPass()); - FunctionLikeNest(passManager) + // Standard MLIR cleanup. + .addPass(mlir::createCanonicalizerPass) + .addPass(mlir::createCSEPass) + // Simplify util.global accesses; this can help with data flow tracking as // redundant store-loads are removed. - .addPass(IREE::Util::createSimplifyGlobalAccessesPass); + .addPass(IREE::Util::createSimplifyGlobalAccessesPass) + + // Aggressive cleanup. + .addPass(IREE::Util::createApplyPatternsPass); // Cleanup and canonicalization of util.global (and other util ops). - passManager.addPass(IREE::Util::createApplyPatternsPass()); passManager.addPass(IREE::Util::createFoldGlobalsPass()); passManager.addPass(IREE::Util::createFuseGlobalsPass()); } diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp b/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp index 96c7eb8fbef7..7adccce1514e 100644 --- a/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp @@ -27,17 +27,19 @@ using FunctionLikeNest = //===----------------------------------------------------------------------===// static void addCleanupPatterns(OpPassManager &passManager) { - // Standard MLIR cleanup. - passManager.addPass(mlir::createCanonicalizerPass()); - passManager.addPass(mlir::createCSEPass()); - FunctionLikeNest(passManager) + // Standard MLIR cleanup. + .addPass(mlir::createCanonicalizerPass) + .addPass(mlir::createCSEPass) + // Simplify util.global accesses; this can help with data flow tracking as // redundant store-loads are removed. - .addPass(IREE::Util::createSimplifyGlobalAccessesPass); + .addPass(IREE::Util::createSimplifyGlobalAccessesPass) + + // Aggressive cleanup. + .addPass(IREE::Util::createApplyPatternsPass); // Cleanup and canonicalization of util.global (and other util ops). - passManager.addPass(IREE::Util::createApplyPatternsPass()); passManager.addPass(IREE::Util::createFoldGlobalsPass()); passManager.addPass(IREE::Util::createFuseGlobalsPass()); }