diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py
index 4652dd170fa..dd0236f39c3 100644
--- a/neural_compressor/common/base_config.py
+++ b/neural_compressor/common/base_config.py
@@ -21,9 +21,8 @@
 import re
 from abc import ABC, abstractmethod
 from collections import OrderedDict
-from copy import deepcopy
 from itertools import product
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
 from neural_compressor.common import Logger
 from neural_compressor.common.utils import (
@@ -44,13 +43,12 @@
     "register_config",
     "BaseConfig",
     "ComposableConfig",
-    "Options",
+    "get_all_config_set_from_config_registry",
     "options",
 ]
 
 
-# Dictionary to store registered configurations
-
+# Config registry to store all registered configs.
 class ConfigRegistry:
     registered_configs = {}
 
@@ -104,6 +102,13 @@ def get_cls_configs(cls) -> Dict[str, Dict[str, object]]:
                 cls_configs[framework_name][algo_name] = config_data["cls"]
         return cls_configs
 
+    @classmethod
+    def get_all_config_cls_by_fwk_name(cls, fwk_name: str) -> List[Type[BaseConfig]]:
+        configs_cls = []
+        for algo_name, config_pairs in cls.registered_configs.get(fwk_name, {}).items():
+            configs_cls.append(config_pairs["cls"])
+        return configs_cls
+
 
 config_registry = ConfigRegistry()
 
@@ -373,6 +378,11 @@ def _is_op_type(name: str) -> bool:
         # TODO (Yi), ort and tf need override it
         return not isinstance(name, str)
 
+    @classmethod
+    @abstractmethod
+    def get_config_set_for_tuning(cls):
+        raise NotImplementedError
+
 
 class ComposableConfig(BaseConfig):
     name = COMPOSABLE_CONFIG
@@ -420,6 +430,24 @@ def register_supported_configs(cls):
         """Add all supported configs."""
         raise NotImplementedError
 
+    @classmethod
+    def get_config_set_for_tuning(cls) -> None:
+        # TODO (Yi) handle the composable config in `tuning_config`
+        return None
+
+
+def get_all_config_set_from_config_registry(fwk_name: str) -> Union[BaseConfig, List[BaseConfig]]:
+    all_registered_config_cls: List[BaseConfig] = config_registry.get_all_config_cls_by_fwk_name(fwk_name)
+    config_set = []
+    for config_cls in all_registered_config_cls:
+        config_set.append(config_cls.get_config_set_for_tuning())
+    return config_set
+
+
+#######################################################
+#### Options
+#######################################################
+
 
 def _check_value(name, src, supported_type, supported_value=[]):
     """Check if the given object is the given supported type and in the given supported value.
diff --git a/neural_compressor/common/base_tuning.py b/neural_compressor/common/base_tuning.py
index eaccc217f10..6145c49379b 100644
--- a/neural_compressor/common/base_tuning.py
+++ b/neural_compressor/common/base_tuning.py
@@ -129,8 +129,8 @@ class Sampler:
 
 
 class ConfigLoader:
-    def __init__(self, quant_configs, sampler: Sampler) -> None:
-        self.quant_configs = quant_configs
+    def __init__(self, config_set, sampler: Sampler) -> None:
+        self.config_set = config_set
         self.sampler = sampler
 
     @staticmethod
@@ -146,7 +146,7 @@ def parse_quant_config(quant_config: BaseConfig) -> List[BaseConfig]:
     def parse_quant_configs(self) -> List[BaseConfig]:
         # TODO (Yi) separate this functionality into `Sampler` in the next PR
         quant_config_list = []
-        for quant_config in self.quant_configs:
+        for quant_config in self.config_set:
             quant_config_list.extend(ConfigLoader.parse_quant_config(quant_config))
         return quant_config_list
 
@@ -210,14 +210,14 @@ class TuningConfig:
     """Base Class for Tuning Criterion.
 
     Args:
-        quant_configs: quantization configs. Default value is empty.
+        config_set: quantization configs. Default value is empty.
         timeout: Tuning timeout (seconds). Default value is 0 which means early stop.
         max_trials: Max tuning times. Default value is 100. Combine with timeout field to decide when to exit.
     """
 
-    def __init__(self, quant_configs=None, timeout=0, max_trials=100, sampler: Sampler = None) -> None:
+    def __init__(self, config_set=None, timeout=0, max_trials=100, sampler: Sampler = None) -> None:
         """Init a TuneCriterion object."""
-        self.quant_configs = quant_configs
+        self.config_set = config_set
         self.timeout = timeout
         self.max_trials = max_trials
         self.sampler = sampler
@@ -265,7 +265,7 @@ def need_stop(self) -> bool:
 
 
 def init_tuning(tuning_config: TuningConfig) -> Tuple[ConfigLoader, TuningLogger, TuningMonitor]:
-    config_loader = ConfigLoader(quant_configs=tuning_config.quant_configs, sampler=tuning_config.sampler)
+    config_loader = ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler)
     tuning_logger = TuningLogger()
     tuning_monitor = TuningMonitor(tuning_config)
     return config_loader, tuning_logger, tuning_monitor
diff --git a/neural_compressor/onnxrt/quantization/config.py b/neural_compressor/onnxrt/quantization/config.py
index a827431d5a9..12e97865149 100644
--- a/neural_compressor/onnxrt/quantization/config.py
+++ b/neural_compressor/onnxrt/quantization/config.py
@@ -144,6 +144,11 @@ def get_model_info(model: Union[onnx.ModelProto, Path, str]) -> List[Tuple[str,
         logger.debug(f"Get model info: {filter_result}")
         return filter_result
 
+    @classmethod
+    def get_config_set_for_tuning(cls) -> Union[None, "RTNConfig", List["RTNConfig"]]:  # pragma: no cover
+        # TODO fwk owner needs to update it.
+        return RTNConfig(weight_bits=[4, 6])
+
 
 # TODO(Yi) run `register_supported_configs` for all registered config.
 RTNConfig.register_supported_configs()
diff --git a/neural_compressor/tensorflow/quantization/config.py b/neural_compressor/tensorflow/quantization/config.py
index e3704b5f57d..781c3ef19a8 100644
--- a/neural_compressor/tensorflow/quantization/config.py
+++ b/neural_compressor/tensorflow/quantization/config.py
@@ -103,6 +103,13 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
         supported_configs.append(OperatorConfig(config=static_quant_config, operators=operators))
         cls.supported_configs = supported_configs
 
+    @classmethod
+    def get_config_set_for_tuning(
+        cls,
+    ) -> Union[None, "StaticQuantConfig", List["StaticQuantConfig"]]:  # pragma: no cover
+        # TODO fwk owner needs to update it.
+        return StaticQuantConfig(weight_sym=[True, False])
+
 
 # TODO(Yi) run `register_supported_configs` for all registered config.
 StaticQuantConfig.register_supported_configs()
diff --git a/neural_compressor/torch/__init__.py b/neural_compressor/torch/__init__.py
index 5fe8d73cc8c..e60a2a2c2ec 100644
--- a/neural_compressor/torch/__init__.py
+++ b/neural_compressor/torch/__init__.py
@@ -28,4 +28,4 @@
 )
 
 from neural_compressor.common.base_tuning import TuningConfig
-from neural_compressor.torch.quantization.autotune import autotune, get_default_tune_config
+from neural_compressor.torch.quantization.autotune import autotune, get_all_config_set
diff --git a/neural_compressor/torch/quantization/autotune.py b/neural_compressor/torch/quantization/autotune.py
index d2c38357b93..bb48f0685c6 100644
--- a/neural_compressor/torch/quantization/autotune.py
+++ b/neural_compressor/torch/quantization/autotune.py
@@ -12,28 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from copy import deepcopy
 from typing import Dict, List, Optional, Union
 
 import torch
 
 from neural_compressor.common import Logger
-from neural_compressor.common.base_config import BaseConfig
+from neural_compressor.common.base_config import BaseConfig, get_all_config_set_from_config_registry
 from neural_compressor.common.base_tuning import TuningConfig, evaluator, init_tuning
 from neural_compressor.torch import quantize
-from neural_compressor.torch.quantization.config import GPTQConfig, RTNConfig
+from neural_compressor.torch.quantization.config import FRAMEWORK_NAME
 
 logger = Logger().get_logger()
 
 
 __all__ = [
-    "get_default_tune_config",
     "autotune",
+    "get_all_config_set",
 ]
 
 
-def get_default_tune_config() -> TuningConfig:
-    # TODO use the registered default tuning config in the next PR
-    return TuningConfig(quant_configs=[GPTQConfig(weight_bits=[4, 8]), RTNConfig(weight_bits=[4, 8])])
+def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
+    return get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME)
 
 
 def autotune(
@@ -52,7 +52,9 @@ def autotune(
     for trial_index, quant_config in enumerate(config_loader):
         tuning_logger.trial_start(trial_index=trial_index)
         tuning_logger.quantization_start()
-        q_model = quantize(model, quant_config=quant_config, run_fn=run_fn, run_args=run_args)
+        logger.info(f"quant config: {quant_config}")
+        # !!! Make sure to use deepcopy only when inplace is set to `True`.
+        q_model = quantize(deepcopy(model), quant_config=quant_config, run_fn=run_fn, run_args=run_args, inplace=True)
         tuning_logger.quantization_end()
         tuning_logger.evaluation_start()
         eval_result: float = evaluator.evaluate(q_model)
@@ -60,7 +62,8 @@ def autotune(
         tuning_monitor.add_trial_result(trial_index, eval_result, quant_config)
         if tuning_monitor.need_stop():
             best_quant_config: BaseConfig = tuning_monitor.get_best_quant_config()
-            quantize(model, quant_config=best_quant_config, run_fn=run_fn, run_args=run_args, inplace=True)
+            # !!! Make sure to use deepcopy only when inplace is set to `True`.
+            quantize(deepcopy(model), quant_config=best_quant_config, run_fn=run_fn, run_args=run_args, inplace=True)
             best_quant_model = model  # quantize model inplace
         tuning_logger.trial_end(trial_index)
     tuning_logger.tuning_end()
diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py
index 86c5b2d458e..e4ee3130587 100644
--- a/neural_compressor/torch/quantization/config.py
+++ b/neural_compressor/torch/quantization/config.py
@@ -36,6 +36,14 @@
 from neural_compressor.torch.utils.constants import PRIORITY_GPTQ, PRIORITY_RTN
 from neural_compressor.torch.utils.utility import is_hpex_avaliable, logger
 
+__all__ = [
+    "RTNConfig",
+    "get_default_rtn_config",
+    "GPTQConfig",
+    "get_default_gptq_config",
+]
+
+
 FRAMEWORK_NAME = "torch"
 DTYPE_RANGE = Union[torch.dtype, List[torch.dtype]]
 
@@ -153,6 +161,11 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
         logger.debug(f"Get model info: {filter_result}")
         return filter_result
 
+    @classmethod
+    def get_config_set_for_tuning(cls) -> Union[None, "RTNConfig", List["RTNConfig"]]:
+        # TODO fwk owner needs to update it.
+        return RTNConfig(weight_bits=[4, 6])
+
 
 # TODO(Yi) run `register_supported_configs` for all registered config.
 RTNConfig.register_supported_configs()
@@ -276,6 +289,11 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
         logger.debug(f"Get model info: {filter_result}")
         return filter_result
 
+    @classmethod
+    def get_config_set_for_tuning(cls) -> Union[None, "GPTQConfig", List["GPTQConfig"]]:
+        # TODO fwk owner needs to update it.
+        return GPTQConfig(weight_bits=[4, 6])
+
 
 # TODO(Yi) run `register_supported_configs` for all registered config.
 GPTQConfig.register_supported_configs()
@@ -352,6 +370,11 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
         logger.debug(f"Get model info: {filter_result}")
         return filter_result
 
+    @classmethod
+    def get_config_set_for_tuning(cls) -> Union[None, "StaticQuantConfig", List["StaticQuantConfig"]]:
+        # TODO fwk owner needs to update it.
+        return StaticQuantConfig(w_sym=[True, False])
+
 
 # TODO(Yi) run `register_supported_configs` for all registered config.
 StaticQuantConfig.register_supported_configs()
@@ -461,6 +484,11 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
         logger.debug(f"Get model info: {filter_result}")
         return filter_result
 
+    @classmethod
+    def get_config_set_for_tuning(cls) -> Union[None, "SmoothQuantConfig", List["SmoothQuantConfig"]]:
+        # TODO fwk owner needs to update it.
+        return SmoothQuantConfig(alpha=[0.1, 0.5])
+
 
 # TODO(Yi) run `register_supported_configs` for all registered config.
 SmoothQuantConfig.register_supported_configs()
@@ -541,6 +569,11 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
         logger.debug(f"Get model info: {filter_result}")
         return filter_result
 
+    @classmethod
+    def get_config_set_for_tuning(cls) -> Union[None, "FP8QConfig", List["FP8QConfig"]]:
+        # TODO fwk owner needs to update it.
+        return FP8QConfig(act_dtype=[torch.float8_e4m3fn])
+
 
 # TODO(Yi) run `register_supported_configs` for all registered config.
 FP8QConfig.register_supported_configs()
diff --git a/test/3x/onnxrt/test_config.py b/test/3x/onnxrt/test_config.py
index 1bb51c141c7..dfc8f00dea5 100644
--- a/test/3x/onnxrt/test_config.py
+++ b/test/3x/onnxrt/test_config.py
@@ -328,6 +328,14 @@ def test_expand_config(self):
         self.assertEqual(expand_config_list[0].weight_bits, 4)
         self.assertEqual(expand_config_list[1].weight_bits, 8)
 
+    def test_config_set_api(self):
+        # *Note: this test is only for improving the code coverage and can be removed once the test_common is enabled.
+        from neural_compressor.common.base_config import config_registry, get_all_config_set_from_config_registry
+        from neural_compressor.onnxrt.quantization.config import FRAMEWORK_NAME
+
+        config_set = get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME)
+        self.assertEqual(len(config_set), len(config_registry.registered_configs[FRAMEWORK_NAME]))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/3x/tensorflow/test_config.py b/test/3x/tensorflow/test_config.py
index fe9c7830356..6a7bd7afeab 100644
--- a/test/3x/tensorflow/test_config.py
+++ b/test/3x/tensorflow/test_config.py
@@ -315,6 +315,14 @@ def test_expand_config(self):
         self.assertEqual(expand_config_list[0].weight_granularity, "per_channel")
         self.assertEqual(expand_config_list[1].weight_granularity, "per_tensor")
 
+    def test_config_set_api(self):
+        # *Note: this test is only for improving the code coverage and can be removed once the test_common is enabled.
+        from neural_compressor.common.base_config import config_registry, get_all_config_set_from_config_registry
+        from neural_compressor.tensorflow.quantization.config import FRAMEWORK_NAME
+
+        config_set = get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME)
+        self.assertEqual(len(config_set), len(config_registry.registered_configs[FRAMEWORK_NAME]))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/3x/torch/test_autotune.py b/test/3x/torch/test_autotune.py
index cbdf587d2c4..e1b717e3163 100644
--- a/test/3x/torch/test_autotune.py
+++ b/test/3x/torch/test_autotune.py
@@ -40,6 +40,62 @@ def forward(self, x):
     return model
 
 
+def get_gpt_j():
+    import transformers
+
+    tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(
+        "hf-internal-testing/tiny-random-GPTJForCausalLM",
+        torchscript=True,
+    )
+    return tiny_gptj
+
+
+class GPTQLLMDataLoader:
+    def __init__(self, length=512):
+        self.batch_size = 1
+        self.length = length
+
+    def __iter__(self):
+        for i in range(10):
+            yield torch.ones([1, self.length], dtype=torch.long)
+
+
+class GPTQLLMDataLoaderList(GPTQLLMDataLoader):
+    def __iter__(self):
+        for i in range(10):
+            yield (torch.ones([1, self.length], dtype=torch.long), torch.ones([1, self.length], dtype=torch.long))
+
+
+class GPTQLLMDataLoaderDict(GPTQLLMDataLoader):
+    def __iter__(self):
+        for i in range(10):
+            yield {
+                "input_ids": torch.ones([1, self.length], dtype=torch.long),
+                "attention_mask": torch.ones([1, self.length], dtype=torch.long),
+            }
+
+
+from tqdm import tqdm
+
+from neural_compressor.torch.algorithms.weight_only.gptq import move_input_to_device
+
+
+def run_fn_for_gptq(model, dataloader_for_calibration, *args):
+    logger.info("Collecting calibration inputs...")
+    for batch in tqdm(dataloader_for_calibration):
+        batch = move_input_to_device(batch, device=None)
+        try:
+            if isinstance(batch, tuple) or isinstance(batch, list):
+                model(batch[0])
+            elif isinstance(batch, dict):
+                model(**batch)
+            else:
+                model(batch)
+        except ValueError:
+            pass
+    return
+
+
 class TestAutoTune(unittest.TestCase):
     @classmethod
     def setUpClass(self):
@@ -67,7 +123,7 @@ def test_autotune_api(self):
         def eval_acc_fn(model) -> float:
             return 1.0
 
-        custom_tune_config = TuningConfig(quant_configs=[RTNConfig(weight_bits=[4, 6])], max_trials=2)
+        custom_tune_config = TuningConfig(config_set=[RTNConfig(weight_bits=[4, 6])], max_trials=2)
         best_model = autotune(
             model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=[{"eval_fn": eval_acc_fn}]
         )
@@ -94,17 +150,55 @@ def eval_perf_fn(model) -> float:
             },
         ]
 
-        custom_tune_config = TuningConfig(quant_configs=[RTNConfig(weight_bits=[4, 6])], max_trials=2)
+        custom_tune_config = TuningConfig(config_set=[RTNConfig(weight_bits=[4, 6])], max_trials=2)
         best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=eval_fns)
         self.assertIsNotNone(best_model)
         self.assertEqual(len(evaluator.eval_fn_registry), 2)
 
+    @reset_tuning_target
+    def test_autotune_get_config_set_api(self):
+        from neural_compressor.torch import TuningConfig, autotune, get_all_config_set
+        from neural_compressor.torch.algorithms.weight_only.gptq import DataloaderPreprocessor
+
+        dataloader = GPTQLLMDataLoader()
+
+        model = get_gpt_j()
+        input = torch.ones([1, 512], dtype=torch.long)
+
+        dataloaderPreprocessor = DataloaderPreprocessor(
+            dataloader_original=dataloader, use_max_length=False, pad_max_length=512, nsamples=128
+        )
+        dataloader_for_calibration = dataloaderPreprocessor.get_prepared_dataloader()
+
+        def eval_acc_fn(model) -> float:
+            return 1.0
+
+        def eval_perf_fn(model) -> float:
+            return 1.0
+
+        eval_fns = [
+            {"eval_fn": eval_acc_fn, "weight": 0.5, "name": "accuracy"},
+            {
+                "eval_fn": eval_perf_fn,
+                "weight": 0.5,
+            },
+        ]
+        custom_tune_config = TuningConfig(config_set=get_all_config_set(), max_trials=4)
+        best_model = autotune(
+            model=get_gpt_j(),
+            tune_config=custom_tune_config,
+            eval_fns=eval_fns,
+            run_fn=run_fn_for_gptq,
+            run_args=dataloader_for_calibration,
+        )
+        self.assertIsNotNone(best_model)
+
     @reset_tuning_target
     def test_autotune_not_eval_func(self):
         logger.info("test_autotune_api")
         from neural_compressor.torch import RTNConfig, TuningConfig, autotune
 
-        custom_tune_config = TuningConfig(quant_configs=[RTNConfig(weight_bits=[4, 6])], max_trials=2)
+        custom_tune_config = TuningConfig(config_set=[RTNConfig(weight_bits=[4, 6])], max_trials=2)
 
         # Use assertRaises to check that an AssertionError is raised
         with self.assertRaises(AssertionError) as context:
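Example usage of the API change above (the `config_set` argument and the new `get_all_config_set()` helper) — a minimal sketch mirroring the updated torch tests; the toy model and the constant-valued eval function are placeholders for illustration, not part of this patch:

    import torch

    from neural_compressor.torch import RTNConfig, TuningConfig, autotune, get_all_config_set


    def eval_acc_fn(model) -> float:
        # Placeholder metric; a real flow would evaluate the quantized model on a dataset.
        return 1.0


    # Option 1: an explicit config set, as in test_autotune_api.
    custom_tune_config = TuningConfig(config_set=[RTNConfig(weight_bits=[4, 6])], max_trials=2)

    # Option 2: the default tuning set of every registered torch config, as in
    # test_autotune_get_config_set_api; configs that need calibration (e.g. GPTQ)
    # also require run_fn/run_args to be passed to autotune().
    # custom_tune_config = TuningConfig(config_set=get_all_config_set(), max_trials=4)

    # Toy stand-in for build_simple_torch_model() used by the tests.
    toy_model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.Linear(64, 32))
    best_model = autotune(
        model=toy_model,
        tune_config=custom_tune_config,
        eval_fns=[{"eval_fn": eval_acc_fn}],
    )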