1 Star 2 Fork 0

jacinth2006/机器学习常见算法及演示

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
dcgan.py 4.69 KB
一键复制 编辑 原始数据 按行查看 历史
jacinth2006 提交于 2021-09-05 00:17 . 条件生成对抗网络
#%%
import imageio
import glob
import matplotlib.pyplot as plt
import os
import tensorflow as tf
from tensorflow.keras.layers import \
Dense, Dropout, Flatten, Conv2D, Conv2DTranspose, Reshape,\
BatchNormalization, LeakyReLU, Input
# Number of real MNIST images consumed per training step.
BATCH_SIZE = 64

# Machine-specific absolute paths (the author's Windows layout); edit them to
# run the script on another machine.
data_path = "D:/AI/tf/tf_offcial_tutorial/src/dataset/mnist.npz"  # passed to mnist.load_data
checkpoint_path, test_images_path = (
    "D:/work/demo/dcgan",              # directory for tf.train.CheckpointManager
    "D:/work/demo/dcgan_test_images",  # per-epoch PNG sample grids and the final GIF
)
def create_generator_model():
    """Build the DCGAN generator: 100-d noise vector -> 28x28x1 image.

    A Dense projection is reshaped to a 7x7x256 feature map. Three 5x5
    transposed convolutions (strides 1, 2, 2) then upsample it to 28x28x1.
    Every intermediate output shape is checked with an assert.

    Returns:
        tf.keras.Sequential: the (uncompiled) generator model.
    """
    net = tf.keras.Sequential()
    # The explicit Input layer is required: without it the model is not
    # built, output_shape is unavailable and the asserts below fail.
    # (Alternatively, give the Dense layer an input shape.)
    net.add(Input((100,)))
    for layer in (Dense(7 * 7 * 256), BatchNormalization(), LeakyReLU(),
                  Reshape((7, 7, 256))):
        net.add(layer)
    assert net.output_shape == (None, 7, 7, 256)

    # Hidden upsampling stages: (filters, stride, expected output shape).
    # The conv bias is omitted because BatchNormalization follows.
    stages = (
        (128, 1, (None, 7, 7, 128)),
        (64, 2, (None, 14, 14, 64)),
    )
    for filters, stride, expected_shape in stages:
        net.add(Conv2DTranspose(filters, (5, 5), strides=(stride, stride),
                                padding="SAME", use_bias=False))
        net.add(BatchNormalization())
        net.add(LeakyReLU())
        assert net.output_shape == expected_shape

    # Output stage: one channel with tanh, so pixels lie in [-1, 1]. This
    # matches the (x - 127.5) / 127.5 scaling applied to the real images.
    net.add(Conv2DTranspose(1, (5, 5), strides=(2, 2), padding="SAME",
                            activation="tanh"))
    assert net.output_shape == (None, 28, 28, 1)
    return net
def create_discriminator_model():
    """Build the DCGAN discriminator: 28x28x1 image -> one real/fake logit.

    Two strided 5x5 convolutions downsample 28x28 -> 14x14 -> 7x7. Each is
    followed by batch norm, LeakyReLU and 30% dropout. The result is
    flattened into a single Dense unit with no activation, i.e. a raw logit;
    the training loss is built with BinaryCrossentropy(from_logits=True).

    Returns:
        tf.keras.Sequential: the (uncompiled) discriminator model.
    """
    model = tf.keras.Sequential()
    # Explicit Input layer, as in create_generator_model. Passing
    # input_shape= to the first Conv2D is deprecated in Keras 3.
    model.add(Input((28, 28, 1)))
    # Downsampling stages: (filters, expected output shape after the stage).
    for filters, expected_shape in ((64, (None, 14, 14, 64)),
                                    (128, (None, 7, 7, 128))):
        model.add(Conv2D(filters, 5, strides=(2, 2), padding="SAME"))
        model.add(BatchNormalization())
        model.add(LeakyReLU())
        model.add(Dropout(0.3))
        assert model.output_shape == expected_shape
    model.add(Flatten())
    model.add(Dense(1))
    return model
def save_and_show_generate_image(model, epoch, test_input, flag=0):
    """Generate samples from ``test_input``, save them as a grid and show it.

    Args:
        model: generator mapping a batch of noise vectors to images that can
            be reshaped to 28x28. Values are assumed to lie in [-1, 1]
            (tanh output).
        epoch: number formatted into the file name
            ``image_at_epoch_XXXX.png``.
        test_input: batch of noise vectors; one subplot is drawn per vector.
        flag: 1 -> call ``plt.show()``; any other value -> ``plt.pause(0.1)``
            so the loop continues while the figure refreshes.

    Side effects:
        Writes the PNG grid into ``test_images_path`` and clears the current
        matplotlib figure afterwards.
    """
    import math  # only needed for the grid layout

    images = model(test_input, training=False)
    # Map [-1, 1] back to pixel intensities in [0, 255].
    images = images * 127.5 + 127.5
    count = images.shape[0]
    images = tf.reshape(images, (count, 28, 28))
    # Near-square grid sized to the batch: 6x6 for the 36-sample test noise.
    # A hard-coded 6x6 grid made plt.subplot raise for more than 36 images.
    cols = max(1, math.ceil(math.sqrt(count)))
    rows = max(1, math.ceil(count / cols))
    for i in range(count):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(images[i], cmap="gray")
    plt.savefig('{}/image_at_epoch_{:04d}.png'.format(test_images_path, epoch))
    if flag == 1:
        plt.show()
    else:
        plt.pause(0.1)
    plt.clf()
