1 Star 0 Fork 0

幽灵代码/ExcaliburGPT

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
chattogpt.py 1.94 KB
一键复制 编辑 原始数据 按行查看 历史
幽灵代码 提交于 2023-04-23 15:21 . 温度控制
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 12 14:22:48 2023
@author: lv
"""
import os
import torch
from models.TransformerQA import TransformerQA
from utils.JiebaTokenizer import JiebaTokenizer
tokenizer = JiebaTokenizer()
# 模型参数
vocab_size = tokenizer.vocab_size
embedding_dim = 96
hidden_dim = 1024
num_layers = 3
num_heads = 12
max_seq_len = 512
model_save_path = f'./run/{embedding_dim}_{hidden_dim}_{num_layers}_{num_heads}'
# 指定设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = 'cpu'
# 加载模型
model = TransformerQA(vocab_size, embedding_dim, hidden_dim, num_layers, num_heads, max_seq_len)
checkpoint = torch.load(os.path.join(model_save_path, 'best_model.pth'),map_location=torch.device(device))
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()
def encode_input(input_text,history_text):
#最新的提问在末尾
querystr = history_text + input_text
#长度超过则截断,优先保留末尾
if len(querystr) > max_seq_len:
querystr = querystr[-max_seq_len:]
# 将输入文本转换为模型需要的格式
query = tokenizer.encode(querystr)
return torch.tensor(query),querystr
history_text = ""
bool_history = False
# 生成对话
while True:
input_text = input("你:")
# 编码输入序列
encoded_input,history_text = encode_input(input_text,history_text)
encoded_input = encoded_input.unsqueeze(0).to(device)
# 生成回答
output_seq = model.generate_output_sequence(encoded_input, None)
#print(output_seq)
#print('gen:',output_seq.size(1))
output_seq = output_seq.argmax(dim=2).squeeze(1)
# 将输出序列转换为自然语言形式输出
output_text = tokenizer.decode(output_seq.view(-1).tolist())
output_text = "".join(output_text)
history_text = history_text + output_text
if bool_history == False:
history_text = ""
print("AI: " + output_text)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/ghostcode/ExcaliburGPT.git
git@gitee.com:ghostcode/ExcaliburGPT.git
ghostcode
ExcaliburGPT
ExcaliburGPT
master

搜索帮助