Skip to content

Commit

Permalink
[aievec] to-llvm flow for aievec.concat op (#1454)
Browse files Browse the repository at this point in the history
* This PR adds support for lowering the aievec.concat op through the to-llvm flow.
* Add aievec-to-llvm conversion pattern/tests for the aievec.concat op.
* Add target external llvm translation tests.
  • Loading branch information
jamestcl-amd authored May 7, 2024
1 parent 7cb4396 commit 4868500
Show file tree
Hide file tree
Showing 4 changed files with 324 additions and 10 deletions.
14 changes: 14 additions & 0 deletions include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,20 @@ def ConcatI512I256IntrOp :
Arguments<(ins VectorOfLengthAndType<[8], [I32]>:$lhs,
VectorOfLengthAndType<[8], [I32]>:$rhs)>;

// Intrinsic op "concat.I1024.I256": takes four vector<8xi32> (256-bit)
// operands ($src0..$src3) and produces one vector<32xi32> (1024-bit) result.
def ConcatI1024I256IntrOp :
AIEVec2_IntrOp<"concat.I1024.I256",
[TypeIs<"res", VectorOfLengthAndType<[32], [I32]>>]>,
Arguments<(ins VectorOfLengthAndType<[8], [I32]>:$src0,
VectorOfLengthAndType<[8], [I32]>:$src1,
VectorOfLengthAndType<[8], [I32]>:$src2,
VectorOfLengthAndType<[8], [I32]>:$src3)>;

// Intrinsic op "concat.I1024.I512": takes two vector<16xi32> (512-bit)
// operands ($lhs, $rhs) and produces one vector<32xi32> (1024-bit) result.
def ConcatI1024I512IntrOp :
AIEVec2_IntrOp<"concat.I1024.I512",
[TypeIs<"res", VectorOfLengthAndType<[32], [I32]>>]>,
Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$lhs,
VectorOfLengthAndType<[16], [I32]>:$rhs)>;

// ----- SHUFFLE -----

def VectorShuffleIntrOp :
Expand Down
57 changes: 51 additions & 6 deletions lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1034,13 +1034,58 @@ class ConcatOpConversion
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();

SmallVector<Value> sources = adaptor.getSources();
Value src = sources.front();
VectorType srcType = cast<VectorType>(src.getType());
Type srcScalarType = srcType.getElementType();
unsigned srcBitWidth = srcScalarType.getIntOrFloatBitWidth();
int srcLanes = getVectorLaneSize(srcType);
int srcVectorSize = srcBitWidth * srcLanes;

Value result = op.getResult();
VectorType resultType = cast<VectorType>(result.getType());
Type resultScaTy = resultType.getElementType();
unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
int resultLanes = getVectorLaneSize(resultType);
int resultVectorSize = resultBitWidth * resultLanes;

if (sources.size() != 2 && sources.size() != 4) {
op.emitWarning() << "aievec.concat with " << sources.size()
<< " operands is not supported.\n";
return failure();
}

// create xllvm intrinsic
auto concatOp = rewriter.create<xllvm::ConcatI512I256IntrOp>(
loc, VectorType::get({16}, rewriter.getI32Type()),
forceCastOperandsToSignature(
rewriter, loc, adaptor.getSources(),
{VectorType::get({8}, rewriter.getI32Type()),
VectorType::get({8}, rewriter.getI32Type())}));
Value concatOp = nullptr;
if (srcVectorSize == 256 && resultVectorSize == 512) {
concatOp = rewriter.create<xllvm::ConcatI512I256IntrOp>(
loc, VectorType::get({16}, rewriter.getI32Type()),
forceCastOperandsToSignature(
rewriter, loc, adaptor.getSources(),
{VectorType::get({8}, rewriter.getI32Type()),
VectorType::get({8}, rewriter.getI32Type())}));
} else if (srcVectorSize == 256 && resultVectorSize == 1024) {
concatOp = rewriter.create<xllvm::ConcatI1024I256IntrOp>(
loc, VectorType::get({32}, rewriter.getI32Type()),
forceCastOperandsToSignature(
rewriter, loc, adaptor.getSources(),
{VectorType::get({8}, rewriter.getI32Type()),
VectorType::get({8}, rewriter.getI32Type()),
VectorType::get({8}, rewriter.getI32Type()),
VectorType::get({8}, rewriter.getI32Type())}));
} else if (srcVectorSize == 512 && resultVectorSize == 1024) {
concatOp = rewriter.create<xllvm::ConcatI1024I512IntrOp>(
loc, VectorType::get({32}, rewriter.getI32Type()),
forceCastOperandsToSignature(
rewriter, loc, adaptor.getSources(),
{VectorType::get({16}, rewriter.getI32Type()),
VectorType::get({16}, rewriter.getI32Type())}));
} else {
op.emitWarning() << "aievec.concat with " << srcVectorSize
<< "-bit operands, and " << resultVectorSize
<< "-bit result is not supported.\n";
return failure();
}

