import argparse
# from asyncio import threads
import torch
import model
import numpy as np
import matplotlib.pyplot as plt
from model import two_d_softmax
# from model import nll_across_batch
from landmark_dataset_adap import LandmarkDataset
from utils import prepare_config_output_and_logger
from einops import repeat
# from torchsummary.torchsummary import summary_string
'''
Code design based on Bin Xiao's Deep High Resolution Network Repository:
https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
'''
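# Per-landmark thresholds for the 19 cephalometric landmarks. They are used in
# two ways below: freq_weights scales each landmark's loss inversely to its
# threshold, and the [1 - t, 1 + t] bands gate which samples contribute to a
# landmark's loss (see nll_across_batch_weighted).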
thresholds = [0.6, 0.6, 0.3, 0.7, 0.6, 0.6,
              0.5, 0.4, 0.6, 0.8, 0.8, 0.8,
              0.7, 0.5, 0.6, 0.8, 0.7, 0.5,
              0.7]
freq_weights = 0.8 / np.array(thresholds)
# thresholds = [1.0, 1.1, 1.2, 1.3, 1.4, 1.0, 1.1, 1.2, 1.3, 1.4, 1.0, 1.1, 1.2, 1.3, 1.4, 1.0, 1.1, 1.2, 1.3, ]
# Rebind thresholds to a (2, 19) array of per-landmark lower/upper bounds
thresholds = np.array([1 - np.array(thresholds), 1 + np.array(thresholds)])
def parse_args():
    parser = argparse.ArgumentParser(description='Train a network to detect landmarks')
    parser.add_argument('--cfg',
                        help='The path to the configuration file for the experiment',
                        default='experiments/cephalometric.yaml',
                        type=str)
    parser.add_argument('--training_images',
                        help='The path to the training images',
                        default='/home1/quanquan/datasets/Cephalometric/RawImage/TrainingData/',
                        type=str)
    parser.add_argument('--annotations',
                        help='The path to the directory where annotations are stored',
                        default='/home1/quanquan/datasets/Cephalometric/AnnotationsByMD',
                        type=str)
    args = parser.parse_args()
    return args
def nll_across_batch_weighted(output, target, weights):
    """Per-landmark weighted NLL between predicted and target heatmaps.

    `weights` holds one factor per sample (taken from meta['mul'] in the
    training loop); a sample only contributes to a landmark's loss when that
    factor falls inside the landmark's [1 - t, 1 + t] band.
    """
    nll = -target * torch.log(output.double())
    # Sum the NLL over each heatmap's spatial dimensions -> shape (batch, landmarks)
    loss = torch.sum(nll, dim=(2, 3))
    # Broadcast the per-sample factor across the 19 landmark channels
    weights = repeat(weights, "b -> b c", c=19).numpy()
    thresholds_0 = repeat(thresholds[0], "c -> b c", b=weights.shape[0])
    thresholds_1 = repeat(thresholds[1], "c -> b c", b=weights.shape[0])
    # Mask out-of-band terms and reweight the rest by the inverse-threshold weights
    ist = (weights >= thresholds_0) & (weights <= thresholds_1)
    loss = loss * torch.Tensor(ist * freq_weights).cuda()
    return torch.mean(loss)
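
# A minimal sanity check for the weighted loss, kept as a sketch: the helper
# name and dummy shapes are assumptions, and it needs a CUDA device because
# nll_across_batch_weighted moves its weight mask onto the GPU. It relies on
# two_d_softmax normalising a (B, C, H, W) tensor into per-channel probability
# maps, as it is used in main() below.
def _sanity_check_weighted_loss(batch_size=2, n_landmarks=19, height=8, width=8):
    probs = two_d_softmax(torch.randn(batch_size, n_landmarks, height, width).cuda())
    target = two_d_softmax(torch.randn(batch_size, n_landmarks, height, width).cuda())
    # A factor of 1.0 lies inside every [1 - t, 1 + t] band, so nothing is masked out
    mul = torch.ones(batch_size)
    return nll_across_batch_weighted(probs, target, mul)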
def main():
    # get arguments and the experiment file
    args = parse_args()
    cfg, logger, _, save_model_path, _ = prepare_config_output_and_logger(args.cfg, 'train')

    # print the arguments into the log
    logger.info("-----------Arguments-----------")
    logger.info(vars(args))
    logger.info("")

    # print the configuration into the log
    logger.info("-----------Configuration-----------")
    logger.info(cfg)
    logger.info("")

    # load the train dataset and put it into a loader
    training_dataset = LandmarkDataset(args.training_images, args.annotations, cfg.DATASET, perform_augmentation=True)
    training_loader = torch.utils.data.DataLoader(training_dataset, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=True)
    # Disabled visualisation snippet: overlays the heatmap channels and the
    # averaged annotator landmarks on one training image
    '''
    for batch, (image, channels, meta) in enumerate(training_loader):
        s = 0
        plt.imshow(image[s, 0].detach().numpy(), cmap='gray')
        squashed_channels = np.max(channels[s].detach().numpy(), axis=0)
        plt.imshow(squashed_channels, cmap='inferno', alpha=0.5)
        landmarks_per_annotator = meta['landmarks_per_annotator'].detach().numpy()[s]
        averaged_landmarks = np.mean(landmarks_per_annotator, axis=0)
        for i, position in enumerate(averaged_landmarks):
            plt.text(position[0], position[1], "{}".format(i + 1), color="yellow", fontsize="small")
        plt.show()
    '''
    # Build the model; NOTE: the local name `model` shadows the imported
    # `model` module from here on
    use_smp_model = True
    model = eval("model." + cfg.MODEL.NAME)(cfg.MODEL, cfg.DATASET.KEY_POINTS, smp_model=use_smp_model).cuda()

    # Load pretrained weights (checkpoints from the ssl runs)
    if not use_smp_model:
        # model.load("/home1/quanquan/code/landmark/code/runs/ssl/ssl_pos_ip/debug/ckpt_v/model_best.pth")
        model.load("/home1/quanquan/code/landmark/code/runs/ssl/interpolate/collect_sim/ckpt_v/model_best.pth")
    else:
        model.load_smp_model("/home1/quanquan/code/landmark/code/runs/ssl2/ssl_smp/debug_384/ckpt_v/model_best.pth")

    logger.info("-----------Model Summary-----------")
    # model_summary, _ = summary_string(model, (1, *cfg.DATASET.CACHED_IMAGE_SIZE))
    # logger.info(model_summary)

    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.TRAIN.LR)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[4, 6, 8], gamma=0.1)
    for epoch in range(cfg.TRAIN.EPOCHS):
        logger.info('-----------Epoch {} Training-----------'.format(epoch))
        model.train()
        losses_per_epoch = []

        for batch, (image, channels, meta) in enumerate(training_loader):
            # Put the image and target channels onto the GPU
            image = image.cuda()
            channels = channels.cuda()
            # Per-sample factor used to weight and gate the loss
            mul = meta['mul']

            # Forward pass: normalise each output channel into a probability map
            output = model(image.float())
            output = two_d_softmax(output)

            optimizer.zero_grad()
            loss = nll_across_batch_weighted(output, channels, mul)
            loss.backward()
            optimizer.step()

            losses_per_epoch.append(loss.item())
            # Log the running mean loss for this epoch every 5 batches
            if (batch + 1) % 5 == 0:
                logger.info("[{}/{}]\tLoss: {:.3f}".format(batch + 1, len(training_loader), np.mean(losses_per_epoch)))

        scheduler.step()

        # Overwrite the checkpoint with the latest state dict after each epoch
        logger.info("Saving Model's State Dict to {}".format(save_model_path))
        torch.save(model.state_dict(), save_model_path)

    logger.info("-----------Training Complete-----------")
if __name__ == '__main__':
    main()
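
# Example invocation (the script filename is assumed; point the paths at your
# local copy of the Cephalometric dataset):
#   python train_adap.py --cfg experiments/cephalometric.yaml \
#       --training_images /path/to/Cephalometric/RawImage/TrainingData/ \
#       --annotations /path/to/Cephalometric/AnnotationsByMD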