
Commit 95e67ea
refine load API for 3.x ipex backend (#1755)
Signed-off-by: Cheng, Zixuan <zixuan.cheng@intel.com>
violetch24 authored Apr 28, 2024
1 parent 0b2080b commit 95e67ea
Showing 6 changed files with 24 additions and 15 deletions.
@@ -143,7 +143,7 @@ def calib_func(model):

 if args.load:
     from neural_compressor.torch.quantization import load
-    user_model = load(user_model, "saved_results")
+    user_model = load("saved_results", user_model)


 if args.approach in ["dynamic", "static"] or args.load:
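With the refined API, the checkpoint directory becomes the first positional argument and the model is optional, since the ipex path can rebuild the quantized model from the saved files alone. A minimal sketch of the two call styles under the new signature (directory and model names follow the example above):

    from neural_compressor.torch.quantization import load

    # ipex static/smooth quant: the checkpoint alone is enough
    user_model = load("saved_results")

    # Habana FP8: also pass the original model so its modules can be recovered
    user_model = load("saved_results", user_model)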
neural_compressor/common/utils/save_load.py (2 additions & 2 deletions)
@@ -19,7 +19,7 @@
 import os


-def save_config_mapping(config_mapping, qconfig_file_path):
+def save_config_mapping(config_mapping, qconfig_file_path):  # pragma: no cover
     """Save config mapping to json file.

     Args:
@@ -36,7 +36,7 @@ def save_config_mapping(config_mapping, qconfig_file_path):
         json.dump(per_op_qconfig, f, indent=4)


-def load_config_mapping(qconfig_file_path, config_name_mapping):
+def load_config_mapping(qconfig_file_path, config_name_mapping):  # pragma: no cover
     """Reload config mapping from json file.

     Args:
neural_compressor/torch/quantization/load_entry.py (17 additions & 9 deletions)
@@ -24,15 +24,23 @@
 }


-def load(model, output_dir="./saved_results"):
+def load(output_dir="./saved_results", model=None):
     from neural_compressor.common.base_config import ConfigRegistry

     qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), "qconfig.json")
-    config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"])
-    model.qconfig = config_mapping
-    # select load function
-    config_object = config_mapping[next(iter(config_mapping))]
-    if isinstance(config_object, FP8Config):
-        from neural_compressor.torch.algorithms.habana_fp8 import load
-
-        return load(model, output_dir)
+    with open(qconfig_file_path, "r") as f:
+        per_op_qconfig = json.load(f)
+    if " " in per_op_qconfig.keys():  # ipex qconfig format: {' ': {'q_op_infos': {'0': {'op_type': ...
+        from neural_compressor.torch.algorithms.static_quant import load
+
+        return load(output_dir)
+
+    else:  # FP8
+        config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"])
+        model.qconfig = config_mapping
+        # select load function
+        config_object = config_mapping[next(iter(config_mapping))]
+        if isinstance(config_object, FP8Config):
+            from neural_compressor.torch.algorithms.habana_fp8 import load
+
+            return load(model, output_dir)
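The dispatch relies on the layout of qconfig.json rather than an explicit flag. A small illustrative probe of that file, mirroring the check load() performs internally (the directory name here is an assumption):

    import json
    import os

    with open(os.path.join("saved_results", "qconfig.json")) as f:
        per_op_qconfig = json.load(f)

    # ipex static quant serializes its per-op info under a single " " key;
    # anything else falls through to the config-mapping (e.g. FP8) path
    print("ipex checkpoint" if " " in per_op_qconfig else "config-mapping checkpoint")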
test/3x/torch/quantization/habana_fp8/test_fp8.py (1 addition & 1 deletion)
@@ -153,7 +153,7 @@ def calib_func(model):
         from neural_compressor.torch.quantization import load

         m = copy.deepcopy(self.model)
-        m = load(m, "saved_results")
+        m = load("saved_results", m)
         recovered_out = m(inp)
         assert (recovered_out == fp8_out).all(), "Unexpected result. Please double check."
         assert isinstance(m.fc1, FP8Linear), "Unexpected result. Please double check."
test/3x/torch/quantization/test_smooth_quant.py (2 additions & 1 deletion)
@@ -133,7 +133,8 @@ def test_sq_save_load(self):
         q_model.save("saved_results")
         inc_out = q_model(example_inputs)

-        from neural_compressor.torch.algorithms.smooth_quant import load, recover_model_from_json
+        from neural_compressor.torch.algorithms.smooth_quant import recover_model_from_json
+        from neural_compressor.torch.quantization import load

         # load using saved model
         loaded_model = load("saved_results")
test/3x/torch/quantization/test_static_quant.py (1 addition & 1 deletion)
@@ -153,7 +153,7 @@ def run_fn(model):
         assert torch.allclose(inc_out, ipex_out, atol=2e-02), "Unexpected result. Please double check."
         q_model.save("saved_results")

-        from neural_compressor.torch.algorithms.static_quant import load
+        from neural_compressor.torch.quantization import load

         # load
         loaded_model = load("saved_results")
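Both ipex-based tests now exercise the same round trip an end user would. A condensed sketch of that flow, assuming q_model and example_inputs are set up as in the tests above:

    # persist the quantized model, then reload it through the unified 3.x entry
    q_model.save("saved_results")

    from neural_compressor.torch.quantization import load
    loaded_model = load("saved_results")  # no model argument needed for ipex
    loaded_out = loaded_model(example_inputs)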
