代码拉取完成,页面将自动刷新
#%%
import imageio
import glob
import matplotlib.pyplot as plt
import os
import tensorflow as tf
from tensorflow.keras.layers import \
Dense, Dropout, Flatten, Conv2D, Conv2DTranspose, Reshape,\
BatchNormalization, LeakyReLU, Input
# Batch size of the real-image training dataset (also the default number of
# noise vectors per training step).
BATCH_SIZE = 64

# Machine-specific file locations; adjust before running:
#   data_path        -> MNIST archive handed to tf.keras.datasets.mnist.load_data
#   checkpoint_path  -> directory managed by tf.train.CheckpointManager
#   test_images_path -> directory receiving per-epoch sample grids and the GIF
data_path = "D:/AI/tf/tf_offcial_tutorial/src/dataset/mnist.npz"
checkpoint_path, test_images_path = (
    "D:/work/demo/dcgan",
    "D:/work/demo/dcgan_test_images",
)
def create_generator_model():
    """Build the DCGAN generator.

    Maps a 100-dimensional noise vector to a 28x28x1 image. The final
    layer uses a tanh activation, so pixel values lie in [-1, 1].

    Returns:
        An unbuilt-to-compile ``tf.keras.Sequential`` model whose output
        shape is ``(None, 28, 28, 1)``.
    """
    model = tf.keras.Sequential()
    # The explicit Input layer is needed so that model.output_shape is
    # defined for the shape asserts below (the alternative is to give the
    # Dense layer an input shape).
    model.add(Input((100,)))

    # Project the noise and reshape it into a 7x7x256 feature map.
    model.add(Dense(7 * 7 * 256))
    model.add(BatchNormalization())
    model.add(LeakyReLU())
    model.add(Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)

    # Transposed-conv stages: (filters, stride, expected output shape).
    upsample_stages = (
        (128, 1, (None, 7, 7, 128)),
        (64, 2, (None, 14, 14, 64)),
    )
    for filters, stride, expected_shape in upsample_stages:
        model.add(Conv2DTranspose(filters, (5, 5), strides=(stride, stride),
                                  padding="SAME", use_bias=False))
        model.add(BatchNormalization())
        model.add(LeakyReLU())
        assert model.output_shape == expected_shape

    # Final upsample to a single-channel image in the tanh range.
    model.add(Conv2DTranspose(1, (5, 5), strides=(2, 2),
                              padding="SAME", activation="tanh"))
    assert model.output_shape == (None, 28, 28, 1)
    return model
def create_discriminator_model():
    """Build the DCGAN discriminator.

    Takes 28x28x1 images. Two strided convolution stages (each followed by
    BatchNormalization, LeakyReLU and Dropout(0.3)) lead to a flattened
    Dense(1). The Dense(1) has no activation, so the model emits one raw
    logit per image.

    Returns:
        A ``tf.keras.Sequential`` model.
    """
    model = tf.keras.Sequential()
    # Per stage: (filters, extra Conv2D kwargs, expected output shape).
    # Only the first convolution declares the model's input shape.
    conv_stages = (
        (64, {"input_shape": [28, 28, 1]}, (None, 14, 14, 64)),
        (128, {}, (None, 7, 7, 128)),
    )
    for filters, extra_kwargs, expected_shape in conv_stages:
        model.add(Conv2D(filters, 5, strides=(2, 2), padding="SAME",
                         **extra_kwargs))
        model.add(BatchNormalization())
        model.add(LeakyReLU())
        model.add(Dropout(0.3))
        assert model.output_shape == expected_shape

    model.add(Flatten())
    model.add(Dense(1))
    return model
def save_and_show_generate_image(model, epoch, test_input, flag=0):
    """Render the generator's output for ``test_input`` as a grid and save it.

    The figure is written to ``{test_images_path}/image_at_epoch_{epoch:04d}.png``
    and then cleared.

    Args:
        model: Generator, called as ``model(test_input, training=False)``.
            Assumed to return images shaped (n, H, W, 1) in [-1, 1], as
            create_generator_model does.
        epoch: Integer used in the output file name.
        test_input: Batch of noise vectors; each row yields one image.
        flag: If 1, block on ``plt.show()``. Otherwise pause for 0.1 s so
            the caller (the training loop) continues.
    """
    import math

    images = model(test_input, training=False)
    # Undo the [-1, 1] scaling back to 0..255 pixel intensities.
    images = images * 127.5 + 127.5
    # Drop the single channel axis: (n, H, W, 1) -> (n, H, W). This no
    # longer hard-codes the 28x28 size.
    images = images[..., 0]

    count = images.shape[0]
    # Near-square grid sized to the batch. 36 images still give a 6x6 grid,
    # but larger batches no longer make plt.subplot raise (a 6x6 grid only
    # holds 36).
    cols = max(1, math.ceil(math.sqrt(count)))
    rows = max(1, math.ceil(count / cols))
    for idx in range(count):
        plt.subplot(rows, cols, idx + 1)
        plt.imshow(images[idx], cmap="gray")

    plt.savefig('{}/image_at_epoch_{:04d}.png'.format(test_images_path, epoch))
    if flag == 1:
        plt.show()
    else:
        plt.pause(0.1)
    plt.clf()
