Skip to content

Commit

Permalink
[aievec] Canonicalize transposed B contract op. (#1475)
Browse files Browse the repository at this point in the history
  • Loading branch information
jsetoain authored May 28, 2024
1 parent 555f014 commit 6338bc5
Show file tree
Hide file tree
Showing 2 changed files with 239 additions and 3 deletions.
161 changes: 158 additions & 3 deletions lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2023, Advanced Micro Devices, Inc.
// (c) Copyright 2023-2024 Advanced Micro Devices, Inc.
//
//===----------------------------------------------------------------------===//
// This file contains conversions and rewrites to the Vector dialect to make
Expand Down Expand Up @@ -39,6 +39,55 @@ using namespace xilinx::aievec;
//================== Common AIE canonicalization analysis ====================//
//============================================================================//

static bool isGemmBTransposedContractionOp(vector::ContractionOp op) {
if (op.getKind() != vector::CombiningKind::ADD)
return false;

// Get and check shape of operands
auto lhsShape = op.getLhsType().getShape();
auto rhsShape = op.getRhsType().getShape();
auto accShape = cast<ShapedType>(op.getAccType()).getShape();
if (lhsShape.size() < 2 || rhsShape.size() < 2 || accShape.size() < 2)
return false;

// Check that the innermost iterators match gemm-like iterators
SmallVector<vector::IteratorType> iterators = op.getIteratorTypesArray();
if (iterators.size() < 3)
return false;
auto innerMostIterators =
SmallVector<vector::IteratorType>(iterators.end() - 3, iterators.end());
if (vector::IteratorType::parallel != innerMostIterators[0] ||
vector::IteratorType::parallel != innerMostIterators[1] ||
vector::IteratorType::reduction != innerMostIterators[2])
return false;

// Get indexing maps of iterators for operands
SmallVector<AffineMap, 4> indexingMaps(op.getIndexingMapsArray());
SmallVector<int64_t> outerMostResults;
for (int64_t i = 0; i < indexingMaps[0].getNumResults() - 2; i++)
outerMostResults.push_back(i);

auto innerLhsMap = indexingMaps[0].dropResults(outerMostResults);
auto innerRhsMap = indexingMaps[1].dropResults(outerMostResults);
auto innerAccMap = indexingMaps[2].dropResults(outerMostResults);

// Check whether they conform to a "transposed B" gemm
auto ctx = op.getContext();
auto mmAidxMap =
AffineMap::getPermutationMap(ArrayRef<unsigned>{1, 0, 2}, ctx)
.dropResults(0);
auto mmBidxMap =
AffineMap::getPermutationMap(ArrayRef<unsigned>{0, 1, 2}, ctx)
.dropResults(0);
auto mmCidxMap =
AffineMap::getPermutationMap(ArrayRef<unsigned>{2, 0, 1}, ctx)
.dropResults(0);
int64_t numOuterMostDims = indexingMaps[0].getNumDims() - 3;
return innerLhsMap == mmAidxMap.shiftDims(numOuterMostDims) &&
innerRhsMap == mmBidxMap.shiftDims(numOuterMostDims) &&
innerAccMap == mmCidxMap.shiftDims(numOuterMostDims);
}

//============================================================================//
//============ Common AIE canonicalization conversion patterns ===============//
//============================================================================//
Expand Down Expand Up @@ -411,6 +460,107 @@ struct FlattenMultDimTransferWritePattern
}
};

