From b0c77fa8ce1c4c860813c8064a500ecc02af2832 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Thu, 24 Aug 2023 13:44:36 -0700 Subject: [PATCH] Improving util.align folding. (#14805) --- .../iree/compiler/Dialect/Util/IR/BUILD.bazel | 1 + .../compiler/Dialect/Util/IR/CMakeLists.txt | 1 + .../compiler/Dialect/Util/IR/UtilAttrs.cpp | 2 +- .../Dialect/Util/IR/UtilOpFolders.cpp | 31 ++++++++---- .../Util/IR/test/alignment_folding.mlir | 48 ++++++++++++++++++- 5 files changed, 73 insertions(+), 10 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel index 213481e16998..9a8c58204970 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel @@ -95,6 +95,7 @@ iree_compiler_cc_library( ":UtilOpsGen", ":UtilTypesGen", "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:CastInterfaces", "@llvm-project//mlir:ControlFlowDialect", diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/IR/CMakeLists.txt index c4ca0ce28d93..5abbea5f901f 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Util/IR/CMakeLists.txt @@ -63,6 +63,7 @@ iree_cc_library( ::UtilOpsGen ::UtilTypesGen LLVMSupport + MLIRAffineDialect MLIRArithDialect MLIRCastInterfaces MLIRControlFlowDialect diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp index 2b664cd72501..0e26efc0fc07 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp @@ -395,7 +395,7 @@ serializeGenericElementData(Location loc, DenseElementsAttr elementsAttr, case 64: return serializeGenericIntegerElements(attr, endian, os); default: - if (bitWidth != 1 && 
bitWidth < 64) { + if (bitWidth < 64) { // Special case for bit-packing of sub-byte aligned types. // This could be extended to handle larger widths (i33, etc) but they // are rare today. diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp index fe3954ac54d4..ea55cd4ab526 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp @@ -4,10 +4,13 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include <numeric> + #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "llvm/ADT/BitVector.h" #include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/SmallPtrSet.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" @@ -364,9 +367,11 @@ void RangeExtentsOp::getCanonicalizationPatterns(RewritePatternSet &results, // a large majority of the cases we generate ourselves from packing/allocation. static bool isAlignedTo(Value value, Value alignment) { APInt staticValue; + bool hasStaticValue = matchPattern(value, m_ConstantInt(&staticValue)); APInt staticAlignment; - if (matchPattern(value, m_ConstantInt(&staticValue)) && - matchPattern(alignment, m_ConstantInt(&staticAlignment))) { + bool hasStaticAlignment = + matchPattern(alignment, m_ConstantInt(&staticAlignment)); + if (hasStaticValue && hasStaticAlignment) { // If this value is itself a multiple of the alignment then we can fold. if (staticValue.urem(staticAlignment).isZero()) { return true; // value % alignment == 0 @@ -381,11 +386,12 @@ static bool isAlignedTo(Value value, Value alignment) { // If the alignments are constant we can compare them inline. 
APInt sourceAlignment; - APInt selfAlignment; - if (matchPattern(sourceAlignOp.getAlignment(), - m_ConstantInt(&sourceAlignment)) && - matchPattern(alignment, m_ConstantInt(&selfAlignment))) { - if (sourceAlignment.uge(selfAlignment)) { + if (hasStaticAlignment && matchPattern(sourceAlignOp.getAlignment(), + m_ConstantInt(&sourceAlignment))) { + if (sourceAlignment.uge(staticAlignment) && + std::gcd(sourceAlignment.getZExtValue(), + staticAlignment.getZExtValue()) == + staticAlignment.getZExtValue()) { return true; // source alignment is >= our alignment } } @@ -395,6 +401,15 @@ static bool isAlignedTo(Value value, Value alignment) { return isAlignedTo(sourceAlignOp.getValue(), alignment); } + // Affine apply ops producing the value to be aligned usually include + // alignment already. + if (auto affineOp = value.getDefiningOp<affine::AffineApplyOp>()) { + if (hasStaticAlignment) { + return (affineOp.getAffineMap().getLargestKnownDivisorOfMapExprs() % + staticAlignment.getZExtValue()) == 0; + } + } + // If we are sourced from add/mul we peephole check to see if what is being // added is also aligned. This should be part of a larger pass doing IPO but // as the common case is that we align+add+align this is worth having in a @@ -414,7 +429,7 @@ static bool isAlignedTo(Value value, Value alignment) { } } else if (auto sourceMulOp = value.getDefiningOp<arith::MulIOp>()) { // Two aligned values multiplied together are still aligned. 
- if (isAlignedTo(sourceMulOp.getLhs(), alignment) && + if (isAlignedTo(sourceMulOp.getLhs(), alignment) || isAlignedTo(sourceMulOp.getRhs(), alignment)) { return true; } diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_folding.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_folding.mlir index 750746e13b17..3477dbe6e0bf 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_folding.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/alignment_folding.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file --canonicalize %s | iree-opt --split-input-file | FileCheck %s +// RUN: iree-opt --split-input-file --canonicalize --mlir-print-local-scope %s | iree-opt --split-input-file --mlir-print-local-scope | FileCheck %s // CHECK-LABEL: @foldSameAlignment // CHECK-SAME: (%[[VALUE:.+]]: index, %[[ALIGNMENT:.+]]: index) @@ -43,6 +43,21 @@ func.func @dontFoldLesserAlignment(%value: index) -> index { // ----- +// CHECK-LABEL: @dontFoldMixedAlignment +// CHECK-SAME: (%[[VALUE:.+]]: index) +func.func @dontFoldMixedAlignment(%value: index) -> index { + %c9 = arith.constant 9 : index + %c16 = arith.constant 16 : index + // CHECK: %[[ALIGN16:.+]] = util.align %[[VALUE]], %c16 + %0 = util.align %value, %c16 : index + // CHECK: %[[ALIGN9:.+]] = util.align %[[ALIGN16]], %c9 + %1 = util.align %0, %c9 : index + // CHECK: return %[[ALIGN9]] + return %1 : index +} + +// ----- + // CHECK-LABEL: @foldAlignmentRecursively // CHECK-SAME: (%[[VALUE:.+]]: index, %[[ALIGNMENT:.+]]: index) func.func @foldAlignmentRecursively(%value: index, %alignment: index) -> index { @@ -94,6 +109,21 @@ func.func @foldAddAlignmentConstant(%lhs: index) -> index { // ----- +// CHECK-LABEL: @foldMulAlignmentConstant +// CHECK-SAME: (%[[LHS:.+]]: index) +func.func @foldMulAlignmentConstant(%lhs: index) -> index { + %c64 = arith.constant 64 : index + %c2048 = arith.constant 2048 : index + // CHECK: %[[RESULT:.+]] = arith.muli %[[LHS]], %c2048 + 
%lhs_mul = arith.muli %lhs, %c2048 : index + // CHECK-NOT: util.align + %result = util.align %lhs_mul, %c64 : index + // CHECK: return %[[RESULT]] + return %result : index +} + +// ----- + // CHECK-LABEL: @foldConstantAlign func.func @foldConstantAlign() -> (index, index, index) { %c0 = arith.constant 0 : index @@ -110,6 +140,22 @@ func.func @foldConstantAlign() -> (index, index, index) { // ----- +// CHECK-LABEL: @foldAffineAlign +func.func @foldAffineAlign(%arg0: index) -> (index, index) { + // CHECK: %[[A0:.+]] = affine.apply affine_map<()[s0] -> (s0 * 16384)>()[%arg0] + %a0 = affine.apply affine_map<()[s0] -> (s0 * 16384)>()[%arg0] + %c64 = arith.constant 64 : index + %a1 = util.align %a0, %c64 : index + // CHECK: %[[B0:.+]] = affine.apply affine_map<()[s0] -> ((s0 * s0) * 4)>()[%arg0] + %b0 = affine.apply affine_map<()[s0] -> ((s0 * s0) * 4)>()[%arg0] + %c4 = arith.constant 4 : index + %b1 = util.align %b0, %c4 : index + // CHECK: return %[[A0]], %[[B0]] + return %a1, %b1 : index, index +} + +// ----- + // CHECK-LABEL: @sizeofWholeInt func.func @sizeofWholeInt() -> index { // CHECK: = arith.constant 4 : index