1 Star 1 Fork 1

LeeDvan/CoupletAI

forked from 东方佑/CoupletAI 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
preprocess.py 3.84 KB
一键复制 编辑 原始数据 按行查看 历史
import argparse
from pathlib import Path
from typing import Tuple, List, Mapping
import torch
import torch.nn as nn
from tqdm import trange
import config
from data_load import load_vocab, load_dataset
def create_dataset(seqs: List[List[str]],
                   tags: List[List[str]],
                   word_to_ix: Mapping[str, int],
                   max_seq_len: int,
                   pad_ix: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert parallel token sequences into padded index tensors.

    Args:
        seqs: tokenized source sequences (first lines of couplets).
        tags: tokenized target sequences; must be parallel with ``seqs``.
        word_to_ix: vocabulary mapping; must contain a '[UNK]' entry.
        max_seq_len: fixed output length; longer sequences are truncated
            (the original implementation raised IndexError on them).
        pad_ix: index used to fill positions past each sequence's end.

    Returns:
        seqs_tensor: LongTensor, shape [num_seqs, max_seq_len].
        seqs_mask: FloatTensor, shape [num_seqs, max_seq_len]; 1 at real
            tokens, 0 at padding.
        tags_tensor: LongTensor, shape [num_seqs, max_seq_len].
    """
    # Fall back to plain range when tqdm is unavailable; trange only adds
    # a progress bar and does not affect results.
    try:
        from tqdm import trange
    except ImportError:
        trange = range
    assert len(seqs) == len(tags)
    num_seqs = len(seqs)
    unk_ix = word_to_ix['[UNK]']  # hoisted: looked up per-token before
    seqs_tensor = torch.full((num_seqs, max_seq_len), pad_ix, dtype=torch.long)
    seqs_mask = torch.zeros(num_seqs, max_seq_len)
    tags_tensor = torch.full((num_seqs, max_seq_len), pad_ix, dtype=torch.long)
    for i in trange(num_seqs):
        # Truncate to max_seq_len so over-long sequences no longer index
        # past the tensor width.
        seq = seqs[i][:max_seq_len]
        tag = tags[i][:max_seq_len]
        seqs_mask[i, : len(seq)] = 1
        for j, word in enumerate(seq):
            seqs_tensor[i, j] = word_to_ix.get(word, unk_ix)
        for j, word in enumerate(tag):
            tags_tensor[i, j] = word_to_ix.get(word, unk_ix)
    return seqs_tensor, seqs_mask, tags_tensor
def save_dataset(seqs_tensor, seqs_mask, tags_tesnor, path):
    """Persist the three dataset tensors as .pkl files under *path*.

    The directory (and any parents) is created if it does not exist.
    """
    out_dir = Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)
    named_tensors = {
        'seqs_tensor.pkl': seqs_tensor,
        'seqs_mask.pkl': seqs_mask,
        'tags_tesnor.pkl': tags_tesnor,
    }
    for filename, tensor in named_tensors.items():
        torch.save(tensor, out_dir / filename)
def create_attention_mask(raw_mask: torch.Tensor) -> torch.Tensor:
    """Turn a 0/1 padding mask into an additive attention mask.

    Two broadcast dimensions are inserted after the batch axis, real
    positions map to 0.0 and padded positions to -10000.0, so the result
    can be added directly to raw attention scores.
    """
    expanded = raw_mask[:, None, None, :]
    additive = (1.0 - expanded) * -10000.0
    return additive.float()
def create_transformer_attention_mask(raw_mask: torch.Tensor) -> torch.Tensor:
    """Invert a 0/1 padding mask into the boolean form used by transformers.

    True marks positions to be ignored (padding); False marks real tokens.
    """
    inverted = raw_mask.neg().add(1)
    return inverted.bool()
if __name__ == '__main__':
    # Script entry point: tensorize the couplet dataset and precompute the
    # maximum cross-entropy reference value, saving both under config.data_dir.
    parser = argparse.ArgumentParser()
    # parser.add_argument("--dir", default='tensor_dataset', type=str)
    parser.add_argument("--max_len", default=32, type=int)
    args = parser.parse_args()
    vocab_path = f'{config.data_dir}/vocabs'
    word_to_ix = load_vocab(vocab_path)
    vocab_size = len(word_to_ix)
    max_seq_len = args.max_len
    # Training set (disabled; uncomment to preprocess it as well).
    # path = f'{config.data_dir}/tensor/train'
    # print('预处理训练集, 保存路劲:' + path)
    # seq_path = f'{config.data_dir}/train/in.txt'  # first lines of couplets
    # tag_path = f'{config.data_dir}/train/out.txt'  # second lines of couplets
    # seqs, tags = load_dataset(seq_path, tag_path)
    # seqs, masks, tags = create_dataset(seqs, tags, word_to_ix, max_seq_len, word_to_ix['[PAD]'])
    # save_dataset(seqs, masks, tags, path)
    # print('成功')
    # Test set: load raw text pairs and convert them to tensors.
    path = f'{config.data_dir}/tensor/test'
    print('预处理测试集, 保存路劲:' + path)
    seq_path = f'{config.data_dir}/test/in.txt'  # first lines of couplets
    tag_path = f'{config.data_dir}/test/out.txt'  # second lines of couplets
    seqs, tags = load_dataset(seq_path, tag_path)
    seqs, masks, tags = create_dataset(seqs, tags, word_to_ix, max_seq_len, word_to_ix['[PAD]'])
    # save_dataset(seqs, masks, tags, path)
    print('成功')
    # Compute the maximum cross-entropy value.
    # Loss function that skips padding positions.
    loss_func = nn.CrossEntropyLoss(ignore_index=word_to_ix['[PAD]'])
    # Cross entropy when every candidate character is equally probable.
    logits = torch.full((32, vocab_size), 1/vocab_size)
    i = 6  # number of non-padding target positions in the synthetic sample
    output = torch.zeros((1, 32), dtype=torch.int64)
    # NOTE(review): positions beyond the first i stay 0, which is only ignored
    # by the loss if word_to_ix['[PAD]'] == 0 — confirm against the vocab file.
    seq = torch.arange(4, i + 4)  # arbitrary non-special token ids as targets
    output[0, 0: i] = seq
    loss = loss_func(logits, output.view(-1))
    print(loss)
    res = torch.tensor([loss])
    path = f'{config.data_dir}/tensor'
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    torch.save(res, path / 'max_entropy_tensor.pkl')
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/LeeDvan/CoupletAI.git
git@gitee.com:LeeDvan/CoupletAI.git
LeeDvan
CoupletAI
CoupletAI
master

搜索帮助