diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp index b999a8390685..b29252bcb897 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp @@ -7,6 +7,7 @@ #include "iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h" #include "iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" #include "iree/compiler/Dialect/Util/Transforms/PassDetail.h" #include "iree/compiler/Dialect/Util/Transforms/Passes.h" #include "llvm/Support/Debug.h" @@ -29,6 +30,50 @@ static llvm::cl::opt clPrintDotGraphToFile( "nodes represent roots and the green nodes represent hoisted values."), llvm::cl::value_desc("filename")); +template +static inline bool isAccessorParameterized(SymbolTable moduleSymbols, + AccessorTy op) { + auto global = + moduleSymbols.lookup(op.getGlobalName()); + if (!global) + return true; + if (!global.getGlobalInitialValue()) + return false; + return !isa(global.getGlobalInitialValue()); +} + +// Today the only way to interact with a global is with loads, stores, and +// addresses, and globals are the only way to reference parameters given where +// const-eval is run today. This is a workaround until we have proper dialect +// interfaces for detecting whether something is evaluatable at compile time. +static bool isParameterized(SymbolTable moduleSymbols, Operation *initializer) { + WalkResult res = initializer->walk([&](Operation *op) { + bool parameterized = + llvm::TypeSwitch(op) + .Case([=](GlobalLoadOpInterface accessor) { + return isAccessorParameterized(moduleSymbols, accessor); + }) + .Case([=](GlobalLoadOpInterface accessor) { + return isAccessorParameterized(moduleSymbols, accessor); + }) + .Case([=](GlobalLoadOpInterface accessor) { + return isAccessorParameterized(moduleSymbols, accessor); + }) + .Case([=](CallOpInterface accessor) { + // Pessimistic case for calls that could transitively load a + // parameter. Today const-expr hoisting does not model calls + // properly so we don't hit this path. + return true; + }) + .Default([=](auto) { return false; }); + if (parameterized) { + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return res.wasInterrupted(); +} + // Maps an original value in the program to the symbol name of a global. using HoistedValueMap = llvm::DenseMap; @@ -149,9 +194,6 @@ class HoistIntoGlobalsPass : public HoistIntoGlobalsBase { Location loc = originalValue.getLoc(); OpBuilder builder = getModuleEndBuilder(); auto initializerOp = builder.create(loc); - // Signals that this initializer is eligible for constant evaluation - // at compile time. - initializerOp->setAttr("iree.compiler.consteval", builder.getUnitAttr()); Block *entryBlock = initializerOp.addEntryBlock(); OpBuilder initBuilder = OpBuilder::atBlockEnd(entryBlock); IRMapping valueMapping; @@ -162,6 +204,13 @@ class HoistIntoGlobalsPass : public HoistIntoGlobalsBase { } existingGlobal = hoistedMap.lookup(originalValue); + + // Signals that this initializer is eligible for constant evaluation + // at compile time. + if (!isParameterized(moduleSymbols, initializerOp)) { + initializerOp->setAttr("iree.compiler.consteval", + builder.getUnitAttr()); + } } assert(existingGlobal && "hoisting const-expr should have mapped a global for the requested " diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir index 4d8a2aa9b667..2eacd77567e4 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir @@ -343,3 +343,19 @@ module @nested_program_const_expr { } } } + +// ----- + +// CHECK-LABEL: @parameterized_const_expr +module @parameterized_const_expr { +// Verify that the initializer does not get labelled as evaluatable by +// const-eval. +// CHECK: util.initializer { +// CHECK-NEXT: util.global.load @parameter_constant + util.global private @parameter_constant = #stream.parameter.named<"compile"::"constant_hoisted_0"> : i32 + func.func @main() -> (i32) { + %load = util.global.load @parameter_constant : i32 + %1 = "iree_unregistered.const_expr"(%load) : (i32) -> i32 + return %1 : i32 + } +}