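# [Web-page notice captured with the file; not part of the program]
# Code pull complete; the page will refresh automatically.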
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
# Single seed shared by the graph-level seed below and by the explicitly
# seeded discriminator kernel initialisers in GAN.__init__.
SEED = 42
# Set TensorFlow's graph-level random seed (TF 1.x API).
# NOTE(review): this presumably has to run before any graph ops are created
# for initial values to be reproducible -- confirm against the TF docs.
tf.set_random_seed(SEED)
class GAN():
    """Build the TensorFlow (1.x graph-mode) ops of a GAN for time-series windows.

    The generator is a three-layer fully connected network that maps a noise
    vector to a ``num_historical_days x num_features`` window.  The
    discriminator is a stack of three convolutions followed by two dense
    layers and emits one real/fake logit per window.

    Attributes set by ``__init__``:
        X: float32 placeholder of shape
            ``[None, num_historical_days, num_features]`` for real windows.
        Z: float32 placeholder of shape ``[None, generator_input_size]`` for
            the generator noise.
        gen_data: generator output reshaped to
            ``[-1, num_historical_days, num_features]``.
        features: flattened output of the last convolution for the real
            input, taken before the dropout applied to the flattened tensor.
        D_l2_loss, G_l2_loss: scaled L2 penalties over each network's
            trainable variables.
        D_loss, G_loss: sigmoid cross-entropy losses plus the matching L2
            penalty.
        D_solver, G_solver: Adam ``minimize`` ops, each restricted to its own
            network's variables.
    """

    def sample_Z(self, batch_size, n):
        """Return a ``(batch_size, n)`` array of noise drawn uniformly from [-1, 1)."""
        return np.random.uniform(-1., 1., size=(batch_size, n))

    def __init__(self, num_features, num_historical_days, generator_input_size=200, is_train=True):
        """Create placeholders, both networks, their losses and optimizers.

        Args:
            num_features: number of features per day (the channel dimension
                seen by the discriminator convolutions).
            num_historical_days: number of days per window. Must be at least
                3 because the last convolution uses a height-3 kernel with
                VALID padding.
            generator_input_size: length of the noise vector fed to ``Z``.
            is_train: Python bool evaluated while the graph is built; when
                true, dropout (keep_prob 0.8) is inserted into the
                discriminator.

        Raises:
            ValueError: if ``num_historical_days`` is smaller than 3.
        """
        if num_historical_days < 3:
            raise ValueError(
                "num_historical_days must be at least 3, got %r" % (num_historical_days,))

        def get_batch_norm_with_global_normalization_vars(size):
            """Create the four ones-initialised variables (v, m, beta, gamma) of length ``size``."""
            v = tf.Variable(tf.ones([size]), dtype=tf.float32)
            m = tf.Variable(tf.ones([size]), dtype=tf.float32)
            beta = tf.Variable(tf.ones([size]), dtype=tf.float32)
            gamma = tf.Variable(tf.ones([size]), dtype=tf.float32)
            return v, m, beta, gamma

        self.X = tf.placeholder(tf.float32, shape=[None, num_historical_days, num_features])
        # conv2d expects 4-D input: days become the height, width is 1 and the
        # features act as channels.
        X = tf.reshape(self.X, [-1, num_historical_days, 1, num_features])
        self.Z = tf.placeholder(tf.float32, shape=[None, generator_input_size])

        generator_output_size = num_features*num_historical_days
        with tf.variable_scope("generator"):
            g_W1 = tf.Variable(tf.truncated_normal([generator_input_size, generator_output_size*10]))
            g_b1 = tf.Variable(tf.truncated_normal([generator_output_size*10]))
            g_h1 = tf.nn.sigmoid(tf.matmul(self.Z, g_W1) + g_b1)

            g_W2 = tf.Variable(tf.truncated_normal([generator_output_size*10, generator_output_size*5]))
            g_b2 = tf.Variable(tf.truncated_normal([generator_output_size*5]))
            g_h2 = tf.nn.sigmoid(tf.matmul(g_h1, g_W2) + g_b2)

            g_W3 = tf.Variable(tf.truncated_normal([generator_output_size*5, generator_output_size]))
            g_b3 = tf.Variable(tf.truncated_normal([generator_output_size]))
            # The last layer is linear; its output is fed to the discriminator
            # in the same 4-D layout as the real input.
            g_log_prob = tf.matmul(g_h2, g_W3) + g_b3
            g_log_prob = tf.reshape(g_log_prob, [-1, num_historical_days, 1, num_features])
            self.gen_data = tf.reshape(g_log_prob, [-1, num_historical_days, num_features])
            theta_G = [g_W1, g_b1, g_W2, g_b2, g_W3, g_b3]

        # Two SAME convolutions keep the height at num_historical_days; the
        # final VALID convolution with a [3, 1] kernel reduces it by 2. The
        # width stays 1 and the last convolution has 128 output channels.
        # (Previously hard-coded as 18*1*128, which only fits 20-day windows.)
        flattened_conv_inputs = (num_historical_days - 2) * 1 * 128

        with tf.variable_scope("discriminator"):
            # Kernel layout: [filter_height, filter_width, in_channels, out_channels]
            d_k1 = tf.Variable(tf.truncated_normal([3, 1, num_features, 32],
                                                   stddev=0.1, seed=SEED, dtype=tf.float32))
            d_b1 = tf.Variable(tf.zeros([32], dtype=tf.float32))
            # The batch-norm variables are not used by any op below. They are
            # still created so that the variable creation order -- and with it
            # the auto-generated variable names -- stays unchanged.
            get_batch_norm_with_global_normalization_vars(32)

            d_k2 = tf.Variable(tf.truncated_normal([3, 1, 32, 64],
                                                   stddev=0.1, seed=SEED, dtype=tf.float32))
            d_b2 = tf.Variable(tf.zeros([64], dtype=tf.float32))
            get_batch_norm_with_global_normalization_vars(64)

            d_k3 = tf.Variable(tf.truncated_normal([3, 1, 64, 128],
                                                   stddev=0.1, seed=SEED, dtype=tf.float32))
            d_b3 = tf.Variable(tf.zeros([128], dtype=tf.float32))
            get_batch_norm_with_global_normalization_vars(128)

            d_W1 = tf.Variable(tf.truncated_normal([flattened_conv_inputs, 128]))
            d_b4 = tf.Variable(tf.truncated_normal([128]))
            get_batch_norm_with_global_normalization_vars(128)

            d_W2 = tf.Variable(tf.truncated_normal([128, 1]))
            theta_D = [d_k1, d_b1, d_k2, d_b2, d_k3, d_b3, d_W1, d_b4, d_W2]

        def discriminator(X):
            """Apply the discriminator to a 4-D batch.

            Returns:
                A tuple ``(D_prob, D_logit, features)``: the sigmoid of the
                logit, the raw logit, and the flattened output of the last
                convolution (before the dropout on the flattened tensor).
            """
            hidden = X
            conv_layers = ((d_k1, d_b1, 'SAME'), (d_k2, d_b2, 'SAME'), (d_k3, d_b3, 'VALID'))
            for kernel, bias, padding in conv_layers:
                conv = tf.nn.conv2d(hidden, kernel, strides=[1, 1, 1, 1], padding=padding)
                hidden = tf.nn.relu(tf.nn.bias_add(conv, bias))
                if is_train:
                    hidden = tf.nn.dropout(hidden, keep_prob=0.8)
                print(hidden)

            flattened_convolution_size = int(hidden.shape[1]) * int(hidden.shape[2]) * int(hidden.shape[3])
            print(flattened_convolution_size)
            flattened_convolution = features = tf.reshape(hidden, [-1, flattened_convolution_size])
            if is_train:
                flattened_convolution = tf.nn.dropout(flattened_convolution, keep_prob=0.8)

            dense = tf.nn.relu(tf.matmul(flattened_convolution, d_W1) + d_b4)
            D_logit = tf.matmul(dense, d_W2)
            D_prob = tf.nn.sigmoid(D_logit)
            return D_prob, D_logit, features

        # The same variables are applied to the real and the generated batch.
        D_real, D_logit_real, self.features = discriminator(X)
        D_fake, D_logit_fake, _ = discriminator(g_log_prob)

        # Discriminator: real windows are labelled 1, generated windows 0.
        D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
        D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
        self.D_l2_loss = (0.0001 * tf.add_n([tf.nn.l2_loss(t) for t in theta_D]) / len(theta_D))
        self.D_loss = D_loss_real + D_loss_fake + self.D_l2_loss

        # Generator: rewarded when the discriminator labels its output as 1.
        self.G_l2_loss = (0.00001 * tf.add_n([tf.nn.l2_loss(t) for t in theta_G]) / len(theta_G))
        self.G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_logit_fake, labels=tf.ones_like(D_logit_fake))) + self.G_l2_loss

        # Each optimizer only updates its own network's variables.
        self.D_solver = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(self.D_loss, var_list=theta_D)
        self.G_solver = tf.train.AdamOptimizer(learning_rate=0.000055).minimize(self.G_loss, var_list=theta_G)
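# [Web-page moderation notice captured with the file; not part of the program]
# This section may contain content that is not suitable for display, so the page does not show it. You can review and modify it through the relevant editing functions.
# If you confirm the content does not involve inappropriate language / pure advertising / violence / vulgar or pornographic material / infringement / piracy / false or worthless content, or content that violates national laws and regulations, you can click submit to appeal and it will be handled as soon as possible.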