# run_longbench.py (forked from amirzandieh/QJL)
import os
import argparse
import random
import time
import numpy as np
import torch
import json
from tqdm import tqdm
from transformers import LlamaConfig, AutoTokenizer
from datasets import load_dataset
from eval_long_bench import dataset2metric
from fastchat.model import get_conversation_template
from models.llama2_utils_qjl import QJLSketch
from models.llama2_qjl import LlamaForCausalLM_QJL
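
# Evaluate a Llama-family model with a QJL-quantized KV cache on a LongBench task:
# load the model and tokenizer with the quantization settings below, build prompts
# from the LongBench config JSON files, generate answers greedily, and score them
# with the metric registered for the dataset in eval_long_bench.dataset2metric.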


def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed_all(seed)


def build_chat(prompt, model_name):
    if "llama" in model_name:
        prompt = f"[INST]{prompt}[/INST]"
    elif "longchat" in model_name or "vicuna" in model_name:
        conv = get_conversation_template("vicuna")
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
    else:
        raise NotImplementedError
    return prompt


def setup_model_and_tokenizer(
        model_name,
        dtype=torch.float16,
        key_quantization_bits=256,
        key_quantization_bits_initial_layers=512,
        initial_layers_count=15,
        outlier_count_general=8,
        outlier_count_initial_layers=8,
        value_quantization_bits=2,
        group_size=32,
        buffer_size=128,
):
    device = 'cuda'
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        use_fast=False,
        trust_remote_code=True,
        tokenizer_type='llama'
    )

    # Build the model config with the QJL quantization settings.
    config = LlamaConfig.from_pretrained(model_name)
    config._flash_attn_2_enabled = True
    config.attention_dropout = 0.0
    config.key_quantization_bits = key_quantization_bits
    config.key_quantization_bits_initial_layers = key_quantization_bits_initial_layers
    config.initial_layers_count = initial_layers_count
    config.outlier_count_general = outlier_count_general
    config.outlier_count_initial_layers = outlier_count_initial_layers
    config.value_quantization_bits = value_quantization_bits
    config.group_size = group_size
    config.buffer_size = buffer_size

    # QJL sketches used to quantize the key cache; the first
    # `initial_layers_count` layers get a higher-bit sketch.
    generator = torch.Generator(device=torch.device(device))
    config.qjl = QJLSketch(dim=(128, config.key_quantization_bits), dim_outlier=256, rot=True, rng=generator)
    config.qjl_initial_layers = QJLSketch(dim=(128, config.key_quantization_bits_initial_layers),
                                          dim_outlier=128, rot=True, rng=generator)
    config.use_flash = True

    model_qjl = LlamaForCausalLM_QJL.from_pretrained(
        pretrained_model_name_or_path=model_name,
        config=config,
        cache_dir=None,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        device_map="auto"
    )
    return model_qjl, tokenizer


def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default="lmsys/longchat-7b-v1.5-32k")
    parser.add_argument('--dtype', type=str, default="float16", choices=["float16", "float32"])
    parser.add_argument('--key_quantization_bits', type=int, default=256)
    parser.add_argument('--key_quantization_bits_initial_layers', type=int, default=512)
    parser.add_argument('--initial_layers_count', type=int, default=15)
    parser.add_argument('--outlier_count_general', type=int, default=8)
    parser.add_argument('--outlier_count_initial_layers', type=int, default=8)
    parser.add_argument('--value_quantization_bits', type=int, default=2)
    parser.add_argument('--group_size', type=int, default=32)
    parser.add_argument('--buffer_size', type=int, default=128)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--dataset_name', type=str, required=True)
    parser.add_argument('--n_data', type=int, default=150)
    # main() reads args.config_dir; the default directory name here is an assumption.
    parser.add_argument('--config_dir', type=str, default="config")
    return parser.parse_args(args)


def load_configurations(config_dir):
    with open(os.path.join(config_dir, 'dataset2maxlen.json')) as f:
        dataset2maxlen = json.load(f)
    with open(os.path.join(config_dir, 'dataset2prompt.json')) as f:
        dataset2prompt = json.load(f)
    with open(os.path.join(config_dir, 'model2maxlen.json')) as f:
        model2maxlen = json.load(f)
    return dataset2maxlen, dataset2prompt, model2maxlen


def evaluate_model(
        model_qjl,
        tokenizer,
        dataset_name,
        dataset2maxlen,
        dataset2prompt,
        model2maxlen,
        n_data=150,
):
    device = 'cuda'
    prompt_format = dataset2prompt.get(
        dataset_name,
        "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:"
    )
    # Assumes the LongBench convention: model2maxlen caps the prompt length per
    # model, dataset2maxlen caps the number of generated tokens per dataset.
    max_length = model2maxlen.get(model_qjl.config.name_or_path, 31500)
    max_gen = dataset2maxlen.get(dataset_name, 64)
    data = load_dataset('THUDM/LongBench', f"{dataset_name}_e", split='test')

    total_score = 0.
    scores = []
    start = time.time()
    n_data = min(n_data, len(data))  # don't index past the end of the split
    for i in tqdm(range(n_data), desc="Evaluating"):
        json_obj = data[i]
        prompt = prompt_format.format(**json_obj)

        # Truncate overly long prompts from the middle, keeping the head and tail.
        tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
        if len(tokenized_prompt) > max_length:
            half = int(max_length / 2)
            prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True) + tokenizer.decode(
                tokenized_prompt[-half:], skip_special_tokens=True)

        # Few-shot and code-completion tasks are evaluated without the chat template.
        if dataset_name not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]:
            prompt = build_chat(prompt, model_qjl.config.name_or_path)

        inputs = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
        context_length = inputs.input_ids.shape[-1]
        output = model_qjl.generate(
            **inputs,
            max_new_tokens=max_gen,
            num_beams=1,
            do_sample=False,
            temperature=1.0,
        )[0]
        pred = tokenizer.decode(output[context_length:], skip_special_tokens=True)

        # Score against every reference answer and keep the best match.
        ground_truths = json_obj['answers']
        all_classes = json_obj['all_classes']
        score = 0.
        for ground_truth in ground_truths:
            score = max(score, dataset2metric[dataset_name](pred, ground_truth, all_classes=all_classes))
        total_score += score
        scores.append(score)

        mem_alloc = torch.cuda.memory_allocated() / 1024 / 1024 / 1024
        mem_reserve = torch.cuda.memory_reserved() / 1024 / 1024 / 1024
        mem_peak = torch.cuda.memory_stats()['active_bytes.all.peak'] / 1024 / 1024 / 1024
        mem_info = f"mem_alloc: {mem_alloc:.5f}, mem_reserved: {mem_reserve:.5f}, mem_peak: {mem_peak:.5f}"
        print(f"[{i:>4}] score: {score:.4f}, avg_score: {total_score / (i + 1):.4f}, | {mem_info}")

    print(f"Average score for dataset {dataset_name}: {np.mean(scores)}")
    print(f"Total evaluation time: {time.time() - start:.2f} seconds")


def main(args):
    seed_everything(args.seed)
    dataset2maxlen, dataset2prompt, model2maxlen = load_configurations(args.config_dir)
    dtype = torch.float16 if args.dtype == "float16" else torch.float32
    model_qjl, tokenizer = setup_model_and_tokenizer(
        args.model_name,
        dtype,
        args.key_quantization_bits,
        args.key_quantization_bits_initial_layers,
        args.initial_layers_count,
        args.outlier_count_general,
        args.outlier_count_initial_layers,
        args.value_quantization_bits,
        args.group_size,
        args.buffer_size,
    )
    print(f"Model and tokenizer for {args.model_name} are set up successfully.")
    evaluate_model(
        model_qjl,
        tokenizer,
        args.dataset_name,
        dataset2maxlen,
        dataset2prompt,
        model2maxlen,
        args.n_data,
    )


if __name__ == "__main__":
    args = parse_args()
    main(args)
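
# Example invocation (a minimal sketch; the dataset name and the `config/`
# directory holding dataset2maxlen.json, dataset2prompt.json and
# model2maxlen.json are illustrative assumptions, not taken from this file):
#
#   python run_longbench.py \
#       --model_name lmsys/longchat-7b-v1.5-32k \
#       --dataset_name qasper \
#       --config_dir config \
#       --n_data 150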