1 Star 2 Fork 0

zhang/seq2seq聊天机器人

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
precessing.py 2.99 KB
一键复制 编辑 原始数据 按行查看 历史
zhang 提交于 2022-04-09 19:02 . 第一次提交
import os
import pickle
MAX_LEN = 7
from mydataset import *
class Lang:
    """Vocabulary for one side of a parallel (question/answer) corpus.

    Maps words to integer ids and back, reserving ids 0-3 for the
    special tokens SOS, EOS, PAD and UNK.
    """

    def __init__(self, name):
        self.name = name
        # The four special tokens occupy the first four ids.
        self.word2index = {"SOS": 0, "EOS": 1, "PAD": 2, "UNK": 3}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS", 2: "PAD", 3: "UNK"}
        self.n_words = 4  # next free id (SOS/EOS/PAD/UNK already taken)

    def addSentence(self, sentence):
        """Register every space-separated token of *sentence*."""
        for token in sentence.split(' '):
            self.addWord(token)

    def addWord(self, word):
        """Add *word* to the vocabulary, or bump its count if already known."""
        if word in self.word2index:
            self.word2count[word] += 1
        else:
            idx = self.n_words
            self.word2index[word] = idx
            self.index2word[idx] = word
            self.word2count[word] = 1
            self.n_words = idx + 1
def creat_lang(max_num, data_path="data/seq.data", dict_dir="dict"):
    """Build question/answer vocabularies from a tab-separated corpus.

    Each line of *data_path* holds a question and an answer separated by
    a tab, with space-separated tokens; pairs where either side has
    fewer than five tokens are skipped (same filter as ``read_data``).
    Both vocabularies are pickled into *dict_dir* and returned.

    Args:
        max_num: unused; kept only for backward compatibility with
            existing callers.
        data_path: path to the corpus file (default "data/seq.data").
        dict_dir: directory the pickled dictionaries are written to
            (created if missing).

    Returns:
        ``(input_lang, out_lang)`` — the question and answer ``Lang``
        vocabularies.
    """
    input_lang = Lang("问题字典")
    out_lang = Lang("答案字典")
    # Stream the file line by line instead of readlines() — avoids
    # holding the whole corpus in memory.
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            data_c = line.strip().split("\t")
            # Keep only well-formed pairs with more than 4 tokens on each side.
            if len(data_c) == 2 and len(data_c[0].split(" ")) > 4 and len(data_c[1].split(" ")) > 4:
                input_lang.addSentence(data_c[0])
                out_lang.addSentence(data_c[1])
    # Persist the dictionaries; the original crashed if the target
    # directory did not exist.
    os.makedirs(dict_dir, exist_ok=True)
    with open(os.path.join(dict_dir, "input_lang.pkl"), 'wb') as f:
        pickle.dump(input_lang, f)
    with open(os.path.join(dict_dir, "out_lang.pkl"), 'wb') as f:
        pickle.dump(out_lang, f)
    return input_lang, out_lang
# Build the training data (encoder/decoder id sequences)
def tensorfromsentenct(sentenct, lang, tag="input", max_len=None):
    """Convert a space-separated sentence into a fixed-length list of ids.

    Args:
        sentenct: sentence of space-separated tokens.
        lang: vocabulary object exposing a ``word2index`` mapping.
        tag: which sequence to produce —
            "input":      encoder input, padded/truncated to *max_len*;
            "tag_input":  decoder input, SOS (0) prepended, then
                          padded/truncated to *max_len*;
            "tag_output": decoder target, EOS (1) appended and always
                          kept as the last real token, padded to *max_len*.
        max_len: target sequence length; ``None`` (the default) falls
            back to the module-level ``MAX_LEN``, preserving the
            original behavior.

    Returns:
        List of token ids of length *max_len*; unknown words map to
        UNK (3). An unrecognized *tag* returns the raw, unpadded ids
        (original behavior, kept for compatibility).
    """
    if max_len is None:
        max_len = MAX_LEN
    # Look up each token; unknown words fall back to UNK id 3.
    ids = [lang.word2index.get(word, 3) for word in sentenct.split(" ")]
    pad = [2]  # PAD token id
    if tag == "input":
        if len(ids) < max_len:
            ids = ids + pad * (max_len - len(ids))  # pad short sequences
        else:
            ids = ids[:max_len]  # truncate long ones
    if tag == "tag_input":
        # Decoder input starts with SOS.
        ids = [0] + ids
        if len(ids) < max_len:
            ids = ids + pad * (max_len - len(ids))
        else:
            ids = ids[:max_len]
    if tag == "tag_output":
        # Decoder target ends with EOS; truncate first if needed so EOS
        # is never cut off.
        if len(ids) < max_len - 1:
            ids = ids + [1]
            ids = ids + pad * (max_len - len(ids))
        else:
            ids = ids[:max_len - 1]
            ids = ids + [1]
    return ids
def read_data(input_lang, out_lang, data_path="data/seq.data", num_max=500):
    """Build parallel id sequences for training from a tab-separated corpus.

    Args:
        input_lang: question vocabulary (``Lang``).
        out_lang: answer vocabulary (``Lang``).
        data_path: corpus file, one question/answer pair per line,
            separated by a tab.
        num_max: currently ignored — the whole file is always read.
            Kept for backward compatibility with existing callers
            (honoring it would silently change results for them).

    Returns:
        ``(input_data, tag_input, tag_output)``: encoder inputs,
        decoder inputs (SOS-prefixed) and decoder targets
        (EOS-suffixed), each a list of fixed-length id lists.
    """
    input_data = []
    tag_input = []
    tag_output = []
    # Stream the file line by line instead of readlines().
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            data_c = line.strip().split("\t")
            # Same well-formedness filter as creat_lang: exactly two
            # fields, each with more than four tokens.
            if len(data_c) == 2 and len(data_c[0].split(" ")) > 4 and len(data_c[1].split(" ")) > 4:
                input_c, output_c = data_c
                input_data.append(tensorfromsentenct(input_c, input_lang, "input"))
                tag_input.append(tensorfromsentenct(output_c, out_lang, "tag_input"))
                tag_output.append(tensorfromsentenct(output_c, out_lang, "tag_output"))
    return input_data, tag_input, tag_output
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/zhangtuo0723/seq2seq-chat-robot.git
git@gitee.com:zhangtuo0723/seq2seq-chat-robot.git
zhangtuo0723
seq2seq-chat-robot
seq2seq聊天机器人
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385