From 9e3785b2557db892eda7a72051ff7f405cffb35a Mon Sep 17 00:00:00 2001 From: Hashem Hashemi Date: Mon, 17 Jun 2024 15:31:02 -0500 Subject: [PATCH] adds wvSpltK optimization for skinny gemm. --- csrc/custom/custom.cu | 15 + csrc/custom/custom_kernels.cu | 1436 ++++++++++++++++++++++ vllm/model_executor/layers/tuned_gemm.py | 14 +- 3 files changed, 1461 insertions(+), 4 deletions(-) diff --git a/csrc/custom/custom.cu b/csrc/custom/custom.cu index 3da25ece3e87c..5c8beed37b304 100644 --- a/csrc/custom/custom.cu +++ b/csrc/custom/custom.cu @@ -39,6 +39,20 @@ void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, at::cuda::getCurrentCUDAStream(), rows_per_block); } +void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M, const int K, const int N, + cudaStream_t stream, const int CuCount); + +void wvSpltK(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, + const int N_in, const int CuCount) { + int M = in_a.size(0); + int K = in_a.size(1); + int N = N_in; + wvSpltK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N, + at::cuda::getCurrentCUDAStream(), CuCount); +} + + + void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, cudaStream_t stream, const int solidx); @@ -90,5 +104,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("LLZZ", &LLZZ); m.def("paged_attention_custom", &paged_attention_custom, "PagedAttention LL4Mi Custom."); + m.def("wvSpltK", &wvSpltK); // m.def("MMCustomGPU", &MMCustomGPU); } diff --git a/csrc/custom/custom_kernels.cu b/csrc/custom/custom_kernels.cu index afecf82eb3d77..09c28d1a46eca 100644 --- a/csrc/custom/custom_kernels.cu +++ b/csrc/custom/custom_kernels.cu @@ -2,6 +2,9 @@ #include #include #include +#include +#include "hsa/hsa.h" +#include "hsa/hsa_ext_amd.h" constexpr int WARP_SIZE = 64; @@ -309,3 +312,1436 @@ void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, if (cudaSuccess != err) throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); } + +///////////////////////////////////////////// + +using half8 = __attribute__((__vector_size__(4 * sizeof(float)))) float; + +/*template +__device__ __forceinline__ T loadnt(T* addr) { + return __builtin_nontemporal_load(addr); + //return *((T*)addr); +}*/ + +#define THRDS 64 +#define YTILE 2 +#define WvPrGrp 16 +#define A_CHUNK 8 +#define UNRL 2 +#define M 1 +#define DTYPE half + +__global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + __int128_t b128; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! 
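+ // Note: commitColumn[] masks off duplicate writes when the last tile is
+ // shifted back to start at column N - YTILE (see below). For example, if
+ // N were 5 with YTILE = 2, the wave assigned n = 4 is moved to n = 3 and
+ // commits only column 4, while the wave that owns n = 2 commits column 3.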
+ //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) + { + commitColumn[i] = 1; + } + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if(n < N && (n + YTILE) >= N) + { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) + { + commitColumn[i] = 0; + } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32*1024); k += THRDS * WvPrGrp * A_CHUNK) + { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + //uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for bank-conflict-free readback + + if (k_in >= min(K * M, 32*1024)) break; + + ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of availble columns + //---------------------------------------------------- + while (n < N) + { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // splitted across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m=0; m= 2) + bigType bigB1[UNRL]; +#endif +#if (YTILE >= 3) + bigType bigB2[UNRL]; +#endif +#if (YTILE >= 4) + bigType bigB3[UNRL]; +#endif +#if (YTILE >= 5) + bigType bigB4[UNRL]; +#endif +#if (YTILE >= 6) + bigType bigB5[UNRL]; +#endif +#if (YTILE >= 7) + bigType bigB6[UNRL]; +#endif +#if (YTILE >= 8) + bigType bigB7[UNRL]; +#endif +#if (YTILE >= 9) + bigType bigB8[UNRL]; +#endif +#if (YTILE >= 10) + bigType bigB9[UNRL]; +#endif +#if (YTILE >= 11) + bigType bigB10[UNRL]; +#endif + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! 
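+ // - "Interleaved" here means consecutive lanes read consecutive 16-byte
+ // chunks of the same column of B, so the loads below coalesce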
+ // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) + { + // Fetch the weight matrix from memory! +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + //if (k_ >= K) break; + //bool skip = (k_ >= K); + //bool dummy = (k_ >= K); + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- +#if (YTILE >= 2) + //if (n+1>=N) continue; + bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); +#endif +#if (YTILE >= 3) + //if (n+2>=N) continue; + bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); +#endif +#if (YTILE >= 4) + //if (n+3>=N) continue; + bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); +#endif +#if (YTILE >= 5) + //if (n+4>=N) continue; + bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); +#endif +#if (YTILE >= 6) + //if (n+5>=N) continue; + bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); +#endif +#if (YTILE >= 7) + //if (n+6>=N) continue; + bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); +#endif +#if (YTILE >= 8) + //if (n+7>=N) continue; + bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); +#endif +/* +#if (YTILE >= 9) + if (n+8>=N) continue; bigB8[k2].h8 = (loadnt((half8*)(&B_[8 * K]))); +#endif +#if (YTILE >= 10) + if (n+9>=N) continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); +#endif +#if (YTILE >= 11) + if (n+10>=N) continue; bigB10[k2].h8 = (loadnt((half8*)(&B_[10 * K]))); +#endif +*/ + } + + // Fetch activation matrix from either just LDS or from both LDS / memory +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m=0; m < M; m++) + { + if (k_+K*m < 32*1024) + bigA[m][k2] = *((const bigType*)(&(s[k_+K*m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_+K*m]))); + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t m = 0; m < M; m++) + { +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Do the matrix multiplication of activation and weight matrix + // - Rememeber the accumulation is happening for K-split of 64! 
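+ // Each v_dot2c_f32_f16 below multiplies one packed pair of halfs from A
+ // with the matching pair from B and accumulates into the f32 sum, so the
+ // A_CHUNK/2 = 4 iterations consume all 8 halfs fetched per lane.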
+#pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) + { + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][0]) : "0"(sum[m][0]), "v" (bigA[m][k2].f[b]), "v" (bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- +#if (YTILE >= 2) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][1]) : "0"(sum[m][1]), "v" (bigA[m][k2].f[b]), "v" (bigB1[k2].f[b])); +#endif +#if (YTILE >= 3) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][2]) : "0"(sum[m][2]), "v" (bigA[m][k2].f[b]), "v" (bigB2[k2].f[b])); +#endif +#if (YTILE >= 4) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][3]) : "0"(sum[m][3]), "v" (bigA[m][k2].f[b]), "v" (bigB3[k2].f[b])); +#endif +#if (YTILE >= 5) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][4]) : "0"(sum[m][4]), "v" (bigA[m][k2].f[b]), "v" (bigB4[k2].f[b])); +#endif +#if (YTILE >= 6) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][5]) : "0"(sum[m][5]), "v" (bigA[m][k2].f[b]), "v" (bigB5[k2].f[b])); +#endif +#if (YTILE >= 7) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][6]) : "0"(sum[m][6]), "v" (bigA[m][k2].f[b]), "v" (bigB6[k2].f[b])); +#endif +#if (YTILE >= 8) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][7]) : "0"(sum[m][7]), "v" (bigA[m][k2].f[b]), "v" (bigB7[k2].f[b])); +#endif +#if (YTILE >= 9) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][8]) : "0"(sum[m][8]), "v" (bigA[m][k2].f[b]), "v" (bigB8[k2].f[b])); +#endif +#if (YTILE >= 10) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][9]) : "0"(sum[m][9]), "v" (bigA[m][k2].f[b]), "v" (bigB9[k2].f[b])); +#endif +#if (YTILE >= 11) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][10]) : "0"(sum[m][10]), "v" (bigA[m][k2].f[b]), "v" (bigB10[k2].f[b])); +#endif + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) + { + for (int y = 0; y < YTILE; y++) + { + //for (int offset = 64 / 2; offset > 4 ; offset /= 2) { + // sum[y] += __shfl_down(sum[y], offset); + //} + sum[m][y] += __shfl_down(sum[m][y], 32); + sum[m][y] += __shfl_down(sum[m][y], 16); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + + } + } + + if (threadIdx.x == 0) + { + for (int m = 0; m < M; m++) + { + for (int i = 0; i < YTILE; i++) + { + if (commitColumn[i]) + C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * WvPrGrp * YTILE; + + //if (threadIdx.x == 0) + //n = atomicAdd(((unsigned int*)(C)), YTILE); + //n = __shfl(n, 0, 64); + + // Check whether there will be fragmenation! + // This will happen only for the last wave! 
+ if(n < N && (n + YTILE) >= N) + { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) + { + commitColumn[i] = 0; + } + n = startColumn; + } + } + + +} + +#undef YTILE +#undef UNRL +#undef M + + +#define YTILE 2 +#define UNRL 2 +#define M 2 + +__global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + __int128_t b128; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) + { + commitColumn[i] = 1; + } + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if(n < N && (n + YTILE) >= N) + { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) + { + commitColumn[i] = 0; + } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32*1024); k += THRDS * WvPrGrp * A_CHUNK) + { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + //uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for bank-conflict-free readback + + if (k_in >= min(K * M, 32*1024)) break; + + ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! 
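+ // m2 variant: M = 2 rows of A are staged, so the 32K-element (64 KB) LDS
+ // buffer covers the whole activation only while K <= 16K; for larger K the
+ // fetch inside the loop below falls back to reading A from global memory.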
+ // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of availble columns + //---------------------------------------------------- + while (n < N) + { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // splitted across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m=0; m= 2) + bigType bigB1[UNRL]; +#endif +#if (YTILE >= 3) + bigType bigB2[UNRL]; +#endif +#if (YTILE >= 4) + bigType bigB3[UNRL]; +#endif +#if (YTILE >= 5) + bigType bigB4[UNRL]; +#endif +#if (YTILE >= 6) + bigType bigB5[UNRL]; +#endif +#if (YTILE >= 7) + bigType bigB6[UNRL]; +#endif +#if (YTILE >= 8) + bigType bigB7[UNRL]; +#endif +#if (YTILE >= 9) + bigType bigB8[UNRL]; +#endif +#if (YTILE >= 10) + bigType bigB9[UNRL]; +#endif +#if (YTILE >= 11) + bigType bigB10[UNRL]; +#endif + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) + { + // Fetch the weight matrix from memory! 
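+ // loadnt() is a non-temporal load (__builtin_nontemporal_load, per the
+ // commented-out template near the top of this file): each chunk of B is
+ // read exactly once, so there is no reuse worth caching.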
+#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + //if (k_ >= K) break; + //bool skip = (k_ >= K); + //bool dummy = (k_ >= K); + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- +#if (YTILE >= 2) + //if (n+1>=N) continue; + bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); +#endif +#if (YTILE >= 3) + //if (n+2>=N) continue; + bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); +#endif +#if (YTILE >= 4) + //if (n+3>=N) continue; + bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); +#endif +#if (YTILE >= 5) + //if (n+4>=N) continue; + bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); +#endif +#if (YTILE >= 6) + //if (n+5>=N) continue; + bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); +#endif +#if (YTILE >= 7) + //if (n+6>=N) continue; + bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); +#endif +#if (YTILE >= 8) + //if (n+7>=N) continue; + bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); +#endif +/* +#if (YTILE >= 9) + if (n+8>=N) continue; bigB8[k2].h8 = (loadnt((half8*)(&B_[8 * K]))); +#endif +#if (YTILE >= 10) + if (n+9>=N) continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); +#endif +#if (YTILE >= 11) + if (n+10>=N) continue; bigB10[k2].h8 = (loadnt((half8*)(&B_[10 * K]))); +#endif +*/ + } + + // Fetch activation matrix from either just LDS or from both LDS / memory +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m=0; m < M; m++) + { + if (k_+K*m < 32*1024) + bigA[m][k2] = *((const bigType*)(&(s[k_+K*m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_+K*m]))); + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t m = 0; m < M; m++) + { +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Do the matrix multiplication of activation and weight matrix + // - Rememeber the accumulation is happening for K-split of 64! 
+#pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) + { + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][0]) : "0"(sum[m][0]), "v" (bigA[m][k2].f[b]), "v" (bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- +#if (YTILE >= 2) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][1]) : "0"(sum[m][1]), "v" (bigA[m][k2].f[b]), "v" (bigB1[k2].f[b])); +#endif +#if (YTILE >= 3) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][2]) : "0"(sum[m][2]), "v" (bigA[m][k2].f[b]), "v" (bigB2[k2].f[b])); +#endif +#if (YTILE >= 4) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][3]) : "0"(sum[m][3]), "v" (bigA[m][k2].f[b]), "v" (bigB3[k2].f[b])); +#endif +#if (YTILE >= 5) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][4]) : "0"(sum[m][4]), "v" (bigA[m][k2].f[b]), "v" (bigB4[k2].f[b])); +#endif +#if (YTILE >= 6) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][5]) : "0"(sum[m][5]), "v" (bigA[m][k2].f[b]), "v" (bigB5[k2].f[b])); +#endif +#if (YTILE >= 7) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][6]) : "0"(sum[m][6]), "v" (bigA[m][k2].f[b]), "v" (bigB6[k2].f[b])); +#endif +#if (YTILE >= 8) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][7]) : "0"(sum[m][7]), "v" (bigA[m][k2].f[b]), "v" (bigB7[k2].f[b])); +#endif +#if (YTILE >= 9) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][8]) : "0"(sum[m][8]), "v" (bigA[m][k2].f[b]), "v" (bigB8[k2].f[b])); +#endif +#if (YTILE >= 10) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][9]) : "0"(sum[m][9]), "v" (bigA[m][k2].f[b]), "v" (bigB9[k2].f[b])); +#endif +#if (YTILE >= 11) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][10]) : "0"(sum[m][10]), "v" (bigA[m][k2].f[b]), "v" (bigB10[k2].f[b])); +#endif + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) + { + for (int y = 0; y < YTILE; y++) + { + //for (int offset = 64 / 2; offset > 4 ; offset /= 2) { + // sum[y] += __shfl_down(sum[y], offset); + //} + sum[m][y] += __shfl_down(sum[m][y], 32); + sum[m][y] += __shfl_down(sum[m][y], 16); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + + } + } + + if (threadIdx.x == 0) + { + for (int m = 0; m < M; m++) + { + for (int i = 0; i < YTILE; i++) + { + if (commitColumn[i]) + C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * WvPrGrp * YTILE; + + //if (threadIdx.x == 0) + //n = atomicAdd(((unsigned int*)(C)), YTILE); + //n = __shfl(n, 0, 64); + + // Check whether there will be fragmenation! + // This will happen only for the last wave! 
+ if(n < N && (n + YTILE) >= N) + { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) + { + commitColumn[i] = 0; + } + n = startColumn; + } + } + + +} + +#undef YTILE +#undef UNRL +#undef M + +#define YTILE 5 +#define UNRL 2 +#define M 3 + +__global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + __int128_t b128; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) + { + commitColumn[i] = 1; + } + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if(n < N && (n + YTILE) >= N) + { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) + { + commitColumn[i] = 0; + } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32*1024); k += THRDS * WvPrGrp * A_CHUNK) + { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + //uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for bank-conflict-free readback + + if (k_in >= min(K * M, 32*1024)) break; + + ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! 
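+ // m3 variant: YTILE = 5 and M = 3, so each wave accumulates a 5-column x
+ // 3-row tile of partial sums and a WG covers 16 * 5 = 80 columns per pass
+ // (the 16/32-column figures above describe the YTILE = 1/2 cases).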
+ // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of availble columns + //---------------------------------------------------- + while (n < N) + { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // splitted across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m=0; m= 2) + bigType bigB1[UNRL]; +#endif +#if (YTILE >= 3) + bigType bigB2[UNRL]; +#endif +#if (YTILE >= 4) + bigType bigB3[UNRL]; +#endif +#if (YTILE >= 5) + bigType bigB4[UNRL]; +#endif +#if (YTILE >= 6) + bigType bigB5[UNRL]; +#endif +#if (YTILE >= 7) + bigType bigB6[UNRL]; +#endif +#if (YTILE >= 8) + bigType bigB7[UNRL]; +#endif +#if (YTILE >= 9) + bigType bigB8[UNRL]; +#endif +#if (YTILE >= 10) + bigType bigB9[UNRL]; +#endif +#if (YTILE >= 11) + bigType bigB10[UNRL]; +#endif + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) + { + // Fetch the weight matrix from memory! 
+#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + //if (k_ >= K) break; + //bool skip = (k_ >= K); + //bool dummy = (k_ >= K); + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- +#if (YTILE >= 2) + //if (n+1>=N) continue; + bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); +#endif +#if (YTILE >= 3) + //if (n+2>=N) continue; + bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); +#endif +#if (YTILE >= 4) + //if (n+3>=N) continue; + bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); +#endif +#if (YTILE >= 5) + //if (n+4>=N) continue; + bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); +#endif +#if (YTILE >= 6) + //if (n+5>=N) continue; + bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); +#endif +#if (YTILE >= 7) + //if (n+6>=N) continue; + bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); +#endif +#if (YTILE >= 8) + //if (n+7>=N) continue; + bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); +#endif +/* +#if (YTILE >= 9) + if (n+8>=N) continue; bigB8[k2].h8 = (loadnt((half8*)(&B_[8 * K]))); +#endif +#if (YTILE >= 10) + if (n+9>=N) continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); +#endif +#if (YTILE >= 11) + if (n+10>=N) continue; bigB10[k2].h8 = (loadnt((half8*)(&B_[10 * K]))); +#endif +*/ + } + + // Fetch activation matrix from either just LDS or from both LDS / memory +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m=0; m < M; m++) + { + if (k_+K*m < 32*1024) + bigA[m][k2] = *((const bigType*)(&(s[k_+K*m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_+K*m]))); + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t m = 0; m < M; m++) + { +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Do the matrix multiplication of activation and weight matrix + // - Rememeber the accumulation is happening for K-split of 64! 
+#pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) + { + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][0]) : "0"(sum[m][0]), "v" (bigA[m][k2].f[b]), "v" (bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- +#if (YTILE >= 2) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][1]) : "0"(sum[m][1]), "v" (bigA[m][k2].f[b]), "v" (bigB1[k2].f[b])); +#endif +#if (YTILE >= 3) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][2]) : "0"(sum[m][2]), "v" (bigA[m][k2].f[b]), "v" (bigB2[k2].f[b])); +#endif +#if (YTILE >= 4) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][3]) : "0"(sum[m][3]), "v" (bigA[m][k2].f[b]), "v" (bigB3[k2].f[b])); +#endif +#if (YTILE >= 5) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][4]) : "0"(sum[m][4]), "v" (bigA[m][k2].f[b]), "v" (bigB4[k2].f[b])); +#endif +#if (YTILE >= 6) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][5]) : "0"(sum[m][5]), "v" (bigA[m][k2].f[b]), "v" (bigB5[k2].f[b])); +#endif +#if (YTILE >= 7) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][6]) : "0"(sum[m][6]), "v" (bigA[m][k2].f[b]), "v" (bigB6[k2].f[b])); +#endif +#if (YTILE >= 8) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][7]) : "0"(sum[m][7]), "v" (bigA[m][k2].f[b]), "v" (bigB7[k2].f[b])); +#endif +#if (YTILE >= 9) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][8]) : "0"(sum[m][8]), "v" (bigA[m][k2].f[b]), "v" (bigB8[k2].f[b])); +#endif +#if (YTILE >= 10) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][9]) : "0"(sum[m][9]), "v" (bigA[m][k2].f[b]), "v" (bigB9[k2].f[b])); +#endif +#if (YTILE >= 11) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][10]) : "0"(sum[m][10]), "v" (bigA[m][k2].f[b]), "v" (bigB10[k2].f[b])); +#endif + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) + { + for (int y = 0; y < YTILE; y++) + { + //for (int offset = 64 / 2; offset > 4 ; offset /= 2) { + // sum[y] += __shfl_down(sum[y], offset); + //} + sum[m][y] += __shfl_down(sum[m][y], 32); + sum[m][y] += __shfl_down(sum[m][y], 16); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + + } + } + + if (threadIdx.x == 0) + { + for (int m = 0; m < M; m++) + { + for (int i = 0; i < YTILE; i++) + { + if (commitColumn[i]) + C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * WvPrGrp * YTILE; + + //if (threadIdx.x == 0) + //n = atomicAdd(((unsigned int*)(C)), YTILE); + //n = __shfl(n, 0, 64); + + // Check whether there will be fragmenation! + // This will happen only for the last wave! 
+ if(n < N && (n + YTILE) >= N) + { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) + { + commitColumn[i] = 0; + } + n = startColumn; + } + } + + +} + +#undef YTILE +#undef UNRL +#undef M + +#define YTILE 7 +#define UNRL 1 +#define M 4 + +__global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + __int128_t b128; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) + { + commitColumn[i] = 1; + } + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if(n < N && (n + YTILE) >= N) + { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) + { + commitColumn[i] = 0; + } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32*1024); k += THRDS * WvPrGrp * A_CHUNK) + { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + //uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for bank-conflict-free readback + + if (k_in >= min(K * M, 32*1024)) break; + + ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! 
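+ // m4 variant: YTILE = 7, M = 4 and UNRL = 1, so each wave accumulates a
+ // 7-column x 4-row tile, a WG covers 16 * 7 = 112 columns per pass, and
+ // the UNRL unroll loops below collapse to a single iteration.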
+ // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of availble columns + //---------------------------------------------------- + while (n < N) + { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // splitted across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m=0; m= 2) + bigType bigB1[UNRL]; +#endif +#if (YTILE >= 3) + bigType bigB2[UNRL]; +#endif +#if (YTILE >= 4) + bigType bigB3[UNRL]; +#endif +#if (YTILE >= 5) + bigType bigB4[UNRL]; +#endif +#if (YTILE >= 6) + bigType bigB5[UNRL]; +#endif +#if (YTILE >= 7) + bigType bigB6[UNRL]; +#endif +#if (YTILE >= 8) + bigType bigB7[UNRL]; +#endif +#if (YTILE >= 9) + bigType bigB8[UNRL]; +#endif +#if (YTILE >= 10) + bigType bigB9[UNRL]; +#endif +#if (YTILE >= 11) + bigType bigB10[UNRL]; +#endif + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) + { + // Fetch the weight matrix from memory! 
+#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + //if (k_ >= K) break; + //bool skip = (k_ >= K); + //bool dummy = (k_ >= K); + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- +#if (YTILE >= 2) + //if (n+1>=N) continue; + bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); +#endif +#if (YTILE >= 3) + //if (n+2>=N) continue; + bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); +#endif +#if (YTILE >= 4) + //if (n+3>=N) continue; + bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); +#endif +#if (YTILE >= 5) + //if (n+4>=N) continue; + bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); +#endif +#if (YTILE >= 6) + //if (n+5>=N) continue; + bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); +#endif +#if (YTILE >= 7) + //if (n+6>=N) continue; + bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); +#endif +#if (YTILE >= 8) + //if (n+7>=N) continue; + bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); +#endif +/* +#if (YTILE >= 9) + if (n+8>=N) continue; bigB8[k2].h8 = (loadnt((half8*)(&B_[8 * K]))); +#endif +#if (YTILE >= 10) + if (n+9>=N) continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); +#endif +#if (YTILE >= 11) + if (n+10>=N) continue; bigB10[k2].h8 = (loadnt((half8*)(&B_[10 * K]))); +#endif +*/ + } + + // Fetch activation matrix from either just LDS or from both LDS / memory +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m=0; m < M; m++) + { + if (k_+K*m < 32*1024) + bigA[m][k2] = *((const bigType*)(&(s[k_+K*m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_+K*m]))); + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t m = 0; m < M; m++) + { +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Do the matrix multiplication of activation and weight matrix + // - Rememeber the accumulation is happening for K-split of 64! 
+#pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) + { + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][0]) : "0"(sum[m][0]), "v" (bigA[m][k2].f[b]), "v" (bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- +#if (YTILE >= 2) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][1]) : "0"(sum[m][1]), "v" (bigA[m][k2].f[b]), "v" (bigB1[k2].f[b])); +#endif +#if (YTILE >= 3) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][2]) : "0"(sum[m][2]), "v" (bigA[m][k2].f[b]), "v" (bigB2[k2].f[b])); +#endif +#if (YTILE >= 4) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][3]) : "0"(sum[m][3]), "v" (bigA[m][k2].f[b]), "v" (bigB3[k2].f[b])); +#endif +#if (YTILE >= 5) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][4]) : "0"(sum[m][4]), "v" (bigA[m][k2].f[b]), "v" (bigB4[k2].f[b])); +#endif +#if (YTILE >= 6) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][5]) : "0"(sum[m][5]), "v" (bigA[m][k2].f[b]), "v" (bigB5[k2].f[b])); +#endif +#if (YTILE >= 7) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][6]) : "0"(sum[m][6]), "v" (bigA[m][k2].f[b]), "v" (bigB6[k2].f[b])); +#endif +#if (YTILE >= 8) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][7]) : "0"(sum[m][7]), "v" (bigA[m][k2].f[b]), "v" (bigB7[k2].f[b])); +#endif +#if (YTILE >= 9) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][8]) : "0"(sum[m][8]), "v" (bigA[m][k2].f[b]), "v" (bigB8[k2].f[b])); +#endif +#if (YTILE >= 10) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][9]) : "0"(sum[m][9]), "v" (bigA[m][k2].f[b]), "v" (bigB9[k2].f[b])); +#endif +#if (YTILE >= 11) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][10]) : "0"(sum[m][10]), "v" (bigA[m][k2].f[b]), "v" (bigB10[k2].f[b])); +#endif + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) + { + for (int y = 0; y < YTILE; y++) + { + //for (int offset = 64 / 2; offset > 4 ; offset /= 2) { + // sum[y] += __shfl_down(sum[y], offset); + //} + sum[m][y] += __shfl_down(sum[m][y], 32); + sum[m][y] += __shfl_down(sum[m][y], 16); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + + } + } + + if (threadIdx.x == 0) + { + for (int m = 0; m < M; m++) + { + for (int i = 0; i < YTILE; i++) + { + if (commitColumn[i]) + C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * WvPrGrp * YTILE; + + //if (threadIdx.x == 0) + //n = atomicAdd(((unsigned int*)(C)), YTILE); + //n = __shfl(n, 0, 64); + + // Check whether there will be fragmenation! + // This will happen only for the last wave! 
+ if(n < N && (n + YTILE) >= N) + { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) + { + commitColumn[i] = 0; + } + n = startColumn; + } + } + + +} + + + +void wvSpltK_(void *in_a, void *in_b, void *out_c, const int M_in, const int K_in,const int N_in, cudaStream_t stream, const int CuCount = 0) { + dim3 grid(CuCount); + dim3 block(THRDS, WvPrGrp); + half* af4 = reinterpret_cast(in_a); + const half* bf4 = reinterpret_cast(in_b); + auto *c = reinterpret_cast(out_c); + switch(N_in) { + case 1: + wvSpltK_hf_m1_<<>>(K_in, M_in, af4, bf4, c, CuCount); + break; + case 2: + wvSpltK_hf_m2_<<>>(K_in, M_in, af4, bf4, c, CuCount); + break; + case 3: + wvSpltK_hf_m3_<<>>(K_in, M_in, af4, bf4, c, CuCount); + break; + case 4: + wvSpltK_hf_m4_<<>>(K_in, M_in, af4, bf4, c, CuCount); + break; + default: + throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + "," + std::to_string(K_in) + "," + std::to_string(N_in)); + } + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) { + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); + } +} + + + diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index a3d299c05caef..6c78430adf4d0 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -23,6 +23,7 @@ def __init__(self): self.bestsols = {} self.load_best_sols() self.create_ds() + self.CuCount = torch.cuda.get_device_properties(device='cuda').multi_processor_count if (self.save_gemm == 1): self.tuned_df = pd.DataFrame(columns=['M', 'N', 'K']) @@ -69,13 +70,12 @@ def mm(self, inp, weights): k = inp_view.shape[1] soltype, solidx = self.query_sol(m=m, n=n, k=k) if soltype == 1: - #print(">>> found hipblas") + print(">>> found hipblas") out = hipb_mm(inp_view, weights.t(), solidx) elif soltype == 2: - #print(">>> found rocblas") + print(">>> found rocblas") out = rocb_mm(inp_view, weights.t(), solidx) else: - if (self.save_gemm == 1): #print('>>>Tgemm Default',inp_view.shape, # inp.shape,weights.shape,soltype,solidx) @@ -89,7 +89,13 @@ def mm(self, inp, weights): ]).drop_duplicates() self.tuned_df.to_csv(self.untune_path, index=False) - if n == 1 and inp_view.dtype == torch.float16: + if ((n == 4 or n == 3 or n== 2 or n == 1 ) and inp_view.dtype == torch.float16) : + out = torch.empty(inp_view.shape[0], + weights.shape[0], + dtype=inp_view.dtype, + device='cuda') + _custom_C.wvSpltK(weights, inp_view, out, n, self.CuCount) + elif n == 1 and inp_view.dtype == torch.float16: out = torch.empty(inp_view.shape[0], weights.shape[0], dtype=inp_view.dtype,