代码拉取完成,页面将自动刷新
import torch.optim as optim
import torch.optim.lr_scheduler as lrs
import torch
from torch.autograd import Variable
import torch.nn as nn
class MeanShift(nn.Conv2d):
    """Constant 1x1 convolution that rescales and shifts three channels.

    For every channel ``c`` the layer computes
    ``x[c] / rgb_std[c] + sign * rgb_range * rgb_mean[c] / rgb_std[c]``.
    The weight and bias are fixed constants and are excluded from training.

    Args:
        rgb_range: Factor multiplied with ``rgb_mean`` in the bias term.
        rgb_mean: Sequence of three per-channel mean values.
        rgb_std: Sequence of three per-channel standard deviations.
        sign: Factor applied to the bias term (default ``-1``).
    """

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        # Identity mixing between channels, scaled by 1 / std per output channel.
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        # Freeze the parameters themselves: assigning ``requires_grad`` on the
        # module only creates a plain attribute and leaves weight/bias
        # trainable, so an optimizer filtering on ``requires_grad`` would
        # still update them.
        for param in self.parameters():
            param.requires_grad = False
        # Kept so that code reading this module-level flag keeps working.
        self.requires_grad = False
def make_optimizer(args, my_model):
    """Build the optimizer selected by ``args.optimizer`` for ``my_model``.

    Only parameters whose ``requires_grad`` flag is set are handed to the
    optimizer. The learning rate is taken from ``args.lr``.

    Args:
        args: Object exposing ``optimizer`` (one of ``'SGD'``, ``'ADAM'``,
            ``'ADAMax'``, ``'RMSprop'``) and ``lr``.
        my_model: Module whose trainable parameters are optimized.

    Returns:
        The constructed ``torch.optim`` optimizer instance.

    Raises:
        ValueError: If ``args.optimizer`` is not one of the supported names.
    """
    # Optimizer class and its fixed hyper-parameters, keyed by option name.
    optimizer_table = {
        'SGD': (optim.SGD, {'momentum': 0.9}),
        'ADAM': (optim.Adam, {'betas': (0.9, 0.9999), 'eps': 1e-08}),
        'ADAMax': (optim.Adamax, {'betas': (0.9, 0.999), 'eps': 1e-08}),
        'RMSprop': (optim.RMSprop, {'eps': 1e-08}),
    }
    if args.optimizer not in optimizer_table:
        # Previously an unknown name fell through and failed later with an
        # UnboundLocalError; fail early with an actionable message instead.
        raise ValueError(
            'Unsupported optimizer: {!r} (expected one of {})'.format(
                args.optimizer, ', '.join(sorted(optimizer_table))
            )
        )

    optimizer_function, defaults = optimizer_table[args.optimizer]
    # Copy the defaults so the table entry is never mutated.
    kwargs = dict(defaults, lr=args.lr)
    # NOTE: weight decay is not forwarded to the optimizer here.

    trainable = [p for p in my_model.parameters() if p.requires_grad]
    return optimizer_function(trainable, **kwargs)
def make_scheduler(args, my_optimizer):
    """Create the learning-rate scheduler described by ``args.decay_type``.

    * ``'step'`` builds a ``StepLR`` with period ``args.lr_decay``.
    * Any other value containing ``'step'`` is split on ``'_'``; the first
      token is dropped and the remaining tokens are parsed as integer
      milestones for a ``MultiStepLR`` (e.g. ``'step_100_200'``).

    Args:
        args: Object exposing ``decay_type``, ``gamma`` and, for the plain
            ``'step'`` mode, ``lr_decay``.
        my_optimizer: Optimizer whose learning rate is scheduled.

    Returns:
        The constructed ``torch.optim.lr_scheduler`` instance.

    Raises:
        ValueError: If ``args.decay_type`` does not contain ``'step'``, or if
            a milestone token is not an integer.
    """
    decay_type = args.decay_type

    if decay_type == 'step':
        return lrs.StepLR(
            my_optimizer,
            step_size=args.lr_decay,
            gamma=args.gamma,
        )

    if 'step' in decay_type:
        # Everything after the first '_'-separated token is a milestone.
        milestones = [int(token) for token in decay_type.split('_')[1:]]
        return lrs.MultiStepLR(
            my_optimizer,
            milestones=milestones,
            gamma=args.gamma,
        )

    # Previously this case fell through and failed with an UnboundLocalError.
    raise ValueError('Unsupported decay_type: {!r}'.format(decay_type))
