Skip to content

Commit

Permalink
[aievec] to-llvm flow for emulated i32xi32 elementwise multiplication (
Browse files Browse the repository at this point in the history
…#1176)

* Add some of the mul/srs/shuflle intrinsics to XLLVM.
* Add aievec-to-llvm conversion pattern for the emulated i32xi32 elementwise multiplication.
* Add aievec-to-llvm conversion tests for the newly added XLLVM ops.
* Add target llvm translation tests.
* Add i32xi32_mul_elem e2e tests for the to-llvm flow. This includes updating the testbench.cc and the test script. The e2e test, like other aievec to-cpp tests, goes through the simulator to verify the numeric correctness.
  • Loading branch information
jamestcl-amd authored Apr 11, 2024
1 parent 370781c commit ecfe108
Show file tree
Hide file tree
Showing 8 changed files with 335 additions and 24 deletions.
23 changes: 23 additions & 0 deletions include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ def MulConfAcc32IntrOp :
VectorOfLengthAndType<[16], [I32]>:$rhs,
I32:$conf)>;

def MulConfAcc64IntrOp :
AIEVec2_IntrOp<"I512.I512.acc64.mul.conf",
[TypeIs<"res", VectorOfLengthAndType<[16], [I64]>>]>,
Arguments<(ins VectorOfLengthAndType<[64], [I8]>:$lhs,
VectorOfLengthAndType<[16], [I32]>:$rhs,
I32:$conf)>;

def MulConfBF16IntrOp :
AIEVec2_IntrOp<"bf.mul16.conf",
[TypeIs<"res", VectorOfLengthAndType<[8], [I64]>>]>,
Expand Down Expand Up @@ -100,6 +107,13 @@ def I256V32Acc32SrsIntrOp :
I32:$shft,
I32:$sign)>;

def I512V16Acc64SrsIntrOp :
AIEVec2_IntrOp<"I512.v16.acc64.srs",
[TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>,
Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$lhs,
I32:$shft,
I32:$sign)>;

def Vector16AccFloatToV16BF16IntrOp :
AIEVec2_IntrOp<"v16accfloat.to.v16bf16",
[TypeIs<"res", VectorOfLengthAndType<[16], [BF16]>>]>,
Expand Down Expand Up @@ -138,6 +152,15 @@ def ConcatI512I256IntrOp :
Arguments<(ins VectorOfLengthAndType<[8], [I32]>:$a0,
VectorOfLengthAndType<[8], [I32]>:$a1)>;

// ----- SHUFFLE -----

def VectorShuffleIntrOp :
AIEVec2_IntrOp<"vshuffle",
[TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>,
Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$a,
VectorOfLengthAndType<[16], [I32]>:$b,
I32:$mode)>;

// ----- UNDEF -----

def UndefV16I32IntrOp :
Expand Down
194 changes: 173 additions & 21 deletions lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,23 +364,39 @@ class MulElemOpConversion
// DtIn0_DtIn1_DtRes_CxMxKxN
I8_I8_I32_32x1x2x1,
I16_I16_I32_32x1x1x1,
I32_I32_I64_32x1x2x1,
BF16_BF16_FP32_16x1x2x1,
UNSUPPORTED
// TODO: I16_I16_I64_16x1x2x1
// TODO: I32 and FP32 mul_elem are emulated
// TODO: FP32 mul_elem is emulated
};

Kind kind;
int conf;
};

