代码拉取完成,页面将自动刷新
import numpy as np
import torch
from torch import nn
from PIL import Image
import torchvision.transforms as transforms
from PIL import ImageFile
from torch.utils.data import Dataset
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
import matplotlib as mlb
import pandas as pd
import os
from torch.utils.data import TensorDataset, DataLoader
from gdordbfcnl import myFcn
# ---------------------------------------------------------------------------
# Runtime environment
# ---------------------------------------------------------------------------
# KMP_DUPLICATE_LIB_OK tolerates more than one OpenMP runtime being loaded;
# CUDA_VISIBLE_DEVICES restricts the process to GPU index 0.
os.environ.update({
    "KMP_DUPLICATE_LIB_OK": "TRUE",
    "CUDA_VISIBLE_DEVICES": "0",
})
# Let PIL load image files whose data is truncated.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# ---------------------------------------------------------------------------
# Data: two CSV files concatenated ("gotmd" rows first), then truncated
# ---------------------------------------------------------------------------
train_dataset1 = pd.read_csv("reduced_data.csv", index_col=None)
train_dataset2 = pd.read_csv("reduced_gotmd.csv", index_col=None)
# NOTE(review): 276352 == 2159 * 128, so the cut-off is presumably chosen to
# make the row count a multiple of the batch size -- confirm.
train_dataset = pd.concat(
    [train_dataset2, train_dataset1], ignore_index=True
).iloc[:276352]

# Every column except the two label columns is used as an input feature.
features = torch.tensor(
    np.array(train_dataset.drop(columns=['label', 'dborgot'])),
    dtype=torch.float32,
)
labels = torch.tensor(np.array(train_dataset['label']), dtype=torch.float32)
# 'dborgot' is the label that tells the two source datasets apart.
dborgotlabels = torch.tensor(
    np.array(train_dataset['dborgot']), dtype=torch.float32
)

# The training target is `dborgotlabels`; `labels` is built but not used here.
# Alternative target: dataset = TensorDataset(features, labels)
dataset = TensorDataset(features, dborgotlabels)

