diff --git a/lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp b/lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp
index ed633af851..fe9ba98f31 100644
--- a/lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp
+++ b/lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp
@@ -47,6 +47,49 @@ using namespace xilinx::aievec;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
+static bool isNarrowingOp(Operation *op) {
+  if (isa<arith::TruncFOp>(op) || isa<arith::TruncIOp>(op))
+    return true;
+
+  if (auto srsOp = dyn_cast<aievec::SRSOp>(op)) {
+    auto srsOpSrcOp = srsOp.getSource().getDefiningOp();
+    if (isa<aievec::UPSOp>(srsOpSrcOp) || isa<aievec::CastOp>(srsOpSrcOp))
+      return true;
+  }
+  return false;
+}
+
+// Given a Value, if it is defined by a widening op (arith::ExtSIOp,
+// arith::ExtUIOp, arith::ExtFOp, aievec::UPSOp + aievec::SRSOp,
+// aievec::UPSOp + aievec::CastOp), return the source of the widening op.
+static std::optional<Value> getSourceOfWideningOp(Value src) {
+  if (auto extSIOp = src.getDefiningOp<arith::ExtSIOp>())
+    return extSIOp.getIn();
+  if (auto extUIOp = src.getDefiningOp<arith::ExtUIOp>())
+    return extUIOp.getIn();
+  if (auto extFOp = src.getDefiningOp<arith::ExtFOp>())
+    return extFOp.getIn();
+  if (auto srsOp = src.getDefiningOp<aievec::SRSOp>()) {
+    // Conversion through AIE intrinsics takes two steps:
+    //   1) Load to accumulator: aievec.ups
+    //   2) Move from accumulator: aievec.srs
+    auto srsSource = srsOp.getSource();
+    if (srsSource)
+      if (auto upsOp = srsSource.getDefiningOp<aievec::UPSOp>())
+        return upsOp.getSource();
+  }
+  if (auto castOp = src.getDefiningOp<aievec::CastOp>()) {
+    // Conversion through AIE intrinsics can also take the following two steps:
+    //   1) Load to accumulator: aievec.ups
+    //   2) Move from accumulator: aievec.cast
+    auto castSource = castOp.getSource();
+    if (castSource)
+      if (auto upsOp = castSource.getDefiningOp<aievec::UPSOp>())
+        return upsOp.getSource();
+  }
+  return std::optional<Value>();
+}
+
 // Given the LHS and RHS of an `arith::AddIOp`, if one of them is defined by an
 // `arith::MulIOp`, return a tuple with the `lhs`, `rhs`, and `acc` of the MAC
 // operation that can replace them.
@@ -619,12 +662,8 @@ struct ConvertMulFToAIEVecMulElemOpPattern
     // Decide the accType for aievec.mul_elem based on mulOp's lhs & rhs
     auto lval = adaptor.getLhs();
     auto rval = adaptor.getRhs();
-    if (auto lvalExtOp = lval.getDefiningOp<arith::ExtFOp>()) {
-      lval = lvalExtOp->getOperand(0);
-    }
-    if (auto rvalExtOp = rval.getDefiningOp<arith::ExtFOp>()) {
-      rval = rvalExtOp->getOperand(0);
-    }
+    lval = getSourceOfWideningOp(lval).value_or(lval);
+    rval = getSourceOfWideningOp(rval).value_or(rval);
     auto lSrcType = cast<VectorType>(lval.getType());
     auto rSrcType = cast<VectorType>(rval.getType());
     unsigned lBitWidth = lSrcType.getElementType().getIntOrFloatBitWidth();
@@ -713,12 +752,8 @@ struct ConvertMulIToAIEVecMulElemOpPattern
     // Decide the accType for aievec.mul_elem based on mulOp's lhs & rhs
     auto lval = adaptor.getLhs();
     auto rval = adaptor.getRhs();
