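'''Training and rendering entry script for the DVGO-family voxel-grid models.

It parses a config, builds the scene representation (dense, contracted, or
multiplane voxel grid), optimizes it in a coarse and a fine stage, and
optionally renders the train/test splits or a novel-view video.
'''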
import os, sys, copy, glob, json, time, random, argparse
from shutil import copyfile
from tqdm import tqdm, trange
import mmcv
import imageio
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from lib import utils, dvgo, dcvgo, dmpigo
from lib.load_data import load_data
from torch_efficient_distloss import flatten_eff_distloss
from loguru import logger
def config_parser():
'''Define command line arguments
'''
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--config', required=True,
help='config file path')
parser.add_argument("--seed", type=int, default=777,
help='Random seed')
parser.add_argument("--no_reload", action='store_true',
help='do not reload weights from saved ckpt')
parser.add_argument("--no_reload_optimizer", action='store_true',
help='do not reload optimizer state from saved ckpt')
parser.add_argument("--ft_path", type=str, default='',
help='specific weights npy file to reload for coarse network')
parser.add_argument("--export_bbox_and_cams_only", type=str, default='',
help='export scene bbox and camera poses for debugging and 3d visualization')
parser.add_argument("--export_coarse_only", type=str, default='')
# testing options
parser.add_argument("--render_only", action='store_true',
help='do not optimize, reload weights and render out render_poses path')
parser.add_argument("--render_test", action='store_true')
parser.add_argument("--render_train", action='store_true')
parser.add_argument("--render_video", action='store_true')
parser.add_argument("--render_video_flipy", action='store_true')
parser.add_argument("--render_video_rot90", default=0, type=int)
parser.add_argument("--render_video_factor", type=float, default=0,
help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
parser.add_argument("--dump_images", action='store_true')
parser.add_argument("--eval_ssim", action='store_true')
parser.add_argument("--eval_lpips_alex", action='store_true')
parser.add_argument("--eval_lpips_vgg", action='store_true')
# logging/saving options
parser.add_argument("--i_logger.info", type=int, default=500,
help='frequency of console logger.infoout and metric loggin')
parser.add_argument("--i_weights", type=int, default=100000,
help='frequency of weight ckpt saving')
return parser
@torch.no_grad()
def render_viewpoints(model, render_poses, HW, Ks, ndc, render_kwargs,
gt_imgs=None, savedir=None, dump_images=False,
render_factor=0, render_video_flipy=False, render_video_rot90=0,
eval_ssim=False, eval_lpips_alex=False, eval_lpips_vgg=False):
'''Render images for the given viewpoints; run evaluation if gt given.
'''
assert len(render_poses) == len(HW) and len(HW) == len(Ks)
if render_factor!=0:
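        # Downscale the image resolution and camera intrinsics by render_factor for a fast preview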
HW = np.copy(HW)
Ks = np.copy(Ks)
HW = (HW/render_factor).astype(int)
Ks[:, :2, :3] /= render_factor
rgbs = []
depths = []
bgmaps = []
psnrs = []
ssims = []
lpips_alex = []
lpips_vgg = []
for i, c2w in enumerate(tqdm(render_poses)):
H, W = HW[i]
K = Ks[i]
c2w = torch.Tensor(c2w)
rays_o, rays_d, viewdirs = dvgo.get_rays_of_a_view(
H, W, K, c2w, ndc, inverse_y=render_kwargs['inverse_y'],
flip_x=cfg.data.flip_x, flip_y=cfg.data.flip_y)
keys = ['rgb_marched', 'depth', 'alphainv_last']
rays_o = rays_o.flatten(0,-2)
rays_d = rays_d.flatten(0,-2)
viewdirs = viewdirs.flatten(0,-2)
render_result_chunks = [
{k: v for k, v in model(ro, rd, vd, **render_kwargs).items() if k in keys}
for ro, rd, vd in zip(rays_o.split(8192, 0), rays_d.split(8192, 0), viewdirs.split(8192, 0))
]
render_result = {
k: torch.cat([ret[k] for ret in render_result_chunks]).reshape(H,W,-1)
for k in render_result_chunks[0].keys()
}
rgb = render_result['rgb_marched'].cpu().numpy()
depth = render_result['depth'].cpu().numpy()
bgmap = render_result['alphainv_last'].cpu().numpy()
rgbs.append(rgb)
depths.append(depth)
bgmaps.append(bgmap)
if i==0:
            logger.info(f'Testing {rgb.shape}')
if gt_imgs is not None and render_factor==0:
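            # PSNR = -10 * log10(MSE) against the ground-truth image (pixel values in [0, 1])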
p = -10. * np.log10(np.mean(np.square(rgb - gt_imgs[i])))
psnrs.append(p)
if eval_ssim:
ssims.append(utils.rgb_ssim(rgb, gt_imgs[i], max_val=1))
if eval_lpips_alex:
lpips_alex.append(utils.rgb_lpips(rgb, gt_imgs[i], net_name='alex', device=c2w.device))
if eval_lpips_vgg:
lpips_vgg.append(utils.rgb_lpips(rgb, gt_imgs[i], net_name='vgg', device=c2w.device))
if len(psnrs):
        logger.info(f'Testing psnr {np.mean(psnrs)} (avg)')
        if eval_ssim: logger.info(f'Testing ssim {np.mean(ssims)} (avg)')
        if eval_lpips_vgg: logger.info(f'Testing lpips (vgg) {np.mean(lpips_vgg)} (avg)')
        if eval_lpips_alex: logger.info(f'Testing lpips (alex) {np.mean(lpips_alex)} (avg)')
if render_video_flipy:
for i in range(len(rgbs)):
rgbs[i] = np.flip(rgbs[i], axis=0)
depths[i] = np.flip(depths[i], axis=0)
bgmaps[i] = np.flip(bgmaps[i], axis=0)
if render_video_rot90 != 0:
for i in range(len(rgbs)):
rgbs[i] = np.rot90(rgbs[i], k=render_video_rot90, axes=(0,1))
depths[i] = np.rot90(depths[i], k=render_video_rot90, axes=(0,1))
bgmaps[i] = np.rot90(bgmaps[i], k=render_video_rot90, axes=(0,1))
if savedir is not None and dump_images:
for i in trange(len(rgbs)):
rgb8 = utils.to8b(rgbs[i])
filename = os.path.join(savedir, '{:03d}.png'.format(i))
imageio.imwrite(filename, rgb8)
rgbs = np.array(rgbs)
depths = np.array(depths)
bgmaps = np.array(bgmaps)
return rgbs, depths, bgmaps
def seed_everything():
    '''Seed everything for better reproducibility.
    (some pytorch operations are non-deterministic, e.g. the backward pass of grid_sample)
    '''
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
def load_everything(args, cfg):
'''Load images / poses / camera settings / data split.
'''
data_dict = load_data(cfg.data)
    # remove useless fields
kept_keys = {
'hwf', 'HW', 'Ks', 'near', 'far', 'near_clip',
'i_train', 'i_val', 'i_test', 'irregular_shape',
'poses', 'render_poses', 'images'}
for k in list(data_dict.keys()):
if k not in kept_keys:
data_dict.pop(k)
# construct data tensor
if data_dict['irregular_shape']:
data_dict['images'] = [torch.FloatTensor(im, device='cpu') for im in data_dict['images']]
else:
data_dict['images'] = torch.FloatTensor(data_dict['images'], device='cpu')
data_dict['poses'] = torch.Tensor(data_dict['poses'])
return data_dict
def _compute_bbox_by_cam_frustrm_bounded(cfg, HW, Ks, poses, i_train, near, far):
xyz_min = torch.Tensor([np.inf, np.inf, np.inf])
xyz_max = -xyz_min
for (H, W), K, c2w in zip(HW[i_train], Ks[i_train], poses[i_train]):
rays_o, rays_d, viewdirs = dvgo.get_rays_of_a_view(
H=H, W=W, K=K, c2w=c2w,
ndc=cfg.data.ndc, inverse_y=cfg.data.inverse_y,
flip_x=cfg.data.flip_x, flip_y=cfg.data.flip_y)
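        # Grow the bbox to cover the near- and far-plane sample points of every training-view frustum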
if cfg.data.ndc:
pts_nf = torch.stack([rays_o+rays_d*near, rays_o+rays_d*far])
else:
pts_nf = torch.stack([rays_o+viewdirs*near, rays_o+viewdirs*far])
xyz_min = torch.minimum(xyz_min, pts_nf.amin((0,1,2)))
xyz_max = torch.maximum(xyz_max, pts_nf.amax((0,1,2)))
return xyz_min, xyz_max
def _compute_bbox_by_cam_frustrm_unbounded(cfg, HW, Ks, poses, i_train, near_clip):
    # Find the tightest cube that covers all camera centers
xyz_min = torch.Tensor([np.inf, np.inf, np.inf])
xyz_max = -xyz_min
for (H, W), K, c2w in zip(HW[i_train], Ks[i_train], poses[i_train]):
rays_o, rays_d, viewdirs = dvgo.get_rays_of_a_view(
H=H, W=W, K=K, c2w=c2w,
ndc=cfg.data.ndc, inverse_y=cfg.data.inverse_y,
flip_x=cfg.data.flip_x, flip_y=cfg.data.flip_y)
pts = rays_o + rays_d * near_clip
xyz_min = torch.minimum(xyz_min, pts.amin((0,1)))
xyz_max = torch.maximum(xyz_max, pts.amax((0,1)))
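    # Expand to a cube centered on the camera hull, scaled by unbounded_inner_r (the inner region of the contracted grid)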
center = (xyz_min + xyz_max) * 0.5
radius = (center - xyz_min).max() * cfg.data.unbounded_inner_r
xyz_min = center - radius
xyz_max = center + radius
return xyz_min, xyz_max
def compute_bbox_by_cam_frustrm(args, cfg, HW, Ks, poses, i_train, near, far, **kwargs):
logger.info('compute_bbox_by_cam_frustrm: start')
if cfg.data.unbounded_inward:
xyz_min, xyz_max = _compute_bbox_by_cam_frustrm_unbounded(
cfg, HW, Ks, poses, i_train, kwargs.get('near_clip', None))
else:
xyz_min, xyz_max = _compute_bbox_by_cam_frustrm_bounded(
cfg, HW, Ks, poses, i_train, near, far)
    logger.info(f'compute_bbox_by_cam_frustrm: xyz_min {xyz_min}')
    logger.info(f'compute_bbox_by_cam_frustrm: xyz_max {xyz_max}')
logger.info('compute_bbox_by_cam_frustrm: finish')
return xyz_min, xyz_max
@torch.no_grad()
def compute_bbox_by_coarse_geo(model_class, model_path, thres):
logger.info('compute_bbox_by_coarse_geo: start')
eps_time = time.time()
model = utils.load_model(model_class, model_path)
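    # Densely query the coarse density grid; keep voxels whose alpha exceeds thres and take their tight bounding box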
interp = torch.stack(torch.meshgrid(
torch.linspace(0, 1, model.world_size[0]),
torch.linspace(0, 1, model.world_size[1]),
torch.linspace(0, 1, model.world_size[2]),
), -1)
dense_xyz = model.xyz_min * (1-interp) + model.xyz_max * interp
density = model.density(dense_xyz)
alpha = model.activate_density(density)
mask = (alpha > thres)
active_xyz = dense_xyz[mask]
xyz_min = active_xyz.amin(0)
xyz_max = active_xyz.amax(0)
    logger.info(f'compute_bbox_by_coarse_geo: xyz_min {xyz_min}')
    logger.info(f'compute_bbox_by_coarse_geo: xyz_max {xyz_max}')
eps_time = time.time() - eps_time
    logger.info(f'compute_bbox_by_coarse_geo: finish (eps time: {eps_time} secs)')
return xyz_min, xyz_max
def create_new_model(cfg, cfg_model, cfg_train, xyz_min, xyz_max, stage, coarse_ckpt_path):
model_kwargs = copy.deepcopy(cfg_model)
num_voxels = model_kwargs.pop('num_voxels')
if len(cfg_train.pg_scale):
num_voxels = int(num_voxels / (2**len(cfg_train.pg_scale))) # largest voxel size, smallest grid size
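        # the grid is progressively upscaled back to the target num_voxels at the steps listed in pg_scale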
if cfg.data.ndc:
logger.info(f'scene_rep_reconstruction ({stage}): \033[96muse multiplane images\033[0m') # multiplane images
model = dmpigo.DirectMPIGO(
xyz_min=xyz_min, xyz_max=xyz_max,
num_voxels=num_voxels,
**model_kwargs)
elif cfg.data.unbounded_inward:
        logger.info(f'scene_rep_reconstruction ({stage}): \033[96muse contracted voxel grid (covering unbounded)\033[0m') # contracted voxel grid
model = dcvgo.DirectContractedVoxGO(
xyz_min=xyz_min, xyz_max=xyz_max,
num_voxels=num_voxels,
**model_kwargs)
else:
logger.info(f'scene_rep_reconstruction ({stage}): \033[96muse dense voxel grid\033[0m') # dense voxel grid
model = dvgo.DirectVoxGO(
xyz_min=xyz_min, xyz_max=xyz_max,
num_voxels=num_voxels,
mask_cache_path=coarse_ckpt_path,
**model_kwargs)
model = model.to(device)
optimizer = utils.create_optimizer_or_freeze_model(model, cfg_train, global_step=0)
return model, optimizer
def load_existed_model(args, cfg, cfg_train, reload_ckpt_path):
if cfg.data.ndc:
model_class = dmpigo.DirectMPIGO
elif cfg.data.unbounded_inward:
model_class = dcvgo.DirectContractedVoxGO
else:
model_class = dvgo.DirectVoxGO
model = utils.load_model(model_class, reload_ckpt_path).to(device)
optimizer = utils.create_optimizer_or_freeze_model(model, cfg_train, global_step=0)
model, optimizer, start = utils.load_checkpoint(
model, optimizer, reload_ckpt_path, args.no_reload_optimizer)
return model, optimizer, start
def scene_rep_reconstruction(args, cfg, cfg_model, cfg_train, xyz_min, xyz_max, data_dict, stage, coarse_ckpt_path=None):
    """Run one optimization stage (coarse or fine) of scene representation reconstruction.

    Args:
        args: command line arguments
        cfg: config file
        cfg_model: model config for the coarse or fine model
        cfg_train: training config for the coarse or fine model
        xyz_min, xyz_max: scene bounding box
        data_dict: loaded images / poses / camera settings / data split
        stage: 'coarse' or 'fine'
        coarse_ckpt_path: coarse checkpoint reused as mask_cache_path when building a DirectVoxGO model
    """
# init
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if abs(cfg_model.world_bound_scale - 1) > 1e-9:
xyz_shift = (xyz_max - xyz_min) * (cfg_model.world_bound_scale - 1) / 2
xyz_min -= xyz_shift
xyz_max += xyz_shift
HW, Ks, near, far, i_train, i_val, i_test, poses, render_poses, images = [
data_dict[k] for k in [
'HW', 'Ks', 'near', 'far', 'i_train', 'i_val', 'i_test', 'poses', 'render_poses', 'images'
]
]
# find whether there is existing checkpoint path
last_ckpt_path = os.path.join(cfg.basedir, cfg.expname, f'{stage}_last.tar')
if args.no_reload:
reload_ckpt_path = None
elif args.ft_path:
reload_ckpt_path = args.ft_path
elif os.path.isfile(last_ckpt_path):
reload_ckpt_path = last_ckpt_path
else:
reload_ckpt_path = None
# init model and optimizer
if reload_ckpt_path is None:
logger.info(f'scene_rep_reconstruction ({stage}): train from scratch')
model, optimizer = create_new_model(cfg, cfg_model, cfg_train, xyz_min, xyz_max, stage, coarse_ckpt_path)
start = 0
if cfg_model.maskout_near_cam_vox:
model.maskout_near_cam_vox(poses[i_train,:3,3], near)
else:
logger.info(f'scene_rep_reconstruction ({stage}): reload from {reload_ckpt_path}')
model, optimizer, start = load_existed_model(args, cfg, cfg_train, reload_ckpt_path)
# init rendering setup
render_kwargs = {
'near': data_dict['near'],
'far': data_dict['far'],
'bg': 1 if cfg.data.white_bkgd else 0,
'rand_bkgd': cfg.data.rand_bkgd,
'stepsize': cfg_model.stepsize,
'inverse_y': cfg.data.inverse_y,
'flip_x': cfg.data.flip_x,
'flip_y': cfg.data.flip_y,
}
# init batch rays sampler
def gather_training_rays():
if data_dict['irregular_shape']:
rgb_tr_ori = [images[i].to('cpu' if cfg.data.load2gpu_on_the_fly else device) for i in i_train]
else:
rgb_tr_ori = images[i_train].to('cpu' if cfg.data.load2gpu_on_the_fly else device) # rgb training images
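        # Three ray samplers: 'in_maskcache' pre-filters rays against the coarse occupancy cache,
        # 'flatten' keeps every ray in one flat buffer, and 'random' keeps the per-image (N, H, W)
        # layout so pixels can be drawn by (image, row, col) indices.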
if cfg_train.ray_sampler == 'in_maskcache':
rgb_tr, rays_o_tr, rays_d_tr, viewdirs_tr, imsz = dvgo.get_training_rays_in_maskcache_sampling(
rgb_tr_ori=rgb_tr_ori,
train_poses=poses[i_train],
HW=HW[i_train], Ks=Ks[i_train],
ndc=cfg.data.ndc, inverse_y=cfg.data.inverse_y,
flip_x=cfg.data.flip_x, flip_y=cfg.data.flip_y,
model=model, render_kwargs=render_kwargs)
elif cfg_train.ray_sampler == 'flatten':
rgb_tr, rays_o_tr, rays_d_tr, viewdirs_tr, imsz = dvgo.get_training_rays_flatten(
rgb_tr_ori=rgb_tr_ori,
train_poses=poses[i_train],
HW=HW[i_train], Ks=Ks[i_train], ndc=cfg.data.ndc, inverse_y=cfg.data.inverse_y,
flip_x=cfg.data.flip_x, flip_y=cfg.data.flip_y)
else:
rgb_tr, rays_o_tr, rays_d_tr, viewdirs_tr, imsz = dvgo.get_training_rays(
rgb_tr=rgb_tr_ori,
train_poses=poses[i_train],
HW=HW[i_train], Ks=Ks[i_train], ndc=cfg.data.ndc, inverse_y=cfg.data.inverse_y,
flip_x=cfg.data.flip_x, flip_y=cfg.data.flip_y)
index_generator = dvgo.batch_indices_generator(len(rgb_tr), cfg_train.N_rand) # N_rand is batch size
batch_index_sampler = lambda: next(index_generator)
return rgb_tr, rays_o_tr, rays_d_tr, viewdirs_tr, imsz, batch_index_sampler
rgb_tr, rays_o_tr, rays_d_tr, viewdirs_tr, imsz, batch_index_sampler = gather_training_rays()
# view-count-based learning rate
if cfg_train.pervoxel_lr:
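        # Scale each voxel's learning rate by how many training views see it, and mask out voxels seen by too few views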
def per_voxel_init():
cnt = model.voxel_count_views(
rays_o_tr=rays_o_tr, rays_d_tr=rays_d_tr, imsz=imsz, near=near, far=far,
stepsize=cfg_model.stepsize, downrate=cfg_train.pervoxel_lr_downrate,
irregular_shape=data_dict['irregular_shape'])
optimizer.set_pervoxel_lr(cnt)
model.mask_cache.mask[cnt.squeeze() <= 2] = False
per_voxel_init()
if cfg_train.maskout_lt_nviews > 0:
model.update_occupancy_cache_lt_nviews(
rays_o_tr, rays_d_tr, imsz, render_kwargs, cfg_train.maskout_lt_nviews)
    # GOGO: main optimization loop
torch.cuda.empty_cache()
psnr_lst = []
time0 = time.time()
global_step = -1
for global_step in trange(1+start, 1+cfg_train.N_iters):
# renew occupancy grid
if model.mask_cache is not None and (global_step + 500) % 1000 == 0:
model.update_occupancy_cache()
        # progressive scaling checkpoint: upsample the voxel grid at the scheduled step and re-create the optimizer
if global_step in cfg_train.pg_scale:
n_rest_scales = len(cfg_train.pg_scale)-cfg_train.pg_scale.index(global_step)-1
cur_voxels = int(cfg_model.num_voxels / (2**n_rest_scales)) # n_rest_scales from large to small
if isinstance(model, (dvgo.DirectVoxGO, dcvgo.DirectContractedVoxGO)):
model.scale_volume_grid(cur_voxels)
elif isinstance(model, dmpigo.DirectMPIGO):
model.scale_volume_grid(cur_voxels, model.mpi_depth)
else:
raise NotImplementedError
optimizer = utils.create_optimizer_or_freeze_model(model, cfg_train, global_step=0)
model.act_shift -= cfg_train.decay_after_scale
torch.cuda.empty_cache()
# random sample rays
if cfg_train.ray_sampler in ['flatten', 'in_maskcache']:
sel_i = batch_index_sampler()
target = rgb_tr[sel_i]
rays_o = rays_o_tr[sel_i]
rays_d = rays_d_tr[sel_i]
viewdirs = viewdirs_tr[sel_i]
elif cfg_train.ray_sampler == 'random':
sel_b = torch.randint(rgb_tr.shape[0], [cfg_train.N_rand])
sel_r = torch.randint(rgb_tr.shape[1], [cfg_train.N_rand])
sel_c = torch.randint(rgb_tr.shape[2], [cfg_train.N_rand])
target = rgb_tr[sel_b, sel_r, sel_c]
rays_o = rays_o_tr[sel_b, sel_r, sel_c]
rays_d = rays_d_tr[sel_b, sel_r, sel_c]
viewdirs = viewdirs_tr[sel_b, sel_r, sel_c]
else:
raise NotImplementedError
if cfg.data.load2gpu_on_the_fly:
target = target.to(device)
rays_o = rays_o.to(device)
rays_d = rays_d.to(device)
viewdirs = viewdirs.to(device)
# volume rendering
render_result = model(
rays_o, rays_d, viewdirs,
global_step=global_step, is_train=True,
**render_kwargs)
# gradient descent step
optimizer.zero_grad(set_to_none=True)
loss = cfg_train.weight_main * F.mse_loss(render_result['rgb_marched'], target)
psnr = utils.mse2psnr(loss.detach())
if cfg_train.weight_entropy_last > 0:
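            # Entropy regularization on the accumulated background transmittance: push it towards 0 or 1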
pout = render_result['alphainv_last'].clamp(1e-6, 1-1e-6)
entropy_last_loss = -(pout*torch.log(pout) + (1-pout)*torch.log(1-pout)).mean()
loss += cfg_train.weight_entropy_last * entropy_last_loss
if cfg_train.weight_nearclip > 0:
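            # Suppress density in front of the near clip plane; (x - x.detach()) contributes zero to the
            # loss value but a unit gradient per masked density sample, so those densities are pushed down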
near_thres = data_dict['near_clip'] / model.scene_radius[0].item()
near_mask = (render_result['t'] < near_thres)
density = render_result['raw_density'][near_mask]
if len(density):
nearclip_loss = (density - density.detach()).sum()
loss += cfg_train.weight_nearclip * nearclip_loss
if cfg_train.weight_distortion > 0:
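            # Distortion loss on the sample weights along each ray (Mip-NeRF 360 style),
            # using the flattened efficient implementation from torch_efficient_distloss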
n_max = render_result['n_max']
s = render_result['s']
w = render_result['weights']
ray_id = render_result['ray_id']
loss_distortion = flatten_eff_distloss(w, s, 1/n_max, ray_id)
loss += cfg_train.weight_distortion * loss_distortion
if cfg_train.weight_rgbper > 0:
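            # Per-point color loss: each sampled point's raw RGB is compared to the target pixel,
            # weighted by its (detached) rendering weight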
rgbper = (render_result['raw_rgb'] - target[render_result['ray_id']]).pow(2).sum(-1)
rgbper_loss = (rgbper * render_result['weights'].detach()).sum() / len(rays_o)
loss += cfg_train.weight_rgbper * rgbper_loss
loss.backward()
if global_step<cfg_train.tv_before and global_step>cfg_train.tv_after and global_step%cfg_train.tv_every==0:
if cfg_train.weight_tv_density>0:
model.density_total_variation_add_grad(
cfg_train.weight_tv_density/len(rays_o), global_step<cfg_train.tv_dense_before)
if cfg_train.weight_tv_k0>0:
model.k0_total_variation_add_grad(
cfg_train.weight_tv_k0/len(rays_o), global_step<cfg_train.tv_dense_before)
optimizer.step()
psnr_lst.append(psnr.item())
# update lr
decay_steps = cfg_train.lrate_decay * 1000
decay_factor = 0.1 ** (1/decay_steps)
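        # per-step exponential decay: lr shrinks by a factor of 10 every lrate_decay*1000 steps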
for i_opt_g, param_group in enumerate(optimizer.param_groups):
param_group['lr'] = param_group['lr'] * decay_factor
# check log & save
        if global_step%args.i_print==0:
eps_time = time.time() - time0
eps_time_str = f'{eps_time//3600:02.0f}:{eps_time//60%60:02.0f}:{eps_time%60:02.0f}'
tqdm.write(f'scene_rep_reconstruction ({stage}): iter {global_step:6d} / '
f'Loss: {loss.item():.9f} / PSNR: {np.mean(psnr_lst):5.2f} / '
f'Eps: {eps_time_str}')
psnr_lst = []
if global_step%args.i_weights==0:
path = os.path.join(cfg.basedir, cfg.expname, f'{stage}_{global_step:06d}.tar')
torch.save({
'global_step': global_step,
'model_kwargs': model.get_kwargs(),
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, path)
            logger.info(f'scene_rep_reconstruction ({stage}): saved checkpoints at {path}')
if global_step != -1:
torch.save({
'global_step': global_step,
'model_kwargs': model.get_kwargs(),
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, last_ckpt_path)
        logger.info(f'scene_rep_reconstruction ({stage}): saved checkpoints at {last_ckpt_path}')
def train(args, cfg, data_dict):
# init
logger.info('train: start')
eps_time = time.time()
os.makedirs(os.path.join(cfg.basedir, cfg.expname), exist_ok=True)
with open(os.path.join(cfg.basedir, cfg.expname, 'args.txt'), 'w') as file:
for arg in sorted(vars(args)):
attr = getattr(args, arg)
file.write('{} = {}\n'.format(arg, attr))
cfg.dump(os.path.join(cfg.basedir, cfg.expname, 'config.py'))
# coarse geometry searching (only works for inward bounded scenes)
eps_coarse = time.time()
xyz_min_coarse, xyz_max_coarse = compute_bbox_by_cam_frustrm(args=args, cfg=cfg, **data_dict)
if cfg.coarse_train.N_iters > 0:
scene_rep_reconstruction(
args=args, cfg=cfg,
cfg_model=cfg.coarse_model_and_render, cfg_train=cfg.coarse_train,
xyz_min=xyz_min_coarse, xyz_max=xyz_max_coarse,
data_dict=data_dict, stage='coarse')
eps_coarse = time.time() - eps_coarse
eps_time_str = f'{eps_coarse//3600:02.0f}:{eps_coarse//60%60:02.0f}:{eps_coarse%60:02.0f}'
        logger.info(f'train: coarse geometry searching in {eps_time_str}')
coarse_ckpt_path = os.path.join(cfg.basedir, cfg.expname, f'coarse_last.tar')
else:
logger.info('train: skip coarse geometry searching')
coarse_ckpt_path = None
# fine detail reconstruction
eps_fine = time.time()
if cfg.coarse_train.N_iters == 0:
xyz_min_fine, xyz_max_fine = xyz_min_coarse.clone(), xyz_max_coarse.clone()
else:
xyz_min_fine, xyz_max_fine = compute_bbox_by_coarse_geo(
model_class=dvgo.DirectVoxGO, model_path=coarse_ckpt_path,
thres=cfg.fine_model_and_render.bbox_thres)
scene_rep_reconstruction(
args=args, cfg=cfg,
cfg_model=cfg.fine_model_and_render, cfg_train=cfg.fine_train,
xyz_min=xyz_min_fine, xyz_max=xyz_max_fine,
data_dict=data_dict, stage='fine',
coarse_ckpt_path=coarse_ckpt_path)
eps_fine = time.time() - eps_fine
eps_time_str = f'{eps_fine//3600:02.0f}:{eps_fine//60%60:02.0f}:{eps_fine%60:02.0f}'
    logger.info(f'train: fine detail reconstruction in {eps_time_str}')
eps_time = time.time() - eps_time
eps_time_str = f'{eps_time//3600:02.0f}:{eps_time//60%60:02.0f}:{eps_time%60:02.0f}'
    logger.info(f'train: finish (eps time {eps_time_str})')
if __name__=='__main__':
# load setup
parser = config_parser()
args = parser.parse_args()
cfg = mmcv.Config.fromfile(args.config)
# add file logger
logger.add(os.path.join(cfg.basedir, cfg.expname, 'log.txt'))
    # init environment
if torch.cuda.is_available():
torch.set_default_tensor_type('torch.cuda.FloatTensor')
device = torch.device('cuda')
else:
device = torch.device('cpu')
seed_everything()
# load images / poses / camera settings / data split
data_dict = load_everything(args=args, cfg=cfg)
# export scene bbox and camera poses in 3d for debugging and visualization
if args.export_bbox_and_cams_only:
logger.info('Export bbox and cameras...')
xyz_min, xyz_max = compute_bbox_by_cam_frustrm(args=args, cfg=cfg, **data_dict)
poses, HW, Ks, i_train = data_dict['poses'], data_dict['HW'], data_dict['Ks'], data_dict['i_train']
near, far = data_dict['near'], data_dict['far']
if data_dict['near_clip'] is not None:
near = data_dict['near_clip']
cam_lst = []
for c2w, (H, W), K in zip(poses[i_train], HW[i_train], Ks[i_train]):
rays_o, rays_d, viewdirs = dvgo.get_rays_of_a_view(
H, W, K, c2w, cfg.data.ndc, inverse_y=cfg.data.inverse_y,
flip_x=cfg.data.flip_x, flip_y=cfg.data.flip_y,)
cam_o = rays_o[0,0].cpu().numpy()
cam_d = rays_d[[0,0,-1,-1],[0,-1,0,-1]].cpu().numpy()
cam_lst.append(np.array([cam_o, *(cam_o+cam_d*max(near, far*0.05))]))
np.savez_compressed(args.export_bbox_and_cams_only,
xyz_min=xyz_min.cpu().numpy(), xyz_max=xyz_max.cpu().numpy(),
cam_lst=np.array(cam_lst))
logger.info('done')
sys.exit()
if args.export_coarse_only:
logger.info('Export coarse visualization...')
with torch.no_grad():
ckpt_path = os.path.join(cfg.basedir, cfg.expname, 'coarse_last.tar')
model = utils.load_model(dvgo.DirectVoxGO, ckpt_path).to(device)
alpha = model.activate_density(model.density.get_dense_grid()).squeeze().cpu().numpy()
rgb = torch.sigmoid(model.k0.get_dense_grid()).squeeze().permute(1,2,3,0).cpu().numpy()
np.savez_compressed(args.export_coarse_only, alpha=alpha, rgb=rgb)
logger.info('done')
sys.exit()
# train
if not args.render_only:
train(args, cfg, data_dict)
    # load model for rendering
if args.render_test or args.render_train or args.render_video:
if args.ft_path:
ckpt_path = args.ft_path
else:
ckpt_path = os.path.join(cfg.basedir, cfg.expname, 'fine_last.tar')
ckpt_name = ckpt_path.split('/')[-1][:-4]
if cfg.data.ndc:
model_class = dmpigo.DirectMPIGO
elif cfg.data.unbounded_inward:
model_class = dcvgo.DirectContractedVoxGO
else:
model_class = dvgo.DirectVoxGO
model = utils.load_model(model_class, ckpt_path).to(device)
stepsize = cfg.fine_model_and_render.stepsize
render_viewpoints_kwargs = {
'model': model,
'ndc': cfg.data.ndc,
'render_kwargs': {
'near': data_dict['near'],
'far': data_dict['far'],
'bg': 1 if cfg.data.white_bkgd else 0,
'stepsize': stepsize,
'inverse_y': cfg.data.inverse_y,
'flip_x': cfg.data.flip_x,
'flip_y': cfg.data.flip_y,
'render_depth': True,
},
}
# render trainset and eval
if args.render_train:
testsavedir = os.path.join(cfg.basedir, cfg.expname, f'render_train_{ckpt_name}')
os.makedirs(testsavedir, exist_ok=True)
        logger.info(f'All results are dumped into {testsavedir}')
rgbs, depths, bgmaps = render_viewpoints(
render_poses=data_dict['poses'][data_dict['i_train']],
HW=data_dict['HW'][data_dict['i_train']],
Ks=data_dict['Ks'][data_dict['i_train']],
gt_imgs=[data_dict['images'][i].cpu().numpy() for i in data_dict['i_train']],
savedir=testsavedir, dump_images=args.dump_images,
eval_ssim=args.eval_ssim, eval_lpips_alex=args.eval_lpips_alex, eval_lpips_vgg=args.eval_lpips_vgg,
**render_viewpoints_kwargs)
imageio.mimwrite(os.path.join(testsavedir, 'video.rgb.mp4'), utils.to8b(rgbs), fps=30, quality=8)
imageio.mimwrite(os.path.join(testsavedir, 'video.depth.mp4'), utils.to8b(1 - depths / np.max(depths)), fps=30, quality=8)
# render testset and eval
if args.render_test:
testsavedir = os.path.join(cfg.basedir, cfg.expname, f'render_test_{ckpt_name}')
os.makedirs(testsavedir, exist_ok=True)
        logger.info(f'All results are dumped into {testsavedir}')
rgbs, depths, bgmaps = render_viewpoints(
render_poses=data_dict['poses'][data_dict['i_test']],
HW=data_dict['HW'][data_dict['i_test']],
Ks=data_dict['Ks'][data_dict['i_test']],
gt_imgs=[data_dict['images'][i].cpu().numpy() for i in data_dict['i_test']],
savedir=testsavedir, dump_images=args.dump_images,
eval_ssim=args.eval_ssim, eval_lpips_alex=args.eval_lpips_alex, eval_lpips_vgg=args.eval_lpips_vgg,
**render_viewpoints_kwargs)
imageio.mimwrite(os.path.join(testsavedir, 'video.rgb.mp4'), utils.to8b(rgbs), fps=30, quality=8)
imageio.mimwrite(os.path.join(testsavedir, 'video.depth.mp4'), utils.to8b(1 - depths / np.max(depths)), fps=30, quality=8)
# render video
if args.render_video:
testsavedir = os.path.join(cfg.basedir, cfg.expname, f'render_video_{ckpt_name}')
os.makedirs(testsavedir, exist_ok=True)
        logger.info(f'All results are dumped into {testsavedir}')
rgbs, depths, bgmaps = render_viewpoints(
render_poses=data_dict['render_poses'],
HW=data_dict['HW'][data_dict['i_test']][[0]].repeat(len(data_dict['render_poses']), 0),
Ks=data_dict['Ks'][data_dict['i_test']][[0]].repeat(len(data_dict['render_poses']), 0),
render_factor=args.render_video_factor,
render_video_flipy=args.render_video_flipy,
render_video_rot90=args.render_video_rot90,
savedir=testsavedir, dump_images=args.dump_images,
**render_viewpoints_kwargs)
imageio.mimwrite(os.path.join(testsavedir, 'video.rgb.mp4'), utils.to8b(rgbs), fps=30, quality=8)
import matplotlib.pyplot as plt
depths_vis = depths * (1-bgmaps) + bgmaps
dmin, dmax = np.percentile(depths_vis[bgmaps < 0.1], q=[5, 95])
depth_vis = plt.get_cmap('rainbow')(1 - np.clip((depths_vis - dmin) / (dmax - dmin), 0, 1)).squeeze()[..., :3]
imageio.mimwrite(os.path.join(testsavedir, 'video.depth.mp4'), utils.to8b(depth_vis), fps=30, quality=8)
logger.info('Done')
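
# Example usage (the config paths below are illustrative; point --config at your own config file):
#   python run.py --config configs/your_scene.py
#   python run.py --config configs/your_scene.py --render_only --render_test --render_video --eval_ssim --dump_images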