Enhance WeightOnlyLinear capability (#1095)
Signed-off-by: Xin He <xin3.he@intel.com>
xin3he authored Jul 19, 2023
1 parent 8593158 commit 59172ad
Showing 10 changed files with 267 additions and 91 deletions.
2 changes: 2 additions & 0 deletions .azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt
@@ -2686,3 +2686,5 @@ Chatbot
chatbot
fba
hostname
qweight
qconfig
2 changes: 1 addition & 1 deletion docs/source/quantization_weight_only.md
@@ -73,7 +73,7 @@ q_model = quantization.fit(model, conf, eval_func=eval_func)
q_model.save('saved_results')
```

The saved_results folder contains two files: `best_model.pt` and `weight_config.json`, and the generated q_model is a fake quantized model.
The saved_results folder contains two files: `best_model.pt` and `qconfig.json`, and the generated q_model is a fake quantized model.

## Reference

13 changes: 7 additions & 6 deletions neural_compressor/adaptor/pytorch.py
@@ -4535,9 +4535,9 @@ def quantize(self, tune_cfg, model, dataloader, calib_func=None):
def rtn_quantize(self, model, tune_cfg):
logger.debug("quantizing with the round-to-nearest algorithm")
if 'rtn_args' in self.recipes:
full_range = self.recipes['rtn_args'].get('full_range', False)
sym_full_range = self.recipes['rtn_args'].get('sym_full_range', False)
else:
full_range=False
sym_full_range=False
from .torch_utils.weight_only import rtn_quantize
from .torch_utils.util import fetch_module
for key, config in tune_cfg['op'].items():
@@ -4553,7 +4553,8 @@ def rtn_quantize(self, model, tune_cfg):
continue
m = fetch_module(model, op_name)
m = rtn_quantize(m, num_bits, group_size, scheme,
return_int=False, full_range=full_range)
return_int=False,
sym_full_range=sym_full_range)
set_module(model, op_name, m)
return model

@@ -4651,9 +4652,9 @@ def awq_quantize(self, model, tune_cfg, dataloader, calib_func):
else:
auto_scale, mse_range = True, True
if 'rtn_args' in self.recipes:
full_range = self.recipes['rtn_args'].get('full_range', False)
sym_full_range = self.recipes['rtn_args'].get('sym_full_range', False)
else:
full_range=False
sym_full_range=False
calib_sampling_size = tune_cfg.get('calib_sampling_size', 1)
model = awq_quantize(
model,
Expand All @@ -4666,7 +4667,7 @@ def awq_quantize(self, model, tune_cfg, dataloader, calib_func):
calib_func=calib_func,
n_blocks=n_blocks,
return_int=False,
full_range=full_range,
sym_full_range=sym_full_range,
)
return model

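For context, both `rtn_quantize` and `awq_quantize` in the adaptor now read the renamed `sym_full_range` key from the `rtn_args` recipe group. A caller would enable it roughly as in the sketch below; this is a minimal sketch, and the `PostTrainingQuantConfig(approach="weight_only", ...)` usage is assumed from the weight-only documentation of this release rather than taken from this diff.

```python
import torch
from neural_compressor import PostTrainingQuantConfig, quantization

# Toy fp32 model; any torch module containing nn.Linear layers would do.
model = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8))

# Assumed user-facing configuration: 'rtn_args' is the recipe group read by the
# PyTorch adaptor in this commit, and 'sym_full_range' replaces the old
# 'full_range' key (use the full symmetric range down to -2**(bits-1)).
conf = PostTrainingQuantConfig(
    approach="weight_only",
    recipes={"rtn_args": {"sym_full_range": True}},
)
q_model = quantization.fit(model, conf)
```
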
112 changes: 83 additions & 29 deletions neural_compressor/adaptor/torch_utils/model_wrapper.py
@@ -151,81 +151,128 @@ def _wrapper_qdq_linear(tmp_model, module_name_list=[]):


class WeightOnlyLinear(torch.nn.Module):
def __init__(self, in_features, out_features, bits, groupsize):
def __init__(self, in_features, out_features, bits, groupsize,
zp=False, bias=False, scale_dtype=torch.float32,
compression_dtype=torch.int32, compression_dim=1):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.bits = bits
self.groupsize = groupsize if groupsize != -1 else in_features
self.n_pack = 32 // self.bits

self.register_buffer(
'packed_weight',
torch.zeros(
(out_features, math.ceil(in_features / self.n_pack)),
dtype=torch.int32,
)
)
self.compression_dim = compression_dim
assert compression_dtype in [torch.int8, torch.int16, torch.int32, torch.int64], \
"Only support torch.int8|16|32|64 as compressed dtype."
dtype_bits_mapping = {torch.int8: 8, torch.int16: 16, torch.int32: 32, torch.int64: 64}
self.compress_bits = dtype_bits_mapping[compression_dtype]
self.n_pack = self.compress_bits // self.bits
self.compressed_dtype = compression_dtype
self.float_type = scale_dtype
# K is input channel, N is output channel
assert compression_dim in [0, 1], "Only support 0 or 1 as compression dimension, " +\
"0 is output channel, 1 is input channel."
self.register_buffer(
'scale',
torch.zeros(
(out_features, math.ceil(in_features / self.groupsize)),
dtype=torch.float,
dtype=self.float_type,
)
)
if compression_dim == 1:
self.register_buffer(
'packed_weight',
torch.zeros(
(out_features, math.ceil(in_features / self.n_pack)),
dtype=self.compressed_dtype,
)
)
if zp:
self.register_buffer(
'packed_zp',
torch.zeros(
(self.out_features, math.ceil(self.in_features / self.groupsize / self.n_pack)),
dtype=self.compressed_dtype,
)
)
else:
self.register_buffer(
'packed_weight',
torch.zeros(
(math.ceil(out_features / self.n_pack), in_features),
dtype=self.compressed_dtype,
)
)
if zp:
self.register_buffer(
'packed_zp',
torch.zeros(
(
math.ceil(self.out_features / self.n_pack),
math.ceil(self.in_features / self.groupsize)
),
dtype=self.compressed_dtype,
)
)
if bias:
self.register_buffer('bias', torch.zeros(self.out_features, dtype=self.float_type))
else:
self.bias = None

def pack(self, int_weight, scale, zp, bias):
if bias is not None:
self.register_buffer('bias', torch.zeros(self.out_features, dtype=torch.float))
else:
self.bias = None
self.bias = bias
assert hasattr(self, 'bias'), "bias is not set when initializing."
self.bias = bias.type(self.float_type)
assert scale.shape == self.scale.shape, "Scale shape is mismatched."
self.scale = scale
self.scale = scale.type(self.float_type)
if self.compression_dim == 0:
int_weight = int_weight.T
self.packed_weight = self.packed_weight.T
origin_shape = int_weight.shape
target_shape = self.packed_weight.shape
assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
mask = torch.tensor(2**self.bits - 1, dtype=torch.int32)
mask = torch.tensor(2**self.bits - 1, dtype=self.compressed_dtype)

# pack weight
for i in range(target_shape[0]):
for j in range(target_shape[1]):
start = self.n_pack * j
end = self.n_pack * (j + 1)
tmp = int_weight[i][start: end].type(torch.int32)
tmp = int_weight[i][start: end].type(self.compressed_dtype)
for e in range(len(tmp)):
tmp[e] &= mask
tmp[e] = tmp[e] << self.bits * (self.n_pack - 1 - e)
self.packed_weight[i][j] |= tmp[e]
if self.compression_dim == 0:
self.packed_weight = self.packed_weight.T

if zp is not None:
# pack zero_points
self.register_buffer(
'packed_zp',
torch.zeros(
(self.out_features, math.ceil(self.in_features / self.groupsize / self.n_pack)),
dtype=torch.int32,
)
)
if self.compression_dim == 0:
zp = zp.T
self.packed_zp = self.packed_zp.T
assert hasattr(self, 'packed_zp'), "zp is not set when initializing."
target_shape = self.packed_zp.shape
for i in range(target_shape[0]):
for j in range(target_shape[1]):
start = self.n_pack * j
end = self.n_pack * (j + 1)
tmp = zp[i][start: end].type(torch.int32)
tmp = zp[i][start: end].type(self.compressed_dtype)
for e in range(len(tmp)):
tmp[e] &= mask
tmp[e] = tmp[e] << self.bits * (self.n_pack - 1 - e)
self.packed_zp[i][j] |= tmp[e]
if self.compression_dim == 0:
self.packed_zp = self.packed_zp.T

def recover(self):
mask = torch.tensor(2**self.bits - 1, dtype=torch.int32)
mask = torch.tensor(2**self.bits - 1, dtype=self.compressed_dtype)
if hasattr(self, 'packed_zp'):
weight_dtype = torch.uint8
else:
weight_dtype = torch.int8
# unpack weight
weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype)
if self.compression_dim == 0:
weight = weight.T
self.packed_weight = self.packed_weight.T
origin_shape = weight.shape
target_shape = self.packed_weight.shape
for i in range(target_shape[0]):
@@ -240,9 +287,14 @@ def recover(self):
if weight_dtype == torch.uint8:
tmp &= mask # remove sign bit
weight[i][index] = tmp.type(weight_dtype)
if self.compression_dim == 0:
weight = weight.T
# unpack zero_point
if hasattr(self, 'packed_zp'):
zp_dtype = torch.int32 # to avoid overflow when weight-zp
if self.compression_dim == 0:
zp = zp.T
self.packed_zp = self.packed_zp.T
zp_dtype = self.compressed_dtype # to avoid overflow when weight-zp
zp = torch.zeros(self.scale.shape, dtype=zp_dtype)
origin_shape = zp.shape
target_shape = self.packed_zp.shape
@@ -257,6 +309,8 @@ def recover(self):
tmp = tmp >> 32 - self.bits
tmp &= mask
zp[i][index] = tmp.type(zp_dtype)
if self.compression_dim == 0:
zp = zp.T
# recover fp32 weight with int_weight, scale, and zero_point
left_element = self.in_features % self.groupsize
if left_element != 0:
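To illustrate the enhanced wrapper, the sketch below constructs a `WeightOnlyLinear` with the new arguments and packs a fake int4 weight. With `compression_dtype=torch.int32` and 4 bits, `n_pack = 32 // 4 = 8` values share one storage element, and `compression_dim=1` packs along the input channel. This is a minimal sketch based on the signatures in this diff; the toy shapes and random data are illustrative only, and `recover()` is assumed to return the reconstructed floating-point weight.

```python
import math
import torch
from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear

out_features, in_features, bits, group_size = 8, 16, 4, 8  # toy shapes

layer = WeightOnlyLinear(
    in_features, out_features, bits, group_size,
    zp=True,                        # allocate a packed zero-point buffer (asym scheme)
    bias=False,
    scale_dtype=torch.float32,      # dtype of the per-group scales
    compression_dtype=torch.int32,  # 32 // 4 = 8 int4 values per storage element
    compression_dim=1,              # pack along the input channel (K)
)

# Fake asymmetric int4 data with one scale/zero-point per group of 8 input channels.
n_groups = math.ceil(in_features / group_size)
int_weight = torch.randint(0, 2**bits, (out_features, in_features))
scale = torch.rand(out_features, n_groups)
zp = torch.randint(0, 2**bits, (out_features, n_groups))

layer.pack(int_weight, scale, zp, bias=None)
fp_weight = layer.recover()  # rebuild the float weight from packed_weight/scale/packed_zp
```
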
4 changes: 2 additions & 2 deletions neural_compressor/adaptor/torch_utils/util.py
@@ -918,9 +918,9 @@ def get_op_type_by_name(op_name, quantizable_ops):
return None

def collect_weight_info(q_config):
"""collect weight info from q_config for dumping into weight_config.json
"""collect weight info from q_config for dumping into qconfig.json
weight_config.json example:
qconfig.json example:
```
{
'fc': {
41 changes: 26 additions & 15 deletions neural_compressor/adaptor/torch_utils/weight_only.py
@@ -86,7 +86,7 @@ def qdq_weight_sym(weight, num_bits=4, quantile=1.0, return_int=False, full_rang
minq = torch.tensor(2 ** (num_bits - 1) - 1)
max_val = torch.max(weight, 1)[0]
min_val = torch.min(weight, 1)[0]
flip_flag = torch.abs(min_val) > torch.abs(max_val)
flip_flag = torch.abs(max_val) > torch.abs(min_val)
wmax = torch.max(torch.abs(max_val), torch.abs(min_val))
wmax = wmax * quantile
tmp = (wmax == 0)
@@ -175,6 +175,9 @@ def quant_weight(weight, num_bits=4, group_size=-1, scheme="asym", quantile=1.0,
weight1, num_bits, scheme=scheme,
quantile=quantile, return_int=True, full_range=full_range
)
scale1 = scale1.reshape(orig_shape[0], -1)
if zp1 is not None:
zp1 = zp1.reshape(orig_shape[0], -1)
else:
weight1 = qdq_weight_actor(
weight1, num_bits, scheme=scheme, quantile=quantile, full_range=full_range
Expand All @@ -187,14 +190,11 @@ def quant_weight(weight, num_bits=4, group_size=-1, scheme="asym", quantile=1.0,
quantile=quantile, return_int=True, full_range=full_range
)
weight = torch.cat([weight1, weight2], dim=1)
scale = torch.cat([scale1, scale2], dim=0)
scale = torch.cat([scale1, scale2], dim=1)
if zp2 is not None:
zp = torch.cat([zp1, zp2], dim=0)
zp = torch.cat([zp1, zp2], dim=1)
else:
zp = None
scale = scale.reshape(orig_shape[0], -1)
if zp is not None:
zp = zp.reshape(orig_shape[0], -1)
return weight, scale, zp
else:
weight2 = qdq_weight_actor(
@@ -206,7 +206,8 @@ def rtn_quantize(model, num_bits=4, group_size=32, scheme="asym",


def rtn_quantize(model, num_bits=4, group_size=32, scheme="asym",
quantile=1.0, weight_config={}, return_int=False, full_range=False):
quantile=1.0, weight_config={}, return_int=False,
sym_full_range=False, **kwargs):
"""Quant the model with round to nearst method.
Args:
@@ -227,13 +228,18 @@ def rtn_quantize(model, num_bits=4, group_size=32, scheme="asym",
}
return_int (bool, optional): Choose return fp32 or int32 model.
Defaults to False.
full_range (bool, optional): Choose sym range whether use -2**(bits-1).
sym_full_range (bool, optional): Choose sym range whether use -2**(bits-1).
Defaults to False.
Returns:
model: fake quantized torch module
"""
assert isinstance(model, torch.nn.Module), "only support torch module"
supported_layers = ['Linear']
if return_int:
compression_dtype = kwargs.get("compression_dtype", torch.int32)
compression_dim = kwargs.get("compression_dim", 1)
scale_dtype = kwargs.get("scale_dtype", torch.float32)
for n, m in model.named_modules():
if m.__class__.__name__ not in supported_layers:
continue
@@ -245,7 +251,7 @@ def rtn_quantize(model, num_bits=4, group_size=32, scheme="asym",
logger.debug(f"RTN quantized module:{n, m}")
if scheme == 'sym':
logger.debug(f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, " + \
f"scheme={scheme}, quantile={quantile}, full_range={full_range}")
f"scheme={scheme}, quantile={quantile}, sym_full_range={sym_full_range}")
else:
logger.debug(f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, " + \
f"scheme={scheme}, quantile={quantile}")
@@ -257,10 +263,14 @@ def rtn_quantize(model, num_bits=4, group_size=32, scheme="asym",
from .model_wrapper import WeightOnlyLinear
int_weight, scale, zp = quant_weight(
weight, num_bits, group_size, scheme,
quantile, return_int=True, full_range=full_range
quantile, return_int=True, full_range=sym_full_range
)
new_module = WeightOnlyLinear(
m.in_features, m.out_features, num_bits, group_size
m.in_features, m.out_features, num_bits, group_size,
zp=zp is not None, bias=m.bias is not None,
compression_dtype=compression_dtype,
compression_dim=compression_dim,
scale_dtype=scale_dtype,
)
new_module.pack(int_weight, scale, zp, m.bias)
if n == '':
@@ -269,7 +279,8 @@ def rtn_quantize(model, num_bits=4, group_size=32, scheme="asym",
set_module(model, n, new_module)
else:
q_weight = quant_weight(
weight, num_bits, group_size, scheme, quantile, full_range=full_range
weight, num_bits, group_size, scheme, quantile,
full_range=sym_full_range
)
m.weight.data.copy_(q_weight)
return model
@@ -389,7 +400,7 @@ def _update_input_with_scale(args, kwargs, scales):
@torch.no_grad()
def awq_quantize(model, weight_config={}, absorb_dict={}, dataloader=None, n_samples=128,
auto_scale=True, mse_range=True, calib_func=None, n_blocks=5,
return_int=False, full_range=False):
return_int=False, sym_full_range=False):
"""Quant the model with Activation-aware Weight quantization(AWQ) method.
Args:
@@ -418,7 +429,7 @@ def awq_quantize(model, weight_config={}, absorb_dict={}, dataloader=None, n_sam
n_blocks: split model into block number to avoid OOM.
return_int (bool, optional): Choose return fp32 or int32 model.
Defaults to False.
full_range (bool, optional): Choose sym range whether use -2**(bits-1).
sym_full_range (bool, optional): Choose sym range whether use -2**(bits-1).
Returns:
model: fake quantized model
@@ -645,7 +656,7 @@ def forward(self, *args, **kwargs):
num_bits=-1,
weight_config=weight_config,
return_int=return_int,
full_range=full_range,
sym_full_range=sym_full_range,
)
logger.info("AWQ quantization is done.")
return model
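
Taken together, the new keyword arguments flow from `rtn_quantize` into `WeightOnlyLinear` when `return_int=True`. Below is a minimal sketch of calling the updated helper directly; the toy model is illustrative, and the packing kwargs shown mirror the defaults in this diff.

```python
import torch
from neural_compressor.adaptor.torch_utils.weight_only import rtn_quantize

# Toy model; every nn.Linear inside it gets quantized.
model = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8))

# With return_int=True each Linear is replaced by a WeightOnlyLinear holding the
# packed integer weight; compression_dtype/compression_dim/scale_dtype are the new
# kwargs that control the packing layout.
q_model = rtn_quantize(
    model,
    num_bits=4,
    group_size=32,
    scheme="sym",
    return_int=True,
    sym_full_range=True,            # renamed from full_range in this commit
    compression_dtype=torch.int32,  # storage dtype of the packed buffers
    compression_dim=1,              # pack along the input channel
    scale_dtype=torch.float32,
)
```
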
