From 5270093401cea05a548352fce312e39d0291024c Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 8 Oct 2024 19:53:04 -0700 Subject: [PATCH] Add an integer divisibility analysis. (#18727) * Also extends the numeric optimization test to elide arith.remui that exactly divides (a reasonable optimization but primarily used for testing in this patch). * Only implements the analysis for `arith.constant` and `util.int.assume` in this patch. * Renames assumption `divisor` to `udiv` to match terminology elsewhere wrt signed/unsigned analysis. * The lattice tracks unsigned and signed interpretations separately as this is needed for propagation through signed ops (but this is not implemented here). --------- Signed-off-by: Stella Laurenzo --- .../test/bind_symbolic_shapes.mlir | 12 +-- .../Dialect/Util/Analysis/BUILD.bazel | 2 + .../Dialect/Util/Analysis/CMakeLists.txt | 2 + .../Analysis/IntegerDivisibilityAnalysis.cpp | 69 ++++++++++++++ .../Analysis/IntegerDivisibilityAnalysis.h | 42 +++++++++ .../compiler/Dialect/Util/IR/UtilAttrs.td | 4 +- .../Dialect/Util/IR/UtilInterfaces.td | 26 ++++++ .../iree/compiler/Dialect/Util/IR/UtilOps.cpp | 19 +++- .../iree/compiler/Dialect/Util/IR/UtilOps.td | 5 +- .../iree/compiler/Dialect/Util/IR/UtilTypes.h | 90 +++++++++++++++++++ .../Dialect/Util/IR/test/assume_ops.mlir | 8 +- .../Dialect/Util/IR/test/attributes.mlir | 8 +- .../Dialect/Util/IR/test/op_verification.mlir | 2 +- .../Util/Transforms/OptimizeIntArithmetic.cpp | 48 ++++++++++ .../Dialect/Util/Transforms/test/BUILD.bazel | 1 + .../Util/Transforms/test/CMakeLists.txt | 1 + .../Transforms/test/integer_divisibility.mlir | 36 ++++++++ .../test/optimize_int_arithmetic.mlir | 12 +++ .../ExternalInterfaces/UtilExternalModels.cpp | 25 ++++++ 19 files changed, 391 insertions(+), 21 deletions(-) create mode 100644 compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp create mode 100644 compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h create mode 100644 compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir diff --git a/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir b/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir index 641013f57fb1..8f78bf7dcb0c 100644 --- a/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir +++ b/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir @@ -19,7 +19,7 @@ module @basic_example { // CHECK: %[[ARG1_DIM0_RANGE:.*]] = util.assume.int %[[DIM0]] // CHECK: %[[MULTIPLIER0:.*]] = arith.constant 2 : index // CHECK: %[[ARG1_DIM1:.*]] = arith.muli %[[DIM1]], %[[MULTIPLIER0]] - // CHECK: %[[ARG1_DIM1_RANGE:.*]] = util.assume.int %[[ARG1_DIM1]] : index + // CHECK: %[[ARG1_DIM1_RANGE:.*]] = util.assume.int %[[ARG1_DIM1]] : index // CHECK: %[[ARG1_TIE:.*]] = flow.tensor.tie_shape %[[ARG1_ANCHOR]] : tensor{%[[ARG1_DIM0_RANGE]], %[[ARG1_DIM1_RANGE]]} // CHECK: %[[ARG1_EXPORT:.*]] = torch_c.from_builtin_tensor %[[ARG1_TIE]] %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int @@ -49,10 +49,10 @@ module @basic_example { module @unbacked_symbol { func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: util.assume.int{{.*}} - // CHECK: util.assume.int{{.*}} + // CHECK: util.assume.int{{.*}} // CHECK: tie_shape // CHECK: util.assume.int{{.*}} - // CHECK: util.assume.int{{.*}} + // CHECK: util.assume.int{{.*}} // CHECK: tie_shape %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int %1 = torch.symbolic_int "2*s4" {min_val = 0, max_val = 2048} : !torch.int @@ -100,7 +100,7 @@ module @all_bindings_dropped { module @add_expr { func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) { // CHECK: addi - // CHECK-NOT: divisor + // CHECK-NOT: udiv %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32> @@ -114,7 +114,7 @@ module @add_expr { module @mod_expr { func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) { // CHECK: remui - // CHECK-NOT: divisor + // CHECK-NOT: udiv %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32> @@ -128,7 +128,7 @@ module @mod_expr { module @floordiv_expr { func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) { // CHECK: divui - // CHECK-NOT: divisor + // CHECK-NOT: udiv %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32> diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Analysis/BUILD.bazel index ebc51fd3e6ce..502818b852e3 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/BUILD.bazel @@ -17,11 +17,13 @@ iree_compiler_cc_library( srcs = [ "Explorer.cpp", "GlobalTable.cpp", + "IntegerDivisibilityAnalysis.cpp", "Position.cpp", ], hdrs = [ "Explorer.h", "GlobalTable.h", + "IntegerDivisibilityAnalysis.h", "Position.h", ], deps = [ diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Analysis/CMakeLists.txt index e253a2ed4fab..04c4c80b4516 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/CMakeLists.txt @@ -16,10 +16,12 @@ iree_cc_library( HDRS "Explorer.h" "GlobalTable.h" + "IntegerDivisibilityAnalysis.h" "Position.h" SRCS "Explorer.cpp" "GlobalTable.cpp" + "IntegerDivisibilityAnalysis.cpp" "Position.cpp" DEPS LLVMSupport diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp new file mode 100644 index 000000000000..f87864ebb0ad --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp @@ -0,0 +1,69 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h" + +#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "iree-util-int-divisibility-analysis" + +using llvm::dbgs; + +namespace mlir::iree_compiler::IREE::Util { + +void IntegerDivisibilityAnalysis::setToEntryState( + IntegerDivisibilityLattice *lattice) { + propagateIfChanged(lattice, + lattice->join(IntegerDivisibility::getMinDivisibility())); +} + +LogicalResult IntegerDivisibilityAnalysis::visitOperation( + Operation *op, ArrayRef operands, + ArrayRef results) { + auto inferrable = dyn_cast(op); + if (!inferrable) { + setAllToEntryStates(results); + return success(); + } + + LLVM_DEBUG(dbgs() << "Inferring divisibility for " << *op << "\n"); + auto argDivs = llvm::map_to_vector( + operands, [](const IntegerDivisibilityLattice *lattice) { + return lattice->getValue(); + }); + auto joinCallback = [&](Value v, const IntegerDivisibility &newDiv) { + auto result = dyn_cast(v); + if (!result) + return; + assert(llvm::is_contained(op->getResults(), result)); + + LLVM_DEBUG(dbgs() << "Inferred divisibility " << newDiv << "\n"); + IntegerDivisibilityLattice *lattice = results[result.getResultNumber()]; + IntegerDivisibility oldDiv = lattice->getValue(); + + ChangeResult changed = lattice->join(newDiv); + + // Catch loop results with loop variant bounds and conservatively make + // them [-inf, inf] so we don't circle around infinitely often (because + // the dataflow analysis in MLIR doesn't attempt to work out trip counts + // and often can't). + bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) { + return op->hasTrait(); + }); + if (isYieldedResult && !oldDiv.isUninitialized() && + !(lattice->getValue() == oldDiv)) { + LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); + changed |= lattice->join(IntegerDivisibility::getMinDivisibility()); + } + propagateIfChanged(lattice, changed); + }; + + inferrable.inferResultDivisibility(argDivs, joinCallback); + return success(); +} + +} // namespace mlir::iree_compiler::IREE::Util diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h new file mode 100644 index 000000000000..2a550f69ec88 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h @@ -0,0 +1,42 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_UTIL_INTEGER_DIVISIBILITY_ANALYSIS_H_ +#define IREE_COMPILER_DIALECT_UTIL_INTEGER_DIVISIBILITY_ANALYSIS_H_ + +#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" + +#include + +namespace mlir::iree_compiler::IREE::Util { + +class IntegerDivisibilityLattice + : public dataflow::Lattice { +public: + using Lattice::Lattice; +}; + +class IntegerDivisibilityAnalysis + : public dataflow::SparseForwardDataFlowAnalysis< + IntegerDivisibilityLattice> { +public: + using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis; + + // At an entry point, set the lattice to the most pessimistic state, + // indicating that no further reasoning can be done. + void setToEntryState(IntegerDivisibilityLattice *lattice) override; + + // Visit an operation, invoking the transfer function. + LogicalResult + visitOperation(Operation *op, + ArrayRef operands, + ArrayRef results) override; +}; + +} // namespace mlir::iree_compiler::IREE::Util + +#endif // IREE_COMPILER_DIALECT_UTIL_INTEGER_DIVISIBILITY_ANALYSIS_H_ diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.td index e43a2b6a1997..be4b11531aa0 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.td +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.td @@ -30,10 +30,10 @@ def Util_IntAssumptionAttr : AttrDef", "std::nullopt">:$umin, DefaultValuedParameter<"std::optional", "std::nullopt">:$umax, - DefaultValuedParameter<"std::optional", "std::nullopt">:$divisor + DefaultValuedParameter<"std::optional", "std::nullopt">:$udiv ); let assemblyFormat = [{ - `<` struct($umin, $umax, $divisor) `>` + `<` struct($umin, $umax, $udiv) `>` }]; } diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td index 85acab88a8c6..f0b541096b21 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td @@ -94,6 +94,32 @@ def Util_ClosureOpInterface : OpInterface<"ClosureOpInterface"> { ]; } +//===----------------------------------------------------------------------===// +// IREE::Util::InferIntDivisibilityOpInterface +//===----------------------------------------------------------------------===// + +def InferIntDivisibilityOpInterface : + OpInterface<"InferIntDivisibilityOpInterface"> { + let cppNamespace = "::mlir::iree_compiler::IREE::Util"; + + let description = [{ + Allows operations to participate in integer divisibility analysis. + }]; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + + }], + /*retTy=*/"void", + /*methodName=*/"inferResultDivisibility", + /*args=*/(ins + "::llvm::ArrayRef<::mlir::iree_compiler::IREE::Util::IntegerDivisibility>":$argDivs, + "::mlir::iree_compiler::IREE::Util::SetIntDivisibilityFn":$setResultDivs) + > + ]; +} + //===----------------------------------------------------------------------===// // IREE::Util::InitializerOpInterface //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp index 201ce8ac5a90..971413c77740 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp @@ -1141,11 +1141,12 @@ AssumeIntOp::getUnionedUnsignedRange(unsigned operandIndex) { // Gets the unioned divisor for an operand. If there are multiple divisor // assumptions, the gcd of all of them is returned. If there are no // divisor assumptions, std::nullopt is returned. -std::optional AssumeIntOp::getUnionedDivisor(unsigned operandIndex) { +std::optional +AssumeIntOp::getUnionedUnsignedDivisor(unsigned operandIndex) { auto assumptions = getOperandAssumptions(operandIndex); std::optional divisorUnion; for (auto assumption : assumptions) { - auto divisor = assumption.getDivisor(); + auto divisor = assumption.getUdiv(); if (divisor) { if (divisorUnion) divisorUnion = std::gcd(*divisor, *divisorUnion); @@ -1176,6 +1177,20 @@ void AssumeIntOp::inferResultRanges(ArrayRef argRanges, } } +void AssumeIntOp::inferResultDivisibility(ArrayRef argDivs, + SetIntDivisibilityFn setResultDivs) { + for (auto [index, result] : llvm::enumerate(getResults())) { + Type type = result.getType(); + if (!isa(type) && !isa(type)) + continue; + auto udiv = getUnionedUnsignedDivisor(index); + if (udiv) { + setResultDivs(result, + ConstantIntDivisibility(/*udiv=*/*udiv, /*sdiv=*/*udiv)); + } + } +} + void AssumeIntOp::build(OpBuilder &builder, OperationState &state, Value singleOperand, IntAssumptionAttr singleAssumption) { diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td index 3a7179b16412..84e0d79b7d78 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td @@ -460,7 +460,8 @@ def OpGroupCompilerHintOps : OpDocGroup { let opDocGroup = OpGroupCompilerHintOps in { def Util_AssumeIntOp : Util_PureOp<"assume.int", [ - DeclareOpInterfaceMethods + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods ]> { let summary = "memorializes assumptions about index/integer values."; let description = [{ @@ -507,7 +508,7 @@ def Util_AssumeIntOp : Util_PureOp<"assume.int", [ // Gets the unioned divisor for an operand. If there are multiple divisor // assumptions, the gcd of all of them is returned. If there are no // divisor assumptions, std::nullopt is returned. - std::optional getUnionedDivisor(unsigned operandIndex); + std::optional getUnionedUnsignedDivisor(unsigned operandIndex); }]; let hasCustomAssemblyFormat = 1; diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h index a80281c43999..ec1c2d984a1a 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h @@ -22,6 +22,8 @@ #include "mlir/IR/Types.h" #include "mlir/Interfaces/CallInterfaces.h" +#include + // clang-format off: must be included after all LLVM/MLIR headers. #include "iree/compiler/Dialect/Util/IR/UtilEnums.h.inc" // IWYU pragma: keep // clang-format on @@ -155,6 +157,94 @@ void excludeTiedOperandAndResultIndices( ArrayRef excludedResultIndices, SmallVector &tiedOperandIndices); +//===----------------------------------------------------------------------===// +// Forward defines for InferIntDivisibilityOpInterface +// See implementations in IntegerDivisibility.h. +//===----------------------------------------------------------------------===// + +class ConstantIntDivisibility { +public: + ConstantIntDivisibility() = default; + ConstantIntDivisibility(uint64_t udiv, uint64_t sdiv) + : udivVal(udiv), sdivVal(sdiv) {} + + bool operator==(const ConstantIntDivisibility &other) const { + return udivVal == other.udivVal && sdivVal == other.sdivVal; + } + + uint64_t udiv() const { return this->udivVal; } + uint64_t sdiv() const { return this->sdivVal; } + + // Returns the union (computed separately for signed and unsigned bounds) + // for this range and `other`. + ConstantIntDivisibility getUnion(const ConstantIntDivisibility &other) const { + return ConstantIntDivisibility( + /*udiv=*/std::gcd(udiv(), other.udiv()), + /*sdiv=*/std::gcd(sdiv(), other.sdiv())); + } + +private: + uint64_t udivVal; + uint64_t sdivVal; + + friend raw_ostream &operator<<(raw_ostream &os, + const ConstantIntDivisibility &div); +}; + +inline raw_ostream &operator<<(raw_ostream &os, + const ConstantIntDivisibility &div) { + os << "ConstantIntDivisibility(udiv = " << div.udivVal + << ", sdiv = " << div.sdivVal << ")"; + return os; +} + +class IntegerDivisibility { +public: + IntegerDivisibility(ConstantIntDivisibility value) + : value(std::move(value)) {} + IntegerDivisibility( + std::optional value = std::nullopt) + : value(std::move(value)) {} + // Gets the minimum divisibility of 1 that is used to indicate that the value + // cannot be analyzed further. + static IntegerDivisibility getMinDivisibility() { + return IntegerDivisibility(ConstantIntDivisibility(1, 1)); + } + + bool isUninitialized() const { return !value.has_value(); } + const ConstantIntDivisibility &getValue() const { + assert(!isUninitialized()); + return *value; + } + + bool operator==(const IntegerDivisibility &rhs) const { + return value == rhs.value; + } + + static IntegerDivisibility join(const IntegerDivisibility &lhs, + const IntegerDivisibility &rhs) { + if (lhs.isUninitialized()) + return rhs; + if (rhs.isUninitialized()) + return lhs; + return IntegerDivisibility(lhs.getValue().getUnion(rhs.getValue())); + } + + void print(raw_ostream &os) const { os << value; } + +private: + std::optional value; +}; + +inline raw_ostream &operator<<(raw_ostream &os, + const IntegerDivisibility &div) { + div.print(os); + return os; +} + +using SetIntDivisibilityFn = + llvm::function_ref; + //===----------------------------------------------------------------------===// // Shape-aware interface utilities //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/assume_ops.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/assume_ops.mlir index 0886d38531dd..ecf6a8c4d741 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/test/assume_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/assume_ops.mlir @@ -10,15 +10,15 @@ util.func public @assume.int.single_assumption(%arg0 : index) -> index { // ----- // CHECK-LABEL: @assume.int.multi_assumption util.func public @assume.int.multi_assumption(%arg0 : index) -> index { - // CHECK: util.assume.int %arg0[, ] : index - %0 = util.assume.int %arg0[, ] : index + // CHECK: util.assume.int %arg0[, ] : index + %0 = util.assume.int %arg0[, ] : index util.return %0 : index } // ----- // CHECK-LABEL: @assume.int.multi_operand util.func public @assume.int.multi_operand(%arg0 : index, %arg1 : i64) -> index, i64 { - // CHECK: util.assume.int %arg0[, ], %arg1[, ] : index, i64 - %0:2 = util.assume.int %arg0[, ], %arg1[, ] : index, i64 + // CHECK: util.assume.int %arg0[, ], %arg1[, ] : index, i64 + %0:2 = util.assume.int %arg0[, ], %arg1[, ] : index, i64 util.return %0#0, %0#1 : index, i64 } diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/attributes.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/attributes.mlir index 5220794b561f..54c954d03fa2 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/test/attributes.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/attributes.mlir @@ -2,12 +2,12 @@ // CHECK-LABEL: @assume_int builtin.module @assume_int attributes { - // CHECK: util.all = #util.int.assumption - // CHECK-SAME: util.divisor = #util.int.assumption + // CHECK: util.all = #util.int.assumption + // CHECK-SAME: util.udiv = #util.int.assumption // CHECK-SAME: util.umax = #util.int.assumption // CHECK-SAME: util.umin = #util.int.assumption - util.all = #util.int.assumption, - util.divisor = #util.int.assumption, + util.all = #util.int.assumption, + util.udiv = #util.int.assumption, util.umax = #util.int.assumption, util.umin = #util.int.assumption } {} diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/op_verification.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/op_verification.mlir index e4a6f6e9bc92..2be8dc549226 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/test/op_verification.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/op_verification.mlir @@ -2,7 +2,7 @@ util.func public @assume.int.multi_operand(%arg0 : index, %arg1 : i64) -> index, i64 { // expected-error @+1 {{expected operand #1 to have 1 assumptions but it has 2}} - %0:2 = util.assume.int %arg0[], %arg1[, ] : index, i64 + %0:2 = util.assume.int %arg0[], %arg1[, ] : index, i64 util.return %0#0, %0#1 : index, i64 } diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp index 04ae8b707baf..7499bf0ce3f2 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp @@ -4,6 +4,7 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include "iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h" #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "iree/compiler/Dialect/Util/Transforms/PassDetail.h" #include "iree/compiler/Dialect/Util/Transforms/Passes.h" @@ -162,6 +163,49 @@ struct ConvertUnsignedI64IndexCastProducerToIndex DataFlowSolver &solver; }; +//===----------------------------------------------------------------------===// +// Divisibility +//===----------------------------------------------------------------------===// + +static LogicalResult getDivisibility(DataFlowSolver &solver, Operation *op, + Value value, PatternRewriter &rewriter, + ConstantIntDivisibility &out) { + auto *div = solver.lookupState(value); + if (!div || div->getValue().isUninitialized()) + return rewriter.notifyMatchFailure(op, + "divisibility could not be determined"); + + out = div->getValue().getValue(); + LLVM_DEBUG(dbgs() << " * Resolved divisibility: " << out << "\n"); + return success(); +} + +struct RemUIDivisibilityByConstant : public OpRewritePattern { + RemUIDivisibilityByConstant(MLIRContext *context, DataFlowSolver &solver) + : OpRewritePattern(context), solver(solver) {} + + LogicalResult matchAndRewrite(arith::RemUIOp op, + PatternRewriter &rewriter) const override { + APInt rhsConstant; + if (!matchPattern(op.getRhs(), m_ConstantInt(&rhsConstant))) + return rewriter.notifyMatchFailure(op, "rhs is not constant"); + + ConstantIntDivisibility lhsDiv; + if (failed(getDivisibility(solver, op, op.getLhs(), rewriter, lhsDiv))) + return failure(); + + uint64_t rhsValue = rhsConstant.getZExtValue(); + if (rhsValue > 0 && lhsDiv.udiv() > 0 && lhsDiv.udiv() % rhsValue != 0) + return rewriter.notifyMatchFailure(op, "rhs does not divide lhs"); + + rewriter.replaceOpWithNewOp( + op, rewriter.getZeroAttr(op.getResult().getType())); + return success(); + } + + DataFlowSolver &solver; +}; + //===----------------------------------------------------------------------===// // Pass setup //===----------------------------------------------------------------------===// @@ -229,6 +273,7 @@ class OptimizeIntArithmeticPass DataFlowSolver solver; solver.load(); solver.load(); + solver.load(); DataFlowListener listener(solver); RewritePatternSet patterns(ctx); @@ -255,6 +300,9 @@ class OptimizeIntArithmeticPass ConvertOpToUnsigned>(ctx, solver); + // Populate divisibility patterns. + patterns.add(ctx, solver); + GreedyRewriteConfig config; // Results in fewer recursive data flow flushes/cycles on modification. config.useTopDownTraversal = false; diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel index 4a1135868c72..c4a4724e7d11 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel @@ -24,6 +24,7 @@ iree_lit_test_suite( "hoist_into_globals.mlir", "hoist_into_globals_linalg.mlir", "import_resources.mlir", + "integer_divisibility.mlir", "ipo.mlir", "optimize_int_arithmetic.mlir", "patterns.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt index 2109f2abdb53..02ba4d93fa78 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt @@ -22,6 +22,7 @@ iree_lit_test_suite( "hoist_into_globals.mlir" "hoist_into_globals_linalg.mlir" "import_resources.mlir" + "integer_divisibility.mlir" "ipo.mlir" "optimize_int_arithmetic.mlir" "patterns.mlir" diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir new file mode 100644 index 000000000000..e8c2740bb31f --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir @@ -0,0 +1,36 @@ +// RUN: iree-opt --split-input-file --iree-util-optimize-int-arithmetic %s | FileCheck %s + +// Use the int arithmetic optimization pipeline to test the integer divisibility +// analysis. This largely relies on the arith.remui operation resolving to a +// constant 0 on even division. + +// CHECK-LABEL: @remui_div_by_exact_factor +util.func @remui_div_by_exact_factor(%arg0 : index) -> index { + %cst = arith.constant 16 : index + %0 = util.assume.int %arg0 : index + %1 = arith.remui %0, %cst : index + // CHECK: %[[CST:.*]] = arith.constant 0 + // CHECK: return %[[CST]] + util.return %1 : index +} + +// ----- +// CHECK-LABEL: @remui_div_by_common_factor +util.func @remui_div_by_common_factor(%arg0 : index) -> index { + %cst = arith.constant 8 : index + %0 = util.assume.int %arg0 : index + %1 = arith.remui %0, %cst : index + // CHECK: %[[CST:.*]] = arith.constant 0 + // CHECK: return %[[CST]] + util.return %1 : index +} + +// ----- +// CHECK-LABEL: @remui_div_by_unrelated +util.func @remui_div_by_unrelated(%arg0 : index) -> index { + %cst = arith.constant 23 : index + %0 = util.assume.int %arg0 : index + // CHECK: arith.remui + %1 = arith.remui %0, %cst : index + util.return %1 : index +} diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir index 39ff74c2cc96..4726941c42db 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir @@ -359,3 +359,15 @@ util.func @index_unsigned_overflow_signed(%arg0 : index) -> index { %1 = arith.divsi %0, %cst : index util.return %1 : index } + +// ----- +// CHECK-LABEL: @index_cast_i64_to_index_remsi +util.func @index_cast_i64_to_index_remsi(%arg0 : index, %arg1 : index) -> index { + // CHECK: %[[ASSUME:.*]] = util.assume.int + %0 = util.assume.int %arg0 : index + %1 = arith.index_cast %0 : index to i64 + // CHECK: arith.remui %[[ASSUME]], %[[ASSUME]] : index + %2 = arith.remsi %1, %1 : i64 + %3 = arith.index_cast %2 : i64 to index + util.return %3 : index +} diff --git a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp index c5e3e1679ec0..2097cbf1e6de 100644 --- a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp +++ b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp @@ -22,6 +22,29 @@ namespace mlir::iree_compiler { namespace { +//===----------------------------------------------------------------------===// +// InferIntDivisibilityOpInterface +//===----------------------------------------------------------------------===// + +struct ArithConstantInferIntDivisibilityOpInterface + : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel< + ArithConstantInferIntDivisibilityOpInterface, arith::ConstantOp> { + + void inferResultDivisibility( + Operation *op, ArrayRef argDivs, + IREE::Util::SetIntDivisibilityFn setResultDivs) const { + auto constOp = cast(op); + auto constAttr = llvm::dyn_cast_or_null(constOp.getValue()); + if (constAttr) { + const APInt &value = constAttr.getValue(); + uint64_t udiv = value.getZExtValue(); + uint64_t sdiv = std::abs(value.getSExtValue()); + setResultDivs(constOp.getResult(), + IREE::Util::ConstantIntDivisibility(udiv, sdiv)); + } + } +}; + //===----------------------------------------------------------------------===// // GlobalOpInterface //===----------------------------------------------------------------------===// @@ -303,6 +326,8 @@ void registerUtilExternalModels(DialectRegistry ®istry) { arith::BitcastOp, arith::ExtFOp, arith::ExtUIOp, arith::ExtSIOp, arith::FPToSIOp, arith::FPToUIOp, arith::IndexCastOp, arith::TruncFOp, arith::TruncIOp, arith::SIToFPOp, arith::UIToFPOp>(context); + arith::ConstantOp::attachInterface< + ArithConstantInferIntDivisibilityOpInterface>(*context); }); registry.addExtension(