# Fetching the repository succeeded.
import os

import torch
from mydataset import *
from precessing import *
from Seq2SeqModel import *
import torch.nn.functional as F
# ---- Hyper-parameters -------------------------------------------------------
# Token id fed to the decoder as its very first input for every sequence.
sos_token = 0
batch_size = 1
# Hidden-state width shared by the encoder and the attention decoder.
HIDDEN_SIZE = 32
# Learning rate applied to both models' parameter groups.
LEARNING_RATE = 1e-4

# ---- Data -------------------------------------------------------------------
# creat_lang / read_data / MyDataset / DataLoader come from the star imports
# above. NOTE(review): the meaning of the argument 500 is defined inside
# creat_lang — confirm there.
input_lang, out_lang = creat_lang(500)
input_data, tag_input, tag_output = read_data(
    input_lang, out_lang, data_path="data/seq.data"
)
# Report how many items read_data returned ("数据长度" = "data length").
print("数据长度:", len(input_data))
Mydatset = MyDataset(input_data, tag_input, tag_output)
train_loader = DataLoader(Mydatset, batch_size=batch_size, shuffle=True)

# ---- Models, loss, optimizer ------------------------------------------------
EncoderModel = Encoder(input_lang.n_words, hidden_size=HIDDEN_SIZE)
DecoderModel = AttentionDencoder(
    output_size=out_lang.n_words, hidden_size=HIDDEN_SIZE
)
crossentropyloss = nn.CrossEntropyLoss()
# One parameter group per model (encoder first, then decoder) so that each
# could be given its own learning rate later.
opt_config = [
    {'params': model.parameters(), 'lr': LEARNING_RATE}
    for model in (EncoderModel, DecoderModel)
]
opt = torch.optim.Adam(opt_config, lr=LEARNING_RATE)
# ---- Training loop ----------------------------------------------------------
# Trains the encoder/decoder pair with full teacher forcing: at every decode
# step the ground-truth token (not the model's own prediction) becomes the next
# decoder input. To decode free-running instead, feed the argmax of `output`
# (e.g. `output.topk(1)` indices) as the next input.
# NOTE(review): assumes every target batch has at least MAX_LEN columns
# (presumably padded upstream) — otherwise `batch_tag_output[:, step]` raises
# IndexError. Confirm against read_data / MyDataset.
NUM_EPOCHS = 1

for epoch in range(NUM_EPOCHS):
    for batch in train_loader:
        # Batch-local names so the module-level `input_data` / `tag_input` /
        # `tag_output` returned by read_data are not overwritten.
        batch_input, batch_tag_input, batch_tag_output = batch
        encoder_output, hidden = EncoderModel(batch_input, None)

        # Every sequence in the batch starts decoding from the SOS token.
        decoder_input = torch.tensor(
            [sos_token] * batch_input.shape[0], device=device
        )

        # Sum of per-step cross-entropy losses over MAX_LEN decode steps.
        loss = 0
        for step in range(MAX_LEN):
            output, hidden, _ = DecoderModel(
                decoder_input, hidden, encoder_output
            )
            # The decoder output is indexed as (batch, 1, vocab) here.
            # NOTE(review): confirm against AttentionDencoder.forward.
            loss = loss + crossentropyloss(
                output[:, 0, :], batch_tag_output[:, step]
            )
            # Teacher forcing: next input is this step's ground-truth token.
            decoder_input = batch_tag_output[:, step]

        # Log the scalar value rather than the full tensor repr (grad_fn etc.).
        print(loss.item())
        opt.zero_grad()
        loss.backward()
        opt.step()
# ---- Checkpointing ----------------------------------------------------------
# Persist only the learned weights (state_dicts) of both models. torch.save
# does not create missing parent directories, so make sure "savemode" exists
# first; exist_ok keeps reruns from failing when it is already there.
os.makedirs("savemode", exist_ok=True)
torch.save(EncoderModel.state_dict(), "savemode/EncoderModel.pkl")
torch.save(DecoderModel.state_dict(), "savemode/DecoderModel.pkl")
# 此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
# 如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。