Quantize weight with in-place mode in weight-only quantization (#1511)
Signed-off-by: Cheng, Penghui <penghui.cheng@intel.com>
PenghuiCheng authored Jan 12, 2024
1 parent 5b2a887, commit deb1ed5
Showing 3 changed files with 98 additions and 91 deletions.
neural_compressor/adaptor/torch_utils/awq.py: 4 changes (2 additions, 2 deletions)
@@ -244,7 +244,7 @@ def search_scale(self, block, block_name, module_list, input_values):
x_max = _get_act_scale(input_val)
absorbed_modules = {_m: fetch_module(block, _m) for _m in module_name_list}
# Step 4: collect origin output for MSE and state_dict for recover.
-org_stat = {_m: module.state_dict() for _m, module in absorbed_modules.items()}
+org_stat = {_m: copy.deepcopy(module.state_dict()) for _m, module in absorbed_modules.items()}
if len(module_tuple) > 1:
# use block inference for multi-modules
org_out = self.block_inference(block)
@@ -364,7 +364,7 @@ def search_clip(self, block_name, module_list, input_values):
# Step 2: update module name
module = fetch_module(self.model, module_name)
# Step 3: collect origin output for MSE and state_dict for recover.
-org_stat = module.state_dict()
+org_stat = copy.deepcopy(module.state_dict())
org_out = self.module_inference(module, input_val)
# Step 4: set different clip range for weight and compare the MSE loss.
logger.info("Searching the best clip range with AWQ algorithm")
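Why the deepcopy matters: `Module.state_dict()` returns references to the live parameter tensors rather than copies, so once `quant_weight` mutates weights in place (see weight_only.py below), an un-copied "backup" would be corrupted together with the module it is meant to restore. A minimal, self-contained sketch of the failure mode (illustrative only, not code from this repository):

```python
import copy

import torch

linear = torch.nn.Linear(4, 4)
saved_ref = linear.state_dict()                  # references the live storage
saved_copy = copy.deepcopy(linear.state_dict())  # detached snapshot

linear.weight.data.mul_(0.0)  # stand-in for an in-place quant_weight call

print(saved_ref["weight"].abs().sum())   # tensor(0.): the "backup" was clobbered
print(saved_copy["weight"].abs().sum())  # non-zero: the deepcopy survived
```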
neural_compressor/adaptor/torch_utils/teq.py: 2 changes (1 addition, 1 deletion)
@@ -294,7 +294,7 @@ def quantize(self):
group_size = self.weight_config[n]["group_size"]
scheme = self.weight_config[n]["scheme"]
if isinstance(m, torch.nn.Linear): # pragma: no cover
-m.weight.data.copy_(quant_weight(m.weight, num_bits=num_bits, group_size=group_size, scheme=scheme))
+quant_weight(m.weight.data, num_bits=num_bits, group_size=group_size, scheme=scheme)

def save(self, save_scale_file="", save_state_dict_file=""):
"""
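The TEQ change is the same contract applied at a call site: `quant_weight` now writes the fake-quantized values straight into `m.weight.data`, so the old pattern of materializing a quantized copy and `copy_`-ing it back disappears, along with the full-size temporary it allocated. A toy contrast of the two patterns, using fixed-scale rounding as a stand-in for the real quantizer:

```python
import torch

w = torch.randn(1024, 1024)

# Old pattern: build a quantized copy, then write it back (peak memory ~2x).
w_q = torch.round(w * 4) / 4
w.copy_(w_q)

# New pattern: mutate the weight's own storage; no full-size temporary.
w.mul_(4).round_().div_(4)
```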
neural_compressor/adaptor/torch_utils/weight_only.py: 183 changes (95 additions, 88 deletions)
@@ -80,7 +80,7 @@ def quantize_4bit(tensor, quantile=1.0, data_type="nf4", return_int=False):
# get scale and update tensor
scale = tensor.abs().max(1)[0] * quantile / max(allow_data)
scale.unsqueeze_(dim=-1)
-tensor = tensor / scale
+tensor.div_(scale)
mid_data = [(allow_data[i] + allow_data[i + 1]) / 2 for i in range(len(allow_data) - 1)]
q_tensor = torch.zeros_like(tensor)
for i in range(len(allow_data)):
@@ -91,9 +91,10 @@ def quantize_4bit(tensor, quantile=1.0, data_type="nf4", return_int=False):
q_tensor += torch.where(tensor > mid_data[i - 1], data, 0)
else:
q_tensor += torch.where((mid_data[i - 1] < tensor) & (tensor <= mid_data[i]), data, 0)
+tensor.copy_(q_tensor)
if return_int:
-return q_tensor.type(torch.int8), scale.type(torch.float), None
-return q_tensor * scale
+return tensor.type(torch.int8), scale.type(torch.float), None
+return tensor.mul_(scale)
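The lookup quantizer keeps its structure but now works inside the caller's tensor: `div_` normalizes into codebook range, the midpoint comparisons build the snapped values, `copy_` writes them back into the same storage, and `mul_` dequantizes in place. A self-contained sketch with a toy five-value codebook standing in for the real NF4/FP4 tables:

```python
import torch

allow_data = [-1.0, -0.5, 0.0, 0.5, 1.0]  # toy codebook

def quantize_lookup_(tensor, quantile=1.0):
    """Snap each row of `tensor` to the nearest codebook value, in place."""
    scale = tensor.abs().max(1)[0] * quantile / max(allow_data)
    scale.unsqueeze_(dim=-1)
    tensor.div_(scale)
    mids = [(allow_data[i] + allow_data[i + 1]) / 2 for i in range(len(allow_data) - 1)]
    q = torch.zeros_like(tensor)
    for i, val in enumerate(allow_data):
        if i == 0:
            q += torch.where(tensor <= mids[i], val, 0)
        elif i == len(allow_data) - 1:
            q += torch.where(tensor > mids[i - 1], val, 0)
        else:
            q += torch.where((mids[i - 1] < tensor) & (tensor <= mids[i]), val, 0)
    tensor.copy_(q)            # write snapped values into the caller's storage
    return tensor.mul_(scale)  # dequantize in place, no fresh allocation

x = torch.randn(2, 6)
quantize_lookup_(x)  # x itself now holds the fake-quantized values
```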


def qdq_weight_asym(weight, num_bits=4, quantile=1.0, return_int=False):
@@ -122,10 +123,14 @@ def qdq_weight_asym(weight, num_bits=4, quantile=1.0, return_int=False):
zp = torch.round(-wmin / scale)
scale.unsqueeze_(dim=-1)
zp.unsqueeze_(dim=-1)
-q = torch.clamp(torch.round(weight / scale) + zp, 0, maxq)
+weight.div_(scale)
+weight.round_()
+weight.add_(zp)
+weight.clamp_(0, maxq)
if return_int:
-return q.type(torch.uint8), scale.type(torch.float), zp.type(torch.uint8)
-return scale * (q - zp)
+return weight.type(torch.uint8), scale.type(torch.float), zp.type(torch.uint8)
+weight.sub_(zp)
+return weight.mul_(scale)
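The asymmetric path is now a chain of in-place ops over one buffer: divide by scale, round, add the zero point, clamp to [0, 2**num_bits - 1], then, on the fake-quant path, undo the zero point and scale. A condensed, runnable sketch of that round trip (per-row quantization; the epsilon guard on scale is an addition for the toy example, not part of this diff):

```python
import torch

def qdq_asym_(weight, num_bits=4):
    """Per-row asymmetric fake-quant built from in-place ops (sketch)."""
    maxq = 2**num_bits - 1
    wmin = weight.min(dim=1, keepdim=True).values.clamp(max=0)
    wmax = weight.max(dim=1, keepdim=True).values.clamp(min=0)
    scale = ((wmax - wmin) / maxq).clamp(min=1e-9)  # eps guard: toy addition
    zp = torch.round(-wmin / scale)
    weight.div_(scale).round_().add_(zp).clamp_(0, maxq)  # integer codes (float dtype)
    return weight.sub_(zp).mul_(scale)                    # dequantize in place

w = torch.randn(3, 8)
ref = w.clone()  # the in-place contract: clone first or lose the original
out = qdq_asym_(w)
print(out is w, (out - ref).abs().max())  # same object; small rounding error
```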


def qdq_weight_sym(weight, num_bits=4, quantile=1.0, return_int=False, full_range=False):
@@ -167,10 +172,12 @@ def qdq_weight_sym(weight, num_bits=4, quantile=1.0, return_int=False, full_range=False):
else:
scale = wmax / maxq
scale.unsqueeze_(dim=-1)
-q = torch.clamp(torch.round(weight / scale), minq, maxq)
+weight.div_(scale)
+weight.round_()
+weight.clamp_(minq, maxq)
if return_int:
-return q.type(torch.int8), scale.type(torch.float), None
-return scale * q
+return weight.type(torch.int8), scale.type(torch.float), None
+return weight.mul_(scale)
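The symmetric variant follows the same chain without a zero point. Note that the `return_int` branch's `.type(torch.int8)` still allocates, since a dtype conversion cannot happen in place; only the float pipeline is allocation-free. A sketch that also demonstrates the storage identity the rewrite relies on:

```python
import torch

def qdq_sym_(weight, num_bits=4):
    """Per-row symmetric fake-quant with in-place ops (sketch)."""
    maxq = 2 ** (num_bits - 1) - 1
    minq = -(2 ** (num_bits - 1))
    scale = (weight.abs().max(dim=1, keepdim=True).values / maxq).clamp(min=1e-9)
    weight.div_(scale).round_().clamp_(minq, maxq)  # signed codes in float storage
    return weight.mul_(scale)                       # dequantize, same storage

w = torch.randn(3, 8)
out = qdq_sym_(w)
print(out.data_ptr() == w.data_ptr())  # True: no new weight tensor was created
```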


def qdq_weight_actor(weight, num_bits, scheme, quantile=1.0, data_type="int", return_int=False, full_range=False):
@@ -200,7 +207,7 @@ def qdq_weight_actor(weight, num_bits, scheme, quantile=1.0, data_type="int", return_int=False, full_range=False):
def quant_weight(
weight, num_bits=4, group_size=-1, scheme="asym", quantile=1.0, data_type="int", return_int=False, full_range=False
):
"""Quant and dequant tensor with group size.
"""Quant and dequant tensor with group size. It is an in-place op.
Args:
weight: input weight
@@ -248,7 +255,7 @@ def quant_weight(
zp = zp.reshape(orig_shape[0], -1)
return weight, scale, zp
else:
-weight = qdq_weight_actor(
+qdq_weight_actor(
weight, num_bits, scheme=scheme, data_type=data_type, quantile=quantile, full_range=full_range
)
return weight.reshape(orig_shape)
@@ -285,7 +292,6 @@ def quant_weight(
return_int=True,
full_range=full_range,
)
-weight = torch.cat([weight1, weight2], dim=1)
scale = torch.cat([scale1, scale2], dim=1)
if zp2 is not None:
zp = torch.cat([zp1, zp2], dim=1)
@@ -296,7 +302,6 @@ def quant_weight(
weight2 = qdq_weight_actor(
weight2, num_bits, scheme=scheme, data_type=data_type, quantile=quantile, full_range=full_range
)
-weight = torch.cat([weight1, weight2], dim=1)
return weight
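With the in-place contract in the docstring, the two `torch.cat` recombinations above become dead code: the grouped halves are views of the original storage, so their results are already in place and only the scale/zero-point tensors still need concatenation. The reshape-to-groups trick is what lets the fast path stay allocation-free; a condensed sketch under the assumption that the column count divides evenly (tail groups, asym scheme, and `return_int` omitted):

```python
import torch

def fake_quant_grouped_(weight, group_size=4, num_bits=4):
    """Group-wise symmetric fake-quant through a reshape view (sketch)."""
    orig_shape = weight.shape
    maxq = 2 ** (num_bits - 1) - 1
    w = weight.reshape(-1, group_size)  # a view: same storage, grouped rows
    scale = (w.abs().max(dim=1, keepdim=True).values / maxq).clamp(min=1e-9)
    w.div_(scale).round_().clamp_(-maxq, maxq).mul_(scale)
    return weight.reshape(orig_shape)

w = torch.randn(4, 8)
q = fake_quant_grouped_(w)
print(torch.equal(q, w))  # True: the same tensor was quantized in place
```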


@@ -314,7 +319,7 @@ def search_clip(m, num_bits=4, group_size=32, scheme="asym", data_type="int", enable_full_range=False):
Returns:
best_clip_ratio (float): best percentile of clip
"""
-org_weight = m.weight.data
+org_weight = m.weight.data.clone()
logger.info("Searching the best clip range with RTN algorithm")
best_error = float("inf")
best_clip_ratio = None
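Same hazard as the AWQ deepcopy above, in raw-tensor form: `m.weight.data` is the live storage, so without `.clone()` the MSE reference would be rewritten by every in-place `quant_weight` trial and each candidate would be compared against itself. A toy skeleton of the clip search under that contract (clamping stands in for clipped quantization):

```python
import torch

def toy_search_clip(weight, ratios=(1.0, 0.98, 0.96)):
    org_weight = weight.clone()  # the reference must not alias the live tensor
    best_ratio, best_err = None, float("inf")
    for ratio in ratios:
        trial = org_weight.clone()           # fresh copy per trial
        trial.clamp_(min=-ratio, max=ratio)  # stand-in for clipped quant_weight
        err = (trial - org_weight).pow(2).mean()
        if err < best_err:
            best_err, best_ratio = err, ratio
    return best_ratio

print(toy_search_clip(torch.randn(8, 8)))
```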
@@ -397,82 +402,84 @@ def rtn_quantize(
scale_dtype = kwargs.get("scale_dtype", torch.float32)
device = kwargs.get("device", "cpu")
use_optimum_format = kwargs.get("use_optimum_format", True)
-    for name, m in model.named_modules():
-        if m.__class__.__name__ not in supported_layers:
-            continue
-        orig_dtype = next(m.parameters()).dtype
-        if orig_dtype != torch.float:
-            m = m.float()
-        if name in weight_config:  # pragma: no cover
-            num_bits = weight_config[name]["bits"]
-            group_size = weight_config[name]["group_size"]
-            scheme = weight_config[name]["scheme"]
-            quantile = weight_config[name].get("quantile", 1.0)
-        logger.debug(f"RTN quantized module:{name, m}")
-        log_msg = (
-            f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, "
-            + f"scheme={scheme}, quantile={quantile}"
-        )
-        if data_type != "int":
-            log_msg += f", dtype={data_type}"
-        elif scheme == "sym":  # nf4/fp4 is always [-7,7]
-            log_msg += f", enable_full_range={enable_full_range}"
-        logger.debug(log_msg)
-        if num_bits <= 0:
-            logger.info(f"Skip {name}")
-            continue
-        weight = m.weight.T if group_dim == 0 else m.weight
-        if enable_mse_search:
-            quantile = search_clip(m, num_bits, group_size, scheme, data_type, enable_full_range)
-        if return_int:
-            from .model_wrapper import WeightOnlyLinear
-
-            int_weight, scale, zp = quant_weight(
-                weight,
-                num_bits,
-                group_size,
-                scheme,
-                quantile,
-                data_type=data_type,
-                return_int=True,
-                full_range=enable_full_range,
-            )
-            int_weight = int_weight.T if group_dim == 0 else int_weight
-            scale = scale.T if group_dim == 0 else scale
-            zp = zp.T if group_dim == 0 and zp is not None else zp
-            new_module = WeightOnlyLinear(
-                m.in_features,
-                m.out_features,
-                num_bits,
-                group_size,
-                dtype=data_type,
-                zp=zp is not None,
-                bias=m.bias is not None,
-                compression_dtype=compression_dtype,
-                compression_dim=compression_dim,
-                scale_dtype=scale_dtype,
-                device=device,
-                use_optimum_format=use_optimum_format,
-            )
-            new_module.pack(int_weight, scale, zp, m.bias)
-            if name == "":
-                return new_module
-            else:
-                set_module(model, name, new_module)
-        else:
-            q_weight = quant_weight(
-                weight,
-                num_bits,
-                group_size,
-                scheme,
-                quantile,
-                data_type=data_type,
-                full_range=enable_full_range,
-            )
-            q_weight = q_weight.T if group_dim == 0 else q_weight
-            m.weight.data.copy_(q_weight)
-        if orig_dtype != torch.float:
-            m = m.to(orig_dtype)
+    with torch.no_grad():
+        for name, m in model.named_modules():
+            if m.__class__.__name__ not in supported_layers:
+                continue
+            orig_dtype = next(m.parameters()).dtype
+            if orig_dtype != torch.float:
+                m = m.float()
+            if name in weight_config:  # pragma: no cover
+                num_bits = weight_config[name]["bits"]
+                group_size = weight_config[name]["group_size"]
+                scheme = weight_config[name]["scheme"]
+                quantile = weight_config[name].get("quantile", 1.0)
+            logger.debug(f"RTN quantized module:{name, m}")
+            log_msg = (
+                f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, "
+                + f"scheme={scheme}, quantile={quantile}"
+            )
+            if data_type != "int":
+                log_msg += f", dtype={data_type}"
+            elif scheme == "sym":  # nf4/fp4 is always [-7,7]
+                log_msg += f", enable_full_range={enable_full_range}"
+            logger.debug(log_msg)
+            if num_bits <= 0:
+                logger.info(f"Skip {name}")
+                continue
+            weight = m.weight.T if group_dim == 0 else m.weight
+            if enable_mse_search:
+                quantile = search_clip(m, num_bits, group_size, scheme, data_type, enable_full_range)
+            if return_int:
+                from .model_wrapper import WeightOnlyLinear
+
+                _, scale, zp = quant_weight(
+                    weight,
+                    num_bits,
+                    group_size,
+                    scheme,
+                    quantile,
+                    data_type=data_type,
+                    return_int=True,
+                    full_range=enable_full_range,
+                )
+                if group_dim == 0:
+                    weight.transpose_(0, 1)
+                scale = scale.T if group_dim == 0 else scale
+                zp = zp.T if group_dim == 0 and zp is not None else zp
+                new_module = WeightOnlyLinear(
+                    m.in_features,
+                    m.out_features,
+                    num_bits,
+                    group_size,
+                    dtype=data_type,
+                    zp=zp is not None,
+                    bias=m.bias is not None,
+                    compression_dtype=compression_dtype,
+                    compression_dim=compression_dim,
+                    scale_dtype=scale_dtype,
+                    device=device,
+                    use_optimum_format=use_optimum_format,
+                )
+                new_module.pack(weight, scale, zp, m.bias)
+                if name == "":
+                    return new_module
+                else:
+                    set_module(model, name, new_module)
+            else:
+                quant_weight(
+                    weight,
+                    num_bits,
+                    group_size,
+                    scheme,
+                    quantile,
+                    data_type=data_type,
+                    full_range=enable_full_range,
+                )
+                if group_dim == 0:
+                    weight.transpose_(0, 1)
+            if orig_dtype != torch.float:
+                m = m.to(orig_dtype)
return model
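End to end, `rtn_quantize` now edits the model it receives: the new `torch.no_grad()` wrapper both skips autograd tracking and makes the in-place writes to `m.weight` legal when the parameter is a leaf with `requires_grad=True` (the same rule that lets optimizers update parameters in place). A typical fake-quant call, assuming the keyword names visible in this diff:

```python
import torch
from neural_compressor.adaptor.torch_utils.weight_only import rtn_quantize

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 4))
# Returns the same object: each Linear's weight has been overwritten in place
# with its fake-quantized values (4-bit, groups of 8, asymmetric scheme).
model = rtn_quantize(model, num_bits=4, group_size=8, scheme="asym")
```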

