Refine base Quantizer (#1760)
Refine base Quantizer class
---------

Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Co-authored-by: yuwenzho <yuwen.zhou@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Apr 28, 2024
1 parent 95e67ea commit 044e6db
Showing 3 changed files with 51 additions and 48 deletions.
72 changes: 46 additions & 26 deletions neural_compressor/torch/algorithms/base_algorithm.py
@@ -14,45 +14,41 @@

 from abc import ABC, abstractmethod
 from collections import OrderedDict
-from typing import Any
+from typing import Any, Optional

 import torch

 from neural_compressor.torch.utils import Mode


 class Quantizer(ABC):
-    """The base quantizer for all algorithm quantizers."""
-
-    def __init__(self, tune_cfg: OrderedDict = {}):
-        """Init a Quantizer object.
-
-        Args:
-            tune_cfg (OrderedDict, optional): quantization config for ops. Defaults to {}.
-                Take weight-only quantization as an example,
-                    tune_cfg={
-                        'fc2':
-                            {
-                                'dtype': 'int',
-                                'bits': 4,
-                                'group_size': 32,
-                                'scheme': 'sym'
-                            }
-                    }
-        """
-        self.tune_cfg = tune_cfg
-
-    @abstractmethod
-    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
-        """Quantizes a given torch model.
-
-        Args:
-            model (torch.nn.Module): The torch model to be quantized.
-
-        Returns:
-            A quantized model.
-        """
-        raise NotImplementedError("{} doesn't implement `quantize` function.".format(self.__class__.__name__))
+    """The base quantizer for all algorithm quantizers.
+
+    The `Quantizer` unifies the interfaces across various quantization algorithms, including GPTQ, RTN, etc.
+    Given a float model, `Quantizer` applies the quantization algorithm to the model according to `quant_config`.
+
+    To implement a new quantization algorithm, inherit from `Quantizer` and implement the following methods:
+        - `prepare`: prepare a given model for convert.
+        - `convert`: convert a prepared model to a quantized model.
+    Note: `quantize` and `execute` are optional for new quantization algorithms.
+    """
+
+    def __init__(self, quant_config: Optional[Any] = None):
+        """Init a Quantizer object.
+
+        Args:
+            quant_config: Specifies how to apply the algorithm on the given model.
+                The format of `quant_config` can be defined by the `Quantizer` subclass itself.
+                For example, `quant_config` can be a dictionary as below:
+                    quant_config={
+                        'fc2':{
+                            'dtype': 'int',
+                            'bits': 4,
+                            'group_size': 32,
+                            'scheme': 'sym'
+                        }}
+        """
+        self.quant_config = quant_config

     @abstractmethod
     def prepare(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
@@ -80,6 +76,30 @@ def convert(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
         """
         raise NotImplementedError("{} doesn't implement `convert` function. ".format(self.__class__.__name__))

+    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
+        """Quantizes a given float model.
+
+        Args:
+            model (torch.nn.Module): The float model to be quantized.
+
+        Returns:
+            A quantized model.
+        """
+        run_fn = kwargs.get("run_fn", None)
+        run_args = kwargs.get("run_args", None)
+        assert run_fn is not None, (
+            "Can't find run_func. Please provide run_func to quantize API "
+            "or overwrite quantize member function in your Quantizer class."
+        )
+
+        model = self.prepare(model, *args, **kwargs)
+        if run_args:
+            run_fn(model, *run_args)
+        else:
+            run_fn(model)
+        model = self.convert(model, *args, **kwargs)
+        return model
+
     def execute(self, model: torch.nn.Module, mode, *args: Any, **kwargs: Any):  # pragma: no cover
         """Execute according to mode.
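With this refactor, an algorithm author only needs to supply `prepare` and `convert`; the base class drives calibration through the inherited `quantize`. Below is a minimal, hypothetical sketch of that pattern. The `ExampleQuantizer` class, `calib_fn`, and the toy model are illustrative only and are not part of this commit.

import torch

from neural_compressor.torch.algorithms.base_algorithm import Quantizer


class ExampleQuantizer(Quantizer):
    """Hypothetical quantizer: only `prepare` and `convert` are implemented."""

    def prepare(self, model, *args, **kwargs):
        # e.g. attach observers to the ops listed in self.quant_config
        model.observers_attached = True
        return model

    def convert(self, model, *args, **kwargs):
        # e.g. replace float modules with quantized ones using collected statistics
        model.converted = True
        return model


def calib_fn(model):
    # run a few calibration samples so anything inserted by `prepare` sees real data
    model(torch.randn(2, 8))


quantizer = ExampleQuantizer(
    quant_config={"fc2": {"dtype": "int", "bits": 4, "group_size": 32, "scheme": "sym"}}
)
float_model = torch.nn.Linear(8, 8)
# inherited quantize(): prepare -> run_fn(model) -> convert
q_model = quantizer.quantize(float_model, run_fn=calib_fn)

Because `run_fn` is pulled from kwargs, forgetting to pass it trips the assertion added above rather than failing later inside `convert`.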
25 changes: 4 additions & 21 deletions neural_compressor/torch/algorithms/static_quant/static_quant.py
@@ -46,13 +46,13 @@


 class StaticQuantQuantizer(Quantizer):
-    def __init__(self, tune_cfg: OrderedDict = {}):
+    def __init__(self, quant_config: OrderedDict = {}):
         """Init a StaticQuantQuantizer object.

         Args:
-            tune_cfg (OrderedDict, optional): quantization config for ops. Defaults to {}.
+            quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
         """
-        super().__init__(tune_cfg)
+        super().__init__(quant_config)

     def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
         """Prepares a given model for quantization.
@@ -71,7 +71,7 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
             model, example_inputs
         )
         # update json file in ipex_config_path; map ipex op_name to pt op_name
-        user_cfg = cfg_to_qconfig(self.tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
+        user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
         model.eval()

         # Check save_qconf_summary part is a workaround for IPEX bug.
@@ -126,23 +126,6 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
         model.save = MethodType(save, model)
         return model

-    def quantize(self, model, example_inputs, run_fn, inplace=True, *args, **kwargs):
-        """Quantizes a given torch model.
-
-        Args:
-            model: A float model to be quantized.
-            example_inputs: Used to trace torch model.
-            run_fn: a calibration function for calibrating the model.
-            inplace: Whether to carry out model transformations in-place. Defaults to True.
-
-        Returns:
-            A quantized model.
-        """
-        model = self.prepare(model, example_inputs=example_inputs, inplace=inplace)
-        run_fn(model)
-        model = self.convert(model, example_inputs=example_inputs, inplace=inplace)
-        return model
-
-
 def _ipex_post_quant_process(model, example_inputs, inplace=False):
     """Convert to a jit model.
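Dropping the class-specific `quantize` means static quantization now flows through the base-class default as well. Here is a hedged sketch of the call shape; the model, calibration function, and empty config below are placeholders, and a real run assumes an IPEX-enabled environment with a properly resolved config mapping.

import torch

from neural_compressor.torch.algorithms.static_quant.static_quant import StaticQuantQuantizer

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
example_inputs = torch.randn(1, 8)


def calib_fn(m):
    m(example_inputs)


# In the framework this mapping is produced from the user's static quantization
# config; an empty dict here is only a stand-in to show the renamed keyword.
quantizer = StaticQuantQuantizer(quant_config={})

# Inherited Quantizer.quantize(): prepare -> calib_fn(model) -> convert,
# with example_inputs/inplace forwarded to prepare and convert via kwargs.
q_model = quantizer.quantize(model, run_fn=calib_fn, example_inputs=example_inputs, inplace=True)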
2 changes: 1 addition & 1 deletion neural_compressor/torch/quantization/algorithm_entry.py
@@ -155,7 +155,7 @@ def static_quant_entry(
     inplace = kwargs.get("inplace", True)
     assert example_inputs is not None, "Please provide example_inputs for static quantization."

-    quantizer = StaticQuantQuantizer(tune_cfg=quant_config_mapping)
+    quantizer = StaticQuantQuantizer(quant_config=quant_config_mapping)
     model = quantizer.execute(model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace)
     return model
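Only the constructor keyword changes in this entry function; `execute` still routes on `mode`. The wrapper below is a rough illustration of that dispatch, assuming the `Mode` enum from `neural_compressor.torch.utils` exposes PREPARE/CONVERT/QUANTIZE members as used by the base class; it is not code from this commit.

from neural_compressor.torch.utils import Mode


def run_static_quant(quantizer, model, mode, run_fn, example_inputs, inplace=True):
    # Illustrative wrapper mirroring the execute() call in static_quant_entry:
    # prepare-only, convert-only, or full one-shot quantization depending on mode.
    if mode == Mode.PREPARE:
        return quantizer.prepare(model, example_inputs=example_inputs, inplace=inplace)
    if mode == Mode.CONVERT:
        return quantizer.convert(model, example_inputs=example_inputs, inplace=inplace)
    return quantizer.quantize(model, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace)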