// create bitcast for result
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
Expand Down
243 changes: 239 additions & 4 deletions test/Conversion/AIEVecToLLVM/test-concat.mlir
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// RUN: aie-opt %s -split-input-file --convert-aievec-to-llvm | FileCheck %s

func.func @i8_concat(%arg0 : vector<32xi8>, %arg1 : vector<32xi8>) -> vector<64xi8> {
func.func @v64i8_concat_v32i8(%arg0 : vector<32xi8>, %arg1 : vector<32xi8>) -> vector<64xi8> {
%0 = aievec.concat %arg0, %arg1 : vector<32xi8>, vector<64xi8>
return %0 : vector<64xi8>
}

// CHECK-LABEL: @i8_concat
// CHECK-LABEL: @v64i8_concat_v32i8
// CHECK-SAME: %[[ARG0:.*]]: vector<32xi8>,
// CHECK-SAME: %[[ARG1:.*]]: vector<32xi8>
// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<32xi8> to vector<8xi32>
Expand All @@ -18,12 +18,154 @@ func.func @i8_concat(%arg0 : vector<32xi8>, %arg1 : vector<32xi8>) -> vector<64x

// -----

func.func @bf16_concat(%arg0 : vector<16xbf16>, %arg1 : vector<16xbf16>) -> vector<32xbf16> {
// Four vector<32xi8> (256-bit) sources concatenated into vector<128xi8>
// (1024-bit). Expected lowering: each source is bitcast to vector<8xi32>,
// passed to the concat.I1024.I256 intrinsic, and the vector<32xi32> result is
// bitcast back to vector<128xi8>.
func.func @v128i8_concat_v32i8(%arg0 : vector<32xi8>, %arg1 : vector<32xi8>,
%arg2 : vector<32xi8>, %arg3 : vector<32xi8>) -> vector<128xi8> {
%0 = aievec.concat %arg0, %arg1, %arg2, %arg3 : vector<32xi8>, vector<128xi8>
return %0 : vector<128xi8>
}

// CHECK-LABEL: @v128i8_concat_v32i8
// CHECK-SAME: %[[ARG0:.*]]: vector<32xi8>, %[[ARG1:.*]]: vector<32xi8>, %[[ARG2:.*]]: vector<32xi8>, %[[ARG3:.*]]: vector<32xi8>
// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<32xi8> to vector<8xi32>
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[ARG1]] : vector<32xi8> to vector<8xi32>
// CHECK-NEXT: %[[BITCAST2:.*]] = llvm.bitcast %[[ARG2]] : vector<32xi8> to vector<8xi32>
// CHECK-NEXT: %[[BITCAST3:.*]] = llvm.bitcast %[[ARG3]] : vector<32xi8> to vector<8xi32>
// CHECK-NEXT: %[[CONCAT:.*]] = "xllvm.intr.aie2.concat.I1024.I256"(
// CHECK-SAME: %[[BITCAST0]], %[[BITCAST1]], %[[BITCAST2]], %[[BITCAST3]]) :
// CHECK-SAME: (vector<8xi32>, vector<8xi32>, vector<8xi32>, vector<8xi32>) -> vector<32xi32>
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[CONCAT]] : vector<32xi32> to vector<128xi8>
// CHECK-NEXT: return %[[RES]] : vector<128xi8>

// -----

