1 Star 0 Fork 0

lipengfeiSUaz/MIMO-UNet

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
valid.py 1.14 KB
一键复制 编辑 原始数据 按行查看 历史
Sungjin 提交于 2021-08-20 14:25 .
import torch
from torchvision.transforms import functional as F
from data import valid_dataloader
from utils import Adder
import os
from skimage.metrics import peak_signal_noise_ratio
def _valid(model, args, ep):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gopro = valid_dataloader(args.data_dir, batch_size=1, num_workers=0)
model.eval()
psnr_adder = Adder()
with torch.no_grad():
print('Start GoPro Evaluation')
for idx, data in enumerate(gopro):
input_img, label_img = data
input_img = input_img.to(device)
if not os.path.exists(os.path.join(args.result_dir, '%d' % (ep))):
os.mkdir(os.path.join(args.result_dir, '%d' % (ep)))
pred = model(input_img)
pred_clip = torch.clamp(pred[2], 0, 1)
p_numpy = pred_clip.squeeze(0).cpu().numpy()
label_numpy = label_img.squeeze(0).cpu().numpy()
psnr = peak_signal_noise_ratio(p_numpy, label_numpy, data_range=1)
psnr_adder(psnr)
print('\r%03d'%idx, end=' ')
print('\n')
model.train()
return psnr_adder.average()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/lipengfeiSUaz/MIMO-UNet.git
git@gitee.com:lipengfeiSUaz/MIMO-UNet.git
lipengfeiSUaz
MIMO-UNet
MIMO-UNet
main

搜索帮助