Skip to content

Commit

Permalink
[aievec] to-llvm/to-cpp flow for emulated fp32xfp32 mul elem (#1239)
Browse files Browse the repository at this point in the history
* Add some of the upd/msc intrinsics to XLLVM.
* Enable vector->aievec support for the f32xf32 mul_elem and add the corresponding conversion test.
* Add aievec-to-llvm conversion pattern for the emulated f32xf32 elementwise multiplication.
* Add aievec-to-llvm conversion tests for the newly added XLLVM ops.
* Add target llvm translation tests.
* Add f32xf32_mul_elem e2e tests for the to-llvm and to-cpp flow. This includes the testbench.cc and the test script.
  • Loading branch information
jamestcl-amd authored Apr 16, 2024
1 parent d088353 commit de9e2dc
Show file tree
Hide file tree
Showing 18 changed files with 441 additions and 29 deletions.
35 changes: 27 additions & 8 deletions include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ class AIEVec2_IntrOp<string mnemonic,

// TODO: Find better names for these

// Shared operand list for the bf16 multiply-accumulate style intrinsics that
// take a configuration word:
//   $lhs, $rhs : 32-lane bf16 vectors
//   $acc       : accumulator operand, typed as an 8 x i64 vector
//   $conf      : i32 configuration word
// Ops inherit from this class instead of repeating the `Arguments<...>` list.
class AIE2bf16MACConf :
Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$lhs,
VectorOfLengthAndType<[32], [BF16]>:$rhs,
VectorOfLengthAndType<[8], [I64]>:$acc,
I32:$conf)>;

// ----- MAC -----

def MacConfAcc32IntrOp :
Expand All @@ -50,10 +56,14 @@ def MacConfAcc64IntrOp :
def MacConfBF16IntrOp :
AIEVec2_IntrOp<"bf.mac16.conf",
[TypeIs<"res", VectorOfLengthAndType<[8], [I64]>>]>,
Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$lhs,
VectorOfLengthAndType<[32], [BF16]>:$rhs,
VectorOfLengthAndType<[8], [I64]>:$acc,
I32:$conf)>;
AIE2bf16MACConf;

// ----- MSC -----

// Intrinsic op with mnemonic `bf.msc16.conf`. It takes the operand list
// supplied by AIE2bf16MACConf and its result `res` is constrained to an
// 8 x i64 vector.
// NOTE(review): presumably the multiply-subtract counterpart of
// `bf.mac16.conf` (acc - lhs * rhs) -- confirm against the AIE2 intrinsic
// reference.
def MscConfBF16IntrOp :
AIEVec2_IntrOp<"bf.msc16.conf",
[TypeIs<"res", VectorOfLengthAndType<[8], [I64]>>]>,
AIE2bf16MACConf;

// ----- MUL -----

Expand Down Expand Up @@ -96,28 +106,28 @@ def VectorSetI512I256IntrOp :
def I512V32Acc32SrsIntrOp :
AIEVec2_IntrOp<"I512.v32.acc32.srs",
[TypeIs<"res", VectorOfLengthAndType<[32], [I16]>>]>,
Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$lhs,
Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$a,
I32:$shft,
I32:$sign)>;

def I256V32Acc32SrsIntrOp :
AIEVec2_IntrOp<"I256.v32.acc32.srs",
[TypeIs<"res", VectorOfLengthAndType<[32], [I8]>>]>,
Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$lhs,
Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$a,
I32:$shft,
I32:$sign)>;

def I512V16Acc64SrsIntrOp :
AIEVec2_IntrOp<"I512.v16.acc64.srs",
[TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>,
Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$lhs,
Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$a,
I32:$shft,
I32:$sign)>;

def Vector16AccFloatToV16BF16IntrOp :
AIEVec2_IntrOp<"v16accfloat.to.v16bf16",
[TypeIs<"res", VectorOfLengthAndType<[16], [BF16]>>]>,
Arguments<(ins VectorOfLengthAndType<[8], [I64]>:$lhs)>;
Arguments<(ins VectorOfLengthAndType<[8], [I64]>:$a)>;