// Two vector<64xi8> (512-bit) sources concatenated into vector<128xi8>
// (1024-bit). Expected lowering: each source is bitcast to vector<16xi32>,
// passed to the concat.I1024.I512 intrinsic, and the vector<32xi32> result is
// bitcast back to vector<128xi8>.
func.func @v128i8_concat_v64i8(%arg0 : vector<64xi8>, %arg1 : vector<64xi8>) -> vector<128xi8> {
%0 = aievec.concat %arg0, %arg1 : vector<64xi8>, vector<128xi8>
return %0 : vector<128xi8>
}

// CHECK-LABEL: @v128i8_concat_v64i8
// CHECK-SAME: %[[ARG0:.*]]: vector<64xi8>,
// CHECK-SAME: %[[ARG1:.*]]: vector<64xi8>
// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<64xi8> to vector<16xi32>
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[ARG1]] : vector<64xi8> to vector<16xi32>
// CHECK-NEXT: %[[CONCAT:.*]] = "xllvm.intr.aie2.concat.I1024.I512"(
// CHECK-SAME: %[[BITCAST0]], %[[BITCAST1]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>) -> vector<32xi32>
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[CONCAT]] : vector<32xi32> to vector<128xi8>
// CHECK-NEXT: return %[[RES]] : vector<128xi8>

// -----

// Two vector<16xi16> (256-bit) sources concatenated into vector<32xi16>
// (512-bit). Expected lowering: each source is bitcast to vector<8xi32>,
// passed to the concat.I512.I256 intrinsic, and the vector<16xi32> result is
// bitcast back to vector<32xi16>.
func.func @v32i16_concat_v16i16(%arg0 : vector<16xi16>, %arg1 : vector<16xi16>) -> vector<32xi16> {
%0 = aievec.concat %arg0, %arg1 : vector<16xi16>, vector<32xi16>
return %0 : vector<32xi16>
}

// CHECK-LABEL: @v32i16_concat_v16i16
// CHECK-SAME: %[[ARG0:.*]]: vector<16xi16>,
// CHECK-SAME: %[[ARG1:.*]]: vector<16xi16>
// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<16xi16> to vector<8xi32>
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[ARG1]] : vector<16xi16> to vector<8xi32>
// CHECK-NEXT: %[[CONCAT:.*]] = "xllvm.intr.aie2.concat.I512.I256"(
// CHECK-SAME: %[[BITCAST0]], %[[BITCAST1]]) :
// CHECK-SAME: (vector<8xi32>, vector<8xi32>) -> vector<16xi32>
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[CONCAT]] : vector<16xi32> to vector<32xi16>
// CHECK-NEXT: return %[[RES]] : vector<32xi16>

// -----

// Four vector<16xi16> (256-bit) sources concatenated into vector<64xi16>
// (1024-bit). Expected lowering: each source is bitcast to vector<8xi32>,
// passed to the concat.I1024.I256 intrinsic, and the vector<32xi32> result is
// bitcast back to vector<64xi16>.
func.func @v64i16_concat_v16i16(%arg0 : vector<16xi16>, %arg1 : vector<16xi16>,
%arg2 : vector<16xi16>, %arg3 : vector<16xi16>) -> vector<64xi16> {
%0 = aievec.concat %arg0, %arg1, %arg2, %arg3 : vector<16xi16>, vector<64xi16>
return %0 : vector<64xi16>
}

// CHECK-LABEL: @v64i16_concat_v16i16
// CHECK-SAME: %[[ARG0:.*]]: vector<16xi16>, %[[ARG1:.*]]: vector<16xi16>, %[[ARG2:.*]]: vector<16xi16>, %[[ARG3:.*]]: vector<16xi16>
// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<16xi16> to vector<8xi32>
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[ARG1]] : vector<16xi16> to vector<8xi32>
// CHECK-NEXT: %[[BITCAST2:.*]] = llvm.bitcast %[[ARG2]] : vector<16xi16> to vector<8xi32>
// CHECK-NEXT: %[[BITCAST3:.*]] = llvm.bitcast %[[ARG3]] : vector<16xi16> to vector<8xi32>
// CHECK-NEXT: %[[CONCAT:.*]] = "xllvm.intr.aie2.concat.I1024.I256"(
// CHECK-SAME: %[[BITCAST0]], %[[BITCAST1]], %[[BITCAST2]], %[[BITCAST3]]) :
// CHECK-SAME: (vector<8xi32>, vector<8xi32>, vector<8xi32>, vector<8xi32>) -> vector<32xi32>
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[CONCAT]] : vector<32xi32> to vector<64xi16>
// CHECK-NEXT: return %[[RES]] : vector<64xi16>

