Skip to content

Commit

Permalink
[mlir][vector] Add ElementwiseToOuterproduct (llvm#93664)
Browse files Browse the repository at this point in the history
1D multi-reductions are lowered to arith ops, which can prevent some
optimisations. I propose `ElementwiseToOuterproduct`, a pattern that matches a
series of ops and rewrites them into a single `vector.outerproduct`.
It is exposed as part of `ElementwiseToVectorOpsPatterns`, a pattern set that
could later be extended to fold other elementwise ops into vector dialect ops.
Originally discussed at
https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/24.

Quoting @MacDue:
```
%lhsBcast = vector.broadcast %lhsCast : vector<[4]xf32> to vector<[4]x[4]xf32>
%lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
%rhsBcast = vector.broadcast %rhs : vector<[4]xf32> to vector<[4]x[4]xf32>
%mul = arith.mulf %lhsT, %rhsBcast : vector<[4]x[4]xf32>
```

Can be rewritten as:

```
%mul = vector.outerproduct %lhsCast, %rhs : vector<[4]xf32>, vector<[4]xf32>
```

---------

Co-authored-by: Han-Chung Wang <hanhan0912@gmail.com>
  • Loading branch information
nujaa and hanhanW authored Jun 21, 2024
1 parent 138ea7d commit 9f0aa05
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 0 deletions.
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
/// into vector contract for the backends with native support.
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns);

/// Collect a set of patterns that fold elementwise ops on vectors into ops of
/// the vector dialect. The implementation in this change registers
/// `FoldArithToVectorOuterProduct` for `arith.mulf` and `arith.muli`, which
/// rewrites `mul(transpose(broadcast(a)), broadcast(b))` into
/// `vector.outerproduct(a, b)`.
void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns);

/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,17 @@ def ApplyFoldArithExtensionPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}

// Transform dialect pattern-descriptor op, spelled
// `transform.apply_patterns.vector.elementwise_to_vector` in IR. Its
// `populatePatterns` implementation (declared through
// PatternDescriptorOpInterface) forwards to
// `vector::populateElementwiseToVectorOpsPatterns`.
def ApplyFoldElementwiseToVectorPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.elementwise_to_vector",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Collect a set of patterns that fold elementwise op on vectors to the vector
dialect.
}];

// The textual form consists only of the op's attribute dictionary.
let assemblyFormat = "attr-dict";
}

def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.reduction_to_contract",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
vector::populateFoldArithExtensionPatterns(patterns);
}

/// Adds the elementwise-to-vector rewrite patterns to `patterns` by forwarding
/// to `vector::populateElementwiseToVectorOpsPatterns`.
void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateElementwiseToVectorOpsPatterns(patterns);
}

void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorReductionToContractPatterns(patterns);
Expand Down
85 changes: 85 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1813,6 +1813,84 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
unsigned maxNumElementsToExtract = 0;
};

/// Fold `mul(transpose(broadcast(A)), broadcast(B))` into
/// `vector.outerproduct(A, B)`, where `mul` is the multiplication op
/// `MulOpType` this pattern is instantiated with.
///
/// Example:
///
///   %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
///   %lhsT = vector.transpose %lhsBcast, [1, 0]
///       : vector<4x4xi32> to vector<4x4xi32>
///   %rhsBcast = vector.broadcast %rhs : vector<4xi32> to vector<4x4xi32>
///   %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
///
/// Becomes:
///
///   %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
///
/// Supports only 1D-to-2D broadcasts. The following cases are not supported:
///
///   %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
///   %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
///   %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
template <typename MulOpType>
struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
  using OpRewritePattern<MulOpType>::OpRewritePattern;

  /// Returns true if `broadcastOp` is usable as an outerproduct operand, i.e.
  /// it broadcasts a 1-D vector without stretching any unit dimension.
  bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
    // Broadcasting unit dimensions (e.g. vector<1x4xf32> -> vector<4x4xf32>)
    // would require generating shape_casts/broadcasts, which does not belong
    // in this pattern.
    if (!broadcastOp.computeBroadcastedUnitDims().empty())
      return false;
    // Reject scalar sources (f32) and any vector source that is not 1-D
    // (e.g. the rank-0 vector<f32>): the rewrite feeds the broadcast source
    // directly to vector.outerproduct as a 1-D operand.
    auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    return srcType && srcType.getRank() == 1;
  }

  /// Matches `mulOp` against the pattern described on the struct and, on
  /// success, replaces it with a `vector.outerproduct`. Returns failure()
  /// without modifying the IR when the pattern does not apply.
  LogicalResult matchAndRewrite(MulOpType mulOp,
                                PatternRewriter &rewriter) const override {
    // `dyn_cast`, not `cast`: the multiplication may produce a scalar or a
    // tensor, in which case the pattern must simply not match rather than
    // assert.
    auto resType = dyn_cast<VectorType>(mulOp.getResult().getType());
    if (!resType || resType.getRank() != 2)
      return failure();

    /// If operandA can be written as tr(broadcast(A)) and operandB as
    /// broadcast(B) where broadcasts are 1D-to-2D, create and return
    /// vector.outerproduct(A, B). Returns failure() otherwise.
    auto matchOuterProduct =
        [&](Value operandA,
            Value operandB) -> FailureOr<vector::OuterProductOp> {
      auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
      if (!transposedLhs)
        return failure();
      // Fail unless this is a true 2-D matrix transpose.
      ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
      if (permutation.size() != 2 || permutation[0] != 1 ||
          permutation[1] != 0)
        return failure();

      auto broadcastedLhs =
          transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
        return failure();

      auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
        return failure();

      // The op is only created once every check has passed, so a failed match
      // leaves the IR untouched. No accumulator operand is supplied (null
      // Value).
      return rewriter.create<vector::OuterProductOp>(
          mulOp->getLoc(), resType, broadcastedLhs.getSource(),
          broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
    };

    Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
    auto maybeOuterP = matchOuterProduct(lhs, rhs);
    // Handle commutativity, the transposed op is the outerproduct LHS.
    if (failed(maybeOuterP))
      maybeOuterP = matchOuterProduct(rhs, lhs);
    if (failed(maybeOuterP))
      return failure();
    rewriter.replaceOp(mulOp, maybeOuterP->getResult());
    return success();
  }
};

} // namespace

