代码拉取完成,页面将自动刷新
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import copy
import csv
import hyperparameters
from processing_data import casia_wav_route, segment_dataset, read_json
from data_set import make_dataset
from model import Light_SERNet_V1
from optim_utils import ExpLR
############################## 预处理数据集 ##############################
print("*"*25, " 开始预处理数据集 ", "*"*25)
print(
"""
简单说明: 1. 统计数据集的信息,在 data 目录下生成 json 文件(文件存在则跳过)
2. 根据生成的信息,将数据集划分为训练集与测试集,并规范每个音频文件为指定长度,划分好的数据重新存在 data 目录下(已划分则跳过)
"""
)
casia_wav_route()
segment_dataset(read_json("data/casia_wav_route.json"))
print("*"*25, " 数据集预处理完成 ", "*"*25)
##########################################################################
############################## 训练模型之前的一些准备 ##############################
# 设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 数据加载
train_dataset, test_dataset = make_dataset("data/casia_4_segment")
train_loader = DataLoader(train_dataset, hyperparameters.BATCH_SIZE)
test_loader = DataLoader(test_dataset, hyperparameters.BATCH_SIZE)
# 模型
net = Light_SERNet_V1(len(hyperparameters.CASIA_LABELS)).to(device)
# 损失函数
criterion = nn.CrossEntropyLoss().to(device)
# 优化器
optimizer = optim.Adam(net.parameters(), lr=hyperparameters.LEARNING_RATE)
# 学习率做动态调整
# 50 个 epoch 后,每 20 个epoch 对学习率进行调整,lr = lr * exp(?)
scheduler = ExpLR(optimizer, hyperparameters.LEARNING_RATE_DECAY_STEP, hyperparameters.LEARNING_RATE_DECAY_PARAMETERS)
# 最佳的验证损失,先设置为最大
best_val_loss = float("inf")
# 最佳模型
best_net = None
####################################################################################
# 计算损失和精确度
def cal_acc(net, data_load, loss):
""" 计算损失和精确度
inputs:
net: 训练的模型
data_load: 需要计算的数据
loss: 模型使用的损失函数
output:
先返回模型精确度,再返回模型的平均损失
"""
net.eval()
all_item = 0.0
acc_item = 0.0
all_loss = 0.0
for x, y in data_load:
all_item += y.shape[0]
x, y = x.to(device), y.to(device)
y_hat = net(x)
all_loss += loss(y_hat, y).item()
y_hat = y_hat.argmax(dim=1)
acc_item += torch.eq(y, y_hat).sum().item()
return acc_item / all_item, all_loss / all_item
# 开始训练
print("-"*25, " 开始训练 ", "-"*25)
with open("train_log.csv", "w") as f: # 记录训练日志
w_log = csv.writer(f)
w_log.writerow(["Epoch", "train_loss", "train_acc", "test_loss", "test_acc", "lr"])
for epoch in range(hyperparameters.EPOCHS):
net.train()
for x, y in tqdm(train_loader):
x, y = x.to(device), y.to(device)
y_hat = net(x)
loss = criterion(y_hat, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_acc, train_loss = cal_acc(net, train_loader, criterion)
test_acc, test_loss = cal_acc(net, test_loader, criterion)
print(
"Epoch {:3d}/{:3d} train_loss: {:5.6f} train_acc: {:5.6f} test_loss: {:5.6f} test_acc: {:5.6f} lr: {:5.6f}".format(
epoch + 1, hyperparameters.EPOCHS, train_loss, train_acc, test_loss, test_acc, scheduler.get_last_lr()[0]
)
)
w_log.writerow([
epoch + 1, train_loss, train_acc, test_loss, test_acc, scheduler.get_last_lr()[0]
])
if test_loss < best_val_loss:
best_val_loss = test_loss
best_net = copy.deepcopy(net)
if epoch >= hyperparameters.LEARNING_RATE_DECAY_STRATPOINT - hyperparameters.LEARNING_RATE_DECAY_STEP:
scheduler.step() # 调整学习率
# 保存最佳模型
torch.save(best_net.state_dict(), "Light_SERNet_V1.pth")
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。