Skip to content

Commit

Permalink
Add hoist support for index type (#18303)
Browse files Browse the repository at this point in the history
#18232

Signed-off-by: jinchen62 <jinchenye62@gmail.com>
  • Loading branch information
jinchen62 authored Aug 24, 2024
1 parent cc44a85 commit d8f0fc3
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 1 deletion.
4 changes: 3 additions & 1 deletion compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,7 @@ class JitGlobalsPass final : public impl::JitGlobalsPassBase<JitGlobalsPass> {
requestedTargetDevice);
compileOptions->targetOptions.f32Extension = true;
compileOptions->targetOptions.f64Extension = true;
compileOptions->targetOptions.indexBits = 64;
compileOptions->targetOptions.truncateUnsupportedFloats = false;
compileOptions->inputOptions.demoteF64ToF32 = false;
if (requestedTargetDevice == "vmvx" || !hasRequestedTargetDevice) {
Expand Down Expand Up @@ -677,14 +678,15 @@ class JitGlobalsPass final : public impl::JitGlobalsPassBase<JitGlobalsPass> {
s.addScalarType(b.getIntegerType(16));
s.addScalarType(b.getIntegerType(32));
s.addScalarType(b.getIntegerType(64));
s.addScalarType(b.getIndexType());
s.addScalarType(b.getF32Type());

s.addElementType(b.getIntegerType(1));

s.addElementType(b.getIntegerType(8));
s.addElementType(b.getIntegerType(16));
s.addElementType(b.getIntegerType(32));
s.addElementType(b.getIntegerType(64));
s.addElementType(b.getIndexType());
s.addElementType(b.getF32Type());
if (requestedTargetDevice != "vmvx" && hasRequestedTargetDevice) {
// The full compilers support additional types.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ iree_cc_library(
"HoistableTypeInterface.cpp"
DEPS
LLVMSupport
MLIRArithDialect
MLIRIR
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Util::IR
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinTypes.h"

namespace mlir::iree_compiler {
Expand Down Expand Up @@ -85,6 +86,36 @@ struct HoistableTensorTypeInterface
}
};

struct HoistableIndexTypeInterface
: public IREE::Util::HoistableTypeInterface::ExternalModel<
HoistableIndexTypeInterface, IndexType> {
bool isHoistableType(Type type) const { return true; }
bool isHoistableLeafType(Type type) const { return true; }
Type getPreferredStorageType(Type type) const {
// Conservatively enforce 64 bit indices for
// (potentially constant evaluated) hoisted globals.
return IntegerType::get(type.getContext(), 64);
}
static Value encodeStorageType(OpBuilder &builder, Location loc,
Type storageType, Value init) {
auto storageIndexType = dyn_cast<IntegerType>(storageType);
if (!storageIndexType || init.getType() == storageIndexType ||
!isa<IndexType>(init.getType())) {
return init;
}
return builder.create<arith::IndexCastOp>(loc, storageType, init);
}
static Value decodeStorageType(OpBuilder &builder, Location loc,
Type originalType, Value loadedGlobal) {
auto originalIndexType = dyn_cast<IndexType>(originalType);
if (!originalIndexType || loadedGlobal.getType() == originalIndexType ||
!isa<IntegerType>(loadedGlobal.getType())) {
return loadedGlobal;
}
return builder.create<arith::IndexCastOp>(loc, originalType, loadedGlobal);
}
};

//===----------------------------------------------------------------------===//
// IREE specific post analysis transformations.
//===----------------------------------------------------------------------===//
Expand All @@ -93,6 +124,7 @@ void registerHoistableTypeInterfaces(DialectRegistry &registry) {
// Register hoistable type interfaces for builtin types.
registry.addExtension(+[](MLIRContext *ctx) {
RankedTensorType::attachInterface<HoistableTensorTypeInterface>(*ctx);
IndexType::attachInterface<HoistableIndexTypeInterface>(*ctx);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,26 @@ module @hoist_dialect_attrs {
util.return %1 : tensor<i32>
}
}

// -----

// CHECK-LABEL: @hoist_index
module @hoist_index {
// CHECK: util.global private @[[HOISTED:.*]] : i64
// CHECK: util.initializer
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[CEXPR:.*]] = "iree_unregistered.const_expr"(%[[C0]])
// CHECK: %[[CAST:.*]] = arith.index_cast %[[CEXPR]] : index to i64
// CHECK: util.global.store %[[CAST]], @[[HOISTED]] : i64
// CHECK: util.return

// CHECK: util.func public @main() -> index
// CHECK: %[[GLOBAL_LD:.*]] = util.global.load immutable @[[HOISTED]] : i64
// CHECK: %[[ORIG_VAL:.*]] = arith.index_cast %[[GLOBAL_LD]] : i64 to index
// CHECK: util.return %[[ORIG_VAL]]
util.func public @main() -> (index) {
%0 = arith.constant 0 : index
%1 = "iree_unregistered.const_expr"(%0) : (index) -> index
util.return %1 : index
}
}

0 comments on commit d8f0fc3

Please sign in to comment.