1 Star 1 Fork 0

wwhio/KAIR

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
main_test_dncnn3_deblocking.py 4.36 KB
一键复制 编辑 原始数据 按行查看 历史
Kai Zhang 提交于 2020-05-21 21:20 . Add DnCNN3 for JPEG image deblocking
import os.path
import logging
import numpy as np
from datetime import datetime
from collections import OrderedDict
import torch
from utils import utils_logger
from utils import utils_model
from utils import utils_image as util
#import os
#os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
'''
Spyder (Python 3.6)
PyTorch 1.1.0
Windows 10 or Linux
Kai Zhang (cskaizhang@gmail.com)
github: https://github.com/cszn/KAIR
https://github.com/cszn/DnCNN
@article{zhang2017beyond,
title={Beyond a gaussian denoiser: Residual learning of deep cnn for image denoising},
author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
journal={IEEE Transactions on Image Processing},
volume={26},
number={7},
pages={3142--3155},
year={2017},
publisher={IEEE}
}
% If you have any questions, please feel free to contact me.
% Kai Zhang (e-mail: cskaizhang@gmail.com; github: https://github.com/cszn)
by Kai Zhang (12/Dec./2019)
'''
"""
# --------------------------------------------
|--model_zoo             # model_zoo
   |--dncnn3             # model_name
|--testsets              # testsets
   |--set12              # testset_name
   |--bsd68
|--results               # results
   |--set12_dncnn3       # result_name = testset_name + '_' + model_name
# --------------------------------------------
"""
def main():
    """Run DnCNN3 JPEG deblocking over every image in a test set.

    Reads low-quality images from ``testsets/<testset_name>``, restores them
    with the pretrained network stored at ``model_zoo/<model_name>.pth`` and
    writes each result as a PNG file to ``results/<testset_name>_<model_name>``.
    A log file named after the result folder is written into that folder.

    The network is single-channel. For color input (``n_channels = 3``) the
    image is converted to YCbCr, only the Y channel is restored, and the
    original Cb/Cr channels are reused when converting back to RGB.
    """
    # ----------------------------------------
    # Preparation
    # ----------------------------------------
    model_name = 'dncnn3'  # 'dncnn3'- can be used for blind Gaussian denoising, JPEG deblocking (quality factor 5-100) and super-resolution (x234)

    # important!
    testset_name = 'bsd68'  # test set, low-quality grayscale/color JPEG images
    n_channels = 1          # set 1 for grayscale image, set 3 for color image
    x8 = False              # default: False, x8 to boost performance

    testsets = 'testsets'   # fixed
    results = 'results'     # fixed
    result_name = testset_name + '_' + model_name  # fixed
    L_path = os.path.join(testsets, testset_name)  # L_path, for Low-quality grayscale/Y-channel JPEG images
    E_path = os.path.join(results, result_name)    # E_path, for Estimated images
    util.mkdir(E_path)

    model_pool = 'model_zoo'  # fixed
    model_path = os.path.join(model_pool, model_name + '.pth')
    logger_name = result_name
    utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------------------------------
    # load model
    # ----------------------------------------
    from models.network_dncnn import DnCNN as net
    # Single input/output channel: grayscale images or the Y channel of YCbCr.
    model = net(in_nc=1, out_nc=1, nc=64, nb=20, act_mode='R')
    # Load onto the CPU first so a checkpoint saved from a GPU session still
    # loads on a CPU-only machine; the model is moved to `device` below.
    model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)
    logger.info('Model path: {:s}'.format(model_path))
    number_parameters = sum(p.numel() for p in model.parameters())
    logger.info('Params number: {}'.format(number_parameters))

    logger.info(L_path)
    L_paths = util.get_image_paths(L_path)

    for idx, img in enumerate(L_paths):

        # ------------------------------------
        # (1) img_L
        # ------------------------------------
        img_name, ext = os.path.splitext(os.path.basename(img))
        logger.info('{:->4d}--> {:>10s}'.format(idx + 1, img_name + ext))
        img_L = util.imread_uint(img, n_channels=n_channels)
        img_L = util.uint2single(img_L)
        if n_channels == 3:
            # Keep the full YCbCr image so Cb/Cr can be restored later;
            # only the Y channel is fed to the network.
            ycbcr = util.rgb2ycbcr(img_L, False)
            img_L = ycbcr[..., 0:1]
        img_L = util.single2tensor4(img_L)
        img_L = img_L.to(device)

        # ------------------------------------
        # (2) img_E
        # ------------------------------------
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            if not x8:
                img_E = model(img_L)
            else:
                img_E = utils_model.test_mode(model, img_L, mode=3)
        img_E = util.tensor2single(img_E)
        if n_channels == 3:
            # Replace Y with the restored channel and convert back to RGB.
            ycbcr[..., 0] = img_E
            img_E = util.ycbcr2rgb(ycbcr)
        img_E = util.single2uint(img_E)

        # ------------------------------------
        # save results
        # ------------------------------------
        util.imsave(img_E, os.path.join(E_path, img_name + '.png'))


if __name__ == '__main__':
    main()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/zfr9b/KAIR.git
git@gitee.com:zfr9b/KAIR.git
zfr9b
KAIR
KAIR
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385