// sgn_x: Sign mask of matrix X. If it is one matrix X is interpreted as
// signed, else it treated as unsigned.
// sgn_y: Sign mask of matrix Y. If it is one matrix Y is interpreted as
// signed, else it treated as unsigned.
// amode/bmode/variant: config acc width, mul precision, and mul mode
// zero_acc: Zeroing of acc1. If it is one then acc1 is zeroed.
// shift16: Shift mask of acc1. If a bit is set the <<16 operation will be
// executed on acc1.
// sub_mul: Negation mask of the matrix multiplication result. If it is
// one the result of the operation will be negated.
// sub_acc1: Negation mask of acc1. If it is one acc1 will be negated.
// sub_acc2: Negation mask of acc2. If it is one acc2 will be negated.
// sub_mask: Negation mask of complex multiplications. Negates a term of a
// complex multiplication.
static int aiev2_mul_mac_compute_control(int sgn_x, int sgn_y, int amode,
int bmode, int variant, int zero_acc,
int shift16, int sub0, int sub1,
int sub2, int sub_mask) {
int shift16, int sub_mul,
int sub_acc1, int sub_acc2,
int sub_mask) {
return ((unsigned)sub_mask << 16) | ((unsigned)shift16 << 10) |
((unsigned)sub0 << 11) | ((unsigned)sub1 << 12) |
((unsigned)sub2 << 13) | ((unsigned)amode << 1) |
((unsigned)sub_mul << 11) | ((unsigned)sub_acc1 << 12) |
((unsigned)sub_acc2 << 13) | ((unsigned)amode << 1) |
((unsigned)bmode << 3) | ((unsigned)variant << 5) |
(((unsigned)sgn_x << 9) | ((unsigned)sgn_y << 8)) |
((unsigned)zero_acc << 0);
Expand All @@ -396,22 +412,145 @@ class MulElemOpConversion
if (lhsScaTy.isa<IntegerType>()) {
if (lhsBitWidth == 8) {
return {DecodedMulElemOp::Kind::I8_I8_I32_32x1x2x1,
aiev2_mul_mac_compute_control(1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0)};
aiev2_mul_mac_compute_control(
/*sgn_x=*/1, /*sgn_y=*/1, /*amode=*/0, /*bmode=*/1,
/*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
/*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
/*sub_mask=*/0)};
} else if (lhsBitWidth == 16) {
return {DecodedMulElemOp::Kind::I16_I16_I32_32x1x1x1,
aiev2_mul_mac_compute_control(1, 1, 0, 3, 1, 0, 0, 0, 0, 0, 0)};
aiev2_mul_mac_compute_control(
/*sgn_x=*/1, /*sgn_y=*/1, /*amode=*/0, /*bmode=*/3,
/*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
/*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
/*sub_mask=*/0)};
} else if (lhsBitWidth == 32) {
return {DecodedMulElemOp::Kind::I32_I32_I64_32x1x2x1, -1};
}
} else {
// Float types
if (lhsBitWidth == 16) {
return {DecodedMulElemOp::Kind::BF16_BF16_FP32_16x1x2x1,
aiev2_mul_mac_compute_control(0, 0, 2, 3, 1, 0, 0, 0, 0, 0, 0)};
aiev2_mul_mac_compute_control(
/*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/2, /*bmode=*/3,
/*variant=*/1, /*zero_acc=*/0, /*shift16=*/0,
/*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0,
/*sub_mask=*/0)};
}
}

return {DecodedMulElemOp::Kind::UNSUPPORTED, -1};
}

// This conversion pattern implements the below CPP I32 mul_elem emulation.
// INTRINSIC(v16acc64)
// mul_elem_16_2(v16int32 a0, v16int32 a1, v16int32 b0, v16int32 b1) {
// v32uint16 a_lo = (v32uint16)shuffle(a0, a1, 2);
// v32int16 a_hi = (v32int16)shuffle(a0, a1, 3);
// v32uint16 b_lo = (v32uint16)shuffle(b0, b1, 2);
// v32int16 b_hi = (v32int16)shuffle(b0, b1, 3);
// v16acc64 acc = ::mul_elem_16_2(a_hi, b_hi);
// acc = mac_elem_16_2_conf(a_hi, 1, b_lo, false, acc, 0, 1, 0, 0);
// acc = mac_elem_16_2_conf(a_lo, false, b_hi, 1, acc, 0, 0, 0, 0);
// acc = mac_elem_16_2_conf(a_lo, false, b_lo, false, acc, 0, 1, 0, 0);
// return acc;
// }
// Caller to the above CPP intrinsic:
// v16int32 v1 = LHS();
// v16int32 v2 = RHS();
// v16acc64 v3 = mul_elem_16_2(v1, broadcast_zero_s32(), v2,
// undef_v16int32());
LogicalResult
convertToI32MulElemEmulation(aievec::MulElemOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {

Location loc = op.getLoc();
auto zeroCst = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
auto a0 = adaptor.getLhs();
auto a1 = rewriter.create<xllvm::VectorBroadcast32I512IntrOp>(
loc, VectorType::get({16}, rewriter.getI32Type()), zeroCst);
auto b0 = adaptor.getRhs();
auto b1 = rewriter.create<xllvm::UndefV16I32IntrOp>(
loc, VectorType::get({16}, rewriter.getI32Type()));

// 4* Shuffle
auto a_lo = rewriter.create<xllvm::VectorShuffleIntrOp>(
loc, VectorType::get({16}, rewriter.getI32Type()), a0, a1,
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
rewriter.getI32IntegerAttr(2)));
auto a_hi = rewriter.create<xllvm::VectorShuffleIntrOp>(
loc, VectorType::get({16}, rewriter.getI32Type()), a0, a1,
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
rewriter.getI32IntegerAttr(3)));
auto b_lo = rewriter.create<xllvm::VectorShuffleIntrOp>(
loc, VectorType::get({16}, rewriter.getI32Type()), b0, b1,
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
rewriter.getI32IntegerAttr(2)));
auto b_hi = rewriter.create<xllvm::VectorShuffleIntrOp>(
loc, VectorType::get({16}, rewriter.getI32Type()), b0, b1,
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
rewriter.getI32IntegerAttr(3)));
// MUL + 3 * MAC
auto mulConfCst = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
rewriter.getI32IntegerAttr(aiev2_mul_mac_compute_control(
/*sgn_x=*/1, /*sgn_y=*/1, /*amode=*/1, /*bmode=*/3,
/*variant=*/2, /*zero_acc=*/0, /*shift16=*/0,
/*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0)));
auto mulConfOp = rewriter.create<xllvm::MulConfAcc64IntrOp>(
loc, VectorType::get({16}, rewriter.getI64Type()),
forceCastOperandsToSignature(
rewriter, loc,
/*operands=*/{a_hi, b_hi, mulConfCst},
/*signature=*/
{VectorType::get({64}, rewriter.getI8Type()),
VectorType::get({16}, rewriter.getI32Type()),
rewriter.getI32Type()}));

