1 Star 8 Fork 1

sunny_ou/SpeechRecongizeSystem

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train.py 7.76 KB
一键复制 编辑 原始数据 按行查看 历史
sunny_ou 提交于 2021-01-29 16:28 . 上传train.py
import os
import time
import tensorflow as tf
from utils import get_data, data_hparams, GetEditDistance, decode_ctc
from keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau, LambdaCallback
from sklearn.metrics import roc_auc_score
import numpy as np
import matplotlib.pyplot as plt
import warnings
# dataLength = 1000  # number of samples per training round (unused leftover)
# RootPath = "./log/"
# Silence TensorFlow's C++ logging ("2" = show errors only).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
# Suppress Python warnings (e.g. deprecation noise from keras/sklearn).
warnings.filterwarnings('ignore')
# gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
# sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
# 0. Prepare the training data ------------------------------
data_args = data_hparams()
# Only the MMCS corpus is enabled; the other corpora are switched off.
_train_cfg = {
    'data_type': 'train',
    'data_path': 'data/',
    'mmcs': True,
    'thchs30': False,
    'aishell': False,
    'prime': False,
    'stcmd': False,
    'batch_size': 5,
    'data_length': None,  # None = use the whole corpus (was 10000)
    'shuffle': True,
}
for _key, _value in _train_cfg.items():
    setattr(data_args, _key, _value)
train_data = get_data(data_args)
count_length = train_data.countLength
# 0. Prepare the validation (dev) data ------------------------------
data_args = data_hparams()
_dev_cfg = {
    'data_type': 'dev',
    'data_path': 'data/',
    'mmcs': True,
    'thchs30': False,
    'aishell': False,
    'prime': False,
    'stcmd': False,
    'batch_size': 1,
    'data_length': None,  # max 893
    'shuffle': False,     # keep dev order stable so labels can be indexed later
}
for _key, _value in _dev_cfg.items():
    setattr(data_args, _key, _value)
dev_data = get_data(data_args)
# Directory and weight file for the acoustic model; the weights are
# loaded below when the file already exists (resume training).
dirPath = "./log/log_am"
modelPath = "./log/log_am/model_gobalModel_300.h5"
# Leftover from an earlier chunked-training outer loop (kept for reference):
# start = i * dataLength
# end = start + dataLength - 1
# train_data.starItem = start
# train_data.endItem = end
# re-fetch the data slice
# train_data.adjustDataList()
# print("training data round:", str(i + 1))
# start training
# 1. Acoustic-model training -----------------------------------
from model_speech.cnn_ctc import Am, am_hparams
# from model_speech.gru_ctc import Am, am_hparams

# Hyper-parameters for the CNN+CTC acoustic model.
am_args = am_hparams()
for _name, _value in (('vocab_size', len(train_data.am_vocab)),
                      ('gpu_nums', 1),
                      ('lr', 0.0008),
                      ('is_training', True)):
    setattr(am_args, _name, _value)
am = Am(am_args)

print("数据开始:", train_data.starItem)
print("数据结束:", train_data.endItem)

epochs = 300
batch_num = len(train_data.wav_lst) // train_data.batch_size

# Resume from previously saved weights when available.
if os.path.exists(modelPath):
    print('load acoustic model...')
    am.ctc_model.load_weights(modelPath)
# Prepare the batch generators.
batch = train_data.get_am_batch()
dev_batch = dev_data.get_am_batch()

# Callbacks.
tensorBoard = TensorBoard(log_dir="./log/logs_am/tensorboard/" + str(int(time.time())), write_grads=True,
                          histogram_freq=0, update_freq="batch")
tensorBoard.set_model(am.ctc_model)
earlyStopping = EarlyStopping(
    monitor='loss', min_delta=1e-5, patience=5, verbose=1
)
reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.1,
                              patience=3, min_lr=0.00001)

# Accumulate the per-epoch loss and redraw the whole curve each epoch.
# BUGFIX: the previous lambda plotted np.arange(epoch) (an array of length
# `epoch`) against the scalar logs['loss'], which raises a shape-mismatch
# ValueError for every epoch after the first.
_epoch_losses = []


def _record_and_plot_loss(epoch, logs):
    """Append this epoch's training loss and plot the loss history."""
    _epoch_losses.append(logs['loss'])
    plt.plot(np.arange(len(_epoch_losses)), _epoch_losses)


plot_loss_callback = LambdaCallback(on_epoch_end=_record_and_plot_loss)

# BUGFIX: Keras invokes on_epoch_end(epoch, logs); the old
# lambda self, batch, logs: self.model.predict(...) signature would have
# raised a TypeError had this callback ever been passed to fit.
# NOTE(review): this callback is currently unused (not in the callbacks
# list below); kept as an inert placeholder for backward compatibility.
myCallBack = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: None)
# Start training the acoustic model.
_callbacks = [tensorBoard, earlyStopping, reduce_lr]
am.ctc_model.fit_generator(
    batch,
    steps_per_epoch=batch_num,
    initial_epoch=0,
    epochs=epochs,
    callbacks=_callbacks,
    workers=1,
    use_multiprocessing=False,
    validation_data=dev_batch,
    validation_steps=10,
)
# Persist the trained weights.
am.ctc_model.save_weights(modelPath)
# Evaluate recognition accuracy: pinyin-level word error rate on 10 dev samples.
word_error_num = 0
word_num = 0
log_path = "./log/logout.txt"
# NOTE: `with` closes the file automatically; the redundant explicit
# file.close() calls of the original version were removed.
with open(log_path, "a") as file:
    file.write("=" * 20 + "\n")
