# (web-page residue, not part of the script) Code pull complete; the page will refresh automatically.
# encoding: UTF-8
# Copyright 2016 Google.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
import tensorflowvisu
import math
from tensorflow.examples.tutorials.mnist import input_data as mnist_data
# Report which TensorFlow release is running (the script below uses the graph/session API: tf.placeholder, tf.Session)
print("Tensorflow version " + tf.__version__)
# Fix the graph-level random seed so that random ops (weight initialisation, dropout) are seeded deterministically
tf.set_random_seed(0)
# neural network with 5 layers
#
# · · · · · · · · · · (input data, flattened pixels) X [batch, 784] # 784 = 28*28
# \x/x\x/x\x/x\x/x\x/ ✞ -- fully connected layer (relu+dropout) W1 [784, 200] B1[200]
# · · · · · · · · · Y1 [batch, 200]
# \x/x\x/x\x/x\x/ ✞ -- fully connected layer (relu+dropout) W2 [200, 100] B2[100]
# · · · · · · · Y2 [batch, 100]
# \x/x\x/x\x/ ✞ -- fully connected layer (relu+dropout) W3 [100, 60] B3[60]
# · · · · · Y3 [batch, 60]
# \x/x\x/ ✞ -- fully connected layer (relu+dropout) W4 [60, 30] B4[30]
# · · · Y4 [batch, 30]
# \x/ -- fully connected layer (softmax) W5 [30, 10] B5[10]
# · Y5 [batch, 10]
# Fetch the images and labels into the "data" directory: mnist.train (60K images+labels, no validation split)
# and mnist.test (10K images+labels). Labels are one-hot encoded; images keep their 28x28x1 shape.
mnist = mnist_data.read_data_sets(
    "data",
    one_hot=True,
    reshape=False,
    validation_size=0
)
# Mini-batch of 28x28 grayscale images; the leading dimension (None) indexes the images in the batch
X = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
# Correct answers (one-hot over the 10 digits) for the same mini-batch
Y_ = tf.placeholder(tf.float32, shape=[None, 10])
# Learning rate, fed at every training step so that it can vary over time
lr = tf.placeholder(tf.float32)
# Probability of keeping a node during dropout: fed as 0.75 when training and 1.0 (no dropout) when evaluating
pkeep = tf.placeholder(tf.float32)
# Neuron counts of the four hidden layers (the fifth and last layer has 10 softmax neurons)
L, M, N, O = 200, 100, 60, 30
# Weights come from a truncated normal with stddev 0.1, i.e. small random values between -0.2 and +0.2.
# With RELUs, biases start at a small *positive* value (0.1 = tf.ones([K])/10); the softmax layer's biases start at 0.
# NOTE(review): with a graph-level seed set, the initial values presumably depend on the order in which
# these ops are created - confirm before reordering the lines below.
W1 = tf.Variable(tf.truncated_normal(shape=[28 * 28, L], stddev=0.1))  # 28 * 28 = 784 input pixels
B1 = tf.Variable(tf.ones(shape=[L]) / 10)
W2 = tf.Variable(tf.truncated_normal(shape=[L, M], stddev=0.1))
B2 = tf.Variable(tf.ones(shape=[M]) / 10)
W3 = tf.Variable(tf.truncated_normal(shape=[M, N], stddev=0.1))
B3 = tf.Variable(tf.ones(shape=[N]) / 10)
W4 = tf.Variable(tf.truncated_normal(shape=[N, O], stddev=0.1))
B4 = tf.Variable(tf.ones(shape=[O]) / 10)
W5 = tf.Variable(tf.truncated_normal(shape=[O, 10], stddev=0.1))
B5 = tf.Variable(tf.zeros(shape=[10]))
# The model: four fully connected RELU layers, each followed by dropout, then a softmax read-out layer.
# Flatten every 28x28x1 image into a vector of 784 pixels.
# NOTE(review): the explicit op names "input_feed" and "operator_restore" are presumably looked up by
# code that restores the saved graph - keep them unchanged.
XX = tf.reshape(X, [-1, 28*28],name="input_feed")
Y1 = tf.nn.relu(tf.matmul(XX, W1) + B1)
Y1d = tf.nn.dropout(Y1, pkeep)
Y2 = tf.nn.relu(tf.matmul(Y1d, W2) + B2)
Y2d = tf.nn.dropout(Y2, pkeep)
Y3 = tf.nn.relu(tf.matmul(Y2d, W3) + B3)
Y3d = tf.nn.dropout(Y3, pkeep)
Y4 = tf.nn.relu(tf.matmul(Y3d, W4) + B4)
Y4d = tf.nn.dropout(Y4, pkeep)
# Un-normalised class scores; the loss is computed from these
Ylogits = tf.matmul(Y4d, W5) + B5
# Class probabilities, derived from the very same logits node so that the loss and the predictions
# can never diverge and the last layer is not computed twice in the graph
Y = tf.nn.softmax(Ylogits, name="operator_restore")
# Cross-entropy loss (= -sum(Y_i * log(Yi)) per image). It is computed from the logits with
# softmax_cross_entropy_with_logits rather than from Y, which avoids the numerical stability
# problem of log(0) being NaN.
per_image_loss = tf.nn.softmax_cross_entropy_with_logits(logits=Ylogits, labels=Y_)
# Mean over the batch, scaled by 100 (normalisation for batches of 100 images)
cross_entropy = tf.reduce_mean(per_image_loss) * 100
# Accuracy of the model, between 0 (worst) and 1 (best): the fraction of images whose most
# probable predicted class matches the labelled class
predicted_class = tf.argmax(Y, 1)
labelled_class = tf.argmax(Y_, 1)
correct_prediction = tf.equal(predicted_class, labelled_class)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# matplotlib visualisation
# Flatten every weight matrix (resp. bias vector) and join them end to end into one 1-D tensor
allweights = tf.concat([tf.reshape(weights, [-1]) for weights in (W1, W2, W3, W4, W5)], 0)
allbiases = tf.concat([tf.reshape(biases, [-1]) for biases in (B1, B2, B3, B4, B5)], 0)
# Image tensors built by tensorflowvisu from the inputs, predictions and labels:
# I is fetched on training batches, It (1000, lines=25) on the test set
I = tensorflowvisu.tf_format_mnist_images(X, Y, Y_)
It = tensorflowvisu.tf_format_mnist_images(X, Y, Y_, 1000, lines=25)
# Object that collects the curves, histograms and images produced during training
datavis = tensorflowvisu.MnistDataVis()
# Training step: Adam minimises the cross-entropy; its learning rate is read from the lr placeholder
optimizer = tf.train.AdamOptimizer(lr)
train_step = optimizer.minimize(cross_entropy)
# Build the initialiser only after the optimiser exists, then start a session and
# give every variable its initial value
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# Call this function in a loop to train the model, one mini-batch of 100 images per call
def training_step(i, update_test_data, update_train_data):
    """Run one training iteration and optionally refresh the visualisation data.

    Args:
        i: iteration number; drives the learning-rate decay and labels the
            printed and plotted values.
        update_test_data: when true, evaluate on the whole test set, print the
            result and hand it to datavis.
        update_train_data: when true, evaluate on the current mini-batch, print
            the result and hand it to datavis.
    """
    # next mini-batch: 100 images together with their 100 labels
    images, labels = mnist.train.next_batch(100)

    # learning rate decays exponentially with i, from lr_max towards lr_min
    lr_max = 0.003
    lr_min = 0.0001
    decay_speed = 2000.0
    decay = math.exp(-i/decay_speed)
    learning_rate = lr_min + (lr_max - lr_min) * decay

    # training values for visualisation, measured on the current batch with dropout disabled (pkeep 1.0)
    if update_train_data:
        train_feed = {X: images, Y_: labels, pkeep: 1.0}
        acc, loss, img, weights, biases = sess.run([accuracy, cross_entropy, I, allweights, allbiases], train_feed)
        print(str(i) + ": accuracy:" + str(acc) + " loss: " + str(loss) + " (lr:" + str(learning_rate) + ")")
        datavis.append_training_curves_data(i, acc, loss)
        datavis.update_image1(img)
        datavis.append_data_histograms(i, weights, biases)

    # test values for visualisation, measured on the full test set with dropout disabled (pkeep 1.0)
    if update_test_data:
        test_feed = {X: mnist.test.images, Y_: mnist.test.labels, pkeep: 1.0}
        acc, loss, img = sess.run([accuracy, cross_entropy, It], test_feed)
        # epoch number: images consumed so far (100 per iteration) over the training-set size, counted from 1
        epoch = i*100//mnist.train.images.shape[0]+1
        print(str(i) + ": ********* epoch " + str(epoch) + " ********* test accuracy:" + str(acc) + " test loss: " + str(loss))
        datavis.append_test_curves_data(i, acc, loss)
        datavis.update_image2(img)

    # the backpropagation training step, with dropout active (pkeep 0.75)
    sess.run(train_step, {X: images, Y_: labels, pkeep: 0.75, lr: learning_rate})
# To train with the live visualisation, use the datavis.animate line instead of the plain loop below
# (add save_movie=True to its arguments to save the animation as a movie):
#datavis.animate(training_step, iterations=10000+1, train_data_update_freq=20, test_data_update_freq=100, more_tests_at_start=True)
# Plain training loop with the visualisation disabled:
# test values are refreshed every 100 iterations, training values every 20
for i in range(10000+1):
    refresh_test = i % 100 == 0
    refresh_train = i % 20 == 0
    training_step(i, refresh_test, refresh_train)
print("max test accuracy: " + str(datavis.get_max_test_accuracy()))
# Some results to expect:
# (In all runs, if sigmoids are used, all biases are initialised at 0, if RELUs are used,
# all biases are initialised at 0.1 apart from the last one which is initialised at 0.)
## test with and without dropout, decaying learning rate from 0.003 to 0.0001 decay_speed 2000, 10K iterations
# final test accuracy = 0.9817 (relu, dropout 0.75, training cross-entropy still a bit noisy, test cross-entropy stable, test accuracy stable just under 98.2)
# final test accuracy = 0.9824 (relu, no dropout, training cross-entropy down to 0, test cross-entropy goes up significantly, test accuracy stable around 98.2)
## learning rate = 0.003, 10K iterations, no dropout
# final test accuracy = 0.9788 (sigmoid - slow start, training cross-entropy not stabilised in the end)
# final test accuracy = 0.9825 (relu - above 0.97 in the first 1500 iterations but noisy curves)
## now with learning rate = 0.0001, 10K iterations, no dropout
# final test accuracy = 0.9722 (relu - slow but smooth curve, would have gone higher in 20K iterations)
## decaying learning rate from 0.003 to 0.0001 decay_speed 2000, 10K iterations, no dropout
# final test accuracy = 0.9746 (sigmoid - training cross-entropy not stabilised)
# final test accuracy = 0.9824 (relu, training cross-entropy down to 0, test cross-entropy goes up significantly, test accuracy stable around 98.2)
# on another run, peak at 0.9836
# Checkpoint writer for the variables of the graph built above
saver = tf.train.Saver()
# NOTE(review): absolute, machine-specific checkpoint prefix - adjust it before running elsewhere
md_path = R"C:\Users\yaoya\AppData\Local\conda\conda\envs\tensorflow\tensorflow-mnist-tutorial\TestProject\mdlib\md2\init"
# Save the variables to disk as they stand after the training loop.
# The initialiser must NOT be run again at this point: it would reset every variable to its
# initial value, and the checkpoint would then contain an untrained model.
save_path = saver.save(sess, md_path, global_step=1000)
print("Model saved in file: %s" % save_path)
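# (web-page residue, not part of the script) This content may be unsuitable for display and is not shown on the page. You can review and modify it using the relevant editing functions.
# (web-page residue, not part of the script) If you confirm that the content involves no inappropriate language / pure advertising / violence / vulgar or pornographic material / infringement / piracy / false or worthless content, and nothing that violates national laws and regulations, you can click submit to appeal and it will be handled as soon as possible.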