Skip to content

Commit

Permalink
[aievec] to-llvm flow for aievec.ups/srs op (#1497)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamestcl-amd authored May 21, 2024
1 parent 0c35b86 commit 37ef519
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 18 deletions.
83 changes: 72 additions & 11 deletions lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -910,14 +910,39 @@ class UPSOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::UPSOp> {
{VectorType::get({16}, rewriter.getBF16Type())}));
} else if (resultVectorSize == 1024) {
// v32bfloat16 -> v32accfloat
// Implement this scenario in emulation. The CPP example is below:
// The CPP example of the implementation is below:
// INTRINSIC(v32accfloat) ups_to_v32accfloat(v32bfloat16 a) {
// v16accfloat x0 = ups_to_v16accfloat(extract_v16bfloat16(a, 0));
// v16accfloat x1 = ups_to_v16accfloat(extract_v16bfloat16(a, 1));
// return concat(x0, x1);
// }
// TODO: implement this after adding 512->256 vector extraction
// intrinsic
auto indexZeroCst = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
auto indexOneCst = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
auto extractUps = [&](Value source, Value index) -> Value {
auto extOp = rewriter.create<xllvm::ExtI256I512IntrOp>(
loc, VectorType::get({8}, rewriter.getI32Type()),
forceCastOperandsToSignature(
rewriter, loc, {source, index},
{VectorType::get({16}, rewriter.getI32Type()),
rewriter.getI32Type()}));
return rewriter.create<xllvm::Vector16BF16ToV16AccFloatIntrOp>(
loc, VectorType::get({8}, rewriter.getI64Type()),
forceCastOperandsToSignature(
rewriter, loc, {extOp},
{VectorType::get({16}, rewriter.getBF16Type())}));
};
auto resLo = extractUps(adaptor.getSource(), indexZeroCst);
auto resHi = extractUps(adaptor.getSource(), indexOneCst);
// Concat the two 512-bit vector to a 1024-bit vector.
// Note that given sources a0 and a1, the result is [a1; a0].
upsIntrOp = rewriter.create<xllvm::ConcatI1024I512IntrOp>(
loc, VectorType::get({32}, rewriter.getI32Type()),
forceCastOperandsToSignature(
rewriter, loc, {resLo, resHi},
{VectorType::get({16}, rewriter.getI32Type()),
VectorType::get({16}, rewriter.getI32Type())}));
}
}

Expand All @@ -926,9 +951,14 @@ class UPSOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::UPSOp> {
return failure();
}

// create bitcast for result
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
upsIntrOp);
// create bitcast for result if needed
if (op.getResult().getType() != upsIntrOp.getType()) {
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
upsIntrOp);
} else {
rewriter.replaceOp(op, upsIntrOp);
}

return success();
}
};
Expand All @@ -950,7 +980,7 @@ class SRSOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::SRSOp> {
int resultVectorSize = resultBitWidth * resultLanes;

// Integer types
Operation *srsIntrOp = nullptr;
Value srsIntrOp = nullptr;
if (llvm::isa<IntegerType>(resultScaTy)) {
// create constant for sign
auto signCst = rewriter.create<LLVM::ConstantOp>(
Expand Down Expand Up @@ -1028,14 +1058,39 @@ class SRSOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::SRSOp> {
{VectorType::get({8}, rewriter.getI64Type())}));
} else if (resultVectorSize == 512) {
// v32accfloat -> v32bfloat16
// Implement this scenario in emulation. The CPP example is below:
// The CPP example of the implementation is below:
// v32bfloat16 to_v32bfloat16(v32accfloat acc) {
// v16bfloat16 x0 = to_v16bfloat16(extract_v16accfloat(acc, 0));
// v16bfloat16 x1 = to_v16bfloat16(extract_v16accfloat(acc, 1));
// return concat(x0, x1);
// }
// TODO: implement this after adding 1024->512 vector extraction
// intrinsic
auto indexZeroCst = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
auto indexOneCst = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
auto extractSrs = [&](Value source, Value index) -> Value {
auto extOp = rewriter.create<xllvm::ExtI512I1024IntrOp>(
loc, VectorType::get({16}, rewriter.getI32Type()),
forceCastOperandsToSignature(
rewriter, loc, {source, index},
{VectorType::get({32}, rewriter.getI32Type()),
rewriter.getI32Type()}));
return rewriter.create<xllvm::Vector16AccFloatToV16BF16IntrOp>(
loc, VectorType::get({16}, rewriter.getBF16Type()),
forceCastOperandsToSignature(
rewriter, loc, {extOp},
{VectorType::get({8}, rewriter.getI64Type())}));
};
auto resLo = extractSrs(adaptor.getSource(), indexZeroCst);
auto resHi = extractSrs(adaptor.getSource(), indexOneCst);
// Concat the two 256-bit vector to a 512-bit vector.
// Note that given sources a0 and a1, the result is [a1; a0].
srsIntrOp = rewriter.create<xllvm::ConcatI512I256IntrOp>(
loc, VectorType::get({16}, rewriter.getI32Type()),
forceCastOperandsToSignature(
rewriter, loc, {resLo, resHi},
{VectorType::get({8}, rewriter.getI32Type()),
VectorType::get({8}, rewriter.getI32Type())}));
}
}

