From 33de5fe3671ce6ee30c91905611ca5c710cc37a9 Mon Sep 17 00:00:00 2001 From: charlifu Date: Thu, 13 Jun 2024 16:20:28 +0000 Subject: [PATCH] fix linter --- vllm/model_executor/models/llama.py | 37 +++++++++++++++++------------ 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 8759d1435c6ad..32dd5b96c8e40 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -52,6 +52,7 @@ import os + class LlamaMLP(nn.Module): def __init__( @@ -439,8 +440,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - - def load_quantized_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + def load_quantized_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]): + def load_ammo(): params_dict = dict(self.named_parameters()) quant_shards = [ @@ -454,7 +457,8 @@ def load_ammo(): ] for name, loaded_weight in weights: name = name.replace('transformer', 'model') - name = name.replace('kv_cache_scaling_factor', 'qkv.output_scaling_factor') + name = name.replace('kv_cache_scaling_factor', + 'qkv.output_scaling_factor') loaded_weight = loaded_weight.to("cuda") if loaded_weight.dtype == torch.int8: loaded_weight[loaded_weight == -128] = 0 @@ -478,15 +482,16 @@ def load_ammo(): continue name = name.replace(weight_name, param_name) param = params_dict[name] - if "activation_scaling_factor" in name or "weights_scaling_factor" in name: - param.data.copy_(loaded_weight) - elif "output_scaling_factor" in name: + if ("activation_scaling_factor" in name + or "weights_scaling_factor" in name + or "output_scaling_factor" in name): param.data.copy_(loaded_weight) else: weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) break + def load_quark(): params_dict = dict(self.named_parameters()) quant_shards = [ @@ -503,7 +508,7 @@ def load_quark(): ("weights_scaling_factor", "weight_quant_scale"), ("output_scaling_factor", "output_quant_scale"), ] - for name, loaded_weight in weights: + for name, loaded_weight in weights: if "zero_point" in name: continue if len(loaded_weight.shape) == 0: @@ -514,9 +519,9 @@ def load_quark(): continue name = name.replace(weight_name, scale_name) if loaded_weight.dtype == torch.int8: - loaded_weight[loaded_weight == -128] = 0 - assert loaded_weight.is_contiguous - loaded_weight = loaded_weight.view(torch.float8_e4m3fnuz) + loaded_weight[loaded_weight == -128] = 0 + assert loaded_weight.is_contiguous + loaded_weight = loaded_weight.view(torch.float8_e4m3fnuz) for (param_name, weight_name, shard_id) in quant_shards: if weight_name not in name: @@ -536,16 +541,18 @@ def load_quark(): continue name = name.replace(weight_name, param_name) param = params_dict[name] - if "activation_scaling_factor" in name or "weights_scaling_factor" in name: - param.data.copy_(loaded_weight) - elif "output_scaling_factor" in name: + if ("activation_scaling_factor" in name + or "weights_scaling_factor" in name + or "output_scaling_factor" in name): param.data.copy_(loaded_weight) else: weight_loader = getattr(param, "weight_loader", - default_weight_loader) + default_weight_loader) weight_loader(param, loaded_weight) break - load_func = load_ammo if os.getenv("VLLM_FP8_USE_AMMO") == "1" else load_quark + + load_func = load_ammo if os.getenv( + "VLLM_FP8_USE_AMMO") == "1" else load_quark load_func() # If this function is called, it should always initialize KV cache scale