// ----- BROADCAST -----

Expand Down Expand Up @@ -167,4 +177,13 @@ def UndefV16I32IntrOp :
AIEVec2_IntrOp<"v16int32",
[TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>;

// ----- UPD -----

// Intrinsic op with mnemonic `upd.bf512.bf256`.
//   $a   : 32-lane bf16 vector
//   $b   : 16-lane bf16 vector
//   $idx : i32 index
// The result `res` is constrained to a 32-lane bf16 vector.
// NOTE(review): presumably returns `$a` with the 16-lane half selected by
// `$idx` replaced by `$b` (i.e. an "insert") -- confirm against the AIE2
// intrinsic reference.
def UpdBF512BF256IntrOp :
AIEVec2_IntrOp<"upd.bf512.bf256",
[TypeIs<"res", VectorOfLengthAndType<[32], [BF16]>>]>,
Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$a,
VectorOfLengthAndType<[16], [BF16]>:$b,
I32:$idx)>;

#endif // AIE_DIALECT_XLLVM_IR_XLLVMAIE2INTROPS_TD
187 changes: 175 additions & 12 deletions lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,9 +366,9 @@ class MulElemOpConversion
I16_I16_I32_32x1x1x1,
I32_I32_I64_32x1x2x1,
BF16_BF16_FP32_16x1x2x1,
FP32_FP32_FP32_16x1x1x1,
UNSUPPORTED
// TODO: I16_I16_I64_16x1x2x1
// TODO: FP32 mul_elem is emulated
};

Kind kind;
Expand Down Expand Up @@ -425,6 +425,7 @@ class MulElemOpConversion
/*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
/*sub_mask=*/0)};
} else if (lhsBitWidth == 32) {
// emulated I32 mul_elem
return {DecodedMulElemOp::Kind::I32_I32_I64_32x1x2x1, -1};
}
} else {
Expand All @@ -436,13 +437,16 @@ class MulElemOpConversion
/*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
/*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
/*sub_mask=*/0)};
} else if (lhsBitWidth == 32) {
// emulated FP32 mul_elem
return {DecodedMulElemOp::Kind::FP32_FP32_FP32_16x1x1x1, -1};
}
}

return {DecodedMulElemOp::Kind::UNSUPPORTED, -1};
}