// -----

// Two vector<32xi16> (512-bit) sources concatenated into vector<64xi16>
// (1024-bit). Expected lowering: each source is bitcast to vector<16xi32>,
// passed to the concat.I1024.I512 intrinsic, and the vector<32xi32> result is
// bitcast back to vector<64xi16>.
func.func @v64i16_concat_v32i16(%arg0 : vector<32xi16>, %arg1 : vector<32xi16>) -> vector<64xi16> {
%0 = aievec.concat %arg0, %arg1 : vector<32xi16>, vector<64xi16>
return %0 : vector<64xi16>
}

// CHECK-LABEL: @v64i16_concat_v32i16
// CHECK-SAME: %[[ARG0:.*]]: vector<32xi16>,
// CHECK-SAME: %[[ARG1:.*]]: vector<32xi16>
// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<32xi16> to vector<16xi32>
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[ARG1]] : vector<32xi16> to vector<16xi32>
// CHECK-NEXT: %[[CONCAT:.*]] = "xllvm.intr.aie2.concat.I1024.I512"(
// CHECK-SAME: %[[BITCAST0]], %[[BITCAST1]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>) -> vector<32xi32>
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[CONCAT]] : vector<32xi32> to vector<64xi16>
// CHECK-NEXT: return %[[RES]] : vector<64xi16>

// -----

// Two vector<8xi32> (256-bit) sources concatenated into vector<16xi32>
// (512-bit). The sources already have the intrinsic's operand type, so the
// expectations feed the function arguments straight into concat.I512.I256.
// The result is still wrapped in an llvm.bitcast, here between identical
// vector<16xi32> types.
func.func @v16i32_concat_v8i32(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>) -> vector<16xi32> {
%0 = aievec.concat %arg0, %arg1 : vector<8xi32>, vector<16xi32>
return %0 : vector<16xi32>
}

// CHECK-LABEL: @v16i32_concat_v8i32
// CHECK-SAME: %[[ARG0:.*]]: vector<8xi32>,
// CHECK-SAME: %[[ARG1:.*]]: vector<8xi32>
// CHECK: %[[CONCAT:.*]] = "xllvm.intr.aie2.concat.I512.I256"(
// CHECK-SAME: %[[ARG0]], %[[ARG1]]) :
// CHECK-SAME: (vector<8xi32>, vector<8xi32>) -> vector<16xi32>
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[CONCAT]] : vector<16xi32> to vector<16xi32>
// CHECK-NEXT: return %[[RES]] : vector<16xi32>

// -----

// Four vector<8xi32> (256-bit) sources concatenated into vector<32xi32>
// (1024-bit). The sources already have the intrinsic's operand type, so the
// expectations feed the function arguments straight into concat.I1024.I256.
// The result is still wrapped in an llvm.bitcast, here between identical
// vector<32xi32> types.
func.func @v32i32_concat_v8i32(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>,
%arg2 : vector<8xi32>, %arg3 : vector<8xi32>) -> vector<32xi32> {
%0 = aievec.concat %arg0, %arg1, %arg2, %arg3 : vector<8xi32>, vector<32xi32>
return %0 : vector<32xi32>
}

// CHECK-LABEL: @v32i32_concat_v8i32
// CHECK-SAME: %[[ARG0:.*]]: vector<8xi32>, %[[ARG1:.*]]: vector<8xi32>, %[[ARG2:.*]]: vector<8xi32>, %[[ARG3:.*]]: vector<8xi32>
// CHECK: %[[CONCAT:.*]] = "xllvm.intr.aie2.concat.I1024.I256"(
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]]) :
// CHECK-SAME: (vector<8xi32>, vector<8xi32>, vector<8xi32>, vector<8xi32>) -> vector<32xi32>
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[CONCAT]] : vector<32xi32> to vector<32xi32>
// CHECK-NEXT: return %[[RES]] : vector<32xi32>

