# Fetch the repository succeeded.
'''
file name: File name
Date: 2024-03-18 16:32:16
LastEditors: YuanMing
LastEditTime: 2024-03-22 17:22:49
Describe: Propagate an annotated object mask through a sequence of video frames with the CUTIE model.
'''
import torch
import numpy as np
from hydra import compose, initialize
from CUTIE.cutie.model.cutie import CUTIE
from CUTIE.cutie.inference.inference_core import InferenceCore
from CUTIE.scripts.download_models import download_models_if_needed
def _frame_to_tensor(frame, device):
    """Return *frame* as a float32 tensor on *device*.

    Tensors are only moved and cast; they are assumed to already be 3xHxW RGB
    (TODO confirm with callers). Any other frame is copied through
    ``np.array``: a 3-D result is treated as HxWxC and permuted to CxHxW, and
    uint8 data is scaled to [0, 1]. This mirrors the ``to_tensor(image)``
    preprocessing in CUTIE's scripting example.
    """
    if isinstance(frame, torch.Tensor):
        return frame.to(device).float()
    # np.array (not np.asarray) gives a writable copy, which torch.from_numpy
    # requires to avoid sharing read-only buffers.
    array = np.array(frame)
    tensor = torch.from_numpy(array)
    if tensor.ndim == 3:
        tensor = tensor.permute(2, 0, 1)
    tensor = tensor.to(device).float()
    if array.dtype == np.uint8:
        tensor = tensor / 255
    return tensor


def _index_mask_like(template, index_mask):
    """Return a copy of PIL image *template* filled with *index_mask*'s pixels.

    Copying the template keeps its mode and palette, so the result can be
    saved as an indexed PNG without this module importing PIL. The template
    is assumed to be a single-byte-per-pixel ("P" or "L") image.

    Raises:
        ValueError: if *index_mask* is not the template's height x width.
    """
    pixels = np.ascontiguousarray(index_mask.cpu().numpy().astype(np.uint8))
    expected = (template.height, template.width)
    if pixels.shape != expected:
        raise ValueError(
            f"predicted mask shape {pixels.shape} does not match the "
            f"annotation size {expected}"
        )
    result = template.copy()
    result.frombytes(pixels.tobytes())
    return result


@torch.inference_mode()
@torch.cuda.amp.autocast()
def cutie_matting(images, mask, isnegative=1):
    """Propagate an annotated object mask through a sequence of frames with CUTIE.

    Args:
        images: Indexable sequence of frames. Tensors are assumed to be
            3xHxW RGB. Other frames (e.g. PIL images) are converted with
            ``np.array`` and treated as HxWx3 (TODO confirm with callers).
        mask: PIL image holding the annotation for the first frame processed.
            Pixel value 0 is background and every other value is an object
            id. Its mode and palette are reused for the returned masks.
        isnegative: ``1`` processes ``images`` front to back. Any other value
            processes them back to front. The annotation always seeds the
            first frame processed: ``images[0]`` going forward and
            ``images[-1]`` going backward.

    Returns:
        list: One PIL mask image per frame, in processing order. The list is
        therefore reversed relative to ``images`` when ``isnegative != 1``.
    """
    if len(images) == 0:
        return []

    if torch.cuda.is_available():
        device = 'cuda'
    elif torch.backends.mps.is_available():
        device = 'mps'
    else:
        device = 'cpu'

    download_models_if_needed()
    # Use initialize() as a context manager so Hydra's global state is cleared
    # on exit. A bare initialize() makes every later call to this function
    # fail with "GlobalHydra is already initialized".
    with initialize(version_base='1.3.2', config_path="cutie/config", job_name="gui"):
        cfg = compose(config_name="gui_config")
    cutie = CUTIE(cfg).eval().to(device)
    model_weights = torch.load(cfg.weights, map_location=device)
    cutie.load_weights(model_weights)
    processor = InferenceCore(cutie, cfg=cutie.cfg)

    mask_array = np.array(mask)
    objects = np.unique(mask_array)
    objects = objects[objects != 0].tolist()  # background "0" does not count as an object
    # processor.step expects the annotation as a tensor on the model's device.
    mask_tensor = torch.from_numpy(mask_array).to(device)

    if isnegative == 1:
        order = range(len(images))
    else:
        order = range(len(images) - 1, -1, -1)

    masks = []
    for step_index, frame_index in enumerate(order):
        image = _frame_to_tensor(images[frame_index], device)
        if step_index == 0:
            # NOTE(review): the annotation seeds whichever frame is processed
            # first. In the original code the reverse loop never reached this
            # branch, so confirm that callers pass a last-frame mask when
            # isnegative != 1.
            output_prob = processor.step(image, mask_tensor, objects=objects)
        else:
            output_prob = processor.step(image)
        predicted = processor.output_prob_to_mask(output_prob)
        masks.append(_index_mask_like(mask, predicted))
    return masks
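# Some content here may be unsuitable for display, so the hosting page has hidden it. You can review and edit it with the editing tools.
# If you confirm that the content contains no inappropriate language, pure advertising or traffic redirection, violence, vulgar or pornographic material, infringement, piracy, false or worthless content, or anything that violates national laws and regulations, you can click Submit to file an appeal, and it will be handled as soon as possible.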