Skip to content

Commit

Permalink
[aievec] to-llvm flow for aievec.shift op (#1432)
Browse files Browse the repository at this point in the history
* Add vshift and vextract.elem intrinsic ops to XLLVM dialect.
* Add aievec-to-llvm conversion pattern/tests for the aievec.shift op.
* Add target external llvm translation tests.
* Update test/unit_tests/aievec_tests/bf16_max_reduce e2e test and mark to-llvm test as XFAIL for now.
* Other minor changes: add an f32 mul_elem scalar test and a missing aievec.broadcast_scalar conversion test.
  • Loading branch information
jamestcl-amd authored Apr 29, 2024
1 parent 774f25b commit 2a55238
Show file tree
Hide file tree
Showing 11 changed files with 329 additions and 22 deletions.
49 changes: 45 additions & 4 deletions include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,16 @@ def ExtI256I512IntrOp :
def ConcatI512I256IntrOp :
AIEVec2_IntrOp<"concat.I512.I256",
[TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>,
Arguments<(ins VectorOfLengthAndType<[8], [I32]>:$src0,
VectorOfLengthAndType<[8], [I32]>:$src1)>;
Arguments<(ins VectorOfLengthAndType<[8], [I32]>:$lhs,
VectorOfLengthAndType<[8], [I32]>:$rhs)>;

// ----- SHUFFLE -----

def VectorShuffleIntrOp :
AIEVec2_IntrOp<"vshuffle",
[TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>,
Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$src0,
VectorOfLengthAndType<[16], [I32]>:$src1,
Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$lhs,
VectorOfLengthAndType<[16], [I32]>:$rhs,
I32:$mode)>;

// ----- UNDEF -----
Expand All @@ -186,4 +186,45 @@ def UpdBF512BF256IntrOp :
VectorOfLengthAndType<[16], [BF16]>:$src,
I32:$idx)>;

// ----- SHIFT -----

// XLLVM op wrapping the AIE2 `vshift.I512.I512` intrinsic.
//   Result:   vector<16xi32>.
//   Operands: two vector<16xi32> sources ($lhs, $rhs) followed by two i32
//             scalars ($step, $shift).
// NOTE(review): the exact meaning of $step is not visible here; confirm
// against the intrinsic's definition.
def VectorShiftI512I512IntrOp :
AIEVec2_IntrOp<"vshift.I512.I512",
[TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>,
Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$lhs,
VectorOfLengthAndType<[16], [I32]>:$rhs,
I32:$step,
I32:$shift)>;

// XLLVM op wrapping the AIE2 `vshift.bf512.bf512` intrinsic.
//   Result:   vector<32xbf16>.
//   Operands: two vector<32xbf16> sources ($lhs, $rhs) followed by two i32
//             scalars ($step, $shift).
// Same operand layout as the I512 variant, with bf16 lanes instead of i32.
def VectorShiftBF512BF512IntrOp :
AIEVec2_IntrOp<"vshift.bf512.bf512",
[TypeIs<"res", VectorOfLengthAndType<[32], [BF16]>>]>,
Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$lhs,
VectorOfLengthAndType<[32], [BF16]>:$rhs,
I32:$step,
I32:$shift)>;

// ----- EXTRACT ELEMENT -----

// XLLVM op wrapping the AIE2 `vextract.elem8.I512` intrinsic.
//   Result:   i32.
//   Operands: a vector<64xi8> source ($src) and two i32 scalars ($idx, $sign).
// NOTE(review): $sign presumably selects signed vs. unsigned widening of the
// extracted 8-bit lane to i32 -- confirm against the intrinsic's definition.
def VectorExtractElem8I512IntrOp :
AIEVec2_IntrOp<"vextract.elem8.I512",
[TypeIs<"res", I32>]>,
Arguments<(ins VectorOfLengthAndType<[64], [I8]>:$src,
I32:$idx,
I32:$sign)>;

