Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DO NOT MERGE] patch wvSpltK_fused_moe from https://github.com/amd-hhashemi/vllm/tre… #126

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions csrc/custom/custom.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,61 @@ void wvSpltK(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int N_in,
at::cuda::getCurrentCUDAStream(), CuCount);
}

void wvSpltK_fsdMoe_(void* in_a, void* in_b, void* out_c,
void* topk_weights,
void* topk_ids,
void* sorted_token_ids,
void* expert_ids,
void* num_tokens_post_padded,
const int M, const int N, const int K, const int EM,
const int num_valid_tokens,
const int stride_am,
const int stride_ak,
const int stride_be,
const int stride_bk,
const int stride_bn,
const int stride_cm,
const int stride_cn,
const int m_blck_sz,
const bool mul_routed_weight,
const int top_k,
cudaStream_t stream, const int CuCount);

void wvSpltK_fsdMoe(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c,
at::Tensor topk_weights,
at::Tensor topk_ids,
at::Tensor sorted_token_ids,
at::Tensor expert_ids,
at::Tensor num_tokens_post_padded,
const int M, const int N, const int K, const int EM,
const int num_valid_tokens,
const int stride_am,
const int stride_ak,
const int stride_be,
const int stride_bk,
const int stride_bn,
const int stride_cm,
const int stride_cn,
const int m_blck_sz,
const bool mul_routed_weight,
const int top_k,
const int CuCount) {
//int M = in_a.size(0);
//int K = in_a.size(1);
//int N = N_in;
wvSpltK_fsdMoe_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(),
topk_weights.data_ptr(),
topk_ids.data_ptr(),
sorted_token_ids.data_ptr(),
expert_ids.data_ptr(),
num_tokens_post_padded.data_ptr(),
M, N, K, EM,
num_valid_tokens,
stride_am, stride_ak,stride_be,stride_bk,stride_bn,stride_cm,stride_cn,
m_blck_sz, mul_routed_weight,top_k,
at::cuda::getCurrentCUDAStream(), CuCount);
}

void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K,
cudaStream_t stream, const int solidx);

Expand Down Expand Up @@ -103,5 +158,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("paged_attention_custom", &paged_attention_custom,
"PagedAttention LL4Mi Custom.");
m.def("wvSpltK", &wvSpltK);
m.def("wvSpltK_fsdMoe", &wvSpltK_fsdMoe);
// m.def("MMCustomGPU", &MMCustomGPU);
}
Loading
Loading