代码拉取完成,页面将自动刷新
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import TensorDataset, DataLoader
import argparse
import os
# Compute device for this module: the default CUDA device when PyTorch reports
# that one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def generate(name, window_size=None, data_dir='data'):
    """Build a sliding-window next-key dataset from a file of key sessions.

    Each line of ``<data_dir>/<name>`` is one session of whitespace-separated
    integer keys. Every key is shifted down by one (``k - 1``), presumably to
    turn 1-based keys into 0-based class indices. Within a session, each run
    of ``window_size`` consecutive keys becomes one input sample, and the key
    right after it becomes that sample's target. Sessions with
    ``window_size`` or fewer keys contribute no samples.

    Args:
        name: File name, relative to ``data_dir``.
        window_size: Number of keys per input window. If omitted, the
            module-level ``window_size`` global is used. This keeps the
            original calling convention, where the training script sets that
            global before calling ``generate(name)``.
        data_dir: Directory that contains ``name``. Defaults to ``'data'``.

    Returns:
        A ``TensorDataset`` of ``(inputs, targets)``:
        ``inputs`` is a float tensor of shape ``(num_seqs, window_size)``.
        ``targets`` is an int64 tensor of shape ``(num_seqs,)``.

    Raises:
        ValueError: If no window size is given and none is set at module
            level, or if the window size is not positive, or if a line
            contains a non-integer token.
        OSError: If the file cannot be opened.
    """
    if window_size is None:
        # Backward compatibility: fall back to the global set by __main__.
        window_size = globals().get('window_size')
        if window_size is None:
            raise ValueError('window_size must be passed explicitly or set at module level')
    if window_size < 1:
        raise ValueError('window_size must be a positive integer, got {}'.format(window_size))

    num_sessions = 0
    inputs = []
    outputs = []
    with open(os.path.join(data_dir, name), 'r') as f:
        for line in f:  # stream the file instead of loading it with readlines()
            num_sessions += 1
            keys = tuple(int(token) - 1 for token in line.split())
            # range() is empty when the session is too short, so short
            # sessions yield no samples.
            for i in range(len(keys) - window_size):
                inputs.append(keys[i:i + window_size])
                outputs.append(keys[i + window_size])
    print('Number of sessions({}): {}'.format(name, num_sessions))
    print('Number of seqs({}): {}'.format(name, len(inputs)))

    if inputs:
        input_tensor = torch.tensor(inputs, dtype=torch.float)
    else:
        # torch.tensor([]) would produce shape (0,); keep the 2-D layout.
        input_tensor = torch.empty((0, window_size), dtype=torch.float)
    # Explicit int64: torch.tensor([]) would otherwise default to float32,
    # which is not a valid class-index tensor.
    target_tensor = torch.tensor(outputs, dtype=torch.long)
    return TensorDataset(input_tensor, target_tensor)
class Model(nn.Module):
    """LSTM classifier that scores the next key given a window of keys.

    A stacked, batch-first ``nn.LSTM`` encodes the input sequence. A linear
    layer then maps the top layer's output at the last time step to
    ``num_keys`` unnormalized logits.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_keys):
        """Create the LSTM encoder and the output projection.

        Args:
            input_size: Number of features per time step.
            hidden_size: LSTM hidden-state width.
            num_layers: Number of stacked LSTM layers.
            num_keys: Number of output classes (logits per sample).
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_keys)

    def forward(self, x):
        """Return logits of shape ``(batch, num_keys)``.

        Args:
            x: Batch-first input of shape ``(batch, seq_len, input_size)``.
        """
        # Zero initial states are allocated on x's device and dtype instead of
        # a module-global device in float32. This lets the model run wherever
        # its input lives (e.g. CPU inference on a CUDA host) and in any
        # dtype the parameters were converted to.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)
        out, _ = self.lstm(x, (h0, c0))
        # Classify from the last time step's top-layer output.
        return self.fc(out[:, -1, :])
if __name__ == '__main__':
    # Fixed hyperparameters; the model shape is configurable from the CLI.
    num_classes = 28
    num_epochs = 300
    batch_size = 2048
    input_size = 1
    model_dir = 'model'
    # Run name: used for both the TensorBoard log dir and the checkpoint file.
    log = 'Adam_batch_size={}_epoch={}'.format(batch_size, num_epochs)

    parser = argparse.ArgumentParser()
    parser.add_argument('-num_layers', default=2, type=int)
    parser.add_argument('-hidden_size', default=64, type=int)
    parser.add_argument('-window_size', default=10, type=int)
    args = parser.parse_args()
    num_layers = args.num_layers
    hidden_size = args.hidden_size
    # Keep this a module-level name: generate('hdfs_train') below reads it.
    window_size = args.window_size

    model = Model(input_size, hidden_size, num_layers, num_classes).to(device)
    seq_dataset = generate('hdfs_train')
    if len(seq_dataset) == 0:
        # Otherwise the per-epoch loss average below divides by zero batches.
        raise SystemExit('No training sequences generated from data/hdfs_train; '
                         'every session may be shorter than window_size + 1.')
    # Pinned host memory only speeds up host-to-GPU copies; skip it on CPU.
    dataloader = DataLoader(seq_dataset, batch_size=batch_size, shuffle=True,
                            pin_memory=(device.type == 'cuda'))
    writer = SummaryWriter(log_dir=os.path.join('log', log))
    try:
        # Loss and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters())

        # Train the model
        start_time = time.time()
        total_step = len(dataloader)
        graph_logged = False
        for epoch in range(num_epochs):  # Loop over the dataset multiple times
            train_loss = 0
            for seq, label in dataloader:
                # Forward pass. DataLoader batches carry no autograd history,
                # so no clone()/detach() is needed before reshaping.
                seq = seq.view(-1, window_size, input_size).to(device)
                output = model(seq)
                loss = criterion(output, label.to(device))
                # Backward and optimize
                optimizer.zero_grad()
                loss.backward()
                train_loss += loss.item()
                optimizer.step()
                # Trace the model graph once. The previous per-step call
                # re-traced the model and appended a graph event every step.
                if not graph_logged:
                    writer.add_graph(model, seq)
                    graph_logged = True
            print('Epoch [{}/{}], train_loss: {:.4f}'.format(epoch + 1, num_epochs, train_loss / total_step))
            writer.add_scalar('train_loss', train_loss / total_step, epoch + 1)
        elapsed_time = time.time() - start_time
        print('elapsed_time: {:.3f}s'.format(elapsed_time))

        os.makedirs(model_dir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(model_dir, log + '.pt'))
    finally:
        # Flush and close the event file even if training fails part-way.
        writer.close()
    print('Finished Training')
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。