1 Star 3 Fork 1

ANDY/Unet

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
predict.py 4.56 KB
一键复制 编辑 原始数据 按行查看 历史
import argparse
import logging
import os
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from unet import *
# from unet import UNet
from utils.data_vis import plot_img_and_mask
from utils.dataset import BasicDataset
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    """Run ``net`` on a single image and return a thresholded mask.

    Args:
        net: Segmentation model. It is put into eval mode and must expose
            an ``n_classes`` attribute (used to choose softmax vs. sigmoid).
        full_img: Image to segment; its ``.size`` is read as
            ``(width, height)`` (PIL convention) to size the output.
        device: Torch device the input tensor is moved to.
        scale_factor: Forwarded unchanged to ``BasicDataset.preprocess``.
        out_threshold: Probabilities strictly greater than this value are
            marked ``True`` in the result.

    Returns:
        NumPy boolean array produced by ``full_mask > out_threshold``.
    """
    net.eval()

    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    img = img.unsqueeze(0)  # add the batch dimension
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)  # drop the batch dimension again

        # Resize takes (height, width). Passing a single int only matches the
        # *shorter* edge to that value, which yields the wrong size for
        # non-square images, so both dimensions are given explicitly.
        width, height = full_img.size[0], full_img.size[1]
        tf = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize((height, width)),
                transforms.ToTensor()
            ]
        )

        probs = tf(probs.cpu())
        # ToTensor already yields a CPU tensor, so no further .cpu() is needed.
        full_mask = probs.squeeze().numpy()

    return full_mask > out_threshold
def get_args():
    """Parse the command-line options of the prediction script.

    Returns:
        argparse.Namespace with the attributes ``model``, ``input``,
        ``output``, ``viz``, ``no_save``, ``mask_threshold`` and ``scale``.
        ``input`` and ``output`` are always lists of filenames.
    """
    parser = argparse.ArgumentParser(description='Predict masks from input images',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--model', '-m', default='/home/andy/project/python/Pytorch-UNet/checkpoints/CP_epoch30.pth',
                        metavar='FILE',
                        help="Specify the file in which the model is stored")
    # nargs='+' produces a list, so the defaults must be lists as well;
    # a bare string default would later be iterated character by character.
    # (The --input default is inert while required=True, but is kept
    # list-typed for consistency.)
    parser.add_argument('--input', '-i',
                        default=["/home/andy/project/python/Pytorch-UNet/data/image/image011.jpg"],
                        metavar='INPUT', nargs='+',
                        help='filenames of input images', required=True)
    parser.add_argument('--output', '-o',
                        default=["/home/andy/project/python/Pytorch-UNet/data/out/predict.png"],
                        metavar='OUTPUT', nargs='+',
                        help='Filenames of ouput images')
    parser.add_argument('--viz', '-v', action='store_true',
                        help="Visualize the images as they are processed",
                        default=False)
    parser.add_argument('--no-save', '-n', action='store_true',
                        help="Do not save the output masks",
                        default=False)
    parser.add_argument('--mask-threshold', '-t', type=float,
                        help="Minimum probability value to consider a mask pixel white",
                        default=0.55)
    parser.add_argument('--scale', '-s', type=float,
                        help="Scale factor for the input images",
                        default=1.0)

    return parser.parse_args()
def get_output_filenames(args):
    """Return one output filename per input file.

    Args:
        args: Parsed arguments exposing ``input`` (list of input filenames)
            and ``output`` (list of output filenames, a single filename
            string, or a falsy value).

    Returns:
        List of output filenames. When ``args.output`` is falsy, each name is
        derived from its input as ``<stem>_OUT<ext>``.

    Raises:
        SystemExit: With exit status 1 when the number of explicit output
            names differs from the number of inputs.
    """
    in_files = args.input
    requested = args.output

    if not requested:
        out_files = []
        for f in in_files:
            stem, ext = os.path.splitext(f)
            out_files.append("{}_OUT{}".format(stem, ext))
        return out_files

    # A bare string is a single filename; without this its *characters*
    # would be counted and indexed as if they were separate files.
    if isinstance(requested, str):
        requested = [requested]

    if len(in_files) != len(requested):
        logging.error("Input files and output files are not of the same length")
        # Non-zero status so shells and callers can see the failure.
        raise SystemExit(1)

    return requested
def mask_to_image(mask):
    """Convert a mask array into a PIL image.

    The mask is multiplied by 255 and cast to ``uint8`` before being handed
    to ``Image.fromarray``.
    """
    pixels = (mask * 255).astype(np.uint8)
    return Image.fromarray(pixels)
if __name__ == "__main__":
    # The root logger defaults to WARNING; without this every logging.info
    # call below would be silently dropped.
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

    args = get_args()
    in_files = args.input
    out_files = get_output_filenames(args)

    # Alternative plain U-Net: net = UNet(n_channels=3, n_classes=1)
    # Do not call .cuda() here: it raises on CPU-only hosts before the
    # device fallback below gets a chance to run. net.to(device) handles it.
    net = get_efficientunet_b0(out_channels=1, concat_input=True, pretrained=False)

    logging.info("Loading model {}".format(args.model))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')
    net.to(device=device)
    net.load_state_dict(torch.load(args.model, map_location=device))

    logging.info("Model loaded !")

    # Size the images are resized to before being fed to the network.
    net_input_size = (224, 224)

    for i, fn in enumerate(in_files):
        logging.info("\nPredicting image {} ...".format(fn))

        img = Image.open(fn)
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        out = img.resize(net_input_size, Image.LANCZOS)

        mask = predict_img(net=net,
                           full_img=out,
                           scale_factor=args.scale,
                           out_threshold=args.mask_threshold,
                           device=device)

        if not args.no_save:
            result = mask_to_image(mask)
            # Scale the mask back to the source image's own (width, height)
            # instead of a size hard-coded for one particular dataset.
            out2 = result.resize(img.size, Image.LANCZOS)
            out2.save(out_files[i])

            logging.info("Mask saved to {}".format(out_files[i]))

        if args.viz:
            logging.info("Visualizing results for image {}, close to continue ...".format(fn))
            plot_img_and_mask(img, mask)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/andyFengce/unet.git
git@gitee.com:andyFengce/unet.git
andyFengce
unet
Unet
master

搜索帮助