
伊拉克肥灵/ChatGLM-LoRA-Tuning

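This repository does not declare an open-source license file (LICENSE); before using it, check the project description and its upstream code dependencies.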
test.py 2.14 KB
西西嘛呦 committed on 2023-05-26 17:01: Update test.py
import os
import json
import numpy as np  # used below for np.where; missing in the original import list
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from dataset import load_data, NerCollate  # module from this repository
from transformers import AutoModel, AutoTokenizer
from config_utils import ConfigParser  # module from this repository
from peft import PeftModel
data_name = "msra"
# Reload the arguments that were saved next to the adapter at training time
train_args_path = "./checkpoint/{}/train_trainer/adapter_model/train_args.json".format(data_name)
with open(train_args_path, "r") as fp:
    args = json.load(fp)
config_parser = ConfigParser(args)
args = config_parser.parse_main()
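# Load the ChatGLM base model and tokenizer (custom modeling code, hence trust_remote_code)
model = AutoModel.from_pretrained(args.model_dir, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
# Attach the trained LoRA adapter saved under save_dir/adapter_model
model = PeftModel.from_pretrained(model, os.path.join(args.save_dir, "adapter_model"), torch_dtype=torch.float32, trust_remote_code=True)
# Run inference in fp16 on the GPU
model.half().cuda()
model.eval()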
# Evaluate on the dev split
test_data = load_data(args.dev_path)
ner_collate = NerCollate(args, tokenizer)
test_dataloader = DataLoader(test_data,
                             batch_size=args.train_batch_size,
                             shuffle=False,  # fixed order so the printed samples are reproducible
                             drop_last=False,
                             collate_fn=ner_collate.collate_fn)
# Token ids marking where the part to predict begins and ends in labels (not used below)
bos_token_id = tokenizer.bos_token_id
eos_token_id = tokenizer.eos_token_id
with torch.no_grad():
    all_preds = []
    all_trues = []
    for step, batch in enumerate(tqdm(test_dataloader, ncols=100)):
        for k, v in batch.items():
            batch[k] = v.cuda()
        labels = batch["labels"].detach().cpu().numpy()
        output = model(**batch)
        logits = output.logits
        # Greedy (argmax) prediction at every position under teacher forcing
        preds = torch.argmax(logits, -1).detach().cpu().numpy()
        # The logit at position t predicts token t + 1: drop the last prediction
        # and the first label so the two arrays line up
        preds = preds[:, :-1]
        labels = labels[:, 1:]
        # Keep only answer positions (label != -100) and fill the rest with the pad id,
        # which skip_special_tokens=True should strip when decoding
        preds = np.where(labels != -100, preds, tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        all_preds.extend(decoded_preds)
        all_trues.extend(decoded_labels)

print("Predicted:", all_preds[:20])
print("Ground truth:", all_trues[:20])
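The prediction/label alignment in the evaluation loop is easy to get off by one, so here is a small NumPy sketch on a toy batch that needs no model or GPU. It assumes, as the loop does, that labels hold -100 on prompt positions and that the logit at position t predicts token t + 1. `PAD_ID = 0`, the token ids, and the helper names `shift_then_mask` / `mask_then_shift` are hypothetical and exist only for illustration; they are not part of this repository.

```python
import numpy as np

IGNORE_INDEX = -100  # label value ignored by PyTorch's CrossEntropyLoss by default
PAD_ID = 0           # hypothetical stand-in for tokenizer.pad_token_id


def shift_then_mask(pred_ids, label_ids, pad_id=PAD_ID):
    """Align argmax predictions with labels for a causal LM.

    pred_ids[:, t] is the guess for the token at position t + 1, so drop the
    last prediction and the first label, then mask using the shifted labels.
    """
    preds = pred_ids[:, :-1]
    labels = label_ids[:, 1:]
    keep = labels != IGNORE_INDEX
    return np.where(keep, preds, pad_id), np.where(keep, labels, pad_id)


def mask_then_shift(pred_ids, label_ids, pad_id=PAD_ID):
    """Mask with the unshifted labels first, then shift (off by one position)."""
    preds = np.where(label_ids != IGNORE_INDEX, pred_ids, pad_id)[:, :-1]
    labels = np.where(label_ids != IGNORE_INDEX, label_ids, pad_id)[:, 1:]
    return preds, labels


# Toy sequence: input ids [1, 2, 3, 7, 8, 9], prompt = [1, 2, 3], answer = [7, 8, 9].
label_ids = np.array([[IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, 7, 8, 9]])
# A perfect model predicts the next input token at every position;
# the final column (5) is whatever it emits after the last token.
pred_ids = np.array([[2, 3, 7, 8, 9, 5]])

good_preds, good_labels = shift_then_mask(pred_ids, label_ids)
bad_preds, bad_labels = mask_then_shift(pred_ids, label_ids)
print(good_preds.tolist(), good_labels.tolist())  # [[0, 0, 7, 8, 9]] [[0, 0, 7, 8, 9]]
print(bad_preds.tolist(), bad_labels.tolist())    # [[0, 0, 0, 8, 9]] [[0, 0, 7, 8, 9]]
```

With a perfect model, shifting first makes predictions and labels match exactly, while masking with the unshifted labels discards the prediction for the first answer token (7). In a padded batch, masking first also keeps the prediction emitted after the last answer token.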
Clone: https://gitee.com/estaryjl/ChatGLM-LoRA-Tuning.git (HTTPS) or git@gitee.com:estaryjl/ChatGLM-LoRA-Tuning.git (SSH)