// XLLVM op wrapping the AIE2 `vextract.elem16.I512` intrinsic.
//   Result:   i32.
//   Operands: a vector<32xi16> source ($src) and two i32 scalars ($idx, $sign).
// NOTE(review): $sign presumably selects signed vs. unsigned widening of the
// extracted 16-bit lane to i32 -- confirm against the intrinsic's definition.
def VectorExtractElem16I512IntrOp :
AIEVec2_IntrOp<"vextract.elem16.I512",
[TypeIs<"res", I32>]>,
Arguments<(ins VectorOfLengthAndType<[32], [I16]>:$src,
I32:$idx,
I32:$sign)>;

// XLLVM op wrapping the AIE2 `vextract.elem32.I512` intrinsic.
//   Result:   i32.
//   Operands: a vector<16xi32> source ($src) and two i32 scalars ($idx, $sign).
// NOTE(review): the role of $sign for lanes that are already 32 bits wide is
// not visible here -- confirm against the intrinsic's definition.
def VectorExtractElem32I512IntrOp :
AIEVec2_IntrOp<"vextract.elem32.I512",
[TypeIs<"res", I32>]>,
Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$src,
I32:$idx,
I32:$sign)>;

#endif // AIE_DIALECT_XLLVM_IR_XLLVMAIE2INTROPS_TD
59 changes: 59 additions & 0 deletions lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1294,6 +1294,64 @@ class BroadcastScalarOpConversion
}
};

/// Lowers `aievec.shift` to an XLLVM AIE2 `vshift` intrinsic.
///
/// Only results whose total width (element bit width x lane count) is 512
/// bits are handled; any other width emits a warning on the op and fails the
/// match. Integer element types are lowered through
/// `xllvm::VectorShiftI512I512IntrOp` with operands presented as
/// vector<16xi32>; every other element type is lowered through
/// `xllvm::VectorShiftBF512BF512IntrOp` with operands presented as
/// vector<32xbf16>. The intrinsic's result is bitcast back to the original
/// result type of the `aievec.shift` op.
class ShiftOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::ShiftOp> {
public:
  using ConvertOpToLLVMPattern<aievec::ShiftOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(aievec::ShiftOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    auto resultType = cast<VectorType>(op.getResult().getType());
    Type resultScaTy = resultType.getElementType();
    unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth();
    int resultLanes = getVectorLaneSize(resultType);
    int resultVectorSize = resultBitWidth * resultLanes;

    if (resultVectorSize != 512) {
      op.emitWarning() << "aievec.shift conversion with result vector size "
                       << resultVectorSize << " is not implemented.\n";
      return failure();
    }

    // The intrinsic takes an extra `step` operand that aievec.shift does not
    // carry; it is assumed to always be zero.
    Type i32Ty = rewriter.getI32Type();
    auto stepCst = rewriter.create<LLVM::ConstantOp>(
        loc, i32Ty, rewriter.getI32IntegerAttr(0));

    // Operand order matches the intrinsic signature: lhs, rhs, step, shift.
    SmallVector<Value> operands(
        {adaptor.getLhs(), adaptor.getRhs(), stepCst, adaptor.getShift()});

    // forceCastOperandsToSignature adapts each operand to the listed type
    // before the intrinsic is created, so the casts precede the intrinsic in
    // the emitted IR.
    Value shiftOp = nullptr;
    if (isa<IntegerType>(resultScaTy)) {
      // Integer lanes of any width are viewed as 16 x i32.
      VectorType v16I32Ty = VectorType::get({16}, i32Ty);
      shiftOp = rewriter.create<xllvm::VectorShiftI512I512IntrOp>(
          loc, v16I32Ty,
          forceCastOperandsToSignature(rewriter, loc, operands,
                                       {v16I32Ty, v16I32Ty, i32Ty, i32Ty}));
    } else {
      // Non-integer lanes are viewed as 32 x bf16.
      VectorType v32BF16Ty = VectorType::get({32}, rewriter.getBF16Type());
      shiftOp = rewriter.create<xllvm::VectorShiftBF512BF512IntrOp>(
          loc, v32BF16Ty,
          forceCastOperandsToSignature(rewriter, loc, operands,
                                       {v32BF16Ty, v32BF16Ty, i32Ty, i32Ty}));
    }

    // Restore the original vector type expected by the op's users.
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, resultType, shiftOp);

