diff --git a/lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp b/lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp
index 42cf201824..0ab77bcde7 100644
--- a/lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp
+++ b/lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp
@@ -4,7 +4,7 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// (c) Copyright 2023, Advanced Micro Devices, Inc.
+// (c) Copyright 2023-2024 Advanced Micro Devices, Inc.
 //
 //===----------------------------------------------------------------------===//
 // This file contains conversions and rewrites to the Vector dialect to make
@@ -39,6 +39,55 @@ using namespace xilinx::aievec;
 //================== Common AIE canonicalization analysis ====================//
 //============================================================================//
 
+static bool isGemmBTransposedContractionOp(vector::ContractionOp op) {
+  if (op.getKind() != vector::CombiningKind::ADD)
+    return false;
+
+  // Get and check shape of operands
+  auto lhsShape = op.getLhsType().getShape();
+  auto rhsShape = op.getRhsType().getShape();
+  auto accShape = cast<ShapedType>(op.getAccType()).getShape();
+  if (lhsShape.size() < 2 || rhsShape.size() < 2 || accShape.size() < 2)
+    return false;
+
+  // Check that the innermost iterators match gemm-like iterators
+  SmallVector<vector::IteratorType> iterators = op.getIteratorTypesArray();
+  if (iterators.size() < 3)
+    return false;
+  auto innerMostIterators =
+      SmallVector<vector::IteratorType>(iterators.end() - 3, iterators.end());
+  if (vector::IteratorType::parallel != innerMostIterators[0] ||
+      vector::IteratorType::parallel != innerMostIterators[1] ||
+      vector::IteratorType::reduction != innerMostIterators[2])
+    return false;
+
+  // Get indexing maps of iterators for operands
+  SmallVector<AffineMap, 4> indexingMaps(op.getIndexingMapsArray());
+  SmallVector<int64_t> outerMostResults;
+  for (int64_t i = 0; i < indexingMaps[0].getNumResults() - 2; i++)
+    outerMostResults.push_back(i);
+
+  auto innerLhsMap = indexingMaps[0].dropResults(outerMostResults);
+  auto innerRhsMap = indexingMaps[1].dropResults(outerMostResults);
+  auto innerAccMap = indexingMaps[2].dropResults(outerMostResults);
+
+  // Check whether they conform to a "transposed B" gemm
+  auto ctx = op.getContext();
+  auto mmAidxMap =
+      AffineMap::getPermutationMap(ArrayRef<unsigned>{1, 0, 2}, ctx)
+          .dropResults(0);
+  auto mmBidxMap =
+      AffineMap::getPermutationMap(ArrayRef<unsigned>{0, 1, 2}, ctx)
+          .dropResults(0);
+  auto mmCidxMap =
+      AffineMap::getPermutationMap(ArrayRef<unsigned>{2, 0, 1}, ctx)
+          .dropResults(0);
+  int64_t numOuterMostDims = indexingMaps[0].getNumDims() - 3;
+  return innerLhsMap == mmAidxMap.shiftDims(numOuterMostDims) &&
+         innerRhsMap == mmBidxMap.shiftDims(numOuterMostDims) &&
+         innerAccMap == mmCidxMap.shiftDims(numOuterMostDims);
+}
+
 //============================================================================//
 //============ Common AIE canonicalization conversion patterns ===============//
 //============================================================================//
@@ -411,6 +460,107 @@ struct FlattenMultDimTransferWritePattern
   }
 };
 
