# -*- coding: utf-8 -*-
"""
Created on Wed Apr 12 14:22:48 2023
@author: lv
"""
import os
import torch
from models.TransformerQA import TransformerQA
from utils.JiebaTokenizer import JiebaTokenizer
tokenizer = JiebaTokenizer()
# Model hyperparameters
vocab_size = tokenizer.vocab_size
embedding_dim = 96
hidden_dim = 1024
num_layers = 3
num_heads = 12
max_seq_len = 512
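# Sanity note (added comment): embedding_dim (96) divides evenly by
# num_heads (12), giving 8-dimensional heads; standard multi-head attention
# requires this, assuming TransformerQA follows the usual layout.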
model_save_path = f'./run/{embedding_dim}_{hidden_dim}_{num_layers}_{num_heads}'
# Select a device: prefer the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")  # uncomment to force CPU inference
# Load the trained model weights
model = TransformerQA(vocab_size, embedding_dim, hidden_dim, num_layers, num_heads, max_seq_len)
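# Defensive check (a small added sketch, not in the original script): fail
# with a clear message if training has not produced a checkpoint yet.
if not os.path.isfile(os.path.join(model_save_path, 'best_model.pth')):
    raise FileNotFoundError(f"Checkpoint not found in {model_save_path}")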
checkpoint = torch.load(os.path.join(model_save_path, 'best_model.pth'), map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()
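# eval() disables dropout (and switches any batch-norm layers to their
# running statistics); generation below also runs under torch.no_grad()
# so no autograd state is recorded during inference.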
def encode_input(input_text, history_text):
    # The newest question is appended to the end of the history
    querystr = history_text + input_text
    # If too long, truncate from the front so the most recent text survives.
    # Truncation is by character count, which is a safe upper bound on the
    # token count, since each jieba token spans one or more characters.
    if len(querystr) > max_seq_len:
        querystr = querystr[-max_seq_len:]
    # Convert the text into the token-id sequence the model expects
    query = tokenizer.encode(querystr)
    return torch.tensor(query), querystr
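# Example of the truncation behavior: with max_seq_len = 512, a 600-character
# combined history keeps only its last 512 characters, so the current turn at
# the end is never cut off; older context is what gets dropped.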
history_text = ""      # accumulated dialogue context
bool_history = False   # set to True to carry context across turns
# Interactive generation loop
while True:
    input_text = input("You: ")
    # Encode the question together with any accumulated history
    encoded_input, history_text = encode_input(input_text, history_text)
    encoded_input = encoded_input.unsqueeze(0).to(device)  # add batch dimension
    # Generate the answer; no_grad avoids tracking gradients during inference
    with torch.no_grad():
        output_seq = model.generate_output_sequence(encoded_input, None)
    # print('gen:', output_seq.size(1))  # debug: generated sequence length
    # Greedy decoding: keep the highest-scoring token at each position
    output_seq = output_seq.argmax(dim=2).squeeze(1)
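    # A sampling alternative (sketch only; assumes output_seq above holds raw
    # logits with the vocabulary on dim 2, as the argmax implies, and that
    # temperature is a hypothetical knob not present in the original script):
    # probs = torch.softmax(output_seq / temperature, dim=2)
    # output_ids = torch.distributions.Categorical(probs=probs).sample()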
    # Convert the output token ids back into natural-language text
    output_text = tokenizer.decode(output_seq.view(-1).tolist())
    output_text = "".join(output_text)
    history_text = history_text + output_text
    if not bool_history:
        history_text = ""  # multi-turn history is disabled; start fresh
    print("AI: " + output_text)