#!/usr/bin/env python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import Variable

class MultiGRU(nn.Module):
    """Implements a three-layer GRU, including an embedding layer
    and an output linear layer back to the size of the vocabulary."""
    def __init__(self, voc_size):
        super(MultiGRU, self).__init__()
        self.embedding = nn.Embedding(voc_size, 128)
        self.gru_1 = nn.GRUCell(128, 512)
        self.gru_2 = nn.GRUCell(512, 512)
        self.gru_3 = nn.GRUCell(512, 512)
        self.linear = nn.Linear(512, voc_size)

    def forward(self, x, h):
        x = self.embedding(x)
        h_out = Variable(torch.zeros(h.size()))
        # Feed the embedded input through the three stacked GRU cells,
        # collecting the new hidden state of each layer.
        x = h_out[0] = self.gru_1(x, h[0])
        x = h_out[1] = self.gru_2(x, h[1])
        x = h_out[2] = self.gru_3(x, h[2])
        x = self.linear(x)
        return x, h_out

    def init_h(self, batch_size):
        # Initial hidden state is zero
        return Variable(torch.zeros(3, batch_size, 512))
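
# A minimal shape-check sketch (illustrative only; the vocabulary size of 30
# is an arbitrary assumption, not a value used elsewhere in this file):
#
#     rnn = MultiGRU(voc_size=30)
#     x = Variable(torch.zeros(4).long())   # a batch of 4 token indices
#     h = rnn.init_h(batch_size=4)
#     logits, h = rnn(x, h)                 # logits: (4, 30), h: (3, 4, 512)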

class RNN():
    """Implements the Prior and Agent RNN. Needs a Vocabulary instance in
    order to determine the size of the vocabulary and the index of the END token."""
    def __init__(self, voc):
        self.rnn = MultiGRU(voc.vocab_size)
        if torch.cuda.is_available():
            self.rnn.cuda()
        self.voc = voc
    def likelihood(self, target):
        """
        Retrieves the likelihood of a given sequence.

        Args:
            target: (batch_size, seq_length) A batch of sequences

        Outputs:
            log_probs : (batch_size) Log likelihood for each example
            entropy: (batch_size) The entropies for the sequences. Not
                     currently used.
        """
        batch_size, seq_length = target.size()
        start_token = Variable(torch.zeros(batch_size, 1).long())
        start_token[:] = self.voc.vocab['GO']
        # Teacher forcing: shift the targets right and prepend the GO token,
        # so each step predicts target[:, step] from the preceding tokens.
        x = torch.cat((start_token, target[:, :-1]), 1)
        h = self.rnn.init_h(batch_size)

        log_probs = Variable(torch.zeros(batch_size))
        entropy = Variable(torch.zeros(batch_size))
        for step in range(seq_length):
            logits, h = self.rnn(x[:, step], h)
            log_prob = F.log_softmax(logits, dim=1)
            prob = F.softmax(logits, dim=1)
            log_probs += NLLLoss(log_prob, target[:, step])
            entropy += -torch.sum((log_prob * prob), 1)
        return log_probs, entropy
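
    # Worked example (illustrative only): for a single target sequence
    # [a, b, EOS], the inputs fed to the network are [GO, a, b] and
    # log_probs accumulates
    #     log P(a | GO) + log P(b | GO, a) + log P(EOS | GO, a, b),
    # i.e. the log likelihood of the whole sequence under the model.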
    def sample(self, batch_size, max_length=140):
        """
        Sample a batch of sequences.

        Args:
            batch_size : Number of sequences to sample
            max_length : Maximum length of the sequences

        Outputs:
            seqs: (batch_size, seq_length) The sampled sequences.
            log_probs : (batch_size) Log likelihood for each sequence.
            entropy: (batch_size) The entropies for the sequences. Not
                     currently used.
        """
        start_token = Variable(torch.zeros(batch_size).long())
        start_token[:] = self.voc.vocab['GO']
        h = self.rnn.init_h(batch_size)
        x = start_token

        sequences = []
        log_probs = Variable(torch.zeros(batch_size))
        finished = torch.zeros(batch_size).byte()
        entropy = Variable(torch.zeros(batch_size))
        if torch.cuda.is_available():
            finished = finished.cuda()

        for step in range(max_length):
            logits, h = self.rnn(x, h)
            prob = F.softmax(logits, dim=1)
            log_prob = F.log_softmax(logits, dim=1)
            # Draw the next token for every sequence from the categorical
            # distribution defined by the softmax probabilities.
            x = torch.multinomial(prob, num_samples=1).view(-1)
            sequences.append(x.view(-1, 1))
            log_probs += NLLLoss(log_prob, x)
            entropy += -torch.sum((log_prob * prob), 1)

            # Detach the sampled tokens from the graph before feeding them
            # back in as the next input.
            x = Variable(x.data)
            EOS_sampled = (x == self.voc.vocab['EOS']).data
            finished = finished | EOS_sampled.byte()
            # Stop early once every sequence has produced an EOS token.
            if finished.all():
                break

        sequences = torch.cat(sequences, 1)
        return sequences.data, log_probs, entropy
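
# Usage sketch (illustrative only; assumes a Vocabulary instance `voc` built
# elsewhere in this repo, exposing a `vocab` dict with 'GO' and 'EOS' entries
# and a `vocab_size` attribute):
#
#     rnn = RNN(voc)
#     seqs, sample_log_probs, _ = rnn.sample(batch_size=16)
#     log_probs, _ = rnn.likelihood(Variable(seqs))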

def NLLLoss(inputs, targets):
    """
    Custom Negative Log Likelihood loss that returns loss per example,
    rather than for the entire batch.

    Args:
        inputs : (batch_size, num_classes) Log probabilities of each class
        targets: (batch_size) Target class index

    Outputs:
        loss : (batch_size) Loss for each example
    """
    if torch.cuda.is_available():
        target_expanded = torch.zeros(inputs.size()).cuda()
    else:
        target_expanded = torch.zeros(inputs.size())
    # One-hot encode the targets, then pick out the log probability of the
    # target class for every example by masking and summing over classes.
    target_expanded.scatter_(1, targets.contiguous().view(-1, 1).data, 1.0)
    loss = Variable(target_expanded) * inputs
    loss = torch.sum(loss, 1)
    return loss
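
# A minimal smoke test (a sketch, not part of the original module). The stub
# vocabulary below is hypothetical: it mimics only the two attributes this
# file actually touches (`vocab` and `vocab_size`); the real Vocabulary class
# lives elsewhere in the repo.
if __name__ == "__main__":

    class _StubVoc:
        vocab = {'GO': 0, 'EOS': 1, 'A': 2, 'B': 3}
        vocab_size = 4

    model = RNN(_StubVoc())
    seqs, log_probs, entropy = model.sample(batch_size=4, max_length=10)
    print(seqs.size(), log_probs.size(), entropy.size())

    # NLLLoss(log_prob, targets) should match picking out the log probability
    # of each target class directly with gather.
    log_prob = F.log_softmax(Variable(torch.randn(4, 4)), dim=1)
    targets = Variable(torch.LongTensor([0, 1, 2, 3]))
    picked = log_prob.gather(1, targets.view(-1, 1)).squeeze(1)
    assert torch.equal(NLLLoss(log_prob, targets).data, picked.data)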