From 63a45114a6e39c714dfd3248e94f87a62ab26c5b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9B=BE=E5=B0=8F=E5=81=A5?= <2119516028@qq.com>
Date: Tue, 27 Jun 2023 11:41:52 +0800
Subject: [PATCH] Create cli_demp.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

An almost directly runnable command-line demo of the official P-Tuning v2
setup (just point CHECKPOINT_PATH at your own weight directory).
---
 ptuning/cli_demp.py | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 82 insertions(+)
 create mode 100644 ptuning/cli_demp.py

diff --git a/ptuning/cli_demp.py b/ptuning/cli_demp.py
new file mode 100644
index 00000000..43ffb080
--- /dev/null
+++ b/ptuning/cli_demp.py
@@ -0,0 +1,82 @@
+import os
+import platform
+import signal
+
+import readline  # imported for its side effect: line editing/history in input()
+import torch
+from transformers import AutoConfig, AutoModel, AutoTokenizer
+
+# Point this at your own P-Tuning v2 output directory; it only needs to
+# contain the PrefixEncoder weights.
+CHECKPOINT_PATH = "output/your-ptuning-checkpoint"
+
+# Load the tokenizer.
+tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+
+# Load the base model with pre_seq_len set, then load the new checkpoint
+# (which contains only the PrefixEncoder parameters) on top of it.
+config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
+model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
+prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
+new_prefix_state_dict = {}
+for k, v in prefix_state_dict.items():
+    if k.startswith("transformer.prefix_encoder."):
+        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
+model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
+
+# Run in half precision on the GPU; keep the trained prefix encoder in fp32.
+model = model.half().cuda()
+model.transformer.prefix_encoder.float()
+model = model.eval()
+
+os_name = platform.system()
+clear_command = 'cls' if os_name == 'Windows' else 'clear'
+stop_stream = False
+
+WELCOME = "Welcome to the ChatGLM-6B model. Type your message to chat, 'clear' to reset the history, 'stop' to exit."
+
+
+def build_prompt(history):
+    prompt = WELCOME
+    for query, response in history:
+        prompt += f"\n\nUser: {query}"
+        prompt += f"\n\nChatGLM-6B: {response}"
+    return prompt
+
+
+def signal_handler(signal, frame):
+    global stop_stream
+    stop_stream = True
+
+
+def main():
+    history = []
+    global stop_stream
+    print(WELCOME)
+    while True:
+        query = input("\nUser: ")
+        if query.strip() == "stop":
+            break
+        if query.strip() == "clear":
+            history = []
+            os.system(clear_command)
+            print(WELCOME)
+            continue
+        count = 0
+        for response, history in model.stream_chat(tokenizer, query, history=history):
+            if stop_stream:
+                stop_stream = False
+                break
+            else:
+                count += 1
+                # Redraw the whole transcript every 8 streamed updates.
+                if count % 8 == 0:
+                    os.system(clear_command)
+                    print(build_prompt(history), flush=True)
+                    signal.signal(signal.SIGINT, signal_handler)
+        os.system(clear_command)
+        print(build_prompt(history), flush=True)
+
+
+if __name__ == "__main__":
+    main()
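
A note on GPU memory: the demo above keeps the full model in fp16 on the GPU.
If that does not fit, the official P-Tuning README quantizes the transformer
weights before moving the model to the GPU. A minimal sketch of that variant
of the loading step, reusing the same config, new_prefix_state_dict, and
placeholder CHECKPOINT_PATH from cli_demp.py (quantize(4) is the int4 helper
exposed by the ChatGLM-6B model code loaded via trust_remote_code):

    # int4-quantized variant of the loading step; same assumptions as above
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
    model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    model = model.quantize(4)                  # int4-quantize the transformer weights
    model = model.half().cuda()
    model.transformer.prefix_encoder.float()   # keep the learned prefix in fp32
    model = model.eval()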