+// This pattern takes out an implicit transposition of the `rhs` operand in a
+// gemm-like contraction op, making it an explicit `vector.transpose` op.
+// If `rhs` is coming from a widening op (`extf`/`extsi`/`extui`), the
+// transposition will be hoisted above the widening op.
+struct ExtractTransposeFromContractionOp
+    : public OpConversionPattern<vector::ContractionOp> {
+  using OpConversionPattern<vector::ContractionOp>::OpConversionPattern;
+
+  static VectorType getTransposedVectorType(VectorType vecTy) {
+    SmallVector<int64_t> shape{vecTy.getShape()};
+    auto nDim = shape.size();
+    int64_t dimNm1 = shape[nDim - 1];
+    shape[nDim - 1] = shape[nDim - 2];
+    shape[nDim - 2] = dimNm1;
+    auto elemTy = vecTy.getElementType();
+    return VectorType::get(shape, elemTy);
+  }
+
+  LogicalResult
+  matchAndRewrite(vector::ContractionOp contractOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!isGemmBTransposedContractionOp(contractOp))
+      return failure();
+
+    Location loc = contractOp.getLoc();
+    auto ctx = rewriter.getContext();
+
+    Value rhsVal = adaptor.getRhs();
+    VectorType rhsVecTy = contractOp.getRhsType();
+    Type rhsElemTy = rhsVecTy.getElementType();
+
+    bool doExtF = false, doExtSI = false, doExtUI = false;
+    if (auto extfRhsOp = rhsVal.getDefiningOp<arith::ExtFOp>()) {
+      rhsVal = extfRhsOp.getIn();
+      rhsVecTy = cast<VectorType>(rhsVal.getType());
+      doExtF = true;
+    } else if (auto extsiRhsOp = rhsVal.getDefiningOp<arith::ExtSIOp>()) {
+      rhsVal = extsiRhsOp.getIn();
+      rhsVecTy = cast<VectorType>(rhsVal.getType());
+      doExtSI = true;
+    } else if (auto extuiRhsOp = rhsVal.getDefiningOp<arith::ExtUIOp>()) {
+      rhsVal = extuiRhsOp.getIn();
+      rhsVecTy = cast<VectorType>(rhsVal.getType());
+      doExtUI = true;
+    }
+
+    int64_t nDim = rhsVecTy.getShape().size();
+    SmallVector<int64_t> rhsPermutation;
+    for (int64_t i = 0; i < nDim - 2; i++)
+      rhsPermutation.push_back(i);
+    rhsPermutation.push_back(nDim - 1);
+    rhsPermutation.push_back(nDim - 2);
+    auto transpRhsVecTy = getTransposedVectorType(rhsVecTy);
+    rhsVal = rewriter
+                 .create<vector::TransposeOp>(loc, transpRhsVecTy, rhsVal,
+                                              rhsPermutation)
+                 .getResult();
+
+    if (doExtF)
+      rhsVal =
+          rewriter
+              .create<arith::ExtFOp>(
+                  loc, VectorType::get(transpRhsVecTy.getShape(), rhsElemTy),
+                  rhsVal)
+              .getOut();
+    if (doExtSI)
+      rhsVal =
+          rewriter
+              .create<arith::ExtSIOp>(
+                  loc, VectorType::get(transpRhsVecTy.getShape(), rhsElemTy),
+                  rhsVal)
+              .getOut();
+    if (doExtUI)
+      rhsVal =
+          rewriter
+              .create<arith::ExtUIOp>(
+                  loc, VectorType::get(transpRhsVecTy.getShape(), rhsElemTy),
+                  rhsVal)
+              .getOut();
+
+    SmallVector<AffineMap, 4> oldIdxMaps(contractOp.getIndexingMapsArray());
+
+    nDim = oldIdxMaps[1].getNumDims();
+    SmallVector<unsigned> innerDimPerm;
+    for (int64_t i = 0; i < nDim - 2; i++)
+      innerDimPerm.push_back(i);
+    innerDimPerm.push_back(nDim - 1);
+    innerDimPerm.push_back(nDim - 2);
+    auto transpPermMap = AffineMap::getPermutationMap(innerDimPerm, ctx);
+
+    auto newIdxMaps = rewriter.getAffineMapArrayAttr(
+        {oldIdxMaps[0], oldIdxMaps[1].compose(transpPermMap), oldIdxMaps[2]});
+
+    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+        contractOp, contractOp.getResult().getType(), adaptor.getLhs(), rhsVal,
+        adaptor.getAcc(), newIdxMaps, contractOp.getIteratorTypes());
+
+    return success();
+  }
+};
+
 //============================================================================//
 //============ AIEML canonicalization conversion patterns ===============//
 //============================================================================//
@@ -470,6 +620,10 @@ static void configureAIEMLCanonicalizeLegalizations(ConversionTarget &target,
       [](vector::TransferWriteOp op) {
        return cast<VectorType>(op.getVector().getType()).getRank() < 2;
      });
+  target.addDynamicallyLegalOp<vector::ContractionOp>(
+      [](vector::ContractionOp op) {
+        return !isGemmBTransposedContractionOp(op);
+      });
 }
 
 static void
@@ -477,8 +631,9 @@
 populateAIEMLCanonicalizeConversionPatterns(RewritePatternSet &patterns,
                                             TargetBackend backend) {
   patterns.add<SplitUnalignedTransferReadPattern>(patterns.getContext(), 1024, 256);
-  patterns.add<FlattenMultDimTransferReadPattern,
-               FlattenMultDimTransferWritePattern>(patterns.getContext());
+  patterns
+      .add<ExtractTransposeFromContractionOp, FlattenMultDimTransferReadPattern,
+           FlattenMultDimTransferWritePattern>(patterns.getContext());
 }
 
 //============================================================================//
diff --git a/test/dialect/AIEVec/precanonicalization-aieml.mlir b/test/dialect/AIEVec/precanonicalization-aieml.mlir
index e4e1004a0c..14b557b49b 100644
--- a/test/dialect/AIEVec/precanonicalization-aieml.mlir
+++ b/test/dialect/AIEVec/precanonicalization-aieml.mlir
@@ -85,3 +85,84 @@ func.func @multidim_vector_transfer(%in : memref<64x64x32x8xbf16>,
 
   return
 }
