Enable the Combination of Multiple Algorithms within a Single Model (#1616)

Signed-off-by: yiliu30 <yi4.liu@intel.com>
yiliu30 authored Feb 26, 2024
1 parent ec91109 commit 071ab31
Showing 4 changed files with 154 additions and 10 deletions.
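
In user-facing terms, the change lets several algorithm configs be combined with a simple "+" and passed to a single quantize call, with each algorithm applied only to the ops its config matches. A minimal usage sketch, adapted from the new test at the end of this commit (the checkpoint name and white_list patterns come from that test; the plain calibration function is an assumption for non-test code, where the ValueError wrapper used in the test is not needed):

import torch
import transformers

from neural_compressor.torch.quantization import GPTQConfig, RTNConfig, quantize

# Compose two weight-only algorithms: RTN for lm_head, GPTQ for the transformer blocks.
rtn_config = RTNConfig(white_list=["lm_head"])
gptq_config = GPTQConfig(white_list=["transformer.*"])
combined_config = rtn_config + gptq_config  # a ComposableConfig

model = transformers.AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-GPTJForCausalLM",
)


def calib_fn(model):
    # One calibration pass with representative inputs (GPTQ needs calibration data).
    model(torch.tensor([[10, 20, 30]], dtype=torch.long))


q_model = quantize(model, combined_config, run_fn=calib_fn)
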
21 changes: 19 additions & 2 deletions neural_compressor/common/base_config.py
@@ -457,9 +457,20 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__} {self.to_json_string()}"

def to_config_mapping(
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
self, config_list: List[BaseConfig] = None, model_info: Dict[str, Any] = None
) -> OrderedDict[str, BaseConfig]:
return super().to_config_mapping(self.config_list, model_info)
config_mapping = OrderedDict()
for config in self.config_list:
global_config = config.global_config
op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config()
single_config_model_info = model_info.get(config.name, None)
for op_name, op_type in single_config_model_info:
if op_type in op_type_config_dict:
config_mapping[(op_name, op_type)] = op_name_config_dict[op_type]
for op_name_pattern in op_name_config_dict:
if re.match(op_name_pattern, op_name):
config_mapping[(op_name, op_type)] = op_name_config_dict[op_name_pattern]
return config_mapping

@classmethod
def register_supported_configs(cls):
@@ -471,6 +482,12 @@ def get_config_set_for_tuning(cls) -> None:
        # TODO (Yi) handle the composable config in `tuning_config`
        return None

    def get_model_info(self, model, *args, **kwargs):
        model_info_dict = dict()
        for config in self.config_list:
            model_info_dict.update({config.name: config.get_model_info(model, *args, **kwargs)})
        return model_info_dict


def get_all_config_set_from_config_registry(fwk_name: str) -> Union[BaseConfig, List[BaseConfig]]:
    all_registered_config_cls: List[BaseConfig] = config_registry.get_all_config_cls_by_fwk_name(fwk_name)
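
Taken together, the two methods added here mean a composed config first collects model info per algorithm (keyed by the config name) and then flattens it into one (op_name, op_type) -> config mapping. A rough sketch of the data shapes, using the fake framework defined in test_common.py below; the entries mirror FAKE_MODEL_INFO and the white_list choices in test_mixed_two_algos and are illustrative, not captured output:

# model_info as returned by ComposableConfig.get_model_info: one op list per algorithm name.
model_info = {
    "fake": [("OP1_NAME", "OP_TYPE1"), ("OP2_NAME", "OP_TYPE1"), ("OP3_NAME", "OP_TYPE2")],
    "fake_one": [("OP1_NAME", "OP_TYPE1"), ("OP2_NAME", "OP_TYPE1"), ("OP3_NAME", "OP_TYPE2")],
}

# config_mapping built by ComposableConfig.to_config_mapping when the "fake" algorithm is
# restricted to OP1_NAME and "fake_one" to OP2_NAME via white_list:
# OrderedDict([
#     (("OP1_NAME", "OP_TYPE1"), <FakeAlgoConfig>),
#     (("OP2_NAME", "OP_TYPE1"), <FakeAlgoOneConfig>),
# ])
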
8 changes: 8 additions & 0 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -41,6 +41,8 @@ def rtn_entry(
    # rebuild weight_config for rtn_quantize function
    weight_config = {}
    for (op_name, op_type), quant_config in configs_mapping.items():
        if quant_config.name != RTN:
            continue
        weight_config[op_name] = {
            "dtype": quant_config.dtype,
            "bits": quant_config.bits,
@@ -74,6 +76,8 @@ def gptq_entry(
    # rebuild weight_config for gptq_quantize function
    weight_config = {}
    for (op_name, op_type), quant_config in configs_mapping.items():
        if quant_config.name != GPTQ:
            continue
        weight_config[op_name] = {
            "dtype": quant_config.dtype,
            "bits": quant_config.bits,
@@ -120,6 +124,8 @@ def static_quant_entry(
    cfgs = deepcopy(configs_mapping)
    quant_config_mapping["op"] = cfgs
    for (op_name, op_type), cfg in cfgs.items():
        if cfg.name != STATIC_QUANT:
            continue
        quant_config_mapping["op"][(op_name, op_type)] = {
            "weight": {
                "dtype": cfg.w_dtype,
@@ -161,6 +167,8 @@ def awq_quantize_entry(

    weight_config = {}
    for (op_name, op_type), op_config in configs_mapping.items():
        if op_config.name != AWQ:
            continue
        if op_config.dtype == "fp32":
            weight_config[op_name] = {
                "bits": -1,
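
Each entry above applies the same guard: with a composed config, every algorithm entry receives the full mapping and must skip entries that belong to another algorithm. A condensed sketch of that dispatch pattern; the helper name is hypothetical, the real entries inline the check as shown in the diff:

# Hypothetical helper expressing the filter used by rtn_entry, gptq_entry,
# static_quant_entry, and awq_quantize_entry: keep only the ops whose config
# belongs to the algorithm being applied.
def filter_configs_for_algo(configs_mapping, algo_name):
    return {
        (op_name, op_type): quant_config
        for (op_name, op_type), quant_config in configs_mapping.items()
        if quant_config.name == algo_name
    }
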
95 changes: 87 additions & 8 deletions test/3x/common/test_common.py
@@ -40,24 +40,40 @@

logger = Logger().get_logger()

-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union

from neural_compressor.common.base_config import (
    BaseConfig,
    ComposableConfig,
    config_registry,
    get_all_config_set_from_config_registry,
    register_config,
    register_supported_configs_for_fwk,
)
from neural_compressor.common.base_tuning import ConfigLoader, ConfigSet, SequentialSampler
from neural_compressor.common.tuning_param import TuningParam
from neural_compressor.common.utils import DEFAULT_WHITE_LIST, OP_NAME_OR_MODULE_TYPE

PRIORITY_FAKE_ALGO = 100
FAKE_CONFIG_NAME = "fake"
PRIORITY_FAKE_ALGO_1 = 90
FAKE_CONFIG_NAME_1 = "fake_one"
DEFAULT_WEIGHT_BITS = [4, 6]

FAKE_FRAMEWORK_NAME = "FAKE_FWK"

FAKE_MODEL_INFO = [("OP1_NAME", "OP_TYPE1"), ("OP2_NAME", "OP_TYPE1"), ("OP3_NAME", "OP_TYPE2")]


class FakeModel:
    def __init__(self) -> None:
        self.name = "fake_model"

    def __call__(self, x) -> Any:
        return x

    def __repr__(self) -> str:
        return "FakeModel"


@register_config(framework_name=FAKE_FRAMEWORK_NAME, algo_name=FAKE_CONFIG_NAME, priority=PRIORITY_FAKE_ALGO)
class FakeAlgoConfig(BaseConfig):
@@ -102,17 +118,14 @@ def register_supported_configs(cls) -> List:
        pass

    @staticmethod
-    def get_model_info(model: Any) -> List[Tuple[str, Callable]]:
-        pass
+    def get_model_info(model: Any) -> List[Tuple[str, Any]]:
+        return FAKE_MODEL_INFO

    @classmethod
    def get_config_set_for_tuning(cls) -> Union[None, "FakeAlgoConfig", List["FakeAlgoConfig"]]:
        return FakeAlgoConfig(weight_bits=DEFAULT_WEIGHT_BITS)


-FakeAlgoConfig.register_supported_configs()


def get_default_fake_config() -> FakeAlgoConfig:
    """Generate the default fake config.
@@ -122,10 +135,64 @@ def get_default_fake_config() -> FakeAlgoConfig:
    return FakeAlgoConfig()


@register_config(framework_name=FAKE_FRAMEWORK_NAME, algo_name=FAKE_CONFIG_NAME_1, priority=PRIORITY_FAKE_ALGO_1)
class FakeAlgoOneConfig(BaseConfig):
    """Config class for fake algo."""

    supported_configs: List = []
    params_list = [
        "weight_dtype",
        "weight_bits",
        TuningParam("target_op_type_list", tunable_type=List[List[str]]),
    ]
    name = FAKE_CONFIG_NAME_1

    def __init__(
        self,
        weight_dtype: str = "int",
        weight_bits: int = 4,
        target_op_type_list: List[str] = ["Conv", "Gemm"],
        white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
    ):
        """Init fake config.
        Args:
            weight_dtype (str): Data type for weights, default is "int".
            weight_bits (int): Number of bits used to represent weights, default is 4.
        """
        super().__init__(white_list=white_list)
        self.weight_bits = weight_bits
        self.weight_dtype = weight_dtype
        self.target_op_type_list = target_op_type_list
        self._post_init()

    def to_dict(self):
        return super().to_dict()

    @classmethod
    def from_dict(cls, config_dict):
        return super(FakeAlgoOneConfig, cls).from_dict(config_dict=config_dict)

    @classmethod
    def register_supported_configs(cls) -> List:
        pass

    @staticmethod
    def get_model_info(model: Any) -> List[Tuple[str, Any]]:
        return FAKE_MODEL_INFO

    @classmethod
    def get_config_set_for_tuning(cls) -> Union[None, "FakeAlgoOneConfig", List["FakeAlgoOneConfig"]]:
        return FakeAlgoOneConfig(weight_bits=DEFAULT_WEIGHT_BITS)


def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
    return get_all_config_set_from_config_registry(fwk_name=FAKE_FRAMEWORK_NAME)


register_supported_configs_for_fwk(fwk_name=FAKE_FRAMEWORK_NAME)


class TestBaseConfig(unittest.TestCase):
    @classmethod
    def setUpClass(self):
@@ -143,7 +210,7 @@ def test_api(self):
        fake_default_config = get_default_fake_config()
        self.assertEqual(fake_default_config.weight_dtype, "int")
        config_set = get_all_config_set()
-        self.assertEqual(len(config_set), 1)
+        self.assertEqual(len(config_set), len(config_registry.get_all_config_cls_by_fwk_name(FAKE_FRAMEWORK_NAME)))
        self.assertEqual(config_set[0].weight_bits, DEFAULT_WEIGHT_BITS)

def test_config_expand_complex_tunable_type(self):
@@ -154,6 +221,18 @@ def test_config_expand_complex_tunable_type(self):
        for i in range(len(configs_list)):
            self.assertEqual(configs_list[i].target_op_type_list, target_op_type_list_options[i])

    def test_mixed_two_algos(self):
        model = FakeModel()
        OP1_NAME = "OP1_NAME"
        OP2_NAME = "OP2_NAME"
        fake_config = FakeAlgoConfig(weight_bits=4, white_list=[OP1_NAME])
        fake1_config = FakeAlgoOneConfig(weight_bits=2, white_list=[OP2_NAME])
        mixed_config = fake_config + fake1_config
        model_info = mixed_config.get_model_info(model)
        config_mapping = mixed_config.to_config_mapping(model_info=model_info)
        self.assertIn(OP1_NAME, [op_info[0] for op_info in config_mapping])
        self.assertIn(OP2_NAME, [op_info[0] for op_info in config_mapping])


class TestConfigSet(unittest.TestCase):
    def setUp(self):
40 changes: 40 additions & 0 deletions test/3x/torch/quantization/weight_only/test_mixed_algos.py
@@ -0,0 +1,40 @@
import copy
from unittest.mock import patch

import pytest
import torch
import transformers

from neural_compressor.common.utils import logger
from neural_compressor.torch.quantization import GPTQConfig, RTNConfig, quantize


def run_fn(model):
    # GPTQ uses ValueError to reduce computation when collecting input data of the first block
    # It's special for UTs, no need to add this wrapper in examples.
    with pytest.raises(ValueError):
        model(torch.tensor([[10, 20, 30]], dtype=torch.long))
        model(torch.tensor([[40, 50, 60]], dtype=torch.long))


class TestMixedTwoAlgo:
    def test_mixed_gptq_and_rtn(self):
        with patch.object(logger, "info") as mock_info:
            rtn_config = RTNConfig(white_list=["lm_head"])
            gptq_config = GPTQConfig(double_quant_bits=4, white_list=["transformer.*"])
            combined_config = rtn_config + gptq_config
            logger.info(combined_config)

            self.tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(
                "hf-internal-testing/tiny-random-GPTJForCausalLM",
            )
            self.example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long)
            # record label for comparison
            out_original_model = self.tiny_gptj(self.example_inputs)[0]
            model = copy.deepcopy(self.tiny_gptj)
            q_model = quantize(model, combined_config, run_fn=run_fn)
            out_q_model = q_model(self.example_inputs)[0]
            rtn_log = "Start to apply rtn on the model."
            gptq_log = "Start to apply gptq on the model."
            assert rtn_log in [_call[0][0] for _call in mock_info.call_args_list]
            assert gptq_log in [_call[0][0] for _call in mock_info.call_args_list]
