
wcj/ERE_pixelaug

forked from Curli Trans/ERE_pixelaug 
train_adap.py 5.70 KB
Curli Trans committed on 2023-09-08 14:37 · first
import argparse
# from asyncio import threads
import torch
import model
import numpy as np
import matplotlib.pyplot as plt
from model import two_d_softmax
# from model import nll_across_batch
from landmark_dataset_adap import LandmarkDataset
from utils import prepare_config_output_and_logger
from einops import repeat
# from torchsummary.torchsummary import summary_string
'''
Code design based on Bin Xiao's Deep High Resolution Network Repository:
https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
'''
# Per-landmark thresholds for the augmentation multiplier; freq_weights
# up-weights landmarks with tighter (smaller) thresholds.
thresholds = [0.6, 0.6, 0.3, 0.7, 0.6, 0.6,
              0.5, 0.4, 0.6, 0.8, 0.8, 0.8,
              0.7, 0.5, 0.6, 0.8, 0.7, 0.5,
              0.7]
freq_weights = 0.8 / np.array(thresholds)
# thresholds = [1.0, 1.1, 1.2, 1.3, 1.4, 1.0, 1.1, 1.2, 1.3, 1.4, 1.0, 1.1, 1.2, 1.3, 1.4, 1.0, 1.1, 1.2, 1.3, ]
# Turn each threshold t into the acceptance interval [1 - t, 1 + t].
thresholds = [1 - np.array(thresholds), 1 + np.array(thresholds)]
thresholds = np.array(thresholds)
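# A worked example of the interval logic above (my reading of the code,
# not the author's documentation): landmark 3 has threshold t = 0.3, so
# its acceptance interval is [0.7, 1.3] and its frequency weight is
# 0.8 / 0.3 ≈ 2.67. A sample whose augmentation multiplier is 1.2 falls
# inside the interval and contributes its loss scaled by 2.67; a
# multiplier of 1.5 falls outside, and that landmark's loss is zeroed
# for the sample.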
def parse_args():
    parser = argparse.ArgumentParser(description='Train a network to detect landmarks')
    parser.add_argument('--cfg',
                        help='The path to the configuration file for the experiment',
                        default='experiments/cephalometric.yaml',
                        type=str)
    parser.add_argument('--training_images',
                        help='The path to the training images',
                        default='/home1/quanquan/datasets/Cephalometric/RawImage/TrainingData/',
                        type=str)
    parser.add_argument('--annotations',
                        help='The path to the directory where annotations are stored',
                        default='/home1/quanquan/datasets/Cephalometric/AnnotationsByMD',
                        type=str)
    args = parser.parse_args()
    return args
def nll_across_batch_weighted(output, target, weights):
    # Negative log-likelihood per pixel, summed over the spatial dims,
    # giving one loss value per (sample, landmark) pair.
    nll = -target * torch.log(output.double())
    loss = torch.sum(nll, dim=(2, 3))
    # Broadcast the per-sample multiplier to all 19 landmarks.
    weights = repeat(weights, "b -> b c", c=19).numpy()
    thresholds_0 = repeat(thresholds[0], "c -> b c", b=weights.shape[0])
    thresholds_1 = repeat(thresholds[1], "c -> b c", b=weights.shape[0])
    # import ipdb; ipdb.set_trace()
    # Keep only landmarks whose multiplier lies inside the acceptance
    # interval, scaled by the per-landmark frequency weight.
    ist = (weights >= thresholds_0) & (weights <= thresholds_1)
    loss = loss * torch.Tensor(ist * freq_weights).cuda()
    return torch.mean(loss)
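# Shape sketch for the function above (inferred from how it is called in
# main(), where output = two_d_softmax(model(image))):
#   output, target: (batch, 19, H, W)  -- predicted and target heatmaps
#   weights:        (batch,)           -- meta['mul'] per sample
#   loss after the spatial sum: (batch, 19), then masked, weighted,
#   and averaged into a scalar.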
def main():
    # get arguments and the experiment file
    args = parse_args()
    cfg, logger, _, save_model_path, _ = prepare_config_output_and_logger(args.cfg, 'train')

    # print the arguments into the log
    logger.info("-----------Arguments-----------")
    logger.info(vars(args))
    logger.info("")

    # print the configuration into the log
    logger.info("-----------Configuration-----------")
    logger.info(cfg)
    logger.info("")

    # load the train dataset and put it into a loader
    training_dataset = LandmarkDataset(args.training_images, args.annotations, cfg.DATASET, perform_augmentation=True)
    training_loader = torch.utils.data.DataLoader(training_dataset, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=True)
    # data = training_dataset.__getitem__(0)
    # import ipdb; ipdb.set_trace()
    '''
    # Debug visualisation: overlay the heatmap channels and the averaged
    # annotator landmarks on a training image.
    for batch, (image, channels, meta) in enumerate(training_loader):
        s = 0
        plt.imshow(image[s, 0].detach().numpy(), cmap='gray')
        squashed_channels = np.max(channels[s].detach().numpy(), axis=0)
        plt.imshow(squashed_channels, cmap='inferno', alpha=0.5)
        landmarks_per_annotator = meta['landmarks_per_annotator'].detach().numpy()[s]
        averaged_landmarks = np.mean(landmarks_per_annotator, axis=0)
        for i, position in enumerate(averaged_landmarks):
            plt.text(position[0], position[1], "{}".format(i + 1), color="yellow", fontsize="small")
        plt.show()
    '''
    # Build the model class named in the config via `eval` on the imported
    # `model` module; the resulting instance then shadows the module name.
    use_smp_model = True
    model = eval("model." + cfg.MODEL.NAME)(cfg.MODEL, cfg.DATASET.KEY_POINTS, smp_model=use_smp_model).cuda()
    if not use_smp_model:
        # model.load("/home1/quanquan/code/landmark/code/runs/ssl/ssl_pos_ip/debug/ckpt_v/model_best.pth")
        model.load("/home1/quanquan/code/landmark/code/runs/ssl/interpolate/collect_sim/ckpt_v/model_best.pth")
    else:
        model.load_smp_model("/home1/quanquan/code/landmark/code/runs/ssl2/ssl_smp/debug_384/ckpt_v/model_best.pth")

    logger.info("-----------Model Summary-----------")
    # model_summary, _ = summary_string(model, (1, *cfg.DATASET.CACHED_IMAGE_SIZE))
    # logger.info(model_summary)

    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.TRAIN.LR)
    # Drop the learning rate by 10x at epochs 4, 6 and 8.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[4, 6, 8], gamma=0.1)
    for epoch in range(cfg.TRAIN.EPOCHS):
        logger.info('-----------Epoch {} Training-----------'.format(epoch))
        model.train()
        losses_per_epoch = []

        for batch, (image, channels, meta) in enumerate(training_loader):
            # Put image and channels onto gpu
            image = image.cuda()
            # print(image.shape)
            channels = channels.cuda()
            mul = meta['mul']

            output = model(image.float())
            output = two_d_softmax(output)

            optimizer.zero_grad()
            loss = nll_across_batch_weighted(output, channels, mul)
            # import ipdb; ipdb.set_trace()
            loss.backward()
            optimizer.step()

            losses_per_epoch.append(loss.item())
            if (batch + 1) % 5 == 0:
                logger.info("[{}/{}]\tLoss: {:.3f}".format(batch + 1, len(training_loader), np.mean(losses_per_epoch)))

        scheduler.step()

    logger.info("Saving Model's State Dict to {}".format(save_model_path))
    torch.save(model.state_dict(), save_model_path)
    logger.info("-----------Training Complete-----------")
if __name__ == '__main__':
    main()
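# Example invocation (a sketch; the default paths above are specific to
# the original author's machine and need to be adapted):
#   python train_adap.py --cfg experiments/cephalometric.yaml \
#       --training_images /path/to/Cephalometric/RawImage/TrainingData/ \
#       --annotations /path/to/Cephalometric/AnnotationsByMD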