Expand All @@ -1044,7 +1099,13 @@ class SRSOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::SRSOp> {
return failure();
}

rewriter.replaceOp(op, srsIntrOp);
// create bitcast for result if needed
if (op.getResult().getType() != srsIntrOp.getType()) {
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
srsIntrOp);
} else {
rewriter.replaceOp(op, srsIntrOp);
}

return success();
}
Expand Down
36 changes: 36 additions & 0 deletions test/Conversion/AIEVecToLLVM/test-srs.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,39 @@ func.func @v16bf16_srs_v16f32(%arg0 : vector<16xf32>) {
// CHECK-NEXT: %[[SRS1:.*]] = "xllvm.intr.aie2.v16accfloat.to.v16bf16"(
// CHECK-SAME: [[BITCAST1]]) :
// CHECK-SAME: (vector<8xi64>) -> vector<16xbf16>

// -----

func.func @v32bf16_srs_v32f32(%arg0 : vector<32xf32>) {
%c0 = arith.constant 0 : i32
%0 = aievec.srs %arg0, %c0 : vector<32xf32>, i32, vector<32xbf16>
return
}

// CHECK-LABEL: @v32bf16_srs_v32f32
// CHECK-SAME: %[[ARG0:.*]]: vector<32xf32>
// CHECK: %[[SHIFT0:.*]] = arith.constant 0 : i32
// CHECK-NEXT: %[[INDEX0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[INDEX1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<32xf32> to vector<32xi32>
// CHECK-NEXT: %[[EXT0:.*]] = "xllvm.intr.aie2.ext.I512.I1024"(
// CHECK-SAME: %[[BITCAST0]], %[[INDEX0]]) :
// CHECK-SAME: (vector<32xi32>, i32) -> vector<16xi32>
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[EXT0]] : vector<16xi32> to vector<8xi64>
// CHECK-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.v16accfloat.to.v16bf16"(
// CHECK-SAME: %[[BITCAST1]]) :
// CHECK-SAME: (vector<8xi64>) -> vector<16xbf16>
// CHECK-NEXT: %[[BITCAST2:.*]] = llvm.bitcast %[[ARG0]] : vector<32xf32> to vector<32xi32>
// CHECK-NEXT: %[[EXT1:.*]] = "xllvm.intr.aie2.ext.I512.I1024"(
// CHECK-SAME: %[[BITCAST2]], %[[INDEX1]]) :
// CHECK-SAME: (vector<32xi32>, i32) -> vector<16xi32>
// CHECK-NEXT: %[[BITCAST3:.*]] = llvm.bitcast %[[EXT1]] : vector<16xi32> to vector<8xi64>
// CHECK-NEXT: %[[SRS1:.*]] = "xllvm.intr.aie2.v16accfloat.to.v16bf16"(
// CHECK-SAME: %[[BITCAST3]]) :
// CHECK-SAME: (vector<8xi64>) -> vector<16xbf16>
// CHECK-NEXT: %[[BITCAST4:.*]] = llvm.bitcast %[[SRS0]] : vector<16xbf16> to vector<8xi32>
// CHECK-NEXT: %[[BITCAST5:.*]] = llvm.bitcast %[[SRS1]] : vector<16xbf16> to vector<8xi32>
// CHECK-NEXT: %[[CONCAT:.*]] = "xllvm.intr.aie2.concat.I512.I256"(
// CHECK-SAME: %[[BITCAST4]], %[[BITCAST5]]) :
// CHECK-SAME: (vector<8xi32>, vector<8xi32>) -> vector<16xi32>
// CHECK-NEXT: %[[BITCAST6:.*]] = llvm.bitcast %[[CONCAT]] : vector<16xi32> to vector<32xbf16>
42 changes: 35 additions & 7 deletions test/Conversion/AIEVecToLLVM/test-ups.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,11 @@ func.func @v8acc64_ups_v8i32(%arg0 : vector<8xi32>) {
// CHECK-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.acc64.v8.I256.ups"(
// CHECK-SAME: [[ARG0]], %[[SHIFT0]], %[[SIGN0]]) :
// CHECK-SAME: (vector<8xi32>, i32, i32) -> vector<8xi64>
// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[SRS0]] : vector<8xi64> to vector<8xi64>
// CHECK-NEXT: %[[SIGN1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[SHIFT5:.*]] = llvm.mlir.constant(5 : i32) : i32
// CHECK-NEXT: %[[SRS1:.*]] = "xllvm.intr.aie2.acc64.v8.I256.ups"(
// CHECK-SAME: [[ARG0]], %[[SHIFT5]], %[[SIGN1]]) :
// CHECK-SAME: (vector<8xi32>, i32, i32) -> vector<8xi64>
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[SRS1]] : vector<8xi64> to vector<8xi64>

// -----

