diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index f1a820fb245b..81f8da846eb1 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -22,6 +22,27 @@ namespace mlir::iree_compiler::IREE::HAL { +namespace { + +// We aribtrarily say that unbounded dimensions in a torch program cannot +// exceed 53bits, making the maximum safe dimension 9007199254740991. The +// astute reader will note that this is also the maximum safe value in +// JavaScript, which also "happens" to be the largest mantissa value in a +// 64bit double. We need a maximum and in the absence of a better choice, +// with this one we are at least in good company. This limit is also used +// in the frontends. +static constexpr uint64_t MAX_DIM_VALUE = (static_cast(1) << 53) - 1; + +// Similarly we use a very conservative maximum rank value for specifying +// ranges of runtime rank resolution functions. Various frameworks have hard +// and practical limits ranging from 32 (numpy) to hundreds. At the time of +// writing, PyTorch throws weird errors if trying to print a tensor with a rank +// greater than 992. We really just want a smallish integer value to bound +// arithmetic, so we use an arbitrary maximum. +static constexpr uint64_t MAX_RANK_VALUE = 4096; + +} // namespace + //===----------------------------------------------------------------------===// // custom($descriptor_type) //===----------------------------------------------------------------------===// @@ -1024,6 +1045,30 @@ void BufferViewBufferOp::getAsmResultNames( setNameFn(getResult(), "buffer"); } +//===----------------------------------------------------------------------===// +// hal.buffer_view.dim +//===----------------------------------------------------------------------===// + +void BufferViewDimOp::inferResultRangesFromOptional( + ArrayRef argRanges, SetIntLatticeFn setResultRange) { + const unsigned indexTypeNumBits = 64; + setResultRange(getResult(), IntegerValueRange(ConstantIntRanges::fromUnsigned( + APInt::getZero(indexTypeNumBits), + APInt(indexTypeNumBits, MAX_DIM_VALUE)))); +} + +//===----------------------------------------------------------------------===// +// hal.buffer_view.dim +//===----------------------------------------------------------------------===// + +void BufferViewRankOp::inferResultRangesFromOptional( + ArrayRef argRanges, SetIntLatticeFn setResultRange) { + const unsigned indexTypeNumBits = 64; + setResultRange(getResult(), IntegerValueRange(ConstantIntRanges::fromUnsigned( + APInt::getZero(indexTypeNumBits), + APInt(indexTypeNumBits, MAX_RANK_VALUE)))); +} + //===----------------------------------------------------------------------===// // hal.channel.create //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.h b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.h index ae58127959bb..16dd46bc5e17 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.h +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.h @@ -20,6 +20,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index fdd43b7a5e72..9e370a10c22b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -18,6 +18,7 @@ include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" @@ -1010,7 +1011,10 @@ def HAL_BufferViewEncodingTypeOp : HAL_PureOp<"buffer_view.encoding_type"> { }]; } -def HAL_BufferViewRankOp : HAL_PureOp<"buffer_view.rank"> { +def HAL_BufferViewRankOp : HAL_PureOp<"buffer_view.rank", [ + DeclareOpInterfaceMethods, +]> { let summary = [{buffer view rank query}]; let description = [{ Returns the rank of the buffer view. @@ -1030,7 +1034,10 @@ def HAL_BufferViewRankOp : HAL_PureOp<"buffer_view.rank"> { }]; } -def HAL_BufferViewDimOp : HAL_PureOp<"buffer_view.dim"> { +def HAL_BufferViewDimOp : HAL_PureOp<"buffer_view.dim", [ + DeclareOpInterfaceMethods, +]> { let summary = [{buffer view dimension value query}]; let description = [{ Returns the value of the given dimension. diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir index 1924f423ef66..f78817cb03af 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir @@ -493,3 +493,33 @@ util.func @util_align_zero(%arg0 : i64) -> i64 { %rem16 = arith.remui %0, %c16 : i64 util.return %rem16 : i64 } + +// ----- + +util.func @hal_buffer_view_dim_min_max(%bv : !hal.buffer_view) -> (i1, i1, i1) { + %zero = arith.constant 0 : index + %max = arith.constant 9007199254740991 : index + %0 = hal.buffer_view.dim<%bv : !hal.buffer_view>[0] : index + %1 = arith.cmpi slt, %0, %zero : index + %2 = arith.cmpi uge, %0, %zero : index + %3 = arith.cmpi ugt, %0, %max : index + // CHECK-DAG: %[[FALSE:.*]] = arith.constant false + // CHECK-DAG: %[[TRUE:.*]] = arith.constant true + // CHECK: util.return %[[FALSE]], %[[TRUE]], %[[FALSE]] + util.return %1, %2, %3 : i1, i1, i1 +} + +// ----- + +util.func @hal_buffer_view_rank_min_max(%bv : !hal.buffer_view) -> (i1, i1, i1) { + %zero = arith.constant 0 : index + %max = arith.constant 4096 : index + %0 = hal.buffer_view.rank<%bv : !hal.buffer_view> : index + %1 = arith.cmpi slt, %0, %zero : index + %2 = arith.cmpi uge, %0, %zero : index + %3 = arith.cmpi ugt, %0, %max : index + // CHECK-DAG: %[[FALSE:.*]] = arith.constant false + // CHECK-DAG: %[[TRUE:.*]] = arith.constant true + // CHECK: util.return %[[FALSE]], %[[TRUE]], %[[FALSE]] + util.return %1, %2, %3 : i1, i1, i1 +}