Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Matrix Multiplication Improvements #1551

Merged
merged 20 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
154581b
[matmul] work around object fifo bug; accumulate in float for verific…
andrej Jun 11, 2024
08e404a
[matmul] allow modifiable tile size for whole_array
andrej Jun 12, 2024
74c59e2
[matmul] simplify verification
andrej Jun 12, 2024
5af0e9d
[matmul] simplify whole_array strides to element size after PR #1538
andrej Jun 12, 2024
cf43feb
[matmul] offsets seem to still be in bytes; fix
andrej Jun 12, 2024
2529d27
[matmul] fix matrix printing error
andrej Jun 12, 2024
d76ce27
[matmul] allow single_core overall and tile size to be adjusted
andrej Jun 12, 2024
9f86e90
[matmul] fix typo in makefile-common
andrej Jun 12, 2024
20349d9
[matmul] restore fifo depth to two for whole_array
andrej Jun 12, 2024
6544134
[matmul] express offsets in bytes in whole_array design; reformat
andrej Jun 12, 2024
c697459
[matmul] reduce verification tolerance to 5% relative, 0.5 absolute
andrej Jun 12, 2024
07c029c
[matmul] format
andrej Jun 12, 2024
1dc7a6a
[matmul] fix CI test errors
andrej Jun 13, 2024
ed9905a
Merge branch 'main' into fix-matmul-2
jgmelber Jun 14, 2024
762d866
[matmul] fix matrix printing rounding error
andrej Jun 14, 2024
cfe3bbf
[matvec] use integers to avoid float errors; swap in scalar kernel fo…
andrej Jun 14, 2024
c254b41
Merge branch 'main' into fix-matmul-2
andrej Jun 14, 2024
df84e5b
[matmul] add missing includes to common.h to make it work standalone
andrej Jun 23, 2024
8093cf0
Merge branch 'main' into fix-matmul-2
andrej Jun 23, 2024
828b13d
format
andrej Jun 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions aie_kernels/aie2/mm.cc
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@

#include "zero.cc"

template <typename T_in, typename T_out, int M, int K, int N>
template <typename T_in, typename T_out, int rowA, int colA, int colB>
void matmul_scalar(T_in *a, T_in *b, T_out *c) {
event0();
for (int row = 0; row < M; row++) {
for (int col = 0; col < N; col++) {
for (int row = 0; row < rowA; row++) {
for (int col = 0; col < colB; col++) {
T_out running_sum = 0;
for (int i = 0; i < K; i++) {
running_sum += a[row * K + i] * b[i * N + col];
for (int i = 0; i < colA; i++) {
running_sum += a[row * colA + i] * b[i * colB + col];
}
c[row * N + col] += running_sum;
c[row * colB + col] += running_sum;
}
}
event1();
Expand Down Expand Up @@ -397,6 +397,23 @@ void matmul_vectorized_4x8x4_bf16_f32(const bfloat16 *__restrict pA,

extern "C" {

// If you want to compile microkernels with different inner tile sizes,
// define DIM_M, DIM_K and DIM_N at compile time using -DDIM_M=32 etc.
// These dimensions must be divisible by the r, s, t dimensions used in
// the kernels.

#ifndef DIM_M
#define DIM_M 64
#endif

#ifndef DIM_K
#define DIM_K 64
#endif

#ifndef DIM_N
#define DIM_N 64
#endif

#define combos(X) \
X(int16, i16, int16, i16, 4, 4, 4) \
X(bfloat16, bf16, bfloat16, bf16, 4, 8, 4) \
Expand All @@ -407,26 +424,27 @@ extern "C" {
void matmul_##mlir_type_in##_##mlir_type_out(ctype_in *a_in, ctype_in *b_in, \
ctype_out *c_out) { \
matmul_vectorized_##r##x##s##x##t##_##mlir_type_in##_##mlir_type_out< \
64, 64, 64>(a_in, b_in, c_out); \
DIM_M, DIM_K, DIM_N>(a_in, b_in, c_out); \
}

#define matmul_scalar_c_func(ctype_in, mlir_type_in, ctype_out, mlir_type_out, \
r, s, t) \
void matmul_scalar_##mlir_type_in##_##mlir_type_out( \
ctype_in *a_in, ctype_in *b_in, ctype_out *c_out) { \
matmul_scalar<ctype_in, ctype_out, 64, 32, 64>(a_in, b_in, c_out); \
matmul_scalar<ctype_in, ctype_out, DIM_M, DIM_K, DIM_N>(a_in, b_in, \
c_out); \
}

#define zero_vectorized_c_func(ctype_in, mlir_type_in, ctype_out, \
mlir_type_out, r, s, t) \
void zero_##mlir_type_out(ctype_out *c_out) { \
zero_vectorized<ctype_out, 64, 64, 32>(c_out); \
zero_vectorized<ctype_out, DIM_M, DIM_N, 32>(c_out); \
}

#define zero_scalar_c_func(ctype_in, mlir_type_in, ctype_out, mlir_type_out, \
r, s, t) \
void zero_scalar_##mlir_type_out(ctype_out *c_out) { \
zero_scalar<ctype_out, 64, 64>(c_out); \
zero_scalar<ctype_out, DIM_M, DIM_N>(c_out); \
}

