1 Star 0 Fork 0

DFTL/Qizhi

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train_test_4.py 13.33 KB
一键复制 编辑 原始数据 按行查看 历史
DFTL 提交于 2023-12-13 11:16 . modify path
# -*- coding = utf-8 -*-
'''
# @time:2023/4/8 10:57
# Author:DFTL
# @File:test.py
'''
import argparse
import os
import imageio
import cv2
import numpy as np
import json
from sklearn.metrics import precision_score, recall_score, f1_score, cohen_kappa_score
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
from torch.utils.data import Dataset, DataLoader
from torch.optim import SGD, Adam
from torch.nn import CrossEntropyLoss
import torch
from utils.TT_Dataset import MyDataset
from MyModel.DFNet import DFNet
# from Models.model0603 import KRModel
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Land-cover class names for the Yellow River delta scene.
# List index must match the integer class id used in the label rasters (0-9)
# and `args.num_class`; order matters for the per-class metric reports.
class_name = ['Seagrass bed',
'Spartina alterniflora',
'Reed',
'Tamarix',
'Tidal flat',
'Sparse vegetation',
'Sea',
'Yellow River',
'Pond',
'Cloud']
# entities_dict = {name:i for i,name in enumerate(class_name)}
# relation_dict = {rel:i for i,rel in enumerate()}
def z_score_normal(image_data):
    """Standardize every channel of an image to zero mean and unit variance.

    Each channel is normalized independently as (band - mean(band)) / std(band),
    matching the original per-band (B1..B4) implementation, but vectorized so it
    works for any number of channels, not just 4.

    Parameters:
        image_data: array of shape (H, W, C) with numeric dtype.

    Returns:
        float32 array of shape (H, W, C) holding per-channel z-scores.
    """
    image_data = np.asarray(image_data)
    # per-channel statistics over the spatial dimensions; keepdims lets the
    # subtraction/division broadcast back over (H, W, C)
    mean = image_data.mean(axis=(0, 1), keepdims=True)
    std = image_data.std(axis=(0, 1), keepdims=True)
    return ((image_data - mean) / std).astype('float32')
def train(model, data_loader, optimizer1, criterion, args):
    """Train `model` on `data_loader` for `args.epoch` epochs.

    The model is expected to return three heads per forward pass
    (output, soft, triple); the loss is the mean of three cross-entropy
    terms: segmentation output vs label, soft output vs label, and the
    per-class `triple` head vs a fixed 0..num_class-1 target repeated
    for every sample in the batch. A checkpoint is saved every 5 epochs
    into `args.weights_path`.

    Parameters:
        model: network returning (output, soft, triple).
        data_loader: yields (img, label) batches.
        optimizer1: optimizer for all model parameters.
        criterion: CrossEntropyLoss instance.
        args: namespace with epoch, batch, num_class, weights_path.
    """
    model.train()
    for epoch in range(args.epoch):
        print('======================epoch:{}/{}========================='.format(epoch, args.epoch))
        epoch_loss = 0.0
        step = 0
        for data in data_loader:
            step += 1
            img, label = data
            img, label = img.to(device), label.to(device)
            # drop the final partial batch: the triple-loss target below is
            # built for exactly args.batch samples
            if img.shape[0] != args.batch:
                break
            output, soft, triple = model(img)
            optimizer1.zero_grad()
            loss1 = criterion(output, label.long())
            loss2 = criterion(soft, label.long())
            # target 0..num_class-1 for each sample (was hard-coded to 10
            # classes, contradicting args.num_class used just below)
            label_ = torch.arange(args.num_class, device=device).repeat(args.batch, 1).view(-1)
            loss3 = criterion(triple.view(-1, args.num_class), label_)
            loss = (loss1 + loss2 + loss3) / 3
            # accumulate a plain float so the autograd graph of each step is
            # freed immediately (accumulating the tensor leaks memory)
            epoch_loss += loss.item()
            loss.backward()
            optimizer1.step()
            print('\r', 'step: ', step, ' loss: {:.6f}'.format(loss.item()), end='', flush=True)
        avg_loss = epoch_loss / len(data_loader)
        print('\n avg_loss:{:.6f} \n'.format(avg_loss))
        '''save models'''
        if epoch % 5 == 0:
            weight_name = 'epoch_' + str(epoch) + '_loss_' + '{:.6f}'.format(avg_loss) + '.pt'
            torch.save(model.state_dict(), os.path.join(args.weights_path, weight_name))
            print('epoch: {} | loss: {:.6f} | Saving model... \n'.format(epoch, avg_loss))
