From 15006418ceb03023e8887cba87e93b499f669ad7 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 16 Oct 2024 18:22:56 -0700 Subject: [PATCH] Various tweaks to numeric optimizations found while looking at programs. (#18765) * Expands affine.apply ops at the program level. These get introduced from various places that should possibly be eliminated at this level. For now expanding them gets the job done, making the program transformation friendly again. * Fixes a multi-use issue in i64->index promotion. * Adds a pattern to fold trunc of an index cast. * Uses a conservative limit to bound all dynamic dims at the torch level, even when coming to us as unbounded. * Implements analysis interfaces on util.align. --------- Signed-off-by: Stella Laurenzo --- .../InputConversion/BindSymbolicShapes.cpp | 24 ++--- .../test/bind_symbolic_shapes.mlir | 6 +- .../iree/compiler/Dialect/Util/IR/UtilOps.cpp | 68 +++++++++++--- .../iree/compiler/Dialect/Util/IR/UtilOps.td | 11 ++- .../Dialect/Util/Transforms/BUILD.bazel | 2 + .../Dialect/Util/Transforms/CMakeLists.txt | 2 + .../Util/Transforms/OptimizeIntArithmetic.cpp | 86 ++++++++++++++++-- .../Transforms/test/integer_divisibility.mlir | 11 +++ .../test/optimize_int_arithmetic.mlir | 91 +++++++++++++++++++ 9 files changed, 263 insertions(+), 38 deletions(-) diff --git a/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp b/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp index b1c90d2f6cef..f8e13f2fa0e8 100644 --- a/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp +++ b/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp @@ -30,6 +30,14 @@ namespace mlir::iree_compiler::TorchInput { namespace { +// We aribtrarily say that unbounded dimensions in a torch program cannot +// exceed 53bits, making the maximum safe dimension 9007199254740991. The +// astute reader will note that this is also the maximum safe value in +// JavaScript, which also "happens" to be the largest mantissa value in a +// 64bit double. We need a maximum and in the absence of a better choice, +// with this one we are at least in good company. +static constexpr uint64_t MAX_DIM_VALUE = (static_cast(1) << 53) - 1; + // Torch "binds" symbolic shape information to all tensors in the program // which are not static. It does this by emitting side-effecting // torch.bind_symbolic_shape ops which are backed by torch.symbolic_int ops @@ -95,15 +103,9 @@ class BindSymbolicShapesPass final auto maxVal = symbolDefOp.getMaxValAttr(); if (minVal && maxVal) { uint64_t minValInt = minVal.getValue().getZExtValue(); - uint64_t maxValInt = maxVal.getValue().getZExtValue(); - // Note that torch represents open ranges in strange ways with various - // magic numbers in the high range of the uint64_t type. We somewhat - // arbitrarily say that anything over a fourth of the uint64_t - // range (which is half of the positive int64_t range, should these have - // originated as signed quantities), is a ridiculously large number not - // suitable as a shape dimension, and we drop the hint. - if (maxValInt >= minValInt && - maxValInt < std::numeric_limits::max() / 4) { + uint64_t maxValInt = + std::min(maxVal.getValue().getZExtValue(), MAX_DIM_VALUE); + if (maxValInt >= minValInt) { // Note that in Torch, min values are "weird" because they encode // some special cases about broadcast behavior. Here we just discard // them, but in the future, there may be more to derive here. @@ -220,8 +222,8 @@ class BindSymbolicShapesPass final for (auto [pos, symbolValue] : llvm::enumerate(symbols)) { const SymbolInfo &symbolInfo = symbolInfos.at(symbolValue); if (!symbolInfo.minMaxBounds) { - lowerBounds.push_back({}); - upperBounds.push_back({}); + lowerBounds.push_back(1); + upperBounds.push_back(MAX_DIM_VALUE); } else { lowerBounds.push_back(symbolInfo.minMaxBounds->first); upperBounds.push_back(symbolInfo.minMaxBounds->second); diff --git a/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir b/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir index 8f78bf7dcb0c..699b6dbf6d60 100644 --- a/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir +++ b/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir @@ -153,14 +153,16 @@ module @unsupported_non_symbolic { // ----- // Torch uses high values to signal unbounded ranges. Ensure they are -// suppressed. +// clamped. // CHECK-LABEL: @torch_unbounded_max_range module @torch_unbounded_max_range { func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) { - // CHECK-NOT: util.assume.int + // CHECK: util.assume.int {{.*}} torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32> + // CHECK: util.assume.int {{.*}} torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 10)> : !torch.vtensor<[?,?],f32> return } diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp index fd1df8e0f2ac..3de051be39dd 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp @@ -1101,6 +1101,40 @@ void printShapedFunctionSignature(OpAsmPrinter &p, Operation *op, namespace mlir::iree_compiler::IREE::Util { +//===----------------------------------------------------------------------===// +// util.align +//===----------------------------------------------------------------------===// + +void AlignOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + auto constantAlignment = argRanges[1].getConstantValue(); + // Note that for non constant alignment, there may still be something we + // want to infer, but this is left for the future. + if (constantAlignment) { + // We can align the range directly. + // (value + (alignment - 1)) & ~(alignment - 1) + // https://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding + APInt umin = argRanges[0].umin(); + APInt umax = argRanges[0].umax(); + APInt one(constantAlignment->getBitWidth(), 1); + APInt alignmentM1 = *constantAlignment - one; + APInt alignmentM1Inv = ~alignmentM1; + auto align = [&](APInt value) -> APInt { + return (value + alignmentM1) & alignmentM1Inv; + }; + setResultRange(getResult(), + ConstantIntRanges::fromUnsigned(align(umin), align(umax))); + } +} + +void AlignOp::inferResultDivisibility(ArrayRef argDivs, + SetIntDivisibilityFn setResultDivs) { + auto alignmentDiv = argDivs[1]; + if (alignmentDiv.isUninitialized()) + return; + setResultDivs(getResult(), alignmentDiv.getValue()); +} + //===----------------------------------------------------------------------===// // util.assume.int //===----------------------------------------------------------------------===// @@ -1120,39 +1154,45 @@ AssumeIntOp::getOperandAssumptions(unsigned operandIndex) { std::pair, std::optional> AssumeIntOp::getUnionedUnsignedRange(unsigned operandIndex) { auto assumptions = getOperandAssumptions(operandIndex); - std::optional uminUnion; - std::optional umaxUnion; + uint64_t uminUnion = std::numeric_limits::max(); + int uminCount = 0; + uint64_t umaxUnion = std::numeric_limits::min(); + int umaxCount = 0; for (auto assumption : assumptions) { auto umin = assumption.getUmin(); auto umax = assumption.getUmax(); if (umin) { uminUnion = std::min( - *umin, uminUnion ? *uminUnion : std::numeric_limits::max()); + *umin, uminUnion ? uminUnion : std::numeric_limits::max()); + uminCount += 1; } if (umax) { umaxUnion = std::max( - *umax, umaxUnion ? *umaxUnion : std::numeric_limits::min()); + *umax, umaxUnion ? umaxUnion : std::numeric_limits::min()); + umaxCount += 1; } } - return std::make_pair(uminUnion, umaxUnion); + return std::make_pair(uminCount > 0 && uminCount == assumptions.size() + ? std::optional(uminUnion) + : std::nullopt, + umaxCount > 0 && umaxCount == assumptions.size() + ? std::optional(umaxUnion) + : std::nullopt); } -// Gets the unioned divisor for an operand. If there are multiple divisor -// assumptions, the gcd of all of them is returned. If there are no -// divisor assumptions, std::nullopt is returned. std::optional AssumeIntOp::getUnionedUnsignedDivisor(unsigned operandIndex) { auto assumptions = getOperandAssumptions(operandIndex); std::optional divisorUnion; for (auto assumption : assumptions) { auto divisor = assumption.getUdiv(); - if (divisor) { - if (divisorUnion) - divisorUnion = std::gcd(*divisor, *divisorUnion); - else - divisorUnion = *divisor; - } + if (!divisor) + return std::nullopt; + if (divisorUnion) + divisorUnion = std::gcd(*divisor, *divisorUnion); + else + divisorUnion = *divisor; } return divisorUnion; } diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td index 648e4f4dad58..aaa10da27005 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td @@ -336,7 +336,9 @@ def OpGroupAddressOffsetArithmeticOps : OpDocGroup { let opDocGroup = OpGroupAddressOffsetArithmeticOps in { def Util_AlignOp : Util_PureOp<"align", [ - SameOperandsAndResultType + SameOperandsAndResultType, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods ]> { let summary = "Aligns up to a power-of-two alignment if required"; let description = [{ @@ -504,14 +506,15 @@ def Util_AssumeIntOp : Util_PureOp<"assume.int", [ // Gets the unioned unsigned range for an operand. If there are multiple // assumptions for the operand, this will return the bounding range for - // them all. If there is no umin/umax, then std::nullopt will be returned - // for that position. + // them all. If there is no umin/umax for any row in the set, then + // std::nullopt will be returned for that position. std::pair, std::optional> getUnionedUnsignedRange(unsigned operandIndex); // Gets the unioned divisor for an operand. If there are multiple divisor // assumptions, the gcd of all of them is returned. If there are no - // divisor assumptions, std::nullopt is returned. + // divisor assumptions or if there is not a udiv for any row, std::nullopt + // is returned. std::optional getUnionedUnsignedDivisor(unsigned operandIndex); }]; diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel index a98135d57bdb..da3c44f08cb1 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel @@ -54,6 +54,8 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AffineTransforms", + "@llvm-project//mlir:AffineUtils", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithTransforms", diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt index d233f11e0278..a8542c40ae64 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt @@ -43,6 +43,8 @@ iree_cc_library( ::PassesIncGen LLVMSupport MLIRAffineDialect + MLIRAffineTransforms + MLIRAffineUtils MLIRAnalysis MLIRArithDialect MLIRArithTransforms diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp index 79e61a174f4c..022beaac439b 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp @@ -12,6 +12,9 @@ #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" #include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/Transforms/Transforms.h" +#include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/IR/Matchers.h" @@ -108,19 +111,35 @@ struct ConvertOpToUnsigned : public OpRewritePattern { // optimizations, it can be useful to eliminate them when possible. //===----------------------------------------------------------------------===// +// Matches IR like: +// %5 = arith.addi %0, %1 : int64 +// %6 = arith.index_castui %5 : int64 to index +// +// And moves the index_castui to the producer's operands: +// %3 = arith.index_castui %0 : int64 to index +// %4 = arith.index_castui %1 : int64 to index +// %5 = arith.addi %3, %4 : index +// struct ConvertUnsignedI64IndexCastProducerToIndex : public OpRewritePattern { ConvertUnsignedI64IndexCastProducerToIndex(MLIRContext *context, DataFlowSolver &solver) : OpRewritePattern(context), solver(solver) {} - LogicalResult matchAndRewrite(arith::IndexCastUIOp op, + LogicalResult matchAndRewrite(arith::IndexCastUIOp origIndexOp, PatternRewriter &rewriter) const override { - Type inType = op.getIn().getType(); - Type outType = op.getOut().getType(); + Type inType = origIndexOp.getIn().getType(); + Type outType = origIndexOp.getOut().getType(); if (!inType.isSignlessInteger(64) && isa(outType)) return failure(); + Operation *producer = origIndexOp.getIn().getDefiningOp(); + if (!producer) + return failure(); + auto producerResult = producer->getResult(0); + if (!producerResult.hasOneUse()) + return failure(); + auto pred = [&](Value v) -> bool { auto *result = solver.lookupState(v); if (!result || result->getValue().isUninitialized()) { @@ -137,7 +156,6 @@ struct ConvertUnsignedI64IndexCastProducerToIndex llvm::all_of(op->getResults(), pred); }; - Operation *producer = op.getIn().getDefiningOp(); if (!isa_and_present(producer)) @@ -145,6 +163,7 @@ struct ConvertUnsignedI64IndexCastProducerToIndex if (!isOpStaticallyLegal(producer)) return failure(); + // Make modifications. rewriter.modifyOpInPlace(producer, [&]() { rewriter.setInsertionPoint(producer); for (auto &operand : producer->getOpOperands()) { @@ -156,6 +175,8 @@ struct ConvertUnsignedI64IndexCastProducerToIndex } producer->getResult(0).setType(outType); }); + origIndexOp.getOut().replaceAllUsesWith(producer->getResult(0)); + rewriter.eraseOp(origIndexOp); return success(); } @@ -206,6 +227,52 @@ struct RemUIDivisibilityByConstant : public OpRewritePattern { DataFlowSolver &solver; }; +//===----------------------------------------------------------------------===// +// Affine expansion +// affine.apply expansion can fail after producing a lot of IR. Since this is +// a bad thing to be doing as part of our overall iteration, we do it as a +// preprocessing walk. This also lets it be well behaved with respect to +// error messaging, etc. We will likely replace this with a more integrated +// version at some point which can use the bounds analysis to avoid corners +// of the original. +//===----------------------------------------------------------------------===// + +void expandAffineOps(Operation *rootOp) { + IRRewriter rewriter(rootOp->getContext()); + rootOp->walk([&](affine::AffineApplyOp op) { + LLVM_DEBUG(dbgs() << "** Expand affine.apply: " << op << "\n"); + rewriter.setInsertionPoint(op); + auto maybeExpanded = + mlir::affine::expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), + llvm::to_vector<4>(op.getOperands())); + if (!maybeExpanded) { + LLVM_DEBUG(dbgs() << "** ERROR: Failed to expand affine.apply\n"); + return; + } + rewriter.replaceOp(op, *maybeExpanded); + }); +} + +//===----------------------------------------------------------------------===// +// General optimization patterns +//===----------------------------------------------------------------------===// + +struct ElideTruncOfIndexCast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::TruncIOp truncOp, + PatternRewriter &rewriter) const override { + Operation *producer = truncOp.getOperand().getDefiningOp(); + if (!producer) + return failure(); + if (!isa(producer)) + return failure(); + rewriter.replaceOpWithNewOp( + truncOp, truncOp.getResult().getType(), producer->getOperand(0)); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Pass setup //===----------------------------------------------------------------------===// @@ -270,6 +337,9 @@ class OptimizeIntArithmeticPass void runOnOperation() override { Operation *op = getOperation(); MLIRContext *ctx = op->getContext(); + + expandAffineOps(op); + DataFlowSolver solver; solver.load(); solver.load(); @@ -281,13 +351,15 @@ class OptimizeIntArithmeticPass arith::populateIntRangeOptimizationsPatterns(patterns, solver); // Populate canonicalization patterns. - auto arithDialectTypeID = - ctx->getOrLoadDialect()->getTypeID(); + auto arithDialect = ctx->getOrLoadDialect(); for (const RegisteredOperationName &name : ctx->getRegisteredOperations()) { - if (name.getDialect().getTypeID() == arithDialectTypeID) + if (&name.getDialect() == arithDialect) name.getCanonicalizationPatterns(patterns, ctx); } + // General optimization patterns. + patterns.add(ctx); + // Populate unsigned conversion patterns. patterns.add, diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir index e8c2740bb31f..3e2235b7f8f4 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir @@ -34,3 +34,14 @@ util.func @remui_div_by_unrelated(%arg0 : index) -> index { %1 = arith.remui %0, %cst : index util.return %1 : index } + +// ----- +// A missing udiv in a multi-row assumption is treated as an unknown. +// CHECK-LABEL: @missing_udiv_skipped +util.func @missing_udiv_skipped(%arg0 : index) -> index { + // CHECK: arith.remui + %cst = arith.constant 16 : index + %0 = util.assume.int %arg0[, <>] : index + %1 = arith.remui %0, %cst : index + util.return %1 : index +} diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir index 4726941c42db..41b304a89c1f 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir @@ -27,6 +27,30 @@ util.func @index_lower_bound(%arg0 : index) -> i1 { util.return %1 : i1 } +// ----- +// If there is a missing umax in a multi-row assumption, then it must +// be treated as having no known upper bound. +// CHECK-LABEL: @missing_umax_skipped +util.func @missing_umax_skipped(%arg0 : index) -> i1 { + // CHECK: arith.cmpi + %cst = arith.constant 101 : index + %0 = util.assume.int %arg0[, ] : index + %1 = arith.cmpi ult, %0, %cst : index + util.return %1 : i1 +} + +// ----- +// If there is a missing umin in a multi-row assumption, then it must +// be treated as having no known lower bound. +// CHECK-LABEL: @missing_umin_skipped +util.func @missing_umin_skipped(%arg0 : index) -> i1 { + // CHECK: arith.cmpi + %cst = arith.constant 5 : index + %0 = util.assume.int %arg0[, ] : index + %1 = arith.cmpi ugt, %0, %cst : index + util.return %1 : i1 +} + // ----- // CHECK-LABEL: @index_indeterminate util.func @index_indeterminate(%arg0 : index) -> i1 { @@ -246,6 +270,20 @@ util.func @index_cast_i64_to_index_addi(%arg0 : index, %arg1 : index) -> index { util.return %3 : index } +// ----- +// Multi-use should not convert +// CHECK-LABEL: @index_cast_i64_to_index_addi_multiuse +util.func @index_cast_i64_to_index_addi_multiuse(%arg0 : index, %arg1 : index) -> index, i64 { + // CHECK: %[[ASSUME:.*]] = util.assume.int + %0 = util.assume.int %arg0 : index + // CHECK: arith.index_cast + // CHECK: arith.index_cast + %1 = arith.index_cast %0 : index to i64 + %2 = arith.addi %1, %1 : i64 + %3 = arith.index_cast %2 : i64 to index + util.return %3, %2 : index, i64 +} + // ----- // CHECK-LABEL: @index_cast_i64_to_index_ceildivsi util.func @index_cast_i64_to_index_ceildivsi(%arg0 : index, %arg1 : index) -> index { @@ -371,3 +409,56 @@ util.func @index_cast_i64_to_index_remsi(%arg0 : index, %arg1 : index) -> index %3 = arith.index_cast %2 : i64 to index util.return %3 : index } + +// ----- +// Truncate of an index cast can be folded into the index cast. +// CHECK-LABEL: @elide_trunc_of_index_castui +util.func @elide_trunc_of_index_castui(%arg0 : index) -> i32 { + %1 = arith.index_castui %arg0 : index to i64 + %2 = arith.trunci %1 : i64 to i32 + // CHECK: %[[RESULT:.*]] = arith.index_castui %arg0 : index to i32 + // CHECH: util.return %[[RESULT]] + util.return %2 : i32 +} + +// ----- +// CHECK-LABEL: @elide_trunc_of_index_cast +util.func @elide_trunc_of_index_cast(%arg0 : index) -> i32 { + %1 = arith.index_cast %arg0 : index to i64 + %2 = arith.trunci %1 : i64 to i32 + // CHECK: %[[RESULT:.*]] = arith.index_castui %arg0 : index to i32 + // CHECH: util.return %[[RESULT]] + util.return %2 : i32 +} + +// ----- +// CHECK-LABEL: @util_align_bounds_div +util.func @util_align_bounds_div(%arg0 : index, %arg1 : index) -> index, index, index, i1, i1 { + %0 = util.assume.int %arg0 : index + %1 = util.assume.int %arg1 : index + // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 + // CHECK-DAG: %[[FALSE:.*]] = arith.constant false + // CHECK-DAG: %[[TRUE:.*]] = arith.constant true + // CHECK-DAG: %[[C64:.*]] = arith.constant 64 + // CHECK-DAG: %[[ASSUME:.*]] = util.assume.int %arg0 + // CHECK: %[[ALIGN:.*]] = util.align %[[ASSUME]], %[[C64]] + %2 = util.align %0, %1 : index + + // The result should be >= 64 and <= 128. + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %lower = arith.cmpi uge, %2, %c64 : index // True + %upper = arith.cmpi ule, %2, %c128 : index // True + %under = arith.cmpi ult, %2, %c64 : index // False + %over = arith.cmpi ugt, %2, %c128 : index // False + %in_bounds = arith.andi %lower, %upper : i1 // True + %out_bounds = arith.andi %under, %over : i1 // False + + // And 64 should evenly divide it. + %rem64 = arith.remui %2, %c64 : index + // But 128 should not. + // CHECK: %[[REM128:.*]] = arith.remui + %rem128 = arith.remui %2, %c128 : index + // CHECK: util.return %[[ALIGN]], %[[ZERO]], %[[REM128]], %[[TRUE]], %[[FALSE]] + util.return %2, %rem64, %rem128, %in_bounds, %out_bounds : index, index, index, i1, i1 +}