Expand Down Expand Up @@ -82,13 +80,11 @@ func.func @v16acc64_ups_v16i32(%arg0 : vector<16xi32>) {
// CHECK-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.acc64.v16.I512.ups"(
// CHECK-SAME: [[ARG0]], %[[SHIFT0]], %[[SIGN0]]) :
// CHECK-SAME: (vector<16xi32>, i32, i32) -> vector<16xi64>
// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[SRS0]] : vector<16xi64> to vector<16xi64>
// CHECK-NEXT: %[[SIGN1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[SHIFT5:.*]] = llvm.mlir.constant(5 : i32) : i32
// CHECK-NEXT: %[[SRS1:.*]] = "xllvm.intr.aie2.acc64.v16.I512.ups"(
// CHECK-SAME: [[ARG0]], %[[SHIFT5]], %[[SIGN1]]) :
// CHECK-SAME: (vector<16xi32>, i32, i32) -> vector<16xi64>
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[SRS1]] : vector<16xi64> to vector<16xi64>

// -----

Expand All @@ -105,13 +101,11 @@ func.func @v16acc64_ups_v16i16(%arg0 : vector<16xi16>) {
// CHECK-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.acc64.v16.I256.ups"(
// CHECK-SAME: [[ARG0]], %[[SHIFT0]], %[[SIGN0]]) :
// CHECK-SAME: (vector<16xi16>, i32, i32) -> vector<16xi64>
// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[SRS0]] : vector<16xi64> to vector<16xi64>
// CHECK-NEXT: %[[SIGN1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[SHIFT5:.*]] = llvm.mlir.constant(5 : i32) : i32
// CHECK-NEXT: %[[SRS1:.*]] = "xllvm.intr.aie2.acc64.v16.I256.ups"(
// CHECK-SAME: [[ARG0]], %[[SHIFT5]], %[[SIGN1]]) :
// CHECK-SAME: (vector<16xi16>, i32, i32) -> vector<16xi64>
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[SRS1]] : vector<16xi64> to vector<16xi64>

// -----

Expand Down Expand Up @@ -153,4 +147,38 @@ func.func @v16f32_ups_v16bf16(%arg0 : vector<16xbf16>) {
// CHECK-NEXT: %[[SRS1:.*]] = "xllvm.intr.aie2.v16bf16.to.v16accfloat"(
// CHECK-SAME: [[ARG0]]) :
// CHECK-SAME: (vector<16xbf16>) -> vector<8xi64>
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[SRS1]] : vector<8xi64> to vector<16xf32>
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[SRS1]] : vector<8xi64> to vector<16xf32>

// -----

func.func @v32f32_ups_v32bf16(%arg0 : vector<32xbf16>) {
%0 = aievec.ups %arg0 {shift = 0 : i8} : vector<32xbf16>, vector<32xf32>
return
}

// CHECK-LABEL: @v32f32_ups_v32bf16
// CHECK-SAME: %[[ARG0:.*]]: vector<32xbf16>
// CHECK: %[[INDEX0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[INDEX1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<32xbf16> to vector<16xi32>
// CHECK-NEXT: %[[EXT0:.*]] = "xllvm.intr.aie2.ext.I256.I512"(
// CHECK-SAME: %[[BITCAST0]], %[[INDEX0]]) :
// CHECK-SAME: (vector<16xi32>, i32) -> vector<8xi32>
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[EXT0]] : vector<8xi32> to vector<16xbf16>
// CHECK-NEXT: %[[UPS0:.*]] = "xllvm.intr.aie2.v16bf16.to.v16accfloat"(
// CHECK-SAME: %[[BITCAST1]]) :
// CHECK-SAME: (vector<16xbf16>) -> vector<8xi64>
// CHECK-NEXT: %[[BITCAST2:.*]] = llvm.bitcast %[[ARG0]] : vector<32xbf16> to vector<16xi32>
// CHECK-NEXT: %[[EXT1:.*]] = "xllvm.intr.aie2.ext.I256.I512"(
// CHECK-SAME: %[[BITCAST2]], %[[INDEX1]]) :
// CHECK-SAME: (vector<16xi32>, i32) -> vector<8xi32>
// CHECK-NEXT: %[[BITCAST3:.*]] = llvm.bitcast %[[EXT1]] : vector<8xi32> to vector<16xbf16>
// CHECK-NEXT: %[[UPS1:.*]] = "xllvm.intr.aie2.v16bf16.to.v16accfloat"(
// CHECK-SAME: %[[BITCAST3]]) :
// CHECK-SAME: (vector<16xbf16>) -> vector<8xi64>
// CHECK-NEXT: %[[BITCAST4:.*]] = llvm.bitcast %[[UPS0]] : vector<8xi64> to vector<16xi32>
// CHECK-NEXT: %[[BITCAST5:.*]] = llvm.bitcast %[[UPS1]] : vector<8xi64> to vector<16xi32>
// CHECK-NEXT: %[[CONCAT:.*]] = "xllvm.intr.aie2.concat.I1024.I512"(
// CHECK-SAME: %[[BITCAST4]], %[[BITCAST5]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>) -> vector<32xi32>
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[CONCAT]] : vector<32xi32> to vector<32xf32>

0 comments on commit 37ef519

Please sign in to comment.