// This pattern takes out an implicit transposition of the `rhs` operand in a
// gemm-like contraction op, making it an explicit `vector.transpose` op.
// If `rhs` is coming from a widening op (`extf`/`extsi`/`extui`), the
// transposition will be hoisted above the widening op.
struct ExtractTransposeFromContractionOp
: public OpConversionPattern<vector::ContractionOp> {
using OpConversionPattern<vector::ContractionOp>::OpConversionPattern;

static VectorType getTransposedVectorType(VectorType vecTy) {
SmallVector<int64_t> shape{vecTy.getShape()};
auto nDim = shape.size();
int64_t dimNm1 = shape[nDim - 1];
shape[nDim - 1] = shape[nDim - 2];
shape[nDim - 2] = dimNm1;
auto elemTy = vecTy.getElementType();
return VectorType::get(shape, elemTy);
}

LogicalResult
matchAndRewrite(vector::ContractionOp contractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!isGemmBTransposedContractionOp(contractOp))
return failure();

Location loc = contractOp.getLoc();
auto ctx = rewriter.getContext();

Value rhsVal = adaptor.getRhs();
VectorType rhsVecTy = contractOp.getRhsType();
Type rhsElemTy = rhsVecTy.getElementType();

bool doExtF = false, doExtSI = false, doExtUI = false;
if (auto extfRhsOp = rhsVal.getDefiningOp<arith::ExtFOp>()) {
rhsVal = extfRhsOp.getIn();
rhsVecTy = cast<VectorType>(rhsVal.getType());
doExtF = true;
} else if (auto extsiRhsOp = rhsVal.getDefiningOp<arith::ExtSIOp>()) {
rhsVal = extsiRhsOp.getIn();
rhsVecTy = cast<VectorType>(rhsVal.getType());
doExtSI = true;
} else if (auto extuiRhsOp = rhsVal.getDefiningOp<arith::ExtUIOp>()) {
rhsVal = extuiRhsOp.getIn();
rhsVecTy = cast<VectorType>(rhsVal.getType());
doExtUI = true;
}

int64_t nDim = rhsVecTy.getShape().size();
SmallVector<int64_t> rhsPermutation;
for (int64_t i = 0; i < nDim - 2; i++)
rhsPermutation.push_back(i);
rhsPermutation.push_back(nDim - 1);
rhsPermutation.push_back(nDim - 2);
auto transpRhsVecTy = getTransposedVectorType(rhsVecTy);
rhsVal = rewriter
.create<vector::TransposeOp>(loc, transpRhsVecTy, rhsVal,
rhsPermutation)
.getResult();

if (doExtF)
rhsVal =
rewriter
.create<arith::ExtFOp>(
loc, VectorType::get(transpRhsVecTy.getShape(), rhsElemTy),
rhsVal)
.getOut();
if (doExtSI)
rhsVal =
rewriter
.create<arith::ExtSIOp>(
loc, VectorType::get(transpRhsVecTy.getShape(), rhsElemTy),
rhsVal)
.getOut();
if (doExtUI)
rhsVal =
rewriter
.create<arith::ExtUIOp>(
loc, VectorType::get(transpRhsVecTy.getShape(), rhsElemTy),
rhsVal)
.getOut();

SmallVector<AffineMap, 4> oldIdxMaps(contractOp.getIndexingMapsArray());

nDim = oldIdxMaps[1].getNumDims();
SmallVector<int64_t> innerDimPerm;
for (int64_t i = 0; i < nDim - 2; i++)
innerDimPerm.push_back(i);
innerDimPerm.push_back(nDim - 1);
innerDimPerm.push_back(nDim - 2);
auto transpPermMap = AffineMap::getPermutationMap(innerDimPerm, ctx);

auto newIdxMaps = rewriter.getAffineMapArrayAttr(
{oldIdxMaps[0], oldIdxMaps[1].compose(transpPermMap), oldIdxMaps[2]});

rewriter.replaceOpWithNewOp<vector::ContractionOp>(
contractOp, contractOp.getResult().getType(), adaptor.getLhs(), rhsVal,
adaptor.getAcc(), newIdxMaps, contractOp.getIteratorTypes());

return success();
}
};

//============================================================================//
//============ AIEML canonicalization conversion patterns ===============//
//============================================================================//
Expand Down Expand Up @@ -470,15 +620,20 @@ static void configureAIEMLCanonicalizeLegalizations(ConversionTarget &target,
[](vector::TransferWriteOp op) {
return cast<VectorType>(op.getVector().getType()).getRank() < 2;
});
target.addDynamicallyLegalOp<vector::ContractionOp>(
[](vector::ContractionOp op) {
return !isGemmBTransposedContractionOp(op);
});
}

static void
populateAIEMLCanonicalizeConversionPatterns(RewritePatternSet &patterns,
TargetBackend backend) {
patterns.add<SplitUnalignedTransferReadPattern>(patterns.getContext(), 1024,
256);
patterns.add<FlattenMultDimTransferReadPattern,
FlattenMultDimTransferWritePattern>(patterns.getContext());
patterns
.add<ExtractTransposeFromContractionOp, FlattenMultDimTransferReadPattern,
FlattenMultDimTransferWritePattern>(patterns.getContext());
}

//============================================================================//
Expand Down
81 changes: 81 additions & 0 deletions test/dialect/AIEVec/precanonicalization-aieml.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,84 @@ func.func @multidim_vector_transfer(%in : memref<64x64x32x8xbf16>,
return
}

