diff --git a/programming_examples/basic/matrix_multiplication/common.h b/programming_examples/basic/matrix_multiplication/common.h index 3fe03bd9ca..b7386f268c 100644 --- a/programming_examples/basic/matrix_multiplication/common.h +++ b/programming_examples/basic/matrix_multiplication/common.h @@ -14,8 +14,11 @@ #ifndef MATRIX_MULTIPLICATION_H #define MATRIX_MULTIPLICATION_H +#include #include #include +#include +#include namespace matmul_common { @@ -175,6 +178,16 @@ void matmul(int M, int N, int K, const std::vector A, } } +template +Tout mul_acc(int M, int N, int K, int row, int col, const std::vector A, + const std::vector B) { + Tout running_sum = 0; + for (int k = 0; k < K; k++) { + running_sum += Tout(A[row * K + k] * B[k * N + col]); + } + return running_sum; +} + // nearly_equal function adapted from Stack Overflow, License CC BY-SA 4.0 // Original author: P-Gn // Source: https://stackoverflow.com/a/32334103 @@ -227,7 +240,7 @@ void print_matrix(const std::vector matrix, int n_cols, ostream << std::fixed << std::setprecision(2); #define print_row(what) \ - for (int col = 0; col < n_printable_cols / 2; col++) { \ + for (int col = 0; col < (n_printable_cols + 1) / 2; col++) { \ ostream << std::right << std::setw(w) << (what); \ ostream << std::setw(0) << col_sep; \ } \ @@ -239,8 +252,8 @@ void print_matrix(const std::vector matrix, int n_cols, ostream << std::setw(0) << col_sep; \ } - for (int row = 0; row < n_printable_rows / 2; row++) { - print_row(matrix[row * n_rows + col]); + for (int row = 0; row < (n_printable_rows + 1) / 2; row++) { + print_row(matrix[row * n_cols + col]); ostream << std::endl; } if (elide_rows) { @@ -248,51 +261,130 @@ void print_matrix(const std::vector matrix, int n_cols, ostream << std::endl; } for (int row = n_printable_rows / 2 + 1; row < n_printable_rows; row++) { - print_row(matrix[row * n_rows + col]); + print_row(matrix[row * n_cols + col]); ostream << std::endl; } #undef print_row } +constexpr int max_printable_errors = 32; + +template +struct error { + int row; + int col; + Tout expected; + Tout actual; +}; + +template +std::optional> +verify_single(std::ostream &os, int row, int col, Tout expected, Tout actual) { + const float absTol = 0.5; + const float relTol = 0.15; + if (!nearly_equal(expected, actual, relTol, absTol)) { + return (struct error){row, col, expected, actual}; + } + return std::nullopt; +} + +template +void print_error_summary(std::ostream &os, int n_errors, + std::vector> &errors) { + for (struct error &err : errors) { + os << "[" << std::setw(5) << err.row << ", " << std::setw(5) << err.col + << "] " << std::setw(4) << std::setprecision(2) << std::fixed + << (float)err.actual << " =!= " << std::setw(4) << std::setprecision(2) + << std::fixed << (float)err.expected << std::endl; + } + if (n_errors > max_printable_errors) { + os << "...and " << std::setw(0) << n_errors - max_printable_errors + << " further errors." << std::endl; + } +} + +void print_progress_bar(std::ostream &os, double progress, int len = 75) { + os << "\r" << std::string((int)(progress * len), '|') + << std::string(len - (int)(progress * len), ' ') << std::setw(4) + << std::fixed << std::setprecision(0) << progress * 100 << "%" + << "\r"; +} + template int verify(int M, int N, int K, std::vector A, std::vector B, - std::vector C) { - int errors = 0; - int max_printable_errors = 10; - const float absTol = 0.5; - const float relTol = 0.5; + std::vector C, int verbosity = 0) { + int n_errors = 0; + std::vector> errors; std::vector CRef(M * N); matmul(M, N, K, A, B, CRef); for (int row = 0; row < M; row++) { for (int col = 0; col < N; col++) { - if (!nearly_equal(CRef[row * N + col], C[row * N + col], relTol, - absTol)) { - errors++; - if (errors < max_printable_errors) { - std::cout << "Error in row " << row << ", col " << col << ". " - << "Expected " << std::setw(4) << (float)CRef[row * N + col] - << ", got " << std::setw(4) << (float)C[row * N + col] - << "." << std::endl; + std::optional> error = verify_single( + std::cout, row, col, CRef[row * N + col], C[row * N + col]); + if (error.has_value()) { + if (n_errors < max_printable_errors) { + errors.push_back(*error); } + n_errors++; } } } + print_error_summary(std::cout, n_errors, errors); - if (errors >= max_printable_errors) { - std::cout << "...and " << std::setw(0) << errors << " further errors." - << std::endl; - } - if (errors > 0) { + if (n_errors > 0) { std::cout << std::endl << "Reference:" << std::endl; matmul_common::print_matrix(CRef, N); std::cout << std::endl << "Output:" << std::endl; matmul_common::print_matrix(C, N); } - return errors; + return n_errors; +} + +template +int verify_stochastic(int M, int N, int K, std::vector A, + std::vector B, std::vector C, int n_samples, + int verbosity = 0) { + std::mt19937 rng; + auto rows = std::views::iota(0, M); + auto cols = std::views::iota(0, N); + auto sampled_rows = std::vector(n_samples); + auto sampled_cols = std::vector(n_samples); + + std::ranges::sample(rows, sampled_rows.begin(), n_samples, rng); + std::ranges::sample(cols, sampled_cols.begin(), n_samples, rng); + + int n_errors = 0; + std::vector> errors; + double progress = 0; + for (std::tuple> cell : + std::views::enumerate(std::views::zip(sampled_rows, sampled_cols))) { + int i = std::get<0>(cell); + int row = std::get<0>(std::get<1>(cell)); + int col = std::get<1>(std::get<1>(cell)); + if (verbosity >= 1 && + (int)(progress * 100) < (int)((double)i / n_samples * 100)) { + // Only print progress bar if percentage changed + progress = (double)i / n_samples; + print_progress_bar(std::cerr, progress); + } + Tout ref = mul_acc(M, N, K, row, col, A, B); + std::optional> error = + verify_single(std::cout, row, col, ref, C[row * N + col]); + if (error.has_value()) { + if (n_errors < max_printable_errors) { + errors.push_back(*error); + } + n_errors++; + } + } + std::cout << std::endl; + + print_error_summary(std::cout, n_errors, errors); + return n_errors; } // -------------------------------------------------------------------------- diff --git a/programming_examples/basic/matrix_multiplication/matrix_vector/aie2.py b/programming_examples/basic/matrix_multiplication/matrix_vector/aie2.py index 0f7938c6a5..6b27d9f9e3 100644 --- a/programming_examples/basic/matrix_multiplication/matrix_vector/aie2.py +++ b/programming_examples/basic/matrix_multiplication/matrix_vector/aie2.py @@ -103,9 +103,10 @@ def device_body(): 2, memRef_A_ty, [ + (k // 2 // 2, 2), (m, k), - (k, 1), - ], + (2, 1), + ], # transpose at 4-byte (2xbf16) granularity ) object_fifo_link( memA_fifos[memA_fifo_names[i]], inA_fifos[inA_fifo_names[i]]