Skip to content

Commit

Permalink
[aievec] to-llvm flow for aievec.ext op with I128.I512 extract intrin…
Browse files Browse the repository at this point in the history
…sic (#1490)
  • Loading branch information
jamestcl-amd authored May 16, 2024
1 parent e60b4f8 commit 3dc49bf
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 12 deletions.
11 changes: 8 additions & 3 deletions include/aie/Dialect/AIEVec/IR/AIEVecOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,11 @@ def AIEVec_ExtOp:
let description = [{
AMD-specific vector extract intrinsic. Selects contiguous lanes from
the source vector, and transfers the data from those lanes to the
result. The lane selection is controlled by index.
result. The lane selection is controlled by index. There are two cases:
1. The extracted vector has half as many lanes as the source vector (e.g. extract v64int8 from v128int8).
2. The extracted vector has a quarter as many lanes as the source vector (e.g. extract v32int8 from v128int8).
In the first case, index can be 0 or 1: index 0 extracts the lower half, and index 1 extracts the upper half.
In the second case, index can be 0 to 3: index 0 extracts the lowest quarter, index 1 the next quarter, and so on.
`$result = ext($source, $index)`
}];
}
Expand Down Expand Up @@ -576,8 +580,9 @@ def AIEVec_ShiftOp:
let description = [{
AMD-specific shift intrinsic. Concatenates two
vectors into a bigger vector, interprets them as a vector of 128 bytes
and returns v1::v2[shift: shift+64]. The verifier confirms that all the
input vectors have the same number of lanes.
and returns v1::v2[shift: shift+64]. `shift` is the number of bytes to
be shifted. The verifier confirms that all the input and result vectors
have the same number of lanes and element types.
`$result = shift($lhs, $rhs, $shift)`
}];
}
Expand Down
3 changes: 1 addition & 2 deletions include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,7 @@ def ExtI256I1024IntrOp :
def ExtI128I512IntrOp :
AIEVec2_IntrOp<"extract.I128.I512",
[TypeIs<"res", VectorOfLengthAndType<[4], [I32]>>]>,
Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$src,
I32:$idx)>;
Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$src)>;

// ----- CONCAT -----