// -----

// Two vector<16xi32> (512-bit) sources concatenated into vector<32xi32>
// (1024-bit). The sources already have the intrinsic's operand type, so the
// expectations feed the function arguments straight into concat.I1024.I512.
// The result is still wrapped in an llvm.bitcast, here between identical
// vector<32xi32> types.
func.func @v32i32_concat_v16i32(%arg0 : vector<16xi32>, %arg1 : vector<16xi32>) -> vector<32xi32> {
%0 = aievec.concat %arg0, %arg1 : vector<16xi32>, vector<32xi32>
return %0 : vector<32xi32>
}

// CHECK-LABEL: @v32i32_concat_v16i32
// CHECK-SAME: %[[ARG0:.*]]: vector<16xi32>,
// CHECK-SAME: %[[ARG1:.*]]: vector<16xi32>
// CHECK: %[[CONCAT:.*]] = "xllvm.intr.aie2.concat.I1024.I512"(
// CHECK-SAME: %[[ARG0]], %[[ARG1]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>) -> vector<32xi32>
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[CONCAT]] : vector<32xi32> to vector<32xi32>
// CHECK-NEXT: return %[[RES]] : vector<32xi32>

// -----

func.func @v32bf16_concat_v16bf16(%arg0 : vector<16xbf16>, %arg1 : vector<16xbf16>) -> vector<32xbf16> {
%0 = aievec.concat %arg0, %arg1 : vector<16xbf16>, vector<32xbf16>
return %0 : vector<32xbf16>
}

// CHECK-LABEL: @bf16_concat
// CHECK-LABEL: @v32bf16_concat_v16bf16
// CHECK-SAME: %[[ARG0:.*]]: vector<16xbf16>,
// CHECK-SAME: %[[ARG1:.*]]: vector<16xbf16>
// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<16xbf16> to vector<8xi32>
Expand All @@ -34,3 +176,96 @@ func.func @bf16_concat(%arg0 : vector<16xbf16>, %arg1 : vector<16xbf16>) -> vect
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[CONCAT]] : vector<16xi32> to vector<32xbf16>
// CHECK-NEXT: return %[[RES]] : vector<32xbf16>

// -----

// Four vector<16xbf16> (256-bit) sources concatenated into vector<64xbf16>
// (1024-bit). Expected lowering: each source is bitcast to vector<8xi32>,
// passed to the concat.I1024.I256 intrinsic, and the vector<32xi32> result is
// bitcast back to vector<64xbf16>.
func.func @v64bf16_concat_v16bf16(%arg0 : vector<16xbf16>, %arg1 : vector<16xbf16>,
%arg2 : vector<16xbf16>, %arg3 : vector<16xbf16>) -> vector<64xbf16> {
%0 = aievec.concat %arg0, %arg1, %arg2, %arg3 : vector<16xbf16>, vector<64xbf16>
return %0 : vector<64xbf16>
}

// CHECK-LABEL: @v64bf16_concat_v16bf16
// CHECK-SAME: %[[ARG0:.*]]: vector<16xbf16>, %[[ARG1:.*]]: vector<16xbf16>, %[[ARG2:.*]]: vector<16xbf16>, %[[ARG3:.*]]: vector<16xbf16>
// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<16xbf16> to vector<8xi32>
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[ARG1]] : vector<16xbf16> to vector<8xi32>
// CHECK-NEXT: %[[BITCAST2:.*]] = llvm.bitcast %[[ARG2]] : vector<16xbf16> to vector<8xi32>
// CHECK-NEXT: %[[BITCAST3:.*]] = llvm.bitcast %[[ARG3]] : vector<16xbf16> to vector<8xi32>
// CHECK-NEXT: %[[CONCAT:.*]] = "xllvm.intr.aie2.concat.I1024.I256"(
// CHECK-SAME: %[[BITCAST0]], %[[BITCAST1]], %[[BITCAST2]], %[[BITCAST3]]) :
// CHECK-SAME: (vector<8xi32>, vector<8xi32>, vector<8xi32>, vector<8xi32>) -> vector<32xi32>
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[CONCAT]] : vector<32xi32> to vector<64xbf16>
// CHECK-NEXT: return %[[RES]] : vector<64xbf16>