auto createMacConfOp = [&](SmallVector<Value> operands,
int macConf) -> Value {
operands.push_back(rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(macConf)));
return rewriter
.create<xllvm::MacConfAcc64IntrOp>(
loc, VectorType::get({16}, rewriter.getI64Type()),
forceCastOperandsToSignature(
rewriter, loc,
/*operands=*/operands,
/*signature=*/
{VectorType::get({64}, rewriter.getI8Type()),
VectorType::get({16}, rewriter.getI32Type()),
VectorType::get({16}, rewriter.getI64Type()),
rewriter.getI32Type()}))
.getResult();
};
auto acc64Val = mulConfOp.getResult();
acc64Val = createMacConfOp(
SmallVector<Value>{a_hi, b_lo, acc64Val},
aiev2_mul_mac_compute_control(
/*sgn_x=*/1, /*sgn_y=*/0, /*amode=*/1, /*bmode=*/3,
/*variant=*/2, /*zero_acc=*/0, /*shift16=*/1,
/*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0));
acc64Val = createMacConfOp(
SmallVector<Value>{a_lo, b_hi, acc64Val},
aiev2_mul_mac_compute_control(
/*sgn_x=*/0, /*sgn_y=*/1, /*amode=*/1, /*bmode=*/3,
/*variant=*/2, /*zero_acc=*/0, /*shift16=*/0,
/*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0));
acc64Val = createMacConfOp(
SmallVector<Value>{a_lo, b_lo, acc64Val},
aiev2_mul_mac_compute_control(
/*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/1, /*bmode=*/3,
/*variant=*/2, /*zero_acc=*/0, /*shift16=*/1,
/*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, /*sub_mask=*/0));

// create bitcast for result
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getResult().getType(),
acc64Val);
return success();
}

LogicalResult
matchAndRewrite(aievec::MulElemOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Expand All @@ -423,6 +562,12 @@ class MulElemOpConversion
return failure();
}

// Handle the I32 mul_elem emulation
// TODO: handle the FP32 mul_elem emulation
if (decodedMulElemOp.kind == DecodedMulElemOp::Kind::I32_I32_I64_32x1x2x1) {
return convertToI32MulElemEmulation(op, adaptor, rewriter);
}