for j in range(10):
    inputs, _ = next(dev_batch)
    x = inputs['the_inputs']
    result = am.model.predict(x)
    # Collapse the raw CTC output into a pinyin token sequence.
    _, result = decode_ctc(result, train_data.am_vocab)
    # NOTE(review): assumes dev_batch yields samples in dev_data.pny_lst
    # order starting at index 0; fit_generator above already pulled
    # validation batches from this same generator — verify alignment.
    label = dev_data.pny_lst[j]
    with open(log_path, "a") as file:
        file.write("预测:" + ','.join(result) + "\n")
        file.write("实际:" + ','.join(label) + "\n")
    # Edit distance between prediction and reference, clamped so one very
    # bad prediction cannot contribute more errors than the label length.
    word_error_num += min(len(label), GetEditDistance(label, result))
    word_num += len(label)
error_rate = word_error_num / word_num
print('词错误率:', error_rate)
i = 1  # single evaluation round (leftover counter from an old outer loop)
strLine = '【第' + str(i) + '轮】词错误率:' + str(error_rate)
# Append the summary line and the trailing separator in one open.
with open(log_path, "a") as file:
    file.write(strLine + "\n")
    file.write("=" * 20 + "\n")
# Print a completion banner for the acoustic-model stage.
_SEPARATOR = "================================="
for _ in range(7):
    print(_SEPARATOR)
print("声学模型学习完毕")
for _ in range(8):
    print(_SEPARATOR)
# checkpoint
# ckpt = "model_MMCS_{epoch:02d}-{val_loss:.2f}.h5"
# checkpoint = ModelCheckpoint(os.path.join('./checkpoint/Hai', ckpt), monitor='val_loss',
# save_weights_only=False,
# verbose=1,
# save_best_only=True)
#
# ckpt_pi = "model_{epoch:02d}-{loss:.2f}.h5"
# checkpointPi = ModelCheckpoint(os.path.join('./checkpointPi/Hai', ckpt_pi), monitor='loss',
# save_weights_only=False,
# verbose=1,
# save_best_only=True)
# 开始训练
# 2. Language-model training -------------------------------------------
from model_language.transformer import Lm, lm_hparams

# Transformer hyper-parameters for the pinyin -> character language model.
lm_args = lm_hparams()
for _name, _value in (('num_heads', 8),
                      ('num_blocks', 6),
                      ('input_vocab_size', len(train_data.pny_vocab)),
                      ('label_vocab_size', len(train_data.han_vocab)),
                      ('max_length', 500),
                      ('hidden_units', 512),
                      ('dropout_rate', 0.2),
                      ('lr', 0.0003),
                      ('is_training', True)):
    setattr(lm_args, _name, _value)
lm = Lm(lm_args)
epochs = 50
# Train the language model with a TF1 graph/session.
# NOTE(review): the indentation below was reconstructed from a flattened
# paste — verify the loop/save nesting against the upstream repository.
with lm.graph.as_default():
    saver = tf.train.Saver()
with tf.Session(graph=lm.graph) as sess:
    merged = tf.summary.merge_all()
    sess.run(tf.global_variables_initializer())
    # Epoch offset recovered from a checkpoint name when resuming (disabled).
    add_num = 0
    # if os.path.exists('logs_lm/checkpoint'):
    #     print('loading language model...')
    #     latest = tf.train.latest_checkpoint('logs_lm')
    #     add_num = int(latest.split('_')[-1])
    #     saver.restore(sess, latest)
    writer = tf.summary.FileWriter('./log/logs_lm/tensorboard/Hai', tf.get_default_graph())
    for k in range(epochs):
        total_loss = 0
        batch = train_data.get_lm_batch()
        for i in range(batch_num):
            input_batch, label_batch = next(batch)
            # Skip malformed batches whose labels are not 2-D.
            if len(np.shape(label_batch)) < 2:
                print(label_batch)
                continue
            feed = {lm.x: input_batch, lm.y: label_batch}
            cost, _ = sess.run([lm.mean_loss, lm.train_op], feed_dict=feed)
            total_loss += cost
            # Write TensorBoard summaries every 10 global steps.
            if (k * batch_num + i) % 10 == 0:
                rs = sess.run(merged, feed_dict=feed)
                writer.add_summary(rs, k * batch_num + i)
        print('epochs', k + 1, ': average loss = ', total_loss / batch_num)
    # Save once after all epochs; %d truncates the float timestamp.
    saver.save(sess, './log/logs_lm/%d_time/model20210129_%d' % (time.time(), (epochs + add_num)))
    writer.close()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/sunny_ou/speech-recongize-system.git
git@gitee.com:sunny_ou/speech-recongize-system.git
sunny_ou
speech-recongize-system
SpeechRecongizeSystem
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385