// -----

// Two vector<32xbf16> (512-bit) sources concatenated into vector<64xbf16>
// (1024-bit). Expected lowering: each source is bitcast to vector<16xi32>,
// passed to the concat.I1024.I512 intrinsic, and the vector<32xi32> result is
// bitcast back to vector<64xbf16>.
func.func @v64bf16_concat_v32bf16(%arg0 : vector<32xbf16>, %arg1 : vector<32xbf16>) -> vector<64xbf16> {
%0 = aievec.concat %arg0, %arg1 : vector<32xbf16>, vector<64xbf16>
return %0 : vector<64xbf16>
}

// CHECK-LABEL: @v64bf16_concat_v32bf16
// CHECK-SAME: %[[ARG0:.*]]: vector<32xbf16>,
// CHECK-SAME: %[[ARG1:.*]]: vector<32xbf16>
// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<32xbf16> to vector<16xi32>
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[ARG1]] : vector<32xbf16> to vector<16xi32>
// CHECK-NEXT: %[[CONCAT:.*]] = "xllvm.intr.aie2.concat.I1024.I512"(
// CHECK-SAME: %[[BITCAST0]], %[[BITCAST1]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>) -> vector<32xi32>
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[CONCAT]] : vector<32xi32> to vector<64xbf16>
// CHECK-NEXT: return %[[RES]] : vector<64xbf16>

// -----

// Two vector<8xf32> (256-bit) sources concatenated into vector<16xf32>
// (512-bit). Expected lowering: each source is bitcast to vector<8xi32>,
// passed to the concat.I512.I256 intrinsic, and the vector<16xi32> result is
// bitcast back to vector<16xf32>.
func.func @v16f32_concat_v8f32(%arg0 : vector<8xf32>, %arg1 : vector<8xf32>) -> vector<16xf32> {
%0 = aievec.concat %arg0, %arg1 : vector<8xf32>, vector<16xf32>
return %0 : vector<16xf32>
}

// CHECK-LABEL: @v16f32_concat_v8f32
// CHECK-SAME: %[[ARG0:.*]]: vector<8xf32>,
// CHECK-SAME: %[[ARG1:.*]]: vector<8xf32>
// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<8xf32> to vector<8xi32>
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[ARG1]] : vector<8xf32> to vector<8xi32>
// CHECK-NEXT: %[[CONCAT:.*]] = "xllvm.intr.aie2.concat.I512.I256"(
// CHECK-SAME: %[[BITCAST0]], %[[BITCAST1]]) :
// CHECK-SAME: (vector<8xi32>, vector<8xi32>) -> vector<16xi32>
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[CONCAT]] : vector<16xi32> to vector<16xf32>
// CHECK-NEXT: return %[[RES]] : vector<16xf32>

// -----

// Four vector<8xf32> (256-bit) sources concatenated into vector<32xf32>
// (1024-bit). Expected lowering: each source is bitcast to vector<8xi32>,
// passed to the concat.I1024.I256 intrinsic, and the vector<32xi32> result is
// bitcast back to vector<32xf32>.
func.func @v32f32_concat_v8f32(%arg0 : vector<8xf32>, %arg1 : vector<8xf32>,
%arg2 : vector<8xf32>, %arg3 : vector<8xf32>) -> vector<32xf32> {
%0 = aievec.concat %arg0, %arg1, %arg2, %arg3 : vector<8xf32>, vector<32xf32>
return %0 : vector<32xf32>
}

// CHECK-LABEL: @v32f32_concat_v8f32
// CHECK-SAME: %[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xf32>, %[[ARG2:.*]]: vector<8xf32>, %[[ARG3:.*]]: vector<8xf32>
// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<8xf32> to vector<8xi32>
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[ARG1]] : vector<8xf32> to vector<8xi32>
// CHECK-NEXT: %[[BITCAST2:.*]] = llvm.bitcast %[[ARG2]] : vector<8xf32> to vector<8xi32>
// CHECK-NEXT: %[[BITCAST3:.*]] = llvm.bitcast %[[ARG3]] : vector<8xf32> to vector<8xi32>
// CHECK-NEXT: %[[CONCAT:.*]] = "xllvm.intr.aie2.concat.I1024.I256"(
// CHECK-SAME: %[[BITCAST0]], %[[BITCAST1]], %[[BITCAST2]], %[[BITCAST3]]) :
// CHECK-SAME: (vector<8xi32>, vector<8xi32>, vector<8xi32>, vector<8xi32>) -> vector<32xi32>
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[CONCAT]] : vector<32xi32> to vector<32xf32>
// CHECK-NEXT: return %[[RES]] : vector<32xf32>