# NOTE(review): both loaders wrap the *same* dataset, so the "validation"
# metrics are computed on the training data.
train_loader = DataLoader(dataset=dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(dataset=dataset, batch_size=128, shuffle=True)

# ---------------------------------------------------------------------------
# Model, loss and optimiser
# ---------------------------------------------------------------------------
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = myFcn().to(device)
# Loss function: cross entropy.
Loss_fn = nn.CrossEntropyLoss().to(device)
# SGD with momentum (weight_decay / nesterov are intentionally left unset).
Yh_fn = torch.optim.SGD(model.parameters(), lr=5e-4, momentum=0.9)
# NOTE(review): this rebinding shadows the imported `lr_scheduler` module with
# the scheduler instance.
lr_scheduler = lr_scheduler.StepLR(Yh_fn, step_size=1, gamma=0.01)

# ---------------------------------------------------------------------------
# Bookkeeping containers
# ---------------------------------------------------------------------------
# Per-epoch metric history, consumed by the plots at the end of the script.
train_loss_list, train_acc_list = [], []
val_loss_list, val_acc_list = [], []
# Per-batch samples (and their targets) grouped by predicted class; these are
# filled by mytrain().
all_x_db, all_y_db = [], []
all_x_gotmd, all_y_gotmd = [], []
def mytrain(train_data, model, Loss_fn, Yh_fn, n_epoch):
    """Run one training epoch and return the mean batch accuracy and loss.

    Args:
        train_data: Iterable yielding ``(x, y)`` batches of input tensors and
            target class indices (targets are cast with ``.long()``).
        model: ``nn.Module`` to optimise. It must also expose
            ``set_val(epoch, batch)``, which is called before every forward
            pass (its semantics are defined by the model class).
        Loss_fn: Criterion called as ``Loss_fn(output, y.long())``.
        Yh_fn: Optimiser that updates ``model``'s parameters.
        n_epoch: Index of the current epoch, forwarded to ``model.set_val``.

    Returns:
        tuple: ``(mean_accuracy, mean_loss)``, each averaged over the batches.

    Raises:
        ZeroDivisionError: If ``train_data`` yields no batches.

    Side effects:
        For every batch, the samples predicted as class 0 / class 1 and their
        targets are appended to the module-level lists ``all_x_db`` /
        ``all_y_db`` and ``all_x_gotmd`` / ``all_y_gotmd``.
        NOTE(review): these lists are never cleared here, so they grow with
        every batch of every epoch.
    """
    # valM() switches the model to eval mode and nothing switches it back, so
    # training mode has to be re-enabled at the start of every epoch;
    # otherwise layers such as dropout / batch-norm stay frozen after the
    # first validation pass.
    model.train()

    total_loss, total_acc, n = 0.0, 0.0, 0
    for batch, (x, y) in enumerate(train_data):
        x, y = x.to(device), y.to(device)

        # Forward pass.
        model.set_val(n_epoch, batch)
        output = model(x)
        Loss = Loss_fn(output, y.long())
        _, pred_idx = torch.max(output, dim=1)

        # Record this batch's samples grouped by predicted class. The masked
        # selections already live on the same device as ``x`` / ``y``.
        is_db = pred_idx == 0
        is_gotmd = pred_idx == 1
        all_x_db.append(x[is_db])
        all_y_db.append(y[is_db])
        all_x_gotmd.append(x[is_gotmd])
        all_y_gotmd.append(y[is_gotmd])

        # Fraction of correct predictions in this batch.
        cur_acc = torch.sum(y == pred_idx) / output.shape[0]

        # Backward pass: clear stale gradients, back-propagate the loss, then
        # update all weights and biases.
        Yh_fn.zero_grad()
        Loss.backward()
        Yh_fn.step()

        # Running totals.
        total_loss += Loss.item()
        total_acc += cur_acc.item()
        n += 1
        print(f'batch:{batch},n={n}')

    return total_acc / n, total_loss / n
# print("训练:总误差:"+str(total_loss)+',平均误差'+str(total_loss/n)+',n='+str(n))
# print("训练:总体准确率:"+str(total_acc)+',平均准确率'+str(total_acc/n)+',n='+str(n))
# Model validation
def valM(train_data, model, loss_fn):
    """Evaluate ``model`` on ``train_data`` with gradient tracking disabled.

    Args:
        train_data: Iterable yielding ``(x, y)`` batches; targets are cast
            with ``.long()`` before being passed to ``loss_fn``.
        model: Network to evaluate. It is switched to eval mode and left in
            that mode when the function returns.
        loss_fn: Criterion called as ``loss_fn(output, y.long())``.

    Returns:
        tuple: ``(mean_accuracy, mean_loss)``, each averaged over the batches.
    """
    model.eval()
    acc_sum = 0.0
    loss_sum = 0.0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets in train_data:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            batch_loss = loss_fn(logits, targets.long())
            # The predicted class is the index of the largest output per row.
            predicted = torch.max(logits, dim=1)[1]
            batch_acc = (targets == predicted).sum() / logits.shape[0]
            loss_sum += batch_loss.item()
            acc_sum += batch_acc.item()
            num_batches += 1
    return acc_sum / num_batches, loss_sum / num_batches
# ---------------------------------------------------------------------------
# Training driver: run every epoch, validate after each one, then plot the
# accuracy / loss history.
# ---------------------------------------------------------------------------
epoch = 100  # number of training epochs
min_acc = 0  # best validation accuracy seen so far (starts at 0)
for t in range(epoch):
    print(f'批次{t + 1}训练:')

    # Train for one epoch and record the metrics.
    t_a, t_loss = mytrain(train_loader, model, Loss_fn, Yh_fn, t)
    train_loss_list.append(t_loss)
    train_acc_list.append(t_a)
    print(f'训练正确率:{t_a},训练Loss:{t_loss}')

    # Validate and record the metrics.
    v_a, v_loss = valM(test_loader, model, Loss_fn)
    val_loss_list.append(v_loss)
    val_acc_list.append(v_a)
    print(f'验证:历史正确率:{min_acc},最新正确率:{v_a},Loss={v_loss}')

    if v_a > min_acc:
        # New best validation accuracy. Checkpointing is currently disabled:
        # torch.save(model.state_dict(), 'best_model_flower.pth')
        min_acc = v_a
    # The learning-rate scheduler is currently not stepped:
    # lr_scheduler.step()

# Font used for the plot text (the labels below are Chinese); 'STKAITI' is an
# alternative.
mlb.rcParams['font.family'] = 'SimHei'
plt.figure(figsize=(10, 10))

# Left panel: accuracy per epoch.
plt.subplot(1, 2, 1)
print("printing")
plt.plot(train_acc_list, label='训练准确率')
plt.plot(val_acc_list, label='验证准确率')
plt.legend(loc='lower right')
plt.title('训练、验证准确率')

# Right panel: loss per epoch.
plt.subplot(1, 2, 2)
plt.plot(train_loss_list, label='训练误差')
plt.plot(val_loss_list, label='验证误差')
plt.legend(loc='lower right')
plt.title('训练、验证损失')

plt.show()
print('done')
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。