# NOTE(review): the line below was a Gitee page banner ("code pull complete, the
# page will auto-refresh") captured by the scraper, not code — kept as a comment
# so the file parses.
# 代码拉取完成,页面将自动刷新
import os
import torch
import json
from pprint import pprint
from accelerate import infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoTokenizer, AutoModel, AutoConfig
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
# Name of the dataset whose fine-tuned adapter we want to load.
data_name = "msra"

# The training run saved its CLI arguments next to the adapter weights; reload
# them so we point at the same base model and save directory used for training.
train_args_path = "./checkpoint/{}/train_trainer/adapter_model/train_args.json".format(data_name)
with open(train_args_path, "r", encoding="utf-8") as fp:
    args = json.load(fp)

# Inspect the base model's configuration (remote code is required for ChatGLM-style repos).
config = AutoConfig.from_pretrained(args["model_dir"], trust_remote_code=True)
pprint(config)

tokenizer = AutoTokenizer.from_pretrained(args["model_dir"], trust_remote_code=True)

# Load the base model in fp16 on GPU for inference.
model = AutoModel.from_pretrained(args["model_dir"], trust_remote_code=True).half().cuda()
model = model.eval()

# Attach the LoRA adapter produced by fine-tuning on top of the frozen base model.
model = PeftModel.from_pretrained(model, os.path.join(args["save_dir"], "adapter_model"), torch_dtype=torch.float32, trust_remote_code=True)
# Re-cast/re-place after wrapping so the adapter weights match the base dtype/device.
model.half().cuda()
model.eval()
# Interactive REPL: read a user prompt, run one stateless chat turn
# (history=[] means each question is answered independently), print the reply.
# Exit with Ctrl-C / Ctrl-D.
while True:
    inp = input("用户 >>> ")
    response, history = model.chat(tokenizer, inp, history=[])
    # Fixed: original line had an unbalanced extra ')' — a SyntaxError.
    print("ChatNER >>> ", response)
    print("=" * 100)
# NOTE(review): the two lines below were a Gitee content-moderation banner
# accidentally captured along with the source — kept as comments so the file parses.
# 此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
# 如您确认内容无涉及不当用语/纯广告导流/暴力/低俗色情/侵权/盗版/虚假/无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。