1 Star 0 Fork 0

xiajw06/VDNet

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
utils.py 5.57 KB
一键复制 编辑 原始数据 按行查看 历史
zsyOAOA 提交于 2020-01-10 10:42 . Add the kai_ming normalization
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2019-01-22 22:07:08
import torch
import torch.nn as nn
from torch.autograd import Function as autoF
from scipy.special import gammaln
from skimage.measure import compare_psnr, compare_ssim
from skimage import img_as_ubyte
import numpy as np
import sys
from math import floor
def ssim_index(im1, im2):
    '''
    Compute the SSIM index between two images.
    Input:
        im1, im2: np.uint8 format, either H x W (grayscale) or H x W x C
    '''
    ssim_opts = dict(data_range=255, gaussian_weights=True,
                     use_sample_covariance=False)
    if im1.ndim == 2:
        return compare_ssim(im1, im2, multichannel=False, **ssim_opts)
    if im1.ndim == 3:
        return compare_ssim(im1, im2, multichannel=True, **ssim_opts)
    # Any other rank is unsupported.
    sys.exit('Please input the corrected images')
def im2patch(im, pch_size, stride=1):
    '''
    Transform an image into an array of (overlapping) patches.

    Input:
        im: C x H x W image (e.g. 3 x H x W or 1 x H x W), numpy format.
            A plain H x W array is also accepted and treated as 1 x H x W.
        pch_size: (int, int) tuple or integer, patch height/width
        stride: (int, int) tuple or integer, sliding step along H/W
    Output:
        pch: C x pch_H x pch_W x num_pch array; patch index runs row-major
             over the patch grid (h * num_W + w).
    '''
    if isinstance(pch_size, tuple):
        pch_H, pch_W = pch_size
    elif isinstance(pch_size, int):
        pch_H = pch_W = pch_size
    else:
        sys.exit('The input of pch_size must be a integer or a int tuple!')

    if isinstance(stride, tuple):
        stride_H, stride_W = stride
    elif isinstance(stride, int):
        stride_H = stride_W = stride
    else:
        sys.exit('The input of stride must be a integer or a int tuple!')

    # Generalization: accept a 2-D grayscale image by adding a channel axis.
    if im.ndim == 2:
        im = im[np.newaxis, :, :]

    C, H, W = im.shape
    # Number of valid patch positions along each axis.
    num_H = len(range(0, H-pch_H+1, stride_H))
    num_W = len(range(0, W-pch_W+1, stride_W))
    num_pch = num_H * num_W
    pch = np.zeros((C, pch_H*pch_W, num_pch), dtype=im.dtype)
    kk = 0
    # For each in-patch offset (ii, jj), gather that pixel from every patch
    # at once with a strided slice — loops over the small patch extent, not
    # over the (usually much larger) number of patches.
    for ii in range(pch_H):
        for jj in range(pch_W):
            temp = im[:, ii:H-pch_H+ii+1:stride_H, jj:W-pch_W+jj+1:stride_W]
            pch[:, kk, :] = temp.reshape((C, num_pch))
            kk += 1
    return pch.reshape((C, pch_H, pch_W, num_pch))
def batch_PSNR(img, imclean):
    '''Average PSNR over a batch of N x C x H x W tensors (converted to uint8).'''
    img_np = img_as_ubyte(img.data.cpu().numpy())
    clean_np = img_as_ubyte(imclean.data.cpu().numpy())
    batch = img_np.shape[0]
    total = sum(compare_psnr(clean_np[k], img_np[k], data_range=255)
                for k in range(batch))
    return total / batch
def batch_SSIM(img, imclean):
    '''Average SSIM over a batch of N x C x H x W tensors (converted to uint8).'''
    img_np = img_as_ubyte(img.data.cpu().numpy())
    clean_np = img_as_ubyte(imclean.data.cpu().numpy())
    batch = img_np.shape[0]
    total = 0
    # ssim_index expects H x W x C, so move the channel axis last.
    for k in range(batch):
        total += ssim_index(clean_np[k].transpose((1, 2, 0)),
                            img_np[k].transpose((1, 2, 0)))
    return total / batch
