# NOTE(review): removed stray web-page text ("code pull complete, page will
# auto-refresh") accidentally captured at the top of the file — it is not
# Python and made the module unparseable.
import torch
import argparse
from tqdm import tqdm
from torchmetrics import R2Score
import datetime
from src.dataset import PPDataset
from torch.utils.tensorboard import SummaryWriter
from test import ModelAnalyzer, DistanceLoss, BCELoss_simple, CELoss, WeighedBCELoss, WeighedBCELossV2
# Prefer the first CUDA GPU; fall back to CPU when CUDA is unavailable.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Launch timestamp (YYMMDDhhmmss) used to tag log dirs and checkpoint names.
dt = datetime.datetime.now()
dt_str = dt.strftime("%y%m%d%H%M%S")
# Command-line arguments controlling training.
parser = argparse.ArgumentParser(description='PyTorch model trainer')
parser.add_argument('--model', default="unetstft2", help='name of model')
parser.add_argument('--pre', default="fcy2",
                    help='name of dataset preprocess method')
parser.add_argument('--lr', type=float, default=0.0005,
                    help='learn rate of optimizer')
parser.add_argument('--epoch_size', type=int, default=10,
                    help='how much epoch to train')
parser.add_argument('--batch_size', type=int, default=32,
                    help='number of waveforms in each batch')
parser.add_argument('--data_len', type=int, default=4000,
                    help='samples in each piece of data')
# Fixed: help text was a copy-paste of --pre's ("name of dataset preprocess
# method"); this flag is actually the dataset's root directory.
parser.add_argument('--ds_path', default="E:/RealSeisData/Diting50hz/",
                    help='root path of the dataset')
if __name__ == "__main__":
    # Read CLI arguments (parse_known_args so unrecognized extras are ignored).
    args, unknown = parser.parse_known_args()
    model_name = 'src.model_'+args.model
    pre_method = args.pre
    dataset_path = args.ds_path
    lr = args.lr
    # Floor of the cosine learning-rate schedule used further below.
    lr_final = lr/50
    batch_size = args.batch_size
    epoch_size = args.epoch_size
    dlen = args.data_len
    # Echo all parsed arguments, one per line, name left-padded to 15 chars.
    for arg in vars(args):
        print(format(arg, '<15'), format(
            str(getattr(args, arg)), '<'))
    print("device:", device)
    # TensorBoard writer; one run directory per (model, launch-timestamp).
    writer = SummaryWriter(log_dir="./logs/{}/{}".format(model_name, dt_str))
