1 Star 0 Fork 0

ccc_38451/ocrcn_tf2

加入 Gitee
与超过 1200 万开发者一起发现、参与优秀开源项目，私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件（LICENSE），使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train_simple.py 4.21 KB
一键复制 编辑 原始数据 按行查看 历史
Your Name 提交于 2019-06-07 19:04 . add
"""
Train a simple network on a Chinese character classification dataset.
We reached about 90% accuracy with just a plain CNN.
"""
from alfred.dl.tf.common import mute_tf
mute_tf()
import os
import sys
import numpy as np
import tensorflow as tf
from alfred.utils.log import logger as logging
import tensorflow_datasets as tfds
from dataset.casia_hwdb import load_ds, load_characters, load_val_ds
from models.cnn_net import CNNNet, build_net_002, build_net_003
# Side length (pixels) that preprocess() resizes every image to, and the
# fallback class count. train() does not read this num_classes: it derives
# its own local count from load_characters().
target_size, num_classes = 64, 7356
# True -> model.compile()/model.fit(); False -> manual GradientTape loop.
use_keras_fit = True
# Weight-checkpoint path template; '{epoch}' is filled in via str.format.
ckpt_path = './checkpoints/cn_ocr-{epoch}.ckpt'
def preprocess(x):
    """Convert one dataset record into an ``(image, label)`` pair.

    The image gets a trailing channel axis, is resized to
    ``target_size`` x ``target_size`` and is rescaled as (pixel - 128) / 128.

    NOTE(review): the rescaled image is also written back into
    ``x['image']``, so the incoming dict is modified in place.
    """
    label = x['label']
    # Add a channel axis so tf.image.resize receives an H x W x C image.
    # Assumes the stored image is 2-D grayscale; the original note says
    # raw images are 64x64 -- confirm against the dataset loader.
    image = tf.expand_dims(x['image'], axis=-1)
    image = tf.image.resize(image, (target_size, target_size))
    x['image'] = (image - 128.) / 128.
    return x['image'], label
def train():
    """Train the Chinese-character classifier and checkpoint its weights.

    Builds the shuffled, batched, infinitely repeating train/val pipelines.
    It then restores the newest checkpoint found in ``ckpt_path``'s directory,
    if there is one. Training runs through Keras ``fit`` or through a manual
    ``tf.GradientTape`` loop, depending on ``use_keras_fit``.
    """
    all_characters = load_characters()
    num_classes = len(all_characters)
    logging.info('all characters: {}'.format(num_classes))
    # Both pipelines repeat() forever, so epoch length is defined purely by
    # this step count (shared by the Keras and the manual training paths).
    steps_per_epoch = 1024
    train_dataset = load_ds()
    train_dataset = train_dataset.shuffle(100).map(preprocess).batch(32).repeat()
    val_ds = load_val_ds()
    val_ds = val_ds.shuffle(100).map(preprocess).batch(32).repeat()
    # Sanity-check the pipeline without dumping whole batches to stdout.
    for images, labels in train_dataset.take(2):
        logging.info('sample batch: images {}, labels {}'.format(images.shape, labels.shape))
    # init model; input shape follows the size preprocess() resizes to
    model = build_net_003((target_size, target_size, 1), num_classes)
    model.summary()
    logging.info('model loaded.')
    start_epoch = _restore_latest(model)
    if use_keras_fit:
        _fit_with_keras(model, train_dataset, val_ds, start_epoch, steps_per_epoch)
    else:
        _fit_custom_loop(model, train_dataset, start_epoch, steps_per_epoch)


def _parse_ckpt_epoch(ckpt):
    """Return the epoch number encoded in a checkpoint prefix.

    Checkpoints here are written as ``ckpt_path.format(epoch=N)``, e.g.
    './checkpoints/cn_ocr-500.ckpt' -> 500. Only the file name is parsed,
    so a '-' or '.' in the directory part cannot break the split.
    """
    name = os.path.basename(ckpt)
    return int(name.rsplit('-', 1)[1].split('.')[0])


def _restore_latest(model):
    """Load the newest checkpoint in ``ckpt_path``'s directory into ``model``.

    Returns the epoch parsed from the checkpoint name, or 0 when there is
    nothing to resume from.
    """
    latest_ckpt = tf.train.latest_checkpoint(os.path.dirname(ckpt_path))
    if not latest_ckpt:
        logging.info('passing resume since weights not there. training from scratch')
        return 0
    start_epoch = _parse_ckpt_epoch(latest_ckpt)
    model.load_weights(latest_ckpt)
    logging.info('model resumed from: {}, start at epoch: {}'.format(latest_ckpt, start_epoch))
    return start_epoch


