代码拉取完成,页面将自动刷新
from models.ERNIE import *
PAD, CLS = '[PAD]', '[CLS]' # padding符号, bert中综合信息符号
class Run(object):
"""配置参数"""
def __init__(self, model_name, dataset):
self.model_name = model_name
self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'
self.config = Config(dataset)
self.model = Model(self.config).to(self.config.device)
self.pad_size = self.config.pad_size
def load_data(self, content):
token = self.config.tokenizer.tokenize(content)
token = [CLS] + token
seq_len = len(token)
mask = []
# 编码后的词
token_ids = self.config.tokenizer.convert_tokens_to_ids(token)
if self.pad_size:
if len(token) < self.pad_size:
# 后面补0
mask = [1] * len(token_ids) + [0] * (self.pad_size - len(token))
token_ids += ([0] * (self.pad_size - len(token)))
else:
mask = [1] * self.pad_size
token_ids = token_ids[:self.pad_size]
seq_len = self.pad_size
x = torch.LongTensor([token_ids]).to(self.config.device)
# pad前的长度(超过pad_size的设为pad_size)
seq_len = torch.LongTensor([seq_len]).to(self.config.device)
mask = torch.LongTensor([mask]).to(self.config.device)
return x, seq_len, mask
def run_single_1(self, data, load=True):
# 读取预训练模型
print("加载模型")
load_dict = torch.load(self.save_path)
# print(load_dict.keys())
if load:
self.model.load_state_dict(load_dict)
self.model.eval()
print("模型加载完成,开始预测")
encoding = self.load_data(data)
# print(encoding)
outputs = self.model(encoding)
# 指定1维,压缩至1维,返回行最大值及其下标
predict = torch.max(outputs.data, 1)[1].cpu().numpy()
print(self.config.class_list[predict.astype(int)[0]])
if __name__ == '__main__':
model_name = 'ERNIE'
dataset = 'THUCNews' # data/class.txt显示分类类别
model = Run(model_name, dataset)
# 测试数据
data = "经济发生危机,导致学生出现罢课,工人出现暴动"
# print("初始结果:")
# model.run_single_1(data, False)
print("已训练结果:")
model.run_single_1(data)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。