diff --git a/csrc/custom/custom.cu b/csrc/custom/custom.cu index 9e92187967d47..bf196b235178e 100644 --- a/csrc/custom/custom.cu +++ b/csrc/custom/custom.cu @@ -51,6 +51,61 @@ void wvSpltK(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int N_in, at::cuda::getCurrentCUDAStream(), CuCount); } +void wvSpltK_fsdMoe_(void* in_a, void* in_b, void* out_c, + void* topk_weights, + void* topk_ids, + void* sorted_token_ids, + void* expert_ids, + void* num_tokens_post_padded, + const int M, const int N, const int K, const int EM, + const int num_valid_tokens, + const int stride_am, + const int stride_ak, + const int stride_be, + const int stride_bk, + const int stride_bn, + const int stride_cm, + const int stride_cn, + const int m_blck_sz, + const bool mul_routed_weight, + const int top_k, + cudaStream_t stream, const int CuCount); + +void wvSpltK_fsdMoe(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, + at::Tensor topk_weights, + at::Tensor topk_ids, + at::Tensor sorted_token_ids, + at::Tensor expert_ids, + at::Tensor num_tokens_post_padded, + const int M, const int N, const int K, const int EM, + const int num_valid_tokens, + const int stride_am, + const int stride_ak, + const int stride_be, + const int stride_bk, + const int stride_bn, + const int stride_cm, + const int stride_cn, + const int m_blck_sz, + const bool mul_routed_weight, + const int top_k, + const int CuCount) { + //int M = in_a.size(0); + //int K = in_a.size(1); + //int N = N_in; + wvSpltK_fsdMoe_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), + topk_weights.data_ptr(), + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), + num_tokens_post_padded.data_ptr(), + M, N, K, EM, + num_valid_tokens, + stride_am, stride_ak,stride_be,stride_bk,stride_bn,stride_cm,stride_cn, + m_blck_sz, mul_routed_weight,top_k, + at::cuda::getCurrentCUDAStream(), CuCount); +} + void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, cudaStream_t stream, const int solidx); @@ -103,5 +158,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("paged_attention_custom", &paged_attention_custom, "PagedAttention LL4Mi Custom."); m.def("wvSpltK", &wvSpltK); + m.def("wvSpltK_fsdMoe", &wvSpltK_fsdMoe); // m.def("MMCustomGPU", &MMCustomGPU); } diff --git a/csrc/custom/custom_kernels.cu b/csrc/custom/custom_kernels.cu index f03d3da5a8f9c..e55e1510ec27f 100644 --- a/csrc/custom/custom_kernels.cu +++ b/csrc/custom/custom_kernels.cu @@ -1925,6 +1925,1419 @@ __global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, #endif // defined(__HIP__MI300__) TODO: Add NAVI support + + +#undef M +#undef YTILE +#undef UNRL +#define UNRL 1 +//#define M_BLOCK 4 + +template +__global__ void +__launch_bounds__(WvPrGrp * THRDS) +wvSpltK_fsdMoe_hf_( + const DTYPE* __restrict__ A, + const DTYPE* __restrict__ B, + DTYPE* C, + const float* __restrict__ topk_weights, + const int* __restrict__ topk_ids, + const int* __restrict__ sorted_token_ids, + const int* __restrict__ expert_ids, + const int* __restrict__ num_tokens_post_padded, + const int M_in, const int N, const int K, const int E, + const int num_valid_tokens, + const int stride_am, + const int stride_ak, + const int stride_be, + const int stride_bk, + const int stride_bn, + const int stride_cm, + const int stride_cn, + const bool mul_routed_weight, + const int top_k, + const int CuCount + ) { + bool PCML = (K * M_in > 32*1024); + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half8 h8; + }; + + 
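+ // `s` below is the workgroup's LDS staging buffer for A (32K half elements),
+ // and `bigType` lets A and B be moved with A_CHUNK-wide vector accesses.
+ // When PCML is set (K * M_in does not fit in that buffer), A is instead
+ // streamed piecemeal in kFit-sized chunks per token block further down.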
__shared__ half s[1024 * 32]; + + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + + if (!PCML) { + for (uint32_t k = 0; k < min(K * M_in, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + if (k_in >= min(K * M_in, 32 * 1024)) break; + + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + } + __syncthreads(); + } + + int YW = (YTILE * WvPrGrp); + int TWC = (THRDS * WvPrGrp * A_CHUNK); + int TUC = (THRDS * UNRL * A_CHUNK); + uint32_t kBase = 0; + //find biggest k size that fits in LDS + uint32_t kFit = (32*1024)/M_BLOCK; + //kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple of TUC + kFit = (kFit%TUC==0) ? kFit : (kFit-kFit%TUC); //round down to multiple of TUC + //if (kFit == 0) kFit = TUC; + kFit = min(kFit, K); + + //if (kFit < TUC) PCML = false; + + float sum[M_BLOCK][YTILE]; + + //TRITON + //offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + //offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + //token_mask = offs_token < num_valid_tokens + int offs_token[M_BLOCK]; + bool token_mask[M_BLOCK]; // add to A[] /top_k*k + int off_experts; // add to B[] *K*N loads + + uint32_t Nrndp = (N%YW==0) ? N : (N-N%YW+YW); // Note: All waves in the group need to stay alive to the bitter end, just in case they're needed for cooperative loading of next chunk of A[] into LDS. Such Zomby waves are prevented from doing any real work with continues in the loop below. + if (!PCML) Nrndp = N; //unless its not peicmeal + while (n < Nrndp) { + kBase = 0; + for (uint32_t e = 0; e < num_tokens_post_padded[0]; e+=M_BLOCK) { + kBase = 0; + + for (int m=0; m= K) break; + if (kOff >= kFit) break; + for (uint32_t m = 0; m < M_BLOCK; m++) { + if (!token_mask[m]) continue; + uint32_t k_in = kBase + (offs_token[m]/top_k) * K + kOff; + uint32_t k_ot = m * kFit + kOff; + *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); + } + } + __syncthreads(); + } + } + + // kept alive just to participate in A[] loads + if (n >= N) continue; + +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // load only 1 column of weights, despite the moe-gate, made possible by expert list. 
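+ // off_experts (read from expert_ids[] for this token block) selects the
+ // expert's K x N weight slab, and each wave streams YTILE consecutive
+ // columns starting at column n, so no per-token gather of B is required.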
+ const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=0; y= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M_BLOCK; m++) + { + if (!token_mask[m]) continue; + if (PCML) { + //bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + kFit*m]))); + // skip A[] fetches for Ms that are disabled + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } else { + int aidx = k_ + (offs_token[m]/top_k) * K; + if (aidx + A_CHUNK <= 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[aidx]))); + else + bigA[m][k2] = *((const bigType*)(&(A[aidx]))); + } + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; +#pragma unroll + for (uint32_t m = 0; m < M_BLOCK; m++) { + // skip compute for Ms that are disabled + if (!token_mask[m]) continue; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! +#pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) + for (int y=0; y= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + } +} + +#define mfmaTILEn 16 +#define mfmaTILEk 4 +//#undef WvPrGrp +//#define WvPrGrp 8 +#define USEMFMA +//#define PIPELINED_33334x +//#define PIPELINED_556x +#define PIPELINED4x + +template +__global__ void +__launch_bounds__(WvPrGrp * THRDS) +wvSpltK_fsdMoe_hf_mfma16_( + const DTYPE* __restrict__ A, + const DTYPE* __restrict__ B, + DTYPE* C, + const float* __restrict__ topk_weights, + const int* __restrict__ topk_ids, + const int* __restrict__ sorted_token_ids, + const int* __restrict__ expert_ids, + const int* __restrict__ num_tokens_post_padded, + const int M_in, const int N, const int K, const int E, + const int num_valid_tokens, + const int stride_am, + const int stride_ak, + const int stride_be, + const int stride_bk, + const int stride_bn, + const int stride_cm, + const int stride_cn, + const bool mul_routed_weight, + const int top_k, + const int CuCount + ) { + +using halfCxT = __attribute__((__vector_size__(mfmaTILEn * A_CHUNK / 2 * sizeof(float)))) float; +using halfC = __attribute__((__vector_size__(A_CHUNK / 2 * sizeof(float)))) float; +using halfT = __attribute__((__vector_size__(mfmaTILEk / 2 * sizeof(float)))) float; + +bool PCML = true;//(K * M_in > 32*1024); + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half8 h8; + int i[A_CHUNK / 2]; + long int l[A_CHUNK / 4]; + halfT hT[A_CHUNK / mfmaTILEk]; + halfC hC; + }; + union bigTypeXt{ + bigType B[mfmaTILEn]; + halfCxT hCT; + }; + + + __shared__ half s[1024 * 32]; + + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + int ETILE = (CuCount * WvPrGrp ) / (N/YTILE); // bump up etile to fill machine + if (ETILE < 1) ETILE = 1; //TODO: what is best default min ETILE? + if (M_in >= 128) ETILE = min(M_in/64, 15); // Heuristic: Add an ETILE for every 64 Ms + + const int num_tblk = num_tokens_post_padded[0] / M_BLOCK; + + // its worth spending time trying to load balance for this num_tokens... 
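+ // The block below estimates how many rounds (N tiles x token-block tiles)
+ // the grid would need with ETILE, ETILE/2 and ETILE*2 expert tiles in
+ // flight, and keeps whichever of the three finishes in the fewest rounds.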
+ if ((CuCount/(ETILE*2) > 0) && (num_tblk>0))// TODO: make sure all overflow/inf conditions are avoided + { + int nPrRnd0 = ((CuCount/(ETILE))*WvPrGrp)*YTILE; + int nRnds0 = (N + nPrRnd0 - 1 ) / nPrRnd0; + int tRnds0 = (num_tblk + (ETILE) - 1) / (ETILE); + int rnds0 = nRnds0 * tRnds0; + + int nPrRnd1n = ((CuCount/(ETILE/2))*WvPrGrp)*YTILE; + int nRnds1n = (N + nPrRnd1n - 1 ) / nPrRnd1n; + int tRnds1n = (num_tblk + (ETILE/2) - 1) / (ETILE/2); + int rnds1n = nRnds1n * tRnds1n; + + int nPrRnd1p = ((CuCount/(ETILE*2))*WvPrGrp)*YTILE; + int nRnds1p = (N + nPrRnd1p - 1 ) / nPrRnd1p; + int tRnds1p = (num_tblk + (ETILE*2) - 1) / (ETILE*2); + int rnds1p = nRnds1p * tRnds1p; + + int etl = ETILE; + if (rnds0 > rnds1n) { etl = ETILE/2; rnds0 = rnds1n; } + if (rnds0 > rnds1p) { etl = ETILE*2; rnds0 = rnds1p; } + ETILE = etl; + } + + uint32_t n = ((blockIdx.x/ETILE) * WvPrGrp + threadIdx.y) * YTILE; + +/* if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + }*/ + + if (!PCML) { + for (uint32_t k = 0; k < min(K * M_in, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + if (k_in >= min(K * M_in, 32 * 1024)) break; + + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + } + __syncthreads(); + } + + int YW = (YTILE * WvPrGrp); + int TWC = (THRDS * WvPrGrp * A_CHUNK); + int TUC = (THRDS * UNRL * A_CHUNK); + uint32_t kBase = 0; + //find biggest k size that fits in LDS + uint32_t kFit = (32*1024)/M_BLOCK; + //kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple of TUC + kFit = (kFit%TUC==0) ? kFit : (kFit-kFit%TUC); //round down to multiple of TUC + //if (kFit == 0) kFit = TUC; + kFit = min(kFit, K); + +#ifdef USEMFMA + using float4_ = __attribute__( (__vector_size__(4 * sizeof(float)) )) float; + float4_ sum4; +#else + float sum[M_BLOCK][YTILE]; +#endif + + //TRITON + //offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + //offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + //token_mask = offs_token < num_valid_tokens + uint32_t offs_token[M_BLOCK]; + bool token_mask[M_BLOCK]; // add to A[] /top_k*k + uint32_t off_experts; // add to B[] *K*N loads + + int kShfl = A_CHUNK * THRDS * ( threadIdx.y + (threadIdx.x/16)); + int kSprd = A_CHUNK * ( threadIdx.x ); + + uint32_t Nrndp = (N%YW==0) ? N : (N-N%YW+YW); // Note: All waves in the group need to stay alive to the bitter end, just in case they're needed for cooperative loading of next chunk of A[] into LDS. Such Zomby waves are prevented from doing any real work with continues in the loop below. + if (!PCML) Nrndp = N; //unless its not peicmeal + while (n < Nrndp) { + kBase = 0; + for (uint32_t e = (blockIdx.x % ETILE) * M_BLOCK; e < num_tokens_post_padded[0]; e+=M_BLOCK*ETILE) { + kBase = 0; + +#pragma unroll M_BLOCK + for (uint32_t m=0; m= K) break; + if (kOff >= kFit) break; +#ifdef USEMFMA + uint32_t k_in = kBase + (offs_token[m]/top_k) * K + kOff; + uint32_t k_ot = m * K + kOff; // yes, K should be kFit here. 
but we'lltranspose this below anyway + // Transpose A for MFMAs + uint32_t k_in_x = (k_ot / A_CHUNK) % (K / A_CHUNK); + uint32_t k_in_y = (k_ot / A_CHUNK) / (K / A_CHUNK); + uint32_t k_ot_x = (k_in_x / mfmaTILEn) * mfmaTILEn + (k_in_y % mfmaTILEn); + uint32_t k_ot_y = (k_in_y / mfmaTILEn) * mfmaTILEn + (k_in_x % mfmaTILEn); + + k_ot = (k_ot_y * (kFit / A_CHUNK) + k_ot_x) * A_CHUNK; + + *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); + //} +#else + //int m = threadIdx.x % M_BLOCK; + //for (uint32_t m = 0; m < M_BLOCK; m++) { + //if (!token_mask[m]) continue; + uint32_t k_in = kBase + (offs_token[m]/top_k) * K + kOff; + uint32_t k_ot = m * kFit + kOff; + *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); + //} +#endif + } + __syncthreads(); + } + } + + // kept alive just to participate in A[] loads + if (n >= N) continue; + + int k1 = k1_; + if (shflk) k1 = kBase + (((k1_-kBase) + kShfl) % kFit ); // shfl loads within this lane, to reduce temporal hotspotting + + #define StgMfma4(_LN) { \ + for (uint32_t _t = 0; _t < A_CHUNK/mfmaTILEk; _t++) { \ + sum4 = __builtin_amdgcn_mfma_f32_16x16x16f16( \ + bigB[0][k2].B[_LN].hT[_t], \ + bigA[_LN][k2].hT[_t], \ + sum4, 0, 0, 0); \ + } \ + } + + +#ifdef PIPELINED1x +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=0; y= K) break; + for (int m = 0; m < M_BLOCK; m++) + { + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=0; y= K) break; + for (int m = 0; m < M_BLOCK/2; m++) + { + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=YTILE/2; y= K) break; + for (int m = M_BLOCK/2; m < M_BLOCK; m++) + { + bigA[m-M_BLOCK/2][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=0; y= K) break; + for (int m = 0; m < M_BLOCK/4; m++) + { + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=YTILE/4; y= K) break; + for (int m = M_BLOCK/4; m < M_BLOCK/2; m++) + { + bigA[m-M_BLOCK/4][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=YTILE/2; y<3*YTILE/4; y++) // should 
this be M_BLOCK? + //bigB[0][k2].B[y-YTILE/2].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-YTILE/2].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = M_BLOCK/2; m < 3*M_BLOCK/4; m++) + { + bigA[m-M_BLOCK/2][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=3*YTILE/4; y= K) break; + for (int m = 3*M_BLOCK/4; m < M_BLOCK; m++) + { + bigA[m-3*M_BLOCK/4][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=0; y<3; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 0; m < 3; m++) + { + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<3; l++) + StgMfma4(l); + } + +///////////////////////////ROUND 2////////////////////////// +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=3; y<6; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-YTILE/4].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-3].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 3; m < 6; m++) + { + bigA[m-3][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<3; l++) + StgMfma4(l); + } +///////////////////////////ROUND 3////////////////////////// +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=6; y<9; y++) // should this be M_BLOCK? 
+ //bigB[0][k2].B[y-YTILE/2].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-6].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 6; m < 9; m++) + { + bigA[m-6][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<3; l++) + StgMfma4(l); + } + +///////////////////////////ROUND 4////////////////////////// +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=9; y<12; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-YTILE/2].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-9].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 9; m < 12; m++) + { + bigA[m-9][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<3; l++) + StgMfma4(l); + } + +///////////////////////////ROUND 5////////////////////////// +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=12; y<16; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-YTILE/2].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-12].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 12; m < 16; m++) + { + bigA[m-12][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<4; l++) + StgMfma4(l); + } + + + + +#elif defined(PIPELINED_556x) //556x + +///////////////////////////ROUND 1////////////////////////// +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=0; y<5; y++) // should this be M_BLOCK? 
+ //bigB[0][k2].B[y].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 0; m < 5; m++) + { + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<5; l++) + StgMfma4(l); + //} + +///////////////////////////ROUND 2////////////////////////// +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + kSprd; + // if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=5; y<10; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-YTILE/4].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-5].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 5; m < 10; m++) + { + bigA[m-5][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<5; l++) + StgMfma4(l); + //} +///////////////////////////ROUND 3////////////////////////// + //#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + kSprd; + // if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=10; y<16; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-YTILE/2].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-10].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 10; m < 16; m++) + { + bigA[m-10][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<6; l++) + StgMfma4(l); + } + +#elif defined(PIPELINED8x) //8x + +///////////////////////////ROUND 1////////////////////////// +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=0; y= K) break; + for (int m = 0; m < M_BLOCK/8; m++) + { + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=YTILE/8; y<2*YTILE/8; y++) // should this be M_BLOCK? 
+ //bigB[0][k2].B[y-YTILE/8].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-YTILE/8].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = M_BLOCK/8; m < 2*M_BLOCK/8; m++) + { + bigA[m-M_BLOCK/8][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=2*YTILE/8; y<3*YTILE/8; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-2*YTILE/8].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-2*YTILE/8].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 2*M_BLOCK/8; m < 3*M_BLOCK/8; m++) + { + bigA[m-2*M_BLOCK/8][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=3*YTILE/8; y<4*YTILE/8; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-3*YTILE/8].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-3*YTILE/8].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 3*M_BLOCK/8; m < 4*M_BLOCK/8; m++) + { + bigA[m-3*M_BLOCK/8][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=4*YTILE/8; y<5*YTILE/8; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-4*YTILE/8].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-4*YTILE/8].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 4*M_BLOCK/8; m < 5*M_BLOCK/8; m++) + { + bigA[m-4*M_BLOCK/8][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=5*YTILE/8; y<6*YTILE/8; y++) // should this be M_BLOCK? 
+ //bigB[0][k2].B[y-5*YTILE/8].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-5*YTILE/8].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 5*M_BLOCK/8; m < 6*M_BLOCK/8; m++) + { + bigA[m-5*M_BLOCK/8][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=6*YTILE/8; y<7*YTILE/8; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-6*YTILE/8].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-6*YTILE/8].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 6*M_BLOCK/8; m < 7*M_BLOCK/8; m++) + { + bigA[m-6*M_BLOCK/8][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=7*YTILE/8; y<8*YTILE/8; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-7*YTILE/8].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-7*YTILE/8].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 7*M_BLOCK/8; m < 8*M_BLOCK/8; m++) + { + bigA[m-7*M_BLOCK/8][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; +#ifdef USEMFMA + for (int y=0; y= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M_BLOCK; m++) + { +#ifdef USEMFMA +#else + if (!token_mask[m]) continue; +#endif + if (PCML) { + //bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + kFit*m]))); + // skip A[] fetches for Ms that are disabled + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } else { + int aidx = k_ + (offs_token[m]/top_k) * K; + if (aidx + A_CHUNK <= 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[aidx]))); + else + bigA[m][k2] = *((const bigType*)(&(A[aidx]))); + } + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + +#ifdef USEMFMA + bigType stgB; + for (int l=0; l= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + } +} + + + +// a = torch.randn((m, k) +// b1 = torch.randn((e, 2 * n, k) +// b2 = torch.randn((e, k, n) +// topk_weights = torch.randn((m, e), device='cuda', dtype=dtype) + +void wvSpltK_fsdMoe_(void* in_a, void* in_b, void* out_c, + void* topk_weights, + void* topk_ids, + void* sorted_token_ids, + 
void* expert_ids, + void* num_tokens_post_padded, + const int M_in, const int N_in, const int K_in, const int E, + const int num_valid_tokens, + const int stride_am, + const int stride_ak, + const int stride_be, + const int stride_bk, + const int stride_bn, + const int stride_cm, + const int stride_cn, + const int m_blck_sz, + const bool mul_routed_weight, + const int top_k, + cudaStream_t stream, const int CuCount) { + dim3 grid(CuCount); + dim3 block(THRDS, WvPrGrp); + auto* a = reinterpret_cast(in_a); + auto* b = reinterpret_cast(in_b); + auto* c = reinterpret_cast(out_c); + auto* topk_weights_ = reinterpret_cast(topk_weights); + auto* topk_ids_ = reinterpret_cast(topk_ids); + auto* sorted_token_ids_ = reinterpret_cast(sorted_token_ids); + auto* expert_ids_ = reinterpret_cast(expert_ids); + auto* num_tokens_post_padded_ = reinterpret_cast(num_tokens_post_padded); + switch (m_blck_sz) { + case 1: + wvSpltK_fsdMoe_hf_<1,4><<>>(a, b, c, topk_weights_, topk_ids_, sorted_token_ids_, expert_ids_, num_tokens_post_padded_, M_in, N_in, K_in, E, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, mul_routed_weight, top_k, CuCount); + break; + case 2: + wvSpltK_fsdMoe_hf_<2,4><<>>(a, b, c, topk_weights_, topk_ids_, sorted_token_ids_, expert_ids_, num_tokens_post_padded_, M_in, N_in, K_in, E, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, mul_routed_weight, top_k, CuCount); + break; + case 3: + wvSpltK_fsdMoe_hf_<3,4><<>>(a, b, c, topk_weights_, topk_ids_, sorted_token_ids_, expert_ids_, num_tokens_post_padded_, M_in, N_in, K_in, E, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, mul_routed_weight, top_k, CuCount); + break; + case 4: + wvSpltK_fsdMoe_hf_<4,4><<>>(a, b, c, topk_weights_, topk_ids_, sorted_token_ids_, expert_ids_, num_tokens_post_padded_, M_in, N_in, K_in, E, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, mul_routed_weight, top_k, CuCount); + break; + case 5: + wvSpltK_fsdMoe_hf_<5,4><<>>(a, b, c, topk_weights_, topk_ids_, sorted_token_ids_, expert_ids_, num_tokens_post_padded_, M_in, N_in, K_in, E, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, mul_routed_weight, top_k, CuCount); + break; + case 6: + wvSpltK_fsdMoe_hf_<6,4><<>>(a, b, c, topk_weights_, topk_ids_, sorted_token_ids_, expert_ids_, num_tokens_post_padded_, M_in, N_in, K_in, E, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, mul_routed_weight, top_k, CuCount); + break; + case 16: + wvSpltK_fsdMoe_hf_mfma16_<16,16><<>>(a, b, c, topk_weights_, topk_ids_, sorted_token_ids_, expert_ids_, num_tokens_post_padded_, M_in, N_in, K_in, E, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, mul_routed_weight, top_k, CuCount); + break; + + } +} + void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, const int K_in, const int N_in, cudaStream_t stream, const int CuCount = 0) { diff --git a/vllm/envs.py b/vllm/envs.py index 739a4792ce078..287f4f9eadf9e 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -42,7 +42,10 @@ VERBOSE: bool = False VLLM_SYNC_SERVER_ACCUM_REQUESTS: int = 1 VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS: int = 1 - VLLM_MOE_PADDING: bool = True + VLLM_MOE_PADDING: bool = False + + VLLM_MOE_MFMASWIZZLE: bool = True + VLLM_MOE_MFMASWIZZLE_M_THRSHLD: int = 32 # The begin-* and end* here are 
used by the documentation generator # to extract the used env vars. @@ -90,6 +93,13 @@ "VERBOSE": lambda: bool(int(os.getenv('VERBOSE', '0'))), + # Swizzle the weights for mfma ops in moe kernel, or not + "VLLM_MOE_MFMASWIZZLE": + lambda: bool(int(os.getenv("VLLM_MOE_MFMASWIZZLE", "1"))), + # Swizzle the weights for mfma ops in moe kernel, or not + "VLLM_MOE_MFMASWIZZLE_M_THRSHLD": + lambda: int(os.getenv("VLLM_MOE_MFMASWIZZLE_M_THRSHLD", "32")), + # Root directory for VLLM configuration files # Note that this not only affects how vllm finds its configuration files # during runtime, but also affects how vllm installs its configuration @@ -245,7 +255,7 @@ # Pad the weight for moe kernel or not "VLLM_MOE_PADDING": - lambda: bool(int(os.getenv("VLLM_MOE_PADDING", "1"))), + lambda: bool(int(os.getenv("VLLM_MOE_PADDING", "0"))), } # end-env-vars-definition diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index e759d63b588b3..7d0a57c6f4236 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -13,6 +13,8 @@ from vllm import envs from vllm.logger import init_logger +from vllm import _custom_C + logger = init_logger(__name__) padding_size = 128 if envs.VLLM_MOE_PADDING else 0 @@ -229,6 +231,42 @@ def moe_align_block_size( ) return sorted_ids, expert_ids, num_tokens_post_pad +def invoke_mega_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + m_blck_sz: int, mul_routed_weight: bool, top_k: int, + use_fp8: bool) -> None: + assert topk_weights.stride(1) == 1 + assert sorted_token_ids.stride(0) == 1 + + #print("\nm=",A.shape[0],"n=",B.shape[1],"k=",B.shape[2],"e=", B.shape[0], "ml_rt:",mul_routed_weight,"tpk",top_k, "\n") + _custom_C.wvSpltK_fsdMoe(#A, B, C, B.shape[1], 80) + A, + B, + C, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + A.shape[0], + B.shape[1], + B.shape[2], + B.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + m_blck_sz, + mul_routed_weight, + top_k, + 80) def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, A_scale: Optional[torch.Tensor], @@ -367,8 +405,9 @@ def fused_experts(hidden_states: torch.Tensor, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None): # Check constraints. 
- assert hidden_states.shape[ - 1] == w1.shape[2] - padding_size, "Hidden size mismatch" + #print("hidenSize:", hidden_states.shape) + # print("hidenSize:", w1.shape[2]) + assert hidden_states.shape[1] == w1.shape[2] - padding_size, "Hidden size mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -407,7 +446,7 @@ def fused_experts(hidden_states: torch.Tensor, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, } - + intermediate_cache1 = torch.empty( (M, topk_ids.shape[1], N), device=hidden_states.device, @@ -424,12 +463,87 @@ def fused_experts(hidden_states: torch.Tensor, dtype=hidden_states.dtype, ) - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, config['BLOCK_SIZE_M'], E) compute_type = (tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16) + #print(hidden_states.shape) + #print(intermediate_cache2.shape) + #print("M1:", hidden_states.shape[0], "M2:", intermediate_cache2.shape[0]) + #if hidden_states.shape[0] <= 256 and hidden_states.shape[1] % 8 == 0 and intermediate_cache2.shape[0] <= 256 and not use_fp8 : + + #WVSPLTK_M_THRSHLD = 64 + #if hidden_states.shape[0] <= WVSPLTK_M_THRSHLD \ + # and hidden_states.shape[1] % 8 == 0 \ + # and intermediate_cache2.shape[0] <= WVSPLTK_M_THRSHLD \ + # and intermediate_cache2.shape[1] % 8 == 0 \ + # and not use_fp8 : + if envs.VLLM_MOE_MFMASWIZZLE and M<=envs.VLLM_MOE_MFMASWIZZLE_M_THRSHLD: + assert(compute_type == tl.float16, "Only fp16 supported for wvSplitK_mfma16x16 for now") + #m_blck_sz = -(-(M*topk_ids.shape[1]*3)//E) # target 75% of expert distribution for this M size + #if (m_blck_sz >= 12): + # m_blck_sz = 16 + + # all calls go to wvSplitK_mfma16x16 + m_blck_sz = 16 # TODO: this is for decode stage, need another for prefill + #print("M:", M, " M_BLOCK PICKED:", m_blck_sz) + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, m_blck_sz, E) # target 75% of expert distribution for this M size + #topk_ids, config2['BLOCK_SIZE_M'],E) + #print("\nsrtd_tkn:", sorted_token_ids) + #print("w1Shape:",w1.shape) + + #env VLLM_MOE_MFMASWIZZLE does this swizzle on init + w1_ = w1 + w2_ = w2 + if not envs.VLLM_MOE_MFMASWIZZLE : # for debug only + if m_blck_sz >= 16 : + w1_ = torch.clone(w1) + w1_ = w1_.view(w1.shape[0], w1.shape[1]//16, 16, w1.shape[2]//128, 16, 8); + w1_ = w1_.permute(0, 1, 4, 3, 2, 5) + w1_ = w1_.contiguous() + w1_ = w1_.view(w1.shape[0],w1.shape[1],w1.shape[2]); + w2_ = torch.clone(w2) + w2_ = w2_.view(w2.shape[0], w2.shape[1]//16, 16, w2.shape[2]//128, 16, 8); + w2_ = w2_.permute(0, 1, 4, 3, 2, 5) + w2_ = w2_.contiguous() + w2_ = w2_.view(w2.shape[0],w2.shape[1],w2.shape[2]); + + #print(w1_) + + invoke_mega_fused_moe_kernel(hidden_states, + w1_, + intermediate_cache1, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + m_blck_sz, + False, + topk_ids.shape[1], + use_fp8=use_fp8) + #print("shdr_invk1:",intermediate_cache1.view(-1, N)) + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + #print("shdr_silu:",intermediate_cache2) + #print("shdr_silu_shape:", intermediate_cache2.shape) + #print("-----------------------------") + + invoke_mega_fused_moe_kernel(intermediate_cache2, + w2_, + intermediate_cache3, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + m_blck_sz, + True, + 1, + 
use_fp8=use_fp8) - invoke_fused_moe_kernel(hidden_states, + else: + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, config['BLOCK_SIZE_M'], E) + invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, a1_scale, @@ -445,9 +559,9 @@ def fused_experts(hidden_states: torch.Tensor, compute_type=compute_type, use_fp8=use_fp8) - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - invoke_fused_moe_kernel(intermediate_cache2, + invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, a2_scale, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index ee9db7048f1f6..382a112df3bbd 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -181,8 +181,29 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, param_data[expert_id] = loaded_weight def process_weights_after_loading(self): - # Fp8 is the only case where we need to process after loading. if not self.use_fp8: + if envs.VLLM_MOE_MFMASWIZZLE: + #w1_ = torch.clone(w1) + b,n,k = self.w13_weight.shape + w1_ = self.w13_weight + w1_ = w1_.view(b, n//16, 16, k//128, 16, 8); + w1_ = w1_.transpose(2,4) + w1_ = w1_.contiguous() + w1_ = w1_.view(b,n,k) + w1_ = w1_.contiguous() + self.w13_weight = nn.Parameter(w1_, requires_grad=False) + + #w2_ = torch.clone(w2) + b,n,k = self.w2_weight.shape + w2_ = self.w2_weight + w2_ = w2_.view(b, n//16, 16, k//128, 16, 8); + w2_ = w2_.transpose(2,4) + #w2_ = w2_.permute(0, 1, 4, 3, 2, 5) + w2_ = w2_.contiguous() + w2_ = w2_.view(b,n,k) + w2_ = w2_.contiguous() + self.w2_weight = nn.Parameter(w2_, requires_grad=False) + return if envs.VLLM_MOE_PADDING: self.w13_weight = nn.Parameter(F.pad(self.w13_weight.data, (0, 128), "constant", 0),
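For reference, the weight swizzle applied in process_weights_after_loading above (and re-created in fused_moe.py when VLLM_MOE_MFMASWIZZLE is disabled) can be written as a small standalone helper. This is a sketch only, assuming fp16 expert weights whose second dimension is a multiple of 16 and whose last dimension is a multiple of 128; the function name and the example shapes are illustrative, not part of the patch.

import torch

def mfma_swizzle(w: torch.Tensor) -> torch.Tensor:
    # w: (num_experts, n, k) fp16 weights; requires n % 16 == 0 and k % 128 == 0.
    b, n, k = w.shape
    # Split n into 16-row MFMA tiles and k into groups of sixteen 8-half
    # vectors, swap the two 16-wide axes (same effect as permute(0, 1, 4, 3, 2, 5)),
    # then flatten back so the mfma16x16 path can use contiguous vector loads.
    w_ = w.view(b, n // 16, 16, k // 128, 16, 8)
    w_ = w_.transpose(2, 4).contiguous()
    return w_.view(b, n, k).contiguous()

# Example with illustrative shapes: 8 experts, n = 64, k = 256.
w = torch.randn(8, 64, 256, dtype=torch.float16)
w_swizzled = mfma_swizzle(w)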