def _fit_with_keras(model, train_dataset, val_ds, start_epoch, steps_per_epoch):
    """Train via compile()/fit(), resuming the epoch count at ``start_epoch``.

    Weights are checkpointed every 500 epochs by ModelCheckpoint. They are
    also saved once more, together with an .h5 model file, when fit()
    finishes or is interrupted with Ctrl-C.
    """
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy'])
    completed_epochs = start_epoch

    def _remember_epoch(epoch, logs):
        # Keras passes a 0-based epoch index; +1 gives the number of finished
        # epochs, the same value ModelCheckpoint puts into '{epoch}'.
        nonlocal completed_epochs
        completed_epochs = epoch + 1

    callbacks = [
        # NOTE(review): newer TF releases deprecate `period` in favour of
        # `save_freq`; kept as-is so the 500-epoch cadence is unchanged.
        tf.keras.callbacks.ModelCheckpoint(ckpt_path,
                                           save_weights_only=True,
                                           verbose=1,
                                           period=500),
        tf.keras.callbacks.LambdaCallback(on_epoch_end=_remember_epoch),
    ]
    try:
        model.fit(
            train_dataset,
            validation_data=val_ds,
            validation_steps=1000,
            epochs=15000,
            # Without this, a resumed run restarts numbering at epoch 1 and
            # the restored start_epoch is silently ignored.
            initial_epoch=start_epoch,
            steps_per_epoch=steps_per_epoch,
            callbacks=callbacks)
    except KeyboardInterrupt:
        logging.info('interrupted.')
    # Save once, on both completion and Ctrl-C. Naming the file after the last
    # finished epoch lets the next run resume from the right place, instead of
    # from a checkpoint always named epoch 0.
    model.save_weights(ckpt_path.format(epoch=completed_epochs))
    logging.info('keras model saved.')
    model.save(os.path.join(os.path.dirname(ckpt_path), 'cn_ocr.h5'))


def _fit_custom_loop(model, train_dataset, start_epoch, steps_per_epoch, total_epochs=120):
    """Train with a hand-written GradientTape loop and RMSprop.

    Each epoch runs ``steps_per_epoch`` batches and then saves weights named
    after the next epoch to run. Ctrl-C saves weights for the in-progress
    epoch and exits the process.

    NOTE(review): the loss uses the default ``from_logits=False``, which
    assumes the model already outputs probabilities -- confirm.
    """
    loss_fn = tf.losses.SparseCategoricalCrossentropy()
    optimizer = tf.optimizers.RMSprop()
    # These running metrics are never reset, so they average over all epochs.
    train_loss = tf.metrics.Mean(name='train_loss')
    train_accuracy = tf.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    for epoch in range(start_epoch, total_epochs):
        try:
            # train_dataset repeats forever; take() bounds the epoch so this
            # outer loop actually advances.
            for batch, (images, labels) in enumerate(train_dataset.take(steps_per_epoch)):
                with tf.GradientTape() as tape:
                    predictions = model(images, training=True)
                    loss = loss_fn(labels, predictions)
                gradients = tape.gradient(loss, model.trainable_variables)
                optimizer.apply_gradients(zip(gradients, model.trainable_variables))
                train_loss(loss)
                train_accuracy(labels, predictions)
                if batch % 10 == 0:
                    logging.info('Epoch: {}, iter: {}, loss: {}, train_acc: {}'.format(
                        epoch, batch, train_loss.result(), train_accuracy.result()))
        except KeyboardInterrupt:
            logging.info('interrupted.')
            model.save_weights(ckpt_path.format(epoch=epoch))
            logging.info('model saved into: {}'.format(ckpt_path.format(epoch=epoch)))
            sys.exit(0)
        # The name is the next epoch to run, which _restore_latest() returns
        # as start_epoch on the following launch.
        model.save_weights(ckpt_path.format(epoch=epoch + 1))
if __name__ == '__main__':
    # Run training only when executed as a script, not on import.
    train()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/chenyang3817/tensor2.git
git@gitee.com:chenyang3817/tensor2.git
chenyang3817
tensor2
ocrcn_tf2
master

搜索帮助