diff --git a/include/aie/Conversion/CMakeLists.txt b/include/aie/Conversion/CMakeLists.txt index af41b487bf..40f9c3ec7e 100644 --- a/include/aie/Conversion/CMakeLists.txt +++ b/include/aie/Conversion/CMakeLists.txt @@ -2,6 +2,8 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name Conversion) mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header -prefix Conversion) mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl -prefix Conversion) +mlir_tablegen(PassesEnums.h.inc -gen-enum-decls) +mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRAIEConversionPassIncGen) add_mlir_doc(Passes MLIRAIEConversionPasses ./ -gen-pass-doc) diff --git a/include/aie/Conversion/Passes.h b/include/aie/Conversion/Passes.h index efc21e74fd..53797352cf 100644 --- a/include/aie/Conversion/Passes.h +++ b/include/aie/Conversion/Passes.h @@ -12,9 +12,13 @@ #define AIE_CONVERSION_PASSES_H #include "aie/Conversion/AIEVecToLLVM/AIEVecToLLVM.h" +#include "aie/Conversion/PassesEnums.h.inc" namespace xilinx { +#define GEN_PASS_DECL +#include "aie/Conversion/Passes.h.inc" + #define GEN_PASS_REGISTRATION #include "aie/Conversion/Passes.h.inc" diff --git a/include/aie/Conversion/Passes.td b/include/aie/Conversion/Passes.td index 23051cdd32..4dc6247b30 100644 --- a/include/aie/Conversion/Passes.td +++ b/include/aie/Conversion/Passes.td @@ -11,11 +11,21 @@ #ifndef AIE_CONVERSION_PASSES #define AIE_CONVERSION_PASSES +include "mlir/IR/EnumAttr.td" include "mlir/Pass/PassBase.td" //===----------------------------------------------------------------------===// // AIEVecToLLVM //===----------------------------------------------------------------------===// +def Aie2Fp32EmulationType : I32EnumAttr<"Aie2Fp32Emulation", "AIE2 FP32 Emulation", + [ + I32EnumAttrCase<"AccuracySafe", 0, "accuracy-safe">, + I32EnumAttrCase<"AccuracyFast", 1, "accuracy-fast">, + I32EnumAttrCase<"AccuracyLow", 2, "accuracy-low"> + ]>{ + let cppNamespace = "xilinx::aievec"; +} + def ConvertAIEVecToLLVM : Pass<"convert-aievec-to-llvm", "mlir::ModuleOp"> { let summary = "Convert AIEVec dialect to LLVM dialect"; let description = [{ @@ -26,6 +36,19 @@ def ConvertAIEVecToLLVM : Pass<"convert-aievec-to-llvm", "mlir::ModuleOp"> { "mlir::arith::ArithDialect", "mlir::vector::VectorDialect", "xilinx::xllvm::XLLVMDialect"]; + let options = [ + Option<"aie2Fp32Emulation", "aie2-fp32-emulation-strategy", "xilinx::aievec::Aie2Fp32Emulation", + /*default=*/"xilinx::aievec::Aie2Fp32Emulation::AccuracySafe", + "Set the AIE2 FP32 emulation strategy. Elementwise multiplication and matrix multiplication intrinsics for FP32 input type are emulated using bfloat16 data-path.", + [{::llvm::cl::values( + clEnumValN(xilinx::aievec::Aie2Fp32Emulation::AccuracySafe, "accuracy-safe", + "Most accurate option since input fp32 number is split into 3 bfloat16 numbers. float a*b would require 9 mac operations due to 3 bfloat16 splits each."), + clEnumValN(xilinx::aievec::Aie2Fp32Emulation::AccuracyFast, "accuracy-fast", + "Fast and Accurate option. Input fp32 number is split in to 3 bfloat16 numbers. In the 9 mac operations to emulate fp32 mul, mac operations with LSBs are ignored. (3 last terms)."), + clEnumValN(xilinx::aievec::Aie2Fp32Emulation::AccuracyLow, "accuracy-low", + "Fast and least accurate option. Input fp32 number is split in to 2 bfloat16 numbers. In the 4 mac operations to emulate fp32 mul, mac operations with LSBs are ignored. (1 last term).") + )}]> + ]; } #endif // AIE_CONVERSION_PASSES diff --git a/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td b/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td index 812f104a3d..236e9779cc 100644 --- a/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td +++ b/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td @@ -106,28 +106,28 @@ def VectorSetI512I256IntrOp : def I512V32Acc32SrsIntrOp : AIEVec2_IntrOp<"I512.v32.acc32.srs", [TypeIs<"res", VectorOfLengthAndType<[32], [I16]>>]>, - Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$a, - I32:$shft, + Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$src, + I32:$shift, I32:$sign)>; def I256V32Acc32SrsIntrOp : AIEVec2_IntrOp<"I256.v32.acc32.srs", [TypeIs<"res", VectorOfLengthAndType<[32], [I8]>>]>, - Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$a, - I32:$shft, + Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$src, + I32:$shift, I32:$sign)>; def I512V16Acc64SrsIntrOp : AIEVec2_IntrOp<"I512.v16.acc64.srs", [TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>, - Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$a, - I32:$shft, + Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$src, + I32:$shift, I32:$sign)>; def Vector16AccFloatToV16BF16IntrOp : AIEVec2_IntrOp<"v16accfloat.to.v16bf16", [TypeIs<"res", VectorOfLengthAndType<[16], [BF16]>>]>, - Arguments<(ins VectorOfLengthAndType<[8], [I64]>:$a)>; + Arguments<(ins VectorOfLengthAndType<[8], [I64]>:$src)>; // ----- BROADCAST ----- @@ -151,7 +151,7 @@ def VectorBroadcast16BF512IntrOp : def ExtI256I512IntrOp : AIEVec2_IntrOp<"ext.I256.I512", [TypeIs<"res", VectorOfLengthAndType<[8], [I32]>>]>, - Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$a, + Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$src, I32:$idx)>; // ----- CONCAT ----- @@ -159,16 +159,16 @@ def ExtI256I512IntrOp : def ConcatI512I256IntrOp : AIEVec2_IntrOp<"concat.I512.I256", [TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>, - Arguments<(ins VectorOfLengthAndType<[8], [I32]>:$a0, - VectorOfLengthAndType<[8], [I32]>:$a1)>; + Arguments<(ins VectorOfLengthAndType<[8], [I32]>:$src0, + VectorOfLengthAndType<[8], [I32]>:$src1)>; // ----- SHUFFLE ----- def VectorShuffleIntrOp : AIEVec2_IntrOp<"vshuffle", [TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>, - Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$a, - VectorOfLengthAndType<[16], [I32]>:$b, + Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$src0, + VectorOfLengthAndType<[16], [I32]>:$src1, I32:$mode)>; // ----- UNDEF ----- @@ -182,8 +182,8 @@ def UndefV16I32IntrOp : def UpdBF512BF256IntrOp : AIEVec2_IntrOp<"upd.bf512.bf256", [TypeIs<"res", VectorOfLengthAndType<[32], [BF16]>>]>, - Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$a, - VectorOfLengthAndType<[16], [BF16]>:$b, + Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$dst, + VectorOfLengthAndType<[16], [BF16]>:$src, I32:$idx)>; #endif // AIE_DIALECT_XLLVM_IR_XLLVMAIE2INTROPS_TD diff --git a/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp b/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp index ee11baf989..bb35d631f7 100644 --- a/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp +++ b/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp @@ -359,6 +359,13 @@ class MulElemOpConversion public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + MulElemOpConversion(const LLVMTypeConverter &typeConverter, + Aie2Fp32Emulation aie2Fp32EmulationOption) + : ConvertOpToLLVMPattern(typeConverter), + aie2Fp32EmulationOption(aie2Fp32EmulationOption) {} + + Aie2Fp32Emulation aie2Fp32EmulationOption; + struct DecodedMulElemOp { enum class Kind { // DtIn0_DtIn1_DtRes_CxMxKxN @@ -681,10 +688,7 @@ class MulElemOpConversion auto [d, e, f] = extractV16FP32ToThreeV16BF16(adaptor.getRhs(), dZeros, eZeros, fZeros); - // 1 MUL + 8 * MACs - auto cfMul = rewriter.create( - loc, VectorType::get({8}, rewriter.getI64Type()), c, f, - mscMacMulConfCst); + // Create 1 MUL and 2/5/8 MACs depending on the Aie2Fp32EmulationOption auto createMacOps = [&](Value lhs, Value rhs, Value acc) -> Value { return rewriter .create( @@ -692,23 +696,66 @@ class MulElemOpConversion mscMacMulConfCst) .getResult(); }; - auto adMac = createMacOps( - a, d, - createMacOps( - a, e, - createMacOps( - b, d, - createMacOps( - d, c, - createMacOps( - b, e, - createMacOps( - a, f, - createMacOps(b, f, createMacOps(c, e, cfMul)))))))); + + Value finalMacVal; + if (aie2Fp32EmulationOption == Aie2Fp32Emulation::AccuracyFast) { + // Fast and Accurate option. float a*b would require 6 mac operations. + // Input fp32 number is split in to 3 bfloat16 numbers to extract all the + // bits of the mantissa. float a,b; both a and b are split in to 3 + // bfloat16 numbers each. Hence there would be 9 mac operations in + // multiplication of a and b. In the 9 mac operations to emulate fp32 mul, + // mac operations with LSBs are ignored. (3 last terms). This helps + // improve cycle count of mul and has least impact on accuracy of result. + // This is the default option to the aiecompiler + auto afMul = rewriter.create( + loc, VectorType::get({8}, rewriter.getI64Type()), a, f, + mscMacMulConfCst); + finalMacVal = createMacOps( + a, d, + createMacOps( + a, e, + createMacOps(b, d, + createMacOps(d, c, createMacOps(b, e, afMul))))); + } else if (aie2Fp32EmulationOption == Aie2Fp32Emulation::AccuracyLow) { + // Fast and least accurate option. float a*b would require 3 mac + // operations. + // Input fp32 number is split in to 2 bfloat16 numbers. Hence not all the + // bits from mantissa can be used. float a,b; Both a and b are split in to + // 2 bfloat16 numbers each. Hence there would be 4 mac operations in + // multiplication of a and b. In the 4 mac operations to emulate fp32 mul, + // mac operations with LSBs are ignored. (1 last term). This helps improve + // cycle count of mul float a, b; + auto bdMul = rewriter.create( + loc, VectorType::get({8}, rewriter.getI64Type()), b, d, + mscMacMulConfCst); + finalMacVal = createMacOps(a, d, createMacOps(a, e, bdMul)); + } else { + // aie2Fp32EmulationOption == Aie2Fp32Emulation::AccuracySafe + // Most accurate option since input fp32 number is split in to 3 bfloat16 + // numbers to extract all the bits of the mantissa. float a*b would + // require 9 mac operations due to 3 bfloat16 splits each. + auto cfMul = rewriter.create( + loc, VectorType::get({8}, rewriter.getI64Type()), c, f, + mscMacMulConfCst); + finalMacVal = createMacOps( + a, d, + createMacOps( + a, e, + createMacOps( + b, d, + createMacOps( + d, c, + createMacOps( + b, e, + createMacOps( + a, f, + createMacOps(b, f, + createMacOps(c, e, cfMul)))))))); + } // create bitcast for result rewriter.replaceOpWithNewOp(op, op.getResult().getType(), - adMac); + finalMacVal); return success(); } @@ -1457,8 +1504,9 @@ class FoldAIECastOps : public mlir::ConvertOpToLLVMPattern { } }; -void populateAIEVecToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, - mlir::RewritePatternSet &patterns) { +void populateAIEVecToLLVMConversionPatterns( + mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns, + Aie2Fp32Emulation aie2Fp32EmulationOption) { // clang-format off patterns.add(converter); + patterns.add(converter, aie2Fp32EmulationOption); // clang-format on } @@ -1492,7 +1540,8 @@ struct ConvertAIEVecToLLVMPass converter.addConversion( [&](VectorType type) -> std::optional { return type; }); - populateAIEVecToLLVMConversionPatterns(converter, patterns); + populateAIEVecToLLVMConversionPatterns(converter, patterns, + aie2Fp32Emulation); LLVMConversionTarget target(getContext()); target.addIllegalDialect(); diff --git a/lib/Conversion/PassDetail.h b/lib/Conversion/PassDetail.h index 51b78bb5f8..537986193e 100644 --- a/lib/Conversion/PassDetail.h +++ b/lib/Conversion/PassDetail.h @@ -11,6 +11,7 @@ #ifndef AIE_CONVERSION_PASSDETAIL_H_ #define AIE_CONVERSION_PASSDETAIL_H_ +#include "aie/Conversion/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinOps.h" diff --git a/test/Conversion/AIEVecToLLVM/mul_elem.mlir b/test/Conversion/AIEVecToLLVM/mul_elem.mlir index 1c7ce8c3fd..38193386e7 100644 --- a/test/Conversion/AIEVecToLLVM/mul_elem.mlir +++ b/test/Conversion/AIEVecToLLVM/mul_elem.mlir @@ -1,4 +1,7 @@ // RUN: aie-opt %s -split-input-file -convert-aievec-to-llvm | FileCheck %s +// RUN: aie-opt %s -split-input-file -convert-aievec-to-llvm="aie2-fp32-emulation-strategy=accuracy-fast" | FileCheck --check-prefix=FP32FAST %s +// RUN: aie-opt %s -split-input-file -convert-aievec-to-llvm="aie2-fp32-emulation-strategy=accuracy-low" | FileCheck --check-prefix=FP32LOW %s + func.func @i16_i16_i32_mul_elem(%arg0 : vector<32xi16>, %arg1 : vector<32xi16>) -> vector<32xi32> { %0 = aievec.mul_elem %arg0, %arg1 : vector<32xi16>, vector<32xi16>, vector<32xi32> @@ -109,7 +112,7 @@ func.func @f32_f32_f32_mul_elem(%arg0 : vector<16xf32>, %arg1 : vector<16xf32>) // CHECK-NEXT: %[[ONES0:.*]] = "xllvm.intr.aie2.vbroadcast16.bf512"(%[[CST1]]) : (bf16) -> vector<32xbf16> // CHECK-NEXT: %[[CST2:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK-NEXT: %[[CONF:.*]] = llvm.mlir.constant(60 : i32) : i32 -// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %arg0 : vector<16xf32> to vector<8xi64> +// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<16xf32> to vector<8xi64> // CHECK-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.v16accfloat.to.v16bf16"(%[[BITCAST0]]) : (vector<8xi64>) -> vector<16xbf16> // CHECK-NEXT: %[[UPD0:.*]] = "xllvm.intr.aie2.upd.bf512.bf256"(%[[ZEROS0]], %[[SRS0]], %[[CST2]]) : (vector<32xbf16>, vector<16xbf16>, i32) -> vector<32xbf16> // CHECK-NEXT: %[[MSC0:.*]] = "xllvm.intr.aie2.bf.msc16.conf"(%[[UPD0]], %[[ONES0]], %[[BITCAST0]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, vector<8xi64>, i32) -> vector<8xi64> @@ -118,7 +121,7 @@ func.func @f32_f32_f32_mul_elem(%arg0 : vector<16xf32>, %arg1 : vector<16xf32>) // CHECK-NEXT: %[[MSC1:.*]] = "xllvm.intr.aie2.bf.msc16.conf"(%[[UPD1]], %[[ONES0]], %[[MSC0]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, vector<8xi64>, i32) -> vector<8xi64> // CHECK-NEXT: %[[SRS2:.*]] = "xllvm.intr.aie2.v16accfloat.to.v16bf16"(%[[MSC1]]) : (vector<8xi64>) -> vector<16xbf16> // CHECK-NEXT: %[[UPD2:.*]] = "xllvm.intr.aie2.upd.bf512.bf256"(%[[ZEROS2]], %[[SRS2]], %[[CST2]]) : (vector<32xbf16>, vector<16xbf16>, i32) -> vector<32xbf16> -// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %arg1 : vector<16xf32> to vector<8xi64> +// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[ARG1]] : vector<16xf32> to vector<8xi64> // CHECK-NEXT: %[[SRS3:.*]] = "xllvm.intr.aie2.v16accfloat.to.v16bf16"(%[[BITCAST1]]) : (vector<8xi64>) -> vector<16xbf16> // CHECK-NEXT: %[[UPD3:.*]] = "xllvm.intr.aie2.upd.bf512.bf256"(%[[ZEROS3]], %[[SRS3]], %[[CST2]]) : (vector<32xbf16>, vector<16xbf16>, i32) -> vector<32xbf16> // CHECK-NEXT: %[[MSC2:.*]] = "xllvm.intr.aie2.bf.msc16.conf"(%[[UPD3]], %[[ONES0]], %[[BITCAST1]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, vector<8xi64>, i32) -> vector<8xi64> @@ -137,4 +140,83 @@ func.func @f32_f32_f32_mul_elem(%arg0 : vector<16xf32>, %arg1 : vector<16xf32>) // CHECK-NEXT: %[[ACC7:.*]] = "xllvm.intr.aie2.bf.mac16.conf"(%[[UPD0]], %[[UPD4]], %[[ACC6]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, vector<8xi64>, i32) -> vector<8xi64> // CHECK-NEXT: %[[ACC8:.*]] = "xllvm.intr.aie2.bf.mac16.conf"(%[[UPD0]], %[[UPD3]], %[[ACC7]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, vector<8xi64>, i32) -> vector<8xi64> // CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[ACC8]] : vector<8xi64> to vector<16xf32> -// CHECK-NEXT: return %[[RES]] : vector<16xf32> \ No newline at end of file +// CHECK-NEXT: return %[[RES]] : vector<16xf32> + +// FP32FAST-LABEL: @f32_f32_f32_mul_elem +// FP32FAST-SAME: %[[ARG0:.*]]: vector<16xf32>, +// FP32FAST-SAME: %[[ARG1:.*]]: vector<16xf32> +// FP32FAST: %[[CST0:.*]] = llvm.mlir.constant(0.000000e+00 : bf16) : bf16 +// FP32FAST-NEXT: %[[ZEROS0:.*]] = "xllvm.intr.aie2.vbroadcast16.bf512"(%[[CST0]]) : (bf16) -> vector<32xbf16> +// FP32FAST-NEXT: %[[ZEROS1:.*]] = "xllvm.intr.aie2.vbroadcast16.bf512"(%[[CST0]]) : (bf16) -> vector<32xbf16> +// FP32FAST-NEXT: %[[ZEROS2:.*]] = "xllvm.intr.aie2.vbroadcast16.bf512"(%[[CST0]]) : (bf16) -> vector<32xbf16> +// FP32FAST-NEXT: %[[ZEROS3:.*]] = "xllvm.intr.aie2.vbroadcast16.bf512"(%[[CST0]]) : (bf16) -> vector<32xbf16> +// FP32FAST-NEXT: %[[ZEROS4:.*]] = "xllvm.intr.aie2.vbroadcast16.bf512"(%[[CST0]]) : (bf16) -> vector<32xbf16> +// FP32FAST-NEXT: %[[ZEROS5:.*]] = "xllvm.intr.aie2.vbroadcast16.bf512"(%[[CST0]]) : (bf16) -> vector<32xbf16> +// FP32FAST-NEXT: %[[CST1:.*]] = llvm.mlir.constant(1.000000e+00 : bf16) : bf16 +// FP32FAST-NEXT: %[[ONES0:.*]] = "xllvm.intr.aie2.vbroadcast16.bf512"(%[[CST1]]) : (bf16) -> vector<32xbf16> +// FP32FAST-NEXT: %[[CST2:.*]] = llvm.mlir.constant(0 : i32) : i32 +// FP32FAST-NEXT: %[[CONF:.*]] = llvm.mlir.constant(60 : i32) : i32 +// FP32FAST-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<16xf32> to vector<8xi64> +// FP32FAST-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.v16accfloat.to.v16bf16"(%[[BITCAST0]]) : (vector<8xi64>) -> vector<16xbf16> +// FP32FAST-NEXT: %[[UPD0:.*]] = "xllvm.intr.aie2.upd.bf512.bf256"(%[[ZEROS0]], %[[SRS0]], %[[CST2]]) : (vector<32xbf16>, vector<16xbf16>, i32) -> vector<32xbf16> +// FP32FAST-NEXT: %[[MSC0:.*]] = "xllvm.intr.aie2.bf.msc16.conf"(%[[UPD0]], %[[ONES0]], %[[BITCAST0]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, vector<8xi64>, i32) -> vector<8xi64> +// FP32FAST-NEXT: %[[SRS1:.*]] = "xllvm.intr.aie2.v16accfloat.to.v16bf16"(%[[MSC0]]) : (vector<8xi64>) -> vector<16xbf16> +// FP32FAST-NEXT: %[[UPD1:.*]] = "xllvm.intr.aie2.upd.bf512.bf256"(%[[ZEROS1]], %[[SRS1]], %[[CST2]]) : (vector<32xbf16>, vector<16xbf16>, i32) -> vector<32xbf16> +// FP32FAST-NEXT: %[[MSC1:.*]] = "xllvm.intr.aie2.bf.msc16.conf"(%[[UPD1]], %[[ONES0]], %[[MSC0]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, vector<8xi64>, i32) -> vector<8xi64> +// FP32FAST-NEXT: %[[SRS2:.*]] = "xllvm.intr.aie2.v16accfloat.to.v16bf16"(%[[MSC1]]) : (vector<8xi64>) -> vector<16xbf16> +// FP32FAST-NEXT: %[[UPD2:.*]] = "xllvm.intr.aie2.upd.bf512.bf256"(%[[ZEROS2]], %[[SRS2]], %[[CST2]]) : (vector<32xbf16>, vector<16xbf16>, i32) -> vector<32xbf16> +// FP32FAST-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[ARG1]] : vector<16xf32> to vector<8xi64> +// FP32FAST-NEXT: %[[SRS3:.*]] = "xllvm.intr.aie2.v16accfloat.to.v16bf16"(%[[BITCAST1]]) : (vector<8xi64>) -> vector<16xbf16> +// FP32FAST-NEXT: %[[UPD3:.*]] = "xllvm.intr.aie2.upd.bf512.bf256"(%[[ZEROS3]], %[[SRS3]], %[[CST2]]) : (vector<32xbf16>, vector<16xbf16>, i32) -> vector<32xbf16> +// FP32FAST-NEXT: %[[MSC2:.*]] = "xllvm.intr.aie2.bf.msc16.conf"(%[[UPD3]], %[[ONES0]], %[[BITCAST1]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, vector<8xi64>, i32) -> vector<8xi64> +// FP32FAST-NEXT: %[[SRS4:.*]] = "xllvm.intr.aie2.v16accfloat.to.v16bf16"(%[[MSC2]]) : (vector<8xi64>) -> vector<16xbf16> +// FP32FAST-NEXT: %[[UPD4:.*]] = "xllvm.intr.aie2.upd.bf512.bf256"(%[[ZEROS4]], %[[SRS4]], %[[CST2]]) : (vector<32xbf16>, vector<16xbf16>, i32) -> vector<32xbf16> +// FP32FAST-NEXT: %[[MSC3:.*]] = "xllvm.intr.aie2.bf.msc16.conf"(%[[UPD4]], %[[ONES0]], %[[MSC2]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, vector<8xi64>, i32) -> vector<8xi64> +// FP32FAST-NEXT: %[[SRS5:.*]] = "xllvm.intr.aie2.v16accfloat.to.v16bf16"(%[[MSC3]]) : (vector<8xi64>) -> vector<16xbf16> +// FP32FAST-NEXT: %[[UPD5:.*]] = "xllvm.intr.aie2.upd.bf512.bf256"(%[[ZEROS5]], %[[SRS5]], %[[CST2]]) : (vector<32xbf16>, vector<16xbf16>, i32) -> vector<32xbf16> +// FP32FAST-NEXT: %[[ACC0:.*]] = "xllvm.intr.aie2.bf.mul16.conf"(%[[UPD0]], %[[UPD5]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, i32) -> vector<8xi64> +// FP32FAST-NEXT: %[[ACC1:.*]] = "xllvm.intr.aie2.bf.mac16.conf"(%[[UPD1]], %[[UPD4]], %[[ACC0]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, vector<8xi64>, i32) -> vector<8xi64> +// FP32FAST-NEXT: %[[ACC2:.*]] = "xllvm.intr.aie2.bf.mac16.conf"(%[[UPD3]], %[[UPD2]], %[[ACC1]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, vector<8xi64>, i32) -> vector<8xi64> +// FP32FAST-NEXT: %[[ACC3:.*]] = "xllvm.intr.aie2.bf.mac16.conf"(%[[UPD1]], %[[UPD3]], %[[ACC2]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, vector<8xi64>, i32) -> vector<8xi64> +// FP32FAST-NEXT: %[[ACC4:.*]] = "xllvm.intr.aie2.bf.mac16.conf"(%[[UPD0]], %[[UPD4]], %[[ACC3]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, vector<8xi64>, i32) -> vector<8xi64> +// FP32FAST-NEXT: %[[ACC5:.*]] = "xllvm.intr.aie2.bf.mac16.conf"(%[[UPD0]], %[[UPD3]], %[[ACC4]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, vector<8xi64>, i32) -> vector<8xi64> +// FP32FAST-NEXT: %[[RES:.*]] = llvm.bitcast %[[ACC5]] : vector<8xi64> to vector<16xf32> +// FP32FAST-NEXT: return %[[RES]] : vector<16xf32> + +// FP32LOW-LABEL: @f32_f32_f32_mul_elem +// FP32LOW-SAME: %[[ARG0:.*]]: vector<16xf32>, +// FP32LOW-SAME: %[[ARG1:.*]]: vector<16xf32> +// FP32LOW: %[[CST0:.*]] = llvm.mlir.constant(0.000000e+00 : bf16) : bf16 +// FP32LOW-NEXT: %[[ZEROS0:.*]] = "xllvm.intr.aie2.vbroadcast16.bf512"(%[[CST0]]) : (bf16) -> vector<32xbf16> +// FP32LOW-NEXT: %[[ZEROS1:.*]] = "xllvm.intr.aie2.vbroadcast16.bf512"(%[[CST0]]) : (bf16) -> vector<32xbf16> +// FP32LOW-NEXT: %[[ZEROS2:.*]] = "xllvm.intr.aie2.vbroadcast16.bf512"(%[[CST0]]) : (bf16) -> vector<32xbf16> +// FP32LOW-NEXT: %[[ZEROS3:.*]] = "xllvm.intr.aie2.vbroadcast16.bf512"(%[[CST0]]) : (bf16) -> vector<32xbf16> +// FP32LOW-NEXT: %[[ZEROS4:.*]] = "xllvm.intr.aie2.vbroadcast16.bf512"(%[[CST0]]) : (bf16) -> vector<32xbf16> +// FP32LOW-NEXT: %[[ZEROS5:.*]] = "xllvm.intr.aie2.vbroadcast16.bf512"(%[[CST0]]) : (bf16) -> vector<32xbf16> +// FP32LOW-NEXT: %[[CST1:.*]] = llvm.mlir.constant(1.000000e+00 : bf16) : bf16 +// FP32LOW-NEXT: %[[ONES0:.*]] = "xllvm.intr.aie2.vbroadcast16.bf512"(%[[CST1]]) : (bf16) -> vector<32xbf16> +// FP32LOW-NEXT: %[[CST2:.*]] = llvm.mlir.constant(0 : i32) : i32 +// FP32LOW-NEXT: %[[CONF:.*]] = llvm.mlir.constant(60 : i32) : i32 +// FP32LOW-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<16xf32> to vector<8xi64> +// FP32LOW-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.v16accfloat.to.v16bf16"(%[[BITCAST0]]) : (vector<8xi64>) -> vector<16xbf16> +// FP32LOW-NEXT: %[[UPD0:.*]] = "xllvm.intr.aie2.upd.bf512.bf256"(%[[ZEROS0]], %[[SRS0]], %[[CST2]]) : (vector<32xbf16>, vector<16xbf16>, i32) -> vector<32xbf16> +// FP32LOW-NEXT: %[[MSC0:.*]] = "xllvm.intr.aie2.bf.msc16.conf"(%[[UPD0]], %[[ONES0]], %[[BITCAST0]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, vector<8xi64>, i32) -> vector<8xi64> +// FP32LOW-NEXT: %[[SRS1:.*]] = "xllvm.intr.aie2.v16accfloat.to.v16bf16"(%[[MSC0]]) : (vector<8xi64>) -> vector<16xbf16> +// FP32LOW-NEXT: %[[UPD1:.*]] = "xllvm.intr.aie2.upd.bf512.bf256"(%[[ZEROS1]], %[[SRS1]], %[[CST2]]) : (vector<32xbf16>, vector<16xbf16>, i32) -> vector<32xbf16> +// FP32LOW-NEXT: %[[MSC1:.*]] = "xllvm.intr.aie2.bf.msc16.conf"(%[[UPD1]], %[[ONES0]], %[[MSC0]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, vector<8xi64>, i32) -> vector<8xi64> +// FP32LOW-NEXT: %[[SRS2:.*]] = "xllvm.intr.aie2.v16accfloat.to.v16bf16"(%[[MSC1]]) : (vector<8xi64>) -> vector<16xbf16> +// FP32LOW-NEXT: %[[UPD2:.*]] = "xllvm.intr.aie2.upd.bf512.bf256"(%[[ZEROS2]], %[[SRS2]], %[[CST2]]) : (vector<32xbf16>, vector<16xbf16>, i32) -> vector<32xbf16> +// FP32LOW-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[ARG1]] : vector<16xf32> to vector<8xi64> +// FP32LOW-NEXT: %[[SRS3:.*]] = "xllvm.intr.aie2.v16accfloat.to.v16bf16"(%[[BITCAST1]]) : (vector<8xi64>) -> vector<16xbf16> +// FP32LOW-NEXT: %[[UPD3:.*]] = "xllvm.intr.aie2.upd.bf512.bf256"(%[[ZEROS3]], %[[SRS3]], %[[CST2]]) : (vector<32xbf16>, vector<16xbf16>, i32) -> vector<32xbf16> +// FP32LOW-NEXT: %[[MSC2:.*]] = "xllvm.intr.aie2.bf.msc16.conf"(%[[UPD3]], %[[ONES0]], %[[BITCAST1]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, vector<8xi64>, i32) -> vector<8xi64> +// FP32LOW-NEXT: %[[SRS4:.*]] = "xllvm.intr.aie2.v16accfloat.to.v16bf16"(%[[MSC2]]) : (vector<8xi64>) -> vector<16xbf16> +// FP32LOW-NEXT: %[[UPD4:.*]] = "xllvm.intr.aie2.upd.bf512.bf256"(%[[ZEROS4]], %[[SRS4]], %[[CST2]]) : (vector<32xbf16>, vector<16xbf16>, i32) -> vector<32xbf16> +// FP32LOW-NEXT: %[[MSC3:.*]] = "xllvm.intr.aie2.bf.msc16.conf"(%[[UPD4]], %[[ONES0]], %[[MSC2]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, vector<8xi64>, i32) -> vector<8xi64> +// FP32LOW-NEXT: %[[SRS5:.*]] = "xllvm.intr.aie2.v16accfloat.to.v16bf16"(%[[MSC3]]) : (vector<8xi64>) -> vector<16xbf16> +// FP32LOW-NEXT: %[[UPD5:.*]] = "xllvm.intr.aie2.upd.bf512.bf256"(%[[ZEROS5]], %[[SRS5]], %[[CST2]]) : (vector<32xbf16>, vector<16xbf16>, i32) -> vector<32xbf16> +// FP32LOW-NEXT: %[[ACC0:.*]] = "xllvm.intr.aie2.bf.mul16.conf"(%[[UPD1]], %[[UPD3]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, i32) -> vector<8xi64> +// FP32LOW-NEXT: %[[ACC1:.*]] = "xllvm.intr.aie2.bf.mac16.conf"(%[[UPD0]], %[[UPD4]], %[[ACC0]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, vector<8xi64>, i32) -> vector<8xi64> +// FP32LOW-NEXT: %[[ACC2:.*]] = "xllvm.intr.aie2.bf.mac16.conf"(%[[UPD0]], %[[UPD3]], %[[ACC1]], %[[CONF]]) : (vector<32xbf16>, vector<32xbf16>, vector<8xi64>, i32) -> vector<8xi64> +// FP32LOW-NEXT: %[[RES:.*]] = llvm.bitcast %[[ACC2]] : vector<8xi64> to vector<16xf32> +// FP32LOW-NEXT: return %[[RES]] : vector<16xf32>