diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py
index 4a206d37486..0515aa0fdfa 100644
--- a/neural_compressor/common/base_config.py
+++ b/neural_compressor/common/base_config.py
@@ -457,9 +457,20 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__} {self.to_json_string()}"
 
     def to_config_mapping(
-        self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
+        self, config_list: List[BaseConfig] = None, model_info: Dict[str, Any] = None
     ) -> OrderedDict[str, BaseConfig]:
-        return super().to_config_mapping(self.config_list, model_info)
+        config_mapping = OrderedDict()
+        for config in self.config_list:
+            global_config = config.global_config
+            op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config()
+            single_config_model_info = model_info.get(config.name, None)
+            for op_name, op_type in single_config_model_info:
+                if op_type in op_type_config_dict:
+                    config_mapping[(op_name, op_type)] = op_type_config_dict[op_type]
+                for op_name_pattern in op_name_config_dict:
+                    if re.match(op_name_pattern, op_name):
+                        config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern]
+        return config_mapping
 
     @classmethod
     def register_supported_configs(cls):
@@ -471,6 +482,12 @@ def get_config_set_for_tuning(cls) -> None:
         # TODO (Yi) handle the composable config in `tuning_config`
         return None
 
+    def get_model_info(self, model, *args, **kwargs):
+        model_info_dict = dict()
+        for config in self.config_list:
+            model_info_dict.update({config.name: config.get_model_info(model, *args, **kwargs)})
+        return model_info_dict
+
 
 def get_all_config_set_from_config_registry(fwk_name: str) -> Union[BaseConfig, List[BaseConfig]]:
     all_registered_config_cls: List[BaseConfig] = config_registry.get_all_config_cls_by_fwk_name(fwk_name)
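For context on the new `ComposableConfig.to_config_mapping` above: each child config contributes its own op list, keyed by algo name in `model_info`; an op-type match is applied first, then an op-name regex match overrides it. A minimal standalone sketch of that matching order — all op names and config values below are made up for illustration and are not library objects:

```python
from collections import OrderedDict
import re

# model_info as the new ComposableConfig.get_model_info would shape it:
# one op list per child config, keyed by that config's algo name.
model_info = {
    "rtn": [("model.linear1", "Linear"), ("model.linear2", "Linear")],
    "gptq": [("model.linear2", "Linear")],
}

# Hypothetical per-type and per-name settings for one child config.
op_type_config_dict = {"Linear": "rtn-linear-cfg"}
op_name_config_dict = {r"model\.linear2": "gptq-cfg"}

config_mapping = OrderedDict()
for op_name, op_type in model_info["rtn"]:
    # Type-level match first...
    if op_type in op_type_config_dict:
        config_mapping[(op_name, op_type)] = op_type_config_dict[op_type]
    # ...then a name-pattern match overrides it.
    for pattern in op_name_config_dict:
        if re.match(pattern, op_name):
            config_mapping[(op_name, op_type)] = op_name_config_dict[pattern]

print(config_mapping)
# OrderedDict([(('model.linear1', 'Linear'), 'rtn-linear-cfg'),
#              (('model.linear2', 'Linear'), 'gptq-cfg')])
```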
diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py
index a963d4e0567..14598278452 100644
--- a/neural_compressor/torch/quantization/algorithm_entry.py
+++ b/neural_compressor/torch/quantization/algorithm_entry.py
@@ -41,6 +41,8 @@ def rtn_entry(
     # rebuild weight_config for rtn_quantize function
     weight_config = {}
     for (op_name, op_type), quant_config in configs_mapping.items():
+        if quant_config.name != RTN:
+            continue
         weight_config[op_name] = {
             "dtype": quant_config.dtype,
             "bits": quant_config.bits,
@@ -74,6 +76,8 @@ def gptq_entry(
     # rebuild weight_config for gptq_quantize function
     weight_config = {}
     for (op_name, op_type), quant_config in configs_mapping.items():
+        if quant_config.name != GPTQ:
+            continue
         weight_config[op_name] = {
             "dtype": quant_config.dtype,
             "bits": quant_config.bits,
@@ -120,6 +124,8 @@ def static_quant_entry(
     cfgs = deepcopy(configs_mapping)
     quant_config_mapping["op"] = cfgs
     for (op_name, op_type), cfg in cfgs.items():
+        if cfg.name != STATIC_QUANT:
+            continue
         quant_config_mapping["op"][(op_name, op_type)] = {
             "weight": {
                 "dtype": cfg.w_dtype,
@@ -161,6 +167,8 @@ def awq_quantize_entry(
 
     weight_config = {}
     for (op_name, op_type), op_config in configs_mapping.items():
+        if op_config.name != AWQ:
+            continue
         if op_config.dtype == "fp32":
             weight_config[op_name] = {
                 "bits": -1,
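The four `continue` guards above all follow one pattern: a composed config yields a single `configs_mapping` holding entries owned by different algorithms, and each entry function must consume only its own. A toy sketch of that filtering — the `QuantConfig` dataclass and every value here are hypothetical stand-ins for the real config objects:

```python
from dataclasses import dataclass

@dataclass
class QuantConfig:
    name: str
    bits: int

# One mapping, entries owned by two different algorithms.
configs_mapping = {
    ("lm_head", "Linear"): QuantConfig(name="rtn", bits=4),
    ("transformer.h.0.attn.out_proj", "Linear"): QuantConfig(name="gptq", bits=4),
}

def build_rtn_weight_config(configs_mapping):
    weight_config = {}
    for (op_name, op_type), quant_config in configs_mapping.items():
        if quant_config.name != "rtn":  # skip ops owned by other algorithms
            continue
        weight_config[op_name] = {"bits": quant_config.bits}
    return weight_config

# Only the RTN-owned op survives the filter.
assert list(build_rtn_weight_config(configs_mapping)) == ["lm_head"]
```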
diff --git a/test/3x/common/test_common.py b/test/3x/common/test_common.py
index 177ec9a361e..338570912b2 100644
--- a/test/3x/common/test_common.py
+++ b/test/3x/common/test_common.py
@@ -40,13 +40,14 @@
 
 logger = Logger().get_logger()
 
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 from neural_compressor.common.base_config import (
     BaseConfig,
-    ComposableConfig,
+    config_registry,
     get_all_config_set_from_config_registry,
     register_config,
+    register_supported_configs_for_fwk,
 )
 from neural_compressor.common.base_tuning import ConfigLoader, ConfigSet, SequentialSampler
 from neural_compressor.common.tuning_param import TuningParam
@@ -54,10 +55,25 @@
 
 PRIORITY_FAKE_ALGO = 100
 FAKE_CONFIG_NAME = "fake"
+PRIORITY_FAKE_ALGO_1 = 90
+FAKE_CONFIG_NAME_1 = "fake_one"
 DEFAULT_WEIGHT_BITS = [4, 6]
 
 FAKE_FRAMEWORK_NAME = "FAKE_FWK"
 
+FAKE_MODEL_INFO = [("OP1_NAME", "OP_TYPE1"), ("OP2_NAME", "OP_TYPE1"), ("OP3_NAME", "OP_TYPE2")]
+
+
+class FakeModel:
+    def __init__(self) -> None:
+        self.name = "fake_model"
+
+    def __call__(self, x) -> Any:
+        return x
+
+    def __repr__(self) -> str:
+        return "FakeModel"
+
 
 @register_config(framework_name=FAKE_FRAMEWORK_NAME, algo_name=FAKE_CONFIG_NAME, priority=PRIORITY_FAKE_ALGO)
 class FakeAlgoConfig(BaseConfig):
@@ -102,17 +118,14 @@ def register_supported_configs(cls) -> List:
         pass
 
     @staticmethod
-    def get_model_info(model: Any) -> List[Tuple[str, Callable]]:
-        pass
+    def get_model_info(model: Any) -> List[Tuple[str, Any]]:
+        return FAKE_MODEL_INFO
 
     @classmethod
     def get_config_set_for_tuning(cls) -> Union[None, "FakeAlgoConfig", List["FakeAlgoConfig"]]:
         return FakeAlgoConfig(weight_bits=DEFAULT_WEIGHT_BITS)
 
 
-FakeAlgoConfig.register_supported_configs()
-
-
 def get_default_fake_config() -> FakeAlgoConfig:
     """Generate the default fake config.
 
@@ -122,10 +135,64 @@ def get_default_fake_config() -> FakeAlgoConfig:
     return FakeAlgoConfig()
 
 
+@register_config(framework_name=FAKE_FRAMEWORK_NAME, algo_name=FAKE_CONFIG_NAME_1, priority=PRIORITY_FAKE_ALGO_1)
+class FakeAlgoOneConfig(BaseConfig):
+    """Config class for fake algo."""
+
+    supported_configs: List = []
+    params_list = [
+        "weight_dtype",
+        "weight_bits",
+        TuningParam("target_op_type_list", tunable_type=List[List[str]]),
+    ]
+    name = FAKE_CONFIG_NAME_1
+
+    def __init__(
+        self,
+        weight_dtype: str = "int",
+        weight_bits: int = 4,
+        target_op_type_list: List[str] = ["Conv", "Gemm"],
+        white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
+    ):
+        """Init fake config.
+
+        Args:
+            weight_dtype (str): Data type for weights, default is "int".
+            weight_bits (int): Number of bits used to represent weights, default is 4.
+        """
+        super().__init__(white_list=white_list)
+        self.weight_bits = weight_bits
+        self.weight_dtype = weight_dtype
+        self.target_op_type_list = target_op_type_list
+        self._post_init()
+
+    def to_dict(self):
+        return super().to_dict()
+
+    @classmethod
+    def from_dict(cls, config_dict):
+        return super(FakeAlgoOneConfig, cls).from_dict(config_dict=config_dict)
+
+    @classmethod
+    def register_supported_configs(cls) -> List:
+        pass
+
+    @staticmethod
+    def get_model_info(model: Any) -> List[Tuple[str, Any]]:
+        return FAKE_MODEL_INFO
+
+    @classmethod
+    def get_config_set_for_tuning(cls) -> Union[None, "FakeAlgoOneConfig", List["FakeAlgoOneConfig"]]:
+        return FakeAlgoOneConfig(weight_bits=DEFAULT_WEIGHT_BITS)
+
+
 def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
     return get_all_config_set_from_config_registry(fwk_name=FAKE_FRAMEWORK_NAME)
 
 
+register_supported_configs_for_fwk(fwk_name=FAKE_FRAMEWORK_NAME)
+
+
 class TestBaseConfig(unittest.TestCase):
     @classmethod
     def setUpClass(self):
@@ -143,7 +210,7 @@ def test_api(self):
         fake_default_config = get_default_fake_config()
         self.assertEqual(fake_default_config.weight_dtype, "int")
         config_set = get_all_config_set()
-        self.assertEqual(len(config_set), 1)
+        self.assertEqual(len(config_set), len(config_registry.get_all_config_cls_by_fwk_name(FAKE_FRAMEWORK_NAME)))
         self.assertEqual(config_set[0].weight_bits, DEFAULT_WEIGHT_BITS)
 
     def test_config_expand_complex_tunable_type(self):
@@ -154,6 +221,18 @@ def test_config_expand_complex_tunable_type(self):
         for i in range(len(configs_list)):
             self.assertEqual(configs_list[i].target_op_type_list, target_op_type_list_options[i])
 
+    def test_mixed_two_algos(self):
+        model = FakeModel()
+        OP1_NAME = "OP1_NAME"
+        OP2_NAME = "OP2_NAME"
+        fake_config = FakeAlgoConfig(weight_bits=4, white_list=[OP1_NAME])
+        fake1_config = FakeAlgoOneConfig(weight_bits=2, white_list=[OP2_NAME])
+        mixed_config = fake_config + fake1_config
+        model_info = mixed_config.get_model_info(model)
+        config_mapping = mixed_config.to_config_mapping(model_info=model_info)
+        self.assertIn(OP1_NAME, [op_info[0] for op_info in config_mapping])
+        self.assertIn(OP2_NAME, [op_info[0] for op_info in config_mapping])
+
 
 class TestConfigSet(unittest.TestCase):
     def setUp(self):
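Spelling out what `test_mixed_two_algos` asserts: both fake configs report the same `FAKE_MODEL_INFO`, but each `white_list` restricts which op names that config claims, so the composed mapping should cover both names. A data-only restatement (the `expected_*` values are my reading of the test, not captured output):

```python
FAKE_MODEL_INFO = [("OP1_NAME", "OP_TYPE1"), ("OP2_NAME", "OP_TYPE1"), ("OP3_NAME", "OP_TYPE2")]

# ComposableConfig.get_model_info keys each child's op list by its algo name.
expected_model_info = {"fake": FAKE_MODEL_INFO, "fake_one": FAKE_MODEL_INFO}

# white_list=["OP1_NAME"] on "fake" plus white_list=["OP2_NAME"] on "fake_one"
# means each algo claims one op, so the mapping spans both names.
expected_op_names = {"OP1_NAME", "OP2_NAME"}
```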
diff --git a/test/3x/torch/quantization/weight_only/test_mixed_algos.py b/test/3x/torch/quantization/weight_only/test_mixed_algos.py
new file mode 100644
index 00000000000..bc5ae94add3
--- /dev/null
+++ b/test/3x/torch/quantization/weight_only/test_mixed_algos.py
@@ -0,0 +1,40 @@
+import copy
+from unittest.mock import patch
+
+import pytest
+import torch
+import transformers
+
+from neural_compressor.common.utils import logger
+from neural_compressor.torch.quantization import GPTQConfig, RTNConfig, quantize
+
+
+def run_fn(model):
+    # GPTQ raises a ValueError to cut off computation once the input data of the
+    # first block has been collected. This is specific to the UTs; there is no
+    # need to add this wrapper in real examples.
+    with pytest.raises(ValueError):
+        model(torch.tensor([[10, 20, 30]], dtype=torch.long))
+        model(torch.tensor([[40, 50, 60]], dtype=torch.long))
+
+
+class TestMixedTwoAlgo:
+    def test_mixed_gptq_and_rtn(self):
+        with patch.object(logger, "info") as mock_info:
+            rtn_config = RTNConfig(white_list=["lm_head"])
+            gptq_config = GPTQConfig(double_quant_bits=4, white_list=["transformer.*"])
+            combined_config = rtn_config + gptq_config
+            logger.info(combined_config)
+
+            self.tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(
+                "hf-internal-testing/tiny-random-GPTJForCausalLM",
+            )
+            self.example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long)
+            # record label for comparison
+            out_original_model = self.tiny_gptj(self.example_inputs)[0]
+            model = copy.deepcopy(self.tiny_gptj)
+            q_model = quantize(model, combined_config, run_fn=run_fn)
+            out_q_model = q_model(self.example_inputs)[0]
+            rtn_log = "Start to apply rtn on the model."
+            gptq_log = "Start to apply gptq on the model."
+            assert rtn_log in [_call[0][0] for _call in mock_info.call_args_list]
+            assert gptq_log in [_call[0][0] for _call in mock_info.call_args_list]
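The user-facing pattern this test validates, shown without the test harness — the model variable and calibration function are placeholders, and whether calibration is required depends on the algorithms being combined:

```python
from neural_compressor.torch.quantization import GPTQConfig, RTNConfig, quantize

rtn_config = RTNConfig(white_list=["lm_head"])          # RTN for the head only
gptq_config = GPTQConfig(white_list=["transformer.*"])  # GPTQ for the rest
combined_config = rtn_config + gptq_config              # composes both algos

# q_model = quantize(model, combined_config, run_fn=calibration_fn)
```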