@tf.function
def train_step(real):
    """Run one adversarial update of the generator and the discriminator.

    Args:
        real: batch of real images scaled to [-1, 1] (see preprocessing
            below). The final batch of an epoch may hold fewer than
            BATCH_SIZE images, because Dataset.batch keeps the remainder.
    """
    # Size the noise to the actual real batch, so the smaller final batch
    # is not scored against more fakes than reals.
    noise = tf.random.normal((tf.shape(real)[0], 100))
    with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
        # training=True is required: BatchNormalization uses batch statistics
        # and Dropout is active only in training mode.
        fake = gen(noise, training=True)
        fake_output = dis(fake, training=True)
        real_output = dis(real, training=True)
        # Discriminator target: fakes -> 0, reals -> 1.
        dis_loss = (loss_object(tf.zeros_like(fake_output), fake_output)
                    + loss_object(tf.ones_like(real_output), real_output))
        # Generator target: have its fakes classified as real (1).
        gen_loss = loss_object(tf.ones_like(fake_output), fake_output)
    gen_gradient = gen_tape.gradient(gen_loss, gen.trainable_variables)
    gen_optimizer.apply_gradients(zip(gen_gradient, gen.trainable_variables))
    dis_gradient = dis_tape.gradient(dis_loss, dis.trainable_variables)
    dis_optimizer.apply_gradients(zip(dis_gradient, dis.trainable_variables))


# Data: MNIST training images, reshaped to NHWC and scaled to [-1, 1].
(x_train, y_train), (x_test, y_test) = \
    tf.keras.datasets.mnist.load_data(path=data_path)
x_train = x_train.reshape((x_train.shape[0], 28, 28, 1))
x_train = (x_train.astype("float32") - 127.5) / 127.5
x_train_data = tf.data.Dataset.from_tensor_slices(
    x_train).shuffle(10000).batch(BATCH_SIZE)

# makedirs also creates missing parent directories (os.mkdir would raise).
os.makedirs(test_images_path, exist_ok=True)

gen = create_generator_model()
dis = create_discriminator_model()
# Fixed noise, created once, so every saved grid shows the same latent points.
tf.random.set_seed(100)
test_noise = tf.random.normal((36, 100))

loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# One optimizer per network. Newer Keras optimizers (TF >= 2.11) raise when
# apply_gradients receives variables other than those they were built with.
# A shared optimizer would also share one iteration counter across both
# updates.
gen_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
dis_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

checkpoint = tf.train.Checkpoint(mygen=gen, mydis=dis,
                                 mygenopt=gen_optimizer,
                                 mydisopt=dis_optimizer)
manager = tf.train.CheckpointManager(
    checkpoint, checkpoint_path, max_to_keep=5)
# Resume from the newest checkpoint when one exists. Model weights under
# mygen/mydis still restore from checkpoints written with the old shared
# "myopt" optimizer; that optimizer state is simply not matched.
checkpoint.restore(manager.latest_checkpoint)

# Epoch 0: samples before any training in this run.
save_and_show_generate_image(gen, 0, test_noise, flag=0)
EPOCH = 20
for i in range(EPOCH):
    for real in x_train_data:
        train_step(real)
    # i + 1 so the grid after the first epoch does not overwrite epoch 0.
    save_and_show_generate_image(gen, i + 1, test_noise, flag=0)
    manager.save()
#%%
# Stitch the per-epoch sample grids into an animated GIF.
# NOTE: duplicates the module-level test_images_path, presumably so this cell
# can be re-run on its own after training.
test_images_path = "D:/work/demo/dcgan_test_images"
# Epoch numbers are zero-padded in the file names, so lexical order is
# chronological order.
filenames = sorted(glob.glob("{}/{}".format(test_images_path, "*.png")))
images = [imageio.imread(frame_path) for frame_path in filenames]
output_file = test_images_path + "/gif.gif"
# NOTE(review): the unit of `duration` (seconds vs. milliseconds) depends on
# which imageio GIF plugin is installed — confirm against the imageio version.
imageio.mimsave(output_file, images, duration=0.8)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/jacinth2006/ML.git
git@gitee.com:jacinth2006/ML.git
jacinth2006
ML
机器学习常见算法及演示
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385