add code-generation evaluation for woq gptq (#1475)
Signed-off-by: changwangss <chang1.wang@intel.com>
Signed-off-by: YIYANGCAI <yiyang.cai@intel.com>
Signed-off-by: chensuyue <suyue.chen@intel.com>
changwangss authored Dec 28, 2023
1 parent c88d765 commit 7634409
Showing 4 changed files with 100 additions and 25 deletions.
23 changes: 22 additions & 1 deletion examples/.config/model_params_pytorch.json
@@ -471,6 +471,13 @@
"main_script": "run_clm_no_trainer.py",
"batch_size": 8
},
"opt_125m_woq_gptq_debug_int4":{
"model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
"dataset_location": "",
"input_model": "",
"main_script": "run_clm_no_trainer.py",
"batch_size": 8
},
"opt_125m_woq_teq":{
"model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
"dataset_location": "",
@@ -513,7 +520,14 @@
"main_script": "run_clm_no_trainer.py",
"batch_size": 1
},
"gpt_j_woq_rtn":{
"gpt_j_woq_rtn_int4":{
"model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
"dataset_location": "",
"input_model": "",
"main_script": "run_clm_no_trainer.py",
"batch_size": 1
},
"gpt_j_woq_gptq_debug_int4":{
"model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
"dataset_location": "",
"input_model": "",
@@ -527,6 +541,13 @@
"main_script": "run_clm_no_trainer.py",
"batch_size": 1
},
"falcon_7b_woq_gptq_debug_int4":{
"model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
"dataset_location": "",
"input_model": "",
"main_script": "run_clm_no_trainer.py",
"batch_size": 1
},
"xlm-roberta-base_MRPC": {
"model_src_dir": "nlp/huggingface_models/text-classification/quantization/ptq_static/fx",
"dataset_location": "",
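The three new *_woq_gptq_debug_int4 entries register the GPTQ debug topologies for OPT-125M, GPT-J-6B and Falcon-7B in the shared example config, mapping each topology name to its source directory, entry script and batch size. A small sanity check after editing the config (a sketch only, not part of this change; it just re-parses the JSON and prints one of the new entries):

    # Confirm the edited config still parses (hand-edited JSON is easy to break
    # with a stray comma), then show one of the new topology entries.
    python -c "import json; json.load(open('examples/.config/model_params_pytorch.json'))" && echo "config OK"
    grep -A 6 '"gpt_j_woq_gptq_debug_int4"' examples/.config/model_params_pytorch.json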
@@ -2,12 +2,12 @@ accelerate
protobuf
sentencepiece != 0.1.92
datasets >= 1.1.3
peft
torch >= 1.10
transformers
pytest
wandb
einops
neural-compressor
intel-extension-for-transformers
git+https://github.com/EleutherAI/lm-evaluation-harness.git@83dbfbf6070324f3e5872f63e49d49ff7ef4c9b3
git+https://github.com/huggingface/peft.git@6c44096c7b8d55a2ecf24be9bc68393467e1584a
git+https://github.com/EleutherAI/lm-evaluation-harness.git@cc9778fbe4fa1a709be2abed9deb6180fd40e7e2
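This hunk pins peft to a specific commit and bumps the lm-evaluation-harness pin, so an existing environment needs a reinstall before running the new evaluation path. A minimal setup sketch, assuming the hunk above edits the requirements.txt that sits next to run_clm_no_trainer.py in the example directory:

    # Reinstall the example requirements; the two git+ lines fetch peft and
    # lm-evaluation-harness at the pinned revisions.
    pip install -r requirements.txt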
@@ -52,7 +52,7 @@
help="calibration iters.")
parser.add_argument("--tasks", nargs='+', default=["lambada_openai",
"hellaswag", "winogrande", "piqa", "wikitext"],
type=str, help="tasks list for accuracy validation")
type=str, help="tasks list for accuracy validation; note that text-generation and code-generation use different task names.")
parser.add_argument("--peft_model_id", type=str, default=None, help="model_name_or_path of peft model")
# ============SmoothQuant configs==============
parser.add_argument("--sq", action="store_true")
@@ -78,7 +78,40 @@
this should align with your model config, \
and your dataset builder args: args.pad_max_length')
parser.add_argument('--gptq_debug', action='store_true', help='Whether to use the fake-quantized (debug) GPTQ model')
# =======================================
# ==============code generation args===========
parser.add_argument("--code_generation", action="store_true")
parser.add_argument("--n_samples", default=200, type=int)
parser.add_argument(
"--limit", default=None, type=int, help="Limit number of samples to eval"
)
parser.add_argument("--allow_code_execution", action="store_true")
parser.add_argument("--prefix", default="")
parser.add_argument("--generation_only", action="store_true")
parser.add_argument("--postprocess", action="store_false")
parser.add_argument("--save_references", action="store_true")
parser.add_argument("--save_generations", action="store_true")
parser.add_argument("--instruction_tokens", default=None)
parser.add_argument("--save_generations_path", default="generations.json")
parser.add_argument("--load_generations_path", default=None)
parser.add_argument("--metric_output_path", default="evaluation_results.json")
parser.add_argument("--max_length_generation", default=512, type=int)
parser.add_argument("--temperature", default=0.8, type=float)
parser.add_argument("--top_p", default=0.8, type=float)
parser.add_argument("--top_k", default=0, type=int)
parser.add_argument("--do_sample", action="store_true")
parser.add_argument("--check_references", action="store_true")
parser.add_argument("--max_memory_per_gpu", type=str, default=None)
parser.add_argument(
"--modeltype",
default="causal",
help="AutoModel to use, it can be causal or seq2seq",
)
parser.add_argument(
"--limit_start",
type=int,
default=0,
help="Optional offset to start from when limiting the number of samples",
)

args = parser.parse_args()
if args.ipex:
@@ -262,7 +295,7 @@ def calib_func(prepared_model):
if args.gptq_debug:
from neural_compressor.adaptor.torch_utils.weight_only import gptq_quantize

conf = {
gptq_conf = {
".*": {
'wbits': args.woq_bits, # 1-8 bits
'group_size': args.woq_group_size, # -1 (per-channel)
@@ -272,20 +305,16 @@
}
q_model_gptq_debug, gptq_config = gptq_quantize(
user_model,
weight_config=conf,
weight_config=gptq_conf,
dataloader=calib_dataloader,
nsamples=args.gptq_nsamples,
use_max_length=args.gptq_use_max_length,
pad_max_length=args.gptq_pad_max_length
pad_max_length=args.gptq_pad_max_length,
)
from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate

results = evaluate(
model="hf-causal",
model_args='pretrained=' + args.model + ',tokenizer=' + args.model + ',dtype=float32',
user_model=q_model_gptq_debug, tasks=["lambada_openai"],
batch_size=4
)
# save the fake quantized model
os.makedirs(args.output_dir, exist_ok=True)
torch.save(q_model_gptq_debug, os.path.join(args.output_dir, "gptq_best_model.pt"))
exit(0)

else:
@@ -317,7 +346,6 @@ def calib_func(prepared_model):
eval_dataset = load_dataset('lambada', split='validation')
evaluator = Evaluator(eval_dataset, tokenizer)


def eval_func(model):
acc = evaluator.evaluate(model)
return acc
@@ -347,15 +375,29 @@ def eval_func(model):

if args.accuracy:
user_model.eval()
from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
if args.gptq_debug:
user_model = torch.load(os.path.join(args.output_dir, "gptq_best_model.pt"))
if args.code_generation:
from intel_extension_for_transformers.llm.evaluation.lm_code_eval import evaluate
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model)
results = evaluate(
model=user_model,
tokenizer=tokenizer,
tasks=",".join(args.tasks),
batch_size=args.batch_size,
args=args,
)
else:
from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
results = evaluate(
model="hf-causal",
model_args='pretrained=' + args.model + ',tokenizer=' + args.model + ',dtype=float32',
user_model=user_model,
batch_size=args.batch_size,
tasks=args.tasks,
)

results = evaluate(
model="hf-causal",
model_args='pretrained=' + args.model + ',tokenizer=' + args.model + ',dtype=float32',
user_model=user_model,
batch_size=args.batch_size,
tasks=args.tasks,
)
dumped = json.dumps(results, indent=2)
if args.save_accuracy_path:
with open(args.save_accuracy_path, "w") as f:
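The new --code_generation branch routes the --accuracy path through the lm_code_eval harness in intel-extension-for-transformers instead of lm_eval, and the --gptq_debug path now saves the fake-quantized model to gptq_best_model.pt under the output directory and reloads it for accuracy runs instead of evaluating inline. A sketch of an accuracy run on a code-generation task, assuming a fake-quantized checkpoint was already produced by a --gptq_debug tuning run (the --model, --accuracy, --output_dir and --batch_size spellings are inferred from the args the diff references but are not shown in it, and humaneval is assumed to be a valid code-generation task name):

    # Step 1 (not shown): run GPTQ tuning with --gptq_debug, e.g. via the
    # run_tuning changes below; it writes <output_dir>/gptq_best_model.pt.
    # Step 2: reload that checkpoint and score it on a code-generation task.
    python run_clm_no_trainer.py \
        --model EleutherAI/gpt-j-6b \
        --output_dir ./saved_results \
        --gptq_debug \
        --accuracy --code_generation --allow_code_execution \
        --tasks humaneval \
        --n_samples 20 --batch_size 1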
@@ -50,6 +50,10 @@ function run_tuning {
model_name_or_path="facebook/opt-125m"
approach="weight_only"
extra_cmd=$extra_cmd" --woq_algo GPTQ"
elif [ "${topology}" = "opt_125m_woq_gptq_debug_int4" ]; then
model_name_or_path="facebook/opt-125m"
approach="weight_only"
extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_scheme asym --woq_group_size 128 --gptq_use_max_length --gptq_debug"
elif [ "${topology}" = "opt_125m_woq_teq" ]; then
model_name_or_path="facebook/opt-125m"
approach="weight_only"
@@ -69,13 +73,21 @@
elif [ "${topology}" = "gpt_j_ipex_sq" ]; then
model_name_or_path="EleutherAI/gpt-j-6b"
extra_cmd=$extra_cmd" --ipex --sq --alpha 1.0"
elif [ "${topology}" = "gpt_j_woq_rtn" ]; then
elif [ "${topology}" = "gpt_j_woq_rtn_int4" ]; then
model_name_or_path="EleutherAI/gpt-j-6b"
approach="weight_only"
extra_cmd=$extra_cmd" --woq_algo RTN --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_enable_mse_search"
elif [ "${topology}" = "gpt_j_woq_gptq_debug_int4" ]; then
model_name_or_path="EleutherAI/gpt-j-6b"
approach="weight_only"
extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --gptq_use_max_length --gptq_debug"
elif [ "${topology}" = "falcon_7b_sq" ]; then
model_name_or_path="tiiuae/falcon-7b-instruct"
extra_cmd=$extra_cmd" --sq --alpha 0.5"
elif [ "${topology}" = "falcon_7b_woq_gptq_debug_int4" ]; then
model_name_or_path="tiiuae/falcon-7b-instruct"
approach="weight_only"
extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --gptq_use_max_length --gptq_debug"
fi

python -u run_clm_no_trainer.py \
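The three new elif branches wire the debug topologies into run_tuning, mapping each topology name to its Hugging Face model id and appending the 4-bit GPTQ flags plus --gptq_debug to extra_cmd. A sketch of driving one of them end to end (the wrapper script name and its --topology/--output_model flags are not part of this diff and are assumed from the usual layout of these examples):

    bash run_tuning.sh \
        --topology=falcon_7b_woq_gptq_debug_int4 \
        --output_model=./saved_results
    # Inside run_tuning this topology expands to:
    #   model_name_or_path=tiiuae/falcon-7b-instruct
    #   --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym \
    #   --gptq_use_max_length --gptq_debug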
