"""Interactive command-line chat demo for ChatGLM-6B with a tuned prefix encoder.

Added to the repository as ``ptuning/cli_demp.py``.

The directory containing the fine-tuned ``pytorch_model.bin`` is taken from
the ``CHECKPOINT_PATH`` environment variable.  Only the weights whose keys
start with ``transformer.prefix_encoder.`` are loaded from that file; the rest
of the model comes from the ``THUDM/chatglm-6b`` pretrained weights.
"""
import os
import platform
import signal

try:
    # Imported only for its side effect on input() line editing.  The module
    # is not part of the standard library on Windows, which this script
    # otherwise supports (see ``clear_command``), so its absence is tolerated.
    import readline  # noqa: F401
except ImportError:
    readline = None

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

MODEL_NAME = "THUDM/chatglm-6b"
PREFIX_KEY = "transformer.prefix_encoder."
WELCOME_MESSAGE = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
# The screen is redrawn once every this many streamed chunks.
REFRESH_EVERY = 8

# The original script referenced CHECKPOINT_PATH without ever defining it,
# which raised NameError on start-up.  Read it from the environment instead.
CHECKPOINT_PATH = os.environ.get("CHECKPOINT_PATH")
if not CHECKPOINT_PATH:
    raise SystemExit(
        "CHECKPOINT_PATH is not set: export it to the directory that contains "
        "the fine-tuned pytorch_model.bin"
    )

# First load the tokenizer.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Load a new-style checkpoint (one that only needs to contain the
# PrefixEncoder parameters) on top of the pretrained model.
config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained(MODEL_NAME, config=config, trust_remote_code=True)
# map_location="cpu" lets a checkpoint saved on a GPU be read on any host;
# load_state_dict then copies the values into the existing parameters.
prefix_state_dict = torch.load(
    os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"), map_location="cpu"
)
# Keep only the prefix-encoder weights and strip the prefix so the keys match
# the prefix_encoder sub-module's own state dict.
new_prefix_state_dict = {
    key[len(PREFIX_KEY):]: value
    for key, value in prefix_state_dict.items()
    if key.startswith(PREFIX_KEY)
}
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
model = model.eval()

os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
# Set by the SIGINT handler; polled by main() between streamed chunks.
stop_stream = False


def build_prompt(history):
    """Return the full screen text: the welcome banner followed by every turn.

    Args:
        history: iterable of ``(query, response)`` pairs.

    Returns:
        A single string with the banner and one "用户"/"ChatGLM-6B" pair of
        paragraphs per history entry.
    """
    turns = "".join(
        f"\n\n用户:{query}\n\nChatGLM-6B:{response}" for query, response in history
    )
    return WELCOME_MESSAGE + turns


def signal_handler(signal, frame):
    """SIGINT handler: ask the streaming loop in main() to stop.

    The arguments are the ones the ``signal`` module passes to every handler
    and are not used.
    """
    global stop_stream
    stop_stream = True


def main():
    """Run the read-generate-print loop until the user enters ``stop``.

    ``clear`` empties the conversation history and the screen.  Any other
    input is sent to the model, and the screen is redrawn periodically while
    the response streams in.  Ctrl-C interrupts the response currently being
    generated.
    """
    global stop_stream
    history = []
    # Install the handler once, up front.  Previously it was only installed
    # after the 8th streamed chunk, so an early Ctrl-C killed the program.
    signal.signal(signal.SIGINT, signal_handler)
    print(WELCOME_MESSAGE)
    while True:
        query = input("\n用户:")
        command = query.strip()
        if command == "stop":
            break
        if command == "clear":
            history = []
            os.system(clear_command)
            print(WELCOME_MESSAGE)
            continue

        # Drop any Ctrl-C received while idle at the prompt so that it does
        # not abort the response that is about to be generated.
        stop_stream = False
        count = 0
        for _response, history in model.stream_chat(tokenizer, query, history=history):
            if stop_stream:
                stop_stream = False
                break
            count += 1
            if count % REFRESH_EVERY == 0:
                os.system(clear_command)
                print(build_prompt(history), flush=True)
        # Final redraw so the complete (or interrupted) answer is shown.
        os.system(clear_command)
        print(build_prompt(history), flush=True)


if __name__ == "__main__":
    main()