"""
Author: Benny
Date: Nov 2019
"""
import datetime
import importlib
import logging
import os
import shutil
from pathlib import Path
import hydra
import numpy as np
import torch
from omegaconf import OmegaConf
from torch.optim.lr_scheduler import MultiStepLR
from tqdm import tqdm
import provider
from dataset import PartNormalDataset

seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
               'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37],
               'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49],
               'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
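# 16 object categories and 50 part labels (0-49); the loop below builds the reverse
# lookup from a part label to its object category.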
seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table}
for cat in seg_classes.keys():
    for label in seg_classes[cat]:
        seg_label_to_cat[label] = cat


def inplace_relu(m):
    classname = m.__class__.__name__
    if classname.find('ReLU') != -1:
        m.inplace = True


def to_categorical(y, num_classes):
    """ 1-hot encodes a tensor """
    new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
    if y.is_cuda:
        return new_y.cuda()
    return new_y
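
# Example: to_categorical(torch.tensor([[2]]), 16) yields a (1, 1, 16) one-hot tensor;
# main() repeats this encoding per point and concatenates it to the point features.
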
@hydra.main(config_path='config', config_name='partseg')
def main(cfg):
    OmegaConf.set_struct(cfg, False)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu)

    logger = logging.getLogger(__name__)

    print(OmegaConf.to_yaml(cfg))
    print('\n')

    root = hydra.utils.to_absolute_path('data/shapenetcore_partanno_segmentation_benchmark_v0_normal')

    TRAIN_DATASET = PartNormalDataset(root=root, npoints=cfg.num_point, split='trainval', normal_channel=cfg.normal)
    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers, drop_last=True)
    TEST_DATASET = PartNormalDataset(root=root, npoints=cfg.num_point, split='test', normal_channel=cfg.normal)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)
    '''MODEL LOADING'''
    cfg.input_dim = (6 if cfg.normal else 3) + 16
    cfg.num_class = 50
    num_category = 16
    num_part = cfg.num_class
    shutil.copy(hydra.utils.to_absolute_path('models/{}/model.py'.format(cfg.model.name)), '.')

    model = getattr(importlib.import_module('models.{}.model'.format(cfg.model.name)), 'PointTransformerSeg')(cfg).cuda()
    criterion = torch.nn.CrossEntropyLoss()

    try:
        checkpoint = torch.load('best_model.pth')
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model_state_dict'])
        logger.info('Use pretrained model')
    except Exception:
        logger.info('No existing model, starting training from scratch...')
        start_epoch = 0
    if cfg.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=cfg.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=cfg.weight_decay
        )
    else:
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=cfg.learning_rate,
            momentum=0.9,
            weight_decay=0.0001
        )

    scheduler = MultiStepLR(
        optimizer,
        milestones=[120, 180],
        gamma=cfg.scheduler_gamma
    )
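    # Note: scheduler.step() is never called in this script; the learning rate is
    # instead set manually at the top of every epoch (step decay with a floor).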

    def bn_momentum_adjust(m, momentum):
        if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d):
            m.momentum = momentum

    LEARNING_RATE_CLIP = 1e-5
    MOMENTUM_ORIGINAL = 0.1
    MOMENTUM_DECAY = 0.5
    MOMENTUM_DECAY_STEP = cfg.step_size

    best_acc = 0
    global_epoch = 0
    best_class_avg_iou = 0
    best_instance_avg_iou = 0
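
    # One pass per epoch: decay LR and BatchNorm momentum, train, then evaluate on the test split.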
    for epoch in range(start_epoch, cfg.epoch):
        mean_correct = []
        logger.info('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, cfg.epoch))

        '''Adjust learning rate and BN momentum'''
        lr = max(cfg.learning_rate * (cfg.lr_decay ** (epoch // cfg.step_size)), LEARNING_RATE_CLIP)
        logger.info('Learning rate:%f' % lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        momentum = MOMENTUM_ORIGINAL * (MOMENTUM_DECAY ** (epoch // MOMENTUM_DECAY_STEP))
        if momentum < 0.01:
            momentum = 0.01
        print('BN momentum updated to: %f' % momentum)
        model = model.apply(lambda x: bn_momentum_adjust(x, momentum))
        model = model.train()

        '''learning one epoch'''
        for i, (points, label, target) in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader), smoothing=0.9):
            points = points.data.numpy()
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            points = torch.Tensor(points)
            points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda()
            optimizer.zero_grad()

            seg_pred = model(torch.cat([points, to_categorical(label, num_category).repeat(1, points.shape[1], 1)], -1))
            seg_pred = seg_pred.contiguous().view(-1, num_part)
            target = target.view(-1, 1)[:, 0]
            pred_choice = seg_pred.data.max(1)[1]

            correct = pred_choice.eq(target.data).cpu().sum()
            mean_correct.append(correct.item() / (cfg.batch_size * cfg.num_point))
            loss = criterion(seg_pred, target)
            loss.backward()
            optimizer.step()

        train_instance_acc = np.mean(mean_correct)
        logger.info('Train accuracy is: %.5f' % train_instance_acc)
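
        # Evaluation: overall point accuracy, per-class accuracy, per-category mIoU,
        # and instance-averaged mIoU over all test shapes.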
        with torch.no_grad():
            test_metrics = {}
            total_correct = 0
            total_seen = 0
            total_seen_class = [0 for _ in range(num_part)]
            total_correct_class = [0 for _ in range(num_part)]
            shape_ious = {cat: [] for cat in seg_classes.keys()}
            seg_label_to_cat = {}  # {0:Airplane, 1:Airplane, ...49:Table}
            for cat in seg_classes.keys():
                for label in seg_classes[cat]:
                    seg_label_to_cat[label] = cat

            model = model.eval()

            for batch_id, (points, label, target) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
                cur_batch_size, NUM_POINT, _ = points.size()
                points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda()
                seg_pred = model(torch.cat([points, to_categorical(label, num_category).repeat(1, points.shape[1], 1)], -1))
                cur_pred_val = seg_pred.cpu().data.numpy()
                cur_pred_val_logits = cur_pred_val
                cur_pred_val = np.zeros((cur_batch_size, NUM_POINT)).astype(np.int32)
                target = target.cpu().data.numpy()

                # Restrict the argmax to the part labels belonging to each shape's category.
                for i in range(cur_batch_size):
                    cat = seg_label_to_cat[target[i, 0]]
                    logits = cur_pred_val_logits[i, :, :]
                    cur_pred_val[i, :] = np.argmax(logits[:, seg_classes[cat]], 1) + seg_classes[cat][0]
                correct = np.sum(cur_pred_val == target)
                total_correct += correct
                total_seen += (cur_batch_size * NUM_POINT)

                for l in range(num_part):
                    total_seen_class[l] += np.sum(target == l)
                    total_correct_class[l] += (np.sum((cur_pred_val == l) & (target == l)))

                for i in range(cur_batch_size):
                    segp = cur_pred_val[i, :]
                    segl = target[i, :]
                    cat = seg_label_to_cat[segl[0]]
                    part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
                    for l in seg_classes[cat]:
                        if (np.sum(segl == l) == 0) and (
                                np.sum(segp == l) == 0):  # part is not present, no prediction as well
                            part_ious[l - seg_classes[cat][0]] = 1.0
                        else:
                            part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float(
                                np.sum((segl == l) | (segp == l)))
                    shape_ious[cat].append(np.mean(part_ious))
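
            # Instance-average mIoU averages over every test shape; class-average mIoU
            # first averages within each category, then across the 16 categories.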
            all_shape_ious = []
            for cat in shape_ious.keys():
                for iou in shape_ious[cat]:
                    all_shape_ious.append(iou)
                shape_ious[cat] = np.mean(shape_ious[cat])
            mean_shape_ious = np.mean(list(shape_ious.values()))
            test_metrics['accuracy'] = total_correct / float(total_seen)
            test_metrics['class_avg_accuracy'] = np.mean(
                np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float64))
            for cat in sorted(shape_ious.keys()):
                logger.info('eval mIoU of %s %f' % (cat + ' ' * (14 - len(cat)), shape_ious[cat]))
            test_metrics['class_avg_iou'] = mean_shape_ious
            test_metrics['instance_avg_iou'] = np.mean(all_shape_ious)

        logger.info('Epoch %d test Accuracy: %f  Class avg mIOU: %f  Instance avg mIOU: %f' % (
            epoch + 1, test_metrics['accuracy'], test_metrics['class_avg_iou'], test_metrics['instance_avg_iou']))
        if test_metrics['instance_avg_iou'] >= best_instance_avg_iou:
            logger.info('Save model...')
            savepath = 'best_model.pth'
            logger.info('Saving at %s' % savepath)
            state = {
                'epoch': epoch,
                'train_acc': train_instance_acc,
                'test_acc': test_metrics['accuracy'],
                'class_avg_iou': test_metrics['class_avg_iou'],
                'instance_avg_iou': test_metrics['instance_avg_iou'],
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            torch.save(state, savepath)
            logger.info('Saving model....')

        if test_metrics['accuracy'] > best_acc:
            best_acc = test_metrics['accuracy']
        if test_metrics['class_avg_iou'] > best_class_avg_iou:
            best_class_avg_iou = test_metrics['class_avg_iou']
        if test_metrics['instance_avg_iou'] > best_instance_avg_iou:
            best_instance_avg_iou = test_metrics['instance_avg_iou']

        logger.info('Best accuracy is: %.5f' % best_acc)
        logger.info('Best class avg mIOU is: %.5f' % best_class_avg_iou)
        logger.info('Best instance avg mIOU is: %.5f' % best_instance_avg_iou)
        global_epoch += 1


if __name__ == '__main__':
    main()