1 Star 2 Fork 0

jacinth2006/机器学习常见算法及演示

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
cvae.py 4.91 KB
一键复制 编辑 原始数据 按行查看 历史
#%%
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import datetime
# Load MNIST from a local copy of the archive (the path is machine-specific).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
    path="D:/AI/tf/tf_offcial_tutorial/src/dataset/mnist.npz"
)

# Divide by 255.0, store as float32 and append a trailing axis, which the
# Conv2D layers of the encoder consume as the channel axis.
x_train = np.expand_dims((x_train / 255.0).astype(np.float32), axis=-1)
x_test = np.expand_dims((x_test / 255.0).astype(np.float32), axis=-1)

# Only the training labels are one-hot encoded (10 classes); y_test is left
# as returned by load_data.
y_train = tf.one_hot(y_train, 10)
from tensorflow.keras.layers import Conv2D,Dense,Flatten,Conv2DTranspose,Reshape
from tensorflow.keras.losses import BinaryCrossentropy
class Encoder(tf.keras.Model):
    """Encoder of the conditional VAE.

    The image branch (two stride-2 convolutions followed by dense layers)
    produces ``mu`` and ``sigma``, each with 2 units.  The label branch (two
    dense layers) produces ``ex_mu``, also with 2 units.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        # Sub-layers are created eagerly so that construction does not depend
        # on the shape of ``x``, which is ``None`` on the generation path.
        # Every layer infers its input size on first use, so no
        # ``input_shape`` argument is needed.
        self.conv1 = Conv2D(16, 3, strides=2, activation='relu', padding="same")
        self.conv2 = Conv2D(32, 3, strides=2, activation='relu', padding="same")
        self.flatten = Flatten()
        self.dense = Dense(16, activation='relu')
        self.dense_mu = Dense(2)
        self.dense_sigma = Dense(2)
        # Label branch.
        self.dense1_cls = Dense(16, activation='relu')
        self.dense2_cls = Dense(2)

    def call(self, x, y):
        """Encode images and labels.

        Args:
            x: image batch, or ``None`` to skip the image branch.
            y: label batch fed to the label branch.

        Returns:
            Tuple ``(mu, sigma, ex_mu)``; ``mu`` and ``sigma`` are ``None``
            when ``x`` is ``None``.
        """
        mu = None
        sigma = None
        # During generation tests x is None, so only the label branch runs.
        if x is not None:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            x3 = self.flatten(x2)
            x4 = self.dense(x3)
            mu = self.dense_mu(x4)
            sigma = self.dense_sigma(x4)
        y1 = self.dense1_cls(y)
        ex_mu = self.dense2_cls(y1)
        return mu, sigma, ex_mu
def sample(mu, sigma):
    """Return ``mu + epsilon * exp(sigma / 2)`` with standard-normal ``epsilon``.

    ``sigma`` is halved and exponentiated to obtain the noise scale, i.e. it
    is used as a log-variance.

    Args:
        mu: tensor of means.
        sigma: tensor with the same shape as ``mu``.

    Returns:
        Tensor with the shape of ``mu``.
    """
    # tf.shape() is the runtime shape, so this also works inside tf.function
    # when a static dimension (typically the batch size) is unknown; the
    # static ``mu.shape`` would contain None in that case.  The noise is drawn
    # in mu's dtype so the arithmetic below never mixes float types.
    epsilon = tf.random.normal(tf.shape(mu), dtype=mu.dtype)
    z = mu + epsilon * tf.math.exp(sigma / 2)
    return z
class Decoder(tf.keras.Model):
    """Decoder of the conditional VAE.

    A dense layer is reshaped to ``(7, 7, 32)`` and passed through two
    stride-2 transposed convolutions and a final single-filter transposed
    convolution with a sigmoid activation.
    """

    def __init__(self):
        super().__init__()

    def build(self, input_shape):
        """Create the sub-layers; ``input_shape`` is accepted but not used."""
        # Settings shared by both upsampling layers.
        upsample = dict(strides=2, padding="SAME", activation='relu', use_bias=False)
        self.dense1 = Dense(7 * 7 * 32, activation='relu')
        self.reshape = Reshape((7, 7, 32))
        self.conv_transpose1 = Conv2DTranspose(32, 3, **upsample)
        self.conv_transpose2 = Conv2DTranspose(16, 3, **upsample)
        self.conv_transpose3 = Conv2DTranspose(
            1, 3, padding="SAME", activation='sigmoid', use_bias=False
        )

    def call(self, x):
        """Run ``x`` through the layers in order and return the final output."""
        stages = (
            self.dense1,
            self.reshape,
            self.conv_transpose1,
            self.conv_transpose2,
            self.conv_transpose3,
        )
        out = x
        for stage in stages:
            out = stage(out)
        return out
# Module-level model instances used by train_step and the training loop.
encoder, decoder = Encoder(), Decoder()
#%%
from scipy.stats import norm
# Mini-batch size and number of passes over the training set.
BATCH = 64
ECHO = 50

# Directory that receives one PNG per epoch (machine-specific path).
test_images_path = "D:/work/demo/cvae_test_images"

# Pairs of (image, one-hot label), shuffled with a 60000-element buffer and
# grouped into batches of BATCH.
dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(60000)
    .batch(BATCH)
)
def restruction_loss_fun(x, x_h, mu, sigma, ex_mu):
    """Return the batch-mean of reconstruction loss plus regularization.

    Args:
        x: target images.
        x_h: reconstructed images.
        mu: encoder means.
        sigma: encoder output used as a log-variance (it is exponentiated).
        ex_mu: label-dependent means that ``mu`` is pulled towards.

    Returns:
        Scalar tensor: mean over the batch of the per-sample sum of both terms.
    """
    # Unreduced binary cross-entropy, summed over axes 1 and 2.
    per_pixel_bce = BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
    reconstruction_term = tf.math.reduce_sum(per_pixel_bce(x, x_h), axis=[1, 2])

    # Has the form of the KL divergence between N(mu, exp(sigma)) and
    # N(ex_mu, 1), summed over the last axis.
    kl_elements = tf.math.pow(mu - ex_mu, 2) + tf.math.exp(sigma) - sigma - 1
    kl_term = 0.5 * tf.math.reduce_sum(kl_elements, axis=-1)

    return tf.math.reduce_mean(reconstruction_term + kl_term)
# Adam optimizer used by train_step.  The keyword is spelled out as
# ``learning_rate`` because ``lr`` is only a deprecated alias that newer
# Keras releases no longer accept.
opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
@tf.function
def train_step(x, y):
    """Run one optimization step on a batch and return the encoder outputs.

    Args:
        x: image batch.
        y: label batch.

    Returns:
        Tuple ``(mu, sigma, ex_mu)`` produced by the encoder for this batch.
    """
    with tf.GradientTape() as tape:
        mu, sigma, ex_mu = encoder(x, y)
        reconstruction = decoder(sample(mu, sigma))
        loss = restruction_loss_fun(x, reconstruction, mu, sigma, ex_mu)
    # Collected after the forward pass: the models create their weights on
    # their first call.
    trainable = encoder.trainable_variables + decoder.trainable_variables
    grads = tape.gradient(loss, trainable)
    opt.apply_gradients(zip(grads, trainable))
    return mu, sigma, ex_mu
def save_and_show_generate_image(decoder, ex_mu, flag=0, epoch=0):
    """Decode a 4x4 grid of latent points around each centre and save the figure.

    For every 2-element centre in ``ex_mu``, 16 latent points are decoded and
    drawn in one row of a 10-row by 16-column subplot grid.  The figure is
    written to ``test_images_path`` as ``image_at_epoch_XXXX.png``.

    Args:
        decoder: callable mapping a ``(1, 2)`` array to an image that can be
            reshaped to ``(28, 28)``.
        ex_mu: iterable of 2-element latent centres; the subplot grid has
            room for at most 10 of them.
        flag: ``1`` displays the figure with ``plt.show()``; any other value
            uses ``plt.pause(0.1)``.
        epoch: number embedded in the output file name.
    """
    import os  # local import: only needed to create the output directory

    # Offsets are standard-normal quantiles at 4 evenly spaced probabilities
    # between 0.4 and 0.6.
    grid_x = norm.ppf(np.linspace(0.4, 0.6, 4))
    grid_y = norm.ppf(np.linspace(0.4, 0.6, 4))
    for n, mu in enumerate(ex_mu):
        for i, yi in enumerate(grid_x):
            for j, xi in enumerate(grid_y):
                x = np.array([[xi + mu[0], yi + mu[1]]])
                image = decoder(x)
                image = image * 255
                image = tf.reshape(image, (28, 28))
                # Row n of the grid holds the 16 samples of centre n.
                plt.subplot(10, 16, n * 16 + i * 4 + j + 1)
                plt.imshow(image, cmap="gray")

    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(test_images_path, exist_ok=True)
    plt.savefig('{}/image_at_epoch_{:04d}.png'.format(test_images_path, epoch))
    if flag == 1:
        plt.show()
    else:
        plt.pause(0.1)
    # Clear the figure so the next call starts from an empty canvas.
    plt.clf()
# Timestamped relative directory name of the form ./YYYYMMDD-HHMMSS.
log_dir = datetime.datetime.now().strftime("./%Y%m%d-%H%M%S")

for i in range(ECHO):
    # One pass over all training batches.
    for x, y in dataset:
        mu, sigma, ex_mu = train_step(x, y)

    # x=None skips the image branch; tf.eye(10) supplies one one-hot label
    # per class, so ex_mu holds one latent centre per class.
    mu, sigma, ex_mu = encoder(None, tf.eye(10))
    print(i)
    save_and_show_generate_image(decoder, ex_mu, flag=0, epoch=i)
#%%
import glob
import imageio
# Assemble the per-epoch PNGs, in sorted file-name order, into one GIF.
test_images_path = "D:/work/demo/cvae_test_images"
filenames = sorted(glob.glob("{}/{}".format(test_images_path, "*.png")))
images = [imageio.imread(name) for name in filenames]
output_file = test_images_path + "/gif.gif"
imageio.mimsave(output_file, images, duration=0.8)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/jacinth2006/ML.git
git@gitee.com:jacinth2006/ML.git
jacinth2006
ML
机器学习常见算法及演示
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385