Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[aievec] Canonicalize transposed B contract op. #1475

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 158 additions & 3 deletions lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2023, Advanced Micro Devices, Inc.
// (c) Copyright 2023-2024 Advanced Micro Devices, Inc.
//
//===----------------------------------------------------------------------===//
// This file contains conversions and rewrites to the Vector dialect to make
Expand Down Expand Up @@ -39,6 +39,55 @@ using namespace xilinx::aievec;
//================== Common AIE canonicalization analysis ====================//
//============================================================================//

static bool isGemmBTransposedContractionOp(vector::ContractionOp op) {
if (op.getKind() != vector::CombiningKind::ADD)
return false;

// Get and check shape of operands
auto lhsShape = op.getLhsType().getShape();
auto rhsShape = op.getRhsType().getShape();
auto accShape = cast<ShapedType>(op.getAccType()).getShape();
if (lhsShape.size() < 2 || rhsShape.size() < 2 || accShape.size() < 2)
return false;

// Check that the innermost iterators match gemm-like iterators
SmallVector<vector::IteratorType> iterators = op.getIteratorTypesArray();
if (iterators.size() < 3)
return false;
auto innerMostIterators =
SmallVector<vector::IteratorType>(iterators.end() - 3, iterators.end());
if (vector::IteratorType::parallel != innerMostIterators[0] ||
vector::IteratorType::parallel != innerMostIterators[1] ||
vector::IteratorType::reduction != innerMostIterators[2])
return false;

// Get indexing maps of iterators for operands
SmallVector<AffineMap, 4> indexingMaps(op.getIndexingMapsArray());
SmallVector<int64_t> outerMostResults;
for (int64_t i = 0; i < indexingMaps[0].getNumResults() - 2; i++)
outerMostResults.push_back(i);

auto innerLhsMap = indexingMaps[0].dropResults(outerMostResults);
auto innerRhsMap = indexingMaps[1].dropResults(outerMostResults);
auto innerAccMap = indexingMaps[2].dropResults(outerMostResults);

// Check whether they conform to a "transposed B" gemm
auto ctx = op.getContext();
auto mmAidxMap =
AffineMap::getPermutationMap(ArrayRef<unsigned>{1, 0, 2}, ctx)
.dropResults(0);
auto mmBidxMap =
AffineMap::getPermutationMap(ArrayRef<unsigned>{0, 1, 2}, ctx)
.dropResults(0);
auto mmCidxMap =
AffineMap::getPermutationMap(ArrayRef<unsigned>{2, 0, 1}, ctx)
.dropResults(0);
int64_t numOuterMostDims = indexingMaps[0].getNumDims() - 3;
return innerLhsMap == mmAidxMap.shiftDims(numOuterMostDims) &&
innerRhsMap == mmBidxMap.shiftDims(numOuterMostDims) &&
innerAccMap == mmCidxMap.shiftDims(numOuterMostDims);
}

//============================================================================//
//============ Common AIE canonicalization conversion patterns ===============//
//============================================================================//
Expand Down Expand Up @@ -411,6 +460,107 @@ struct FlattenMultDimTransferWritePattern
}
};

// This pattern takes out an implicit transposition of the `rhs` operand in a
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest to clarify the comment as follows:
an implicit transposition of the rhs operand
=>
an implicit transposition of the 2 innermost dimensions of the rhs operand

