Add fp8 support for llama model family on Navi4x #245

Merged · 4 commits · Oct 25, 2024
Changes from all commits
CMakeLists.txt (15 additions, 1 deletion)

@@ -37,7 +37,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12")
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")

# Supported AMD GPU architectures.
-set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101")
+set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1200")

#
# Supported/expected torch versions for CUDA/ROCm.
@@ -172,6 +172,20 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-result")
#
get_torch_gpu_compiler_flags(VLLM_GPU_FLAGS ${VLLM_GPU_LANG})

#
# Get supported FP8 format based on GPU arches
#
get_supported_fp8_format(FP8_FORMAT ${VLLM_GPU_LANG} "${VLLM_GPU_ARCHES}")
if(${FP8_FORMAT} STREQUAL "E4M3FN")
message(STATUS "FP8 format: E4M3FN")
list(APPEND VLLM_GPU_FLAGS "-DUSE_CUDA_FP8_FORMAT")
elseif(${FP8_FORMAT} STREQUAL "E4M3FNUZ")
message(STATUS "FP8 format: E4M3FNUZ")
list(APPEND VLLM_GPU_FLAGS "-DUSE_HIP_FP8_FORMAT")
elseif(${FP8_FORMAT} STREQUAL "CONFLICT")
message(FATAL_ERROR "Target architectures support different types of FP8 formats!")
endif()

#
# Set nvcc parallelism.
#
cmake/utils.cmake (30 additions, 0 deletions)

@@ -435,3 +435,33 @@ function (define_gpu_extension_target GPU_MOD_NAME)

install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME})
endfunction()


# gfx12xx should not be compiled together with gfx94x (MI300) because they support different types of FP8 format.
# FP8_FORMAT will be returned (E4M3FN / E4M3FNUZ / NONE / CONFLICT)
macro (get_supported_fp8_format FP8_FORMAT GPU_LANG GPU_ARCHES)
set(_USING_CUDA_FP8_FORMAT "FALSE")
set(_USING_HIP_FP8_FORMAT "FALSE")

if (NOT (${GPU_LANG} STREQUAL "HIP"))
set(_USING_CUDA_FP8_FORMAT "TRUE")
else()
foreach (_ARCH ${GPU_ARCHES})
if (_ARCH MATCHES "gfx94.")
set(_USING_HIP_FP8_FORMAT "TRUE")
elseif(_ARCH MATCHES "gfx12..")
set(_USING_CUDA_FP8_FORMAT "TRUE")
endif()
endforeach()
endif()

if ((${_USING_CUDA_FP8_FORMAT} STREQUAL "FALSE") AND (${_USING_HIP_FP8_FORMAT} STREQUAL "FALSE"))
set(FP8_FORMAT "NONE")
elseif((${_USING_CUDA_FP8_FORMAT} STREQUAL "FALSE") AND (${_USING_HIP_FP8_FORMAT} STREQUAL "TRUE"))
set(FP8_FORMAT "E4M3FNUZ")
elseif((${_USING_CUDA_FP8_FORMAT} STREQUAL "TRUE") AND (${_USING_HIP_FP8_FORMAT} STREQUAL "FALSE"))
set(FP8_FORMAT "E4M3FN")
else()
set(FP8_FORMAT "CONFLICT")
endif()
endmacro()
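
The split exists because gfx94x (MI300) hardware uses the E4M3FNUZ 8-bit float encoding, while Navi4x (gfx12xx) uses the OCP E4M3FN encoding that CUDA devices also use, so one binary cannot serve both. A minimal illustration of the difference, assuming a PyTorch build recent enough to expose both float8 dtypes:

import torch

# E4M3FN (CUDA and Navi4x) and E4M3FNUZ (MI300) differ in dynamic range
# and NaN encoding, so kernels compiled for one format cannot be reused
# for the other.
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0
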
csrc/quantization/fp8/common.cu (3 additions, 3 deletions)

@@ -7,15 +7,15 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"

-#ifndef USE_ROCM
+#if defined(USE_CUDA_FP8_FORMAT)
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#endif

-#ifndef USE_ROCM
+#if defined(USE_CUDA_FP8_FORMAT)
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
std::numeric_limits<FP8_TYPE>::max();
@@ -50,7 +50,7 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
}

float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
-#ifndef USE_ROCM
+#if defined(USE_CUDA_FP8_FORMAT)
return static_cast<c10::Float8_e4m3fn>(r);
#else
// Use hardware cvt instruction for fp8 on rocm
vllm/_custom_ops.py (2 additions, 2 deletions)

@@ -9,7 +9,7 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType
-from vllm.utils import is_hip
+from vllm.utils import is_hip, is_navi4x

logger = init_logger(__name__)

@@ -711,7 +711,7 @@ def scaled_fp8_quant(
assert (input.ndim == 2)
shape: Union[Tuple[int, int], torch.Size] = input.shape
# For rocm, the output fp8 dtype is torch.float_e3m3fnuz
-out_dtype: torch.dtype = torch.float8_e4m3fnuz if is_hip() \
+out_dtype: torch.dtype = torch.float8_e4m3fnuz if is_hip() and not is_navi4x() \
else torch.float8_e4m3fn
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
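
For reference, a rough sketch of how the quantization entry point behaves after this change; this is illustrative only and assumes a GPU build of vLLM with the custom ops compiled and a CUDA or ROCm device visible:

import torch
from vllm import _custom_ops as ops
from vllm.utils import is_hip, is_navi4x

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
out, scale = ops.scaled_fp8_quant(x)  # no scale passed, so a dynamic per-tensor scale is computed
# Expected out.dtype: torch.float8_e4m3fnuz on MI300-class ROCm devices,
# torch.float8_e4m3fn on CUDA devices and on Navi4x after this PR.
print(out.dtype, is_hip(), is_navi4x())
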
vllm/model_executor/layers/quantization/fp8.py (8 additions, 7 deletions)

@@ -27,7 +27,7 @@
PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
-from vllm.utils import is_hip, print_warning_once
+from vllm.utils import is_hip, is_navi4x, print_warning_once

ACTIVATION_SCHEMES = ["static", "dynamic"]

@@ -227,8 +227,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
weight = layer.weight
weight_scale = layer.weight_scale

-# If rocm, use float8_e4m3fnuz.
-if is_hip():
+# If rocm (except Navi4x), use float8_e4m3fnuz.
+if is_hip() and not is_navi4x():
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
@@ -378,9 +378,9 @@ def process_weights_after_loading(self, layer: Module) -> None:

# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
-# If rocm, use float8_e4m3fnuz as dtype
+# If rocm (except Navi4x), use float8_e4m3fnuz as dtype
fp8_dtype = torch.float8_e4m3fnuz \
-if is_hip() else torch.float8_e4m3fn
+if is_hip() and not is_navi4x() else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data,
dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
@@ -427,8 +427,9 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.w13_input_scale.max(), requires_grad=False)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False)
-# If rocm, normalize the weights and scales to e4m3fnuz
-if is_hip():
+# If rocm (except Navi4x, which uses e4m3fn),
+# normalize the weights and scales to e4m3fnuz
+if is_hip() and not is_navi4x():
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
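
The normalize_e4m3fn_to_e4m3fnuz step that Navi4x now skips is needed because the same bit pattern decodes to half the value in E4M3FNUZ compared with E4M3FN, so serialized FP8 checkpoints and their scales have to be adjusted when reinterpreted on MI300. A small CPU-only check, assuming a PyTorch version with float8 support:

import torch

raw = torch.tensor([0x40], dtype=torch.uint8)      # arbitrary example byte
as_fn = raw.view(torch.float8_e4m3fn).float()      # decodes to 2.0
as_fnuz = raw.view(torch.float8_e4m3fnuz).float()  # decodes to 1.0
print(as_fn.item(), as_fnuz.item())                # the fnuz reading is half the fn reading
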
vllm/model_executor/models/llama.py (9 additions, 4 deletions)

@@ -54,7 +54,7 @@
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
-from vllm.utils import is_hip
+from vllm.utils import is_hip, is_navi4x

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
@@ -87,7 +87,8 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
-self.use_fp8 = isinstance(quant_config, Fp8Config)
+self.use_fp8 = isinstance(quant_config, Fp8Config) \
+if is_hip() and not is_navi4x() else False
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@@ -189,8 +190,10 @@ def __init__(
cache_config=cache_config,
quant_config=quant_config,
)
# For CUDA devices and Navi4x, attn_fp8_out will be set to false.
self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \
and is_hip() \
and not is_navi4x() \
and isinstance(quant_config, Fp8Config)

def forward(
@@ -225,7 +228,8 @@ def __init__(
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
-self.use_fp8 = isinstance(quant_config, Fp8Config)
+self.use_fp8 = isinstance(quant_config, Fp8Config) \
+if is_hip() and not is_navi4x() else False
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
@@ -456,7 +460,8 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
if not isinstance(self.layers[layer_idx], nn.Identity):
layer_self_attn = self.layers[layer_idx].self_attn

-if is_hip():
+# Navi4x quantization should be treated as CUDA devices.
+if is_hip() and not is_navi4x():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
vllm/utils.py (12 additions, 0 deletions)

@@ -8,6 +8,7 @@
import ipaddress
import os
import random
import re
import socket
import subprocess
import sys
@@ -425,6 +426,17 @@ def is_hip() -> bool:
return torch.version.hip is not None


@lru_cache(maxsize=None)
def is_navi4x() -> bool:
if not is_hip() or not torch.cuda.is_available():
return False
# All (visible) GPUs must be of the same type,
# otherwise FP8 results can't be guaranteed.
archName = torch.cuda.get_device_properties('cuda').gcnArchName
return (archName is not None) and \
(re.match("gfx12[0-9]{2}", archName) is not None)


@lru_cache(maxsize=None)
def is_cpu() -> bool:
from importlib.metadata import PackageNotFoundError, version
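
The architecture test in is_navi4x() comes down to the "gfx12[0-9]{2}" pattern matched against gcnArchName. A quick sanity check of the regex with some illustrative arch strings (no GPU needed; real gcnArchName values may carry feature suffixes such as ":sramecc+:xnack-", which re.match tolerates because it anchors at the start of the string):

import re

for arch in ("gfx1200", "gfx1201", "gfx942", "gfx90a", "gfx1100"):
    # Only the gfx12xx (Navi4x) entries match and take the E4M3FN path.
    print(arch, re.match("gfx12[0-9]{2}", arch) is not None)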