// create constant for config
auto confCst = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
Expand Down Expand Up @@ -496,12 +641,21 @@ class SRSOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::SRSOp> {
SmallVector<Value> operands(
{adaptor.getSource(), adaptor.getShift(), signCst});
if (resultVectorSize == 512) {
rewriter.replaceOpWithNewOp<xllvm::I512V32Acc32SrsIntrOp>(
op, VectorType::get({32}, rewriter.getI16Type()),
forceCastOperandsToSignature(
rewriter, loc, operands,
{VectorType::get({16}, rewriter.getI64Type()),
rewriter.getI32Type(), rewriter.getI32Type()}));
if (resultBitWidth == 16) {
rewriter.replaceOpWithNewOp<xllvm::I512V32Acc32SrsIntrOp>(
op, VectorType::get({32}, rewriter.getI16Type()),
forceCastOperandsToSignature(
rewriter, loc, operands,
{VectorType::get({16}, rewriter.getI64Type()),
rewriter.getI32Type(), rewriter.getI32Type()}));
} else if (resultBitWidth == 32) {
rewriter.replaceOpWithNewOp<xllvm::I512V16Acc64SrsIntrOp>(
op, VectorType::get({16}, rewriter.getI32Type()),
forceCastOperandsToSignature(
rewriter, loc, operands,
{VectorType::get({16}, rewriter.getI64Type()),
rewriter.getI32Type(), rewriter.getI32Type()}));
}
} else if (resultVectorSize == 256) {
rewriter.replaceOpWithNewOp<xllvm::I256V32Acc32SrsIntrOp>(
op, VectorType::get({32}, rewriter.getI8Type()),
Expand Down Expand Up @@ -1124,13 +1278,11 @@ class MatMulOpConversion
}
};

/*
This pattern folds aievec.cast op. For AIE-ML, the accumulators are in 32/64
bits, and the vectors are in 4/8/16/32 bits. Hence, we don't have to
explicitly express the casting between accumulators and vectors at the LLVM
dialect level. The backend LLVM compiler will decide the correct accumulator
or vector registers given the ops and intrinsics.
*/
// This pattern folds aievec.cast op. For AIE-ML, the accumulators are in 32/64
// bits, and the vectors are in 4/8/16/32 bits. Hence, we don't have to
// explicitly express the casting between accumulators and vectors at the LLVM
// dialect level. The backend LLVM compiler will decide the correct accumulator
// or vector registers given the ops and intrinsics.
class FoldAIECastOps : public mlir::ConvertOpToLLVMPattern<aievec::CastOp> {
using ConvertOpToLLVMPattern<aievec::CastOp>::ConvertOpToLLVMPattern;

Expand Down
36 changes: 36 additions & 0 deletions test/Conversion/AIEVecToLLVM/mul_elem.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,42 @@ func.func @i8_i8_i32_mul_elem(%arg0 : vector<64xi8>, %arg1 : vector<64xi8>) -> v

// -----

func.func @i32_i32_i32_mul_elem(%arg0 : vector<16xi32>, %arg1 : vector<16xi32>) -> vector<16xi64> {
%0 = aievec.mul_elem %arg0, %arg1 : vector<16xi32>, vector<16xi32>, vector<16xi64>
return %0 : vector<16xi64>
}

// CHECK-LABEL: @i32_i32_i32_mul_elem
// CHECK-SAME: %[[ARG0:.*]]: vector<16xi32>,
// CHECK-SAME: %[[ARG1:.*]]: vector<16xi32>
// CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[VBROADCAST:.*]] = "xllvm.intr.aie2.vbroadcast32.I512"(%[[CST0]]) : (i32) -> vector<16xi32>
// CHECK-NEXT: %[[UNDEF:.*]] = "xllvm.intr.aie2.v16int32"() : () -> vector<16xi32>
// CHECK-NEXT: %[[CST1:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK-NEXT: %[[SHUFF0:.*]] = "xllvm.intr.aie2.vshuffle"(%[[ARG0]], %[[VBROADCAST]], %[[CST1]]) : (vector<16xi32>, vector<16xi32>, i32) -> vector<16xi32>
// CHECK-NEXT: %[[CST2:.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK-NEXT: %[[SHUFF1:.*]] = "xllvm.intr.aie2.vshuffle"(%[[ARG0]], %[[VBROADCAST]], %[[CST2]]) : (vector<16xi32>, vector<16xi32>, i32) -> vector<16xi32>
// CHECK-NEXT: %[[CST3:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK-NEXT: %[[SHUFF2:.*]] = "xllvm.intr.aie2.vshuffle"(%[[ARG1]], %[[UNDEF]], %[[CST3]]) : (vector<16xi32>, vector<16xi32>, i32) -> vector<16xi32>
// CHECK-NEXT: %[[CST4:.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK-NEXT: %[[SHUFF3:.*]] = "xllvm.intr.aie2.vshuffle"(%[[ARG1]], %[[UNDEF]], %[[CST4]]) : (vector<16xi32>, vector<16xi32>, i32) -> vector<16xi32>
// CHECK-NEXT: %[[CST5:.*]] = llvm.mlir.constant(858 : i32) : i32
// CHECK-NEXT: %[[BITCAST0:.*]] = llvm.bitcast %[[SHUFF1]] : vector<16xi32> to vector<64xi8>
// CHECK-NEXT: %[[ACC0:.*]] = "xllvm.intr.aie2.I512.I512.acc64.mul.conf"(%[[BITCAST0]], %[[SHUFF3]], %[[CST5]]) : (vector<64xi8>, vector<16xi32>, i32) -> vector<16xi64>
// CHECK-NEXT: %[[CST6:.*]] = llvm.mlir.constant(1626 : i32) : i32
// CHECK-NEXT: %[[BITCAST1:.*]] = llvm.bitcast %[[SHUFF1]] : vector<16xi32> to vector<64xi8>
// CHECK-NEXT: %[[ACC1:.*]] = "xllvm.intr.aie2.I512.I512.ACC1024.acc64.mac.conf"(%[[BITCAST1]], %[[SHUFF2]], %[[ACC0]], %[[CST6]]) : (vector<64xi8>, vector<16xi32>, vector<16xi64>, i32) -> vector<16xi64>
// CHECK-NEXT: %[[CST7:.*]] = llvm.mlir.constant(346 : i32) : i32
// CHECK-NEXT: %[[BITCAST2:.*]] = llvm.bitcast %[[SHUFF0]] : vector<16xi32> to vector<64xi8>
// CHECK-NEXT: %[[ACC2:.*]] = "xllvm.intr.aie2.I512.I512.ACC1024.acc64.mac.conf"(%[[BITCAST2]], %[[SHUFF3]], %[[ACC1]], %[[CST7]]) : (vector<64xi8>, vector<16xi32>, vector<16xi64>, i32) -> vector<16xi64>
// CHECK-NEXT: %[[CST8:.*]] = llvm.mlir.constant(1114 : i32) : i32
// CHECK-NEXT: %[[BITCAST3:.*]] = llvm.bitcast %[[SHUFF0]] : vector<16xi32> to vector<64xi8>
// CHECK-NEXT: %[[ACC3:.*]] = "xllvm.intr.aie2.I512.I512.ACC1024.acc64.mac.conf"(%[[BITCAST3]], %[[SHUFF2]], %[[ACC2]], %[[CST8]]) : (vector<64xi8>, vector<16xi32>, vector<16xi64>, i32) -> vector<16xi64>
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[ACC3]] : vector<16xi64> to vector<16xi64>
// CHECK-NEXT: return %[[RES]] : vector<16xi64>

// -----

func.func @bf16_bf16_f32_mul_elem(%arg0 : vector<32xbf16>, %arg1 : vector<32xbf16>) -> vector<16xf32> {
%0 = aievec.mul_elem %arg0, %arg1 : vector<32xbf16>, vector<32xbf16>, vector<16xf32>
return %0 : vector<16xf32>
Expand Down
23 changes: 23 additions & 0 deletions test/Conversion/AIEVecToLLVM/test-srs.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,29 @@ func.func @v32i32_srs_v32i8(%arg0 : vector<32xi32>) {

// -----

func.func @v16i64_srs_v16i32(%arg0 : vector<16xi64>) {
%c0 = arith.constant 0 : i32
%c5 = arith.constant 5 : i32
%0 = aievec.srs %arg0, %c0 : vector<16xi64>, i32, vector<16xi32>
%1 = aievec.srs %arg0, %c5 : vector<16xi64>, i32, vector<16xi32>
return
}

// CHECK-LABEL: @v16i64_srs_v16i32
// CHECK-SAME: %[[ARG0:.*]]: vector<16xi64>
// CHECK-NEXT: %[[SHIFT0:.*]] = arith.constant 0 : i32
// CHECK-NEXT: %[[SHIFT5:.*]] = arith.constant 5 : i32
// CHECK-NEXT: %[[SIGN0:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[SRS0:.*]] = "xllvm.intr.aie2.I512.v16.acc64.srs"(
// CHECK-SAME: [[ARG0]], %[[SHIFT0]], %[[SIGN0]]) :
// CHECK-SAME: (vector<16xi64>, i32, i32) -> vector<16xi32>
// CHECK-NEXT: %[[SIGN1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[SRS1:.*]] = "xllvm.intr.aie2.I512.v16.acc64.srs"(
// CHECK-SAME: [[ARG0]], %[[SHIFT5]], %[[SIGN1]]) :
// CHECK-SAME: (vector<16xi64>, i32, i32) -> vector<16xi32>

// -----

func.func @v16f32_srs_v16bf16(%arg0 : vector<16xf32>) {
%c0 = arith.constant 0 : i32
%c5 = arith.constant 5 : i32
Expand Down
Loading

0 comments on commit ecfe108

Please sign in to comment.