import torch
from torch import nn

from enc import Encoder
from dec import Decoder
class Transformer(nn.Module):
    def __init__(self, n, n_vocab, d_model,
                 d_k, d_v, n_head, d_ff,
                 pad_token=0, max_len=5000,
                 dropout=0.5, device='cpu'):
        super().__init__()
        self.device = device
        self.encoder = Encoder(n, n_vocab, d_model,
                               d_k, n_head, d_ff, pad_token,
                               max_len, dropout, device)
        self.decoder = Decoder(n, n_vocab, d_model, d_k,
                               d_v, n_head, d_ff, pad_token,
                               max_len, dropout, device)
        # Tie the decoder's input embedding to the encoder's embedding.
        self.decoder.embd.weight = self.encoder.embd.weight
        self.proj = nn.Linear(d_model, n_vocab).to(device)
    def forward(self, x_enc, x_dec):
        enc_output = self.encoder(x_enc)
        dec_output = self.decoder(x_dec, x_enc, enc_output)
        # logits: (batch_size, seq_len, n_vocab)
        logits = self.proj(dec_output)
        # Flatten to (batch_size * seq_len, n_vocab) for cross-entropy loss.
        return logits.view(-1, logits.size(-1))
    def greedy_decoder(self, x_enc, start_token, end_token, pad_token=0):
        self.decoder.use_kv_cache()
        enc_output = self.encoder(x_enc)
        # The decoder input starts as all padding and is filled in one position per step.
        dec_input = torch.full_like(x_enc, pad_token)
        next_token = torch.full((x_enc.shape[0],), start_token,
                                dtype=x_enc.dtype, device=self.device)
        end_token = torch.full((x_enc.shape[0],), end_token,
                               dtype=x_enc.dtype, device=self.device)
        stop_flag = torch.full((x_enc.shape[0],), pad_token,
                               dtype=x_enc.dtype, device=self.device)
        i = 0
        pred = torch.full_like(x_enc, pad_token)
        # Stop once every sequence has produced pad or end, or the output
        # length (capped at the source length) is exhausted.
        while (not torch.all(next_token == stop_flag)
               and not torch.all(next_token == end_token)
               and i < x_enc.shape[1]):
            dec_input[:, i] = next_token
            dec_output = self.decoder(dec_input, x_enc, enc_output, i)
            logits = self.proj(dec_output)
            pred = logits.argmax(dim=-1)
            next_token = pred[:, i]
            i += 1
        self.decoder.clear_cache()
        return pred
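

# A minimal usage sketch for the model above. It assumes the Encoder/Decoder
# from enc/dec accept the hyperparameters shown in __init__; the concrete
# values and the token ids (pad=0, start=1, end=2) are illustrative
# assumptions, not fixed by this repo.
if __name__ == '__main__':
    model = Transformer(n=2, n_vocab=100, d_model=64,
                        d_k=8, d_v=8, n_head=4, d_ff=256,
                        pad_token=0, max_len=64,
                        dropout=0.1, device='cpu')

    x_enc = torch.randint(3, 100, (4, 10))   # 4 source sequences of length 10
    x_dec = torch.randint(3, 100, (4, 10))   # teacher-forced decoder input

    logits = model(x_enc, x_dec)             # (4 * 10, 100), suitable for nn.CrossEntropyLoss
    pred = model.greedy_decoder(x_enc, start_token=1, end_token=2)  # (4, 10) token ids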