Un-unroll ukernel C+intrinsics code. (#14908)
Fully unrolling C+intrinsics code is a common defensive practice against the
tendency of compilers to miss good codegen for SIMD code. It's not just
about the loops: it's about using arrays of vector variables, which is
necessary to be able to write the loops in the first place. A sufficiently
naive compiler will literally take that to mean that the vectors are memory
objects. Several years ago I filed
https://bugs.llvm.org/show_bug.cgi?id=34945 and never heard back about it.
To this day, XNNPACK sticks to this practice, e.g.
https://github.com/google/XNNPACK/blob/master/src/f32-gemm/gen/f32-gemm-8x8s4-minmax-neon.c#L238-L269.
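
To illustrate the pattern (a minimal standalone sketch, not code from this
repository): writing the loop at all requires the accumulators to be an array
of vector variables, and it is that array which a naive compiler demotes to a
stack object.

#include <arm_neon.h>

// Minimal sketch of the "array of vector variables" pattern. A good
// compiler keeps acc[] entirely in NEON registers; a naive one treats
// it as memory, spilling/reloading acc[i] on every inner iteration.
void dot_4rows(float* out, const float* a, const float* b, int k) {
  float32x4_t acc[4];
  for (int i = 0; i < 4; ++i) acc[i] = vdupq_n_f32(0);
  for (; k > 0; --k) {
    float32x4_t vb = vld1q_f32(b);
    for (int i = 0; i < 4; ++i) {
      acc[i] = vfmaq_f32(acc[i], vld1q_f32(a + 4 * i), vb);
    }
    a += 16;
    b += 4;
  }
  for (int i = 0; i < 4; ++i) vst1q_f32(out + 4 * i, acc[i]);
}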

This raises the question of how to manage the resulting verbose code and how
to scale to supporting many ukernel variants. The immediate motivation for us
here is that we are about to introduce narrow variants of matmul kernels.
XNNPACK deals with this with a Python-based generator of unrolled C code. In
our case, the primary deployment path for ukernels is to compile them to LLVM
bitcode that IREE can then inline at each call site and "LTO", at which point
it should be able to perform loop unrolling and dead-code elimination itself.
It would be neat to simply take advantage of that, instead of inventing a new
way to unroll loops and skip over dead code, or carrying verbose source code.
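
To sketch that "LTO" effect concretely (illustrative code with made-up names,
not the actual IREE mechanism): once a generic loop is inlined at a call site
where the trip count and flags are compile-time constants, the optimizer can
fully unroll it and delete the dead branch, recovering what hand-unrolled
source would have given.

// Sketch: a generic shared loop, as it might exist in the bitcode.
static inline void tile_loop(float* acc, const float* lhs,
                             const float* rhs, int K, int accumulate) {
  if (!accumulate) *acc = 0.0f;  // dead branch once 'accumulate' is constant
  for (int k = 0; k < K; ++k) *acc += lhs[k] * rhs[k];
}

// Call site with constant K: after inlining, the k-loop is fully
// unrollable and the !accumulate branch is eliminated.
void call_site(float* acc, const float* lhs, const float* rhs) {
  tile_loop(acc, lhs, rhs, /*K=*/4, /*accumulate=*/1);
}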

The danger is regressing performance in the native-toolchain, non-bitcode
builds of ukernels. Those builds are only used in VMVX and in the ukernels'
own microbenchmarks and unit tests, so their performance isn't really
critical: we want them to build correctly, but suboptimal performance is
acceptable. To be clear, to preserve performance in the native build, we
still instantiate functions with the loop size known at compile time, calling
into the shared loop implementation, which gets inlined into each case. The
only question is whether the native toolchain handles that inlining as well
as Clang does. Concretely, I tried one case and found that GCC generates ~2x
slower code, while Clang and MSVC did fine: https://godbolt.org/z/WsbW487ze
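
Schematically, that native-build instantiation pattern looks like this
(illustrative names and a simplified K0=1 tile layout, not the actual IREE
helpers):

// Shared loop implementation, parametrized over the tile shape.
static inline void mmt4d_tile_loop(float* out, const float* lhs,
                                   const float* rhs, int K, int M0, int N0) {
  for (int k = 0; k < K; ++k)
    for (int i = 0; i < M0; ++i)
      for (int j = 0; j < N0; ++j)
        out[i * N0 + j] += lhs[k * M0 + i] * rhs[k * N0 + j];
}

// Per-shape entry point: M0/N0 become compile-time constants once the
// shared loop is inlined here, so a good compiler can still specialize
// and unroll. This inlining step is where GCC lagged ~2x in the godbolt
// experiment above, while Clang and MSVC did fine.
void mmt4d_tile_8x8(float* out, const float* lhs, const float* rhs, int K) {
  mmt4d_tile_loop(out, lhs, rhs, K, 8, 8);
}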

Just look at the code shrink here, and this is only a first step. The next PR
will introduce variants for narrow M0 dimensions, and they will all be able
to share the same loop implementation, both in source code and in the
embedded bitcode of the bitcode build.
bjacob authored Sep 11, 2023
1 parent 0eb7b1a commit 801356f
Showing 9 changed files with 490 additions and 1,348 deletions.
337 changes: 99 additions & 238 deletions runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64.c

Large diffs are not rendered by default.

146 changes: 38 additions & 108 deletions runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_bf16.c
@@ -38,122 +38,52 @@ void iree_uk_mmt4d_tile_bf16bf16f32_8x8x4_arm_64_bf16(
   const bfloat16_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
   const bfloat16_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
   float* IREE_UK_RESTRICT out_ptr = out_tile;
-  float32x4_t acc_01_01, acc_01_23, acc_01_45, acc_01_67;
-  float32x4_t acc_23_01, acc_23_23, acc_23_45, acc_23_67;
-  float32x4_t acc_45_01, acc_45_23, acc_45_45, acc_45_67;
-  float32x4_t acc_67_01, acc_67_23, acc_67_45, acc_67_67;
+  // Accumulator 2x2 register tiles.
+  float32x4_t acc[4][4];
   if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
-    float32x4_t acc_0_0123 = vld1q_f32(out_ptr + 8 * 0 + 0);
-    float32x4_t acc_0_4567 = vld1q_f32(out_ptr + 8 * 0 + 4);
-    float32x4_t acc_1_0123 = vld1q_f32(out_ptr + 8 * 1 + 0);
-    float32x4_t acc_1_4567 = vld1q_f32(out_ptr + 8 * 1 + 4);
-    float32x4_t acc_2_0123 = vld1q_f32(out_ptr + 8 * 2 + 0);
-    float32x4_t acc_2_4567 = vld1q_f32(out_ptr + 8 * 2 + 4);
-    float32x4_t acc_3_0123 = vld1q_f32(out_ptr + 8 * 3 + 0);
-    float32x4_t acc_3_4567 = vld1q_f32(out_ptr + 8 * 3 + 4);
-    float32x4_t acc_4_0123 = vld1q_f32(out_ptr + 8 * 4 + 0);
-    float32x4_t acc_4_4567 = vld1q_f32(out_ptr + 8 * 4 + 4);
-    float32x4_t acc_5_0123 = vld1q_f32(out_ptr + 8 * 5 + 0);
-    float32x4_t acc_5_4567 = vld1q_f32(out_ptr + 8 * 5 + 4);
-    float32x4_t acc_6_0123 = vld1q_f32(out_ptr + 8 * 6 + 0);
-    float32x4_t acc_6_4567 = vld1q_f32(out_ptr + 8 * 6 + 4);
-    float32x4_t acc_7_0123 = vld1q_f32(out_ptr + 8 * 7 + 0);
-    float32x4_t acc_7_4567 = vld1q_f32(out_ptr + 8 * 7 + 4);
-    acc_01_01 = iree_uk_neon_zip1_f32_as_s64(acc_0_0123, acc_1_0123);
-    acc_01_23 = iree_uk_neon_zip2_f32_as_s64(acc_0_0123, acc_1_0123);
-    acc_01_45 = iree_uk_neon_zip1_f32_as_s64(acc_0_4567, acc_1_4567);
-    acc_01_67 = iree_uk_neon_zip2_f32_as_s64(acc_0_4567, acc_1_4567);
-    acc_23_01 = iree_uk_neon_zip1_f32_as_s64(acc_2_0123, acc_3_0123);
-    acc_23_23 = iree_uk_neon_zip2_f32_as_s64(acc_2_0123, acc_3_0123);
-    acc_23_45 = iree_uk_neon_zip1_f32_as_s64(acc_2_4567, acc_3_4567);
-    acc_23_67 = iree_uk_neon_zip2_f32_as_s64(acc_2_4567, acc_3_4567);
-    acc_45_01 = iree_uk_neon_zip1_f32_as_s64(acc_4_0123, acc_5_0123);
-    acc_45_23 = iree_uk_neon_zip2_f32_as_s64(acc_4_0123, acc_5_0123);
-    acc_45_45 = iree_uk_neon_zip1_f32_as_s64(acc_4_4567, acc_5_4567);
-    acc_45_67 = iree_uk_neon_zip2_f32_as_s64(acc_4_4567, acc_5_4567);
-    acc_67_01 = iree_uk_neon_zip1_f32_as_s64(acc_6_0123, acc_7_0123);
-    acc_67_23 = iree_uk_neon_zip2_f32_as_s64(acc_6_0123, acc_7_0123);
-    acc_67_45 = iree_uk_neon_zip1_f32_as_s64(acc_6_4567, acc_7_4567);
-    acc_67_67 = iree_uk_neon_zip2_f32_as_s64(acc_6_4567, acc_7_4567);
+    // Load row-major accumulator and swizzle into 2x2 register tiles.
+    for (int i = 0; i < 4; ++i) {
+      for (int j = 0; j < 2; ++j) {
+        float32x4_t acc_1x4_0 = vld1q_f32(out_ptr + 8 * (2 * i + 0) + 4 * j);
+        float32x4_t acc_1x4_1 = vld1q_f32(out_ptr + 8 * (2 * i + 1) + 4 * j);
+        acc[i][2 * j + 0] = iree_uk_neon_zip1_f32_as_s64(acc_1x4_0, acc_1x4_1);
+        acc[i][2 * j + 1] = iree_uk_neon_zip2_f32_as_s64(acc_1x4_0, acc_1x4_1);
+      }
+    }
   } else {
-    acc_01_01 = vdupq_n_f32(0);
-    acc_01_23 = vdupq_n_f32(0);
-    acc_01_45 = vdupq_n_f32(0);
-    acc_01_67 = vdupq_n_f32(0);
-    acc_23_01 = vdupq_n_f32(0);
-    acc_23_23 = vdupq_n_f32(0);
-    acc_23_45 = vdupq_n_f32(0);
-    acc_23_67 = vdupq_n_f32(0);
-    acc_45_01 = vdupq_n_f32(0);
-    acc_45_23 = vdupq_n_f32(0);
-    acc_45_45 = vdupq_n_f32(0);
-    acc_45_67 = vdupq_n_f32(0);
-    acc_67_01 = vdupq_n_f32(0);
-    acc_67_23 = vdupq_n_f32(0);
-    acc_67_45 = vdupq_n_f32(0);
-    acc_67_67 = vdupq_n_f32(0);
+    for (int i = 0; i < 4; ++i) {
+      for (int j = 0; j < 4; ++j) {
+        acc[i][j] = vdupq_n_f32(0);
+      }
+    }
   }
 
   IREE_UK_ASSUME(params->K >= 1);
   for (int k = 0; k < params->K; ++k) {
-    bfloat16x8_t lhs01 = vld1q_bf16(lhs_ptr + 0);
-    bfloat16x8_t lhs23 = vld1q_bf16(lhs_ptr + 8);
-    bfloat16x8_t lhs45 = vld1q_bf16(lhs_ptr + 16);
-    bfloat16x8_t lhs67 = vld1q_bf16(lhs_ptr + 24);
+    bfloat16x8_t lhs[4];
+    bfloat16x8_t rhs[4];
+    for (int i = 0; i < 4; ++i) {
+      lhs[i] = vld1q_bf16(lhs_ptr + 8 * i);
+      rhs[i] = vld1q_bf16(rhs_ptr + 8 * i);
+    }
     lhs_ptr += 32;
-    bfloat16x8_t rhs01 = vld1q_bf16(rhs_ptr + 0);
-    bfloat16x8_t rhs23 = vld1q_bf16(rhs_ptr + 8);
-    bfloat16x8_t rhs45 = vld1q_bf16(rhs_ptr + 16);
-    bfloat16x8_t rhs67 = vld1q_bf16(rhs_ptr + 24);
     rhs_ptr += 32;
-    acc_01_01 = vbfmmlaq_f32(acc_01_01, lhs01, rhs01);
-    acc_01_23 = vbfmmlaq_f32(acc_01_23, lhs01, rhs23);
-    acc_01_45 = vbfmmlaq_f32(acc_01_45, lhs01, rhs45);
-    acc_01_67 = vbfmmlaq_f32(acc_01_67, lhs01, rhs67);
-    acc_23_01 = vbfmmlaq_f32(acc_23_01, lhs23, rhs01);
-    acc_23_23 = vbfmmlaq_f32(acc_23_23, lhs23, rhs23);
-    acc_23_45 = vbfmmlaq_f32(acc_23_45, lhs23, rhs45);
-    acc_23_67 = vbfmmlaq_f32(acc_23_67, lhs23, rhs67);
-    acc_45_01 = vbfmmlaq_f32(acc_45_01, lhs45, rhs01);
-    acc_45_23 = vbfmmlaq_f32(acc_45_23, lhs45, rhs23);
-    acc_45_45 = vbfmmlaq_f32(acc_45_45, lhs45, rhs45);
-    acc_45_67 = vbfmmlaq_f32(acc_45_67, lhs45, rhs67);
-    acc_67_01 = vbfmmlaq_f32(acc_67_01, lhs67, rhs01);
-    acc_67_23 = vbfmmlaq_f32(acc_67_23, lhs67, rhs23);
-    acc_67_45 = vbfmmlaq_f32(acc_67_45, lhs67, rhs45);
-    acc_67_67 = vbfmmlaq_f32(acc_67_67, lhs67, rhs67);
+    for (int i = 0; i < 4; ++i) {
+      for (int j = 0; j < 4; ++j) {
+        acc[i][j] = vbfmmlaq_f32(acc[i][j], lhs[i], rhs[j]);
+      }
+    }
   }
 
-  float32x4_t acc_0_0123 = iree_uk_neon_uzp1_f32_as_s64(acc_01_01, acc_01_23);
-  float32x4_t acc_0_4567 = iree_uk_neon_uzp1_f32_as_s64(acc_01_45, acc_01_67);
-  float32x4_t acc_1_0123 = iree_uk_neon_uzp2_f32_as_s64(acc_01_01, acc_01_23);
-  float32x4_t acc_1_4567 = iree_uk_neon_uzp2_f32_as_s64(acc_01_45, acc_01_67);
-  float32x4_t acc_2_0123 = iree_uk_neon_uzp1_f32_as_s64(acc_23_01, acc_23_23);
-  float32x4_t acc_2_4567 = iree_uk_neon_uzp1_f32_as_s64(acc_23_45, acc_23_67);
-  float32x4_t acc_3_0123 = iree_uk_neon_uzp2_f32_as_s64(acc_23_01, acc_23_23);
-  float32x4_t acc_3_4567 = iree_uk_neon_uzp2_f32_as_s64(acc_23_45, acc_23_67);
-  float32x4_t acc_4_0123 = iree_uk_neon_uzp1_f32_as_s64(acc_45_01, acc_45_23);
-  float32x4_t acc_4_4567 = iree_uk_neon_uzp1_f32_as_s64(acc_45_45, acc_45_67);
-  float32x4_t acc_5_0123 = iree_uk_neon_uzp2_f32_as_s64(acc_45_01, acc_45_23);
-  float32x4_t acc_5_4567 = iree_uk_neon_uzp2_f32_as_s64(acc_45_45, acc_45_67);
-  float32x4_t acc_6_0123 = iree_uk_neon_uzp1_f32_as_s64(acc_67_01, acc_67_23);
-  float32x4_t acc_6_4567 = iree_uk_neon_uzp1_f32_as_s64(acc_67_45, acc_67_67);
-  float32x4_t acc_7_0123 = iree_uk_neon_uzp2_f32_as_s64(acc_67_01, acc_67_23);
-  float32x4_t acc_7_4567 = iree_uk_neon_uzp2_f32_as_s64(acc_67_45, acc_67_67);
-  vst1q_f32(out_ptr + 8 * 0 + 0, acc_0_0123);
-  vst1q_f32(out_ptr + 8 * 0 + 4, acc_0_4567);
-  vst1q_f32(out_ptr + 8 * 1 + 0, acc_1_0123);
-  vst1q_f32(out_ptr + 8 * 1 + 4, acc_1_4567);
-  vst1q_f32(out_ptr + 8 * 2 + 0, acc_2_0123);
-  vst1q_f32(out_ptr + 8 * 2 + 4, acc_2_4567);
-  vst1q_f32(out_ptr + 8 * 3 + 0, acc_3_0123);
-  vst1q_f32(out_ptr + 8 * 3 + 4, acc_3_4567);
-  vst1q_f32(out_ptr + 8 * 4 + 0, acc_4_0123);
-  vst1q_f32(out_ptr + 8 * 4 + 4, acc_4_4567);
-  vst1q_f32(out_ptr + 8 * 5 + 0, acc_5_0123);
-  vst1q_f32(out_ptr + 8 * 5 + 4, acc_5_4567);
-  vst1q_f32(out_ptr + 8 * 6 + 0, acc_6_0123);
-  vst1q_f32(out_ptr + 8 * 6 + 4, acc_6_4567);
-  vst1q_f32(out_ptr + 8 * 7 + 0, acc_7_0123);
-  vst1q_f32(out_ptr + 8 * 7 + 4, acc_7_4567);
+  // Swizzle accumulator 2x2 register tiles back to row-major and store.
+  for (int i = 0; i < 4; ++i) {
+    for (int j = 0; j < 2; ++j) {
+      float32x4_t acc_1x4_0 =
+          iree_uk_neon_uzp1_f32_as_s64(acc[i][2 * j + 0], acc[i][2 * j + 1]);
+      float32x4_t acc_1x4_1 =
+          iree_uk_neon_uzp2_f32_as_s64(acc[i][2 * j + 0], acc[i][2 * j + 1]);
+      vst1q_f32(out_ptr + 8 * (2 * i + 0) + 4 * j, acc_1x4_0);
+      vst1q_f32(out_ptr + 8 * (2 * i + 1) + 4 * j, acc_1x4_1);
+    }
+  }
 }
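
(Editorial note, not part of the patch: the 2x2 tiling above follows from
BFMMLA's semantics. vbfmmlaq_f32 treats each bfloat16x8_t operand as a 2x4
matrix and accumulates a 2x2 f32 tile, acc(2x2) += LHS(2x4) * RHS(2x4)^T, so
the 8x8 accumulator is naturally held as a 4x4 grid of such tiles, with rows
zipped on load and unzipped on store. A scalar model:)

// Scalar model of one vbfmmlaq_f32: acc(2x2) += lhs(2x4) * rhs(2x4)^T.
void bfmmla_model(float acc[2][2], const float lhs[2][4],
                  const float rhs[2][4]) {
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 4; ++k)
        acc[i][j] += lhs[i][k] * rhs[j][k];
}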
103 changes: 33 additions & 70 deletions runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_dotprod.c
@@ -14,82 +14,45 @@ void iree_uk_mmt4d_tile_i8i8i32_8x8x4_arm_64_dotprod(
   const iree_uk_int8_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
   const iree_uk_int8_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
   iree_uk_int32_t* IREE_UK_RESTRICT out_ptr = out_tile;
-  int32x4_t acc0, acc1, acc2, acc3, acc4, acc5, acc6, acc7, acc8, acc9, acc10,
-      acc11, acc12, acc13, acc14, acc15;
+  int32x4_t acc[16];
   if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
-    acc0 = vld1q_s32(out_ptr + 4 * 0);
-    acc1 = vld1q_s32(out_ptr + 4 * 1);
-    acc2 = vld1q_s32(out_ptr + 4 * 2);
-    acc3 = vld1q_s32(out_ptr + 4 * 3);
-    acc4 = vld1q_s32(out_ptr + 4 * 4);
-    acc5 = vld1q_s32(out_ptr + 4 * 5);
-    acc6 = vld1q_s32(out_ptr + 4 * 6);
-    acc7 = vld1q_s32(out_ptr + 4 * 7);
-    acc8 = vld1q_s32(out_ptr + 4 * 8);
-    acc9 = vld1q_s32(out_ptr + 4 * 9);
-    acc10 = vld1q_s32(out_ptr + 4 * 10);
-    acc11 = vld1q_s32(out_ptr + 4 * 11);
-    acc12 = vld1q_s32(out_ptr + 4 * 12);
-    acc13 = vld1q_s32(out_ptr + 4 * 13);
-    acc14 = vld1q_s32(out_ptr + 4 * 14);
-    acc15 = vld1q_s32(out_ptr + 4 * 15);
+    for (int i = 0; i < 16; ++i) {
+      acc[i] = vld1q_s32(out_ptr + 4 * i);
+    }
   } else {
-    acc0 = vdupq_n_s32(0);
-    acc1 = vdupq_n_s32(0);
-    acc2 = vdupq_n_s32(0);
-    acc3 = vdupq_n_s32(0);
-    acc4 = vdupq_n_s32(0);
-    acc5 = vdupq_n_s32(0);
-    acc6 = vdupq_n_s32(0);
-    acc7 = vdupq_n_s32(0);
-    acc8 = vdupq_n_s32(0);
-    acc9 = vdupq_n_s32(0);
-    acc10 = vdupq_n_s32(0);
-    acc11 = vdupq_n_s32(0);
-    acc12 = vdupq_n_s32(0);
-    acc13 = vdupq_n_s32(0);
-    acc14 = vdupq_n_s32(0);
-    acc15 = vdupq_n_s32(0);
+    for (int i = 0; i < 16; ++i) {
+      acc[i] = vdupq_n_s32(0);
+    }
   }
   IREE_UK_ASSUME(params->K >= 1);
   for (int k = 0; k < params->K; ++k) {
-    int8x16_t lhs0 = vld1q_s8(lhs_ptr + 0);
-    int8x16_t lhs1 = vld1q_s8(lhs_ptr + 16);
+    int8x16_t lhs[2];
+    int8x16_t rhs[2];
+    for (int i = 0; i < 2; ++i) {
+      lhs[i] = vld1q_s8(lhs_ptr + 16 * i);
+      rhs[i] = vld1q_s8(rhs_ptr + 16 * i);
+    }
     lhs_ptr += 32;
-    int8x16_t rhs0 = vld1q_s8(rhs_ptr + 0);
-    int8x16_t rhs1 = vld1q_s8(rhs_ptr + 16);
     rhs_ptr += 32;
-    acc0 = vdotq_lane_s32(acc0, rhs0, vget_low_s8(lhs0), 0);
-    acc1 = vdotq_lane_s32(acc1, rhs1, vget_low_s8(lhs0), 0);
-    acc2 = vdotq_lane_s32(acc2, rhs0, vget_low_s8(lhs0), 1);
-    acc3 = vdotq_lane_s32(acc3, rhs1, vget_low_s8(lhs0), 1);
-    acc4 = vdotq_lane_s32(acc4, rhs0, vget_high_s8(lhs0), 0);
-    acc5 = vdotq_lane_s32(acc5, rhs1, vget_high_s8(lhs0), 0);
-    acc6 = vdotq_lane_s32(acc6, rhs0, vget_high_s8(lhs0), 1);
-    acc7 = vdotq_lane_s32(acc7, rhs1, vget_high_s8(lhs0), 1);
-    acc8 = vdotq_lane_s32(acc8, rhs0, vget_low_s8(lhs1), 0);
-    acc9 = vdotq_lane_s32(acc9, rhs1, vget_low_s8(lhs1), 0);
-    acc10 = vdotq_lane_s32(acc10, rhs0, vget_low_s8(lhs1), 1);
-    acc11 = vdotq_lane_s32(acc11, rhs1, vget_low_s8(lhs1), 1);
-    acc12 = vdotq_lane_s32(acc12, rhs0, vget_high_s8(lhs1), 0);
-    acc13 = vdotq_lane_s32(acc13, rhs1, vget_high_s8(lhs1), 0);
-    acc14 = vdotq_lane_s32(acc14, rhs0, vget_high_s8(lhs1), 1);
-    acc15 = vdotq_lane_s32(acc15, rhs1, vget_high_s8(lhs1), 1);
+    acc[0] = vdotq_lane_s32(acc[0], rhs[0], vget_low_s8(lhs[0]), 0);
+    acc[1] = vdotq_lane_s32(acc[1], rhs[1], vget_low_s8(lhs[0]), 0);
+    acc[2] = vdotq_lane_s32(acc[2], rhs[0], vget_low_s8(lhs[0]), 1);
+    acc[3] = vdotq_lane_s32(acc[3], rhs[1], vget_low_s8(lhs[0]), 1);
+    acc[4] = vdotq_lane_s32(acc[4], rhs[0], vget_high_s8(lhs[0]), 0);
+    acc[5] = vdotq_lane_s32(acc[5], rhs[1], vget_high_s8(lhs[0]), 0);
+    acc[6] = vdotq_lane_s32(acc[6], rhs[0], vget_high_s8(lhs[0]), 1);
+    acc[7] = vdotq_lane_s32(acc[7], rhs[1], vget_high_s8(lhs[0]), 1);
+    acc[8] = vdotq_lane_s32(acc[8], rhs[0], vget_low_s8(lhs[1]), 0);
+    acc[9] = vdotq_lane_s32(acc[9], rhs[1], vget_low_s8(lhs[1]), 0);
+    acc[10] = vdotq_lane_s32(acc[10], rhs[0], vget_low_s8(lhs[1]), 1);
+    acc[11] = vdotq_lane_s32(acc[11], rhs[1], vget_low_s8(lhs[1]), 1);
+    acc[12] = vdotq_lane_s32(acc[12], rhs[0], vget_high_s8(lhs[1]), 0);
+    acc[13] = vdotq_lane_s32(acc[13], rhs[1], vget_high_s8(lhs[1]), 0);
+    acc[14] = vdotq_lane_s32(acc[14], rhs[0], vget_high_s8(lhs[1]), 1);
+    acc[15] = vdotq_lane_s32(acc[15], rhs[1], vget_high_s8(lhs[1]), 1);
   }
-  vst1q_s32(out_ptr + 4 * 0, acc0);
-  vst1q_s32(out_ptr + 4 * 1, acc1);
-  vst1q_s32(out_ptr + 4 * 2, acc2);
-  vst1q_s32(out_ptr + 4 * 3, acc3);
-  vst1q_s32(out_ptr + 4 * 4, acc4);
-  vst1q_s32(out_ptr + 4 * 5, acc5);
-  vst1q_s32(out_ptr + 4 * 6, acc6);
-  vst1q_s32(out_ptr + 4 * 7, acc7);
-  vst1q_s32(out_ptr + 4 * 8, acc8);
-  vst1q_s32(out_ptr + 4 * 9, acc9);
-  vst1q_s32(out_ptr + 4 * 10, acc10);
-  vst1q_s32(out_ptr + 4 * 11, acc11);
-  vst1q_s32(out_ptr + 4 * 12, acc12);
-  vst1q_s32(out_ptr + 4 * 13, acc13);
-  vst1q_s32(out_ptr + 4 * 14, acc14);
-  vst1q_s32(out_ptr + 4 * 15, acc15);
+
+  for (int i = 0; i < 16; ++i) {
+    vst1q_s32(out_ptr + 4 * i, acc[i]);
+  }
 }
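
(Editorial note, not part of the patch: each vdotq_lane_s32 above
accumulates, into every s32 lane of acc, a 4-element i8 dot product between
that lane's 4 bytes of the rhs register and the 4 bytes of lhs selected by
the lane index; hence 8 lhs rows times 2 rhs registers = 16 accumulators. A
scalar model:)

#include <stdint.h>

// Scalar model of vdotq_lane_s32(acc, a, b, lane): 'b4' stands for the
// 4-byte group that 'lane' selects from the 64-bit b operand.
void sdot_lane_model(int32_t acc[4], const int8_t a[16], const int8_t b4[4]) {
  for (int i = 0; i < 4; ++i)
    for (int k = 0; k < 4; ++k)
      acc[i] += (int32_t)a[4 * i + k] * (int32_t)b4[k];
}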
51 changes: 18 additions & 33 deletions runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_fp16.c
@@ -14,47 +14,32 @@ void iree_uk_mmt4d_tile_f16f16f16_8x8x1_arm_64_fp16(
   float16_t* IREE_UK_RESTRICT out_ptr = out_tile;
   const float16_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
   const float16_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
-  float16x8_t acc0, acc1, acc2, acc3, acc4, acc5, acc6, acc7;
+  float16x8_t acc[8];
   if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
-    acc0 = vld1q_f16(out_ptr + 8 * 0);
-    acc1 = vld1q_f16(out_ptr + 8 * 1);
-    acc2 = vld1q_f16(out_ptr + 8 * 2);
-    acc3 = vld1q_f16(out_ptr + 8 * 3);
-    acc4 = vld1q_f16(out_ptr + 8 * 4);
-    acc5 = vld1q_f16(out_ptr + 8 * 5);
-    acc6 = vld1q_f16(out_ptr + 8 * 6);
-    acc7 = vld1q_f16(out_ptr + 8 * 7);
+    for (int i = 0; i < 8; ++i) {
+      acc[i] = vld1q_f16(out_ptr + 8 * i);
+    }
   } else {
-    acc0 = vdupq_n_f16(0);
-    acc1 = vdupq_n_f16(0);
-    acc2 = vdupq_n_f16(0);
-    acc3 = vdupq_n_f16(0);
-    acc4 = vdupq_n_f16(0);
-    acc5 = vdupq_n_f16(0);
-    acc6 = vdupq_n_f16(0);
-    acc7 = vdupq_n_f16(0);
+    for (int i = 0; i < 8; ++i) {
+      acc[i] = vdupq_n_f16(0);
+    }
   }
   IREE_UK_ASSUME(params->K >= 1);
   for (int k = 0; k < params->K; ++k) {
     float16x8_t lhs = vld1q_f16(lhs_ptr);
     lhs_ptr += 8;
     float16x8_t rhs = vld1q_f16(rhs_ptr);
     rhs_ptr += 8;
-    acc0 = vfmaq_lane_f16(acc0, rhs, vget_low_f16(lhs), 0);
-    acc1 = vfmaq_lane_f16(acc1, rhs, vget_low_f16(lhs), 1);
-    acc2 = vfmaq_lane_f16(acc2, rhs, vget_low_f16(lhs), 2);
-    acc3 = vfmaq_lane_f16(acc3, rhs, vget_low_f16(lhs), 3);
-    acc4 = vfmaq_lane_f16(acc4, rhs, vget_high_f16(lhs), 0);
-    acc5 = vfmaq_lane_f16(acc5, rhs, vget_high_f16(lhs), 1);
-    acc6 = vfmaq_lane_f16(acc6, rhs, vget_high_f16(lhs), 2);
-    acc7 = vfmaq_lane_f16(acc7, rhs, vget_high_f16(lhs), 3);
+    acc[0] = vfmaq_lane_f16(acc[0], rhs, vget_low_f16(lhs), 0);
+    acc[1] = vfmaq_lane_f16(acc[1], rhs, vget_low_f16(lhs), 1);
+    acc[2] = vfmaq_lane_f16(acc[2], rhs, vget_low_f16(lhs), 2);
+    acc[3] = vfmaq_lane_f16(acc[3], rhs, vget_low_f16(lhs), 3);
+    acc[4] = vfmaq_lane_f16(acc[4], rhs, vget_high_f16(lhs), 0);
+    acc[5] = vfmaq_lane_f16(acc[5], rhs, vget_high_f16(lhs), 1);
+    acc[6] = vfmaq_lane_f16(acc[6], rhs, vget_high_f16(lhs), 2);
+    acc[7] = vfmaq_lane_f16(acc[7], rhs, vget_high_f16(lhs), 3);
   }
-  vst1q_f16(out_ptr + 8 * 0, acc0);
-  vst1q_f16(out_ptr + 8 * 1, acc1);
-  vst1q_f16(out_ptr + 8 * 2, acc2);
-  vst1q_f16(out_ptr + 8 * 3, acc3);
-  vst1q_f16(out_ptr + 8 * 4, acc4);
-  vst1q_f16(out_ptr + 8 * 5, acc5);
-  vst1q_f16(out_ptr + 8 * 6, acc6);
-  vst1q_f16(out_ptr + 8 * 7, acc7);
+  for (int i = 0; i < 8; ++i) {
+    vst1q_f16(out_ptr + 8 * i, acc[i]);
+  }
 }