Skip to content

Commit

Permalink
Improving util.align folding. (#14805)
Browse files Browse the repository at this point in the history
  • Loading branch information
benvanik authored Aug 24, 2023
1 parent 8f62f9b commit b0c77fa
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 10 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ iree_compiler_cc_library(
":UtilOpsGen",
":UtilTypesGen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:CastInterfaces",
"@llvm-project//mlir:ControlFlowDialect",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Dialect/Util/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ iree_cc_library(
::UtilOpsGen
::UtilTypesGen
LLVMSupport
MLIRAffineDialect
MLIRArithDialect
MLIRCastInterfaces
MLIRControlFlowDialect
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ serializeGenericElementData(Location loc, DenseElementsAttr elementsAttr,
case 64:
return serializeGenericIntegerElements<uint64_t>(attr, endian, os);
default:
if (bitWidth != 1 && bitWidth < 64) {
if (bitWidth < 64) {
// Special case for bit-packing of sub-byte aligned types.
// This could be extended to handle larger widths (i33, etc) but they
// are rare today.
Expand Down
31 changes: 23 additions & 8 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <numeric>

#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
Expand Down Expand Up @@ -364,9 +367,11 @@ void RangeExtentsOp::getCanonicalizationPatterns(RewritePatternSet &results,
// a large majority of the cases we generate ourselves from packing/allocation.
static bool isAlignedTo(Value value, Value alignment) {
APInt staticValue;
bool hasStaticValue = matchPattern(value, m_ConstantInt(&staticValue));
APInt staticAlignment;
if (matchPattern(value, m_ConstantInt(&staticValue)) &&
matchPattern(alignment, m_ConstantInt(&staticAlignment))) {
bool hasStaticAlignment =
matchPattern(alignment, m_ConstantInt(&staticAlignment));
if (hasStaticValue && hasStaticAlignment) {
// If this value is itself a multiple of the alignment then we can fold.
if (staticValue.urem(staticAlignment).isZero()) {
return true; // value % alignment == 0
Expand All @@ -381,11 +386,12 @@ static bool isAlignedTo(Value value, Value alignment) {

// If the alignments are constant we can compare them inline.
APInt sourceAlignment;
APInt selfAlignment;
if (matchPattern(sourceAlignOp.getAlignment(),
m_ConstantInt(&sourceAlignment)) &&
matchPattern(alignment, m_ConstantInt(&selfAlignment))) {
if (sourceAlignment.uge(selfAlignment)) {
if (hasStaticAlignment && matchPattern(sourceAlignOp.getAlignment(),
m_ConstantInt(&sourceAlignment))) {
if (sourceAlignment.uge(staticAlignment) &&
std::gcd(sourceAlignment.getZExtValue(),
staticAlignment.getZExtValue()) ==
staticAlignment.getZExtValue()) {
return true; // source alignment is a multiple of our alignment
}
}
Expand All @@ -395,6 +401,15 @@ static bool isAlignedTo(Value value, Value alignment) {
return isAlignedTo(sourceAlignOp.getValue(), alignment);
}

// Affine apply ops carry a statically-known divisor for their result; if
// that divisor is a multiple of the alignment the value is already aligned.
if (auto affineOp = value.getDefiningOp<affine::AffineApplyOp>()) {
if (hasStaticAlignment) {
return (affineOp.getAffineMap().getLargestKnownDivisorOfMapExprs() %
staticAlignment.getZExtValue()) == 0;
}
}

// If we are sourced from add/mul we peephole check to see if what is being
// added is also aligned. This should be part of a larger pass doing IPO but
// as the common case is that we align+add+align this is worth having in a
Expand All @@ -414,7 +429,7 @@ static bool isAlignedTo(Value value, Value alignment) {
}
} else if (auto sourceMulOp = value.getDefiningOp<arith::MulIOp>()) {
// A product is aligned if either factor is aligned: a multiple of the
// alignment times any integer is still a multiple of the alignment.
if (isAlignedTo(sourceMulOp.getLhs(), alignment) &&
if (isAlignedTo(sourceMulOp.getLhs(), alignment) ||
isAlignedTo(sourceMulOp.getRhs(), alignment)) {
return true;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt --split-input-file --canonicalize %s | iree-opt --split-input-file | FileCheck %s
// RUN: iree-opt --split-input-file --canonicalize --mlir-print-local-scope %s | iree-opt --split-input-file --mlir-print-local-scope | FileCheck %s

// CHECK-LABEL: @foldSameAlignment
// CHECK-SAME: (%[[VALUE:.+]]: index, %[[ALIGNMENT:.+]]: index)
Expand Down Expand Up @@ -43,6 +43,21 @@ func.func @dontFoldLesserAlignment(%value: index) -> index {

// -----

// CHECK-LABEL: @dontFoldMixedAlignment
// CHECK-SAME: (%[[VALUE:.+]]: index)
func.func @dontFoldMixedAlignment(%value: index) -> index {
// 9 does not divide 16, so alignment to 16 implies nothing about alignment
// to 9; neither util.align may fold away.
%c9 = arith.constant 9 : index
%c16 = arith.constant 16 : index
// CHECK: %[[ALIGN16:.+]] = util.align %[[VALUE]], %c16
%0 = util.align %value, %c16 : index
// CHECK: %[[ALIGN9:.+]] = util.align %[[ALIGN16]], %c9
%1 = util.align %0, %c9 : index
// CHECK: return %[[ALIGN9]]
return %1 : index
}

// -----

// CHECK-LABEL: @foldAlignmentRecursively
// CHECK-SAME: (%[[VALUE:.+]]: index, %[[ALIGNMENT:.+]]: index)
func.func @foldAlignmentRecursively(%value: index, %alignment: index) -> index {
Expand Down Expand Up @@ -94,6 +109,21 @@ func.func @foldAddAlignmentConstant(%lhs: index) -> index {

// -----

// CHECK-LABEL: @foldMulAlignmentConstant
// CHECK-SAME: (%[[LHS:.+]]: index)
func.func @foldMulAlignmentConstant(%lhs: index) -> index {
// %lhs * 2048 is always a multiple of 64 (64 divides 2048), so the
// util.align is a no-op and folds to the multiply result.
%c64 = arith.constant 64 : index
%c2048 = arith.constant 2048 : index
// CHECK: %[[RESULT:.+]] = arith.muli %[[LHS]], %c2048
%lhs_mul = arith.muli %lhs, %c2048 : index
// CHECK-NOT: util.align
%result = util.align %lhs_mul, %c64 : index
// CHECK: return %[[RESULT]]
return %result : index
}

// -----

// CHECK-LABEL: @foldConstantAlign
func.func @foldConstantAlign() -> (index, index, index) {
%c0 = arith.constant 0 : index
Expand All @@ -110,6 +140,22 @@ func.func @foldConstantAlign() -> (index, index, index) {

// -----

// CHECK-LABEL: @foldAffineAlign
func.func @foldAffineAlign(%arg0: index) -> (index, index) {
// The affine map's largest known divisor (16384) is a multiple of the
// alignment (64), so the util.align folds to the affine.apply result.
// CHECK: %[[A0:.+]] = affine.apply affine_map<()[s0] -> (s0 * 16384)>()[%arg0]
%a0 = affine.apply affine_map<()[s0] -> (s0 * 16384)>()[%arg0]
%c64 = arith.constant 64 : index
%a1 = util.align %a0, %c64 : index
// Likewise (s0 * s0) * 4 has a known divisor of 4 matching the alignment.
// CHECK: %[[B0:.+]] = affine.apply affine_map<()[s0] -> ((s0 * s0) * 4)>()[%arg0]
%b0 = affine.apply affine_map<()[s0] -> ((s0 * s0) * 4)>()[%arg0]
%c4 = arith.constant 4 : index
%b1 = util.align %b0, %c4 : index
// CHECK: return %[[A0]], %[[B0]]
return %a1, %b1 : index, index
}

// -----

// CHECK-LABEL: @sizeofWholeInt
func.func @sizeofWholeInt() -> index {
// CHECK: = arith.constant 4 : index
Expand Down

0 comments on commit b0c77fa

Please sign in to comment.