1 Star 2 Fork 0

jacinth2006/机器学习常见算法及演示

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
GAN.py 6.79 KB
一键复制 编辑 原始数据 按行查看 历史
jacinth2006 提交于 2021-09-05 00:16 . 生成对抗生成手写数字
#%%
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense,BatchNormalization,LeakyReLU
from tensorflow.keras.losses import SparseCategoricalCrossentropy,BinaryCrossentropy
from tensorflow.keras.optimizers import Adam
BATCH_SIZE=64
BATCH_SIZE_FAKE=16
IMAGE_SIZE=28
Z_DIMS=100
test_images_path = "D:/work/demo/gan_test_images"
(x_train, y_train), (x_test, y_test)= \
tf.keras.datasets.mnist.load_data(path="D:/AI/tf/tf_offcial_tutorial/src/dataset/mnist.npz")
#将像素数据转换到[-1,1],因为Generator的输入经过tanh,也是在该区间。
x_train=(x_train.astype("float32")-127.5)/127.5
#将数据展开
x_train=tf.reshape(x_train,(x_train.shape[0],IMAGE_SIZE*IMAGE_SIZE))
train_image=tf.data.Dataset.from_tensor_slices((x_train)).shuffle(1000).batch(BATCH_SIZE)
train_image=[i for i in train_image]
#%%
class Generator(Model):
def __init__(self):
super(Generator,self).__init__()
self.d1=Dense(1024)
#在激活之前先批标准化
self.b1=BatchNormalization()
#使用ReLU激活
self.a1=LeakyReLU()
self.d2=Dense(512)
self.b2=BatchNormalization()
self.a2=LeakyReLU()
self.d3=Dense(28*28,activation="tanh")
def call(self,x):
x=self.d1(x)
x=self.b1(x)
x=self.a1(x)
x=self.d2(x)
x=self.b2(x)
x=self.a2(x)
return self.d3(x)
class Discriminator(Model):
def __init__(self):
super(Discriminator,self).__init__()
self.d1=Dense(1024)
self.b1=BatchNormalization()
self.a1=LeakyReLU()
self.d2=Dense(512)
self.b2=BatchNormalization()
self.a2=LeakyReLU()
self.d3=Dense(32)
self.b3=BatchNormalization()
self.a3=LeakyReLU()
#最后一个神经元,输出图片判别为真的概率
self.d4=Dense(1)
#如果使用如下作为最后的输出,其输出二分类的概率分布,需要修改损失函数为SparseCategoricalCrossentropy()
#self.d4=Dense(2,activation="softmax")
def call(self,x):
x=self.d1(x)
x=self.b1(x)
x=self.a1(x)
x=self.d2(x)
x=self.b2(x)
x=self.a2(x)
x=self.d3(x)
x=self.b3(x)
x=self.a3(x)
return self.d4(x)
fake_loss_object=BinaryCrossentropy(from_logits=True)
fake_optimizer=Adam(learning_rate=1e-4)
#如果最后一层输出没有经过softmax或者sigmoid进行概率化,from_logits必须设置为True,使结果更稳定,
#否则由于结果输出差别太大,将找不到下降梯度。
real_loss_object=BinaryCrossentropy(from_logits=True)
real_optimizer=Adam(learning_rate=1e-4)
#如果判别器最后输出为2个神经元,输出二分类概率分布,则使用该损失函数
loss_object=SparseCategoricalCrossentropy()
gen=Generator()
dis=Discriminator()
@tf.function
def train_step(real):
#输入100长度的噪声数据
noises=tf.random.normal((BATCH_SIZE_FAKE,Z_DIMS))
with tf.GradientTape() as gen_tape,tf.GradientTape() as dis_tape:
fake=gen(noises)
fake_output=dis(fake)
real_output=dis(real)
#相当于-E_(x~p_data (x) ) log⁡(D(x)),对该值最小化即使真实图片尽量输出1
#相当于-E_(z~p_z (z) ) log⁡(1-D(G(z))),对该值最小化即使假图片尽量输出0
dis_loss=real_loss_object(tf.ones_like(real_output),real_output) + \
real_loss_object(tf.zeros_like(fake_output),fake_output)
#判别器输出2维的概率分布时,使用如下损失函数
#dis_loss=loss_object(tf.ones((real_output.shape[0])),real_output)+\
# loss_object(tf.zeros((fake_output.shape[0])),fake_output)
#相当于E_(z~p_z (z) ) log⁡(1-D(G(z))),对该值最小化即使假图片尽量输出1
#损失函数计算在原始代码种生成器和判别器分别定义的,其实可以公用一个
gen_loss=real_loss_object(tf.ones_like(fake_output),fake_output)
#判别器输出2维的概率分布时,生成器使用如下损失函数
#gen_loss=loss_object(tf.ones((fake_output.shape[0])),fake_output)
#Gradient在with内部将目标函数和其依赖的参数记录下来,求导的时候知道其依赖关系
#如果不某些依赖不放在with里将导致梯度无法计算,因为没有记下来
gen_grad=gen_tape.gradient(gen_loss,gen.trainable_variables)
#梯度generator和discriminator可以公用
real_optimizer.apply_gradients(zip(gen_grad,gen.trainable_variables))
#GradientTape不能公用,否则报如下错误RuntimeError: GradientTape.gradient can only be called once on non-persistent tapes
dis_grad=dis_tape.gradient(dis_loss,dis.trainable_variables)
real_optimizer.apply_gradients(zip(dis_grad,dis.trainable_variables))
i=0
ECHO=50
seed=tf.random.normal((16,Z_DIMS))
def generate_and_save_images(model, epoch, test_input,flag=0):
predictions = model(test_input, training=False)
#生成器网络预测结果输出数据区间维[-1,1],转换到mnist像素空间
predictions=predictions*127.5+127.5
#变换成2维图片数据
predictions=tf.reshape(predictions,(predictions.shape[0],28,28))
for i in range(predictions.shape[0]):
plt.subplot(4, 4, i+1)
plt.imshow(predictions[i], cmap='gray')
#clf前必现加pause,否则不显示
if(1==flag):
plt.show()
else:
plt.pause(0.1)
plt.clf()
def save_and_show_generate_image(model, epoch, test_input, flag=0):
images = model(test_input, training=False)
images = images*127.5+127.5
images = tf.reshape(images, (images.shape[0], 28, 28))
for i in range(images.shape[0]):
plt.subplot(4, 4, i+1)
plt.imshow(images[i], cmap="gray")
plt.savefig('{}/image_at_epoch_{:04d}.png'.format(test_images_path, epoch))
if(1 == flag):
plt.show()
else:
plt.pause(0.1)
plt.clf()
for i in range(ECHO):
for j in range(len(train_image)):
train_step(train_image[j])
# if(j%20000==0):
# generate_and_save_images(gen,i,seed)
print("echo{}".format(i))
save_and_show_generate_image(gen,i,seed,0)
#generate_and_save_images(gen,i,seed,1)
import glob
import imageio
filenames=glob.glob("{}/{}".format(test_images_path,"*.png"))
filenames=sorted(filenames)
images=[]
for i in filenames:
print(i)
images.append(imageio.imread(i))
output_file = test_images_path +"/gif.gif"
imageio.mimsave(output_file, images, duration=0.3)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/jacinth2006/ML.git
git@gitee.com:jacinth2006/ML.git
jacinth2006
ML
机器学习常见算法及演示
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385