1 Star 0 Fork 0

le-cheng/leNet_ptf_gat

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train_partseg.py 11.15 KB
一键复制 编辑 原始数据 按行查看 历史
le-cheng 提交于 2022-03-10 10:28 . first commit
"""
Author: Benny
Date: Nov 2019
"""
import datetime
import importlib
import logging
import os
import shutil
from pathlib import Path
import hydra
import numpy as np
import torch
from omegaconf import OmegaConf
from torch.optim.lr_scheduler import MultiStepLR
from tqdm import tqdm
import provider
from dataset import PartNormalDataset
# ShapeNet-part object category -> list of the global part labels (0..49) it owns.
# Each category's labels form a contiguous range; evaluation code relies on that.
seg_classes = {
    'Earphone': [16, 17, 18],
    'Motorbike': [30, 31, 32, 33, 34, 35],
    'Rocket': [41, 42, 43],
    'Car': [8, 9, 10, 11],
    'Laptop': [28, 29],
    'Cap': [6, 7],
    'Skateboard': [44, 45, 46],
    'Mug': [36, 37],
    'Guitar': [19, 20, 21],
    'Bag': [4, 5],
    'Lamp': [24, 25, 26, 27],
    'Table': [47, 48, 49],
    'Airplane': [0, 1, 2, 3],
    'Pistol': [38, 39, 40],
    'Chair': [12, 13, 14, 15],
    'Knife': [22, 23],
}
# Inverse lookup: part label -> category name, e.g. {0: 'Airplane', ..., 49: 'Table'}.
seg_label_to_cat = {
    part_label: category
    for category, part_labels in seg_classes.items()
    for part_label in part_labels
}
def inplace_relu(m):
    """Turn on in-place mode for a module whose class name contains ``'ReLU'``.

    Meant to be passed to ``model.apply``. Modules whose class name does not
    contain ``'ReLU'`` are left untouched.
    """
    if 'ReLU' in m.__class__.__name__:
        m.inplace = True
def to_categorical(y, num_classes):
    """One-hot encode the integer tensor ``y`` as a float tensor.

    The result has shape ``y.shape + (num_classes,)`` (rows of an identity
    matrix picked by ``y``'s values). It is moved to the GPU with ``.cuda()``
    only when ``y`` itself is a CUDA tensor.
    """
    indices = y.cpu().data.numpy()
    one_hot = torch.eye(num_classes)[indices,]
    if not y.is_cuda:
        return one_hot
    return one_hot.cuda()
def _resume_from_checkpoint(model, logger, path='best_model.pth'):
    """Try to resume from a checkpoint previously written by ``main``.

    On success the model weights in ``path`` are loaded into ``model`` in place.

    Args:
        model: The freshly constructed segmentation network.
        logger: Logger used to report whether resuming worked.
        path: Checkpoint file, relative to the current working directory.

    Returns:
        ``(start_epoch, best_acc, best_class_avg_iou, best_inctance_avg_iou)``;
        ``(0, 0, 0, 0)`` when there is no usable checkpoint (train from scratch).
    """
    try:
        checkpoint = torch.load(path)
        # 'epoch' is the index of the epoch that had already finished when the
        # checkpoint was saved, so training continues with the next one.
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model_state_dict'])
    except FileNotFoundError:
        logger.info('No existing model, starting training from scratch...')
        return 0, 0, 0, 0
    except Exception as err:  # unreadable/incompatible checkpoint: stay best-effort, but say why
        logger.warning('Could not resume from %s (%s: %s); starting training from scratch...',
                       path, type(err).__name__, err)
        return 0, 0, 0, 0
    logger.info('Use pretrain model')
    # Restore the stored model's metrics as the best-so-far values; otherwise the
    # first resumed epoch would overwrite best_model.pth even if it scored worse.
    return (
        start_epoch,
        float(checkpoint.get('test_acc', 0)),
        float(checkpoint.get('class_avg_iou', 0)),
        float(checkpoint.get('inctance_avg_iou', 0)),
    )


def _shape_miou(segp, segl, parts):
    """Mean IoU of one shape over the part labels ``parts``.

    ``segp`` / ``segl`` are the per-point predicted / ground-truth labels of the
    shape. A part that is absent from both prediction and ground truth counts as
    IoU 1.0.
    """
    part_ious = []
    for part in parts:
        in_gt = segl == part
        in_pred = segp == part
        union = np.sum(in_gt | in_pred)
        if union == 0:  # part not present and not predicted
            part_ious.append(1.0)
        else:
            part_ious.append(np.sum(in_gt & in_pred) / float(union))
    return np.mean(part_ious)