-    if (auto lvalExtOp = lval.getDefiningOp<arith::ExtSIOp>()) {
-      lval = lvalExtOp->getOperand(0);
-    }
-    if (auto rvalExtOp = rval.getDefiningOp<arith::ExtSIOp>()) {
-      rval = rvalExtOp->getOperand(0);
-    }
+    lval = getSourceOfWideningOp(lval).value_or(lval);
+    rval = getSourceOfWideningOp(rval).value_or(rval);
     auto lSrcType = cast<VectorType>(lval.getType());
     auto rSrcType = cast<VectorType>(rval.getType());
     unsigned lBitWidth = lSrcType.getElementType().getIntOrFloatBitWidth();
@@ -1054,8 +1089,8 @@ struct LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp
 
       // If element width is 32, we need to consider sign extension cases
       if (resultElWidth == 32) {
-        auto lhsExt = lhsDefOp ? dyn_cast<arith::ExtSIOp>(lhsDefOp) : nullptr;
-        auto rhsExt = rhsDefOp ? dyn_cast<arith::ExtSIOp>(rhsDefOp) : nullptr;
+        auto lhsExt = getSourceOfWideningOp(lhs).value_or(nullptr);
+        auto rhsExt = getSourceOfWideningOp(rhs).value_or(nullptr);
 
         if (!lhsExt && !rhsExt) {
           if (laneSize * resultElWidth == 512) {
@@ -1068,8 +1103,8 @@ struct LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp
         }
 
         if (lhsExt && rhsExt) {
-          auto lval = lhsExt->getOperand(0);
-          auto rval = rhsExt->getOperand(0);
+          auto lval = lhsExt;
+          auto rval = rhsExt;
          VectorType lSrcType = cast<VectorType>(lval.getType());
 
           Type accType = getVectorOpDestType(lSrcType, /*AIEML =*/true);
@@ -1086,8 +1121,8 @@ struct LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp
         }
 
         if (!lhsExt || !rhsExt) {
-          auto lval = lhsExt ? lhsExt->getOperand(0) : lhs;
-          auto rval = rhsExt ? rhsExt->getOperand(0) : rhs;
+          auto lval = lhsExt ? lhsExt : lhs;
+          auto rval = rhsExt ? rhsExt : rhs;
           auto extVal = lhsExt ? lval : rval;
           VectorType vType = cast<VectorType>(extVal.getType());
           unsigned bitWidth = vType.getElementType().getIntOrFloatBitWidth();
@@ -1160,8 +1195,8 @@ struct LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp
                                               resultType, srcOp);
       }
 
-      auto lhsExt = lhsDefOp ? dyn_cast<arith::ExtFOp>(lhsDefOp) : nullptr;
-      auto rhsExt = rhsDefOp ? dyn_cast<arith::ExtFOp>(rhsDefOp) : nullptr;
+      auto lhsExt = getSourceOfWideningOp(lhs).value_or(nullptr);
+      auto rhsExt = getSourceOfWideningOp(rhs).value_or(nullptr);
       // v16float
       if (!lhsExt && !rhsExt) {
         return genAddElemAieML(rewriter, lhs, rhs,
@@ -1170,8 +1205,8 @@ struct LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp
 
       // v16bf16 with two extension ops
       if (lhsExt && rhsExt) {
-        auto lval = lhsExt->getOperand(0);
-        auto rval = rhsExt->getOperand(0);
+        auto lval = lhsExt;
+        auto rval = rhsExt;
         VectorType vType = cast<VectorType>(lval.getType());
 
         Type accType = getVectorOpDestType(vType, /*AIEML =*/true);
@@ -1189,8 +1224,8 @@ struct LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp
 
       // v16bf16 with one extension op
       if (!lhsExt || !rhsExt) {
-        auto lval = lhsExt ? lhsExt->getOperand(0) : lhs;
-        auto rval = rhsExt ? rhsExt->getOperand(0) : rhs;
+        auto lval = lhsExt ? lhsExt : lhs;
+        auto rval = rhsExt ? rhsExt : rhs;
         auto extVal = lhsExt ? lval : rval;
         VectorType vType = cast<VectorType>(extVal.getType());
         Type accType = getVectorOpDestType(vType, /*AIEML =*/true);
@@ -1819,7 +1854,7 @@ struct ComputeInvOpByLUTPattern : OpConversionPattern<arith::DivFOp> {
         !isa<FloatType>(srcType))
       return failure();
 
-    if (!isa<arith::TruncFOp>(*divOp->getUsers().begin()))
+    if (!isNarrowingOp(*divOp->getUsers().begin()))
       return failure();
 
     auto fType = cast<FloatType>(srcType);
@@ -2735,26 +2770,20 @@ struct LowerVectorContractionOpToAIEVecMatMulPattern
       // contractions, lower precisions operands are cast to the target
       // precission outside the contraction. For those cases, we check.
       lhs = adaptor.getLhs();
+      auto wideLhsValue = getSourceOfWideningOp(lhs).value_or(nullptr);
+      if (wideLhsValue)
+        lhs = reshapeLeadingUnitDims(rewriter, wideLhsValue);
+
       rhs = adaptor.getRhs();
-      if (auto lhsExtSIOp = lhs.getDefiningOp<arith::ExtSIOp>())
-        lhs = reshapeLeadingUnitDims(rewriter, lhsExtSIOp.getIn());
-      else if (auto lhsExtUIOp = lhs.getDefiningOp<arith::ExtUIOp>())
-        lhs = reshapeLeadingUnitDims(rewriter, lhsExtUIOp.getIn());
-      else if (auto lhsExtFOp = lhs.getDefiningOp<arith::ExtFOp>())
-        lhs = reshapeLeadingUnitDims(rewriter, lhsExtFOp.getIn());
-
-      if (auto rhsExtSIOp = rhs.getDefiningOp<arith::ExtSIOp>())
-        rhs = reshapeLeadingUnitDims(rewriter, rhsExtSIOp.getIn());
-      else if (auto rhsExtUIOp = rhs.getDefiningOp<arith::ExtUIOp>())
-        rhs = reshapeLeadingUnitDims(rewriter, rhsExtUIOp.getIn());
-      else if (auto rhsExtFOp = rhs.getDefiningOp<arith::ExtFOp>())
-        rhs = reshapeLeadingUnitDims(rewriter, rhsExtFOp.getIn());
+      auto wideRhsValue = getSourceOfWideningOp(rhs).value_or(nullptr);
+      if (wideRhsValue)
+        rhs = reshapeLeadingUnitDims(rewriter, wideRhsValue);
 
       matmulOp = rewriter.create<aievec::MatMulOp>(
           contractOp.getLoc(), acc.getType(), lhs, rhs, acc);
+      if (failed(matmulOp.verifyInvariants()))
+        return failure();
     }
-    if (failed(matmulOp.verifyInvariants()))
-      return failure();
   }
 
   Value result = matmulOp.getResult();
@@ -2776,6 +2805,16 @@ struct LowerVectorContractionOpToAIEVecMatMulPattern
 // Pattern collection
 //===----------------------------------------------------------------------===//
 
+static void populateAIEVecCommonConversionPatterns(RewritePatternSet &patterns,
+                                                   TargetBackend backend) {
+  // clang-format off
+  patterns.add(patterns.getContext());
+  // clang-format on
+}
+
 static void populateAIEVecV1ConversionPatterns(RewritePatternSet &patterns,
                                                TargetBackend backend) {
   patterns.add<LowerVectorTransferReadToAIEUPD>(patterns.getContext(), 128, 512,
@@ -2795,15 +2834,20 @@ static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns,
                                                TargetBackend backend) {
-  if (backend == TargetBackend::CPP) {
-    patterns.add<LowerVectorTransferReadToAIEUPD>(patterns.getContext(), 128,
-                                                  1024, 256, 1024);
-  }
   // clang-format off
   // TODO: Reorder these alphabetically
+  if (backend == TargetBackend::CPP) {
+    patterns.add<
+      LowerVectorTransferReadToAIEUPD
+    >(patterns.getContext(), 128, 1024, 256, 1024);
+    patterns.add<
+      LowerVectorAddFOpToAIEVecAddElemOp,
+      LowerVectorSubFOpToAIEVecSubElemOp,
+      LowerVectorAddIOpToAIEVecAddElemOp,
+      LowerVectorSubIOpToAIEVecSubElemOp
+    >(patterns.getContext());
+  }
   patterns.add<
-      LowerVectorAddIOpToAIEVecAddElemOp,
-      LowerVectorSubIOpToAIEVecSubElemOp,
       ComputeExpOpByLUTPattern,
       ComputeInvOpByLUTPattern,
       ComputeTanhOpByLUTPattern,
@@ -2821,8 +2865,6 @@ static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns,
       ComputeBandOpPattern,
      ComputeSignedIntRightShiftOpPattern,
       ConvertMulIToAIEVecMulElemOpPattern,
-      LowerVectorAddFOpToAIEVecAddElemOp,
-      LowerVectorSubFOpToAIEVecSubElemOp,
       ConvertMulFToAIEVecMulElemOpPattern,
       LowerVectorMinSIOpToAIEVecMinOp,
       LowerVectorMinimumFOpToAIEVecMinOp,
@@ -2915,6 +2957,98 @@ static void configureAIEVecCommonLegalizations(ConversionTarget &target,
     target.addIllegalOp();
   }
   target.addIllegalOp();
+
+  // ****************************NEW****************************
+  target.addDynamicallyLegalOp<arith::ExtFOp>([](arith::ExtFOp extfOp) {
+    auto srcType = dyn_cast<VectorType>(extfOp.getIn().getType());
+    auto dstType = dyn_cast<VectorType>(extfOp.getOut().getType());
+    if (!srcType || !dstType)
+      return true;
+
+    Type srcScalarType = srcType.getElementType();
+    Type dstScalarType = dstType.getElementType();
+    if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
+      return true;
+
+    unsigned srcLaneSize = getVectorLaneSize(srcType);
+    unsigned dstLaneSize = getVectorLaneSize(dstType);
+    unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
+    unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
+    if (srcElWidth != 16 || srcLaneSize != 16 || dstElWidth != 32 ||
+        dstLaneSize != 16)
+      return true;
+
+    return false;
+  });
+
+  target.addDynamicallyLegalOp<arith::ExtSIOp>([](arith::ExtSIOp extsiOp) {
+    auto srcType = dyn_cast<VectorType>(extsiOp.getIn().getType());
+    auto dstType = dyn_cast<VectorType>(extsiOp.getOut().getType());
+    if (!srcType || !dstType)
+      return true;
+
+    Type srcScalarType = srcType.getElementType();
+    Type dstScalarType = dstType.getElementType();
+    if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
+      return true;
+
+    unsigned srcLaneSize = getVectorLaneSize(srcType);
+    unsigned dstLaneSize = getVectorLaneSize(dstType);
+    unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
+    unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
+    if (!(srcLaneSize == 32 && (dstElWidth > srcElWidth) &&
+          (dstLaneSize == srcLaneSize)))
+      return true;
+
+    return false;
+  });
+
+  target.addDynamicallyLegalOp<arith::TruncFOp>([](arith::TruncFOp truncfOp) {
+    auto srcType = dyn_cast<VectorType>(truncfOp.getIn().getType());
+    auto dstType = dyn_cast<VectorType>(truncfOp.getOut().getType());
+    if (!srcType || !dstType)
+      return true;
+
+    Type srcScalarType = srcType.getElementType();
+    Type dstScalarType = dstType.getElementType();
+    if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
+      return true;
+
+    unsigned srcLaneSize = getVectorLaneSize(srcType);
+    unsigned dstLaneSize = getVectorLaneSize(dstType);
+    unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
+    unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
+    if (srcElWidth != 32 || srcLaneSize != 16 || dstElWidth != 16 ||
+        dstLaneSize != 16)
+      return true;
+
+    return false;
+  });
+
+  target.addDynamicallyLegalOp<arith::TruncIOp>([](arith::TruncIOp trunciOp) {
+    auto srcType = dyn_cast<VectorType>(trunciOp.getIn().getType());
+    auto dstType = dyn_cast<VectorType>(trunciOp.getOut().getType());
+    if (!srcType || !dstType)
+      return true;
+
+    Type srcScalarType = srcType.getElementType();
+    Type dstScalarType = dstType.getElementType();
+    if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
+      return true;
+
+    unsigned srcLaneSize = getVectorLaneSize(srcType);
+    unsigned dstLaneSize = getVectorLaneSize(dstType);
+    unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
+    unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
+
+    if (!(srcLaneSize == 32 && (dstElWidth < srcElWidth) &&
+          (dstLaneSize == srcLaneSize)))
+      return true;
+
+    return false;
+  });
+  // ***********************************************************
+
   target.addDynamicallyLegalOp<math::ExpOp>([](math::ExpOp expOp) {
     auto srcType = dyn_cast<VectorType>(expOp.getOperand().getType());
     if (!srcType)
       return true;
@@ -3030,7 +3164,7 @@ static void configureAIEVecCommonLegalizations(ConversionTarget &target,
     Type scalarType = divfOp.getLhs().getType();
     if (!divfOp->hasOneUse() || !isa<FloatType>(scalarType))
      return true;
-    if (!isa<arith::TruncFOp>(*divfOp->getUsers().begin()))
+    if (!isNarrowingOp(*divfOp->getUsers().begin()))
      return true;
 
     auto fType = cast<FloatType>(scalarType);
@@ -3159,8 +3293,10 @@ static void configureAIEVecCommonLegalizations(ConversionTarget &target,
     return laneSize * elWidth != 512;
   });
 
-  target.addDynamicallyLegalOp<arith::AddIOp>(
-      [](arith::AddIOp op) { return !isa<VectorType>(op.getType()); });
+  if (backend == TargetBackend::CPP) {
+    target.addDynamicallyLegalOp<arith::AddIOp>(
+        [](arith::AddIOp op) { return !isa<VectorType>(op.getType()); });
+  }
   target.addDynamicallyLegalOp<arith::AddFOp>(
       [](arith::AddFOp op) { return !isa<VectorType>(op.getType()); });
   target.addDynamicallyLegalOp(
@@ -3227,17 +3363,19 @@ static void configureAIEVecV2Legalizations(ConversionTarget &target,
   elWidthSet.insert(16);
   elWidthSet.insert(32);
 
-  target.addDynamicallyLegalOp<arith::AddIOp>([=](arith::AddIOp op) {
-    auto resultType = dyn_cast<VectorType>(op.getType());
-    if (!resultType)
-      return true;
+  if (backend == TargetBackend::CPP) {
+    target.addDynamicallyLegalOp<arith::AddIOp>([=](arith::AddIOp op) {
+      auto resultType = dyn_cast<VectorType>(op.getType());
+      if (!resultType)
+        return true;
 
-    auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
-    unsigned laneSize = getVectorLaneSize(resultType);
+      auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
+      unsigned laneSize = getVectorLaneSize(resultType);
 
-    return !laneSizeElWidthPairSet.count(
-        std::make_pair(laneSize, resultElWidth));
-  });
+      return !laneSizeElWidthPairSet.count(
+          std::make_pair(laneSize, resultElWidth));
+    });
+  }
 
   target.addDynamicallyLegalOp<arith::SubIOp>([=](arith::SubIOp op) {
     auto resultType = dyn_cast<VectorType>(op.getType());
@@ -3504,6 +3642,7 @@ struct LowerVectorToAIEVec : PassWrapper<LowerVectorToAIEVec, OperationPass<>> {
       }
     }
 
+    populateAIEVecCommonConversionPatterns(patterns, backend);
    configureAIEVecCommonLegalizations(target, backend);
     if (aieVersion == AIEArch::AIE) {
       populateAIEVecV1ConversionPatterns(patterns, backend);
@@ -3527,111 +3666,6 @@ createLowerVectorToAIEVec(const LowerVectorToAIEVecOptions &options) {
 // Custom canonicalization passes
 //===---------------------------------------------------------------------------
 
-struct ProcessExtOpsPass : PassWrapper<ProcessExtOpsPass, OperationPass<>> {
-
-  void runOnOperation() override {
-    MLIRContext *context = &getContext();
-    RewritePatternSet patterns(context);
-    ConversionTarget target(*context);
-    patterns.add(patterns.getContext());
-    target.addLegalDialect();
-    target.addDynamicallyLegalOp<arith::ExtFOp>([](arith::ExtFOp extfOp) {
-      auto srcType = dyn_cast<VectorType>(extfOp.getIn().getType());
-      auto dstType = dyn_cast<VectorType>(extfOp.getOut().getType());
-      if (!srcType || !dstType)
-        return true;
-
-      Type srcScalarType = srcType.getElementType();
-      Type dstScalarType = dstType.getElementType();
-      if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
-        return true;
-
-      unsigned srcLaneSize = getVectorLaneSize(srcType);
-      unsigned dstLaneSize = getVectorLaneSize(dstType);
-      unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
-      unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
-      if (srcElWidth != 16 || srcLaneSize != 16 || dstElWidth != 32 ||
-          dstLaneSize != 16)
-        return true;
-
-      return false;
-    });
-
-    target.addDynamicallyLegalOp<arith::ExtSIOp>([](arith::ExtSIOp extsiOp) {
-      auto srcType = dyn_cast<VectorType>(extsiOp.getIn().getType());
-      auto dstType = dyn_cast<VectorType>(extsiOp.getOut().getType());
-      if (!srcType || !dstType)
-        return true;
-
-      Type srcScalarType = srcType.getElementType();
-      Type dstScalarType = dstType.getElementType();
-      if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
-        return true;
-
-      unsigned srcLaneSize = getVectorLaneSize(srcType);
-      unsigned dstLaneSize = getVectorLaneSize(dstType);
-      unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
-      unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
-      if (!(srcLaneSize == 32 && (dstElWidth > srcElWidth) &&
-            (dstLaneSize == srcLaneSize)))
-        return true;
-
-      return false;
-    });
-
-    target.addDynamicallyLegalOp<arith::TruncFOp>([](arith::TruncFOp truncfOp) {
-      auto srcType = dyn_cast<VectorType>(truncfOp.getIn().getType());
-      auto dstType = dyn_cast<VectorType>(truncfOp.getOut().getType());
-      if (!srcType || !dstType)
-        return true;
-
-      Type srcScalarType = srcType.getElementType();
-      Type dstScalarType = dstType.getElementType();
-      if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
-        return true;
-
-      unsigned srcLaneSize = getVectorLaneSize(srcType);
-      unsigned dstLaneSize = getVectorLaneSize(dstType);
-      unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
-      unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
-      if (srcElWidth != 32 || srcLaneSize != 16 || dstElWidth != 16 ||
-          dstLaneSize != 16)
-        return true;
-
-      return false;
-    });
-
-    target.addDynamicallyLegalOp<arith::TruncIOp>([](arith::TruncIOp trunciOp) {
-      auto srcType = dyn_cast<VectorType>(trunciOp.getIn().getType());
-      auto dstType = dyn_cast<VectorType>(trunciOp.getOut().getType());
-      if (!srcType || !dstType)
-        return true;
-
-      Type srcScalarType = srcType.getElementType();
-      Type dstScalarType = dstType.getElementType();
-      if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
-        return true;
-
-      unsigned srcLaneSize = getVectorLaneSize(srcType);
-      unsigned dstLaneSize = getVectorLaneSize(dstType);
-      unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
-      unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
-
-      if (!(srcLaneSize == 32 && (dstElWidth < srcElWidth) &&
-            (dstLaneSize == srcLaneSize)))
-        return true;
-
-      return false;
-    });
-
-    if (auto op = getOperation();
-        failed(applyPartialConversion(op, target, std::move(patterns)))) {
-      return signalPassFailure();
-    }
-  }
-};
-
 // This pass widens UPD ops to twice the width followed by an ext op of the
 // bottom half. This can be used together with SimplifyUPDOpsPass to find
 // additional common subexpressions with UPDs generated from unaligned
@@ -3692,7 +3726,6 @@ void xilinx::aievec::buildLowerVectorToAIEVec(
     OpPassManager &pm, const LowerVectorToAIEVecOptions &options) {
   // Add lowering from `Vector` to `AIEVec`
   pm.addPass(createLowerVectorToAIEVec(options));
-  pm.addPass(std::make_unique<ProcessExtOpsPass>());
   pm.addPass(createCanonicalizerPass());
 
   // Simplify UPD ops
diff --git a/test/Conversion/VectorToAIEVec/test-trunc-ext.mlir b/test/Conversion/VectorToAIEVec/test-trunc-ext.mlir
new file mode 100644
index 0000000000..365e257b61
--- /dev/null
+++ b/test/Conversion/VectorToAIEVec/test-trunc-ext.mlir
@@ -0,0 +1,79 @@
+// RUN: aie-opt %s --convert-vector-to-aievec | FileCheck %s
+
+// CHECK-LABEL: func.func @test_exti(
+// CHECK-SAME: %[[V1:[a-zA-Z0-9]+]]: vector<32xi8>
+// CHECK-SAME: %[[V2:[a-zA-Z0-9]+]]: vector<32xi8>
+// CHECK-SAME: %[[V3:[a-zA-Z0-9]+]]: vector<32xi16>
+func.func @test_exti(%vi8_4_i16 : vector<32xi8>, %vi8_4_i32 : vector<32xi8>,
+                     %vi16_4_i32 : vector<32xi16>) ->
+            (vector<32xi16>, vector<32xi32>, vector<32xi32>) {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : i32
+  %0 = arith.extsi %vi8_4_i16 : vector<32xi8> to vector<32xi16>
+  // CHECK: %[[V1A:.*]] = aievec.ups %[[V1]] {shift = 0 : i8} :
+  // CHECK-SAME: vector<32xi8>, vector<32xi32>
+  // CHECK: %[[V1E:.*]] = aievec.srs %[[V1A]], %[[C0]] :
+  // CHECK-SAME: vector<32xi32>, i32, vector<32xi16>
+  %1 = arith.extsi %vi8_4_i32 : vector<32xi8> to vector<32xi32>
+  // CHECK: %[[V2A:.*]] = aievec.ups %[[V2]] {shift = 0 : i8} :
+  // CHECK-SAME: vector<32xi8>, vector<32xi32>
+  // CHECK: %[[V2E:.*]] = aievec.cast %[[V2A]] {isResAcc = false} :
+  // CHECK-SAME: vector<32xi32>, vector<32xi32>
+  %2 = arith.extsi %vi16_4_i32 : vector<32xi16> to vector<32xi32>
+  // CHECK: %[[V3A:.*]] = aievec.ups %[[V3]] {shift = 0 : i8} :
+  // CHECK-SAME: vector<32xi16>, vector<32xi32>
+  // CHECK: %[[V3E:.*]] = aievec.cast %[[V3A]] {isResAcc = false} :
+  // CHECK-SAME: vector<32xi32>, vector<32xi32>
+  return %0, %1, %2 : vector<32xi16>, vector<32xi32>, vector<32xi32>
+  // CHECK: return %[[V1E]], %[[V2E]], %[[V3E]]
+}
+
+// CHECK-LABEL: func.func @test_extf(
+// CHECK-SAME: %[[V:[a-zA-Z0-9]+]]: vector<16xbf16>
+func.func @test_extf(%vbf16_4_f32 : vector<16xbf16>) -> vector<16xf32> {
+  %0 = arith.extf %vbf16_4_f32 : vector<16xbf16> to vector<16xf32>
+  // CHECK: %[[VA:.*]] = aievec.ups %[[V]] {shift = 0 : i8} :
+  // CHECK-SAME: vector<16xbf16>, vector<16xf32>
+  // CHECK: %[[VE:.*]] = aievec.cast %[[VA]] {isResAcc = false} :
+  // CHECK-SAME: vector<16xf32>, vector<16xf32>
+  return %0 : vector<16xf32>
+  // CHECK: return %[[VE]]
+}
+
+// CHECK-LABEL: func.func @test_trunci(
+// CHECK-SAME: %[[V1:[a-zA-Z0-9]+]]: vector<32xi32>
+// CHECK-SAME: %[[V2:[a-zA-Z0-9]+]]: vector<32xi32>
+// CHECK-SAME: %[[V3:[a-zA-Z0-9]+]]: vector<32xi16>
+func.func @test_trunci(%vi32_4_i8 : vector<32xi32>, %vi32_4_i16 : vector<32xi32>,
+                       %vi16_4_i8 : vector<32xi16>) ->
+            (vector<32xi8>, vector<32xi16>, vector<32xi8>) {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : i32
+  %0 = arith.trunci %vi32_4_i8 : vector<32xi32> to vector<32xi8>
+  // CHECK: %[[V1A:.*]] = aievec.cast %[[V1]] {isResAcc = true} :
+  // CHECK-SAME: vector<32xi32>, vector<32xi32>
+  // CHECK: %[[V1T:.*]] = aievec.srs %[[V1A]], %[[C0]] :
+  // CHECK-SAME: vector<32xi32>, i32, vector<32xi8>
+  %1 = arith.trunci %vi32_4_i16 : vector<32xi32> to vector<32xi16>
+  // CHECK: %[[V2A:.*]] = aievec.cast %[[V2]] {isResAcc = true} :
+  // CHECK-SAME: vector<32xi32>, vector<32xi32>
+  // CHECK: %[[V2T:.*]] = aievec.srs %[[V2A]], %[[C0]] :
+  // CHECK-SAME: vector<32xi32>, i32, vector<32xi16>
+  %2 = arith.trunci %vi16_4_i8 : vector<32xi16> to vector<32xi8>
+  // CHECK: %[[V3A:.*]] = aievec.ups %[[V3]] {shift = 0 : i8} :
+  // CHECK-SAME: vector<32xi16>, vector<32xi32>
+  // CHECK: %[[V3T:.*]] = aievec.srs %[[V3A]], %[[C0]] :
+  // CHECK-SAME: vector<32xi32>, i32, vector<32xi8>
+  return %0, %1, %2 : vector<32xi8>, vector<32xi16>, vector<32xi8>
+  // CHECK: return %[[V1T]], %[[V2T]], %[[V3T]]
+}
+
+// CHECK-LABEL: func.func @test_truncf(
+// CHECK-SAME: %[[V:[a-zA-Z0-9]+]]: vector<16xf32>
+func.func @test_truncf(%vf32_4_bf16 : vector<16xf32>) -> vector<16xbf16> {
+  %0 = arith.truncf %vf32_4_bf16 : vector<16xf32> to vector<16xbf16>
+  // CHECK: %[[VA:.*]] = aievec.cast %[[V]] {isResAcc = true} :
+  // CHECK-SAME: vector<16xf32>, vector<16xf32>
+  // CHECK: %[[VT:.*]] = aievec.srs %[[VA]], %[[C0]] :
+  // CHECK-SAME: vector<16xf32>, i32, vector<16xbf16>
+  return %0 : vector<16xbf16>
+  // CHECK: return %[[VT]]
+}
\ No newline at end of file