import os
import sys
import cv2
import glob
import pathlib
import webbrowser as web
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
import matplotlib
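# use the non-interactive Agg backend so matplotlib never opens its own windows inside the Qt app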
matplotlib.use('Agg')
import yaml
from argparse import ArgumentParser
from tqdm import tqdm
import imageio
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte
import torch
from sync_batchnorm import DataParallelWithCallback
from modules.generator import OcclusionAwareGenerator
from modules.keypoint_detector import KPDetector
from animate import normalize_kp
from scipy.spatial import ConvexHull
import RRDBNet_arch as arch
# device = torch.device('cuda')
# precision = torch.float32
# loader = transforms.Compose([
# transforms.ToTensor()])
def load_checkpoints(config_path, checkpoint_path, cpu=False):
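    """Build the generator and keypoint detector defined by the YAML config and
    restore their weights from a first-order-model checkpoint; both models are
    wrapped in DataParallelWithCallback when running on GPU."""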
with open(config_path) as f:
config = yaml.load(f,Loader=yaml.FullLoader)
generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
**config['model_params']['common_params'])
if not cpu:
generator.cuda()
kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
**config['model_params']['common_params'])
if not cpu:
kp_detector.cuda()
if cpu:
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
else:
checkpoint = torch.load(checkpoint_path)
generator.load_state_dict(checkpoint['generator'])
kp_detector.load_state_dict(checkpoint['kp_detector'])
if not cpu:
generator = DataParallelWithCallback(generator)
kp_detector = DataParallelWithCallback(kp_detector)
generator.eval()
kp_detector.eval()
return generator, kp_detector
def make_animation(window, source_image, driving_video, generator, kp_detector, relative=True,
                   adapt_movement_scale=True, cpu=False):
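    """Animate source_image with driving_video: keypoints are detected once for
    the source and the first driving frame, then one prediction is generated per
    driving frame, with progress echoed to the window's command box."""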
with torch.no_grad():
predictions = []
source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
if not cpu:
source = source.cuda()
driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
kp_source = kp_detector(source)
kp_driving_initial = kp_detector(driving[:, :, 0])
count = 0
for frame_idx in tqdm(range(driving.shape[2])):
driving_frame = driving[:, :, frame_idx]
if not cpu:
driving_frame = driving_frame.cuda()
kp_driving = kp_detector(driving_frame)
kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
            count += 1
window.update_command(f'make_animation {count}')
return predictions
def find_best_frame(source, driving, cpu=False):
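    """Return the index of the driving frame whose scale-normalized facial
    landmarks best match the source image, a good frame to start
    forward/backward generation from."""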
import face_alignment
def normalize_kp(kp):
kp = kp - kp.mean(axis=0, keepdims=True)
area = ConvexHull(kp[:, :2]).volume
area = np.sqrt(area)
kp[:, :2] = kp[:, :2] / area
return kp
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
device='cpu' if cpu else 'cuda')
kp_source = fa.get_landmarks(255 * source)[0]
kp_source = normalize_kp(kp_source)
norm = float('inf')
frame_num = 0
for i, image in tqdm(enumerate(driving)):
kp_driving = fa.get_landmarks(255 * image)[0]
kp_driving = normalize_kp(kp_driving)
new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
if new_norm < norm:
norm = new_norm
frame_num = i
return frame_num
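
# The commented-out MattingProgressBar / DoThread classes below appear to be
# leftovers from a separate video-matting tool (they reference MattingRefine and
# VideoDataset, which are not imported here) and are unused by this window.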
# class MattingProgressBar(QWidget):
# close_Signal = pyqtSignal(int)
# def __init__(self):
# super(MattingProgressBar,self).__init__()
# self.setStyleSheet(
# 'QWidget{background-color:rgb(75,75,75);color:rgb(200,200,200)}')
#
# self.setWindowTitle('Generate Alpha')
# self.setWindowFlags(Qt.WindowStaysOnTopHint)
# self.setMinimumWidth(500)
# self.mainLayout=QVBoxLayout()
# self.mainLayout.setContentsMargins(15,12,15,15)
# self.tip_layout =QHBoxLayout()
# self.tip_layout.setAlignment(Qt.AlignCenter)
# self.tip_label = QLabel('File name:')
# self.tip_name =QLabel('xxx')
# self.tip_layout.addWidget(self.tip_label)
# self.tip_layout.addWidget(self.tip_name)
# self.frame_layout =QHBoxLayout()
# self.frame_layout.setAlignment(Qt.AlignCenter)
# self.all_frame_label = QLabel('Total frames: ')
# self.all_frame = QLabel('0')
# self.frame_layout.addWidget(self.all_frame_label)
# self.frame_layout.addWidget(self.all_frame)
# self.current_frame_layout =QHBoxLayout()
# self.current_frame_layout.setAlignment(Qt.AlignCenter)
# self.current_frame_label = QLabel('Current frame:')
# self.current_frame = QLabel('0')
# self.current_frame_layout.addWidget(self.current_frame_label)
# self.current_frame_layout.addWidget(self.current_frame)
# self.progress_layout=QHBoxLayout()
# self.p=QProgressBar()
#
#
# self.mainLayout.addLayout(self.tip_layout)
# self.mainLayout.addLayout(self.frame_layout)
# self.mainLayout.addLayout(self.current_frame_layout)
# self.mainLayout.addWidget(self.p)
# self.setLayout(self.mainLayout)
#
# def setValue(self,value):
# self.p.setValue(value)
#
# def closeEvent(self, QCloseEvent):
# self.close_Signal.emit(1)
# class DoThread(QThread):
# back_Signal =pyqtSignal(dict)
# def __init__(self,data_info):
# super(DoThread,self).__init__()
#
# self.model_path =data_info['model_path']
# self.sample_type = data_info['sample_type']
# self.model_type = data_info['model_type']
# self.current_filename = data_info['filename']  # name of the file currently being converted
# self.type = data_info['type']
# self.src_path = data_info['src_path']
# self.bg_path = data_info['bg_path']
# self.out_dir = data_info['out_dir']
# self.running =False
# self.count = 0  # number of frames processed so far
#
#
# self.model = MattingRefine(backbone=self.model_type,
# backbone_scale=0.25,
# refine_mode=self.sample_type)
# self.model.load_state_dict(torch.load(self.model_path))
# print('start 02')
# self.model = self.model.eval().to(precision).to(device)
# print('start 03')
# self.fg_imgs,self.bg_imgs,self.all_frames = self.get_imgs()
# print('start 04')
#
# print('self.all_frames:',self.all_frames)
#
#
# def get_imgs(self):
# if self.type == 'video_video':
# fg= VideoDataset(self.src_path)
# bg= VideoDataset(self.bg_path)
# all_frames = fg.frame_count
# elif self.type in ['video_img','video_imgs']:
# fg= VideoDataset(self.src_path)
# bg = self.bg_path
# all_frames = fg.frame_count
# elif self.type == 'img_img':
# fg= self.src_path
# bg = self.bg_path
# all_frames = 1
# elif self.type == 'imgs_video':
# fg= self.src_path
# bg = VideoDataset(self.bg_path)
# all_frames = len(self.src_path)
# elif self.type == 'imgs_imgs':
# fg= self.src_path
# bg = self.bg_path
# all_frames = len(fg)
#
#
# return fg,bg,all_frames
#
# def image2tensor(self,image_name):
# image = Image.open(image_name).convert('RGB')
# image = loader(image).unsqueeze(0)
# return image.to(device, torch.float)
#
# def pil2tensor(self,pil_img):
# image = loader(pil_img).unsqueeze(0)
# return image.to(device, torch.float)
#
# def writer(self,img, path):
# img = to_pil_image(img[0].cpu())
# img.save(path)
#
#
# def video_video(self,filename,model,fg_imgs,bg_imgs):
# """
# Video foreground with a video background
# """
# with torch.no_grad():
# for i in range(fg_imgs.frame_count):
# if self.running:
# src_tensor= self.pil2tensor(fg_imgs[i])
# bg_tensor = self.pil2tensor(bg_imgs[i])
# pha, fgr = model(src_tensor, bg_tensor)[:2]
# out_dir = os.path.join(self.out_dir,filename)
# if not os.path.exists(out_dir):
# os.makedirs(out_dir)
# out_path = os.path.join(out_dir,'{}.{}.png'.format(self.model_type,'%04d' % (i+1)))
# self.writer(pha, out_path)
# self.count += 1
# print('{} done'.format(i + 1))
#
# def video_img(self,filename,model,fg_imgs,bg_path):
# """
# Single video foreground with a single background image
# """
# bg_tensor = self.image2tensor(bg_path)
# with torch.no_grad():
# for i in range(fg_imgs.frame_count):
# if self.running:
# src_tensor= self.pil2tensor(fg_imgs[i])
# pha, fgr = model(src_tensor, bg_tensor)[:2]
# out_dir = os.path.join(self.out_dir,filename)
# if not os.path.exists(out_dir):
# os.makedirs(out_dir)
# out_path = os.path.join(out_dir,'{}.{}.png'.format(self.model_type,'%04d' % (i+1)))
# self.writer(pha, out_path)
# self.count += 1
# print('{} done'.format(i + 1))
#
# def video_imgs(self,filename,model,fg_imgs,bg_paths):
# """
# Single video foreground with a background image sequence
# """
#
# with torch.no_grad():
# for i in range(fg_imgs.frame_count):
# print('self.running:',self.running)
# if self.running:
# src_tensor= self.pil2tensor(fg_imgs[i])
# bg_tensor =self.image2tensor(bg_paths[i])
# pha, fgr = model(src_tensor, bg_tensor)[:2]
# out_dir = os.path.join(self.out_dir,filename)
# if not os.path.exists(out_dir):
# os.makedirs(out_dir)
# out_path = os.path.join(out_dir,'{}.{}.png'.format(self.model_type,'%04d' % (i+1)))
# self.writer(pha, out_path)
# self.count += 1
# print('{} done'.format(i + 1))
#
# def img_img(self,filename,model,src_path,bg_path):
# """
# Single image foreground with a single background image
# """
# src_tensor = self.image2tensor(src_path)
# bg_tensor = self.image2tensor(bg_path)
# with torch.no_grad():
# pha, fgr = model(src_tensor, bg_tensor)[:2]
# out_dir = os.path.join(self.out_dir, filename)
# if not os.path.exists(out_dir):
# os.makedirs(out_dir)
# out_path = os.path.join(out_dir, '{}.png'.format(self.model_type))
# self.writer(pha, out_path)
# self.count += 1
# print('done')
#
# def imgs_video(self,filename,model,src_paths,bg_imgs):
# """
# Image-sequence foreground with a single video background
# """
#
# with torch.no_grad():
# for i in range(len(src_paths)):
# if self.running:
# src_tensor = self.image2tensor(src_paths[i])
# bg_tensor = self.pil2tensor(bg_imgs[i])
# pha, fgr = model(src_tensor, bg_tensor)[:2]
# out_dir = os.path.join(self.out_dir, filename)
# if not os.path.exists(out_dir):
# os.makedirs(out_dir)
# out_path = os.path.join(out_dir, '{}.{}.png'.format(self.model_type, '%04d' % (i+1)))
# self.writer(pha, out_path)
# self.count += 1
# print('{} done'.format(i + 1))
#
# def imgs_imgs(self,filename,model,src_paths,bg_paths):
# """
# Image-sequence foreground with an image-sequence background
# """
# with torch.no_grad():
# for i in range(len(src_paths)):
# if self.running:
# src_tensor = self.image2tensor(src_paths[i])
# bg_tensor = self.image2tensor(bg_paths[i])
# pha, fgr = model(src_tensor, bg_tensor)[:2]
# out_dir = os.path.join(self.out_dir, filename)
# if not os.path.exists(out_dir):
# os.makedirs(out_dir)
# out_path = os.path.join(out_dir, '{}.{}.png'.format(self.model_type, i + 1))
# self.writer(pha, out_path)
# self.count += 1
# print('{} done'.format(i + 1))
#
#
# def run(self):
# print('start')
# print(self.model_type)
# print(self.sample_type)
#
# try:
# eval('self.{}'.format(self.type))(self.current_filename,self.model,self.fg_imgs,self.bg_imgs)
# print('start 04')
# except Exception as run_ERR:
# print('run_ERR:',str(run_ERR))
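
# Main window: drives a source image with a driving video (First Order Motion Model)
# and optionally upscales the generated frames with an ESRGAN-style RRDBNet.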
class FirstOrder(QWidget):
def __init__(self):
super(FirstOrder, self).__init__()
self.config_path = os.path.join(os.path.dirname(__file__),'config')
self.vox_yaml_path = os.path.join(self.config_path,'vox-256.yaml')
self.models_path = os.path.join(os.path.dirname(__file__),'models')
self.interp_path = os.path.join(self.models_path,'interp_10.pth')
self.checkpoint = 'vox-cpk.pth.tar'
self.config = {'dataset_params': {'root_dir': 'data/vox-png', 'frame_shape': [256, 256, 3],
'id_sampling': True, 'pairs_list': 'data/vox256.csv',
'augmentation_params': {'flip_param': {'horizontal_flip': True, 'time_flip': True},
'jitter_param': {'brightness': 0.1, 'contrast': 0.1,
'saturation': 0.1, 'hue': 0.1}}},
'model_params': {'common_params': {'num_kp': 10, 'num_channels': 3, 'estimate_jacobian': True},
'kp_detector_params': {'temperature': 0.1, 'block_expansion': 32, 'max_features': 1024,
'scale_factor': 0.25, 'num_blocks': 5},
'generator_params': {'block_expansion': 64, 'max_features': 512, 'num_down_blocks': 2,
'num_bottleneck_blocks': 6, 'estimate_occlusion_map': True,
'dense_motion_params': {'block_expansion': 64, 'max_features': 1024,
'num_blocks': 5, 'scale_factor': 0.25}},
'discriminator_params': {'scales': [1], 'block_expansion': 32, 'max_features': 512, 'num_blocks': 4, 'sn': True}},
'train_params': {'num_epochs': 100, 'num_repeats': 75, 'epoch_milestones': [60, 90],
'lr_generator': 0.0002, 'lr_discriminator': 0.0002, 'lr_kp_detector': 0.0002,
'batch_size': 40, 'scales': [1, 0.5, 0.25, 0.125], 'checkpoint_freq': 50,
'transform_params': {'sigma_affine': 0.05, 'sigma_tps': 0.005, 'points_tps': 5},
'loss_weights': {'generator_gan': 0, 'discriminator_gan': 1,
'feature_matching': [10, 10, 10, 10], 'perceptual': [10, 10, 10, 10, 10],
'equivariance_value': 10, 'equivariance_jacobian': 10}},
'reconstruction_params': {'num_videos': 1000, 'format': '.mp4'},
'animate_params': {'num_pairs': 50, 'format': '.mp4',
'normalization_params': {'adapt_movement_scale': False, 'use_relative_movement': True,
'use_relative_jacobian': True}},
'visualizer_params': {'kp_size': 5, 'draw_border': True, 'colormap': 'gist_rainbow'}}
"""
dataset_params:不需要提取,这是训练集参数
model_params:模型参数,不需要修改
train_params:训练参数,也不需要修改
reconstruction_params : 重构参数,也不需要动
animate_params:动画参数,也不需要
visualizer_params:视觉参数,也不需要
"""
icon = QIcon()
cgai_icon = str(pathlib.Path('CGAI.png').resolve())
icon.addPixmap(QPixmap(cgai_icon))
p = self.palette()
p.setColor(QPalette.Base, QColor('#1C1C1C'))
p.setColor(QPalette.Window, QColor('#393939'))
p.setColor(QPalette.WindowText, QColor('#E8E8E8'))
p.setColor(QPalette.Text, QColor('#1C1C1C'))
self.setPalette(p)
self.big_font = QFont('', 20, 65)
self.label_font = QFont('', 15, 65)
self.mid_font = QFont('', 11, 75)
self.link_btn_style = '''QPushButton{color:black}
QPushButton:hover{color:#FF7F24}
QPushButton{background-color:#CFCFCF}
QPushButton{border:2px}
QPushButton{border-radius:10px}
QPushButton{padding:5px 1px}'''
self.file_btn_style = '''QPushButton{color:black}
QPushButton:hover{color:#FF7F24}
QPushButton{background-color:#CFCFCF}
QPushButton{border:2px}
QPushButton{border-radius:3px}
QPushButton{padding:5px 1px}'''
self.export_btn_style = '''QPushButton{color:black}
QPushButton:hover{color:#FF7F24}
QPushButton{background-color:#CFCFCF}
QPushButton{border:2px}
QPushButton{border-radius:3px}
QPushButton{padding:5px 1px}'''
self.radio_btn_style = ''' QRadioButton:hover{color:#FF7F24}
QRadioButton{color : #E8E8E8}
QRadioButton{border:2px}
QRadioButton{border-radius:10px}
QRadioButton{padding:5px 1px}'''
        self.input_type = 'image'
self.setWindowIcon(icon)
self.setWindowTitle('First Order Motion Model')
self.setMinimumHeight(500)
self.setMaximumWidth(650)
self.main_layout = QVBoxLayout()
self.main_layout.setAlignment(Qt.AlignTop)
self.main_layout.setSpacing(18)
self.cgai_layout = QHBoxLayout()
self.cgai_layout.setAlignment(Qt.AlignLeft)
self.cgai_icon = QLabel()
cgai_pixmap = QPixmap(cgai_icon)
        scaled_pixmap = cgai_pixmap.scaled(80, 80, Qt.KeepAspectRatio)
        self.cgai_icon.setPixmap(scaled_pixmap)
        self.cgai_label = QLabel('First-Order-Model application')
# self.cgai_label.setAlignment(Qt.AlignCenter)
self.cgai_label.setFont(self.big_font)
self.cgai_layout.addWidget(self.cgai_icon)
self.cgai_layout.addWidget(self.cgai_label)
self.git_layout = QHBoxLayout()
self.git_layout.setContentsMargins(0, 0, 0, 20)
        self.gitee_btn = QPushButton('Gitee source')
self.gitee_btn.clicked.connect(self._open_gitee)
self.gitee_btn.setStyleSheet(self.link_btn_style)
        self.github_btn = QPushButton('GitHub original project')
self.github_btn.clicked.connect(self._open_github)
self.github_btn.setStyleSheet(self.link_btn_style)
        self.baidu_btn = QPushButton('Baidu Pan (extraction code: CGAI)')
self.baidu_btn.clicked.connect(self._open_baidupan)
self.baidu_btn.setStyleSheet(self.link_btn_style)
        self.help_btn = QPushButton('Help')
self.help_btn.clicked.connect(self._help)
self.help_btn.setStyleSheet(self.link_btn_style)
self.git_layout.addWidget(self.gitee_btn)
self.git_layout.addWidget(self.github_btn)
self.git_layout.addWidget(self.baidu_btn)
self.git_layout.addWidget(self.help_btn)
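        # regex validators restricting the parameter line edits to numeric input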
        float_rex = QRegExp(r"[0-9.\-]+$")
        int_rex = QRegExp(r"[0-9]+$")
self.float_rv = QRegExpValidator(float_rex, self)
self.int_rv = QRegExpValidator(int_rex, self)
        self.first_motion_label = QLabel('Video driving')
self.first_motion_label.setAlignment(Qt.AlignCenter)
self.first_motion_label.setFont(self.label_font)
self.src_layout = QHBoxLayout()
self.src_layout.setContentsMargins(0, 25, 0, 0)
        self.src_label = QLabel('Driving video path:')
self.src_label.setFont(self.mid_font)
self.src_edit = QLineEdit()
self.src_btn = QPushButton('··')
self.src_btn.setStyleSheet(self.file_btn_style)
self.src_btn.clicked.connect(self._select_src)
self.open_src_btn = QPushButton('👆')
self.open_src_btn.setStyleSheet(self.file_btn_style)
self.open_src_btn.clicked.connect(self._open_src)
self.src_layout.addWidget(self.src_label)
self.src_layout.addWidget(self.src_edit)
self.src_layout.addWidget(self.src_btn)
self.src_layout.addWidget(self.open_src_btn)
self.img_layout = QHBoxLayout()
self.img_layout.setContentsMargins(0, 0, 0, 0)
        self.img_label = QLabel('Source image path:')
self.img_label.setFont(self.mid_font)
self.img_edit = QLineEdit()
self.img_btn = QPushButton('··')
self.img_btn.setStyleSheet(self.file_btn_style)
self.img_btn.clicked.connect(self._select_img)
self.open_img_btn = QPushButton('👆')
self.open_img_btn.setStyleSheet(self.file_btn_style)
self.open_img_btn.clicked.connect(self._open_img)
self.img_layout.addWidget(self.img_label)
self.img_layout.addWidget(self.img_edit)
self.img_layout.addWidget(self.img_btn)
self.img_layout.addWidget(self.open_img_btn)
self.export_layout = QHBoxLayout()
self.export_layout.setContentsMargins(0, 0, 0, 0)
        self.export_label = QLabel('Output directory:')
self.export_label.setFont(self.mid_font)
self.export_edit = QLineEdit()
self.export_btn = QPushButton('··')
self.export_btn.setStyleSheet(self.file_btn_style)
self.export_btn.clicked.connect(self._export_dir)
self.open_export_btn = QPushButton('👆')
self.open_export_btn.setStyleSheet(self.file_btn_style)
self.open_export_btn.clicked.connect(self._open_export)
self.export_layout.addWidget(self.export_label)
self.export_layout.addWidget(self.export_edit)
self.export_layout.addWidget(self.export_btn)
self.export_layout.addWidget(self.open_export_btn)
self.btn_layout = QHBoxLayout()
self.btn_layout.setContentsMargins(0, 15, 0, 20)
        self.btn = QPushButton('Animate image')
self.btn.setMaximumWidth(100)
self.btn.setStyleSheet(self.export_btn_style)
self.btn.clicked.connect(self._create)
self.btn_layout.addWidget(self.btn)
        self.super_resolution_label = QLabel('Super-resolution')
self.super_resolution_label.setAlignment(Qt.AlignCenter)
self.super_resolution_label.setFont(self.label_font)
# self.parm_layout = QVBoxLayout()
self.fps_layout = QHBoxLayout()
self.fps_layout.setContentsMargins(0,0,100,0)
self.fps_layout.setSpacing(55)
self.fps_layout.setAlignment(Qt.AlignLeft)
        self.fps_label = QLabel('Output fps:')
        self.fps_label.setToolTip('Frame rate')
        self.fps_label.setFont(self.mid_font)
        self.fps_edit = QLineEdit()
        self.fps_edit.setValidator(self.int_rv)
        self.fps_edit.setText('30')
        self.fps_edit.setFixedWidth(50)
        self.video_check = QCheckBox('Also export video')
        self.video_check.setChecked(True)
        self.gpu_check = QCheckBox('Use GPU')
        self.gpu_check.setChecked(True)
        self.fps_layout.addWidget(self.fps_label)
        self.fps_layout.addWidget(self.fps_edit)
        self.fps_layout.addWidget(self.video_check)
        self.fps_layout.addWidget(self.gpu_check)
# self.parm_layout.addLayout(self.fps_layout)
self.images_layout = QHBoxLayout()
self.images_layout.setContentsMargins(0, 0, 0, 0)
        self.images_label = QLabel('Image directory:')
self.images_label.setFont(self.mid_font)
self.images_edit = QLineEdit()
self.images_btn = QPushButton('··')
self.images_btn.setStyleSheet(self.file_btn_style)
self.images_btn.clicked.connect(self._images_dir)
self.open_images_btn = QPushButton('👆')
self.open_images_btn.setStyleSheet(self.file_btn_style)
self.open_images_btn.clicked.connect(self._open_images)
self.images_layout.addWidget(self.images_label)
self.images_layout.addWidget(self.images_edit)
self.images_layout.addWidget(self.images_btn)
self.images_layout.addWidget(self.open_images_btn)
self.out_layout = QHBoxLayout()
self.out_layout.setContentsMargins(0, 0, 0, 0)
        self.out_label = QLabel('Output directory:')
self.out_label.setFont(self.mid_font)
self.out_edit = QLineEdit()
self.out_btn = QPushButton('··')
self.out_btn.setStyleSheet(self.file_btn_style)
self.out_btn.clicked.connect(self._out_dir)
self.open_out_btn = QPushButton('👆')
self.open_out_btn.setStyleSheet(self.file_btn_style)
self.open_out_btn.clicked.connect(self._open_out)
self.out_layout.addWidget(self.out_label)
self.out_layout.addWidget(self.out_edit)
self.out_layout.addWidget(self.out_btn)
self.out_layout.addWidget(self.open_out_btn)
self.convert_btn_layout = QHBoxLayout()
self.convert_btn_layout.setContentsMargins(0, 15, 0, 20)
        self.convert_btn = QPushButton('Upscale images')
self.convert_btn.setMaximumWidth(100)
self.convert_btn.setStyleSheet(self.export_btn_style)
self.convert_btn.clicked.connect(self._convert)
self.convert_btn_layout.addWidget(self.convert_btn)
self.command_text = QTextBrowser()
self.command_text.setTextColor(QColor('#E8E8E8'))
        self.command_text.append('CGAI instant rendition: what you imagine is what you get')
self.command_text.setMinimumHeight(50)
self.main_layout.addLayout(self.cgai_layout)
self.main_layout.addLayout(self.git_layout)
self.main_layout.addWidget(self.first_motion_label)
self.main_layout.addLayout(self.src_layout)
self.main_layout.addLayout(self.img_layout)
self.main_layout.addLayout(self.export_layout)
self.main_layout.addLayout(self.btn_layout)
self.main_layout.addWidget(self.super_resolution_label)
self.main_layout.addLayout(self.fps_layout)
self.main_layout.addLayout(self.images_layout)
self.main_layout.addLayout(self.out_layout)
self.main_layout.addLayout(self.convert_btn_layout)
self.main_layout.addWidget(self.command_text)
self.setLayout(self.main_layout)
# self.timer = QBasicTimer()
# self.prog = MattingProgressBar()
# self.prog.close_Signal.connect(self._close_prog)
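    # NOTE: _keep_len and update_parms reference widgets (longer_side_len_eidt,
    # num_frames_eidt, x_shift_range_0, ...) that are never created in this class;
    # they appear to be leftovers from another tool and will raise AttributeError if called.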
    def _keep_len(self, status):
        self.longer_side_len_eidt.setDisabled(bool(status))
def update_parms(self):
print('update_parms start')
        fps = int(self.fps_edit.text())
num_frame = int(self.num_frames_eidt.text())
x_shift_range = [float(self.x_shift_range_0.text()),float(self.x_shift_range_1.text()),
float(self.x_shift_range_2.text()),float(self.x_shift_range_3.text())]
y_shift_range = [float(self.y_shift_range_0.text()),float(self.y_shift_range_1.text()),
float(self.y_shift_range_2.text()),float(self.y_shift_range_3.text())]
z_shift_range = [float(self.z_shift_range_0.text()),float(self.z_shift_range_1.text()),
float(self.z_shift_range_2.text()),float(self.z_shift_range_3.text())]
# print('update_parms 001')
longer_side_len = int(self.longer_side_len_eidt.text())
save_ply = self.save_ply_check.isChecked()
gpu_ids = int(self.gpu_ids_eidt.text())
depth_threshold = float(self.depth_threshold_eidt.text())
ext_edge_threshold = float(self.ext_edge_threshold_eidt.text())
sparse_iter = int(self.sparse_iter_eidt.text())
sigma_s = float(self.sigma_s_eidt.text())
sigma_r = float(self.sigma_r_eidt.text())
# print('update_parms 002')
redundant_number = int(self.redundant_number_eidt.text())
background_thickness = int(self.background_thickness_eidt.text())
background_thickness_2 = int(self.background_thickness2_eidt.text())
context_thickness = int(self.context_thickness_eidt.text())
context_thickness_2 = int(self.context_thickness2_eidt.text())
largest_size = int(self.largest_size_eidt.text())
depth_edge_dilate = int(self.depth_edge_dilate_eidt.text())
depth_edge_dilate_2 = int(self.depth_edge_dilate2_eidt.text())
extrapolation_thickness = int(self.extrapolation_thickness_eidt.text())
crop_border = [float(self.crop_border_0.text()),float(self.crop_border_1.text()),
float(self.crop_border_2.text()),float(self.crop_border_3.text())]
# print('update_parms 003')
parms = {'fps':fps,'num_frame':num_frame,'x_shift_range':x_shift_range,
'y_shift_range':y_shift_range,'z_shift_range':z_shift_range,'longer_side_len':longer_side_len,
'save_ply':save_ply,'gpu_ids':gpu_ids,'depth_threshold':depth_threshold,
'ext_edge_threshold':ext_edge_threshold,'sparse_iter':sparse_iter,'sigma_s':sigma_s,
'sigma_r':sigma_r,'redundant_number':redundant_number,'background_thickness':background_thickness,
'background_thickness_2':background_thickness_2,'context_thickness':context_thickness,'context_thickness_2':context_thickness_2,
'largest_size':largest_size,'depth_edge_dilate':depth_edge_dilate,'depth_edge_dilate_2':depth_edge_dilate_2,
'extrapolation_thickness':extrapolation_thickness,'crop_border':crop_border }
self.config.update(parms)
# import pprint
# pprint.pprint(self.config)
def _select_src(self):
        # dir_path = QFileDialog.getExistingDirectory(self, 'Select folder')
# # print('dir_path:',dir_path)
# if dir_path:
# self.src_edit.setText(dir_path)
        video_path, _ = QFileDialog.getOpenFileName(self, 'Select video', '', 'Video files (*.mp4 *.flv *.avi)')
        # print('video_path:', video_path)
        if video_path:
            self.src_edit.setText(video_path)
def _open_src(self):
os.startfile(os.path.dirname(self.src_edit.text()))
def _select_img(self):
        img_path, _ = QFileDialog.getOpenFileName(self, 'Select image', '', 'Image files (*.jpg *.png *.jpeg)')
        if img_path:
            self.img_edit.setText(img_path)
def _open_img(self):
os.startfile(os.path.dirname(self.img_edit.text()))
def _select_bg(self):
        dir_path = QFileDialog.getExistingDirectory(self, 'Select folder')
if dir_path:
self.bg_edit.setText(dir_path)
def _export_dir(self):
        dir_path = QFileDialog.getExistingDirectory(self, 'Select folder')
if dir_path:
self.export_edit.setText(dir_path)
def _open_export(self):
os.startfile(self.export_edit.text())
def _images_dir(self):
        dir_path = QFileDialog.getExistingDirectory(self, 'Select folder')
if dir_path:
self.images_edit.setText(dir_path)
def _open_images(self):
os.startfile(self.images_edit.text())
def _out_dir(self):
        dir_path = QFileDialog.getExistingDirectory(self, 'Select folder')
if dir_path:
self.out_edit.setText(dir_path)
def _open_out(self):
os.startfile(self.out_edit.text())
def _open_gitee(self):
web.open(r'https://gitee.com/cgai/first-order-model')
def _open_github(self):
web.open(r'https://github.com/AliaksandrSiarohin/first-order-model')
def _open_baidupan(self):
web.open(r'https://pan.baidu.com/s/1D6OcSzQKlxM_iYzuUaueOg')
def update_command(self,text):
self.command_text.append(text)
QApplication.processEvents()
def _create(self):
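        """Full driving pipeline: read the source image and driving video, resize
        both to 256x256, load the vox checkpoint, generate the predictions, and
        save them as an mp4 plus a numbered jpg sequence."""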
src_path = self.src_edit.text()
img_path = self.img_edit.text()
out_dir = self.export_edit.text()
if src_path and img_path and out_dir:
img_full_name = os.path.split(img_path)[1]
img_name = os.path.splitext(img_full_name)[0]
source_image = img_path
driving_video = src_path
result_video = os.path.join(out_dir,f'{img_name}.mp4')
imgs_dir = os.path.join(out_dir,img_name)
print('imgs_dir:',imgs_dir)
self.update_command(f'imgs_dir: {imgs_dir}')
os.makedirs(imgs_dir,exist_ok=True)
relative = True
adapt_scale = True
source_image = imageio.imread(source_image)
reader = imageio.get_reader(driving_video)
print('reader')
fps = reader.get_meta_data()['fps']
driving_video = []
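            # imageio may raise RuntimeError on truncated videos; keep whatever frames were read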
try:
frame_count = 0
for im in reader:
driving_video.append(im)
frame_count +=1
print('append frame:',frame_count)
self.update_command(f'append frame: {frame_count}')
except RuntimeError:
pass
reader.close()
print('resize to 256')
self.update_command('resize to 256')
source_image = resize(source_image, (256, 256))[..., :3]
driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
# generator, kp_detector = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint,
# cpu=opt.cpu)
print('load_checkpoints start')
self.update_command('load_checkpoints start')
generator, kp_detector = load_checkpoints(config_path=self.vox_yaml_path, checkpoint_path=self.checkpoint,
cpu=False)
print('load_checkpoints over')
self.update_command('load_checkpoints over')
# if opt.find_best_frame or opt.best_frame is not None:
# # i = opt.best_frame if opt.best_frame is not None else find_best_frame(source_image, driving_video,
# # cpu=opt.cpu)
# i = opt.best_frame if opt.best_frame is not None else find_best_frame(source_image, driving_video,
# cpu='cpu')
# print("Best frame: " + str(i))
# driving_forward = driving_video[i:]
# driving_backward = driving_video[:(i + 1)][::-1]
# predictions_forward = make_animation(source_image, driving_forward, generator, kp_detector,
# relative=opt.relative, adapt_movement_scale=opt.adapt_scale,
# cpu=opt.cpu)
# predictions_backward = make_animation(source_image, driving_backward, generator, kp_detector,
# relative=opt.relative, adapt_movement_scale=opt.adapt_scale,
# cpu=opt.cpu)
# predictions = predictions_backward[::-1] + predictions_forward[1:]
# else:
# predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=opt.relative,
# adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
print('predictions start')
self.update_command('predictions start')
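            # note: the checkpoint was loaded with cpu=False (models on CUDA, wrapped in
            # DataParallelWithCallback) while make_animation is called with cpu=True, so the
            # frames stay on the CPU; this appears to rely on DataParallel scattering the
            # inputs to the GPU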
predictions = make_animation(self,source_image, driving_video, generator, kp_detector, relative=relative,
adapt_movement_scale=adapt_scale, cpu=True)
print('predictions over')
self.update_command('predictions over')
print()
            print('save video start')
            self.update_command('save video start')
            imageio.mimsave(result_video, [img_as_ubyte(frame) for frame in predictions], fps=fps)
            print('save video over')
            self.update_command('save video over')
            print()
            print('start saving images')
            self.update_command('start saving images')
for f in range(len(predictions)):
img = img_as_ubyte(predictions[f])
out_img_path = os.path.join(imgs_dir,f'{f+1}.jpg')
imageio.imsave(out_img_path,img)
                print('output:', out_img_path)
                self.update_command(f'output: {out_img_path}')
            self.update_command('Done!')
def _convert(self):
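        """Upscale every image in the chosen directory with the RRDBNet model and,
        when requested, assemble the upscaled frames into an mp4."""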
imgs_dir = self.images_edit.text()
out_dir = self.out_edit.text()
if imgs_dir and out_dir:
img_0 = os.listdir(imgs_dir)[0]
img_full_name = os.path.split(img_0)[1]
img_name = os.path.splitext(img_full_name)[0]
            # model_path = 'models/interp_08.pth'  # models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth
            model_path = self.interp_path  # interpolated model; alternatives: models/RRDB_ESRGAN_x4.pth or models/RRDB_PSNR_x4.pth
            dev = 'cuda' if self.gpu_check.isChecked() else 'cpu'
            device = torch.device(dev)  # follows the "Use GPU" checkbox
export_imgs_dir = os.path.join(out_dir,img_name)
os.makedirs(export_imgs_dir,exist_ok=True)
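            # RRDBNet(3, 3, 64, 23, gc=32): the ESRGAN 4x generator (3-channel in/out, 64 features, 23 RRDB blocks)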
model = arch.RRDBNet(3, 3, 64, 23, gc=32)
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
model = model.to(device)
print('Model path {:s}. \nTesting...'.format(model_path))
            self.update_command('Starting super-resolution')
idx = 0
for img in os.listdir(imgs_dir):
idx += 1
path = os.path.join(imgs_dir,img)
full_name = os.path.split(path)[1]
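                # cv2 loads BGR uint8; scale to [0, 1], reorder BGR->RGB, and move channels first (CHW)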
img = cv2.imread(path, cv2.IMREAD_COLOR)
img = img * 1.0 / 255
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
img_LR = img.unsqueeze(0)
img_LR = img_LR.to(device)
with torch.no_grad():
output = model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
output = (output * 255.0).round()
# cv2.imwrite('results/{:s}_rlt.png'.format(base), output)
write_img_path = os.path.join(export_imgs_dir,full_name)
cv2.imwrite(write_img_path, output)
                self.update_command(f'generated: {write_img_path}')
            self.update_command('Image sequence complete!')
            if self.video_check.isChecked():
                self.update_command('Compositing video')
                sfps = self.fps_edit.text()
fps = int(sfps) if sfps else 30
video_path = os.path.join(out_dir,f'{img_name}.mp4')
imgs_dir = export_imgs_dir
imgs_list = os.listdir(imgs_dir)
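                # sort numerically (frames are named 1.jpg, 2.jpg, ...) so 10 does not sort before 2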
o_imgs_list = sorted(imgs_list, key=lambda x: int(os.path.splitext(x)[0]))
frames = []
for i in o_imgs_list:
img_path = os.path.join(imgs_dir, i)
source_image = imageio.imread(img_path)
frames.append(source_image)
imageio.mimsave(video_path, frames, fps=fps)
                self.update_command('Video export complete')
def _help(self):
        QMessageBox.about(None, 'Help',
                          '1. Under "driving video" pick a video; under "source image path" pick the image to animate; under "output directory" pick an output project folder.\n\n'
                          '2. The input video and image must both be square (equal width and height).\n'
                          '   Only uint8 image formats are supported for now; float32 formats such as exr and tif are not.\n'
                          '   Note: whatever the input size, the output is always 256x256, so high-resolution inputs gain nothing.\n\n'
                          '3. Super-resolution can enlarge the result to some extent; the default output is 1024x1024.\n\n'
                          '4. To render a video while running super-resolution, tick "Also export video", and be sure to set the frame rate.\n\n'
                          '5. The super-resolution image directory must point to a folder of sequence images (a clean directory containing nothing but images).\n\n'
                          '6. The super-resolution output directory must point to an output folder.\n\n'
                          '7. Generating the driven frames is faster than super-resolution; super-resolution takes longer (about 8 minutes for 210 frames on my machine).\n\n'
                          '8. Driving and super-resolution speed both depend on machine performance, especially the graphics card.'
                          )
if __name__ == '__main__':
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
app = QApplication(sys.argv)
p = FirstOrder()
p.show()
sys.exit(app.exec_())