Skip to content

Commit

Permalink
[aievec] emulated fp32xfp32 mul elem in different accuracy modes (#1291)
Browse files Browse the repository at this point in the history
Add AIEVecToLLVM pass option for --aie2-fp32-emulation-strategy= accuracy-safe, accuracy-fast, and accuracy-low
Add aievec-to-llvm conversion pattern and tests for the emulated f32xf32 elementwise multiplication in different options.
  • Loading branch information
jamestcl-amd authored Apr 19, 2024
1 parent d065e66 commit aec7ee5
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 39 deletions.
2 changes: 2 additions & 0 deletions include/aie/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Conversion)
mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header -prefix Conversion)
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl -prefix Conversion)
mlir_tablegen(PassesEnums.h.inc -gen-enum-decls)
mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRAIEConversionPassIncGen)

add_mlir_doc(Passes MLIRAIEConversionPasses ./ -gen-pass-doc)
4 changes: 4 additions & 0 deletions include/aie/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
#define AIE_CONVERSION_PASSES_H

#include "aie/Conversion/AIEVecToLLVM/AIEVecToLLVM.h"
#include "aie/Conversion/PassesEnums.h.inc"

namespace xilinx {

#define GEN_PASS_DECL
#include "aie/Conversion/Passes.h.inc"

#define GEN_PASS_REGISTRATION
#include "aie/Conversion/Passes.h.inc"

Expand Down
23 changes: 23 additions & 0 deletions include/aie/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,21 @@
#ifndef AIE_CONVERSION_PASSES
#define AIE_CONVERSION_PASSES

include "mlir/IR/EnumAttr.td"
include "mlir/Pass/PassBase.td"

//===----------------------------------------------------------------------===//
// AIEVecToLLVM
//===----------------------------------------------------------------------===//
def Aie2Fp32EmulationType : I32EnumAttr<"Aie2Fp32Emulation", "AIE2 FP32 Emulation",
[
I32EnumAttrCase<"AccuracySafe", 0, "accuracy-safe">,
I32EnumAttrCase<"AccuracyFast", 1, "accuracy-fast">,
I32EnumAttrCase<"AccuracyLow", 2, "accuracy-low">
]>{
let cppNamespace = "xilinx::aievec";
}

def ConvertAIEVecToLLVM : Pass<"convert-aievec-to-llvm", "mlir::ModuleOp"> {
let summary = "Convert AIEVec dialect to LLVM dialect";
let description = [{
Expand All @@ -26,6 +36,19 @@ def ConvertAIEVecToLLVM : Pass<"convert-aievec-to-llvm", "mlir::ModuleOp"> {
"mlir::arith::ArithDialect",
"mlir::vector::VectorDialect",
"xilinx::xllvm::XLLVMDialect"];
let options = [
Option<"aie2Fp32Emulation", "aie2-fp32-emulation-strategy", "xilinx::aievec::Aie2Fp32Emulation",
/*default=*/"xilinx::aievec::Aie2Fp32Emulation::AccuracySafe",
"Set the AIE2 FP32 emulation strategy. Elementwise multiplication and matrix multiplication intrinsics for FP32 input type are emulated using bfloat16 data-path.",
[{::llvm::cl::values(
clEnumValN(xilinx::aievec::Aie2Fp32Emulation::AccuracySafe, "accuracy-safe",
"Most accurate option since input fp32 number is split into 3 bfloat16 numbers. float a*b would require 9 mac operations due to 3 bfloat16 splits each."),
clEnumValN(xilinx::aievec::Aie2Fp32Emulation::AccuracyFast, "accuracy-fast",
"Fast and Accurate option. Input fp32 number is split in to 3 bfloat16 numbers. In the 9 mac operations to emulate fp32 mul, mac operations with LSBs are ignored. (3 last terms)."),
clEnumValN(xilinx::aievec::Aie2Fp32Emulation::AccuracyLow, "accuracy-low",
"Fast and least accurate option. Input fp32 number is split in to 2 bfloat16 numbers. In the 4 mac operations to emulate fp32 mul, mac operations with LSBs are ignored. (1 last term).")
)}]>
];
}

#endif // AIE_CONVERSION_PASSES
28 changes: 14 additions & 14 deletions include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -106,28 +106,28 @@ def VectorSetI512I256IntrOp :
def I512V32Acc32SrsIntrOp :
AIEVec2_IntrOp<"I512.v32.acc32.srs",
[TypeIs<"res", VectorOfLengthAndType<[32], [I16]>>]>,
Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$a,
I32:$shft,
Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$src,
I32:$shift,
I32:$sign)>;

def I256V32Acc32SrsIntrOp :
AIEVec2_IntrOp<"I256.v32.acc32.srs",
[TypeIs<"res", VectorOfLengthAndType<[32], [I8]>>]>,
Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$a,
I32:$shft,
Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$src,
I32:$shift,
I32:$sign)>;

def I512V16Acc64SrsIntrOp :
AIEVec2_IntrOp<"I512.v16.acc64.srs",
[TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>,
Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$a,
I32:$shft,
Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$src,
I32:$shift,
I32:$sign)>;

def Vector16AccFloatToV16BF16IntrOp :
AIEVec2_IntrOp<"v16accfloat.to.v16bf16",
[TypeIs<"res", VectorOfLengthAndType<[16], [BF16]>>]>,
Arguments<(ins VectorOfLengthAndType<[8], [I64]>:$a)>;
Arguments<(ins VectorOfLengthAndType<[8], [I64]>:$src)>;

// ----- BROADCAST -----

Expand All @@ -151,24 +151,24 @@ def VectorBroadcast16BF512IntrOp :
def ExtI256I512IntrOp :
AIEVec2_IntrOp<"ext.I256.I512",
[TypeIs<"res", VectorOfLengthAndType<[8], [I32]>>]>,
Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$a,
Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$src,
I32:$idx)>;

