1 Star 0 Fork 5

耀轩之/tflite_train

forked from WJG/tflite_train 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
save.py 3.16 KB
一键复制 编辑 原始数据 按行查看 历史
WJG 提交于 2019-12-01 15:51 . 正式版1.0
# -*- coding: utf-8 -*-
import json
import tensorflow as tf
import numpy as np
import time
from PIL import Image
import random
import os
from cnn import CNN
class TestError(Exception):
pass
class TestBatch(CNN):
def __init__(self, max_captcha,image_width, image_height,img_path, char_set, model_save_dir, total):
# 模型路径
self.model_save_dir = model_save_dir
# 打乱文件顺序
self.img_path = img_path
self.img_list = os.listdir(img_path)
random.seed(time.time())
random.shuffle(self.img_list)
# 获得图片宽高和字符长度基本信息
self.image_height = image_height
self.image_width = image_width
self.max_captcha = max_captcha
# 初始化变量
super(TestBatch, self).__init__(image_height, image_width, max_captcha, char_set, model_save_dir)
self.total = total
# 相关信息打印
print("-->图片尺寸: {} X {}".format(image_height, image_width))
print("-->验证码长度: {}".format(self.max_captcha))
print("-->验证码共{}类 {}".format(self.char_set_len, char_set))
print("-->使用测试集为 {}".format(img_path))
def gen_captcha_text_image(self):
"""
返回一个验证码的array形式和对应的字符串标签
:return:tuple (str, numpy.array)
"""
img_name = random.choice(self.img_list)
# 标签
label = img_name.split("_")[0]
img_file = os.path.join(self.img_path, img_name)
captcha_image = Image.open(img_file)
captcha_array = self.img2input(captcha_image) # 图片转输入向量
return label, captcha_array
def img2input(self,img):
img = np.array(img)
test_image = 0.3 * img[:, :, 0] + 0.6 * img[:, :, 1] + 0.1 * img[:, :, 2]
tmpe_array = test_image.flatten() / 255
input_array = np.expand_dims(tmpe_array, axis=0)
return input_array
def test_batch(self):
y_predict = self.model()
saver = tf.compat.v1.train.Saver()
with tf.compat.v1.Session() as sess:
saver.restore(sess, self.model_save_dir)
test_text, test_image = self.gen_captcha_text_image() # 随机
self.Y = tf.argmax(input=tf.reshape(y_predict, [-1, self.max_captcha, self.char_set_len]), axis=2)
sess.run(self.Y, feed_dict={self.X: test_image})
print('准备保存模型......')
tf.compat.v1.saved_model.simple_save(sess,
"model2/",
inputs={"X": self.X},
outputs={"Y": self.Y})
print('>---模型保存成功---<')
def main():
with open("config.json", "r") as f:
sample_conf = json.load(f)
test_image_dir = sample_conf["test_image_dir"]
model_save_dir = sample_conf["model_save_dir"]
image_width = sample_conf['image_width']
image_height = sample_conf['image_height']
max_captcha = sample_conf['max_captcha']
char_set = sample_conf["char_set"]
total = 100
tb = TestBatch(max_captcha,image_width, image_height,test_image_dir, char_set, model_save_dir, total)
tb.test_batch()
if __name__ == '__main__':
main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/yaoxuanzhi/tflite_train.git
git@gitee.com:yaoxuanzhi/tflite_train.git
yaoxuanzhi
tflite_train
tflite_train
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385