1 Star 0 Fork 0

painx/WZCQ

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
Batch.py 2.15 KB
一键复制 编辑 原始数据 按行查看 历史
冯泉荔 提交于 2021-07-05 00:13 . Add files via upload
import torch
#from torchtext import data
import numpy as np
from torch.autograd import Variable
def nopeak_mask(size, device):
    """Build the "no-peek" (causal) mask for a sequence of length ``size``.

    Entry ``[0, i, j]`` is ``True`` when ``j <= i`` (on or below the main
    diagonal) and ``False`` when ``j > i``, so position ``i`` cannot see
    later positions.

    Args:
        size: Sequence length; the mask covers ``size x size`` positions.
        device: Target device for the mask. Anything ``Tensor.to`` accepts
            works (a ``torch.device``, a string such as ``"cpu"`` or
            ``"cuda:0"``, or a CUDA ordinal). ``None`` means the current
            CUDA device.

    Returns:
        A boolean tensor of shape ``(1, size, size)`` located on ``device``.
    """
    if device is None:
        # Preserve the "current CUDA device" meaning of a missing device.
        device = torch.device("cuda")

    # Ones strictly above the diagonal mark the "future" positions.
    future = np.triu(np.ones((1, size, size)), k=1).astype('uint8')
    # Invert: True wherever attention is allowed.
    mask = torch.from_numpy(future) == 0
    # `.to` rather than `.cuda` so that CPU devices are supported as well.
    return mask.to(device)
def create_masks(src, trg, device):
    """Create the padding / no-peek masks for a source and target batch.

    Elements equal to ``-1`` are treated as masked-out positions.

    Args:
        src: Source tensor; compared elementwise against ``-1``.
        trg: Target tensor, or ``None`` when no target mask is needed.
            ``trg.size(1)`` is taken as the sequence length.
        device: Device handed to ``nopeak_mask`` for the no-peek mask.

    Returns:
        A ``(src_mask, trg_mask)`` tuple. ``src_mask`` is
        ``(src != -1).unsqueeze(-2)`` and stays on ``src``'s device.
        ``trg_mask`` is ``None`` when ``trg`` is ``None``; otherwise it is
        the target padding mask AND-ed with the no-peek mask, located on
        the no-peek mask's device.
    """
    src_mask = (src != -1).unsqueeze(-2)
    if trg is None:
        return src_mask, None

    size = trg.size(1)  # seq_len for the square no-peek matrix
    np_mask = nopeak_mask(size, device)

    # Tensor device moves are NOT in-place: the moved tensor must be kept,
    # otherwise the `&` below mixes devices whenever `trg` does not already
    # live where the no-peek mask does.
    trg_mask = (trg != -1).unsqueeze(-2).to(np_mask.device)
    return src_mask, trg_mask & np_mask
# patch on Torchtext's batching process that makes it more efficient
# from http://nlp.seas.harvard.edu/2018/04/03/attention.html#position-wise-feed-forward-networks
# class MyIterator(data.Iterator):
# def create_batches(self):
# if self.train:
# def pool(d, random_shuffler):
# for p in data.batch(d, self.batch_size * 100):
# p_batch = data.batch(
# sorted(p, key=self.sort_key),
# self.batch_size, self.batch_size_fn)
# for b in random_shuffler(list(p_batch)):
# yield b
# self.batches = pool(self.data(), self.random_shuffler)
#
# else:
# self.batches = []
# for b in data.batch(self.data(), self.batch_size,
# self.batch_size_fn):
# self.batches.append(sorted(b, key=self.sort_key))
# Running maxima for the batch currently being assembled; shared between
# successive `batch_size_fn` calls. They are initialised here so that a
# first call with `count != 1` does not fail with a NameError.
max_src_in_batch = 0
max_tgt_in_batch = 0


def batch_size_fn(new, count, sofar):
    """Keep augmenting batch and calculate total number of tokens + padding.

    Args:
        new: The example being added; must expose sized ``src`` and ``trg``
            attributes.
        count: Number of examples in the batch including ``new``. A value
            of ``1`` starts a new batch and resets the running maxima.
        sofar: Unused; kept for interface compatibility with the caller.

    Returns:
        ``count`` multiplied by the longest source length or by the longest
        target length (plus 2) seen in the current batch, whichever is
        larger, i.e. the padded size of the batch.
    """
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch, len(new.src))
    # NOTE(review): the +2 presumably accounts for start/end tokens added to
    # the target later - confirm against the data pipeline.
    max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2)
    return max(count * max_src_in_batch, count * max_tgt_in_batch)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/painx/WZCQ.git
git@gitee.com:painx/WZCQ.git
painx
WZCQ
WZCQ
main

搜索帮助