Skip to content

Commit

Permalink
add quantization_weights_path for fp8 weights
Browse files Browse the repository at this point in the history
  • Loading branch information
charlifu committed Jun 19, 2024
1 parent 719bf9d commit a9be7c9
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 4 deletions.
2 changes: 1 addition & 1 deletion ROCm_performance.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ For more details, please refer to Quark's documentation.

To use ammo, please follow this [instruction](https://github.com/ROCm/vllm/tree/main/examples/fp8/quantizer), and set `VLLM_FP8_USE_AMMO=1`.

Both quantizers generate a safetensor file that contains the quantized weights and the corresponding scaling factors of your model. The safetensor file should be placed under your model folder. Then we can run a model with fp8 quantization using vllm. When creating `vllm.LLM` object, two additional parameters should be added: `quantization="fp8"` and `quantization_param_path={relative path of the safetensors with your model path}`.
Both quantizers generate a safetensor file that contains the quantized weights and the corresponding scaling factors of your model. The safetensor file should be placed under your model folder. Then we can run a model with fp8 quantization using vllm. When creating `vllm.LLM` object, two additional parameters should be added: `quantization="fp8"` and `quantization_weights_path={relative path of the safetensors with your model path}`.

## Gemm Tuning for Fp8

Expand Down
8 changes: 8 additions & 0 deletions benchmarks/benchmark_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def main(args: argparse.Namespace):
num_speculative_tokens=args.num_speculative_tokens,
tokenizer=args.tokenizer,
quantization=args.quantization,
quantization_weights_path=args.quantization_weights_path,
tensor_parallel_size=args.tensor_parallel_size,
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
Expand Down Expand Up @@ -175,6 +176,13 @@ def run_to_completion(profile_dir: Optional[str] = None):
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.')
parser.add_argument(
'--quantization-weights-path',
type=str,
default=None,
help='Path to the safetensor file containing the quantized weights '
'and scaling factors. This should generally be supplied, when '
'quantization is FP8.')
parser.add_argument(
'--profile',
action='store_true',
Expand Down
10 changes: 10 additions & 0 deletions benchmarks/benchmark_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def run_vllm(
model: str,
tokenizer: str,
quantization: Optional[str],
quantization_weights_path: Optional[str],
tensor_parallel_size: int,
seed: int,
n: int,
Expand All @@ -87,6 +88,7 @@ def run_vllm(
model=model,
tokenizer=tokenizer,
quantization=quantization,
quantization_weights_path=quantization_weights_path,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
Expand Down Expand Up @@ -222,6 +224,7 @@ def main(args: argparse.Namespace):
if args.backend == "vllm":
elapsed_time = run_vllm(
requests, args.model, args.tokenizer, args.quantization,
args.quantization_weights_path,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype,
Expand Down Expand Up @@ -342,6 +345,13 @@ def main(args: argparse.Namespace):
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.')
parser.add_argument(
'--quantization-weights-path',
type=str,
default=None,
help='Path to the safetensor file containing the quantized weights '
'and scaling factors. This should generally be supplied, when '
'quantization is FP8.')
parser.add_argument(
"--device",
type=str,
Expand Down
2 changes: 2 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __init__(
max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
quantization_weights_path: Optional[str] = None,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
Expand All @@ -116,6 +117,7 @@ def __init__(
self.tokenizer_revision = tokenizer_revision
self.quantization = quantization
self.quantization_param_path = quantization_param_path
self.quantization_weights_path = quantization_weights_path
self.enforce_eager = enforce_eager
self.max_context_len_to_capture = max_context_len_to_capture
if self.max_context_len_to_capture is not None:
Expand Down
11 changes: 10 additions & 1 deletion vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class EngineArgs:
rope_scaling: Optional[dict] = None
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
quantization_weights_path: Optional[str] = None
enforce_eager: bool = False
max_context_len_to_capture: Optional[int] = None
max_seq_len_to_capture: int = 32768
Expand Down Expand Up @@ -337,6 +338,13 @@ def add_cli_args(
'None, we assume the model weights are not '
'quantized and use `dtype` to determine the data '
'type of the weights.')
parser.add_argument(
'--quantization-weights-path',
type=nullable_str,
default=None,
help='Path to the safetensor file containing the quantized weights '
'and scaling factors. This should generally be supplied, when '
'quantization is FP8.')
parser.add_argument('--rope-scaling',
default=None,
type=json.loads,
Expand Down Expand Up @@ -562,7 +570,8 @@ def create_engine_config(self, ) -> EngineConfig:
self.trust_remote_code, self.dtype, self.seed, self.revision,
self.code_revision, self.rope_scaling, self.tokenizer_revision,
self.max_model_len, self.quantization,
self.quantization_param_path, self.enforce_eager,
self.quantization_param_path, self.quantization_weights_path,
self.enforce_eager,
self.max_context_len_to_capture, self.max_seq_len_to_capture,
self.max_logprobs, self.disable_sliding_window,
self.skip_tokenizer_init, self.served_model_name)
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,11 @@ def load_model(self, *, model_config: ModelConfig,
"fall_back_to_pt_during_load",
True)), )
if (model_config.quantization == 'fp8'
and model_config.quantization_param_path is not None):
and model_config.quantization_weights_path is not None):
model.load_quantized_weights(
safetensors_weights_iterator([
model_config.model +
model_config.quantization_param_path
model_config.quantization_weights_path
]))
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
Expand Down

0 comments on commit a9be7c9

Please sign in to comment.