void mlir::vector::populateFoldArithExtensionPatterns(
Expand Down Expand Up @@ -1900,6 +1978,13 @@ void mlir::vector::populateBreakDownVectorReductionPatterns(
maxNumElementsToExtract, benefit);
}

/// Registers the elementwise-to-vector folding patterns in `patterns`: one
/// `FoldArithToVectorOuterProduct` instance for `arith::MulFOp` and one for
/// `arith::MulIOp`, both created with the set's own context.
void mlir::vector::populateElementwiseToVectorOpsPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>>(context);
  patterns.add<FoldArithToVectorOuterProduct<arith::MulIOp>>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//
Expand Down
38 changes: 38 additions & 0 deletions mlir/test/Dialect/Vector/transform-vector.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,41 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----

// Integer multiply on scalable vectors. The first arith.muli operand is
// transpose(broadcast(%lhs)) and the second is broadcast(%rhs); the whole
// chain is expected to fold into one vector.outerproduct of %lhs and %rhs.
// CHECK-LABEL: func.func @arith_to_outerproduct_scalable_i32
// CHECK-SAME: %[[LHS:.*]]: vector<[4]xi32>,
// CHECK-SAME: %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
// CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32>
// CHECK: return %[[RES]] : vector<[4]x[4]xi32>
func.func @arith_to_outerproduct_scalable_i32(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
%lhsBcast = vector.broadcast %lhs : vector<[4]xi32> to vector<[4]x[4]xi32>
%lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
%rhsBcast = vector.broadcast %rhs : vector<[4]xi32> to vector<[4]x[4]xi32>
%mul = arith.muli %lhsT, %rhsBcast : vector<[4]x[4]xi32>
return %mul: vector<[4]x[4]xi32>
}

// Float multiply with fixed, non-square shapes where the transposed broadcast
// is the second operand of arith.mulf. The transposed value becomes the
// outerproduct LHS, so the expected operand order is (%rhs, %lhs).
// CHECK-LABEL: func.func @arith_to_outerproduct_trans_rhs_f32
// CHECK-SAME: %[[LHS:.*]]: vector<16xf32>,
// CHECK-SAME: %[[RHS:.*]]: vector<8xf32>) -> vector<8x16xf32> {
// CHECK: %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<8xf32>, vector<16xf32>
// CHECK: return %[[RES]] : vector<8x16xf32>
func.func @arith_to_outerproduct_trans_rhs_f32(%lhs: vector<16xf32>, %rhs: vector<8xf32>) -> vector<8x16xf32> {
%rhsBcast = vector.broadcast %rhs : vector<8xf32> to vector<16x8xf32>
%rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x8xf32> to vector<8x16xf32>
%lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<8x16xf32>
%mul = arith.mulf %lhsBcast, %rhsT : vector<8x16xf32>
return %mul: vector<8x16xf32>
}

// Transform script for the two functions above: matches every func.func in the
// payload module and applies the elementwise_to_vector pattern set to it.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.vector.elementwise_to_vector
} : !transform.any_op
transform.yield
}
}

0 comments on commit 9f0aa05

Please sign in to comment.