1 Star 0 Fork 1

linzhengtian/Bert-Chinese-Text-Classification-Pytorch-Learn

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
predict.py 2.36 KB
一键复制 编辑 原始数据 按行查看 历史
sdfvgvsvs 提交于 2024-02-06 11:09 . 2024020601 修复bug
from models.ERNIE import *
# Special token strings: PAD is the padding token; CLS is BERT's aggregate
# (sentence-level) token, which is prepended to every input before encoding.
PAD = '[PAD]'
CLS = '[CLS]'
class Run(object):
    """Single-text inference wrapper around a fine-tuned classifier.

    Builds the model described by ``Config(dataset)`` and records where the
    fine-tuned checkpoint for ``model_name`` is stored.
    """

    def __init__(self, model_name, dataset):
        """Set up the config, the model and the checkpoint path.

        Args:
            model_name: Model identifier; used as the checkpoint file stem.
            dataset: Dataset directory; the checkpoint is expected at
                ``<dataset>/saved_dict/<model_name>.ckpt``.
        """
        self.model_name = model_name
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'
        self.config = Config(dataset)
        self.model = Model(self.config).to(self.config.device)
        self.pad_size = self.config.pad_size

    def load_data(self, content):
        """Encode one text into a batch of size 1 for the model.

        Prepends the CLS token and converts tokens to ids. If
        ``self.pad_size`` is set, the ids are right-padded with 0 (or
        truncated) to exactly that length; if it is falsy, the sequence
        keeps its natural length.

        Args:
            content: Raw input text.

        Returns:
            Tuple ``(x, seq_len, mask)`` of LongTensors on ``config.device``.
            ``x`` and ``mask`` have shape (1, L). ``seq_len`` has shape (1,)
            and holds the length before padding, capped at ``pad_size``.
        """
        tokens = [CLS] + self.config.tokenizer.tokenize(content)
        token_ids = list(self.config.tokenizer.convert_tokens_to_ids(tokens))
        seq_len = len(token_ids)
        # Every real token is attended to. Starting from this (instead of an
        # empty list) keeps mask aligned with x when pad_size is falsy.
        mask = [1] * seq_len
        if self.pad_size:
            if seq_len < self.pad_size:
                # Right-pad ids with 0 and mark the padding as masked out.
                pad = self.pad_size - seq_len
                token_ids = token_ids + [0] * pad
                mask = mask + [0] * pad
            else:
                # Too long (or exact fit): truncate to pad_size.
                token_ids = token_ids[:self.pad_size]
                mask = [1] * self.pad_size
                seq_len = self.pad_size
        device = self.config.device
        x = torch.LongTensor([token_ids]).to(device)
        seq_len = torch.LongTensor([seq_len]).to(device)
        mask = torch.LongTensor([mask]).to(device)
        return x, seq_len, mask

    def run_single_1(self, data, load=True):
        """Classify one text, print the predicted label and return it.

        Args:
            data: Raw input text.
            load: If True, load the fine-tuned weights from
                ``self.save_path`` before predicting. If False, predict with
                the model's current weights (e.g. the untrained baseline).

        Returns:
            The predicted label taken from ``config.class_list``.
        """
        if load:
            print("加载模型")
            # Read the checkpoint only when it is needed. map_location lets a
            # checkpoint saved on GPU load on a CPU-only machine (and vice versa).
            state_dict = torch.load(self.save_path, map_location=self.config.device)
            self.model.load_state_dict(state_dict)
        # Inference behaviour (e.g. dropout disabled) in both paths.
        self.model.eval()
        print("模型加载完成,开始预测")
        encoding = self.load_data(data)
        # No gradients are needed for prediction; skip building the graph.
        with torch.no_grad():
            outputs = self.model(encoding)
        # Index of the highest-scoring class for the single row in the batch.
        predict = int(outputs.argmax(dim=1)[0])
        label = self.config.class_list[predict]
        print(label)
        return label
if __name__ == '__main__':
    # Demo: classify one hard-coded sentence with the fine-tuned ERNIE model
    # stored under the THUCNews dataset directory (per the original note,
    # data/class.txt lists the category names).
    runner = Run('ERNIE', 'THUCNews')
    sample = "经济发生危机,导致学生出现罢课,工人出现暴动"
    # For the untrained baseline, print "初始结果:" and call
    # runner.run_single_1(sample, False) instead.
    print("已训练结果:")
    runner.run_single_1(sample)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/linzhengtian/bert-chinese-text-classification-pytorch-learn.git
git@gitee.com:linzhengtian/bert-chinese-text-classification-pytorch-learn.git
linzhengtian
bert-chinese-text-classification-pytorch-learn
Bert-Chinese-Text-Classification-Pytorch-Learn
master

搜索帮助