// ----- CONCAT -----

def ConcatI512I256IntrOp :
AIEVec2_IntrOp<"concat.I512.I256",
[TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>,
Arguments<(ins VectorOfLengthAndType<[8], [I32]>:$a0,
VectorOfLengthAndType<[8], [I32]>:$a1)>;
Arguments<(ins VectorOfLengthAndType<[8], [I32]>:$src0,
VectorOfLengthAndType<[8], [I32]>:$src1)>;

// ----- SHUFFLE -----

def VectorShuffleIntrOp :
AIEVec2_IntrOp<"vshuffle",
[TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>,
Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$a,
VectorOfLengthAndType<[16], [I32]>:$b,
Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$src0,
VectorOfLengthAndType<[16], [I32]>:$src1,
I32:$mode)>;

// ----- UNDEF -----
Expand All @@ -182,8 +182,8 @@ def UndefV16I32IntrOp :
def UpdBF512BF256IntrOp :
AIEVec2_IntrOp<"upd.bf512.bf256",
[TypeIs<"res", VectorOfLengthAndType<[32], [BF16]>>]>,
Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$a,
VectorOfLengthAndType<[16], [BF16]>:$b,
Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$dst,
VectorOfLengthAndType<[16], [BF16]>:$src,
I32:$idx)>;

#endif // AIE_DIALECT_XLLVM_IR_XLLVMAIE2INTROPS_TD
93 changes: 71 additions & 22 deletions lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,13 @@ class MulElemOpConversion
public:
using ConvertOpToLLVMPattern<aievec::MulElemOp>::ConvertOpToLLVMPattern;

MulElemOpConversion(const LLVMTypeConverter &typeConverter,
Aie2Fp32Emulation aie2Fp32EmulationOption)
: ConvertOpToLLVMPattern(typeConverter),
aie2Fp32EmulationOption(aie2Fp32EmulationOption) {}

Aie2Fp32Emulation aie2Fp32EmulationOption;

