# (Scraped web-page residue, translated: "Code pull complete; the page will refresh automatically.")
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torchvision import models
import torch.nn.functional as F
from torch.optim import lr_scheduler
import math
import utils
import matplotlib.pyplot as plt
import numpy as np
# Module-wide default device: the first CUDA GPU when one is available, else the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
PI = math.pi
###############################################################################
# Helper Functions
###############################################################################
class Identity(nn.Module):
    """No-op module: its forward pass hands back the input object unchanged."""

    def forward(self, x):
        passthrough = x
        return passthrough
def get_norm_layer(norm_type='instance'):
    """Return a factory that builds the requested normalization layer.

    Parameters:
        norm_type (str) -- batch | instance | none

    'batch' yields BatchNorm2d with learnable affine parameters and running statistics.
    'instance' yields InstanceNorm2d with no affine parameters and no running statistics.
    'none' yields a factory that ignores its argument and returns an Identity module.

    Raises:
        NotImplementedError -- for any other norm_type.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    if norm_type == 'none':
        def _identity_factory(x):
            # The channel count is accepted for interface compatibility but unused.
            return Identity()
        return _identity_factory
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal (kaiming ignores it).

    Every submodule whose class name contains 'Conv' or 'Linear' gets its weight drawn
    according to ``init_type`` and its bias (if any) set to 0. Every 'BatchNorm2d'
    submodule gets weight ~ N(1.0, init_gain) and bias = 0; a BatchNorm2d built with
    ``affine=False`` has no weight/bias and is left untouched.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.

    Raises:
        NotImplementedError -- if ``init_type`` is unknown and ``net`` contains a Conv/Linear layer.
    """
    def init_func(m):  # called by net.apply on every submodule
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm's weight is a vector, not a matrix; only a normal draw applies.
            # With affine=False the parameters are None, so there is nothing to set.
            if getattr(m, 'weight', None) is not None:
                init.normal_(m.weight.data, 1.0, init_gain)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=None):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2. ``None`` or an
                              empty list leaves the network on its current device, unwrapped.
    Return an initialized network; it is wrapped in ``torch.nn.DataParallel`` when
    ``gpu_ids`` is non-empty.
    """
    if gpu_ids is None:
        gpu_ids = []  # fresh list per call instead of a shared mutable default
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net
def define_G(rdrr, netG, init_type='normal', init_gain=0.02, gpu_ids=None):
    """Build a generator network by name, then place and initialize it via ``init_net``.

    Parameters:
        rdrr               -- configuration object handed to the generator constructor; the
                              generators in this module read ``rdrr.d`` and, for the fusion
                              nets, ``rdrr.d_shape`` / ``rdrr.renderer``
        netG (str)         -- plain-dcgan | plain-unet | huang-net | zou-fusion-net | zou-fusion-net-light
        init_type (str)    -- weight initialization method, forwarded to ``init_net``
        init_gain (float)  -- initialization scaling factor, forwarded to ``init_net``
        gpu_ids (int list) -- GPUs to run on; ``None`` or empty means no device move / no DataParallel
    Returns:
        the initialized network returned by ``init_net``.
    Raises:
        NotImplementedError -- if ``netG`` is not one of the names above.
    """
    if gpu_ids is None:
        gpu_ids = []  # fresh list per call instead of a shared mutable default
    if netG == 'plain-dcgan':
        net = DCGAN(rdrr)
    elif netG == 'plain-unet':
        net = UNet(rdrr)
    elif netG == 'huang-net':
        net = HuangNet(rdrr)
    elif netG == 'zou-fusion-net':
        net = ZouFCNFusion(rdrr)
    elif netG == 'zou-fusion-net-light':
        net = ZouFCNFusionLight(rdrr)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(net, init_type, init_gain, gpu_ids)
class DCGAN(nn.Module):
    """DCGAN-style transposed-conv generator producing two 3-channel maps.

    Parameters:
        rdrr      -- object whose ``d`` attribute is the number of input channels
        ngf (int) -- base number of generator filters

    The first layer maps H x W to (H+3) x (W+3) (1x1 -> 4x4) and each of the five
    following layers doubles the spatial size, so ``out_size`` (128) holds for an input
    of shape (B, d, 1, 1). forward returns (channels 0-2, channels 3-5) of the
    6-channel output.
    """

    def __init__(self, rdrr, ngf=64):
        super(DCGAN, self).__init__()
        input_nc = rdrr.d
        self.out_size = 128
        self.main = nn.Sequential(
            # input is Z, assumed (input_nc) x 1 x 1, going into a convolution
            nn.ConvTranspose2d(input_nc, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 8 x 8
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 16 x 16
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 32 x 32
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 64 x 64
            nn.ConvTranspose2d(ngf, 6, 4, 2, 1, bias=False),
            # state size. 6 x 128 x 128 (no normalization or activation on the output)
        )

    def forward(self, input):
        output_tensor = self.main(input)
        # Split the 6 output channels into two 3-channel maps.
        return output_tensor[:, 0:3, :, :], output_tensor[:, 3:6, :, :]
class DCGAN_32(nn.Module):
    """Lightweight DCGAN-style generator producing two 3-channel 32x32 maps.

    Same layer pattern as ``DCGAN`` but with three stride-2 stages instead of five:
    1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32 for an input of shape (B, rdrr.d, 1, 1).
    forward returns (channels 0-2, channels 3-5) of the 6-channel output.
    """

    def __init__(self, rdrr, ngf=64):
        super(DCGAN_32, self).__init__()
        self.out_size = 32
        # Stem: project the input to (ngf*8) x 4 x 4.
        layers = [
            nn.ConvTranspose2d(rdrr.d, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
        ]
        # Upsampling stages: (ngf*8) x 4 x 4 -> (ngf*4) x 8 x 8 -> (ngf*2) x 16 x 16.
        for in_ch, out_ch in ((ngf * 8, ngf * 4), (ngf * 4, ngf * 2)):
            layers.extend([
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ])
        # Head: (ngf*2) x 16 x 16 -> 6 x 32 x 32, no normalization or activation.
        layers.append(nn.ConvTranspose2d(ngf * 2, 6, 4, 2, 1, bias=False))
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        stacked = self.main(input)
        first, second = stacked[:, :3], stacked[:, 3:6]
        return first, second
class PixelShuffleNet(nn.Module):
    """MLP + pixel-shuffle decoder from an ``input_nc``-vector to a 3 x 128 x 128 map.

    Four ReLU fully-connected layers expand the input to 4096 values, reshaped to
    16 x 16 x 16. Three (conv, ReLU, conv, PixelShuffle(2)) stages then upsample
    16 -> 32 -> 64 -> 128 pixels.
    """

    def __init__(self, input_nc):
        super(PixelShuffleNet, self).__init__()
        # Fully-connected trunk: input_nc -> 512 -> 1024 -> 2048 -> 4096.
        self.fc1 = nn.Linear(input_nc, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, 2048)
        self.fc4 = nn.Linear(2048, 4096)
        # Each (expand, refine) conv pair feeds a 2x pixel shuffle, which divides the
        # channel count by 4 while doubling height and width.
        self.conv1 = nn.Conv2d(16, 32, 3, 1, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1, 1)
        self.conv3 = nn.Conv2d(8, 16, 3, 1, 1)
        self.conv4 = nn.Conv2d(16, 16, 3, 1, 1)
        self.conv5 = nn.Conv2d(4, 8, 3, 1, 1)
        self.conv6 = nn.Conv2d(8, 4 * 3, 3, 1, 1)
        self.pixel_shuffle = nn.PixelShuffle(2)

    def forward(self, x):
        # NOTE: squeeze() drops every size-1 dim, including the batch dim when the
        # batch holds one sample; the view(-1, ...) below restores it.
        h = x.squeeze()
        for fc in (self.fc1, self.fc2, self.fc3, self.fc4):
            h = F.relu(fc(h))
        h = h.view(-1, 16, 16, 16)
        stages = ((self.conv1, self.conv2), (self.conv3, self.conv4), (self.conv5, self.conv6))
        for expand, refine in stages:
            h = self.pixel_shuffle(refine(F.relu(expand(h))))
        return h.view(-1, 3, 128, 128)
class PixelShuffleNet_32(nn.Module):
    """MLP + pixel-shuffle decoder from an ``input_nc``-vector to a 3 x 32 x 32 map.

    Three ReLU fully-connected layers expand the input to 2048 values, reshaped to
    8 x 16 x 16. One (conv, ReLU, conv, PixelShuffle(2)) stage then upsamples to 32 x 32.
    """

    def __init__(self, input_nc):
        super(PixelShuffleNet_32, self).__init__()
        # Fully-connected trunk: input_nc -> 512 -> 1024 -> 2048.
        self.fc1 = nn.Linear(input_nc, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, 2048)
        # 8 -> 64 -> 12 channels; the 2x shuffle turns 12 channels into 3.
        self.conv1 = nn.Conv2d(8, 64, 3, 1, 1)
        self.conv2 = nn.Conv2d(64, 4 * 3, 3, 1, 1)
        self.pixel_shuffle = nn.PixelShuffle(2)

    def forward(self, x):
        # NOTE: squeeze() also drops a batch dim of 1; view(-1, ...) restores it.
        h = x.squeeze()
        for fc in (self.fc1, self.fc2, self.fc3):
            h = F.relu(fc(h))
        h = F.relu(self.conv1(h.view(-1, 8, 16, 16)))
        h = self.pixel_shuffle(self.conv2(h))
        return h.view(-1, 3, 32, 32)
class HuangNet(nn.Module):
    """MLP + pixel-shuffle generator producing two 3-channel 128x128 maps.

    Four ReLU fully-connected layers expand the ``rdrr.d``-dim input to 4096 values,
    reshaped to 16 x 16 x 16. Three (conv, ReLU, conv, PixelShuffle(2)) stages then
    upsample to 6 x 128 x 128. forward returns (channels 0-2, channels 3-5).
    """

    def __init__(self, rdrr):
        super(HuangNet, self).__init__()
        self.rdrr = rdrr
        self.out_size = 128
        # Fully-connected trunk: rdrr.d -> 512 -> 1024 -> 2048 -> 4096.
        self.fc1 = nn.Linear(rdrr.d, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, 2048)
        self.fc4 = nn.Linear(2048, 4096)
        # (expand, refine) conv pairs, each followed by a 2x pixel shuffle.
        self.conv1 = nn.Conv2d(16, 32, 3, 1, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1, 1)
        self.conv3 = nn.Conv2d(8, 16, 3, 1, 1)
        self.conv4 = nn.Conv2d(16, 16, 3, 1, 1)
        self.conv5 = nn.Conv2d(4, 8, 3, 1, 1)
        self.conv6 = nn.Conv2d(8, 4 * 6, 3, 1, 1)
        self.pixel_shuffle = nn.PixelShuffle(2)

    def forward(self, x):
        # NOTE: squeeze() also drops a batch dim of 1; view(-1, ...) restores it.
        h = x.squeeze()
        for fc in (self.fc1, self.fc2, self.fc3, self.fc4):
            h = F.relu(fc(h))
        h = h.view(-1, 16, 16, 16)
        stages = ((self.conv1, self.conv2), (self.conv3, self.conv4), (self.conv5, self.conv6))
        for expand, refine in stages:
            h = self.pixel_shuffle(refine(F.relu(expand(h))))
        stacked = h.view(-1, 6, 128, 128)
        return stacked[:, :3], stacked[:, 3:6]
class ZouFCNFusion(nn.Module):
    """Fuse a shape mask from PixelShuffleNet with a color map from DCGAN (128x128).

    The first ``rdrr.d_shape`` input channels drive the mask network. The full input
    drives DCGAN, of which only the first returned 3-channel map (``color``) is used.
    forward returns ``(color * mask, alpha * mask)``. ``alpha`` is the last input
    channel, or the constant 1 for the 'oilpaintbrush' and 'airbrush' renderers.
    """

    def __init__(self, rdrr):
        super(ZouFCNFusion, self).__init__()
        self.rdrr = rdrr
        self.out_size = 128
        self.huangnet = PixelShuffleNet(rdrr.d_shape)
        self.dcgan = DCGAN(rdrr)

    def forward(self, x):
        x_shape = x[:, 0:self.rdrr.d_shape, :, :]
        x_alpha = x[:, [-1], :, :]
        if self.rdrr.renderer in ['oilpaintbrush', 'airbrush']:
            # Constant alpha of 1, created on the input's device rather than the
            # module-level default device. This works on CPU even when CUDA exists,
            # and on DataParallel replicas living on GPUs other than cuda:0.
            x_alpha = torch.tensor(1.0, device=x.device)
        mask = self.huangnet(x_shape)
        color, _ = self.dcgan(x)
        return color * mask, x_alpha * mask
class ZouFCNFusionLight(nn.Module):
    """Fuse a shape mask from PixelShuffleNet_32 with a color map from DCGAN_32 (32x32).

    The first ``rdrr.d_shape`` input channels drive the mask network. The full input
    drives DCGAN_32, of which only the first returned 3-channel map (``color``) is used.
    forward returns ``(color * mask, alpha * mask)``. ``alpha`` is the last input
    channel, or the constant 1 for the 'oilpaintbrush' and 'airbrush' renderers.
    """

    def __init__(self, rdrr):
        super(ZouFCNFusionLight, self).__init__()
        self.rdrr = rdrr
        self.out_size = 32
        self.huangnet = PixelShuffleNet_32(rdrr.d_shape)
        self.dcgan = DCGAN_32(rdrr)

    def forward(self, x):
        x_shape = x[:, 0:self.rdrr.d_shape, :, :]
        x_alpha = x[:, [-1], :, :]
        if self.rdrr.renderer in ['oilpaintbrush', 'airbrush']:
            # Constant alpha of 1, created on the input's device rather than the
            # module-level default device. This works on CPU even when CUDA exists,
            # and on DataParallel replicas living on GPUs other than cuda:0.
            x_alpha = torch.tensor(1.0, device=x.device)
        mask = self.huangnet(x_shape)
        color, _ = self.dcgan(x)
        return color * mask, x_alpha * mask
class UNet(torch.nn.Module):
    """U-Net generator producing two 3-channel 128x128 maps.

    The input is tiled over a 128x128 grid and passed through a 7-level
    ``UnetGenerator`` (batch norm, no dropout, 6 output channels). The output is split
    into (channels 0-2, channels 3-5), like the other generators in this module.
    """

    def __init__(self, rdrr):
        super(UNet, self).__init__()
        # Spatial size of both outputs; declared like every other generator here.
        self.out_size = 128
        norm_layer = get_norm_layer(norm_type='batch')
        self.unet = UnetGenerator(rdrr.d, 6, 7, norm_layer=norm_layer, use_dropout=False)

    def forward(self, x):
        # Tile the input over an out_size x out_size grid. This gives a 128x128 image
        # only if x is (B, rdrr.d, 1, 1), presumably what callers pass (TODO confirm).
        x = x.repeat(1, 1, self.out_size, self.out_size)
        output_tensor = self.unet(x)
        return output_tensor[:, 0:3, :, :], output_tensor[:, 3:6, :, :]
class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator.

        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, if
                               num_downs == 7, a 128x128 image becomes 1x1 at the bottleneck.
                               Values below 5 still build the five fixed levels; only the
                               extra ngf*8 levels depend on it.
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer
            use_dropout (bool) -- add dropout to the extra ngf*8 levels

        The U-Net is assembled recursively, from the innermost block outwards.
        """
        super(UnetGenerator, self).__init__()
        # Bottleneck block.
        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                        norm_layer=norm_layer, innermost=True)
        # Extra levels that keep ngf * 8 filters.
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=block,
                                            norm_layer=norm_layer, use_dropout=use_dropout)
        # Taper the filter count: ngf*8 -> ngf*4 -> ngf*2 -> ngf.
        for mult in (4, 2, 1):
            block = UnetSkipConnectionBlock(ngf * mult, ngf * mult * 2, input_nc=None, submodule=block,
                                            norm_layer=norm_layer)
        # Outermost block maps input_nc -> output_nc.
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block,
                                             outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)
class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.
        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet submodule with skip connections.

        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features (defaults to outer_nc)
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)   -- if this module is the outermost module
            innermost (bool)   -- if this module is the innermost module
            norm_layer         -- normalization layer (class or functools.partial of one)
            use_dropout (bool) -- if use dropout layers (only used by intermediate blocks)

        Non-outermost blocks concatenate their input with their output along the
        channel axis in forward.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # Conv layers get a bias only when the norm is InstanceNorm2d, given plainly or
        # wrapped in functools.partial (including subclasses of partial).
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            # Input is concatenated skip + submodule output, hence inner_nc * 2.
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            # No final activation (Tanh / Sigmoid heads are left out): raw outputs.
            up = [uprelu, upconv]
            model = down + [submodule] + up
        elif innermost:
            # No submodule below, so the upconv sees only inner_nc channels.
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:  # add skip connections
            return torch.cat([x, self.model(x)], 1)
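# (Hosting-site notice, translated: "Content here may be unsuitable for display and has been hidden.
# You can review and edit it with the editing tools. If you confirm it contains no inappropriate
# language, pure advertising/traffic redirection, violence, vulgar or pornographic material,
# infringement, piracy, false or worthless content, or anything violating applicable laws and
# regulations, click Submit to appeal and it will be handled as soon as possible.")
# NOTE(review): the remainder of the original file may be missing after this point.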