diff --git a/.azure-pipelines/scripts/ut/env_setup.sh b/.azure-pipelines/scripts/ut/env_setup.sh
index b0ee2c77cf8..d5876b07cef 100644
--- a/.azure-pipelines/scripts/ut/env_setup.sh
+++ b/.azure-pipelines/scripts/ut/env_setup.sh
@@ -92,7 +92,7 @@ elif [[ $(echo "${test_case}" | grep -c "tf pruning") != 0 ]]; then
 fi
 
 if [[ $(echo "${test_case}" | grep -c "api") != 0 ]] || [[ $(echo "${test_case}" | grep -c "adaptor") != 0 ]]; then
-    pip install git+https://github.com/intel/auto-round.git@ecca5349981044e1278773a251b3fc5c0a11fe7b
+    pip install auto-round
 fi
 
 # test deps
diff --git a/neural_compressor/torch/algorithms/weight_only/autoround.py b/neural_compressor/torch/algorithms/weight_only/autoround.py
index 5ff78a3413d..2e97533c0bb 100644
--- a/neural_compressor/torch/algorithms/weight_only/autoround.py
+++ b/neural_compressor/torch/algorithms/weight_only/autoround.py
@@ -12,27 +12,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
+import json
 import time
+from typing import Union
 
 import torch
 from auto_round import AutoRound  # pylint: disable=E0401
-from auto_round.calib_dataset import CALIB_DATASETS  # pylint: disable=E0401
-from auto_round.utils import get_block_names  # pylint: disable=E0401
+from auto_round.export.export_to_itrex.export import pack_model  # pylint: disable=E0401
 
 from neural_compressor.torch.algorithms import Quantizer
-from neural_compressor.torch.utils import logger
+from neural_compressor.torch.utils import get_accelerator, logger
+
+from .utility import CapturedDataloader, InputCaptureModule
 
 
 class AutoRoundQuantizer(Quantizer):
     def __init__(
         self,
-        quant_config: dict = None,
+        quant_config: dict = {},
         enable_full_range: bool = False,
         batch_size: int = 8,
         amp: bool = True,
         device=None,
         lr_scheduler=None,
-        use_quant_input: bool = True,
+        enable_quanted_input: bool = True,
         enable_minmax_tuning: bool = True,
         lr: float = None,
         minmax_lr: float = None,
@@ -46,7 +50,9 @@ def __init__(
         gradient_accumulate_steps: int = 1,
         not_use_best_mse: bool = False,
         dynamic_max_gap: int = -1,
-        scale_dtype="fp32",
+        data_type: str = "int",
+        scale_dtype: str = "fp16",
+        **kwargs,
     ):
         """Init a AutQRoundQuantizer object.
 
@@ -86,7 +92,7 @@ def __init__(
             gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
             not_use_best_mse (bool): Whether to use mean squared error (default is False).
             dynamic_max_gap (int): The dynamic maximum gap (default is -1).
-            scale_dtype (str): The data type of quantization scale to be used (default is "float32"), different kernels
+            scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels
                 have different choices.
""" super().__init__(quant_config) @@ -94,9 +100,9 @@ def __init__( self.enable_full_range = enable_full_range self.batch_size = batch_size self.amp = amp - self.device = device + self.device = get_accelerator(kwargs.pop("device", "auto")).current_device_name() self.lr_scheduler = lr_scheduler - self.use_quant_input = use_quant_input + self.enable_quanted_input = enable_quanted_input self.enable_minmax_tuning = enable_minmax_tuning self.lr = lr self.minmax_lr = minmax_lr @@ -110,7 +116,7 @@ def __init__( self.gradient_accumulate_steps = gradient_accumulate_steps self.not_use_best_mse = not_use_best_mse self.dynamic_max_gap = dynamic_max_gap - self.data_type = "int" + self.data_type = data_type self.scale_dtype = scale_dtype def prepare(self, model: torch.nn.Module, *args, **kwargs): @@ -121,16 +127,23 @@ def prepare(self, model: torch.nn.Module, *args, **kwargs): Returns: A prepared model. """ - self.rounder = AutoRoundProcessor( + prepare_model = InputCaptureModule(model) + return prepare_model + + def convert(self, model: torch.nn.Module, *args, **kwargs): + dataloader = CapturedDataloader(model.args_list, model.kwargs_list) + model = model.orig_model + rounder = AutoRound( model=model, tokenizer=None, + dataset=dataloader, weight_config=self.quant_config or {}, enable_full_range=self.enable_full_range, batch_size=self.batch_size, amp=self.amp, device=self.device, lr_scheduler=self.lr_scheduler, - use_quant_input=self.use_quant_input, + enable_quanted_input=self.enable_quanted_input, enable_minmax_tuning=self.enable_minmax_tuning, lr=self.lr, minmax_lr=self.minmax_lr, @@ -147,179 +160,32 @@ def prepare(self, model: torch.nn.Module, *args, **kwargs): data_type=self.data_type, scale_dtype=self.scale_dtype, ) - self.rounder.prepare() - return model - - def convert(self, model: torch.nn.Module, *args, **kwargs): - model, weight_config = self.rounder.convert() + model, weight_config = rounder.quantize() model.autoround_config = weight_config + model = pack_model(model, weight_config, device=self.device, inplace=True) return model -@torch.no_grad() -def get_autoround_default_run_fn( - model, - tokenizer, - dataset_name="NeelNanda/pile-10k", - n_samples=512, - seqlen=2048, - seed=42, - bs=8, - dataset_split: str = "train", - dataloader=None, -): - """Perform calibration for quantization. - - This method calibrates the model for quantization by processing a specified - number of samples from the calibration dataset. It ensures that the data is - properly formatted and feeds it to the model. If the number of samples processed - is less than the specified number, it logs a warning. If no samples are processed, - it logs an error and exits. +def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, n_samples=512): + """Generate a DataLoader for calibration using specified parameters. Args: - n_samples (int): The number of samples to use for calibration. + tokenizer (Tokenizer): The tokenizer to use for tokenization. + seqlen (int): The exact sequence length. samples < seqlen will be dropped, + samples longer than seqlen will be truncated + dataset_name (str, optional): The name of the dataset or datasets separated by commas. + Defaults to "NeelNanda/pile-10k". + split (str, optional): The data split to use. Defaults to None. + seed (int, optional): The random seed for reproducibility. Defaults to 42. + bs (int, optional): The batch size. Defaults to 4. + n_samples (int, optional): The total number of samples to include. Defaults to 512. 
+
+    Returns:
+        DataLoader: The DataLoader for the calibrated dataset.
     """
-    if dataloader is None:
-        get_dataloader = CALIB_DATASETS.get(dataset_name, CALIB_DATASETS["NeelNanda/pile-10k"])
-        dataloader = get_dataloader(
-            tokenizer,
-            seqlen,
-            seed=seed,
-            bs=bs,
-            split=dataset_split,
-            dataset_name=dataset_name,
-        )
-    total_cnt = 0
-    for data in dataloader:
-        if data is None:
-            continue
-        if isinstance(data, torch.Tensor):
-            data_new = data.to(model.device)
-            input_ids = data_new
-        else:
-            data_new = {}
-            for key in data.keys():
-                data_new[key] = data[key].to(model.device)
-            input_ids = data_new["input_ids"]
-        # if input_ids.shape[-1] < seqlen:
-        #     continue
-        if total_cnt + input_ids.shape[0] > n_samples:
-            input_ids = input_ids[: n_samples - total_cnt, ...]
-        try:
-            if isinstance(data_new, torch.Tensor):
-                model(data_new)
-            elif isinstance(data_new, dict):
-                model(**data_new)
-            else:
-                # Handle cases where data_new is neither a Tensor nor a dict
-                raise NotImplementedError(f"Handling not implemented for data type {type(data)}")
-        except Exception as error:
-            logger.error(error)
-        total_cnt += input_ids.shape[0]
-        if total_cnt >= n_samples:
-            break
-    if total_cnt == 0:
-        logger.error(
-            "no data has been cached, please provide more data with sequence length >= {} in the ".format(seqlen)
-            + "dataloader or decease the sequence length."
-        )
-        exit()
-    elif total_cnt < n_samples:
-        logger.warning(
-            "Insufficient number of samples collected may affect the quantification. "
-            "Effective samples size: {}, Target sample size: {}".format(total_cnt, n_samples)
-        )
-
-
-class AutoRoundProcessor(AutoRound):
-
-    def prepare(self):
-        """Prepares a given model for quantization."""
-        # logger.info("cache block input")
-        self.start_time = time.time()
-        self.block_names = get_block_names(self.model)
-        if len(self.block_names) == 0:
-            logger.warning("could not find blocks, exit with original model")
-            return
-        if self.amp:
-            self.model = self.model.to(self.amp_dtype)
-        if not self.low_gpu_mem_usage:
-            self.model = self.model.to(self.device)
-        # inputs = self.cache_block_input(block_names[0], self.n_samples)
-
-        # cache block input
-        self.inputs = {}
-        self.tmp_block_name = self.block_names[0]
-        self._replace_forward()
-
-    def convert(self):
-        """Converts a prepared model to a quantized model."""
-        self._recover_forward()
-        inputs = self.inputs[self.tmp_block_name]
-        del self.tmp_block_name
-
-        del self.inputs
-        if "input_ids" in inputs.keys():
-            dim = int((hasattr(self.model, "config") and "chatglm" in self.model.config.model_type))
-            total_samples = inputs["input_ids"].shape[dim]
-            self.n_samples = total_samples
-            if total_samples < self.train_bs:
-                self.train_bs = total_samples
-                logger.warning(f"force the train batch size to {total_samples} ")
-        self.model = self.model.to("cpu")
-        torch.cuda.empty_cache()
-        self.qdq_weight_round(
-            self.model,
-            inputs,
-            self.block_names,
-            n_blocks=self.n_blocks,
-            device=self.device,
-        )
-        for n, m in self.model.named_modules():
-            if n in self.weight_config.keys():
-                if hasattr(m, "scale"):
-                    self.weight_config[n]["scale"] = m.scale
-                    self.weight_config[n]["zp"] = m.zp
-                    if self.group_size <= 0:
-                        self.weight_config[n]["g_idx"] = torch.tensor(
-                            [0 for i in range(m.weight.shape[1])], dtype=torch.int32, device="cpu"
-                        )
-                    else:
-                        self.weight_config[n]["g_idx"] = torch.tensor(
-                            [i // self.group_size for i in range(m.weight.shape[1])], dtype=torch.int32, device="cpu"
-                        )
-                    delattr(m, "scale")
-                    delattr(m, "zp")
-                else:
-                    self.weight_config[n]["data_type"] = "float"
-                    if self.amp_dtype == torch.bfloat16:
-                        self.weight_config[n]["data_type"] = "bfloat"
-                    self.weight_config[n]["bits"] = 16
-                    self.weight_config[n]["group_size"] = None
-                    self.weight_config[n]["sym"] = None
-
-        end_time = time.time()
-        cost_time = end_time - self.start_time
-        logger.info(f"quantization tuning time {cost_time}")
-        ## dump a summary
-        quantized_layers = []
-        unquantized_layers = []
-        for n, m in self.model.named_modules():
-            if isinstance(m, tuple(self.supported_types)):
-                if self.weight_config[n]["bits"] == 16:
-                    unquantized_layers.append(n)
-                else:
-                    quantized_layers.append(n)
-        summary_info = (
-            f"Summary: quantized {len(quantized_layers)}/{len(quantized_layers) + len(unquantized_layers)} in the model"
-        )
-        if len(unquantized_layers) > 0:
-            summary_info += f", {unquantized_layers} have not been quantized"
-
-        logger.info(summary_info)
-        if len(unquantized_layers) > 0:
-            logger.info(f"Summary: {unquantized_layers} have not been quantized")
+    from auto_round.calib_dataset import get_dataloader  # pylint: disable=E0401
 
-        self.quantized = True
-        self.model = self.model.to(self.model_orig_dtype)
-        return self.model, self.weight_config
+    dataloader = get_dataloader(
+        tokenizer, seqlen, dataset_name=dataset_name, seed=seed, bs=bs, n_samples=n_samples
+    )
+    return dataloader
diff --git a/neural_compressor/torch/algorithms/weight_only/utility.py b/neural_compressor/torch/algorithms/weight_only/utility.py
index b0acf53280d..31cbe3bc342 100644
--- a/neural_compressor/torch/algorithms/weight_only/utility.py
+++ b/neural_compressor/torch/algorithms/weight_only/utility.py
@@ -1072,3 +1072,32 @@ def _hook(module, inputs, outputs):
     for h in hook_list:
         h.remove()
     return total_values
+
+
+class CapturedDataloader:
+    def __init__(self, args_list, kwargs_list) -> None:
+        self.args_list = args_list
+        self.kwargs_list = kwargs_list
+
+    def __iter__(self):
+        for args, kwargs in zip(self.args_list, self.kwargs_list):
+            if not args:
+                yield kwargs
+            elif not kwargs:
+                yield args
+            else:
+                yield args, kwargs
+
+
+class InputCaptureModule(torch.nn.Module):
+
+    def __init__(self, model) -> None:
+        super().__init__()
+        self.args_list = []
+        self.kwargs_list = []
+        self.orig_model = model
+
+    def forward(self, *args, **kwargs):
+        with torch.no_grad():
+            self.args_list.append(args)
+            self.kwargs_list.append(kwargs)
diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py
index 03e3bf23115..59e42729555 100644
--- a/neural_compressor/torch/quantization/algorithm_entry.py
+++ b/neural_compressor/torch/quantization/algorithm_entry.py
@@ -449,7 +449,7 @@ def autoround_quantize_entry(
     enable_full_range = quant_config.enable_full_range
     batch_size = quant_config.batch_size
     lr_scheduler = quant_config.lr_scheduler
-    use_quant_input = quant_config.use_quant_input
+    enable_quanted_input = quant_config.enable_quanted_input
     enable_minmax_tuning = quant_config.enable_minmax_tuning
     lr = quant_config.lr
     minmax_lr = quant_config.minmax_lr
@@ -474,7 +474,7 @@
         enable_full_range=enable_full_range,
         batch_size=batch_size,
         lr_scheduler=lr_scheduler,
-        use_quant_input=use_quant_input,
+        enable_quanted_input=enable_quanted_input,
         enable_minmax_tuning=enable_minmax_tuning,
         lr=lr,
         minmax_lr=minmax_lr,
diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py
index 160b17e8f96..430e3a07983 100644
--- a/neural_compressor/torch/quantization/config.py
+++ b/neural_compressor/torch/quantization/config.py
@@ -666,7 +666,7 @@ def __init__(
         enable_full_range: bool = False,
         batch_size: int = 8,
         lr_scheduler=None,
-        use_quant_input: bool = True,
+        enable_quanted_input: bool = True,
         enable_minmax_tuning: bool = True,
         lr: float = None,
         minmax_lr: float = None,
@@ -680,7 +680,7 @@ def __init__(
         gradient_accumulate_steps: int = 1,
         not_use_best_mse: bool = False,
         dynamic_max_gap: int = -1,
-        scale_dtype: str = "fp32",
+        scale_dtype: str = "fp16",
         white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
     ):
         """Init AUTOROUND weight-only quantization config.
@@ -693,7 +693,7 @@ def __init__(
            enable_full_range (bool): Whether to enable full range quantization (default is False).
            batch_size (int): Batch size for training (default is 8).
            lr_scheduler: The learning rate scheduler to be used.
-           use_quant_input (bool): Whether to use quantized input data (default is True).
+           enable_quanted_input (bool): Whether to use quantized input data (default is True).
            enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True).
            lr (float): The learning rate (default is 0.005).
            minmax_lr (float): The learning rate for min-max tuning (default is None).
@@ -707,7 +707,7 @@ def __init__(
            gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
            not_use_best_mse (bool): Whether to use mean squared error (default is False).
            dynamic_max_gap (int): The dynamic maximum gap (default is -1).
-           scale_dtype (str): The data type of quantization scale to be used (default is "float32"), different kernels
+           scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels
                have different choices.
        """
        super().__init__(white_list=white_list)
@@ -718,7 +718,7 @@ def __init__(
        self.enable_full_range = enable_full_range
        self.batch_size = batch_size
        self.lr_scheduler = lr_scheduler
-       self.use_quant_input = use_quant_input
+       self.enable_quanted_input = enable_quanted_input
        self.enable_minmax_tuning = enable_minmax_tuning
        self.lr = lr
        self.minmax_lr = minmax_lr
diff --git a/neural_compressor/torch/utils/environ.py b/neural_compressor/torch/utils/environ.py
index 7e38e9484c6..611ab5fda15 100644
--- a/neural_compressor/torch/utils/environ.py
+++ b/neural_compressor/torch/utils/environ.py
@@ -31,6 +31,20 @@ def is_hpex_available():
     return _hpex_available
 
 
+def is_ipex_imported() -> bool:
+    for name, _ in sys.modules.items():
+        if name == "intel_extension_for_pytorch":
+            return True
+    return False
+
+
+def is_transformers_imported() -> bool:
+    for name, _ in sys.modules.items():
+        if name == "transformers":
+            return True
+    return False
+
+
 try:
     import intel_extension_for_pytorch as ipex
 
@@ -67,20 +81,6 @@ def get_torch_version():
     return version
 
 
-def is_ipex_imported() -> bool:
-    for name, _ in sys.modules.items():
-        if name == "intel_extension_for_pytorch":
-            return True
-    return False
-
-
-def is_transformers_imported() -> bool:
-    for name, _ in sys.modules.items():
-        if name == "transformers":
-            return True
-    return False
-
-
 def get_accelerator(device_name="auto"):
     global accelerator  # update the global accelerator when calling this func
     from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
diff --git a/test/3x/torch/quantization/weight_only/test_autoround.py b/test/3x/torch/quantization/weight_only/test_autoround.py
index 6ff08696a01..b4ca66ad00b 100644
--- a/test/3x/torch/quantization/weight_only/test_autoround.py
+++ b/test/3x/torch/quantization/weight_only/test_autoround.py
@@ -6,7 +6,6 @@
 import transformers
 from packaging.version import Version
 
-from neural_compressor.torch.algorithms.weight_only.autoround import AutoRoundQuantizer, get_autoround_default_run_fn
 from neural_compressor.torch.quantization import (
     AutoRoundConfig,
     convert,
@@ -16,18 +15,29 @@
 )
 from neural_compressor.torch.utils import logger
 
+torch.backends.__allow_nonbracketed_mutation_flag = True
+from neural_compressor.torch.algorithms.weight_only.autoround import get_dataloader
+
 try:
     import auto_round
+    from auto_round.export.export_to_itrex.model_wrapper import WeightOnlyLinear
 
-    AUTO_ROUND_VERSION_0_11 = Version("0.11")
-
-    auto_round_version = auto_round.__version__.split("+")[0]
-    auto_round_version = Version(auto_round_version)
     auto_round_installed = True
 except ImportError:
     auto_round_installed = False
 
 
+@torch.no_grad()
+def run_fn(model, dataloader):
+    for data in dataloader:
+        if isinstance(data, tuple) or isinstance(data, list):
+            model(*data)
+        elif isinstance(data, dict):
+            model(**data)
+        else:
+            model(data)
+
+
 @pytest.mark.skipif(not auto_round_installed, reason="auto_round module is not installed")
 class TestAutoRound:
     def setup_class(self):
@@ -36,9 +46,10 @@ def setup_class(self):
             torchscript=True,
         )
         self.inp = torch.ones([1, 10], dtype=torch.long)
-        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+        tokenizer = transformers.AutoTokenizer.from_pretrained(
             "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True
         )
+        self.dataloader = get_dataloader(tokenizer, 32, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, n_samples=10)
         self.label = self.gptj(self.inp)[0]
 
     def teardown_class(self):
@@ -47,98 +58,56 @@ def teardown_class(self):
     def setup_method(self, method):
         logger.info(f"Running TestAutoRound test: {method.__name__}")
 
-    def test_autoround(self):
-        gpt_j_model = copy.deepcopy(self.gptj)
-        quant_config = AutoRoundConfig(n_samples=20, seqlen=10, iters=10, scale_dtype="fp32")
+    @pytest.mark.parametrize("quant_lm_head", [True, False])
+    def test_autoround(self, quant_lm_head):
+        fp32_model = copy.deepcopy(self.gptj)
+        quant_config = AutoRoundConfig(n_samples=32, seqlen=10, iters=10, scale_dtype="fp32")
+        if quant_lm_head is False:
+            quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))
         logger.info(f"Test AutoRound with config {quant_config}")
-        run_fn = get_autoround_default_run_fn
-        run_args = (
-            self.tokenizer,
-            "NeelNanda/pile-10k",
-            20,
-            10,
-        )
-        fp32_model = gpt_j_model
-
         # prepare + convert API
         model = prepare(model=fp32_model, quant_config=quant_config)
-        run_fn(model, *run_args)
+
+        run_fn(model, self.dataloader)
         q_model = convert(model)
         out = q_model(self.inp)[0]
         assert torch.allclose(out, self.label, atol=1e-1)
         assert "transformer.h.0.attn.k_proj" in q_model.autoround_config.keys()
         assert "scale" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys()
         assert torch.float32 == q_model.autoround_config["transformer.h.0.attn.k_proj"]["scale_dtype"]
-
-    def test_quantizer(self):
-        gpt_j_model = copy.deepcopy(self.gptj)
-
-        run_fn = get_autoround_default_run_fn
-        run_args = (
-            self.tokenizer,
-            "NeelNanda/pile-10k",
-            20,
-            10,
-        )
-        weight_config = {
-            "*": {
-                "data_type": "int",
-                "bits": 4,
-                "group_size": 32,
-                "sym": False,
-            }
-        }
-        quantizer = AutoRoundQuantizer(quant_config=weight_config)
-        fp32_model = gpt_j_model
-
-        # quantizer execute
-        model = quantizer.prepare(model=fp32_model)
-        run_fn(model, *run_args)
-        q_model = quantizer.convert(model)
-
-        out = q_model(self.inp)[0]
-        assert torch.allclose(self.label, out, atol=1e-1)
-        assert "transformer.h.0.attn.k_proj" in q_model.autoround_config.keys()
-        assert "scale" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys()
-        assert torch.float32 == q_model.autoround_config["transformer.h.0.attn.k_proj"]["scale_dtype"]
+        assert isinstance(q_model.transformer.h[0].attn.k_proj, WeightOnlyLinear), "packing model failed."
+        if quant_lm_head is True:
+            assert isinstance(q_model.lm_head, WeightOnlyLinear), "quantization for lm_head failed."
 
     def test_autoround_with_quantize_API(self):
         gpt_j_model = copy.deepcopy(self.gptj)
-        quant_config = get_default_AutoRound_config()
+        quant_config = AutoRoundConfig(n_samples=32, seqlen=10, iters=10, scale_dtype="fp32")
+        quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))
+        logger.info(f"Test AutoRound with config {quant_config}")
 
         # quantize API
         q_model = quantize(
             model=gpt_j_model,
             quant_config=quant_config,
-            run_fn=get_autoround_default_run_fn,
-            run_args=(
-                self.tokenizer,
-                "NeelNanda/pile-10k",
-                20,
-                10,
-            ),
+            run_fn=run_fn,
+            run_args=(self.dataloader,),
         )
         out = q_model(self.inp)[0]
         assert torch.allclose(out, self.label, atol=1e-1)
+        assert isinstance(q_model.transformer.h[0].attn.k_proj, WeightOnlyLinear), "packing model failed."
 
     def test_save_and_load(self):
         fp32_model = copy.deepcopy(self.gptj)
-        quant_config = get_default_AutoRound_config()
+        quant_config = AutoRoundConfig(n_samples=32, seqlen=10, iters=10, scale_dtype="fp32")
+        # quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))
         logger.info(f"Test AutoRound with config {quant_config}")
-        run_fn = get_autoround_default_run_fn
-        run_args = (
-            self.tokenizer,
-            "NeelNanda/pile-10k",
-            20,
-            10,
-        )
 
         # quantizer execute
         model = prepare(model=fp32_model, quant_config=quant_config)
-        run_fn(model, *run_args)
+        run_fn(model, self.dataloader)
         q_model = convert(model)
         assert q_model is not None, "Quantization failed!"
@@ -151,8 +120,10 @@ def test_save_and_load(self):
         loaded_model = load("saved_results")
         loaded_out = loaded_model(self.inp)[0]
         assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check."
+        assert isinstance(
+            loaded_model.transformer.h[0].attn.k_proj, WeightOnlyLinear
+        ), "loading compressed model failed."
 
-    @pytest.mark.skipif(auto_round_version <= AUTO_ROUND_VERSION_0_11, reason="Requires auto_round>=0.11")
     def test_conv1d(self):
         input = torch.randn(1, 32)
         from transformers import GPT2Model, GPT2Tokenizer
@@ -162,26 +133,10 @@ def test_conv1d(self):
         text = "Replace me by any text you'd like."
         encoded_input = tokenizer(text, return_tensors="pt")
         out1 = model(**encoded_input)[0]
-        run_fn = get_autoround_default_run_fn
-        run_args = (
-            tokenizer,
-            "NeelNanda/pile-10k",
-            20,
-            10,
-        )
-        weight_config = {
-            "*": {
-                "data_type": "int",
-                "bits": 4,
-                "group_size": 32,
-                "sym": False,
-            }
-        }
-        quantizer = AutoRoundQuantizer(quant_config=weight_config)
-
-        # quantizer execute
-        model = quantizer.prepare(model=model)
-        run_fn(model, *run_args)
-        q_model = quantizer.convert(model)
+        quant_config = AutoRoundConfig(n_samples=32, seqlen=10, iters=10, scale_dtype="fp32")
+        model = prepare(model=model, quant_config=quant_config)
+        run_fn(model, self.dataloader)
+        q_model = convert(model)
         out2 = q_model(**encoded_input)[0]
         assert torch.allclose(out2, out1, atol=0.01), "Accuracy gap atol > 0.01 is unexpected."
+        assert isinstance(q_model.h[0].attn.c_attn, WeightOnlyLinear), "packing model failed."
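
Reviewer note: below is a minimal usage sketch of the prepare/convert flow this patch introduces, assembled from the updated test_autoround.py above. The tiny GPT-J checkpoint, the calibration settings, and the run_fn helper are the ones used in the tests; the snippet is illustrative only and is not part of the patch.

import torch
import transformers

from neural_compressor.torch.algorithms.weight_only.autoround import get_dataloader
from neural_compressor.torch.quantization import AutoRoundConfig, convert, prepare

# Tiny test checkpoint used by the unit tests above; swap in any causal LM.
model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
fp32_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

# Calibration data now comes from auto_round's dataset helper rather than the removed
# get_autoround_default_run_fn.
dataloader = get_dataloader(tokenizer, 32, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, n_samples=10)


@torch.no_grad()
def run_fn(model, dataloader):
    # prepare() wraps the model in InputCaptureModule, which records these calibration
    # inputs so that convert() can replay them through AutoRound and pack the result.
    for data in dataloader:
        if isinstance(data, (tuple, list)):
            model(*data)
        elif isinstance(data, dict):
            model(**data)
        else:
            model(data)


quant_config = AutoRoundConfig(n_samples=32, seqlen=10, iters=10, scale_dtype="fp32")
model = prepare(model=fp32_model, quant_config=quant_config)
run_fn(model, dataloader)
q_model = convert(model)  # quantized linears are packed into WeightOnlyLinear modules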