+//
+// -----
+//
+
+// CHECK: #[[IDXMAPA:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
+// CHECK: #[[IDXMAPB:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
+// CHECK: #[[IDXMAPC:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
+
+// CHECK-LABEL: func.func @vector_contract_permuted_b(
+// CHECK-SAME: %[[VA:[a-zA-Z0-9]+]]: vector<1x1x4x8xbf16>,
+// CHECK-SAME: %[[VB:[a-zA-Z0-9]+]]: vector<1x1x4x8xbf16>,
+// CHECK-SAME: %[[VC:[a-zA-Z0-9]+]]: vector<1x1x4x4xf32>
+func.func @vector_contract_permuted_b(%A : vector<1x1x4x8xbf16>,
+                                      %B : vector<1x1x4x8xbf16>,
+                                      %C : vector<1x1x4x4xf32>)
+    -> vector<1x1x4x4xf32> {
+  // CHECK: %[[TRB:.*]] = vector.transpose %[[VB]], [0, 1, 3, 2] :
+  // CHECK-SAME: vector<1x1x4x8xbf16> to vector<1x1x8x4xbf16>
+  // CHECK: %[[RES:.*]] = vector.contract {
+  // CHECK-SAME: indexing_maps = [#[[IDXMAPA]], #[[IDXMAPB]], #[[IDXMAPC]]],
+  // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction",
+  // CHECK-SAME:                   "parallel", "parallel", "reduction"],
+  // CHECK-SAME: kind = #vector.kind<add>}
+  // CHECK-SAME: %[[VA]], %[[TRB]], %[[VC]] :
+  // CHECK-SAME: vector<1x1x4x8xbf16>, vector<1x1x8x4xbf16>
+  // CHECK-SAME: into vector<1x1x4x4xf32>
+  %res = vector.contract {
+      indexing_maps = [#map1, #map2, #map3],
+      iterator_types = ["parallel", "parallel", "reduction",
+                        "parallel", "parallel", "reduction"],
+      kind = #vector.kind<add>} %A, %B, %C :
+      vector<1x1x4x8xbf16>, vector<1x1x4x8xbf16> into vector<1x1x4x4xf32>
+  return %res : vector<1x1x4x4xf32>
+}
+
+//
+// -----
+//
+
+// CHECK: #[[IDXMAPA:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
+// CHECK: #[[IDXMAPB:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
+// CHECK: #[[IDXMAPC:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
+
+// CHECK-LABEL: func.func @vector_contract_permuted_b(
+// CHECK-SAME: %[[VA:[a-zA-Z0-9]+]]: vector<1x1x4x8xbf16>,
+// CHECK-SAME: %[[VB:[a-zA-Z0-9]+]]: vector<1x1x4x8xbf16>,
+// CHECK-SAME: %[[VC:[a-zA-Z0-9]+]]: vector<1x1x4x4xf32>
+func.func @vector_contract_permuted_b(%A : vector<1x1x4x8xbf16>,
+                                      %B : vector<1x1x4x8xbf16>,
+                                      %C : vector<1x1x4x4xf32>)
+    -> vector<1x1x4x4xf32> {
+  // CHECK: %[[LHS:.*]] = arith.extf %[[VA]] :
+  // CHECK-SAME: vector<1x1x4x8xbf16> to vector<1x1x4x8xf32>
+  // CHECK: %[[TRB:.*]] = vector.transpose %[[VB]], [0, 1, 3, 2] :
+  // CHECK-SAME: vector<1x1x4x8xbf16> to vector<1x1x8x4xbf16>
+  // CHECK: %[[RHS:.*]] = arith.extf %[[TRB]] :
+  // CHECK-SAME: vector<1x1x8x4xbf16> to vector<1x1x8x4xf32>
+  // CHECK: %[[RES:.*]] = vector.contract {
+  // CHECK-SAME: indexing_maps = [#[[IDXMAPA]], #[[IDXMAPB]], #[[IDXMAPC]]],
+  // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction",
+  // CHECK-SAME:                   "parallel", "parallel", "reduction"],
+  // CHECK-SAME: kind = #vector.kind<add>}
+  // CHECK-SAME: %[[LHS]], %[[RHS]], %[[VC]] :
+  // CHECK-SAME: vector<1x1x4x8xf32>, vector<1x1x8x4xf32>
+  // CHECK-SAME: into vector<1x1x4x4xf32>
+  %lhs = arith.extf %A : vector<1x1x4x8xbf16> to vector<1x1x4x8xf32>
+  %rhs = arith.extf %B : vector<1x1x4x8xbf16> to vector<1x1x4x8xf32>
+  %res = vector.contract {
+      indexing_maps = [#map1, #map2, #map3],
+      iterator_types = ["parallel", "parallel", "reduction",
+                        "parallel", "parallel", "reduction"],
+      kind = #vector.kind<add>} %lhs, %rhs, %C :
+      vector<1x1x4x8xf32>, vector<1x1x4x8xf32> into vector<1x1x4x4xf32>
+  return %res : vector<1x1x4x4xf32>
+}
\ No newline at end of file