diff --git a/neural_compressor/adaptor/torch_utils/awq.py b/neural_compressor/adaptor/torch_utils/awq.py
index daa17271431..3de46c82a6c 100644
--- a/neural_compressor/adaptor/torch_utils/awq.py
+++ b/neural_compressor/adaptor/torch_utils/awq.py
@@ -244,7 +244,7 @@ def search_scale(self, block, block_name, module_list, input_values):
             x_max = _get_act_scale(input_val)
             absorbed_modules = {_m: fetch_module(block, _m) for _m in module_name_list}
             # Step 4: collect origin output for MSE and state_dict for recover.
-            org_stat = {_m: module.state_dict() for _m, module in absorbed_modules.items()}
+            org_stat = {_m: copy.deepcopy(module.state_dict()) for _m, module in absorbed_modules.items()}
             if len(module_tuple) > 1:
                 # use block inference for multi-modules
                 org_out = self.block_inference(block)
@@ -364,7 +364,7 @@ def search_clip(self, block_name, module_list, input_values):
             # Step 2: update module name
             module = fetch_module(self.model, module_name)
             # Step 3: collect origin output for MSE and state_dict for recover.
-            org_stat = module.state_dict()
+            org_stat = copy.deepcopy(module.state_dict())
             org_out = self.module_inference(module, input_val)
             # Step 4: set different clip range for weight and compare the MSE loss.
             logger.info("Searching the best clip range with AWQ algorithm")
diff --git a/neural_compressor/adaptor/torch_utils/teq.py b/neural_compressor/adaptor/torch_utils/teq.py
index d22302839b2..716dcf236b2 100644
--- a/neural_compressor/adaptor/torch_utils/teq.py
+++ b/neural_compressor/adaptor/torch_utils/teq.py
@@ -294,7 +294,7 @@ def quantize(self):
             group_size = self.weight_config[n]["group_size"]
             scheme = self.weight_config[n]["scheme"]
             if isinstance(m, torch.nn.Linear):  # pragma: no cover
-                m.weight.data.copy_(quant_weight(m.weight, num_bits=num_bits, group_size=group_size, scheme=scheme))
+                quant_weight(m.weight.data, num_bits=num_bits, group_size=group_size, scheme=scheme)

     def save(self, save_scale_file="", save_state_dict_file=""):
         """
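Note on the two awq.py hunks: `Module.state_dict()` returns tensors that share storage with the live parameters, so a backup taken without `copy.deepcopy` is silently clobbered as soon as the search loop starts rewriting weights in place, and the later recovery step would restore already-quantized values. The teq.py hunk is the matching call-site update: `quant_weight` now operates on `m.weight.data` in place (see the weight_only.py hunks below), so copying its return value back with `copy_` is redundant. A minimal standalone sketch of the aliasing pitfall; the `Linear(4, 4)` layer is an arbitrary stand-in, not code from the patch:

    import copy

    import torch

    layer = torch.nn.Linear(4, 4)
    alias = layer.state_dict()                    # shares storage with layer.weight
    snapshot = copy.deepcopy(layer.state_dict())  # independent backup

    with torch.no_grad():
        layer.weight.mul_(0)                      # stand-in for an in-place weight update

    print(int(torch.count_nonzero(alias["weight"])))     # 0 -> the "backup" changed too
    print(int(torch.count_nonzero(snapshot["weight"])))  # original values preserved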
diff --git a/neural_compressor/adaptor/torch_utils/weight_only.py b/neural_compressor/adaptor/torch_utils/weight_only.py
index c29994f7755..65bc1fb4e80 100644
--- a/neural_compressor/adaptor/torch_utils/weight_only.py
+++ b/neural_compressor/adaptor/torch_utils/weight_only.py
@@ -80,7 +80,7 @@ def quantize_4bit(tensor, quantile=1.0, data_type="nf4", return_int=False):
     # get scale and update tensor
     scale = tensor.abs().max(1)[0] * quantile / max(allow_data)
     scale.unsqueeze_(dim=-1)
-    tensor = tensor / scale
+    tensor.div_(scale)
     mid_data = [(allow_data[i] + allow_data[i + 1]) / 2 for i in range(len(allow_data) - 1)]
     q_tensor = torch.zeros_like(tensor)
     for i in range(len(allow_data)):
@@ -91,9 +91,10 @@ def quantize_4bit(tensor, quantile=1.0, data_type="nf4", return_int=False):
             q_tensor += torch.where(tensor > mid_data[i - 1], data, 0)
         else:
             q_tensor += torch.where((mid_data[i - 1] < tensor) & (tensor <= mid_data[i]), data, 0)
+    tensor.copy_(q_tensor)
     if return_int:
-        return q_tensor.type(torch.int8), scale.type(torch.float), None
-    return q_tensor * scale
+        return tensor.type(torch.int8), scale.type(torch.float), None
+    return tensor.mul_(scale)


 def qdq_weight_asym(weight, num_bits=4, quantile=1.0, return_int=False):
@@ -122,10 +123,14 @@ def qdq_weight_asym(weight, num_bits=4, quantile=1.0, return_int=False):
     zp = torch.round(-wmin / scale)
     scale.unsqueeze_(dim=-1)
     zp.unsqueeze_(dim=-1)
-    q = torch.clamp(torch.round(weight / scale) + zp, 0, maxq)
+    weight.div_(scale)
+    weight.round_()
+    weight.add_(zp)
+    weight.clamp_(0, maxq)
     if return_int:
-        return q.type(torch.uint8), scale.type(torch.float), zp.type(torch.uint8)
-    return scale * (q - zp)
+        return weight.type(torch.uint8), scale.type(torch.float), zp.type(torch.uint8)
+    weight.sub_(zp)
+    return weight.mul_(scale)


 def qdq_weight_sym(weight, num_bits=4, quantile=1.0, return_int=False, full_range=False):
@@ -167,10 +172,12 @@ def qdq_weight_sym(weight, num_bits=4, quantile=1.0, return_int=False, full_rang
     else:
         scale = wmax / maxq
     scale.unsqueeze_(dim=-1)
-    q = torch.clamp(torch.round(weight / scale), minq, maxq)
+    weight.div_(scale)
+    weight.round_()
+    weight.clamp_(minq, maxq)
     if return_int:
-        return q.type(torch.int8), scale.type(torch.float), None
-    return scale * q
+        return weight.type(torch.int8), scale.type(torch.float), None
+    return weight.mul_(scale)


 def qdq_weight_actor(weight, num_bits, scheme, quantile=1.0, data_type="int", return_int=False, full_range=False):
@@ -200,7 +207,7 @@ def qdq_weight_actor(weight, num_bits, scheme, quantile=1.0, data_type="int", re
 def quant_weight(
     weight, num_bits=4, group_size=-1, scheme="asym", quantile=1.0, data_type="int", return_int=False, full_range=False
 ):
-    """Quant and dequant tensor with group size.
+    """Quant and dequant tensor with group size. It is an in-place op.

     Args:
         weight: input weight
@@ -248,7 +255,7 @@ def quant_weight(
             zp = zp.reshape(orig_shape[0], -1)
             return weight, scale, zp
         else:
-            weight = qdq_weight_actor(
+            qdq_weight_actor(
                 weight, num_bits, scheme=scheme, data_type=data_type, quantile=quantile, full_range=full_range
             )
             return weight.reshape(orig_shape)
@@ -285,7 +292,6 @@ def quant_weight(
                 return_int=True,
                 full_range=full_range,
             )
-            weight = torch.cat([weight1, weight2], dim=1)
             scale = torch.cat([scale1, scale2], dim=1)
             if zp2 is not None:
                 zp = torch.cat([zp1, zp2], dim=1)
@@ -296,7 +302,6 @@ def quant_weight(
             weight2 = qdq_weight_actor(
                 weight2, num_bits, scheme=scheme, data_type=data_type, quantile=quantile, full_range=full_range
             )
-            weight = torch.cat([weight1, weight2], dim=1)
             return weight


@@ -314,7 +319,7 @@ def search_clip(m, num_bits=4, group_size=32, scheme="asym", data_type="int", en
     Returns:
         best_clip_ratio (float): best percentile of clip
     """
-    org_weight = m.weight.data
+    org_weight = m.weight.data.clone()
     logger.info("Searching the best clip range with RTN algorithm")
     best_error = float("inf")
     best_clip_ratio = None
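The weight_only.py hunks above change the contract of the qdq helpers: `quantize_4bit`, `qdq_weight_asym`, `qdq_weight_sym` and `quant_weight` now write their result back into the tensor they are given (via `div_`, `round_`, `clamp_`, `copy_`, `mul_`) instead of returning a freshly allocated tensor, which is why the docstring gains "It is an in-place op." and why `search_clip` now snapshots the original weight with `clone()` before sweeping clip ratios. A simplified standalone sketch of the same in-place quant-dequant pattern (per-row symmetric, int grid only; not the library code itself):

    import torch


    def qdq_sym_inplace(weight: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
        """Quantize-dequantize `weight` in place with a per-row symmetric scale."""
        maxq = 2 ** (num_bits - 1) - 1
        minq = -(2 ** (num_bits - 1))
        scale = weight.abs().max(dim=1, keepdim=True)[0] / maxq
        scale.clamp_(min=1e-12)      # guard against all-zero rows
        weight.div_(scale)           # map onto the integer grid ...
        weight.round_()              # ... round ...
        weight.clamp_(minq, maxq)    # ... clip to the representable range ...
        return weight.mul_(scale)    # ... and dequantize, all without a second buffer


    w = torch.randn(8, 16)
    backup = w.clone()               # callers that still need the original must clone first
    qdq_sym_inplace(w)               # w now holds the dequantized values
    print((w - backup).abs().max())  # quantization error introduced by the round trip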
@@ -397,82 +402,84 @@ def rtn_quantize(
     scale_dtype = kwargs.get("scale_dtype", torch.float32)
     device = kwargs.get("device", "cpu")
     use_optimum_format = kwargs.get("use_optimum_format", True)
-    for name, m in model.named_modules():
-        if m.__class__.__name__ not in supported_layers:
-            continue
-        orig_dtype = next(m.parameters()).dtype
-        if orig_dtype != torch.float:
-            m = m.float()
-        if name in weight_config:  # pragma: no cover
-            num_bits = weight_config[name]["bits"]
-            group_size = weight_config[name]["group_size"]
-            scheme = weight_config[name]["scheme"]
-            quantile = weight_config[name].get("quantile", 1.0)
-        logger.debug(f"RTN quantized module:{name, m}")
-        log_msg = (
-            f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, "
-            + f"scheme={scheme}, quantile={quantile}"
-        )
-        if data_type != "int":
-            log_msg += f", dtype={data_type}"
-        elif scheme == "sym":  # nf4/fp4 is always [-7,7]
-            log_msg += f", enable_full_range={enable_full_range}"
-        logger.debug(log_msg)
-        if num_bits <= 0:
-            logger.info(f"Skip {name}")
-            continue
-        weight = m.weight.T if group_dim == 0 else m.weight
-        if enable_mse_search:
-            quantile = search_clip(m, num_bits, group_size, scheme, data_type, enable_full_range)
-        if return_int:
-            from .model_wrapper import WeightOnlyLinear
-
-            int_weight, scale, zp = quant_weight(
-                weight,
-                num_bits,
-                group_size,
-                scheme,
-                quantile,
-                data_type=data_type,
-                return_int=True,
-                full_range=enable_full_range,
+    with torch.no_grad():
+        for name, m in model.named_modules():
+            if m.__class__.__name__ not in supported_layers:
+                continue
+            orig_dtype = next(m.parameters()).dtype
+            if orig_dtype != torch.float:
+                m = m.float()
+            if name in weight_config:  # pragma: no cover
+                num_bits = weight_config[name]["bits"]
+                group_size = weight_config[name]["group_size"]
+                scheme = weight_config[name]["scheme"]
+                quantile = weight_config[name].get("quantile", 1.0)
+            logger.debug(f"RTN quantized module:{name, m}")
+            log_msg = (
+                f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, "
+                + f"scheme={scheme}, quantile={quantile}"
             )
-            int_weight = int_weight.T if group_dim == 0 else int_weight
-            scale = scale.T if group_dim == 0 else scale
-            zp = zp.T if group_dim == 0 and zp is not None else zp
-            new_module = WeightOnlyLinear(
-                m.in_features,
-                m.out_features,
-                num_bits,
-                group_size,
-                dtype=data_type,
-                zp=zp is not None,
-                bias=m.bias is not None,
-                compression_dtype=compression_dtype,
-                compression_dim=compression_dim,
-                scale_dtype=scale_dtype,
-                device=device,
-                use_optimum_format=use_optimum_format,
-            )
-            new_module.pack(int_weight, scale, zp, m.bias)
-            if name == "":
-                return new_module
+            if data_type != "int":
+                log_msg += f", dtype={data_type}"
+            elif scheme == "sym":  # nf4/fp4 is always [-7,7]
+                log_msg += f", enable_full_range={enable_full_range}"
+            logger.debug(log_msg)
+            if num_bits <= 0:
+                logger.info(f"Skip {name}")
+                continue
+            weight = m.weight.T if group_dim == 0 else m.weight
+            if enable_mse_search:
+                quantile = search_clip(m, num_bits, group_size, scheme, data_type, enable_full_range)
+            if return_int:
+                from .model_wrapper import WeightOnlyLinear
+
+                _, scale, zp = quant_weight(
+                    weight,
+                    num_bits,
+                    group_size,
+                    scheme,
+                    quantile,
+                    data_type=data_type,
+                    return_int=True,
+                    full_range=enable_full_range,
+                )
+                if group_dim == 0:
+                    weight.transpose_(0, 1)
+                scale = scale.T if group_dim == 0 else scale
+                zp = zp.T if group_dim == 0 and zp is not None else zp
+                new_module = WeightOnlyLinear(
+                    m.in_features,
+                    m.out_features,
+                    num_bits,
+                    group_size,
+                    dtype=data_type,
+                    zp=zp is not None,
+                    bias=m.bias is not None,
+                    compression_dtype=compression_dtype,
+                    compression_dim=compression_dim,
+                    scale_dtype=scale_dtype,
+                    device=device,
+                    use_optimum_format=use_optimum_format,
+                )
+                new_module.pack(weight, scale, zp, m.bias)
+                if name == "":
+                    return new_module
+                else:
+                    set_module(model, name, new_module)
             else:
-                set_module(model, name, new_module)
-        else:
-            q_weight = quant_weight(
-                weight,
-                num_bits,
-                group_size,
-                scheme,
-                quantile,
-                data_type=data_type,
-                full_range=enable_full_range,
-            )
-            q_weight = q_weight.T if group_dim == 0 else q_weight
-            m.weight.data.copy_(q_weight)
-        if orig_dtype != torch.float:
-            m = m.to(orig_dtype)
+                quant_weight(
+                    weight,
+                    num_bits,
+                    group_size,
+                    scheme,
+                    quantile,
+                    data_type=data_type,
+                    full_range=enable_full_range,
+                )
+                if group_dim == 0:
+                    weight.transpose_(0, 1)
+            if orig_dtype != torch.float:
+                m = m.to(orig_dtype)
     return model
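On the rtn_quantize hunk: two things make the in-place path work. First, `m.weight` is a leaf `Parameter` with `requires_grad=True`, and PyTorch rejects in-place updates on such leaves while autograd is tracking, so the whole loop moves under `with torch.no_grad():` (and gains one indentation level, which accounts for most of the textual churn). Second, for `group_dim == 0` the quantized view is flipped back with `weight.transpose_(0, 1)` and packed directly, instead of materializing `int_weight.T` / `q_weight.T` and copying it into `m.weight.data`. A small standalone illustration of the first point (hypothetical layer, not the patched function):

    import torch

    layer = torch.nn.Linear(16, 16)   # weight is a leaf Parameter with requires_grad=True

    try:
        layer.weight.div_(2.0)        # in-place update while autograd is tracking
    except RuntimeError as err:
        print(f"rejected: {err}")     # "a leaf Variable that requires grad ..."

    with torch.no_grad():
        layer.weight.div_(2.0)        # the same update is allowed with autograd disabled,
                                      # and no extra weight-sized buffer is created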