Skip to content

Commit

Permalink
[AMD-AIE] Add zero fill ukernel for iree-amd-aie (#1479)
Browse files Browse the repository at this point in the history
Signed-off-by: Abhishek Varma <abhvarma@amd.com>
  • Loading branch information
Abhishek-Varma authored May 15, 2024
1 parent b085480 commit 896cebb
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
10 changes: 9 additions & 1 deletion aie_kernels/mm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

#include <aie_api/aie.hpp>

#include "zero.cc"

template <typename T_in, typename T_out, unsigned rowA, unsigned colA,
unsigned colB, unsigned r, unsigned s, unsigned t>
void matmul_vectorized(const T_in *__restrict pA, unsigned offsetA,
Expand Down Expand Up @@ -273,6 +275,12 @@ extern "C" {
64, 64, 64>(a_in, offsetA, b_in, offsetB, c_out, offsetC); \
}

combos(matmul_vectorized_c_func)
#define zero_vectorized_c_func(ctype_in, mlir_type_in, ctype_out, \
mlir_type_out, r, s, t) \
void zero_##mlir_type_out(ctype_out *c_out, unsigned offsetC) { \
zero_vectorized<ctype_out, 64, 64, 32>(c_out, offsetC); \
}

combos(matmul_vectorized_c_func) combos(zero_vectorized_c_func)

} // extern "C"
34 changes: 34 additions & 0 deletions aie_kernels/zero.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//===- zero.cc --------------------------------------------000---*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Copyright (C) 2024, Advanced Micro Devices, Inc.
//
//===----------------------------------------------------------------------===//

#ifndef ZERO_CC
#define ZERO_CC

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <type_traits>

template <typename T, int M, int N, int r>
void zero_vectorized(T *__restrict pC, unsigned offsetC) {
const aie::vector<T, r> zeros = aie::zeros<T, r>();
T *__restrict pC1 = pC + offsetC;
const T *__restrict c_end = pC1 + M * N;
for (; pC1 + r < c_end; pC1 += r) {
aie::store_v(pC1, zeros);
}
// Do a scalar write for any remainder not divisible by vector instruction
// size r
for (; pC1 < c_end; pC1++) {
*pC1 = 0;
}
}

#endif

0 comments on commit 896cebb

Please sign in to comment.