代码拉取完成,页面将自动刷新
import torch.nn as nn
import torch
device = torch.device("cpu")
MAX_LEN = 7
import torch.nn.functional as F
class Encoder(nn.Module):
def __init__(self, input_size, hidden_size):
super(Encoder, self).__init__()
self.hidden_size = hidden_size
self.embedding = nn.Embedding(input_size, hidden_size)
self.gru = nn.GRU(hidden_size, hidden_size,batch_first=True)
def forward(self, input, hidden):
# 对输入的序列进行embdedding处理
embedded = self.embedding(input)
output = embedded
# 在进行embedding处理之后,作为gru网络的输入,输入到gru,提取输入语句的特征。
output, hidden = self.gru(output, hidden)
return output, hidden[-1]
def initHidden(self):
return torch.zeros(1, 1, self.hidden_size, device=device)
# 定义Decoder方法类,这里的decoder过程是加上了attention机制
class AttentionDencoder(nn.Module):
def __init__(self, hidden_size, output_size, dropout_p=0.5, max_length=MAX_LEN):
super(AttentionDencoder, self).__init__()
self.hidden_size = hidden_size
self.output_size = output_size
self.dropout_p = dropout_p
self.max_length = max_length
self.embedding = nn.Embedding(self.output_size, self.hidden_size)
self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
self.dropout = nn.Dropout(self.dropout_p)
self.gru = nn.GRU(self.hidden_size, self.hidden_size,batch_first=True,dropout=self.dropout_p)
self.out = nn.Linear(self.hidden_size, self.output_size)
def forward(self, input, hidden, encoder_outputs):
batch_size = input.size(0)
embedded = self.embedding(input).view(batch_size,-1,self.hidden_size)
embedded = self.dropout(embedded)
# 使用softmax方法来计算出attention的权重值
temp = torch.cat((embedded, hidden.unsqueeze(1)), -1)
attn_weights = F.softmax(self.attn(temp))
attn_applied = torch.bmm(attn_weights, encoder_outputs)
output = torch.cat((embedded, attn_applied), -1)
output = self.attn_combine(output)
output = F.relu(output)
output, hidden = self.gru(output, hidden.unsqueeze(0))
output = self.out(output)
return output, hidden[-1], attn_weights
def initHidden(self):
return torch.zeros(1, 1, self.hidden_size, device=device)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。