    return success();
  }
};

class FMAElemOpConversion
: public mlir::ConvertOpToLLVMPattern<aievec::FMAElemOp> {
public:
Expand Down Expand Up @@ -1524,6 +1582,7 @@ void populateAIEVecToLLVMConversionPatterns(
BroadcastScalarOpConversion,
FMAElemOpConversion,
MatMulOpConversion,
ShiftOpConversion,
FoldAIECastOps>(converter);
patterns.add<MulElemOpConversion>(converter, aie2Fp32EmulationOption);
// clang-format on
Expand Down
14 changes: 14 additions & 0 deletions test/Conversion/AIEVecToLLVM/broadcast_scalar.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,20 @@ func.func @i8_broadcast_scalar(%arg0 : i8) -> vector<64xi8> {

// -----

// Test input for the i32 case: broadcast a scalar i32 into vector<16xi32>.
func.func @i32_broadcast_scalar(%arg0 : i32) -> vector<16xi32> {
%0 = aievec.broadcast_scalar %arg0 : i32, vector<16xi32>
return %0 : vector<16xi32>
}

// CHECK-LABEL: @i32_broadcast_scalar
// CHECK-SAME: %[[ARG0:.*]]: i32
// CHECK-NEXT: %[[VBROADCAST:.*]] = "xllvm.intr.aie2.vbroadcast32.I512"(
// CHECK-SAME: %[[ARG0]]) :
// CHECK-SAME: (i32) -> vector<16xi32>
// CHECK-NEXT: return %[[VBROADCAST]] : vector<16xi32>

// -----

func.func @bf16_broadcast_scalar(%arg0 : bf16) -> vector<32xbf16> {
%0 = aievec.broadcast_scalar %arg0 : bf16, vector<32xbf16>
return %0 : vector<32xbf16>
Expand Down
73 changes: 73 additions & 0 deletions test/Conversion/AIEVecToLLVM/shift.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// RUN: aie-opt %s -split-input-file -convert-aievec-to-llvm | FileCheck %s

// -----

// Test input: 64 x i8 shift using %arg0 for both vector operands. The
// expected lowering bitcasts the operands to vector<16xi32> for the integer
// vshift intrinsic and bitcasts the result back to vector<64xi8>.
func.func @i8_shift(%arg0 : vector<64xi8>, %shift : i32) -> vector<64xi8> {
%0 = aievec.shift %arg0, %arg0, %shift {isAcc = false} : vector<64xi8>, vector<64xi8>, i32, vector<64xi8>
return %0 : vector<64xi8>
}

// CHECK-LABEL: @i8_shift
// CHECK-SAME: %[[ARG0:.*]]: vector<64xi8>,
// CHECK-SAME: %[[SHIFT:.*]]: i32
// CHECK: %[[CST:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<64xi8> to vector<16xi32>
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<64xi8> to vector<16xi32>
// CHECK-NEXT: %[[VSHIFT:.*]] = "xllvm.intr.aie2.vshift.I512.I512"(
// CHECK-SAME: %[[BITCAST0]], %[[BITCAST1]], %[[CST]], %[[SHIFT]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32>
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[VSHIFT]] : vector<16xi32> to vector<64xi8>
// CHECK-NEXT: return %[[RES]] : vector<64xi8>

// -----

// Test input: 32 x i16 shift using %arg0 for both vector operands. The
// expected lowering bitcasts the operands to vector<16xi32> for the integer
// vshift intrinsic and bitcasts the result back to vector<32xi16>.
func.func @i16_shift(%arg0 : vector<32xi16>, %shift : i32) -> vector<32xi16> {
%0 = aievec.shift %arg0, %arg0, %shift {isAcc = false} : vector<32xi16>, vector<32xi16>, i32, vector<32xi16>
return %0 : vector<32xi16>
}

// CHECK-LABEL: @i16_shift
// CHECK-SAME: %[[ARG0:.*]]: vector<32xi16>,
// CHECK-SAME: %[[SHIFT:.*]]: i32
// CHECK: %[[CST:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG0]] : vector<32xi16> to vector<16xi32>
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<32xi16> to vector<16xi32>
// CHECK-NEXT: %[[VSHIFT:.*]] = "xllvm.intr.aie2.vshift.I512.I512"(
// CHECK-SAME: %[[BITCAST0]], %[[BITCAST1]], %[[CST]], %[[SHIFT]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32>
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[VSHIFT]] : vector<16xi32> to vector<32xi16>
// CHECK-NEXT: return %[[RES]] : vector<32xi16>

// -----

// Test input: 16 x i32 shift using %arg0 for both vector operands. The
// operand type already matches the integer vshift intrinsic, so the expected
// lowering passes %arg0 through without operand bitcasts.
func.func @i32_shift(%arg0 : vector<16xi32>, %shift : i32) -> vector<16xi32> {
%0 = aievec.shift %arg0, %arg0, %shift {isAcc = false} : vector<16xi32>, vector<16xi32>, i32, vector<16xi32>
return %0 : vector<16xi32>
}

// CHECK-LABEL: @i32_shift
// CHECK-SAME: %[[ARG0:.*]]: vector<16xi32>,
// CHECK-SAME: %[[SHIFT:.*]]: i32
// CHECK: %[[CST:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[VSHIFT:.*]] = "xllvm.intr.aie2.vshift.I512.I512"(
// CHECK-SAME: %[[ARG0]], %[[ARG0]], %[[CST]], %[[SHIFT]]) :
// CHECK-SAME: (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32>
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[VSHIFT]] : vector<16xi32> to vector<16xi32>
// CHECK-NEXT: return %[[RES]] : vector<16xi32>

// -----

// Test input: 32 x bf16 shift using %arg0 for both vector operands. The
// expected lowering targets the bf512 vshift intrinsic and passes %arg0
// through without operand bitcasts.
func.func @bf16_shift(%arg0 : vector<32xbf16>, %shift : i32) -> vector<32xbf16> {
%0 = aievec.shift %arg0, %arg0, %shift {isAcc = false} : vector<32xbf16>, vector<32xbf16>, i32, vector<32xbf16>
return %0 : vector<32xbf16>
}

// CHECK-LABEL: @bf16_shift
// CHECK-SAME: %[[ARG0:.*]]: vector<32xbf16>,
// CHECK-SAME: %[[SHIFT:.*]]: i32
// CHECK: %[[CST:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[VSHIFT:.*]] = "xllvm.intr.aie2.vshift.bf512.bf512"(
// CHECK-SAME: %[[ARG0]], %[[ARG0]], %[[CST]], %[[SHIFT]]) :
// CHECK-SAME: (vector<32xbf16>, vector<32xbf16>, i32, i32) -> vector<32xbf16>
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[VSHIFT]] : vector<32xbf16> to vector<32xbf16>
// CHECK-NEXT: return %[[RES]] : vector<32xbf16>
46 changes: 46 additions & 0 deletions test/Target/LLVMIR/aievec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,49 @@ llvm.func @upd_bf512_bf256(%a : vector<32xbf16>, %b : vector<16xbf16>, %idx : i3
%0 = "xllvm.intr.aie2.upd.bf512.bf256"(%a, %b, %idx) : (vector<32xbf16>, vector<16xbf16>, i32) -> vector<32xbf16>
llvm.return %0 : vector<32xbf16>
}

