
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Apr 28, 2024
1 parent 2903672 commit 8e12a5b
Showing 2 changed files with 19 additions and 17 deletions.
32 changes: 17 additions & 15 deletions neural_compressor/torch/algorithms/weight_only/autoround.py
@@ -12,18 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import torch
from auto_round import AutoRound  # pylint: disable=E0401
from auto_round.calib_dataset import CALIB_DATASETS  # pylint: disable=E0401
from auto_round.utils import get_block_names  # pylint: disable=E0401

from neural_compressor.torch.algorithms import Quantizer
from neural_compressor.torch.utils import logger


class AutoRoundQuantizer(Quantizer):
    def __init__(
        self,
        model,
        weight_config: dict = {},
        enable_full_range: bool = False,
@@ -48,6 +50,7 @@ def __init__(
        scale_dtype="fp32",
    ):
        """Init an AutoRoundQuantizer object.

        Args:
            model: The PyTorch model to be quantized.
            weight_config (dict): Configuration for weight quantization (default is an empty dictionary).
@@ -88,7 +91,7 @@ def __init__(
            scale_dtype (str): The data type of quantization scale to be used (default is "float32");
                different kernels have different choices.
        """

        self.model = model
        self.tokenizer = None
        self.weight_config = weight_config
@@ -113,7 +116,6 @@ def __init__(
        self.dynamic_max_gap = dynamic_max_gap
        self.data_type = "int"
        self.scale_dtype = scale_dtype

    def quantize(self, model: torch.nn.Module, *args, **kwargs):
        run_fn = kwargs.get("run_fn", None)
@@ -129,8 +131,7 @@ def quantize(self, model: torch.nn.Module, *args, **kwargs):
            run_fn(model)
        model = self.convert(model)
        return model

    def prepare(self, model, *args, **kwargs):
        """Prepares a given model for quantization.
        Args:
@@ -167,12 +168,12 @@ def prepare(self, model, *args, **kwargs):
        )
        self.rounder.prepare()
        return model

    def convert(self, model: torch.nn.Module, *args, **kwargs):
        model, weight_config = self.rounder.convert()
        model.autoround_config = weight_config
        return model


@torch.no_grad()
def get_autoround_default_run_fn(
@@ -248,8 +249,9 @@ def get_autoround_default_run_fn(
        "Effective samples size: {}, Target sample size: {}".format(total_cnt, n_samples)
    )


class AutoRoundProcessor(AutoRound):

    def prepare(self):
        """Quantize the model and return the quantized model along with weight configurations.
@@ -267,18 +269,18 @@ def prepare(self):
        if not self.low_gpu_mem_usage:
            self.model = self.model.to(self.device)
        # inputs = self.cache_block_input(block_names[0], self.n_samples)

        # cache block input
        self.inputs = {}
        self.tmp_block_name = self.block_names[0]
        self._replace_forward()

    def convert(self):
        # self.calib(self.n_samples)
        self._recover_forward()
        inputs = self.inputs[self.tmp_block_name]
        del self.tmp_block_name

        del self.inputs
        if "input_ids" in inputs.keys():
            dim = int((hasattr(self.model, "config") and "chatglm" in self.model.config.model_type))
@@ -343,4 +345,4 @@ def convert(self):

        self.quantized = True
        self.model = self.model.to(self.model_orig_dtype)
        return self.model, self.weight_config
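
For orientation, the class in this file follows a two-step prepare/convert protocol: quantize() prepares the model, invokes a calibration run_fn, then converts. Below is a minimal usage sketch, not a confirmed recipe: the checkpoint and tokenizer names are placeholders, and the assumption that get_autoround_default_run_fn accepts (model, tokenizer) is inferred from its name rather than shown in this diff.

# Hedged usage sketch of the prepare -> calibrate -> convert flow above.
from transformers import AutoModelForCausalLM, AutoTokenizer  # placeholder model source

from neural_compressor.torch.algorithms.weight_only.autoround import (
    AutoRoundQuantizer,
    get_autoround_default_run_fn,
)

name = "facebook/opt-125m"  # illustrative checkpoint, not from this commit
model = AutoModelForCausalLM.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)

# An empty weight_config falls back to the defaults set in __init__ above.
quantizer = AutoRoundQuantizer(model=model, weight_config={})

# quantize() wires the steps shown in this diff: prepare(), run_fn(model), convert().
q_model = quantizer.quantize(
    model,
    run_fn=lambda m: get_autoround_default_run_fn(m, tokenizer),  # assumed signature
)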
4 changes: 2 additions & 2 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -334,7 +334,7 @@ def teq_quantize_entry(
###################### AUTOROUND Algo Entry ##################################
@register_algo(name=AUTOROUND)
def autoround_quantize_entry(
    model: torch.nn.Module,
    configs_mapping: Dict[Tuple[str, callable], AutoRoundConfig],
    mode: Mode = Mode.QUANTIZE,
    *args,
@@ -376,7 +376,7 @@ def autoround_quantize_entry(
    scale_dtype = quant_config.scale_dtype

    kwargs.pop("example_inputs")

    quantizer = AutoRoundQuantizer(
        model=model,
        weight_config=weight_config,
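
The @register_algo decorator seen in this hunk is what routes an AutoRoundConfig to this entry function. As a rough illustration of that dispatch pattern (simplified for exposition; not neural_compressor's actual implementation):

# Simplified sketch of a name -> entry-function registry, for illustration only.
algos_mapping = {}

def register_algo(name):
    def decorator(entry_fn):
        algos_mapping[name] = entry_fn  # record the algorithm's entry point
        return entry_fn
    return decorator

@register_algo(name="autoround")
def autoround_entry(model, configs_mapping, *args, **kwargs):
    # In neural_compressor this step builds an AutoRoundQuantizer from the
    # config (as shown above) and returns the quantized model.
    ...

# Dispatch then becomes a dictionary lookup:
# algos_mapping["autoround"](model, configs_mapping)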
