# NOTE: removed a web-page artifact accidentally captured here
# (it read "code pull complete, the page will refresh" and is not Python).
import jieba
import torch
import pickle
from Seq2SeqModel import *
from beam import *
sos_token =0  # vocabulary index of the start-of-sequence token
eos_token =1  # vocabulary index of the end-of-sequence token
MAX_LEN = 7  # maximum sequence length (in tokens) for encoding/decoding
def evaluatoin_beamsearch_heapq(encoder_outputs, encoder_hidden, DecoderModel, output_lang):
    """Decode one reply with beam search backed by a heap (priority queue).

    Candidates live in a ``Beam`` ordered by accumulated probability. Each
    step expands every unfinished candidate with its top ``beam_width`` next
    tokens, then iterates until the best candidate either ends with EOS or
    reaches ``MAX_LEN`` entries.

    Args:
        encoder_outputs: encoder outputs fed to the attention decoder.
        encoder_hidden: final encoder hidden state, used as the initial
            decoder hidden state.
        DecoderModel: attention decoder, called as
            ``DecoderModel(input, hidden, encoder_outputs)``.
        output_lang: vocabulary used by ``decode`` to map ids back to words.

    Returns:
        The decoded sentence as a string (SOS and final token stripped).
    """
    beam_width = 3  # number of next-token candidates kept per expansion
    # 1. Seed the beam with the SOS token (probability 1, not complete).
    decoder_input = torch.LongTensor([sos_token] * 1)
    decoder_hidden = encoder_hidden
    prev_beam = Beam()
    prev_beam.add(1, False, [decoder_input], decoder_input, decoder_hidden)
    while True:
        cur_beam = Beam()
        # 2. Pop every candidate from the previous step and expand it.
        for _probility, _complete, _seq, _decoder_input, _decoder_hidden in prev_beam:
            if _complete:
                # A finished candidate is carried over unchanged: it may
                # still be outscored by a longer, unfinished one.
                cur_beam.add(_probility, _complete, _seq, _decoder_input, _decoder_hidden)
            else:
                decoder_output_t, decoder_hidden, _ = DecoderModel(
                    _decoder_input, _decoder_hidden, encoder_outputs)
                # 3. Take the top-k tokens of this step as candidate inputs.
                # value/index: [batch_size=1, beam_width]
                value, index = torch.topk(decoder_output_t.squeeze(0), beam_width)
                for m, n in zip(value[0], index[0]):
                    decoder_input = torch.LongTensor([[n]])
                    seq = _seq + [n.item()]
                    probility = _probility * m
                    complete = n.item() == eos_token
                    # 4. Save the data the next time step needs in the new beam.
                    cur_beam.add(probility, complete, seq, decoder_input, decoder_hidden)
        # 5. Stop when the highest-probability candidate is complete or has
        # reached the maximum length (the SOS entry counts toward the length).
        best_prob, best_complete, best_seq, _, _ = max(cur_beam)
        if best_complete or len(best_seq) == MAX_LEN:
            # best_seq[0] is the SOS tensor and the last id is EOS (or the
            # token that hit the length limit) — both are dropped.
            return decode(output_lang, best_seq[1:-1])
        # 6. Otherwise iterate on the freshly built beam.
        prev_beam = cur_beam
def words_tensor(words, lang, max_len=None):
    """Convert a list of words to a fixed-length LongTensor of token ids.

    Unknown words map to id 3 (UNK); sequences are truncated to ``max_len``
    or right-padded with id 2 (PAD).

    Args:
        words: list of word strings.
        lang: vocabulary object exposing a ``word2index`` mapping.
        max_len: target length; defaults to the module-level ``MAX_LEN``.

    Returns:
        ``torch.LongTensor`` of exactly ``max_len`` ids.
    """
    if max_len is None:
        max_len = MAX_LEN
    # Avoid shadowing the builtin `id`.
    ids = [lang.word2index.get(word, 3) for word in words]
    if len(ids) > max_len:
        ids = ids[:max_len]
    else:
        ids = ids + [2] * (max_len - len(ids))
    return torch.LongTensor(ids)
# Load the vocabulary objects (word<->index tables) saved during training.
# NOTE(review): pickle.load assumes these are trusted local artifacts; never
# point these paths at untrusted files.
with open("dict/input_lang.pkl","rb") as f:
    input_lang = pickle.load(f)
with open("dict/out_lang.pkl","rb") as f:
    out_lang = pickle.load(f)
# Rebuild the encoder/decoder with the same hidden size used at training
# time, restore their trained weights, and switch to inference mode
# (eval() disables dropout/batch-norm updates).
EncoderModel = Encoder(input_lang.n_words,hidden_size=100)
EncoderModel.load_state_dict(torch.load("savemode/EncoderModel.pkl"))
EncoderModel.eval()
DecoderModel = AttentionDencoder(output_size=out_lang.n_words, hidden_size=100)
DecoderModel.load_state_dict(torch.load("savemode/DecoderModel.pkl"))
DecoderModel.eval()
def decode(lang, id_len):
    """Map a sequence of token ids back to a sentence string.

    Args:
        lang: vocabulary object exposing an ``index2word`` mapping.
        id_len: iterable of integer token ids.

    Returns:
        The words concatenated into one string (no separators, matching
        the Chinese-text convention used elsewhere in this script).
    """
    # str.join is linear; repeated += concatenation is quadratic.
    return "".join(lang.index2word[i] for i in id_len)
def predict(sentence, input_lang, output_lang, EncoderModel, DecoderModel):
    """Answer one user sentence with the seq2seq model.

    Segments the sentence with jieba, encodes it as a length-MAX_LEN id
    tensor, runs the encoder, then beam-search decodes and prints the answer.

    Args:
        sentence: raw user input string.
        input_lang: source-side vocabulary (word2index).
        output_lang: target-side vocabulary (index2word).
        EncoderModel: trained encoder; called as ``EncoderModel(ids, None)``.
        DecoderModel: trained attention decoder passed to beam search.

    Returns:
        The decoded answer string.
    """
    input_words = jieba.lcut(sentence)
    # Renamed from `input` to avoid shadowing the builtin used below by the REPL.
    input_tensor = words_tensor(input_words, input_lang)
    # unsqueeze(0) adds a batch dimension of 1; the encoder builds its own
    # initial hidden state when None is passed.
    encoder_output, hidden = EncoderModel(input_tensor.unsqueeze(0), None)
    # Removed: an unused `decoder_input = torch.tensor(..., device=device)`
    # (`device` is not defined in this file — latent NameError), an unused
    # `output_id` list, and a commented-out greedy-decode loop superseded
    # by beam search.
    answer = evaluatoin_beamsearch_heapq(encoder_output, hidden, DecoderModel, output_lang)
    print("A:", answer)
    return answer
# Simple interactive REPL: read a question, print the model's answer.
# Runs forever; terminate with Ctrl-C / EOF.
while True:
    sentence = input("Q:")
    predict(sentence,input_lang,out_lang,EncoderModel,DecoderModel)
# new approach inside the decoder
# NOTE: removed a web-page moderation notice accidentally captured here
# (content-review boilerplate from the hosting site, not Python code).