# (Web-page residue) Code pull complete; the page will refresh automatically.
import os
import pickle
MAX_LEN = 7
from mydataset import *
class Lang:
    """Vocabulary mapping words to integer ids, with per-word occurrence counts.

    Ids 0-3 are reserved for the special tokens SOS, EOS, PAD and UNK. Every
    other word receives the next free id the first time it is added.
    """

    def __init__(self, name):
        """Create a vocabulary that contains only the four special tokens.

        Args:
            name: Label stored on the instance as ``self.name``.
        """
        self.name = name
        self.word2index = {"SOS": 0, "EOS": 1, "PAD": 2, "UNK": 3}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS", 2: "PAD", 3: "UNK"}
        self.n_words = 4  # SOS, EOS, PAD and UNK are already registered

    def addSentence(self, sentence):
        """Add every token of ``sentence`` (split on single spaces) to the vocabulary."""
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        """Register ``word`` if it is new and increment its occurrence count."""
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.index2word[self.n_words] = word
            self.n_words += 1
        # The special tokens live in word2index but start without a count
        # entry, so a plain ``+= 1`` would raise KeyError when the corpus
        # contains e.g. a literal "UNK"; .get() covers that case.
        self.word2count[word] = self.word2count.get(word, 0) + 1
def creat_lang(max_num):
    """Build the question and answer vocabularies from ``data/seq.data`` and pickle them.

    Each line of the data file is stripped and split on tabs. A line is used
    only when it yields exactly two fields and each field has more than four
    space-separated tokens. The first field feeds the input vocabulary, the
    second the output vocabulary. Both vocabularies are then written to
    ``dict/input_lang.pkl`` and ``dict/out_lang.pkl``.

    Args:
        max_num: Not used by this function; kept so existing callers keep working.

    Returns:
        Tuple ``(input_lang, out_lang)`` of the two populated ``Lang`` objects.
    """
    input_lang = Lang("问题字典")
    out_lang = Lang("答案字典")
    with open("data/seq.data", "r", encoding="utf-8") as f:
        for line in f:
            fields = line.strip().split("\t")
            if len(fields) != 2:
                continue
            question, answer = fields
            if len(question.split(" ")) > 4 and len(answer.split(" ")) > 4:
                input_lang.addSentence(question)
                out_lang.addSentence(answer)

    # Save the vocabularies. Create the target directory first so that a
    # missing "dict" folder does not make the dump fail after all the work.
    os.makedirs("dict", exist_ok=True)
    with open("dict/input_lang.pkl", 'wb') as f:
        pickle.dump(input_lang, f)
    with open("dict/out_lang.pkl", 'wb') as f:
        pickle.dump(out_lang, f)
    return input_lang, out_lang
# Get train_data
def tensorfromsentenct(sentenct, lang, tag="input", max_len=None):
    """Convert a space-separated sentence into a fixed-length list of word ids.

    Words missing from ``lang.word2index`` are mapped to the UNK id (3).

    Args:
        sentenct: Sentence whose tokens are separated by single spaces.
        lang: Object exposing a ``word2index`` dict (word -> id).
        tag: Selects the layout of the result:
            ``"input"``      - word ids, truncated / PAD-padded to ``max_len``;
            ``"tag_input"``  - SOS id followed by the word ids, truncated /
                               PAD-padded to ``max_len``;
            ``"tag_output"`` - at most ``max_len - 1`` word ids followed by
                               the EOS id, then PAD-padded to ``max_len``.
            Any other value returns the raw ids without padding or truncation.
        max_len: Target sequence length. ``None`` (the default) uses the
            module-level ``MAX_LEN``.

    Returns:
        List of integer ids.
    """
    sos_id, eos_id, pad_id, unk_id = 0, 1, 2, 3
    limit = MAX_LEN if max_len is None else max_len

    # Look up each word's id; unknown words fall back to the UNK id.
    ids = [lang.word2index.get(word, unk_id) for word in sentenct.split(" ")]

    if tag == "input":
        ids = ids[:limit]
    elif tag == "tag_input":
        ids = ([sos_id] + ids)[:limit]
    elif tag == "tag_output":
        # Reserve the last slot for EOS so it survives truncation.
        ids = ids[:limit - 1] + [eos_id]
    else:
        return ids

    # Pad sequences shorter than the limit (a non-positive count adds nothing).
    return ids + [pad_id] * (limit - len(ids))
def read_data(input_lang, out_lang, data_path="data/seq.data", num_max=500):
    """Read tab-separated sentence pairs and encode them as id sequences.

    Each line of ``data_path`` is stripped and split on tabs. A line is kept
    only when it yields exactly two fields and each field has more than four
    space-separated tokens. For every kept pair the first field is encoded
    with ``input_lang`` (tag ``"input"``) and the second field is encoded
    twice with ``out_lang`` (tags ``"tag_input"`` and ``"tag_output"``).

    Args:
        input_lang: Vocabulary used for the first field of each line.
        out_lang: Vocabulary used for the second field of each line.
        data_path: Path of the UTF-8 data file.
        num_max: Not used by this function; every line of the file is read.

    Returns:
        Tuple ``(input_data, tag_input, tag_output)`` of three parallel lists
        of id lists.
    """
    encoder_inputs = []
    decoder_inputs = []
    decoder_targets = []
    with open(data_path, "r", encoding="utf-8") as f:
        for raw_line in f:
            parts = raw_line.strip().split("\t")
            if len(parts) != 2:
                continue
            source, target = parts
            # Skip pairs where either side has four tokens or fewer.
            if len(source.split(" ")) <= 4 or len(target.split(" ")) <= 4:
                continue
            encoder_inputs.append(tensorfromsentenct(source, input_lang, "input"))
            decoder_inputs.append(tensorfromsentenct(target, out_lang, "tag_input"))
            decoder_targets.append(tensorfromsentenct(target, out_lang, "tag_output"))
    return encoder_inputs, decoder_inputs, decoder_targets
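# (Web-page residue) This section may contain content unsuitable for display, so the page does not show it. You can review and modify it using the relevant editing features.
# (Web-page residue) If you confirm that the content involves no inappropriate language / pure advertising / violence / vulgar or pornographic material / infringement / piracy / false information / worthless content, and nothing that violates relevant national laws and regulations, you can click submit to appeal and we will handle it as soon as possible.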