// ----- SHIFT -----

// CHECK-LABEL: define <16 x i32> @vshift_i512_i512
// Translation test: the xllvm vshift.I512.I512 op is expected to become a
// call to @llvm.aie2.vshift.I512.I512 taking two <16 x i32> and two i32 values.
llvm.func @vshift_i512_i512(%a : vector<16xi32>, %b : vector<16xi32>, %step : i32, %shift : i32) -> vector<16xi32> {
// CHECK: call <16 x i32> @llvm.aie2.vshift.I512.I512(
// CHECK-SAME: <16 x i32> %{{[0-9]+}}, <16 x i32> %{{[0-9]+}},
// CHECK-SAME: i32 %{{[0-9]+}}, i32 %{{[0-9]+}})
%0 = "xllvm.intr.aie2.vshift.I512.I512"(%a, %b, %step, %shift) : (vector<16xi32>, vector<16xi32>, i32, i32) -> vector<16xi32>
llvm.return %0 : vector<16xi32>
}

// CHECK-LABEL: define <32 x bfloat> @vshift_bf512_bf512
// Translation test: the xllvm vshift.bf512.bf512 op is expected to become a
// call to @llvm.aie2.vshift.bf512.bf512 taking two <32 x bfloat> and two i32
// values.
llvm.func @vshift_bf512_bf512(%a : vector<32xbf16>, %b : vector<32xbf16>, %step : i32, %shift : i32) -> vector<32xbf16> {
// CHECK: call <32 x bfloat> @llvm.aie2.vshift.bf512.bf512(
// CHECK-SAME: <32 x bfloat> %{{[0-9]+}}, <32 x bfloat> %{{[0-9]+}},
// CHECK-SAME: i32 %{{[0-9]+}}, i32 %{{[0-9]+}})
%0 = "xllvm.intr.aie2.vshift.bf512.bf512"(%a, %b, %step, %shift) : (vector<32xbf16>, vector<32xbf16>, i32, i32) -> vector<32xbf16>
llvm.return %0 : vector<32xbf16>
}

