diff --git a/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td b/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td index 500c2b9f52..b5a24a6bc4 100644 --- a/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td +++ b/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td @@ -103,9 +103,16 @@ def VectorSetI512I256IntrOp : // ----- SRS ----- -def I512V32Acc32SrsIntrOp : - AIEVec2_IntrOp<"I512.v32.acc32.srs", - [TypeIs<"res", VectorOfLengthAndType<[32], [I16]>>]>, +def I256V16Acc32SrsIntrOp : + AIEVec2_IntrOp<"I256.v16.acc32.srs", + [TypeIs<"res", VectorOfLengthAndType<[16], [I16]>>]>, + Arguments<(ins VectorOfLengthAndType<[8], [I64]>:$src, + I32:$shift, + I32:$sign)>; + +def I256V16Acc64SrsIntrOp : + AIEVec2_IntrOp<"I256.v16.acc64.srs", + [TypeIs<"res", VectorOfLengthAndType<[16], [I16]>>]>, Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$src, I32:$shift, I32:$sign)>; @@ -117,6 +124,13 @@ def I256V32Acc32SrsIntrOp : I32:$shift, I32:$sign)>; +def I256V8Acc64SrsIntrOp : + AIEVec2_IntrOp<"I256.v8.acc64.srs", + [TypeIs<"res", VectorOfLengthAndType<[8], [I32]>>]>, + Arguments<(ins VectorOfLengthAndType<[8], [I64]>:$src, + I32:$shift, + I32:$sign)>; + def I512V16Acc64SrsIntrOp : AIEVec2_IntrOp<"I512.v16.acc64.srs", [TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>, @@ -124,6 +138,13 @@ def I512V16Acc64SrsIntrOp : I32:$shift, I32:$sign)>; +def I512V32Acc32SrsIntrOp : + AIEVec2_IntrOp<"I512.v32.acc32.srs", + [TypeIs<"res", VectorOfLengthAndType<[32], [I16]>>]>, + Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$src, + I32:$shift, + I32:$sign)>; + def Vector16AccFloatToV16BF16IntrOp : AIEVec2_IntrOp<"v16accfloat.to.v16bf16", [TypeIs<"res", VectorOfLengthAndType<[16], [BF16]>>]>, diff --git a/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp b/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp index 84d1ac28f5..7c82148b85 100644 --- a/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp +++ b/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp @@ -842,6 +842,7 @@ class SRSOpConversion : public mlir::ConvertOpToLLVMPattern { int resultVectorSize = resultBitWidth * resultLanes; // Integer types + Operation *srsIntrOp = nullptr; if (llvm::isa(resultScaTy)) { // create constant for sign auto signCst = rewriter.create( @@ -852,47 +853,91 @@ class SRSOpConversion : public mlir::ConvertOpToLLVMPattern { {adaptor.getSource(), adaptor.getShift(), signCst}); if (resultVectorSize == 512) { if (resultBitWidth == 16) { - rewriter.replaceOpWithNewOp( - op, VectorType::get({32}, rewriter.getI16Type()), + // v32acc32 -> v32int16 + srsIntrOp = rewriter.create( + loc, VectorType::get({32}, rewriter.getI16Type()), forceCastOperandsToSignature( rewriter, loc, operands, {VectorType::get({16}, rewriter.getI64Type()), rewriter.getI32Type(), rewriter.getI32Type()})); } else if (resultBitWidth == 32) { - rewriter.replaceOpWithNewOp( - op, VectorType::get({16}, rewriter.getI32Type()), + // v16acc64 -> v16int32 + srsIntrOp = rewriter.create( + loc, VectorType::get({16}, rewriter.getI32Type()), forceCastOperandsToSignature( rewriter, loc, operands, {VectorType::get({16}, rewriter.getI64Type()), rewriter.getI32Type(), rewriter.getI32Type()})); } } else if (resultVectorSize == 256) { - rewriter.replaceOpWithNewOp( - op, VectorType::get({32}, rewriter.getI8Type()), - forceCastOperandsToSignature( - rewriter, loc, operands, - {VectorType::get({16}, rewriter.getI64Type()), - rewriter.getI32Type(), rewriter.getI32Type()})); - } else { - op.emitWarning() << "aievec.srs with result vector size = " - << resultVectorSize << " is not supported.\n"; - return failure(); + Value src = adaptor.getSource(); + VectorType srcType = cast(src.getType()); + Type srcScaType = srcType.getElementType(); + unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth(); + + if (resultBitWidth == 16 && srcBitWidth == 32) { + // v16acc32 -> v16int16 + srsIntrOp = rewriter.create( + loc, VectorType::get({16}, rewriter.getI16Type()), + forceCastOperandsToSignature( + rewriter, loc, operands, + {VectorType::get({8}, rewriter.getI64Type()), + rewriter.getI32Type(), rewriter.getI32Type()})); + } else if (resultBitWidth == 8 && srcBitWidth == 32) { + // v32acc32 -> v32int8 + srsIntrOp = rewriter.create( + loc, VectorType::get({32}, rewriter.getI8Type()), + forceCastOperandsToSignature( + rewriter, loc, operands, + {VectorType::get({16}, rewriter.getI64Type()), + rewriter.getI32Type(), rewriter.getI32Type()})); + } else if (resultBitWidth == 16 && srcBitWidth == 64) { + // v16acc64 -> v16int16 + srsIntrOp = rewriter.create( + loc, VectorType::get({16}, rewriter.getI16Type()), + forceCastOperandsToSignature( + rewriter, loc, operands, + {VectorType::get({16}, rewriter.getI64Type()), + rewriter.getI32Type(), rewriter.getI32Type()})); + } else if (resultBitWidth == 32 && srcBitWidth == 64) { + // v8acc64 -> v8int32 + srsIntrOp = rewriter.create( + loc, VectorType::get({8}, rewriter.getI32Type()), + forceCastOperandsToSignature( + rewriter, loc, operands, + {VectorType::get({8}, rewriter.getI64Type()), + rewriter.getI32Type(), rewriter.getI32Type()})); + } } } else { // Float types if (resultVectorSize == 256) { - rewriter.replaceOpWithNewOp( - op, VectorType::get({16}, rewriter.getBF16Type()), + // v16accfloat -> v16bfloat16 + srsIntrOp = rewriter.create( + loc, VectorType::get({16}, rewriter.getBF16Type()), forceCastOperandsToSignature( rewriter, loc, {adaptor.getSource()}, {VectorType::get({8}, rewriter.getI64Type())})); - } else { - op.emitWarning() << "aievec.srs with result vector size = " - << resultVectorSize << " is not supported.\n"; - return failure(); + } else if (resultVectorSize == 512) { + // v32accfloat -> v32bfloat16 + // Implement this scenario in emulation. The CPP example is below: + // v32bfloat16 to_v32bfloat16(v32accfloat acc) { + // v16bfloat16 x0 = to_v16bfloat16(extract_v16accfloat(acc, 0)); + // v16bfloat16 x1 = to_v16bfloat16(extract_v16accfloat(acc, 1)); + // return concat(x0, x1); + // } + // TODO: implement this after adding 1024->512 vector extraction + // intrinsic } } + if (!srsIntrOp) { + op.emitWarning() << "aievec.srs is not supported.\n"; + return failure(); + } + + rewriter.replaceOp(op, srsIntrOp); + return success(); } }; @@ -1471,7 +1516,7 @@ class ExtractElemOpConversion } // create truncation op (and bitcast op) - if (resultType.isa()) { + if (llvm::isa(resultType)) { if (resultBitWidth < 32) { rewriter.replaceOpWithNewOp(op, resultType, extElemOp); } else { diff --git a/lib/Dialect/AIEVec/IR/AIEVecOps.cpp b/lib/Dialect/AIEVec/IR/AIEVecOps.cpp index b8db4437d9..b0055be060 100644 --- a/lib/Dialect/AIEVec/IR/AIEVecOps.cpp +++ b/lib/Dialect/AIEVec/IR/AIEVecOps.cpp @@ -541,7 +541,7 @@ LogicalResult BroadcastScalarOp::verify() { if (!resultType) return emitError("requires vector type"); - if (!sourceType.isa()) + if (!isa(sourceType)) return emitError("requires source type to be integer or float"); Type resultElemType = resultType.getElementType(); diff --git a/test/Conversion/AIEVecToLLVM/test-srs.mlir b/test/Conversion/AIEVecToLLVM/test-srs.mlir index 4a1b557870..b6eda22d57 100644 --- a/test/Conversion/AIEVecToLLVM/test-srs.mlir +++ b/test/Conversion/AIEVecToLLVM/test-srs.mlir @@ -1,6 +1,6 @@ // RUN: aie-opt %s -split-input-file --convert-aievec-to-llvm | FileCheck %s -func.func @v32i32_srs_v32i16(%arg0 : vector<32xi32>) { +func.func @v32i16_srs_v32i32(%arg0 : vector<32xi32>) { %c0 = arith.constant 0 : i32 %c5 = arith.constant 5 : i32 %0 = aievec.srs %arg0, %c0 : vector<32xi32>, i32, vector<32xi16> @@ -8,7 +8,7 @@ func.func @v32i32_srs_v32i16(%arg0 : vector<32xi32>) { return } -// CHECK-LABEL: @v32i32_srs_v32i16 +// CHECK-LABEL: @v32i16_srs_v32i32 // CHECK-SAME: %[[ARG0:.*]]: vector<32xi32> // CHECK-NEXT: %[[SHIFT0:.*]] = arith.constant 0 : i32 // CHECK-NEXT: %[[SHIFT5:.*]] = arith.constant 5 : i32 @@ -25,7 +25,55 @@ func.func @v32i32_srs_v32i16(%arg0 : vector<32xi32>) { // ----- -func.func @v32i32_srs_v32i8(%arg0 : vector<32xi32>) { +func.func @v16i32_srs_v16i64(%arg0 : vector<16xi64>) { + %c0 = arith.constant 0 : i32 + %c5 = arith.constant 5 : i32 + %0 = aievec.srs %arg0, %c0 : vector<16xi64>, i32, vector<16xi32> + %1 = aievec.srs %arg0, %c5 : vector<16xi64>, i32, vector<16xi32> + return +} + +// CHECK-LABEL: @v16i32_srs_v16i64 +// CHECK-SAME: %[[ARG0:.*]]: vector<16xi64> +// CHECK-NEXT: %[[SHIFT0:.*]] = arith.constant 0 : i32 +// CHECK-NEXT: %[[SHIFT5:.*]] = arith.constant 5 : i32 +// CHECK-NEXT: %[[SIGN0:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.I512.v16.acc64.srs"( +// CHECK-SAME: [[ARG0]], %[[SHIFT0]], %[[SIGN0]]) : +// CHECK-SAME: (vector<16xi64>, i32, i32) -> vector<16xi32> +// CHECK-NEXT: %[[SIGN1:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK-NEXT: %[[SRS1:.*]] = "xllvm.intr.aie2.I512.v16.acc64.srs"( +// CHECK-SAME: [[ARG0]], %[[SHIFT5]], %[[SIGN1]]) : +// CHECK-SAME: (vector<16xi64>, i32, i32) -> vector<16xi32> + +// ----- + +func.func @v16i16_srs_v16i32(%arg0 : vector<16xi32>) { + %c0 = arith.constant 0 : i32 + %c5 = arith.constant 5 : i32 + %0 = aievec.srs %arg0, %c0 : vector<16xi32>, i32, vector<16xi16> + %1 = aievec.srs %arg0, %c5 : vector<16xi32>, i32, vector<16xi16> + return +} + +// CHECK-LABEL: @v16i16_srs_v16i32 +// CHECK-SAME: %[[ARG0:.*]]: vector<16xi32> +// CHECK-NEXT: %[[SHIFT0:.*]] = arith.constant 0 : i32 +// CHECK-NEXT: %[[SHIFT5:.*]] = arith.constant 5 : i32 +// CHECK-NEXT: %[[SIGN0:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<16xi32> to vector<8xi64> +// CHECK-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.I256.v16.acc32.srs"( +// CHECK-SAME: [[BITCAST0]], %[[SHIFT0]], %[[SIGN0]]) : +// CHECK-SAME: (vector<8xi64>, i32, i32) -> vector<16xi16> +// CHECK-NEXT: %[[SIGN1:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<16xi32> to vector<8xi64> +// CHECK-NEXT: %[[SRS1:.*]] = "xllvm.intr.aie2.I256.v16.acc32.srs"( +// CHECK-SAME: [[BITCAST1]], %[[SHIFT5]], %[[SIGN1]]) : +// CHECK-SAME: (vector<8xi64>, i32, i32) -> vector<16xi16> + +// ----- + +func.func @v32i8_srs_v32i32(%arg0 : vector<32xi32>) { %c0 = arith.constant 0 : i32 %c5 = arith.constant 5 : i32 %0 = aievec.srs %arg0, %c0 : vector<32xi32>, i32, vector<32xi8> @@ -33,7 +81,7 @@ func.func @v32i32_srs_v32i8(%arg0 : vector<32xi32>) { return } -// CHECK-LABEL: @v32i32_srs_v32i8 +// CHECK-LABEL: @v32i8_srs_v32i32 // CHECK-SAME: %[[ARG0:.*]]: vector<32xi32> // CHECK-NEXT: %[[SHIFT0:.*]] = arith.constant 0 : i32 // CHECK-NEXT: %[[SHIFT5:.*]] = arith.constant 5 : i32 @@ -50,30 +98,53 @@ func.func @v32i32_srs_v32i8(%arg0 : vector<32xi32>) { // ----- -func.func @v16i64_srs_v16i32(%arg0 : vector<16xi64>) { +func.func @v16i16_srs_v16i64(%arg0 : vector<16xi64>) { %c0 = arith.constant 0 : i32 %c5 = arith.constant 5 : i32 - %0 = aievec.srs %arg0, %c0 : vector<16xi64>, i32, vector<16xi32> - %1 = aievec.srs %arg0, %c5 : vector<16xi64>, i32, vector<16xi32> + %0 = aievec.srs %arg0, %c0 : vector<16xi64>, i32, vector<16xi16> + %1 = aievec.srs %arg0, %c5 : vector<16xi64>, i32, vector<16xi16> return } -// CHECK-LABEL: @v16i64_srs_v16i32 +// CHECK-LABEL: @v16i16_srs_v16i64 // CHECK-SAME: %[[ARG0:.*]]: vector<16xi64> // CHECK-NEXT: %[[SHIFT0:.*]] = arith.constant 0 : i32 // CHECK-NEXT: %[[SHIFT5:.*]] = arith.constant 5 : i32 // CHECK-NEXT: %[[SIGN0:.*]] = llvm.mlir.constant(1 : i32) : i32 -// CHECK-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.I512.v16.acc64.srs"( +// CHECK-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.I256.v16.acc64.srs"( // CHECK-SAME: [[ARG0]], %[[SHIFT0]], %[[SIGN0]]) : -// CHECK-SAME: (vector<16xi64>, i32, i32) -> vector<16xi32> +// CHECK-SAME: (vector<16xi64>, i32, i32) -> vector<16xi16> // CHECK-NEXT: %[[SIGN1:.*]] = llvm.mlir.constant(1 : i32) : i32 -// CHECK-NEXT: %[[SRS1:.*]] = "xllvm.intr.aie2.I512.v16.acc64.srs"( +// CHECK-NEXT: %[[SRS1:.*]] = "xllvm.intr.aie2.I256.v16.acc64.srs"( // CHECK-SAME: [[ARG0]], %[[SHIFT5]], %[[SIGN1]]) : -// CHECK-SAME: (vector<16xi64>, i32, i32) -> vector<16xi32> +// CHECK-SAME: (vector<16xi64>, i32, i32) -> vector<16xi16> + +// ----- + +func.func @v8i32_srs_v8i64(%arg0 : vector<8xi64>) { + %c0 = arith.constant 0 : i32 + %c5 = arith.constant 5 : i32 + %0 = aievec.srs %arg0, %c0 : vector<8xi64>, i32, vector<8xi32> + %1 = aievec.srs %arg0, %c5 : vector<8xi64>, i32, vector<8xi32> + return +} + +// CHECK-LABEL: @v8i32_srs_v8i64 +// CHECK-SAME: %[[ARG0:.*]]: vector<8xi64> +// CHECK-NEXT: %[[SHIFT0:.*]] = arith.constant 0 : i32 +// CHECK-NEXT: %[[SHIFT5:.*]] = arith.constant 5 : i32 +// CHECK-NEXT: %[[SIGN0:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.I256.v8.acc64.srs"( +// CHECK-SAME: [[ARG0]], %[[SHIFT0]], %[[SIGN0]]) : +// CHECK-SAME: (vector<8xi64>, i32, i32) -> vector<8xi32> +// CHECK-NEXT: %[[SIGN1:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK-NEXT: %[[SRS1:.*]] = "xllvm.intr.aie2.I256.v8.acc64.srs"( +// CHECK-SAME: [[ARG0]], %[[SHIFT5]], %[[SIGN1]]) : +// CHECK-SAME: (vector<8xi64>, i32, i32) -> vector<8xi32> // ----- -func.func @v16f32_srs_v16bf16(%arg0 : vector<16xf32>) { +func.func @v16bf16_srs_v16f32(%arg0 : vector<16xf32>) { %c0 = arith.constant 0 : i32 %c5 = arith.constant 5 : i32 %0 = aievec.srs %arg0, %c0 : vector<16xf32>, i32, vector<16xbf16> @@ -81,7 +152,7 @@ func.func @v16f32_srs_v16bf16(%arg0 : vector<16xf32>) { return } -// CHECK-LABEL: @v16f32_srs_v16bf16 +// CHECK-LABEL: @v16bf16_srs_v16f32 // CHECK-SAME: %[[ARG0:.*]]: vector<16xf32> // CHECK-NEXT: %[[SHIFT0:.*]] = arith.constant 0 : i32 // CHECK-NEXT: %[[SHIFT5:.*]] = arith.constant 5 : i32 diff --git a/test/Target/LLVMIR/aievec.mlir b/test/Target/LLVMIR/aievec.mlir index 875c07d0ac..acf634f2ca 100644 --- a/test/Target/LLVMIR/aievec.mlir +++ b/test/Target/LLVMIR/aievec.mlir @@ -107,13 +107,22 @@ llvm.func @vector_set_256b_into_512b(%v : vector<8xi32>) -> vector<16xi32> { // ----- SRS ----- -// CHECK-LABEL: define <32 x i16> @srs_512b_v32_acc32 -llvm.func @srs_512b_v32_acc32(%v : vector<16xi64>, %shft : i32, %sign : i32) -> vector<32xi16> { - // CHECK: call <32 x i16> @llvm.aie2.I512.v32.acc32.srs( +// CHECK-LABEL: define <16 x i16> @srs_256b_v16_acc32 +llvm.func @srs_256b_v16_acc32(%v : vector<8xi64>, %shft : i32, %sign : i32) -> vector<16xi16> { + // CHECK: call <16 x i16> @llvm.aie2.I256.v16.acc32.srs( + // CHECK-SAME: <8 x i64> %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}) + %0 = "xllvm.intr.aie2.I256.v16.acc32.srs"(%v, %shft, %sign) : + (vector<8xi64>, i32, i32) -> vector<16xi16> + llvm.return %0 : vector<16xi16> +} + +// CHECK-LABEL: define <16 x i16> @srs_256b_v16_acc64 +llvm.func @srs_256b_v16_acc64(%v : vector<16xi64>, %shft : i32, %sign : i32) -> vector<16xi16> { + // CHECK: call <16 x i16> @llvm.aie2.I256.v16.acc64.srs( // CHECK-SAME: <16 x i64> %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}) - %0 = "xllvm.intr.aie2.I512.v32.acc32.srs"(%v, %shft, %sign) : - (vector<16xi64>, i32, i32) -> vector<32xi16> - llvm.return %0 : vector<32xi16> + %0 = "xllvm.intr.aie2.I256.v16.acc64.srs"(%v, %shft, %sign) : + (vector<16xi64>, i32, i32) -> vector<16xi16> + llvm.return %0 : vector<16xi16> } // CHECK-LABEL: define <32 x i8> @srs_256b_v32_acc32 @@ -125,6 +134,15 @@ llvm.func @srs_256b_v32_acc32(%v : vector<16xi64>, %shft : i32, %sign : i32) -> llvm.return %0 : vector<32xi8> } +// CHECK-LABEL: define <8 x i32> @srs_256b_v8_acc64 +llvm.func @srs_256b_v8_acc64(%v : vector<8xi64>, %shft : i32, %sign : i32) -> vector<8xi32> { + // CHECK: call <8 x i32> @llvm.aie2.I256.v8.acc64.srs( + // CHECK-SAME: <8 x i64> %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}) + %0 = "xllvm.intr.aie2.I256.v8.acc64.srs"(%v, %shft, %sign) : + (vector<8xi64>, i32, i32) -> vector<8xi32> + llvm.return %0 : vector<8xi32> +} + // CHECK-LABEL: define <16 x i32> @srs_512b_v16_acc64 llvm.func @srs_512b_v16_acc64(%v : vector<16xi64>, %shft : i32, %sign : i32) -> vector<16xi32> { // CHECK: call <16 x i32> @llvm.aie2.I512.v16.acc64.srs( @@ -134,6 +152,15 @@ llvm.func @srs_512b_v16_acc64(%v : vector<16xi64>, %shft : i32, %sign : i32) -> llvm.return %0 : vector<16xi32> } +// CHECK-LABEL: define <32 x i16> @srs_512b_v32_acc32 +llvm.func @srs_512b_v32_acc32(%v : vector<16xi64>, %shft : i32, %sign : i32) -> vector<32xi16> { + // CHECK: call <32 x i16> @llvm.aie2.I512.v32.acc32.srs( + // CHECK-SAME: <16 x i64> %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}) + %0 = "xllvm.intr.aie2.I512.v32.acc32.srs"(%v, %shft, %sign) : + (vector<16xi64>, i32, i32) -> vector<32xi16> + llvm.return %0 : vector<32xi16> +} + // CHECK-LABEL: define <16 x bfloat> @srs_256b_v16_accfloat llvm.func @srs_256b_v16_accfloat(%v : vector<8xi64>) -> vector<16xbf16> { // CHECK: call <16 x bfloat> @llvm.aie2.v16accfloat.to.v16bf16(