1 Star 0 Fork 0

lengyanju8/DISC-LawLLM

加入 Gitee
与超过 1200 万开发者一起发现、参与优秀开源项目，私有仓库也完全免费 :)
免费加入
文件
克隆/下载
cli_demo.py 2.44 KB
一键复制 编辑 原始数据 按行查看 历史
Yao Xiao 提交于 2023-09-23 19:18 . black and ruff
import os
import torch
import platform
from colorama import Fore, Style
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
def init_model(model_path="ShengbinYue/DISC-LawLLM"):
    """Load the DISC-LawLLM causal LM and its tokenizer.

    Args:
        model_path: Hugging Face Hub repo id or local checkpoint directory
            passed to every ``from_pretrained`` call. Defaults to the
            published DISC-LawLLM repository, so existing callers are
            unaffected.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print("Initializing model...")
    # Half-precision weights; device_map="auto" lets transformers decide where
    # to place them. trust_remote_code=True allows code shipped with the
    # checkpoint to run. It is presumably required because ``model.chat``
    # (used by the CLI loop) is not a standard transformers API. TODO confirm.
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True
    )
    # Use the generation settings stored with the checkpoint rather than the
    # library defaults.
    model.generation_config = GenerationConfig.from_pretrained(model_path)
    # use_fast=False selects the slow (Python) tokenizer implementation.
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, use_fast=False, trust_remote_code=True
    )
    return model, tokenizer
def clear_screen():
    """Clear the terminal, show the usage banner, and start a new history.

    The banner says, roughly: welcome to Fudan DISC-LawLLM; type to chat,
    ``clear`` clears history, Ctrl+C interrupts generation, ``stream``
    toggles streaming output, ``exit`` quits.

    Returns:
        A new, empty list to serve as the conversation history.
    """
    # Windows consoles understand ``cls``; every other platform gets ``clear``.
    on_windows = platform.system() == "Windows"
    os.system("cls" if on_windows else "clear")

    banner = (
        "欢迎使用复旦 DISC-LawLLM,输入进行对话,clear 清空历史,Ctrl+C 中断生成,"
        + "stream 开关流式生成,exit 结束。"
    )
    print(Fore.YELLOW + Style.BRIGHT + banner)

    fresh_history = []
    return fresh_history
def _empty_mps_cache():
    """Free cached memory on Apple's MPS backend when that backend is available."""
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()


def _stream_reply(model, tokenizer, messages):
    """Print a streamed reply incrementally and return the text received.

    Each value yielded by ``model.chat(..., stream=True)`` is treated as the
    whole reply so far, so only the newly added suffix is printed each step.
    Ctrl+C stops generation early. Whatever arrived before the interrupt is
    returned, or ``""`` if nothing arrived yet.
    """
    response = ""
    try:
        for partial in model.chat(tokenizer, messages, stream=True):
            print(partial[len(response):], end="", flush=True)
            response = partial
            _empty_mps_cache()
    except KeyboardInterrupt:
        pass
    print()
    return response


def _blocking_reply(model, tokenizer, messages):
    """Generate a complete reply, print it, and return it.

    Ctrl+C during generation abandons the reply and returns ``""``.
    """
    try:
        response = model.chat(tokenizer, messages)
    except KeyboardInterrupt:
        print()
        return ""
    print(response)
    _empty_mps_cache()
    return response


def main(stream=True):
    """Run the interactive DISC-LawLLM chat loop.

    Prompt commands:
        * ``exit`` quits.
        * ``clear`` clears the screen and the history.
        * ``stream`` toggles streamed output.

    Any other non-blank input is sent to the model as a user turn.

    Args:
        stream: Whether replies are initially streamed chunk by chunk.
    """
    model, tokenizer = init_model()
    messages = clear_screen()
    try:
        while True:
            try:
                prompt = input(Fore.GREEN + Style.BRIGHT + "\n用户:" + Style.NORMAL)
            except (EOFError, KeyboardInterrupt):
                # EOF (e.g. Ctrl+D) or Ctrl+C at the prompt ends the session
                # cleanly instead of dumping a traceback.
                break
            command = prompt.strip()
            if command == "exit":
                break
            if command == "clear":
                messages = clear_screen()
                continue
            if not command:
                # Blank input: nothing to send to the model.
                continue
            print(Fore.CYAN + Style.BRIGHT + "\nDISC-LawLLM:" + Style.NORMAL, end="")
            if command == "stream":
                stream = not stream
                print(
                    Fore.YELLOW + "({}流式生成)\n".format("开启" if stream else "关闭"),
                    end="",
                )
                continue
            messages.append({"role": "user", "content": prompt})
            if stream:
                response = _stream_reply(model, tokenizer, messages)
            else:
                response = _blocking_reply(model, tokenizer, messages)
            if response:
                messages.append({"role": "assistant", "content": response})
            else:
                # Interrupted before any text arrived. Drop the unanswered user
                # turn so the history keeps alternating user/assistant and
                # never records a reply left over from an earlier turn.
                messages.pop()
    finally:
        # Always restore terminal colors, even if an exception escapes.
        print(Style.RESET_ALL)
if __name__ == "__main__":
    # Script entry point: start the chat loop with streamed replies enabled.
    main(stream=True)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/lengyanju8/DISC-LawLLM.git
git@gitee.com:lengyanju8/DISC-LawLLM.git
lengyanju8
DISC-LawLLM
DISC-LawLLM
main

搜索帮助