def _evaluate(model, test_loader, num_category, num_part, logger):
    """Evaluate ``model`` on ``test_loader`` and return a dict of test metrics.

    Per shape, the argmax is restricted to the parts of the shape's category
    (looked up from its first ground-truth label), as in the original loop.
    Logs the per-category mIoU and returns the keys ``'accuracy'``,
    ``'class_avg_accuracy'``, ``'class_avg_iou'`` and ``'inctance_avg_iou'``.
    """
    total_correct = 0
    total_seen = 0
    total_seen_class = [0 for _ in range(num_part)]
    total_correct_class = [0 for _ in range(num_part)]
    shape_ious = {cat: [] for cat in seg_classes.keys()}

    model.eval()
    with torch.no_grad():
        for points, label, target in tqdm(test_loader, total=len(test_loader), smoothing=0.9):
            cur_batch_size, num_point, _ = points.size()
            points, label = points.float().cuda(), label.long().cuda()
            one_hot = to_categorical(label, num_category).repeat(1, num_point, 1)
            # Model output is indexed as (batch, point, part) below.
            logits = model(torch.cat([points, one_hot], -1)).cpu().data.numpy()
            target = target.long().numpy()

            pred = np.zeros((cur_batch_size, num_point)).astype(np.int32)
            for i in range(cur_batch_size):
                parts = seg_classes[seg_label_to_cat[target[i, 0]]]
                # Only compare logits of this category's parts, then map back to global ids.
                pred[i, :] = np.argmax(logits[i, :, :][:, parts], 1) + parts[0]

            total_correct += np.sum(pred == target)
            total_seen += cur_batch_size * num_point
            for part in range(num_part):
                total_seen_class[part] += np.sum(target == part)
                total_correct_class[part] += np.sum((pred == part) & (target == part))

            for segp, segl in zip(pred, target):
                cat = seg_label_to_cat[segl[0]]
                shape_ious[cat].append(_shape_miou(segp, segl, seg_classes[cat]))

    all_shape_ious = []
    for cat in shape_ious:
        all_shape_ious.extend(shape_ious[cat])
        shape_ious[cat] = np.mean(shape_ious[cat])

    test_metrics = {
        'accuracy': total_correct / float(total_seen),
        # np.float was removed in NumPy 1.24; np.float64 is the dtype it aliased.
        'class_avg_accuracy': np.mean(
            np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float64)),
    }
    for cat in sorted(shape_ious.keys()):
        logger.info('eval mIoU of %s %f' % (cat + ' ' * (14 - len(cat)), shape_ious[cat]))
    test_metrics['class_avg_iou'] = np.mean(list(shape_ious.values()))
    test_metrics['inctance_avg_iou'] = np.mean(all_shape_ious)
    return test_metrics


