From f812e673999ae961c3ca9ab0e8384185bbaa28ba Mon Sep 17 00:00:00 2001 From: xinhe Date: Tue, 27 Feb 2024 22:14:37 +0800 Subject: [PATCH] update fp8 implementation, design and implement save&load (#1605) Signed-off-by: xinhe3 --- .../quantization/habana_fp8/README.md | 0 .../models/configuration_chatglm.py | 0 .../habana_fp8/models/modeling_chatglm.py | 0 .../habana_fp8/models/modeling_llama.py | 0 .../models/tokenization_baichuan.py | 0 .../quantization/habana_fp8/requirement.txt | 0 .../quantization/habana_fp8/run_llm.py | 239 ++- .../quantization/habana_fp8/utils.py | 35 + neural_compressor/common/utils/__init__.py | 1 + neural_compressor/common/utils/save_load.py | 61 + .../torch/algorithms/habana_fp8/__init__.py | 1 + .../torch/algorithms/habana_fp8/fp8_quant.py | 51 +- .../torch/algorithms/habana_fp8/modules.py | 447 +++-- .../torch/algorithms/habana_fp8/observer.py | 4 +- .../torch/algorithms/habana_fp8/save_load.py | 98 + .../algorithms/habana_fp8/tensor/__init__.py | 13 + .../algorithms/habana_fp8/tensor/convert.cpp | 63 + .../algorithms/smooth_quant/smoothquant.py | 1596 +++++++++++++++++ neural_compressor/torch/amp/fp8/functions.py | 7 +- .../torch/quantization/__init__.py | 3 + .../torch/quantization/algorithm_entry.py | 14 +- .../torch/quantization/config.py | 144 +- .../torch/quantization/load_entry.py | 38 + .../torch/quantization/modules.py | 47 - neural_compressor/torch/utils/environ.py | 1 - neural_compressor/torch/utils/utility.py | 4 + requirements_pt.txt | 1 - setup.py | 21 + test/3x/torch/quantization/fp8/test_fp8.py | 130 -- .../torch/quantization/habana_fp8/test_fp8.py | 166 ++ test/3x/torch/requirements.txt | 1 + 31 files changed, 2676 insertions(+), 510 deletions(-) rename examples/{ => 3.x_api}/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/README.md (100%) rename examples/{ => 3.x_api}/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/configuration_chatglm.py (100%) rename examples/{ => 3.x_api}/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_chatglm.py (100%) rename examples/{ => 3.x_api}/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_llama.py (100%) rename examples/{ => 3.x_api}/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/tokenization_baichuan.py (100%) rename examples/{ => 3.x_api}/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/requirement.txt (100%) rename examples/{ => 3.x_api}/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py (67%) create mode 100644 examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/utils.py create mode 100644 neural_compressor/common/utils/save_load.py create mode 100644 neural_compressor/torch/algorithms/habana_fp8/save_load.py create mode 100644 neural_compressor/torch/algorithms/habana_fp8/tensor/__init__.py create mode 100644 neural_compressor/torch/algorithms/habana_fp8/tensor/convert.cpp create mode 100644 neural_compressor/torch/algorithms/smooth_quant/smoothquant.py create mode 100644 neural_compressor/torch/quantization/load_entry.py delete mode 100644 neural_compressor/torch/quantization/modules.py delete mode 100644 test/3x/torch/quantization/fp8/test_fp8.py create mode 100644 test/3x/torch/quantization/habana_fp8/test_fp8.py diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/README.md 
b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/README.md similarity index 100% rename from examples/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/README.md rename to examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/README.md diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/configuration_chatglm.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/configuration_chatglm.py similarity index 100% rename from examples/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/configuration_chatglm.py rename to examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/configuration_chatglm.py diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_chatglm.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_chatglm.py similarity index 100% rename from examples/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_chatglm.py rename to examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_chatglm.py diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_llama.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_llama.py similarity index 100% rename from examples/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_llama.py rename to examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_llama.py diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/tokenization_baichuan.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/tokenization_baichuan.py similarity index 100% rename from examples/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/tokenization_baichuan.py rename to examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/tokenization_baichuan.py diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/requirement.txt b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/requirement.txt similarity index 100% rename from examples/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/requirement.txt rename to examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/requirement.txt diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py similarity index 67% rename from examples/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py rename to examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py index 3034b4b72a6..3c0f91b9f58 100644 --- a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py +++ 
b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py @@ -1,30 +1,33 @@ +import os +os.environ["EXPERIMENTAL_WEIGHT_SHARING"] = "False" +os.environ["USE_GAUDI2_SCALE"] = "True" +os.environ.pop("USE_GAUDI2_SCALE") # gaudi scale work +# os.environ["GRAPH_VISUALIZATION"] = "True" +# import shutil +# shutil.rmtree(".graph_dumps", ignore_errors=True) import argparse import time import json import re import torch -import transformers -import os -import deepspeed -from transformers import AutoModelForCausalLM, AutoTokenizer import habana_frameworks.torch.hpex -from habana_frameworks.torch.hpu import memory_stats +import torch.nn.functional as F +import deepspeed +import transformers +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig +import habana_frameworks.torch.core as htcore import numpy as np import lm_eval import lm_eval.tasks import lm_eval.evaluator -torch.set_grad_enabled(False) +from accelerate import init_empty_weights +from utils import itrex_bootstrap_stderr, show_msg, save_to_excel -def itrex_bootstrap_stderr(f, xs, iters): - from lm_eval.metrics import _bootstrap_internal, sample_stddev - res = [] - chunk_size = min(1000, iters) - it = _bootstrap_internal(f, chunk_size) - for i in range(iters // chunk_size): - bootstrap = it((i, xs)) - res.extend(bootstrap) - return sample_stddev(res) +torch.set_grad_enabled(False) +htcore.hpu_set_env() +torch.device('hpu') + # to avoid out-of-memory caused by Popen for large language models. lm_eval.metrics.bootstrap_stderr = itrex_bootstrap_stderr @@ -51,22 +54,26 @@ def itrex_bootstrap_stderr(f, xs, iters): parser.add_argument("--accuracy", action="store_true") parser.add_argument("--performance", action="store_true") parser.add_argument("--generate", action="store_true") +parser.add_argument("--skip_fp8_mm", action="store_true") +parser.add_argument("--dump_to_excel", action="store_true") +parser.add_argument("--save", action="store_true") +parser.add_argument("--load", action="store_true") parser.add_argument("--batch_size", default=1, type=int, help="For accuracy measurement only.") parser.add_argument("--pad_max_length", default=512, type=int, help="Pad input ids to max length.") parser.add_argument("--calib_iters", default=100, type=int, help="calibration iters.") -parser.add_argument("--tasks", nargs='+', default=["lambada_openai"], type=str, \ - choices=["winogrande", "copa", "piqa", "rte", "hellaswag", \ - "openbookqa", "lambada_openai", "lambada_standard", "wikitext"], +parser.add_argument("--tasks", nargs='+', default=["lambada_openai"], \ + type=str, choices=["hellaswag", "lambada_openai", "piqa", "winogrande", "copa", + "rte", "openbookqa", "lambada_standard", "wikitext"], help="tasks list for accuracy validation") parser.add_argument("--limit", default=None, type=int, help="the sample num of evaluation.") parser.add_argument("--max_new_tokens", default=100, type=int, help="calibration iters.") parser.add_argument('--buckets', type=int, nargs='+', \ - help="Input length buckets to use with static_shapes", default=[129]) + help="Input length buckets to use with static_shapes", default=[256, 512]) parser.add_argument("--local_rank", type=int, default=-1, @@ -78,67 +85,65 @@ def itrex_bootstrap_stderr(f, xs, iters): world_size = int(os.getenv('WORLD_SIZE', '1')) local_rank = int(os.getenv('LOCAL_RANK', '-1')) -#if local_rank == 0: -# os.environ["ENABLE_CONSOLE"] = 'True' -# os.environ["LOG_LEVEL_ALL"] = '0' -# model +model_dtype = torch.float32 if re.search("llama", 
args.model.lower()) or re.search("bloom", args.model.lower()): from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer - torch.device('hpu') - config = AutoConfig.from_pretrained(args.model) if world_size > 1: - model_dtype = torch.bfloat16 + config = AutoConfig.from_pretrained(args.model) + model_dtype = torch.bfloat16 # RuntimeErrorCastToFp8V2 input must be of float or bfloat16 dtype deepspeed.init_distributed(dist_backend="hccl") with deepspeed.OnDevice(dtype=model_dtype, device="meta"): user_model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype) import tempfile checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") - from utils import write_checkpoints_json + from optimum.habana.checkpoint_utils import write_checkpoints_json # in optimum-habana write_checkpoints_json( - args.model, - local_rank, - checkpoints_json, - token=None, + args.model, + local_rank, + checkpoints_json, + token=None, ) - elif re.search("llama", args.model.lower()): - from models.modeling_llama import LlamaForCausalLM - user_model = LlamaForCausalLM.from_pretrained( + else: + if args.load: + config = AutoConfig.from_pretrained(args.model) + with init_empty_weights(): + user_model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype) + else: + user_model = AutoModelForCausalLM.from_pretrained( + args.model, + device_map='hpu', + torch_dtype=model_dtype, + ) +elif re.search("chatglm", args.model.lower()): + if args.load: + config = AutoConfig.from_pretrained(args.model, torch_dtype=model_dtype) + with init_empty_weights(): + user_model = AutoModelForCausalLM.from_config(config) + else: + from models.modeling_chatglm import ChatGLMForConditionalGeneration + user_model = ChatGLMForConditionalGeneration.from_pretrained( args.model, + revision=args.revision, device_map='hpu', + torch_dtype=model_dtype, ) + # print(user_model.transformer.output_layer.weight.dtype) # always fp16 + user_model.float() # static fp8 need float32 for graph compiler +else: + if args.load: + config = AutoConfig.from_pretrained(args.model) + with init_empty_weights(): + user_model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype) else: user_model = AutoModelForCausalLM.from_pretrained( args.model, + trust_remote_code=args.trust_remote_code, + revision=args.revision, device_map='hpu', + torch_dtype=model_dtype, ) -elif re.search("chatglm", args.model.lower()): - from models.modeling_chatglm import ChatGLMForConditionalGeneration - user_model = ChatGLMForConditionalGeneration.from_pretrained( - args.model, - revision=args.revision, - device_map='hpu', - ) -else: - user_model = AutoModelForCausalLM.from_pretrained( - args.model, - trust_remote_code=args.trust_remote_code, - revision=args.revision, - device_map='hpu', - ) -# tokenizer -if re.search("baichuan", args.model.lower()): - from models.tokenization_baichuan import BaichuanTokenizer - tokenizer = BaichuanTokenizer.from_pretrained( - args.model, - trust_remote_code=args.trust_remote_code - ) -else: - tokenizer = AutoTokenizer.from_pretrained( - args.model, - trust_remote_code=args.trust_remote_code - ) if world_size > 1: if re.search("llama", args.model.lower()): @@ -148,7 +153,6 @@ def itrex_bootstrap_stderr(f, xs, iters): from transformers.models.llama.modeling_llama import LlamaDecoderLayer ds_inference_kwargs["injection_policy"] = {LlamaDecoderLayer: ("self_attn.o_proj", "mlp.down_proj")} ds_inference_kwargs["checkpoint"] = checkpoints_json.name - ds_model = deepspeed.init_inference(user_model, 
**ds_inference_kwargs) else: ds_model = deepspeed.init_inference(user_model, @@ -156,28 +160,37 @@ def itrex_bootstrap_stderr(f, xs, iters): replace_with_kernel_inject=False) user_model = ds_model.module + +# tokenizer +if re.search("baichuan", args.model.lower()): + from models.tokenization_baichuan import BaichuanTokenizer + tokenizer = BaichuanTokenizer.from_pretrained( + args.model, + trust_remote_code=args.trust_remote_code + ) +else: + tokenizer = AutoTokenizer.from_pretrained( + args.model, + trust_remote_code=args.trust_remote_code + ) + + user_model.eval() -if args.approach in ["dynamic", "static"]: + +### dynamic & static quantization ### +if args.approach in ["dynamic", "static"] and not args.load: print("device:", next(user_model.parameters()).device) - from neural_compressor.torch.quantization.config import FP8QConfig, get_default_fp8_qconfig - from neural_compressor.torch.algorithms.habana_fp8 import quantize_dynamic + from neural_compressor.torch.quantization.config import FP8Config, get_default_fp8_config from neural_compressor.torch.quantization import quantize - if args.precision == "fp8_e4m3": - dtype = torch.float8_e4m3fn - else: - dtype = torch.float8_e5m2 + dtype = args.precision if args.approach == "dynamic": - #user_model = quantize_dynamic(user_model, dtype, inplace=True) - qconfig = FP8QConfig(weight_dtype=dtype, act_dtype=dtype, approach="dynamic") - if args.skip_lm_head: - fp32_config = FP8QConfig(weight_dtype=torch.float32, act_dtype=torch.float32) - qconfig.set_local("lm_head", fp32_config) - user_model = quantize_dynamic(user_model, qconfig, inplace=True) + from neural_compressor.torch.algorithms.habana_fp8 import quantize_dynamic + user_model = quantize_dynamic(user_model, dtype, inplace=True) elif args.approach == "static": - qconfig = FP8QConfig(weight_dtype=dtype, act_dtype=dtype, approach="static") + qconfig = FP8Config(w_dtype=dtype, act_dtype=dtype, approach="static") if args.skip_lm_head: - fp32_config = FP8QConfig(weight_dtype=torch.float32, act_dtype=torch.float32) + fp32_config = FP8Config(w_dtype="fp32", act_dtype="fp32") qconfig.set_local("lm_head", fp32_config) # dataset from datasets import load_dataset @@ -186,7 +199,13 @@ def itrex_bootstrap_stderr(f, xs, iters): calib_data = [] for examples in calib_dataset: calib_data.append( - tokenizer(examples["text"], return_tensors="pt", max_length=128) + tokenizer( + examples["text"], + return_tensors="pt", + max_length=64, + padding="max_length", + truncation=True + ) ) def calib_func(model): @@ -199,12 +218,46 @@ def calib_func(model): ) user_model = quantize(user_model, qconfig, calib_func, inplace=True) - print(user_model, flush=True) + # saving + if args.save and local_rank in [-1, 0]: + user_model.save("saved_results") + + +if args.load: + from neural_compressor.torch.quantization import load + user_model = load(user_model, "saved_results") + + +if args.approach in ["dynamic", "static"] or args.load: + # It enables weights constant folding + from habana_frameworks.torch.core.quantization import _check_params_as_const, _mark_params_as_const + _mark_params_as_const(user_model) # can reduce memory allocated and speed up + _check_params_as_const(user_model) + + + +# If torch.matmul and torch.bmm are not replaced by INC module, +# Below codes can make torch.matmul and torch.bmm run on fp8 by injection. 
+if not args.skip_fp8_mm and args.precision in ['fp8_e4m3', 'fp8_e5m2']: + def replace_torch_mm_bmm(): + from neural_compressor.torch.amp.fp8.functions import fp8_matmul + torch.matmul = fp8_matmul + torch.bmm = fp8_matmul + replace_torch_mm_bmm() + + +# inference optimization if args.to_graph: import habana_frameworks.torch.hpu.graphs as htgraphs user_model = htgraphs.wrap_in_hpu_graph(user_model) + +# dump message of HPU after quantization or reloading +show_msg() + + +### generation, performance and accuracy validation ### if args.generate: input_prompt = "Here is my prompt" print("Prompt sentence:", input_prompt) @@ -234,6 +287,7 @@ def calib_func(model): print("Generated sentence:", output_sentence) print("Duration:", eval_end - eval_start) + if args.performance: eval_start = time.perf_counter() input_prompt = "Intel is a company which" @@ -242,6 +296,7 @@ def calib_func(model): outputs = user_model.generate(input_tokens, **generation_config) print("Duration of generating 100 tokens :", time.perf_counter() - eval_start) + if args.accuracy: class HabanaModelAdapter(lm_eval.base.BaseLM): @@ -292,16 +347,14 @@ def find_bucket(self, length): return [b for b in self.buckets if b >= length][0] def _model_call(self, inps): - #print(inps.shape) seq_length = inps.shape[-1] + padding_length = 0 bucket_length = self.find_bucket(seq_length) padding_length = bucket_length - seq_length - if True: - import torch.nn.functional as F - inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id) + inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id) + logits = self.model(inps.to(self._device))["logits"].cpu() - logits = self.model(inps.to(self._device))['logits'] - if True and padding_length > 0: + if padding_length > 0: logits = logits[:, :-padding_length, :] logits = logits.to(torch.float32) return logits @@ -333,18 +386,18 @@ def _model_call(self, inps): dumped = json.dumps(results, indent=2) + accu_dict = {} + case_name = args.approach + "-" + args.precision for task_name in args.tasks: if task_name == "wikitext": print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"]), flush=True) + accu_dict[task_name] = [args.model, case_name, results["results"][task_name]["word_perplexity"]] else: print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"]), flush=True) + accu_dict[task_name] = [args.model, case_name, results["results"][task_name]["acc"]] + if args.dump_to_excel and local_rank in [-1, 0]: + save_to_excel(accu_dict) + -# show memory usage -mem_stats = memory_stats() -mem_dict = { - "memory_allocated (GB)": np.round(mem_stats["InUse"] / 1024**3, 2), - "max_memory_allocated (GB)": np.round(mem_stats["MaxInUse"] / 1024**3, 2), - "total_memory_available (GB)": np.round(mem_stats["Limit"] / 1024**3, 2), -} -for k, v in mem_dict.items(): - print("{:35} = {} GB".format(k[:-5].replace("_", " ").capitalize(), v)) \ No newline at end of file +# dump final message of HPU +show_msg() diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/utils.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/utils.py new file mode 100644 index 00000000000..7eac0e0bdf7 --- /dev/null +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/utils.py @@ -0,0 +1,35 @@ +def show_msg(): + import numpy as np + import glob + from habana_frameworks.torch.hpu import memory_stats + print("Number of HPU 
graphs:", len(glob.glob(".graph_dumps/*PreGraph*"))) + mem_stats = memory_stats() + mem_dict = { + "memory_allocated (GB)": np.round(mem_stats["InUse"] / 1024**3, 2), + "max_memory_allocated (GB)": np.round(mem_stats["MaxInUse"] / 1024**3, 2), + "total_memory_available (GB)": np.round(mem_stats["Limit"] / 1024**3, 2), + } + for k, v in mem_dict.items(): + print("{:35} = {} GB".format(k[:-5].replace("_", " ").capitalize(), v)) + + +def itrex_bootstrap_stderr(f, xs, iters): + from lm_eval.metrics import _bootstrap_internal, sample_stddev + res = [] + chunk_size = min(1000, iters) + it = _bootstrap_internal(f, chunk_size) + for i in range(iters // chunk_size): + bootstrap = it((i, xs)) + res.extend(bootstrap) + return sample_stddev(res) + + +def save_to_excel(dict): + import pandas as pd + df_new = pd.DataFrame(dict) + try: + df_existing = pd.read_excel('output.xlsx') + except FileNotFoundError: + df_existing = pd.DataFrame() + df_combined = pd.concat([df_existing, df_new], axis=0, ignore_index=True) + df_combined.to_excel('output.xlsx', index=False, engine='openpyxl', header=True) diff --git a/neural_compressor/common/utils/__init__.py b/neural_compressor/common/utils/__init__.py index 3a2ae7280d3..88aeae223ef 100644 --- a/neural_compressor/common/utils/__init__.py +++ b/neural_compressor/common/utils/__init__.py @@ -14,6 +14,7 @@ from neural_compressor.common.utils.constants import * from neural_compressor.common.utils.logger import * +from neural_compressor.common.utils.save_load import save_config_mapping, load_config_mapping # ! Put the following `utility` import after the `logger` import as `utility` used `logger` from neural_compressor.common.utils.utility import * diff --git a/neural_compressor/common/utils/save_load.py b/neural_compressor/common/utils/save_load.py new file mode 100644 index 00000000000..5ecd8ce3b97 --- /dev/null +++ b/neural_compressor/common/utils/save_load.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + + +def save_config_mapping(config_mapping, qconfig_file_path): + """Save config mapping to json file. + + Args: + config_mapping (dict): config mapping. + qconfig_file_path (str): path to saved json file. + """ + + per_op_qconfig = {} + for (op_name, op_type), op_config in config_mapping.items(): + value = {op_config.name: op_config.to_dict()} + per_op_qconfig[str((op_name, op_type))] = value + + with open(qconfig_file_path, "w") as f: + json.dump(per_op_qconfig, f, indent=4) + + +def load_config_mapping(qconfig_file_path, config_name_mapping): + """Reload config mapping from json file. + + Args: + qconfig_file_path (str): path to saved json file. + config_name_mapping (dict): map config name to config object. + For example: ConfigRegistry.get_all_configs()["torch"] + + Returns: + config_mapping (dict): config mapping. 
+ """ + config_mapping = {} + with open(qconfig_file_path, "r") as f: + per_op_qconfig = json.load(f) + for key, value in per_op_qconfig.items(): + op_name, op_type = eval(key) + # value here is a dict, so we convert it to an object with config_name_mapping, + # which is defined in a specific framework. + config_name = next(iter(value)) + config_obj = config_name_mapping[config_name]["cls"]() + config_obj.from_dict(value[config_name]) + config_mapping[(op_name, op_type)] = config_obj + return config_mapping diff --git a/neural_compressor/torch/algorithms/habana_fp8/__init__.py b/neural_compressor/torch/algorithms/habana_fp8/__init__.py index 7cba9af24cd..fe3a05d7d0b 100644 --- a/neural_compressor/torch/algorithms/habana_fp8/__init__.py +++ b/neural_compressor/torch/algorithms/habana_fp8/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. from .fp8_quant import quantize_dynamic, quantize, white_list +from .save_load import save, load diff --git a/neural_compressor/torch/algorithms/habana_fp8/fp8_quant.py b/neural_compressor/torch/algorithms/habana_fp8/fp8_quant.py index b1abe329d3b..da11ae0323f 100644 --- a/neural_compressor/torch/algorithms/habana_fp8/fp8_quant.py +++ b/neural_compressor/torch/algorithms/habana_fp8/fp8_quant.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + # pylint:disable=import-error import copy @@ -20,12 +21,15 @@ import torch from deepspeed.module_inject import LinearAllreduce, LinearLayer from deepspeed.module_inject.layers import LmHeadLinearAllreduce +from habana_frameworks.torch.core.quantization import _check_params_as_const, _mark_params_as_const -from neural_compressor.common.utils import FP8_QUANT -from neural_compressor.torch.quantization.modules import Autocast, BatchMatmul, Matmul -from neural_compressor.torch.utils.utility import fetch_module, logger, register_algo, set_module +from neural_compressor.torch.utils import fetch_module, logger, set_module -from .modules import ( +from .modules import ( # fp32; dynamic modules; static modules; dtype amax + E4M3_AMAX, + E5M2_AMAX, + Autocast, + BatchMatmul, FP8BatchMatmul, FP8Cast, FP8DynamicBatchMatmul, @@ -36,6 +40,8 @@ FP8LinearLayer, FP8LmHeadLinearAllreduce, FP8Matmul, + Matmul, + _map_guadi2_scale, ) quantization_mapping = { @@ -51,20 +57,20 @@ white_list = tuple(quantization_mapping.keys()) -# without scale factor 0.9, the output will be abnormal. -E4M3_AMAX = torch.tensor(240 * 0.9, dtype=torch.float).to("hpu") -E5M2_AMAX = torch.tensor(57344 * 0.9, dtype=torch.float).to("hpu") -FP8_DTYPE = [torch.float8_e5m2, torch.float8_e4m3fn] +FP8_DTYPE = [torch.float8_e5m2, torch.float8_e4m3fn, "fp8_e5m2", "fp8_e4m3"] +dtype_mapping = {"fp8_e5m2": torch.float8_e5m2, "fp8_e4m3": torch.float8_e4m3fn} +# enable inference optimizations +htcore.hpu_initialize() def _replace_module(module, qconfig): if qconfig.approach == "static": if isinstance(module, white_list): QModule = quantization_mapping[type(module)] - assert qconfig.weight_dtype == qconfig.act_dtype, "weight and activation should be the same dtype." - module = QModule(module, qconfig.act_dtype) + assert qconfig.w_dtype == qconfig.act_dtype, "weight and activation should be the same dtype." 
+ module = QModule(module, dtype_mapping[qconfig.act_dtype]) elif qconfig.approach == "dynamic": - dtype = qconfig.act_dtype + dtype = dtype_mapping[qconfig.act_dtype] if isinstance(module, torch.nn.Linear): # need module for initialization module = FP8DynamicLinear(module, dtype) @@ -74,12 +80,14 @@ def _replace_module(module, qconfig): module = FP8DynamicBatchMatmul(dtype) elif isinstance(module, Autocast): module = FP8Cast(dtype=dtype) - htcore.mark_step() + htcore.mark_step() return module def quantize_dynamic(model, dtype=torch.float8_e4m3fn, inplace=True): q_model = model if inplace else copy.deepcopy(model) + if isinstance(dtype, str): + dtype = dtype_mapping[dtype] for n, m in q_model.named_modules(): if isinstance(m, torch.nn.Linear): new_m = FP8DynamicLinear(m, dtype) # need m for init @@ -94,6 +102,8 @@ def quantize_dynamic(model, dtype=torch.float8_e4m3fn, inplace=True): new_m = FP8Cast(dtype=dtype) set_module(q_model, n, new_m) htcore.mark_step() + _mark_params_as_const(q_model) + _check_params_as_const(q_model) return q_model @@ -129,7 +139,7 @@ def _remove_observer(module, qconfig): import deepspeed.comm as dist from torch.distributed import ReduceOp - HF_max = E4M3_AMAX if qconfig.act_dtype == torch.float8_e4m3fn else E5M2_AMAX + HF_max = E4M3_AMAX if qconfig.act_dtype == "fp8_e4m3" else E5M2_AMAX if hasattr(module, "input_activation_post_process"): if hasattr(module.input_activation_post_process, "_non_linear_param_search"): # kl min_val, max_val = module.input_activation_post_process._non_linear_param_search() @@ -141,7 +151,11 @@ def _remove_observer(module, qconfig): amax = amax.to("hpu") dist.all_reduce(amax, op=ReduceOp.MAX) scale = HF_max / amax - module.register_parameter("scale", torch.nn.Parameter(scale)) + scale = _map_guadi2_scale(scale) + if hasattr(module, "input_activation_post_process1"): + module.register_parameter("scale1", torch.nn.Parameter(scale)) + else: + module.register_parameter("scale", torch.nn.Parameter(scale)) delattr(module, "input_activation_post_process") if hasattr(module, "input_activation_post_process1"): if hasattr(module.input_activation_post_process1, "_non_linear_param_search"): @@ -154,7 +168,8 @@ def _remove_observer(module, qconfig): amax = amax.to("hpu") dist.all_reduce(amax, op=ReduceOp.MAX) scale = HF_max / amax - module.register_parameter("scale1", torch.nn.Parameter(scale)) + scale = _map_guadi2_scale(scale) + module.register_parameter("scale2", torch.nn.Parameter(scale)) delattr(module, "input_activation_post_process1") # remove observer hooks @@ -171,7 +186,7 @@ def prepare(model, qconfig_mapping): for (op_name, op_type), qconfig in qconfig_mapping.items(): if qconfig.approach == "dynamic": continue - if qconfig.weight_dtype not in FP8_DTYPE: + if qconfig.w_dtype not in FP8_DTYPE: continue module = fetch_module(model, op_name) if module is None: @@ -184,7 +199,7 @@ def prepare(model, qconfig_mapping): def convert(model, qconfig_mapping): for (op_name, op_type), qconfig in qconfig_mapping.items(): - if qconfig.weight_dtype not in FP8_DTYPE: + if qconfig.w_dtype not in FP8_DTYPE: continue module = fetch_module(model, op_name) if module is None: @@ -207,4 +222,6 @@ def quantize(model, qconfig_mapping, run_fn=None, run_args=None, inplace=True): else: run_fn(q_model) q_model = convert(q_model, qconfig_mapping) + _mark_params_as_const(q_model) + _check_params_as_const(q_model) return q_model diff --git a/neural_compressor/torch/algorithms/habana_fp8/modules.py b/neural_compressor/torch/algorithms/habana_fp8/modules.py index 
759c4b8ada7..6e74c46870e 100644 --- a/neural_compressor/torch/algorithms/habana_fp8/modules.py +++ b/neural_compressor/torch/algorithms/habana_fp8/modules.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + # pylint:disable=import-error import os @@ -18,20 +19,56 @@ import habana_frameworks.torch.core as htcore import habana_frameworks.torch.hpex import torch +import torch.nn as nn from torch.nn import functional as F from neural_compressor.common import logger -# without scale factor 0.9, the output will be abnormal. -E4M3_AMAX = torch.tensor(240 * 0.9, dtype=torch.float).to("hpu") -E5M2_AMAX = torch.tensor(57344 * 0.9, dtype=torch.float).to("hpu") +E4M3_AMAX = torch.tensor(240, dtype=torch.float).to("hpu") +E5M2_AMAX = torch.tensor(57344, dtype=torch.float).to("hpu") + + +##################### FP32 modules ####################### +class Matmul(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + +class BatchMatmul(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.bmm(x, y) + + +class Autocast(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x + + +##################### FP8 modules ####################### +def _map_guadi2_scale(scale): + USE_GAUDI2_SCALE = os.environ.get("USE_GAUDI2_SCALE") + if USE_GAUDI2_SCALE: + scale_list = torch.tensor([16, 1, 1 / 16, 1 / 256]) + for i in scale_list: + if scale > i or i == torch.tensor(1 / 256): + return i + else: + return scale class FP8DynamicLinear(torch.nn.Module): def __init__(self, org_module, dtype=torch.float8_e4m3fn) -> None: super().__init__() # attributes - org_module.to("hpu") self.use_amax = True self.dtype = dtype self.dtype_amax = E4M3_AMAX if self.dtype == torch.float8_e4m3fn else E5M2_AMAX @@ -39,6 +76,7 @@ def __init__(self, org_module, dtype=torch.float8_e4m3fn) -> None: self.out_features = org_module.out_features self.weight_dtype = self.dtype self.out_dtype = org_module.weight.dtype + # register weight, bias self.register_buffer( "weight", torch.empty( @@ -48,35 +86,52 @@ def __init__(self, org_module, dtype=torch.float8_e4m3fn) -> None: dtype=self.weight_dtype, ), ) + if org_module.bias is not None: + self.register_buffer( + "bias", + torch.empty( + self.out_features, + device="hpu", + dtype=self.out_dtype, + ), + ) + else: + self.bias = None + # register scale + if not org_module.weight.device.type == "meta": + weight_scale = self.dtype_amax / org_module.weight.data.abs().max() + weight_scale = _map_guadi2_scale(weight_scale) + else: + weight_scale = torch.tensor(1.0) self.register_buffer( - "bias", - torch.empty( - self.out_features, + "weight_scale", + torch.tensor( + weight_scale, device="hpu", - dtype=self.out_dtype, + dtype=torch.float32, ), ) - # user configuration - # scale = HF_max /amax - if self.use_amax: - self.weight_scale = self.dtype_amax / org_module.weight.data.abs().max() - self.weight_scale_inv = torch.reciprocal(self.weight_scale) - else: - self.weight_scale = None - self.weight_scale_inv = None - self.weight = torch.ops.hpu.cast_to_fp8_v2(org_module.weight.data, self.weight_scale, False, False, self.dtype)[ - 0 - ] - - if org_module.bias is not None: - self.bias = org_module.bias.data.type(self.out_dtype) - else: - self.bias = None + self.register_buffer( + "weight_scale_inv", + torch.tensor( + 
torch.reciprocal(weight_scale), + device="hpu", + dtype=torch.float32, + ), + ) + # copy weight and bias + if not org_module.weight.device.type == "meta": + org_module.to("hpu") + self.weight.data.copy_( + torch.ops.hpu.cast_to_fp8_v2(org_module.weight.data, self.weight_scale, False, False, self.dtype)[0] + ) + if org_module.bias is not None: + self.bias.data.copy_(org_module.bias.data.type(self.out_dtype)) def forward(self, inp): assert inp.shape[-1] == self.in_features, "GEMM not possible" org_middle_shape = inp.shape[1:-1] - inp = inp.view((-1, self.in_features)) + inp = inp.view(-1, self.in_features) if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: if self.use_amax: input_scale = self.dtype_amax / inp.abs().max() @@ -174,11 +229,10 @@ class FP8Linear(torch.nn.Module): def __init__(self, org_module, dtype) -> None: super().__init__() # attributes - org_module.to("hpu") - self.dtype = dtype - self.dtype_amax = E4M3_AMAX if self.dtype == torch.float8_e4m3fn else E5M2_AMAX self.in_features = org_module.in_features self.out_features = org_module.out_features + self.dtype = dtype + self.dtype_amax = E4M3_AMAX if self.dtype == torch.float8_e4m3fn else E5M2_AMAX self.weight_dtype = self.dtype self.out_dtype = org_module.weight.dtype self.register_buffer( @@ -190,41 +244,69 @@ def __init__(self, org_module, dtype) -> None: dtype=self.weight_dtype, ), ) + if org_module.bias is not None: + self.register_buffer( + "bias", + torch.empty( + self.out_features, + device="hpu", + dtype=self.out_dtype, + ), + ) + else: + self.bias = None + input_scale = _map_guadi2_scale(org_module.scale) if hasattr(org_module, "scale") else torch.tensor(1.0) self.register_buffer( - "bias", - torch.empty( - self.out_features, + "input_scale", + torch.tensor( + input_scale, device="hpu", - dtype=self.out_dtype, + dtype=torch.float32, ), ) - assert hasattr(org_module, "scale"), "scale is not recorded when convert to FP8Linear." 
self.register_buffer( - "scale", + "input_scale_inv", torch.tensor( - org_module.scale, + torch.reciprocal(input_scale), device="hpu", dtype=torch.float32, ), ) - self.scale_inv = torch.reciprocal(self.scale) - - self.weight_scale = self.dtype_amax / org_module.weight.data.abs().max() - self.weight_scale_inv = torch.reciprocal(self.weight_scale) - self.weight = torch.ops.hpu.cast_to_fp8_v2(org_module.weight.data, self.weight_scale, False, False, self.dtype)[ - 0 - ] - - if org_module.bias is not None: - self.bias = org_module.bias.data.type(self.out_dtype) + if not org_module.weight.device.type == "meta": + weight_scale = self.dtype_amax / org_module.weight.data.abs().max() + weight_scale = _map_guadi2_scale(weight_scale) else: - self.bias = None + weight_scale = torch.tensor(1.0) + self.register_buffer( + "weight_scale", + torch.tensor( + weight_scale, + device="hpu", + dtype=torch.float32, + ), + ) + self.register_buffer( + "weight_scale_inv", + torch.tensor( + torch.reciprocal(weight_scale), + device="hpu", + dtype=torch.float32, + ), + ) + # copy weight and bias + if not org_module.weight.device.type == "meta": + org_module.to("hpu") + self.weight.data.copy_( + torch.ops.hpu.cast_to_fp8_v2(org_module.weight.data, self.weight_scale, False, False, self.dtype)[0] + ) + if org_module.bias is not None: + self.bias.data.copy_(org_module.bias.data.type(self.out_dtype)) def forward(self, inp): assert inp.shape[-1] == self.in_features, "GEMM not possible" org_middle_shape = inp.shape[1:-1] - inp = inp.view((-1, self.in_features)) - inp = torch.ops.hpu.cast_to_fp8_v2(inp, self.scale, False, False, self.dtype)[0] + inp = inp.view(-1, self.in_features) + inp = torch.ops.hpu.cast_to_fp8_v2(inp, self.input_scale, False, False, self.dtype)[0] out = torch.ops.hpu.fp8_gemm_v2( inp, False, @@ -232,7 +314,7 @@ def forward(self, inp): True, None, self.out_dtype, - self.scale_inv, # inv is used for recover scale + self.input_scale_inv, # inv is used for recover scale self.weight_scale_inv, self.bias, False, @@ -245,7 +327,7 @@ def extra_repr(self) -> str: self.in_features, self.out_features, self.bias is not None, - self.scale, + self.input_scale, self.dtype, ) @@ -257,27 +339,24 @@ def __init__(self, org_module, dtype) -> None: self.dtype = dtype self.dtype_amax = E4M3_AMAX if self.dtype == torch.float8_e4m3fn else E5M2_AMAX self.out_dtype = torch.float32 - assert hasattr(org_module, "scale") and hasattr( - org_module, "scale1" - ), "scale is not recorded when convert to FP8Linear." 
+ scale1 = org_module.scale1 if hasattr(org_module, "scale1") else 1.0 + scale2 = org_module.scale2 if hasattr(org_module, "scale2") else 1.0 self.register_buffer( - "scale", + "scale1", torch.tensor( - org_module.scale, + scale1, device="hpu", dtype=self.out_dtype, ), ) self.register_buffer( - "scale1", + "scale2", torch.tensor( - org_module.scale1, + scale2, device="hpu", dtype=self.out_dtype, ), ) - self.input1_scale_inv = torch.reciprocal(self.scale) - self.input2_scale_inv = torch.reciprocal(self.scale1) def forward(self, input1, input2): dim1 = input1.shape[-1] @@ -286,12 +365,14 @@ def forward(self, input1, input2): if input1.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: self.out_dtype = input1.dtype - input1 = torch.ops.hpu.cast_to_fp8_v2(input1, self.scale, False, False, self.dtype)[0] + input1 = torch.ops.hpu.cast_to_fp8_v2(input1, self.scale1, False, False, self.dtype)[0] + self.input1_scale_inv = torch.reciprocal(self.scale1) else: self.input1_scale_inv = None if input2.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: self.out_dtype = input2.dtype - input2 = torch.ops.hpu.cast_to_fp8_v2(input2, self.scale1, False, False, self.dtype)[0] + input2 = torch.ops.hpu.cast_to_fp8_v2(input2, self.scale2, False, False, self.dtype)[0] + self.input2_scale_inv = torch.reciprocal(self.scale2) else: self.input2_scale_inv = None out = torch.ops.hpu.fp8_gemm_v2( @@ -310,7 +391,7 @@ def forward(self, input1, input2): def extra_repr(self) -> str: return "scales={}, format={}".format( - (self.scale, self.scale1), + (self.scale1, self.scale2), self.dtype, ) @@ -326,11 +407,11 @@ def __init__(self, org_module=None, dtype=torch.float8_e4m3fn) -> None: self.dtype_amax = E4M3_AMAX if self.dtype == torch.float8_e4m3fn else E5M2_AMAX if org_module is not None: org_module.to("hpu") - assert hasattr(org_module, "scale"), "scale is not recorded when convert to FP8Cast." + scale = org_module.scale if hasattr(org_module, "scale") else 1.0 self.register_buffer( "scale", torch.tensor( - org_module.scale, + scale, device="hpu", dtype=torch.float32, ), @@ -357,13 +438,13 @@ class FP8LinearLayer(torch.nn.Module): def __init__(self, org_module, dtype) -> None: super().__init__() # attributes - org_module.to("hpu") self.dtype = dtype self.dtype_amax = E4M3_AMAX if self.dtype == torch.float8_e4m3fn else E5M2_AMAX self.in_features = org_module.weight.shape[1] self.out_features = org_module.weight.shape[0] self.weight_dtype = self.dtype self.out_dtype = org_module.weight.dtype + # register weight, bias self.register_buffer( "weight", torch.empty( @@ -373,23 +454,6 @@ def __init__(self, org_module, dtype) -> None: dtype=self.weight_dtype, ), ) - assert hasattr(org_module, "scale"), "scale is not recorded when convert to FP8Linear." 
- self.register_buffer( - "scale", - torch.tensor( - org_module.scale, - device="hpu", - dtype=torch.float32, - ), - ) - self.scale_inv = 1.0 / self.scale - # user configuration - # scale = HF_max /amax - self.weight_scale = self.dtype_amax / org_module.weight.data.abs().max() - self.weight_scale_inv = 1.0 / self.weight_scale - self.weight = torch.ops.hpu.cast_to_fp8_v2(org_module.weight.data, self.weight_scale, False, False, self.dtype)[ - 0 - ] if org_module.bias is not None: self.register_buffer( "bias", @@ -399,14 +463,60 @@ def __init__(self, org_module, dtype) -> None: dtype=self.out_dtype, ), ) - self.bias = org_module.bias.data.type(self.out_dtype) else: self.bias = None + # register scale + if not org_module.weight.device.type == "meta": + weight_scale = self.dtype_amax / org_module.weight.data.abs().max() + weight_scale = _map_guadi2_scale(weight_scale) + else: + weight_scale = torch.tensor(1.0) + self.register_buffer( + "weight_scale", + torch.tensor( + weight_scale, + device="hpu", + dtype=torch.float32, + ), + ) + self.register_buffer( + "weight_scale_inv", + torch.tensor( + torch.reciprocal(weight_scale), + device="hpu", + dtype=torch.float32, + ), + ) + # copy weight and bias + if not org_module.weight.device.type == "meta": + org_module.to("hpu") + self.weight.data.copy_( + torch.ops.hpu.cast_to_fp8_v2(org_module.weight.data, self.weight_scale, False, False, self.dtype)[0] + ) + if org_module.bias is not None: + self.bias.data.copy_(org_module.bias.data.type(self.out_dtype)) + input_scale = _map_guadi2_scale(org_module.scale) if hasattr(org_module, "scale") else torch.tensor(1.0) + self.register_buffer( + "input_scale", + torch.tensor( + input_scale, + device="hpu", + dtype=torch.float32, + ), + ) + self.register_buffer( + "input_scale_inv", + torch.tensor( + torch.reciprocal(input_scale), + device="hpu", + dtype=torch.float32, + ), + ) def forward(self, inp): assert inp.shape[-1] == self.in_features, "GEMM not possible" - inputmat = inp.view((-1, self.in_features)) - inputmat = torch.ops.hpu.cast_to_fp8_v2(inputmat, self.scale, False, False, self.dtype)[0] + inputmat = inp.view(-1, self.in_features) + inputmat = torch.ops.hpu.cast_to_fp8_v2(inputmat, self.input_scale, False, False, self.dtype)[0] out = torch.ops.hpu.fp8_gemm_v2( inputmat, False, @@ -414,13 +524,11 @@ def forward(self, inp): True, None, self.out_dtype, - self.scale_inv, # inv is used for recover scale + self.input_scale_inv, # inv is used for recover scale self.weight_scale_inv, - None, + self.bias, False, ) - if self.bias is not None: - out += self.bias return out.view(-1, *inp.shape[1:-1], out.shape[-1]) def extra_repr(self) -> str: @@ -428,7 +536,7 @@ def extra_repr(self) -> str: self.in_features, self.out_features, self.bias is not None, - self.scale, + self.input_scale, self.dtype, ) @@ -437,13 +545,13 @@ class FP8LinearAllreduce(torch.nn.Module): def __init__(self, org_module, dtype) -> None: super().__init__() # attributes - org_module.to("hpu") self.dtype = dtype self.dtype_amax = E4M3_AMAX if self.dtype == torch.float8_e4m3fn else E5M2_AMAX self.in_features = org_module.weight.shape[1] self.out_features = org_module.weight.shape[0] self.weight_dtype = self.dtype self.out_dtype = org_module.weight.dtype + # register weight, bias self.register_buffer( "weight", torch.empty( @@ -453,23 +561,6 @@ def __init__(self, org_module, dtype) -> None: dtype=self.weight_dtype, ), ) - assert hasattr(org_module, "scale"), "scale is not recorded when convert to FP8Linear." 
- self.register_buffer( - "scale", - torch.tensor( - org_module.scale, - device="hpu", - dtype=torch.float32, - ), - ) - self.scale_inv = 1.0 / self.scale - # user configuration - # scale = HF_max /amax - self.weight_scale = self.dtype_amax / org_module.weight.data.abs().max() - self.weight_scale_inv = 1.0 / self.weight_scale - self.weight = torch.ops.hpu.cast_to_fp8_v2(org_module.weight.data, self.weight_scale, False, False, self.dtype)[ - 0 - ] if org_module.bias is not None: self.register_buffer( "bias", @@ -479,15 +570,61 @@ def __init__(self, org_module, dtype) -> None: dtype=self.out_dtype, ), ) - self.bias = org_module.bias.data.type(self.out_dtype) else: self.bias = None - self.mp_group = org_module.mp_group + # register scale + if not org_module.weight.device.type == "meta": + weight_scale = self.dtype_amax / org_module.weight.data.abs().max() + weight_scale = _map_guadi2_scale(weight_scale) + else: + weight_scale = torch.tensor(1.0) + self.register_buffer( + "weight_scale", + torch.tensor( + weight_scale, + device="hpu", + dtype=torch.float32, + ), + ) + self.register_buffer( + "weight_scale_inv", + torch.tensor( + torch.reciprocal(weight_scale), + device="hpu", + dtype=torch.float32, + ), + ) + # copy weight and bias + if not org_module.weight.device.type == "meta": + org_module.to("hpu") + self.weight.data.copy_( + torch.ops.hpu.cast_to_fp8_v2(org_module.weight.data, self.weight_scale, False, False, self.dtype)[0] + ) + if org_module.bias is not None: + self.bias.data.copy_(org_module.bias.data.type(self.out_dtype)) + self.mp_group = org_module.mp_group + input_scale = _map_guadi2_scale(org_module.scale) if hasattr(org_module, "scale") else torch.tensor(1.0) + self.register_buffer( + "input_scale", + torch.tensor( + input_scale, + device="hpu", + dtype=torch.float32, + ), + ) + self.register_buffer( + "input_scale_inv", + torch.tensor( + torch.reciprocal(input_scale), + device="hpu", + dtype=torch.float32, + ), + ) def forward(self, inp): assert inp.shape[-1] == self.in_features, "GEMM not possible" - inputmat = inp.view((-1, self.in_features)) - inputmat = torch.ops.hpu.cast_to_fp8_v2(inputmat, self.scale, False, False, self.dtype)[0] + inputmat = inp.view(-1, self.in_features) + inputmat = torch.ops.hpu.cast_to_fp8_v2(inputmat, self.input_scale, False, False, self.dtype)[0] out = torch.ops.hpu.fp8_gemm_v2( inputmat, False, @@ -495,7 +632,7 @@ def forward(self, inp): True, None, self.out_dtype, - self.scale_inv, # inv is used for recover scale + self.input_scale_inv, # inv is used for recover scale self.weight_scale_inv, None, False, @@ -513,7 +650,7 @@ def extra_repr(self) -> str: self.in_features, self.out_features, self.bias is not None, - self.scale, + self.input_scale, self.dtype, ) @@ -522,13 +659,13 @@ class FP8LmHeadLinearAllreduce(torch.nn.Module): def __init__(self, org_module, dtype) -> None: super().__init__() # attributes - org_module.to("hpu") self.dtype = dtype self.dtype_amax = E4M3_AMAX if self.dtype == torch.float8_e4m3fn else E5M2_AMAX self.in_features = org_module.weight.shape[1] self.out_features = org_module.weight.shape[0] self.weight_dtype = self.dtype self.out_dtype = org_module.weight.dtype + # register weight, bias self.register_buffer( "weight", torch.empty( @@ -538,23 +675,6 @@ def __init__(self, org_module, dtype) -> None: dtype=self.weight_dtype, ), ) - assert hasattr(org_module, "scale"), "scale is not recorded when convert to FP8Linear." 
- self.register_buffer( - "scale", - torch.tensor( - org_module.scale, - device="hpu", - dtype=torch.float32, - ), - ) - self.scale_inv = 1.0 / self.scale - # user configuration - # scale = HF_max /amax - self.weight_scale = self.dtype_amax / org_module.weight.data.abs().max() - self.weight_scale_inv = 1.0 / self.weight_scale - self.weight = torch.ops.hpu.cast_to_fp8_v2(org_module.weight.data, self.weight_scale, False, False, self.dtype)[ - 0 - ] if org_module.bias is not None: self.register_buffer( "bias", @@ -564,12 +684,58 @@ def __init__(self, org_module, dtype) -> None: dtype=self.out_dtype, ), ) - self.bias = org_module.bias.data.type(self.out_dtype) else: self.bias = None - self.mp_group = org_module.mp_group - self.rank = org_module.rank - self.world_size = org_module.world_size + # register scale + if not org_module.weight.device.type == "meta": + weight_scale = self.dtype_amax / org_module.weight.data.abs().max() + weight_scale = _map_guadi2_scale(weight_scale) + else: + weight_scale = torch.tensor(1.0) + self.register_buffer( + "weight_scale", + torch.tensor( + weight_scale, + device="hpu", + dtype=torch.float32, + ), + ) + self.register_buffer( + "weight_scale_inv", + torch.tensor( + torch.reciprocal(weight_scale), + device="hpu", + dtype=torch.float32, + ), + ) + # copy weight and bias + if not org_module.weight.device.type == "meta": + org_module.to("hpu") + self.weight.data.copy_( + torch.ops.hpu.cast_to_fp8_v2(org_module.weight.data, self.weight_scale, False, False, self.dtype)[0] + ) + if org_module.bias is not None: + self.bias.data.copy_(org_module.bias.data.type(self.out_dtype)) + self.mp_group = org_module.mp_group + self.rank = org_module.rank + self.world_size = org_module.world_size + input_scale = _map_guadi2_scale(org_module.scale) if hasattr(org_module, "scale") else torch.tensor(1.0) + self.register_buffer( + "input_scale", + torch.tensor( + input_scale, + device="hpu", + dtype=torch.float32, + ), + ) + self.register_buffer( + "input_scale_inv", + torch.tensor( + torch.reciprocal(input_scale), + device="hpu", + dtype=torch.float32, + ), + ) def forward(self, inp): # from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list @@ -581,8 +747,9 @@ def forward(self, inp): inp.shape[-1] % self.world_size == 0 ), "Please ensure that self.world_size is divisible by input.shape[-1]" input_shard = inp.shape[-1] // self.world_size - inputmat = inp[:, :, self.rank * input_shard : (self.rank + 1) * input_shard] - inputmat = torch.ops.hpu.cast_to_fp8_v2(inputmat, self.scale, False, False, self.dtype)[0] + inp_part = inp[:, :, self.rank * input_shard : (self.rank + 1) * input_shard] + inputmat = inp_part.view(-1, input_shard) # dim=2 will help kernel speed + inputmat = torch.ops.hpu.cast_to_fp8_v2(inputmat, self.input_scale, False, False, self.dtype)[0] out = torch.ops.hpu.fp8_gemm_v2( inputmat, False, @@ -590,7 +757,7 @@ def forward(self, inp): True, None, self.out_dtype, - self.scale_inv, # inv is used for recover scale + self.input_scale_inv, # inv is used for recover scale self.weight_scale_inv, None, False, @@ -608,6 +775,6 @@ def extra_repr(self) -> str: self.in_features, self.out_features, self.bias is not None, - self.scale, + self.input_scale, self.dtype, ) diff --git a/neural_compressor/torch/algorithms/habana_fp8/observer.py b/neural_compressor/torch/algorithms/habana_fp8/observer.py index 27d585a7aa0..f329ebe04b3 100644 --- a/neural_compressor/torch/algorithms/habana_fp8/observer.py +++ 
b/neural_compressor/torch/algorithms/habana_fp8/observer.py @@ -17,9 +17,7 @@ import torch from torch.ao.quantization.observer import * -# without scale factor 0.9, the output will be abnormal. -E4M3_AMAX = torch.tensor(240 * 0.9, dtype=torch.float).to("hpu") -E5M2_AMAX = torch.tensor(57344 * 0.9, dtype=torch.float).to("hpu") +from .modules import E4M3_AMAX, E5M2_AMAX class FP8HistogramObserver(HistogramObserver): diff --git a/neural_compressor/torch/algorithms/habana_fp8/save_load.py b/neural_compressor/torch/algorithms/habana_fp8/save_load.py new file mode 100644 index 00000000000..f563cce234a --- /dev/null +++ b/neural_compressor/torch/algorithms/habana_fp8/save_load.py @@ -0,0 +1,98 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint:disable=import-error + +import json +import os + +import habana_frameworks.torch.core as htcore +import torch + +from neural_compressor.common.utils import load_config_mapping, save_config_mapping +from neural_compressor.torch.utils import QCONFIG_NAME, WEIGHT_NAME, logger + +from .fp8_quant import FP8_DTYPE, dtype_mapping +from .modules import ( # fp32; dynamic modules + Autocast, + BatchMatmul, + FP8Cast, + FP8DynamicBatchMatmul, + FP8DynamicLinear, + FP8DynamicMatmul, + Matmul, +) + + +def save(model, output_dir="./saved_results"): + if not os.path.exists(output_dir): + os.mkdir(output_dir) + qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME) + qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME) + # saving process + save_config_mapping(model.qconfig, qconfig_file_path) + + import fp8_convert + + stat_dict = {} + for k, v in model.state_dict().items(): + if v.dtype in FP8_DTYPE: + v = fp8_convert.to_u8(v.to("cpu")) + stat_dict[k] = v.to("cpu") + torch.save(stat_dict, qmodel_file_path) + + logger.info("Save state_dict of quantized model to {}.".format(qmodel_file_path)) + logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path)) + + +def load(model, output_dir="./saved_results"): + from neural_compressor.torch.utils import fetch_module, set_module + + from .fp8_quant import quantization_mapping, white_list + + qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME) + stat_dict = torch.load(qmodel_file_path) + import fp8_convert + + for (op_name, op_type), op_qconfig in model.qconfig.items(): + dtype = op_qconfig.w_dtype + choice = 1 if dtype == "fp8_e4m3" else 0 + if op_name + ".weight" in stat_dict: + stat_dict[op_name + ".weight"] = fp8_convert.from_u8(stat_dict[op_name + ".weight"], choice) + if dtype not in FP8_DTYPE: + continue + module = fetch_module(model, op_name) + dtype = dtype_mapping[dtype] + # replace module + if op_qconfig.approach == "static": + if isinstance(module, white_list): + QModule = quantization_mapping[type(module)] + module = QModule(module, dtype) + else: + if isinstance(module, torch.nn.Linear): + # need module for initialization + 
module = FP8DynamicLinear(module, dtype) + elif isinstance(module, Matmul): + module = FP8DynamicMatmul(dtype) + elif isinstance(module, BatchMatmul): + module = FP8DynamicBatchMatmul(dtype) + elif isinstance(module, Autocast): + module = FP8Cast(dtype=dtype) + set_module(model, op_name, module) + htcore.mark_step() + model.load_state_dict(stat_dict, assign=True) + model.to("hpu") + htcore.mark_step() + logger.info("Quantized model loading successful.") + return model diff --git a/neural_compressor/torch/algorithms/habana_fp8/tensor/__init__.py b/neural_compressor/torch/algorithms/habana_fp8/tensor/__init__.py new file mode 100644 index 00000000000..28f108cb636 --- /dev/null +++ b/neural_compressor/torch/algorithms/habana_fp8/tensor/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/neural_compressor/torch/algorithms/habana_fp8/tensor/convert.cpp b/neural_compressor/torch/algorithms/habana_fp8/tensor/convert.cpp new file mode 100644 index 00000000000..f22c5c82c89 --- /dev/null +++ b/neural_compressor/torch/algorithms/habana_fp8/tensor/convert.cpp @@ -0,0 +1,63 @@ +// Copyright (c) 2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Temporary implementation of fp8 tensor saving and loading +// Will remove after Habana torch applies below patch: +// https://github.com/pytorch/pytorch/pull/114662 + + +#include + + +// function prototype declaration +torch::Tensor to_u8(torch::Tensor tensor); +torch::Tensor from_u8(torch::Tensor tensor, int choice=1); + + +torch::Tensor to_u8(torch::Tensor tensor) { + auto p = tensor.data_ptr(); + // RuntimeError: HPU device type not enabled. + auto options = torch::TensorOptions().device(torch::kCPU).dtype(torch::kUInt8); + auto tmp = torch::from_blob(p, tensor.sizes(), options); + // copy to avoid memory leak. + torch::Tensor tensor_uint8 = torch::empty_like(tensor, torch::kUInt8).copy_(tmp); + return tensor_uint8; +}; + + +/* +choice=1 means torch.float8_e4m3fn; +others means torch.float8_e5m2; +*/ +torch::Tensor from_u8(torch::Tensor tensor, int choice) { + auto p = tensor.data_ptr(); + torch::ScalarType dtype; + if (choice == 1) { + dtype = torch::kFloat8_e4m3fn; + } + else { + dtype = torch::kFloat8_e5m2; + } + auto options = torch::TensorOptions().device(torch::kCPU).dtype(dtype); + auto tmp = torch::from_blob(p, tensor.sizes(), options); + // copy to avoid memory leak. 
+ torch::Tensor tensor_fp8 = torch::empty_like(tensor, dtype).copy_(tmp); + return tensor_fp8; +}; + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("to_u8", &to_u8, "Convert tensor to u8 for saving."); + m.def("from_u8", &from_u8, "Recover tensor from u8 for loading."); +}; diff --git a/neural_compressor/torch/algorithms/smooth_quant/smoothquant.py b/neural_compressor/torch/algorithms/smooth_quant/smoothquant.py new file mode 100644 index 00000000000..9de5bbb40f9 --- /dev/null +++ b/neural_compressor/torch/algorithms/smooth_quant/smoothquant.py @@ -0,0 +1,1596 @@ +# +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import copy +import json +import logging + +import torch + +logger = logging.getLogger() +from collections import UserDict, defaultdict + +import numpy +from tqdm import tqdm + + +def enough_memo_store_scale(device, need_space): + if device == "cuda": # pragma: no cover + current_gpu_index = torch.cuda.current_device() + total_memory = torch.cuda.get_device_properties(current_gpu_index).total_memory + used_memory = torch.cuda.memory_allocated(current_gpu_index) + free_space = total_memory - used_memory + else: + import psutil + + free_space = psutil.virtual_memory().free + return free_space >= need_space + + +def move_input_to_device(input, device=torch.device("cpu")): + if isinstance(input, dict) or isinstance(input, UserDict): + tmp_input = {} + for k, inp in input.items(): + tmp_input[k] = move_input_to_device(inp, device) + input = tmp_input + elif isinstance(input, list) or isinstance(input, tuple): + is_tuple = isinstance(input, tuple) + tmp_input = [] + for inp in input: + tmp_input.append(move_input_to_device(inp, device)) + input = tuple(tmp_input) if is_tuple else tmp_input + elif isinstance(input, torch.Tensor): + input = input.to(device) # pylint: disable=no-member + return input + + +##TODO potential bug, data typeR +def forward_wrapper(model, input, device=torch.device("cpu")): + try: + model = model.to(device) + input = move_input_to_device(input, device) + except Exception as e: + logger.warning(e) + logger.warning("Please check the input device if the error raised.") + if isinstance(input, dict) or isinstance(input, UserDict): + output = model(**input) + elif isinstance(input, list) or isinstance(input, tuple): + try: + output = model(*input) + except: + output = model(input) + else: + output = model(input) + return output + + +def model_forward(model, dataloader, iters, device): + try: + cnt = 0 + for idx, (input, label) in enumerate(dataloader): + output = forward_wrapper(model, input, device) + cnt += 1 + if iters != -1 and cnt >= iters: + break + except Exception as e: + cnt = 0 + for idx, input in enumerate(dataloader): + output = forward_wrapper(model, input, device) + cnt += 1 + if iters != -1 and cnt >= iters: + break + + +def model_forward_per_sample(model, sample, device): + try: + output = forward_wrapper(model, sample, device) + return output + + except Exception as e: + 
output = forward_wrapper(model, sample[0], device) + return output + + +def quant_dequant_w(m, num_bits=8, scheme="sym"): + eps = torch.finfo(torch.float32).eps + if isinstance(m, torch.nn.Linear): + x = m.weight + tmp = torch.zeros(torch.max(x, dim=1).values.size()) + if scheme == "sym": + q_min, q_max = -(2.0 ** (num_bits - 1)), 2.0 ** (num_bits - 1) - 1.0 + x_max = torch.max(torch.abs(x), dim=1).values + scale = x_max / (float(q_max - q_min) / 2) + else: + q_min, q_max = 0, 2.0**num_bits - 1.0 + x_max = torch.maximum(torch.max(x, dim=1).values, tmp) + x_min = torch.minimum(torch.min(x, dim=1).values, tmp) + scale = (x_max - x_min) / (2**num_bits - 1) + + scale = torch.clip(scale, min=eps) + + if scheme == "sym": + bias = 0 + else: + bias = torch.round(0 - (torch.min(x, dim=1).values) / scale) + bias = bias.unsqueeze(dim=-1) + scale = scale.unsqueeze(dim=-1) + q_x = torch.round(x / scale + bias) + q_x.clamp_(q_min, q_max) + return (q_x - bias) * scale + elif isinstance(m, torch.nn.Conv2d): + x = m.weight + x = torch.permute(x, (0, 2, 3, 1)) + x = x.reshape(-1, x.shape[-1]) + tmp = torch.zeros(torch.max(x, dim=0).values.size()) + if scheme == "sym": + q_min, q_max = -(2.0 ** (num_bits - 1)), 2.0 ** (num_bits - 1) - 1.0 + x_max = torch.max(torch.abs(x), dim=0).values + scale = x_max / (2 ** (num_bits - 1) - 1) + else: + q_min, q_max = 0, 2.0**num_bits - 1.0 + x_max = torch.maximum(torch.max(x, dim=0).values, tmp) + x_min = torch.minimum(torch.min(x, dim=0).values, tmp) + scale = (x_max - x_min) / (2**num_bits - 1) + scale = torch.clip(scale, min=eps) + if scheme == "sym": + bias = 0 + else: + bias = torch.round(0 - (torch.min(x, dim=0).values) / scale) + bias = bias.unsqueeze(dim=0) + scale = scale.unsqueeze(dim=0) + + q_x = x / scale + bias + q_x.clamp_(q_min, q_max).round_() + q_dq_x = (q_x - bias) * scale + q_dq_x = q_dq_x.view(m.weight.shape[0], m.weight.shape[2], m.weight.shape[3], m.weight.shape[1]) + q_dq_x = torch.permute(q_dq_x, (0, 3, 1, 2)) + return q_dq_x + else: + logger.warning("unsupported layer type, please have a check") + + +def quant_dequant_x(x, min_x=None, max_x=None, num_bits=8): + eps = torch.finfo(torch.float32).eps + q_min, q_max = 0, 2.0**num_bits - 1.0 + if max_x is None or min_x is None: + max_x, min_x = torch.max(x), torch.min(x) + else: + max_x = torch.max(max_x) + min_x = torch.min(min_x) + scale = (max_x - min_x) / (2**num_bits - 1) + scale = torch.clip(scale, min=eps) + bias = torch.round((0 - min_x) / scale) + q_x = torch.round(x / scale + bias) + q_x.clamp_(q_min, q_max) + return scale * (q_x - bias) + + +def get_module(model, key): + """Get module from model by key name. + + Args: + model (torch.nn.Module): original model + key (str): module name to be replaced + """ + module = model + name_list = key.split(".") + for name in name_list: + if hasattr(module, name): + module = getattr(module, name) + elif hasattr(module, "sq_linear"): # for peft models + module = getattr(module, "sq_linear") + module = getattr(module, name) + elif hasattr(module, "orig_layer"): # for peft models and auto alpha + module = getattr(module, "orig_layer") + module = getattr(module, name) + else: + module = module + return module + + +def set_module(model, key, new_module): + """Set new module into model by key name. 
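For example, set_module(model, "encoder.layer.0.fc", new_linear) walks the dotted
path and swaps in new_linear for the final "fc" submodule only (the module names
here are illustrative).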
+ + Args: + model (torch.nn.Module): original model + key (str): module name to be replaced + new_module (torch.nn.Module): new module to be inserted + """ + module = model + name_list = key.split(".") + for name in name_list[:-1]: + if hasattr(module, name): + module = getattr(module, name) + elif hasattr(module, ("sq_linear")): # for peft models that Linears are contained in Linear + module = getattr(module, "sq_linear") + module = getattr(module, name) + elif hasattr(module, ("orig_layer")): # for peft models and auto alpha + module = getattr(module, "orig_layer") + module = getattr(module, name) + else: + module = module + + if hasattr(module, "sq_linear") and name_list[-1] != "sq_linear": # for peft models + module = getattr(module, "sq_linear") + if hasattr(module, "orig_layer") and name_list[-1] != "orig_layer": # for peft models and auto alpha + module = getattr(module, "orig_layer") + setattr(module, name_list[-1], new_module) + + +def cal_scale(input_max, weights, alpha, scale_type="orig"): + if scale_type == "orig": # same as the paper + weights = torch.cat(weights, dim=0) + weight_max = torch.max(torch.abs(weights), dim=0)[0] + input_power = torch.pow(input_max, alpha) + logger.debug(f"{max(input_max)}, {min(input_max)}") + weight_power = torch.pow(weight_max, 1 - alpha) + scale = torch.clip(input_power / weight_power, min=1e-5) + scale[input_power == 0] = 1.0 + if input_power.size() == weight_power.size(): + scale[weight_power == 0] = 0.0 ##FIXME + return scale + + +class WrapperLayer(torch.nn.Module): + def __init__(self, layer, input_min, input_max, save_q_input=False): + super(WrapperLayer, self).__init__() + self.add_module("orig_layer", layer) # set orig_layer in get/set_module + self.quant = False + self.q_input = None + self.fp32_output = None + self.input_max = input_max + self.input_min = input_min + self.weight_scale = None + self.input_scale = None + self.save_q_input = save_q_input + self.do_blockwise = False + + def enable_quant(self): + self.quant = True + + def disable_quant(self): + self.quant = False + + def update_scale(self, input_scale, weight_scale): + self.input_scale = input_scale + self.weight_scale = weight_scale + + ##TODO better tradeoff performance and memory, currently it's too slow + def q_dq_forward(self, x, input_scale, weight_scale): + layer_copy = copy.deepcopy(self.orig_layer) + if weight_scale is not None: + layer_copy.weight *= weight_scale + q_dq_weight = quant_dequant_w(layer_copy) + layer_copy.weight.data.copy_(q_dq_weight) + if input_scale is None: + x = quant_dequant_x(x, self.input_min, self.input_max) + else: + x = input_scale * x + x = quant_dequant_x(x, self.input_min * input_scale, self.input_max * input_scale) ##FIXME + output = layer_copy(x) + return output + + def q_dq_forward_blockwise(self, x, input_scale): + layer_copy = copy.deepcopy(self.orig_layer) + if input_scale is None: + x = quant_dequant_x(x, self.input_min, self.input_max) + else: + x = input_scale * x + x = quant_dequant_x(x, self.input_min * input_scale, self.input_max * input_scale) ##FIXME + output = layer_copy(x) + return output + + def forward(self, x): + if self.quant: + # self.q_input = x * scale ##save the q_input + if self.save_q_input: + self.q_input = x + if not self.do_blockwise: + output = self.q_dq_forward(x, self.input_scale, self.weight_scale) + else: + output = self.q_dq_forward_blockwise(x, self.input_scale) + + else: + output = self.orig_layer(x) + self.output = output + return output + + +class TorchSmoothQuant: + """Fake input channel 
quantization, for more details please refer to + [1] SmoothQuant: Accurate and Efficient + Post-Training Quantization for Large Language Models + [2] SPIQ: Data-Free Per-Channel Static Input Quantization + Currently, we only handle the layers whose smooth scale could be absorbed, we will support other layers later. + + We only support inplace mode which means the model weights will be changed, you can call recover function + to recover the weights if needed + """ + + def __init__(self, model, dataloader=None, example_inputs=None, q_func=None, traced_model=None): + """ + :param model: Torch model :param dataloader: Calibration dataloader :param traced_model: A specific model + shares the same architecture as the model and could be traced by torch.jit. If not supplied, we use model + instead. + """ + self.model = model + if not isinstance(self.model, torch.nn.Module): + return + device, dtype = self._get_device() + self.model = self.model.to(device) + self.model.eval() + self.device = device + self.dtype = dtype + self.dataloader = dataloader + self.example_inputs = example_inputs + self.q_func = q_func + self.input_maxes = {} + self.input_mins = {} + self.input_maxes_abs = {} + self.traced_model = traced_model + if self.traced_model is None: + self.traced_model = self.model + self.weight_scale_info = {} + self.absorb_scales_info = {} + self.insert_mul = False + self.allow_absorb = True + self.record_max_info = False + self.max_value_info = {} # to record max values for alpha tune + self.self_absorb_layers = {} + self.absorb_to_layer = {} + self.adjust_alpha_space = False + self.weight_clip = True + self.default_alpha = 0.5 + + self._save_scale = False + self.weight_scale_dict = {} + + self.do_blockwise = False + self.block_inputs = {} + self.block_outputs = {} + + def _get_device(self): + """Get the model device + :return:Model device.""" + for _, p in self.model.named_parameters(): + return p.data.device, p.data.dtype + + def _save_input_pc_hook(self, name, percentile=100): + """A forward hook to save input max of a module + :param name: the module name + :return: A hook function.""" + + def save_input_hook(module, inputs, outputs): + input = inputs[0] + ##TODO check input channel is correct + if len(module.weight.shape) == 4: ##conv3d or conv1d not supported now, need better way + input = input.permute(0, 2, 3, 1) + input = input.reshape(-1, input.shape[-1]) + max_tensor = torch.max(input, dim=0)[0] + min_tensor = torch.min(input, dim=0)[0] + k_index = int(input.shape[0] * percentile / 100) + res, _ = torch.kthvalue(torch.abs(input), k_index, dim=0) + ##res = torch.max(torch.abs(input),dim=0)[0] + if name not in self.input_maxes.keys(): + self.input_mins[name], self.input_maxes[name] = min_tensor, max_tensor + self.input_maxes_abs[name] = res + else: + self.input_mins[name] = torch.min(self.input_mins[name], min_tensor) + self.input_maxes[name] = torch.max(self.input_maxes[name], max_tensor) + self.input_maxes_abs[name] = torch.max(self.input_maxes_abs[name], res) + + return save_input_hook + + def _add_min_max_observer(self, modules, percentile=100): + """ + :param modules: the modules which the observer will insert to + :return: + """ + self.hook_handles = [] + for key in modules.keys(): + hook_func = self._save_input_pc_hook(key, percentile) + hook_handle = modules[key].register_forward_hook(hook_func) + self.hook_handles.append(hook_handle) + + def _remove_observer(self): + """Remove the observer from the model + :return:""" + for hook_handle in self.hook_handles: + 
hook_handle.remove() + + def _calibrate(self, absorb_to_layer, calib_iter, percentile): + """ + :param absorb_to_layer: A dict,key is the absorb layer, val is a list of the to be smoothed layer + :param calib_iter: Data size for calibration + :return: A dict that saved the layer name and the channel-wise max value info + """ + ##hook all the module + hook_modules = {} + for n, module in self.model.named_modules(): + if isinstance(module, tuple(self.op_types)): + hook_modules[n] = module + + self._add_min_max_observer(hook_modules, percentile) + + self._dump_min_max(calib_iter=calib_iter) + self._remove_observer() + return self.input_maxes_abs + + def _dump_min_max(self, calib_iter=100): + """Dump min max per channel information, the min max value will be saved in input_maxes attribute + :param calibration_method: only support min_max currently + :param calib_iter: Sample size for calibration + :return:""" + logger.info("Calibrating...") + if self.q_func: + self.q_func(self.model) + else: + assert self.dataloader, "Please set dataloader for calibration." + model_forward(self.model, self.dataloader, calib_iter, self.device) + + def _reshape_in_channel_to_last(self, layer_name): + """Move the input channel to the last dim + :param layer_name: Layer name + :return: The reshaped weight.""" + layer = get_module(self.model, layer_name) + if layer.__class__.__name__ == "WrapperLayer": + layer = layer.orig_layer + + weight = layer.weight ##TODO oc*ic, support transposed conv + if len(weight.shape) == 4: + weight = weight.permute(0, 2, 3, 1) + weight = weight.reshape(-1, weight.shape[-1]) + return weight + + def _reshape_scale_for_weight(self, layer, scale): + """Reshape the scale for weight input channel, depthwise output channel + :param layer: torch module + :param scale: orig scale + :return: reshaped scale.""" + if hasattr(layer, "orig_layer"): + layer = layer.orig_layer + if isinstance(layer, torch.nn.Conv2d) and layer.groups > 1: ##only depthwise conv could hit here + scale = scale.view(scale.shape[0], 1, 1, 1) ##mount on output channel + + elif isinstance(layer, torch.nn.Conv2d): + scale = scale.view(1, scale.shape[0], 1, 1) + + elif isinstance(layer, torch.nn.Linear): + scale = scale.view(1, scale.shape[0]) + + return scale + + def get_blocks(self): + block_names = [] + for n, m in self.model.named_modules(): + if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__: + for nn, mm in m.named_children(): + block_name = n + "." 
+ nn + block_names.append(block_name) + return block_names + + def _reshape_scale_for_input(self, layer, scale): + """Reshape the scale for input feature in channel + :param layer: + + :param scale: + :return: + """ + if hasattr(layer, "orig_layer"): + layer = layer.orig_layer + if isinstance(layer, torch.nn.Conv2d): + scale = scale.view(1, scale.shape[0], 1, 1) + + elif isinstance(layer, torch.nn.Linear): + scale = scale.view(1, scale.shape[0]) + + return scale + + def _scale_layer_weight(self, layer_name, scale, alpha=0.5, input_minmax=None): ##input channel + """Scale the layer weights at input channel, depthwise conv output channel + :param layer_name: The layer name + :param scale: The scale to be multiplied + :param alpha: alpha for SQLinearWrapper + :param input_minmax: input_minmax for SQLinearWrapper + :return:""" + layer = get_module(self.model, layer_name) + if self.insert_mul: + from .model_wrapper import SQLinearWrapper + + layer = get_module(self.model, layer_name) + if isinstance(layer, SQLinearWrapper): + layer._recover_sq_linear() + set_module(self.model, layer_name, layer.sq_linear) ##recover + else: + new_module = SQLinearWrapper(layer, 1.0 / scale, input_minmax, alpha) + set_module(self.model, layer_name, new_module) + elif self.allow_absorb: + scale = self._reshape_scale_for_weight(layer, scale) + layer.weight = torch.nn.Parameter(layer.weight * scale) + return scale + + def _absorb_scales(self, layer_name, scale): ##output channel + """Absorb the scale to the layer at output channel + :param layer_name: The module name + :param scale: The scale to be absorbed + :param alpha_key: The alpha passed to SQLinearWrapper + :return:""" + if self.insert_mul or not self.allow_absorb: + return # absorb is updated in SQLinearWrapper in def _scale_layer_weight + + ##if self.allow absorb + layer = get_module(self.model, layer_name) + if layer.__class__.__name__ == "WrapperLayer": + layer = layer.orig_layer + if ( + isinstance(layer, torch.nn.BatchNorm2d) + or isinstance(layer, torch.nn.GroupNorm) + or isinstance(layer, torch.nn.InstanceNorm2d) + ): + if layer.affine: + layer.weight *= scale + layer.bias *= scale + else: + layer.affine = True + weight = torch.ones(layer.num_features, device=self.device, dtype=self.dtype) * scale + layer.weight = torch.nn.Parameter(weight, requires_grad=False) + bias = torch.zeros(layer.num_features, device=self.device, dtype=self.dtype) + layer.bias = torch.nn.Parameter(bias, requires_grad=False) + elif isinstance(layer, torch.nn.LayerNorm): + if layer.elementwise_affine: + layer.weight *= scale + layer.bias *= scale + else: + layer.elementwise_affine = True + weight = torch.ones(layer.num_features, device=self.device, dtype=self.dtype) * scale + layer.weight = torch.nn.Parameter(torch.ones(weight, requires_grad=False)) + bias = torch.zeros(layer.num_features, device=self.device, dtype=self.dtype) + layer.bias = torch.nn.Parameter(bias, requires_grad=False) + + elif isinstance(layer, torch.nn.Conv2d): + ##the order could not be changed + if hasattr(layer, "bias") and (layer.bias is not None): + layer.bias *= scale + scale = scale.view(scale.shape[0], 1, 1, 1) + layer.weight *= scale + + elif isinstance(layer, torch.nn.Linear): + if hasattr(layer, "bias") and (layer.bias is not None): + layer.bias *= scale + scale = scale.view(scale.shape[0], 1) + layer.weight *= scale + + elif layer.__class__.__name__ == "LlamaRMSNorm" or layer.__class__.__name__ == "T5LayerNorm": ##quite tricky + layer.weight *= scale + + else: + logger.warning( + f"found 
unsupported layer {type(layer)}, try to multiply scale to " + f"weight and bias directly, this may introduce accuracy issue, please have a check " + ) + if hasattr(layer, "weight") and layer.weight is not None: + layer.weight *= scale + if hasattr(layer, "bias") and layer.bias is not None: + layer.bias *= scale + + def _cal_scales(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=False): + """Cal the adjust scales + :param absorb_to_layer: A dict mapping absorb layer to smooth quantized layer + :param input_maxes: The channel-wise input max info for layers + :param alpha: Alpha value to balance the quantization difficulty of activation and weight, a float of a dict + :return:""" + absorb_to_input_maxes = {} + for key in absorb_to_layer.keys(): + layer_name = absorb_to_layer[key][0] + absorb_to_input_maxes[key] = input_maxes[layer_name] + + weight_scales_info = {} + absorb_scales_info = {} + for index, key in enumerate(absorb_to_layer.keys()): + alpha_tmp = alpha[key] if isinstance(alpha, dict) else alpha + if alpha_tmp < 0: + scale = torch.ones((1), device=self.device) + else: + input_max = absorb_to_input_maxes[key] + layer_names = absorb_to_layer[key] + weights = [] + for layer_name in layer_names: + weight = self._reshape_in_channel_to_last(layer_name) + weights.append(weight) + + weight_max_per_channel = torch.max(torch.abs(torch.cat(weights, dim=0)), dim=0)[0] + if self.weight_clip: + weight_max_per_channel = weight_max_per_channel.clamp(min=1e-5) + if self.record_max_info and not tuning: + # the input of layers with same absorb layer is the same. + input_minmax = [self.input_mins[layer_names[0]], self.input_maxes[layer_names[0]]] + self.max_value_info[key] = {} + self.max_value_info[key]["alpha"] = alpha_tmp + self.max_value_info[key]["input_minmax"] = input_minmax + self.max_value_info[key]["weight_max"] = weight_max_per_channel + self.max_value_info[key]["absorbed_layer"] = layer_names + continue + + if self._save_scale: + if key in self.weight_scale_dict and alpha_tmp in self.weight_scale_dict[key]: + scale = self.weight_scale_dict[key][alpha_tmp] + else: + scale = cal_scale(input_max, weights, alpha_tmp) + else: + scale = cal_scale(input_max, weights, alpha_tmp) + + absorb_scales_info[key] = 1.0 / scale + absorb_scales_info[key][scale == 0] = 0 + layer_names = absorb_to_layer[key] + for layer_name in layer_names: + ##self._scale_layer_weight(layer_name, scale) + weight_scales_info[layer_name] = scale + if self._save_scale: + if layer_name not in self.weight_scale_dict: + self.weight_scale_dict[layer_name] = {} + self.weight_scale_dict[layer_name][alpha_tmp] = scale + return absorb_scales_info, weight_scales_info + + def _adjust_parameters(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=False): + """Adjust the weights and biases + :param absorb_to_layer: A dict mapping absorb layer to smooth quantized layer + :param input_maxes: The channel-wise input max info for layers + :param alpha: Alpha value to balance the quantization difficulty of activation and weight, a float of a dict + :return:""" + absorb_scales_info, weight_scales_info = self._cal_scales(absorb_to_layer, input_maxes, alpha, tuning) + if not absorb_scales_info or not weight_scales_info: + return weight_scales_info, absorb_scales_info + for index, key in enumerate(absorb_to_layer.keys()): + if isinstance(alpha, float): + alpha_tmp = alpha + elif isinstance(alpha, dict): + alpha_tmp = alpha[key] + absorb_scale = absorb_scales_info[key] + self._absorb_scales(key, absorb_scale) + layer_names = 
absorb_to_layer[key] + for layer_name in layer_names: + input_minmax = [self.input_mins[layer_names[0]], self.input_maxes[layer_names[0]]] + self._scale_layer_weight(layer_name, weight_scales_info[layer_name], alpha_tmp, input_minmax) + return weight_scales_info, absorb_scales_info + + def _check_need_calibration(self, alpha, percentile, op_types, scales_per_op, calib_iter): + """ + check need calibration or not + :param alpha: current alpha + :param percentile: current percentile + :param op_types: current op_types + :param scales_per_op: current scales_per_op + :param calib_iter:: current scales_per_op + :return: + """ + need_calib = True + if len(self.input_maxes) == 0: ## the first time + need_calib = True + self.alpha = alpha + self.percentile = percentile + self.op_types = op_types + self.scales_per_op = scales_per_op + self.calib_iter = calib_iter + return need_calib + + if ( + self.percentile == percentile + and self.op_types == op_types + and self.scales_per_op == scales_per_op + and self.calib_iter == calib_iter + ): + if isinstance(alpha, float) or self.alpha == "auto": + need_calib = False + + self.alpha, self.percentile = alpha, percentile + self.op_types, self.scales_per_op = op_types, scales_per_op + self.calib_iter = calib_iter + return need_calib + + def _get_auto_loss(self, output, output_q, loss_type="abs", loss_alpha=1.0): + """Get the loss for auto tuning + :param output: Fp32 output for one layer + :param output_q: Quant output for one layer + :param loss_type: The type of loss + :param loss_alpha: Loss alpha i for mean scale error + :return: A tensor of the loss.""" + if len(output.shape) <= 2: + max_value = torch.max(torch.abs(output)) + else: + output = output.reshape(output.shape[0], -1) + output_q = output_q.reshape(output_q.shape[0], -1) + max_value = torch.max(torch.abs(output), dim=-1).values.unsqueeze(-1) + max_value = torch.clip(max_value, 1e-5) + output = output / max_value ##FIXME need copy not replace + output_q = output_q / max_value + # if loss_type == "nsr": # nsr is unused at this point. 
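# (clarifying note) both active branches below compare the max-normalized outputs:
# "abs" applies a p=0.5 penalty, sum(|fp32 - quant| ** 0.5), which is less dominated
# by outliers than the squared-error default branch.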
+ # output[output == 0] = 1e-5 + # loss = torch.sum(torch.log(1.0 + torch.abs(output - output_q) / torch.abs(output))) + # return loss + if loss_type == "abs": + return torch.sum(torch.pow(torch.abs(output - output_q), 0.5)) + else: + return torch.sum((output - output_q) ** 2) + + def _get_sq_layer_names(self): + """Get the all the hook sq layer + :return: All the sq layer names.""" + ##TODO this may not fit for folding=False + module_names = [] + for key in self.absorb_to_layer: + module_names += self.absorb_to_layer[key] + return module_names + + def _get_all_hook_module_names(self): + module_names = [] + for n, module in self.model.named_modules(): + if isinstance(module, tuple(self.op_types)): + module_names.append(n) + return module_names + + def _qdq_model_wrapper_for_auto(self, save_q_input=False): + """Wrapper all the module with qdq + :return:""" + module_names = self._get_all_hook_module_names() + self.to_unwrap_module_names = module_names + for name in module_names: + if name not in self.input_mins: # skip module if it's not used in calibration + continue + module = get_module(self.model, name) + new_module = WrapperLayer(module, self.input_mins[name], self.input_maxes[name], save_q_input=save_q_input) + set_module(self.model, name, new_module) + + def _qdq_model_unwrapper_for_auto(self): + module_names = self.to_unwrap_module_names + for name in module_names: + module = get_module(self.model, name) + if not hasattr(module, "orig_layer"): # skip module if it's not used in calibration + continue + set_module(self.model, name, module.orig_layer) + + def _change_qdq_for_auto(self, enable=True): + module_names = self._get_all_hook_module_names() + for name in module_names: + name = name.split(".orig_layer")[0] + module = get_module(self.model, name) + if not hasattr(module, "orig_layer"): # skip module if it's not used in calibration + continue + if enable: + module.enable_quant() + else: + module.disable_quant() + + def _update_scales_for_auto(self, absorb_scales, weight_scales): + for key in self.absorb_to_layer.keys(): + layer_names = self.absorb_to_layer[key] + for layer_name in layer_names: + layer = get_module(self.model, layer_name) + input_scale = absorb_scales[key] + weight_scale = weight_scales[layer_name] + input_scale = self._reshape_scale_for_input(layer, input_scale) + weight_scale = self._reshape_scale_for_weight(layer, weight_scale) + layer.update_scale(input_scale, weight_scale) ##FIXME + + def _add_blockwise_observer(self, block_modules): + """ + :param block_modules: the block modules which the observer will insert to + :return: + """ + self.blockwise_hook_handles = [] + for key in block_modules.keys(): + hook_func = self._save_blockwise_hook(key) + hook_handle = block_modules[key].register_forward_hook(hook_func) + self.blockwise_hook_handles.append(hook_handle) + + def _save_blockwise_hook(self, name): + """A forward hook to save inputs/outputs of a block + :param name: the block name + :return: A hook function.""" + + def save_blockwise_hook(module, inputs, outputs): + self.block_inputs[name] = inputs[0] + self.block_outputs[name] = outputs[0] + + return save_blockwise_hook + + def _get_one_batch_auto_loss(self, input, alpha_space, orig_best_alpha, input_maxes): + self._change_qdq_for_auto(enable=False) + module_names = self._get_sq_layer_names() + + if self.do_blockwise: + block_modules = {} + for key in self.block_names: + block_modules[key] = get_module(self.model, key) + self._add_blockwise_observer(block_modules) + + forward_wrapper(self.model, input, 
self.device) ##disable quant and get fp32 output + + fp32_output = {} + if not self.do_blockwise: + for name in module_names: + module = get_module(self.model, name) + fp32_output[name] = module.output + module.output = None + else: + for block_name in self.block_names: + fp32_output[block_name] = self.block_outputs[block_name] + self._change_qdq_for_auto(enable=True) + absorb_input_scales, weight_scales = self._cal_scales( + self.absorb_to_layer, input_maxes, orig_best_alpha, tuning=True + ) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + forward_wrapper(self.model, input, self.device) ##save quant_input + for mod_name in module_names: # save fp32 values + mod = get_module(self.model, mod_name) + if mod_name in self.fp32_output_val: + self.fp32_output_val[mod_name].append(torch.norm(mod.output)) + else: + self.fp32_output_val[mod_name] = [torch.norm(mod.output)] + del mod + + loss_alphas = {} + if not self.do_blockwise: + for name in module_names: + module = get_module(self.model, name) + loss = self._get_auto_loss(fp32_output[name], module.output) + cur_alpha = orig_best_alpha + if isinstance(orig_best_alpha, dict): + cur_alpha = orig_best_alpha[name] + key_name = str(cur_alpha) + loss_alphas[name] = {key_name: loss} + else: + for block_name in self.block_names: + block = get_module(self.model, block_name) + loss = self._get_auto_loss(fp32_output[block_name], self.block_outputs[block_name]) + cur_alpha = orig_best_alpha + if isinstance(orig_best_alpha, dict): + cur_alpha = orig_best_alpha[self.block_to_module[block_name][0]] + key_name = str(cur_alpha) + loss_alphas[block_name] = {key_name: loss} + # for name in module_names: + # loss_alphas[name]={} + for alpha in alpha_space: + absorb_input_scales, weight_scales = self._cal_scales(self.absorb_to_layer, input_maxes, alpha, tuning=True) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + if not self.do_blockwise: + for name in module_names: + losses = loss_alphas[name] + if str(alpha) in losses.keys(): + continue + module = get_module(self.model, name) + output = module.q_dq_forward(module.q_input, module.input_scale, module.weight_scale) + loss = self._get_auto_loss(fp32_output[name], output) + loss_alphas[name][str(alpha)] = loss + else: + for block_name in self.block_names: + losses = loss_alphas[block_name] + if str(alpha) in losses.keys(): + continue + block = get_module(self.model, block_name) + block_copy = copy.deepcopy(block) + for name in self.block_to_module[block_name]: + if name == block_name and len(self.block_to_module[block_name]) == 1: + module, module_copy = block, block_copy + else: + module = get_module(block, name) + module_copy = copy.deepcopy(module) + if module.weight_scale is not None: + module_copy.orig_layer.weight *= module.weight_scale + q_dq_weight = quant_dequant_w(module_copy.orig_layer) + module_copy.orig_layer.weight.data.copy_(q_dq_weight) + module_copy.do_blockwise = True + if not (name == block_name and len(self.block_to_module[block_name]) == 1): + set_module(block_copy, name, module_copy) + try: + output = block_copy(self.block_inputs[block_name])[0] + except: # Llama model decoder_layer forward requires position_id + position_ids = torch.arange(self.block_inputs[block_name].size()[1]) + position_ids = position_ids.view(self.block_inputs[block_name].size()[0], -1) + output = block_copy(self.block_inputs[block_name], position_ids=position_ids)[0] + loss = self._get_auto_loss(fp32_output[block_name], output) + loss_alphas[block_name][str(alpha)] = loss + del 
block_copy # release memory + return loss_alphas + + def _get_best_alpha(self, absorb_to_layer, loss_alphas, shared_criterion): + def dict_to_list(dic): + res = [] + for key in dic.keys(): + res.append((key, dic[key])) + return res + + best_alpha = {} + for ln_name in absorb_to_layer.keys(): + layer_names = absorb_to_layer[ln_name] + cur_shared_criterion = shared_criterion + if len(layer_names) == 1: + cur_shared_criterion = "min" + if cur_shared_criterion == "mean": + loss_tmp = {} + for alpha in loss_alphas[layer_names[0]].keys(): + if alpha not in loss_tmp.keys(): + loss_tmp[alpha] = 0 + for layer_name in layer_names: + loss_tmp[alpha] += loss_alphas[layer_name][alpha] + res = dict_to_list(loss_tmp) + res.sort(key=lambda x: x[1]) + + best_alpha[ln_name] = float(res[0][0]) + + elif cur_shared_criterion == "min" or cur_shared_criterion == "max": + tmp_best_alpha = [] + for layer_name in layer_names: + res = dict_to_list(loss_alphas[layer_name]) + res.sort(key=lambda x: x[1]) + tmp_best_alpha.append(float(res[0][0])) + if cur_shared_criterion == "min": + best_alpha[ln_name] = min(tmp_best_alpha) + else: + best_alpha[ln_name] = max(tmp_best_alpha) + + else: + raise NotImplementedError + return best_alpha + + def _auto_tune_alpha( + self, + input_maxes, + calib_sample_num=32, + alpha_min=0.3, + alpha_max=0.7, + alpha_step=0.05, + shared_criterion="min", + do_blockwise=False, + ): + """Perform alpha-tuning to obtain layer-wise optimal alpha values and adjust parameters accordingly. + + This function takes quantization of the former layers into consideration when qdq one layer + Also, it reduces the memory usage at the cost of increasingtuning time + TODO may have compatibility issue when setting folding=True, check whether having issues when bs!=1 + :param input_maxes: calibration data, input max + :param calib_sample_num: sample count used to auto tuning alpha + :param alpha_min: the min value of alpha + :param alpha_max: the max value of alpha + :param alpha_step: the alpha step in search space + :param shared_criterion: the criterion to choose alpha when multiple layers must share one same alpha + :return: + """ + logger.info("start sq auto tuning") + round_num = max( + len(str(alpha_min).split(".")[1]), len(str(alpha_max).split(".")[1]), len(str(alpha_step).split(".")[1]) + ) + alpha_space = numpy.round(numpy.arange(alpha_min, alpha_max + alpha_step, alpha_step), round_num).tolist() + ##wrapper new module + self._qdq_model_wrapper_for_auto(save_q_input=True) + ##set alpha to 0.5 as default + default_alpha = alpha_space[len(alpha_space) // 2] + if 0.5 in alpha_space: + default_alpha = 0.5 + default_alpha = self.default_alpha + absorb_input_scales, weight_scales = self._cal_scales( + self.absorb_to_layer, input_maxes, default_alpha, tuning=True + ) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + total_cnt = 0 + tmp_cnt = 0 + alpha_update_iter = 0 + # multiply_factor is used to combine samples to calib_sample_num // 4 before summarizing the best alpha + tune_cnt = 4 + multiply_factor = calib_sample_num // tune_cnt if calib_sample_num >= tune_cnt else calib_sample_num + self.fp32_output_val = {} + + best_alphas = default_alpha + if not self.dataloader: + logger.info(f"Auto-tuning failed due to no dataloader, using {best_alphas} instead.") + self._qdq_model_unwrapper_for_auto() + return best_alphas + bar = tqdm(self.dataloader, total=calib_sample_num, desc="auto tune alpha") + try: + for input, label in bar: + loss_alphas = {} + best_alphas_per_module = best_alphas + if 
isinstance(best_alphas, dict): + for key in self.absorb_to_layer.keys(): + layer_names = self.absorb_to_layer[key] + for layer_name in layer_names: + best_alphas_per_module[layer_name] = best_alphas_per_module[key] + + loss_tmp = self._get_one_batch_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes) + if self.do_blockwise: + if loss_alphas == {}: + for block_name in self.block_names: + for key in self.block_to_module[block_name]: + loss_alphas[key] = loss_tmp[block_name] + else: + for block_name in self.block_names: + for key in self.block_to_module[block_name]: + cur_loss = loss_alphas[key] + for alpha_key in cur_loss.keys(): + cur_loss[alpha_key] += loss_tmp[block_name][alpha_key] + else: + if loss_alphas == {}: + loss_alphas = loss_tmp + else: + for key in loss_alphas.keys(): + cur_loss = loss_alphas[key] + for alpha_key in cur_loss.keys(): + cur_loss[alpha_key] += loss_tmp[key][alpha_key] + total_cnt += self.dataloader.batch_size + tmp_cnt += self.dataloader.batch_size + if tmp_cnt // multiply_factor >= 1: + alpha_update_iter += 1 + tmp_cnt = 0 + best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion) + for key in best_alphas.keys(): + logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}") + absorb_input_scales, weight_scales = self._cal_scales( + self.absorb_to_layer, input_maxes, best_alphas, tuning=True + ) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + # does not need to reset the weight_scale_dict, because use the weight of ori_layer, no change + # self.weight_scale_dict = {} + if total_cnt >= calib_sample_num: + break + except: + for input in bar: + loss_alphas = {} + best_alphas_per_module = best_alphas + if isinstance(best_alphas, dict): + for key in self.absorb_to_layer.keys(): + layer_names = self.absorb_to_layer[key] + for layer_name in layer_names: + best_alphas_per_module[layer_name] = best_alphas_per_module[key] + + loss_tmp = self._get_one_batch_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes) + if self.do_blockwise: + if loss_alphas == {}: + for block_name in self.block_names: + for key in self.block_to_module[block_name]: + loss_alphas[key] = loss_tmp[block_name] + else: + for block_name in self.block_names: + for key in self.block_to_module[block_name]: + cur_loss = loss_alphas[key] + for alpha_key in cur_loss.keys(): + cur_loss[alpha_key] += loss_tmp[block_name][alpha_key] + else: + if loss_alphas == {}: + loss_alphas = loss_tmp + else: + for key in loss_alphas.keys(): + cur_loss = loss_alphas[key] + for alpha_key in cur_loss.keys(): + cur_loss[alpha_key] += loss_tmp[key][alpha_key] + total_cnt += self.dataloader.batch_size + tmp_cnt += self.dataloader.batch_size + if tmp_cnt // multiply_factor >= 1: + alpha_update_iter += 1 + tmp_cnt = 0 + + best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion) + for key in best_alphas.keys(): + logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}") + absorb_input_scales, weight_scales = self._cal_scales( + self.absorb_to_layer, input_maxes, best_alphas, tuning=True + ) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + # self.weight_scale_dict = {} + if total_cnt >= calib_sample_num: + break + + best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion) + for key in best_alphas.keys(): + logger.info(f"Final alpha {key}:{best_alphas[key]}") + max_op, max_ratio, max_key = "", 0, "" + ratio_info = {} + for 
key in self.absorb_to_layer: + for op_name in self.absorb_to_layer[key]: + fp32_norm, loss_ = ( + torch.sum(torch.stack(self.fp32_output_val[op_name])), + loss_alphas[op_name][str(best_alphas[key])], + ) + ratio = loss_ / fp32_norm + max_op = op_name if ratio > max_ratio else max_op + max_key = key if ratio > max_ratio else max_key + max_ratio = max(ratio, max_ratio) + ratio_info[op_name] = ratio + logger.debug( + f"final loss: {op_name}: {loss_}; @alpha {best_alphas[key]}; \ + fp32_output norm: {fp32_norm}; ratio: {ratio}" + ) + import operator + + ratio_info = dict(sorted(ratio_info.items(), key=operator.itemgetter(1), reverse=True)) + for key in list(ratio_info.keys()): + logger.debug(f"sorted opname-ratio: {key}: {ratio_info[key]}") + if max_op != "": + logger.debug( + f"max loss: {max_op}: {loss_alphas[max_op][str(best_alphas[max_key])]} @alpha {best_alphas[max_key]}\ + fp32_output norm: {torch.sum(torch.stack(self.fp32_output_val[max_op]))}; ratio: {max_ratio}" + ) + self._qdq_model_unwrapper_for_auto() + logger.info("auto tuning done") + return best_alphas + + def transform( + self, + alpha=0.5, + folding=False, + percentile=100, + op_types=[torch.nn.Linear, torch.nn.Conv2d], + scales_per_op=False, + calib_iter=100, + auto_alpha_args={ + "alpha_min": 0.0, + "alpha_max": 1.0, + "alpha_step": 0.1, + "shared_criterion": "mean", + "do_blockwise": False, + }, + weight_clip=True, + default_alpha=0.5, + ): + """The main entry of smooth quant + :param alpha: Alpha value to balance the quantization difficulty of activation and weight, please refer + to the paper for more details + :param folding: whether insert mul(False) or just allow foldable layers(True) for SmoothQuant + :param percentile: remove the activation outlier when calculating the scale + :param op_types: The op typed to be smooth quantized + :param scales_per_op: Not supported now + :param calib_iter: Data size for calibration + :param weight_clip: Whether to clip weight_max when calculating scales. + + :param auto_alpha_args: Hyperparameters used to set the alpha search space in SQ auto-tuning. + By default the search space is 0.0-1.0 with step_size 0.1. + do_blockwise: Whether to do blockwise auto-tuning. + :param default_alpha: A hyperparameter that is used in SQ auto-tuning; by default it is 0.5. + :return: A FP32 model with the same architecture as the orig model but with different weight which will be + benefit to quantization. 
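Example (an illustrative sketch, not part of this patch; fp32_model and
calib_dataloader are placeholders supplied by the caller):

    sq = TorchSmoothQuant(fp32_model, dataloader=calib_dataloader)
    smoothed_model = sq.transform(alpha=0.5, folding=False, calib_iter=100)
    # alpha="auto" instead searches per-layer alphas using auto_alpha_args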
+ """ + if isinstance(auto_alpha_args, dict): + self.do_blockwise = auto_alpha_args.get("do_blockwise", False) + else: + self.do_blockwise = False + if self.do_blockwise: + self.block_names = self.get_blocks() + logger.info("Blockwise auto-tuning will be performed") + if not isinstance(self.model, torch.nn.Module): + logger.warning("smooth quant is ignored since the model is not a torch module") + return self.model + + if folding: + self.insert_mul, self.allow_absorb = False, True + else: + self.insert_mul, self.allow_absorb = True, False + if isinstance(alpha, float) and (alpha < 0 or alpha > 1): + logger.warning("reset alpha to in range [0.0, 1.0]") + + alpha = numpy.clip(alpha, 0.0, 1.0) + + self.weight_clip = weight_clip + self.default_alpha = default_alpha + self.auto_alpha_args = auto_alpha_args + self.recover() + need_calibration = self._check_need_calibration(alpha, percentile, op_types, scales_per_op, calib_iter) + with torch.no_grad(): + str_op_types = [i.__name__ for i in op_types] + input_maxes_abs = self.input_maxes_abs + if need_calibration: ##avoid multiple calibaration during tuning if the only difference is alpha + if self.insert_mul: + self.self_absorb_layers = self._get_all_layer_names(op_types) # TODO: only support linear now. + # fetch modules with the same input + group_modules = self._trace(str_op_types, skip_unsupported_layers=False) + if group_modules is not None: + # use one input for qkv + for k, v in group_modules.items(): + for i in v: + if i in self.self_absorb_layers: + self.self_absorb_layers.pop(i) + self.self_absorb_layers[v[0]] = v + logger.debug(f"self_absorb_layers:{self.self_absorb_layers}") + if self.allow_absorb: + self.absorb_to_layer, no_absorb_layers = self._trace( + str_op_types + ) ##TODO we need to insert mul layer for no_absorb_layers later + if self.absorb_to_layer is None and no_absorb_layers is None: + return self.model + + # remove self.self_absorb_layers if it exists in self.absorb_to_layer + for k, v in self.absorb_to_layer.items(): + for i in v: + if i in self.self_absorb_layers: + self.self_absorb_layers.pop(i) + self.absorb_to_layer.update(self.self_absorb_layers) + + if self.absorb_to_layer is None and no_absorb_layers is None: + logger.warning( + "sorry, could not trace the model, smooth quant is ignored." + "If you are using huggingface model," + "you could set torchscript to True " + ) + return self.model + + if self.do_blockwise: + module_names = self._get_sq_layer_names() + block_names, self.block_to_module = self.block_names, {} + for block in block_names: + self.block_to_module[block] = [] + for module in module_names: + checked = False + for block in block_names: + if block + "." 
in module: + self.block_to_module[block].append(module) + checked = True + if not checked: + self.block_to_module[module] = [module] + self.block_names = list(self.block_to_module.keys()) + logger.info(f"Blockwise auto-tuning: {len(self.block_names)} blocks found") + logger.debug(f"Blockwise auto-tuning blocks info: {self.block_to_module}") + + input_maxes_abs = self._calibrate(self.absorb_to_layer, calib_iter, percentile) + + # Check if input_maxes match self.absorb_to_layer + # (due to self._get_all_layer_names use layer tree instead of forward_path) + if not folding: + diff_modules = set(self.absorb_to_layer.keys()).difference(input_maxes_abs.keys()) + for d in diff_modules: + del self.absorb_to_layer[d] + + scale_memo_use = 0 + for key in self.absorb_to_layer: + layer_name = self.absorb_to_layer[key][0] + input_max = input_maxes_abs[layer_name] + scale_memo_use += 4 * input_max.shape[0] * len(self.absorb_to_layer[key]) + if alpha == "auto": + alpha_space = (auto_alpha_args["alpha_max"] - auto_alpha_args["alpha_min"]) / auto_alpha_args[ + "alpha_step" + ] + 1 + scale_memo_use *= alpha_space + self._save_scale = enough_memo_store_scale(self.device, scale_memo_use) + + if alpha == "auto": + self.alpha_per_layer = self._auto_tune_alpha( + input_maxes_abs, calib_sample_num=32, **auto_alpha_args + ) ##save the alpha + + if alpha == "auto": + alpha = self.alpha_per_layer + example_inputs = self._get_example_input() + if example_inputs is not None: + out_pre_sq = model_forward_per_sample(self.model, example_inputs, self.device) + + if folding: + self._save_scale = False + if self.record_max_info: + # max_info is recorded in self.max_value_info + self._adjust_parameters(self.absorb_to_layer, input_maxes_abs, alpha) + self.model._smoothquant_optimized = False + return self.model + + self.weight_scale_info, self.absorb_scales_info = self._adjust_parameters( + self.absorb_to_layer, input_maxes_abs, alpha + ) + + self.model._smoothquant_optimized = True + if example_inputs is not None: + # Check mathematical equivelancy + out_post_sq = model_forward_per_sample(self.model, example_inputs, self.device) + + if not self.output_is_equal(out_post_sq, out_pre_sq): + logger.warning( + "Mathematical equivelancy of Smoothquant is not preserved. " + "Please kindly report this issue to https://github.com/intel/neural-compressor." + ) + else: + logger.warning(" Could not get example input, equivelancy check is skipped") + + return self.model + + def output_is_equal(self, out1, out2, atol=1e-04): + try: + if isinstance(out1, tuple): + return all(torch.all(torch.isclose(out1[i], out2[i], atol=atol)) for i in range(len(out1))) + elif isinstance(out1, dict): + return all(torch.all(torch.isclose(out1[k], out2[k], atol=atol)) for k in out1.keys()) + elif isinstance(out1, torch.Tensor): + return torch.all(torch.isclose(out1, out2, atol=atol)) + return False + except: + logger.warning( + "Automatically check failed, Please check equivelancy manually " + "between out_pre_sq and out_post_sq if necessary." + ) + return True + + def recover(self): + """Recover the model weights + :return:""" + with torch.no_grad(): + for key in self.weight_scale_info: + self._scale_layer_weight(key, 1.0 / self.weight_scale_info[key]) + for key in self.absorb_scales_info: + self._absorb_scales(key, 1.0 / self.absorb_scales_info[key]) + self.weight_scale_info = {} ##clear the data + self.absorb_scales_info = {} + + def _get_all_layer_names(self, op_types=[torch.nn.Linear]): + """Try the model to find the layers which can be smooth quantized. 
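For instance, a model whose only matching module is a Linear named "fc" yields
{"fc": ["fc"]}: each layer absorbs its own scale through the mul inserted later.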
+ + :param op_types: The op types to be smooth quantized + :return: + self_absorb_layer: A dict, absorb layer name (itself): layers to be smooth quantized + """ + self_absorb_layer = {} + op_types = [torch.nn.Linear] # TODOļ¼š only support SQLinearWrapper + for name, module in self.model.named_modules(): + if isinstance(module, tuple(op_types)): + self_absorb_layer[name] = [name] + return self_absorb_layer + + def _get_example_input(self): + if self.dataloader is None and self.example_inputs is None: + return None + if self.example_inputs is None: + try: + for idx, (input, label) in enumerate(self.dataloader): + self.example_inputs = input + break + except: + for idx, input in enumerate(self.dataloader): + self.example_inputs = input + break + + return self.example_inputs + + def _trace(self, op_types, skip_unsupported_layers=True): + """Try the model to find the layers which can be smooth quantized. + + :param op_types: The op types to be smooth quantized + :return: + absorb_to_layer: A dict, absorb layer name:layers to be smooth quantized + no_absorb_layers: A list saving the layers which could not find the absorb layer + """ + tg = GraphTrace() + self._get_example_input() + absorb_to_layer, no_absorb_layers = tg.get_absorb_to_layer( + self.traced_model, + self.example_inputs, + op_types, + skip_unsupported_layers=skip_unsupported_layers, + ) + if not skip_unsupported_layers: + return absorb_to_layer + if absorb_to_layer is None and no_absorb_layers is None: + logger.warning( + "sorry, could not trace the model, smooth quant is skipped." + "If you are using huggingface model," + "you could set torchscript to True " + "when loading the model or set the return_dict to False" + ) + elif absorb_to_layer == {}: + logger.warning("could not find any layer to be absorbed") + else: + to_absorb_cnt = 0 + for key, item in absorb_to_layer.items(): + to_absorb_cnt += len(item) + logger.info( + f" {to_absorb_cnt} out of {to_absorb_cnt + len(no_absorb_layers)} " + f"layers could be absorbed in smooth quant" + ) + return absorb_to_layer, no_absorb_layers + + +def get_parent(node, all_parents=False): + if node.inputs() is None: + return None + elif len(list(node.inputs())) == 0: + return None + if not all_parents: + return list(node.inputs())[0].node() + else: + return list(node.inputs()) + + +class GraphTrace: + """""" + + def __init__(self): + self.supported_torch_module_to_aten = { + "Linear": "aten::linear", + "Conv2d": "aten::_convolution", + "ConvTranspose2d": "aten::_convolution", + "LayerNorm": "aten::layer_norm", + "BatchNorm2d": "aten::batch_norm", + "GroupNorm": "aten::group_norm", + "InstanceNorm2d": "aten::instance_norm", + "LlamaRMSNorm": "aten::mul", + "T5LayerNorm": "aten::mul", + "LPLayerNorm": "aten::layer_norm", ##mpt_chat + } + + ##TODO potential bug, need to check only have one bug + ##TODO, must satisfy af(x)=f(ax),current skip layer may be incomplete + self.skip_ops_to_find_absorb = ["aten::to", "aten::relu", "aten::leaky_relu", "aten::hardtanh"] + + self.could_absorb_layers = [ + "aten::layer_norm", + "aten::batch_norm", + "aten::linear", + "aten::_convolution", + "aten::group_norm", + "aten::instance_norm", + "aten::mul", + ] ##TODO,support more norm + + def trace(self, model, dummy_input): + traced_model = None + optimize_numerics = False + orig_device = str(next(model.parameters()).device) + if orig_device != "cpu" and orig_device != "meta": # pragma: no cover + model = model.to("cpu") + dummy_input = move_input_to_device(dummy_input, "cpu") + if isinstance(dummy_input, dict) 
or isinstance(dummy_input, UserDict): + try: + traced_model = torch.jit.trace( + model, example_kwarg_inputs=dict(dummy_input), strict=False, check_trace=False + ) + traced_model = torch.jit.freeze(traced_model.eval(), optimize_numerics=optimize_numerics) + except Exception as e: + logger.warning(e) + logger.warning("Jit trace in GraphTrace failed, absorb layer detection is skipped") + else: + try: + traced_model = torch.jit.trace(model, dummy_input, strict=False) + traced_model = torch.jit.freeze(traced_model.eval(), optimize_numerics=optimize_numerics) + except: + try: + traced_model = torch.jit.trace(model, dummy_input[0], strict=False) + traced_model = torch.jit.freeze(traced_model.eval(), optimize_numerics=optimize_numerics) + except Exception as e: + logger.warning(e) + logger.warning("Jit trace in GraphTrace failed, absorb layer detection is skipped") + model = model.to(orig_device) + return traced_model + + def get_nodes(self, traced_model, op_types=["Linear"]): + if isinstance(op_types, str): + op_types = [op_types] + nodes = [] + for node in traced_model.graph.nodes(): + node_type = node.kind() + for op_type in op_types: + if node_type == op_type: + nodes.append((node, op_type)) + break + return nodes + + def get_prev_absorb_layer(self, nodes): + prev_absorb_layer = [] + for node in nodes: + parent = get_parent(node) + while 1: + if parent.kind() in self.skip_ops_to_find_absorb: + parent = get_parent(parent) + continue + if parent.kind() in self.could_absorb_layers: + parent_out_kinds = [] + for val_user in list(parent.outputs())[0].uses(): + next_node = val_user.user + parent_out_kinds.append(next_node.kind()) + parent_out_kinds = set(parent_out_kinds) + parent_out_kinds.discard("aten::size") + + if parent_out_kinds == parent_out_kinds.intersection(self.could_absorb_layers): + prev_absorb_layer.append(parent) + elif parent_out_kinds.intersection(self.skip_ops_to_find_absorb): + res = self.skip_op_absorb_helper(parent) + prev_absorb_layer.append(parent) if res else prev_absorb_layer.append(None) + else: # When parent to multiple ops, sq transformation could be wrong. 
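# (clarifying note) i.e. the candidate also feeds ops that are neither absorbable
# nor in skip_ops_to_find_absorb, so rescaling its output would perturb those other
# consumers; record None so this layer ends up in no_absorb_layers instead.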
+ prev_absorb_layer.append(None) + else: + prev_absorb_layer.append(None) + break + return prev_absorb_layer + + def skip_op_absorb_helper(self, parent_node): + for val_user in list(parent_node.outputs())[0].uses(): + next_node = val_user.user + if next_node.kind() == "aten::size": + continue + elif next_node.kind() in self.could_absorb_layers: + continue + elif next_node.kind() in self.skip_ops_to_find_absorb: + node_res = self.skip_op_absorb_helper(next_node) + if not node_res: + return False + else: + return False + return True + + def mapping_torch_module_to_aten(self, op_types): + res = [] + for op in op_types: + if op not in self.supported_torch_module_to_aten.keys(): + logger.warning(f"{op} is not supported in smooth quant, ignoring...") + continue + res.append(self.supported_torch_module_to_aten[op]) + res = list(set(res)) + return res + + def _check_valid_conv(self, module): + """Remove group conv except depthwise conv + :param module: + + :return: + """ + if not isinstance(module, torch.nn.Conv2d): + return True + if module.groups > 1: + if module.in_channels == module.out_channels and module.groups == module.in_channels: + return True + else: + return False + return True + + def get_absorb_to_layer(self, model, example_input, op_types, skip_unsupported_layers=True): + traced_model = self.trace(model, example_input) + if traced_model is None: + return None, None + + aten_op_types = self.mapping_torch_module_to_aten(op_types) + nodes_types = self.get_nodes(traced_model, aten_op_types) + nodes = [node_type[0] for node_type in nodes_types] + nodes_prev_absorb = self.get_prev_absorb_layer(nodes) + absorb_to_layer = {} + no_absorb_layers = [] + for index, absorb in enumerate(nodes_prev_absorb): + if absorb is None: + no_absorb_layers.append(".".join(nodes[index].scopeName().split("/")[-1].split(".")[1:])) + continue + node = nodes[index] + layer_name = ".".join(node.scopeName().split("/")[-1].split(".")[1:]) + absorb_name = ".".join(absorb.scopeName().split("/")[-1].split(".")[1:]) + if layer_name == "" or absorb_name == "": + continue + if absorb_name in absorb_to_layer.keys(): + absorb_to_layer[absorb_name].append(layer_name) + else: + absorb_to_layer[absorb_name] = [layer_name] + if skip_unsupported_layers: + absorb_to_layer = self.remove_unsupported_layers(model, absorb_to_layer, no_absorb_layers) + return absorb_to_layer, no_absorb_layers + + def remove_unsupported_layers(self, model, absorb_to_layer, no_absorb_layers): + res = {} + for key in absorb_to_layer.keys(): + absorb_layer = get_module(model, key) + layer_type = absorb_layer.__class__.__name__ + if layer_type not in self.supported_torch_module_to_aten.keys(): + no_absorb_layers.extend(absorb_to_layer[key]) + continue + supported = True + for layer_name in absorb_to_layer[key]: + layer = get_module(model, layer_name) + layer_type = layer.__class__.__name__ + if (layer_type not in self.supported_torch_module_to_aten.keys()) or not self._check_valid_conv(layer): + supported = False + no_absorb_layers.extend(absorb_to_layer[key]) + break + if supported: + res[key] = absorb_to_layer[key] + return res diff --git a/neural_compressor/torch/amp/fp8/functions.py b/neural_compressor/torch/amp/fp8/functions.py index 9a5fc277d97..49427f921f1 100644 --- a/neural_compressor/torch/amp/fp8/functions.py +++ b/neural_compressor/torch/amp/fp8/functions.py @@ -28,15 +28,14 @@ DATA_TYPE = torch.float8_e4m3fn -# without scale factor 0.9, the output will be abnormal. 
-E4M3_AMAX = torch.tensor(240 * 0.9, dtype=torch.float).to("hpu") -E5M2_AMAX = torch.tensor(57344 * 0.9, dtype=torch.float).to("hpu") +E4M3_AMAX = torch.tensor(240, dtype=torch.float).to("hpu") +E5M2_AMAX = torch.tensor(57344, dtype=torch.float).to("hpu") DTYPE_AMAX = E4M3_AMAX if DATA_TYPE == torch.float8_e4m3fn else E5M2_AMAX USE_AMAX = False if os.getenv("PT_USE_FP8_AMAX") is None else True -def fp8_linear_forward(input, weight, bias): +def fp8_linear_forward(input, weight, bias=None): out_dtype = torch.float32 org_middle_shape = input.shape[1:-1] input = input.view((-1, weight.shape[-1])) diff --git a/neural_compressor/torch/quantization/__init__.py b/neural_compressor/torch/quantization/__init__.py index fdb3cccae09..7892df51e23 100644 --- a/neural_compressor/torch/quantization/__init__.py +++ b/neural_compressor/torch/quantization/__init__.py @@ -29,6 +29,8 @@ get_default_teq_config, HQQConfig, get_default_hqq_config, + FP8Config, + get_default_fp8_config, ) from neural_compressor.torch.quantization.autotune import ( @@ -40,3 +42,4 @@ ### Quantization Function Registration ### import neural_compressor.torch.quantization.algorithm_entry +from neural_compressor.torch.quantization.load_entry import load diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index 14598278452..0b5bddd9146 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -13,6 +13,7 @@ # limitations under the License. from copy import deepcopy +from types import MethodType from typing import Any, Callable, Dict, Tuple import torch @@ -20,6 +21,7 @@ from neural_compressor.common.utils import AWQ, FP8_QUANT, GPTQ, HQQ, RTN, STATIC_QUANT, TEQ from neural_compressor.torch.quantization import ( AWQConfig, + FP8Config, GPTQConfig, HQQConfig, RTNConfig, @@ -285,8 +287,14 @@ def hqq_entry( from neural_compressor.torch.utils import is_hpex_available if is_hpex_available(): - from neural_compressor.torch.algorithms.habana_fp8 import quantize + from neural_compressor.torch.algorithms.habana_fp8 import quantize, save @register_algo(FP8_QUANT) - def fp8_quant_entry(model, qconfig_mapping, run_fn=None, run_args=None, inplace=True): - return quantize(model, qconfig_mapping, run_fn=run_fn, run_args=run_args, inplace=inplace) + def fp8_quant_entry( + model: torch.nn.Module, configs_mapping: Dict[Tuple[str], FP8Config], *args, **kwargs + ) -> torch.nn.Module: + kwargs.pop("example_inputs") + model = quantize(model, configs_mapping, *args, **kwargs) + model.qconfig = configs_mapping + model.save = MethodType(save, model) + return model diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 3e179b9875f..be5b221132c 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -59,7 +59,6 @@ FRAMEWORK_NAME = "torch" -DTYPE_RANGE = Union[torch.dtype, List[torch.dtype]] class OperatorConfig(NamedTuple): @@ -869,85 +868,88 @@ def get_default_hqq_config() -> HQQConfig: ######################## FP8 Config ############################### -if is_hpex_available(): - - @register_config(framework_name=FRAMEWORK_NAME, algo_name=FP8_QUANT) - class FP8QConfig(BaseConfig): - """Config class for FP8 quantization.""" - - name = FP8_QUANT - supported_configs: List[OperatorConfig] = [] - params_list = [ - "weight_dtype", - "act_dtype", - "act_algo", - "approach", - "device", - ] - - def 
__init__( - self, - weight_dtype: DTYPE_RANGE = torch.float8_e4m3fn, - act_dtype: DTYPE_RANGE = torch.float8_e4m3fn, - act_algo: Union[str, List[str]] = "minmax", - approach: Union[str, List[str]] = "static", - device: Union[str, List[str]] = "hpu", - white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, - ): - """Init FP8 config. - - Args: - """ - super().__init__(white_list=white_list) - self.weight_dtype = weight_dtype - self.act_dtype = act_dtype - self.act_algo = act_algo - self.approach = approach - self.device = device - self._post_init() - - @classmethod - def register_supported_configs(cls) -> List[OperatorConfig]: - supported_configs = [] - fp8_config = FP8QConfig( - weight_dtype=[torch.float8_e5m2, torch.float8_e4m3fn], - act_dtype=[torch.float8_e5m2, torch.float8_e4m3fn], - act_algo=["minmax", "kl"], - approach=["static", "dynamic"], - device=["hpu"], - ) +@register_config(framework_name=FRAMEWORK_NAME, algo_name=FP8_QUANT) +class FP8Config(BaseConfig): + """Config class for FP8 quantization.""" + + name = FP8_QUANT + supported_configs: List[OperatorConfig] = [] + params_list = [ + "w_dtype", + "act_dtype", + "act_algo", + "approach", + "device", + ] + + def __init__( + self, + w_dtype: str = "fp8_e4m3", + act_dtype: str = "fp8_e4m3", + act_algo: Union[str, List[str]] = "minmax", + approach: Union[str, List[str]] = "static", + device: Union[str, List[str]] = "hpu", + white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, + ): + """Init FP8 config. + + Args: + """ + super().__init__(white_list=white_list) + self.w_dtype = w_dtype + self.act_dtype = act_dtype + self.act_algo = act_algo + self.approach = approach + self.device = device + self._post_init() + + @classmethod + def register_supported_configs(cls) -> List[OperatorConfig]: + supported_configs = [] + fp8_config = FP8Config( + w_dtype=["fp8_e5m2", "fp8_e4m3"], + act_dtype=["fp8_e5m2", "fp8_e4m3"], + act_algo=["minmax", "kl"], + approach=["static", "dynamic"], + device=["hpu"], + ) + if is_hpex_available(): from neural_compressor.torch.algorithms.habana_fp8 import white_list operators = white_list - supported_configs.append(OperatorConfig(config=fp8_config, operators=operators)) - cls.supported_configs = supported_configs + else: + operators = () + supported_configs.append(OperatorConfig(config=fp8_config, operators=operators)) + cls.supported_configs = supported_configs - @staticmethod - def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: - from neural_compressor.torch.algorithms.habana_fp8 import white_list + @staticmethod + def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: + from neural_compressor.torch.algorithms.habana_fp8 import white_list + + filter_result = [] + for op_name, module in model.named_modules(): + if isinstance(module, white_list): + pair = (op_name, type(module).__name__) + filter_result.append(pair) + logger.debug(f"Get model info: {filter_result}") + return filter_result - filter_result = [] - for op_name, module in model.named_modules(): - if isinstance(module, white_list): - pair = (op_name, type(module).__name__) - filter_result.append(pair) - logger.debug(f"Get model info: {filter_result}") - return filter_result + @classmethod + def get_config_set_for_tuning(cls) -> Union[None, "FP8Config", List["FP8Config"]]: + # TODO fwk owner needs to update it. 
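Compared with the FP8QConfig it replaces, FP8Config takes plain strings ("fp8_e4m3", "fp8_e5m2") for w_dtype and act_dtype, so a config can be built and serialized even on machines without HPEX. A typical static-quantization call, modeled on the updated unit tests below (an HPU device and a short calibration function are assumed):

    import torch
    from neural_compressor.torch.quantization import FP8Config, quantize

    model = torch.nn.Linear(10, 10).to("hpu")
    config = FP8Config(w_dtype="fp8_e4m3", act_dtype="fp8_e4m3", act_algo="minmax", approach="static")

    def calib_func(model):
        for _ in range(3):                          # feed a few representative batches
            model(torch.randn(4, 10).to("hpu"))

    q_model = quantize(model, config, run_fn=calib_func, inplace=True)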
+ return FP8Config(act_algo=["minmax", "kl"]) - @classmethod - def get_config_set_for_tuning(cls) -> Union[None, "FP8QConfig", List["FP8QConfig"]]: - # TODO fwk owner needs to update it. - return FP8QConfig(act_dtype=[torch.float8_e4m3fn]) - def get_default_fp8_qconfig() -> FP8QConfig: - """Generate the default gptq config. +def get_default_fp8_config() -> FP8Config: + """Generate the default gptq config. + + Returns: + the default gptq config. + """ + return FP8Config() - Returns: - the default gptq config. - """ - return FP8QConfig() - ##################### Algo Configs End ################################### +##################### Algo Configs End ################################### register_supported_configs_for_fwk(fwk_name=FRAMEWORK_NAME) diff --git a/neural_compressor/torch/quantization/load_entry.py b/neural_compressor/torch/quantization/load_entry.py new file mode 100644 index 00000000000..52c5ad759bb --- /dev/null +++ b/neural_compressor/torch/quantization/load_entry.py @@ -0,0 +1,38 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +from neural_compressor.common.utils import FP8_QUANT # unified namespace +from neural_compressor.common.utils import load_config_mapping # unified namespace +from neural_compressor.torch.quantization.config import FP8Config + +config_name_mapping = { + FP8_QUANT: FP8Config, +} + + +def load(model, output_dir="./saved_results"): + from neural_compressor.common.base_config import ConfigRegistry + + qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), "qconfig.json") + config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"]) + model.qconfig = config_mapping + # select load function + config_object = config_mapping[next(iter(config_mapping))] + if isinstance(config_object, FP8Config): + from neural_compressor.torch.algorithms.habana_fp8 import load + + return load(model, output_dir) diff --git a/neural_compressor/torch/quantization/modules.py b/neural_compressor/torch/quantization/modules.py deleted file mode 100644 index 97b843d816a..00000000000 --- a/neural_compressor/torch/quantization/modules.py +++ /dev/null @@ -1,47 +0,0 @@ -# -# -*- coding: utf-8 -*- -# -# Copyright (c) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Torch.nn.Module Class Definition.""" -# Note: Do not import this file unless you have already imported torch, -# since the model classes inherit torch.nn.Module. 
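The new load() entry point reads qconfig.json from the output directory, rebuilds the config mapping through the registry, and dispatches on the config type; today only the Habana FP8 loader is wired up. Continuing the sketch above, the round trip exercised by test_save_load looks like:

    from neural_compressor.torch.quantization import load

    q_model.save("saved_results")                    # quantized weights plus qconfig.json

    model_fp32 = torch.nn.Linear(10, 10).to("hpu")   # fresh FP32 copy of the same architecture
    q_model_2 = load(model_fp32, "saved_results")    # restores the FP8 modules and their weights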
-import math - -import torch -import torch.nn as nn - - -class Matmul(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return torch.matmul(x, y) - - -class BatchMatmul(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return torch.bmm(x, y) - - -class Autocast(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x diff --git a/neural_compressor/torch/utils/environ.py b/neural_compressor/torch/utils/environ.py index fbe4d2af91a..cab60b40416 100644 --- a/neural_compressor/torch/utils/environ.py +++ b/neural_compressor/torch/utils/environ.py @@ -18,7 +18,6 @@ # pylint:disable=import-error try: - import deepspeed import habana_frameworks.torch.hpex _hpex_available = True diff --git a/neural_compressor/torch/utils/utility.py b/neural_compressor/torch/utils/utility.py index fbd182daefd..fb38b8eff53 100644 --- a/neural_compressor/torch/utils/utility.py +++ b/neural_compressor/torch/utils/utility.py @@ -30,6 +30,10 @@ WHITE_MODULE_LIST = [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d] +WEIGHT_NAME = "quantized_model.pt" +QCONFIG_NAME = "qconfig.json" + + def register_algo(name): """Decorator function to register algorithms in the algos_mapping dictionary. diff --git a/requirements_pt.txt b/requirements_pt.txt index e3129bee51a..4cc182d4c85 100644 --- a/requirements_pt.txt +++ b/requirements_pt.txt @@ -1,3 +1,2 @@ -intel_extension_for_pytorch pydantic torch diff --git a/setup.py b/setup.py index b17dc7c3ab5..071d56da9f6 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,4 @@ +import os import re import subprocess import sys @@ -152,6 +153,13 @@ def get_build_version(): if __name__ == "__main__": cfg_key = "neural_compressor" + + # Temporary implementation of fp8 tensor saving and loading + # Will remove after Habana torch applies below patch: + # https://github.com/pytorch/pytorch/pull/114662 + ext_modules = [] + cmdclass = {} + if "neural_insights" in sys.argv: sys.argv.remove("neural_insights") cfg_key = "neural_insights" @@ -176,6 +184,17 @@ def get_build_version(): sys.argv.remove("ort") cfg_key = "neural_compressor_3x_ort" + if bool(os.getenv("USE_FP8_CONVERT", False)): + from torch.utils.cpp_extension import BuildExtension, CppExtension + + ext_modules = [ + CppExtension( + "fp8_convert", + ["neural_compressor/torch/algorithms/habana_fp8/tensor/convert.cpp"], + ), + ] + cmdclass = {"build_ext": BuildExtension} + project_name = PKG_INSTALL_CFG[cfg_key].get("project_name") include_packages = PKG_INSTALL_CFG[cfg_key].get("include_packages") or {} package_data = PKG_INSTALL_CFG[cfg_key].get("package_data") or {} @@ -199,6 +218,8 @@ def get_build_version(): include_package_data=True, package_data=package_data, install_requires=install_requires, + ext_modules=ext_modules, # for fp8 + cmdclass=cmdclass, # for fp8 entry_points=entry_points, extras_require=extras_require, python_requires=">=3.7.0", diff --git a/test/3x/torch/quantization/fp8/test_fp8.py b/test/3x/torch/quantization/fp8/test_fp8.py deleted file mode 100644 index 6a1e4a99d2f..00000000000 --- a/test/3x/torch/quantization/fp8/test_fp8.py +++ /dev/null @@ -1,130 +0,0 @@ -import copy -import shutil -import unittest - -import torch - -from neural_compressor.torch.utils import is_hpex_available - -if is_hpex_available(): - from neural_compressor.torch.algorithms.habana_fp8 import quantize_dynamic - from neural_compressor.torch.algorithms.habana_fp8.modules import ( - FP8BatchMatmul, - FP8DynamicBatchMatmul, - 
FP8DynamicLinear, - FP8DynamicMatmul, - FP8Linear, - FP8Matmul, - ) - from neural_compressor.torch.quantization import quantize - from neural_compressor.torch.quantization.config import FP8QConfig, get_default_fp8_qconfig - from neural_compressor.torch.quantization.modules import BatchMatmul, Matmul - - torch.set_grad_enabled(False) - - -class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.fc1 = torch.nn.Linear(10, 5) - self.fc2 = torch.nn.Linear(5, 10) - self.mm = Matmul() - self.bmm = BatchMatmul() - - def forward(self, inp): - x1 = self.fc1(inp) - x2 = self.fc2(x1) - x3 = self.mm(inp.T, x2) - x3 = x3.unsqueeze(0) - x4 = self.mm(inp.T, x2) - x4 = x4.unsqueeze(0) - x5 = self.bmm(x3, x4) - x6 = self.bmm(x3, x4) - out = x5 + x6 - return out - - -@unittest.skipIf(not is_hpex_available(), "HPEX is required for HPU inference") -class TestPytorchFP8Adaptor(unittest.TestCase): - @classmethod - def setUpClass(self): - self.model = M().to("hpu") - self.inp = torch.randn(1, 10).to("hpu") - - @classmethod - def tearDownClass(self): - shutil.rmtree("./saved", ignore_errors=True) - shutil.rmtree("./.graph_dumps", ignore_errors=True) - shutil.rmtree("runs", ignore_errors=True) - - def test_dynamic(self): - m = copy.deepcopy(self.model) - inp = self.inp - fp32_out = m(inp) - m = quantize_dynamic(m, dtype=torch.float8_e5m2, inplace=True) - self.assertTrue(isinstance(m.fc1, FP8DynamicLinear)) - self.assertTrue(isinstance(m.mm, FP8DynamicMatmul)) - self.assertTrue(isinstance(m.bmm, FP8DynamicBatchMatmul)) - print(m) - fp8_out = m(inp) - print("Dynamic quantization FP8_E5M2 MSE:", (fp32_out - fp8_out).pow(2).sum()) - - m = copy.deepcopy(self.model) - inp = self.inp - fp32_out = m(inp) - m = quantize_dynamic(m, dtype=torch.float8_e4m3fn, inplace=True) - self.assertTrue(isinstance(m.fc1, FP8DynamicLinear)) - self.assertTrue(isinstance(m.mm, FP8DynamicMatmul)) - self.assertTrue(isinstance(m.bmm, FP8DynamicBatchMatmul)) - print(m) - fp8_out = m(inp) - print("Dynamic quantization FP8_E4M3 MSE:", (fp32_out - fp8_out).pow(2).sum()) - - m = copy.deepcopy(self.model) - inp = self.inp - fp32_out = m(inp) - qconfig = FP8QConfig(approach="dynamic") - m = quantize(m, qconfig, inplace=True) - self.assertTrue(isinstance(m.fc1, FP8DynamicLinear)) - self.assertTrue(isinstance(m.mm, FP8DynamicMatmul)) - self.assertTrue(isinstance(m.bmm, FP8DynamicBatchMatmul)) - print(m) - fp8_out = m(inp) - print("Dynamic quantization FP8_E4M3 MSE:", (fp32_out - fp8_out).pow(2).sum()) - - def test_static(self): - m = copy.deepcopy(self.model) - inp = self.inp - fp32_out = m(inp) - qconfig = FP8QConfig(weight_dtype=torch.float8_e5m2, act_dtype=torch.float8_e5m2, approach="static") - - def calib_func(model): - model(inp) - - m = quantize(m, qconfig, run_fn=calib_func, inplace=True) - self.assertTrue(isinstance(m.fc1, FP8Linear)) - self.assertTrue(isinstance(m.mm, FP8Matmul)) - self.assertTrue(isinstance(m.bmm, FP8BatchMatmul)) - print(m) - fp8_out = m(inp) - print("Static quantization FP8_E5M2 MSE:", (fp32_out - fp8_out).pow(2).sum()) - - m = copy.deepcopy(self.model) - inp = self.inp - fp32_out = m(inp) - qconfig = get_default_fp8_qconfig() - - def calib_func(model): - model(inp) - - m = quantize(m, qconfig, run_fn=calib_func, inplace=True) - self.assertTrue(isinstance(m.fc1, FP8Linear)) - self.assertTrue(isinstance(m.mm, FP8Matmul)) - self.assertTrue(isinstance(m.bmm, FP8BatchMatmul)) - print(m) - fp8_out = m(inp) - print("Static quantization FP8_E4M3 MSE:", (fp32_out - fp8_out).pow(2).sum()) - - -if __name__ 
== "__main__": - unittest.main() diff --git a/test/3x/torch/quantization/habana_fp8/test_fp8.py b/test/3x/torch/quantization/habana_fp8/test_fp8.py new file mode 100644 index 00000000000..41e3af35870 --- /dev/null +++ b/test/3x/torch/quantization/habana_fp8/test_fp8.py @@ -0,0 +1,166 @@ +import copy +import shutil + +import pytest +import torch + +from neural_compressor.torch.utils import is_hpex_available + +if is_hpex_available(): + from neural_compressor.torch.algorithms.habana_fp8 import quantize_dynamic + from neural_compressor.torch.algorithms.habana_fp8.modules import ( + BatchMatmul, + FP8BatchMatmul, + FP8DynamicBatchMatmul, + FP8DynamicLinear, + FP8DynamicMatmul, + FP8Linear, + FP8Matmul, + Matmul, + ) + from neural_compressor.torch.quantization import quantize + from neural_compressor.torch.quantization.config import FP8Config, get_default_fp8_config + + torch.set_grad_enabled(False) + + +class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(10, 5) + self.fc2 = torch.nn.Linear(5, 10) + self.mm = Matmul() + self.bmm = BatchMatmul() + + def forward(self, inp): + x1 = self.fc1(inp) + x2 = self.fc2(x1) + x3 = self.mm(inp.T, x2) + x3 = x3.unsqueeze(0) + x4 = self.mm(inp.T, x2) + x4 = x4.unsqueeze(0) + x5 = self.bmm(x3, x4) + x6 = self.bmm(x3, x4) + out = x5 + x6 + return out + + +@pytest.mark.skipif(not is_hpex_available(), reason="no hpex in environment here.") +class TestPytorchFP8Adaptor: + def setup_class(self): + self.model = M().to("hpu") + self.inp = torch.randn(1, 10).to("hpu") + + def teardown_class(self): + shutil.rmtree("./saved", ignore_errors=True) + shutil.rmtree("./.graph_dumps", ignore_errors=True) + shutil.rmtree("runs", ignore_errors=True) + + def test_dynamic_accu(self): + m = copy.deepcopy(self.model) + inp = self.inp + fp32_out = m(inp) + m = quantize_dynamic(m, dtype="fp8_e5m2", inplace=True) + assert isinstance(m.fc1, FP8DynamicLinear), "Unexpected result. Please double check." + assert isinstance(m.mm, FP8DynamicMatmul), "Unexpected result. Please double check." + assert isinstance(m.bmm, FP8DynamicBatchMatmul), "Unexpected result. Please double check." + print(m) + fp8_out = m(inp) + print("Dynamic quantization FP8_E5M2 MSE:", (fp32_out - fp8_out).pow(2).sum()) + + m = copy.deepcopy(self.model) + inp = self.inp + fp32_out = m(inp) + m = quantize_dynamic(m, dtype="fp8_e4m3", inplace=True) + assert isinstance(m.fc1, FP8DynamicLinear), "Unexpected result. Please double check." + assert isinstance(m.mm, FP8DynamicMatmul), "Unexpected result. Please double check." + assert isinstance(m.bmm, FP8DynamicBatchMatmul), "Unexpected result. Please double check." + print(m) + fp8_out = m(inp) + print("Dynamic quantization FP8_E4M3 MSE:", (fp32_out - fp8_out).pow(2).sum()) + + m = copy.deepcopy(self.model) + inp = self.inp + fp32_out = m(inp) + qconfig = FP8Config(approach="dynamic") + m = quantize(m, qconfig, inplace=True) + assert isinstance(m.fc1, FP8DynamicLinear), "Unexpected result. Please double check." + assert isinstance(m.mm, FP8DynamicMatmul), "Unexpected result. Please double check." + assert isinstance(m.bmm, FP8DynamicBatchMatmul), "Unexpected result. Please double check." 
+ print(m) + fp8_out = m(inp) + print("Dynamic quantization FP8_E4M3 MSE:", (fp32_out - fp8_out).pow(2).sum()) + + def test_static_accu(self): + m = copy.deepcopy(self.model) + inp = self.inp + fp32_out = m(inp) + qconfig = FP8Config(w_dtype="fp8_e5m2", act_dtype="fp8_e5m2", approach="static") + + def calib_func(model): + model(inp) + + m = quantize(m, qconfig, run_fn=calib_func, inplace=True) + assert isinstance(m.fc1, FP8Linear), "Unexpected result. Please double check." + assert isinstance(m.mm, FP8Matmul), "Unexpected result. Please double check." + assert isinstance(m.bmm, FP8BatchMatmul), "Unexpected result. Please double check." + print(m) + fp8_out = m(inp) + print("Static quantization FP8_E5M2 MSE:", (fp32_out - fp8_out).pow(2).sum()) + + m = copy.deepcopy(self.model) + inp = self.inp + fp32_out = m(inp) + qconfig = get_default_fp8_config() + + def calib_func(model): + model(inp) + + m = quantize(m, qconfig, run_fn=calib_func, inplace=True) + assert isinstance(m.fc1, FP8Linear), "Unexpected result. Please double check." + assert isinstance(m.mm, FP8Matmul), "Unexpected result. Please double check." + assert isinstance(m.bmm, FP8BatchMatmul), "Unexpected result. Please double check." + print(m) + fp8_out = m(inp) + print("Static quantization FP8_E4M3 MSE:", (fp32_out - fp8_out).pow(2).sum()) + + def test_convert(self): + # Temporary implementation of fp8 tensor saving and loading + # Will remove after Habana torch applies below patch: + # https://github.com/pytorch/pytorch/pull/114662 + # e4m3 + fp8_inp = torch.ops.hpu.cast_to_fp8_v2(self.inp, 500, dtype=torch.float8_e4m3fn)[0].to("cpu") + import fp8_convert + + int8_inp = fp8_convert.to_u8(fp8_inp) + torch.save(int8_inp, "tmp.pt") + saved_int8_inp = torch.load("tmp.pt") + recovered_inp = fp8_convert.from_u8(saved_int8_inp, 1) + assert (fp8_inp == recovered_inp).all(), "Unexpected result. Please double check." + # e5m2 + fp8_inp = torch.ops.hpu.cast_to_fp8_v2(self.inp, 500, dtype=torch.float8_e5m2)[0].to("cpu") + int8_inp = fp8_convert.to_u8(fp8_inp) + recovered_inp = fp8_convert.from_u8(int8_inp, 0) + assert (fp8_inp == recovered_inp).all(), "Unexpected result. Please double check." + + def test_save_load(self): + m = copy.deepcopy(self.model) + inp = self.inp + qconfig = get_default_fp8_config() + + def calib_func(model): + model(inp) + + m = quantize(m, qconfig, run_fn=calib_func, inplace=True) + fp8_out = m(inp) + m.save("saved_results") + + from neural_compressor.torch.quantization import load + + m = copy.deepcopy(self.model) + m = load(m, "saved_results") + recovered_out = m(inp) + assert (recovered_out == fp8_out).all(), "Unexpected result. Please double check." + assert isinstance(m.fc1, FP8Linear), "Unexpected result. Please double check." + assert isinstance(m.mm, FP8Matmul), "Unexpected result. Please double check." + assert isinstance(m.bmm, FP8BatchMatmul), "Unexpected result. Please double check." diff --git a/test/3x/torch/requirements.txt b/test/3x/torch/requirements.txt index 664d541d556..48c8f3cd632 100644 --- a/test/3x/torch/requirements.txt +++ b/test/3x/torch/requirements.txt @@ -1,3 +1,4 @@ +intel-extension-for-pytorch numpy prettytable psutil
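Until the referenced PyTorch patch lands, float8 tensors have to pass through uint8 to survive torch.save and torch.load; that is what the small fp8_convert C++ extension provides and what test_convert verifies. A device-free sketch of the same idea with stock PyTorch float8 dtypes, assuming a build where casting to torch.float8_e4m3fn and bit-level view() are available; this is an illustration, not the extension's implementation:

    import torch

    x = torch.randn(4, 8)
    amax = x.abs().max()
    scale = 240.0 / amax                             # 240 is the E4M3_AMAX used on HPU in functions.py
    x_fp8 = (x * scale).to(torch.float8_e4m3fn)

    # Serialization workaround: reinterpret the fp8 payload as uint8, save, then view it back.
    torch.save(x_fp8.view(torch.uint8), "tmp.pt")
    restored = torch.load("tmp.pt").view(torch.float8_e4m3fn)

    assert torch.equal(restored.view(torch.uint8), x_fp8.view(torch.uint8))
    x_dequant = restored.to(torch.float32) / scale   # undo the amax-based scale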