2 Star 59 Fork 20

天涯/RobustVideoMattingGUI

Create your Gitee Account
Explore and code with more than 12 million developers,Free private repositories !:)
Sign up
文件
Clone or Download
cutie_matting.py 3.40 KB
Copy Edit Raw Blame History
天涯 authored 2024-04-07 21:24 . 修复BUG
'''
file name: cutie_matting.py
Date: 2024-03-18 16:32:16
LastEditors: YuanMing
LastEditTime: 2024-03-22 17:22:49
Describe: Propagate an initial object mask across a sequence of video frames with the CUTIE model.
'''
import functools

import numpy as np
import torch
from hydra import compose, initialize

from CUTIE.cutie.model.cutie import CUTIE
from CUTIE.cutie.inference.inference_core import InferenceCore
from CUTIE.scripts.download_models import download_models_if_needed

@torch.inference_mode()
@torch.cuda.amp.autocast()
def cutie_matting(images,mask,isnegative=1):
if torch.cuda.is_available():
device = 'cuda'
elif torch.backends.mps.is_available():
device = 'mps'
else:
device = 'cpu'
masks=[]
download_models_if_needed()
initialize(version_base='1.3.2', config_path="cutie/config", job_name="gui")
cfg = compose(config_name="gui_config")
cutie = CUTIE(cfg).eval().to(device)
model_weights = torch.load(cfg.weights, map_location=device)
cutie.load_weights(model_weights)
# cutie = get_default_model()
processor = InferenceCore(cutie, cfg=cutie.cfg)
# image_path = './examples/images/bike'
# images = sorted(os.listdir(image_path)) # ordering is important
# mask = Image.open('./examples/masks/bike/00000.png')
# palette = mask.getpalette()
# objects = np.unique(np.array(mask))
# objects = objects[objects != 0].tolist() # background "0" does not count as an object
# mask = torch.from_numpy(np.array(mask)).cuda()
# for ti, image_name in enumerate(images):
# image = Image.open(os.path.join(image_path, image_name))
# image = to_tensor(image).cuda().float()
# if ti == 0:
# output_prob = processor.step(image, mask, objects=objects)
# else:
# output_prob = processor.step(image)
# # convert output probabilities to an object mask
# mask = processor.output_prob_to_mask(output_prob)
# # visualize prediction
# mask = Image.fromarray(mask.cpu().numpy().astype(np.uint8))
# mask.putpalette(palette)
# mask.show() # or use mask.save(...) to save it somewhere
objects = np.unique(np.array(mask))
objects = objects[objects != 0].tolist() # background "0" does not count as an object
palette = mask.getpalette()
if isnegative==1:
for i in range(0,len(images),1):
image=images[i]
if i==0:
output_prob = processor.step(image, mask, objects=objects)
else:
output_prob = processor.step(image)
pass
mask = processor.output_prob_to_mask(output_prob)
# visualize prediction
mask = image.fromarray(mask.cpu().numpy().astype(np.uint8))
mask.putpalette(palette)
masks.append(mask)
pass
# mask.show() # or use mask.save(...) to save it somewhere
pass
else:
for i in range(len(images),0,-1):
image=images[i]
if i==0:
output_prob = processor.step(image, mask, objects=objects)
else:
output_prob = processor.step(image)
pass
mask = processor.output_prob_to_mask(output_prob)
# visualize prediction
mask = image.fromarray(mask.cpu().numpy().astype(np.uint8))
mask.putpalette(palette)
masks.append(mask)
pass
# mask.show() # or use mask.save(...) to save it somewhere
pass
pass
pass
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/ymfjly/RobustVideoMatting.git
git@gitee.com:ymfjly/RobustVideoMatting.git
ymfjly
RobustVideoMatting
RobustVideoMattingGUI
RVMGUI.SCTOOL

Search