Adds wvSpltK optimization for skinny GEMM.
Hashem Hashemi committed Jun 17, 2024
1 parent 0370719 commit 9e3785b
Showing 3 changed files with 1,461 additions and 4 deletions.
15 changes: 15 additions & 0 deletions csrc/custom/custom.cu
@@ -39,6 +39,20 @@ void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c,
at::cuda::getCurrentCUDAStream(), rows_per_block);
}

void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M, const int K,
              const int N, cudaStream_t stream, const int CuCount);

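// Wrapper bound to Python below: M and K are taken from in_a's shape, N is
// passed in explicitly, and CuCount (the GPU's compute-unit count) is
// forwarded to the kernel launcher.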
void wvSpltK(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c,
             const int N_in, const int CuCount) {
  int M = in_a.size(0);
  int K = in_a.size(1);
  int N = N_in;
  wvSpltK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N,
           at::cuda::getCurrentCUDAStream(), CuCount);
}

void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K,
cudaStream_t stream, const int solidx);

@@ -90,5 +104,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("LLZZ", &LLZZ);
m.def("paged_attention_custom", &paged_attention_custom,
"PagedAttention LL4Mi Custom.");
m.def("wvSpltK", &wvSpltK);
// m.def("MMCustomGPU", &MMCustomGPU);
}
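
For readers trying the new binding, here is a minimal usage sketch from PyTorch. The module name custom_ops is a hypothetical stand-in for whatever the extension built from csrc/custom/custom.cu is imported as, and the in_b layout is an assumption, not something this commit specifies:

# Minimal usage sketch; "custom_ops" is a hypothetical module name and the
# in_b layout is assumed. wvSpltK targets "skinny" GEMMs: a tall-and-thin
# activation matrix (tiny M, large K) multiplied by a weight matrix, with the
# work split across the GPU's compute units.
import torch
import custom_ops  # hypothetical name for the built extension

M, K, N = 1, 4096, 4096  # "skinny": very small M, large K and N
in_a = torch.randn(M, K, dtype=torch.float16, device="cuda")
in_b = torch.randn(K, N, dtype=torch.float16, device="cuda")  # layout assumed
out_c = torch.empty(M, N, dtype=torch.float16, device="cuda")

# Pass the device's compute-unit count so the kernel can split work across it.
cu_count = torch.cuda.get_device_properties(0).multi_processor_count
custom_ops.wvSpltK(in_a, in_b, out_c, N, cu_count)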
