From 95e67eac624285d304487b654330d660b169cfb1 Mon Sep 17 00:00:00 2001
From: Zixuan Cheng <110808245+violetch24@users.noreply.github.com>
Date: Sun, 28 Apr 2024 16:08:25 +0800
Subject: [PATCH] refine load API for 3.x ipex backend (#1755)

Signed-off-by: Cheng, Zixuan
---
 .../quantization/habana_fp8/run_llm.py        |  2 +-
 neural_compressor/common/utils/save_load.py   |  4 ++--
 .../torch/quantization/load_entry.py          | 26 +++++++++++++++++---------
 .../torch/quantization/habana_fp8/test_fp8.py |  2 +-
 .../torch/quantization/test_smooth_quant.py   |  3 ++-
 .../torch/quantization/test_static_quant.py   |  2 +-
 6 files changed, 24 insertions(+), 15 deletions(-)

diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py
index 2ed07707a7e..5cd0f046aba 100644
--- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py
+++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py
@@ -143,7 +143,7 @@ def calib_func(model):
 
 if args.load:
     from neural_compressor.torch.quantization import load
-    user_model = load(user_model, "saved_results")
+    user_model = load("saved_results", user_model)
 
 
 if args.approach in ["dynamic", "static"] or args.load:
diff --git a/neural_compressor/common/utils/save_load.py b/neural_compressor/common/utils/save_load.py
index 5ecd8ce3b97..15de5d8c2a3 100644
--- a/neural_compressor/common/utils/save_load.py
+++ b/neural_compressor/common/utils/save_load.py
@@ -19,7 +19,7 @@
 import os
 
 
-def save_config_mapping(config_mapping, qconfig_file_path):
+def save_config_mapping(config_mapping, qconfig_file_path):  # pragma: no cover
     """Save config mapping to json file.
 
     Args:
@@ -36,7 +36,7 @@ def save_config_mapping(config_mapping, qconfig_file_path):
         json.dump(per_op_qconfig, f, indent=4)
 
 
-def load_config_mapping(qconfig_file_path, config_name_mapping):
+def load_config_mapping(qconfig_file_path, config_name_mapping):  # pragma: no cover
     """Reload config mapping from json file.
 
    Args:
diff --git a/neural_compressor/torch/quantization/load_entry.py b/neural_compressor/torch/quantization/load_entry.py
index 52c5ad759bb..a576e005bf5 100644
--- a/neural_compressor/torch/quantization/load_entry.py
+++ b/neural_compressor/torch/quantization/load_entry.py
@@ -24,15 +24,23 @@
 }
 
 
-def load(model, output_dir="./saved_results"):
+def load(output_dir="./saved_results", model=None):
     from neural_compressor.common.base_config import ConfigRegistry
 
     qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), "qconfig.json")
-    config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"])
-    model.qconfig = config_mapping
-    # select load function
-    config_object = config_mapping[next(iter(config_mapping))]
-    if isinstance(config_object, FP8Config):
-        from neural_compressor.torch.algorithms.habana_fp8 import load
-
-        return load(model, output_dir)
+    with open(qconfig_file_path, "r") as f:
+        per_op_qconfig = json.load(f)
+    if " " in per_op_qconfig.keys():  # ipex qconfig format: {' ': {'q_op_infos': {'0': {'op_type': ...
+        from neural_compressor.torch.algorithms.static_quant import load
+
+        return load(output_dir)
+
+    else:  # FP8
+        config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"])
+        model.qconfig = config_mapping
+        # select load function
+        config_object = config_mapping[next(iter(config_mapping))]
+        if isinstance(config_object, FP8Config):
+            from neural_compressor.torch.algorithms.habana_fp8 import load
+
+            return load(model, output_dir)
diff --git a/test/3x/torch/quantization/habana_fp8/test_fp8.py b/test/3x/torch/quantization/habana_fp8/test_fp8.py
index f1c6df7092b..8fafc302f65 100644
--- a/test/3x/torch/quantization/habana_fp8/test_fp8.py
+++ b/test/3x/torch/quantization/habana_fp8/test_fp8.py
@@ -153,7 +153,7 @@ def calib_func(model):
         from neural_compressor.torch.quantization import load
 
         m = copy.deepcopy(self.model)
-        m = load(m, "saved_results")
+        m = load("saved_results", m)
         recovered_out = m(inp)
         assert (recovered_out == fp8_out).all(), "Unexpected result. Please double check."
         assert isinstance(m.fc1, FP8Linear), "Unexpected result. Please double check."
diff --git a/test/3x/torch/quantization/test_smooth_quant.py b/test/3x/torch/quantization/test_smooth_quant.py
index 7aae7ab61f2..f5e82412265 100644
--- a/test/3x/torch/quantization/test_smooth_quant.py
+++ b/test/3x/torch/quantization/test_smooth_quant.py
@@ -133,7 +133,8 @@ def test_sq_save_load(self):
         q_model.save("saved_results")
         inc_out = q_model(example_inputs)
 
-        from neural_compressor.torch.algorithms.smooth_quant import load, recover_model_from_json
+        from neural_compressor.torch.algorithms.smooth_quant import recover_model_from_json
+        from neural_compressor.torch.quantization import load
 
         # load using saved model
         loaded_model = load("saved_results")
diff --git a/test/3x/torch/quantization/test_static_quant.py b/test/3x/torch/quantization/test_static_quant.py
index c8569262343..82177c6bfb4 100644
--- a/test/3x/torch/quantization/test_static_quant.py
+++ b/test/3x/torch/quantization/test_static_quant.py
@@ -153,7 +153,7 @@ def run_fn(model):
         assert torch.allclose(inc_out, ipex_out, atol=2e-02), "Unexpected result. Please double check."
         q_model.save("saved_results")
 
-        from neural_compressor.torch.algorithms.static_quant import load
+        from neural_compressor.torch.quantization import load
 
         # load
         loaded_model = load("saved_results")
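Usage after this change (a minimal sketch distilled from the tests above;
user_model is a stand-in for the caller's original FP32 model and is only
needed on the FP8 path):

    from neural_compressor.torch.quantization import load

    # IPEX static/smooth quant: qconfig.json is in the ipex per-op format,
    # so the quantized model is rebuilt from "saved_results" alone.
    loaded_model = load("saved_results")

    # Habana FP8: pass the original model; its qconfig is restored from
    # qconfig.json before dispatching to the habana_fp8 loader.
    loaded_model = load("saved_results", user_model)  # user_model: assumed FP32 model

Making output_dir the first positional argument keeps the common IPEX path
to a single-argument call, which is what the updated tests exercise.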