SQ refactor (#1633)
Signed-off-by: Lu, Yintong <yintong.lu@intel.com>
yintong-lu authored Feb 27, 2024
1 parent 5343009 commit a8d81ca
Showing 18 changed files with 2,109 additions and 1,680 deletions.
1 change: 1 addition & 0 deletions .azure-pipelines/scripts/codeScan/pylint/pylint.sh
@@ -30,6 +30,7 @@ pip install torch \
fvcore \
pymoo \
onnxruntime_extensions \
+ peft \
tf_slim \
transformers \
accelerate \
2 changes: 1 addition & 1 deletion docs/source/smooth_quant.md
@@ -304,7 +304,7 @@ In our experiments, an $\alpha$ range of [0.0, 1.0] with a step_size of 0.1 is f
*fully automated*: users only need to pass a model and dataloader.

```python
-from neural_compressor.adaptor.torch_utils.smooth_quant import TorchSmoothQuant
+from neural_compressor.adaptor.torch_utils.waq import TorchSmoothQuant

sq = TorchSmoothQuant(model, dataloader)
alpha = "auto" ##alpha could be a float number to disable auto-tuning and enable fixed-value alpha smoothquant.
46 changes: 9 additions & 37 deletions neural_compressor/adaptor/pytorch.py
@@ -20,7 +20,7 @@
import math
import os
import re
-from collections import OrderedDict, UserDict, namedtuple
+from collections import OrderedDict, UserDict
from functools import partial

import yaml
@@ -1800,7 +1800,7 @@ def smooth_quant(
assert folding, "IPEX version >= 2.1 is required for SmoothQuant folding=False."

if not hasattr(self, "sq") or force_re_smooth:
-from .torch_utils.smooth_quant import TorchSmoothQuant
+from neural_compressor.adaptor.torch_utils.waq import TorchSmoothQuant

self.sq = TorchSmoothQuant(
model._model, dataloader=dataloader, example_inputs=self.example_inputs, q_func=self.q_func
@@ -1813,17 +1813,18 @@
kwargs["percentile"] = percentile
if scales_per_op is not None:
kwargs["scales_per_op"] = scales_per_op
+auto_alpha_args["init_alpha"] = default_alpha
model._model = self.sq.transform(
alpha=alpha,
folding=folding,
calib_iter=calib_iter,
weight_clip=weight_clip,
-default_alpha=default_alpha,
auto_alpha_args=auto_alpha_args,
**kwargs,
)
if self.sq.record_max_info:
model.sq_max_info = self.sq.max_value_info
+model.sq_scale_info = self.sq.sq_scale_info
return model
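The refactor folds `default_alpha` into `auto_alpha_args` as its `"init_alpha"` entry before calling `transform()`. A hedged sketch of the resulting call; only `init_alpha` is confirmed by this diff, the other keys are assumed tuning knobs modeled on the [0.0, 1.0] / step 0.1 search described in docs/source/smooth_quant.md, and the values are illustrative:

```python
from neural_compressor.adaptor.torch_utils.waq import TorchSmoothQuant

# `model` and `dataloader` as in docs/source/smooth_quant.md above.
sq = TorchSmoothQuant(model, dataloader)

# Only "init_alpha" is confirmed by this diff; the remaining keys are
# assumptions, not the documented API.
auto_alpha_args = {
    "init_alpha": 0.5,  # formerly the separate default_alpha argument
    "alpha_min": 0.0,
    "alpha_max": 1.0,
    "alpha_step": 0.1,
}
smoothed = sq.transform(
    alpha="auto",
    folding=False,
    calib_iter=100,
    weight_clip=True,
    auto_alpha_args=auto_alpha_args,
)
```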

def _apply_pre_optimization(self, model, tune_cfg, recover=False):
@@ -1840,7 +1841,7 @@ def _apply_pre_optimization(self, model, tune_cfg, recover=False):
q_model = model._model
sq_max_info = model.sq_max_info
if sq_max_info:
-from .torch_utils.smooth_quant import TorchSmoothQuant
+from neural_compressor.adaptor.torch_utils.waq import TorchSmoothQuant

tsq = TorchSmoothQuant(q_model, None)
alpha = tune_cfg["recipe_cfgs"]["smooth_quant_args"]["alpha"]
@@ -1876,8 +1877,9 @@ def qdq_quantize(self, model, tune_cfg):
model: qdq quantized model.
"""
q_model = model._model
+from neural_compressor.adaptor.torch_utils.waq import get_module, set_module
+
from .torch_utils.model_wrapper import QDQLinear, SQLinearWrapper
-from .torch_utils.smooth_quant import get_module, set_module

smoothquant_scale_info = {}
fallback_op_name_list = []
@@ -3317,37 +3319,7 @@ def qdq_quantize(self, model, q_model, tune_cfg, dataloader, q_func):
inplace = True if self.performance_only else False

# fetch SmoothQuant scale info from pre-optimized model
-sq_max_info = model.sq_max_info
-if sq_max_info:
-    smoothquant_scale_info = {}
-    from .torch_utils.model_wrapper import SQLinearWrapper
-    from .torch_utils.smooth_quant import get_module
-
-    for _, info in sq_max_info.items():
-        alpha = info["alpha"]
-        absorbed_layer = info["absorbed_layer"]
-        input_minmax = info["input_minmax"]
-        # for peft model,lora_B weights is 0.
-        weight_max = info["weight_max"]
-        if self.sq.weight_clip:
-            weight_max = weight_max.clamp(min=1e-5)
-        abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1]))
-        input_power = torch.pow(abs_input_max, alpha)
-        weight_power = torch.pow(weight_max, 1 - alpha)
-        scale = torch.clip(input_power / weight_power, min=1e-5)
-        for op_name in absorbed_layer:
-            module = copy.deepcopy(get_module(q_model._model, op_name))
-            new_module = SQLinearWrapper(module, 1.0 / scale, input_minmax, alpha)
-            weight_scale = new_module._get_weight_scale()
-            smoothquant_scale_info[op_name] = {
-                "alpha": new_module.alpha,
-                "input_scale_for_mul": new_module.input_scale,
-                "input_scale_after_mul": new_module.scale,
-                "input_zero_point_after_mul": new_module.zero_point,
-                "input_dtype": new_module.dtype,
-                "weight_scale_after_mul": weight_scale,
-            }
-            logger.debug(f"Current SmoothQuant alpha of {op_name} is {alpha}")
+smoothquant_scale_info = model.sq_scale_info

# Checking save_qconf_summary is a workaround for an IPEX bug:
# sometimes the prepared model from get_op_capability loses this attribute.
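For context, the deleted block above computed SmoothQuant scales inline; after the refactor they arrive precomputed in `model.sq_scale_info`. A standalone sketch of that scale math, with illustrative tensors rather than the INC API:

```python
import torch

# Per-channel calibration statistics (illustrative values).
alpha = 0.5
input_minmax = (torch.tensor([-2.0, -8.0]), torch.tensor([1.5, 6.0]))
weight_max = torch.tensor([1.0, 0.25]).clamp(min=1e-5)  # weight_clip guard

abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1]))
# SmoothQuant: scale = |x|_max^alpha / |w|_max^(1 - alpha), clipped away from zero.
scale = torch.clip(abs_input_max.pow(alpha) / weight_max.pow(1 - alpha), min=1e-5)

# Activations are divided by `scale` and weights multiplied by it, so the
# product W @ x is unchanged while activation outliers are smoothed.
print(scale)
```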
@@ -4795,7 +4767,7 @@ def teq_quantize(self, model, tune_cfg, dataloader, calib_func):

supported_layers = ["Linear"]
if folding: # pragma: no cover
-from .torch_utils.smooth_quant import GraphTrace
+from neural_compressor.adaptor.torch_utils.waq import GraphTrace

tg = GraphTrace()
absorb_to_layer, _ = tg.get_absorb_to_layer(model, self.example_inputs, supported_layers)
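As a rough illustration of what `get_absorb_to_layer` returns — the mapping shape is inferred from how `absorb_to_layer` is consumed here, not from a documented API, and the layer names are made up:

```python
# Hypothetical result: each absorbable parent op maps to the Linear layers
# whose SmoothQuant scales it can absorb.
absorb_to_layer = {
    "decoder.layers.0.input_layernorm": [
        "decoder.layers.0.self_attn.q_proj",
        "decoder.layers.0.self_attn.k_proj",
        "decoder.layers.0.self_attn.v_proj",
    ],
}
```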
3 changes: 1 addition & 2 deletions neural_compressor/adaptor/torch_utils/awq.py
@@ -13,7 +13,6 @@
# limitations under the License.

import copy
-from functools import partial

import torch

@@ -25,10 +24,10 @@
get_hidden_states,
get_module_input_output,
)
+from neural_compressor.adaptor.torch_utils.waq import set_module

from ...utils import logger
from .model_wrapper import MulLinear
-from .smooth_quant import model_forward, set_module


def _get_absorb_per_block(model, example_inputs, folding=False, weight_config={}):
@@ -24,10 +24,10 @@
from torch.quantization import convert, prepare
from tqdm import tqdm

+from neural_compressor.adaptor.torch_utils.waq import TorchSmoothQuant
from neural_compressor.config import default_workspace

from ..model_wrapper import QDQLayer
-from ..smooth_quant import TorchSmoothQuant
from .utils import (
_get_path,
clean_module_weight,
10 changes: 5 additions & 5 deletions neural_compressor/adaptor/torch_utils/model_wrapper.py
@@ -66,9 +66,9 @@ def forward(self, X):

def qdq_weight(self):
# update weight w/ QDQ
-from .smooth_quant import quant_dequant_w
+from neural_compressor.adaptor.torch_utils.waq.utils import quant_dequant_w_v1

-weith_qdq = quant_dequant_w(self.module)
+weith_qdq = quant_dequant_w_v1(self.module)
self.module.weight = torch.nn.Parameter(weith_qdq)


@@ -139,7 +139,7 @@ def _calculate_qparams(self, input_scale, input_minmax, dtype=torch.quint8):
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
-scale = torch.max(scale, torch.tensor([torch.finfo(torch.float32).eps]))
+scale = torch.max(scale, torch.tensor([torch.finfo(torch.float32).eps], device=scale.device))
zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
return scale, zero_point
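The added `device=` keeps the eps tensor on the same device as `scale`; previously a CUDA-resident `scale` would hit a device-mismatch error. A minimal repro sketch, assuming a CUDA-enabled build:

```python
import torch

if torch.cuda.is_available():
    scale = torch.rand(4, device="cuda")
    eps = torch.tensor([torch.finfo(torch.float32).eps])  # CPU tensor
    # torch.max(scale, eps) would raise a RuntimeError: tensors on different devices.
    eps = torch.tensor([torch.finfo(torch.float32).eps], device=scale.device)
    scale = torch.max(scale, eps)  # now both operands live on the same device
```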
@@ -181,7 +181,7 @@ def forward(self, X):
return X

module_name_list = input_scale_dict.keys()
-from .smooth_quant import get_module, set_module
+from neural_compressor.adaptor.torch_utils.waq import get_module, set_module

for name in module_name_list:
module = get_module(tmp_model, name)
@@ -193,7 +193,7 @@

def _wrapper_qdq_linear(tmp_model, module_name_list=[]):
"""Help function to generate a fake QDQ model for loading weights."""
-from .smooth_quant import get_module, set_module
+from neural_compressor.adaptor.torch_utils.waq import get_module, set_module

for name in module_name_list:
module = get_module(tmp_model, name)
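`get_module` and `set_module` now live in the `waq` package. A rough sketch of the dotted-name traversal such helpers typically perform — a paraphrase under that assumption, not the exact `waq` implementation:

```python
import torch

def get_module(model: torch.nn.Module, key: str) -> torch.nn.Module:
    """Walk a dotted name like 'encoder.layer.0.dense' down to the submodule."""
    module = model
    for name in key.split("."):
        module = getattr(module, name)
    return module

def set_module(model: torch.nn.Module, key: str, new_module: torch.nn.Module) -> None:
    """Replace the submodule at a dotted name with new_module."""
    parent_key, _, child = key.rpartition(".")
    parent = get_module(model, parent_key) if parent_key else model
    setattr(parent, child, new_module)
```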