1 Star 0 Fork 0

萧纵/MimicMotion

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
inference.py 5.88 KB
一键复制 编辑 原始数据 按行查看 历史
zyayoung 提交于 2024-07-09 18:45 . Optimize VRAM Usage
import os
import argparse
import logging
import math
from omegaconf import OmegaConf
from datetime import datetime
from pathlib import Path
import numpy as np
import torch.jit
from torchvision.datasets.folder import pil_loader
from torchvision.transforms.functional import pil_to_tensor, resize, center_crop
from torchvision.transforms.functional import to_pil_image
from mimicmotion.utils.geglu_patch import patch_geglu_inplace
patch_geglu_inplace()
from constants import ASPECT_RATIO
from mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipeline
from mimicmotion.utils.loader import create_pipeline
from mimicmotion.utils.utils import save_to_mp4
from mimicmotion.dwpose.preprocess import get_video_pose, get_image_pose
# Root logging configuration: INFO and above, timestamped messages.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s: [%(levelname)s] %(message)s",
)
logger = logging.getLogger(__name__)

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def preprocess(video_path, image_path, resolution=576, sample_stride=2):
    """Load the reference image, fit it to the target size and extract poses.

    Args:
        video_path (str): path of the video handed to ``get_video_pose``
        image_path (str): path of the reference image
        resolution (int, optional): length of one target side; used as the width
            when the image is taller than wide, otherwise as the height.
            Defaults to 576.
        sample_stride (int, optional): forwarded to ``get_video_pose``.
            Defaults to 2.

    Returns:
        tuple: ``(pose_pixels, image_pixels)`` tensors, both mapped through
        ``x / 127.5 - 1``. ``pose_pixels`` is the reference-image pose followed
        by the video poses; ``image_pixels`` is the cropped reference image with
        a leading axis of size 1 and channels moved to the second axis.
    """
    ref_image = pil_to_tensor(pil_loader(image_path))  # (c, h, w)
    src_h, src_w = ref_image.shape[-2:]

    # One target side is `resolution`; the other is derived from ASPECT_RATIO
    # and floored to a multiple of 64.
    derived_side = int(resolution / ASPECT_RATIO // 64) * 64
    if src_h > src_w:
        h_target, w_target = derived_side, resolution
    else:
        h_target, w_target = resolution, derived_side

    # Resize so the image covers the target box while keeping its own aspect
    # ratio, then center-crop the overflow.
    h_w_ratio = float(src_h) / float(src_w)
    if h_w_ratio < h_target / w_target:
        resized_hw = [h_target, math.ceil(h_target / h_w_ratio)]
    else:
        resized_hw = [math.ceil(w_target * h_w_ratio), w_target]
    ref_image = resize(ref_image, resized_hw, antialias=None)
    ref_image = center_crop(ref_image, [h_target, w_target])
    ref_image_hwc = ref_image.permute((1, 2, 0)).numpy()  # (h, w, c)

    # Pose of the reference image goes first, followed by the video poses.
    image_pose = get_image_pose(ref_image_hwc)
    video_pose = get_video_pose(video_path, ref_image_hwc, sample_stride=sample_stride)
    pose_pixels = np.concatenate([np.expand_dims(image_pose, 0), video_pose])

    # Add a leading axis and move channels in front of height/width.
    image_batch = np.expand_dims(ref_image_hwc, 0).transpose(0, 3, 1, 2)

    pose_tensor = torch.from_numpy(pose_pixels.copy()) / 127.5 - 1
    image_tensor = torch.from_numpy(image_batch) / 127.5 - 1
    return pose_tensor, image_tensor
def run_pipeline(pipeline: MimicMotionPipeline, image_pixels, pose_pixels, device, task_config):
    """Run the MimicMotion pipeline for one task and return the generated frames.

    Args:
        pipeline (MimicMotionPipeline): the pipeline object to call
        image_pixels (torch.Tensor): reference image batch scaled as ``x / 127.5 - 1``
        pose_pixels (torch.Tensor): pose frames; the first dimension is used as the
            frame count and the last two as height and width
        device: device used for the random generator and passed to the pipeline
        task_config: task settings; reads ``seed``, ``num_frames``,
            ``frames_overlap``, ``noise_aug_strength``, ``num_inference_steps``
            and ``guidance_scale``

    Returns:
        torch.Tensor: uint8 frames of the last video in the pipeline output,
        with its first frame removed.

    Raises:
        ValueError: if the pipeline output contains no video.
    """
    # Undo the ``x / 127.5 - 1`` scaling and convert each image to PIL.
    image_pixels = [to_pil_image(img.to(torch.uint8)) for img in (image_pixels + 1.0) * 127.5]

    generator = torch.Generator(device=device)
    generator.manual_seed(task_config.seed)

    frames = pipeline(
        image_pixels, image_pose=pose_pixels, num_frames=pose_pixels.size(0),
        tile_size=task_config.num_frames, tile_overlap=task_config.frames_overlap,
        height=pose_pixels.shape[-2], width=pose_pixels.shape[-1], fps=7,
        noise_aug_strength=task_config.noise_aug_strength, num_inference_steps=task_config.num_inference_steps,
        generator=generator, min_guidance_scale=task_config.guidance_scale,
        max_guidance_scale=task_config.guidance_scale, decode_chunk_size=8, output_type="pt", device=device
    ).frames.cpu()
    video_frames = (frames * 255.0).to(torch.uint8)

    if video_frames.shape[0] == 0:
        # Previously this surfaced as an UnboundLocalError on the return line.
        raise ValueError("pipeline returned no video frames")

    # The original loop overwrote its result on every iteration, so only the
    # last video was ever returned; index it directly. The first frame is
    # dropped because of the reference image.
    return video_frames[-1, 1:]
@torch.no_grad()
def main(args):
    """Run every test case of the inference config and save each result as mp4.

    Args:
        args: parsed command-line arguments; reads ``no_use_float16``,
            ``inference_config`` and ``output_dir``.
    """
    use_float16 = not args.no_use_float16
    if use_float16:
        torch.set_default_dtype(torch.float16)

    infer_config = OmegaConf.load(args.inference_config)
    pipeline = create_pipeline(infer_config, device)

    for task in infer_config.test_case:
        # Pre-process the reference image and the reference video.
        pose_pixels, image_pixels = preprocess(
            task.ref_video_path,
            task.ref_image_path,
            resolution=task.resolution,
            sample_stride=task.sample_stride,
        )

        # Run the MimicMotion pipeline.
        generated_frames = run_pipeline(pipeline, image_pixels, pose_pixels, device, task)

        # Save the result as "<output_dir>/<video name up to first dot>_<timestamp>.mp4".
        video_stem = os.path.basename(task.ref_video_path).split('.')[0]
        timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
        output_path = f"{args.output_dir}/{video_stem}_{timestamp}.mp4"
        save_to_mp4(generated_frames, output_path, fps=task.fps)
def set_logger(log_file=None, log_level=logging.INFO):
    """Attach a file handler to this module's logger.

    Args:
        log_file (str, optional): path of the log file; it is opened in "w"
            mode. When None, no file handler is added. Defaults to None.
        log_level (int, optional): level set on the file handler.
            Defaults to logging.INFO.
    """
    if log_file is None:
        # logging.FileHandler(None, ...) raises TypeError, so the documented
        # default has to be handled before constructing the handler.
        return

    log_handler = logging.FileHandler(log_file, "w")
    log_handler.setFormatter(
        logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s]: %(message)s")
    )
    log_handler.setLevel(log_level)
    # getLogger returns the same object for the same name, so this is the
    # logger the rest of this module logs through.
    logging.getLogger(__name__).addHandler(log_handler)
if __name__ == "__main__":
    # Command-line entry point: parse options, set up file logging, run inference.
    parser = argparse.ArgumentParser()
    parser.add_argument("--log_file", type=str, default=None)
    parser.add_argument("--inference_config", type=str, default="configs/test.yaml")  # ToDo
    parser.add_argument("--output_dir", type=str, default="outputs/", help="path to output")
    parser.add_argument(
        "--no_use_float16",
        action="store_true",
        # The flag turns float16 OFF; the previous help text described the opposite.
        help="Disable float16 inference (float16 is used by default to speed up inference)",
    )
    args = parser.parse_args()

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Without --log_file, write a timestamped log into the output directory.
    log_file = args.log_file
    if log_file is None:
        log_file = f"{args.output_dir}/{datetime.now().strftime('%Y%m%d%H%M%S')}.log"
    set_logger(log_file)

    main(args)
    logger.info("--- Finished ---")
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/ccst/mimic-motion.git
git@gitee.com:ccst/mimic-motion.git
ccst
mimic-motion
MimicMotion
main

搜索帮助