From 268b6de58fecfb3a9b314ee3baab771b82ea0488 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 28 Jul 2023 11:57:59 -0700 Subject: [PATCH] Rework consteval to be more memory efficient. (#14504) * Also adds an option to use the llvm-cpu backend instead of VMVX. This isn't quite ready for primetime because I need to isolate some flags better, but it can be used in a pinch. * Changes the way that JIT programs are constructed so that stateless functions are compiled at once for all constants and the data is passed in/out vs being cloned from module globals. * Each global has a public function to compute it, allowing us to run them one by one or in batches to keep memory use under control. * In this patch, I am just running them one by one and also haven't optimized memory by use of resources yet. Still a big improvement. * Drops compilation time of h2ogpt from ~45m to 2m47s. More optimization is possible. * JIT'ing globals now takes 52s for this model (vs ~42m with the Linalg const evaler). * Memory pressure is kept under control and does not regress from the current state (more optimization is possible - just a starting point). I also changed the behavior of the constexpr hoisting and jit passes to produce/process initializers with the `iree.compiler.consteval` attribute. This was advisable because in the new way of doing evaluation, we are doing a non-general transformation on the initializers and I didn't want it stumbling over arbitrary initializers in the wild. Having the behavior be opt-in seemed prudent. Flag changes: * `iree-consteval-jit-use-vmvx` (default true): Uses the VMVX backend. When false, uses the LLVMCPU backend. I'll be doing work to change this to false by default when LLVMCPU is available. * `iree-consteval-jit-debug`: Prints debugging information about constant evaluation. * `iree-opt-const-eval`: Flipped to true since it now only processes initializers targeted at it and is therefore safe to always have enabled. --- .../src/iree/compiler/API/Internal/Embed.cpp | 8 +- .../src/iree/compiler/ConstEval/BUILD.bazel | 2 + .../iree/compiler/ConstEval/CMakeLists.txt | 2 + .../iree/compiler/ConstEval/JitGlobals.cpp | 586 +++++++++++++----- compiler/src/iree/compiler/ConstEval/Passes.h | 8 + .../src/iree/compiler/ConstEval/Runtime.cpp | 268 +++++--- .../src/iree/compiler/ConstEval/Runtime.h | 36 +- .../compiler/ConstEval/test/jit_globals.mlir | 36 +- .../Util/Transforms/HoistIntoGlobals.cpp | 3 + .../Transforms/test/hoist_into_globals.mlir | 10 +- .../src/iree/compiler/Pipelines/Options.cpp | 2 +- .../src/iree/compiler/Pipelines/Options.h | 2 +- .../docs/reference/optimization-options.md | 2 +- 13 files changed, 682 insertions(+), 283 deletions(-) diff --git a/compiler/src/iree/compiler/API/Internal/Embed.cpp b/compiler/src/iree/compiler/API/Internal/Embed.cpp index b96dbb07d731..20f29796e098 100644 --- a/compiler/src/iree/compiler/API/Internal/Embed.cpp +++ b/compiler/src/iree/compiler/API/Internal/Embed.cpp @@ -581,9 +581,11 @@ Invocation::Invocation(Session &session) // Since the jitter invokes much of the top-level compiler recursively, // it must be injected at the top-level here vs in the pass pipeline // (or else the circular dependency cannot be resolved). - pipelineHooks.buildConstEvalPassPipelineCallback = [](OpPassManager &pm) { - pm.addPass(ConstEval::createJitGlobalsPass()); - }; + auto &targetRegistry = session.targetRegistry; + pipelineHooks.buildConstEvalPassPipelineCallback = + [&targetRegistry](OpPassManager &pm) { + pm.addPass(ConstEval::createJitGlobalsPass(targetRegistry)); + }; // The PluginSession implements PipelineExtensions and delegates it to // activated plugins. pipelineHooks.pipelineExtensions = &session.pluginSession; diff --git a/compiler/src/iree/compiler/ConstEval/BUILD.bazel b/compiler/src/iree/compiler/ConstEval/BUILD.bazel index 2b5b9b00d662..1771f7c0e6ec 100644 --- a/compiler/src/iree/compiler/ConstEval/BUILD.bazel +++ b/compiler/src/iree/compiler/ConstEval/BUILD.bazel @@ -57,6 +57,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Pipelines", "//compiler/src/iree/compiler/Utils", "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -79,6 +80,7 @@ iree_compiler_cc_library( "//runtime/src/iree/tooling:vm_util", "//runtime/src/iree/vm", "//runtime/src/iree/vm/bytecode:module", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", ], ) diff --git a/compiler/src/iree/compiler/ConstEval/CMakeLists.txt b/compiler/src/iree/compiler/ConstEval/CMakeLists.txt index 94757910488f..e13efe8af57a 100644 --- a/compiler/src/iree/compiler/ConstEval/CMakeLists.txt +++ b/compiler/src/iree/compiler/ConstEval/CMakeLists.txt @@ -46,6 +46,7 @@ iree_cc_library( ::PassesIncGen ::Runtime LLVMSupport + MLIRArithDialect MLIRFuncDialect MLIRIR MLIRPass @@ -62,6 +63,7 @@ iree_cc_library( SRCS "Runtime.cpp" DEPS + LLVMSupport MLIRIR iree::compiler::Dialect::VM::Target::Bytecode iree::hal diff --git a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp index 1d6d7a8820f4..1f2a56da622e 100644 --- a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp +++ b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp @@ -11,9 +11,12 @@ #include "iree/compiler/Utils/PassUtils.h" #include "llvm/ADT/DenseSet.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/Timer.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/SymbolTable.h" @@ -25,119 +28,297 @@ namespace mlir { namespace iree_compiler { namespace ConstEval { +static llvm::cl::opt clUseVMVX( + "iree-consteval-jit-use-vmvx", + llvm::cl::desc( + "Uses VMVX (reference backend) instead of the full CPU compiler."), + llvm::cl::init(true)); + +static llvm::cl::opt clEnableDebug( + "iree-consteval-jit-debug", + llvm::cl::desc( + "Prints debugging information to stderr (useful since when consteval " + "has issues, it is often in production on the largest models where we " + "don't want to run a debug compiler)."), + llvm::cl::init(false)); + namespace { -struct ProgramExtractor { +// These options structs are not copy-constructable so we have to allocate them +// shared. +// TODO: See if we can make them copyable? +struct CompileOptions { + BindingOptions bindingOptions; + InputDialectOptions inputOptions; + PreprocessingOptions preprocessingOptions; + HighLevelOptimizationOptions highLevelOptimizationOptions; + SchedulingOptions schedulingOptions; + IREE::HAL::TargetOptions executableOptions; + IREE::VM::TargetOptions targetOptions; + IREEVMPipelineHooks hooks; +}; + +// Supported types vary by backend and other factors, so we track them here. +// Types that cross the ABI boundary are configured here. +class SupportedFeatures { public: - ProgramExtractor(Operation *sourceModuleOp, Operation *targetModuleOp) - : sourceSymbolTable(sourceModuleOp), targetSymbolTable(targetModuleOp), - builder(OpBuilder::atBlockEnd(&targetModuleOp->getRegion(0).front())) {} - - // Creates an accessor function to load the given global value. - // Returns the created symbol name. - StringAttr createAccessor(IREE::Util::GlobalOp globalOp) { - Location loc = globalOp.getLoc(); - std::string name = (llvm::Twine("get$") + globalOp.getSymName()).str(); - Type globalType = globalOp.getType(); - auto funcType = - builder.getType(TypeRange{}, TypeRange{globalType}); - auto funcOp = func::FuncOp::create(loc, name, funcType); - StringAttr funcSymbolName = targetSymbolTable.insert(funcOp); - Block *entryBlock = funcOp.addEntryBlock(); - - OpBuilder funcBuilder = OpBuilder::atBlockEnd(entryBlock); - Value loadValue = - funcBuilder.create(loc, globalOp); - funcBuilder.create(loc, ValueRange{loadValue}); - return funcSymbolName; + void addScalarType(Type t) { scalarTypes.insert(t); } + void addElementType(Type t) { elementTypes.insert(t); } + + bool supportsScalarType(Type t) const { return scalarTypes.contains(t); } + + bool supportsElementType(Type t) const { return elementTypes.contains(t); } + + bool isSupportedAbiType(Type t) const { + if (auto tensorType = llvm::dyn_cast(t)) { + return supportsElementType(tensorType.getElementType()); + } else { + return supportsScalarType(t); + } } - // Imports an op from the source module into the target. Cannot be used to - // import symbols. - void importOperation(Operation *sourceOp) { - Operation *targetOp = sourceOp->clone(); - builder.insert(targetOp); - scanDependentSymbols(targetOp); +private: + llvm::DenseSet scalarTypes; + llvm::DenseSet elementTypes; +}; + +// JIT functions take arguments, generally from the source program. We capture +// them here. +class ArgumentBinding { +public: + enum class Type { + // An ElementsAttr. + ElementsAttr, + + // The value of a GlobalOp. It may not be set at the start of the run + // if there is a dependency that evaluates first. + GlobalOp, + }; + + ArgumentBinding(ElementsAttr attr) + : type(Type::ElementsAttr), elementsAttr(attr) {} + ArgumentBinding(IREE::Util::GlobalOp globalOp) + : type(Type::GlobalOp), globalOp(globalOp) {} + + Type getType() { return type; } + + ElementsAttr getElementsAttr() { + assert(type == Type::ElementsAttr); + return elementsAttr; } - // Imports any dependencies. Should be called after all user-required imports - // are completed. - LogicalResult importDependencies() { - SmallVector iterWorklist; + IREE::Util::GlobalOp getGlobalOp() { + assert(type == Type::GlobalOp); + return globalOp; + } - while (!symbolImportWorklist.empty()) { - iterWorklist.clear(); - iterWorklist.swap(symbolImportWorklist); +private: + Type type; + ElementsAttr elementsAttr; + IREE::Util::GlobalOp globalOp; +}; - for (StringAttr symbolRef : iterWorklist) { - if (targetSymbolTable.lookup(symbolRef)) - continue; +// How to bind results to the original program. +class ResultBinding { +public: + enum class Type { + // Set the result on the global op. + GlobalOp, + }; - Operation *sourceOp = sourceSymbolTable.lookup(symbolRef); - if (!sourceOp) { - return mlir::emitError(targetSymbolTable.getOp()->getLoc()) - << "symbol not found while building jit-eval module: " - << symbolRef; - } + ResultBinding(IREE::Util::GlobalOp globalOp) + : type(Type::GlobalOp), globalOp(globalOp) {} - // Insert at top as ordering is respected. - auto ip = targetSymbolTable.getOp()->getRegion(0).front().begin(); - Operation *targetOp = sourceOp->clone(); - targetSymbolTable.insert(targetOp, ip); - scanDependentSymbols(targetOp); - } + Type getType() { return type; } + + IREE::Util::GlobalOp getGlobalOp() { + assert(type == Type::GlobalOp); + return globalOp; + } + +private: + Type type; + ElementsAttr elementsAttr; + IREE::Util::GlobalOp globalOp; +}; + +// Description of a JIT function that we have created for doing some +// initialization work. +struct JitFunctionDesc { + JitFunctionDesc(Location loc, std::string name) + : loc(loc), name(std::move(name)) {} + Location loc; + std::string name; + llvm::SmallVector argumentBindings; + llvm::SmallVector resultBindings; +}; + +class ProgramBuilder { +public: + ProgramBuilder(ModuleOp sourceModuleOp, + const SupportedFeatures &supportedFeatures) + : targetModuleOp(createInnerModule(sourceModuleOp)), + sourceSymbolTable(sourceModuleOp), targetSymbolTable(targetModuleOp), + supportedFeatures(supportedFeatures) {} + + llvm::SmallVector &getJitFunctions() { return jitFunctions; } + ModuleOp getTargetModule() { return targetModuleOp; } + + LogicalResult importInitializer(IREE::Util::InitializerOp initOp) { + // We convert each initializer into a public FuncOp by converting each: + // - Tensor constant into an argument + // - util.global_load into an argument + // - util.global_store into a result + // It is considered an eval'able initializer if it contains stores + // into immutable global(s). In the future, we will also want to + // condition this on an attribute so as to not try to statically + // compile dynamic initializers. + // Build it into a new function. + if (!initOp.getBody().hasOneBlock()) { + // It would be possible to support these in theory but unclear if + // worth it in practice. + emitWarning(initOp.getLoc()) + << "skipping consteval initializer: initializers with >1 block not " + "yet supported"; + return failure(); + } + + OpBuilder moduleBuilder = OpBuilder::atBlockEnd(targetModuleOp.getBody()); + auto funcOp = moduleBuilder.create( + initOp.getLoc(), "jit_eval", moduleBuilder.getFunctionType({}, {})); + targetSymbolTable.insert(funcOp); + IRMapping unusedMapping; + initOp.getBody().cloneInto(&funcOp.getBody(), unusedMapping); + if (failed(transformToJitFunction(funcOp))) { + funcOp.erase(); + return failure(); } return success(); } - void scanDependentSymbols(Operation *parentOp) { - // Find any global accessors and note their dependent symbols. - parentOp->walk([&](Operation *op) { - TypeSwitch(op) - .Case([&](IREE::Util::GlobalAddressOpInterface addressOp) { - symbolImportWorklist.push_back(addressOp.getGlobalAttr().getAttr()); - }) - .Case([&](IREE::Util::GlobalLoadOpInterface loadOp) { - symbolImportWorklist.push_back(loadOp.getGlobalAttr().getAttr()); - }) - .Case([&](IREE::Util::GlobalStoreOpInterface storeOp) { - symbolImportWorklist.push_back(storeOp.getGlobalAttr().getAttr()); - }); - }); - - // TODO: Scan for functions, etc. +private: + static ModuleOp createInnerModule(ModuleOp sourceModuleOp) { + OpBuilder builder = OpBuilder::atBlockEnd(sourceModuleOp.getBody()); + auto m = builder.create(sourceModuleOp.getLoc()); + m->setAttr("iree.consteval", builder.getUnitAttr()); + return m; } -private: + LogicalResult transformToJitFunction(func::FuncOp funcOp) { + JitFunctionDesc desc(funcOp.getLoc(), funcOp.getName().str()); + llvm::SmallVector argumentTypes; + llvm::SmallVector returnTypes; + llvm::SmallVector returns; + llvm::SmallVector eraseOps; + + Block *entryBlock = &funcOp.getBody().front(); + + // Find immutable loads. + for (auto loadOp : funcOp.getOps()) { + auto globalOp = llvm::dyn_cast_or_null( + sourceSymbolTable.lookup(loadOp.getGlobalAttr().getAttr())); + if (!globalOp || globalOp.getIsMutable()) { + emitWarning(loadOp.getLoc()) << "skipping consteval initializer: load " + "from mutable globals not supported"; + return failure(); + } + Type t = loadOp.getResult().getType(); + if (!supportedFeatures.isSupportedAbiType(t)) { + emitWarning(funcOp.getLoc()) + << "skipping consteval initializer: unsupported type for current " + "jit configuration: " + << t; + return failure(); + } + argumentTypes.push_back(t); + BlockArgument entryArg = entryBlock->addArgument(t, loadOp.getLoc()); + loadOp.getResult().replaceAllUsesWith(entryArg); + eraseOps.push_back(loadOp); + desc.argumentBindings.emplace_back(globalOp); + } + + // And loose tensor constants. + for (auto constantOp : funcOp.getOps()) { + auto tensorType = constantOp.getResult().getType().dyn_cast(); + auto elementsAttr = constantOp.getValue().dyn_cast(); + if (!tensorType || !elementsAttr) + continue; + if (!supportedFeatures.isSupportedAbiType(tensorType)) { + emitWarning(funcOp.getLoc()) + << "skipping consteval initializer: unsupported type for current " + "jit configuration: " + << tensorType; + return failure(); + } + argumentTypes.push_back(tensorType); + BlockArgument entryArg = + entryBlock->addArgument(tensorType, constantOp.getLoc()); + constantOp.getResult().replaceAllUsesWith(entryArg); + eraseOps.push_back(constantOp); + desc.argumentBindings.emplace_back(elementsAttr); + } + + // Find immutable stores, early exiting if not supported. + // The consumers must come after rewrites of the producers above. + for (auto storeOp : funcOp.getOps()) { + auto globalOp = llvm::dyn_cast_or_null( + sourceSymbolTable.lookup(storeOp.getGlobalAttr().getAttr())); + if (!globalOp || globalOp.getIsMutable()) { + emitWarning(storeOp.getLoc()) << "skipping consteval initializer: stor " + "to mutable globals not supported"; + return failure(); + } + Type t = storeOp.getValue().getType(); + if (!supportedFeatures.isSupportedAbiType(t)) { + emitWarning(funcOp.getLoc()) + << "skipping consteval initializer: unsupported type for current " + "jit configuration: " + << t; + return failure(); + } + + returns.push_back(storeOp.getValue()); + returnTypes.push_back(t); + eraseOps.push_back(storeOp); + desc.resultBindings.emplace_back(globalOp); + } + + // Cleanup. + for (auto *op : eraseOps) { + op->erase(); + } + + // Rewrite the terminator and the function type. + entryBlock->getTerminator()->erase(); + OpBuilder termBuilder = OpBuilder::atBlockEnd(entryBlock); + termBuilder.create(funcOp.getLoc(), returns); + funcOp.setType(termBuilder.getFunctionType(argumentTypes, returnTypes)); + + jitFunctions.push_back(std::move(desc)); + return success(); + } + + ModuleOp targetModuleOp; SymbolTable sourceSymbolTable; SymbolTable targetSymbolTable; - OpBuilder builder; - SmallVector symbolImportWorklist; -}; - -// These options structs are not copy-constructable so we have to allocate them -// shared. -// TODO: See if we can make them copyable? -struct CompileOptions { - BindingOptions bindingOptions; - InputDialectOptions inputOptions; - PreprocessingOptions preprocessingOptions; - HighLevelOptimizationOptions highLevelOptimizationOptions; - SchedulingOptions schedulingOptions; - IREE::HAL::TargetOptions executableOptions; - IREE::VM::TargetOptions targetOptions; - IREEVMPipelineHooks hooks; + llvm::SmallVector jitFunctions; + const SupportedFeatures &supportedFeatures; }; struct JitGlobalsPass : public JitGlobalsBase { - JitGlobalsPass() + JitGlobalsPass(const IREE::HAL::TargetBackendRegistry &targetRegistry) : options(std::make_shared()), compilePipeline("builtin.module") { - // Invoke IREE compilation flow. - options->executableOptions.targets.push_back("vmvx"); - options->targetOptions.f32Extension = true; - options->targetOptions.f64Extension = false; // not yet implemented + // Detect backend. + hasLLVMCPUBackend = targetRegistry.getTargetBackend("llvm-cpu") != nullptr; + if (clUseVMVX || !hasLLVMCPUBackend) { + options->executableOptions.targets.push_back("vmvx"); + options->targetOptions.f32Extension = true; + options->targetOptions.f64Extension = false; // not yet implemented + } else { + options->executableOptions.targets.push_back("llvm-cpu"); + } // Disable constant evaluation for our Jit compilation pipeline. // It would make no sense to recursively do constant evaluation, and since @@ -146,123 +327,196 @@ struct JitGlobalsPass : public JitGlobalsBase { options->highLevelOptimizationOptions.constEval = false; buildIREEVMTransformPassPipeline( - // TODO: If ever not using VMVX, plumb the real target registry - // through. - IREE::HAL::TargetBackendRegistry::getGlobal(), options->bindingOptions, - options->inputOptions, options->preprocessingOptions, - options->highLevelOptimizationOptions, options->schedulingOptions, - options->executableOptions, options->targetOptions, options->hooks, - compilePipeline); + targetRegistry, options->bindingOptions, options->inputOptions, + options->preprocessingOptions, options->highLevelOptimizationOptions, + options->schedulingOptions, options->executableOptions, + options->targetOptions, options->hooks, compilePipeline); } void getDependentDialects(DialectRegistry ®istry) const override { compilePipeline.getDependentDialects(registry); } + const SupportedFeatures getSupportedFeatures(MLIRContext *context) { + SupportedFeatures s; + Builder b(context); + s.addScalarType(b.getIntegerType(8)); + s.addScalarType(b.getIntegerType(16)); + s.addScalarType(b.getIntegerType(32)); + s.addScalarType(b.getIntegerType(64)); + s.addScalarType(b.getF32Type()); + + s.addElementType(b.getIntegerType(1)); + s.addElementType(b.getIntegerType(8)); + s.addElementType(b.getIntegerType(16)); + s.addElementType(b.getIntegerType(32)); + s.addElementType(b.getIntegerType(64)); + s.addElementType(b.getF32Type()); + if (!clUseVMVX && hasLLVMCPUBackend) { + // The full compilers support additional types. + // TODO: Enable support for i4 once it is worked out how to + // transfer to and from ElementsAttr. + s.addScalarType(b.getF64Type()); + s.addElementType(b.getF16Type()); + s.addElementType(b.getBF16Type()); + s.addElementType(b.getF64Type()); + } + return s; + } + + LogicalResult + processFunctions(CompiledBinary &binary, + llvm::SmallVector &jitFunctions, + ModuleOp module, llvm::TimerGroup &tg) { + // Process each function through the runtime. + for (JitFunctionDesc &jitFunction : jitFunctions) { + std::optional invokeTimer; + if (clEnableDebug) { + std::string timerName("Invoke "); + timerName.append(jitFunction.name); + invokeTimer.emplace(timerName, timerName, tg); + invokeTimer->startTimer(); + dbgs() << "::: Invoking " << jitFunction.name << "\n"; + } + + FunctionCall call(binary, jitFunction.argumentBindings.size(), + jitFunction.resultBindings.size()); + + // Convert arguments. + for (ArgumentBinding &arg : jitFunction.argumentBindings) { + switch (arg.getType()) { + case ArgumentBinding::Type::ElementsAttr: + if (failed(call.addArgument(jitFunction.loc, arg.getElementsAttr()))) + return failure(); + break; + + case ArgumentBinding::Type::GlobalOp: { + auto globalValue = arg.getGlobalOp().getInitialValue(); + if (!globalValue) { + return emitError(jitFunction.loc) + << "internal error: jit global source initialization order. " + "global " + << arg.getGlobalOp().getSymName() << " has no value"; + } + if (failed( + call.addArgument(arg.getGlobalOp().getLoc(), *globalValue))) + return failure(); + } break; + } + } + + if (failed(call.invoke(jitFunction.loc, jitFunction.name))) { + return failure(); + } + + // Process results. + for (auto it : llvm::enumerate(jitFunction.resultBindings)) { + ResultBinding &resultBinding = it.value(); + switch (resultBinding.getType()) { + case ResultBinding::Type::GlobalOp: { + TypedAttr attr; + if (failed(call.getResultAsAttr(resultBinding.getGlobalOp().getLoc(), + it.index(), attr))) + return failure(); + resultBinding.getGlobalOp().setInitialValueAttr(attr); + break; + } + } + } + + if (clEnableDebug) { + invokeTimer->stopTimer(); + } + } + + return success(); + } + void runOnOperation() override { + llvm::TimerGroup tg("iree-consteval-jit", "Consteval Jit"); auto outerModule = getOperation(); - SymbolTable outerSymbolTable(outerModule); - OpBuilder builder = OpBuilder::atBlockEnd(outerModule.getBody()); - auto innerModule = builder.create(outerModule.getLoc()); - ProgramExtractor extractor(outerModule, innerModule); - SmallVector pruneOps; - - // Import initializers. - for (auto childOp : outerModule.getOps()) { - extractor.importOperation(childOp); - pruneOps.push_back(childOp); + auto supportedFeatures = getSupportedFeatures(&getContext()); + if (!clUseVMVX && !hasLLVMCPUBackend) { + emitWarning(UnknownLoc::get(&getContext())) + << "consteval jit requested with llvm-cpu backend, but it is not " + "available. Falling back to vmvx"; } - // Transitively import any dependencies. - if (failed(extractor.importDependencies())) { - signalPassFailure(); + llvm::SmallVector initOps; + llvm::SmallVector deadInitOps; + for (auto childOp : outerModule.getOps()) { + initOps.push_back(childOp); } - // Find any globals that we pulled in which lack an initializer. These - // are the ones we will try to eval. Stash {func_symbol, global_symbol} - // pairs for later. - SmallVector> uninitializedGlobals; - for (Operation &childOp : *innerModule.getBody()) { - auto globalOp = llvm::dyn_cast(childOp); - if (!globalOp) - continue; - if (globalOp.getInitialValueAttr()) + // Build the program. + ProgramBuilder programBuilder(outerModule, supportedFeatures); + for (auto initOp : initOps) { + if (!initOp->hasAttr("iree.compiler.consteval")) continue; - // Only generate an accessor for types our runtime bridge knows how to - // handle. - Type type = globalOp.getType(); - if (!CompiledBinary::isSupportedResultType(type)) { - LLVM_DEBUG(dbgs() << "JitGlobals: unsupported global type " << type); - continue; + if (succeeded(programBuilder.importInitializer(initOp))) { + deadInitOps.push_back(initOp); } - - StringAttr funcSymbol = extractor.createAccessor(globalOp); - uninitializedGlobals.emplace_back(funcSymbol, globalOp.getSymNameAttr()); } - - // Early exit without compiling if no entry-points (this is not just an - // optimization: the low level compiler will fail on an empty module). - if (uninitializedGlobals.empty()) { - LLVM_DEBUG(dbgs() << "Not JIT'ing globals: no undefined globals found\n"); - innerModule.erase(); + if (programBuilder.getJitFunctions().empty()) { + programBuilder.getTargetModule()->erase(); return; } - // Run the IREE compiler, transforming the inner module into a vm.module. - LLVM_DEBUG(dbgs() << "JIT'ing " << uninitializedGlobals.size() - << " uninitialized globals\n"); - if (failed(runPipeline(compilePipeline, innerModule))) { + std::optional compileTimer; + if (clEnableDebug) { + dbgs() << "::: COMPILING JIT: " << programBuilder.getTargetModule() + << "\n"; + compileTimer.emplace("iree-consteval-jit-compile", "Compiling", tg); + compileTimer->startTimer(); + } + if (failed( + runPipeline(compilePipeline, programBuilder.getTargetModule()))) { return signalPassFailure(); } - // Generate a binary. InMemoryCompiledBinary binary; - if (failed(binary.translateFromModule(innerModule))) { + if (failed(binary.translateFromModule(programBuilder.getTargetModule()))) { return signalPassFailure(); } - - // Kill the temporary program we constructed. - innerModule.erase(); - - bool modified = false; - for (auto &it : uninitializedGlobals) { - StringAttr funcSymbol = it.first; - StringAttr globalSymbol = it.second; - auto targetGlobal = llvm::cast( - outerSymbolTable.lookup(globalSymbol)); - Location loc = targetGlobal->getLoc(); - - Attribute value = - binary.invokeNullaryAsAttribute(loc, funcSymbol.strref()); - if (!value) { - return signalPassFailure(); - } - - modified = true; - targetGlobal.setInitialValueAttr(cast(value)); + if (clEnableDebug) { + compileTimer->stopTimer(); } - // Delete any ops noted for pruning. - for (Operation *op : pruneOps) { - op->erase(); + // Kill the temporary program. + programBuilder.getTargetModule()->erase(); + + // Process the functions. + if (failed(processFunctions(binary, programBuilder.getJitFunctions(), + outerModule, tg))) { + signalPassFailure(); + return; } - // Signal any outer fixed point iterator that we have modified - // globals and need another pass. - if (modified) { - signalFixedPointModified(outerModule); + // Cleanup any initializers we replaced. + // We do this after running the JIT-ed functions because we have deep + // references into ops and attributes that need to be converted to + // arguments. + for (auto deadOp : deadInitOps) { + deadOp.erase(); } } std::shared_ptr options; OpPassManager compilePipeline; + bool hasLLVMCPUBackend; }; } // namespace +std::unique_ptr> +createJitGlobalsPass(const IREE::HAL::TargetBackendRegistry &targetRegistry) { + return std::make_unique(targetRegistry); +} + std::unique_ptr> createJitGlobalsPass() { - return std::make_unique(); + return std::make_unique( + IREE::HAL::TargetBackendRegistry::getGlobal()); } } // namespace ConstEval diff --git a/compiler/src/iree/compiler/ConstEval/Passes.h b/compiler/src/iree/compiler/ConstEval/Passes.h index e867da272711..4822b0175b59 100644 --- a/compiler/src/iree/compiler/ConstEval/Passes.h +++ b/compiler/src/iree/compiler/ConstEval/Passes.h @@ -12,11 +12,19 @@ namespace mlir { namespace iree_compiler { +namespace IREE::HAL { +class TargetBackendRegistry; +} // namespace IREE::HAL namespace ConstEval { /// Creates a pass which uses the compiler and runtime to Jit global /// initializers eligible for optimization and uses the actual results to /// simplify the globals in the module. +std::unique_ptr> +createJitGlobalsPass(const IREE::HAL::TargetBackendRegistry &targetRegistry); + +// Creates with the global target registry (for opt and such). This +// may only have access to the VMVX backend. std::unique_ptr> createJitGlobalsPass(); void registerConstEvalPasses(); diff --git a/compiler/src/iree/compiler/ConstEval/Runtime.cpp b/compiler/src/iree/compiler/ConstEval/Runtime.cpp index 8260cdf6e731..e49f9a9141ca 100644 --- a/compiler/src/iree/compiler/ConstEval/Runtime.cpp +++ b/compiler/src/iree/compiler/ConstEval/Runtime.cpp @@ -11,12 +11,70 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" +#define DEBUG_TYPE "iree-const-eval" +using llvm::dbgs; + namespace mlir { namespace iree_compiler { namespace ConstEval { namespace { +LogicalResult handleRuntimeError(Location loc, iree_status_t status) { + if (iree_status_is_ok(status)) + return success(); + std::string message; + message.resize(512); + iree_host_size_t buffer_length; + if (!iree_status_format(status, message.size(), &message[0], + &buffer_length)) { + message.resize(buffer_length + 1); + iree_status_format(status, message.size(), &message[0], &buffer_length); + } + message.resize(buffer_length); + iree_status_ignore(status); + return emitError(loc) << "runtime error in consteval: " << message; +} + +LogicalResult convertToElementType(Location loc, Type baseType, + iree_hal_element_type_t *outElementType) { + Builder builder(loc.getContext()); + if (auto t = llvm::dyn_cast(baseType)) { + switch (t.getWidth()) { + case 32: + *outElementType = IREE_HAL_ELEMENT_TYPE_INT_32; + return success(); + case 64: + *outElementType = IREE_HAL_ELEMENT_TYPE_INT_64; + return success(); + case 8: + *outElementType = IREE_HAL_ELEMENT_TYPE_INT_8; + return success(); + case 16: + *outElementType = IREE_HAL_ELEMENT_TYPE_INT_16; + return success(); + case 4: + *outElementType = IREE_HAL_ELEMENT_TYPE_INT_4; + return success(); + } + } else if (baseType == builder.getF32Type()) { + *outElementType = IREE_HAL_ELEMENT_TYPE_FLOAT_32; + return success(); + } else if (baseType == builder.getF64Type()) { + *outElementType = IREE_HAL_ELEMENT_TYPE_FLOAT_64; + return success(); + } else if (baseType == builder.getF16Type()) { + *outElementType = IREE_HAL_ELEMENT_TYPE_FLOAT_16; + return success(); + } else if (baseType == builder.getBF16Type()) { + *outElementType = IREE_HAL_ELEMENT_TYPE_BFLOAT_16; + return success(); + } + + return emitError(loc) + << "internal error: unhandled element type in consteval: " << baseType; +} + Type mapElementType(Location loc, iree_hal_element_type_t halElementType) { Builder builder(loc.getContext()); if (iree_hal_element_numerical_type_is_boolean(halElementType)) { @@ -38,7 +96,7 @@ Type mapElementType(Location loc, iree_hal_element_type_t halElementType) { return {}; } -static Attribute createAttributeFromRawData(Location loc, +static TypedAttr createAttributeFromRawData(Location loc, RankedTensorType tensorType, MutableArrayRef rawBuffer) { Type elementType = tensorType.getElementType(); @@ -94,98 +152,156 @@ void CompiledBinary::deinitialize() { device.reset(); } -LogicalResult CompiledBinary::invokeNullary(Location loc, StringRef name, - ResultsCallback callback) { - iree_vm_function_t function; - if (auto status = iree_vm_module_lookup_function_by_name( - main_module.get(), IREE_VM_FUNCTION_LINKAGE_EXPORT, - iree_string_view_t{name.data(), - static_cast(name.size())}, - &function)) { - iree_status_ignore(status); - return emitError(loc) << "internal error evaling constant: func '" << name - << "' not found"; - } +FunctionCall::FunctionCall(CompiledBinary &binary, iree_host_size_t argCapacity, + iree_host_size_t resultCapacity) + : binary(binary) { + IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), + argCapacity, iree_allocator_system(), + &inputs)); + IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), + resultCapacity, iree_allocator_system(), + &outputs)); +} - iree::vm::ref inputs; - IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 0, - iree_allocator_system(), &inputs)); - iree::vm::ref outputs; - IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), 1, - iree_allocator_system(), &outputs)); - - if (auto status = - iree_vm_invoke(context.get(), function, IREE_VM_INVOCATION_FLAG_NONE, - /*policy=*/nullptr, inputs.get(), outputs.get(), - iree_allocator_system())) { - std::string message; - message.resize(512); - iree_host_size_t buffer_length; - if (!iree_status_format(status, message.size(), &message[0], - &buffer_length)) { - message.resize(buffer_length + 1); - iree_status_format(status, message.size(), &message[0], &buffer_length); - } - message.resize(buffer_length); - iree_status_ignore(status); - return emitError(loc) << "internal error evaling constant: " << message; +// Imports or snapshots a raw host buffer, depending on whether import is +// possible. +LogicalResult FunctionCall::importBufferForRead(Location loc, + const uint8_t *rawData, + iree_host_size_t length, + iree_hal_buffer_t **buffer) { + // TODO: Allow import when we have resources in the input where alignment + // can be guaranteed. + bool tryImport = false; + if (tryImport) { + iree_hal_buffer_params_t params; + std::memset(¶ms, 0, sizeof(params)); + params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE; + iree_hal_external_buffer_t external_buffer; + std::memset(&external_buffer, 0, sizeof(external_buffer)); + external_buffer.type = IREE_HAL_EXTERNAL_BUFFER_TYPE_HOST_ALLOCATION; + external_buffer.size = length; + external_buffer.handle.host_allocation.ptr = + const_cast(static_cast(rawData)); + auto status = iree_hal_allocator_import_buffer( + binary.getAllocator(), params, &external_buffer, + /*release_callback=*/{nullptr, nullptr}, buffer); + if (iree_status_is_ok(status)) + return success(); + else if (!(iree_status_is_out_of_range(status) || + iree_status_is_unavailable(status))) + return handleRuntimeError(loc, status); } - if (failed(callback(outputs.get()))) { - return failure(); + // Buffer is not compatible with import. Snapshot. + { + iree_hal_buffer_params_t params; + std::memset(¶ms, 0, sizeof(params)); + params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE; + LLVM_DEBUG( + dbgs() + << "Cannot import consteval buffer. Falling back to snapshot.\n"); + return handleRuntimeError(loc, iree_hal_allocator_allocate_buffer( + binary.getAllocator(), params, length, + iree_const_byte_span_t{rawData, length}, + buffer)); } - return success(); } -Attribute CompiledBinary::invokeNullaryAsAttribute(Location loc, - StringRef name) { - Attribute result; - if (failed(invokeNullary( - loc, name, [&](iree_vm_list_t *outputs) -> LogicalResult { - if (iree_vm_list_size(outputs) != 1) { - return emitError(loc) << "expected 1 result for func " << name - << " got " << iree_vm_list_size(outputs); - } - iree_vm_variant_t variant = iree_vm_variant_empty(); - IREE_CHECK_OK( - iree_vm_list_get_variant_assign(outputs, 0, &variant)); - result = convertVariantToAttribute(loc, variant); - return success(result != nullptr); - }))) { - return nullptr; +LogicalResult FunctionCall::addArgument(Location loc, Attribute attr) { + if (auto elementsAttr = llvm::dyn_cast(attr)) { + // Meta-data. + ArrayRef data = elementsAttr.getRawData(); + ShapedType st = elementsAttr.getType(); + auto stShape = st.getShape(); + auto rank = static_cast(st.getRank()); + iree_hal_dim_t *shape = + static_cast(alloca(rank * sizeof(iree_hal_dim_t))); + for (size_t i = 0; i < rank; ++i) { + shape[i] = stShape[i]; + } + iree_hal_element_type_t elementType; + if (failed(convertToElementType(loc, st.getElementType(), &elementType))) + return failure(); + + iree::vm::ref buffer; + if (elementsAttr.isSplat()) { + // Handle splat. In this case, the data size is one element. + iree_device_size_t bufferSize = data.size() * st.getNumElements(); + iree_hal_buffer_params_t params; + std::memset(¶ms, 0, sizeof(params)); + params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE; + if (failed(handleRuntimeError( + loc, iree_hal_allocator_allocate_buffer( + binary.getAllocator(), params, bufferSize, + iree_const_byte_span_t{nullptr, 0}, &buffer)))) + return failure(); + + if (failed(handleRuntimeError( + loc, iree_hal_buffer_map_fill( + buffer.get(), 0, bufferSize, + static_cast(data.data()), data.size())))) + return failure(); + } else { + // Dense, non-splat. + if (failed(importBufferForRead( + loc, reinterpret_cast(data.data()), data.size(), + &buffer))) + return failure(); + } + + // Construct buffer view. + iree::vm::ref bv; + if (failed(handleRuntimeError( + loc, + iree_hal_buffer_view_create(buffer.get(), rank, shape, elementType, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + iree_allocator_system(), &bv)))) + return failure(); + + return handleRuntimeError( + loc, iree_vm_list_push_ref_move(inputs.get(), std::move(bv))); } - return result; + return emitError(loc) + << "internal error: unsupported consteval jit function input (" << attr + << ")"; } -bool CompiledBinary::isSupportedResultType(Type type) { - // TODO(laurenzo): Not currently supported. VMVX would need to support these - // and today it doesn't. We could use alternative backends (LLVM CPU/etc) if - // we wanted to handle f64, but f16 and bf16 often need special hardware. - if (llvm::isa(type) || llvm::isa(type) || - llvm::isa(type)) { - return false; +LogicalResult FunctionCall::invoke(Location loc, StringRef name) { + // Lookup function. + iree_vm_function_t function; + if (auto status = iree_vm_module_lookup_function_by_name( + binary.main_module.get(), IREE_VM_FUNCTION_LINKAGE_EXPORT, + iree_string_view_t{name.data(), + static_cast(name.size())}, + &function)) { + iree_status_ignore(status); + return emitError(loc) << "internal error evaling constant: func '" << name + << "' not found"; } - // Support scalar int and float type of byte aligned widths. - if (type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0) { - return true; - } + return handleRuntimeError(loc, iree_vm_invoke(binary.context.get(), function, + IREE_VM_INVOCATION_FLAG_NONE, + /*policy=*/nullptr, + inputs.get(), outputs.get(), + iree_allocator_system())); +} - // Special support for i1. - if (llvm::isa(type) && type.getIntOrFloatBitWidth() == 1) { - return true; - } +LogicalResult FunctionCall::getResultAsAttr(Location loc, size_t index, + TypedAttr &outAttr) { + iree_vm_variant_t variant = iree_vm_variant_empty(); + if (failed(handleRuntimeError(loc, iree_vm_list_get_variant_assign( + outputs.get(), index, &variant)))) + return failure(); - // Support tensors. - if (auto tt = llvm::dyn_cast(type)) { - return isSupportedResultType(tt.getElementType()); - } + outAttr = binary.convertVariantToAttribute(loc, variant); + if (!outAttr) + return failure(); - return false; + return success(); } -Attribute +TypedAttr CompiledBinary::convertVariantToAttribute(Location loc, iree_vm_variant_t &variant) { auto context = loc.getContext(); diff --git a/compiler/src/iree/compiler/ConstEval/Runtime.h b/compiler/src/iree/compiler/ConstEval/Runtime.h index 6078da069b38..be7259ac55c8 100644 --- a/compiler/src/iree/compiler/ConstEval/Runtime.h +++ b/compiler/src/iree/compiler/ConstEval/Runtime.h @@ -12,6 +12,7 @@ #include "iree/modules/hal/module.h" #include "iree/vm/api.h" #include "iree/vm/bytecode/module.h" +#include "llvm/Support/Debug.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" @@ -28,16 +29,9 @@ class CompiledBinary { using ResultsCallback = std::function; virtual ~CompiledBinary(); - // Invokes a nullary function. - LogicalResult invokeNullary(Location loc, StringRef name, - ResultsCallback callback); - - // Invokes a nullary function and returns its (presumed single) single result - // as an Attribute. - Attribute invokeNullaryAsAttribute(Location loc, StringRef name); - - // Whether the given type is supported in *AsAttribute methods. - static bool isSupportedResultType(Type type); + iree_hal_allocator_t *getAllocator() { + return iree_hal_device_allocator(device.get()); + } protected: CompiledBinary(); @@ -46,12 +40,32 @@ class CompiledBinary { // explicitly by subclasses, ensuring that any backing images remain valid // through the call to deinitialize(). void deinitialize(); - Attribute convertVariantToAttribute(Location loc, iree_vm_variant_t &variant); + TypedAttr convertVariantToAttribute(Location loc, iree_vm_variant_t &variant); iree::vm::ref device; iree::vm::ref hal_module; iree::vm::ref main_module; iree::vm::ref context; + + friend class FunctionCall; +}; + +class FunctionCall { +public: + FunctionCall(CompiledBinary &binary, iree_host_size_t argCapacity, + iree_host_size_t resultCapacity); + + LogicalResult addArgument(Location loc, Attribute attr); + LogicalResult invoke(Location loc, StringRef name); + LogicalResult getResultAsAttr(Location loc, size_t index, TypedAttr &outAttr); + +private: + LogicalResult importBufferForRead(Location loc, const uint8_t *rawData, + iree_host_size_t length, + iree_hal_buffer_t **buffer); + CompiledBinary binary; + iree::vm::ref inputs; + iree::vm::ref outputs; }; // An in-memory compiled binary and accessors for working with it. diff --git a/compiler/src/iree/compiler/ConstEval/test/jit_globals.mlir b/compiler/src/iree/compiler/ConstEval/test/jit_globals.mlir index f8a47cc968ee..bffb87ecc076 100644 --- a/compiler/src/iree/compiler/ConstEval/test/jit_globals.mlir +++ b/compiler/src/iree/compiler/ConstEval/test/jit_globals.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --iree-consteval-jit-globals %s | FileCheck %s +// RUN: iree-opt --split-input-file --iree-consteval-jit-use-vmvx=true --verify-diagnostics --iree-consteval-jit-debug --iree-consteval-jit-globals %s | FileCheck %s // TODO(laurenzo): Full type matrix for tests. @@ -22,7 +22,7 @@ module @linalg_tensor_jit { return %hoisted : tensor<5x6xf32> } // CHECK-NOT: util.initializer - util.initializer { + util.initializer attributes {iree.compiler.consteval} { %cst = arith.constant dense<2.0e+02> : tensor %0 = tensor.empty() : tensor<5x6xf32> %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst : tensor) outs(%0 : tensor<5x6xf32>) { @@ -49,7 +49,7 @@ module @eval_splat_detection { %hoisted = util.global.load @hoisted : tensor<2xi32> return %hoisted : tensor<2xi32> } - util.initializer { + util.initializer attributes {iree.compiler.consteval} { %cst = arith.constant dense<[2, 2]> : tensor<2xi32> util.global.store %cst, @hoisted : tensor<2xi32> util.initializer.return @@ -59,15 +59,14 @@ module @eval_splat_detection { // ----- // CHECK-LABEL: @eval_f16_tensor -// Not currently supported (initializer should remain) -// CHECK: util.initializer module @eval_f16_tensor { util.global private @hoisted : tensor<5x6xf16> func.func @main() -> tensor<5x6xf16> { %hoisted = util.global.load @hoisted : tensor<5x6xf16> return %hoisted : tensor<5x6xf16> } - util.initializer { + // expected-warning @+1 {{unsupported type for current jit configuration}} + util.initializer attributes {iree.compiler.consteval} { %cst = arith.constant dense<2.0e+2> : tensor<5x6xf16> util.global.store %cst, @hoisted : tensor<5x6xf16> util.initializer.return @@ -77,14 +76,14 @@ module @eval_f16_tensor { // ----- // CHECK-LABEL: @eval_bf16_tensor // Not currently supported (initializer should remain) -// CHECK: util.initializer module @eval_bf16_tensor { util.global private @hoisted : tensor<5x6xbf16> func.func @main() -> tensor<5x6xbf16> { %hoisted = util.global.load @hoisted : tensor<5x6xbf16> return %hoisted : tensor<5x6xbf16> } - util.initializer { + // expected-warning @+1 {{unsupported type for current jit configuration}} + util.initializer attributes {iree.compiler.consteval} { %cst = arith.constant dense<2.0e+2> : tensor<5x6xbf16> util.global.store %cst, @hoisted : tensor<5x6xbf16> util.initializer.return @@ -100,7 +99,7 @@ module @eval_f32_tensor { %hoisted = util.global.load @hoisted : tensor<2xf32> return %hoisted : tensor<2xf32> } - util.initializer { + util.initializer attributes {iree.compiler.consteval} { %cst = arith.constant dense<[2.0e+2, 3.2e+3]> : tensor<2xf32> util.global.store %cst, @hoisted : tensor<2xf32> util.initializer.return @@ -109,15 +108,14 @@ module @eval_f32_tensor { // ----- // CHECK-LABEL: @eval_f64_tensor -// Not currently supported (initializer should remain) -// CHECK: util.initializer module @eval_f64_tensor { util.global private @hoisted : tensor<2xf64> func.func @main() -> tensor<2xf64> { %hoisted = util.global.load @hoisted : tensor<2xf64> return %hoisted : tensor<2xf64> } - util.initializer { + // expected-warning @+1 {{unsupported type for current jit configuration}} + util.initializer attributes {iree.compiler.consteval} { %cst = arith.constant dense<[2.0e+2, 3.2e+3]> : tensor<2xf64> util.global.store %cst, @hoisted : tensor<2xf64> util.initializer.return @@ -133,7 +131,7 @@ module @eval_i1_tensor { %hoisted = util.global.load @hoisted : tensor<6xi1> return %hoisted : tensor<6xi1> } - util.initializer { + util.initializer attributes {iree.compiler.consteval} { // Note that the level we are testing at is a bit odd in the way i1 vs // i8 are handled. %cst = arith.constant dense<[0, 1, 0, 1, 1, 0]> : tensor<6xi8> @@ -145,14 +143,14 @@ module @eval_i1_tensor { // ----- // CHECK-LABEL: @eval_i4_tensor -// CHECK: util.initializer module @eval_i4_tensor { util.global private @hoisted : tensor<5x6xi4> func.func @main() -> tensor<5x6xi4> { %hoisted = util.global.load @hoisted : tensor<5x6xi4> return %hoisted : tensor<5x6xi4> } - util.initializer { + // expected-warning @+1 {{unsupported type for current jit configuration}} + util.initializer attributes {iree.compiler.consteval} { %cst = arith.constant dense<3> : tensor<5x6xi4> util.global.store %cst, @hoisted : tensor<5x6xi4> util.initializer.return @@ -168,7 +166,7 @@ module @eval_i8_tensor { %hoisted = util.global.load @hoisted : tensor<2xi8> return %hoisted : tensor<2xi8> } - util.initializer { + util.initializer attributes {iree.compiler.consteval} { %cst = arith.constant dense<[2, 3]> : tensor<2xi8> util.global.store %cst, @hoisted : tensor<2xi8> util.initializer.return @@ -184,7 +182,7 @@ module @eval_i16_tensor { %hoisted = util.global.load @hoisted : tensor<2xi16> return %hoisted : tensor<2xi16> } - util.initializer { + util.initializer attributes {iree.compiler.consteval} { %cst = arith.constant dense<[2, 3]> : tensor<2xi16> util.global.store %cst, @hoisted : tensor<2xi16> util.initializer.return @@ -200,7 +198,7 @@ module @eval_i32_tensor { %hoisted = util.global.load @hoisted : tensor<2xi32> return %hoisted : tensor<2xi32> } - util.initializer { + util.initializer attributes {iree.compiler.consteval} { %cst = arith.constant dense<[2, 3]> : tensor<2xi32> util.global.store %cst, @hoisted : tensor<2xi32> util.initializer.return @@ -216,7 +214,7 @@ module @eval_i64_tensor { %hoisted = util.global.load @hoisted : tensor<2xi64> return %hoisted : tensor<2xi64> } - util.initializer { + util.initializer attributes {iree.compiler.consteval} { %cst = arith.constant dense<[2, 3]> : tensor<2xi64> util.global.store %cst, @hoisted : tensor<2xi64> util.initializer.return diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp index 1c41f322b492..c685e567971c 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp @@ -102,6 +102,9 @@ class HoistIntoGlobalsPass : public HoistIntoGlobalsBase { Location loc = originalValue.getLoc(); OpBuilder builder = getModuleEndBuilder(); auto initializerOp = builder.create(loc); + // Signals that this initializer is eligible for constant evaluation + // at compile time. + initializerOp->setAttr("iree.compiler.consteval", builder.getUnitAttr()); Block *entryBlock = initializerOp.addEntryBlock(); OpBuilder initBuilder = OpBuilder::atBlockEnd(entryBlock); IRMapping valueMapping; diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir index fbe4dd6c871c..a285289aee35 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir @@ -14,7 +14,7 @@ module @hoist_simple_const_expr { %2 = "iree_unregistered.const_expr"(%0, %1) : (i32, i32) -> i32 return %2 : i32 } - // CHECK: util.initializer { + // CHECK: util.initializer attributes {iree.compiler.consteval} { // CHECK: %[[C0:.*]] = arith.constant 0 : i32 // CHECK: %[[C1:.*]] = arith.constant 1 : i32 // CHECK: %[[CE0:.*]] = "iree_unregistered.const_expr"(%[[C0]], %[[C1]]) @@ -126,14 +126,14 @@ module @hoist_tree_const_expr { %5 = "iree_unregistered.var_expr"(%4) : (i32) -> i32 return %2, %4, %5 : i32, i32, i32 } - // CHECK: util.initializer { + // CHECK: util.initializer attributes {iree.compiler.consteval} { // CHECK: %[[C0:.*]] = arith.constant 0 : i32 // CHECK: %[[C1:.*]] = arith.constant 1 : i32 // CHECK: %[[CE0:.*]] = "iree_unregistered.const_expr"(%[[C0]], %[[C1]]) // CHECK: util.global.store %[[CE0]], @[[HOISTED_0]] : i32 // CHECK: util.initializer.return // CHECK: } - // CHECK: util.initializer { + // CHECK: util.initializer attributes {iree.compiler.consteval} { // CHECK: %[[LOAD_HOISTED_0:.*]] = util.global.load @[[HOISTED_0]] : i32 // CHECK: %[[LOAD_LATENT_GLOBAL:.*]] = util.global.load @latent_global : i32 // CHECK: %[[CE1:.*]] = "iree_unregistered.const_expr"(%[[LOAD_HOISTED_0]], %[[LOAD_LATENT_GLOBAL]]) @@ -161,7 +161,7 @@ module @hoist_non_leaf_const_expr { %4 = "iree_unregistered.non_leaf_const_expr"(%3) : (i32) -> i32 return %4 : i32 } - // CHECK: util.initializer { + // CHECK: util.initializer attributes {iree.compiler.consteval} { // CHECK: %[[C0:.*]] = arith.constant 0 : i32 // CHECK: %[[C1:.*]] = arith.constant 1 : i32 // CHECK: %[[CE0:.*]] = "iree_unregistered.non_leaf_const_expr"(%[[C0]], %[[C1]]) @@ -192,7 +192,7 @@ module @hoist_implicit_capture { } // Key checks: arith.constant 1 gets pulled in to the initializer // and the reference is updated correctly in the custom op region. - // CHECK: util.initializer { + // CHECK: util.initializer attributes {iree.compiler.consteval} { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32 // CHECK: %[[CE0:.*]] = "iree_unregistered.const_expr"(%[[C0]]) diff --git a/compiler/src/iree/compiler/Pipelines/Options.cpp b/compiler/src/iree/compiler/Pipelines/Options.cpp index b27a560447d4..eaa1cea11d4b 100644 --- a/compiler/src/iree/compiler/Pipelines/Options.cpp +++ b/compiler/src/iree/compiler/Pipelines/Options.cpp @@ -105,7 +105,7 @@ void HighLevelOptimizationOptions::bindOptions(OptionsBinder &binder) { binder.opt( "iree-opt-const-eval", constEval, llvm::cl::desc("Enables eager evaluation of constants using the full " - "compiler and runtime."), + "compiler and runtime (on by default)."), llvm::cl::cat(category)); binder.opt( "iree-opt-const-expr-hoisting", constExprHoisting, diff --git a/compiler/src/iree/compiler/Pipelines/Options.h b/compiler/src/iree/compiler/Pipelines/Options.h index 86b5bf29507a..0b61177492ec 100644 --- a/compiler/src/iree/compiler/Pipelines/Options.h +++ b/compiler/src/iree/compiler/Pipelines/Options.h @@ -81,7 +81,7 @@ struct HighLevelOptimizationOptions { // Enables recursive evaluation of immutable globals using the compiler // and runtime. - bool constEval = false; + bool constEval = true; // Optimizations to reduce numeric precision where it is safe to do so. bool numericPrecisionReduction = false; diff --git a/docs/website/docs/reference/optimization-options.md b/docs/website/docs/reference/optimization-options.md index de8f668c3458..9f61a6f227f2 100644 --- a/docs/website/docs/reference/optimization-options.md +++ b/docs/website/docs/reference/optimization-options.md @@ -18,7 +18,7 @@ These flags can be passed to the: ## High level program optimizations -### Constant evaluation (`--iree-opt-const-eval` (off)) +### Constant evaluation (`--iree-opt-const-eval` (on)) Performs compile-time evaluation of any global initializers which produce the initial values for global constants, storing the global directly in the