def peaks(n):
    '''
    Python port of MATLAB's ``peaks`` function, sampled on an
    n x n uniform grid over [-3, 3] x [-3, 3].
    '''
    axis = np.linspace(-3, 3, n)
    XX, YY = np.meshgrid(axis, axis)
    bump = 3 * (1 - XX)**2 * np.exp(-XX**2 - (YY + 1)**2)
    ripple = 10 * (XX / 5.0 - XX**3 - YY**5) * np.exp(-XX**2 - YY**2)
    tail = 1 / 3.0 * np.exp(-(XX + 1)**2 - YY**2)
    return bump - ripple - tail
def generate_gauss_kernel_mix(H, W):
    '''
    Generate an H x W map that averages K isotropic Gaussian densities,
    one per 32 x 32 cell of the image: each component gets a uniformly
    random center inside its cell and a random std in [16, 32).
    Input:
        H, W: integer map size (cells below 32 pixels are dropped)
    '''
    pch_size = 32
    K_H = floor(H / pch_size)
    K_W = floor(W / pch_size)
    K = K_H * K_W

    # Random x-center per cell, shifted by the cell's column origin.
    centerW = np.random.uniform(low=0, high=pch_size, size=(K_H, K_W))
    centerW += (np.arange(K_W) * pch_size).reshape((1, -1))
    centerW = centerW.reshape((1, 1, K)).astype(np.float32)

    # Random y-center per cell, shifted by the cell's row origin.
    centerH = np.random.uniform(low=0, high=pch_size, size=(K_H, K_W))
    centerH += (np.arange(K_H) * pch_size).reshape((-1, 1))
    centerH = centerH.reshape((1, 1, K)).astype(np.float32)

    scale = np.random.uniform(low=pch_size/2, high=pch_size, size=(1, 1, K))
    scale = scale.astype(np.float32)

    # Pixel-coordinate grids, broadcast against the K components.
    XX, YY = np.meshgrid(np.arange(0, W), np.arange(0, H))
    XX = XX[:, :, np.newaxis].astype(np.float32)
    YY = YY[:, :, np.newaxis].astype(np.float32)

    # 2-D Gaussian density for every component, then average over K.
    ZZ = 1./(2*np.pi*scale**2) * np.exp( (-(XX-centerW)**2-(YY-centerH)**2)/(2*scale**2) )
    return ZZ.sum(axis=2, keepdims=False) / K
def sincos_kernel():
    '''Deterministic 256 x 256 kernel: sin over x in [1, 10] plus cos over y in [1, 20]. (NIPS version)'''
    xx, yy = np.meshgrid(np.linspace(1, 10, 256), np.linspace(1, 20, 256))
    return np.sin(xx) + np.cos(yy)
def capacity_cal(net):
    '''Return the parameter footprint of ``net`` in megabytes, assuming 4 bytes per parameter.'''
    return sum(p.numel() * 4 / 1024 / 1024 for p in net.parameters())
class LogGamma(autoF):
    '''
    Autograd-compatible logarithm of the Gamma function.

    Forward evaluates SciPy's ``gammaln`` on the CPU (round-tripping through
    numpy); backward uses ``torch.digamma``, the analytic derivative of
    log-gamma, so gradients flow on the original device.
    '''
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # gammaln only runs on numpy, so detach and (if needed) move to CPU.
        if input.is_cuda:
            x_np = input.detach().cpu().numpy()
        else:
            x_np = input.detach().numpy()
        result = torch.from_numpy(gammaln(x_np))
        # Restore the caller's device and dtype.
        return result.to(device=input.device).type(dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # d/dx log(Gamma(x)) = digamma(x)
        return torch.digamma(x) * grad_output
def load_state_dict_cpu(net, state_dict0):
    '''
    Load a checkpoint saved from a ``DataParallel`` model (whose keys carry
    a "module." prefix) into a plain single-device ``net``.
    '''
    new_state = net.state_dict()
    for key in new_state.keys():
        prefixed = 'module.' + key
        assert prefixed in state_dict0
        new_state[key] = state_dict0[prefixed]
    net.load_state_dict(new_state)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/xiajw06/VDNet.git
git@gitee.com:xiajw06/VDNet.git
xiajw06
VDNet
VDNet
master

搜索帮助