diff --git a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp index 4fd578af6bb7..9a4d89a65818 100644 --- a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp +++ b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp @@ -628,6 +628,7 @@ class JitGlobalsPass final : public impl::JitGlobalsPassBase { requestedTargetDevice); compileOptions->targetOptions.f32Extension = true; compileOptions->targetOptions.f64Extension = true; + compileOptions->targetOptions.indexBits = 64; compileOptions->targetOptions.truncateUnsupportedFloats = false; compileOptions->inputOptions.demoteF64ToF32 = false; if (requestedTargetDevice == "vmvx" || !hasRequestedTargetDevice) { @@ -677,14 +678,15 @@ class JitGlobalsPass final : public impl::JitGlobalsPassBase { s.addScalarType(b.getIntegerType(16)); s.addScalarType(b.getIntegerType(32)); s.addScalarType(b.getIntegerType(64)); + s.addScalarType(b.getIndexType()); s.addScalarType(b.getF32Type()); s.addElementType(b.getIntegerType(1)); - s.addElementType(b.getIntegerType(8)); s.addElementType(b.getIntegerType(16)); s.addElementType(b.getIntegerType(32)); s.addElementType(b.getIntegerType(64)); + s.addElementType(b.getIndexType()); s.addElementType(b.getF32Type()); if (requestedTargetDevice != "vmvx" && hasRequestedTargetDevice) { // The full compilers support additional types. diff --git a/compiler/src/iree/compiler/GlobalOptimization/Interfaces/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/Interfaces/BUILD.bazel index 0b2f50fcd450..27f9a8e21ef8 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Interfaces/BUILD.bazel +++ b/compiler/src/iree/compiler/GlobalOptimization/Interfaces/BUILD.bazel @@ -38,6 +38,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/Flow/IR", "//compiler/src/iree/compiler/Dialect/Util/IR", "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:IR", ], ) diff --git a/compiler/src/iree/compiler/GlobalOptimization/Interfaces/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/Interfaces/CMakeLists.txt index 0b5911bf2cfe..657d9aac5f99 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Interfaces/CMakeLists.txt +++ b/compiler/src/iree/compiler/GlobalOptimization/Interfaces/CMakeLists.txt @@ -32,6 +32,7 @@ iree_cc_library( "HoistableTypeInterface.cpp" DEPS LLVMSupport + MLIRArithDialect MLIRIR iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::Util::IR diff --git a/compiler/src/iree/compiler/GlobalOptimization/Interfaces/HoistableTypeInterface.cpp b/compiler/src/iree/compiler/GlobalOptimization/Interfaces/HoistableTypeInterface.cpp index 76c90cf9f09d..f5a3135b6296 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Interfaces/HoistableTypeInterface.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Interfaces/HoistableTypeInterface.cpp @@ -9,6 +9,7 @@ #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "llvm/Support/MathExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinTypes.h" namespace mlir::iree_compiler { @@ -85,6 +86,36 @@ struct HoistableTensorTypeInterface } }; +struct HoistableIndexTypeInterface + : public IREE::Util::HoistableTypeInterface::ExternalModel< + HoistableIndexTypeInterface, IndexType> { + bool isHoistableType(Type type) const { return true; } + bool isHoistableLeafType(Type type) const { return true; } + Type getPreferredStorageType(Type type) const { + // Conservatively enforce 64 bit indices for + // (potentially constant evaluated) hoisted globals. + return IntegerType::get(type.getContext(), 64); + } + static Value encodeStorageType(OpBuilder &builder, Location loc, + Type storageType, Value init) { + auto storageIndexType = dyn_cast(storageType); + if (!storageIndexType || init.getType() == storageIndexType || + !isa(init.getType())) { + return init; + } + return builder.create(loc, storageType, init); + } + static Value decodeStorageType(OpBuilder &builder, Location loc, + Type originalType, Value loadedGlobal) { + auto originalIndexType = dyn_cast(originalType); + if (!originalIndexType || loadedGlobal.getType() == originalIndexType || + !isa(loadedGlobal.getType())) { + return loadedGlobal; + } + return builder.create(loc, originalType, loadedGlobal); + } +}; + //===----------------------------------------------------------------------===// // IREE specific post analysis transformations. //===----------------------------------------------------------------------===// @@ -93,6 +124,7 @@ void registerHoistableTypeInterfaces(DialectRegistry ®istry) { // Register hoistable type interfaces for builtin types. registry.addExtension(+[](MLIRContext *ctx) { RankedTensorType::attachInterface(*ctx); + IndexType::attachInterface(*ctx); }); } diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir index e289f07575a9..67e631518488 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir @@ -156,3 +156,26 @@ module @hoist_dialect_attrs { util.return %1 : tensor } } + +// ----- + +// CHECK-LABEL: @hoist_index +module @hoist_index { + // CHECK: util.global private @[[HOISTED:.*]] : i64 + // CHECK: util.initializer + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[CEXPR:.*]] = "iree_unregistered.const_expr"(%[[C0]]) + // CHECK: %[[CAST:.*]] = arith.index_cast %[[CEXPR]] : index to i64 + // CHECK: util.global.store %[[CAST]], @[[HOISTED]] : i64 + // CHECK: util.return + + // CHECK: util.func public @main() -> index + // CHECK: %[[GLOBAL_LD:.*]] = util.global.load immutable @[[HOISTED]] : i64 + // CHECK: %[[ORIG_VAL:.*]] = arith.index_cast %[[GLOBAL_LD]] : i64 to index + // CHECK: util.return %[[ORIG_VAL]] + util.func public @main() -> (index) { + %0 = arith.constant 0 : index + %1 = "iree_unregistered.const_expr"(%0) : (index) -> index + util.return %1 : index + } +}