zhang / seq2seq聊天机器人 (seq2seq chatbot)
Repository: https://gitee.com/zhangtuo0723/seq2seq-chat-robot.git (branch: master)

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
This repository declares no open-source license file (LICENSE); before using it, check the project description and the upstream dependencies of its code.
Seq2SeqModel.py · 2.46 KB
Committed by zhang on 2022-04-09 19:02 · first commit
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cpu")
MAX_LEN = 7
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

    def forward(self, input, hidden):
        # Embed the input token ids: (batch, seq_len) -> (batch, seq_len, hidden_size).
        embedded = self.embedding(input)
        # Run the embeddings through the GRU to extract features of the input sentence.
        output, hidden = self.gru(embedded, hidden)
        # hidden is (num_layers, batch, hidden_size); return the top layer's final state.
        return output, hidden[-1]

    def initHidden(self):
        # Initial hidden state for a single-layer GRU with batch size 1.
        return torch.zeros(1, 1, self.hidden_size, device=device)
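
# --- A minimal smoke test of the encoder. This is an illustrative sketch: the
# --- vocabulary size (10), hidden size (16), and random token ids are assumed
# --- values, not taken from this repository.
encoder = Encoder(input_size=10, hidden_size=16)
tokens = torch.randint(0, 10, (1, MAX_LEN))          # (batch, seq_len) of token ids
enc_outputs, enc_hidden = encoder(tokens, encoder.initHidden())
print(enc_outputs.shape)  # torch.Size([1, 7, 16]): one feature vector per time step
print(enc_hidden.shape)   # torch.Size([1, 16]): final top-layer hidden state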
# Decoder class; this decoding step adds an attention mechanism over the
# encoder outputs. (The class name is corrected from the original "AttentionDencoder".)
class AttentionDecoder(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.5, max_length=MAX_LEN):
        super(AttentionDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length
        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        # GRU dropout only applies between stacked layers, so it is a no-op on
        # this single-layer GRU (PyTorch warns about it); it is omitted here.
        self.gru = nn.GRU(self.hidden_size, self.hidden_size, batch_first=True)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        batch_size = input.size(0)
        # Embed the current target token: (batch, 1) -> (batch, 1, hidden_size).
        embedded = self.embedding(input).view(batch_size, -1, self.hidden_size)
        embedded = self.dropout(embedded)
        # Compute attention weights over the encoder time steps with a softmax.
        temp = torch.cat((embedded, hidden.unsqueeze(1)), -1)
        attn_weights = F.softmax(self.attn(temp), dim=-1)
        # Weighted sum of the encoder outputs: (batch, 1, hidden_size).
        attn_applied = torch.bmm(attn_weights, encoder_outputs)
        # Combine the embedding with the attention context, then run the GRU.
        output = torch.cat((embedded, attn_applied), -1)
        output = self.attn_combine(output)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden.unsqueeze(0))
        output = self.out(output)
        return output, hidden[-1], attn_weights

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)
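
# --- A minimal greedy decoding loop over the encoder outputs from the smoke
# --- test above. This is an illustrative sketch: the target vocabulary of 10
# --- and the use of token id 0 as the start-of-sequence symbol are assumptions,
# --- not values from this repository.
decoder = AttentionDecoder(hidden_size=16, output_size=10)
dec_input = torch.tensor([[0]])    # (batch, 1): assumed SOS token id
dec_hidden = enc_hidden            # (batch, hidden_size) from the encoder
for _ in range(MAX_LEN):
    logits, dec_hidden, attn_weights = decoder(dec_input, dec_hidden, enc_outputs)
    dec_input = logits.argmax(dim=-1)          # greedy choice, shape (batch, 1)
    print(dec_input.item(), attn_weights.squeeze().tolist())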