@tf.function
def train_step(real):
    """Run one adversarial update of the generator and the discriminator.

    Args:
        real: A batch of real images from ``x_train_data``, shaped
            (batch, 28, 28, 1) and scaled to [-1, 1].

    Uses the module-level ``gen``, ``dis``, ``loss_object``,
    ``gen_optimizer`` and ``dis_optimizer`` created by the script below.
    They must exist before the first call.
    """
    # One noise vector per real image, so the smaller final batch of an
    # epoch is paired with an equally sized fake batch.
    noise = tf.random.normal((tf.shape(real)[0], 100))
    with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
        # training=True must be passed so BatchNormalization uses batch
        # statistics and Dropout is active during training.
        fake = gen(noise, training=True)
        fake_output = dis(fake, training=True)
        real_output = dis(real, training=True)
        # Discriminator target: fakes -> 0, reals -> 1.
        dis_loss = (loss_object(tf.zeros_like(fake_output), fake_output)
                    + loss_object(tf.ones_like(real_output), real_output))
        # Generator target: its fakes classified as real (1).
        gen_loss = loss_object(tf.ones_like(fake_output), fake_output)

    gen_gradient = gen_tape.gradient(gen_loss, gen.trainable_variables)
    dis_gradient = dis_tape.gradient(dis_loss, dis.trainable_variables)
    # Each network has its own optimizer (see the note where they are created).
    gen_optimizer.apply_gradients(zip(gen_gradient, gen.trainable_variables))
    dis_optimizer.apply_gradients(zip(dis_gradient, dis.trainable_variables))


# ---- Data: MNIST training digits shaped (N, 28, 28, 1) and scaled to [-1, 1]
# to match the generator's tanh output. Labels and the test split are unused.
(x_train, _), _ = tf.keras.datasets.mnist.load_data(path=data_path)
x_train = x_train.reshape((x_train.shape[0], 28, 28, 1))
x_train = (x_train.astype("float32") - 127.5) / 127.5
x_train_data = tf.data.Dataset.from_tensor_slices(
    x_train).shuffle(10000).batch(BATCH_SIZE)

# makedirs also creates missing parents; exist_ok replaces the separate
# existence check.
os.makedirs(test_images_path, exist_ok=True)

gen = create_generator_model()
dis = create_discriminator_model()

# Seeded noise, reused for every snapshot so the images are comparable
# across epochs.
tf.random.set_seed(100)
test_noise = tf.random.normal((36, 100))

loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# One optimizer per network. Sharing a single optimizer between two disjoint
# variable sets fails with the TF >= 2.11 Keras optimizers ("optimizer cannot
# recognize variable ..."). It would also advance Adam's step counter twice
# per training step.
gen_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
dis_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# NOTE: the former single "myopt" entry is replaced by two optimizer
# entries. When restoring an older checkpoint, the network weights still
# match via mygen/mydis, but the Adam state starts fresh.
checkpoint = tf.train.Checkpoint(mygen=gen, mydis=dis,
                                 mygenopt=gen_optimizer,
                                 mydisopt=dis_optimizer)
manager = tf.train.CheckpointManager(
    checkpoint, checkpoint_path, max_to_keep=5)
checkpoint.restore(manager.latest_checkpoint)

# Snapshot before any training counts as epoch 0.
save_and_show_generate_image(gen, 0, test_noise, flag=0)
EPOCH = 20
for epoch in range(EPOCH):
    for real in x_train_data:
        train_step(real)
    # Number trained epochs from 1 so the pre-training image (epoch 0) is
    # not overwritten by the first trained epoch.
    save_and_show_generate_image(gen, epoch + 1, test_noise, flag=0)
    manager.save()
#%%
# Standalone cell: stitch the saved per-epoch PNG snapshots into a GIF.
test_images_path = "D:/work/demo/dcgan_test_images"
# Calling top-level imageio.imread emits a DeprecationWarning on newer
# imageio releases, so prefer the explicit v2 API. Fall back to the
# top-level module on versions that predate imageio.v2.
_imageio_v2 = getattr(imageio, "v2", imageio)
# File names carry zero-padded epoch numbers, so lexical order equals
# epoch order.
filenames = sorted(glob.glob("{}/{}".format(test_images_path, "*.png")))
images = [_imageio_v2.imread(name) for name in filenames]
output_file = test_images_path + "/gif.gif"
# NOTE(review): the unit of `duration` depends on the installed imageio
# version / GIF plugin (seconds in the legacy plugin, milliseconds in the
# newer pillow-based one). Confirm that 0.8 gives the intended frame time.
imageio.mimsave(output_file, images, duration=0.8)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。