//
// -----
//

// CHECK: #[[IDXMAPA:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
// CHECK: #[[IDXMAPB:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
// CHECK: #[[IDXMAPC:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>

// CHECK-LABEL: func.func @vector_contract_permuted_b(
// CHECK-SAME: %[[VA:[a-zA-Z0-9]+]]: vector<1x1x4x8xbf16>,
// CHECK-SAME: %[[VB:[a-zA-Z0-9]+]]: vector<1x1x4x8xbf16>,
// CHECK-SAME: %[[VC:[a-zA-Z0-9]+]]: vector<1x1x4x4xf32>
func.func @vector_contract_permuted_b(%A : vector<1x1x4x8xbf16>,
%B : vector<1x1x4x8xbf16>,
%C : vector<1x1x4x4xf32>)
-> vector<1x1x4x4xf32> {
// CHECK: %[[TRB:.*]] = vector.transpose %[[VB]], [0, 1, 3, 2] :
// CHECK-SAME: vector<1x1x4x8xbf16> to vector<1x1x8x4xbf16>
// CHECK: %[[RES:.*]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[IDXMAPA]], #[[IDXMAPB]], #[[IDXMAPC]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction",
// CHECK-SAME: "parallel", "parallel", "reduction"],
// CHECK-SAME: kind = #vector.kind<add>}
// CHECK-SAME: %[[VA]], %[[TRB]], %[[VC]] :
// CHECK-SAME: vector<1x1x4x8xbf16>, vector<1x1x8x4xbf16>
// CHECK-SAME: into vector<1x1x4x4xf32>
%res = vector.contract {
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction",
"parallel", "parallel", "reduction"],
kind = #vector.kind<add>} %A, %B, %C :
vector<1x1x4x8xbf16>, vector<1x1x4x8xbf16> into vector<1x1x4x4xf32>
return %res : vector<1x1x4x4xf32>
}

//
// -----
//

// CHECK: #[[IDXMAPA:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
// CHECK: #[[IDXMAPB:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
// CHECK: #[[IDXMAPC:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>

// CHECK-LABEL: func.func @vector_contract_permuted_b(
// CHECK-SAME: %[[VA:[a-zA-Z0-9]+]]: vector<1x1x4x8xbf16>,
// CHECK-SAME: %[[VB:[a-zA-Z0-9]+]]: vector<1x1x4x8xbf16>,
// CHECK-SAME: %[[VC:[a-zA-Z0-9]+]]: vector<1x1x4x4xf32>
func.func @vector_contract_permuted_b(%A : vector<1x1x4x8xbf16>,
%B : vector<1x1x4x8xbf16>,
%C : vector<1x1x4x4xf32>)
-> vector<1x1x4x4xf32> {
// CHECK: %[[LHS:.*]] = arith.extf %[[VA]] :
// CHECK-SAME: vector<1x1x4x8xbf16> to vector<1x1x4x8xf32>
// CHECK: %[[TRB:.*]] = vector.transpose %[[VB]], [0, 1, 3, 2] :
// CHECK-SAME: vector<1x1x4x8xbf16> to vector<1x1x8x4xbf16>
// CHECK: %[[RHS:.*]] = arith.extf %[[TRB]] :
// CHECK-SAME: vector<1x1x8x4xbf16> to vector<1x1x8x4xf32>
// CHECK: %[[RES:.*]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[IDXMAPA]], #[[IDXMAPB]], #[[IDXMAPC]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction",
// CHECK-SAME: "parallel", "parallel", "reduction"],
// CHECK-SAME: kind = #vector.kind<add>}
// CHECK-SAME: %[[LHS]], %[[RHS]], %[[VC]] :
// CHECK-SAME: vector<1x1x4x8xf32>, vector<1x1x8x4xf32>
// CHECK-SAME: into vector<1x1x4x4xf32>
%lhs = arith.extf %A : vector<1x1x4x8xbf16> to vector<1x1x4x8xf32>
%rhs = arith.extf %B : vector<1x1x4x8xbf16> to vector<1x1x4x8xf32>
%res = vector.contract {
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction",
"parallel", "parallel", "reduction"],
kind = #vector.kind<add>} %lhs, %rhs, %C :
vector<1x1x4x8xf32>, vector<1x1x4x8xf32> into vector<1x1x4x4xf32>
return %res : vector<1x1x4x4xf32>
}

0 comments on commit 6338bc5

Please sign in to comment.