// This conversion pattern implements the below CPP I32 mul_elem emulation.
// This conversion pattern implements the below CPP emulated I32 mul_elem.
// INTRINSIC(v16acc64)
// mul_elem_16_2(v16int32 a0, v16int32 a1, v16int32 b0, v16int32 b1) {
// v32uint16 a_lo = (v32uint16)shuffle(a0, a1, 2);
Expand All @@ -455,14 +459,26 @@ class MulElemOpConversion
// acc = mac_elem_16_2_conf(a_lo, false, b_lo, false, acc, 0, 1, 0, 0);
// return acc;
// }
// Caller to the above CPP intrinsic:
// v16int32 v1 = LHS();
// v16int32 v2 = RHS();
// v16acc64 v3 = mul_elem_16_2(v1, broadcast_zero_s32(), v2,
// undef_v16int32());
// Caller example when handling the elementwise mul of two v16int32 vectors.
// v16int32 v1 = LHS();
// v16int32 v2 = RHS();
// v16acc64 v3 = mul_elem_16_2(v1, broadcast_zero_s32(), v2,
// undef_v16int32());
// Explanation:
// a_lo = low_part(a0[0]--a0[15], a1[0]--a1[15])
// a_hi = high_part(a0[0]--a0[15], a1[0]--a1[15])
// b_lo = low_part(b0[0]--b0[15], b1[0]--b1[15])
// b_hi = high_part(b0[0]--b0[15], b1[0]--b1[15])
// The first `acc` is from mul_elem_16_2(a_hi, b_hi), which performs 16
// channels of 1x2x1 matmul: acc[0] = a_hi[0]*b_hi[0]+a_hi[16]*b_hi[16], ... ,
// acc[15] = a_hi[15]*b_hi[15]+a_hi[31]*b_hi[31]. Then, the first MAC shifts
// `acc` left by 16 bits and performs 16 channels of 1x2x1 matmul (a_hi, b_lo),
// accumulating into `acc`. The second MAC performs 16 channels of 1x2x1 matmul
// (a_lo, b_hi), accumulating into `acc`. Finally, the third MAC performs 16
// channels of 1x2x1 matmul (a_lo, b_lo), accumulating into `acc`.
LogicalResult
convertToI32MulElemEmulation(aievec::MulElemOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
convertToEmulatedI32MulElem(aievec::MulElemOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {

Location loc = op.getLoc();
auto zeroCst = rewriter.create<LLVM::ConstantOp>(
Expand Down Expand Up @@ -551,6 +567,151 @@ class MulElemOpConversion
return success();
}

// This conversion pattern implements the below CPP emulated FP32 mul_elem.
// inline v16accfloat mul_elem_16_accuracy_safe(v16float v1, v16float v2) {
// v32bfloat16 a = broadcast_zero_to_v32bfloat16();
// v32bfloat16 b = broadcast_zero_to_v32bfloat16();
// v32bfloat16 c = broadcast_zero_to_v32bfloat16();
// v32bfloat16 d = broadcast_zero_to_v32bfloat16();
// v32bfloat16 e = broadcast_zero_to_v32bfloat16();
// v32bfloat16 f = broadcast_zero_to_v32bfloat16();
// v32bfloat16 dummy0 = broadcast_one_to_v32bfloat16();
// a = insert(a,0,to_v16bfloat16((v16accfloat)v1));
// v16accfloat acc0 = msc_elem_16_2(a, dummy0, (v16accfloat)v1);
// b = insert(b,0,to_v16bfloat16(acc0));
// c = insert(c,0,to_v16bfloat16(msc_elem_16_2(b, dummy0, acc0)));
// d = insert(d,0,to_v16bfloat16((v16accfloat)v2));
// v16accfloat acc1 = msc_elem_16_2(d, dummy0, (v16accfloat)v2);
// e = insert(e,0,to_v16bfloat16(acc1));
// f = insert(f,0,to_v16bfloat16(msc_elem_16_2(e, dummy0, acc1)));
// return
// mac_elem_16_2(a,d,mac_elem_16_2(a,e,mac_elem_16_2(b,d,mac_elem_16_2(
// d,c,mac_elem_16_2(b,e,mac_elem_16_2(a,f,mac_elem_16_2(
// b,f,mac_elem_16_2(c,e,mul_elem_16_2(c,f)))))))));
// }
// Caller example when handling the elementwise mul of two v16float vectors.
// v16float v1 = LHS(); v16float v2 = RHS();
// v16accfloat v3 = mul_elem_16(v1, v2);
// Explanation: For v32bfloat16 `a`, the first half (v16bf16) contains the
// most significant 7 bits of the mantissa of v1, and the second half (v16bf16)
// is all zeros. For v16accfloat `acc0`, the MSC computes "(original `v1` with
// 23 bits of mantissa) - (`a` with the most significant 7 bits of mantissa
// from v1)". For v32bfloat16 `b`, the first half contains bits [7:13] of the
// mantissa of v1, and the second half is all zeros. For v32bfloat16 `c`, the
// first half contains bits [14:20] of the mantissa of v1, and the second half
// is all zeros. `d`, `e` and `f` are derived from v2 in the same way. Hence,
// each v16float can be represented by three v32bfloat16 vectors, and 9
// MUL/MACs on v32bfloat16 produce the final elementwise multiplication result.

// Lowers an f32 x f32 aievec.mul_elem to a sequence of bf16 intrinsics that
// emulates the multiplication (IR counterpart of the reference
// mul_elem_16_accuracy_safe routine).
//
// Each f32 operand is decomposed into three 32-lane bf16 vectors whose first
// 16 lanes hold successive slices of the operand (the reference code calls
// them a/b/c for the LHS and d/e/f for the RHS; here they are named
// hi/mid/lo). The product is then accumulated from the nine pairwise bf16
// products: one MUL followed by eight MACs. The final accumulator is bitcast
// to the result type of `op`, which replaces `op`. Always returns success().
LogicalResult
convertToEmulatedFP32MulElem(aievec::MulElemOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();

  // Types that are used repeatedly below.
  auto bf16Ty = rewriter.getBF16Type();
  auto i32Ty = rewriter.getI32Type();
  auto v32bf16Ty = VectorType::get({32}, bf16Ty);
  auto v16bf16Ty = VectorType::get({16}, bf16Ty);
  auto accTy = VectorType::get({8}, rewriter.getI64Type());

  // Six zero-filled v32bf16 vectors, one per operand slice (a..f in the
  // reference code). They are the insertion targets for the bf16 slices.
  auto bf16Zero = rewriter.create<LLVM::ConstantOp>(
      loc, bf16Ty, rewriter.getZeroAttr(bf16Ty));
  auto makeZeroVec = [&]() -> Value {
    return rewriter.create<xllvm::VectorBroadcast16BF512IntrOp>(loc, v32bf16Ty,
                                                                bf16Zero);
  };
  Value lhsHiInit = makeZeroVec();
  Value lhsMidInit = makeZeroVec();
  Value lhsLoInit = makeZeroVec();
  Value rhsHiInit = makeZeroVec();
  Value rhsMidInit = makeZeroVec();
  Value rhsLoInit = makeZeroVec();

  // All-ones v32bf16 vector (`dummy0` in the reference code): multiplying by
  // it lets the MSC intrinsic subtract a bf16 slice from an accumulator.
  auto bf16One = rewriter.create<LLVM::ConstantOp>(
      loc, bf16Ty, rewriter.getOneAttr(bf16Ty));
  auto bf16Ones = rewriter.create<xllvm::VectorBroadcast16BF512IntrOp>(
      loc, v32bf16Ty, bf16One);

  // Index 0 for the insert (upd) intrinsic.
  auto i32Zero = rewriter.create<LLVM::ConstantOp>(
      loc, i32Ty, rewriter.getI32IntegerAttr(0));

  // One configuration word shared by every MSC/MAC/MUL emitted here.
  auto confCst = rewriter.create<LLVM::ConstantOp>(
      loc, i32Ty,
      rewriter.getI32IntegerAttr(aiev2_mul_mac_compute_control(
          /*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/2, /*bmode=*/3,
          /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
          /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0)));

  // Splits one v16 f32 input into its three v32bf16 slices. The `*Init`
  // arguments are the zero vectors the slices are inserted into.
  auto splitIntoBF16Slices =
      [&](Value f32Vec, Value hiInit, Value midInit,
          Value loInit) -> std::tuple<Value, Value, Value> {
    // hi = insert(hiInit, 0, to_v16bfloat16((v16accfloat)input));
    auto inputAsAcc = forceCastValueToType(rewriter, loc, f32Vec, accTy);
    auto inputBF16 = rewriter.create<xllvm::Vector16AccFloatToV16BF16IntrOp>(
        loc, v16bf16Ty, inputAsAcc);
    auto hi = rewriter.create<xllvm::UpdBF512BF256IntrOp>(
        loc, v32bf16Ty, hiInit, inputBF16, i32Zero);

    // residual = msc_elem_16_2(hi, ones, (v16accfloat)input);
    auto residual = rewriter.create<xllvm::MscConfBF16IntrOp>(
        loc, accTy, hi, bf16Ones, inputAsAcc, confCst);

    // mid = insert(midInit, 0, to_v16bfloat16(residual));
    auto residualBF16 = rewriter.create<xllvm::Vector16AccFloatToV16BF16IntrOp>(
        loc, v16bf16Ty, residual);
    auto mid = rewriter.create<xllvm::UpdBF512BF256IntrOp>(
        loc, v32bf16Ty, midInit, residualBF16, i32Zero);

    // lo = insert(loInit, 0,
    //             to_v16bfloat16(msc_elem_16_2(mid, ones, residual)));
    auto residual2 = rewriter.create<xllvm::MscConfBF16IntrOp>(
        loc, accTy, mid, bf16Ones, residual, confCst);
    auto residual2BF16 =
        rewriter.create<xllvm::Vector16AccFloatToV16BF16IntrOp>(
            loc, v16bf16Ty, residual2);
    auto lo = rewriter.create<xllvm::UpdBF512BF256IntrOp>(
        loc, v32bf16Ty, loInit, residual2BF16, i32Zero);

    return std::make_tuple(hi.getResult(), mid.getResult(), lo.getResult());
  };

  // v32bfloat16 slices (a, b, c) representing the v16float LHS.
  auto [lhsHi, lhsMid, lhsLo] = splitIntoBF16Slices(
      adaptor.getLhs(), lhsHiInit, lhsMidInit, lhsLoInit);
  // v32bfloat16 slices (d, e, f) representing the v16float RHS.
  auto [rhsHi, rhsMid, rhsLo] = splitIntoBF16Slices(
      adaptor.getRhs(), rhsHiInit, rhsMidInit, rhsLoInit);

  // 1 MUL + 8 MACs. Start with the lo x lo product (c * f) ...
  Value acc = rewriter.create<xllvm::MulConfBF16IntrOp>(loc, accTy, lhsLo,
                                                        rhsLo, confCst);

  // ... then accumulate the remaining products. The order matches the
  // innermost-to-outermost nesting of the reference expression:
  // (c,e) (b,f) (a,f) (b,e) (d,c) (b,d) (a,e) (a,d).
  const Value macOperands[][2] = {
      {lhsLo, rhsMid},  {lhsMid, rhsLo}, {lhsHi, rhsLo}, {lhsMid, rhsMid},
      {rhsHi, lhsLo},   {lhsMid, rhsHi}, {lhsHi, rhsMid}, {lhsHi, rhsHi}};
  for (const auto &operands : macOperands) {
    acc = rewriter
              .create<xllvm::MacConfBF16IntrOp>(loc, accTy, operands[0],
                                                operands[1], acc, confCst)
              .getResult();
  }

  // Bitcast the accumulator to the type expected by users of `op`.
  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
                                               acc);
  return success();
}

LogicalResult
matchAndRewrite(aievec::MulElemOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Expand All @@ -562,10 +723,12 @@ class MulElemOpConversion
return failure();
}

// Handle the I32 mul_elem emulation
// TODO: handle the FP32 mul_elem emulation
// Handle the emulated I32/FP32 mul_elem
if (decodedMulElemOp.kind == DecodedMulElemOp::Kind::I32_I32_I64_32x1x2x1) {
return convertToI32MulElemEmulation(op, adaptor, rewriter);
return convertToEmulatedI32MulElem(op, adaptor, rewriter);
} else if (decodedMulElemOp.kind ==
DecodedMulElemOp::Kind::FP32_FP32_FP32_16x1x1x1) {
return convertToEmulatedFP32MulElem(op, adaptor, rewriter);
}

// create constant for config
Expand Down
6 changes: 1 addition & 5 deletions lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ struct ConvertMulFToAIEVecMulElemOpPattern

unsigned laneSize = getVectorLaneSize(resultType);

// bfloat16 type
// bfloat16 and float type
if (laneSize != 16 || (resultElWidth != 16 && resultElWidth != 32))
return failure();

Expand All @@ -636,10 +636,6 @@ struct ConvertMulFToAIEVecMulElemOpPattern
if (lSrcType != rSrcType) {
return failure();
}
// Only support two bfloat16 inputs at the moment
if (lBitWidth != 16 || rBitWidth != 16) {
return failure();
}

// Prepare lhr/rhs for the aievec.mul_elem op
VectorType targetInputType =
Expand Down
Loading

0 comments on commit de9e2dc

Please sign in to comment.