From 9f0aa05bfb40c077a5b1c2ea8cac88fdd51f0c5c Mon Sep 17 00:00:00 2001 From: Hugo Trachino Date: Fri, 21 Jun 2024 13:34:37 +0100 Subject: [PATCH] [mlir][vector] Add ElementwiseToOuterproduct (#93664) 1D multi-reduction are lowered to arith which can prevent some optimisations. I propose `ElementwiseToOuterproduct` matching a series of ops to generate `vector.outerproduct`. As part of some `ElementwiseToVectorOpsPatterns`, it could allow to fuse other elementwiseOps to vector dialect. Originally discussed https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/24. quote @MacDue ``` %lhsBcast = vector.broadcast %lhsCast : vector<[4]xf32> to vector<[4]x[4]xf32> %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32> %rhsBcast = vector.broadcast %rhs : vector<[4]xf32> to vector<[4]x[4]xf32> %mul = arith.mulf %lhsT, %rhsBcast : vector<[4]x[4]xf32> ``` Can be rewritten as: ``` %mul = vector.outerproduct $lhs, $rhs : vector<[4]xf32>, vector<[4]xf32> ``` --------- Co-authored-by: Han-Chung Wang --- .../mlir/Dialect/Vector/IR/VectorOps.h | 4 + .../Vector/TransformOps/VectorTransformOps.td | 11 +++ .../TransformOps/VectorTransformOps.cpp | 5 ++ .../Vector/Transforms/VectorTransforms.cpp | 85 +++++++++++++++++++ .../test/Dialect/Vector/transform-vector.mlir | 38 +++++++++ 5 files changed, 143 insertions(+) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h index 4603953cb40fa5..ac55433fadb2f4 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -80,6 +80,10 @@ void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, /// into vector contract for the backends with native support. void populateFoldArithExtensionPatterns(RewritePatternSet &patterns); +/// Collect a set of patterns that fold elementwise op on vectors to the vector +/// dialect. +void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns); + /// Returns the integer type required for subscripts in the vector dialect. IntegerType getVectorSubscriptType(Builder &builder); diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index c91e8fbbae90f2..820a18731ffdb0 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -406,6 +406,17 @@ def ApplyFoldArithExtensionPatternsOp : Op]> { + let description = [{ + Collect a set of patterns that fold elementwise op on vectors to the vector + dialect. + }]; + + let assemblyFormat = "attr-dict"; +} + def ApplyVectorReductionToContractPatternsOp : Op]> { diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 23960269095e50..2e9aa88011825b 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -59,6 +59,11 @@ void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns( vector::populateFoldArithExtensionPatterns(patterns); } +void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::populateElementwiseToVectorOpsPatterns(patterns); +} + void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::populateVectorReductionToContractPatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index b824508728ac8a..eac6db585aad78 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1813,6 +1813,84 @@ struct BreakDownVectorReduction final : OpRewritePattern { unsigned maxNumElementsToExtract = 0; }; +/// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A, +/// B)`. +/// Example: +/// %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32> +/// %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to +/// vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to +/// vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32> +/// +/// Becomes : +/// +/// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32> +/// +/// Supports only 1D-to-2D broadcasts. The following cases are not supported. +/// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32> +/// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32> +/// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32> +template +struct FoldArithToVectorOuterProduct : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + // Returns whether a vector.broadcast matches requirements for an outerproduct + // pattern. aka a 1D-to-2D broadcastOp without broadcasted unit dimension. + bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const { + // Fail if it is not a 1-to-2 dimension to broadcast to avoid generating + // shape_casts/broadcasts which does not belong in this pattern. + if (!broadcastOp.computeBroadcastedUnitDims().empty()) + return false; + // Avoid broadcast like f32 or vector -> ResType + auto srcType = dyn_cast(broadcastOp.getSourceType()); + return srcType && srcType.getRank() != 2; + } + + LogicalResult matchAndRewrite(MulOpType mulOp, + PatternRewriter &rewriter) const override { + auto resType = llvm::cast(mulOp.getResult().getType()); + if (!resType) + return failure(); + if (resType.getRank() != 2) + return failure(); + /// If operandA can be written as tr(broadcast(A)) and operandB as + /// broadcast(B) where broadcasts are 1D-to-2D, create and return + /// vector.outerproduct(A, B). Returns failure() otherwise. + auto matchOuterProduct = + [&](Value operandA, + Value operandB) -> FailureOr { + auto transposedLhs = operandA.getDefiningOp(); + if (!transposedLhs) + return failure(); + // Fail unless this is a true 2-D matrix transpose. + ArrayRef permutation = transposedLhs.getPermutation(); + if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0) + return failure(); + + auto broadcastedLhs = + transposedLhs.getVector().getDefiningOp(); + if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs)) + return failure(); + + auto broadcastedRhs = operandB.getDefiningOp(); + if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs)) + return failure(); + + return rewriter.create( + mulOp->getLoc(), resType, broadcastedLhs.getSource(), + broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD); + }; + + Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1); + auto maybeOuterP = matchOuterProduct(lhs, rhs); + // Handle commutativity, the transposed op is the outerproduct LHS. + if (failed(maybeOuterP)) + maybeOuterP = matchOuterProduct(rhs, lhs); + if (failed(maybeOuterP)) + return failure(); + rewriter.replaceOp(mulOp, maybeOuterP->getResult()); + return success(); + } +}; + } // namespace void mlir::vector::populateFoldArithExtensionPatterns( @@ -1900,6 +1978,13 @@ void mlir::vector::populateBreakDownVectorReductionPatterns( maxNumElementsToExtract, benefit); } +void mlir::vector::populateElementwiseToVectorOpsPatterns( + RewritePatternSet &patterns) { + patterns.add, + FoldArithToVectorOuterProduct>( + patterns.getContext()); +} + //===----------------------------------------------------------------------===// // TableGen'd enum attribute definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir index 75b29e22b4d2ce..4b38db79bff3e1 100644 --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -92,3 +92,41 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +// CHECK-LABEL: func.func @arith_to_outerproduct_scalable_i32 +// CHECK-SAME: %[[LHS:.*]]: vector<[4]xi32>, +// CHECK-SAME: %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> { +// CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32> +// CHECK: return %[[RES]] : vector<[4]x[4]xi32> +func.func @arith_to_outerproduct_scalable_i32(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> { + %lhsBcast = vector.broadcast %lhs : vector<[4]xi32> to vector<[4]x[4]xi32> + %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32> + %rhsBcast = vector.broadcast %rhs : vector<[4]xi32> to vector<[4]x[4]xi32> + %mul = arith.muli %lhsT, %rhsBcast : vector<[4]x[4]xi32> + return %mul: vector<[4]x[4]xi32> +} + +// CHECK-LABEL: func.func @arith_to_outerproduct_trans_rhs_f32 +// CHECK-SAME: %[[LHS:.*]]: vector<16xf32>, +// CHECK-SAME: %[[RHS:.*]]: vector<8xf32>) -> vector<8x16xf32> { +// CHECK: %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<8xf32>, vector<16xf32> +// CHECK: return %[[RES]] : vector<8x16xf32> +func.func @arith_to_outerproduct_trans_rhs_f32(%lhs: vector<16xf32>, %rhs: vector<8xf32>) -> vector<8x16xf32> { + %rhsBcast = vector.broadcast %rhs : vector<8xf32> to vector<16x8xf32> + %rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x8xf32> to vector<8x16xf32> + %lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<8x16xf32> + %mul = arith.mulf %lhsBcast, %rhsT : vector<8x16xf32> + return %mul: vector<8x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.vector.elementwise_to_vector + } : !transform.any_op + transform.yield + } +}