// -----

// Two vector<16xf32> (512-bit) sources concatenated into vector<32xf32>
// (1024-bit). Expected lowering: each source is bitcast to vector<16xi32>,
// passed to the concat.I1024.I512 intrinsic, and the vector<32xi32> result is
// bitcast back to vector<32xf32>.
func.func @v32f32_concat_v16f32(%arg0 : vector<16xf32>, %arg1 : vector<16xf32>) -> vector<32xf32> {
%0 = aievec.concat %arg0, %arg1 : vector<16xf32>, vector<32xf32>
return %0 : vector<32xf32>
}

// CHECK-LABEL: @v32f32_concat_v16f32
// CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>,
// CHECK-SAME: %[[ARG1:.*]]: vector<16xf32>
// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<16xf32> to vector<16xi32>
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[ARG1]] : vector<16xf32> to vector<16xi32>
// CHECK-NEXT: %[[CONCAT:.*]] = "xllvm.intr.aie2.concat.I1024.I512"(
// CHECK-SAME: %[[BITCAST0]], %[[BITCAST1]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>) -> vector<32xi32>
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[CONCAT]] : vector<32xi32> to vector<32xf32>
// CHECK-NEXT: return %[[RES]] : vector<32xf32>
20 changes: 20 additions & 0 deletions test/Target/LLVMIR/aievec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,26 @@ llvm.func @concat_i512_i256(%a : vector<8xi32>, %b : vector<8xi32>) -> vector<16
llvm.return %0 : vector<16xi32>
}

// Translation test for the four-operand concat intrinsic op: the
// "xllvm.intr.aie2.concat.I1024.I256" op is expected to become a call to
// @llvm.aie2.concat.I1024.I256 taking four <8 x i32> values and returning
// <32 x i32>.
// CHECK-LABEL: define <32 x i32> @concat_i1024_i256
llvm.func @concat_i1024_i256(%a : vector<8xi32>, %b : vector<8xi32>,
%c : vector<8xi32>, %d : vector<8xi32>) -> vector<32xi32> {
// CHECK: call <32 x i32> @llvm.aie2.concat.I1024.I256(
// CHECK-SAME: <8 x i32> %{{[0-9]+}}, <8 x i32> %{{[0-9]+}},
// CHECK-SAME: <8 x i32> %{{[0-9]+}}, <8 x i32> %{{[0-9]+}})
%0 = "xllvm.intr.aie2.concat.I1024.I256"(%a, %b, %c, %d) :
(vector<8xi32>, vector<8xi32>, vector<8xi32>, vector<8xi32>) -> vector<32xi32>
llvm.return %0 : vector<32xi32>
}

// Translation test for the two-operand 512-bit concat intrinsic op: the
// "xllvm.intr.aie2.concat.I1024.I512" op is expected to become a call to
// @llvm.aie2.concat.I1024.I512 taking two <16 x i32> values and returning
// <32 x i32>.
// CHECK-LABEL: define <32 x i32> @concat_i1024_i512
llvm.func @concat_i1024_i512(%a : vector<16xi32>, %b : vector<16xi32>) -> vector<32xi32> {
// CHECK: call <32 x i32> @llvm.aie2.concat.I1024.I512(
// CHECK-SAME: <16 x i32> %{{[0-9]+}}, <16 x i32> %{{[0-9]+}})
%0 = "xllvm.intr.aie2.concat.I1024.I512"(%a, %b) :
(vector<16xi32>, vector<16xi32>) -> vector<32xi32>
llvm.return %0 : vector<32xi32>
}

// ----- SHUFFLE -----

// CHECK-LABEL: define <16 x i32> @shuffle_i512
Expand Down

0 comments on commit 4868500

Please sign in to comment.