代码拉取完成,页面将自动刷新
# coding:utf8
import tensorflow as tf
from PIL import Image,ImageFilter
import matplotlib.pyplot as plt
import time
import numpy as np
model_path = "./isEndModel/model/model.ckpt" #模型文件
CLASS = 2
g1=tf.Graph()
#读取二进制数据
def read_and_decode(filename):
with g1.as_default():
# 创建文件队列,不限读取的数量
filename_queue = tf.train.string_input_producer([filename])
# create a reader from file queue
reader = tf.TFRecordReader()
# reader从文件队列中读入一个序列化的样本
_, serialized_example = reader.read(filename_queue)
# get feature from serialized example
# 解析符号化的样本
features = tf.parse_single_example(
serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string)
}
)
label = features['label']
label = tf.cast(label, tf.int32)
label = tf.one_hot(label,CLASS,1,0)
img = features['img_raw']
img = tf.decode_raw(img, tf.uint8)
img = tf.reshape(img, [28,28, 1])
img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
return img, label
"""
权重初始化
初始化为一个接近0的很小的正数
"""
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev = 0.1)
return tf.Variable(initial)
def bias_variable(shape):
initial = tf.constant(0.1, shape = shape)
return tf.Variable(initial)
"""
卷积和池化,使用卷积步长为1(stride size),0边距(padding size)
池化用简单传统的2x2大小的模板做max pooling
"""
def conv2d(x, W):
return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding = 'SAME')
def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize = [1, 2, 2, 1],
strides = [1, 2, 2, 1], padding = 'SAME')
# 在计算图g1中定义张量和操作
with g1.as_default():
x = tf.placeholder("float", shape=[None, 28,28,1])
y_ = tf.placeholder("float", shape=[None,2])
"""
第一层 卷积层
"""
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv2d(x, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)
"""
第二层 卷积层
"""
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)
"""
第三层 全连接层
"""
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
"""
Dropout
"""
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
"""
第四层 Softmax输出层
"""
W_fc2 = weight_variable([1024, 2])
b_fc2 = bias_variable([2])
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
cross_entropy = tf.reduce_mean(tf.square(y_- y_conv))
train_step = tf.train.AdamOptimizer(0.0001).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
pred = tf.argmax(y_conv,1)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
saver = tf.train.Saver()
sess = tf.Session(graph=g1)
#sess.run(tf.global_variables_initializer())#第一次运行不要屏蔽
saver.restore(sess, model_path)#恢复模型 第一次运行屏蔽
#预测
def predict(img):
with g1.as_default():
img = img.crop((100, 1600, 300, 1800))
img = img.convert('L')
img = img.resize((28, 28))
img = np.array(img)
img = tf.constant(img)
img = tf.reshape(img, [28, 28, 1])
img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
img = tf.expand_dims(img, 0)
img = sess.run(img)
result = sess.run([pred],feed_dict={x:img,keep_prob: 1})
return result[0][0]
# 训练
def train():
with sess.as_default():
batch_size = 15
img, label = read_and_decode("./isEndModel/data/train.tfrecords")
img_batch, label_batch = tf.train.shuffle_batch([img, label],
batch_size=batch_size, capacity=2000,
min_after_dequeue=1000)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(5000): #开始训练模型,循环训练5000次
train_batch_x, train_batch_y = sess.run([img_batch, label_batch])
sess.run([train_step],feed_dict={x:train_batch_x, y_: train_batch_y,keep_prob: 0.5})
if i % 100 == 0:
p,accuracy2 = sess.run([pred,accuracy],feed_dict={x:train_batch_x, y_: train_batch_y,keep_prob: 1})
print(train_batch_y)
print(p)
print(i, accuracy2)
save_path = saver.save(sess, model_path)#保存模型
coord.request_stop()
coord.join(threads)
if __name__ == '__main__':
print(predict(Image.open("./isEndModel/data/img/0001-1.png")))
#train()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。