1 Star 2 Fork 0

zhang/seq2seq聊天机器人

Create your Gitee Account
Explore and code with more than 12 million developers. Free private repositories! :)
Sign up
文件
This repository doesn't specify license. Please pay attention to the specific project description and its upstream code dependency when using it.
Clone or Download
main.py 1.72 KB
Copy Edit Raw Blame History
zhang authored 2022-04-09 19:02 . 第一次提交
import os

import torch
import torch.nn.functional as F

from mydataset import *
from precessing import *
from Seq2SeqModel import *
# Train the attention seq2seq chatbot (Encoder + AttentionDencoder) and save
# both models' state_dicts under SAVE_DIR.
#
# Names used below but not defined in this file (creat_lang, read_data,
# MyDataset, Encoder, AttentionDencoder, DataLoader, nn, device, MAX_LEN)
# come from the project's star imports.

# --- Hyper-parameters --------------------------------------------------------
# Token index fed to the decoder at the first step.
# NOTE(review): assumed to match the SOS index of the vocab built by
# creat_lang; confirm against precessing.py.
sos_token = 0
batch_size = 1
HIDDEN_SIZE = 32
LEARNING_RATE = 1e-4
NUM_EPOCHS = 1
# True: feed the ground-truth token of step i as the decoder input for step
# i + 1 (teacher forcing; the original behaviour). False: feed back the
# decoder's own greedy (top-1) prediction instead.
TEACHER_FORCING = True
SAVE_DIR = "savemode"

# --- Data --------------------------------------------------------------------
# 500 is passed through unchanged; its meaning is defined by creat_lang.
input_lang, out_lang = creat_lang(500)
input_data, tag_input, tag_output = read_data(
    input_lang, out_lang, data_path="data/seq.data"
)
print("数据长度:", len(input_data))
train_dataset = MyDataset(input_data, tag_input, tag_output)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# --- Models, loss, optimiser -------------------------------------------------
# Move the models to `device` so they live where `decoder_input` (and the
# batch tensors below) are created. Module.to() returns the module itself and
# is a no-op if the parameters are already on that device.
EncoderModel = Encoder(input_lang.n_words, hidden_size=HIDDEN_SIZE).to(device)
DecoderModel = AttentionDencoder(
    output_size=out_lang.n_words, hidden_size=HIDDEN_SIZE
).to(device)
crossentropyloss = nn.CrossEntropyLoss()
opt_config = [
    {"params": EncoderModel.parameters(), "lr": LEARNING_RATE},
    {"params": DecoderModel.parameters(), "lr": LEARNING_RATE},
]
opt = torch.optim.Adam(opt_config, lr=LEARNING_RATE)

# --- Training loop -----------------------------------------------------------
for _ in range(NUM_EPOCHS):
    for batch in train_loader:
        # Distinct names so the full-dataset variables loaded above are not
        # overwritten by the last batch. The decoder-input column of the
        # dataset is not used by this loop.
        batch_input, _, batch_target = batch
        batch_input = batch_input.to(device)
        batch_target = batch_target.to(device)

        encoder_output, hidden = EncoderModel(batch_input, None)
        # One start token per sequence in the batch.
        decoder_input = torch.tensor(
            [sos_token] * batch_input.shape[0], device=device
        )

        step_outputs = []
        for step in range(MAX_LEN):
            output, hidden, _attn_weights = DecoderModel(
                decoder_input, hidden, encoder_output
            )
            step_outputs.append(output)
            if TEACHER_FORCING:
                # Ground-truth token of this step becomes the next input.
                decoder_input = batch_target[:, step]
            else:
                # Greedy decoding: feed back the most likely token index.
                decoder_input = output.topk(1)[1].view(-1)

        # Sum of the per-step cross-entropy losses.
        # NOTE(review): `out[:, 0, :]` assumes the decoder returns
        # (batch, 1, vocab) logits per step -- confirm in Seq2SeqModel. No
        # ignore_index is set, so padding positions (if any) also count.
        loss = sum(
            crossentropyloss(out[:, 0, :], batch_target[:, step])
            for step, out in enumerate(step_outputs)
        )
        print(loss.item())
        opt.zero_grad()
        loss.backward()
        opt.step()

# torch.save does not create missing directories.
os.makedirs(SAVE_DIR, exist_ok=True)
torch.save(EncoderModel.state_dict(), os.path.join(SAVE_DIR, "EncoderModel.pkl"))
torch.save(DecoderModel.state_dict(), os.path.join(SAVE_DIR, "DecoderModel.pkl"))
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/zhangtuo0723/seq2seq-chat-robot.git
git@gitee.com:zhangtuo0723/seq2seq-chat-robot.git
zhangtuo0723
seq2seq-chat-robot
seq2seq聊天机器人
master

Search

0d507c66 1850385 C8b1a773 1850385