@hydra.main(config_path='config', config_name='partseg')
def main(cfg):
    """Train and evaluate a part-segmentation network on ShapeNet part data.

    Hydra entry point (config ``config/partseg``). Every epoch trains on the
    ``'trainval'`` split and evaluates on ``'test'``. Whenever the instance-average
    mIoU is at least as good as the best so far, ``best_model.pth`` is written to
    the current working directory. An existing ``best_model.pth`` there is used
    to resume.

    Config keys read here: ``gpu``, ``num_point``, ``normal``, ``batch_size``,
    ``num_workers``, ``model.name``, ``optimizer``, ``learning_rate``,
    ``weight_decay`` (Adam only), ``lr_decay``, ``step_size`` and ``epoch``.
    ``input_dim`` and ``num_class`` are written into ``cfg`` before the model is built.
    """
    OmegaConf.set_struct(cfg, False)

    # HYPER PARAMETER
    os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu)
    logger = logging.getLogger(__name__)
    print(OmegaConf.to_yaml(cfg))
    print('\n')

    # DATA LOADING
    root = hydra.utils.to_absolute_path('data/shapenetcore_partanno_segmentation_benchmark_v0_normal')
    train_dataset = PartNormalDataset(root=root, npoints=cfg.num_point, split='trainval', normal_channel=cfg.normal)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True,
                                               num_workers=cfg.num_workers, drop_last=True)
    test_dataset = PartNormalDataset(root=root, npoints=cfg.num_point, split='test', normal_channel=cfg.normal)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False,
                                              num_workers=cfg.num_workers)

    # MODEL LOADING
    # Input features = xyz (+ normals) concatenated with a 16-way one-hot object category.
    cfg.input_dim = (6 if cfg.normal else 3) + 16
    cfg.num_class = 50
    num_category = 16
    num_part = cfg.num_class
    shutil.copy(hydra.utils.to_absolute_path('models/{}/model.py'.format(cfg.model.name)), '.')
    model = getattr(importlib.import_module('models.{}.model'.format(cfg.model.name)), 'PointTransformerSeg')(cfg).cuda()
    criterion = torch.nn.CrossEntropyLoss()
    start_epoch, best_acc, best_class_avg_iou, best_inctance_avg_iou = _resume_from_checkpoint(model, logger)

    if cfg.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=cfg.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=cfg.weight_decay
        )
    else:
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=cfg.learning_rate,
            momentum=0.9,
            weight_decay=0.0001
        )
    # The learning rate follows a manual step schedule applied at the start of
    # every epoch (below); no torch lr_scheduler is used.

    def bn_momentum_adjust(m, momentum):
        if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d):
            m.momentum = momentum

    LEARNING_RATE_CLIP = 1e-5
    MOMENTUM_ORIGINAL = 0.1
    MOMENTUM_DECCAY = 0.5
    MOMENTUM_DECCAY_STEP = cfg.step_size
    global_epoch = 0

    for epoch in range(start_epoch, cfg.epoch):
        logger.info('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, cfg.epoch))

        # Adjust learning rate and BN momentum (step decay every cfg.step_size epochs, with floors).
        lr = max(cfg.learning_rate * (cfg.lr_decay ** (epoch // cfg.step_size)), LEARNING_RATE_CLIP)
        logger.info('Learning rate:%f' % lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        momentum = max(MOMENTUM_ORIGINAL * (MOMENTUM_DECCAY ** (epoch // MOMENTUM_DECCAY_STEP)), 0.01)
        print('BN momentum updated to: %f' % momentum)
        model.apply(lambda m: bn_momentum_adjust(m, momentum))
        model.train()

        # Learn one epoch.
        mean_correct = []
        for points, label, target in tqdm(train_loader, total=len(train_loader), smoothing=0.9):
            # Augment xyz only (random scale + shift); assumes points is (batch, point, channel).
            points = points.data.numpy()
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            points = torch.Tensor(points)
            points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda()
            optimizer.zero_grad()
            # presumably label is (batch, 1) so the one-hot repeats to (batch, point, 16) — confirm in dataset
            one_hot = to_categorical(label, num_category).repeat(1, points.shape[1], 1)
            seg_pred = model(torch.cat([points, one_hot], -1))
            seg_pred = seg_pred.contiguous().view(-1, num_part)
            target = target.view(-1)
            pred_choice = seg_pred.data.max(1)[1]
            correct = pred_choice.eq(target.data).cpu().sum()
            # Normalise by the points actually in this batch, not by config values.
            mean_correct.append(correct.item() / target.numel())
            loss = criterion(seg_pred, target)
            loss.backward()
            optimizer.step()
        train_instance_acc = np.mean(mean_correct)
        logger.info('Train accuracy is: %.5f' % train_instance_acc)

        test_metrics = _evaluate(model, test_loader, num_category, num_part, logger)
        logger.info('Epoch %d test Accuracy: %f Class avg mIOU: %f Inctance avg mIOU: %f' % (
            epoch + 1, test_metrics['accuracy'], test_metrics['class_avg_iou'], test_metrics['inctance_avg_iou']))

        if test_metrics['inctance_avg_iou'] >= best_inctance_avg_iou:
            logger.info('Save model...')
            savepath = 'best_model.pth'
            logger.info('Saving at %s' % savepath)
            # Store metrics as plain Python floats (not NumPy scalars) so the file
            # can be read back by torch.load with weights_only=True.
            state = {
                'epoch': epoch,
                'train_acc': float(train_instance_acc),
                'test_acc': float(test_metrics['accuracy']),
                'class_avg_iou': float(test_metrics['class_avg_iou']),
                'inctance_avg_iou': float(test_metrics['inctance_avg_iou']),
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            torch.save(state, savepath)
            logger.info('Saving model....')

        best_acc = max(best_acc, test_metrics['accuracy'])
        best_class_avg_iou = max(best_class_avg_iou, test_metrics['class_avg_iou'])
        best_inctance_avg_iou = max(best_inctance_avg_iou, test_metrics['inctance_avg_iou'])
        logger.info('Best accuracy is: %.5f' % best_acc)
        logger.info('Best class avg mIOU is: %.5f' % best_class_avg_iou)
        logger.info('Best inctance avg mIOU is: %.5f' % best_inctance_avg_iou)
        global_epoch += 1


if __name__ == '__main__':
    main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/chengleniubi/leNet_ptf_gat.git
git@gitee.com:chengleniubi/leNet_ptf_gat.git
chengleniubi
leNet_ptf_gat
leNet_ptf_gat
master

搜索帮助