1 Star 0 Fork 5

耀轩之/tflite_train

forked from WJG/tflite_train 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
cnn.py 4.71 KB
一键复制 编辑 原始数据 按行查看 历史
WJG 提交于 2019-12-01 10:06 . 测试版0.2
# -*- coding: UTF-8 -*-
import tensorflow as tf
import numpy as np
import os
from PIL import Image
import random
# Disable TF 2.x eager execution: the CNN class below builds a static graph with
# tf.compat.v1.placeholder / tf.compat.v1.get_variable, and placeholders cannot be
# used while eager execution is enabled.
tf.compat.v1.disable_eager_execution()
class CNN(object):
    """Convolutional captcha recogniser built with TF1-style graph APIs.

    Inputs are flattened single-channel images of ``image_height * image_width``
    floats (``self.X``); labels are ``max_captcha * len(char_set)`` floats
    (``self.Y``) made of one one-hot segment per captcha position.
    """

    def __init__(self, image_height, image_width, max_captcha, char_set, model_save_dir):
        """Store the model configuration and create the input placeholders.

        :param image_height: input image height in pixels
        :param image_width: input image width in pixels
        :param max_captcha: maximum number of characters in one captcha
        :param char_set: sequence of candidate characters; a character's index
            in it is its class id
        :param model_save_dir: model directory; stored for callers, not read
            inside this class
        """
        self.image_height = image_height
        self.image_width = image_width
        self.max_captcha = max_captcha
        self.char_set = char_set
        self.char_set_len = len(char_set)
        self.model_save_dir = model_save_dir  # model path
        with tf.compat.v1.name_scope('parameters'):
            # NOTE(review): w_alpha is not read anywhere in this class; the
            # weights are initialised by _weight_initializer() instead.
            self.w_alpha = 0.01
            # Scale applied to the normally-distributed initial bias values.
            self.b_alpha = 0.1
        # Graph inputs; placeholders require eager execution to be disabled.
        with tf.compat.v1.name_scope('data'):
            # Flattened grayscale image features.
            self.X = tf.compat.v1.placeholder(tf.float32, [None, self.image_height * self.image_width])
            # Concatenated one-hot labels, one segment per captcha position.
            self.Y = tf.compat.v1.placeholder(tf.float32, [None, self.max_captcha * self.char_set_len])

    @staticmethod
    def convert2gray(img):
        """Convert an image array to grayscale.

        Arrays with more than two dimensions are indexed as ``img[:, :, c]`` and
        reduced with the weights 0.2989/0.5870/0.1140 over the first three
        channels; any further channel (e.g. alpha) is ignored. Two-dimensional
        arrays are treated as already grayscale and returned unchanged.

        :param img: numpy array of shape ``(H, W)`` or ``(H, W, C)`` with ``C >= 3``
        :return: numpy array of shape ``(H, W)``
        """
        if len(img.shape) > 2:
            r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
            return 0.2989 * r + 0.5870 * g + 0.1140 * b
        return img

    def text2vec(self, text):
        """Encode a captcha label as a flat one-hot vector.

        Character ``i`` of ``text`` sets element
        ``i * char_set_len + char_set.index(text[i])`` to 1. Segments for
        positions past ``len(text)`` stay all-zero.

        :param text: label string, at most ``max_captcha`` characters
        :return: numpy float array of length ``max_captcha * char_set_len``
        :raises ValueError: if ``text`` is longer than ``max_captcha`` or
            contains a character that is not in ``char_set``
        """
        if len(text) > self.max_captcha:
            raise ValueError('验证码最长{}个字符'.format(self.max_captcha))
        vector = np.zeros(self.max_captcha * self.char_set_len)
        for i, ch in enumerate(text):
            try:
                char_idx = self.char_set.index(ch)
            except ValueError:
                # The built-in index() error ("substring not found" for str)
                # does not say which character was rejected.
                raise ValueError('字符{!r}不在字符集中'.format(ch)) from None
            vector[i * self.char_set_len + char_idx] = 1
        return vector

    @staticmethod
    def _weight_initializer():
        """Return a new VarianceScaling(scale=1.0, fan_avg, uniform) initializer."""
        return tf.compat.v1.keras.initializers.VarianceScaling(
            scale=1.0, mode="fan_avg", distribution="uniform")

    def _conv_pool_layer(self, inputs, name, in_channels, out_channels):
        """Build one 3x3 conv (stride 1, SAME) + bias + ReLU + 2x2/stride-2 max-pool stage.

        The kernel is created via ``get_variable(name=name)`` so the variable
        names ('wc1', 'wc2', ...) stay stable for checkpoint restore.
        """
        weights = tf.compat.v1.get_variable(
            name=name, shape=[3, 3, in_channels, out_channels], dtype=tf.float32,
            initializer=self._weight_initializer())
        bias = tf.Variable(self.b_alpha * tf.random.normal([out_channels]))
        conv = tf.nn.relu(tf.nn.bias_add(
            tf.nn.conv2d(input=inputs, filters=weights, strides=[1, 1, 1, 1], padding='SAME'), bias))
        return tf.nn.max_pool2d(input=conv, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

    def model(self):
        """Build the inference graph and return the output logits.

        Three conv/pool stages (32, 64, 128 channels) each halve the spatial
        size (SAME pooling rounds up). They are followed by a 1024-unit ReLU
        dense layer and a linear layer of ``max_captcha * char_set_len`` logits.

        Variables use fixed names via ``get_variable``. Calling this twice in
        the same graph without variable reuse makes ``get_variable`` raise
        ``ValueError``.

        :return: ``y_predict`` tensor of shape ``(batch, max_captcha * char_set_len)``
        """
        x = tf.reshape(self.X, shape=[-1, self.image_height, self.image_width, 1])
        print(">>> input x: {}".format(x))
        # Convolution stages 1-3 (created in the same order as before, so
        # auto-generated bias/op names are unchanged).
        conv1 = self._conv_pool_layer(x, 'wc1', 1, 32)
        conv2 = self._conv_pool_layer(conv1, 'wc2', 32, 64)
        conv3 = self._conv_pool_layer(conv2, 'wc3', 64, 128)
        print(">>> convolution 3: ", conv3.shape)
        next_shape = conv3.shape[1] * conv3.shape[2] * conv3.shape[3]
        # Fully connected layer 1.
        wd1 = tf.compat.v1.get_variable(name='wd1', shape=[next_shape, 1024], dtype=tf.float32,
                                        initializer=self._weight_initializer())
        bd1 = tf.Variable(self.b_alpha * tf.random.normal([1024]))
        dense = tf.reshape(conv3, [-1, wd1.get_shape().as_list()[0]])
        dense = tf.nn.relu(tf.add(tf.matmul(dense, wd1), bd1))
        # Fully connected layer 2 (output logits).
        # NOTE(review): this kernel is literally named 'name' ('wout' was
        # presumably intended). It is kept because renaming it would change the
        # checkpoint key and break restoring models saved with this name.
        wout = tf.compat.v1.get_variable('name', shape=[1024, self.max_captcha * self.char_set_len],
                                         dtype=tf.float32, initializer=self._weight_initializer())
        bout = tf.Variable(self.b_alpha * tf.random.normal([self.max_captcha * self.char_set_len]))
        with tf.compat.v1.name_scope('y_prediction'):
            y_predict = tf.add(tf.matmul(dense, wout), bout)
        return y_predict
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/yaoxuanzhi/tflite_train.git
git@gitee.com:yaoxuanzhi/tflite_train.git
yaoxuanzhi
tflite_train
tflite_train
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385