1 Star 0 Fork 0

宝贝龙/wx-jump

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
isEndModel.py 5.32 KB
一键复制 编辑 原始数据 按行查看 历史
宝贝龙 提交于 2019-09-17 21:57 . sss
# coding:utf8
import tensorflow as tf
from PIL import Image,ImageFilter
import matplotlib.pyplot as plt
import time
import numpy as np
model_path = "./isEndModel/model/model.ckpt" #模型文件
CLASS = 2
g1=tf.Graph()
#读取二进制数据
def read_and_decode(filename):
with g1.as_default():
# 创建文件队列,不限读取的数量
filename_queue = tf.train.string_input_producer([filename])
# create a reader from file queue
reader = tf.TFRecordReader()
# reader从文件队列中读入一个序列化的样本
_, serialized_example = reader.read(filename_queue)
# get feature from serialized example
# 解析符号化的样本
features = tf.parse_single_example(
serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string)
}
)
label = features['label']
label = tf.cast(label, tf.int32)
label = tf.one_hot(label,CLASS,1,0)
img = features['img_raw']
img = tf.decode_raw(img, tf.uint8)
img = tf.reshape(img, [28,28, 1])
img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
return img, label
"""
权重初始化
初始化为一个接近0的很小的正数
"""
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev = 0.1)
return tf.Variable(initial)
def bias_variable(shape):
initial = tf.constant(0.1, shape = shape)
return tf.Variable(initial)
"""
卷积和池化,使用卷积步长为1(stride size),0边距(padding size)
池化用简单传统的2x2大小的模板做max pooling
"""
def conv2d(x, W):
return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding = 'SAME')
def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize = [1, 2, 2, 1],
strides = [1, 2, 2, 1], padding = 'SAME')
# 在计算图g1中定义张量和操作
with g1.as_default():
x = tf.placeholder("float", shape=[None, 28,28,1])
y_ = tf.placeholder("float", shape=[None,2])
"""
第一层 卷积层
"""
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv2d(x, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)
"""
第二层 卷积层
"""
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)
"""
第三层 全连接层
"""
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
"""
Dropout
"""
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
"""
第四层 Softmax输出层
"""
W_fc2 = weight_variable([1024, 2])
b_fc2 = bias_variable([2])
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
cross_entropy = tf.reduce_mean(tf.square(y_- y_conv))
train_step = tf.train.AdamOptimizer(0.0001).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
pred = tf.argmax(y_conv,1)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
saver = tf.train.Saver()
sess = tf.Session(graph=g1)
#sess.run(tf.global_variables_initializer())#第一次运行不要屏蔽
saver.restore(sess, model_path)#恢复模型 第一次运行屏蔽
#预测
def predict(img):
with g1.as_default():
img = img.crop((100, 1600, 300, 1800))
img = img.convert('L')
img = img.resize((28, 28))
img = np.array(img)
img = tf.constant(img)
img = tf.reshape(img, [28, 28, 1])
img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
img = tf.expand_dims(img, 0)
img = sess.run(img)
result = sess.run([pred],feed_dict={x:img,keep_prob: 1})
return result[0][0]
# 训练
def train():
with sess.as_default():
batch_size = 15
img, label = read_and_decode("./isEndModel/data/train.tfrecords")
img_batch, label_batch = tf.train.shuffle_batch([img, label],
batch_size=batch_size, capacity=2000,
min_after_dequeue=1000)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(5000): #开始训练模型,循环训练5000次
train_batch_x, train_batch_y = sess.run([img_batch, label_batch])
sess.run([train_step],feed_dict={x:train_batch_x, y_: train_batch_y,keep_prob: 0.5})
if i % 100 == 0:
p,accuracy2 = sess.run([pred,accuracy],feed_dict={x:train_batch_x, y_: train_batch_y,keep_prob: 1})
print(train_batch_y)
print(p)
print(i, accuracy2)
save_path = saver.save(sess, model_path)#保存模型
coord.request_stop()
coord.join(threads)
if __name__ == '__main__':
print(predict(Image.open("./isEndModel/data/img/0001-1.png")))
#train()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/aokinba/wx-jump.git
git@gitee.com:aokinba/wx-jump.git
aokinba
wx-jump
wx-jump
master

搜索帮助