combos(matmul_vectorized_c_func) combos(matmul_scalar_c_func)
Expand Down
28 changes: 21 additions & 7 deletions aie_kernels/aie2/mv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ void matvec_vectorized(T_in *__restrict a, T_in *__restrict b,
static_assert(s == 8); // s is fixed to 8 because that is the number of
// column vectors (a_vec_0_0..a_vec_3_1) we create
static_assert(k % s == 0);
static_assert(std::is_same<T_in, bfloat16>::value);
static_assert(std::is_same<T_in, bfloat16>::value ||
std::is_same<T_in, int16_t>::value);

// This kernel expects a "32-bit word transposed matrix", i.e. the result
// of transposing the row-major representation of the matrix at a
Expand Down Expand Up @@ -133,35 +134,48 @@ void matvec_vectorized(T_in *__restrict a, T_in *__restrict b,

extern "C" {

// If you want to compile microkernels with different inner tile sizes,
// define DIM_M and DIM_K at compile time using -DDIM_M=16 etc.
// These dimensions must be divisible by the r, s dimensions used in
// the kernels.

#ifndef DIM_M
#define DIM_M 32
#endif

#ifndef DIM_K
#define DIM_K 32
#endif

#define combos(X) \
X(bfloat16, bf16, float, f32, accfloat) \
// X(int16, i16, int16, i16, acc32) \
/* X(bfloat16, bf16, float, f32, accfloat) */ \
X(int16, i16, int32, i32, acc32)

#define matvec_scalar_c_func(ctype_in, mlir_type_in, ctype_out, mlir_type_out, \
ctype_acc) \
void matvec_scalar_##mlir_type_in##_##mlir_type_out( \
ctype_in *a_in, ctype_in *b_in, ctype_out *c_out) { \
matvec_scalar<ctype_in, ctype_out, 32, 32>(a_in, b_in, c_out); \
matvec_scalar<ctype_in, ctype_out, DIM_M, DIM_K>(a_in, b_in, c_out); \
}

#define matvec_vectorized_c_func(ctype_in, mlir_type_in, ctype_out, \
mlir_type_out, ctype_acc) \
void matvec_vectorized_##mlir_type_in##_##mlir_type_out( \
ctype_in *a_in, ctype_in *b_in, ctype_out *c_out) { \
matvec_vectorized<ctype_in, ctype_out, ctype_acc, 32, 32, 16, 8>( \
matvec_vectorized<ctype_in, ctype_out, ctype_acc, DIM_M, DIM_K, 16, 8>( \
a_in, b_in, c_out); \
}

#define zero_vectorized_c_func(ctype_in, mlir_type_in, ctype_out, \
mlir_type_out, ctype_acc) \
void zero_vectorized_##mlir_type_out(ctype_out *c_out) { \
zero_vectorized<ctype_out, 32, 1, 32>(c_out); \
zero_vectorized<ctype_out, DIM_M, 1, 32>(c_out); \
}

#define zero_scalar_c_func(ctype_in, mlir_type_in, ctype_out, mlir_type_out, \
ctype_acc) \
void zero_scalar_##mlir_type_out(ctype_out *c_out) { \
zero_scalar<ctype_out, 32, 1>(c_out); \
zero_scalar<ctype_out, DIM_M, 1>(c_out); \
}

combos(matvec_scalar_c_func) combos(matvec_vectorized_c_func)
Expand Down
112 changes: 44 additions & 68 deletions programming_examples/basic/matrix_multiplication/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
#define MATRIX_MULTIPLICATION_H

#include <algorithm>
#include <bits/stdc++.h>
#include <boost/program_options.hpp>
#include <cmath>
#include <fstream>
#include <optional>
#include <ostream>
#include <stdfloat>