// gemm-like contraction op, making it an explicit `vector.transpose` op.
// If `rhs` is coming from a widening op (`extf`/`extsi`/`extui`), the
// transposition will be hoisted above the widening op.
struct ExtractTransposeFromContractionOp
: public OpConversionPattern<vector::ContractionOp> {
using OpConversionPattern<vector::ContractionOp>::OpConversionPattern;

static VectorType getTransposedVectorType(VectorType vecTy) {
SmallVector<int64_t> shape{vecTy.getShape()};
auto nDim = shape.size();
int64_t dimNm1 = shape[nDim - 1];
shape[nDim - 1] = shape[nDim - 2];
shape[nDim - 2] = dimNm1;
auto elemTy = vecTy.getElementType();
return VectorType::get(shape, elemTy);
}

LogicalResult
matchAndRewrite(vector::ContractionOp contractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!isGemmBTransposedContractionOp(contractOp))
return failure();

Location loc = contractOp.getLoc();
auto ctx = rewriter.getContext();

Value rhsVal = adaptor.getRhs();
VectorType rhsVecTy = contractOp.getRhsType();
Type rhsElemTy = rhsVecTy.getElementType();

bool doExtF = false, doExtSI = false, doExtUI = false;
if (auto extfRhsOp = rhsVal.getDefiningOp<arith::ExtFOp>()) {
rhsVal = extfRhsOp.getIn();
rhsVecTy = cast<VectorType>(rhsVal.getType());
doExtF = true;
} else if (auto extsiRhsOp = rhsVal.getDefiningOp<arith::ExtSIOp>()) {
rhsVal = extsiRhsOp.getIn();
rhsVecTy = cast<VectorType>(rhsVal.getType());
doExtSI = true;
} else if (auto extuiRhsOp = rhsVal.getDefiningOp<arith::ExtUIOp>()) {
rhsVal = extuiRhsOp.getIn();
rhsVecTy = cast<VectorType>(rhsVal.getType());
doExtUI = true;
}

int64_t nDim = rhsVecTy.getShape().size();
SmallVector<int64_t> rhsPermutation;
for (int64_t i = 0; i < nDim - 2; i++)
rhsPermutation.push_back(i);
rhsPermutation.push_back(nDim - 1);
rhsPermutation.push_back(nDim - 2);
auto transpRhsVecTy = getTransposedVectorType(rhsVecTy);
rhsVal = rewriter
.create<vector::TransposeOp>(loc, transpRhsVecTy, rhsVal,
rhsPermutation)
.getResult();

if (doExtF)
rhsVal =
rewriter
.create<arith::ExtFOp>(
loc, VectorType::get(transpRhsVecTy.getShape(), rhsElemTy),
rhsVal)
.getOut();
if (doExtSI)
rhsVal =
rewriter
.create<arith::ExtSIOp>(
loc, VectorType::get(transpRhsVecTy.getShape(), rhsElemTy),
rhsVal)
.getOut();
if (doExtUI)
rhsVal =
rewriter
.create<arith::ExtUIOp>(
loc, VectorType::get(transpRhsVecTy.getShape(), rhsElemTy),
rhsVal)
.getOut();

SmallVector<AffineMap, 4> oldIdxMaps(contractOp.getIndexingMapsArray());

nDim = oldIdxMaps[1].getNumDims();
SmallVector<int64_t> innerDimPerm;
for (int64_t i = 0; i < nDim - 2; i++)
innerDimPerm.push_back(i);
innerDimPerm.push_back(nDim - 1);
innerDimPerm.push_back(nDim - 2);
auto transpPermMap = AffineMap::getPermutationMap(innerDimPerm, ctx);

auto newIdxMaps = rewriter.getAffineMapArrayAttr(
{oldIdxMaps[0], oldIdxMaps[1].compose(transpPermMap), oldIdxMaps[2]});

rewriter.replaceOpWithNewOp<vector::ContractionOp>(
contractOp, contractOp.getResult().getType(), adaptor.getLhs(), rhsVal,
adaptor.getAcc(), newIdxMaps, contractOp.getIteratorTypes());

return success();
}
};

//============================================================================//
//============ AIEML canonicalization conversion patterns ===============//
//============================================================================//
Expand Down Expand Up @@ -470,15 +620,20 @@ static void configureAIEMLCanonicalizeLegalizations(ConversionTarget &target,
[](vector::TransferWriteOp op) {
return cast<VectorType>(op.getVector().getType()).getRank() < 2;
});
target.addDynamicallyLegalOp<vector::ContractionOp>(
[](vector::ContractionOp op) {
return !isGemmBTransposedContractionOp(op);
});
}

static void
populateAIEMLCanonicalizeConversionPatterns(RewritePatternSet &patterns,
TargetBackend backend) {
patterns.add<SplitUnalignedTransferReadPattern>(patterns.getContext(), 1024,
256);
patterns.add<FlattenMultDimTransferReadPattern,
FlattenMultDimTransferWritePattern>(patterns.getContext());
patterns
.add<ExtractTransposeFromContractionOp, FlattenMultDimTransferReadPattern,
FlattenMultDimTransferWritePattern>(patterns.getContext());
}

//============================================================================//
Expand Down
81 changes: 81 additions & 0 deletions test/dialect/AIEVec/precanonicalization-aieml.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,84 @@ func.func @multidim_vector_transfer(%in : memref<64x64x32x8xbf16>,
return
}

