孙志强/Pytorch-UNet

evaluate.py 1.54 KB
Arka committed on 2021-10-24 17:07: bug fixes for i/o and low val set
import torch
import torch.nn.functional as F
from tqdm import tqdm

from utils.dice_score import multiclass_dice_coeff, dice_coeff


def evaluate(net, dataloader, device):
    """Return the mean Dice score of `net` over the validation `dataloader`."""
    net.eval()
    num_val_batches = len(dataloader)
    dice_score = 0

    # iterate over the validation set
    for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
        image, mask_true = batch['image'], batch['mask']
        # move images and labels to correct device and type
        image = image.to(device=device, dtype=torch.float32)
        mask_true = mask_true.to(device=device, dtype=torch.long)

        with torch.no_grad():
            # predict the mask
            mask_pred = net(image)

            if net.n_classes == 1:
                # binary case: F.one_hot with a single class rejects label 1,
                # so add a channel dimension to the 0/1 target instead
                mask_true = mask_true.unsqueeze(1).float()
                # torch.sigmoid replaces the deprecated F.sigmoid
                mask_pred = (torch.sigmoid(mask_pred) > 0.5).float()
                # compute the Dice score
                dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
            else:
                # convert target and prediction to one-hot format
                mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
                mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
                # compute the Dice score, ignoring background
                dice_score += multiclass_dice_coeff(mask_pred[:, 1:, ...], mask_true[:, 1:, ...], reduce_batch_first=False)

    # restore training mode before returning to the training loop
    net.train()

    # Fixes a potential division by zero error
    if num_val_batches == 0:
        return dice_score
    return dice_score / num_val_batches
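
The file imports `dice_coeff` from `utils.dice_score` but that module is not shown here. As a rough illustration of the metric being averaged, the following is a minimal pure-Python sketch of a Dice coefficient on flat binary masks. It is not the repository's implementation: the names `dice_coefficient` and `mean_dice`, the `epsilon` smoothing term, and the flat-list input format are all assumptions made for this example.

```python
def dice_coefficient(pred, target, epsilon=1e-6):
    """Dice = 2 * |A and B| / (|A| + |B|) for two flat binary masks (hypothetical helper)."""
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of pixels")
    # overlap: pixels that are 1 in both masks
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    # epsilon (an assumed smoothing term) keeps the ratio defined when both masks are empty
    return (2 * intersection + epsilon) / (total + epsilon)


def mean_dice(batches):
    """Average per-batch Dice scores, mirroring the accumulation pattern in evaluate()."""
    if not batches:
        # same guard as evaluate(): avoid dividing by zero on an empty set
        return 0.0
    return sum(dice_coefficient(p, t) for p, t in batches) / len(batches)
```

For example, `dice_coefficient([1, 1, 0, 0], [1, 0, 0, 0])` is close to 2/3, since one pixel overlaps and the two masks contain three foreground pixels in total.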
https://gitee.com/sunzhiq99/Pytorch-UNet.git