Skip to content

Commit

Permalink
examples/matmul: fix host matrix printing & verification code (#1480)
Browse files Browse the repository at this point in the history
Co-authored-by: Joseph Melber <jgmelber@gmail.com>
  • Loading branch information
andrej and jgmelber authored May 29, 2024
1 parent 250f117 commit 61ef658
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 25 deletions.
138 changes: 115 additions & 23 deletions programming_examples/basic/matrix_multiplication/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
#ifndef MATRIX_MULTIPLICATION_H
#define MATRIX_MULTIPLICATION_H

#include <algorithm>
#include <boost/program_options.hpp>
#include <cmath>
#include <optional>
#include <ostream>
#include <random>

namespace matmul_common {

Expand Down Expand Up @@ -175,6 +178,16 @@ void matmul(int M, int N, int K, const std::vector<Tin> A,
}
}

// Computes a single element of the product A x B: the dot product of row
// `row` of A with column `col` of B.
//
// A is indexed as A[row * K + k] and B as B[k * N + col]; M is accepted for
// signature symmetry but is not read. Each product is formed in the input
// type (after the usual arithmetic promotions) and only then converted to
// Tout, where it is accumulated.
//
// The matrices are taken by const reference: this function is called once
// per checked cell, and taking them by value copied both matrices each time.
template <typename Tin, typename Tout>
Tout mul_acc(int M, int N, int K, int row, int col, const std::vector<Tin> &A,
             const std::vector<Tin> &B) {
  (void)M; // unused
  Tout running_sum = 0;
  for (int k = 0; k < K; k++) {
    running_sum += Tout(A[row * K + k] * B[k * N + col]);
  }
  return running_sum;
}

// nearly_equal function adapted from Stack Overflow, License CC BY-SA 4.0
// Original author: P-Gn
// Source: https://stackoverflow.com/a/32334103
Expand Down Expand Up @@ -227,7 +240,7 @@ void print_matrix(const std::vector<T> matrix, int n_cols,
ostream << std::fixed << std::setprecision(2);

#define print_row(what) \
for (int col = 0; col < n_printable_cols / 2; col++) { \
for (int col = 0; col < (n_printable_cols + 1) / 2; col++) { \
ostream << std::right << std::setw(w) << (what); \
ostream << std::setw(0) << col_sep; \
} \
Expand All @@ -239,60 +252,139 @@ void print_matrix(const std::vector<T> matrix, int n_cols,
ostream << std::setw(0) << col_sep; \
}

for (int row = 0; row < n_printable_rows / 2; row++) {
print_row(matrix[row * n_rows + col]);
for (int row = 0; row < (n_printable_rows + 1) / 2; row++) {
print_row(matrix[row * n_cols + col]);
ostream << std::endl;
}
if (elide_rows) {
print_row(elide_sym);
ostream << std::endl;
}
for (int row = n_printable_rows / 2 + 1; row < n_printable_rows; row++) {
print_row(matrix[row * n_rows + col]);
print_row(matrix[row * n_cols + col]);
ostream << std::endl;
}

#undef print_row
}

constexpr int max_printable_errors = 32;

// One mismatching output element: where it is and which two values
// disagreed. Filled in by verify_single() in the order
// {row, col, expected, actual} and printed by print_error_summary().
template <typename Tout>
struct error {
int row; // row index of the mismatching element
int col; // column index of the mismatching element
Tout expected; // reference value
Tout actual; // value that was being checked
};

template <typename Tout>
std::optional<struct error<Tout>>
verify_single(std::ostream &os, int row, int col, Tout expected, Tout actual) {
const float absTol = 0.5;
const float relTol = 0.15;
if (!nearly_equal(expected, actual, relTol, absTol)) {
return (struct error<Tout>){row, col, expected, actual};
}
return std::nullopt;
}

// Writes one "[row, col] actual =!= expected" line to `os` for every entry
// in `errors`, with both values shown as fixed-point floats with two
// decimals. When `n_errors` exceeds max_printable_errors, a trailing line
// reports how many errors are beyond that limit.
template <typename Tout>
void print_error_summary(std::ostream &os, int n_errors,
                         std::vector<struct error<Tout>> &errors) {
  for (auto it = errors.begin(); it != errors.end(); ++it) {
    const float got = static_cast<float>(it->actual);
    const float want = static_cast<float>(it->expected);
    os << "[" << std::setw(5) << it->row << ", " << std::setw(5) << it->col
       << "] ";
    os << std::setw(4) << std::setprecision(2) << std::fixed << got
       << " =!= ";
    os << std::setw(4) << std::setprecision(2) << std::fixed << want
       << std::endl;
  }
  if (n_errors > max_printable_errors) {
    const int n_omitted = n_errors - max_printable_errors;
    os << "...and " << std::setw(0) << n_omitted << " further errors."
       << std::endl;
  }
}

// Draws a single-line text progress bar on `os`.
//
// The line starts and ends with '\r' so the next call overwrites it. It
// consists of `len` cells ('|' for the completed fraction, ' ' for the rest)
// followed by the percentage, right-aligned in four columns, and '%'.
//
// progress: completed fraction; values outside [0, 1] (including NaN) are
//           clamped so the cell counts can never become negative.
// len:      bar width in characters; negative values are treated as 0.
//
// `inline` because this non-template function is defined in a header and
// would otherwise be defined once per including translation unit.
inline void print_progress_bar(std::ostream &os, double progress,
                               int len = 75) {
  if (!(progress >= 0.0)) {
    progress = 0.0;
  } else if (progress > 1.0) {
    progress = 1.0;
  }
  if (len < 0) {
    len = 0;
  }
  const int n_filled = static_cast<int>(progress * len);
  os << "\r" << std::string(n_filled, '|')
     << std::string(len - n_filled, ' ') << std::setw(4) << std::fixed
     << std::setprecision(0) << progress * 100 << "%"
     << "\r";
}

// NOTE(review): this span looks like a rendered diff of verify() in which the
// removed (pre-commit) and added (post-commit) lines are interleaved and the
// +/- markers have been lost, so it does not compile as shown. The notes
// below mark which lines presumably belong to which version -- TODO confirm
// against the actual commit.
//
// Reading only the presumably-new lines: a reference result CRef is computed
// with matmul(), every element of C is compared to it via verify_single(),
// up to max_printable_errors mismatches are stored, print_error_summary() is
// called, both matrices are printed when at least one mismatch was found,
// and the mismatch count n_errors is returned.
template <typename Tin, typename Tout>
int verify(int M, int N, int K, std::vector<Tin> A, std::vector<Tin> B,
// presumably old: signature tail and locals (counter named `errors`, local
// print limit, tolerances passed straight to nearly_equal)
std::vector<Tout> C) {
int errors = 0;
int max_printable_errors = 10;
const float absTol = 0.5;
const float relTol = 0.5;
// presumably new: signature tail with `verbosity`; `errors` is now the list
// of stored mismatches and `n_errors` the counter
std::vector<Tout> C, int verbosity = 0) {
int n_errors = 0;
std::vector<struct error<Tout>> errors;

std::vector<Tout> CRef(M * N);
matmul(M, N, K, A, B, CRef);

for (int row = 0; row < M; row++) {
for (int col = 0; col < N; col++) {
// presumably old: inline comparison that prints each mismatch immediately
if (!nearly_equal(CRef[row * N + col], C[row * N + col], relTol,
absTol)) {
errors++;
if (errors < max_printable_errors) {
std::cout << "Error in row " << row << ", col " << col << ". "
<< "Expected " << std::setw(4) << (float)CRef[row * N + col]
<< ", got " << std::setw(4) << (float)C[row * N + col]
<< "." << std::endl;
// presumably new: comparison delegated to verify_single(); mismatches are
// stored while fewer than max_printable_errors have been seen, and counted
// regardless
std::optional<struct error<Tout>> error = verify_single(
std::cout, row, col, CRef[row * N + col], C[row * N + col]);
if (error.has_value()) {
if (n_errors < max_printable_errors) {
errors.push_back(*error);
}
n_errors++;
}
}
}
// presumably new: stored mismatches are printed after the scan
print_error_summary(std::cout, n_errors, errors);

// presumably old: overflow message and `errors`-based condition
if (errors >= max_printable_errors) {
std::cout << "...and " << std::setw(0) << errors << " further errors."
<< std::endl;
}
if (errors > 0) {
// presumably new: condition on `n_errors`
if (n_errors > 0) {
std::cout << std::endl << "Reference:" << std::endl;
matmul_common::print_matrix(CRef, N);
std::cout << std::endl << "Output:" << std::endl;
matmul_common::print_matrix(C, N);
}

// presumably old return, followed by the presumably new one
return errors;
return n_errors;
}

template <typename Tin, typename Tout>
int verify_stochastic(int M, int N, int K, std::vector<Tin> A,
std::vector<Tin> B, std::vector<Tout> C, int n_samples,
int verbosity = 0) {
std::mt19937 rng;
auto rows = std::views::iota(0, M);
auto cols = std::views::iota(0, N);
auto sampled_rows = std::vector<int>(n_samples);
auto sampled_cols = std::vector<int>(n_samples);

std::ranges::sample(rows, sampled_rows.begin(), n_samples, rng);
std::ranges::sample(cols, sampled_cols.begin(), n_samples, rng);

int n_errors = 0;
std::vector<struct error<Tout>> errors;
double progress = 0;
for (std::tuple<size_t, std::tuple<int &, int &>> cell :
std::views::enumerate(std::views::zip(sampled_rows, sampled_cols))) {
int i = std::get<0>(cell);
int row = std::get<0>(std::get<1>(cell));
int col = std::get<1>(std::get<1>(cell));
if (verbosity >= 1 &&
(int)(progress * 100) < (int)((double)i / n_samples * 100)) {
// Only print progress bar if percentage changed
progress = (double)i / n_samples;
print_progress_bar(std::cerr, progress);
}
Tout ref = mul_acc<Tin, Tout>(M, N, K, row, col, A, B);
std::optional<struct error<Tout>> error =
verify_single(std::cout, row, col, ref, C[row * N + col]);
if (error.has_value()) {
if (n_errors < max_printable_errors) {
errors.push_back(*error);
}
n_errors++;
}
}
std::cout << std::endl;

print_error_summary(std::cout, n_errors, errors);
return n_errors;
}

// --------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,10 @@ def device_body():
2,
memRef_A_ty,
[
(k // 2 // 2, 2),
(m, k),
(k, 1),
],
(2, 1),
], # transpose at 4-byte (2xbf16) granularity
)
object_fifo_link(
memA_fifos[memA_fifo_names[i]], inA_fifos[inA_fifo_names[i]]
Expand Down

0 comments on commit 61ef658

Please sign in to comment.