From 3dc49bf13729764781ae9aeaddcdd3451e3a822d Mon Sep 17 00:00:00 2001 From: James Lin Date: Thu, 16 May 2024 18:52:48 -0500 Subject: [PATCH] [aievec] to-llvm flow for aievec.ext op with I128.I512 extract intrinsic (#1490) --- include/aie/Dialect/AIEVec/IR/AIEVecOps.td | 11 +- .../aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td | 3 +- lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp | 32 ++++- test/Conversion/AIEVecToLLVM/test-ext.mlir | 132 ++++++++++++++++++ test/Target/LLVMIR/aievec.mlir | 8 +- 5 files changed, 174 insertions(+), 12 deletions(-) diff --git a/include/aie/Dialect/AIEVec/IR/AIEVecOps.td b/include/aie/Dialect/AIEVec/IR/AIEVecOps.td index 79cb1c5e00..3861dc32d6 100644 --- a/include/aie/Dialect/AIEVec/IR/AIEVecOps.td +++ b/include/aie/Dialect/AIEVec/IR/AIEVecOps.td @@ -475,7 +475,11 @@ def AIEVec_ExtOp: let description = [{ AMD-specific vector extract intrinsic. Selects contiguous lanes from the source vector, and transfers the data from those lanes to the - result. The lane selection is controlled by index. + result. The lane selection is controlled by index. There are two cases: + 1. Extracted vector fills half of the original vector lanes (e.g. extract v64int8 from v128int8) + 2. Extracted vector fills a fourth of the original vector lanes (e.g. extract v32int8 from v128int8) + In the first case, index can be 0 or 1. Index 0 extracts the lower half, and index 1 extracts the upper half. + In the second case, index can be 0 to 3. Index 0 extracts the lowest quarter, index 1 the next quarter, and so on. `$result = ext($source, $index)` }]; } @@ -576,8 +580,9 @@ def AIEVec_ShiftOp: let description = [{ AMD-specific shift intrinsic. Concatenates two vectors into a bigger vector, interprets them as a vector of 128 bytes - and returns v1::v2[shift: shift+64]. The verifier confirms that all the - input vectors have the same number of lanes. + and returns v1::v2[shift: shift+64]. `shift` is the number of bytes to + be shifted. 
The verifier confirms that all the input and result vectors + have the same number of lanes and element types. `$result = shift($lhs, $rhs, $shift)` }]; } diff --git a/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td b/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td index 1d1a70bafd..37e8a12ae9 100644 --- a/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td +++ b/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td @@ -200,8 +200,7 @@ def ExtI256I1024IntrOp : def ExtI128I512IntrOp : AIEVec2_IntrOp<"extract.I128.I512", [TypeIs<"res", VectorOfLengthAndType<[4], [I32]>>]>, - Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$src, - I32:$idx)>; + Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$src)>; // ----- CONCAT ----- diff --git a/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp b/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp index e3899d8ad0..9c1c3d5bd5 100644 --- a/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp +++ b/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp @@ -1300,9 +1300,35 @@ class ExtOpConversion : public mlir::ConvertOpToLLVMPattern { rewriter, loc, operands, {VectorType::get({32}, rewriter.getI32Type()), rewriter.getI32Type()})); - // TODO: handle below case - // } else if (resultVectorSize == 128 && srcVectorSize == 512) { - // Special case + } else if (resultVectorSize == 128 && srcVectorSize == 512) { + auto shiftOp = adaptor.getSource(); + if (op.getIndex() > 0) { + auto undefOp = rewriter.create( + loc, VectorType::get({16}, rewriter.getI32Type())); + auto stepCst = rewriter.create( + loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); + auto shiftCst = rewriter.create( + loc, rewriter.getI32Type(), + rewriter.getI32IntegerAttr(op.getIndex() * 16)); + SmallVector shiftOperands{adaptor.getSource(), undefOp, stepCst, + shiftCst}; + // Right shift the source vector by index * 16 bytes (i.e. by index * + // 128 bits). The integer index is expected to be 0 to 3. 
+ shiftOp = rewriter.create( + loc, VectorType::get({16}, rewriter.getI32Type()), + forceCastOperandsToSignature( + rewriter, loc, shiftOperands, + {VectorType::get({16}, rewriter.getI32Type()), + VectorType::get({16}, rewriter.getI32Type()), + rewriter.getI32Type(), rewriter.getI32Type()})); + } + // The underlying intrinsic takes a source vector and extracts the lowest + // 128 bits, i.e. it always extracts the input vector with index = 0. + extOp = rewriter.create( + loc, VectorType::get({4}, rewriter.getI32Type()), + forceCastOperandsToSignature( + rewriter, loc, /*operands=*/{shiftOp}, + {VectorType::get({16}, rewriter.getI32Type())})); } else { op.emitWarning() << "aievec.ext with " << srcVectorSize << "-bit source, and " << resultVectorSize diff --git a/test/Conversion/AIEVecToLLVM/test-ext.mlir b/test/Conversion/AIEVecToLLVM/test-ext.mlir index a4f3b93c73..87d33af53c 100644 --- a/test/Conversion/AIEVecToLLVM/test-ext.mlir +++ b/test/Conversion/AIEVecToLLVM/test-ext.mlir @@ -69,6 +69,33 @@ func.func @v32i8_ext_v128i8(%arg0 : vector<128xi8>) -> (vector<32xi8>, vector<32 // ----- +func.func @v16i8_ext_v64i8(%arg0 : vector<64xi8>) -> (vector<16xi8>, vector<16xi8>) { + %0 = aievec.ext %arg0 {index = 0 : i8} : vector<64xi8>, vector<16xi8> + %1 = aievec.ext %arg0 {index = 3 : i8} : vector<64xi8>, vector<16xi8> + return %0, %1 : vector<16xi8>, vector<16xi8> +} + +// CHECK-LABEL: @v16i8_ext_v64i8 +// CHECK-SAME: %[[ARG0:.*]]: vector<64xi8> +// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<64xi8> to vector<16xi32> +// CHECK-NEXT: %[[EXT0:.*]] = "xllvm.intr.aie2.extract.I128.I512"( +// CHECK-SAME: %[[BITCAST0]]) : +// CHECK-SAME: (vector<16xi32>) -> vector<4xi32> +// CHECK-NEXT: %[[RES0:.*]] = llvm.bitcast %[[EXT0]] : vector<4xi32> to vector<16xi8> +// CHECK-NEXT: %[[UNDEF:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32> +// CHECK-NEXT: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK-NEXT: %[[CST48:.*]] = llvm.mlir.constant(48 : 
i32) : i32 +// CHECK-NEXT: %[[VSHIFT1:.*]] = "xllvm.intr.aie2.vshift.I512.I512"( +// CHECK-SAME: %[[BITCAST0]], %[[UNDEF]], %[[CST0]], %[[CST48]]) : +// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32> +// CHECK-NEXT: %[[EXT1:.*]] = "xllvm.intr.aie2.extract.I128.I512"( +// CHECK-SAME: %[[VSHIFT1]]) : +// CHECK-SAME: (vector<16xi32>) -> vector<4xi32> +// CHECK-NEXT: %[[RES1:.*]] = llvm.bitcast %[[EXT1]] : vector<4xi32> to vector<16xi8> +// CHECK-NEXT: return %[[RES0]], %[[RES1]] : vector<16xi8>, vector<16xi8> + +// ----- + func.func @v16i16_ext_v32i16(%arg0 : vector<32xi16>) -> (vector<16xi16>, vector<16xi16>) { %0 = aievec.ext %arg0 {index = 0 : i8} : vector<32xi16>, vector<16xi16> %1 = aievec.ext %arg0 {index = 1 : i8} : vector<32xi16>, vector<16xi16> @@ -138,6 +165,33 @@ func.func @v16i16_ext_v64i16(%arg0 : vector<64xi16>) -> (vector<16xi16>, vector< // ----- +func.func @v8i16_ext_v32i16(%arg0 : vector<32xi16>) -> (vector<8xi16>, vector<8xi16>) { + %0 = aievec.ext %arg0 {index = 0 : i8} : vector<32xi16>, vector<8xi16> + %1 = aievec.ext %arg0 {index = 3 : i8} : vector<32xi16>, vector<8xi16> + return %0, %1 : vector<8xi16>, vector<8xi16> +} + +// CHECK-LABEL: @v8i16_ext_v32i16 +// CHECK-SAME: %[[ARG0:.*]]: vector<32xi16> +// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<32xi16> to vector<16xi32> +// CHECK-NEXT: %[[EXT0:.*]] = "xllvm.intr.aie2.extract.I128.I512"( +// CHECK-SAME: %[[BITCAST0]]) : +// CHECK-SAME: (vector<16xi32>) -> vector<4xi32> +// CHECK-NEXT: %[[RES0:.*]] = llvm.bitcast %[[EXT0]] : vector<4xi32> to vector<8xi16> +// CHECK-NEXT: %[[UNDEF:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32> +// CHECK-NEXT: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK-NEXT: %[[CST48:.*]] = llvm.mlir.constant(48 : i32) : i32 +// CHECK-NEXT: %[[VSHIFT1:.*]] = "xllvm.intr.aie2.vshift.I512.I512"( +// CHECK-SAME: %[[BITCAST0]], %[[UNDEF]], %[[CST0]], %[[CST48]]) : +// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, 
i32) -> vector<16xi32> +// CHECK-NEXT: %[[EXT1:.*]] = "xllvm.intr.aie2.extract.I128.I512"( +// CHECK-SAME: %[[VSHIFT1]]) : +// CHECK-SAME: (vector<16xi32>) -> vector<4xi32> +// CHECK-NEXT: %[[RES1:.*]] = llvm.bitcast %[[EXT1]] : vector<4xi32> to vector<8xi16> +// CHECK-NEXT: return %[[RES0]], %[[RES1]] : vector<8xi16>, vector<8xi16> + +// ----- + func.func @v8i32_ext_v16i32(%arg0 : vector<16xi32>) -> (vector<8xi32>, vector<8xi32>) { %0 = aievec.ext %arg0 {index = 0 : i8} : vector<16xi32>, vector<8xi32> %1 = aievec.ext %arg0 {index = 1 : i8} : vector<16xi32>, vector<8xi32> @@ -198,6 +252,30 @@ func.func @v8i32_ext_v32i32(%arg0 : vector<32xi32>) -> (vector<8xi32>, vector<8x // ----- +func.func @v4i32_ext_v16i32(%arg0 : vector<16xi32>) -> (vector<4xi32>, vector<4xi32>) { + %0 = aievec.ext %arg0 {index = 0 : i8} : vector<16xi32>, vector<4xi32> + %1 = aievec.ext %arg0 {index = 3 : i8} : vector<16xi32>, vector<4xi32> + return %0, %1 : vector<4xi32>, vector<4xi32> +} + +// CHECK-LABEL: @v4i32_ext_v16i32 +// CHECK-SAME: %[[ARG0:.*]]: vector<16xi32> +// CHECK: %[[EXT0:.*]] = "xllvm.intr.aie2.extract.I128.I512"( +// CHECK-SAME: %[[ARG0]]) : +// CHECK-SAME: (vector<16xi32>) -> vector<4xi32> +// CHECK-NEXT: %[[UNDEF:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32> +// CHECK-NEXT: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK-NEXT: %[[CST48:.*]] = llvm.mlir.constant(48 : i32) : i32 +// CHECK-NEXT: %[[VSHIFT1:.*]] = "xllvm.intr.aie2.vshift.I512.I512"( +// CHECK-SAME: %[[ARG0]], %[[UNDEF]], %[[CST0]], %[[CST48]]) : +// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32> +// CHECK-NEXT: %[[EXT1:.*]] = "xllvm.intr.aie2.extract.I128.I512"( +// CHECK-SAME: %[[VSHIFT1]]) : +// CHECK-SAME: (vector<16xi32>) -> vector<4xi32> +// CHECK-NEXT: return %[[EXT0]], %[[EXT1]] : vector<4xi32>, vector<4xi32> + +// ----- + func.func @v16bf16_ext_v32bf16(%arg0 : vector<32xbf16>) -> (vector<16xbf16>, vector<16xbf16>) { %0 = aievec.ext %arg0 {index = 0 : i8} 
: vector<32xbf16>, vector<16xbf16> %1 = aievec.ext %arg0 {index = 1 : i8} : vector<32xbf16>, vector<16xbf16> @@ -267,6 +345,33 @@ func.func @v16bf16_ext_v64bf16(%arg0 : vector<64xbf16>) -> (vector<16xbf16>, vec // ----- +func.func @v8bf16_ext_v32bf16(%arg0 : vector<32xbf16>) -> (vector<8xbf16>, vector<8xbf16>) { + %0 = aievec.ext %arg0 {index = 0 : i8} : vector<32xbf16>, vector<8xbf16> + %1 = aievec.ext %arg0 {index = 3 : i8} : vector<32xbf16>, vector<8xbf16> + return %0, %1 : vector<8xbf16>, vector<8xbf16> +} + +// CHECK-LABEL: @v8bf16_ext_v32bf16 +// CHECK-SAME: %[[ARG0:.*]]: vector<32xbf16> +// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<32xbf16> to vector<16xi32> +// CHECK-NEXT: %[[EXT0:.*]] = "xllvm.intr.aie2.extract.I128.I512"( +// CHECK-SAME: %[[BITCAST0]]) : +// CHECK-SAME: (vector<16xi32>) -> vector<4xi32> +// CHECK-NEXT: %[[RES0:.*]] = llvm.bitcast %[[EXT0]] : vector<4xi32> to vector<8xbf16> +// CHECK-NEXT: %[[UNDEF:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32> +// CHECK-NEXT: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK-NEXT: %[[CST48:.*]] = llvm.mlir.constant(48 : i32) : i32 +// CHECK-NEXT: %[[VSHIFT1:.*]] = "xllvm.intr.aie2.vshift.I512.I512"( +// CHECK-SAME: %[[BITCAST0]], %[[UNDEF]], %[[CST0]], %[[CST48]]) : +// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32> +// CHECK-NEXT: %[[EXT1:.*]] = "xllvm.intr.aie2.extract.I128.I512"( +// CHECK-SAME: %[[VSHIFT1]]) : +// CHECK-SAME: (vector<16xi32>) -> vector<4xi32> +// CHECK-NEXT: %[[RES1:.*]] = llvm.bitcast %[[EXT1]] : vector<4xi32> to vector<8xbf16> +// CHECK-NEXT: return %[[RES0]], %[[RES1]] : vector<8xbf16>, vector<8xbf16> + +// ----- + func.func @v8f32_ext_v16f32(%arg0 : vector<16xf32>) -> (vector<8xf32>, vector<8xf32>) { %0 = aievec.ext %arg0 {index = 0 : i8} : vector<16xf32>, vector<8xf32> %1 = aievec.ext %arg0 {index = 1 : i8} : vector<16xf32>, vector<8xf32> @@ -333,3 +438,30 @@ func.func @v8f32_ext_v32f32(%arg0 : vector<32xf32>) -> 
(vector<8xf32>, vector<8x // CHECK-SAME: (vector<32xi32>, i32) -> vector<8xi32> // CHECK-NEXT: %[[RES1:.*]] = llvm.bitcast %[[EXT1]] : vector<8xi32> to vector<8xf32> // CHECK-NEXT: return %[[RES0]], %[[RES1]] : vector<8xf32>, vector<8xf32> + +// ----- + +func.func @v4f32_ext_v16f32(%arg0 : vector<16xf32>) -> (vector<4xf32>, vector<4xf32>) { + %0 = aievec.ext %arg0 {index = 0 : i8} : vector<16xf32>, vector<4xf32> + %1 = aievec.ext %arg0 {index = 3 : i8} : vector<16xf32>, vector<4xf32> + return %0, %1 : vector<4xf32>, vector<4xf32> +} + +// CHECK-LABEL: @v4f32_ext_v16f32 +// CHECK-SAME: %[[ARG0:.*]]: vector<16xf32> +// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<16xf32> to vector<16xi32> +// CHECK-NEXT: %[[EXT0:.*]] = "xllvm.intr.aie2.extract.I128.I512"( +// CHECK-SAME: %[[BITCAST0]]) : +// CHECK-SAME: (vector<16xi32>) -> vector<4xi32> +// CHECK-NEXT: %[[RES0:.*]] = llvm.bitcast %[[EXT0]] : vector<4xi32> to vector<4xf32> +// CHECK-NEXT: %[[UNDEF:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32> +// CHECK-NEXT: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK-NEXT: %[[CST48:.*]] = llvm.mlir.constant(48 : i32) : i32 +// CHECK-NEXT: %[[VSHIFT1:.*]] = "xllvm.intr.aie2.vshift.I512.I512"( +// CHECK-SAME: %[[BITCAST0]], %[[UNDEF]], %[[CST0]], %[[CST48]]) : +// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32> +// CHECK-NEXT: %[[EXT1:.*]] = "xllvm.intr.aie2.extract.I128.I512"( +// CHECK-SAME: %[[VSHIFT1]]) : +// CHECK-SAME: (vector<16xi32>) -> vector<4xi32> +// CHECK-NEXT: %[[RES1:.*]] = llvm.bitcast %[[EXT1]] : vector<4xi32> to vector<4xf32> +// CHECK-NEXT: return %[[RES0]], %[[RES1]] : vector<4xf32>, vector<4xf32> diff --git a/test/Target/LLVMIR/aievec.mlir b/test/Target/LLVMIR/aievec.mlir index c0d386a715..fbea27805b 100644 --- a/test/Target/LLVMIR/aievec.mlir +++ b/test/Target/LLVMIR/aievec.mlir @@ -241,11 +241,11 @@ llvm.func @ext_i256_i1024(%v : vector<32xi32>, %idx : i32) -> vector<8xi32> { } // CHECK-LABEL: 
define <4 x i32> @ext_i128_i512 -llvm.func @ext_i128_i512(%v : vector<16xi32>, %idx : i32) -> vector<4xi32> { +llvm.func @ext_i128_i512(%v : vector<16xi32>) -> vector<4xi32> { // CHECK: call <4 x i32> @llvm.aie2.extract.I128.I512( - // CHECK-SAME: <16 x i32> %{{[0-9]+}}, i32 %{{[0-9]+}}) - %1 = "xllvm.intr.aie2.extract.I128.I512"(%v, %idx) : - (vector<16xi32>, i32) -> vector<4xi32> + // CHECK-SAME: <16 x i32> %{{[0-9]+}}) + %1 = "xllvm.intr.aie2.extract.I128.I512"(%v) : + (vector<16xi32>) -> vector<4xi32> llvm.return %1 : vector<4xi32> }