def CharbonnierFunc(data, epsilon=0.001):
    """Return the mean of ``sqrt(data ** 2 + epsilon ** 2)`` over all elements.

    Args:
        data: Tensor of values to penalize.
        epsilon: Constant added (squared) under the square root.

    Returns:
        A scalar tensor holding the averaged penalty.
    """
    squared = data ** 2
    penalty = torch.sqrt(squared + epsilon ** 2)
    return torch.mean(penalty)
class Module_CharbonnierLoss(nn.Module):
    """Loss module averaging ``sqrt((output - gt) ** 2 + epsilon ** 2)``.

    Args:
        epsilon: Constant added (squared) under the square root.
    """

    def __init__(self, epsilon=0.001):
        super(Module_CharbonnierLoss, self).__init__()
        self.epsilon = epsilon

    def forward(self, output, gt):
        """Return the scalar loss between ``output`` and ``gt``."""
        residual = output - gt
        per_element = torch.sqrt(residual ** 2 + self.epsilon ** 2)
        return torch.mean(per_element)
def to_variable(x):
    """Wrap ``x`` in a ``Variable``, moving it to CUDA first when available.

    Args:
        x: Tensor to wrap.

    Returns:
        ``Variable(x)``, where ``x`` has been transferred with ``.cuda()`` if
        ``torch.cuda.is_available()`` reports a usable device.
    """
    placed = x.cuda() if torch.cuda.is_available() else x
    return Variable(placed)
def moduleNormalize(frame):
    """Subtract fixed offsets from the first three channels of ``frame``.

    Channels 0, 1 and 2 along dimension 1 are shifted by 0.4631, 0.4352 and
    0.3990 respectively and concatenated back along that dimension. Any
    further channels are not part of the result.

    Args:
        frame: 4-D tensor indexed as ``[:, channel, :, :]``.

    Returns:
        Tensor with three channels along dimension 1.
    """
    channel_offsets = (0.4631, 0.4352, 0.3990)
    shifted_channels = [
        frame[:, idx:idx + 1, :, :] - offset
        for idx, offset in enumerate(channel_offsets)
    ]
    return torch.cat(shifted_channels, 1)
def normalize(x):
    """Map ``x`` through ``x * 2 - 1`` (so 0 becomes -1 and 1 becomes 1)."""
    doubled = x * 2.0
    return doubled - 1.0
def denormalize(x):
    """Map ``x`` through ``(x + 1) / 2`` (so -1 becomes 0 and 1 becomes 1)."""
    shifted = x + 1.0
    return shifted / 2.0
def meshgrid(height, width, grid_min, grid_max):
    """Build x/y coordinate grids spanning ``[grid_min, grid_max]``.

    Args:
        height: Number of rows in each grid.
        width: Number of columns in each grid.
        grid_min: Coordinate value at the first row / column.
        grid_max: Coordinate value at the last row / column.

    Returns:
        Tuple ``(grid_x, grid_y)`` of tensors shaped ``(1, height, width)``.
        ``grid_x`` varies along the last dimension, ``grid_y`` along the
        middle one.
    """
    # One row of x coordinates, copied down every row.
    column_coords = torch.linspace(grid_min, grid_max, width).view(1, width)
    grid_x = column_coords.repeat(height, 1).unsqueeze(0)

    # One column of y coordinates, copied across every column.
    row_coords = torch.linspace(grid_min, grid_max, height).view(height, 1)
    grid_y = row_coords.repeat(1, width).unsqueeze(0)

    return grid_x, grid_y
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。