From ecfe1083bf200d03328c89a28fc5f9de3437e86f Mon Sep 17 00:00:00 2001
From: James Lin
Date: Thu, 11 Apr 2024 15:12:30 -0500
Subject: [PATCH] [aievec] to-llvm flow for emulated i32xi32 elementwise
 multiplication (#1176)

* Add some of the mul/srs/shuffle intrinsics to XLLVM.
* Add an aievec-to-llvm conversion pattern for the emulated i32xi32
  elementwise multiplication. The emulation shuffles each 32-bit operand
  into 16-bit halves and combines one mul.conf with three mac.conf ops
  into a 64-bit accumulator.
* Add aievec-to-llvm conversion tests for the newly added XLLVM ops.
* Add target LLVM translation tests.
* Add i32xi32_mul_elem e2e tests for the to-llvm flow. This includes
  updating testbench.cc and the test script. The e2e test, like the other
  aievec to-cpp tests, runs through the simulator to verify numeric
  correctness.
---
 .../aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td  |  23 +++
 lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp  | 199 ++++++++++++++++--
 test/Conversion/AIEVecToLLVM/mul_elem.mlir    |  36 ++++
 test/Conversion/AIEVecToLLVM/test-srs.mlir    |  23 +++
 test/Target/LLVMIR/aievec.mlir                |  33 +++
 .../i32xi32_mul_elem-peano.mlir               |  28 +++
 .../i32xi32_mul_elem/i32xi32_mul_elem.mlir    |   6 +-
 .../i32xi32_mul_elem/testbench.cc             |  16 ++
 8 files changed, 340 insertions(+), 24 deletions(-)
 create mode 100644 test/unit_tests/aievec_tests/i32xi32_mul_elem/i32xi32_mul_elem-peano.mlir

diff --git a/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td b/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td
index c00303c739..94e79ff5ef 100644
--- a/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td
+++ b/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td
@@ -64,6 +64,13 @@ def MulConfAcc32IntrOp :
                    VectorOfLengthAndType<[16], [I32]>:$rhs,
                    I32:$conf)>;
 
+def MulConfAcc64IntrOp :
+    AIEVec2_IntrOp<"I512.I512.acc64.mul.conf",
+        [TypeIs<"res", VectorOfLengthAndType<[16], [I64]>>]>,
+    Arguments<(ins VectorOfLengthAndType<[64], [I8]>:$lhs,
+                   VectorOfLengthAndType<[16], [I32]>:$rhs,
+                   I32:$conf)>;
+
 def MulConfBF16IntrOp :
     AIEVec2_IntrOp<"bf.mul16.conf",
         [TypeIs<"res", VectorOfLengthAndType<[8], [I64]>>]>,
@@ -100,6 +107,13 @@ def I256V32Acc32SrsIntrOp :
                    I32:$shft,
                    I32:$sign)>;
 
+def I512V16Acc64SrsIntrOp :
+    AIEVec2_IntrOp<"I512.v16.acc64.srs",
+        [TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>,
+    Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$lhs,
+                   I32:$shft,
+                   I32:$sign)>;
+
 def Vector16AccFloatToV16BF16IntrOp :
     AIEVec2_IntrOp<"v16accfloat.to.v16bf16",
         [TypeIs<"res", VectorOfLengthAndType<[16], [BF16]>>]>,
@@ -138,6 +152,15 @@ def ConcatI512I256IntrOp :
     Arguments<(ins VectorOfLengthAndType<[8], [I32]>:$a0,
                    VectorOfLengthAndType<[8], [I32]>:$a1)>;
 
+// ----- SHUFFLE -----
+
+def VectorShuffleIntrOp :
+    AIEVec2_IntrOp<"vshuffle",
+        [TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>,
+    Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$a,
+                   VectorOfLengthAndType<[16], [I32]>:$b,
+                   I32:$mode)>;
+
 // ----- UNDEF -----
 
 def UndefV16I32IntrOp :
diff --git a/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp b/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
index f4e50c03ac..f690b65296 100644
--- a/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
+++ b/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
@@ -364,23 +364,39 @@ class MulElemOpConversion
       // DtIn0_DtIn1_DtRes_CxMxKxN
       I8_I8_I32_32x1x2x1,
       I16_I16_I32_32x1x1x1,
+      I32_I32_I64_32x1x2x1,
       BF16_BF16_FP32_16x1x2x1,
       UNSUPPORTED
       // TODO: I16_I16_I64_16x1x2x1
-      // TODO: I32 and FP32 mul_elem are emulated
+      // TODO: FP32 mul_elem is emulated
     };
     Kind kind;
     int conf;
   };
 
+  // sgn_x: Sign mask of matrix X. If it is one, matrix X is interpreted as
+  //        signed; otherwise it is treated as unsigned.
+  // sgn_y: Sign mask of matrix Y. 
If it is one, matrix Y is interpreted as
+  //        signed; otherwise it is treated as unsigned.
+  // amode/bmode/variant: configure the acc width, mul precision, and mul mode.
+  // zero_acc: Zeroing of acc1. If it is one, acc1 is zeroed.
+  // shift16: Shift mask of acc1. If a bit is set, the <<16 operation will be
+  //        executed on acc1.
+  // sub_mul: Negation mask of the matrix multiplication result. If it is
+  //        one, the result of the operation will be negated.
+  // sub_acc1: Negation mask of acc1. If it is one, acc1 will be negated.
+  // sub_acc2: Negation mask of acc2. If it is one, acc2 will be negated.
+  // sub_mask: Negation mask of complex multiplications. Negates a term of a
+  //        complex multiplication.
   static int aiev2_mul_mac_compute_control(int sgn_x, int sgn_y, int amode,
                                            int bmode, int variant, int zero_acc,
-                                           int shift16, int sub0, int sub1,
-                                           int sub2, int sub_mask) {
+                                           int shift16, int sub_mul,
+                                           int sub_acc1, int sub_acc2,
+                                           int sub_mask) {
     return ((unsigned)sub_mask << 16) | ((unsigned)shift16 << 10) |
-           ((unsigned)sub0 << 11) | ((unsigned)sub1 << 12) |
-           ((unsigned)sub2 << 13) | ((unsigned)amode << 1) |
+           ((unsigned)sub_mul << 11) | ((unsigned)sub_acc1 << 12) |
+           ((unsigned)sub_acc2 << 13) | ((unsigned)amode << 1) |
            ((unsigned)bmode << 3) | ((unsigned)variant << 5) |
            (((unsigned)sgn_x << 9) | ((unsigned)sgn_y << 8)) |
            ((unsigned)zero_acc << 0);
@@ -396,22 +412,150 @@ class MulElemOpConversion
     if (lhsScaTy.isa<IntegerType>()) {
       if (lhsBitWidth == 8) {
         return {DecodedMulElemOp::Kind::I8_I8_I32_32x1x2x1,
-                aiev2_mul_mac_compute_control(1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0)};
+                aiev2_mul_mac_compute_control(
+                    /*sgn_x=*/1, /*sgn_y=*/1, /*amode=*/0, /*bmode=*/1,
+                    /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
+                    /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
+                    /*sub_mask=*/0)};
       } else if (lhsBitWidth == 16) {
         return {DecodedMulElemOp::Kind::I16_I16_I32_32x1x1x1,
-                aiev2_mul_mac_compute_control(1, 1, 0, 3, 1, 0, 0, 0, 0, 0, 0)};
+                aiev2_mul_mac_compute_control(
+                    /*sgn_x=*/1, /*sgn_y=*/1, /*amode=*/0, /*bmode=*/3,
+                    /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
+                    /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
+                    /*sub_mask=*/0)};
+      } else if (lhsBitWidth == 32) {
+        return {DecodedMulElemOp::Kind::I32_I32_I64_32x1x2x1, -1};
       }
     } else {
       // Float types
       if (lhsBitWidth == 16) {
         return {DecodedMulElemOp::Kind::BF16_BF16_FP32_16x1x2x1,
-                aiev2_mul_mac_compute_control(0, 0, 2, 3, 1, 0, 0, 0, 0, 0, 0)};
+                aiev2_mul_mac_compute_control(
+                    /*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/2, /*bmode=*/3,
+                    /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
+                    /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
+                    /*sub_mask=*/0)};
       }
     }
 
     return {DecodedMulElemOp::Kind::UNSUPPORTED, -1};
   }
 
+  // This conversion pattern implements the C++ I32 mul_elem emulation below.
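+  // Each 32-bit operand is split into an unsigned low half and a signed
+  // high half (a = (a_hi << 16) + a_lo), so the 64-bit product is
+  // reassembled as
+  //   a*b = (((a_hi*b_hi << 16) + a_hi*b_lo + a_lo*b_hi) << 16) + a_lo*b_lo,
+  // where each << 16 on the accumulator comes from the shift16 config bit.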
+  // INTRINSIC(v16acc64)
+  // mul_elem_16_2(v16int32 a0, v16int32 a1, v16int32 b0, v16int32 b1) {
+  //   v32uint16 a_lo = (v32uint16)shuffle(a0, a1, 2);
+  //   v32int16 a_hi = (v32int16)shuffle(a0, a1, 3);
+  //   v32uint16 b_lo = (v32uint16)shuffle(b0, b1, 2);
+  //   v32int16 b_hi = (v32int16)shuffle(b0, b1, 3);
+  //   v16acc64 acc = ::mul_elem_16_2(a_hi, b_hi);
+  //   acc = mac_elem_16_2_conf(a_hi, 1, b_lo, false, acc, 0, 1, 0, 0);
+  //   acc = mac_elem_16_2_conf(a_lo, false, b_hi, 1, acc, 0, 0, 0, 0);
+  //   acc = mac_elem_16_2_conf(a_lo, false, b_lo, false, acc, 0, 1, 0, 0);
+  //   return acc;
+  // }
+  // Caller of the above C++ intrinsic:
+  //   v16int32 v1 = LHS();
+  //   v16int32 v2 = RHS();
+  //   v16acc64 v3 = mul_elem_16_2(v1, broadcast_zero_s32(), v2,
+  //                               undef_v16int32());
+  LogicalResult
+  convertToI32MulElemEmulation(aievec::MulElemOp op, OpAdaptor adaptor,
+                               ConversionPatternRewriter &rewriter) const {
+
+    Location loc = op.getLoc();
+    auto zeroCst = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
+    auto a0 = adaptor.getLhs();
+    auto a1 = rewriter.create<xllvm::VectorBroadcast32I512IntrOp>(
+        loc, VectorType::get({16}, rewriter.getI32Type()), zeroCst);
+    auto b0 = adaptor.getRhs();
+    auto b1 = rewriter.create<xllvm::UndefV16I32IntrOp>(
+        loc, VectorType::get({16}, rewriter.getI32Type()));
+
+    // 4x Shuffle
+    auto a_lo = rewriter.create<xllvm::VectorShuffleIntrOp>(
+        loc, VectorType::get({16}, rewriter.getI32Type()), a0, a1,
+        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                          rewriter.getI32IntegerAttr(2)));
+    auto a_hi = rewriter.create<xllvm::VectorShuffleIntrOp>(
+        loc, VectorType::get({16}, rewriter.getI32Type()), a0, a1,
+        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                          rewriter.getI32IntegerAttr(3)));
+    auto b_lo = rewriter.create<xllvm::VectorShuffleIntrOp>(
+        loc, VectorType::get({16}, rewriter.getI32Type()), b0, b1,
+        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                          rewriter.getI32IntegerAttr(2)));
+    auto b_hi = rewriter.create<xllvm::VectorShuffleIntrOp>(
+        loc, VectorType::get({16}, rewriter.getI32Type()), b0, b1,
+        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                          rewriter.getI32IntegerAttr(3)));
+    // MUL + 3 * MAC
+    auto mulConfCst = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(),
+        rewriter.getI32IntegerAttr(aiev2_mul_mac_compute_control(
+            /*sgn_x=*/1, /*sgn_y=*/1, /*amode=*/1, /*bmode=*/3,
+            /*variant=*/2, /*zero_acc=*/0, /*shift16=*/0,
+            /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0)));
+    auto mulConfOp = rewriter.create<xllvm::MulConfAcc64IntrOp>(
+        loc, VectorType::get({16}, rewriter.getI64Type()),
+        forceCastOperandsToSignature(
+            rewriter, loc,
+            /*operands=*/{a_hi, b_hi, mulConfCst},
+            /*signature=*/
+            {VectorType::get({64}, rewriter.getI8Type()),
+             VectorType::get({16}, rewriter.getI32Type()),
+             rewriter.getI32Type()}));
+
+    auto createMacConfOp = [&](SmallVector<Value> operands,
+                               int macConf) -> Value {
+      operands.push_back(rewriter.create<LLVM::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(macConf)));
+      return rewriter
+          .create<xllvm::MacConfAcc64IntrOp>(
+              loc, VectorType::get({16}, rewriter.getI64Type()),
+              forceCastOperandsToSignature(
+                  rewriter, loc,
+                  /*operands=*/operands,
+                  /*signature=*/
+                  {VectorType::get({64}, rewriter.getI8Type()),
+                   VectorType::get({16}, rewriter.getI32Type()),
+                   VectorType::get({16}, rewriter.getI64Type()),
+                   rewriter.getI32Type()}))
+          .getResult();
+    };
+    auto acc64Val = mulConfOp.getResult();
+    acc64Val = createMacConfOp(
+        SmallVector<Value>{a_hi, b_lo, acc64Val},
+        aiev2_mul_mac_compute_control(
+            /*sgn_x=*/1, /*sgn_y=*/0, /*amode=*/1, /*bmode=*/3,
+            /*variant=*/2, /*zero_acc=*/0, /*shift16=*/1,
+            /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0));
+    acc64Val = createMacConfOp(
+        SmallVector<Value>{a_lo, b_hi, acc64Val},
+        
aiev2_mul_mac_compute_control(
+            /*sgn_x=*/0, /*sgn_y=*/1, /*amode=*/1, /*bmode=*/3,
+            /*variant=*/2, /*zero_acc=*/0, /*shift16=*/0,
+            /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0));
+    acc64Val = createMacConfOp(
+        SmallVector<Value>{a_lo, b_lo, acc64Val},
+        aiev2_mul_mac_compute_control(
+            /*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/1, /*bmode=*/3,
+            /*variant=*/2, /*zero_acc=*/0, /*shift16=*/1,
+            /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0));
+
+    // create bitcast for result
+    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
+                                                 acc64Val);
+    return success();
+  }
+
   LogicalResult
   matchAndRewrite(aievec::MulElemOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -423,6 +567,12 @@ class MulElemOpConversion
       return failure();
     }
 
+    // Handle the I32 mul_elem emulation
+    // TODO: handle the FP32 mul_elem emulation
+    if (decodedMulElemOp.kind == DecodedMulElemOp::Kind::I32_I32_I64_32x1x2x1) {
+      return convertToI32MulElemEmulation(op, adaptor, rewriter);
+    }
+
     // create constant for config
     auto confCst = rewriter.create<LLVM::ConstantOp>(
         loc, rewriter.getI32Type(),
@@ -496,12 +646,21 @@ class SRSOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::SRSOp> {
     SmallVector<Value> operands(
         {adaptor.getSource(), adaptor.getShift(), signCst});
     if (resultVectorSize == 512) {
-      rewriter.replaceOpWithNewOp<xllvm::I512V32Acc32SrsIntrOp>(
-          op, VectorType::get({32}, rewriter.getI16Type()),
-          forceCastOperandsToSignature(
-              rewriter, loc, operands,
-              {VectorType::get({16}, rewriter.getI64Type()),
-               rewriter.getI32Type(), rewriter.getI32Type()}));
+      if (resultBitWidth == 16) {
+        rewriter.replaceOpWithNewOp<xllvm::I512V32Acc32SrsIntrOp>(
+            op, VectorType::get({32}, rewriter.getI16Type()),
+            forceCastOperandsToSignature(
+                rewriter, loc, operands,
+                {VectorType::get({16}, rewriter.getI64Type()),
+                 rewriter.getI32Type(), rewriter.getI32Type()}));
+      } else if (resultBitWidth == 32) {
+        rewriter.replaceOpWithNewOp<xllvm::I512V16Acc64SrsIntrOp>(
+            op, VectorType::get({16}, rewriter.getI32Type()),
+            forceCastOperandsToSignature(
+                rewriter, loc, operands,
+                {VectorType::get({16}, rewriter.getI64Type()),
+                 rewriter.getI32Type(), rewriter.getI32Type()}));
+      }
     } else if (resultVectorSize == 256) {
       rewriter.replaceOpWithNewOp<xllvm::I256V32Acc32SrsIntrOp>(
           op, VectorType::get({32}, rewriter.getI8Type()),
@@ -1124,13 +1283,11 @@ class MatMulOpConversion
   }
 };
 
-/*
-  This pattern folds aievec.cast op. For AIE-ML, the accumulators are in 32/64
-  bits, and the vectors are in 4/8/16/32 bits. Hence, we don't have to
-  explicitly express the casting between accumulators and vectors at the LLVM
-  dialect level. The backend LLVM compiler will decide the correct accumulator
-  or vector registers given the ops and intrinsics.
-*/
+// This pattern folds aievec.cast op. For AIE-ML, the accumulators are in 32/64
+// bits, and the vectors are in 4/8/16/32 bits. Hence, we don't have to
+// explicitly express the casting between accumulators and vectors at the LLVM
+// dialect level. The backend LLVM compiler will decide the correct accumulator
+// or vector registers given the ops and intrinsics.
class FoldAIECastOps : public mlir::ConvertOpToLLVMPattern<aievec::CastOp> {
   using ConvertOpToLLVMPattern<aievec::CastOp>::ConvertOpToLLVMPattern;
diff --git a/test/Conversion/AIEVecToLLVM/mul_elem.mlir b/test/Conversion/AIEVecToLLVM/mul_elem.mlir
index 885b9d2ac6..4bbaab93c4 100644
--- a/test/Conversion/AIEVecToLLVM/mul_elem.mlir
+++ b/test/Conversion/AIEVecToLLVM/mul_elem.mlir
@@ -37,6 +37,42 @@ func.func @i8_i8_i32_mul_elem(%arg0 : vector<64xi8>, %arg1 : vector<64xi8>) -> v
 
 // -----
 
+func.func @i32_i32_i32_mul_elem(%arg0 : vector<16xi32>, %arg1 : vector<16xi32>) -> vector<16xi64> {
+  %0 = aievec.mul_elem %arg0, %arg1 : vector<16xi32>, vector<16xi32>, vector<16xi64>
+  return %0 : vector<16xi64>
+}
+
+// CHECK-LABEL: @i32_i32_i32_mul_elem
+// CHECK-SAME: %[[ARG0:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[ARG1:.*]]: vector<16xi32>
+// CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-NEXT: %[[VBROADCAST:.*]] = "xllvm.intr.aie2.vbroadcast32.I512"(%[[CST0]]) : (i32) -> vector<16xi32>
+// CHECK-NEXT: %[[UNDEF:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32>
+// CHECK-NEXT: %[[CST1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK-NEXT: %[[SHUFF0:.*]] = "xllvm.intr.aie2.vshuffle"(%[[ARG0]], %[[VBROADCAST]], %[[CST1]]) : (vector<16xi32>, vector<16xi32>, i32) -> vector<16xi32>
+// CHECK-NEXT: %[[CST2:.*]] = llvm.mlir.constant(3 : i32) : i32
+// CHECK-NEXT: %[[SHUFF1:.*]] = "xllvm.intr.aie2.vshuffle"(%[[ARG0]], %[[VBROADCAST]], %[[CST2]]) : (vector<16xi32>, vector<16xi32>, i32) -> vector<16xi32>
+// CHECK-NEXT: %[[CST3:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK-NEXT: %[[SHUFF2:.*]] = "xllvm.intr.aie2.vshuffle"(%[[ARG1]], %[[UNDEF]], %[[CST3]]) : (vector<16xi32>, vector<16xi32>, i32) -> vector<16xi32>
+// CHECK-NEXT: %[[CST4:.*]] = llvm.mlir.constant(3 : i32) : i32
+// CHECK-NEXT: %[[SHUFF3:.*]] = "xllvm.intr.aie2.vshuffle"(%[[ARG1]], %[[UNDEF]], %[[CST4]]) : (vector<16xi32>, vector<16xi32>, i32) -> vector<16xi32>
+// CHECK-NEXT: %[[CST5:.*]] = llvm.mlir.constant(858 : i32) : i32
+// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[SHUFF1]] : vector<16xi32> to vector<64xi8>
+// CHECK-NEXT: %[[ACC0:.*]] = "xllvm.intr.aie2.I512.I512.acc64.mul.conf"(%[[BITCAST0]], %[[SHUFF3]], %[[CST5]]) : (vector<64xi8>, vector<16xi32>, i32) -> vector<16xi64>
+// CHECK-NEXT: %[[CST6:.*]] = llvm.mlir.constant(1626 : i32) : i32
+// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[SHUFF1]] : vector<16xi32> to vector<64xi8>
+// CHECK-NEXT: %[[ACC1:.*]] = "xllvm.intr.aie2.I512.I512.ACC1024.acc64.mac.conf"(%[[BITCAST1]], %[[SHUFF2]], %[[ACC0]], %[[CST6]]) : (vector<64xi8>, vector<16xi32>, vector<16xi64>, i32) -> vector<16xi64>
+// CHECK-NEXT: %[[CST7:.*]] = llvm.mlir.constant(346 : i32) : i32
+// CHECK-NEXT: %[[BITCAST2:.*]] = llvm.bitcast %[[SHUFF0]] : vector<16xi32> to vector<64xi8>
+// CHECK-NEXT: %[[ACC2:.*]] = "xllvm.intr.aie2.I512.I512.ACC1024.acc64.mac.conf"(%[[BITCAST2]], %[[SHUFF3]], %[[ACC1]], %[[CST7]]) : (vector<64xi8>, vector<16xi32>, vector<16xi64>, i32) -> vector<16xi64>
+// CHECK-NEXT: %[[CST8:.*]] = llvm.mlir.constant(1114 : i32) : i32
+// CHECK-NEXT: %[[BITCAST3:.*]] = llvm.bitcast %[[SHUFF0]] : vector<16xi32> to vector<64xi8>
+// CHECK-NEXT: %[[ACC3:.*]] = "xllvm.intr.aie2.I512.I512.ACC1024.acc64.mac.conf"(%[[BITCAST3]], %[[SHUFF2]], %[[ACC2]], %[[CST8]]) : (vector<64xi8>, vector<16xi32>, vector<16xi64>, i32) -> vector<16xi64>
+// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[ACC3]] : vector<16xi64> to vector<16xi64>
+// CHECK-NEXT: return %[[RES]] : vector<16xi64>
+
+// -----
+
 func.func 
@bf16_bf16_f32_mul_elem(%arg0 : vector<32xbf16>, %arg1 : vector<32xbf16>) -> vector<16xf32> { %0 = aievec.mul_elem %arg0, %arg1 : vector<32xbf16>, vector<32xbf16>, vector<16xf32> return %0 : vector<16xf32> diff --git a/test/Conversion/AIEVecToLLVM/test-srs.mlir b/test/Conversion/AIEVecToLLVM/test-srs.mlir index 80f6f64b33..4a1b557870 100644 --- a/test/Conversion/AIEVecToLLVM/test-srs.mlir +++ b/test/Conversion/AIEVecToLLVM/test-srs.mlir @@ -50,6 +50,29 @@ func.func @v32i32_srs_v32i8(%arg0 : vector<32xi32>) { // ----- +func.func @v16i64_srs_v16i32(%arg0 : vector<16xi64>) { + %c0 = arith.constant 0 : i32 + %c5 = arith.constant 5 : i32 + %0 = aievec.srs %arg0, %c0 : vector<16xi64>, i32, vector<16xi32> + %1 = aievec.srs %arg0, %c5 : vector<16xi64>, i32, vector<16xi32> + return +} + +// CHECK-LABEL: @v16i64_srs_v16i32 +// CHECK-SAME: %[[ARG0:.*]]: vector<16xi64> +// CHECK-NEXT: %[[SHIFT0:.*]] = arith.constant 0 : i32 +// CHECK-NEXT: %[[SHIFT5:.*]] = arith.constant 5 : i32 +// CHECK-NEXT: %[[SIGN0:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.I512.v16.acc64.srs"( +// CHECK-SAME: [[ARG0]], %[[SHIFT0]], %[[SIGN0]]) : +// CHECK-SAME: (vector<16xi64>, i32, i32) -> vector<16xi32> +// CHECK-NEXT: %[[SIGN1:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK-NEXT: %[[SRS1:.*]] = "xllvm.intr.aie2.I512.v16.acc64.srs"( +// CHECK-SAME: [[ARG0]], %[[SHIFT5]], %[[SIGN1]]) : +// CHECK-SAME: (vector<16xi64>, i32, i32) -> vector<16xi32> + +// ----- + func.func @v16f32_srs_v16bf16(%arg0 : vector<16xf32>) { %c0 = arith.constant 0 : i32 %c5 = arith.constant 5 : i32 diff --git a/test/Target/LLVMIR/aievec.mlir b/test/Target/LLVMIR/aievec.mlir index 537fb5203f..45816ec69f 100644 --- a/test/Target/LLVMIR/aievec.mlir +++ b/test/Target/LLVMIR/aievec.mlir @@ -45,6 +45,19 @@ llvm.func @mul_conf_acc32(%A : vector<64xi8>, llvm.return %0 : vector<16xi64> } +// CHECK-LABEL: define <16 x i64> @mul_conf_acc64 +llvm.func @mul_conf_acc64(%A : vector<64xi8>, + %B : vector<16xi32>, + %cfg : i32) + -> vector<16xi64> { + // CHECK: call <16 x i64> @llvm.aie2.I512.I512.acc64.mul.conf( + // CHECK-SAME: <64 x i8> %{{[0-9]+}}, <16 x i32> %{{[0-9]+}}, + // CHECK-SAME: i32 %{{[0-9]+}}) + %0 = "xllvm.intr.aie2.I512.I512.acc64.mul.conf"(%A, %B, %cfg) : + (vector<64xi8>, vector<16xi32>, i32) -> vector<16xi64> + llvm.return %0 : vector<16xi64> +} + // CHECK-LABEL: define <8 x i64> @mul_conf_bf16 llvm.func @mul_conf_bf16(%A : vector<32xbf16>, %B : vector<32xbf16>, @@ -96,6 +109,15 @@ llvm.func @srs_256b_v32_acc32(%v : vector<16xi64>, %shft : i32, %sign : i32) -> llvm.return %0 : vector<32xi8> } +// CHECK-LABEL: define <16 x i32> @srs_512b_v16_acc64 +llvm.func @srs_512b_v16_acc64(%v : vector<16xi64>, %shft : i32, %sign : i32) -> vector<16xi32> { + // CHECK: call <16 x i32> @llvm.aie2.I512.v16.acc64.srs( + // CHECK-SAME: <16 x i64> %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}) + %0 = "xllvm.intr.aie2.I512.v16.acc64.srs"(%v, %shft, %sign) : + (vector<16xi64>, i32, i32) -> vector<16xi32> + llvm.return %0 : vector<16xi32> +} + // CHECK-LABEL: define <16 x bfloat> @srs_256b_v16_accfloat llvm.func @srs_256b_v16_accfloat(%v : vector<8xi64>) -> vector<16xbf16> { // CHECK: call <16 x bfloat> @llvm.aie2.v16accfloat.to.v16bf16( @@ -152,6 +174,17 @@ llvm.func @concat_i512_i256(%a : vector<8xi32>, %b : vector<8xi32>) -> vector<16 llvm.return %0 : vector<16xi32> } +// ----- SHUFFLE ----- + +// CHECK-LABEL: define <16 x i32> @shuffle_i512 +llvm.func @shuffle_i512(%a : vector<16xi32>, %b : vector<16xi32>, %mode : i32) -> 
vector<16xi32> { + // CHECK: call <16 x i32> @llvm.aie2.vshuffle( + // CHECK-SAME: <16 x i32> %{{[0-9]+}}, <16 x i32> %{{[0-9]+}}, i32 %{{[0-9]+}}) + %0 = "xllvm.intr.aie2.vshuffle"(%a, %b, %mode) : + (vector<16xi32>, vector<16xi32>, i32) -> vector<16xi32> + llvm.return %0 : vector<16xi32> +} + // ----- UNDEF ----- // CHECK-LABEL: define <16 x i32> @undef_v16i32 diff --git a/test/unit_tests/aievec_tests/i32xi32_mul_elem/i32xi32_mul_elem-peano.mlir b/test/unit_tests/aievec_tests/i32xi32_mul_elem/i32xi32_mul_elem-peano.mlir new file mode 100644 index 0000000000..8a32609a10 --- /dev/null +++ b/test/unit_tests/aievec_tests/i32xi32_mul_elem/i32xi32_mul_elem-peano.mlir @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Copyright (C) 2023, Advanced Micro Devices, Inc. + +// REQUIRES: valid_xchess_license +// REQUIRES: peano +// RUN: mkdir -p %t/data; cd %t +// RUN: aie-opt %s -affine-super-vectorize="virtual-vector-size=16" %vector-to-llvmir% -o llvmir.mlir +// RUN: aie-translate llvmir.mlir %llvmir-to-ll% -o dut.ll +// RUN: %PEANO_INSTALL_DIR/bin/clang %clang_aie2_args -c dut.ll -o dut.o +// RUN: xchesscc_wrapper %xchesscc_aie2_args -DTO_LLVM +w work +o work -I%S -I. %S/testbench.cc dut.o +// RUN: xca_udm_dbg --aiearch aie-ml -qf -T -P %aietools/data/aie_ml/lib/ -t "%S/../profiling.tcl ./work/a.out" >& xca_udm_dbg.stdout +// RUN: FileCheck --input-file=./xca_udm_dbg.stdout %s +// CHECK: TEST PASSED + +module { + func.func @dut(%arg0: memref<1024xi32>, %arg1: memref<1024xi32>, %arg2: memref<1024xi32>) { + memref.assume_alignment %arg0, 32 : memref<1024xi32> + memref.assume_alignment %arg1, 32 : memref<1024xi32> + memref.assume_alignment %arg2, 32 : memref<1024xi32> + affine.for %arg3 = 0 to 1024 { + %0 = affine.load %arg0[%arg3] : memref<1024xi32> + %1 = affine.load %arg1[%arg3] : memref<1024xi32> + %2 = arith.muli %0, %1 : i32 + affine.store %2, %arg2[%arg3] : memref<1024xi32> + } + return + } +} diff --git a/test/unit_tests/aievec_tests/i32xi32_mul_elem/i32xi32_mul_elem.mlir b/test/unit_tests/aievec_tests/i32xi32_mul_elem/i32xi32_mul_elem.mlir index 8176fec826..b90ab5bbaa 100644 --- a/test/unit_tests/aievec_tests/i32xi32_mul_elem/i32xi32_mul_elem.mlir +++ b/test/unit_tests/aievec_tests/i32xi32_mul_elem/i32xi32_mul_elem.mlir @@ -2,10 +2,10 @@ // Copyright (C) 2023, Advanced Micro Devices, Inc. // REQUIRES: valid_xchess_license +// RUN: mkdir -p %t/data; cd %t // RUN: aie-opt %s -affine-super-vectorize="virtual-vector-size=16" --convert-vector-to-aievec="aie-target=aieml" -lower-affine | aie-translate -aieml=true --aievec-to-cpp -o dut.cc -// RUN: xchesscc_wrapper aie2 -f -g +s +w work +o work -I%S -I. -c dut.cc -o dut.o -// RUN: xchesscc_wrapper aie2 -f -g +s +w work +o work -I%S -I. %S/testbench.cc work/dut.o -// RUN: mkdir -p data +// RUN: xchesscc_wrapper %xchesscc_aie2_args +w work +o work -I%S -I. -c dut.cc -o dut.o +// RUN: xchesscc_wrapper %xchesscc_aie2_args -DTO_CPP +w work +o work -I%S -I. 
%S/testbench.cc work/dut.o
// RUN: xca_udm_dbg --aiearch aie-ml -qf -T -P %aietools/data/aie_ml/lib/ -t "%S/../profiling.tcl ./work/a.out" >& xca_udm_dbg.stdout
// RUN: FileCheck --input-file=./xca_udm_dbg.stdout %s
// CHECK: TEST PASSED
diff --git a/test/unit_tests/aievec_tests/i32xi32_mul_elem/testbench.cc b/test/unit_tests/aievec_tests/i32xi32_mul_elem/testbench.cc
index bd04a945a3..5f437db045 100644
--- a/test/unit_tests/aievec_tests/i32xi32_mul_elem/testbench.cc
+++ b/test/unit_tests/aievec_tests/i32xi32_mul_elem/testbench.cc
@@ -4,7 +4,19 @@
 #include <cstdint>
 #include <cstdio>
 #include <cstdlib>
+
+#ifdef TO_CPP
 void dut(int32_t *restrict in0, int32_t *restrict in1, int32_t *restrict out0);
+#elif TO_LLVM
+extern "C" {
+void dut(int32_t *in0_allocated, int32_t *in0_aligned, int64_t in0_offset,
+         int64_t in0_sizes_0, int64_t in0_strides_0, int32_t *in1_allocated,
+         int32_t *in1_aligned, int64_t in1_offset, int64_t in1_sizes_0,
+         int64_t in1_strides_0, int32_t *out0_allocated, int32_t *out0_aligned,
+         int64_t out0_offset, int64_t out0_sizes_0, int64_t out0_strides_0);
+}
+#endif
+
 void dut_ref(int32_t *in0, int32_t *in1, int32_t *out0);
 
 alignas(32) int32_t g_in0[IN0_SIZE];
@@ -26,7 +38,11 @@ int main(int argc, char *argv[]) {
   chess_memory_fence();
 
   auto cyclesBegin = chess_cycle_count();
+#ifdef TO_CPP
   dut(g_in0, g_in1, g_out0);
+#elif TO_LLVM
+  dut(g_in0, g_in0, 0, 0, 0, g_in1, g_in1, 0, 0, 0, g_out0, g_out0, 0, 0, 0);
+#endif
   auto cyclesEnd = chess_cycle_count();
   chess_memory_fence();