1 Star 0 Fork 2

zzzhangys/PointRend-PyTorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
infer.py 703 Bytes
一键复制 编辑 原始数据 按行查看 历史
小荷才露尖尖角 提交于 2020-02-06 22:19 . # update for ourself
import torch
import logging
def iou_pytorch(outputs, labels, eps=1e-6):
outputs = outputs.squeeze(1)
intersection = (outputs & labels).float().sum((1, 2))
union = (outputs | labels).float().sum((1, 2))
iou = (intersection + eps) / (union + eps)
return iou
@torch.no_grad()
def infer(device, loader, net):
net.eval()
mIoU = 0
for i, (x, gt) in enumerate(loader):
x = x.to(device, non_blocking=True)
gt = gt.squeeze(1).to(device, dtype=torch.long, non_blocking=True)
pred = net(x)["fine"]
mIoU += iou_pytorch(pred, gt).mean()
mIoU = (mIoU / len(loader.dataset)).item()
logging.info(f"[Infer] mIOU : {mIoU}")
return mIoU
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/zhangys0425/PointRend-PyTorch.git
git@gitee.com:zhangys0425/PointRend-PyTorch.git
zhangys0425
PointRend-PyTorch
PointRend-PyTorch
master

搜索帮助

23e8dbc6 1850385 7e0993f3 1850385