// ----- EXTRACT ELEMENT -----

// CHECK-LABEL: define i32 @vextract_elem8_i512
// Translation test: the xllvm vextract.elem8.I512 op is expected to become a
// call to @llvm.aie2.vextract.elem8.I512 taking a <64 x i8> and two i32 values.
llvm.func @vextract_elem8_i512(%a : vector<64xi8>, %idx : i32, %sign : i32) -> i32 {
// CHECK: call i32 @llvm.aie2.vextract.elem8.I512(
// CHECK-SAME: <64 x i8> %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}})
%0 = "xllvm.intr.aie2.vextract.elem8.I512"(%a, %idx, %sign) : (vector<64xi8>, i32, i32) -> i32
llvm.return %0 : i32
}

// CHECK-LABEL: define i32 @vextract_elem16_i512
// Translation test: the xllvm vextract.elem16.I512 op is expected to become a
// call to @llvm.aie2.vextract.elem16.I512 taking a <32 x i16> and two i32
// values.
llvm.func @vextract_elem16_i512(%a : vector<32xi16>, %idx : i32, %sign : i32) -> i32 {
// CHECK: call i32 @llvm.aie2.vextract.elem16.I512(
// CHECK-SAME: <32 x i16> %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}})
%0 = "xllvm.intr.aie2.vextract.elem16.I512"(%a, %idx, %sign) : (vector<32xi16>, i32, i32) -> i32
llvm.return %0 : i32
}