Expand Down
32 changes: 29 additions & 3 deletions lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1300,9 +1300,35 @@ class ExtOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::ExtOp> {
rewriter, loc, operands,
{VectorType::get({32}, rewriter.getI32Type()),
rewriter.getI32Type()}));
// TODO: handle below case
// } else if (resultVectorSize == 128 && srcVectorSize == 512) {
// Special case
} else if (resultVectorSize == 128 && srcVectorSize == 512) {
auto shiftOp = adaptor.getSource();
if (op.getIndex() > 0) {
auto undefOp = rewriter.create<xllvm::UndefV16I32IntrOp>(
loc, VectorType::get({16}, rewriter.getI32Type()));
auto stepCst = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
auto shiftCst = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
rewriter.getI32IntegerAttr(op.getIndex() * 16));
SmallVector<Value> shiftOperands{adaptor.getSource(), undefOp, stepCst,
shiftCst};
// Right-shift the source vector by index * 16 bytes (i.e. by index *
// 128 bits). The integer index is expected to be in the range 0 to 3.
shiftOp = rewriter.create<xllvm::VectorShiftI512I512IntrOp>(
loc, VectorType::get({16}, rewriter.getI32Type()),
forceCastOperandsToSignature(
rewriter, loc, shiftOperands,
{VectorType::get({16}, rewriter.getI32Type()),
VectorType::get({16}, rewriter.getI32Type()),
rewriter.getI32Type(), rewriter.getI32Type()}));
}
// The underlying intrinsic takes a source vector and extracts its lowest
// 128 bits, i.e. it always behaves as an extract with index = 0.
extOp = rewriter.create<xllvm::ExtI128I512IntrOp>(
loc, VectorType::get({4}, rewriter.getI32Type()),
forceCastOperandsToSignature(
rewriter, loc, /*operands=*/{shiftOp},
{VectorType::get({16}, rewriter.getI32Type())}));
} else {
op.emitWarning() << "aievec.ext with " << srcVectorSize
<< "-bit source, and " << resultVectorSize
Expand Down
132 changes: 132 additions & 0 deletions test/Conversion/AIEVecToLLVM/test-ext.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,33 @@ func.func @v32i8_ext_v128i8(%arg0 : vector<128xi8>) -> (vector<32xi8>, vector<32

// -----

// Extracts the quarters at index 0 and index 3 (128 bits each) from a
// 512-bit vector<64xi8> source.
func.func @v16i8_ext_v64i8(%src : vector<64xi8>) -> (vector<16xi8>, vector<16xi8>) {
  %first_quarter = aievec.ext %src {index = 0 : i8} : vector<64xi8>, vector<16xi8>
  %last_quarter = aievec.ext %src {index = 3 : i8} : vector<64xi8>, vector<16xi8>
  return %first_quarter, %last_quarter : vector<16xi8>, vector<16xi8>
}

// CHECK-LABEL: @v16i8_ext_v64i8
// CHECK-SAME: %[[ARG0:.*]]: vector<64xi8>
// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<64xi8> to vector<16xi32>
// CHECK-NEXT: %[[EXT0:.*]] = "xllvm.intr.aie2.extract.I128.I512"(
// CHECK-SAME: %[[BITCAST0]]) :
// CHECK-SAME: (vector<16xi32>) -> vector<4xi32>
// CHECK-NEXT: %[[RES0:.*]] = llvm.bitcast %[[EXT0]] : vector<4xi32> to vector<16xi8>
// CHECK-NEXT: %[[UNDEF:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32>
// CHECK-NEXT: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[CST48:.*]] = llvm.mlir.constant(48 : i32) : i32
// CHECK-NEXT: %[[VSHIFT1:.*]] = "xllvm.intr.aie2.vshift.I512.I512"(
// CHECK-SAME: %[[BITCAST0]], %[[UNDEF]], %[[CST0]], %[[CST48]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32>
// CHECK-NEXT: %[[EXT1:.*]] = "xllvm.intr.aie2.extract.I128.I512"(
// CHECK-SAME: %[[VSHIFT1]]) :
// CHECK-SAME: (vector<16xi32>) -> vector<4xi32>
// CHECK-NEXT: %[[RES1:.*]] = llvm.bitcast %[[EXT1]] : vector<4xi32> to vector<16xi8>
// CHECK-NEXT: return %[[RES0]], %[[RES1]] : vector<16xi8>, vector<16xi8>

// -----

func.func @v16i16_ext_v32i16(%arg0 : vector<32xi16>) -> (vector<16xi16>, vector<16xi16>) {
%0 = aievec.ext %arg0 {index = 0 : i8} : vector<32xi16>, vector<16xi16>
%1 = aievec.ext %arg0 {index = 1 : i8} : vector<32xi16>, vector<16xi16>
Expand Down Expand Up @@ -138,6 +165,33 @@ func.func @v16i16_ext_v64i16(%arg0 : vector<64xi16>) -> (vector<16xi16>, vector<

// -----

// Extracts the quarters at index 0 and index 3 (128 bits each) from a
// 512-bit vector<32xi16> source.
func.func @v8i16_ext_v32i16(%src : vector<32xi16>) -> (vector<8xi16>, vector<8xi16>) {
  %first_quarter = aievec.ext %src {index = 0 : i8} : vector<32xi16>, vector<8xi16>
  %last_quarter = aievec.ext %src {index = 3 : i8} : vector<32xi16>, vector<8xi16>
  return %first_quarter, %last_quarter : vector<8xi16>, vector<8xi16>
}

// CHECK-LABEL: @v8i16_ext_v32i16
// CHECK-SAME: %[[ARG0:.*]]: vector<32xi16>
// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<32xi16> to vector<16xi32>
// CHECK-NEXT: %[[EXT0:.*]] = "xllvm.intr.aie2.extract.I128.I512"(
// CHECK-SAME: %[[BITCAST0]]) :
// CHECK-SAME: (vector<16xi32>) -> vector<4xi32>
// CHECK-NEXT: %[[RES0:.*]] = llvm.bitcast %[[EXT0]] : vector<4xi32> to vector<8xi16>
// CHECK-NEXT: %[[UNDEF:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32>
// CHECK-NEXT: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[CST48:.*]] = llvm.mlir.constant(48 : i32) : i32
// CHECK-NEXT: %[[VSHIFT1:.*]] = "xllvm.intr.aie2.vshift.I512.I512"(
// CHECK-SAME: %[[BITCAST0]], %[[UNDEF]], %[[CST0]], %[[CST48]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32>
// CHECK-NEXT: %[[EXT1:.*]] = "xllvm.intr.aie2.extract.I128.I512"(
// CHECK-SAME: %[[VSHIFT1]]) :
// CHECK-SAME: (vector<16xi32>) -> vector<4xi32>
// CHECK-NEXT: %[[RES1:.*]] = llvm.bitcast %[[EXT1]] : vector<4xi32> to vector<8xi16>
// CHECK-NEXT: return %[[RES0]], %[[RES1]] : vector<8xi16>, vector<8xi16>

// -----

func.func @v8i32_ext_v16i32(%arg0 : vector<16xi32>) -> (vector<8xi32>, vector<8xi32>) {
%0 = aievec.ext %arg0 {index = 0 : i8} : vector<16xi32>, vector<8xi32>
%1 = aievec.ext %arg0 {index = 1 : i8} : vector<16xi32>, vector<8xi32>
Expand Down Expand Up @@ -198,6 +252,30 @@ func.func @v8i32_ext_v32i32(%arg0 : vector<32xi32>) -> (vector<8xi32>, vector<8x

// -----

// Extracts the quarters at index 0 and index 3 (128 bits each) from a
// 512-bit vector<16xi32> source.
func.func @v4i32_ext_v16i32(%src : vector<16xi32>) -> (vector<4xi32>, vector<4xi32>) {
  %first_quarter = aievec.ext %src {index = 0 : i8} : vector<16xi32>, vector<4xi32>
  %last_quarter = aievec.ext %src {index = 3 : i8} : vector<16xi32>, vector<4xi32>
  return %first_quarter, %last_quarter : vector<4xi32>, vector<4xi32>
}

// CHECK-LABEL: @v4i32_ext_v16i32
// CHECK-SAME: %[[ARG0:.*]]: vector<16xi32>
// CHECK: %[[EXT0:.*]] = "xllvm.intr.aie2.extract.I128.I512"(
// CHECK-SAME: %[[ARG0]]) :
// CHECK-SAME: (vector<16xi32>) -> vector<4xi32>
// CHECK-NEXT: %[[UNDEF:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32>
// CHECK-NEXT: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[CST48:.*]] = llvm.mlir.constant(48 : i32) : i32
// CHECK-NEXT: %[[VSHIFT1:.*]] = "xllvm.intr.aie2.vshift.I512.I512"(
// CHECK-SAME: %[[ARG0]], %[[UNDEF]], %[[CST0]], %[[CST48]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32>
// CHECK-NEXT: %[[EXT1:.*]] = "xllvm.intr.aie2.extract.I128.I512"(
// CHECK-SAME: %[[VSHIFT1]]) :
// CHECK-SAME: (vector<16xi32>) -> vector<4xi32>
// CHECK-NEXT: return %[[EXT0]], %[[EXT1]] : vector<4xi32>, vector<4xi32>

// -----

func.func @v16bf16_ext_v32bf16(%arg0 : vector<32xbf16>) -> (vector<16xbf16>, vector<16xbf16>) {
%0 = aievec.ext %arg0 {index = 0 : i8} : vector<32xbf16>, vector<16xbf16>
%1 = aievec.ext %arg0 {index = 1 : i8} : vector<32xbf16>, vector<16xbf16>
Expand Down Expand Up @@ -267,6 +345,33 @@ func.func @v16bf16_ext_v64bf16(%arg0 : vector<64xbf16>) -> (vector<16xbf16>, vec

// -----

// Extracts the quarters at index 0 and index 3 (128 bits each) from a
// 512-bit vector<32xbf16> source.
func.func @v8bf16_ext_v32bf16(%src : vector<32xbf16>) -> (vector<8xbf16>, vector<8xbf16>) {
  %first_quarter = aievec.ext %src {index = 0 : i8} : vector<32xbf16>, vector<8xbf16>
  %last_quarter = aievec.ext %src {index = 3 : i8} : vector<32xbf16>, vector<8xbf16>
  return %first_quarter, %last_quarter : vector<8xbf16>, vector<8xbf16>
}

// CHECK-LABEL: @v8bf16_ext_v32bf16
// CHECK-SAME: %[[ARG0:.*]]: vector<32xbf16>
// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<32xbf16> to vector<16xi32>
// CHECK-NEXT: %[[EXT0:.*]] = "xllvm.intr.aie2.extract.I128.I512"(
// CHECK-SAME: %[[BITCAST0]]) :
// CHECK-SAME: (vector<16xi32>) -> vector<4xi32>
// CHECK-NEXT: %[[RES0:.*]] = llvm.bitcast %[[EXT0]] : vector<4xi32> to vector<8xbf16>
// CHECK-NEXT: %[[UNDEF:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32>
// CHECK-NEXT: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[CST48:.*]] = llvm.mlir.constant(48 : i32) : i32
// CHECK-NEXT: %[[VSHIFT1:.*]] = "xllvm.intr.aie2.vshift.I512.I512"(
// CHECK-SAME: %[[BITCAST0]], %[[UNDEF]], %[[CST0]], %[[CST48]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32>
// CHECK-NEXT: %[[EXT1:.*]] = "xllvm.intr.aie2.extract.I128.I512"(
// CHECK-SAME: %[[VSHIFT1]]) :
// CHECK-SAME: (vector<16xi32>) -> vector<4xi32>
// CHECK-NEXT: %[[RES1:.*]] = llvm.bitcast %[[EXT1]] : vector<4xi32> to vector<8xbf16>
// CHECK-NEXT: return %[[RES0]], %[[RES1]] : vector<8xbf16>, vector<8xbf16>

// -----

func.func @v8f32_ext_v16f32(%arg0 : vector<16xf32>) -> (vector<8xf32>, vector<8xf32>) {
%0 = aievec.ext %arg0 {index = 0 : i8} : vector<16xf32>, vector<8xf32>
%1 = aievec.ext %arg0 {index = 1 : i8} : vector<16xf32>, vector<8xf32>
Expand Down Expand Up @@ -333,3 +438,30 @@ func.func @v8f32_ext_v32f32(%arg0 : vector<32xf32>) -> (vector<8xf32>, vector<8x
// CHECK-SAME: (vector<32xi32>, i32) -> vector<8xi32>
// CHECK-NEXT: %[[RES1:.*]] = llvm.bitcast %[[EXT1]] : vector<8xi32> to vector<8xf32>
// CHECK-NEXT: return %[[RES0]], %[[RES1]] : vector<8xf32>, vector<8xf32>

// -----

// Extracts the quarters at index 0 and index 3 (128 bits each) from a
// 512-bit vector<16xf32> source.
func.func @v4f32_ext_v16f32(%src : vector<16xf32>) -> (vector<4xf32>, vector<4xf32>) {
  %first_quarter = aievec.ext %src {index = 0 : i8} : vector<16xf32>, vector<4xf32>
  %last_quarter = aievec.ext %src {index = 3 : i8} : vector<16xf32>, vector<4xf32>
  return %first_quarter, %last_quarter : vector<4xf32>, vector<4xf32>
}

// CHECK-LABEL: @v4f32_ext_v16f32
// CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>
// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<16xf32> to vector<16xi32>
// CHECK-NEXT: %[[EXT0:.*]] = "xllvm.intr.aie2.extract.I128.I512"(
// CHECK-SAME: %[[BITCAST0]]) :
// CHECK-SAME: (vector<16xi32>) -> vector<4xi32>
// CHECK-NEXT: %[[RES0:.*]] = llvm.bitcast %[[EXT0]] : vector<4xi32> to vector<4xf32>
// CHECK-NEXT: %[[UNDEF:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32>
// CHECK-NEXT: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[CST48:.*]] = llvm.mlir.constant(48 : i32) : i32
// CHECK-NEXT: %[[VSHIFT1:.*]] = "xllvm.intr.aie2.vshift.I512.I512"(
// CHECK-SAME: %[[BITCAST0]], %[[UNDEF]], %[[CST0]], %[[CST48]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32>
// CHECK-NEXT: %[[EXT1:.*]] = "xllvm.intr.aie2.extract.I128.I512"(
// CHECK-SAME: %[[VSHIFT1]]) :
// CHECK-SAME: (vector<16xi32>) -> vector<4xi32>
// CHECK-NEXT: %[[RES1:.*]] = llvm.bitcast %[[EXT1]] : vector<4xi32> to vector<4xf32>
// CHECK-NEXT: return %[[RES0]], %[[RES1]] : vector<4xf32>, vector<4xf32>
8 changes: 4 additions & 4 deletions test/Target/LLVMIR/aievec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,11 @@ llvm.func @ext_i256_i1024(%v : vector<32xi32>, %idx : i32) -> vector<8xi32> {
}

// CHECK-LABEL: define <4 x i32> @ext_i128_i512
llvm.func @ext_i128_i512(%v : vector<16xi32>, %idx : i32) -> vector<4xi32> {
llvm.func @ext_i128_i512(%v : vector<16xi32>) -> vector<4xi32> {
// CHECK: call <4 x i32> @llvm.aie2.extract.I128.I512(
// CHECK-SAME: <16 x i32> %{{[0-9]+}}, i32 %{{[0-9]+}})
%1 = "xllvm.intr.aie2.extract.I128.I512"(%v, %idx) :
(vector<16xi32>, i32) -> vector<4xi32>
// CHECK-SAME: <16 x i32> %{{[0-9]+}})
%1 = "xllvm.intr.aie2.extract.I128.I512"(%v) :
(vector<16xi32>) -> vector<4xi32>
llvm.return %1 : vector<4xi32>
}

Expand Down

0 comments on commit 3dc49bf

Please sign in to comment.