# NOTE: web-page residue, not part of the program: "Code pull complete; the page will refresh automatically."
#! -*- coding: utf-8 -*-
import numpy as np
from scipy import misc
import glob
import imageio
from keras.models import Model
from keras.layers import *
from keras import backend as K
from keras.optimizers import Adam
from keras.callbacks import Callback
# Fixed sizes used throughout the script: side length of the square images
# fed to the network, and the channel count of the widest layers.
img_dim = 64
z_dim = 512

# Gather every JPEG under img_align_celeba/ and shuffle the list in place.
imgs = glob.glob('img_align_celeba/*.jpg')
np.random.shuffle(imgs)

# Probe the first image for the raw frame size. `center_height` is the first
# row of the `width`-tall band that imread() later crops out
# (presumably the images are taller than wide -- confirm against the data).
_probe = misc.imread(imgs[0])
height, width = _probe.shape[:2]
center_height = int((height - width) / 2)
def imread(f):
    """Load image file `f`, crop a band of rows, resize, and rescale it.

    Rows ``center_height`` to ``center_height + width`` are kept, the result
    is resized to ``(img_dim, img_dim)`` and returned as float32 computed as
    ``pixel / 255 * 2 - 1`` (i.e. [-1, 1] assuming 8-bit pixel values).
    """
    raw = misc.imread(f)
    top = center_height
    band = raw[top:top + width, :]
    resized = misc.imresize(band, (img_dim, img_dim))
    as_float = resized.astype(np.float32)
    return as_float / 255 * 2 - 1
def data_generator(batch_size=32):
    """Yield ``(batch, None)`` pairs forever.

    The global ``imgs`` list is reshuffled at the start of every pass. Images
    are loaded with ``imread`` and accumulated until ``batch_size`` of them
    are collected, then emitted as one stacked array. A partially filled
    batch at the end of a pass is carried over into the next pass.
    """
    pending = []
    while True:
        np.random.shuffle(imgs)
        for filename in imgs:
            pending.append(imread(filename))
            if len(pending) != batch_size:
                continue
            # The second element is None: no separate target array is supplied.
            yield np.array(pending), None
            pending = []
# Encoder: five stride-2 convolution stages followed by global average pooling.
x_in = Input(shape=(img_dim, img_dim, 3))
x = x_in
# Filter counts grow z_dim/16 -> z_dim.  Floor division keeps them integers;
# true division ("/") produces floats on Python 3, which Conv2D does not
# accept as a filter count.
for _filters in (z_dim // 16, z_dim // 8, z_dim // 4, z_dim // 2, z_dim):
    x = Conv2D(_filters, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)
x = GlobalAveragePooling2D()(x)
encoder = Model(x_in, x)
encoder.summary()
# Spatial dimensions of the layer just before the pooling layer (batch and
# channel axes stripped); the decoder uses this to size its first feature map.
map_size = K.int_shape(encoder.layers[-2].output)[1:-1]
# Decoder, i.e. the generator part: a dense projection reshaped to a
# map_size feature map, then five stride-2 transposed convolutions.
z_in = Input(shape=K.int_shape(x)[1:])
z = z_in
# int(...) turns the NumPy integer returned by np.prod into a plain Python int
# before it is used as the Dense unit count.
z = Dense(int(np.prod(map_size)) * z_dim)(z)
z = Reshape(map_size + (z_dim,))(z)
# Filter counts shrink z_dim/2 -> z_dim/16.  Floor division keeps them
# integers; true division ("/") produces floats on Python 3.
for _filters in (z_dim // 2, z_dim // 4, z_dim // 8, z_dim // 16):
    z = Conv2DTranspose(_filters, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
    z = BatchNormalization()(z)
    z = Activation('relu')(z)
# Final stage maps to 3 channels; tanh bounds the output to [-1, 1].
z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
z = Activation('tanh')(z)
decoder = Model(z_in, z)
decoder.summary()
class ScaleShift(Layer):
    """Affine transform ``exp(log_scale) * z + shift`` that registers a loss.

    The layer is called on a list ``[z, shift, log_scale]``.  Besides
    returning the transformed tensor it adds
    ``-sum(mean(log_scale, axis=0))`` to the layer's losses
    (presumably the negative log-determinant of the transform -- confirm).
    """

    def __init__(self, **kwargs):
        super(ScaleShift, self).__init__(**kwargs)

    def call(self, inputs):
        noise, shift, log_scale = inputs
        # Mean over axis 0 first, then sum over what remains.
        mean_log_scale = K.mean(log_scale, 0)
        self.add_loss(-K.sum(mean_log_scale))
        return K.exp(log_scale) * noise + shift
# Two dense heads on the encoder output: a shift and a log-scale, each z_dim wide.
z_shift = Dense(z_dim)(x)
z_log_scale = Dense(z_dim)(x)

# Standard-normal noise shaped like z_shift, pushed through the affine
# ScaleShift layer to obtain the latent code.
u = Lambda(lambda ref: K.random_normal(shape=K.shape(ref)))(z_shift)
z = ScaleShift()([u, z_shift, z_log_scale])

# Decode the latent code; the model output is the residual x_in - x_recon.
x_recon = decoder(z)
x_out = Subtract()([x_in, x_recon])

# Reconstruction term: half the squared residual (mean over axis 0, summed
# over the rest) plus a constant 0.5*log(2*pi) per output element.
_half_sq_residual = 0.5 * K.sum(K.mean(x_out**2, 0))
_gauss_constant = 0.5 * np.log(2*np.pi) * np.prod(K.int_shape(x_out)[1:])
recon_loss = _half_sq_residual + _gauss_constant

# Latent term: half squared norm of z minus half squared norm of the noise u.
_half_sq_z = 0.5 * K.sum(K.mean(z**2, 0))
_half_sq_u = 0.5 * K.sum(K.mean(u**2, 0))
z_loss = _half_sq_z - _half_sq_u

vae_loss = recon_loss + z_loss

# The loss is attached with add_loss, so compile() gets no loss argument.
vae = Model(x_in, x_out)
vae.add_loss(vae_loss)
vae.compile(optimizer=Adam(1e-4))
def sample(path, n=9):
    """Decode random latent vectors and save them as an ``n`` x ``n`` image grid.

    Args:
        path: Destination file passed to ``imageio.imwrite``.
        n: Number of tiles per row and per column (default 9).

    ``n * n`` vectors are drawn from a standard normal with the encoder's
    output shape, decoded by ``decoder`` and tiled row by row.
    """
    latent_shape = K.int_shape(x)[1:]
    # One batched forward pass instead of n*n single-sample predict() calls.
    images = decoder.predict(np.random.randn(n * n, *latent_shape))
    figure = np.zeros((img_dim * n, img_dim * n, 3))
    for i in range(n):
        for j in range(n):
            figure[i * img_dim: (i + 1) * img_dim,
                   j * img_dim: (j + 1) * img_dim] = images[i * n + j]
    # Decoder output is tanh-bounded in [-1, 1]; map it to [0, 255].
    figure = (figure + 1) / 2 * 255
    # Convert explicitly to 8-bit so the writer receives a well-defined image
    # format rather than a float array in the 0-255 range.
    figure = np.clip(np.round(figure), 0, 255).astype(np.uint8)
    imageio.imwrite(path, figure)
class Evaluate(Callback):
    """Per-epoch callback: write a sample grid and keep the best encoder weights.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self):
        import os
        # Let the Callback base class set up its own state.
        super(Evaluate, self).__init__()
        # Lowest training loss seen so far (starts at a large sentinel).
        self.lowest = 1e10
        # History of (epoch, loss) pairs.
        self.losses = []
        # Sample grids are written under samples/; create it if missing.
        if not os.path.exists('samples'):
            os.mkdir('samples')

    def on_epoch_end(self, epoch, logs=None):
        """Save a sample image, record the loss, and checkpoint on improvement.

        Args:
            epoch: Epoch index, used in the sample file name.
            logs: Metrics dict supplied by the training loop; may be None.
        """
        path = 'samples/test_%s.png' % epoch
        sample(path)
        # logs defaults to None and may lack 'loss'; skip bookkeeping then
        # instead of failing on the lookup.
        loss = (logs or {}).get('loss')
        if loss is None:
            return
        self.losses.append((epoch, loss))
        if loss <= self.lowest:
            self.lowest = loss
            encoder.save_weights('./best_encoder.weights')
# Train from the endless generator: 1000 epochs of 1000 batches each, with the
# Evaluate callback running at the end of every epoch.
evaluator = Evaluate()
vae.fit_generator(
    data_generator(),
    steps_per_epoch=1000,
    epochs=1000,
    callbacks=[evaluator]
)
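# NOTE: web-page residue, not part of the program: "This section may contain content unsuitable for display and is not shown. You can review and modify it using the relevant editing features."
# "If you confirm the content involves no inappropriate language, pure advertising, violence, vulgar or pornographic material, infringement, piracy, false or worthless content, or content violating national laws and regulations, you may click Submit to appeal and we will handle it as soon as possible."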