Skip to content

Commit

Permalink
[aievec] to-llvm flow for aievec.srs op (#1463)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamestcl-amd authored May 8, 2024
1 parent a9037bc commit 28ee711
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 45 deletions.
27 changes: 24 additions & 3 deletions include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,16 @@ def VectorSetI512I256IntrOp :

// ----- SRS -----

def I512V32Acc32SrsIntrOp :
AIEVec2_IntrOp<"I512.v32.acc32.srs",
[TypeIs<"res", VectorOfLengthAndType<[32], [I16]>>]>,
// SRS intrinsic op "I256.v16.acc32.srs".
// Operands: $src as a vector of 8 x i64, an i32 $shift and an i32 $sign.
// Result ("res") is constrained to a vector of 16 x i16.
def I256V16Acc32SrsIntrOp :
AIEVec2_IntrOp<"I256.v16.acc32.srs",
[TypeIs<"res", VectorOfLengthAndType<[16], [I16]>>]>,
Arguments<(ins VectorOfLengthAndType<[8], [I64]>:$src,
I32:$shift,
I32:$sign)>;

// SRS intrinsic op "I256.v16.acc64.srs".
// Operands: $src as a vector of 16 x i64, an i32 $shift and an i32 $sign.
// Result ("res") is constrained to a vector of 16 x i16.
def I256V16Acc64SrsIntrOp :
AIEVec2_IntrOp<"I256.v16.acc64.srs",
[TypeIs<"res", VectorOfLengthAndType<[16], [I16]>>]>,
Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$src,
I32:$shift,
I32:$sign)>;
Expand All @@ -117,13 +124,27 @@ def I256V32Acc32SrsIntrOp :
I32:$shift,
I32:$sign)>;

// SRS intrinsic op "I256.v8.acc64.srs".
// Operands: $src as a vector of 8 x i64, an i32 $shift and an i32 $sign.
// Result ("res") is constrained to a vector of 8 x i32.
def I256V8Acc64SrsIntrOp :
AIEVec2_IntrOp<"I256.v8.acc64.srs",
[TypeIs<"res", VectorOfLengthAndType<[8], [I32]>>]>,
Arguments<(ins VectorOfLengthAndType<[8], [I64]>:$src,
I32:$shift,
I32:$sign)>;

// SRS intrinsic op "I512.v16.acc64.srs".
// Operands: $src as a vector of 16 x i64, an i32 $shift and an i32 $sign.
// Result ("res") is constrained to a vector of 16 x i32.
def I512V16Acc64SrsIntrOp :
AIEVec2_IntrOp<"I512.v16.acc64.srs",
[TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>,
Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$src,
I32:$shift,
I32:$sign)>;

// SRS intrinsic op "I512.v32.acc32.srs".
// Operands: $src as a vector of 16 x i64, an i32 $shift and an i32 $sign.
// Result ("res") is constrained to a vector of 32 x i16.
def I512V32Acc32SrsIntrOp :
AIEVec2_IntrOp<"I512.v32.acc32.srs",
[TypeIs<"res", VectorOfLengthAndType<[32], [I16]>>]>,
Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$src,
I32:$shift,
I32:$sign)>;