// CHECK-LABEL: define i32 @vextract_elem32_i512
// Translation test: the xllvm vextract.elem32.I512 op is expected to become a
// call to @llvm.aie2.vextract.elem32.I512 taking a <16 x i32> and two i32
// values.
llvm.func @vextract_elem32_i512(%a : vector<16xi32>, %idx : i32, %sign : i32) -> i32 {
// CHECK: call i32 @llvm.aie2.vextract.elem32.I512(
// CHECK-SAME: <16 x i32> %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}})
%0 = "xllvm.intr.aie2.vextract.elem32.I512"(%a, %idx, %sign) : (vector<16xi32>, i32, i32) -> i32
llvm.return %0 : i32
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Copyright (C) 2024, Advanced Micro Devices, Inc.

// REQUIRES: valid_xchess_license
// REQUIRES: peano
// RUN: mkdir -p %t/data; cd %t
// RUN: aie-opt %s %vector-to-llvmir% -o llvmir.mlir
// RUN: aie-translate llvmir.mlir %llvmir-to-ll% -o dut.ll
// RUN: %PEANO_INSTALL_DIR/bin/clang %clang_aie2_args -c dut.ll -o dut.o
// RUN: xchesscc_wrapper %xchesscc_aie2_args -DTO_LLVM +w work +o work -I%S -I. %S/testbench.cc dut.o
// RUN: xca_udm_dbg --aiearch aie-ml -qf -T -P %aietools/data/aie_ml/lib/ -t "%S/../profiling.tcl ./work/a.out" >& xca_udm_dbg.stdout
// RUN: FileCheck --input-file=./xca_udm_dbg.stdout %s
// CHECK: TEST PASSED
// XFAIL: *

// Device-under-test kernel: reduces 1024 bf16 inputs (%arg0) to their maximum
// and stores the single result into the rank-0 memref %arg1.
module {
func.func @dut(%arg0: memref<1024xbf16>, %arg1: memref<bf16>) {
// Initial accumulator; 0xFF80 is the bf16 bit pattern for -infinity, the
// identity element for a maximum.
%cst_0 = arith.constant dense<0xFF80> : vector<32xbf16>
// Lane-wise running maximum over the input, 32 elements per iteration.
%0 = affine.for %arg2 = 0 to 1024 step 32 iter_args(%arg3 = %cst_0) -> (vector<32xbf16>) {
// Padding value required by vector.transfer_read.
%cst_1 = arith.constant 0.000000e+00 : bf16
%3 = vector.transfer_read %arg0[%arg2], %cst_1 : memref<1024xbf16>, vector<32xbf16>
%4 = arith.maximumf %arg3, %3 : vector<32xbf16>
affine.yield %4 : vector<32xbf16>
}
// Collapse the 32 per-lane maxima into one scalar.
%1 = vector.reduction <maximumf>, %0 : vector<32xbf16> into bf16
affine.store %1, %arg1[] : memref<bf16>
return
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Copyright (C) 2024, Advanced Micro Devices, Inc.

// REQUIRES: valid_xchess_license
// RUN: mkdir -p %t/data; cd %t
// RUN: aie-opt %s --convert-vector-to-aievec="aie-target=aieml" -lower-affine | aie-translate -aieml=true --aievec-to-cpp -o dut.cc
// RUN: xchesscc_wrapper aie2 -f -g +s +w work +o work -I%S -I. -c dut.cc -o dut.o
// RUN: xchesscc_wrapper aie2 -f -g +s +w work +o work -I%S -I. %S/testbench.cc work/dut.o
// RUN: mkdir -p data
// RUN: xchesscc_wrapper %xchesscc_aie2_args +w work +o work -I%S -I. -c dut.cc -o dut.o
// RUN: xchesscc_wrapper %xchesscc_aie2_args -DTO_CPP +w work +o work -I%S -I. %S/testbench.cc work/dut.o
// RUN: xca_udm_dbg --aiearch aie-ml -qf -T -P %aietools/data/aie_ml/lib/ -t "%S/../profiling.tcl ./work/a.out" >& xca_udm_dbg.stdout
// RUN: FileCheck --input-file=./xca_udm_dbg.stdout %s
// CHECK: TEST PASSED
Expand Down
16 changes: 15 additions & 1 deletion test/unit_tests/aievec_tests/bf16_max_reduce/testbench.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,22 @@
void dut(bfloat16 *restrict in0, bfloat16 *restrict out0);
void dut_ref(bfloat16 *in0, bfloat16 *out0);

// Declaration of the device-under-test entry point. Its signature depends on
// which flow produced dut.o; the flow is selected with -DTO_CPP or -DTO_LLVM
// on the compiler command line. Both branches test whether the macro is
// defined (not its value) so the two selectors behave consistently.
#if defined(TO_CPP)
// C++ flow: the kernel takes plain pointers.
void dut(bfloat16 *restrict in0, bfloat16 *restrict out0);
#elif defined(TO_LLVM)
// LLVM flow: each memref argument is passed as separate scalar arguments
// (allocated pointer, aligned pointer, offset, size, stride), hence the
// unmangled extern "C" symbol with ten parameters.
// NOTE(review): the MLIR kernel's output is a rank-0 memref, which under
// MLIR's default memref calling convention appears to expand to only
// (allocated, aligned, offset) -- confirm whether out0_sizes_0 and
// out0_strides_0 should be here.
extern "C" {
void dut(bfloat16 *in0_allocated, bfloat16 *in0_aligned, int64_t in0_offset,
         int64_t in0_sizes_0, int64_t in0_strides_0, bfloat16 *out0_allocated,
         bfloat16 *out0_aligned, int64_t out0_offset, int64_t out0_sizes_0,
         int64_t out0_strides_0);
}
#endif

alignas(32) bfloat16 g_in0[IN0_SIZE];
alignas(32) bfloat16 g_out0[OUT0_SIZE];
alignas(32) bfloat16 g_out0Ref[OUT0_SIZE];

int main(int argc, char *argv[]) {
// XXX Figure out how to use argv with xca_udm_dbg --aiearch aie-ml -A
std::string dataDir(TO_STR(DATA_DIR));
srand(10);
std::generate(g_in0, g_in0 + IN0_SIZE,
Expand All @@ -22,7 +32,11 @@ int main(int argc, char *argv[]) {

chess_memory_fence();
auto cyclesBegin = chess_cycle_count();
#ifdef TO_CPP
dut(g_in0, g_out0);
#elif TO_LLVM
dut(g_in0, g_in0, 0, 0, 0, g_out0, g_out0, 0, 0, 0);
#endif
auto cyclesEnd = chess_cycle_count();
chess_memory_fence();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Copyright (C) 2024, Advanced Micro Devices, Inc.

// REQUIRES: valid_xchess_license
// REQUIRES: peano
// RUN: mkdir -p %t/data; cd %t
// RUN: aie-opt %s %vector-to-generic-llvmir% -o llvmir.mlir
// RUN: aie-translate llvmir.mlir %llvmir-to-ll% -o dut.ll
// RUN: %PEANO_INSTALL_DIR/bin/clang %clang_aie2_args -c dut.ll -o dut.o
// RUN: xchesscc_wrapper %xchesscc_aie2_args -DTO_LLVM +w work +o work -I%S -I. %S/testbench.cc dut.o
// RUN: xca_udm_dbg --aiearch aie-ml -qf -T -P %aietools/data/aie_ml/lib/ -t "%S/../profiling.tcl ./work/a.out" >& xca_udm_dbg.stdout
// RUN: FileCheck --input-file=./xca_udm_dbg.stdout %s
// CHECK: TEST PASSED
// XFAIL: *

// Device-under-test kernel: element-wise f32 multiply of two 1024-element
// buffers, written as a scalar loop.
module {
func.func @dut(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>, %arg2: memref<1024xf32>) {
// Declare that all three buffers are 32-byte aligned.
memref.assume_alignment %arg0, 32 : memref<1024xf32>
memref.assume_alignment %arg1, 32 : memref<1024xf32>
memref.assume_alignment %arg2, 32 : memref<1024xf32>
// %arg2[i] = %arg0[i] * %arg1[i] for i in [0, 1024).
affine.for %arg3 = 0 to 1024 {
%0 = affine.load %arg0[%arg3] : memref<1024xf32>
%1 = affine.load %arg1[%arg3] : memref<1024xf32>
%2 = arith.mulf %0, %1 : f32
affine.store %2, %arg2[%arg3] : memref<1024xf32>
}
return
}
}
Loading

0 comments on commit 2a55238

Please sign in to comment.