//
// -----
//

// CHECK: #[[IDXMAPA:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
// CHECK: #[[IDXMAPB:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
// CHECK: #[[IDXMAPC:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>

// CHECK-LABEL: func.func @vector_contract_permuted_b(
// CHECK-SAME: %[[VA:[a-zA-Z0-9]+]]: vector<1x1x4x8xbf16>,
// CHECK-SAME: %[[VB:[a-zA-Z0-9]+]]: vector<1x1x4x8xbf16>,
// CHECK-SAME: %[[VC:[a-zA-Z0-9]+]]: vector<1x1x4x4xf32>
func.func @vector_contract_permuted_b(%A : vector<1x1x4x8xbf16>,
%B : vector<1x1x4x8xbf16>,
%C : vector<1x1x4x4xf32>)
-> vector<1x1x4x4xf32> {
// CHECK: %[[TRB:.*]] = vector.transpose %[[VB]], [0, 1, 3, 2] :
// CHECK-SAME: vector<1x1x4x8xbf16> to vector<1x1x8x4xbf16>
// CHECK: %[[RES:.*]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[IDXMAPA]], #[[IDXMAPB]], #[[IDXMAPC]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction",
// CHECK-SAME: "parallel", "parallel", "reduction"],
// CHECK-SAME: kind = #vector.kind<add>}
// CHECK-SAME: %[[VA]], %[[TRB]], %[[VC]] :
// CHECK-SAME: vector<1x1x4x8xbf16>, vector<1x1x8x4xbf16>
// CHECK-SAME: into vector<1x1x4x4xf32>
%res = vector.contract {
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction",
"parallel", "parallel", "reduction"],
kind = #vector.kind<add>} %A, %B, %C :
vector<1x1x4x8xbf16>, vector<1x1x4x8xbf16> into vector<1x1x4x4xf32>
return %res : vector<1x1x4x4xf32>
}

//
// -----
//

// CHECK: #[[IDXMAPA:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
// CHECK: #[[IDXMAPB:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
// CHECK: #[[IDXMAPC:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>

// CHECK-LABEL: func.func @vector_contract_permuted_b(
// CHECK-SAME: %[[VA:[a-zA-Z0-9]+]]: vector<1x1x4x8xbf16>,
// CHECK-SAME: %[[VB:[a-zA-Z0-9]+]]: vector<1x1x4x8xbf16>,
// CHECK-SAME: %[[VC:[a-zA-Z0-9]+]]: vector<1x1x4x4xf32>
func.func @vector_contract_permuted_b(%A : vector<1x1x4x8xbf16>,
%B : vector<1x1x4x8xbf16>,
%C : vector<1x1x4x4xf32>)
-> vector<1x1x4x4xf32> {
// CHECK: %[[LHS:.*]] = arith.extf %[[VA]] :
// CHECK-SAME: vector<1x1x4x8xbf16> to vector<1x1x4x8xf32>
// CHECK: %[[TRB:.*]] = vector.transpose %[[VB]], [0, 1, 3, 2] :
// CHECK-SAME: vector<1x1x4x8xbf16> to vector<1x1x8x4xbf16>
// CHECK: %[[RHS:.*]] = arith.extf %[[TRB]] :
// CHECK-SAME: vector<1x1x8x4xbf16> to vector<1x1x8x4xf32>
// CHECK: %[[RES:.*]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[IDXMAPA]], #[[IDXMAPB]], #[[IDXMAPC]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction",
// CHECK-SAME: "parallel", "parallel", "reduction"],
// CHECK-SAME: kind = #vector.kind<add>}
// CHECK-SAME: %[[LHS]], %[[RHS]], %[[VC]] :
// CHECK-SAME: vector<1x1x4x8xf32>, vector<1x1x8x4xf32>
// CHECK-SAME: into vector<1x1x4x4xf32>
%lhs = arith.extf %A : vector<1x1x4x8xbf16> to vector<1x1x4x8xf32>
%rhs = arith.extf %B : vector<1x1x4x8xbf16> to vector<1x1x4x8xf32>
%res = vector.contract {
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction",
"parallel", "parallel", "reduction"],
kind = #vector.kind<add>} %lhs, %rhs, %C :
vector<1x1x4x8xf32>, vector<1x1x4x8xf32> into vector<1x1x4x4xf32>
return %res : vector<1x1x4x4xf32>
}
Loading