Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[aievec] to-llvm flow for aievec.ext op with I128.I512 extract intrinsic #1490

Merged
merged 7 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,7 @@ def ExtI256I1024IntrOp :
def ExtI128I512IntrOp :
AIEVec2_IntrOp<"extract.I128.I512",
[TypeIs<"res", VectorOfLengthAndType<[4], [I32]>>]>,
Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$src,
I32:$idx)>;
Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$src)>;

// ----- CONCAT -----

Expand Down
27 changes: 24 additions & 3 deletions lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1300,9 +1300,30 @@ class ExtOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::ExtOp> {
rewriter, loc, operands,
{VectorType::get({32}, rewriter.getI32Type()),
rewriter.getI32Type()}));
// TODO: handle below case
// } else if (resultVectorSize == 128 && srcVectorSize == 512) {
// Special case
} else if (resultVectorSize == 128 && srcVectorSize == 512) {
auto undefOp = rewriter.create<xllvm::UndefV16I32IntrOp>(
loc, VectorType::get({16}, rewriter.getI32Type()));
auto stepCst = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
auto shiftCst = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
rewriter.getI32IntegerAttr(op.getIndex() * 16));
SmallVector<Value> shiftOperands{adaptor.getSource(), undefOp, stepCst,
shiftCst};
// right shift the source in index * 16 bytes (i.e. index * 128 bits)
auto shiftOp = rewriter.create<xllvm::VectorShiftI512I512IntrOp>(
loc, VectorType::get({16}, rewriter.getI32Type()),
forceCastOperandsToSignature(
rewriter, loc, shiftOperands,
{VectorType::get({16}, rewriter.getI32Type()),
VectorType::get({16}, rewriter.getI32Type()),
rewriter.getI32Type(), rewriter.getI32Type()}));
// the intrinsic takes 1 source vector and extract the first 128-bit
extOp = rewriter.create<xllvm::ExtI128I512IntrOp>(
loc, VectorType::get({4}, rewriter.getI32Type()),
forceCastOperandsToSignature(
rewriter, loc, /*operands=*/{shiftOp},
{VectorType::get({16}, rewriter.getI32Type())}));
} else {
op.emitWarning() << "aievec.ext with " << srcVectorSize
<< "-bit source, and " << resultVectorSize
Expand Down
152 changes: 152 additions & 0 deletions test/Conversion/AIEVecToLLVM/test-ext.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,37 @@ func.func @v32i8_ext_v128i8(%arg0 : vector<128xi8>) -> (vector<32xi8>, vector<32

// -----

func.func @v16i8_ext_v64i8(%arg0 : vector<64xi8>) -> (vector<16xi8>, vector<16xi8>) {
%0 = aievec.ext %arg0 {index = 0 : i8} : vector<64xi8>, vector<16xi8>
%1 = aievec.ext %arg0 {index = 3 : i8} : vector<64xi8>, vector<16xi8>
return %0, %1 : vector<16xi8>, vector<16xi8>
}

// CHECK-LABEL: @v16i8_ext_v64i8
// CHECK-SAME: %[[ARG0:.*]]: vector<64xi8>
// CHECK: %[[UNDEF0:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32>
// CHECK-NEXT: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<64xi8> to vector<16xi32>
// CHECK-NEXT: %[[VSHIFT0:.*]] = "xllvm.intr.aie2.vshift.I512.I512"(
// CHECK-SAME: %[[BITCAST0]], %[[UNDEF0]], %[[CST0]], %[[CST0]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32>
// CHECK-NEXT: %[[EXT0:.*]] = "xllvm.intr.aie2.extract.I128.I512"(
// CHECK-SAME: %[[VSHIFT0]]) :
// CHECK-SAME: (vector<16xi32>) -> vector<4xi32>
// CHECK-NEXT: %[[RES0:.*]] = llvm.bitcast %[[EXT0]] : vector<4xi32> to vector<16xi8>
// CHECK-NEXT: %[[UNDEF1:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32>
// CHECK-NEXT: %[[CST48:.*]] = llvm.mlir.constant(48 : i32) : i32
// CHECK-NEXT: %[[VSHIFT1:.*]] = "xllvm.intr.aie2.vshift.I512.I512"(
// CHECK-SAME: %[[BITCAST0]], %[[UNDEF1]], %[[CST0]], %[[CST48]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32>
// CHECK-NEXT: %[[EXT1:.*]] = "xllvm.intr.aie2.extract.I128.I512"(
// CHECK-SAME: %[[VSHIFT1]]) :
// CHECK-SAME: (vector<16xi32>) -> vector<4xi32>
// CHECK-NEXT: %[[RES1:.*]] = llvm.bitcast %[[EXT1]] : vector<4xi32> to vector<16xi8>
// CHECK-NEXT: return %[[RES0]], %[[RES1]] : vector<16xi8>, vector<16xi8>

// -----

func.func @v16i16_ext_v32i16(%arg0 : vector<32xi16>) -> (vector<16xi16>, vector<16xi16>) {
%0 = aievec.ext %arg0 {index = 0 : i8} : vector<32xi16>, vector<16xi16>
%1 = aievec.ext %arg0 {index = 1 : i8} : vector<32xi16>, vector<16xi16>
Expand Down Expand Up @@ -138,6 +169,37 @@ func.func @v16i16_ext_v64i16(%arg0 : vector<64xi16>) -> (vector<16xi16>, vector<

// -----

func.func @v8i16_ext_v32i16(%arg0 : vector<32xi16>) -> (vector<8xi16>, vector<8xi16>) {
%0 = aievec.ext %arg0 {index = 0 : i8} : vector<32xi16>, vector<8xi16>
%1 = aievec.ext %arg0 {index = 3 : i8} : vector<32xi16>, vector<8xi16>
return %0, %1 : vector<8xi16>, vector<8xi16>
}

// CHECK-LABEL: @v8i16_ext_v32i16
// CHECK-SAME: %[[ARG0:.*]]: vector<32xi16>
// CHECK: %[[UNDEF0:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32>
// CHECK-NEXT: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<32xi16> to vector<16xi32>
// CHECK-NEXT: %[[VSHIFT0:.*]] = "xllvm.intr.aie2.vshift.I512.I512"(
// CHECK-SAME: %[[BITCAST0]], %[[UNDEF0]], %[[CST0]], %[[CST0]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32>
// CHECK-NEXT: %[[EXT0:.*]] = "xllvm.intr.aie2.extract.I128.I512"(
// CHECK-SAME: %[[VSHIFT0]]) :
// CHECK-SAME: (vector<16xi32>) -> vector<4xi32>
// CHECK-NEXT: %[[RES0:.*]] = llvm.bitcast %[[EXT0]] : vector<4xi32> to vector<8xi16>
// CHECK-NEXT: %[[UNDEF1:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32>
// CHECK-NEXT: %[[CST48:.*]] = llvm.mlir.constant(48 : i32) : i32
// CHECK-NEXT: %[[VSHIFT1:.*]] = "xllvm.intr.aie2.vshift.I512.I512"(
// CHECK-SAME: %[[BITCAST0]], %[[UNDEF1]], %[[CST0]], %[[CST48]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32>
// CHECK-NEXT: %[[EXT1:.*]] = "xllvm.intr.aie2.extract.I128.I512"(
// CHECK-SAME: %[[VSHIFT1]]) :
// CHECK-SAME: (vector<16xi32>) -> vector<4xi32>
// CHECK-NEXT: %[[RES1:.*]] = llvm.bitcast %[[EXT1]] : vector<4xi32> to vector<8xi16>
// CHECK-NEXT: return %[[RES0]], %[[RES1]] : vector<8xi16>, vector<8xi16>

// -----

func.func @v8i32_ext_v16i32(%arg0 : vector<16xi32>) -> (vector<8xi32>, vector<8xi32>) {
%0 = aievec.ext %arg0 {index = 0 : i8} : vector<16xi32>, vector<8xi32>
%1 = aievec.ext %arg0 {index = 1 : i8} : vector<16xi32>, vector<8xi32>
Expand Down Expand Up @@ -198,6 +260,34 @@ func.func @v8i32_ext_v32i32(%arg0 : vector<32xi32>) -> (vector<8xi32>, vector<8x

// -----

func.func @v4i32_ext_v16i32(%arg0 : vector<16xi32>) -> (vector<4xi32>, vector<4xi32>) {
%0 = aievec.ext %arg0 {index = 0 : i8} : vector<16xi32>, vector<4xi32>
%1 = aievec.ext %arg0 {index = 3 : i8} : vector<16xi32>, vector<4xi32>
return %0, %1 : vector<4xi32>, vector<4xi32>
}

// CHECK-LABEL: @v4i32_ext_v16i32
// CHECK-SAME: %[[ARG0:.*]]: vector<16xi32>
// CHECK: %[[UNDEF0:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32>
// CHECK-NEXT: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[VSHIFT0:.*]] = "xllvm.intr.aie2.vshift.I512.I512"(
// CHECK-SAME: %[[ARG0]], %[[UNDEF0]], %[[CST0]], %[[CST0]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32>
// CHECK-NEXT: %[[EXT0:.*]] = "xllvm.intr.aie2.extract.I128.I512"(
// CHECK-SAME: %[[VSHIFT0]]) :
// CHECK-SAME: (vector<16xi32>) -> vector<4xi32>
// CHECK-NEXT: %[[UNDEF1:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32>
// CHECK-NEXT: %[[CST48:.*]] = llvm.mlir.constant(48 : i32) : i32
// CHECK-NEXT: %[[VSHIFT1:.*]] = "xllvm.intr.aie2.vshift.I512.I512"(
// CHECK-SAME: %[[ARG0]], %[[UNDEF1]], %[[CST0]], %[[CST48]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32>
// CHECK-NEXT: %[[EXT1:.*]] = "xllvm.intr.aie2.extract.I128.I512"(
// CHECK-SAME: %[[VSHIFT1]]) :
// CHECK-SAME: (vector<16xi32>) -> vector<4xi32>
// CHECK-NEXT: return %[[EXT0]], %[[EXT1]] : vector<4xi32>, vector<4xi32>

// -----

func.func @v16bf16_ext_v32bf16(%arg0 : vector<32xbf16>) -> (vector<16xbf16>, vector<16xbf16>) {
%0 = aievec.ext %arg0 {index = 0 : i8} : vector<32xbf16>, vector<16xbf16>
%1 = aievec.ext %arg0 {index = 1 : i8} : vector<32xbf16>, vector<16xbf16>
Expand Down Expand Up @@ -267,6 +357,37 @@ func.func @v16bf16_ext_v64bf16(%arg0 : vector<64xbf16>) -> (vector<16xbf16>, vec

// -----

func.func @v8bf16_ext_v32bf16(%arg0 : vector<32xbf16>) -> (vector<8xbf16>, vector<8xbf16>) {
%0 = aievec.ext %arg0 {index = 0 : i8} : vector<32xbf16>, vector<8xbf16>
%1 = aievec.ext %arg0 {index = 3 : i8} : vector<32xbf16>, vector<8xbf16>
return %0, %1 : vector<8xbf16>, vector<8xbf16>
}

// CHECK-LABEL: @v8bf16_ext_v32bf16
// CHECK-SAME: %[[ARG0:.*]]: vector<32xbf16>
// CHECK: %[[UNDEF0:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32>
// CHECK-NEXT: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<32xbf16> to vector<16xi32>
// CHECK-NEXT: %[[VSHIFT0:.*]] = "xllvm.intr.aie2.vshift.I512.I512"(
// CHECK-SAME: %[[BITCAST0]], %[[UNDEF0]], %[[CST0]], %[[CST0]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32>
// CHECK-NEXT: %[[EXT0:.*]] = "xllvm.intr.aie2.extract.I128.I512"(
// CHECK-SAME: %[[VSHIFT0]]) :
// CHECK-SAME: (vector<16xi32>) -> vector<4xi32>
// CHECK-NEXT: %[[RES0:.*]] = llvm.bitcast %[[EXT0]] : vector<4xi32> to vector<8xbf16>
// CHECK-NEXT: %[[UNDEF1:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32>
// CHECK-NEXT: %[[CST48:.*]] = llvm.mlir.constant(48 : i32) : i32
// CHECK-NEXT: %[[VSHIFT1:.*]] = "xllvm.intr.aie2.vshift.I512.I512"(
// CHECK-SAME: %[[BITCAST0]], %[[UNDEF1]], %[[CST0]], %[[CST48]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32>
// CHECK-NEXT: %[[EXT1:.*]] = "xllvm.intr.aie2.extract.I128.I512"(
// CHECK-SAME: %[[VSHIFT1]]) :
// CHECK-SAME: (vector<16xi32>) -> vector<4xi32>
// CHECK-NEXT: %[[RES1:.*]] = llvm.bitcast %[[EXT1]] : vector<4xi32> to vector<8xbf16>
// CHECK-NEXT: return %[[RES0]], %[[RES1]] : vector<8xbf16>, vector<8xbf16>

// -----

func.func @v8f32_ext_v16f32(%arg0 : vector<16xf32>) -> (vector<8xf32>, vector<8xf32>) {
%0 = aievec.ext %arg0 {index = 0 : i8} : vector<16xf32>, vector<8xf32>
%1 = aievec.ext %arg0 {index = 1 : i8} : vector<16xf32>, vector<8xf32>
Expand Down Expand Up @@ -333,3 +454,34 @@ func.func @v8f32_ext_v32f32(%arg0 : vector<32xf32>) -> (vector<8xf32>, vector<8x
// CHECK-SAME: (vector<32xi32>, i32) -> vector<8xi32>
// CHECK-NEXT: %[[RES1:.*]] = llvm.bitcast %[[EXT1]] : vector<8xi32> to vector<8xf32>
// CHECK-NEXT: return %[[RES0]], %[[RES1]] : vector<8xf32>, vector<8xf32>

// -----

func.func @v4f32_ext_v16f32(%arg0 : vector<16xf32>) -> (vector<4xf32>, vector<4xf32>) {
%0 = aievec.ext %arg0 {index = 0 : i8} : vector<16xf32>, vector<4xf32>
%1 = aievec.ext %arg0 {index = 3 : i8} : vector<16xf32>, vector<4xf32>
return %0, %1 : vector<4xf32>, vector<4xf32>
}

// CHECK-LABEL: @v4f32_ext_v16f32
// CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>
// CHECK: %[[UNDEF0:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32>
// CHECK-NEXT: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<16xf32> to vector<16xi32>
// CHECK-NEXT: %[[VSHIFT0:.*]] = "xllvm.intr.aie2.vshift.I512.I512"(
// CHECK-SAME: %[[BITCAST0]], %[[UNDEF0]], %[[CST0]], %[[CST0]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32>
// CHECK-NEXT: %[[EXT0:.*]] = "xllvm.intr.aie2.extract.I128.I512"(
// CHECK-SAME: %[[VSHIFT0]]) :
// CHECK-SAME: (vector<16xi32>) -> vector<4xi32>
// CHECK-NEXT: %[[RES0:.*]] = llvm.bitcast %[[EXT0]] : vector<4xi32> to vector<4xf32>
// CHECK-NEXT: %[[UNDEF1:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32>
// CHECK-NEXT: %[[CST48:.*]] = llvm.mlir.constant(48 : i32) : i32
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this correct? If I understand correctly, the process to extract a subvector from a given position is to shift the whole vector by the number of bits up to the position, and then extracting the lowest part. For 16-bit element vectors and an extraction from position 3, that's 3 x 16 = 48 bit, but for 32-bit element vectors, shouldn't it be 3 x 32 = 96 bit? Am I misunderstanding something?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The shift amount for the shift op/intrinsic is in the number of bytes. The I128.I512 extract intrinsic is special. It always extracts the lowest 128-bit vector from the source 512-bit vector. For index=0 vector extraction, we actually don't need any vector shift beforehand, so I have updated the conversion pattern for this scenario. As for index=3 vector extraction, the shift amount is 48 bytes = 384 bits = 3*128 bits. After the right shift of 48 bytes, the I128.I512 intrinsic extracts the lowest 128-bit vector, which is correct. I have updated the aievec.ext/shift op descriptions, to explicitly clarify the shift amount in bytes and the extraction index to be either 0--1 or 0--3. Let me know if there is anything else I can do :).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aaaah! Got it, I was trying to make the maths work in my head and failing, thanks for the clarification 🙂

// CHECK-NEXT: %[[VSHIFT1:.*]] = "xllvm.intr.aie2.vshift.I512.I512"(
// CHECK-SAME: %[[BITCAST0]], %[[UNDEF1]], %[[CST0]], %[[CST48]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32>
// CHECK-NEXT: %[[EXT1:.*]] = "xllvm.intr.aie2.extract.I128.I512"(
// CHECK-SAME: %[[VSHIFT1]]) :
// CHECK-SAME: (vector<16xi32>) -> vector<4xi32>
// CHECK-NEXT: %[[RES1:.*]] = llvm.bitcast %[[EXT1]] : vector<4xi32> to vector<4xf32>
// CHECK-NEXT: return %[[RES0]], %[[RES1]] : vector<4xf32>, vector<4xf32>
8 changes: 4 additions & 4 deletions test/Target/LLVMIR/aievec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,11 @@ llvm.func @ext_i256_i1024(%v : vector<32xi32>, %idx : i32) -> vector<8xi32> {
}

// CHECK-LABEL: define <4 x i32> @ext_i128_i512
llvm.func @ext_i128_i512(%v : vector<16xi32>, %idx : i32) -> vector<4xi32> {
llvm.func @ext_i128_i512(%v : vector<16xi32>) -> vector<4xi32> {
// CHECK: call <4 x i32> @llvm.aie2.extract.I128.I512(
// CHECK-SAME: <16 x i32> %{{[0-9]+}}, i32 %{{[0-9]+}})
%1 = "xllvm.intr.aie2.extract.I128.I512"(%v, %idx) :
(vector<16xi32>, i32) -> vector<4xi32>
// CHECK-SAME: <16 x i32> %{{[0-9]+}})
%1 = "xllvm.intr.aie2.extract.I128.I512"(%v) :
(vector<16xi32>) -> vector<4xi32>
llvm.return %1 : vector<4xi32>
}

Expand Down
Loading