代码拉取完成,页面将自动刷新
import torch
from torch import nn
from torch.nn import functional as F
import math
import numpy as np
from net import TransformerModel
def train(model, epochs, eval_every=10):
    """Train ``model`` with Adam on the pre-batched arrays under ``./npys``.

    The arrays are indexed as ``data[j]`` (one batch fed to the model) and
    ``labels[j][k][0]`` (sample ``k`` of batch ``j``), so they are assumed to be
    stacked as ``(num_batches, batch_size, ...)``; the model must return at
    least two logits per sample.  A sample counts as correct when
    "logit 0 >= logit 1" agrees with ``labels[..., 0]`` -- presumably one-hot
    binary labels; confirm against the data-preparation script.

    Every ``eval_every`` epochs the test split is evaluated and a checkpoint
    ``{'model': ..., 'optimizer': ...}`` is written to
    ``models/TransformerModel.pkl``; the final state is also saved when the
    last epoch is not an evaluation epoch.

    Args:
        model: network to train; data is moved to the device of its first
            parameter (falls back to ``"cuda"`` if it has none).
        epochs: number of passes over the training batches.
        eval_every: evaluation / checkpoint interval in epochs (default 10).

    Returns:
        list[float]: the training loss of every optimisation step, in order.

    Raises:
        ValueError: if ``eval_every`` < 1, or a split's data and labels contain
            a different number of batches.
    """
    import os

    if eval_every < 1:
        raise ValueError("eval_every must be >= 1, got %r" % (eval_every,))

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    def load_split(secs_path, labs_path):
        # Load one split as float32 tensors on `device`.
        # allow_pickle=True is kept from the original; it can execute code
        # embedded in the file, so only load .npy files from a trusted source.
        data = np.load(secs_path, allow_pickle=True)
        labels = np.load(labs_path, allow_pickle=True)
        if len(data) != len(labels):
            raise ValueError("%s has %d batches but %s has %d"
                             % (secs_path, len(data), labs_path, len(labels)))
        return (torch.tensor(data, dtype=torch.float32).to(device),
                torch.tensor(labels, dtype=torch.float32).to(device))

    def batch_accuracy(logits, labels):
        # Vectorised form of the original per-sample rule: predict "first
        # class" unless logit 0 < logit 1, then compare with labels[:, 0].
        # Reads the logits instead of overwriting the model output in place.
        pred_first = (~(logits[:, 0] < logits[:, 1])).to(labels.dtype)
        return (pred_first == labels[:, 0]).float().mean().item()

    def mean(values):
        # Mean of a list of floats; nan for an empty split instead of raising.
        return sum(values) / len(values) if values else float("nan")

    train_data, train_labels = load_split("./npys/train_secs.npy", "./npys/train_labs.npy")
    test_data, test_labels = load_split("./npys/test_secs.npy", "./npys/test_labs.npy")
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    def save_checkpoint():
        # Write model + optimizer state; create the directory if needed.
        os.makedirs("models", exist_ok=True)
        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
        torch.save(state, 'models/TransformerModel.pkl')

    losses = []
    for i in range(epochs):
        model.train()
        epoch_losses = []
        epoch_accs = []
        for batch, target in zip(train_data, train_labels):
            y_pred = model(batch)
            loss = criterion(y_pred, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())
            epoch_accs.append(batch_accuracy(y_pred.detach(), target))
        losses.extend(epoch_losses)
        # Per-epoch means: the original printed only the last batch's loss and
        # an accuracy accumulated across epochs (and across test evaluations).
        print('epoch:%d loss:%.5f acc:%.5f' % (i, mean(epoch_losses), mean(epoch_accs)))
        # Check the test-set accuracy every `eval_every` epochs.
        if (i + 1) % eval_every == 0:
            print('Start eval in test dataset...')
            model.eval()  # disable train-only layers (e.g. dropout) while measuring
            with torch.no_grad():
                test_accs = [batch_accuracy(model(b), t)
                             for b, t in zip(test_data, test_labels)]
            print('test_acc:%.5f' % (mean(test_accs)))
            save_checkpoint()
    if epochs > 0 and epochs % eval_every != 0:
        # Persist the final weights even when the last epoch was not an eval epoch.
        save_checkpoint()
    return losses
def test(model, weight):
    """Load the checkpoint at ``weight`` into ``model`` and print its test accuracy.

    ``weight`` must be a file written by ``torch.save`` holding a dict with a
    ``'model'`` state dict.  ``torch.load`` unpickles the file, so only load
    checkpoints from a trusted source.  Test arrays are read from
    ``./npys/test_secs.npy`` and ``./npys/test_labs.npy`` and indexed as
    ``data[j]`` / ``labels[j][k][0]``, i.e. assumed to be stacked as
    ``(num_batches, batch_size, ...)``.  A sample counts as correct when
    "logit 0 >= logit 1" agrees with ``labels[..., 0]`` (presumably one-hot
    binary labels -- confirm).

    Args:
        model: network to evaluate; the checkpoint and data are mapped to the
            device of its first parameter (``"cuda"`` if it has none).
        weight: path of the checkpoint file.

    Returns:
        float: mean of the per-batch accuracies (``nan`` for an empty split).
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")
    model.load_state_dict(torch.load(weight, map_location=device)['model'])
    # allow_pickle=True kept from the original; only load trusted .npy files.
    test_data = torch.tensor(np.load("./npys/test_secs.npy", allow_pickle=True),
                             dtype=torch.float32).to(device)
    test_labels = torch.tensor(np.load("./npys/test_labs.npy", allow_pickle=True),
                               dtype=torch.float32).to(device)
    model.eval()  # inference mode for train-only layers such as dropout
    acc = []
    with torch.no_grad():
        for batch, target in zip(test_data, test_labels):
            logits = model(batch)
            # Vectorised per-sample rule: predict "first class" unless
            # logit 0 < logit 1, compare with the first label column.
            # Covers every sample of the batch (shape[1] of the stack), not
            # shape[0], and leaves the model output untouched.
            pred_first = (~(logits[:, 0] < logits[:, 1])).to(target.dtype)
            acc.append((pred_first == target[:, 0]).float().mean().item())
    test_acc = sum(acc) / len(acc) if acc else float("nan")
    print('test_acc:%.5f' % (test_acc))
    return test_acc
if __name__ == "__main__":
    import argparse

    # Select the mode on the command line instead of (un)commenting code.
    # With no flags this evaluates the saved weights, as before.
    # parse_known_args tolerates unrelated argv entries (e.g. notebook kernels).
    parser = argparse.ArgumentParser(description="Train and/or evaluate TransformerModel.")
    parser.add_argument("--train", action="store_true",
                        help="train (and checkpoint) before evaluating")
    parser.add_argument("--epochs", type=int, default=100,
                        help="number of training epochs (default: 100)")
    parser.add_argument("--weights", default="./models/TransformerModel.pkl",
                        help="checkpoint to evaluate")
    args, _ = parser.parse_known_args()

    epochs = args.epochs
    model = TransformerModel().to("cuda")
    weight_path = args.weights
    if args.train:
        train(model, epochs=epochs)
    test(model, weight_path)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。