# 取得数据集
print('loading dataset...')
train_dataset = PPDataset(
dataset_path, "DiTing330km_train.csv", methodmame=pre_method, dlen=dlen)
val_dataset = PPDataset(
dataset_path, "DiTing330km_validation.csv", methodmame=pre_method, dlen=dlen)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, num_workers=8)
val_loader = torch.utils.data.DataLoader(
val_dataset, batch_size=batch_size*8, num_workers=4)
# 导入并定义模型,优化算法
modelpkg = __import__(model_name, fromlist=['Model'])
model = modelpkg.Model().to(device)
# 加载预训练模型
model_pretrain = torch.load("./model/unetstft2_231101154751_ep16.pth")
model.load_state_dict(model_pretrain.state_dict())
# 优化器
# opt = torch.optim.SGD(model.parameters(), lr=lr)
opt = torch.optim.AdamW(
model.parameters(), lr=lr, weight_decay=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
opt, epoch_size, lr_final)
# 指定评价指标
# loss_MAE = torch.nn.L1Loss()
# loss_SmoothMAE = torch.nn.SmoothL1Loss()
loss_R2 = R2Score(num_outputs=6)
# loss_BCE = torch.nn.BCELoss(reduction='none')
loss_CE = CELoss()
loss_BCE = BCELoss_simple()
loss_Distance = DistanceLoss()
loss_WeighedBCE = WeighedBCELossV2(w=0.8)
# P S N 损失权重
# w = torch.ones(1, 2, 1).to(device)
# w[0, 0, 0] = 0.4 #P
# w[0, 1, 0] = 0.6 #S
# w=w * 2 # let mean == 1
analyzer = ModelAnalyzer(
'DiTing330km_validation.csv', 0.2, 25, 'fcy2', dlen=dlen)
    # Epoch loop.
    for epoch in range(epoch_size):
        print('\n[ epoch {} ]'.format(epoch + 1))
        loss_sum = 0
        model.train()
        pbar = tqdm(train_loader, ncols=0, mininterval=1)
        pbar.set_description("train")
        for train_x, train_label in pbar:
            train_x = train_x.to(device)
            train_label = train_label.to(device)
            # Keep only the first two output channels — presumably the P and
            # S probability traces; confirm against the model definition.
            predict_y = model(train_x.float())[:, :2]
            # loss1 = (loss_BCE(predict_y, train_label.float())*w).mean()
            loss1 = loss_WeighedBCE(predict_y, train_label.float())
            # loss1 = loss_BCE(predict_y, train_label.float()).mean()
            # loss2 is computed for progress-bar monitoring only; it is
            # excluded from the optimized loss on the next line.
            loss2 = loss_Distance(predict_y, train_label.float())
            loss = loss1 #+ loss2 # final loss
            opt.zero_grad()
            loss.backward()
            opt.step()
            pbar.set_postfix(
                loss1='{:.6f}'.format(loss1.item()),
                loss2='{:.6f}'.format(loss2.item()), refresh=False)
            loss_sum += loss1.item()
        print('train bce loss: {:.6f}'.format(loss_sum/len(train_loader)))
        # After every epoch, save the current model (the whole pickled
        # module, not just the state_dict) under ./model.
        # The model output location is placed under /model
        model_fname = './model/{}_{}_ep{}.pth'.format(
            args.model, dt_str, epoch + 1)
        torch.save(
            model, model_fname)
        print(model_fname)
        scheduler.step()
        print('learn_rate:', scheduler.get_last_lr()[0])
        # Per-epoch validation-loss pass (disabled; the peak-picking
        # analyzer below is used for evaluation instead).
        # loss1_sum = 0
        # model.eval()
        # pbar = tqdm(val_loader, ncols=0, mininterval=1)
        # pbar.set_description("validation")
        # for val_x, val_label in pbar:
        # val_x = val_x.to(device)
        # val_label = val_label.to(device)
        # predict_y = model(val_x.float()).detach().to(device)[:, :2]
        # loss1 = loss_BCE(predict_y, val_label.float())
        # pbar.set_postfix(
        # loss1='{:.6f}'.format(loss1.item()), refresh=False)
        # loss1_sum += loss1.item()
        # print('vali loss1: {:.6f}'.format(loss1_sum/len(val_loader)))
        # Evaluate the just-saved checkpoint file: peak-picking
        # recall/precision/F1 for P and S on the validation set.
        p_recall, p_precision, p_f1, s_recall, s_precision, s_f1 = analyzer.analyse_findpeak(
            model_fname)
        # Log metrics to TensorBoard (epochs are logged 1-indexed).
        writer.add_scalar('learn_rate', scheduler.get_last_lr()[0], epoch + 1)
        writer.add_scalar('Loss/train', loss_sum/len(train_loader), epoch + 1)
        # writer.add_scalar('Loss/val', loss1_sum/len(val_loader), epoch)
        writer.add_scalar('Recall/p', p_recall, epoch + 1)
        writer.add_scalar('Precision/p', p_precision, epoch + 1)
        writer.add_scalar('F1/p', p_f1, epoch + 1)
        writer.add_scalar('Recall/s', s_recall, epoch + 1)
        writer.add_scalar('Precision/s', s_precision, epoch + 1)
        writer.add_scalar('F1/s', s_f1, epoch + 1)
# NOTE(review): removed web-page moderation boilerplate (content-review
# notice) accidentally captured at the end of the file — not part of the code.