diff --git a/ppq/IR/base/graph.py b/ppq/IR/base/graph.py index c59b9ede..766bac85 100644 --- a/ppq/IR/base/graph.py +++ b/ppq/IR/base/graph.py @@ -234,7 +234,9 @@ def copy(self, copy_value: bool = False): 'however its value is not an instance of torch.Tensor, ' 'ppq will automaticall convert it to torch.Tensor now.') self.value = convert_any_to_torch_tensor(self.value) - return Variable(name=self.name, value=self.value.clone(), is_parameter=self.is_parameter) + if isinstance(self.value, torch.Tensor): + value = self.value.clone() + return Variable(name=self.name, value=value, is_parameter=self.is_parameter) class Operation(OperationBase, Serializable): diff --git a/ppq/api/interface.py b/ppq/api/interface.py index eab53059..a5bdaf5a 100644 --- a/ppq/api/interface.py +++ b/ppq/api/interface.py @@ -157,7 +157,7 @@ def dump_torch_to_onnx( model: torch.nn.Module, onnx_export_file: str, input_shape: List[int], - input_dtype: torch.dtype, + input_dtype: torch.dtype = torch.float, inputs: List[Any] = None, device: str = 'cuda'): """ diff --git a/ppq/core/common.py b/ppq/core/common.py index 42ab0a12..c1152437 100644 --- a/ppq/core/common.py +++ b/ppq/core/common.py @@ -2,7 +2,6 @@ # PPQ System configuration # You can modify following codes for your own purpose. - # Observer 中,最小 scale 限制,所有小于该值的 scale 将被该值覆盖 OBSERVER_MIN_SCALE = 1e-8 # Observer 中,最小 scale 的手动覆盖属性 @@ -64,9 +63,6 @@ DEFAULT_OPSET_VERSION = 11 STRICT_OPSET_CHECKING = False -# 导出 qdq 节点时是否需要导出状态已经是 overlap 的节点 -EXPORT_OVERLAPPED_CONFIG = False - # LSTM 算子的权重缓存属性 LSTM_FLATTEN_WEIGHT_ATTRIB = 'LSTM_FLATTEN_WEIGHT_ATTRIB' # GRU 算子的权重缓存属性 @@ -90,4 +86,7 @@ CHECKPOINT_TOLERANCE = 1 # 要做 Bias Correction 的算子种类 -BIAS_CORRECTION_INTERST_TYPE = {'Conv', 'Gemm', 'ConvTranspose'} \ No newline at end of file +BIAS_CORRECTION_INTERST_TYPE = {'Conv', 'Gemm', 'ConvTranspose'} + +# 导出 qdq 节点时是否需要导出状态已经是 overlap 的节点 +EXPORT_OVERLAPPED_CONFIG = False diff --git a/ppq/core/quant.py b/ppq/core/quant.py index 762c8c62..a974a752 100644 --- a/ppq/core/quant.py +++ b/ppq/core/quant.py @@ -4,15 +4,20 @@ """ import time # for hash generation -from abc import abstractmethod from enum import Enum from typing import Any, Iterable, List import torch +from .common import EXPORT_OVERLAPPED_CONFIG from .storage import Serializable +class QuantizationVisiblity(Enum): + FORCE_EXPORT = 1 + EXPOET_WHEN_ACTIVE = 2 + INTERNAL = 3 + class NetworkFramework(Enum): PPL = 1 ONNX = 2 @@ -365,7 +370,7 @@ def __init__( offset: Any = None, observer_algorithm: str = None, detail: Any = None, - require_export: bool = None, + visiblity: QuantizationVisiblity = QuantizationVisiblity.EXPOET_WHEN_ACTIVE, state: QuantizationStates = QuantizationStates.INITIAL ): """Create a PPQ Tensor Quantization Configuration Instance. @@ -395,7 +400,13 @@ def __init__( detail (Any, optional): Only used by PPQ internal logic, detail is used to store some internal data, you are not supposed to use it. - require_export (bool, optional): If require_export == True, PPQ exporter will export this TQC ignoring state checks. + visiblity (Visiblity): visiblity is the attribute that controls export logic. + + Currently, there are 3 Visiblity level in PPQ: + if Visiblity == FORCE_EXPORT, ppq exporter will export this TQC + ignoring state check(even if current TQC has been overrlapped). + if Visiblity == EXPORT_WHEN_ACTIVD, ppq exporter will export this TQC only when it has been actived. + if Visiblity == INTERNAL, This TQC will not be exported. state (QuantizationStates, optional): Defaults to QuantizationStates.INITIAL, see QuantizationStates for more detail. @@ -416,17 +427,25 @@ def __init__( self.detail = {} if detail is None else detail self._father_config = self # union-find self._hash = self.__create_hash() - self._require_export = require_export + self._visiblity = visiblity super().__init__() - @ abstractmethod - def export(self) -> str: - raise Exception('Implement this first') + def can_export(self) -> bool: + if self.visiblity == QuantizationVisiblity.INTERNAL: return False + type_check = isinstance(self.scale, torch.Tensor) and isinstance(self.offset, torch.Tensor) + valid_states = {QuantizationStates.BAKED, QuantizationStates.PASSIVE_BAKED} + + if EXPORT_OVERLAPPED_CONFIG: valid_states.add(QuantizationStates.OVERLAPPED) + state_check = QuantizationStates.is_activated(self.state) or self.state in valid_states + + if (state_check or self.visiblity == QuantizationVisiblity.FORCE_EXPORT): + if type_check: return True + return False def __eq__(self, o: object) -> bool: if not isinstance(o, TensorQuantizationConfig): - raise TypeError('Can only compare TensorQuantizationConfig object '\ - 'with another TensorQuantizationConfig object.') + raise TypeError('Can only compare TensorQuantizationConfig object ' + 'with another TensorQuantizationConfig object.') return self._hash == o._hash def __str__(self) -> str: @@ -509,17 +528,13 @@ def is_revisable(self): }) @ property - def exportable(self) -> bool: - value_check = isinstance(self.scale, torch.Tensor) - if self._require_export is None: - state_check = QuantizationStates.can_export(self.state) - return (value_check and state_check) - else: return (self._require_export and value_check) - - @ exportable.setter - def exportable(self, export_override: bool): - self._require_export = export_override - + def visiblity(self) -> bool: + return self._visiblity + + @ visiblity.setter + def visiblity(self, visiblity: bool): + self._visiblity = visiblity + @ property def scale(self) -> torch.Tensor: if self.dominated_by == self: return self._scale diff --git a/ppq/parser/caffe_exporter.py b/ppq/parser/caffe_exporter.py index ea72ce92..fc96395b 100644 --- a/ppq/parser/caffe_exporter.py +++ b/ppq/parser/caffe_exporter.py @@ -40,12 +40,7 @@ def export_quantization_config(self, config_path: str, graph: BaseGraph): for operation in graph.operations.values(): if not isinstance(operation, QuantableOperation): continue for config, var in operation.config_with_variable: - if not QuantizationStates.can_export(config.state): - raise PermissionError( - 'Can not export quant config cause not all quantization configurations ' - 'have been correctly initialized(or some of them has been deactivated). ' - f'Operation {operation.name} has an invalid quantization config({config.state}) ' - f'at variable {var.name}.') + if not config.can_export(): continue # PATCH 2021.11.25 # REMOVE BIAS FROM CONFIGURATION diff --git a/ppq/parser/ncnn_exporter.py b/ppq/parser/ncnn_exporter.py index 2b364b3a..65473444 100644 --- a/ppq/parser/ncnn_exporter.py +++ b/ppq/parser/ncnn_exporter.py @@ -18,6 +18,8 @@ def export_quantization_config(self, config_path: str, graph: BaseGraph): if op.is_computing_op and isinstance(op, QuantableOperation): fd.write(f'{op.name}_param_0 ') param_cfg = op.config.input_quantization_config[1] + if not param_cfg.can_export(): continue + assert param_cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED}\ and param_cfg.observer_algorithm in {'minmax', 'Minmax'} and \ param_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL) @@ -32,6 +34,7 @@ def export_quantization_config(self, config_path: str, graph: BaseGraph): for s in scale: fd.write('%f '% s) fd.write('\n') + for op in topo_order: if op.is_computing_op and isinstance(op, QuantableOperation): fd.write(f'{op.name} ') diff --git a/ppq/parser/nxp_exporter.py b/ppq/parser/nxp_exporter.py index 659ef5a6..18ac422d 100644 --- a/ppq/parser/nxp_exporter.py +++ b/ppq/parser/nxp_exporter.py @@ -63,9 +63,8 @@ def export(self, file_path: str, graph: BaseGraph, if variable.is_parameter and not export_param: continue for config in configs: if config is None: continue # source_op can be None - if config.state in {QuantizationStates.ACTIVATED, QuantizationStates.BAKED, - QuantizationStates.OVERLAPPED, QuantizationStates.PASSIVE_BAKED}: - if config.state == QuantizationStates.OVERLAPPED: config = config.dominated_by + if config.can_export(): + tensor_range = config.scale * pow(2, config.num_of_bits - 1) min_val, max_val = -tensor_range, tensor_range - config.scale min_tensor = numpy_helper.from_array( diff --git a/ppq/parser/onnxruntime_exporter.py b/ppq/parser/onnxruntime_exporter.py index c2fcaa85..4069b6f4 100644 --- a/ppq/parser/onnxruntime_exporter.py +++ b/ppq/parser/onnxruntime_exporter.py @@ -3,20 +3,46 @@ import onnx import torch from onnx import helper -from ppq.core import (EXPORT_OVERLAPPED_CONFIG, GRAPH_OPSET_ATTRIB, PPQ_CONFIG, +from ppq.core import (GRAPH_OPSET_ATTRIB, PPQ_CONFIG, ChannelwiseTensorQuantizationConfig, DataType, OperationMeta, QuantizationProperty, QuantizationStates, TensorMeta, TensorQuantizationConfig, - convert_any_to_torch_tensor) -from ppq.IR import BaseGraph, Operation, QuantableOperation, QuantableVariable -from ppq.IR.base.command import GraphCommand, GraphCommandType -from ppq.IR.morph import GraphDeviceSwitcher, GraphFormatter + convert_any_to_torch_tensor, ppq_warning) +from ppq.IR import (BaseGraph, Operation, QuantableOperation, + QuantableVariable, Variable) +from ppq.IR.morph import GraphDeviceSwitcher from ppq.quantization.qfunction.linear import PPQLinearQuant_toInt from ppq.utils.round import ppq_tensor_round from .onnx_exporter import OnnxExporter +class QDQHelper(): + """Helper class for processing onnx qdq format""" + @ staticmethod + def TQC_Exportable_Check( + TQC: TensorQuantizationConfig, bounded_var: Variable) -> bool: + if not TQC.can_export(): return False + meta_check = bounded_var.meta is not None + + if TQC.num_of_bits == 8: + if TQC.policy.has_property(QuantizationProperty.ASYMMETRICAL): + range_check = (TQC.quant_max <= 255 and TQC.quant_min >= 0) + else: range_check = (TQC.quant_max <= 127 and TQC.quant_min >= -128) + else: range_check = True + + if not range_check: + ppq_warning(f'Is it not safe to export TQC({bounded_var.name}) to Onnx, ' + f'INT8 value range must be [-128, 127] or [0, 255], ' + f'however [{TQC.quant_min, TQC.quant_max}] was given.') + return False + + if not meta_check: + raise ValueError(f'Meta Data is missing! Graph Export Failed. ' + f'(Check Meta For Varaible: {bounded_var.name})') + return True + + class ONNXRUNTIMExporter(OnnxExporter): """ONNXRUNTIME int8 QDQ format exporter, no further actions should be applied to the graph because we will modify the graph in-place, and the @@ -145,7 +171,7 @@ def remove_activation_ops(self, graph: BaseGraph) -> BaseGraph: be removed from your network safely. Their function can be replaced by quant & dequant operations. - So to say those activation is unnecessary for Asymmetric quantized network. + Those activation is unnecessary for Asymmetric quantized network. Args: graph (BaseGraph): Processing Graph @@ -157,11 +183,29 @@ def remove_activation_ops(self, graph: BaseGraph) -> BaseGraph: if op.type in {'Relu', 'Clip'}: config = op.config.output_quantization_config[0] # Only ASYMMETRICAL quantized activations can be safely removed. - if config.policy.has_property(QuantizationProperty.ASYMMETRICAL): - # Patch 2022 06 29, 有些时候当我们启动了 alignment pass, relu 后面的定点信息将不再是 0, - # 此时我们拒绝移除激活函数 - if config._father_config != config: continue - removed_activations.append(op) + if config.policy.has_property(QuantizationProperty.SYMMETRICAL): continue + + if not isinstance(config.scale, torch.Tensor): continue + if not isinstance(config.offset, torch.Tensor): continue + + range_min = (config.scale * (config.quant_min - config.offset)).min().item() + range_max = (config.scale * (config.quant_max - config.offset)).max().item() + + if op.type == 'Relu': + if range_min >= 0: + removed_activations.append(op) + + if op.type == 'Clip': + if op.num_of_input == 3: + clip_min = op.inputs[1].value + clip_max = op.inputs[2].value + if clip_min is not None: clip_min = clip_min.item() + else: clip_min = float('-inf') + if clip_max is not None: clip_max = clip_max.item() + else: clip_max = float('+inf') + + if range_min >= clip_min and range_max <= clip_max: + removed_activations.append(op) # Activation op can only be relu and clip, # so it is safe to access op.inputs[0], op.outputs[0] as their input and output. @@ -193,23 +237,22 @@ def remove_activation_ops(self, graph: BaseGraph) -> BaseGraph: graph=graph, var=input_var, config=quant_config, related_op=upstream_op, meta=input_var.meta) - formatter = GraphFormatter(graph) - formatter(GraphCommand(GraphCommandType.DELETE_ISOLATED)) + # formatter = GraphFormatter(graph) + # formatter(GraphCommand(GraphCommandType.DELETE_ISOLATED)) return graph def remove_duplicated_quant_op(self, graph: BaseGraph) -> BaseGraph: - """Some time there will be more than 1 quant operation inserted with a + """ + Pattern: Quant - Dequant - Quant - Dequant + + Can reduced to: Quant - Dequant + + Some time there will be more than 1 quant operation inserted with a single variable. This function will remove duplicated quant operation from variable if it is possible. If inserted quant operations do not share a same zeropoint and scale, Then there is no way to remove any one of them. - - Args: - graph (BaseGraph): Processing Graph - - Returns: - _type_: Processed Graph """ interested_pairs = [] for qt_op in graph.operations.values(): @@ -234,6 +277,19 @@ def remove_duplicated_quant_op(self, graph: BaseGraph) -> BaseGraph: input_var, output_var = op.inputs[0], op.outputs[0] graph.remove_operation(op) graph.create_link_with_var(input_var, output_var) + + """ + There is another type of fusion: + Pattern: Quant +-- Dequant + | + +-- Dequant + + Can reduce to: Quant - Dequant +-- + | + +-- + + Not implemented. + """ return graph @ property @@ -291,54 +347,40 @@ def convert_operation(self, graph: BaseGraph, op: QuantableOperation, """ # collect quantable vars, where we need to insert quant and dequant op for config, var in op.config_with_variable: + if not QDQHelper.TQC_Exportable_Check(TQC=config, bounded_var=var): continue + meta = var.meta - if var.is_parameter: + if var.is_parameter and process_parameter: # we do not want to process clip value here. - if op.type in {'Clip'}: continue - + if op.type in {'Clip', 'Pad'}: continue assert len(var.dest_ops) == 1, ( f'Can not export variable {var.name}, cause it has more than 1 destination operations. ' 'PPQ require all parameters to have only 1 destination operation.') - if not process_parameter: continue # override quantization state, so that we can export parameter correctly. if config.state == QuantizationStates.BAKED: config.state = QuantizationStates.ACTIVATED if config.state == QuantizationStates.PASSIVE_BAKED: config.state = QuantizationStates.PASSIVE - if QuantizationStates.can_export(config.state) and config.state not in { - QuantizationStates.FP32, QuantizationStates.SOI}: - if (config.state == QuantizationStates.OVERLAPPED and - not EXPORT_OVERLAPPED_CONFIG): - continue - - # if not quant parameter to int, all parameter should export as fp32. - # needs insert both quant and dequant op for them - if not quant_param_to_int: - created = self.insert_quant_on_variable( - graph=graph, var=var, config=config, related_op=op, meta=meta) - var = created.outputs[0] - - self.insert_dequant_on_variable( - graph=graph, var=var, config=config, related_op=op, meta=meta) - if quant_param_to_int: - var.value = PPQLinearQuant_toInt(tensor=var.value, config=config) - - else: - if not process_activation: continue - - if QuantizationStates.can_export(config.state) and config.state not in { - QuantizationStates.FP32, QuantizationStates.SOI}: - if (config.state == QuantizationStates.OVERLAPPED and - not EXPORT_OVERLAPPED_CONFIG): - continue - + # if not quant parameter to int, all parameter should export as fp32. + # needs insert both quant and dequant op for them + if not quant_param_to_int: created = self.insert_quant_on_variable( graph=graph, var=var, config=config, related_op=op, meta=meta) var = created.outputs[0] - self.insert_dequant_on_variable( - graph=graph, var=var, config=config, related_op=op, meta=meta) + + self.insert_dequant_on_variable( + graph=graph, var=var, config=config, related_op=op, meta=meta) + if quant_param_to_int: + var.value = PPQLinearQuant_toInt(tensor=var.value, config=config) + + elif (not var.is_parameter) and process_activation: + created = self.insert_quant_on_variable( + graph=graph, var=var, config=config, related_op=op, meta=meta) + self.insert_dequant_on_variable( + graph=graph, var=created.outputs[0], config=config, + related_op=op, meta=meta) def prepare_graph( self, graph: BaseGraph, @@ -369,11 +411,6 @@ def prepare_graph( processor = GraphDeviceSwitcher(graph) processor.remove_switcher() - # remove activations - if remove_activation_fn: - # remove useless activation. - self.remove_activation_ops(graph) - # mark quantable variables for op in [op for op in graph.operations.values()]: if not isinstance(op, QuantableOperation): continue @@ -383,6 +420,11 @@ def prepare_graph( process_parameter=process_parameter, quant_param_to_int=quant_parameter_to_int) + # remove activations + if remove_activation_fn: + # remove useless activation. + self.remove_activation_ops(graph) + return self.remove_duplicated_quant_op(graph) def export(self, file_path: str, graph: BaseGraph, config_path: str = None) -> None: diff --git a/ppq/parser/onnxruntime_oos_exporter.py b/ppq/parser/onnxruntime_oos_exporter.py index f0717db4..9364176b 100644 --- a/ppq/parser/onnxruntime_oos_exporter.py +++ b/ppq/parser/onnxruntime_oos_exporter.py @@ -2,12 +2,11 @@ from typing import List, Tuple import torch -from ppq.IR.search import SearchableGraph -from ppq.core import (PASSIVE_OPERATIONS, DataType, QuantizationStates, - TensorMeta, TensorQuantizationConfig, - convert_any_to_torch_tensor, ppq_warning) +from ppq.core import (DataType, QuantizationStates, TargetPlatform, TensorMeta, + TensorQuantizationConfig, convert_any_to_torch_tensor, + ppq_warning) from ppq.IR import BaseGraph, Operation, QuantableOperation, Variable -from ppq.core.quant import TargetPlatform +from ppq.IR.search import SearchableGraph from ppq.quantization.qfunction.linear import PPQLinearQuant_toInt from ppq.utils.round import ppq_tensor_round diff --git a/ppq/parser/ppl.py b/ppq/parser/ppl.py index e976c725..3ae0e161 100644 --- a/ppq/parser/ppl.py +++ b/ppq/parser/ppl.py @@ -15,18 +15,14 @@ def convert_type(platform: TargetPlatform) -> str: if platform == TargetPlatform.FP32: return None raise TypeError(f'Unsupported platform type. ({str(platform)})') + class PPLBackendExporter(OnnxExporter): def export_quantization_config(self, config_path: str, graph: BaseGraph): var_quant_info_recorder, op_platform_recorder = {}, {} for operation in graph.operations.values(): if not isinstance(operation, QuantableOperation): continue for config, var in operation.config_with_variable: - if not QuantizationStates.can_export(config.state): - raise PermissionError( - 'Can not export quant config cause not all quantization configurations ' - 'have been correctly initialized(or some of them has been deactivated). ' - f'Operation {operation.name} has an invalid quantization state({config.state}) ' - f'at variable {var.name}.') + if not config.can_export(): continue # PATCH 2021.11.25 # REMOVE BIAS FROM CONFIGURATION diff --git a/ppq/quantization/algorithm/equalization.py b/ppq/quantization/algorithm/equalization.py index 61460089..834fde34 100644 --- a/ppq/quantization/algorithm/equalization.py +++ b/ppq/quantization/algorithm/equalization.py @@ -1,9 +1,8 @@ from enum import Enum -from typing import List, Tuple +from typing import List import torch from ppq.IR import Operation -from tqdm import tqdm as Progressbar class EqualizationMethod(Enum): @@ -22,386 +21,268 @@ class EqualizationMethod(Enum): SQUARE_MEAN = 4, +class EqualizationHelper(): + + @ staticmethod + def key_value_from_upstream( + op: Operation, including_bias: bool = False, including_act: bool = False, + bias_multiplier: float = 0.5, act_multiplier: float = 0.5) -> torch.Tensor: + if op.type not in {'Gemm', 'MatMul', 'Conv', 'ConvTranspose'}: + raise TypeError(f'Unsupported Op type {op.name}({op.type}) for Equalization Optimization.') + if not op.inputs[1].is_parameter: + raise ValueError(f'Parameter of Op {op.name} is non-static.') + buffer = [] + + # ---------------------------------- + # step - 1, extract weight from op: + # ---------------------------------- + w = op.inputs[1].value + if op.type == 'ConvTranspose': + num_of_groups = op.attributes.get('group', 1) + if w.ndim == 3: + w = torch.reshape(w, (num_of_groups, w.shape[0] // num_of_groups) + w.shape[1: ]) + w = torch.permute(w, (2, 0, 1, 3)) + w = torch.reshape(w, (w.shape[0] * w.shape[1], -1)) + elif w.ndim == 4: + w = torch.reshape(w, (num_of_groups, w.shape[0] // num_of_groups) + w.shape[1: ]) + w = torch.permute(w, (2, 0, 1, 3, 4)) + w = torch.reshape(w, (w.shape[0] * w.shape[1], -1)) + elif w.ndim == 5: + w = torch.reshape(w, (num_of_groups, w.shape[0] // num_of_groups) + w.shape[1: ]) + w = torch.permute(w, (2, 0, 1, 3, 4, 5)) + w = torch.reshape(w, (w.shape[0] * w.shape[1], -1)) + else: + raise ValueError(f'Unexpected dimension of weight of {op.name}.') + buffer.append(w) + + if op.type in {'MatMul', 'Gemm'}: + assert w.ndim == 2, f'Unexpected Error, Parameter of MatMul {op.name} should be 2-d.' + if op.attributes.get('transB', 0) == 0: + w = torch.transpose(w, 1, 0) + buffer.append(w) + + if op.type == 'Conv': + w = torch.reshape(w, (w.shape[0], -1)) + buffer.append(w) + + # ---------------------------------- + # step - 2, extract bias from op: + # ---------------------------------- + if including_bias and op.num_of_input == 3: + b = op.inputs[-1].value * bias_multiplier + if op.type in {'Conv', 'Gemm'} and op.inputs[-1].is_parameter: + b = torch.reshape(b, (w.shape[0], 1)) + buffer.append(b) + + if op.type == 'ConvTranspose': + b = torch.reshape(b, (w.shape[0], 1)) + buffer.append(b) + + # ---------------------------------- + # step - 3, extract activation from op: + # ---------------------------------- + if including_act and op.inputs[0].value is not None: + a = op.outputs[0].value * act_multiplier + buffer.append(a) + + # concat and return + return torch.cat(buffer, dim=-1) + + @ staticmethod + def key_value_from_downstream(op: Operation) -> torch.Tensor: + # ---------------------------------- + # step - 1, extract weight from op: + # ---------------------------------- + w = op.inputs[1].value + if op.type == 'ConvTranspose': + w = torch.reshape(w, (w.shape[0], -1)) + + if op.type in {'MatMul', 'Gemm'}: + assert w.ndim == 2, f'Unexpected Error, Parameter of MatMul {op.name} should be 2-d.' + if op.attributes.get('transB', 0) != 0: + w = torch.transpose(w, 1, 0) + + if op.type == 'Conv': + # for group convolution, we have to select its weight by group + num_of_groups = op.attributes.get('group', 1) + if w.ndim == 3: + w = torch.reshape(w, (num_of_groups, w.shape[0] // num_of_groups) + w.shape[1: ]) + w = torch.permute(w, (2, 0, 1, 3)) + w = torch.reshape(w, (w.shape[0] * w.shape[1], -1)) + elif w.ndim == 4: + w = torch.reshape(w, (num_of_groups, w.shape[0] // num_of_groups) + w.shape[1: ]) + w = torch.permute(w, (2, 0, 1, 3, 4)) + w = torch.reshape(w, (w.shape[0] * w.shape[1], -1)) + elif w.ndim == 5: + w = torch.reshape(w, (num_of_groups, w.shape[0] // num_of_groups) + w.shape[1: ]) + w = torch.permute(w, (2, 0, 1, 3, 4, 5)) + w = torch.reshape(w, (w.shape[0] * w.shape[1], -1)) + else: + raise ValueError(f'Unexpected dimension of weight of {op.name}.') + return w + + @ staticmethod + def scale_to_upstream(op: Operation, scale_factor: torch.Tensor): + if op.type not in {'Gemm', 'MatMul', 'Conv', 'ConvTranspose'}: + raise TypeError(f'Unsupported Op type {op.name}({op.type}) for Equalization Optimization.') + if not op.inputs[1].is_parameter: + raise ValueError(f'Parameter of Op {op.name} is non-static.') + + w = op.inputs[1].value + has_bias = op.num_of_input == 3 + if has_bias and not op.inputs[-1].is_parameter: + raise ValueError(f'Bias of Op {op.name} is non-static.') + if has_bias: bias = op.inputs[-1].value + + if op.type == 'ConvTranspose': + num_of_groups = op.attributes.get('group', 1) + w = torch.reshape(w, (num_of_groups, w.shape[0] // num_of_groups) + w.shape[1:]) + w *= torch.reshape(scale_factor, [num_of_groups, 1, -1] + [1] * (w.ndim - 3)) + w = torch.reshape(w, (w.shape[0] * w.shape[1], ) + w.shape[2:]) + if has_bias: bias *= scale_factor + + elif op.type == 'Conv': + w *= torch.reshape(scale_factor, [-1] + ([1] * (w.ndim - 1))) + if has_bias: bias *= scale_factor + + elif op.type in {'Gemm', 'MatMul'}: + if op.attributes.get('transB', 0) == 0: w = torch.transpose(w, 1, 0) + w *= torch.reshape(scale_factor, (-1, 1)) + if op.attributes.get('transB', 0) == 0: w = torch.transpose(w, 1, 0) + if has_bias: bias *= scale_factor + + # write back + with torch.no_grad(): + op.inputs[1].value.copy_(w) + if has_bias: op.inputs[-1].value.copy_(bias) + + @ staticmethod + def scale_to_downstream(op: Operation, scale_factor: torch.Tensor): + if op.type not in {'Gemm', 'MatMul', 'Conv', 'ConvTranspose'}: + raise TypeError(f'Unsupported Op type {op.name}({op.type}) for Equalization Optimization.') + if not op.inputs[1].is_parameter: + raise ValueError(f'Parameter of Op {op.name} is non-static.') + w = op.inputs[1].value + + if op.type == 'ConvTranspose': + w /= torch.reshape(scale_factor, [-1] + ([1] * (w.ndim - 1))) + + if op.type == 'Conv': + num_of_groups = op.attributes.get('group', 1) + w = torch.reshape(w, (num_of_groups, w.shape[0] // num_of_groups) + w.shape[1: ]) + w /= torch.reshape(scale_factor, [num_of_groups, 1, -1] + [1] * (w.ndim - 3)) + w = torch.reshape(w, (w.shape[1] * num_of_groups, ) + w.shape[2: ]) + + if op.type in {'Gemm', 'MatMul'}: + if op.attributes.get('transB', 0) != 0: w = torch.transpose(w, 1, 0) + w /= torch.reshape(scale_factor, (-1, 1)) + if op.attributes.get('transB', 0) != 0: w = torch.transpose(w, 1, 0) + + # write back + with torch.no_grad(): + op.inputs[1].value.copy_(w) + + class EqualizationPair: def __init__( self, - all_upstream_layers: List[Operation], - all_downstream_layers: List[Operation], - method:EqualizationMethod=EqualizationMethod.ABSOLUTE_MAX + upstream_layers: List[Operation], + downstream_layers: List[Operation] ): """ - EqualizationPair - 一个数据结构,封装了 equalization 的核心数据抽象和执行逻辑 a class encapsulating execution logic of equalization - 在 self.all_upstream_layers 包含了 equalization 操作中的所有上游层(op) - self.all_upstream_layers contain all upstream ops + 在 self.upstream_layers 包含了 equalization 操作中的所有上游层(op) + self.upstream_layers contain all upstream ops - 在 self.all_downstream_layers 包含了 equalization 操作中的所有下游层(op) - self.all_downstream_layers contain all downstream ops + 在 self.downstream_layers 包含了 equalization 操作中的所有下游层(op) + self.downstream_layers contain all downstream ops 一个 EqualizationPair 结构记录了参与 equalization 变换的所有相关层与其局部图结构信息 从而支持在局部子图上的 equalization 操作 An EqualizationPair records all relevant ops participating in the equalization transformation, thus supporting equalization on local subgraphs Args: - all_upstream_layers (list): + upstream_layers (list): equalization 操作中的所有上游层(op) - all_downstream_layers (list): + downstream_layers (list): equalization 操作中的所有下游层(op) - - method (EqualizationMethod, optional): - equalization 操作中,变换系数 s 的计算方式. Defaults to EqualizationMethod.ABSOLUTE_MAX. - """ - self.upstream_layers = all_upstream_layers - self.downstream_layers = all_downstream_layers - self.method = method - - def extract_key_value(self, including_bias: bool) -> Tuple[torch.Tensor, torch.Tensor]: - # extract all params from upstream_layers - upstream_params, downstream_params = [], [] - for upstream_layer in self.upstream_layers: - assert upstream_layer.type in ('Conv', 'ConvTranspose', 'Gemm'), ( - 'Only Conv or Linear layer is support in layerwise equalization now, ' - 'but %s got' % upstream_layer.type) - - if upstream_layer.type == 'ConvTranspose': - # weight shape is: [input channel, output channel / group, kernel, kernel] - weight, bias = self.get_convtranspose2d_params(upstream_layer, including_bias) - num_of_groups = upstream_layer.attributes.get('group', 1) - weight = torch.reshape(weight, (num_of_groups, weight.shape[0] // num_of_groups) + weight.shape[1:]) - weight = weight.permute(0, 2, 1, 3, 4) - weight = weight.reshape(weight.shape[0] * weight.shape[1], -1) - - upstream_params.append(weight) - if including_bias and bias is not None: - upstream_params.append(torch.reshape(bias, (weight.shape[1] * num_of_groups, 1))) - - elif upstream_layer.type == 'Conv': - # weight shape is: [output channel, input channel, kernel, kernel] - weight, bias = self.get_conv2d_params(upstream_layer, including_bias) - weight = torch.reshape(weight, (weight.shape[0], -1)) - - upstream_params.append(weight) - if including_bias and bias is not None: - upstream_params.append(torch.reshape(bias, (weight.shape[0], 1))) - - elif upstream_layer.type == 'Gemm': - # weight shape is: [output channel, input channel] - weight, bias = self.get_linear_params(upstream_layer, including_bias) - - upstream_params.append(weight) - if including_bias and bias is not None: - upstream_params.append(torch.reshape(bias, (weight.shape[0], 1))) - - # extract all params from downstream_layers - for downstream_layer in self.downstream_layers: - assert downstream_layer.type in ('Conv', 'ConvTranspose', 'Gemm'), ( - 'Only Conv or Linear layer is support in layerwise equalization now, ' - 'but %s got' % downstream_layer.type) - - if downstream_layer.type == 'Conv': - # weight shape is: [output channel, input channel // num_of_groups, kernel, kernel] - weight, bias = self.get_conv2d_params(downstream_layer, False) - - # for group convolution, we have to select its weight by group - num_of_groups = downstream_layer.attributes.get('group', 1) - - weight = torch.reshape(weight, (num_of_groups, weight.shape[0] // num_of_groups) + weight.shape[1: ]) - weight = weight.permute(2, 0, 1, 3, 4) - weight = torch.reshape(weight, (weight.shape[0] * weight.shape[1], -1)) - - downstream_params.append(weight) - - elif downstream_layer.type == 'ConvTranspose': - # weight shape is: [input channel, output channel // num_of_groups, kernel, kernel] - weight, bias = self.get_convtranspose2d_params(downstream_layer, False) - - # for group convolution, we have to select its weight by group - num_of_groups = downstream_layer.attributes.get('group', 1) - - weight = torch.reshape(weight, (weight.shape[0], -1)) - - downstream_params.append(weight) - - elif downstream_layer.type == 'Gemm': - # weight shape is: [output channel, input channel] - weight, bias = self.get_linear_params(downstream_layer, False) - downstream_params.append(weight.permute(1, 0)) - - # format all params - upstream_key_values = self.reduce_by_axis(upstream_params, method=self.method, aggerate_axis=1) - downstream_key_values = self.reduce_by_axis(downstream_params, method=self.method, aggerate_axis=1) - return upstream_key_values, downstream_key_values - - def layerwise_equalize( + self.upstream_layers = upstream_layers + self.downstream_layers = downstream_layers + + def equalize( self, value_threshold: float, - including_bias: bool + including_act: bool = False, + act_multiplier: float = 0.5, + including_bias: bool = False, + bias_multiplier: float = 0.5, + method: EqualizationMethod = EqualizationMethod.ABSOLUTE_MAX ): # extract key value from pair - upstream_key_values, downstream_key_values = self.extract_key_value(including_bias=including_bias) + upstream_key_values, downstream_key_values = [], [] + for op in self.upstream_layers: + key_value = EqualizationHelper.key_value_from_upstream( + op=op, including_bias=including_bias, including_act=including_act, + bias_multiplier=bias_multiplier, act_multiplier=act_multiplier) + upstream_key_values.append(key_value) + + for op in self.downstream_layers: + key_value = EqualizationHelper.key_value_from_downstream(op=op) + downstream_key_values.append(key_value) + + upstream_key_values = self.reduce_by_axis(upstream_key_values, method=method) + downstream_key_values = self.reduce_by_axis(downstream_key_values, method=method) # calculate scale scale = self.calculate_scale( upstream_key_values=upstream_key_values, downstream_key_values=downstream_key_values, - minval_threshold=value_threshold - ) + value_threshold=value_threshold) # write back all params - for upstream_layer in self.upstream_layers: - if upstream_layer.type == 'ConvTranspose': - weight, bias = self.get_convtranspose2d_params(upstream_layer, True) - num_of_groups = upstream_layer.attributes.get('group', 1) - weight = torch.reshape(weight, (num_of_groups, weight.shape[0] // num_of_groups) + weight.shape[1:]) - weight *= torch.reshape(scale, (num_of_groups, 1, -1, 1, 1)) - weight = torch.reshape(weight, (weight.shape[0] * weight.shape[1], ) + weight.shape[2:]) - if bias is not None: - bias *= scale - self.set_convtranspose2d_params(upstream_layer, bias, weight) - - elif upstream_layer.type == 'Conv': - weight, bias = self.get_conv2d_params(upstream_layer, True) - weight *= torch.reshape(scale, (-1, 1, 1, 1)) - if bias is not None: - bias *= scale - self.set_conv2d_params(upstream_layer, bias, weight) - - elif upstream_layer.type == 'Gemm': - weight, bias = self.get_linear_params(upstream_layer, True) - weight *= torch.reshape(scale, (-1, 1)) - if bias is not None: - bias *= scale - self.set_linear_params(upstream_layer, bias, weight) - - for downstream_layer in self.downstream_layers: - - if downstream_layer.type == 'ConvTranspose': - weight, bias = self.get_convtranspose2d_params(downstream_layer, False) - # for group convolution, we have to select its weight by group - - weight /= torch.reshape(scale, (-1, 1, 1, 1)) - self.set_convtranspose2d_params(downstream_layer, bias, weight) - - elif downstream_layer.type == 'Conv': - weight, bias = self.get_conv2d_params(downstream_layer, False) - # for group convolution, we have to select its weight by group - num_of_groups = downstream_layer.attributes.get('group', 1) - - weight = torch.reshape(weight, (num_of_groups, weight.shape[0] // num_of_groups) + weight.shape[1: ]) - weight /= torch.reshape(scale, (num_of_groups, 1, -1, 1, 1)) - weight = torch.reshape(weight, (weight.shape[1] * num_of_groups, ) + weight.shape[2: ]) - self.set_conv2d_params(downstream_layer, bias, weight) - - elif downstream_layer.type == 'Gemm': - weight, bias = self.get_linear_params(downstream_layer, False) - weight /= torch.reshape(scale, (1, -1)) - - self.set_linear_params(downstream_layer, bias, weight) - - def layerwise_channel_split( - self, value_threshold: float, - including_bias: bool): - pass - - def display(self) -> str: - for layer in self.upstream_layers + self.downstream_layers: - if layer.type == 'Conv': - weight, bias = self.get_conv2d_params(layer, including_bias=True) - elif layer.type == 'Gemm': - weight, bias = self.get_linear_params(layer, including_bias=True) - else: - raise Exception('Expect conv layer or linear layer only, while %s was given.' % layer.type) - - print('Stat of Layer %s: \t{%.4f}(Weight Max),\t{%.4f}(Weight Std)\t{%.4f}(Bias Max),\t{%.4f}(Bias Std)' % ( - layer.name, - torch.max(torch.abs(weight)), - torch.std(weight), - torch.max(torch.abs(bias)) if bias is not None else 0, - torch.std(bias) if bias is not None else 0 - )) - print('--- Layer-wise Equalization display end. ---') - - - def __str__(self) -> str: - return ( - 'Class EqualizationPair: ' - '[all_upstream_layers: %s, all_downstream_layers: %s]' % - (self.upstream_layers, self.downstream_layers)) - - - def get_conv2d_params(self, conv: Operation, including_bias: bool): - - assert conv.type == 'Conv', ( - 'Except input object with type Conv, but %s got' % conv.type) - assert conv.inputs[1].is_parameter, ( - f'Convolution layer {conv.name} has no static weights, please remove it from layerwise optimization.') - assert conv.inputs[1].value.ndim == 4, ( - f'Convolution layer {conv.name} is not a 2-d convolution, ' - 'layerwise equalization or split with n-d convolution is not supported yet.') - - weight, bias = conv.parameters[0].value, None - if including_bias and len(conv.parameters) > 1: - bias = conv.parameters[1].value - - return weight, bias - - - def get_convtranspose2d_params(self, conv: Operation, including_bias: bool): - - assert conv.type == 'ConvTranspose', ( - 'Except input object with type Conv, but %s got' % conv.type) - assert conv.inputs[1].is_parameter, ( - f'Convolution layer {conv.name} has no static weights, please remove it from layerwise optimization.') - assert conv.inputs[1].value.ndim == 4, ( - f'Convolution layer {conv.name} is not a 2-d convolution, ' - 'layerwise equalization or split with n-d convolution is not supported yet.') - - weight, bias = conv.parameters[0].value, None - if including_bias and len(conv.parameters) > 1: - bias = conv.parameters[1].value - - return weight, bias - - - def get_linear_params(self, linear: Operation, including_bias: bool): - - assert linear.type == 'Gemm', ( - 'Except input object with type Gemm, but %s got' % linear.type) - assert linear.inputs[1].is_parameter, ( - f'Linear layer {linear.name} has no static weights, please remove it from layerwise optimization.') - - weight, bias = linear.parameters[0].value, None - if including_bias and len(linear.parameters) > 1: - bias = linear.parameters[1].value - - if not linear.attributes.get('transB', 0): - weight = torch.transpose(weight, 1, 0) - if bias is not None: return weight, bias - else: return [weight, None] + for op in self.upstream_layers: + EqualizationHelper.scale_to_upstream(op, scale) + for op in self.downstream_layers: + EqualizationHelper.scale_to_downstream(op, scale) - def set_conv2d_params(self, conv: Operation, bias: torch.Tensor, weight: torch.Tensor): - - assert conv.type == 'Conv', ( - 'Except input object with type Conv, but %s got' % conv.type) - - conv.parameters[0].value = weight - if bias is not None and len(conv.parameters) > 1: - conv.parameters[1].value = bias - - - def set_convtranspose2d_params(self, conv: Operation, bias: torch.Tensor, weight: torch.Tensor): - - assert conv.type == 'ConvTranspose', ( - 'Except input object with type Conv, but %s got' % conv.type) - - conv.parameters[0].value = weight - if bias is not None and len(conv.parameters) > 1: - conv.parameters[1].value = bias - - - def set_linear_params(self, linear: Operation, bias: torch.Tensor, weight: torch.Tensor): - - assert linear.type == 'Gemm', ( - 'Except input object with type Gemm, but %s got' % linear.type) - - if not linear.attributes.get('transB', 0): - weight = torch.transpose(weight, 1, 0) - linear.parameters[0].value = weight - if bias is not None and len(linear.parameters) > 1: - linear.parameters[1].value = bias - + def channel_split( + self, + value_threshold: float, + including_bias: bool): + pass def calculate_scale( - self, - upstream_key_values: torch.Tensor, + self, upstream_key_values: torch.Tensor, downstream_key_values: torch.Tensor, - minval_threshold: float, - scale_clip_value: float = 10, - ): + value_threshold: float, scale_clip_value: float = 10): scale = 1 / torch.sqrt(upstream_key_values / downstream_key_values) scale = torch.clamp(scale, 1 / scale_clip_value, scale_clip_value) - scale[(upstream_key_values + downstream_key_values) < minval_threshold] = 1 - + scale[(upstream_key_values + downstream_key_values) < value_threshold] = 1 return scale - def reduce_by_axis( self, params: List[torch.Tensor], method: EqualizationMethod, - aggerate_axis: int=1, + axis: int=1, ) -> torch.Tensor: - params = torch.cat(params, axis=aggerate_axis) + params = torch.cat(params, axis=axis) if method is EqualizationMethod.ABSOLUTE_MAX: - return torch.max(torch.abs(params), axis=aggerate_axis)[0] + return torch.max(torch.abs(params), axis=axis)[0] elif method is EqualizationMethod.ABSOLUTE_MEAN: - return torch.mean(torch.abs(params), axis=aggerate_axis) + return torch.mean(torch.abs(params), axis=axis) elif method is EqualizationMethod.SQUARE_MAX: - return torch.max(torch.square(params), axis=aggerate_axis)[0] + return torch.max(torch.square(params), axis=axis)[0] elif method is EqualizationMethod.SQUARE_MEAN: - return torch.mean(torch.square(params), axis=aggerate_axis) + return torch.mean(torch.square(params), axis=axis) else: raise NotImplementedError('Equalization method %s is not support.' % str(method)) - - -def layerwise_equalization( - equalization_pairs: List[EqualizationPair], - weight_threshold: float = 0.5, - including_bias: bool = True, - iteration: int = 10, - verbose: bool = False -): - """ - - layerwise_equalization - 层间权重均一化,使用该函数从大尺度上拉平各个层之间的权重与bias,从而使得量化结果更加精确 - this func equalizes weights and biases between differenr layers and reduces - quantization error - 一次 equalization 操作是指利用性质: C * ( AX + b ) = C/s * ( AsX + b ),所作的恒等变换,其中s为对角矩阵 - one equalization step refers to use above formula to do equivalent transformation - - 通过上述变换以及精心选取的 s,可以使得权重矩阵 A, b, C 的数值大小尽可能接近,从而使得量化更加精准 - we could make numerical ranges of weight matrice of different layers become as similar as possible by - choosing appropriate s, reducing quantization error, while preserving correct results - - 注意目前只支持关于 CONV, GEMM 的权重拉平策略 - for now only CONV and GEMM support equalization - 相关论文: - - "Markus Nagel et al., Data-Free Quantization through Weight Equalization and Bias Correction" arXiv:1906.04721, 2019. - - Args: - equalization_pairs (list): 所有需要被拉平的层的组合结构 all equalization pairs. - weight_threshold (float, optional): 参与权重均一化的最小权重 minimum weight for weight equalization defaults to 0.5. - including_bias (bool, optional): 是否执行带bias的权重均一化 whether to include bias defaults to True. - iteration (int, optional): 均一化执行次数 num of equalization iterations defaults to 10. - verbose (bool, optional): 是否输出均一化的相关结果,这将打印均一化前后的权重变化情况 whether to print details defaults to True. - """ - - if verbose: - for equalization_pair in equalization_pairs: - equalization_pair.display() - - print(f'{len(equalization_pairs)} equalization pair(s) was found, ready to run optimization.') - for iter_times in Progressbar(range(iteration), desc='Layerwise Equalization', total=iteration): - for equalization_pair in equalization_pairs: - assert isinstance(equalization_pair, EqualizationPair), ( - 'Input equalization pairs should be encapsuled with class EqualizationPair') - - equalization_pair.layerwise_equalize( - value_threshold=weight_threshold, - including_bias=including_bias - ) - - if verbose: - for equalization_pair in equalization_pairs: - equalization_pair.display() - diff --git a/ppq/quantization/analyse/graphwise.py b/ppq/quantization/analyse/graphwise.py index 273202b3..68bdf1ea 100644 --- a/ppq/quantization/analyse/graphwise.py +++ b/ppq/quantization/analyse/graphwise.py @@ -2,7 +2,7 @@ from typing import Callable, Dict, Iterator, List import torch -from ppq.core import PASSIVE_OPERATIONS, OperationMeta +from ppq.core import PASSIVE_OPERATIONS, OperationMeta, ppq_warning from ppq.executor import RuntimeHook, TorchExecutor from ppq.IR import BaseGraph, Operation, QuantableOperation, Variable from ppq.quantization.measure.norm import torch_snr_error @@ -23,8 +23,6 @@ def pre_forward_hook(self, inputs: list, **kwargs) -> list: return super().pre_forward_hook(inputs, **kwargs) def post_forward_hook(self, outputs: list, **kwargs) -> list: - assert len(outputs) == 1, ('Multiple output tensor detected. ' - 'Can not monitoring an operation with more than 1 output.') output_tensor = outputs[0] assert isinstance(output_tensor, torch.Tensor), ( 'Output of monitoring operation is not a torch.Tensor') @@ -126,6 +124,10 @@ def graphwise_error_analyse( recorders, hooks, caches = {}, {}, {} for operation in interested_op: if isinstance(operation, QuantableOperation): + if operation.num_of_output > 1: + ppq_warning(f'Operation {operation.name} has more than 1 output, ' + 'analyser will process the first output of it.') + recorders[operation.name] = MeasureRecorder(measurement=method) hooks[operation.name] = OutputRecorder( operation=operation, operation_meta=operation.meta_data, fetchs=fetchs) @@ -197,8 +199,8 @@ def statistical_analyse( The return value of this function is a collection of statistics parameters You are recommended to processing them with pandas - from pandas import Dataframe - report_df = Dataframe(report) + from pandas import DataFrame + report_df = DataFrame(report) Args: graph (BaseGraph): _description_ diff --git a/ppq/quantization/optim/equalization.py b/ppq/quantization/optim/equalization.py index b0550201..f6c8ec5f 100644 --- a/ppq/quantization/optim/equalization.py +++ b/ppq/quantization/optim/equalization.py @@ -8,8 +8,7 @@ from ppq.IR import (BaseGraph, Operation, QuantableOperation, SearchableGraph, TraversalCommand) from ppq.IR.base.graph import BaseGraph -from ppq.quantization.algorithm.equalization import (EqualizationPair, - layerwise_equalization) +from ppq.quantization.algorithm.equalization import EqualizationPair from tqdm import tqdm from .base import QuantizationOptimizationPass @@ -173,8 +172,8 @@ class LayerwiseEqualizationPass(QuantizationOptimizationPass): """ def __init__( self, iterations: int, weight_threshold: float = 0.5, - including_bias: bool = False, including_activation: bool = False, - bias_multiplier: float = 0.5, activation_multiplier: float = 0.5, + including_bias: bool = False, including_act: bool = False, + bias_multiplier: float = 0.5, act_multiplier: float = 0.5, interested_layers: List[str] = None, optimize_level: int = 2, verbose:bool = False) -> None: """PPQ Customized Layerwise Equalization Pass. @@ -200,7 +199,7 @@ def __init__( set this to be True if your hardware does not allow a 32-bit bias. Defaults to False. - including_activation (bool, optional): + including_act (bool, optional): whether to include activation into consideration. Defaults to False. @@ -208,7 +207,7 @@ def __init__( a multiplier to bias, if not necessary do not change this. Defaults to 0.5. - activation_multiplier (float, optional): + act_multiplier (float, optional): a multiplier to activation, if not necessary do not change this. Defaults to 0.5. @@ -221,13 +220,13 @@ def __init__( """ self.optimize_level = optimize_level self.iterations = iterations - self.weight_threshold = weight_threshold + self.value_threshold = weight_threshold self.including_bias = including_bias - self.bias_multiplier = bias_multiplier + self.bias_multiplier = bias_multiplier - self.including_activation = including_activation - self.activation_multiplier = activation_multiplier + self.including_act = including_act + self.act_multiplier = act_multiplier self.interested_layers = interested_layers self.verbose = verbose @@ -287,25 +286,23 @@ def find_equalization_pair( # construct a new equalization pair. if len(upstream_ops) > 0 and len(downstream_ops) > 0: pairs.append(EqualizationPair( - all_upstream_layers=list(upstream_ops), - all_downstream_layers=list(downstream_ops))) + upstream_layers=list(upstream_ops), + downstream_layers=list(downstream_ops))) return pairs - def collect_activations(self, + def collect_activations(self, graph: BaseGraph, executor: TorchExecutor, dataloader: Iterable, collate_fn: Callable, operations: List[Operation], - steps: int = 16) -> Dict[Operation, torch.Tensor]: + steps: int = 16) -> Dict[str, torch.Tensor]: - def aggregate(tensor: torch.Tensor): - if tensor.ndim == 4: # Conv result: [n,c,h,w] + def aggregate(op: Operation, tensor: torch.Tensor): + if op.type in {'Conv', 'ConvTranspose'}: # Conv result: [n,c,h,w] num_of_channel = tensor.shape[1] - tensor = tensor.permute(dims=[1, 0, 2, 3]) + tensor = tensor.transpose(0, 1) tensor = tensor.reshape(shape=[num_of_channel, -1]) tensor = torch.max(tensor.abs(), dim=-1, keepdim=False)[0] - elif tensor.ndim == 2: # Gemm result: [n, c] - num_of_channel = tensor.shape[1] - tensor = tensor.permute(dims=[1, 0]) - tensor = tensor.reshape(shape=[num_of_channel, -1]) + elif op.type in {'MatMul', 'Gemm'}: # Gemm result: [n, c] + tensor = tensor.transpose(0, 1) tensor = torch.max(tensor.abs(), dim=-1, keepdim=False)[0] return tensor @@ -319,15 +316,19 @@ def aggregate(tensor: torch.Tensor): for idx, batch in tqdm(enumerate(dataloader), desc='Equalization Data Collecting.', total=min(len(dataloader), steps)): - data = collate_fn(batch) + data = batch + if collate_fn is not None: + data = collate_fn(batch) outputs = executor.forward(data, output_names=output_names) for name, output in zip(output_names, outputs): - output_collector[name].append(aggregate(output).unsqueeze(-1)) + op = graph.variables[name].source_op + output_collector[name].append(aggregate(op, output).unsqueeze(-1)) if idx > steps: break result = {} for name, output in zip(output_names, outputs): - result[name] = torch.max(torch.cat(output_collector[name], dim=-1)[0], dim=-1) + result[name] = torch.cat(output_collector[name], dim=-1) + print(name, result[name].shape) return result @ empty_ppq_cache @@ -351,19 +352,33 @@ def optimize( pairs = self.find_equalization_pair( graph=graph, interested_operations=interested_operations) - ''' + if self.including_act: + activations = self.collect_activations( + graph=graph, executor=executor, dataloader=dataloader, collate_fn=collate_fn, + operations=interested_operations) + + for name, act in activations.items(): + graph.variables[name].value = act # 将激活值写回网络 + + for name, act in activations.items(): + print(name, torch.max(act, dim=-1)[0][:25]) + + print(f'{len(pairs)} equalization pair(s) was found, ready to run optimization.') + for iter_times in tqdm(range(self.iterations), desc='Layerwise Equalization', total=self.iterations): + for equalization_pair in pairs: + equalization_pair.equalize( + value_threshold=self.value_threshold, + including_bias=self.including_bias, + including_act=self.including_act, + bias_multiplier=self.bias_multiplier, + act_multiplier=self.act_multiplier) + activations = self.collect_activations( - executor=executor, dataloader=dataloader, collate_fn=collate_fn, + graph=graph, executor=executor, dataloader=dataloader, collate_fn=collate_fn, operations=interested_operations) - ''' - - layerwise_equalization( - equalization_pairs=pairs, - weight_threshold=self.weight_threshold, - including_bias=self.including_bias, - iteration=self.iterations, - verbose=self.verbose) + for name, act in activations.items(): + print(name, torch.max(act, dim=-1)[0][:25]) # equalization progress directly changes fp32 value of weight, # store it for following procedure. for op in graph.operations.values(): diff --git a/ppq/quantization/optim/parameters.py b/ppq/quantization/optim/parameters.py index 9597f8c0..8cb7f704 100644 --- a/ppq/quantization/optim/parameters.py +++ b/ppq/quantization/optim/parameters.py @@ -46,6 +46,9 @@ def check_state(state: QuantizationStates): # PATCH 2022.07.29 有的时候 bias 是个多维的东西,此时要求前面的维度都是1 bias = op.inputs[-1].value + if bias is None: raise ValueError(f'Bias Varaible {op.inputs[-1].name} must be constant. ' + 'Please check it again.') + assert bias.numel() == bias.shape[-1], ( f'For op {op.name}, expect Bias shape to be {[bias.numel()]}, ' f'however {bias.shape} was given') @@ -108,12 +111,11 @@ def check_state(state: QuantizationStates): # inputs are [input value, pad[shape-related], pad value[optional]] if op.num_of_input != 3: continue i_cfg = op.config.input_quantization_config[0] - if i_cfg.state != QuantizationStates.PASSIVE_INIT and not self._override: continue if not check_state(i_cfg.state): raise PermissionError(f'Can not quantize pad value of layer {op.name}, ' 'cause input has not been correctly quantized.') - + if len(op.config.input_quantization_config) > 1: pad_config = op.config.input_quantization_config[-1] # 在两种情况下可以执行后续逻辑,1 状态为 PASSIVE_INIT,2 要求 override diff --git a/ppq/quantization/optim/refine.py b/ppq/quantization/optim/refine.py index 58f75306..59b9b71d 100644 --- a/ppq/quantization/optim/refine.py +++ b/ppq/quantization/optim/refine.py @@ -422,12 +422,12 @@ def optimize( ppq_warning(f'Unexpected dispatching was found: ' f'Op {computing_op.name} and {act_op.name} should be send to a same platform.') continue - + if not isinstance(act_op, QuantableOperation): ppq_warning(f'Unexpected dispatching was found: ' f'Op {computing_op.name} and {act_op.name} should both be quantized operation.') continue - + assert isinstance(act_op, QuantableOperation) if (len(graph.get_downstream_operations(computing_op)) == 1 and len(graph.get_upstream_operations(act_op)) == 1): @@ -435,7 +435,7 @@ def optimize( act_op.config.output_quantization_config[0]) act_op.config.input_quantization_config[0].dominated_by = ( act_op.config.output_quantization_config[0]) - + # fuse relu and clip if possible for op in graph.operations.values(): if op.type in {'Relu', 'Clip'}: @@ -642,7 +642,9 @@ def optimize( elif operation.type in TYPES_FOR_ALIGNMENT['Pooling']: if self.averagepool_method == 'None': continue if self.averagepool_method == 'Align to Output': - self.align_to_output(operation) + master_config = self.align_to_output(operation) + if self.averagepool_method == 'Align to Large': + raise ValueError('Alignment Method Error, Pooling Op can not align to lager input.') elif ALIGNMENT_MANUL_OVERRIDE in operation.extension_attrib: method = operation.extension_attrib[ALIGNMENT_MANUL_OVERRIDE] diff --git a/ppq/quantization/qfunction/linear.py b/ppq/quantization/qfunction/linear.py index 65b6413c..5bb8fb1f 100644 --- a/ppq/quantization/qfunction/linear.py +++ b/ppq/quantization/qfunction/linear.py @@ -35,7 +35,9 @@ def forward(ctx, tensor: torch.Tensor, scales: torch.Tensor, offsets: torch.Tensor, quant_min: int, quant_max: int, rounding: RoundingPolicy) -> torch.Tensor: - if not PPQ_CONFIG.USING_CUDA_KERNEL: + if not PPQ_CONFIG.USING_CUDA_KERNEL or not tensor.is_cuda: + scales = scales.to(tensor.device) + offsets = offsets.to(tensor.device) # quantization function, pytorch implmentation tensor = ppq_tensor_round((tensor / scales), rounding) + offsets tensor = torch.clamp(tensor, quant_min, quant_max) @@ -44,6 +46,7 @@ def forward(ctx, tensor: torch.Tensor, scales: torch.Tensor, else: from ppq.core import CUDA + # quantization function, pure cuda implmentation quantized = CUDA.LinearQuantize_T( tensor=tensor, @@ -78,7 +81,9 @@ def forward(ctx, tensor: torch.Tensor, scales: torch.Tensor, offsets: torch.Tensor, channel_axis: int, quant_min: int, quant_max: int, rounding: RoundingPolicy) -> torch.Tensor: - if not PPQ_CONFIG.USING_CUDA_KERNEL: + if not PPQ_CONFIG.USING_CUDA_KERNEL or not tensor.is_cuda: + scales = scales.to(tensor.device) + offsets = offsets.to(tensor.device) # generate a shape that likes [1, 1, -1, 1], the only -1 is at channel axe. shape = [1 if axis != channel_axis else -1 for axis in range(tensor.ndim)] scale, offset = scales.view(shape), offsets.view(shape) @@ -106,6 +111,19 @@ def backward(ctx, dy: torch.Tensor): def PPQLinearQuantFunction( tensor: torch.Tensor, config: TensorQuantizationConfig) -> torch.Tensor: + """ + PPQ 核心量化函数 + + Args: + tensor (torch.Tensor): _description_ + config (TensorQuantizationConfig): _description_ + + Raises: + ValueError: _description_ + + Returns: + torch.Tensor: _description_ + """ if not QuantizationStates.is_activated(config.state): return tensor if not config.policy.has_property(QuantizationProperty.LINEAR): raise ValueError('Critical Quantization Error! Non-linear config detected.') diff --git a/ppq/quantization/quantizer/DSPQuantizer.py b/ppq/quantization/quantizer/DSPQuantizer.py index 8ed284d8..27db6c19 100644 --- a/ppq/quantization/quantizer/DSPQuantizer.py +++ b/ppq/quantization/quantizer/DSPQuantizer.py @@ -73,7 +73,7 @@ def quant_operation_types(self) -> set: 'GlobalMaxPool', 'GlobalAveragePool', 'Softmax', 'Mul', 'Add', 'Max', 'Sub', 'Div', 'Reshape', 'LeakyRelu', 'Concat', 'Sigmoid', 'Slice', 'Interp', - 'ReduceMean'} + 'ReduceMean', 'Flatten'} @ property def quantize_policy(self) -> QuantizationPolicy: diff --git a/ppq/quantization/quantizer/OpenvinoQuantizer.py b/ppq/quantization/quantizer/OpenvinoQuantizer.py index c7ae1162..678d8f9d 100644 --- a/ppq/quantization/quantizer/OpenvinoQuantizer.py +++ b/ppq/quantization/quantizer/OpenvinoQuantizer.py @@ -16,7 +16,7 @@ def __init__( ) -> Union[torch.Tensor, list, dict]: super().__init__(graph=graph) self._num_of_bits = 8 - self._quant_min = -127 + self._quant_min = -128 self._quant_max = 127 def init_quantize_config( @@ -35,7 +35,7 @@ def init_quantize_config( # layout: [out_channel, in_channel, kernel_size, kernel_size] if operation.type in {'Conv', 'ConvTranspose'}: conv_weight_config = base_quant_config.input_quantization_config[1] - conv_weight_config._quant_min = -127 + conv_weight_config._quant_min = -128 conv_weight_config._quant_max = 127 conv_weight_config.policy = QuantizationPolicy( QuantizationProperty.SYMMETRICAL + @@ -52,7 +52,7 @@ def init_quantize_config( # layout: [in_dim, out_dim] elif operation.type in {'Gemm'}: gemm_weight_config = base_quant_config.input_quantization_config[1] - gemm_weight_config._quant_min = -127 + gemm_weight_config._quant_min = -128 gemm_weight_config._quant_max = 127 gemm_weight_config.policy = QuantizationPolicy( QuantizationProperty.SYMMETRICAL + diff --git a/ppq/quantization/quantizer/base.py b/ppq/quantization/quantizer/base.py index 2ba5376d..f2fe6233 100644 --- a/ppq/quantization/quantizer/base.py +++ b/ppq/quantization/quantizer/base.py @@ -375,9 +375,9 @@ def build_prequant_pipeline( iterations = equalization_setting.iterations, weight_threshold = equalization_setting.value_threshold, including_bias = equalization_setting.including_bias, - including_activation = equalization_setting.including_act, - bias_multiplier = equalization_setting.bias_multiplier, - activation_multiplier = equalization_setting.act_multiplier + including_act = equalization_setting.including_act, + bias_multiplier = equalization_setting.bias_multiplier, + act_multiplier = equalization_setting.act_multiplier )) return QuantizationOptimizationPipeline(passes=list_of_passes) diff --git a/tests/test_layerwise_equalization.py b/tests/test_layerwise_equalization.py new file mode 100644 index 00000000..d76b327f --- /dev/null +++ b/tests/test_layerwise_equalization.py @@ -0,0 +1,236 @@ +from ppq import * +from ppq.api import * +import torch + +# 创建计算图 +graph = BaseGraph(name='Created Graph', built_from=NetworkFramework.NATIVE) +op1 = graph.create_operation( + op_type='Gemm', name='Gemm1', + inputs=[graph.create_variable(), + graph.create_variable(is_parameter=True, value=torch.rand(size=[128, 1024]).cuda() * 100), + graph.create_variable(is_parameter=True, value=torch.rand(size=[1024]).cuda() * 100)], + outputs=[graph.create_variable()]) + +op2 = graph.create_operation( + op_type='Gemm', name='Gemm2', + inputs=[op1.outputs[0], + graph.create_variable(is_parameter=True, value=torch.rand(size=[1024, 128]).cuda()), + graph.create_variable(is_parameter=True, value=torch.rand(size=[128]).cuda())], + outputs=[graph.create_variable()]) + +op3 = graph.create_operation( + op_type='Gemm', name='Gemm3', attributes={'transB': 1}, + inputs=[op2.outputs[0], + graph.create_variable(is_parameter=True, value=torch.rand(size=[1024, 128]).cuda() * 500), + graph.create_variable(is_parameter=True, value=torch.rand(size=[1024]).cuda() * 0)], + outputs=[graph.create_variable()]) + +graph.mark_variable_as_graph_input(op1.inputs[0]) +graph.mark_variable_as_graph_output(op3.outputs[0]) + +inputs = [torch.rand(size=[8, 128]) for _ in range(32)] +executor = TorchExecutor(graph=graph) +b_outputs = [executor.forward(inputs=t.cuda())[0].unsqueeze(0) for t in inputs] +b_outputs = torch.cat(b_outputs) + +from ppq.quantization.optim import LayerwiseEqualizationPass +LayerwiseEqualizationPass(iterations=1000, including_bias=True, including_act=True).optimize( + graph=graph, dataloader=inputs, executor=executor, collate_fn=lambda x: x.cuda()) +p_outputs = [executor.forward(inputs=t.cuda())[0].unsqueeze(0) for t in inputs] +p_outputs = torch.cat(p_outputs) +from ppq.quantization.measure import torch_snr_error +assert torch_snr_error(b_outputs, p_outputs).item() < 1e-7 + +import torch +class MyModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + with torch.no_grad(): + self.conv1 = torch.nn.Conv2d(in_channels=8, out_channels=32, kernel_size=3, stride=2, padding=1) + self.conv2 = torch.nn.Conv2d(in_channels=32, out_channels=32, groups=16, kernel_size=3, stride=1, padding=1, bias=False) + self.conv3 = torch.nn.Conv2d(in_channels=32, out_channels=8, groups=8, kernel_size=5, stride=1, padding=2) + self.convtranspose1 = torch.nn.ConvTranspose2d(in_channels=8, out_channels=32, kernel_size=5, stride=1, padding=2) + self.convtranspose2 = torch.nn.ConvTranspose2d(in_channels=32, out_channels=32, groups=32, kernel_size=3, stride=2, bias=False) + self.convtranspose3 = torch.nn.ConvTranspose2d(in_channels=32, out_channels=8, groups=1, kernel_size=1) + + self.conv1.bias.copy_(torch.rand_like(self.conv1.bias)) + self.conv3.bias.copy_(torch.rand_like(self.conv3.bias)) + self.convtranspose1.bias.copy_(torch.rand_like(self.convtranspose1.bias)) + self.convtranspose3.bias.copy_(torch.rand_like(self.convtranspose3.bias)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.convtranspose1(x) + x = self.convtranspose2(x) + x = self.convtranspose3(x) + return x + +model = MyModel().cuda() +dump_torch_to_onnx(model=model, onnx_export_file='model.onnx', input_shape=[1, 8, 96, 96]) +graph = load_onnx_graph(onnx_import_file='model.onnx') + +inputs = [torch.rand(size=[1, 8, 96, 96]) for _ in range(32)] +executor = TorchExecutor(graph=graph) +b_outputs = [executor.forward(inputs=t.cuda())[0].unsqueeze(0) for t in inputs] +b_outputs = torch.cat(b_outputs) + +from ppq.quantization.optim import LayerwiseEqualizationPass +LayerwiseEqualizationPass(iterations=10, including_bias=True, including_act=True).optimize( + graph=graph, dataloader=inputs, executor=executor, collate_fn=lambda x: x.cuda()) +p_outputs = [executor.forward(inputs=t.cuda())[0].unsqueeze(0) for t in inputs] +p_outputs = torch.cat(p_outputs) +from ppq.quantization.measure import torch_snr_error +assert torch_snr_error(b_outputs, p_outputs).item() < 1e-7 + + +class MyModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + with torch.no_grad(): + self.conv1 = torch.nn.Conv3d(in_channels=8, out_channels=32, kernel_size=3, stride=2, padding=1) + self.conv2 = torch.nn.Conv3d(in_channels=32, out_channels=32, groups=16, kernel_size=3, stride=1, padding=1, bias=False) + self.conv3 = torch.nn.Conv3d(in_channels=32, out_channels=8, groups=8, kernel_size=5, stride=1, padding=2) + self.conv4 = torch.nn.Conv3d(in_channels=8, out_channels=32, kernel_size=3, stride=2, padding=1) + + self.conv1.bias.copy_(torch.rand_like(self.conv1.bias)) + self.conv3.bias.copy_(torch.rand_like(self.conv3.bias)) + self.conv4.bias.copy_(torch.rand_like(self.conv4.bias)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + return x + +model = MyModel().cuda() +dump_torch_to_onnx(model=model, onnx_export_file='model.onnx', input_shape=[1, 8, 16, 96, 96]) +graph = load_onnx_graph(onnx_import_file='model.onnx') + +inputs = [torch.rand(size=[1, 8, 16, 96, 96]) for _ in range(32)] +executor = TorchExecutor(graph=graph) +b_outputs = [executor.forward(inputs=t.cuda())[0].unsqueeze(0) for t in inputs] +b_outputs = torch.cat(b_outputs) + +from ppq.quantization.optim import LayerwiseEqualizationPass +LayerwiseEqualizationPass(iterations=10, including_bias=True, including_act=True).optimize( + graph=graph, dataloader=inputs, executor=executor, collate_fn=lambda x: x.cuda()) +p_outputs = [executor.forward(inputs=t.cuda())[0].unsqueeze(0) for t in inputs] +p_outputs = torch.cat(p_outputs) +from ppq.quantization.measure import torch_snr_error +assert torch_snr_error(b_outputs, p_outputs).item() < 1e-7 + + +class MyModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + with torch.no_grad(): + self.conv1 = torch.nn.Conv1d(in_channels=8, out_channels=32, kernel_size=3, stride=2, padding=1) + self.conv2 = torch.nn.Conv1d(in_channels=32, out_channels=32, groups=16, kernel_size=3, stride=1, padding=1, bias=False) + self.conv3 = torch.nn.Conv1d(in_channels=32, out_channels=8, groups=8, kernel_size=5, stride=1, padding=2) + self.conv4 = torch.nn.Conv1d(in_channels=8, out_channels=32, kernel_size=3, stride=2, padding=1) + + self.conv1.bias.copy_(torch.rand_like(self.conv1.bias)) + self.conv3.bias.copy_(torch.rand_like(self.conv3.bias)) + self.conv4.bias.copy_(torch.rand_like(self.conv4.bias)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + return x + +model = MyModel().cuda() +dump_torch_to_onnx(model=model, onnx_export_file='model.onnx', input_shape=[1, 8, 96]) +graph = load_onnx_graph(onnx_import_file='model.onnx') + +inputs = [torch.rand(size=[1, 8, 96]) for _ in range(32)] +executor = TorchExecutor(graph=graph) +b_outputs = [executor.forward(inputs=t.cuda())[0].unsqueeze(0) for t in inputs] +b_outputs = torch.cat(b_outputs) + +from ppq.quantization.optim import LayerwiseEqualizationPass +LayerwiseEqualizationPass(iterations=10, including_bias=True, including_act=True).optimize( + graph=graph, dataloader=inputs, executor=executor, collate_fn=lambda x: x.cuda()) +p_outputs = [executor.forward(inputs=t.cuda())[0].unsqueeze(0) for t in inputs] +p_outputs = torch.cat(p_outputs) +from ppq.quantization.measure import torch_snr_error +assert torch_snr_error(b_outputs, p_outputs).item() < 1e-7 + + +class MyModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + with torch.no_grad(): + self.conv1 = torch.nn.ConvTranspose1d(in_channels=8, out_channels=32, kernel_size=3, stride=2, padding=1) + self.conv2 = torch.nn.ConvTranspose1d(in_channels=32, out_channels=32, groups=16, kernel_size=3, stride=1, padding=1, bias=False) + self.conv3 = torch.nn.ConvTranspose1d(in_channels=32, out_channels=8, groups=8, kernel_size=5, stride=1, padding=2) + self.conv4 = torch.nn.ConvTranspose1d(in_channels=8, out_channels=32, kernel_size=3, stride=2, padding=1) + + self.conv1.bias.copy_(torch.rand_like(self.conv1.bias)) + self.conv3.bias.copy_(torch.rand_like(self.conv3.bias)) + self.conv4.bias.copy_(torch.rand_like(self.conv4.bias)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + return x + +model = MyModel().cuda() +dump_torch_to_onnx(model=model, onnx_export_file='model.onnx', input_shape=[1, 8, 96]) +graph = load_onnx_graph(onnx_import_file='model.onnx') + +inputs = [torch.rand(size=[1, 8, 96]) for _ in range(32)] +executor = TorchExecutor(graph=graph) +b_outputs = [executor.forward(inputs=t.cuda())[0].unsqueeze(0) for t in inputs] +b_outputs = torch.cat(b_outputs) + +from ppq.quantization.optim import LayerwiseEqualizationPass +LayerwiseEqualizationPass(iterations=10, including_bias=True, including_act=True).optimize( + graph=graph, dataloader=inputs, executor=executor, collate_fn=lambda x: x.cuda()) +p_outputs = [executor.forward(inputs=t.cuda())[0].unsqueeze(0) for t in inputs] +p_outputs = torch.cat(p_outputs) +from ppq.quantization.measure import torch_snr_error +assert torch_snr_error(b_outputs, p_outputs).item() < 1e-7 + +class MyModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + with torch.no_grad(): + self.conv1 = torch.nn.ConvTranspose3d(in_channels=8, out_channels=32, kernel_size=3, stride=2, padding=1) + self.conv2 = torch.nn.ConvTranspose3d(in_channels=32, out_channels=32, groups=16, kernel_size=3, stride=1, padding=1, bias=False) + self.conv3 = torch.nn.ConvTranspose3d(in_channels=32, out_channels=8, groups=8, kernel_size=5, stride=1, padding=2) + self.conv4 = torch.nn.ConvTranspose3d(in_channels=8, out_channels=32, kernel_size=3, stride=2, padding=1) + + self.conv1.bias.copy_(torch.rand_like(self.conv1.bias)) + self.conv3.bias.copy_(torch.rand_like(self.conv3.bias)) + self.conv4.bias.copy_(torch.rand_like(self.conv4.bias)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + return x + +model = MyModel().cuda() +dump_torch_to_onnx(model=model, onnx_export_file='model.onnx', input_shape=[1, 8, 8, 8, 8]) +graph = load_onnx_graph(onnx_import_file='model.onnx') + +inputs = [torch.rand(size=[1, 8, 8, 8, 8]) for _ in range(32)] +executor = TorchExecutor(graph=graph) +b_outputs = [executor.forward(inputs=t.cuda())[0].unsqueeze(0) for t in inputs] +b_outputs = torch.cat(b_outputs) + +from ppq.quantization.optim import LayerwiseEqualizationPass +LayerwiseEqualizationPass(iterations=10, including_bias=True, including_act=True).optimize( + graph=graph, dataloader=inputs, executor=executor, collate_fn=lambda x: x.cuda()) +p_outputs = [executor.forward(inputs=t.cuda())[0].unsqueeze(0) for t in inputs] +p_outputs = torch.cat(p_outputs) +from ppq.quantization.measure import torch_snr_error +assert torch_snr_error(b_outputs, p_outputs).item() < 1e-7 \ No newline at end of file