def Vector16AccFloatToV16BF16IntrOp :
AIEVec2_IntrOp<"v16accfloat.to.v16bf16",
[TypeIs<"res", VectorOfLengthAndType<[16], [BF16]>>]>,
Expand Down
87 changes: 66 additions & 21 deletions lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,7 @@ class SRSOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::SRSOp> {
int resultVectorSize = resultBitWidth * resultLanes;

// Integer types
Operation *srsIntrOp = nullptr;
if (llvm::isa<IntegerType>(resultScaTy)) {
// create constant for sign
auto signCst = rewriter.create<LLVM::ConstantOp>(
Expand All @@ -852,47 +853,91 @@ class SRSOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::SRSOp> {
{adaptor.getSource(), adaptor.getShift(), signCst});
if (resultVectorSize == 512) {
if (resultBitWidth == 16) {
rewriter.replaceOpWithNewOp<xllvm::I512V32Acc32SrsIntrOp>(
op, VectorType::get({32}, rewriter.getI16Type()),
// v32acc32 -> v32int16
srsIntrOp = rewriter.create<xllvm::I512V32Acc32SrsIntrOp>(
loc, VectorType::get({32}, rewriter.getI16Type()),
forceCastOperandsToSignature(
rewriter, loc, operands,
{VectorType::get({16}, rewriter.getI64Type()),
rewriter.getI32Type(), rewriter.getI32Type()}));
} else if (resultBitWidth == 32) {
rewriter.replaceOpWithNewOp<xllvm::I512V16Acc64SrsIntrOp>(
op, VectorType::get({16}, rewriter.getI32Type()),
// v16acc64 -> v16int32
srsIntrOp = rewriter.create<xllvm::I512V16Acc64SrsIntrOp>(
loc, VectorType::get({16}, rewriter.getI32Type()),
forceCastOperandsToSignature(
rewriter, loc, operands,
{VectorType::get({16}, rewriter.getI64Type()),
rewriter.getI32Type(), rewriter.getI32Type()}));
}
} else if (resultVectorSize == 256) {
rewriter.replaceOpWithNewOp<xllvm::I256V32Acc32SrsIntrOp>(
op, VectorType::get({32}, rewriter.getI8Type()),
forceCastOperandsToSignature(
rewriter, loc, operands,
{VectorType::get({16}, rewriter.getI64Type()),
rewriter.getI32Type(), rewriter.getI32Type()}));
} else {
op.emitWarning() << "aievec.srs with result vector size = "
<< resultVectorSize << " is not supported.\n";
return failure();
Value src = adaptor.getSource();
VectorType srcType = cast<VectorType>(src.getType());
Type srcScaType = srcType.getElementType();
unsigned srcBitWidth = srcScaType.getIntOrFloatBitWidth();

if (resultBitWidth == 16 && srcBitWidth == 32) {
// v16acc32 -> v16int16
srsIntrOp = rewriter.create<xllvm::I256V16Acc32SrsIntrOp>(
loc, VectorType::get({16}, rewriter.getI16Type()),
forceCastOperandsToSignature(
rewriter, loc, operands,
{VectorType::get({8}, rewriter.getI64Type()),
rewriter.getI32Type(), rewriter.getI32Type()}));
} else if (resultBitWidth == 8 && srcBitWidth == 32) {
// v32acc32 -> v32int8
srsIntrOp = rewriter.create<xllvm::I256V32Acc32SrsIntrOp>(
loc, VectorType::get({32}, rewriter.getI8Type()),
forceCastOperandsToSignature(
rewriter, loc, operands,
{VectorType::get({16}, rewriter.getI64Type()),
rewriter.getI32Type(), rewriter.getI32Type()}));
} else if (resultBitWidth == 16 && srcBitWidth == 64) {
// v16acc64 -> v16int16
srsIntrOp = rewriter.create<xllvm::I256V16Acc64SrsIntrOp>(
loc, VectorType::get({16}, rewriter.getI16Type()),
forceCastOperandsToSignature(
rewriter, loc, operands,
{VectorType::get({16}, rewriter.getI64Type()),
rewriter.getI32Type(), rewriter.getI32Type()}));
} else if (resultBitWidth == 32 && srcBitWidth == 64) {
// v8acc64 -> v8int32
srsIntrOp = rewriter.create<xllvm::I256V8Acc64SrsIntrOp>(
loc, VectorType::get({8}, rewriter.getI32Type()),
forceCastOperandsToSignature(
rewriter, loc, operands,
{VectorType::get({8}, rewriter.getI64Type()),
rewriter.getI32Type(), rewriter.getI32Type()}));
}
}
} else {
// Float types
if (resultVectorSize == 256) {
rewriter.replaceOpWithNewOp<xllvm::Vector16AccFloatToV16BF16IntrOp>(
op, VectorType::get({16}, rewriter.getBF16Type()),
// v16accfloat -> v16bfloat16
srsIntrOp = rewriter.create<xllvm::Vector16AccFloatToV16BF16IntrOp>(
loc, VectorType::get({16}, rewriter.getBF16Type()),
forceCastOperandsToSignature(
rewriter, loc, {adaptor.getSource()},
{VectorType::get({8}, rewriter.getI64Type())}));
} else {
op.emitWarning() << "aievec.srs with result vector size = "
<< resultVectorSize << " is not supported.\n";
return failure();
} else if (resultVectorSize == 512) {
// v32accfloat -> v32bfloat16
// Implement this scenario in emulation. The CPP example is below:
// v32bfloat16 to_v32bfloat16(v32accfloat acc) {
// v16bfloat16 x0 = to_v16bfloat16(extract_v16accfloat(acc, 0));
// v16bfloat16 x1 = to_v16bfloat16(extract_v16accfloat(acc, 1));
// return concat(x0, x1);
// }
// TODO: implement this after adding 1024->512 vector extraction
// intrinsic
}
}

if (!srsIntrOp) {
op.emitWarning() << "aievec.srs is not supported.\n";
return failure();
}

rewriter.replaceOp(op, srsIntrOp);

return success();
}
};
Expand Down Expand Up @@ -1471,7 +1516,7 @@ class ExtractElemOpConversion
}

// create truncation op (and bitcast op)
if (resultType.isa<IntegerType>()) {
if (llvm::isa<IntegerType>(resultType)) {
if (resultBitWidth < 32) {
rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, resultType, extElemOp);
} else {
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/AIEVec/IR/AIEVecOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ LogicalResult BroadcastScalarOp::verify() {
if (!resultType)
return emitError("requires vector type");

if (!sourceType.isa<IntegerType, FloatType>())
if (!isa<IntegerType, FloatType>(sourceType))
return emitError("requires source type to be integer or float");

Type resultElemType = resultType.getElementType();
Expand Down
99 changes: 85 additions & 14 deletions test/Conversion/AIEVecToLLVM/test-srs.mlir
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
// RUN: aie-opt %s -split-input-file --convert-aievec-to-llvm | FileCheck %s

func.func @v32i32_srs_v32i16(%arg0 : vector<32xi32>) {
// Test input: applies aievec.srs twice to the same vector<32xi32> argument,
// once with shift constant 0 and once with shift constant 5, each producing
// a vector<32xi16>. Neither result is used.
func.func @v32i16_srs_v32i32(%arg0 : vector<32xi32>) {
%c0 = arith.constant 0 : i32
%c5 = arith.constant 5 : i32
%0 = aievec.srs %arg0, %c0 : vector<32xi32>, i32, vector<32xi16>
%1 = aievec.srs %arg0, %c5 : vector<32xi32>, i32, vector<32xi16>
return
}

// CHECK-LABEL: @v32i32_srs_v32i16
// CHECK-LABEL: @v32i16_srs_v32i32
// CHECK-SAME: %[[ARG0:.*]]: vector<32xi32>
// CHECK-NEXT: %[[SHIFT0:.*]] = arith.constant 0 : i32
// CHECK-NEXT: %[[SHIFT5:.*]] = arith.constant 5 : i32
Expand All @@ -25,15 +25,63 @@ func.func @v32i32_srs_v32i16(%arg0 : vector<32xi32>) {

// -----

func.func @v32i32_srs_v32i8(%arg0 : vector<32xi32>) {
// Test input: applies aievec.srs twice to the same vector<16xi64> argument,
// once with shift constant 0 and once with shift constant 5, each producing
// a vector<16xi32>. Neither result is used.
func.func @v16i32_srs_v16i64(%arg0 : vector<16xi64>) {
%c0 = arith.constant 0 : i32
%c5 = arith.constant 5 : i32
%0 = aievec.srs %arg0, %c0 : vector<16xi64>, i32, vector<16xi32>
%1 = aievec.srs %arg0, %c5 : vector<16xi64>, i32, vector<16xi32>
return
}

// CHECK-LABEL: @v16i32_srs_v16i64
// CHECK-SAME: %[[ARG0:.*]]: vector<16xi64>
// CHECK-NEXT: %[[SHIFT0:.*]] = arith.constant 0 : i32
// CHECK-NEXT: %[[SHIFT5:.*]] = arith.constant 5 : i32
// CHECK-NEXT: %[[SIGN0:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.I512.v16.acc64.srs"(
// CHECK-SAME: [[ARG0]], %[[SHIFT0]], %[[SIGN0]]) :
// CHECK-SAME: (vector<16xi64>, i32, i32) -> vector<16xi32>
// CHECK-NEXT: %[[SIGN1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[SRS1:.*]] = "xllvm.intr.aie2.I512.v16.acc64.srs"(
// CHECK-SAME: [[ARG0]], %[[SHIFT5]], %[[SIGN1]]) :
// CHECK-SAME: (vector<16xi64>, i32, i32) -> vector<16xi32>

// -----

// Test input: applies aievec.srs twice to the same vector<16xi32> argument,
// once with shift constant 0 and once with shift constant 5, each producing
// a vector<16xi16>. Neither result is used.
func.func @v16i16_srs_v16i32(%arg0 : vector<16xi32>) {
%c0 = arith.constant 0 : i32
%c5 = arith.constant 5 : i32
%0 = aievec.srs %arg0, %c0 : vector<16xi32>, i32, vector<16xi16>
%1 = aievec.srs %arg0, %c5 : vector<16xi32>, i32, vector<16xi16>
return
}

// CHECK-LABEL: @v16i16_srs_v16i32
// CHECK-SAME: %[[ARG0:.*]]: vector<16xi32>
// CHECK-NEXT: %[[SHIFT0:.*]] = arith.constant 0 : i32
// CHECK-NEXT: %[[SHIFT5:.*]] = arith.constant 5 : i32
// CHECK-NEXT: %[[SIGN0:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<16xi32> to vector<8xi64>
// CHECK-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.I256.v16.acc32.srs"(
// CHECK-SAME: [[BITCAST0]], %[[SHIFT0]], %[[SIGN0]]) :
// CHECK-SAME: (vector<8xi64>, i32, i32) -> vector<16xi16>
// CHECK-NEXT: %[[SIGN1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<16xi32> to vector<8xi64>
// CHECK-NEXT: %[[SRS1:.*]] = "xllvm.intr.aie2.I256.v16.acc32.srs"(
// CHECK-SAME: [[BITCAST1]], %[[SHIFT5]], %[[SIGN1]]) :
// CHECK-SAME: (vector<8xi64>, i32, i32) -> vector<16xi16>

// -----

// Test input: applies aievec.srs twice to the same vector<32xi32> argument,
// once with shift constant 0 and once with shift constant 5, each producing
// a vector<32xi8>. Neither result is used.
func.func @v32i8_srs_v32i32(%arg0 : vector<32xi32>) {
%c0 = arith.constant 0 : i32
%c5 = arith.constant 5 : i32
%0 = aievec.srs %arg0, %c0 : vector<32xi32>, i32, vector<32xi8>
%1 = aievec.srs %arg0, %c5 : vector<32xi32>, i32, vector<32xi8>
return
}

// CHECK-LABEL: @v32i32_srs_v32i8
// CHECK-LABEL: @v32i8_srs_v32i32
// CHECK-SAME: %[[ARG0:.*]]: vector<32xi32>
// CHECK-NEXT: %[[SHIFT0:.*]] = arith.constant 0 : i32
// CHECK-NEXT: %[[SHIFT5:.*]] = arith.constant 5 : i32
Expand All @@ -50,38 +98,61 @@ func.func @v32i32_srs_v32i8(%arg0 : vector<32xi32>) {

// -----

func.func @v16i64_srs_v16i32(%arg0 : vector<16xi64>) {
func.func @v16i16_srs_v16i64(%arg0 : vector<16xi64>) {
%c0 = arith.constant 0 : i32
%c5 = arith.constant 5 : i32
%0 = aievec.srs %arg0, %c0 : vector<16xi64>, i32, vector<16xi32>
%1 = aievec.srs %arg0, %c5 : vector<16xi64>, i32, vector<16xi32>
%0 = aievec.srs %arg0, %c0 : vector<16xi64>, i32, vector<16xi16>
%1 = aievec.srs %arg0, %c5 : vector<16xi64>, i32, vector<16xi16>
return
}

// CHECK-LABEL: @v16i64_srs_v16i32
// CHECK-LABEL: @v16i16_srs_v16i64
// CHECK-SAME: %[[ARG0:.*]]: vector<16xi64>
// CHECK-NEXT: %[[SHIFT0:.*]] = arith.constant 0 : i32
// CHECK-NEXT: %[[SHIFT5:.*]] = arith.constant 5 : i32
// CHECK-NEXT: %[[SIGN0:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.I512.v16.acc64.srs"(
// CHECK-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.I256.v16.acc64.srs"(
// CHECK-SAME: [[ARG0]], %[[SHIFT0]], %[[SIGN0]]) :
// CHECK-SAME: (vector<16xi64>, i32, i32) -> vector<16xi32>
// CHECK-SAME: (vector<16xi64>, i32, i32) -> vector<16xi16>
// CHECK-NEXT: %[[SIGN1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[SRS1:.*]] = "xllvm.intr.aie2.I512.v16.acc64.srs"(
// CHECK-NEXT: %[[SRS1:.*]] = "xllvm.intr.aie2.I256.v16.acc64.srs"(
// CHECK-SAME: [[ARG0]], %[[SHIFT5]], %[[SIGN1]]) :
// CHECK-SAME: (vector<16xi64>, i32, i32) -> vector<16xi32>
// CHECK-SAME: (vector<16xi64>, i32, i32) -> vector<16xi16>

// -----

// Test input: applies aievec.srs twice to the same vector<8xi64> argument,
// once with shift constant 0 and once with shift constant 5, each producing
// a vector<8xi32>. Neither result is used.
func.func @v8i32_srs_v8i64(%arg0 : vector<8xi64>) {
%c0 = arith.constant 0 : i32
%c5 = arith.constant 5 : i32
%0 = aievec.srs %arg0, %c0 : vector<8xi64>, i32, vector<8xi32>
%1 = aievec.srs %arg0, %c5 : vector<8xi64>, i32, vector<8xi32>
return
}

// CHECK-LABEL: @v8i32_srs_v8i64
// CHECK-SAME: %[[ARG0:.*]]: vector<8xi64>
// CHECK-NEXT: %[[SHIFT0:.*]] = arith.constant 0 : i32
// CHECK-NEXT: %[[SHIFT5:.*]] = arith.constant 5 : i32
// CHECK-NEXT: %[[SIGN0:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.I256.v8.acc64.srs"(
// CHECK-SAME: [[ARG0]], %[[SHIFT0]], %[[SIGN0]]) :
// CHECK-SAME: (vector<8xi64>, i32, i32) -> vector<8xi32>
// CHECK-NEXT: %[[SIGN1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[SRS1:.*]] = "xllvm.intr.aie2.I256.v8.acc64.srs"(
// CHECK-SAME: [[ARG0]], %[[SHIFT5]], %[[SIGN1]]) :
// CHECK-SAME: (vector<8xi64>, i32, i32) -> vector<8xi32>

// -----

func.func @v16f32_srs_v16bf16(%arg0 : vector<16xf32>) {
// Test input: applies aievec.srs twice to the same vector<16xf32> argument,
// once with shift constant 0 and once with shift constant 5, each producing
// a vector<16xbf16>. Neither result is used.
func.func @v16bf16_srs_v16f32(%arg0 : vector<16xf32>) {
%c0 = arith.constant 0 : i32
%c5 = arith.constant 5 : i32
%0 = aievec.srs %arg0, %c0 : vector<16xf32>, i32, vector<16xbf16>
%1 = aievec.srs %arg0, %c5 : vector<16xf32>, i32, vector<16xbf16>
return
}

// CHECK-LABEL: @v16f32_srs_v16bf16
// CHECK-LABEL: @v16bf16_srs_v16f32
// CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>
// CHECK-NEXT: %[[SHIFT0:.*]] = arith.constant 0 : i32
// CHECK-NEXT: %[[SHIFT5:.*]] = arith.constant 5 : i32
Expand Down
Loading

0 comments on commit 28ee711

Please sign in to comment.