# ========================================================================================================================================
def model_predict(model, img_data, img_size, kg=None):
    """Run tile-wise inference over a full scene and stitch the class map.

    The image is zero-padded up to the next multiple of `img_size` in both
    spatial dimensions, split into img_size x img_size tiles, each tile is
    pushed through `model` (first head only), and the argmax class indices
    are written back into a full-resolution map.

    Parameters:
        model: network whose forward returns (logits, soft, triple);
            logits has shape (1, num_class, img_size, img_size).
        img_data: float array of shape (row, col, dep).
        img_size: tile side length.
        kg: unused, kept for call-site compatibility.

    Returns:
        uint8 array of shape (row, col) with predicted class indices.
    """
    model.eval()
    row, col, dep = img_data.shape
    # pad each spatial dimension up to the next multiple of img_size
    if row % img_size != 0 or col % img_size != 0:
        padding_h = (row // img_size + 1) * img_size
        padding_w = (col // img_size + 1) * img_size
    else:
        padding_h = (row // img_size) * img_size
        padding_w = (col // img_size) * img_size
    padding_img = np.zeros((padding_h, padding_w, dep), dtype='float32')
    padding_img[:row, :col, :] = img_data[:row, :col, :]
    # canvas receiving per-tile predictions
    padding_pre = np.zeros((padding_h, padding_w), dtype='uint8')
    for i in range(0, padding_h, img_size):
        for j in range(0, padding_w, img_size):
            # (H, W, C) tile -> (1, C, H, W) tensor for the network
            tile = padding_img[i:i + img_size, j:j + img_size, :]
            tile = np.transpose(tile[np.newaxis, :, :, :], (0, 3, 1, 2))
            tile = torch.from_numpy(tile).to(device)
            y_pre, _, _ = model(tile)
            y_pre = torch.squeeze(y_pre, dim=0)
            y_pre = torch.argmax(y_pre, dim=0)
            padding_pre[i:i + img_size, j:j + img_size] = y_pre[:img_size, :img_size].cpu().detach().numpy()
    # crop the padding back off
    return padding_pre[:row, :col]
# ========================================================================================================================================
def calculation(y_label, y_pre, row, col):
    """Compute per-class precision, recall, F1 and overall Cohen's kappa.

    Parameters:
        y_label: ground-truth label map of shape (row, col).
        y_pre: predicted label map of shape (row, col).
        row, col: spatial dimensions used to flatten the maps.

    Returns:
        (precision, recall, f1, kappa) where the first three are per-class
        arrays (sklearn `average=None`) and kappa is a single float.
    """
    # flatten to column vectors for sklearn
    y_label = np.reshape(y_label, (row * col, 1))
    y_pre = np.reshape(y_pre, (row * col, 1))
    # NOTE: the original called .astype('float64') here and discarded the
    # result (a no-op); the integer labels are passed straight to sklearn.
    precision = precision_score(y_label, y_pre, average=None)
    recall = recall_score(y_label, y_pre, average=None)
    f1 = f1_score(y_label, y_pre, average=None)
    kappa = cohen_kappa_score(y_label, y_pre)
    return precision, recall, f1, kappa
# ========================================================================================================================================
def estimate(y_label, y_pred, model_hdf5_name, class_name, dirname):
    """Evaluate a prediction map against ground truth and dump metrics to disk.

    Computes overall accuracy, prints per-class F1/precision/recall (the
    kappa shown on each line is the single global kappa), and appends a
    JSON summary to a txt file named after the checkpoint epoch.

    Parameters:
        y_label: ground-truth label map (2-D).
        y_pred: predicted label map, same shape.
        model_hdf5_name: checkpoint file name (e.g. 'epoch_5_loss_...pt').
        class_name: ordered list of class display names.
        dirname: directory receiving the metrics txt file.

    Returns:
        Overall pixel accuracy as a float.
    """
    # overall accuracy: fraction of pixels where prediction matches label
    acc = np.mean(y_label == y_pred)
    print('=================================================================================================')
    print('The estimate result of {} are as follows:'.format(model_hdf5_name))
    print('The acc of {} is {}'.format(model_hdf5_name, acc))
    precision, recall, f1, kappa = calculation(y_label, y_pred, y_label.shape[0], y_label.shape[1])
    for i in range(len(class_name)):
        print('{} F1: {:.5f}, Precision: {:.5f}, Recall: {:.5f}, kappa: {:.5f}'.format(class_name[i], f1[i],
                                                                                       precision[i], recall[i],
                                                                                       kappa))
    # only write the report when sklearn produced one score per class
    # (a mismatch means some classes were absent from label/prediction)
    if len(f1) == len(class_name):
        result = {}
        for i in range(len(class_name)):
            result[class_name[i]] = []
            tmp = {}
            tmp['Recall'] = str(round(recall[i], 5))
            tmp['Precision'] = str(round(precision[i], 5))
            tmp['F1'] = str(round(f1[i], 5))
            result[class_name[i]].append(tmp)
        result['Model Name'] = [model_hdf5_name]
        result['Accuracy'] = str(round(acc, 5))
        result['kappa'] = str(kappa)
        # append the JSON summary; file name carries epoch number and accuracy
        txt_name = "epoch_" + model_hdf5_name.split("_")[1] + "_acc_" + str(round(acc, 5))
        with open(os.path.join(dirname, txt_name + '.txt'), 'a', encoding="utf-8") as f:
            f.write(json.dumps(result, ensure_ascii=False))
    else:
        print("======================================>Estimate error!===========================================")
    return acc
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Model Controller')
parser.add_argument('--mode', type=str, default='pre_train', help='pre_train/test/train/final_test')
parser.add_argument('--image_path', type=str, default=r"/tmp/dataset/cut_train_images")
parser.add_argument('--label_path', type=str, default=r"/tmp/dataset/cut_train_labels")
parser.add_argument('--weights_path', type=str, default='/tmp/output/DFNet_Unet_checkpoint', help='the path saving weights')
parser.add_argument('--result_path', type=str, default='/tmp/output/DFNet_Unet_result', help='the path saving result')
parser.add_argument('--log_file',type=str,default='/tmp/output/DFNet_UNet.txt',help='the log of training process')
parser.add_argument('--val_image_path', type=str, default=r'/tmp/dataset/image.tif', help='val_imageset path')
parser.add_argument('--val_label_path', type=str, default=r'/tmp/dataset/rastercalc1.tif', help='val_labelset path')
parser.add_argument('--in_ch', type=int, default=4)
parser.add_argument('--num_class', type=int, default=10)
# parser.add_argument('--num_superpixel', type=int, default=64)
parser.add_argument('--feat_dim', type=int, default=128)
parser.add_argument('--batch', type=int, default=32)
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--epoch', type=int, default=200, help='train_epochs')
args = parser.parse_args()
criterion = CrossEntropyLoss()
# 加载数据集
# mydataset = Dataset_df.MyDataset(args.image_path,args.label_path)
# mydataset = Dataset(args.image_path,args.label_path)
mydataset = MyDataset(args.image_path,args.label_path)
data_loader = DataLoader(dataset=mydataset, batch_size=args.batch, shuffle=True, pin_memory=True)
print("The images in Dataset: %d" %len(mydataset))
# model = KRModel(num_class=10).to(device)
model = DFNet(args).to(device)
print(model)
total = sum([param.nelement() for param in model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6))
if not os.path.exists(args.weights_path):
os.makedirs(args.weights_path)
if not os.path.exists(args.result_path):
os.makedirs(args.result_path)
'''优化器更新参数'''
# optimizer = SGD(model.parameters(),lr=args.lr)
# optimizer1 = Adam(model.parameters(),lr=args.lr)
optimizer1 = Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.99))
# 开始训练
# 验证训练好的模型-8.27
train(model, data_loader, optimizer1, criterion, args)
# 验证
label_data = imageio.imread(args.val_label_path) - 1
image_data = imageio.imread(args.val_image_path)
image_data = z_score_normal(image_data)
weights = os.listdir(args.weights_path)
best_acc = 0
for w in weights:
model.load_state_dict(torch.load(os.path.join(args.weights_path, w), map_location=device))
# 原图预测,添加基线返回结果
output = model_predict(model, image_data, img_size=128)
acc = estimate(label_data, output, w, class_name, args.result_path)
# acc2 = estimate(label_data, output_base, w, class_name, args.result_path)
print('The acc of {} is {}'.format(w, acc))
if acc > best_acc:
best_acc = acc
best_weight = w
else:
os.remove(os.path.join(args.weights_path, w))
# 保存
save_name = "epoch_" + w.split("_")[1] + "_acc_" + str(acc) + ".tif"
imageio.imwrite(os.path.join(args.result_path, save_name), output_base)
# print(save_name)
print("Sucessfully saved to " + os.path.join(args.result_path, save_name))
print('=================================================================================================')
print("#################### Evaluate Finshed #################### ")
print("the best acc is {} of {}".format(best_acc, best_weight))
with open('/tmp/output/Backbone.txt', 'a', encoding="utf-8") as f:
# f.write("the eval model is {} \n".format(model_name))
f.write("--------------Backbone-Unet-----------------\n")
f.write("the best acc is {} of {} \n".format(best_acc, best_weight))
f.write("--------------------------------------------\n")
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/dftl/qizhi.git
git@gitee.com:dftl/qizhi.git
dftl
qizhi
Qizhi
master

搜索帮助