struct DecodedMulElemOp {
enum class Kind {
// DtIn0_DtIn1_DtRes_CxMxKxN
Expand Down Expand Up @@ -681,34 +688,74 @@ class MulElemOpConversion
auto [d, e, f] =
extractV16FP32ToThreeV16BF16(adaptor.getRhs(), dZeros, eZeros, fZeros);

// 1 MUL + 8 * MACs
auto cfMul = rewriter.create<xllvm::MulConfBF16IntrOp>(
loc, VectorType::get({8}, rewriter.getI64Type()), c, f,
mscMacMulConfCst);
// Create 1 MUL and 2/5/8 MACs depending on the Aie2Fp32EmulationOption
auto createMacOps = [&](Value lhs, Value rhs, Value acc) -> Value {
return rewriter
.create<xllvm::MacConfBF16IntrOp>(
loc, VectorType::get({8}, rewriter.getI64Type()), lhs, rhs, acc,
mscMacMulConfCst)
.getResult();
};
auto adMac = createMacOps(
a, d,
createMacOps(
a, e,
createMacOps(
b, d,
createMacOps(
d, c,
createMacOps(
b, e,
createMacOps(
a, f,
createMacOps(b, f, createMacOps(c, e, cfMul))))))));

Value finalMacVal;
if (aie2Fp32EmulationOption == Aie2Fp32Emulation::AccuracyFast) {
// Fast and Accurate option. float a*b would require 6 mac operations.
// Input fp32 number is split in to 3 bfloat16 numbers to extract all the
// bits of the mantissa. float a,b; both a and b are split in to 3
// bfloat16 numbers each. Hence there would be 9 mac operations in
// multiplication of a and b. In the 9 mac operations to emulate fp32 mul,
// mac operations with LSBs are ignored. (3 last terms). This helps
// improve cycle count of mul and has least impact on accuracy of result.
// This is the default option to the aiecompiler
auto afMul = rewriter.create<xllvm::MulConfBF16IntrOp>(
loc, VectorType::get({8}, rewriter.getI64Type()), a, f,
mscMacMulConfCst);
finalMacVal = createMacOps(
a, d,
createMacOps(
a, e,
createMacOps(b, d,
createMacOps(d, c, createMacOps(b, e, afMul)))));
} else if (aie2Fp32EmulationOption == Aie2Fp32Emulation::AccuracyLow) {
// Fast and least accurate option. float a*b would require 3 mac
// operations.
// Input fp32 number is split in to 2 bfloat16 numbers. Hence not all the
// bits from mantissa can be used. float a,b; Both a and b are split in to
// 2 bfloat16 numbers each. Hence there would be 4 mac operations in
// multiplication of a and b. In the 4 mac operations to emulate fp32 mul,
// mac operations with LSBs are ignored. (1 last term). This helps improve
// cycle count of mul float a, b;
auto bdMul = rewriter.create<xllvm::MulConfBF16IntrOp>(
loc, VectorType::get({8}, rewriter.getI64Type()), b, d,
mscMacMulConfCst);
finalMacVal = createMacOps(a, d, createMacOps(a, e, bdMul));
} else {
// aie2Fp32EmulationOption == Aie2Fp32Emulation::AccuracySafe
// Most accurate option since input fp32 number is split in to 3 bfloat16
// numbers to extract all the bits of the mantissa. float a*b would
// require 9 mac operations due to 3 bfloat16 splits each.
auto cfMul = rewriter.create<xllvm::MulConfBF16IntrOp>(
loc, VectorType::get({8}, rewriter.getI64Type()), c, f,
mscMacMulConfCst);
finalMacVal = createMacOps(
a, d,
createMacOps(
a, e,
createMacOps(
b, d,
createMacOps(
d, c,
createMacOps(
b, e,
createMacOps(
a, f,
createMacOps(b, f,
createMacOps(c, e, cfMul))))))));
}

// create bitcast for result
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
adMac);
finalMacVal);
return success();
}

Expand Down Expand Up @@ -1457,8 +1504,9 @@ class FoldAIECastOps : public mlir::ConvertOpToLLVMPattern<aievec::CastOp> {
}
};

void populateAIEVecToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter,
mlir::RewritePatternSet &patterns) {
void populateAIEVecToLLVMConversionPatterns(
mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns,
Aie2Fp32Emulation aie2Fp32EmulationOption) {
// clang-format off
patterns.add<AddOpConversion,
SubOpConversion,
Expand All @@ -1475,9 +1523,9 @@ void populateAIEVecToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter,
BroadcastOpConversion,
BroadcastScalarOpConversion,
FMAElemOpConversion,
MulElemOpConversion,
MatMulOpConversion,
FoldAIECastOps>(converter);
patterns.add<MulElemOpConversion>(converter, aie2Fp32EmulationOption);
// clang-format on
}

Expand All @@ -1492,7 +1540,8 @@ struct ConvertAIEVecToLLVMPass
converter.addConversion(
[&](VectorType type) -> std::optional<Type> { return type; });

populateAIEVecToLLVMConversionPatterns(converter, patterns);
populateAIEVecToLLVMConversionPatterns(converter, patterns,
aie2Fp32Emulation);

LLVMConversionTarget target(getContext());
target.addIllegalDialect<AIEVecDialect>();
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/PassDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#ifndef AIE_CONVERSION_PASSDETAIL_H_
#define AIE_CONVERSION_PASSDETAIL_H_

#include "aie/Conversion/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinOps.h"
Expand Down
Loading

0 comments on commit aec7ee5

Please sign in to comment.