From 27eea97d7d55fb4c0530d48c6ede0ea560b7ed1c Mon Sep 17 00:00:00 2001 From: Matt Wong <156021403+mawong-amd@users.noreply.github.com> Date: Fri, 7 Jun 2024 15:52:46 -0500 Subject: [PATCH] Linting main in line with upstream requirements (#43) * Linting * Fix linting for triton: unmeld if with constexpr --- .../kernels/benchmark_paged_attention.py | 3 +- csrc/custom/custom.cu | 137 +- csrc/custom/custom_kernels.cu | 604 ++++---- csrc/custom/fused_kernels.cu | 343 ++--- .../custom/paged_attention/attention_ll4mi.cu | 1335 +++++++++-------- csrc/ops.h | 26 +- csrc/pybind.cpp | 4 +- csrc/quantization/fp8/amd/gemm_kernel.cu | 514 ++++--- csrc/quantization/fp8/amd/quant_utils.cuh | 491 +++--- csrc/quantization/fp8/common.cu | 156 +- vllm/attention/backends/rocm_flash_attn.py | 91 +- vllm/attention/ops/triton_flash_attention.py | 295 ++-- vllm/distributed/communication_op.py | 3 +- vllm/distributed/parallel_state.py | 3 +- vllm/model_executor/layers/linear.py | 2 - .../layers/quantization/__init__.py | 2 +- .../layers/quantization/fp8_rocm.py | 123 +- vllm/model_executor/model_loader/loader.py | 9 +- vllm/model_executor/models/llama.py | 14 +- 19 files changed, 2197 insertions(+), 1958 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 0fcfc0a295ca2..d0d990410bc6e 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -81,8 +81,7 @@ def main( if not args.custom_paged_attn: global PARTITION_SIZE PARTITION_SIZE = 512 - num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // - PARTITION_SIZE) + num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) tmp_output = torch.empty( size=(num_seqs, num_query_heads, num_partitions, head_size), dtype=output.dtype, diff --git a/csrc/custom/custom.cu b/csrc/custom/custom.cu index d75b2d2e41005..3da25ece3e87c 100644 --- a/csrc/custom/custom.cu +++ b/csrc/custom/custom.cu @@ -6,94 +6,89 @@ namespace py = pybind11; // declare templates for front (cpp) and back (cuda) sides of function: -//template - -void LLGemm_Silu(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream, const int rows_per_block); -void LLMM_Silu(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int rows_per_block) { - int M = in_a.size(0); - int K = in_a.size(1); - LLGemm_Silu(in_a.data_ptr(), in_b.data_ptr(), - out_c.data_ptr(), M, K, at::cuda::getCurrentCUDAStream(),rows_per_block); +// template + +void LLGemm_Silu(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int rows_per_block); +void LLMM_Silu(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, + const int rows_per_block) { + int M = in_a.size(0); + int K = in_a.size(1); + LLGemm_Silu(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, + at::cuda::getCurrentCUDAStream(), rows_per_block); } -void LLGemm1(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream,const int rows_per_block); - -//template -void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int rows_per_block=4) { - int M = in_a.size(0); - int K = in_a.size(1); - //if (N != in_b.numel()) - // throw std::invalid_argument("Size mismatch A.numel(): " + std::to_string(in_a.numel()) - // + ", B.numel(): " + std::to_string(in_b.numel())); - - //out_c.resize_({N}); - - // call the kernel function... 
- LLGemm1(in_a.data_ptr(), in_b.data_ptr(), - out_c.data_ptr(), M, K, at::cuda::getCurrentCUDAStream(),rows_per_block); +void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int rows_per_block); + +// template +void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, + const int rows_per_block = 4) { + int M = in_a.size(0); + int K = in_a.size(1); + // if (N != in_b.numel()) + // throw std::invalid_argument("Size mismatch A.numel(): " + + // std::to_string(in_a.numel()) + // + ", B.numel(): " + + // std::to_string(in_b.numel())); + + // out_c.resize_({N}); + + // call the kernel function... + LLGemm1(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, + at::cuda::getCurrentCUDAStream(), rows_per_block); } -void LLGemmZZ(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream, const int solidx); +void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int solidx); -void LLZZ(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int solidx=0) { - int M = in_a.size(0); - int K = in_a.size(1); +void LLZZ(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, + const int solidx = 0) { + int M = in_a.size(0); + int K = in_a.size(1); - LLGemmZZ(in_a.data_ptr(), in_b.data_ptr(), - out_c.data_ptr(), M, K, at::cuda::getCurrentCUDAStream(),solidx); + LLGemmZZ(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, + at::cuda::getCurrentCUDAStream(), solidx); } // instantiate the CPP template for T=float: -//template void AddGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c); - - -void MMGPUKernel(float *in_a, float *in_b, float *out_c, - int numARows, int numAColumns, - int numBRows, int numBColumns, - int numCRows, int numCColumns, - cudaStream_t stream); +// template void AddGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor +// out_c); +void MMGPUKernel(float* in_a, float* in_b, float* out_c, int numARows, + int numAColumns, int numBRows, int numBColumns, int numCRows, + int numCColumns, cudaStream_t stream); void MMCustomGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c) { - auto matA_sizes { in_a.sizes() }; - auto matB_sizes { in_b.sizes() }; - auto matO_sizes { out_c.sizes() }; - MMGPUKernel(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), - matA_sizes[0], matA_sizes[1], - matB_sizes[0], matB_sizes[1], - matO_sizes[0], matO_sizes[1], - at::cuda::getCurrentCUDAStream()); + auto matA_sizes{in_a.sizes()}; + auto matB_sizes{in_b.sizes()}; + auto matO_sizes{out_c.sizes()}; + MMGPUKernel(in_a.data_ptr(), in_b.data_ptr(), + out_c.data_ptr(), matA_sizes[0], matA_sizes[1], + matB_sizes[0], matB_sizes[1], matO_sizes[0], matO_sizes[1], + at::cuda::getCurrentCUDAStream()); } -void paged_attention_custom( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int block_size, - int max_context_len, +void paged_attention_custom(torch::Tensor& out, torch::Tensor& exp_sums, + torch::Tensor& max_logits, torch::Tensor& tmp_out, + torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, + float scale, torch::Tensor& block_tables, + torch::Tensor& context_lens, int block_size, + int max_context_len, #if 0 torch::Tensor& qk_out, torch::Tensor& softmax_out, #endif - const c10::optional& alibi_slopes, 
- const std::string& kv_cache_dtype); + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype); // declare the extension module with the AddGPU function: -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ - m.doc() = "pybind11 example plugin"; - m.def("LLMM1", &LLMM1); - m.def("LLMM_Silu", &LLMM_Silu); - m.def("LLZZ", &LLZZ); - m.def( - "paged_attention_custom", - &paged_attention_custom, - "PagedAttention LL4Mi Custom."); -//m.def("MMCustomGPU", &MMCustomGPU); +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "pybind11 example plugin"; + m.def("LLMM1", &LLMM1); + m.def("LLMM_Silu", &LLMM_Silu); + m.def("LLZZ", &LLZZ); + m.def("paged_attention_custom", &paged_attention_custom, + "PagedAttention LL4Mi Custom."); + // m.def("MMCustomGPU", &MMCustomGPU); } diff --git a/csrc/custom/custom_kernels.cu b/csrc/custom/custom_kernels.cu index b5ab0dbe8317c..6321f7ba23b3f 100644 --- a/csrc/custom/custom_kernels.cu +++ b/csrc/custom/custom_kernels.cu @@ -7,361 +7,355 @@ constexpr int WARP_SIZE = 64; template __device__ __forceinline__ T loadnt(T* addr) { - return __builtin_nontemporal_load(addr); + return __builtin_nontemporal_load(addr); } __device__ __forceinline__ float4 load_ntmprl(const float4* addr) { - auto addr_alias = reinterpret_cast(addr); - auto dat0 = loadnt(addr_alias); - auto dat1 = loadnt(addr_alias + 1); - auto dat2 = loadnt(addr_alias + 2); - auto dat3 = loadnt(addr_alias + 3); - //auto dat0 = *(addr_alias); - //auto dat1 = *(addr_alias+1); - //auto dat2 = *(addr_alias+2); - //auto dat3 = *(addr_alias+3); - return make_float4(dat0,dat1,dat2,dat3); + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + // auto dat0 = *(addr_alias); + // auto dat1 = *(addr_alias+1); + // auto dat2 = *(addr_alias+2); + // auto dat3 = *(addr_alias+3); + return make_float4(dat0, dat1, dat2, dat3); } -//TBlock fetches entire rows of A, and entire col of B (K dimension); assume N=1 for time being -//grid is M/A_NUM_ROWS blocks +// TBlock fetches entire rows of A, and entire col of B (K dimension); assume +// N=1 for time being grid is M/A_NUM_ROWS blocks template -__global__ void LLGemm1_kernel(float4 *af4, __half2 *bf4, __half2 *c) { - __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE]; - const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK * blockDim.x; - //int row_addr_1 = row_addr + CUDA_NUM_THREADS; - //int row_addr_2 = row_addr_1 + CUDA_NUM_THREADS; - //int row_addr_3 = row_addr_2 + CUDA_NUM_THREADS; - const int threadid = threadIdx.x; - const int warp = threadIdx.x / WARP_SIZE; - const int lane = threadIdx.x % WARP_SIZE; - const int num_warps = blockDim.x / WARP_SIZE; - const int qwarpid = threadid/16; - const int qthreadid = threadid%16; - float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; - //float4 colB_elem4; - __half2 colB_elem4x,colB_elem4y,colB_elem4z,colB_elem4w; - float4 sum4; //[NUM_A_ROWS_PER_BLOCK]; - float acc[NUM_A_ROWS_PER_BLOCK]; //= 0.0; - __half2 acch2; - __half2 oval; - - //rowA_elem4 = af4[row_addr + threadid]; - //__syncthreads(); - //rowA_elem4_1 = af4[row_addr_1 + threadid]; - //rowA_elem4_2 = af4[row_addr_2 + threadid]; - //rowA_elem4_3 = af4[row_addr_3 + threadid]; - #pragma unroll - for (int i=0; i(&colB_elem4); - //auto Bf2x = *Bh2ptr; - //auto Bf2y = *(Bh2ptr+1); - //auto Bf2z = *(Bh2ptr+2); - //auto Bf2w = *(Bh2ptr+3); - auto Ah2ptr = reinterpret_cast<__half2 *>(&rowA_elem4); - __half2 *ah2lptr; - #pragma unroll - for 
(int i=0; i= 1; mask /= 2) { - #pragma unroll - for (int i=0; i(&colB_elem4); + // auto Bf2x = *Bh2ptr; + // auto Bf2y = *(Bh2ptr+1); + // auto Bf2z = *(Bh2ptr+2); + // auto Bf2w = *(Bh2ptr+3); + auto Ah2ptr = reinterpret_cast<__half2*>(&rowA_elem4); + __half2* ah2lptr; +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + ah2lptr = Ah2ptr + i * 4; + Af2 = *(ah2lptr); + acch2 = __hmul2(Af2, colB_elem4x); + Af2 = *(ah2lptr + 1); + acch2 = __hfma2(Af2, colB_elem4y, acch2); + Af2 = *(ah2lptr + 2); + acch2 = __hfma2(Af2, colB_elem4z, acch2); + Af2 = *(ah2lptr + 3); + acch2 = __hfma2(Af2, colB_elem4w, acch2); + S = __half22float2(acch2); + acc[i] = S.x + S.y; + } - ////if (qthreadid= 1; mask /= 2) { - //#pragma unroll - //for (int i=0; i= 1; mask /= 2) { +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + acc[i] += __shfl_xor(acc[i], mask); + } + } + // Warp leaders store the data to shared memory. + // if (lane == 0) { + // #pragma unroll + // for (int i=0; i8) { - // #pragma unroll - // for (int j=0; j<8; j++) { - // acc[2*threadid] += red_smem[2*threadid][j]; - // acc[2*threadid+1] += red_smem[2*threadid+1][j]; - // } - // } - // #pragma unroll - // for (int j=0; j= 1; mask /= 2) { + // #pragma unroll + // for (int i=0; i8) { + // #pragma unroll + // for (int j=0; j<8; j++) { + // acc[2*threadid] += red_smem[2*threadid][j]; + // acc[2*threadid+1] += red_smem[2*threadid+1][j]; + // } + // } + // #pragma unroll + // for (int j=0; j -void LLGemm1(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream, const int rows_per_block=4) { - float4 *af4 = reinterpret_cast(in_a); - auto *bf4 = reinterpret_cast<__half2*>(in_b); - auto *c = reinterpret_cast<__half2*>(out_c); - //constexpr int A_ROWS_PER_BLOCK = 8; - const int NUM_THREADS = K*2/16; - int NUM_BLOCKS = M/rows_per_block; - if (rows_per_block==2) { - LLGemm1_kernel<2><<>>(af4, bf4, c); - } - else if (rows_per_block==4) { - LLGemm1_kernel<4><<>>(af4, bf4, c); - } - else if (rows_per_block==8) { - LLGemm1_kernel<8><<>>(af4, bf4, c); - } - else if (rows_per_block==16) { - LLGemm1_kernel<16><<>>(af4, bf4, c); - } - else { - NUM_BLOCKS = M/4; - LLGemm1_kernel<4><<>>(af4, bf4, c); - } - +// template +void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int rows_per_block = 4) { + float4* af4 = reinterpret_cast(in_a); + auto* bf4 = reinterpret_cast<__half2*>(in_b); + auto* c = reinterpret_cast<__half2*>(out_c); + // constexpr int A_ROWS_PER_BLOCK = 8; + const int NUM_THREADS = K * 2 / 16; + int NUM_BLOCKS = M / rows_per_block; + if (rows_per_block == 2) { + LLGemm1_kernel<2><<>>(af4, bf4, c); + } else if (rows_per_block == 4) { + LLGemm1_kernel<4><<>>(af4, bf4, c); + } else if (rows_per_block == 8) { + LLGemm1_kernel<8><<>>(af4, bf4, c); + } else if (rows_per_block == 16) { + LLGemm1_kernel<16><<>>(af4, bf4, c); + } else { + NUM_BLOCKS = M / 4; + LLGemm1_kernel<4><<>>(af4, bf4, c); + } - cudaError_t err = cudaGetLastError(); - if (cudaSuccess != err) - throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); } // instantiate the kernel template for T=float: -//template void AddGPUKernel(float *in_a, float *in_b, float *out_c, const int M, const int K, cudaStream_t stream); +// template void AddGPUKernel(float *in_a, float *in_b, float *out_c, +// const int M, const int K, cudaStream_t 
stream); const unsigned int TILE_WIDTH = 32; // Compute C = A * B -__global__ void matrixMultiplyShared(float *A, float *B, float *C, - int numARows, int numAColumns, - int numBRows, int numBColumns, - int numCRows, int numCColumns) { - __shared__ float sA[TILE_WIDTH][TILE_WIDTH]; // Tile size of 32x32 - __shared__ float sB[TILE_WIDTH][TILE_WIDTH]; - - int Row = blockDim.y * blockIdx.y + threadIdx.y; - int Col = blockDim.x * blockIdx.x + threadIdx.x; - float Cvalue = 0.0; - sA[threadIdx.y][threadIdx.x] = 0.0; - sB[threadIdx.y][threadIdx.x] = 0.0; - - for (int ph = 0; ph < (((numAColumns - 1) / TILE_WIDTH) + 1); ph++) { - if ((Row < numARows) && (threadIdx.x + (ph * TILE_WIDTH)) < numAColumns) { - sA[threadIdx.y][threadIdx.x] = A[(Row * numAColumns) + threadIdx.x + (ph * TILE_WIDTH)]; - } else { - sA[threadIdx.y][threadIdx.x] = 0.0; - } - if (Col < numBColumns && (threadIdx.y + ph * TILE_WIDTH) < numBRows) { - sB[threadIdx.y][threadIdx.x] = B[(threadIdx.y + ph * TILE_WIDTH) * numBColumns + Col]; - } else { - sB[threadIdx.y][threadIdx.x] = 0.0; - } - __syncthreads(); - for (int j = 0; j < TILE_WIDTH; ++j) { - Cvalue += sA[threadIdx.y][j] * sB[j][threadIdx.x]; - } - } - if (Row < numCRows && Col < numCColumns) { - C[Row * numCColumns + Col] = Cvalue; - } +__global__ void matrixMultiplyShared(float* A, float* B, float* C, int numARows, + int numAColumns, int numBRows, + int numBColumns, int numCRows, + int numCColumns) { + __shared__ float sA[TILE_WIDTH][TILE_WIDTH]; // Tile size of 32x32 + __shared__ float sB[TILE_WIDTH][TILE_WIDTH]; + + int Row = blockDim.y * blockIdx.y + threadIdx.y; + int Col = blockDim.x * blockIdx.x + threadIdx.x; + float Cvalue = 0.0; + sA[threadIdx.y][threadIdx.x] = 0.0; + sB[threadIdx.y][threadIdx.x] = 0.0; + + for (int ph = 0; ph < (((numAColumns - 1) / TILE_WIDTH) + 1); ph++) { + if ((Row < numARows) && (threadIdx.x + (ph * TILE_WIDTH)) < numAColumns) { + sA[threadIdx.y][threadIdx.x] = + A[(Row * numAColumns) + threadIdx.x + (ph * TILE_WIDTH)]; + } else { + sA[threadIdx.y][threadIdx.x] = 0.0; + } + if (Col < numBColumns && (threadIdx.y + ph * TILE_WIDTH) < numBRows) { + sB[threadIdx.y][threadIdx.x] = + B[(threadIdx.y + ph * TILE_WIDTH) * numBColumns + Col]; + } else { + sB[threadIdx.y][threadIdx.x] = 0.0; + } + __syncthreads(); + for (int j = 0; j < TILE_WIDTH; ++j) { + Cvalue += sA[threadIdx.y][j] * sB[j][threadIdx.x]; + } + } + if (Row < numCRows && Col < numCColumns) { + C[Row * numCColumns + Col] = Cvalue; + } } - -void MMGPUKernel(float *in_a, float *in_b, float *out_c, - int numARows, int numAColumns, - int numBRows, int numBColumns, - int numCRows, int numCColumns, - cudaStream_t stream) { - - // Initialize the grid and block dimensions - dim3 dimBlock(TILE_WIDTH, TILE_WIDTH, 1); - dim3 dimGrid((numCColumns / TILE_WIDTH) + 1, (numCRows / TILE_WIDTH) + 1, 1); - //@@ Launch the GPU Kernel here - matrixMultiplyShared <<>> - (in_a, in_b, out_c, numARows, numAColumns, numBRows, numBColumns, numCRows, numCColumns); - - cudaError_t err = cudaGetLastError(); - if (cudaSuccess != err) - throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +void MMGPUKernel(float* in_a, float* in_b, float* out_c, int numARows, + int numAColumns, int numBRows, int numBColumns, int numCRows, + int numCColumns, cudaStream_t stream) { + // Initialize the grid and block dimensions + dim3 dimBlock(TILE_WIDTH, TILE_WIDTH, 1); + dim3 dimGrid((numCColumns / TILE_WIDTH) + 1, (numCRows / TILE_WIDTH) + 1, 1); + //@@ Launch the GPU Kernel here + matrixMultiplyShared<<>>( + in_a, 
in_b, out_c, numARows, numAColumns, numBRows, numBColumns, numCRows, + numCColumns); + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); } - - -template -__global__ -__launch_bounds__(512) -void HGEMV_WFPerRow(int m, int n, const _Float16 *A, int lda, const _Float16 *x, _Float16 *y) -{ +template +__global__ __launch_bounds__(512) void HGEMV_WFPerRow( + int m, int n, const _Float16* A, int lda, const _Float16* x, _Float16* y) { int num_row_per_block = CTA / nThreads_per_row; - int row_id = (blockIdx.x*num_row_per_block+threadIdx.y)*MT0; - int inc = (gridDim.x * num_row_per_block)*MT0; + int row_id = (blockIdx.x * num_row_per_block + threadIdx.y) * MT0; + int inc = (gridDim.x * num_row_per_block) * MT0; while (row_id < m) { float2 sum2[MT0]; #pragma unroll - for (int i = 0; i < MT0; ++i) - { - sum2[i] = {0.0,0.0}; + for (int i = 0; i < MT0; ++i) { + sum2[i] = {0.0, 0.0}; } - for (int j = threadIdx.x; j < n; j += (nThreads_per_row*MT1)){ - bool is_active = j < n; - if (is_active) { - float2 x2[MT1>>1]; + for (int j = threadIdx.x; j < n; j += (nThreads_per_row * MT1)) { + bool is_active = j < n; + if (is_active) { + float2 x2[MT1 >> 1]; #pragma unroll - for(int offset = 0; offset < MT1; offset += 2) - { - x2[offset>>1] = {x[j+nThreads_per_row*offset], x[j+nThreads_per_row*(offset+1)]}; - } - float2 a2[MT0][MT1>>1]; + for (int offset = 0; offset < MT1; offset += 2) { + x2[offset >> 1] = {x[j + nThreads_per_row * offset], + x[j + nThreads_per_row * (offset + 1)]}; + } + float2 a2[MT0][MT1 >> 1]; #pragma unroll - for (int i = 0; i < MT0; i++) - { + for (int i = 0; i < MT0; i++) { #pragma unroll - for (int offset = 0; offset < MT1; offset += 2) - { - a2[i][offset>>1] = {A[(row_id+i)*n+j+nThreads_per_row*offset], A[(row_id+i)*n+j+nThreads_per_row*(offset+1)]}; - } - } + for (int offset = 0; offset < MT1; offset += 2) { + a2[i][offset >> 1] = { + A[(row_id + i) * n + j + nThreads_per_row * offset], + A[(row_id + i) * n + j + nThreads_per_row * (offset + 1)]}; + } + } #pragma unroll - for (int i = 0; i < MT0; i++) - { + for (int i = 0; i < MT0; i++) { #pragma unroll - for (int offset = 0; offset < (MT1>>1); offset++) - { - sum2[i] += a2[i][offset]*x2[offset]; - } - } - + for (int offset = 0; offset < (MT1 >> 1); offset++) { + sum2[i] += a2[i][offset] * x2[offset]; + } } + } } float sum[MT0]; #pragma unroll - for (int i = 0; i < MT0; i++) - { - sum[i] = sum2[i].x+sum2[i].y; + for (int i = 0; i < MT0; i++) { + sum[i] = sum2[i].x + sum2[i].y; } #pragma unroll - for (int i = 0; i < MT0; i++) - { -#pragma unroll - for (int offset = nThreads_per_row >> 1; offset >= 1; offset = offset >> 1) { - sum[i] += __shfl_down(sum[i], offset, nThreads_per_row); - } + for (int i = 0; i < MT0; i++) { +#pragma unroll + for (int offset = nThreads_per_row >> 1; offset >= 1; + offset = offset >> 1) { + sum[i] += __shfl_down(sum[i], offset, nThreads_per_row); + } } - if (threadIdx.x == 0) - { + if (threadIdx.x == 0) { #pragma unroll - for (int i = 0; i < MT0; i++) - { - y[row_id+i] = sum[i]; - } + for (int i = 0; i < MT0; i++) { + y[row_id + i] = sum[i]; + } } row_id += inc; } } -void LLGemmZZ(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream, const int solidx=0) { - //m -> M, n-> K - dim3 grid(1024); - dim3 block(64, 8); - if (solidx==0) { - HGEMV_WFPerRow<64, 512, 4, 8><<>>(M, K, reinterpret_cast(in_a), K, - reinterpret_cast(in_b),reinterpret_cast<_Float16*>(out_c)); - } - else if (solidx==1) { - 
HGEMV_WFPerRow<64, 512, 2, 8><<>>(M, K, reinterpret_cast(in_a), K, - reinterpret_cast(in_b),reinterpret_cast<_Float16*>(out_c)); - } - else if (solidx==2) { - HGEMV_WFPerRow<64, 512, 1, 8><<>>(M, K, reinterpret_cast(in_a), K, - reinterpret_cast(in_b),reinterpret_cast<_Float16*>(out_c)); - } - else { - HGEMV_WFPerRow<64, 512, 4, 8><<>>(M, K, reinterpret_cast(in_a), K, - reinterpret_cast(in_b),reinterpret_cast<_Float16*>(out_c)); - } - cudaError_t err = cudaGetLastError(); - if (cudaSuccess != err) - throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int solidx = 0) { + // m -> M, n-> K + dim3 grid(1024); + dim3 block(64, 8); + if (solidx == 0) { + HGEMV_WFPerRow<64, 512, 4, 8><<>>( + M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b), + reinterpret_cast<_Float16*>(out_c)); + } else if (solidx == 1) { + HGEMV_WFPerRow<64, 512, 2, 8><<>>( + M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b), + reinterpret_cast<_Float16*>(out_c)); + } else if (solidx == 2) { + HGEMV_WFPerRow<64, 512, 1, 8><<>>( + M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b), + reinterpret_cast<_Float16*>(out_c)); + } else { + HGEMV_WFPerRow<64, 512, 4, 8><<>>( + M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b), + reinterpret_cast<_Float16*>(out_c)); + } + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); } diff --git a/csrc/custom/fused_kernels.cu b/csrc/custom/fused_kernels.cu index 5a4a11f914eb9..4f3eea4562949 100644 --- a/csrc/custom/fused_kernels.cu +++ b/csrc/custom/fused_kernels.cu @@ -5,188 +5,191 @@ constexpr int WARP_SIZE = 64; -template +template __device__ __forceinline__ T silu(const T& x) { // x * sigmoid(x) - return (T) (((float) x) / (1.0f + expf((float) -x))); + return (T)(((float)x) / (1.0f + expf((float)-x))); } template __device__ __forceinline__ T loadnt(T* addr) { - return __builtin_nontemporal_load(addr); + return __builtin_nontemporal_load(addr); } __device__ __forceinline__ float4 load_ntmprl(const float4* addr) { - auto addr_alias = reinterpret_cast(addr); - auto dat0 = loadnt(addr_alias); - auto dat1 = loadnt(addr_alias + 1); - auto dat2 = loadnt(addr_alias + 2); - auto dat3 = loadnt(addr_alias + 3); - //auto dat0 = *(addr_alias); - //auto dat1 = *(addr_alias+1); - //auto dat2 = *(addr_alias+2); - //auto dat3 = *(addr_alias+3); - return make_float4(dat0,dat1,dat2,dat3); + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + // auto dat0 = *(addr_alias); + // auto dat1 = *(addr_alias+1); + // auto dat2 = *(addr_alias+2); + // auto dat3 = *(addr_alias+3); + return make_float4(dat0, dat1, dat2, dat3); } -//TBlock fetches entire rows of A, and entire col of B (K dimension); assume N=1 for time being -//grid is M/A_NUM_ROWS blocks +// TBlock fetches entire rows of A, and entire col of B (K dimension); assume +// N=1 for time being grid is M/A_NUM_ROWS blocks template -__global__ void LLGemm_Silu_kernel(float4 *af4, __half2 *bf4, _Float16 *c, const int d) { - __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE]; - const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK/2 * blockDim.x; - const int row_addr_d = row_addr + d * blockDim.x; - //int row_addr_1 = row_addr + CUDA_NUM_THREADS; - //int row_addr_2 = row_addr_1 
+ CUDA_NUM_THREADS; - //int row_addr_3 = row_addr_2 + CUDA_NUM_THREADS; - const int threadid = threadIdx.x; - const int warp = threadIdx.x / WARP_SIZE; - const int lane = threadIdx.x % WARP_SIZE; - const int num_warps = blockDim.x / WARP_SIZE; - const int qwarpid = threadid/16; - const int qthreadid = threadid%16; - float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; - //float4 colB_elem4; - __half2 colB_elem4x,colB_elem4y,colB_elem4z,colB_elem4w; - float4 sum4; //[NUM_A_ROWS_PER_BLOCK]; - float acc[NUM_A_ROWS_PER_BLOCK]; //= 0.0; - __half2 acch2; - __half2 oval; - - //rowA_elem4 = af4[row_addr + threadid]; - //__syncthreads(); - //rowA_elem4_1 = af4[row_addr_1 + threadid]; - //rowA_elem4_2 = af4[row_addr_2 + threadid]; - //rowA_elem4_3 = af4[row_addr_3 + threadid]; - #pragma unroll - for (int i=0; i(&colB_elem4); - //auto Bf2x = *Bh2ptr; - //auto Bf2y = *(Bh2ptr+1); - //auto Bf2z = *(Bh2ptr+2); - //auto Bf2w = *(Bh2ptr+3); - auto Ah2ptr = reinterpret_cast<__half2 *>(&rowA_elem4); - __half2 *ah2lptr; - #pragma unroll - for (int i=0; i= 1; mask /= 2) { - #pragma unroll - for (int i=0; i= 1; mask /= 2) { - //#pragma unroll - //for (int i=0; i(&colB_elem4); + // auto Bf2x = *Bh2ptr; + // auto Bf2y = *(Bh2ptr+1); + // auto Bf2z = *(Bh2ptr+2); + // auto Bf2w = *(Bh2ptr+3); + auto Ah2ptr = reinterpret_cast<__half2*>(&rowA_elem4); + __half2* ah2lptr; +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + ah2lptr = Ah2ptr + i * 4; + Af2 = *(ah2lptr); + acch2 = __hmul2(Af2, colB_elem4x); + Af2 = *(ah2lptr + 1); + acch2 = __hfma2(Af2, colB_elem4y, acch2); + Af2 = *(ah2lptr + 2); + acch2 = __hfma2(Af2, colB_elem4z, acch2); + Af2 = *(ah2lptr + 3); + acch2 = __hfma2(Af2, colB_elem4w, acch2); + S = __half22float2(acch2); + acc[i] = S.x + S.y; + } + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + acc[i] += __shfl_xor(acc[i], mask); + } + } + + // Warp leaders store the data to shared memory. 
+ // if (lane == 0) { + // #pragma unroll + // for (int i=0; i= 1; mask /= 2) { + // #pragma unroll + // for (int i=0; i -void LLGemm_Silu(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream, const int rows_per_block=4) { - float4 *af4 = reinterpret_cast(in_a); - auto *bf4 = reinterpret_cast<__half2*>(in_b); - auto *c = reinterpret_cast<_Float16*>(out_c); - const int d = M/2; - const int NUM_THREADS = K*2/16; - int NUM_BLOCKS = M/rows_per_block; - if (rows_per_block==2) { - LLGemm_Silu_kernel<2><<>>(af4, bf4, c, d); - } - else if (rows_per_block==4) { - LLGemm_Silu_kernel<4><<>>(af4, bf4, c, d); - } - else if (rows_per_block==8) { - LLGemm_Silu_kernel<8><<>>(af4, bf4, c, d); - } - else if (rows_per_block==16) { - LLGemm_Silu_kernel<16><<>>(af4, bf4, c, d); - } - else { - NUM_BLOCKS = M/4; - LLGemm_Silu_kernel<4><<>>(af4, bf4, c, d); - } - - - cudaError_t err = cudaGetLastError(); - if (cudaSuccess != err) - throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +// template +void LLGemm_Silu(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int rows_per_block = 4) { + float4* af4 = reinterpret_cast(in_a); + auto* bf4 = reinterpret_cast<__half2*>(in_b); + auto* c = reinterpret_cast<_Float16*>(out_c); + const int d = M / 2; + const int NUM_THREADS = K * 2 / 16; + int NUM_BLOCKS = M / rows_per_block; + if (rows_per_block == 2) { + LLGemm_Silu_kernel<2> + <<>>(af4, bf4, c, d); + } else if (rows_per_block == 4) { + LLGemm_Silu_kernel<4> + <<>>(af4, bf4, c, d); + } else if (rows_per_block == 8) { + LLGemm_Silu_kernel<8> + <<>>(af4, bf4, c, d); + } else if (rows_per_block == 16) { + LLGemm_Silu_kernel<16> + <<>>(af4, bf4, c, d); + } else { + NUM_BLOCKS = M / 4; + LLGemm_Silu_kernel<4> + <<>>(af4, bf4, c, d); + } + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); } - diff --git a/csrc/custom/paged_attention/attention_ll4mi.cu b/csrc/custom/paged_attention/attention_ll4mi.cu index 6c9e84ab2f5f4..dcabc7932cfd5 100644 --- a/csrc/custom/paged_attention/attention_ll4mi.cu +++ b/csrc/custom/paged_attention/attention_ll4mi.cu @@ -1,4 +1,4 @@ -//TODO: add license terms +// TODO: add license terms #include #include #include @@ -14,9 +14,12 @@ #define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16 using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; -using float16x4 = __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16; +using float16x4 = + __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16; typedef float16x4 _Half4; -typedef struct _Half8 { _Half4 xy[2]; } _Half8; +typedef struct _Half8 { + _Half4 xy[2]; +} _Half8; ////// Non temporal load stores /////// #if 1 @@ -39,68 +42,62 @@ __device__ __forceinline__ T load(const T* addr) { } template <> -__device__ __forceinline__ -float2 load (const float2* addr) { - auto addr_alias { reinterpret_cast(addr) }; +__device__ __forceinline__ float2 load(const float2* addr) { + auto addr_alias{reinterpret_cast(addr)}; auto result = __builtin_nontemporal_load(addr_alias); - auto ret = reinterpret_cast(&result); + auto ret = reinterpret_cast(&result); return ret[0]; } template <> -__device__ __forceinline__ -float4 load (const float4* addr) { - auto addr_alias { reinterpret_cast(addr) }; +__device__ __forceinline__ float4 load(const float4* addr) { + auto addr_alias{reinterpret_cast(addr)}; auto result1 = 
__builtin_nontemporal_load(addr_alias); auto result2 = __builtin_nontemporal_load(addr_alias + 1); float4 ret{}; - auto ret_alias = reinterpret_cast(&result1); + auto ret_alias = reinterpret_cast(&result1); ret.x = ret_alias->x; ret.y = ret_alias->y; - ret_alias = reinterpret_cast(&result2); + ret_alias = reinterpret_cast(&result2); ret.z = ret_alias->x; ret.w = ret_alias->y; return ret; } template <> -__device__ __forceinline__ -__half load (const __half* addr) { - auto addr_alias { reinterpret_cast(addr) }; +__device__ __forceinline__ __half load(const __half* addr) { + auto addr_alias{reinterpret_cast(addr)}; auto result = __builtin_nontemporal_load(addr_alias); - auto ret = reinterpret_cast<__half *>(&result); + auto ret = reinterpret_cast<__half*>(&result); return ret[0]; } template <> -__device__ __forceinline__ -__half2 load (const __half2* addr) { - auto addr_alias { reinterpret_cast(addr) }; +__device__ __forceinline__ __half2 load(const __half2* addr) { + auto addr_alias{reinterpret_cast(addr)}; auto result = __builtin_nontemporal_load(addr_alias); - auto ret = reinterpret_cast<__half2 *>(&result); + auto ret = reinterpret_cast<__half2*>(&result); return ret[0]; } template <> -__device__ __forceinline__ -vllm::Half4_ load (const vllm::Half4_* addr) { - auto addr_alias { reinterpret_cast(addr) }; +__device__ __forceinline__ vllm::Half4_ load(const vllm::Half4_* addr) { + auto addr_alias{reinterpret_cast(addr)}; auto result = __builtin_nontemporal_load(addr_alias); - auto ret = reinterpret_cast(&result); + auto ret = reinterpret_cast(&result); return ret[0]; } template <> -__device__ __forceinline__ -vllm::Half8_ load (const vllm::Half8_* addr) { - auto addr_alias { reinterpret_cast(addr) }; +__device__ __forceinline__ vllm::Half8_ load(const vllm::Half8_* addr) { + auto addr_alias{reinterpret_cast(addr)}; auto result1 = __builtin_nontemporal_load(addr_alias); auto result2 = __builtin_nontemporal_load(addr_alias + 1); - vllm::Half8_ ret {}; - auto ret_alias = reinterpret_cast(&result1); + vllm::Half8_ ret{}; + auto ret_alias = reinterpret_cast(&result1); ret.x = ret_alias->x; ret.y = ret_alias->y; - ret_alias = reinterpret_cast(&result2); + ret_alias = reinterpret_cast(&result2); ret.z = ret_alias->x; ret.w = ret_alias->y; return ret; @@ -116,394 +113,456 @@ __device__ __forceinline__ void store(T value, T* addr) { /////////////////////////////////////// -//grid (num_seqs, num_partitions,num_heads/gqa_ratio) -//block (partition size) -template +// grid (num_seqs, num_partitions,num_heads/gqa_ratio) +// block (partition size) +template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] - scalar_t* __restrict__ 
final_out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] #if 0 scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] #endif - int max_ctx_blocks - ) { - constexpr int NWARPS = NUM_THREADS/WARP_SIZE; - const int warpid = threadIdx.x / WARP_SIZE; - const int laneid = threadIdx.x % WARP_SIZE; - const int lane4id = laneid%4; - - const int seq_idx = blockIdx.x; - const int partition_idx = blockIdx.y; - const int partition_size = blockDim.x; - const int max_num_partitions = gridDim.y; - - const int context_len = context_lens[seq_idx]; - const int partition_start_token_idx = partition_idx * partition_size; - //exit if partition is out of context for seq - if (partition_start_token_idx >= context_len) { - return; - } - constexpr int QHLOOP = DIVIDE_ROUND_UP(GQA_RATIO,4); // each 4 lanes fetch 4 different qheads, total qheads =8, so qhloop is 2 - constexpr int GQA_RATIO4 = 4*QHLOOP; - __shared__ float shared_qk_max[NWARPS][GQA_RATIO4+1]; - __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4+1]; - _Half8 Qlocal[QHLOOP]; - constexpr int x = 16 / sizeof(scalar_t); - constexpr int KHELOOP = HEAD_SIZE/x; - _Half8 Klocal[KHELOOP]; - constexpr int VHELOOP = HEAD_SIZE/WARP_SIZE; //v head_size dimension is distributed across lanes - constexpr int VTLOOP = 8; //16 separate 4xtokens across warp -> 16/2 8xtokens - _Half8 Vlocal[VHELOOP][VTLOOP]; - floatx4 dout[QHLOOP]; - float qk_max[QHLOOP]; - #pragma unroll - for (int h=0; h= context_len) { + return; + } + constexpr int QHLOOP = + DIVIDE_ROUND_UP(GQA_RATIO, 4); // each 4 lanes fetch 4 different qheads, + // total qheads =8, so qhloop is 2 + constexpr int GQA_RATIO4 = 4 * QHLOOP; + __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1]; + __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1]; + _Half8 Qlocal[QHLOOP]; + constexpr int x = 16 / sizeof(scalar_t); + constexpr int KHELOOP = HEAD_SIZE / x; + _Half8 Klocal[KHELOOP]; + constexpr int VHELOOP = + HEAD_SIZE / + WARP_SIZE; // v head_size dimension is distributed across lanes + constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2 + // 8xtokens + _Half8 Vlocal[VHELOOP][VTLOOP]; + floatx4 dout[QHLOOP]; + float qk_max[QHLOOP]; +#pragma unroll + for (int h = 0; h < QHLOOP; h++) { + dout[h] = {0}; + qk_max[h] = -FLT_MAX; + } - const int warp_start_token_idx = partition_start_token_idx + warpid*WARP_SIZE; + const int wg_start_head_idx = blockIdx.z * GQA_RATIO; + const int wg_start_kv_head_idx = blockIdx.z; - if (warp_start_token_idx >= context_len) { //warp out of context - #pragma unroll - for(int h=0;h= context_len) { // warp out of context +#pragma 
unroll + for (int h = 0; h < GQA_RATIO4; h++) { + shared_qk_max[warpid][h] = -FLT_MAX; + shared_exp_sum[warpid][h] = 0.0f; + } + } else { // warp within context + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int last_ctx_block = num_context_blocks - 1; + + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + + const int local_token_idx = threadIdx.x; + const int global_token_idx = partition_start_token_idx + local_token_idx; + + const int block_idx = (global_token_idx < context_len) + ? global_token_idx / BLOCK_SIZE + : last_ctx_block; + + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + + // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems + const scalar_t* q_ptr = + q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE; + const _Half8* q_ptrh8 = reinterpret_cast(q_ptr); + const int qhead_elemh8 = laneid / 4; +#pragma unroll + for (int h = 0; h < QHLOOP - 1; h++) { + const int qhead_idx = h * 4 + lane4id; + Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8]; + } + const int final_qhead_idx = 4 * (QHLOOP - 1) + lane4id; + if (final_qhead_idx < GQA_RATIO) { + Qlocal[QHLOOP - 1] = + q_ptrh8[final_qhead_idx * HEAD_SIZE / 8 + qhead_elemh8]; + } else { + Qlocal[QHLOOP - 1].xy[0] = {0}; + Qlocal[QHLOOP - 1].xy[1] = {0}; + } - const int local_token_idx = threadIdx.x; - const int global_token_idx = partition_start_token_idx + local_token_idx; + const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + wg_start_kv_head_idx * kv_head_stride; + const _Half8* k_ptrh8 = reinterpret_cast(k_ptr); - const int block_idx = (global_token_idx < context_len) ? 
global_token_idx / BLOCK_SIZE : last_ctx_block; + const int physical_block_offset = + local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset + // is already cast as _H8 - //int32 physical_block_number leads to overflow when multiplied with kv_block_stride - const int64_t physical_block_number = static_cast(block_table[block_idx]); +#pragma unroll + for (int d = 0; d < KHELOOP; d++) { + Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; + } - //each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems - const scalar_t* q_ptr = q + seq_idx*q_stride + wg_start_head_idx*HEAD_SIZE; - const _Half8* q_ptrh8 = reinterpret_cast(q_ptr); - const int qhead_elemh8 = laneid/4; - #pragma unroll - for (int h=0; h(k_ptr); - - const int physical_block_offset = local_token_idx%BLOCK_SIZE; //since x=half8, physical_block_offset is already cast as _H8 - + } - #pragma unroll - for (int d=0;d(v_ptr); +// iterate over each v block +#pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t vphysical_block_number = + static_cast(vphysical_blocks[b]); + const _Half8* v_ptrh8b = + v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; +// iterate over each head elem (within head_size) +#pragma unroll + for (int h = 0; h < VHELOOP; h++) { + const int head_size_elem = h * WARP_SIZE + laneid; + const _Half8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; +// iterate over all velems within block +#pragma unroll + for (int d = 0; d < BLOCK_SIZE / 8; d++) { + Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; } } + } - constexpr int VBLOCKS=8*VTLOOP/BLOCK_SIZE; - int vphysical_blocks[VBLOCKS]; - - const int warp_start_block_idx = warp_start_token_idx/BLOCK_SIZE; - //fetch vphysical block numbers - #pragma unroll - for (int b=0;b 8) { + dout[h] = + GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[8].xy[0], dout[h], 4, 8, 0); + dout[h] = + GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[8].xy[1], dout[h], 4, 8, 0); + dout[h] = + GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[9].xy[0], dout[h], 4, 9, 0); + dout[h] = + GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[9].xy[1], dout[h], 4, 9, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[10].xy[0], dout[h], 4, + 10, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[10].xy[1], dout[h], 4, + 10, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[11].xy[0], dout[h], 4, + 11, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[11].xy[1], dout[h], 4, + 11, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[12].xy[0], dout[h], 4, + 12, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[12].xy[1], dout[h], 4, + 12, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[13].xy[0], dout[h], 4, + 13, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[13].xy[1], dout[h], 4, + 13, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[14].xy[0], dout[h], 4, + 14, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[14].xy[1], dout[h], 4, + 14, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[15].xy[0], dout[h], 4, + 15, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[15].xy[1], dout[h], 4, + 15, 0); + } // KHELOOP>8 + dout[h] *= scale; + } +// transpose dout so that 4 token ids are in each lane, and 4 heads are across 4 +// lanes +#pragma unroll + for (int h = 0; h < QHLOOP; h++) { + floatx4 tmp = {0}; +#pragma unroll + for (int i = 0; i < 4; i++) { + const float B = (lane4id == i) ? 
1.0f : 0.0f; + // const float A = (global_token_idx < context_len) ? dout[h][i] : 0.0f; + tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(dout[h][i], B, tmp, 0, 0, 0); + // tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0); } + dout[h] = tmp; + } - const scalar_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; - const _Half8* v_ptrh8 = reinterpret_cast(v_ptr); - //iterate over each v block - #pragma unroll - for (int b=0;b(vphysical_blocks[b]); - const _Half8* v_ptrh8b = v_ptrh8 + (vphysical_block_number * kv_block_stride)/8; - //iterate over each head elem (within head_size) - #pragma unroll - for (int h=0;h> 2); + const int alibi_offset = lane4_token_idx - context_len + 1; + if (alibi_slopes != nullptr) { +#pragma unroll + for (int h = 0; h < QHLOOP; h++) { +#pragma unroll + for (int i = 0; i < 4; i++) { + dout[h][i] += alibi_slope[h] * (alibi_offset + i); } } + } - #pragma unroll - for (int h=0;h8) { - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[8].xy[0], dout[h], 4, 8, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[8].xy[1], dout[h], 4, 8, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[9].xy[0], dout[h], 4, 9, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[9].xy[1], dout[h], 4, 9, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[10].xy[0], dout[h], 4, 10, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[10].xy[1], dout[h], 4, 10, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[11].xy[0], dout[h], 4, 11, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[11].xy[1], dout[h], 4, 11, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[12].xy[0], dout[h], 4, 12, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[12].xy[1], dout[h], 4, 12, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[13].xy[0], dout[h], 4, 13, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[13].xy[1], dout[h], 4, 13, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[14].xy[0], dout[h], 4, 14, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[14].xy[1], dout[h], 4, 14, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[15].xy[0], dout[h], 4, 15, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[15].xy[1], dout[h], 4, 15, 0); - } //KHELOOP>8 - dout[h]*=scale; - } - //transpose dout so that 4 token ids are in each lane, and 4 heads are across 4 lanes - #pragma unroll - for (int h=0;h>2); - const int alibi_offset = lane4_token_idx - context_len + 1; - if (alibi_slopes != nullptr) { - #pragma unroll - for (int h=0;h=4; mask/=2) { - qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h],mask)); - } - } - - float exp_sum[QHLOOP]; - #pragma unroll - for (int h=0;h=4; mask/=2) { - exp_sum[h] += __shfl_xor(exp_sum[h],mask); - } - } +#pragma unroll + for (int h = 0; h < QHLOOP; h++) { + qk_max[h] = -FLT_MAX; +#pragma unroll + for (int i = 0; i < 4; i++) { + qk_max[h] = (lane4_token_idx + i < context_len) + ? fmaxf(qk_max[h], dout[h][i]) + : qk_max[h]; + } +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); + } + } + float exp_sum[QHLOOP]; +#pragma unroll + for (int h = 0; h < QHLOOP; h++) { + exp_sum[h] = 0.0f; +#pragma unroll + for (int i = 0; i < 4; i++) { + dout[h][i] = (lane4_token_idx + i < context_len) + ? 
__expf(dout[h][i] - qk_max[h]) + : 0.0f; + exp_sum[h] += dout[h][i]; + } +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + exp_sum[h] += __shfl_xor(exp_sum[h], mask); + } + } - #pragma unroll - for (int h=0;h every 4 lanes hold 4 heads, each lane holds 4 tokens, there are 4x16 tokens across warp - float16x4 logits[QHLOOP]; - #pragma unroll - for (int h=0;h every 4 lanes hold 4 heads, each lane holds 4 tokens, there + // are 4x16 tokens across warp + float16x4 logits[QHLOOP]; +#pragma unroll + for (int h = 0; h < QHLOOP; h++) { +#pragma unroll + for (int i = 0; i < 4; i++) { + logits[h][i] = (scalar_t)dout[h][i]; + } + } - __shared__ float16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS+1]; + __shared__ float16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; - if (warp_start_token_idx >= context_len) { //warp out of context - #pragma unroll - for (int qh=0; qh= context_len) { // warp out of context +#pragma unroll + for (int qh = 0; qh < QHLOOP; qh++) { +#pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { + vout_shared[qh][vh][laneid][warpid] = {0}; + } } - else{//warp in context - //iterate across heads - #pragma unroll - for (int qh=0; qh partition_size) { + out_num_partitions = max_num_partitions; + out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; + } else { + out_num_partitions = 1; + out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE; } - }//warp in context - - __syncthreads(); - - if (warpid==0) { - float16x4 vout[QHLOOP][VHELOOP]; - //iterate across heads - scalar_t* out_ptr; - int out_num_partitions; - if (context_len > partition_size) { - out_num_partitions = max_num_partitions; - out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + partition_idx * HEAD_SIZE; - } else { - out_num_partitions = 1; - out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE; +#pragma unroll + for (int qh = 0; qh < QHLOOP; qh++) { +// iterate over each v head elem (within head_size) +#pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { + vout[qh][vh] = {0}; +#pragma unroll + for (int w = 0; w < NWARPS; w++) { + vout[qh][vh] += vout_shared[qh][vh][laneid][w]; } - #pragma unroll - for (int qh=0; qh -__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_partitions) { - const int num_heads = gridDim.x; - const int head_idx = blockIdx.x; - const int seq_idx = blockIdx.y; - const int context_len = context_lens[seq_idx]; - const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); - if (num_partitions == 1) { - //if num_partitions==1, main kernel will write to out directly, no work in reduction kernel - return; - } - - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - const int warpid = threadIdx.x / WARP_SIZE; - const int laneid = threadIdx.x % WARP_SIZE; - - __shared__ float shared_global_exp_sum; - __shared__ float shared_exp_sums[2*WARP_SIZE]; - - if (warpid==0) { - - const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; +template +__global__ +__launch_bounds__(NUM_THREADS) void 
paged_attention_ll4mi_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + if (num_partitions == 1) { + // if num_partitions==1, main kernel will write to out directly, no work in + // reduction kernel + return; + } - //valid partition is the last valid partition in case threadid > num partitions - const int valid_partition = (threadIdx.x < num_partitions) ? threadIdx.x : num_partitions-1; - const int valid_partition2 = (WARP_SIZE+threadIdx.x < num_partitions) ? WARP_SIZE+threadIdx.x : num_partitions-1; + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + + __shared__ float shared_global_exp_sum; + __shared__ float shared_exp_sums[2 * WARP_SIZE]; + + if (warpid == 0) { + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + // valid partition is the last valid partition in case threadid > num + // partitions + const int valid_partition = + (threadIdx.x < num_partitions) ? threadIdx.x : num_partitions - 1; + const int valid_partition2 = (WARP_SIZE + threadIdx.x < num_partitions) + ? WARP_SIZE + threadIdx.x + : num_partitions - 1; float reg_max_logit = max_logits_ptr[valid_partition]; float reg_max_logit2 = max_logits_ptr[valid_partition2]; - float max_logit = fmaxf(reg_max_logit,reg_max_logit2); + float max_logit = fmaxf(reg_max_logit, reg_max_logit2); - #pragma unroll +#pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); } - const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; float global_exp_sum = 0.0f; float rescaled_exp_sum = exp_sums_ptr[valid_partition]; float rescaled_exp_sum2 = exp_sums_ptr[valid_partition2]; - rescaled_exp_sum *= (threadIdx.x < num_partitions) ? expf(reg_max_logit - max_logit) : 0.0f; - rescaled_exp_sum2 *= (threadIdx.x+WARP_SIZE < num_partitions) ? expf(reg_max_logit2 - max_logit) : 0.0f; + rescaled_exp_sum *= + (threadIdx.x < num_partitions) ? expf(reg_max_logit - max_logit) : 0.0f; + rescaled_exp_sum2 *= (threadIdx.x + WARP_SIZE < num_partitions) + ? 
expf(reg_max_logit2 - max_logit) + : 0.0f; global_exp_sum += rescaled_exp_sum + rescaled_exp_sum2; shared_exp_sums[threadIdx.x] = rescaled_exp_sum; - shared_exp_sums[threadIdx.x+WARP_SIZE] = rescaled_exp_sum2; + shared_exp_sums[threadIdx.x + WARP_SIZE] = rescaled_exp_sum2; - #pragma unroll +#pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { global_exp_sum += __shfl_xor(global_exp_sum, mask); } - if (threadIdx.x==0) { + if (threadIdx.x == 0) { shared_global_exp_sum = global_exp_sum; } - }//warpid == 0 - const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x; - constexpr int MAX_NPAR = 64; - scalar_t tmps[MAX_NPAR]; - #pragma unroll - for (int j = 0; j < MAX_NPAR; j++) { - tmps[j] = 0.0f; + } // warpid == 0 + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x; + constexpr int MAX_NPAR = 64; + scalar_t tmps[MAX_NPAR]; +#pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + tmps[j] = 0.0f; + } + const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE; + const int num_partition_offset = (num_partitions)*HEAD_SIZE; + int idx = 0; + + constexpr int JCHUNK = 16; + +#pragma unroll + for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + __syncthreads(); + + if (num_partitions > JCHUNK) { +#pragma unroll + for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; } - const int last_partition_offset = (num_partitions-1)*HEAD_SIZE; - const int num_partition_offset = (num_partitions)*HEAD_SIZE; - int idx=0; - constexpr int JCHUNK = 16; - - #pragma unroll - for (int j = 0; j < JCHUNK*HEAD_SIZE; j+=HEAD_SIZE) { - //lastj is last valid partition - const int lastj_offset = (j 2 * JCHUNK) { +#pragma unroll + for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; tmps[idx] = tmp_out_ptr[lastj_offset]; idx++; + } } - __syncthreads(); - - if (num_partitions > JCHUNK) { - #pragma unroll - for (int j = JCHUNK*HEAD_SIZE; j < 2*JCHUNK*HEAD_SIZE; j+=HEAD_SIZE) { - const int lastj_offset = (j JCHUNK - if (num_partitions > 2*JCHUNK) { - #pragma unroll - for (int j = 2*JCHUNK*HEAD_SIZE; j < MAX_NPAR*HEAD_SIZE; j+=HEAD_SIZE) { - const int lastj_offset = (j JCHUNK - - // Aggregate tmp_out to out. - float acc = 0.0f; - #pragma unroll - for (int j = 0; j < JCHUNK; j++) { + // Aggregate tmp_out to out. 
+ float acc = 0.0f; +#pragma unroll + for (int j = 0; j < JCHUNK; j++) { + acc += tmps[j] * shared_exp_sums[j]; + } + if (num_partitions > JCHUNK) { +#pragma unroll + for (int j = JCHUNK; j < 2 * JCHUNK; j++) { acc += tmps[j] * shared_exp_sums[j]; } - if (num_partitions > JCHUNK) { - #pragma unroll - for (int j = JCHUNK; j < 2*JCHUNK; j++) { - acc += tmps[j] * shared_exp_sums[j]; - } - if (num_partitions > 2*JCHUNK) { - #pragma unroll - for (int j = 2*JCHUNK; j < MAX_NPAR; j++) { - acc += tmps[j] * shared_exp_sums[j]; - } - } + if (num_partitions > 2 * JCHUNK) { +#pragma unroll + for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) { + acc += tmps[j] * shared_exp_sums[j]; + } } + } - if (num_partitions > MAX_NPAR) { - idx=0; - #pragma unroll - for (int j = MAX_NPAR*HEAD_SIZE; j < 2*MAX_NPAR*HEAD_SIZE; j+=HEAD_SIZE) { - //lastj is last valid partition - const int lastj_offset = (j MAX_NPAR) { + idx = 0; +#pragma unroll + for (int j = MAX_NPAR * HEAD_SIZE; j < 2 * MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; } - const float inv_global_exp_sum = __fdividef(1.0f, shared_global_exp_sum + 1e-6f); - acc *= inv_global_exp_sum; - //from_float(out_ptr[threadIdx.x], acc); - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - out_ptr[threadIdx.x] = (scalar_t)acc; +#pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + acc += tmps[j] * shared_exp_sums[j + MAX_NPAR]; + } } + const float inv_global_exp_sum = + __fdividef(1.0f, shared_global_exp_sum + 1e-6f); + acc *= inv_global_exp_sum; + // from_float(out_ptr[threadIdx.x], acc); + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + out_ptr[threadIdx.x] = (scalar_t)acc; +} + +#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ + paged_attention_ll4mi_QKV_kernel \ + <<>>( \ + query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks); -#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ - paged_attention_ll4mi_QKV_kernel \ - <<>>( \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - num_kv_heads, \ - scale, \ - block_tables_ptr, \ - context_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,out_ptr,max_ctx_blocks); - -template +template void paged_attention_custom_launcher( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - const int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int max_context_len, + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, const int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& context_lens, + int max_context_len, #if 0 torch::Tensor& qk_out, torch::Tensor& softmax_out, #endif - const c10::optional& alibi_slopes) { + const c10::optional& alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); @@ -712,9 +774,10 @@ void paged_attention_custom_launcher( int 
kv_head_stride = key_cache.stride(1); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); @@ -731,116 +794,152 @@ void paged_attention_custom_launcher( #endif const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); - const int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); - const int gqa_ratio = num_heads/num_kv_heads; - assert(num_heads%num_kv_heads==0); - assert(head_size==HEAD_SIZE); - assert(max_num_partitions<=128); + const int max_num_partitions = + DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + const int gqa_ratio = num_heads / num_kv_heads; + assert(num_heads % num_kv_heads == 0); + assert(head_size == HEAD_SIZE); + assert(max_num_partitions <= 128); constexpr int NTHR = PARTITION_SIZE; - dim3 grid(num_seqs,max_num_partitions,num_kv_heads); + dim3 grid(num_seqs, max_num_partitions, num_kv_heads); dim3 block(NTHR); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); switch (gqa_ratio) { - case 1: LAUNCH_CUSTOM_ATTENTION(1); break; - case 2: LAUNCH_CUSTOM_ATTENTION(2); break; - case 3: LAUNCH_CUSTOM_ATTENTION(3); break; - case 4: LAUNCH_CUSTOM_ATTENTION(4); break; - case 5: LAUNCH_CUSTOM_ATTENTION(5); break; - case 6: LAUNCH_CUSTOM_ATTENTION(6); break; - case 7: LAUNCH_CUSTOM_ATTENTION(7); break; - case 8: LAUNCH_CUSTOM_ATTENTION(8); break; - case 9: LAUNCH_CUSTOM_ATTENTION(9); break; - case 10: LAUNCH_CUSTOM_ATTENTION(10); break; - case 11: LAUNCH_CUSTOM_ATTENTION(11); break; - case 12: LAUNCH_CUSTOM_ATTENTION(12); break; - case 13: LAUNCH_CUSTOM_ATTENTION(13); break; - case 14: LAUNCH_CUSTOM_ATTENTION(14); break; - case 15: LAUNCH_CUSTOM_ATTENTION(15); break; - case 16: LAUNCH_CUSTOM_ATTENTION(16); break; - default: - TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); - break; + case 1: + LAUNCH_CUSTOM_ATTENTION(1); + break; + case 2: + LAUNCH_CUSTOM_ATTENTION(2); + break; + case 3: + LAUNCH_CUSTOM_ATTENTION(3); + break; + case 4: + LAUNCH_CUSTOM_ATTENTION(4); + break; + case 5: + LAUNCH_CUSTOM_ATTENTION(5); + break; + case 6: + LAUNCH_CUSTOM_ATTENTION(6); + break; + case 7: + LAUNCH_CUSTOM_ATTENTION(7); + break; + case 8: + LAUNCH_CUSTOM_ATTENTION(8); + break; + case 9: + LAUNCH_CUSTOM_ATTENTION(9); + break; + case 10: + LAUNCH_CUSTOM_ATTENTION(10); + break; + case 11: + LAUNCH_CUSTOM_ATTENTION(11); + break; + case 12: + LAUNCH_CUSTOM_ATTENTION(12); + break; + case 13: + LAUNCH_CUSTOM_ATTENTION(13); + break; + case 14: + LAUNCH_CUSTOM_ATTENTION(14); + break; + case 15: + LAUNCH_CUSTOM_ATTENTION(15); + break; + case 16: + LAUNCH_CUSTOM_ATTENTION(16); + break; + default: + TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); + break; } - //dim3 grid2(num_heads,num_seqs,head_size/HEAD_ELEMS_PER_WG); - //dim3 block2(1024); - // LAUNCH_CUSTOM_ATTENTION2; - - //reduction kernel is only required if max_context_len > partition size, otherwise main kernel writes directly to final output - // note there are cases with graphing where max_context_len is the max supported by graphing, not the actual max among - // all the sequences: in that case reduction kernel will still run but return immediately + // dim3 
grid2(num_heads,num_seqs,head_size/HEAD_ELEMS_PER_WG); + // dim3 block2(1024); + // LAUNCH_CUSTOM_ATTENTION2; + + // reduction kernel is only required if max_context_len > partition size, + // otherwise main kernel writes directly to final output + // note there are cases with graphing where max_context_len is the max + // supported by graphing, not the actual max among all the sequences: in that + // case reduction kernel will still run but return immediately if (max_context_len > PARTITION_SIZE) { dim3 reduce_grid(num_heads, num_seqs); dim3 reduce_block(head_size); paged_attention_ll4mi_reduce_kernel - <<>>( - out_ptr, - exp_sums_ptr, - max_logits_ptr, - tmp_out_ptr, - context_lens_ptr, - max_num_partitions); + <<>>( + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, + context_lens_ptr, max_num_partitions); } } -#define CALL_CUSTOM_LAUNCHER(T,BLK_SIZE,HEAD_SIZE) \ - paged_attention_custom_launcher( \ - out, \ - exp_sums, \ - max_logits, \ - tmp_out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - context_lens, \ - max_context_len,\ - alibi_slopes); - -#define CALL_CUSTOM_LAUNCHER_BLK(T,HEAD_SIZE) \ - switch (block_size) { \ - case 8: CALL_CUSTOM_LAUNCHER(T,8,HEAD_SIZE); break; \ - case 16: CALL_CUSTOM_LAUNCHER(T,16,HEAD_SIZE); break; \ - case 32: CALL_CUSTOM_LAUNCHER(T,32,HEAD_SIZE); break; \ - default: TORCH_CHECK(false, "Unsupported block size: ", block_size); break; \ - } +#define CALL_CUSTOM_LAUNCHER(T, BLK_SIZE, HEAD_SIZE) \ + paged_attention_custom_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, max_context_len, \ + alibi_slopes); + +#define CALL_CUSTOM_LAUNCHER_BLK(T, HEAD_SIZE) \ + switch (block_size) { \ + case 8: \ + CALL_CUSTOM_LAUNCHER(T, 8, HEAD_SIZE); \ + break; \ + case 16: \ + CALL_CUSTOM_LAUNCHER(T, 16, HEAD_SIZE); \ + break; \ + case 32: \ + CALL_CUSTOM_LAUNCHER(T, 32, HEAD_SIZE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } -#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T) \ - switch (head_size) { \ - case 64: CALL_CUSTOM_LAUNCHER_BLK(T,64); break; \ - case 128: CALL_CUSTOM_LAUNCHER_BLK(T,128); break; \ - default: TORCH_CHECK(false, "Unsupported head size: ", head_size); break; \ - } +#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T) \ + switch (head_size) { \ + case 64: \ + CALL_CUSTOM_LAUNCHER_BLK(T, 64); \ + break; \ + case 128: \ + CALL_CUSTOM_LAUNCHER_BLK(T, 128); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported head size: ", head_size); \ + break; \ + } void paged_attention_custom( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - int num_kv_heads, - float scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& context_lens, // [num_seqs] - int block_size, - int max_context_len, + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + 
torch::Tensor& + tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int num_kv_heads, float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + int block_size, int max_context_len, #if 0 torch::Tensor& qk_out, torch::Tensor& softmax_out, #endif - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype) { - const int head_size = query.size(2); - if (query.dtype() == at::ScalarType::Half) { - CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype) { + const int head_size = query.size(2); + if (query.dtype() == at::ScalarType::Half) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } } #undef WARP_SIZE diff --git a/csrc/ops.h b/csrc/ops.h index d6cdfab434f2c..aa015c3d5dc39 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -113,25 +113,17 @@ void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); -void convert_fp8(torch::Tensor& dst_data, torch::Tensor& src_data, torch::Tensor& scale); +void convert_fp8(torch::Tensor& dst_data, torch::Tensor& src_data, + torch::Tensor& scale); #ifdef USE_ROCM -torch::Tensor fp8_gemm( - torch::Tensor& a, - torch::Tensor& b, - torch::Tensor& scaleA, - torch::Tensor& scaleB, - torch::Tensor& scaleD, - int algo_idx -); - -torch::Tensor fp8_gemm_16( - torch::Tensor& a, - torch::Tensor& b, - torch::Tensor& scaleA, - torch::Tensor& scaleB, - int algo_idx -); +torch::Tensor fp8_gemm(torch::Tensor& a, torch::Tensor& b, + torch::Tensor& scaleA, torch::Tensor& scaleB, + torch::Tensor& scaleD, int algo_idx); + +torch::Tensor fp8_gemm_16(torch::Tensor& a, torch::Tensor& b, + torch::Tensor& scaleA, torch::Tensor& scaleB, + int algo_idx); #endif void moe_align_block_size(torch::Tensor topk_ids, int num_experts, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index a4693ccc2ae75..a507af396bcf9 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -67,8 +67,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Aligning the number of tokens to be processed by each expert such " "that it is divisible by the block size."); ops.def("convert_fp8", &convert_fp8, - "Convert the key and value cache to fp8 data type"); - + "Convert the key and value cache to fp8 data type"); + #ifdef USE_ROCM ops.def("fp8_gemm", &fp8_gemm, "fp8 GEMM with fp8 output"); ops.def("fp8_gemm_16", &fp8_gemm_16, "fp8 GEMM with fp16 output"); diff --git a/csrc/quantization/fp8/amd/gemm_kernel.cu b/csrc/quantization/fp8/amd/gemm_kernel.cu index 5464e9381e343..f8586b77d7792 100644 --- a/csrc/quantization/fp8/amd/gemm_kernel.cu +++ b/csrc/quantization/fp8/amd/gemm_kernel.cu @@ -12,258 +12,290 @@ #define max_workspace_size 2 * 128 * 1024 * 1024 #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + 
CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) #ifndef CHECK_HIP_ERROR -#define CHECK_HIP_ERROR(error) \ - if (error != hipSuccess) { \ - fprintf(stderr, "Hip error: '%s'(%d) at %s:%d\n", hipGetErrorString(error), error, __FILE__, __LINE__); \ - exit(EXIT_FAILURE); \ + #define CHECK_HIP_ERROR(error) \ + if (error != hipSuccess) { \ + fprintf(stderr, "Hip error: '%s'(%d) at %s:%d\n", \ + hipGetErrorString(error), error, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ } #endif #ifndef CHECK_HIPBLASLT_ERROR -#define CHECK_HIPBLASLT_ERROR(error) \ - if (error != HIPBLAS_STATUS_SUCCESS) { \ - fprintf( \ - stderr, "hipBLASLt error: '%s'(%d) at %s:%d\n", hipblasStatusToString(error), error, __FILE__, __LINE__); \ - exit(EXIT_FAILURE); \ + #define CHECK_HIPBLASLT_ERROR(error) \ + if (error != HIPBLAS_STATUS_SUCCESS) { \ + fprintf(stderr, "hipBLASLt error: '%s'(%d) at %s:%d\n", \ + hipblasStatusToString(error), error, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ } #endif -torch::Tensor fp8_gemm(torch::Tensor& a, torch::Tensor& b, torch::Tensor& scaleA, torch::Tensor& scaleB, - torch::Tensor& scaleD, int algo_idx) -{ - auto a_strides{a.strides()}; - auto b_strides{b.strides()}; - auto a_sizes{a.sizes()}; - auto b_sizes{b.sizes()}; - - // CHECK_INPUT(a); - // CHECK_INPUT(b); - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fnuz && b.dtype() == torch::kFloat8_e4m3fnuz, - "The input tensors should be in fp8."); - TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Input tensors must be 2-D."); - TORCH_CHECK(a_sizes[1] == b_sizes[0], "a dim 1 must match b dim 0."); - - auto options{at::TensorOptions().dtype(torch::kFloat8_e4m3fnuz).device(at::kCUDA)}; - auto result{torch::empty({a_sizes[0], b_sizes[1]}, options)}; - - constexpr bool transpose_result = true; - bool transpose_a; - bool transpose_b; - if ((b_strides[0] == 1) && (b_strides[1] >= std::max(1, b_sizes[0]))) { - transpose_b = false; - } else if ((b_strides[1] == 1) && (b_strides[0] >= std::max(1, b_sizes[1]))) { - transpose_b = true; - } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); +torch::Tensor fp8_gemm(torch::Tensor& a, torch::Tensor& b, + torch::Tensor& scaleA, torch::Tensor& scaleB, + torch::Tensor& scaleD, int algo_idx) { + auto a_strides{a.strides()}; + auto b_strides{b.strides()}; + auto a_sizes{a.sizes()}; + auto b_sizes{b.sizes()}; + + // CHECK_INPUT(a); + // CHECK_INPUT(b); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fnuz && + b.dtype() == torch::kFloat8_e4m3fnuz, + "The input tensors should be in fp8."); + TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Input tensors must be 2-D."); + TORCH_CHECK(a_sizes[1] == b_sizes[0], "a dim 1 must match b dim 0."); + + auto options{ + at::TensorOptions().dtype(torch::kFloat8_e4m3fnuz).device(at::kCUDA)}; + auto result{torch::empty({a_sizes[0], b_sizes[1]}, options)}; + + constexpr bool transpose_result = true; + bool transpose_a; + bool transpose_b; + if ((b_strides[0] == 1) && + (b_strides[1] >= std::max(1, b_sizes[0]))) { + transpose_b = false; + } else if ((b_strides[1] == 1) && + (b_strides[0] >= std::max(1, b_sizes[1]))) { + transpose_b = true; + } else { + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); + } + if ((a_strides[0] == 1) && + (a_strides[1] >= std::max(1, a_sizes[0]))) { + transpose_a = false; + } else if ((a_strides[1] == 1) && + (a_strides[0] >= std::max(1, a_sizes[1]))) { + transpose_a = true; + } else { + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); + } + + if 
(transpose_result) {
+    bool tmp = transpose_a;
+    transpose_a = !transpose_b;
+    transpose_b = !tmp;
+    a_strides = b.strides();
+    b_strides = a.strides();
+    a_sizes = b.sizes();
+    b_sizes = a.sizes();
+  }
+
+  float alpha = 1.0f;
+  float beta = 0.0f;
+  int64_t m = a_sizes[transpose_result ? 1 : 0];
+  int64_t k = a_sizes[transpose_result ? 0 : 1];
+  int64_t n = b_sizes[transpose_result ? 0 : 1];
+
+  void* d_a = static_cast<void*>((transpose_result ? b : a).data_ptr());
+  void* d_b = static_cast<void*>((transpose_result ? a : b).data_ptr());
+  void* d_d = static_cast<void*>(result.data_ptr());
+
+  // void *d_scaleA, *d_scaleB, *d_workspace;
+  // CHECK_HIP_ERROR(hipMalloc(&d_scaleA, sizeof(float)));
+  // CHECK_HIP_ERROR(hipMalloc(&d_scaleB, sizeof(float)));
+  // CHECK_HIP_ERROR(hipMalloc(&d_workspace, max_workspace_size));
+  // CHECK_HIP_ERROR(hipMemcpy(d_scaleA, &(transpose_result ? scaleB : scaleA),
+  // sizeof(float), hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(d_scaleB,
+  // &(transpose_result ? scaleA : scaleB), sizeof(float),
+  // hipMemcpyHostToDevice));
+  auto d_scaleA = transpose_result ? scaleB.data_ptr() : scaleA.data_ptr();
+  auto d_scaleB = transpose_result ? scaleA.data_ptr() : scaleB.data_ptr();
+  auto d_scaleD = scaleD.data_ptr();
+
+  auto handle = at::cuda::getCurrentCUDABlasLtHandle();
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  hipblaslt_ext::GemmPreference gemmPref;
+  gemmPref.setMaxWorkspaceBytes(0);
+  hipblaslt_ext::Gemm gemm(handle, transpose_a ? HIPBLAS_OP_T : HIPBLAS_OP_N,
+                           transpose_b ? HIPBLAS_OP_T : HIPBLAS_OP_N,
+                           HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ,
+                           HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ,
+                           HIPBLAS_COMPUTE_32F);
+
+  hipblaslt_ext::GemmEpilogue
+      epilogue{};  // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT.
+                   // (Gemm only)
+  hipblaslt_ext::GemmInputs inputs;
+  inputs.a = d_a;
+  inputs.b = d_b;
+  inputs.c = d_d;
+  inputs.d = d_d;
+  inputs.alpha = &alpha;
+  inputs.beta = &beta;
+  inputs.scaleA = d_scaleA;
+  inputs.scaleB = d_scaleB;
+  inputs.scaleD = d_scaleD;
+  gemm.setProblem(m, n, k, 1, epilogue, inputs);
+  if (algo_idx == 0) {
+    constexpr int request_solutions = 1024;
+    std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
+    heuristicResult.reserve(request_solutions);
+    CHECK_HIPBLASLT_ERROR(
+        gemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult));
+    static size_t solSize = 0;
+    if (heuristicResult.size() != solSize) {
+      std::cout << "fp8 sols: " << heuristicResult.size() << "\n";
+      solSize = heuristicResult.size();
+      for (auto& res : heuristicResult) {
+        auto idx = hipblaslt_ext::getIndexFromAlgo(res.algo);
+        std::cout << idx << "\n";
+      }
    }
-    if ((a_strides[0] == 1) && (a_strides[1] >= std::max<int64_t>(1, a_sizes[0]))) {
-        transpose_a = false;
-    } else if ((a_strides[1] == 1) && (a_strides[0] >= std::max<int64_t>(1, a_sizes[1]))) {
-        transpose_a = true;
-    } else {
-        assert(false && "unusual strides detected, may need to clone a contiguous tensor");
-    }
-
-    if (transpose_result) {
-        bool tmp = transpose_a;
-        transpose_a = !transpose_b;
-        transpose_b = !tmp;
-        a_strides = b.strides();
-        b_strides = a.strides();
-        a_sizes = b.sizes();
-        b_sizes = a.sizes();
-    }
-
-    float alpha = 1.0f;
-    float beta = 0.0f;
-    int64_t m = a_sizes[transpose_result ? 1 : 0];
-    int64_t k = a_sizes[transpose_result ? 0 : 1];
-    int64_t n = b_sizes[transpose_result ? 0 : 1];
-
-    void* d_a = static_cast<void*>((transpose_result ? b : a).data_ptr());
-    void* d_b = static_cast<void*>((transpose_result ? 
a : b).data_ptr()); - void* d_d = static_cast(result.data_ptr()); - - // void *d_scaleA, *d_scaleB, *d_workspace; - // CHECK_HIP_ERROR(hipMalloc(&d_scaleA, sizeof(float))); - // CHECK_HIP_ERROR(hipMalloc(&d_scaleB, sizeof(float))); - // CHECK_HIP_ERROR(hipMalloc(&d_workspace, max_workspace_size)); - // CHECK_HIP_ERROR(hipMemcpy(d_scaleA, &(transpose_result ? scaleB : scaleA), sizeof(float), hipMemcpyHostToDevice)); - // CHECK_HIP_ERROR(hipMemcpy(d_scaleB, &(transpose_result ? scaleA : scaleB), sizeof(float), hipMemcpyHostToDevice)); - auto d_scaleA = transpose_result ? scaleB.data_ptr() : scaleA.data_ptr(); - auto d_scaleB = transpose_result ? scaleA.data_ptr() : scaleB.data_ptr(); - auto d_scaleD = scaleD.data_ptr(); - - auto handle = at::cuda::getCurrentCUDABlasLtHandle(); - auto stream = at::cuda::getCurrentCUDAStream(); - - hipblaslt_ext::GemmPreference gemmPref; - gemmPref.setMaxWorkspaceBytes(0); - hipblaslt_ext::Gemm gemm(handle, transpose_a ? HIPBLAS_OP_T : HIPBLAS_OP_N, - transpose_b ? HIPBLAS_OP_T : HIPBLAS_OP_N, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, - HIP_R_8F_E4M3_FNUZ, HIPBLAS_COMPUTE_32F); - - hipblaslt_ext::GemmEpilogue epilogue{}; // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only) - hipblaslt_ext::GemmInputs inputs; - inputs.a = d_a; - inputs.b = d_b; - inputs.c = d_d; - inputs.d = d_d; - inputs.alpha = α - inputs.beta = β - inputs.scaleA = d_scaleA; - inputs.scaleB = d_scaleB; - inputs.scaleD = d_scaleD; - gemm.setProblem(m, n, k, 1, epilogue, inputs); - if (algo_idx == 0) { - constexpr int request_solutions = 1024; - std::vector heuristicResult; - heuristicResult.reserve(request_solutions); - CHECK_HIPBLASLT_ERROR(gemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult)); - static size_t solSize = 0; - if (heuristicResult.size() != solSize) { - std::cout << "fp8 sols: " << heuristicResult.size() << "\n"; - solSize = heuristicResult.size(); - for (auto& res : heuristicResult) { - auto idx = hipblaslt_ext::getIndexFromAlgo(res.algo); - std::cout << idx << "\n"; - } - } - TORCH_CHECK(!heuristicResult.empty(), "No valid solution found!"); - algo_idx = hipblaslt_ext::getIndexFromAlgo(heuristicResult[0].algo); - } - std::vector algoIndex(1); - algoIndex[0] = algo_idx; - std::vector tmpAlgo; - TORCH_CUDABLAS_CHECK(hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, tmpAlgo)); - - CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, nullptr)); - CHECK_HIPBLASLT_ERROR(gemm.run(stream)); - - // hipFree(d_scaleA); - // hipFree(d_scaleB); - - return result; + TORCH_CHECK(!heuristicResult.empty(), "No valid solution found!"); + algo_idx = hipblaslt_ext::getIndexFromAlgo(heuristicResult[0].algo); + } + std::vector algoIndex(1); + algoIndex[0] = algo_idx; + std::vector tmpAlgo; + TORCH_CUDABLAS_CHECK( + hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, tmpAlgo)); + + CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, nullptr)); + CHECK_HIPBLASLT_ERROR(gemm.run(stream)); + + // hipFree(d_scaleA); + // hipFree(d_scaleB); + + return result; } -torch::Tensor fp8_gemm_16( - torch::Tensor& a, torch::Tensor& b, torch::Tensor& scaleA, torch::Tensor& scaleB, int algo_idx) -{ - auto a_strides{a.strides()}; - auto b_strides{b.strides()}; - auto a_sizes{a.sizes()}; - auto b_sizes{b.sizes()}; - - // CHECK_INPUT(a); - // CHECK_INPUT(b); - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fnuz && b.dtype() == torch::kFloat8_e4m3fnuz, - "The input tensors should be in fp8."); - TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Input tensors must be 
2-D."); - TORCH_CHECK(a_sizes[1] == b_sizes[0], "a dim 1 must match b dim 0."); - - auto options{at::TensorOptions().dtype(torch::kFloat16).device(at::kCUDA)}; - auto result{torch::empty({a_sizes[0], b_sizes[1]}, options)}; - - constexpr bool transpose_result = true; - bool transpose_a; - bool transpose_b; - if ((b_strides[0] == 1) && (b_strides[1] >= std::max(1, b_sizes[0]))) { - transpose_b = false; - } else if ((b_strides[1] == 1) && (b_strides[0] >= std::max(1, b_sizes[1]))) { - transpose_b = true; - } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); +torch::Tensor fp8_gemm_16(torch::Tensor& a, torch::Tensor& b, + torch::Tensor& scaleA, torch::Tensor& scaleB, + int algo_idx) { + auto a_strides{a.strides()}; + auto b_strides{b.strides()}; + auto a_sizes{a.sizes()}; + auto b_sizes{b.sizes()}; + + // CHECK_INPUT(a); + // CHECK_INPUT(b); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fnuz && + b.dtype() == torch::kFloat8_e4m3fnuz, + "The input tensors should be in fp8."); + TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Input tensors must be 2-D."); + TORCH_CHECK(a_sizes[1] == b_sizes[0], "a dim 1 must match b dim 0."); + + auto options{at::TensorOptions().dtype(torch::kFloat16).device(at::kCUDA)}; + auto result{torch::empty({a_sizes[0], b_sizes[1]}, options)}; + + constexpr bool transpose_result = true; + bool transpose_a; + bool transpose_b; + if ((b_strides[0] == 1) && + (b_strides[1] >= std::max(1, b_sizes[0]))) { + transpose_b = false; + } else if ((b_strides[1] == 1) && + (b_strides[0] >= std::max(1, b_sizes[1]))) { + transpose_b = true; + } else { + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); + } + if ((a_strides[0] == 1) && + (a_strides[1] >= std::max(1, a_sizes[0]))) { + transpose_a = false; + } else if ((a_strides[1] == 1) && + (a_strides[0] >= std::max(1, a_sizes[1]))) { + transpose_a = true; + } else { + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); + } + + if (transpose_result) { + bool tmp = transpose_a; + transpose_a = !transpose_b; + transpose_b = !tmp; + a_strides = b.strides(); + b_strides = a.strides(); + a_sizes = b.sizes(); + b_sizes = a.sizes(); + } + + float alpha = 1.0f; + float beta = 0.0f; + int64_t m = a_sizes[transpose_result ? 1 : 0]; + int64_t k = a_sizes[transpose_result ? 0 : 1]; + int64_t n = b_sizes[transpose_result ? 0 : 1]; + + void* d_a = static_cast((transpose_result ? b : a).data_ptr()); + void* d_b = static_cast((transpose_result ? a : b).data_ptr()); + void* d_d = static_cast(result.data_ptr()); + + // void *d_scaleA, *d_scaleB, *d_workspace; + // CHECK_HIP_ERROR(hipMalloc(&d_scaleA, sizeof(float))); + // CHECK_HIP_ERROR(hipMalloc(&d_scaleB, sizeof(float))); + // CHECK_HIP_ERROR(hipMalloc(&d_workspace, max_workspace_size)); + // CHECK_HIP_ERROR(hipMemcpy(d_scaleA, &(transpose_result ? scaleB : scaleA), + // sizeof(float), hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(d_scaleB, + // &(transpose_result ? scaleA : scaleB), sizeof(float), + // hipMemcpyHostToDevice)); + auto d_scaleA = transpose_result ? scaleB.data_ptr() : scaleA.data_ptr(); + auto d_scaleB = transpose_result ? scaleA.data_ptr() : scaleB.data_ptr(); + + auto handle = at::cuda::getCurrentCUDABlasLtHandle(); + auto stream = at::cuda::getCurrentCUDAStream(); + + hipblaslt_ext::GemmPreference gemmPref; + gemmPref.setMaxWorkspaceBytes(0); + hipblaslt_ext::Gemm gemm(handle, transpose_a ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transpose_b ? 
HIPBLAS_OP_T : HIPBLAS_OP_N,
+                           HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_16F,
+                           HIP_R_16F, HIPBLAS_COMPUTE_32F);
+
+  hipblaslt_ext::GemmEpilogue
+      epilogue{};  // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT.
+                   // (Gemm only)
+  hipblaslt_ext::GemmInputs inputs;
+  inputs.a = d_a;
+  inputs.b = d_b;
+  inputs.c = d_d;
+  inputs.d = d_d;
+  inputs.alpha = &alpha;
+  inputs.beta = &beta;
+  inputs.scaleA = d_scaleA;
+  inputs.scaleB = d_scaleB;
+  gemm.setProblem(m, n, k, 1, epilogue, inputs);
+  if (algo_idx == 0) {
+    constexpr int request_solutions = 1024;
+    std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
+    heuristicResult.reserve(request_solutions);
+    CHECK_HIPBLASLT_ERROR(
+        gemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult));
+    static size_t solSize = 0;
+    if (heuristicResult.size() != solSize) {
+      std::cout << "fp16 sols: " << heuristicResult.size() << "\n";
+      solSize = heuristicResult.size();
+      for (auto& res : heuristicResult) {
+        auto idx = hipblaslt_ext::getIndexFromAlgo(res.algo);
+        std::cout << idx << "\n";
+      }
    }
-    if ((a_strides[0] == 1) && (a_strides[1] >= std::max<int64_t>(1, a_sizes[0]))) {
-        transpose_a = false;
-    } else if ((a_strides[1] == 1) && (a_strides[0] >= std::max<int64_t>(1, a_sizes[1]))) {
-        transpose_a = true;
-    } else {
-        assert(false && "unusual strides detected, may need to clone a contiguous tensor");
-    }
-
-    if (transpose_result) {
-        bool tmp = transpose_a;
-        transpose_a = !transpose_b;
-        transpose_b = !tmp;
-        a_strides = b.strides();
-        b_strides = a.strides();
-        a_sizes = b.sizes();
-        b_sizes = a.sizes();
-    }
-
-    float alpha = 1.0f;
-    float beta = 0.0f;
-    int64_t m = a_sizes[transpose_result ? 1 : 0];
-    int64_t k = a_sizes[transpose_result ? 0 : 1];
-    int64_t n = b_sizes[transpose_result ? 0 : 1];
-
-    void* d_a = static_cast<void*>((transpose_result ? b : a).data_ptr());
-    void* d_b = static_cast<void*>((transpose_result ? a : b).data_ptr());
-    void* d_d = static_cast<void*>(result.data_ptr());
-
-    // void *d_scaleA, *d_scaleB, *d_workspace;
-    // CHECK_HIP_ERROR(hipMalloc(&d_scaleA, sizeof(float)));
-    // CHECK_HIP_ERROR(hipMalloc(&d_scaleB, sizeof(float)));
-    // CHECK_HIP_ERROR(hipMalloc(&d_workspace, max_workspace_size));
-    // CHECK_HIP_ERROR(hipMemcpy(d_scaleA, &(transpose_result ? scaleB : scaleA), sizeof(float), hipMemcpyHostToDevice));
-    // CHECK_HIP_ERROR(hipMemcpy(d_scaleB, &(transpose_result ? scaleA : scaleB), sizeof(float), hipMemcpyHostToDevice));
-    auto d_scaleA = transpose_result ? scaleB.data_ptr() : scaleA.data_ptr();
-    auto d_scaleB = transpose_result ? scaleA.data_ptr() : scaleB.data_ptr();
-
-    auto handle = at::cuda::getCurrentCUDABlasLtHandle();
-    auto stream = at::cuda::getCurrentCUDAStream();
-
-    hipblaslt_ext::GemmPreference gemmPref;
-    gemmPref.setMaxWorkspaceBytes(0);
-    hipblaslt_ext::Gemm gemm(handle, transpose_a ? HIPBLAS_OP_T : HIPBLAS_OP_N,
-        transpose_b ? HIPBLAS_OP_T : HIPBLAS_OP_N, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_16F, HIP_R_16F,
-        HIPBLAS_COMPUTE_32F);
-
-    hipblaslt_ext::GemmEpilogue epilogue{}; // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. 
(Gemm only) - hipblaslt_ext::GemmInputs inputs; - inputs.a = d_a; - inputs.b = d_b; - inputs.c = d_d; - inputs.d = d_d; - inputs.alpha = α - inputs.beta = β - inputs.scaleA = d_scaleA; - inputs.scaleB = d_scaleB; - gemm.setProblem(m, n, k, 1, epilogue, inputs); - if (algo_idx == 0) { - constexpr int request_solutions = 1024; - std::vector heuristicResult; - heuristicResult.reserve(request_solutions); - CHECK_HIPBLASLT_ERROR(gemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult)); - static size_t solSize = 0; - if (heuristicResult.size() != solSize) { - std::cout << "fp16 sols: " << heuristicResult.size() << "\n"; - solSize = heuristicResult.size(); - for (auto& res : heuristicResult) { - auto idx = hipblaslt_ext::getIndexFromAlgo(res.algo); - std::cout << idx << "\n"; - } - } - algo_idx = hipblaslt_ext::getIndexFromAlgo(heuristicResult[0].algo); - TORCH_CHECK(!heuristicResult.empty(), "No valid solution found!"); - } - std::vector algoIndex(1); - algoIndex[0] = algo_idx; - std::vector tmpAlgo; - TORCH_CUDABLAS_CHECK(hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, tmpAlgo)); - - CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, nullptr)); - CHECK_HIPBLASLT_ERROR(gemm.run(stream)); - - // hipFree(d_scaleA); - // hipFree(d_scaleB); - - return result; + algo_idx = hipblaslt_ext::getIndexFromAlgo(heuristicResult[0].algo); + TORCH_CHECK(!heuristicResult.empty(), "No valid solution found!"); + } + std::vector algoIndex(1); + algoIndex[0] = algo_idx; + std::vector tmpAlgo; + TORCH_CUDABLAS_CHECK( + hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, tmpAlgo)); + + CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, nullptr)); + CHECK_HIPBLASLT_ERROR(gemm.run(stream)); + + // hipFree(d_scaleA); + // hipFree(d_scaleB); + + return result; } \ No newline at end of file diff --git a/csrc/quantization/fp8/amd/quant_utils.cuh b/csrc/quantization/fp8/amd/quant_utils.cuh index 23d975fe0f37e..8a35467edbc21 100644 --- a/csrc/quantization/fp8/amd/quant_utils.cuh +++ b/csrc/quantization/fp8/amd/quant_utils.cuh @@ -307,344 +307,351 @@ vec_conversion(const Float8_& a) { // fp8 -> half template <> -__inline__ __device__ uint16_t scaled_vec_conversion(const uint8_t& a, float scale) -{ - hip_fp8 f8{a, hip_fp8::from_bits()}; - __half_raw res; - res.data = static_cast(f8) * scale; - return res.x; +__inline__ __device__ uint16_t +scaled_vec_conversion(const uint8_t& a, float scale) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + __half_raw res; + res.data = static_cast(f8) * scale; + return res.x; } // fp8x2 -> half2 template <> -__inline__ __device__ uint32_t scaled_vec_conversion(const uint16_t& a, float scale) -{ -#if defined(__HIP__MI300__) - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - union { - __half2_raw h2r; - uint32_t ui32; - } tmp; - tmp.h2r.x.data = f2[0] * scale; - tmp.h2r.y.data = f2[1] * scale; - return tmp.ui32; -#else - union { - uint16_t u16[2]; - uint32_t u32; - } tmp; - - tmp.u16[0] = scaled_vec_conversion(static_cast(a), scale); - tmp.u16[1] = scaled_vec_conversion(static_cast(a >> 8U), scale); - return tmp.u32; -#endif +__inline__ __device__ uint32_t +scaled_vec_conversion(const uint16_t& a, float scale) { + #if defined(__HIP__MI300__) + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r.x.data = f2[0] * scale; + tmp.h2r.y.data = f2[1] * scale; + return tmp.ui32; + #else + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + + tmp.u16[0] = + scaled_vec_conversion(static_cast(a), scale); + 
tmp.u16[1] = scaled_vec_conversion( + static_cast(a >> 8U), scale); + return tmp.u32; + #endif } // fp8x4 -> half2x2 template <> -__inline__ __device__ uint2 scaled_vec_conversion(const uint32_t& a, float scale) -{ - union { - uint2 u32x2; - uint32_t u32[2]; - } tmp; - tmp.u32[0] = scaled_vec_conversion((uint16_t)a, scale); - tmp.u32[1] = scaled_vec_conversion((uint16_t)(a >> 16U), scale); - return tmp.u32x2; +__inline__ __device__ uint2 +scaled_vec_conversion(const uint32_t& a, float scale) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = scaled_vec_conversion((uint16_t)a, scale); + tmp.u32[1] = + scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return tmp.u32x2; } // fp8x8 -> half2x4 template <> -__inline__ __device__ uint4 scaled_vec_conversion(const uint2& a, float scale) -{ - union { - uint4 u64x2; - uint2 u64[2]; - } tmp; - tmp.u64[0] = scaled_vec_conversion(a.x, scale); - tmp.u64[1] = scaled_vec_conversion(a.y, scale); - return tmp.u64x2; +__inline__ __device__ uint4 scaled_vec_conversion(const uint2& a, + float scale) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = scaled_vec_conversion(a.x, scale); + tmp.u64[1] = scaled_vec_conversion(a.y, scale); + return tmp.u64x2; } using __nv_bfloat16 = __hip_bfloat16; // fp8 -> __nv_bfloat16 template <> -__inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) -{ - hip_fp8 f8{a, hip_fp8::from_bits()}; - float f{f8}; - return __float2bfloat16(f * scale); +__inline__ __device__ __nv_bfloat16 +scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + float f{f8}; + return __float2bfloat16(f * scale); } using __nv_bfloat162 = __hip_bfloat162; // fp8x2 -> __nv_bfloat162 template <> -__inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, float scale) -{ - __nv_bfloat162 res; - res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); - res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); - return res; +__inline__ __device__ __nv_bfloat162 +scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, + float scale) { + __nv_bfloat162 res; + res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); + res.y = + scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); + return res; } // fp8x4 -> bf16_4_t template <> -__inline__ __device__ bf16_4_t scaled_vec_conversion(const uint32_t& a, float scale) -{ - bf16_4_t res; - res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); - res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale); - return res; +__inline__ __device__ bf16_4_t +scaled_vec_conversion(const uint32_t& a, float scale) { + bf16_4_t res; + res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); + res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), + scale); + return res; } // fp8x8 -> bf16_8_t template <> -__inline__ __device__ bf16_8_t scaled_vec_conversion(const uint2& a, float scale) -{ - bf16_4_t tmp1, tmp2; - tmp1 = scaled_vec_conversion(a.x, scale); - tmp2 = scaled_vec_conversion(a.y, scale); - bf16_8_t res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; +__inline__ __device__ bf16_8_t +scaled_vec_conversion(const uint2& a, float scale) { + bf16_4_t tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + 
tmp2 = scaled_vec_conversion(a.y, scale); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; } // fp8 -> float template <> -__inline__ __device__ float scaled_vec_conversion(const uint8_t& a, float scale) -{ - hip_fp8 fp8{a, hip_fp8::from_bits()}; - return static_cast(fp8) * scale; +__inline__ __device__ float scaled_vec_conversion( + const uint8_t& a, float scale) { + hip_fp8 fp8{a, hip_fp8::from_bits()}; + return static_cast(fp8) * scale; } // fp8x2 -> float2 template <> -__inline__ __device__ float2 scaled_vec_conversion(const uint16_t& a, float scale) -{ -#if defined(__HIP__MI300__) - float2 res; - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - res.x = f2[0] * scale; - res.y = f2[1] * scale; - return res; -#else - float2 res; - res.x = scaled_vec_conversion(static_cast(a), scale); - res.y = scaled_vec_conversion(static_cast(a >> 8U), scale); - return res; -#endif +__inline__ __device__ float2 +scaled_vec_conversion(const uint16_t& a, float scale) { + #if defined(__HIP__MI300__) + float2 res; + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + res.x = f2[0] * scale; + res.y = f2[1] * scale; + return res; + #else + float2 res; + res.x = scaled_vec_conversion(static_cast(a), scale); + res.y = scaled_vec_conversion(static_cast(a >> 8U), + scale); + return res; + #endif } // fp8x4 -> float4 template <> -__inline__ __device__ Float4_ scaled_vec_conversion(const uint32_t& a, const float scale) -{ - Float4_ res; - res.x = scaled_vec_conversion((uint16_t)a, scale); - res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale); - return res; +__inline__ __device__ Float4_ +scaled_vec_conversion(const uint32_t& a, const float scale) { + Float4_ res; + res.x = scaled_vec_conversion((uint16_t)a, scale); + res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return res; } // fp8x4 -> float4 template <> -__inline__ __device__ float4 scaled_vec_conversion(const uint32_t& a, float scale) -{ - Float4_ res = scaled_vec_conversion(a, scale); - return {res.x.x, res.x.y, res.y.x, res.y.y}; +__inline__ __device__ float4 +scaled_vec_conversion(const uint32_t& a, float scale) { + Float4_ res = scaled_vec_conversion(a, scale); + return {res.x.x, res.x.y, res.y.x, res.y.y}; } // fp8x8 -> float8 template <> -__inline__ __device__ Float8_ scaled_vec_conversion(const uint2& a, float scale) -{ - Float4_ tmp1, tmp2; - tmp1 = scaled_vec_conversion(a.x, scale); - tmp2 = scaled_vec_conversion(a.y, scale); - Float8_ res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; +__inline__ __device__ Float8_ +scaled_vec_conversion(const uint2& a, float scale) { + Float4_ tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + tmp2 = scaled_vec_conversion(a.y, scale); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; } // half -> fp8 template <> -__inline__ __device__ uint8_t scaled_vec_conversion(const uint16_t& a, float scale) -{ - __half_raw tmp; - tmp.x = a; +__inline__ __device__ uint8_t +scaled_vec_conversion(const uint16_t& a, float scale) { + __half_raw tmp; + tmp.x = a; - hip_fp8 f8{static_cast(tmp.data / scale)}; - return f8.data; + hip_fp8 f8{static_cast(tmp.data / scale)}; + return f8.data; } // halfx2 -> fp8x2 -template<> -__inline__ __device__ uint16_t scaled_vec_conversion(const uint32_t& a, float scale) -{ -#ifdef __HIP__MI300__ - union { - uint32_t ui32; - __half2_raw h2r; - } tmp; - tmp.ui32 = a; - - union { - uint32_t ui32; - float f; - } f1, f2; 
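A note on the quantization convention used throughout these helpers: dequantization multiplies by scale, while quantization divides by scale and saturates to ±240, the largest finite magnitude of the e4m3fnuz encoding targeted on MI300 (that is what the __builtin_amdgcn_fmed3f(..., 240.0, -240.0) clamp just below is doing). A small PyTorch sketch of the same round trip, assuming a PyTorch build that exposes torch.float8_e4m3fnuz (this is an illustration, not the kernel code path):

import torch

FP8_E4M3FNUZ_MAX = 240.0  # largest finite value of the fnuz variant

def quantize_fp8(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Scale down, saturate to the representable range, then narrow to fp8.
    y = (x.float() / scale).clamp(-FP8_E4M3FNUZ_MAX, FP8_E4M3FNUZ_MAX)
    return y.to(torch.float8_e4m3fnuz)

def dequantize_fp8(x_fp8: torch.Tensor, scale: float) -> torch.Tensor:
    # Mirrors the fp8 -> half/float conversions above: value * scale.
    return x_fp8.float() * scale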
- f1.f = tmp.h2r.x.data / scale; - f2.f = tmp.h2r.y.data / scale; - if ((f1.ui32 & 0x7F800000) != 0x7F800000) { - f1.f = __builtin_amdgcn_fmed3f(f1.f, 240.0, -240.0); - } - if ((f2.ui32 & 0x7F800000) != 0x7F800000) { - f2.f = __builtin_amdgcn_fmed3f(f2.f, 240.0, -240.0); - } - return __builtin_amdgcn_cvt_pk_fp8_f32(f1.f, f2.f, 0, 0); -#else - union { - uint32_t ui32; - __half2_raw h2r; - } tmp; - tmp.ui32 = a; - - union { - uint8_t ui8[2]; - uint16_t ui16; - } res; - res.ui8[0] = scaled_vec_conversion(tmp.h2r.x.x, scale); - res.ui8[1] = scaled_vec_conversion(tmp.h2r.y.x, scale); - return res.ui16; -#endif +template <> +__inline__ __device__ uint16_t +scaled_vec_conversion(const uint32_t& a, float scale) { + #ifdef __HIP__MI300__ + union { + uint32_t ui32; + __half2_raw h2r; + } tmp; + tmp.ui32 = a; + + union { + uint32_t ui32; + float f; + } f1, f2; + f1.f = tmp.h2r.x.data / scale; + f2.f = tmp.h2r.y.data / scale; + if ((f1.ui32 & 0x7F800000) != 0x7F800000) { + f1.f = __builtin_amdgcn_fmed3f(f1.f, 240.0, -240.0); + } + if ((f2.ui32 & 0x7F800000) != 0x7F800000) { + f2.f = __builtin_amdgcn_fmed3f(f2.f, 240.0, -240.0); + } + return __builtin_amdgcn_cvt_pk_fp8_f32(f1.f, f2.f, 0, 0); + #else + union { + uint32_t ui32; + __half2_raw h2r; + } tmp; + tmp.ui32 = a; + + union { + uint8_t ui8[2]; + uint16_t ui16; + } res; + res.ui8[0] = scaled_vec_conversion(tmp.h2r.x.x, scale); + res.ui8[1] = scaled_vec_conversion(tmp.h2r.y.x, scale); + return res.ui16; + #endif } // half2x2 -> fp8x4 template <> -__inline__ __device__ uint32_t scaled_vec_conversion(const uint2& a, float scale) -{ - union { - uint16_t ui16[2]; - uint32_t ui32; - } tmp; - tmp.ui16[0] = scaled_vec_conversion(a.x, scale); - tmp.ui16[1] = scaled_vec_conversion(a.y, scale); - return tmp.ui32; +__inline__ __device__ uint32_t +scaled_vec_conversion(const uint2& a, float scale) { + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion(a.x, scale); + tmp.ui16[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui32; } // half2x4 -> fp8x8 template <> -__inline__ __device__ uint2 scaled_vec_conversion(const uint4& a, float scale) -{ - union { - uint2 ui2[2]; - uint4 ui4; - } tmp; - tmp.ui4 = a; - uint2 res; - res.x = scaled_vec_conversion(tmp.ui2[0], scale); - res.y = scaled_vec_conversion(tmp.ui2[1], scale); - return res; +__inline__ __device__ uint2 scaled_vec_conversion(const uint4& a, + float scale) { + union { + uint2 ui2[2]; + uint4 ui4; + } tmp; + tmp.ui4 = a; + uint2 res; + res.x = scaled_vec_conversion(tmp.ui2[0], scale); + res.y = scaled_vec_conversion(tmp.ui2[1], scale); + return res; } // bf16 -> fp8 template <> -__inline__ __device__ uint8_t scaled_vec_conversion(const __nv_bfloat16& a, float scale) -{ - hip_fp8 res{__bfloat162float(a) / scale}; - return res.data; +__inline__ __device__ uint8_t scaled_vec_conversion( + const __nv_bfloat16& a, float scale) { + hip_fp8 res{__bfloat162float(a) / scale}; + return res.data; } // bf16x2 -> fp8x2 template <> -__inline__ __device__ uint16_t scaled_vec_conversion(const __nv_bfloat162& a, float scale) -{ - union { - uint8_t ui8[2]; - uint16_t ui16; - } tmp; - tmp.ui8[0] = scaled_vec_conversion(a.x, scale); - tmp.ui8[1] = scaled_vec_conversion(a.y, scale); - return tmp.ui16; +__inline__ __device__ uint16_t scaled_vec_conversion( + const __nv_bfloat162& a, float scale) { + union { + uint8_t ui8[2]; + uint16_t ui16; + } tmp; + tmp.ui8[0] = scaled_vec_conversion(a.x, scale); + tmp.ui8[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui16; } // bf16x4 -> 
fp8x4 template <> -__inline__ __device__ uint32_t scaled_vec_conversion(const bf16_4_t& a, float scale) -{ - union { - uint16_t ui16[2]; - uint32_t ui32; - } tmp; - tmp.ui16[0] = scaled_vec_conversion(a.x, scale); - tmp.ui16[1] = scaled_vec_conversion(a.y, scale); - return tmp.ui32; +__inline__ __device__ uint32_t +scaled_vec_conversion(const bf16_4_t& a, float scale) { + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion(a.x, scale); + tmp.ui16[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui32; } // bf16x8 -> fp8x8 template <> -__inline__ __device__ uint2 scaled_vec_conversion(const bf16_8_t& a, float scale) -{ - uint2 res; - res.x = scaled_vec_conversion({a.x, a.y}, scale); - res.y = scaled_vec_conversion({a.z, a.w}, scale); - return res; +__inline__ __device__ uint2 +scaled_vec_conversion(const bf16_8_t& a, float scale) { + uint2 res; + res.x = scaled_vec_conversion({a.x, a.y}, scale); + res.y = scaled_vec_conversion({a.z, a.w}, scale); + return res; } // float -> fp8 template <> -__inline__ __device__ uint8_t scaled_vec_conversion(const float& a, float scale) -{ - hip_fp8 f8(a); - return f8.data; +__inline__ __device__ uint8_t +scaled_vec_conversion(const float& a, float scale) { + hip_fp8 f8(a); + return f8.data; } // floatx2 -> fp8x2 template <> -__inline__ __device__ uint16_t scaled_vec_conversion(const float2& a, float scale) -{ -#ifdef __HIP__MI300__ - union { - uint32_t ui32; - float f; - } f1, f2; - f1.f = a.x / scale; - f2.f = a.y / scale; - if ((f1.ui32 & 0x7F800000) != 0x7F800000) { - f1.f = __builtin_amdgcn_fmed3f(f1.f, 240.0, -240.0); - } - if ((f2.ui32 & 0x7F800000) != 0x7F800000) { - f2.f = __builtin_amdgcn_fmed3f(f2.f, 240.0, -240.0); - } - return __builtin_amdgcn_cvt_pk_fp8_f32(f1.f,f2.f, 0, 0); -#else - union { - uint8_t ui8[2]; - uint16_t ui16; - } tmp; - tmp.ui8[0] = scaled_vec_conversion(a.x, scale); - tmp.ui8[1] = scaled_vec_conversion(a.y, scale); - return tmp.ui16; -#endif +__inline__ __device__ uint16_t +scaled_vec_conversion(const float2& a, float scale) { + #ifdef __HIP__MI300__ + union { + uint32_t ui32; + float f; + } f1, f2; + f1.f = a.x / scale; + f2.f = a.y / scale; + if ((f1.ui32 & 0x7F800000) != 0x7F800000) { + f1.f = __builtin_amdgcn_fmed3f(f1.f, 240.0, -240.0); + } + if ((f2.ui32 & 0x7F800000) != 0x7F800000) { + f2.f = __builtin_amdgcn_fmed3f(f2.f, 240.0, -240.0); + } + return __builtin_amdgcn_cvt_pk_fp8_f32(f1.f, f2.f, 0, 0); + #else + union { + uint8_t ui8[2]; + uint16_t ui16; + } tmp; + tmp.ui8[0] = scaled_vec_conversion(a.x, scale); + tmp.ui8[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui16; + #endif } // floatx4 -> fp8x4 template <> -__inline__ __device__ uint32_t scaled_vec_conversion(const float4& a, float scale) -{ - union { - uint16_t ui16[2]; - uint32_t ui32; - } tmp; - tmp.ui16[0] = scaled_vec_conversion({a.x, a.y}, scale); - tmp.ui16[1] = scaled_vec_conversion({a.z, a.w}, scale); - return tmp.ui32; +__inline__ __device__ uint32_t +scaled_vec_conversion(const float4& a, float scale) { + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion({a.x, a.y}, scale); + tmp.ui16[1] = scaled_vec_conversion({a.z, a.w}, scale); + return tmp.ui32; } #endif // ENABLE_FP8 diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 937df5a0bec13..bcb8fa514444d 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -14,33 +14,36 @@ #endif namespace vllm { - -template -__global__ void convert_fp8_kernel( - 
const Tin* __restrict__ src_data, Tout* __restrict__ dst_data, const float* scale, size_t N) -{ - const int64_t block_idx = blockIdx.x; - - using V_in_vec = typename Vec::Type; - using V_out_vec = typename Vec::Type; - auto dst_data_vec = reinterpret_cast(dst_data); - auto src_data_vec = reinterpret_cast(src_data); - int64_t startIdx = (threadIdx.x + blockDim.x * blockIdx.x); - auto idx = startIdx; - if (idx >= N) { - return; - } - dst_data_vec[idx] = fp8::scaled_vec_conversion(src_data_vec[idx], *scale); - //dst_data_vec[idx+1] = fp8_e4m3::vec_conversion(src_data_vec[idx+1], *scale); - - //for (int64_t i = 0; i < loopSize; ++i) { - // auto idx = startIdx + i; - // if (idx >= N) { - // return; - // } - // dst_data_vec[idx] = fp8_e4m3::vec_conversion(src_data_vec[idx], *scale); - //} +template +__global__ void convert_fp8_kernel(const Tin* __restrict__ src_data, + Tout* __restrict__ dst_data, + const float* scale, size_t N) { + const int64_t block_idx = blockIdx.x; + + using V_in_vec = typename Vec::Type; + using V_out_vec = typename Vec::Type; + auto dst_data_vec = reinterpret_cast(dst_data); + auto src_data_vec = reinterpret_cast(src_data); + + int64_t startIdx = (threadIdx.x + blockDim.x * blockIdx.x); + auto idx = startIdx; + if (idx >= N) { + return; + } + dst_data_vec[idx] = fp8::scaled_vec_conversion( + src_data_vec[idx], *scale); + // dst_data_vec[idx+1] = fp8_e4m3::vec_conversion(src_data_vec[idx+1], *scale); + + // for (int64_t i = 0; i < loopSize; ++i) { + // auto idx = startIdx + i; + // if (idx >= N) { + // return; + // } + // dst_data_vec[idx] = fp8_e4m3::vec_conversion(src_data_vec[idx], *scale); + // } } __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { @@ -117,7 +120,7 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, } } -} // namespace vllm +} // namespace vllm void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] torch::Tensor& input, // [..., d] @@ -158,54 +161,57 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] } template -struct call_convert_fp8 -{ - void operator()(torch::Tensor& src_data, torch::Tensor& dst_data, torch::Tensor& scale) - { - const auto N = src_data.numel() / 2; - //std::cout << N << "\n"; - constexpr uint32_t loopSize = 1;//std::max(N / 50000000LL, 1); - constexpr dim3 numThreads{1024, 1, 1}; - auto neededBlocks = (N + (numThreads.x * loopSize) - 1) / (numThreads.x * loopSize); - uint32_t actualBlocks = neededBlocks; - - //static uint32_t maxBlocks = 0; - //if (actualBlocks != maxBlocks) { - // maxBlocks = actualBlocks; - // std::cout << actualBlocks << "\n"; - //} - - const dim3 grid{actualBlocks, 1, 1}; - - const auto stream = at::cuda::getCurrentCUDAStream(); - - vllm::convert_fp8_kernel - <<>>(reinterpret_cast(src_data.data_ptr()), - reinterpret_cast(dst_data.data_ptr()), (float*)scale.data_ptr(), N); - } +struct call_convert_fp8 { + void operator()(torch::Tensor& src_data, torch::Tensor& dst_data, + torch::Tensor& scale) { + const auto N = src_data.numel() / 2; + // std::cout << N << "\n"; + constexpr uint32_t loopSize = 1; // std::max(N / 50000000LL, 1); + constexpr dim3 numThreads{1024, 1, 1}; + auto neededBlocks = + (N + (numThreads.x * loopSize) - 1) / (numThreads.x * loopSize); + uint32_t actualBlocks = neededBlocks; + + // static uint32_t maxBlocks = 0; + // if (actualBlocks != maxBlocks) { + // maxBlocks = actualBlocks; + // std::cout << actualBlocks << "\n"; + // } + + const dim3 grid{actualBlocks, 1, 1}; + + const auto stream = 
at::cuda::getCurrentCUDAStream(); + + vllm::convert_fp8_kernel + <<>>( + reinterpret_cast(src_data.data_ptr()), + reinterpret_cast(dst_data.data_ptr()), + (float*)scale.data_ptr(), N); + } }; -void convert_fp8(torch::Tensor& dst_data, torch::Tensor& src_data, torch::Tensor& scale) -{ - torch::Device src_device = src_data.device(); - torch::Device dst_device = dst_data.device(); - TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") - TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") - TORCH_CHECK(src_device.index() == dst_device.index(), "src and dst must be on the same GPU"); - at::cuda::OptionalCUDAGuard device_guard(src_device); - auto t1 = src_data.dtype(); - auto t2 = dst_data.dtype(); - if (src_data.dtype() == at::ScalarType::Float) { - call_convert_fp8{}(src_data, dst_data, scale); - } else if (src_data.dtype() == at::ScalarType::Half) { - call_convert_fp8{}(src_data, dst_data, scale); - } else if (src_data.dtype() == at::ScalarType::BFloat16) { - call_convert_fp8{}(src_data, dst_data, scale); - } else if (dst_data.dtype() == at::ScalarType::Float) { - call_convert_fp8{}(src_data, dst_data, scale); - } else if (dst_data.dtype() == at::ScalarType::Half) { - call_convert_fp8{}(src_data, dst_data, scale); - } else if (dst_data.dtype() == at::ScalarType::BFloat16) { - call_convert_fp8<__nv_bfloat16, uint8_t, 2>{}(src_data, dst_data, scale); - } +void convert_fp8(torch::Tensor& dst_data, torch::Tensor& src_data, + torch::Tensor& scale) { + torch::Device src_device = src_data.device(); + torch::Device dst_device = dst_data.device(); + TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") + TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); + at::cuda::OptionalCUDAGuard device_guard(src_device); + auto t1 = src_data.dtype(); + auto t2 = dst_data.dtype(); + if (src_data.dtype() == at::ScalarType::Float) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (src_data.dtype() == at::ScalarType::Half) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (src_data.dtype() == at::ScalarType::BFloat16) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (dst_data.dtype() == at::ScalarType::Float) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (dst_data.dtype() == at::ScalarType::Half) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (dst_data.dtype() == at::ScalarType::BFloat16) { + call_convert_fp8<__nv_bfloat16, uint8_t, 2>{}(src_data, dst_data, scale); + } } \ No newline at end of file diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index ad6ec10892b6e..894c9e9dc6554 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -165,6 +165,7 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: ) return self._cached_decode_metadata + def _make_alibi_bias( alibi_slopes: torch.Tensor, dtype: torch.dtype, @@ -185,43 +186,44 @@ def _make_alibi_bias( bias.mul_(alibi_slopes[:, None, None]) inf_mask = torch.empty( (1, seq_len, seq_len), - dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(alibi_slopes.device) + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to( + alibi_slopes.device) attn_biases.append((bias + inf_mask).to(dtype)) return attn_biases -def _make_alibi_bias_v2( - alibi_slopes: torch.Tensor, - dtype: torch.dtype, - seq_lens: List[int], - make_attn_mask: bool = True -) -> 
List[torch.Tensor]: +def _make_alibi_bias_v2(alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: Optional[List[int]], + make_attn_mask: bool = True) -> List[torch.Tensor]: attn_biases = [] - for seq_len in seq_lens: - bias = torch.arange(seq_len, dtype=dtype) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(seq_len, 1)` - # here. We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. - bias = bias[None, :] - bias[:, None] - - num_heads = alibi_slopes.shape[0] - bias = bias[None, :].repeat((num_heads, 1, 1)).to(alibi_slopes.device) - bias.mul_(alibi_slopes[:, None, None]) - if make_attn_mask: - inf_mask = torch.empty( - (1, seq_len, seq_len), - dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(alibi_slopes.device) - attn_biases.append((bias + inf_mask).to(dtype)) - else: - attn_biases.append(bias.to(dtype)) + if seq_lens: + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat( + (num_heads, 1, 1)).to(alibi_slopes.device) + bias.mul_(alibi_slopes[:, None, None]) + if make_attn_mask: + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to( + alibi_slopes.device) + attn_biases.append((bias + inf_mask).to(dtype)) + else: + attn_biases.append(bias.to(dtype)) return attn_biases - class ROCmFlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: @@ -384,8 +386,10 @@ def forward( if self.use_triton_flash_attn: if self.alibi_slopes is not None: att_masks = _make_alibi_bias_v2( - self.alibi_slopes, query.dtype, - attn_metadata.seq_lens, make_attn_mask=False) # type: ignore + self.alibi_slopes, + query.dtype, + attn_metadata.seq_lens, + make_attn_mask=False) # type: ignore out, _ = self.attn_func( query, key, @@ -402,20 +406,17 @@ def forward( elif self.use_naive_attn: if self.alibi_slopes is not None: att_masks = _make_alibi_bias_v2( - self.alibi_slopes, query.dtype, - attn_metadata.seq_lens, make_attn_mask=True) # type: ignore + self.alibi_slopes, + query.dtype, + attn_metadata.seq_lens, + make_attn_mask=True) # type: ignore if self.num_kv_heads != self.num_heads: # Interleave for MQA workaround. key = self.repeat_kv(key, self.num_queries_per_kv) value = self.repeat_kv(value, self.num_queries_per_kv) - out = self.attn_func( - query, - key, - value, - prefill_meta.seq_lens, - self.scale, - att_masks - ) + out = self.attn_func(query, key, value, + prefill_meta.seq_lens, self.scale, + att_masks) else: out = self.attn_func( q=query, @@ -486,7 +487,7 @@ def _naive_attention( key[start:end], value[start:end], scale, - attn_masks[i], + attn_masks[i] if attn_masks else None, ) # TODO(woosuk): Unnecessary copy. Optimize. 
output[start:end].copy_(out) @@ -505,13 +506,13 @@ def _naive_masked_attention( seq_len, head_size, head_dim = query.shape if attn_mask is None: attn_mask = torch.triu(torch.ones(seq_len, - seq_len, - dtype=query.dtype, - device=query.device), - diagonal=1) + seq_len, + dtype=query.dtype, + device=query.device), + diagonal=1) attn_mask = attn_mask * torch.finfo(query.dtype).min attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() attn_weights = attn_weights + attn_mask.float() attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) out = torch.einsum("hqk,khd->qhd", attn_weights, value) - return out \ No newline at end of file + return out diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index c99029175b5a2..05134872ba39c 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -61,7 +61,8 @@ def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): @triton.jit -def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): +def load_fn(ptrs, offset_first, offset_second, boundary_first, + boundary_second): if offset_first is not None and offset_second is not None: mask = (offset_first[:, None] < boundary_first) & \ (offset_second[None, :] < boundary_second) @@ -78,38 +79,43 @@ def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): @triton.jit -def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, start_m, - actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, batch_philox_offset, encoded_sm_ptrs, - block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, - IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr): +def _attn_fwd_inner( + acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, + stride_bn, start_m, actual_seqlen_k, actual_seqlen_q, dropout_p, + philox_seed, batch_philox_offset, encoded_sm_ptrs, block_min, + block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, + IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr): # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. - if MASK_STEPS: - k_offs_n = start_n + tl.arange(0, BLOCK_N) - else: - k_offs_n = None + k_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) - k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k) + k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, + actual_seqlen_k) if PRE_LOAD_V: # We can use the same offsets as k, just with dims transposed. 
- v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, + ACTUAL_BLOCK_DMODEL) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n # TODO: This can be optimized to only be true for the padded block. - if MASK_STEPS: - # If this is the last block / iteration, we want to - # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. - # last step might get wasted but that is okay. check if this masking works For - # that case. - if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): - boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) + if MASK_STEPS: # NOQA: SIM102 + if start_n + BLOCK_N == block_max and n_extra_tokens != 0: + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if + # not is_modulo_mn. Last step might get wasted but that is okay. + # Check if this masking works for that case. + boundary_m = tl.full([BLOCK_M], + actual_seqlen_k, + dtype=tl.int32) size_n = start_n + OFFS_N[None, :] mask = size_n < boundary_m[:, None] qk = tl.where(mask, qk, float("-inf")) @@ -120,11 +126,14 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri # -- compute qk ---- qk += tl.dot(q, k) if bias_ptrs is not None: - bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None - bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k) + bias_offs_n = start_n + tl.arange(0, + BLOCK_N) if MASK_STEPS else None + bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, + actual_seqlen_k) # While bias is added after multiplying qk with sm_scale, - # our optimization to use 2^x instead of e^x results in an additional - # scale factor of log2(e) which we must also multiply the bias with. + # our optimization to use 2^x instead of e^x results in an + # additional scale factor of log2(e) which we must also multiply + # the bias with. 
qk += (bias * 1.44269504089) # softmax @@ -135,10 +144,15 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) if ENABLE_DROPOUT: - philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N - keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) + philox_offset = (batch_philox_offset + + start_m * BLOCK_M * actual_seqlen_k + start_n - + BLOCK_N) + keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, + BLOCK_N, actual_seqlen_k) if RETURN_ENCODED_SOFTMAX: - tl.store(encoded_sm_ptrs, tl.where(keep, p, -p).to(encoded_sm_ptrs.type.element_ty)) + tl.store( + encoded_sm_ptrs, + tl.where(keep, p, -p).to(encoded_sm_ptrs.type.element_ty)) p = tl.where(keep, p, 0.0) elif RETURN_ENCODED_SOFTMAX: tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty)) @@ -146,7 +160,8 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri alpha = tl.math.exp2(m_i - m_ij) acc = acc * alpha[:, None] if not PRE_LOAD_V: - v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, + ACTUAL_BLOCK_DMODEL) # -- update m_i and l_i l_i = l_i * alpha + l_ij # update m_i and l_i @@ -260,14 +275,20 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], ) @triton.jit -def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, - stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, - stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q, cu_seqlens_k, - dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr, - HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, - MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr): +def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, + stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, + stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, + stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, + stride_ah, cu_seqlens_q, cu_seqlens_k, dropout_p, philox_seed, + philox_offset_base, encoded_softmax, alibi_slopes, + HQ: tl.constexpr, HK: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr): start_m = tl.program_id(0) off_h_q = tl.program_id(1) off_z = tl.program_id(2) @@ -298,42 +319,46 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s # This block of code determines what N is, and if this WG is operating # on those M rows. n_blocks = cdiv_fn(seqlen_k, BLOCK_N) - if (IS_CAUSAL): + if IS_CAUSAL: # If seqlen_q == seqlen_k, the attn scores are a square matrix. 
# If seqlen_q != seqlen_k, attn scores are rectangular which means # the causal mask boundary is bottom right aligned, and ends at either # the top edge (seqlen_q < seqlen_k) or left edge. - # This captures the decrease in n_blocks if we have a rectangular attn matrix - n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This captures the decrease in n_blocks if we have a rectangular + # attn matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) # This is what adjusts the block_max for the current WG, only # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks n_blocks = min(n_blocks, n_blocks_seqlen) - # If we have no blocks after adjusting for seqlen deltas, this WG is part of - # the blocks that are all 0. We exit early. + # If we have no blocks after adjusting for seqlen deltas, this WG is + # part of the blocks that are all 0. We exit early. if n_blocks <= 0: - o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om - o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh + + cu_seqlens_q_start * stride_om) + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[ + None, :] * stride_on acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) o_ptrs_mask = offs_m[:, None] < seqlen_q # We still need to write 0s to the result tl.store(o_ptrs, acc, mask=o_ptrs_mask) # The tensor allocated for L is based on MAX_SEQLENS_Q as that is # statically known. - l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m - # We store inf to LSE, not -inf because in the bwd pass, we subtract this - # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. - l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q + + off_h_q * MAX_SEQLENS_Q + offs_m) + # We store inf to LSE, not -inf because in the bwd pass, we subtract + # this from qk which makes it -inf, such that exp(qk - inf) = 0 for + # these masked blocks. + l = tl.full( # NOQA: E741 + [BLOCK_M], value=float("inf"), dtype=tl.float32) l_ptrs_mask = offs_m < MAX_SEQLENS_Q tl.store(l_ptrs, l, mask=l_ptrs_mask) - # TODO: Should dropout and return encoded softmax be handled here too? + # TODO: Should dropout & return encoded softmax be handled here too? return # If MQA / GQA, set the K and V head offsets appropriately. GROUP_SIZE: tl.constexpr = HQ // HK - if GROUP_SIZE != 1: - off_h_k = off_h_q // GROUP_SIZE - else: - off_h_k = off_h_q + off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q n_extra_tokens = 0 if seqlen_k < BLOCK_N: @@ -343,16 +368,23 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) # Compute pointers for all the tensors used in this kernel. 
- q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm - q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk - k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn - k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn - v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk - v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn + q_offset = (Q + off_z * stride_qz + off_h_q * stride_qh + + cu_seqlens_q_start * stride_qm) + q_ptrs = (q_offset + offs_m[:, None] * stride_qm + + offs_d[None, :] * stride_qk) + k_offset = (K + off_z * stride_kz + off_h_k * stride_kh + + cu_seqlens_k_start * stride_kn) + k_ptrs = (k_offset + offs_d[:, None] * stride_kk + + offs_n[None, :] * stride_kn) + v_offset = (V + off_z * stride_vz + off_h_k * stride_vh + + cu_seqlens_k_start * stride_vk) + v_ptrs = (v_offset + offs_n[:, None] * stride_vk + + offs_d[None, :] * stride_vn) if USE_BIAS: # Note: this might get large enough to overflow on some configs bias_offset = off_h_q * stride_bh - bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn + bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[ + None, :] * stride_bn else: bias_ptrs = None @@ -367,11 +399,13 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k else: batch_philox_offset = 0 - # We can ask to return the dropout mask without actually doing any dropout. In + # We can ask to return the dropout mask without actually doing dropout. In # this case, we return an invalid pointer so indicate the mask is not valid. if RETURN_ENCODED_SOFTMAX: encoded_sm_base = encoded_softmax + off_h_q * seqlen_q * seqlen_k - encoded_sm_ptrs = encoded_sm_base + offs_m[:, None] * seqlen_k + offs_n[None, :] + encoded_sm_ptrs = encoded_sm_base + offs_m[:, + None] * seqlen_k + offs_n[ + None, :] else: encoded_sm_ptrs = None # initialize pointer to m and l @@ -398,50 +432,105 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s else: # Padding on Q does not need to be masked in the FA loop. masked_blocks = padded_block_k - # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. - # In this case we might exceed n_blocks so pick the min. + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional + # block. In this case we might exceed n_blocks so pick the min. masked_blocks = min(masked_blocks, n_blocks) n_full_blocks = n_blocks - masked_blocks block_min = 0 block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false regardless of its actual - # value because there is no masking. Similarly we do not need padding. + # Compute for full blocks. Here we set causal to false unconditionally + # because there is no masking. Similarly we do not need padding. if n_full_blocks > 0: block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, - start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, - encoded_sm_ptrs, - # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, block_max, 0, 0, 0, alibi_slope, - # IS_CAUSAL, .... - False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, - # _, MASK_STEPS, ... 
- PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL) + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_sm_ptrs, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, + block_max, + 0, + 0, + 0, + alibi_slope, + # IS_CAUSAL, .... + False, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + False, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + PADDED_HEAD, + ACTUAL_BLOCK_DMODEL) block_min = block_max block_max = n_blocks * BLOCK_N tl.debug_barrier() # Remaining blocks, if any, are full / not masked. if (masked_blocks > 0): - if IS_CAUSAL: - offs_n_causal = offs_n + (seqlen_q - seqlen_k) - else: - offs_n_causal = 0 + offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 k_ptrs += n_full_blocks * BLOCK_N * stride_kn v_ptrs += n_full_blocks * BLOCK_N * stride_vk if USE_BIAS: bias_ptrs += n_full_blocks * BLOCK_N * stride_bn if RETURN_ENCODED_SOFTMAX: encoded_sm_ptrs += n_full_blocks * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, - start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, - encoded_sm_ptrs, block_min, block_max, offs_n_causal, masked_blocks, - n_extra_tokens, alibi_slope, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, - offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL) + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_sm_ptrs, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + alibi_slope, + IS_CAUSAL, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + True, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + PADDED_HEAD, + ACTUAL_BLOCK_DMODEL) # epilogue acc = acc / l_i[:, None] if ENABLE_DROPOUT: @@ -454,28 +543,36 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s start_m_idx = start_m * BLOCK_M causal_start_idx = seqlen_q - seqlen_k acc = acc.to(Out.type.element_ty) - if IS_CAUSAL: + if IS_CAUSAL: # NOQA: SIM102 if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) + out_mask_boundary = tl.full((BLOCK_DMODEL, ), + causal_start_idx, + dtype=tl.int32) mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + out_ptrs_mask = mask_m_offsets[:, + None] >= out_mask_boundary[None, :] z = 0.0 acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) # write back LSE l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m - # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. - # This is only true for the last M block. For others, overflow_size will be -ve + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few + # rows. This is only true for the last M block. 
For others, overflow_size + # will be -ve overflow_size = end_m_idx - seqlen_q if overflow_size > 0: - boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) + boundary = tl.full((BLOCK_M, ), + BLOCK_M - overflow_size, + dtype=tl.int32) l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) else: tl.store(l_ptrs, m_i + tl.math.log2(l_i)) # write back O - o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om - o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh + + cu_seqlens_q_start * stride_om) + o_ptrs = (o_offset + offs_m[:, None] * stride_om + + offs_d[None, :] * stride_on) o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) if overflow_size > 0: o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) @@ -595,7 +692,9 @@ def forward( else: bias_strides = (0, 0, 0, 0) alibi_strides = (0, 0) - M = torch.empty((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32) + M = torch.empty((batch, nheads_q, max_seqlens_q), + device=q.device, + dtype=torch.float32) attn_fwd[grid]( q, diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index b04adb532dd38..d5acc965ad200 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -6,13 +6,14 @@ import torch from torch.distributed import ProcessGroup +from vllm.utils import is_hip + from .parallel_state import (get_cpu_world_group, get_pp_pynccl_communicator, get_tensor_model_parallel_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_ca_communicator, get_tp_pynccl_communicator) -from vllm.utils import is_hip @dataclass diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 6d1b3b47c3dd7..6e9b017ea93b3 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -188,7 +188,8 @@ def initialize_model_parallel( _TP_CPU_GROUP = cpu_group if tensor_model_parallel_size > 1 and not is_hip(): - from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.device_communicators.pynccl import ( + PyNcclCommunicator) _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator( group=_TP_CPU_GROUP, device=_LOCAL_RANK, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 476301a216c48..00890a49b9be3 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -5,7 +5,6 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter -from vllm import _custom_C from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, @@ -16,7 +15,6 @@ QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.tuned_gemm import tgemm from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import is_hip logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index b963576aa4471..228a646d471f6 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -23,7 +23,7 @@ "aqlm": AQLMConfig, "awq": AWQConfig, "deepspeedfp": DeepSpeedFPConfig, - "fp8": Fp8Config if not is_hip() else Fp8RocmConfig, + "fp8": Fp8Config if not 
is_hip() else Fp8RocmConfig, # type: ignore # The order of gptq methods is important for config.py iteration over # override_quantization_method(..) "marlin": MarlinConfig, diff --git a/vllm/model_executor/layers/quantization/fp8_rocm.py b/vllm/model_executor/layers/quantization/fp8_rocm.py index caa53fb6ceee8..ddccc5825c8a4 100644 --- a/vllm/model_executor/layers/quantization/fp8_rocm.py +++ b/vllm/model_executor/layers/quantization/fp8_rocm.py @@ -1,22 +1,19 @@ -from typing import Any, Dict, List, Optional, Tuple, Union, Iterator +import os +from typing import List, Optional, Tuple, Union +import pandas as pd import torch +import torch.nn.functional as F from torch.nn import Module from torch.nn.parameter import Parameter -import torch.nn.functional as F -from safetensors import safe_open -from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.utils import set_weight_attrs -import pandas as pd -import os - -try: +try: # NOQA: SIM105 from vllm._C import ops as vllm_ops except ImportError: pass @@ -25,10 +22,10 @@ class Fp8RocmConfig(QuantizationConfig): + def __init__(self) -> None: # self.quantized_weights_path = config["quantized_weights"] self._tuned = {} - self._stats = {} gemm_type = os.getenv("FP8_GEMM", "fp8_16") #print(f"Integral Cross factor = {self.factor}") if gemm_type == "fp8_8": @@ -38,10 +35,14 @@ def __init__(self) -> None: self.gemm_method = Fp8RocmLinearMethod.apply_fp8_16 tuned_filename = "/projects/tuned_fp8_16.csv" else: - raise Exception(f"Unknown fp8 gemm type: {gemm_type}") + raise ValueError(f"Unknown fp8 gemm type: {gemm_type}") try: df = pd.read_csv(tuned_filename) - except: + except pd.errors.ParserError as e: + logger.warning( + "An error occurred while parsing `%s`: %s" + "FP8 tuning results will not be used!", tuned_filename, e) + except (IOError, pd.errors.EmptyDataError): return for i in range(len(df)): @@ -58,7 +59,7 @@ def get_config_filenames(cls) -> List[str]: @classmethod def from_config(cls, config) -> "Fp8RocmConfig": - return cls(config) + return cls() @classmethod def get_supported_act_dtypes(cls) -> List[torch.dtype]: @@ -73,8 +74,8 @@ def get_min_capability(cls) -> int: def get_name(cls) -> str: return "Fp8Rocm" - def get_quant_method(self, - layer: torch.nn.Module) -> Optional["Fp8RocmLinearMethod"]: + def get_quant_method( + self, layer: torch.nn.Module) -> Optional["Fp8RocmLinearMethod"]: if isinstance(layer, LinearBase): return Fp8RocmLinearMethod(self) return None @@ -84,10 +85,10 @@ def get_scaled_act_names(self) -> List[str]: class Fp8RocmLinearMethod(LinearMethodBase): + def __init__(self, config: Fp8RocmConfig): self._config = config - def _create_scale_param( self, scale_name: str, @@ -106,7 +107,6 @@ def _create_scale_param( self.scales_shard_indexer, }) - def create_weights( self, layer: torch.nn.Module, @@ -132,59 +132,56 @@ def create_weights( layer.register_parameter("weight", weight) set_weight_attrs(weight, { - **extra_weight_attrs, - "input_dim": 1, + **extra_weight_attrs, "input_dim": 1, "output_dim": 0 }) - - self._create_scale_param( - scale_name="weights_scaling_factor", - layer=layer, - output_partition_sizes=output_partition_sizes, - **extra_weight_attrs) - - self._create_scale_param( - scale_name="activation_scaling_factor", - layer=layer, - output_partition_sizes=output_partition_sizes, - **extra_weight_attrs) - - 
self._create_scale_param( - scale_name="output_scaling_factor", - layer=layer, - output_partition_sizes=output_partition_sizes, - **extra_weight_attrs) - - + + self._create_scale_param(scale_name="weights_scaling_factor", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs) + + self._create_scale_param(scale_name="activation_scaling_factor", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs) + + self._create_scale_param(scale_name="output_scaling_factor", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs) + def process_weights_after_loading(self, layer: Module) -> None: if (not hasattr(layer, "process_after_load") or not layer.process_after_load): return - layer.activation_scaling_factor = Parameter(layer.activation_scaling_factor.max(), - requires_grad=False) - layer.output_scaling_factor = Parameter(layer.output_scaling_factor.reciprocal().max(), - requires_grad=False) + layer.activation_scaling_factor = Parameter( + layer.activation_scaling_factor.max(), requires_grad=False) + layer.output_scaling_factor = Parameter( + layer.output_scaling_factor.reciprocal().max(), + requires_grad=False) max_w_scale = layer.weights_scaling_factor.max() if len(layer.logical_widths) > 1: start = 0 for idx, logical_width in enumerate(layer.logical_widths): end = start + logical_width - weight_dq = _per_tensor_dequantize(layer.weight[start:end, :], - layer.weights_scaling_factor[idx]) + weight_dq = _per_tensor_dequantize( + layer.weight[start:end, :], + layer.weights_scaling_factor[idx]) layer.weight[start:end, :] = _per_tensor_quantize( weight_dq, max_w_scale) start = end - layer.weights_scaling_factor = Parameter(max_w_scale, requires_grad=False) + layer.weights_scaling_factor = Parameter(max_w_scale, + requires_grad=False) # WEIGHT # Transpose weight for passing to torch._scaled_mm weight = layer.weight layer.weight = Parameter(weight, requires_grad=False) - def scales_shard_indexer( self, param: torch.Tensor, loaded_weight: torch.Tensor, shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]: @@ -198,7 +195,7 @@ def scales_shard_indexer( shard_id = qkv_idxs[shard_id] else: ValueError(f"Shard id must be int or str but got {type(shard_id)}") - + # To handle the scalar loaded tensor if loaded_weight.numel() == 1 and len(loaded_weight.shape) != 0: loaded_weight = torch.scalar_tensor(loaded_weight[0]) @@ -212,7 +209,9 @@ def apply_fp8_16( asf: torch.Tensor, wsf: torch.Tensor, osf: torch.Tensor, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert not bias x8 = torch.empty_like(x, dtype=torch.float8_e4m3fnuz) vllm_ops.convert_fp8(x8, x, asf) m = weight.shape[0] @@ -226,11 +225,15 @@ def apply_fp8_16( if os.getenv("TUNE_FP8") == "1": try: df = pd.read_csv("/projects/fp8_shapes.csv") - except: + except (IOError, pd.errors.EmptyDataError, + pd.errors.ParserError): df = pd.DataFrame(columns=["M", "N", "K"]) df = pd.concat( - [df, pd.DataFrame({"M": [m], "N": [n], "K": [k]})] - ).drop_duplicates() + [df, pd.DataFrame({ + "M": [m], + "N": [n], + "K": [k] + })]).drop_duplicates() df.to_csv("/projects/fp8_shapes.csv", index=False) algo = 0 res = vllm_ops.fp8_gemm_16(x8, weight.t(), asf, wsf, int(algo)) @@ -259,17 +262,21 @@ def apply_fp8_8( if os.getenv("TUNE_FP8") == "1": try: df = pd.read_csv("/projects/fp8_shapes.csv") - except: + except (IOError, pd.errors.EmptyDataError, + pd.errors.ParserError): df = pd.DataFrame(columns=["M", "N", "K"]) df = pd.concat( - [df, pd.DataFrame({"M": 
[m], "N": [n], "K": [k]})] - ).drop_duplicates() - df.to_csv("/projects/fp8_shapese.csv", index=False) + [df, pd.DataFrame({ + "M": [m], + "N": [n], + "K": [k] + })]).drop_duplicates() + df.to_csv("/projects/fp8_shapes.csv", index=False) algo = 0 res = vllm_ops.fp8_gemm(x8, weight.t(), asf, wsf, osf, int(algo)) res16 = torch.empty_like(res, dtype=torch.float16) - vllm_ops.convert_fp8(res16, res, 1/osf) + vllm_ops.convert_fp8(res16, res, 1 / osf) return res16 def apply( @@ -287,17 +294,17 @@ def apply( return self._config.gemm_method(self, x, weight, asf, wsf, osf) return F.linear(x, weight, bias) - + def _per_tensor_quantize(tensor: torch.Tensor, - inv_scale: float) -> torch.Tensor: + inv_scale: float) -> torch.Tensor: finfo = torch.finfo(torch.float8_e4m3fnuz) qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) return qweight.to(torch.float8_e4m3fnuz) def _per_tensor_dequantize(tensor: torch.Tensor, - inv_scale: float) -> torch.Tensor: + inv_scale: float) -> torch.Tensor: fake_qweight = tensor.to(torch.float16) dq_weight = fake_qweight * inv_scale return dq_weight diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index aa143c65d82b9..a7b8d1ad35620 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -247,10 +247,13 @@ def load_model(self, *, model_config: ModelConfig, model, "fall_back_to_pt_during_load", True)), ) - if model_config.quantization == 'fp8' and model_config.quantization_param_path is not None: + if (model_config.quantization == 'fp8' + and model_config.quantization_param_path is not None): model.load_quantized_weights( - safetensors_weights_iterator([model_config.model + model_config.quantization_param_path]) - ) + safetensors_weights_iterator([ + model_config.model + + model_config.quantization_param_path + ])) for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) if quant_method is not None: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 2b8d3573f45cf..c7d63df353ca5 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -438,8 +438,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - - def load_quantized_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + def load_quantized_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]): params_dict = dict(self.named_parameters()) #with open("/projects/a.txt", "r") as f: # j = json.load(f) @@ -457,7 +458,8 @@ def load_quantized_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: #print(name) name = name.replace('transformer', 'model') - name = name.replace('kv_cache_scaling_factor', 'qkv.output_scaling_factor') + name = name.replace('kv_cache_scaling_factor', + 'qkv.output_scaling_factor') loaded_weight = loaded_weight.to("cuda") if loaded_weight.dtype == torch.int8: loaded_weight[loaded_weight == -128] = 0 @@ -481,9 +483,9 @@ def load_quantized_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue name = name.replace(weight_name, param_name) param = params_dict[name] - if "activation_scaling_factor" in name or "weights_scaling_factor" in name: - param.data.copy_(loaded_weight) - elif "output_scaling_factor" in name: + if ("activation_scaling_factor" in name + or "weights_scaling_factor" 
in name + or "output_scaling_factor" in name): param.data.copy_(loaded_weight) else: weight_loader = getattr(param, "weight_loader",