Commit 33de5fe: fix linter

charlifu committed Jun 13, 2024 (parent: 8cdee54)
Showing 1 changed file with 22 additions and 15 deletions.
vllm/model_executor/models/llama.py
@@ -52,6 +52,7 @@
 
 import os
 
+
 class LlamaMLP(nn.Module):
 
     def __init__(
@@ -439,8 +440,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
 
-    def load_quantized_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_quantized_weights(self, weights: Iterable[Tuple[str,
+                                                              torch.Tensor]]):
 
         def load_ammo():
             params_dict = dict(self.named_parameters())
             quant_shards = [
@@ -454,7 +457,8 @@ def load_ammo():
             ]
             for name, loaded_weight in weights:
                 name = name.replace('transformer', 'model')
-                name = name.replace('kv_cache_scaling_factor', 'qkv.output_scaling_factor')
+                name = name.replace('kv_cache_scaling_factor',
+                                    'qkv.output_scaling_factor')
                 loaded_weight = loaded_weight.to("cuda")
                 if loaded_weight.dtype == torch.int8:
                     loaded_weight[loaded_weight == -128] = 0
@@ -478,15 +482,16 @@ def load_ammo():
                     continue
                 name = name.replace(weight_name, param_name)
                 param = params_dict[name]
-                if "activation_scaling_factor" in name or "weights_scaling_factor" in name:
-                    param.data.copy_(loaded_weight)
-                elif "output_scaling_factor" in name:
+                if ("activation_scaling_factor" in name
+                        or "weights_scaling_factor" in name
+                        or "output_scaling_factor" in name):
                     param.data.copy_(loaded_weight)
                 else:
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
                 break
+
         def load_quark():
             params_dict = dict(self.named_parameters())
             quant_shards = [
@@ -503,7 +508,7 @@ def load_quark():
                 ("weights_scaling_factor", "weight_quant_scale"),
                 ("output_scaling_factor", "output_quant_scale"),
             ]
-            for name, loaded_weight in weights: 
+            for name, loaded_weight in weights:
                 if "zero_point" in name:
                     continue
                 if len(loaded_weight.shape) == 0:
@@ -514,9 +519,9 @@ def load_quark():
                     continue
                 name = name.replace(weight_name, scale_name)
                 if loaded_weight.dtype == torch.int8:
-                    loaded_weight[loaded_weight == -128] = 0
-                assert loaded_weight.is_contiguous
-                loaded_weight = loaded_weight.view(torch.float8_e4m3fnuz)
+                    loaded_weight[loaded_weight == -128] = 0
+                    assert loaded_weight.is_contiguous
+                    loaded_weight = loaded_weight.view(torch.float8_e4m3fnuz)
 
                 for (param_name, weight_name, shard_id) in quant_shards:
                     if weight_name not in name:
@@ -536,16 +541,18 @@ def load_quark():
                         continue
                     name = name.replace(weight_name, param_name)
                     param = params_dict[name]
-                    if "activation_scaling_factor" in name or "weights_scaling_factor" in name:
-                        param.data.copy_(loaded_weight)
-                    elif "output_scaling_factor" in name:
+                    if ("activation_scaling_factor" in name
+                            or "weights_scaling_factor" in name
+                            or "output_scaling_factor" in name):
                         param.data.copy_(loaded_weight)
                     else:
                         weight_loader = getattr(param, "weight_loader",
-                                            default_weight_loader)
+                                                default_weight_loader)
                         weight_loader(param, loaded_weight)
                     break
-        load_func = load_ammo if os.getenv("VLLM_FP8_USE_AMMO") == "1" else load_quark
+
+        load_func = load_ammo if os.getenv(
+            "VLLM_FP8_USE_AMMO") == "1" else load_quark
         load_func()
 
         # If this function is called, it should always initialize KV cache scale
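A note on the int8-to-float8_e4m3fnuz handling that both `load_ammo` and `load_quark` share: the checkpoint stores fp8 tensors as raw int8 bytes, and in the e4m3fnuz encoding there is no negative zero, so the 0x80 bit pattern (int8 -128) decodes to NaN — presumably why the loaders zero it out before reinterpreting the storage. A minimal sketch of the pattern, assuming a PyTorch build that defines `torch.float8_e4m3fnuz` (the helper name is illustrative, not from the diff):

```python
import torch


def int8_bytes_to_fp8(raw: torch.Tensor) -> torch.Tensor:
    """Reinterpret raw int8 storage as float8_e4m3fnuz, as the loaders do.

    The fnuz encoding has no negative zero: the 0x80 bit pattern
    (int8 -128) decodes to NaN, so it is zeroed before the view.
    """
    assert raw.dtype == torch.int8
    out = raw.clone()           # avoid mutating the caller's tensor
    out[out == -128] = 0        # 0x80 would otherwise decode as NaN
    assert out.is_contiguous()  # view() reinterprets the underlying bytes
    return out.view(torch.float8_e4m3fnuz)


# Toy blob standing in for a serialized scaling factor:
raw = torch.tensor([0, 1, -128, 64], dtype=torch.int8)
print(int8_bytes_to_fp8(raw).dtype)  # torch.float8_e4m3fnuz
```

Incidentally, the diff's `assert loaded_weight.is_contiguous` is missing the call parentheses, so it asserts a bound method (always truthy) and can never fire; the sketch calls `is_contiguous()`.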
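The tail of the hunk also reformats the two dispatch idioms the loader is built on: each parameter may carry its own `weight_loader` attribute with a plain copy as the fallback, and an environment variable selects between the AMMO and Quark checkpoint formats. A self-contained sketch of both idioms under stand-in names (only the `getattr` fallback and the `VLLM_FP8_USE_AMMO` switch come from the diff; the rest is illustrative):

```python
import os

import torch
import torch.nn as nn


def default_weight_loader(param: nn.Parameter, w: torch.Tensor) -> None:
    # Fallback when a parameter declares no custom loader.
    param.data.copy_(w)


def load_one(param: nn.Parameter, w: torch.Tensor) -> None:
    # Per-parameter override with a module-level fallback, as in the diff.
    loader = getattr(param, "weight_loader", default_weight_loader)
    loader(param, w)


def load_ammo() -> None:   # stand-in for the diff's nested loader
    print("loading AMMO-format weights/scales")


def load_quark() -> None:  # stand-in for the diff's nested loader
    print("loading Quark-format weights/scales")


# Environment-variable dispatch, formatted as the commit leaves it:
load_func = load_ammo if os.getenv(
    "VLLM_FP8_USE_AMMO") == "1" else load_quark
load_func()
```

The `getattr` fallback lets sharded parameters (e.g. the fused qkv projection) install a loader that knows its shard layout, while everything else takes the plain copy path.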