namespace matmul_common {

Expand Down Expand Up @@ -116,76 +119,28 @@ static inline std::bfloat16_t random_bfloat16_t() {
return std::bfloat16_t(4.0 * (float)rand() / (float)(RAND_MAX));
}

template <typename Tin, typename Tout>
void matmul_naive(int M, int N, int K, const std::vector<Tin> A,
const std::vector<Tin> B, std::vector<Tout> &C) {
for (int row = 0; row < M; row++) {
for (int col = 0; col < N; col++) {
Tout running_sum = 0;
for (int k = 0; k < K; k++) {
running_sum += Tout(A[row * K + k] * B[k * N + col]);
}
C[row * N + col] = Tout(running_sum);
}
}
}

template <typename Tin, typename Tout>
template <typename Tin, typename Tout, typename Tacc>
void matmul(int M, int N, int K, const std::vector<Tin> A,
const std::vector<Tin> B, std::vector<Tout> &C) {
// A is an MxK matrix
// B is a KxN matrix
// C is the MxN output matrix, assumed to be zeroed out

constexpr int K_block_size = 64;
const int n_K_blocks = K / K_block_size;

const Tin *B_origin = B.data(); /* Avoid a calls to B.data() within the loop
with this const variable. B does not get
resized, so the pointer remains valid. */

const Tin *A_base = A.data(); /* Points to start of current row of A,
monotonically increasing by K. */
const Tin *B_base = B_origin; /* Points to start of current column of B;
increases by 1 in each inner loop, resets
to B_origin (0) at the start of a new row
(outer loop). */

const Tin *A_ptr = A_base;
const Tin *B_ptr = B_base;
Tout *C_ptr = C.data(); /* Monotonically increasing by 1. */

for (int row = 0; row < M; row++) {
for (int col = 0; col < N; col++) {
A_ptr = A_base;
B_ptr = B_base;
Tout running_sum = 0;
for (int k = 0; k < n_K_blocks; k++) {
for (int i = 0; i < K_block_size; i++) {
running_sum += Tout(*A_ptr) * Tout(*B_ptr);
A_ptr += 1; // Advance to right neighbor; next value in this row
B_ptr += N; // Advance to bottom neighbor; next value in this column
}
Tacc running_sum = 0;
for (int k = 0; k < K; k++) {
running_sum += Tacc(A[row * K + k] * B[k * N + col]);
}
*C_ptr = Tout(running_sum);
C_ptr += 1;
B_base += 1; /* Next iteration: same row of A (A_base unchanged),
next column of B (B_base increases by 1) */
C[row * N + col] = Tout(running_sum);
}
A_base += K; // Advance to next row of A
B_base = B_origin; /* Next row of A means we need to restart at the first
column of B. */
}
}

template <typename Tin, typename Tout>
template <typename Tin, typename Tout, typename Tacc>
Tout mul_acc(int M, int N, int K, int row, int col, const std::vector<Tin> A,
const std::vector<Tin> B) {
Tout running_sum = 0;
Tacc running_sum = 0;
for (int k = 0; k < K; k++) {
running_sum += Tout(A[row * K + k] * B[k * N + col]);
running_sum += Tacc(A[row * K + k] * B[k * N + col]);
}
return running_sum;
return (Tout)running_sum;
}

// nearly_equal function adapted from Stack Overflow, License CC BY-SA 4.0
Expand Down Expand Up @@ -219,7 +174,7 @@ void print_matrix(const std::vector<T> matrix, int n_cols,
assert(matrix.size() % n_cols == 0);

auto maxima = std::minmax_element(matrix.begin(), matrix.end());
T max_val = std::max(*maxima.first, std::abs(*maxima.second));
T max_val = std::max(*maxima.first, (T)std::abs(*maxima.second));
size_t n_digits = log10(max_val);
if (w == -1) {
w = n_digits;
Expand Down Expand Up @@ -247,7 +202,8 @@ void print_matrix(const std::vector<T> matrix, int n_cols,
if (elide_cols) { \
ostream << std::setw(0) << elide_sym; \
} \
for (int col = n_printable_cols / 2 + 1; col < n_printable_cols; col++) { \
for (int i = 0; i < n_printable_cols / 2; i++) { \
int col = n_cols - n_printable_cols / 2 + i; \
ostream << std::right << std::setw(w) << (what); \
ostream << std::setw(0) << col_sep; \
}
Expand All @@ -260,7 +216,8 @@ void print_matrix(const std::vector<T> matrix, int n_cols,
print_row(elide_sym);
ostream << std::endl;
}
for (int row = n_printable_rows / 2 + 1; row < n_printable_rows; row++) {
for (int i = 0; i < n_printable_rows / 2; i++) {
int row = n_rows - n_printable_rows / 2 + i;
print_row(matrix[row * n_cols + col]);
ostream << std::endl;
}
Expand All @@ -282,7 +239,7 @@ template <typename Tout>
std::optional<struct error<Tout>>
verify_single(std::ostream &os, int row, int col, Tout expected, Tout actual) {
const float absTol = 0.5;
const float relTol = 0.15;
const float relTol = 0.05;
if (!nearly_equal(expected, actual, relTol, absTol)) {
return (struct error<Tout>){row, col, expected, actual};
}
Expand All @@ -291,7 +248,8 @@ verify_single(std::ostream &os, int row, int col, Tout expected, Tout actual) {

template <typename Tout>
void print_error_summary(std::ostream &os, int n_errors,
std::vector<struct error<Tout>> &errors) {
std::vector<struct error<Tout>> &errors,
Tout max_rel_error) {
for (struct error<Tout> &err : errors) {
os << "[" << std::setw(5) << err.row << ", " << std::setw(5) << err.col
<< "] " << std::setw(4) << std::setprecision(2) << std::fixed
Expand All @@ -302,6 +260,10 @@ void print_error_summary(std::ostream &os, int n_errors,
os << "...and " << std::setw(0) << n_errors - max_printable_errors
<< " further errors." << std::endl;
}
if (n_errors > 0) {
os << "Maximum relative error: " << std::setw(3) << std::setprecision(0)
<< max_rel_error * 100 << "%" << std::endl;
}
}

void print_progress_bar(std::ostream &os, double progress, int len = 75) {
Expand All @@ -311,14 +273,15 @@ void print_progress_bar(std::ostream &os, double progress, int len = 75) {
<< "\r";
}

template <typename Tin, typename Tout>
template <typename Tin, typename Tout, typename Tacc>
int verify(int M, int N, int K, std::vector<Tin> A, std::vector<Tin> B,
std::vector<Tout> C, int verbosity = 0) {
int n_errors = 0;
std::vector<struct error<Tout>> errors;
Tout max_rel_error = (Tout)0.0f;

std::vector<Tout> CRef(M * N);
matmul(M, N, K, A, B, CRef);
matmul<Tin, Tout, Tacc>(M, N, K, A, B, CRef);

for (int row = 0; row < M; row++) {
for (int col = 0; col < N; col++) {
Expand All @@ -328,11 +291,17 @@ int verify(int M, int N, int K, std::vector<Tin> A, std::vector<Tin> B,
if (n_errors < max_printable_errors) {
errors.push_back(*error);
}
Tout rel_error =
std::abs(error->actual - error->expected) /
std::max(std::abs(error->actual), std::abs(error->expected));
if (rel_error > max_rel_error) {
max_rel_error = rel_error;
}
n_errors++;
}
}
}
print_error_summary(std::cout, n_errors, errors);
print_error_summary(std::cout, n_errors, errors, max_rel_error);

if (n_errors > 0) {
std::cout << std::endl << "Reference:" << std::endl;
Expand All @@ -344,7 +313,7 @@ int verify(int M, int N, int K, std::vector<Tin> A, std::vector<Tin> B,
return n_errors;
}

template <typename Tin, typename Tout>
template <typename Tin, typename Tout, typename Tacc>
int verify_stochastic(int M, int N, int K, std::vector<Tin> A,
std::vector<Tin> B, std::vector<Tout> C, int n_samples,
int verbosity = 0) {
Expand All @@ -359,6 +328,7 @@ int verify_stochastic(int M, int N, int K, std::vector<Tin> A,

int n_errors = 0;
std::vector<struct error<Tout>> errors;
Tout max_rel_error = (Tout)0.0f;
double progress = 0;
for (std::tuple<size_t, std::tuple<int &, int &>> cell :
std::views::enumerate(std::views::zip(sampled_rows, sampled_cols))) {
Expand All @@ -371,19 +341,25 @@ int verify_stochastic(int M, int N, int K, std::vector<Tin> A,
progress = (double)i / n_samples;
print_progress_bar(std::cerr, progress);
}
Tout ref = mul_acc<Tin, Tout>(M, N, K, row, col, A, B);
Tout ref = mul_acc<Tin, Tout, Tacc>(M, N, K, row, col, A, B);
std::optional<struct error<Tout>> error =
verify_single(std::cout, row, col, ref, C[row * N + col]);
if (error.has_value()) {
if (n_errors < max_printable_errors) {
errors.push_back(*error);
}
Tout rel_error =
std::abs(error->actual - error->expected) /
std::max(std::abs(error->actual), std::abs(error->expected));
if (rel_error > max_rel_error) {
max_rel_error = rel_error;
}
n_errors++;
}
}
std::cout << std::endl;

print_error_summary(std::cout, n_errors, errors);
print_error_summary(std::cout, n_errors, errors, max_rel_error);
return n_errors;
}

Expand Down
Loading
Loading