1 Star 1 Fork 0

旮旯里的秋田犬/stylized_neural_painting

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
networks.py 17.26 KB
一键复制 编辑 原始数据 按行查看 历史
jiupinjia 提交于 2020-11-26 15:10 . Updates on lightweight renderers
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torchvision import models
import torch.nn.functional as F
from torch.optim import lr_scheduler
import math
import utils
import matplotlib.pyplot as plt
import numpy as np
# Decide which device we want to run on:
# the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Module-level shorthand for math.pi.
PI = math.pi
###############################################################################
# Helper Functions
###############################################################################
class Identity(nn.Module):
    """Pass-through module: ``forward`` returns its input unchanged."""

    def forward(self, x):
        """Return ``x`` itself (no copy, no computation)."""
        return x
def get_norm_layer(norm_type='instance'):
    """Return a factory that builds a normalization layer.

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none

    Returns:
        A callable taking the number of feature channels and returning a module.
        For 'batch' it is a ``functools.partial`` of ``nn.BatchNorm2d`` with
        learnable affine parameters and tracked running statistics.
        For 'instance' it is a ``functools.partial`` of ``nn.InstanceNorm2d``
        without affine parameters and without running statistics.
        For 'none' it is a plain function that ignores its argument and returns
        an ``Identity`` module.

    Raises:
        NotImplementedError -- if ``norm_type`` is not one of the names above.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    if norm_type == 'none':
        # A named function instead of a lambda bound to a name; the channel
        # count is accepted only so the call shape matches the other factories.
        def norm_layer(num_features):
            return Identity()
        return norm_layer
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    Conv*/Linear* layers get the selected initializer on their weight and a zero
    bias (when they have one). BatchNorm2d layers get weight ~ N(1, init_gain)
    and zero bias; BatchNorm2d layers built with ``affine=False`` have no
    weight/bias and are left untouched.

    Raises:
        NotImplementedError -- if ``init_type`` is unknown and the network
        contains at least one Conv/Linear layer (the check happens per layer).
    """
    def init_func(m):  # applied to every submodule by net.apply
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if weight is not None and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'normal':
                init.normal_(weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if bias is not None:
                init.constant_(bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm's weight is not a matrix; only a normal distribution applies.
            # With affine=False both weight and bias are None, so guard each one
            # instead of dereferencing .data on None.
            if weight is not None:
                init.normal_(weight.data, 1.0, init_gain)
            if bias is not None:
                init.constant_(bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=None):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2.
                              None (the default) or an empty list keeps the
                              network where it is and skips DataParallel.

    Return an initialized network (wrapped in DataParallel when gpu_ids is non-empty).
    """
    # None instead of a shared mutable [] default.
    if gpu_ids is None:
        gpu_ids = []
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available(), 'gpu_ids were given but CUDA is not available'
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net
def define_G(rdrr, netG, init_type='normal', init_gain=0.02, gpu_ids=None):
    """Build and initialize a generator network selected by name.

    Parameters:
        rdrr               -- renderer object handed to the generator constructor
        netG (str)         -- generator name: plain-dcgan | plain-unet | huang-net |
                              zou-fusion-net | zou-fusion-net-light
        init_type (str)    -- weight initialization method, forwarded to init_net
        init_gain (float)  -- initialization scaling factor, forwarded to init_net
        gpu_ids (int list) -- GPUs to run on, forwarded to init_net; None means no GPUs

    Returns the network produced by init_net.

    Raises:
        NotImplementedError -- if ``netG`` is not one of the names above.
    """
    if netG == 'plain-dcgan':
        net = DCGAN(rdrr)
    elif netG == 'plain-unet':
        net = UNet(rdrr)
    elif netG == 'huang-net':
        net = HuangNet(rdrr)
    elif netG == 'zou-fusion-net':
        net = ZouFCNFusion(rdrr)
    elif netG == 'zou-fusion-net-light':
        net = ZouFCNFusionLight(rdrr)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    # Always hand init_net a real list so it never sees the None default.
    return init_net(net, init_type, init_gain, [] if gpu_ids is None else gpu_ids)
class DCGAN(nn.Module):
    """Transposed-convolution generator producing a 6-channel 128x128 map.

    Parameters:
        rdrr      -- renderer object; only ``rdrr.d`` (number of input channels) is read
        ngf (int) -- base number of filters

    The state-size comments below assume a 1x1 spatial input, i.e. a tensor of
    shape (N, rdrr.d, 1, 1) -- TODO confirm against callers.
    """

    def __init__(self, rdrr, ngf=64):
        super(DCGAN, self).__init__()
        input_nc = rdrr.d
        self.out_size = 128
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(input_nc, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 8 x 8
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 16 x 16
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 32 x 32
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. ngf x 64 x 64
            nn.ConvTranspose2d(ngf, 6, 4, 2, 1, bias=False),
            # state size. 6 x 128 x 128
        )

    def forward(self, input):
        """Run the generator and split its 6 output channels into two 3-channel tensors."""
        output_tensor = self.main(input)
        return output_tensor[:, 0:3, :, :], output_tensor[:, 3:6, :, :]
class DCGAN_32(nn.Module):
    """Lightweight transposed-convolution generator producing a 6-channel 32x32 map.

    Parameters:
        rdrr      -- renderer object; only ``rdrr.d`` (number of input channels) is read
        ngf (int) -- base number of filters

    The state-size comments below assume a 1x1 spatial input, i.e. a tensor of
    shape (N, rdrr.d, 1, 1) -- TODO confirm against callers.
    """

    def __init__(self, rdrr, ngf=64):
        super(DCGAN_32, self).__init__()
        input_nc = rdrr.d
        self.out_size = 32
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(input_nc, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, 6, 4, 2, 1, bias=False),
            # state size. 6 x 32 x 32
        )

    def forward(self, input):
        """Run the generator and split its 6 output channels into two 3-channel tensors."""
        output_tensor = self.main(input)
        return output_tensor[:, 0:3, :, :], output_tensor[:, 3:6, :, :]
class PixelShuffleNet(nn.Module):
    """Fully-connected + pixel-shuffle network mapping a vector to a 3x128x128 map.

    Parameters:
        input_nc (int) -- length of the input parameter vector
    """

    def __init__(self, input_nc):
        super(PixelShuffleNet, self).__init__()
        self.fc1 = nn.Linear(input_nc, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, 2048)
        self.fc4 = nn.Linear(2048, 4096)
        self.conv1 = nn.Conv2d(16, 32, 3, 1, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1, 1)
        self.conv3 = nn.Conv2d(8, 16, 3, 1, 1)
        self.conv4 = nn.Conv2d(16, 16, 3, 1, 1)
        self.conv5 = nn.Conv2d(4, 8, 3, 1, 1)
        self.conv6 = nn.Conv2d(8, 4 * 3, 3, 1, 1)
        self.pixel_shuffle = nn.PixelShuffle(2)

    def forward(self, x):
        """Map ``x`` (any shape whose elements group into rows of ``input_nc``
        values, e.g. (N, input_nc, 1, 1)) to a tensor of shape (N, 3, 128, 128)."""
        # Flatten to (N, input_nc) explicitly. A bare squeeze() also removes a
        # size-1 feature axis, which made fc1 fail when input_nc == 1.
        x = x.reshape(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = x.view(-1, 16, 16, 16)               # 4096 -> 16 x 16 x 16
        x = F.relu(self.conv1(x))
        x = self.pixel_shuffle(self.conv2(x))    # 32 x 16 x 16 -> 8 x 32 x 32
        x = F.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))    # 16 x 32 x 32 -> 4 x 64 x 64
        x = F.relu(self.conv5(x))
        x = self.pixel_shuffle(self.conv6(x))    # 12 x 64 x 64 -> 3 x 128 x 128
        x = x.view(-1, 3, 128, 128)
        return x
class PixelShuffleNet_32(nn.Module):
    """Lightweight fully-connected + pixel-shuffle network mapping a vector to a 3x32x32 map.

    Parameters:
        input_nc (int) -- length of the input parameter vector
    """

    def __init__(self, input_nc):
        super(PixelShuffleNet_32, self).__init__()
        self.fc1 = nn.Linear(input_nc, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, 2048)
        self.conv1 = nn.Conv2d(8, 64, 3, 1, 1)
        self.conv2 = nn.Conv2d(64, 4 * 3, 3, 1, 1)
        self.pixel_shuffle = nn.PixelShuffle(2)

    def forward(self, x):
        """Map ``x`` (any shape whose elements group into rows of ``input_nc``
        values, e.g. (N, input_nc, 1, 1)) to a tensor of shape (N, 3, 32, 32)."""
        # Flatten to (N, input_nc) explicitly. A bare squeeze() also removes a
        # size-1 feature axis, which made fc1 fail when input_nc == 1.
        x = x.reshape(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = x.view(-1, 8, 16, 16)                # 2048 -> 8 x 16 x 16
        x = F.relu(self.conv1(x))
        x = self.pixel_shuffle(self.conv2(x))    # 12 x 16 x 16 -> 3 x 32 x 32
        x = x.view(-1, 3, 32, 32)
        return x
class HuangNet(nn.Module):
    """Fully-connected + pixel-shuffle generator producing a 6-channel 128x128 map.

    Parameters:
        rdrr -- renderer object; ``rdrr.d`` gives the length of the input vector
    """

    def __init__(self, rdrr):
        super(HuangNet, self).__init__()
        self.rdrr = rdrr
        self.out_size = 128
        self.fc1 = nn.Linear(rdrr.d, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, 2048)
        self.fc4 = nn.Linear(2048, 4096)
        self.conv1 = nn.Conv2d(16, 32, 3, 1, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1, 1)
        self.conv3 = nn.Conv2d(8, 16, 3, 1, 1)
        self.conv4 = nn.Conv2d(16, 16, 3, 1, 1)
        self.conv5 = nn.Conv2d(4, 8, 3, 1, 1)
        self.conv6 = nn.Conv2d(8, 4 * 6, 3, 1, 1)
        self.pixel_shuffle = nn.PixelShuffle(2)

    def forward(self, x):
        """Map ``x`` (e.g. shape (N, rdrr.d, 1, 1)) to two (N, 3, 128, 128) tensors:
        the first three and the last three of the six output channels."""
        # Flatten to (N, rdrr.d) explicitly. A bare squeeze() also removes a
        # size-1 feature axis, which made fc1 fail when rdrr.d == 1.
        x = x.reshape(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = x.view(-1, 16, 16, 16)               # 4096 -> 16 x 16 x 16
        x = F.relu(self.conv1(x))
        x = self.pixel_shuffle(self.conv2(x))    # 32 x 16 x 16 -> 8 x 32 x 32
        x = F.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))    # 16 x 32 x 32 -> 4 x 64 x 64
        x = F.relu(self.conv5(x))
        x = self.pixel_shuffle(self.conv6(x))    # 24 x 64 x 64 -> 6 x 128 x 128
        output_tensor = x.view(-1, 6, 128, 128)
        return output_tensor[:, 0:3, :, :], output_tensor[:, 3:6, :, :]
class ZouFCNFusion(nn.Module):
    """Two-branch generator: a PixelShuffleNet mask multiplied with a DCGAN colour map.

    Parameters:
        rdrr -- renderer object; reads ``rdrr.d`` (via DCGAN), ``rdrr.d_shape``
                (number of leading input channels fed to the mask branch) and
                ``rdrr.renderer`` (renderer name)
    """

    def __init__(self, rdrr):
        super(ZouFCNFusion, self).__init__()
        self.rdrr = rdrr
        self.out_size = 128
        self.huangnet = PixelShuffleNet(rdrr.d_shape)
        self.dcgan = DCGAN(rdrr)

    def forward(self, x):
        """Return ``(color * mask, alpha * mask)`` for a 4-D input ``x``.

        ``alpha`` is the last input channel, except for the 'oilpaintbrush'
        and 'airbrush' renderers where it is the constant 1.0.
        """
        x_shape = x[:, 0:self.rdrr.d_shape, :, :]
        if self.rdrr.renderer in ['oilpaintbrush', 'airbrush']:
            # Create the constant on the input's own device; relying on the
            # module-level default device can mismatch the device the model
            # actually runs on (e.g. CPU inference on a CUDA-capable machine).
            x_alpha = torch.tensor(1.0, device=x.device)
        else:
            x_alpha = x[:, [-1], :, :]
        mask = self.huangnet(x_shape)
        color, _ = self.dcgan(x)
        return color * mask, x_alpha * mask
class ZouFCNFusionLight(nn.Module):
    """Lightweight two-branch generator: a PixelShuffleNet_32 mask multiplied
    with a DCGAN_32 colour map (32x32 output).

    Parameters:
        rdrr -- renderer object; reads ``rdrr.d`` (via DCGAN_32), ``rdrr.d_shape``
                (number of leading input channels fed to the mask branch) and
                ``rdrr.renderer`` (renderer name)
    """

    def __init__(self, rdrr):
        super(ZouFCNFusionLight, self).__init__()
        self.rdrr = rdrr
        self.out_size = 32
        self.huangnet = PixelShuffleNet_32(rdrr.d_shape)
        self.dcgan = DCGAN_32(rdrr)

    def forward(self, x):
        """Return ``(color * mask, alpha * mask)`` for a 4-D input ``x``.

        ``alpha`` is the last input channel, except for the 'oilpaintbrush'
        and 'airbrush' renderers where it is the constant 1.0.
        """
        x_shape = x[:, 0:self.rdrr.d_shape, :, :]
        if self.rdrr.renderer in ['oilpaintbrush', 'airbrush']:
            # Create the constant on the input's own device; relying on the
            # module-level default device can mismatch the device the model
            # actually runs on (e.g. CPU inference on a CUDA-capable machine).
            x_alpha = torch.tensor(1.0, device=x.device)
        else:
            x_alpha = x[:, [-1], :, :]
        mask = self.huangnet(x_shape)
        color, _ = self.dcgan(x)
        return color * mask, x_alpha * mask
class UNet(torch.nn.Module):
    """U-Net generator wrapper producing a 6-channel 128x128 map.

    Parameters:
        rdrr -- renderer object; only ``rdrr.d`` (number of input channels) is read
    """

    def __init__(self, rdrr):
        """Build a 7-level UnetGenerator with batch normalization and 6 output channels."""
        super(UNet, self).__init__()
        # Exposed for consistency with the other generators in this file, which
        # all publish their output resolution; forward() tiles the input to 128x128.
        self.out_size = 128
        norm_layer = get_norm_layer(norm_type='batch')
        self.unet = UnetGenerator(rdrr.d, 6, 7, norm_layer=norm_layer, use_dropout=False)

    def forward(self, x):
        """Tile ``x`` 128 times along its last two axes, run the U-Net, and split
        the 6 output channels into two 3-channel tensors.

        Presumably ``x`` is (N, rdrr.d, 1, 1), so the tiled input is 128x128 --
        TODO confirm against callers.
        """
        x = x.repeat(1, 1, 128, 128)
        output_tensor = self.unet(x)
        return output_tensor[:, 0:3, :, :], output_tensor[:, 3:6, :, :]
class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, if num_downs == 7,
                               an image of size 128x128 becomes 1x1 at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer
            use_dropout (bool) -- whether the intermediate ngf * 8 blocks use dropout

        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        # construct unet structure
        # add the innermost layer
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                             norm_layer=norm_layer, innermost=True)
        # add intermediate layers with ngf * 8 filters
        for _ in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                                                 norm_layer=norm_layer, use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer)
        # add the outermost layer
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block,
                                             outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.

        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet submodule with skip connections.

        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features;
                              defaults to outer_nc when None
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)    -- if this module is the outermost module
            innermost (bool)    -- if this module is the innermost module
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers (intermediate blocks only).
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # Convolutions that feed a norm layer carry a bias only when that norm
        # layer is InstanceNorm2d (directly or wrapped in functools.partial).
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            # The outermost up-convolution keeps its default bias and is not
            # followed by an output activation.
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv]
            model = down + [submodule] + up
        elif innermost:
            # No submodule, so the up path sees inner_nc channels (no concatenation).
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            # The submodule returns its input concatenated with its output,
            # hence inner_nc * 2 input channels on the way up.
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Return the block output; non-outermost blocks concatenate the input
        with the output along the channel axis (the skip connection)."""
        if self.outermost:
            return self.model(x)
        else:  # add skip connections
            return torch.cat([x, self.model(x)], 1)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/QiutianDog/stylized_neural_painting.git
git@gitee.com:QiutianDog/stylized_neural_painting.git
QiutianDog
stylized_neural_painting
stylized_neural_painting
main

搜索帮助