代码拉取完成,页面将自动刷新
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @author: Wesley
# @time: 2020-12-11 10:47
import os
import cv2
import torch
from models.unet import UNet
from torchvision import transforms
import numpy as np
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(1, 1).to(device)
weight = r'E:\PyCharmProject\Road-Detection\weights\weight.pt'
if os.path.exists(weight):
net.load_state_dict(torch.load(weight))
img_path = 'src/img/1_sat.jpg'
mask_path = 'src/img/1_mask.png'
if __name__ == '__main__':
origin = cv2.imread(img_path, 1)
cv2.imshow('origin', origin)
tr = transforms.Compose([transforms.ToTensor()])
img = tr(origin).unsqueeze(0).to(device)
mask = tr(cv2.imread(mask_path, 0))
net.eval()
with torch.no_grad():
pred = net(img)
pred[pred >= 0.5] = 1
pred[pred < 0.5] = 0
TP = ((pred == 1) & (mask == 1)).sum()
TN = ((pred == 0) & (mask == 0)).sum()
FN = ((pred == 0) & (mask == 1)).sum()
FP = ((pred == 1) & (mask == 0)).sum()
pa = (TP + TN) / (TP + TN + FP + FN)
iou = TP / (TP + FP + FN)
print('pa: ', pa)
print('iou', iou)
cv2.imshow('origin_out', np.hstack([img, pred]))
cv2.waitKey(0)
cv2.destroyAllWindows()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。