update fp8 implementation, design and implement save&load (#1605)
Signed-off-by: xinhe3 <xinhe3@habana.ai>
xin3he authored Feb 27, 2024
1 parent a8d81ca commit f812e67
Showing 31 changed files with 2,676 additions and 510 deletions.
@@ -1,30 +1,33 @@
import os
os.environ["EXPERIMENTAL_WEIGHT_SHARING"] = "False"
os.environ["USE_GAUDI2_SCALE"] = "True"
os.environ.pop("USE_GAUDI2_SCALE") # gaudi scale work
# os.environ["GRAPH_VISUALIZATION"] = "True"
# import shutil
# shutil.rmtree(".graph_dumps", ignore_errors=True)
import argparse
import time
import json
import re
import torch
import torch.nn.functional as F
import deepspeed
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import habana_frameworks.torch.hpex
import habana_frameworks.torch.core as htcore
from habana_frameworks.torch.hpu import memory_stats
import numpy as np
import lm_eval
import lm_eval.tasks
import lm_eval.evaluator
from accelerate import init_empty_weights
from utils import itrex_bootstrap_stderr, show_msg, save_to_excel

torch.set_grad_enabled(False)
htcore.hpu_set_env()
torch.device('hpu')


# Override lm_eval's bootstrap_stderr to avoid out-of-memory caused by Popen for large language models.
lm_eval.metrics.bootstrap_stderr = itrex_bootstrap_stderr
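# A minimal sketch of the chunked bootstrap imported from utils (assumed to match
# utils.itrex_bootstrap_stderr): it resamples in chunks of at most 1000 iterations so
# each process spawned by Popen handles a bounded amount of data.
#
#   def itrex_bootstrap_stderr(f, xs, iters):
#       from lm_eval.metrics import _bootstrap_internal, sample_stddev
#       res = []
#       chunk_size = min(1000, iters)
#       it = _bootstrap_internal(f, chunk_size)
#       for i in range(iters // chunk_size):
#           res.extend(it((i, xs)))
#       return sample_stddev(res)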
@@ -51,22 +54,26 @@
parser.add_argument("--accuracy", action="store_true")
parser.add_argument("--performance", action="store_true")
parser.add_argument("--generate", action="store_true")
parser.add_argument("--skip_fp8_mm", action="store_true")
parser.add_argument("--dump_to_excel", action="store_true")
parser.add_argument("--save", action="store_true")
parser.add_argument("--load", action="store_true")
parser.add_argument("--batch_size", default=1, type=int,
help="For accuracy measurement only.")
parser.add_argument("--pad_max_length", default=512, type=int,
help="Pad input ids to max length.")
parser.add_argument("--calib_iters", default=100, type=int,
help="calibration iters.")
parser.add_argument("--tasks", nargs='+', default=["lambada_openai"], type=str, \
choices=["winogrande", "copa", "piqa", "rte", "hellaswag", \
"openbookqa", "lambada_openai", "lambada_standard", "wikitext"],
parser.add_argument("--tasks", nargs='+', default=["lambada_openai"], \
type=str, choices=["hellaswag", "lambada_openai", "piqa", "winogrande", "copa",
"rte", "openbookqa", "lambada_standard", "wikitext"],
help="tasks list for accuracy validation")
parser.add_argument("--limit", default=None, type=int,
help="the sample num of evaluation.")
parser.add_argument("--max_new_tokens", default=100, type=int,
help="calibration iters.")
parser.add_argument('--buckets', type=int, nargs='+', \
                    help="Input length buckets to use with static_shapes", default=[256, 512])
parser.add_argument("--local_rank",
type=int,
default=-1,
@@ -78,67 +85,65 @@
world_size = int(os.getenv('WORLD_SIZE', '1'))
local_rank = int(os.getenv('LOCAL_RANK', '-1'))

#if local_rank == 0:
# os.environ["ENABLE_CONSOLE"] = 'True'
# os.environ["LOG_LEVEL_ALL"] = '0'

# model
model_dtype = torch.float32
if re.search("llama", args.model.lower()) or re.search("bloom", args.model.lower()):
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
torch.device('hpu')
config = AutoConfig.from_pretrained(args.model)
if world_size > 1:
model_dtype = torch.bfloat16
config = AutoConfig.from_pretrained(args.model)
model_dtype = torch.bfloat16 # RuntimeErrorCastToFp8V2 input must be of float or bfloat16 dtype
deepspeed.init_distributed(dist_backend="hccl")
with deepspeed.OnDevice(dtype=model_dtype, device="meta"):
user_model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype)
import tempfile
checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
from utils import write_checkpoints_json
from optimum.habana.checkpoint_utils import write_checkpoints_json # in optimum-habana
write_checkpoints_json(
args.model,
local_rank,
checkpoints_json,
token=None,
args.model,
local_rank,
checkpoints_json,
token=None,
)
elif re.search("llama", args.model.lower()):
from models.modeling_llama import LlamaForCausalLM
user_model = LlamaForCausalLM.from_pretrained(
else:
if args.load:
config = AutoConfig.from_pretrained(args.model)
with init_empty_weights():
user_model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype)
else:
user_model = AutoModelForCausalLM.from_pretrained(
args.model,
device_map='hpu',
torch_dtype=model_dtype,
)
elif re.search("chatglm", args.model.lower()):
if args.load:
config = AutoConfig.from_pretrained(args.model, torch_dtype=model_dtype)
with init_empty_weights():
user_model = AutoModelForCausalLM.from_config(config)
else:
from models.modeling_chatglm import ChatGLMForConditionalGeneration
user_model = ChatGLMForConditionalGeneration.from_pretrained(
args.model,
revision=args.revision,
device_map='hpu',
torch_dtype=model_dtype,
)
# print(user_model.transformer.output_layer.weight.dtype) # always fp16
user_model.float() # static fp8 need float32 for graph compiler
else:
if args.load:
config = AutoConfig.from_pretrained(args.model)
with init_empty_weights():
user_model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype)
else:
user_model = AutoModelForCausalLM.from_pretrained(
args.model,
trust_remote_code=args.trust_remote_code,
revision=args.revision,
device_map='hpu',
torch_dtype=model_dtype,
)
elif re.search("chatglm", args.model.lower()):
from models.modeling_chatglm import ChatGLMForConditionalGeneration
user_model = ChatGLMForConditionalGeneration.from_pretrained(
args.model,
revision=args.revision,
device_map='hpu',
)
else:
user_model = AutoModelForCausalLM.from_pretrained(
args.model,
trust_remote_code=args.trust_remote_code,
revision=args.revision,
device_map='hpu',
)
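# Note (assumption): with --load, the skeleton built under init_empty_weights() holds meta
# tensors only; the actual FP8 weights are restored later by load(user_model, "saved_results").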

if world_size > 1:
if re.search("llama", args.model.lower()):
@@ -148,36 +153,44 @@
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
ds_inference_kwargs["injection_policy"] = {LlamaDecoderLayer: ("self_attn.o_proj", "mlp.down_proj")}
ds_inference_kwargs["checkpoint"] = checkpoints_json.name

ds_model = deepspeed.init_inference(user_model, **ds_inference_kwargs)
else:
ds_model = deepspeed.init_inference(user_model,
mp_size=world_size,
replace_with_kernel_inject=False)
user_model = ds_model.module


# tokenizer
if re.search("baichuan", args.model.lower()):
from models.tokenization_baichuan import BaichuanTokenizer
tokenizer = BaichuanTokenizer.from_pretrained(
args.model,
trust_remote_code=args.trust_remote_code
)
else:
tokenizer = AutoTokenizer.from_pretrained(
args.model,
trust_remote_code=args.trust_remote_code
)


user_model.eval()

### dynamic & static quantization ###
if args.approach in ["dynamic", "static"] and not args.load:
    print("device:", next(user_model.parameters()).device)
    from neural_compressor.torch.quantization.config import FP8Config, get_default_fp8_config
    from neural_compressor.torch.quantization import quantize
    dtype = args.precision
    if args.approach == "dynamic":
        from neural_compressor.torch.algorithms.habana_fp8 import quantize_dynamic
        user_model = quantize_dynamic(user_model, dtype, inplace=True)
    elif args.approach == "static":
        qconfig = FP8Config(w_dtype=dtype, act_dtype=dtype, approach="static")
        if args.skip_lm_head:
            fp32_config = FP8Config(w_dtype="fp32", act_dtype="fp32")
            qconfig.set_local("lm_head", fp32_config)
        # dataset
        from datasets import load_dataset
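        # Design note (hedged): FP8Config above takes string dtypes ("fp8_e4m3", "fp8_e5m2",
        # "fp32"); set_local("lm_head", fp32_config) keeps the vocabulary projection in fp32
        # while the rest of the model is quantized to FP8.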
@@ -186,7 +199,13 @@
        calib_data = []
        for examples in calib_dataset:
            calib_data.append(
                tokenizer(
                    examples["text"],
                    return_tensors="pt",
                    max_length=64,
                    padding="max_length",
                    truncation=True
                )
            )

def calib_func(model):
@@ -199,12 +218,46 @@ def calib_func(model):
)

user_model = quantize(user_model, qconfig, calib_func, inplace=True)
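        # Assumption about the static flow: calib_func drives the model over calib_data so that
        # activation ranges can be observed, after which quantize() swaps supported modules for
        # their FP8 counterparts using those recorded scales.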
print(user_model, flush=True)
# saving
if args.save and local_rank in [-1, 0]:
user_model.save("saved_results")


if args.load:
from neural_compressor.torch.quantization import load
user_model = load(user_model, "saved_results")
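# Intended save/load flow, a sketch based on the flags above (script name assumed):
#   1) python run_llm.py --model <name> --approach static --precision fp8_e4m3 --save
#      -> calibrate, quantize, and write the quantized model to ./saved_results (rank 0 only)
#   2) python run_llm.py --model <name> --load --accuracy
#      -> rebuild the skeleton with init_empty_weights() and restore it via
#         load(user_model, "saved_results")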


if args.approach in ["dynamic", "static"] or args.load:
# It enables weights constant folding
from habana_frameworks.torch.core.quantization import _check_params_as_const, _mark_params_as_const
_mark_params_as_const(user_model) # can reduce memory allocated and speed up
_check_params_as_const(user_model)



# If torch.matmul and torch.bmm are not replaced by an INC module,
# the code below makes them run in fp8 by monkey-patching the torch functions.
if not args.skip_fp8_mm and args.precision in ['fp8_e4m3', 'fp8_e5m2']:
def replace_torch_mm_bmm():
from neural_compressor.torch.amp.fp8.functions import fp8_matmul
torch.matmul = fp8_matmul
torch.bmm = fp8_matmul

replace_torch_mm_bmm()
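# After the patch above, any plain matmul/bmm executed inside the model dispatches to the FP8
# kernel, e.g. (hypothetical shapes):
#   q = torch.randn(1, 8, 128, 64).to("hpu")
#   k = torch.randn(1, 8, 64, 128).to("hpu")
#   scores = torch.matmul(q, k)  # routed to fp8_matmul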


# inference optimization
if args.to_graph:
import habana_frameworks.torch.hpu.graphs as htgraphs
user_model = htgraphs.wrap_in_hpu_graph(user_model)
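    # Assumed behavior: wrap_in_hpu_graph records the forward pass as an HPU graph and replays
    # it on later calls with identical shapes, cutting host-side op launch overhead.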


# dump message of HPU after quantization or reloading
show_msg()


### generation, performance and accuracy validation ###
if args.generate:
input_prompt = "Here is my prompt"
print("Prompt sentence:", input_prompt)
@@ -234,6 +287,7 @@
print("Generated sentence:", output_sentence)
print("Duration:", eval_end - eval_start)


if args.performance:
eval_start = time.perf_counter()
input_prompt = "Intel is a company which"
@@ -242,6 +296,7 @@
outputs = user_model.generate(input_tokens, **generation_config)
print("Duration of generating 100 tokens :", time.perf_counter() - eval_start)


if args.accuracy:

class HabanaModelAdapter(lm_eval.base.BaseLM):
@@ -292,16 +347,14 @@ def find_bucket(self, length):
return [b for b in self.buckets if b >= length][0]

    def _model_call(self, inps):
        seq_length = inps.shape[-1]
        bucket_length = self.find_bucket(seq_length)
        padding_length = bucket_length - seq_length
        inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id)
        logits = self.model(inps.to(self._device))["logits"].cpu()
        if padding_length > 0:
            logits = logits[:, :-padding_length, :]
        logits = logits.to(torch.float32)
        return logits
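        # Bucketing example (with the default --buckets 256 512): a 300-token input is padded to
        # 512 so the HPU sees a small, fixed set of static shapes and can reuse compiled recipes;
        # the padded positions are sliced off the logits above before scoring.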
@@ -333,18 +386,18 @@


dumped = json.dumps(results, indent=2)
accu_dict = {}
case_name = args.approach + "-" + args.precision
for task_name in args.tasks:
if task_name == "wikitext":
print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"]), flush=True)
accu_dict[task_name] = [args.model, case_name, results["results"][task_name]["word_perplexity"]]
else:
print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"]), flush=True)
accu_dict[task_name] = [args.model, case_name, results["results"][task_name]["acc"]]
if args.dump_to_excel and local_rank in [-1, 0]:
save_to_excel(accu_dict)


# dump final message of HPU
show_msg()