代码拉取完成,页面将自动刷新
# -*- coding: UTF-8 -*-
import scipy.io.wavfile as wav
import matplotlib.pyplot as plt
import os
import numpy as np
from scipy.fftpack import fft
import matplotlib.pyplot as plt
import tensorflow as tf
import tqdm
import keras
from keras.layers import Input, Conv2D, BatchNormalization, MaxPooling2D
from keras.layers import Reshape, Dense, Lambda, Dropout
from keras.optimizers import Adam
from keras import backend as K
from keras.models import Model
from keras.utils import multi_gpu_model
import pandas.util.testing as tm
from fun.utils import cnn_cell, dense, ctc_lambda, source_get, gen_label_data, mk_vocab, data_generator, decode_ctc, \
mk_lm_pny_vocab, mk_lm_han_vocab, GetEditDistance
from fun.model import Amodel
from model_language.transformer import Lm, lm_hparams
from GetAudio import get_Audio
import warnings
import os
warnings.filterwarnings("ignore")
# NOTE(review): tensorflow is imported above, before this variable is set, so it
# may not suppress messages TF already emitted during import; setting it before
# `import tensorflow` would be more reliable.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# Default audio file recognised by getText() when no path is passed.
testPath = "./TestWav/test1.wav"
# Acoustic-model weights (loaded into am.ctc_model below).
modelPath = "./model/model_gobal202101172Model_300.h5"
# Directory searched for the latest language-model checkpoint.
lm_modelPath = "./lm/1610903008_time"
# To record a fresh sample before recognising it, call get_Audio(testPath) here.

source_file = ["./hai_train.txt"]
size = 289  # how many leading transcript entries are used below

# Data processing: keep the first `size` entries of each parallel list.
wav_lst, pany_lst, han_lst = source_get(source_file)
wav_data = wav_lst[:size]
label_data = pany_lst[:size]
han_data = han_lst[:size]
vocab = mk_vocab(label_data)
vocab_size = len(vocab)
# Index list handed to data_generator. Sized from the data actually loaded so
# it never points past the end when the file has fewer than `size` entries.
shuffle_list = list(range(len(label_data)))

am = Amodel(vocab_size)
print("===== loading model ======")
print(modelPath)
am.ctc_model.load_weights(modelPath)

# Training-data generator kept as a public module name; getText() does not read it.
batch = data_generator(1, shuffle_list, wav_lst, label_data, vocab)
# Vocabularies for the language model (pinyin input ids, Chinese output ids).
pany_vocab = mk_lm_pny_vocab(label_data)
han_vocab = mk_lm_han_vocab(han_data)
def getHanZi(text, i):
    """Convert a sequence of pinyin tokens to Chinese characters with the LM.

    On every call a new ``Lm`` graph is built, its weights are restored from
    the latest checkpoint in ``lm_modelPath`` and a single prediction is run.
    NOTE(review): rebuilding the graph per call is expensive; consider caching
    the model/session if this is called repeatedly.

    Args:
        text: iterable of pinyin strings, e.g. the ``text`` returned by
            ``decode_ctc``. Every token must be present in ``pany_vocab``.
        i: index into ``han_lst`` of the reference transcript, or ``None``
            when there is no reference. When given, the edit distance to the
            reference divided by the reference length is printed.

    Returns:
        The predicted Chinese string, or ``''`` when ``text`` is empty.

    Raises:
        ValueError: if a token is not in ``pany_vocab`` or no checkpoint is
            found in ``lm_modelPath``.
    """
    joined = ' '.join(text).strip('\n')
    if not joined:
        # Nothing was decoded acoustically: splitting would yield [''], whose
        # vocabulary lookup fails, and loading the LM would be wasted work.
        return ''

    # Map pinyin to LM input ids before the (slow) model load so bad input
    # fails fast.
    try:
        ids = [pany_vocab.index(pny) for pny in joined.split(' ')]
    except ValueError as err:
        raise ValueError(
            'pinyin token not in language-model vocabulary ({})'.format(err)
        ) from err
    x = np.array(ids).reshape(1, -1)  # a batch containing one sequence

    lm_args = lm_hparams()
    lm_args.input_vocab_size = len(pany_vocab)
    lm_args.label_vocab_size = len(han_vocab)
    lm_args.dropout_rate = 0.  # inference: no dropout
    print('loading language model...')
    lm = Lm(lm_args)
    with lm.graph.as_default():
        # The Saver must be created inside the LM's own graph to see its variables.
        saver = tf.train.Saver()

    latest = tf.train.latest_checkpoint(lm_modelPath)
    if latest is None:
        raise ValueError(
            'no language-model checkpoint found in {!r}'.format(lm_modelPath))

    # The session as a context manager is installed as the default session and
    # is always closed afterwards (the previous version leaked one per call).
    with tf.Session(graph=lm.graph) as sess:
        saver.restore(sess, latest)
        preds = sess.run(lm.preds, {lm.x: x})

    got = ''.join(han_vocab[idx] for idx in preds[0])
    print('语言模型-识别结果:', got)
    if i is not None:
        label = han_lst[i]
        # Cap the distance at the reference length so the rate stays <= 1.
        word_error_num = min(len(label), GetEditDistance(label, got))
        if label:  # an empty reference would divide by zero
            print('词错误率:', word_error_num / len(label))
    return got
# 方法函数
def getText(path=None):
    """Recognise one WAV file and return its Chinese transcription.

    The acoustic model ``am`` produces a prediction that ``decode_ctc`` turns
    into pinyin. ``getHanZi`` then converts that pinyin to Chinese characters.

    Args:
        path: WAV file to recognise. When given, it also becomes the new
            module-level default ``testPath`` for later calls made without a
            path (existing behaviour, kept for compatibility). When ``None``,
            the current ``testPath`` is used.

    Returns:
        The string returned by ``getHanZi`` for the decoded pinyin.
    """
    global testPath
    if path is not None:
        testPath = path

    # A one-file "dataset" pushed through the same generator as training data.
    test_wav_lst = [testPath]
    print("test_wav_lst \n")
    print(test_wav_lst)
    # Pair the single wav with the first label only. Slice locally so the
    # module-level label_data is left intact.
    test_batch = data_generator(1, shuffle_list, test_wav_lst, label_data[:1], vocab)
    test_inputs, _ = next(test_batch)
    features = test_inputs['the_inputs']

    # 1. Acoustic model: audio features -> pinyin.
    print("开始识别")
    predict = am.model.predict(features, steps=1)
    result, text = decode_ctc(predict, vocab)
    print("数字结果:", result)
    print("声学模型-识别结果:", str(text))
    print('-----------------------------------------------')
    print('-----------------------------------------------')
    print('----------------语言模型预测---------------------')
    print('-----------------------------------------------')
    print('-----------------------------------------------')

    # 2. Language model: pinyin -> Chinese characters (no reference transcript).
    return getHanZi(text, None)
if __name__ == "__main__":
    # Script entry point: recognise the default test recording (testPath).
    getText(path=None)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。