From 044e6db24c22764633ac8213c41998358433fa7a Mon Sep 17 00:00:00 2001
From: Yi Liu <106061964+yiliu30@users.noreply.github.com>
Date: Sun, 28 Apr 2024 16:51:11 +0800
Subject: [PATCH] Refine base Quantizer (#1760)

Refine base Quantizer class

---------

Signed-off-by: yuwenzho
Signed-off-by: yiliu30
Co-authored-by: yuwenzho
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../torch/algorithms/base_algorithm.py        | 72 ++++++++++++-------
 .../algorithms/static_quant/static_quant.py   | 25 ++-----
 .../torch/quantization/algorithm_entry.py     |  2 +-
 3 files changed, 51 insertions(+), 48 deletions(-)

diff --git a/neural_compressor/torch/algorithms/base_algorithm.py b/neural_compressor/torch/algorithms/base_algorithm.py
index 36337d0f4aa..08cbde43ab8 100644
--- a/neural_compressor/torch/algorithms/base_algorithm.py
+++ b/neural_compressor/torch/algorithms/base_algorithm.py
@@ -14,7 +14,7 @@
 
 from abc import ABC, abstractmethod
 from collections import OrderedDict
-from typing import Any
+from typing import Any, Optional
 
 import torch
 
@@ -22,37 +22,33 @@
 
 
 class Quantizer(ABC):
-    """The base quantizer for all algorithm quantizers."""
+    """The base quantizer for all algorithm quantizers.
 
-    def __init__(self, tune_cfg: OrderedDict = {}):
-        """Init a Quantizer object.
+    The `Quantizer` unifies the interfaces across various quantization algorithms, including GPTQ, RTN, etc.
+    Given a float model, `Quantizer` applies the quantization algorithm to the model according to `quant_config`.
 
-        Args:
-            tune_cfg (OrderedDict, optional): quantization config for ops. Defaults to {}.
-                Take weight-only quantization as an example,
-                    tune_cfg={
-                        'fc2':
-                            {
-                                'dtype': 'int',
-                                'bits': 4,
-                                'group_size': 32,
-                                'scheme': 'sym'
-                            }
-                    }
-        """
-        self.tune_cfg = tune_cfg
+    To implement a new quantization algorithm, inherit from `Quantizer` and implement the following methods:
+        - `prepare`: prepare a given model for conversion.
+        - `convert`: convert a prepared model to a quantized model.
+    Note: `quantize` and `execute` are optional for new quantization algorithms.
+    """
 
-    @abstractmethod
-    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
-        """Quantizes a given torch model.
+    def __init__(self, quant_config: Optional[Any] = None):
+        """Init a Quantizer object.
 
         Args:
-            model (torch.nn.Module): The torch model to be quantized.
-
-        Returns:
-            A quantized model.
+            quant_config: Specifies how to apply the algorithm to the given model.
+                The format of `quant_config` can be defined by the `Quantizer` itself.
+                For example, `quant_config` can be a dictionary as below:
+                    quant_config={
+                        'fc2':{
+                            'dtype': 'int',
+                            'bits': 4,
+                            'group_size': 32,
+                            'scheme': 'sym'
+                        }}
         """
-        raise NotImplementedError("{} doesn't implement `quantize` function.".format(self.__class__.__name__))
+        self.quant_config = quant_config
 
     @abstractmethod
     def prepare(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
@@ -80,6 +76,30 @@ def convert(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
         """
         raise NotImplementedError("{} doesn't implement `convert` function. ".format(self.__class__.__name__))
 
+    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
+        """Quantizes a given float model.
+
+        Args:
+            model (torch.nn.Module): The float model to be quantized.
+
+        Returns:
+            A quantized model.
+        """
+        run_fn = kwargs.get("run_fn", None)
+        run_args = kwargs.get("run_args", None)
+        assert run_fn is not None, (
+            "Can't find run_fn. Please provide run_fn to the quantize API "
+            "or override the quantize method in your Quantizer class."
+        )
+
+        model = self.prepare(model, *args, **kwargs)
+        if run_args:
+            run_fn(model, *run_args)
+        else:
+            run_fn(model)
+        model = self.convert(model, *args, **kwargs)
+        return model
+
     def execute(self, model: torch.nn.Module, mode, *args: Any, **kwargs: Any):  # pragma: no cover
         """Execute according to mode.
 
diff --git a/neural_compressor/torch/algorithms/static_quant/static_quant.py b/neural_compressor/torch/algorithms/static_quant/static_quant.py
index 06796c82199..a013ce51144 100644
--- a/neural_compressor/torch/algorithms/static_quant/static_quant.py
+++ b/neural_compressor/torch/algorithms/static_quant/static_quant.py
@@ -46,13 +46,13 @@
 
 
 class StaticQuantQuantizer(Quantizer):
-    def __init__(self, tune_cfg: OrderedDict = {}):
+    def __init__(self, quant_config: OrderedDict = {}):
         """Init a StaticQuantQuantizer object.
 
         Args:
-            tune_cfg (OrderedDict, optional): quantization config for ops. Defaults to {}.
+            quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
         """
-        super().__init__(tune_cfg)
+        super().__init__(quant_config)
 
     def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
         """Prepares a given model for quantization.
@@ -71,7 +71,7 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
             model, example_inputs
         )  # update json file in ipex_config_path; map ipex op_name to pt op_name
-        user_cfg = cfg_to_qconfig(self.tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
+        user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
         model.eval()
 
         # Check save_qconf_summary part is a workaround for IPEX bug.
@@ -126,23 +126,6 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
         model.save = MethodType(save, model)
         return model
 
-    def quantize(self, model, example_inputs, run_fn, inplace=True, *args, **kwargs):
-        """Quantizes a given torch model.
-
-        Args:
-            model: A float model to be quantized.
-            example_inputs: Used to trace torch model.
-            run_fn: a calibration function for calibrating the model.
-            inplace: Whether to carry out model transformations in-place. Defaults to True.
-
-        Returns:
-            A quantized model.
-        """
-        model = self.prepare(model, example_inputs=example_inputs, inplace=inplace)
-        run_fn(model)
-        model = self.convert(model, example_inputs=example_inputs, inplace=inplace)
-        return model
-
 def _ipex_post_quant_process(model, example_inputs, inplace=False):
     """Convert to a jit model.
 
diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py
index 63ba0db7421..2834b42949a 100644
--- a/neural_compressor/torch/quantization/algorithm_entry.py
+++ b/neural_compressor/torch/quantization/algorithm_entry.py
@@ -155,7 +155,7 @@ def static_quant_entry(
     inplace = kwargs.get("inplace", True)
     assert example_inputs is not None, "Please provide example_inputs for static quantization."
 
-    quantizer = StaticQuantQuantizer(tune_cfg=quant_config_mapping)
+    quantizer = StaticQuantQuantizer(quant_config=quant_config_mapping)
     model = quantizer.execute(model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace)
     return model
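
Not part of the patch, for reviewers: a minimal sketch of how a new algorithm quantizer plugs into the refined base class once this change lands. `FakeWeightQuantizer`, its rounding scheme, and the toy calibration lambda are hypothetical illustrations; only `Quantizer`, `quant_config`, `prepare`, `convert`, and the inherited `quantize(model, run_fn=...)` flow come from this patch.

import torch

from neural_compressor.torch.algorithms.base_algorithm import Quantizer


class FakeWeightQuantizer(Quantizer):
    """Hypothetical algorithm: round Linear weights to `bits` symmetric levels."""

    def prepare(self, model, *args, **kwargs):
        # Weight-only fake quantization needs no observers, so prepare is a
        # pass-through; the base quantize() still runs the calibration fn next.
        return model

    def convert(self, model, *args, **kwargs):
        # Read the algorithm-specific config stored by the base __init__.
        bits = (self.quant_config or {}).get("bits", 8)
        qmax = 2 ** (bits - 1) - 1
        for module in model.modules():
            if isinstance(module, torch.nn.Linear):
                w = module.weight.data
                scale = w.abs().max().clamp(min=1e-8) / qmax
                module.weight.data = (w / scale).round().clamp(-qmax, qmax) * scale
        return model


# The inherited quantize() drives prepare -> run_fn (calibration) -> convert,
# so a subclass only implements the two abstract methods:
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
quantizer = FakeWeightQuantizer(quant_config={"bits": 4})
q_model = quantizer.quantize(model, run_fn=lambda m: m(torch.randn(2, 8)))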