# Data utilities for DeepSpeechRecognition: dataset indexing, batch
# generation, acoustic features (MFCC / spectrogram), and CTC decoding.
# Requires TensorFlow 1.x (tf.contrib.training.HParams) and Keras.
import os
import difflib
import numpy as np
import tensorflow as tf
import scipy.io.wavfile as wav
from tqdm import tqdm
from scipy.fftpack import fft
from python_speech_features import mfcc
from random import shuffle
from keras import backend as K

def data_hparams():
    # tf.contrib.training.HParams is a TensorFlow 1.x API
    params = tf.contrib.training.HParams(
        # data sources and batching options
        data_type='train',
        data_path='data/',
        thchs30=True,
        aishell=True,
        prime=True,
        stcmd=True,
        batch_size=1,
        data_length=10,
        shuffle=True)
    return params
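
# A minimal usage sketch (assumes the index files under data/ listed in
# get_data.source_init() are present):
#
#     args = data_hparams()
#     args.data_type = 'train'
#     args.data_length = None  # None keeps the full lists
#     train_data = get_data(args)
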
class get_data():
def __init__(self, args):
self.data_type = args.data_type
self.data_path = args.data_path
self.thchs30 = args.thchs30
self.aishell = args.aishell
self.prime = args.prime
self.stcmd = args.stcmd
self.data_length = args.data_length
self.batch_size = args.batch_size
self.shuffle = args.shuffle
self.source_init()
def source_init(self):
print('get source list...')
read_files = []
        if self.data_type == 'train':
            if self.thchs30:
                read_files.append('thchs_train.txt')
            if self.aishell:
                read_files.append('aishell_train.txt')
            if self.prime:
                read_files.append('prime.txt')
            if self.stcmd:
                read_files.append('stcmd.txt')
        elif self.data_type == 'dev':
            if self.thchs30:
                read_files.append('thchs_dev.txt')
            if self.aishell:
                read_files.append('aishell_dev.txt')
        elif self.data_type == 'test':
            if self.thchs30:
                read_files.append('thchs_test.txt')
            if self.aishell:
                read_files.append('aishell_test.txt')
self.wav_lst = []
self.pny_lst = []
self.han_lst = []
        for file in read_files:
            print('load ', file, ' data...')
            sub_file = 'data/' + file  # index files live under data/
            with open(sub_file, 'r', encoding='utf8') as f:
                data = f.readlines()
            for line in tqdm(data):
                # each line: wav path \t space-separated pinyin \t hanzi text
                wav_file, pny, han = line.split('\t')
                self.wav_lst.append(wav_file)
                self.pny_lst.append(pny.split(' '))
                self.han_lst.append(han.strip('\n'))
if self.data_length:
self.wav_lst = self.wav_lst[:self.data_length]
self.pny_lst = self.pny_lst[:self.data_length]
self.han_lst = self.han_lst[:self.data_length]
print('make am vocab...')
self.am_vocab = self.mk_am_vocab(self.pny_lst)
print('make lm pinyin vocab...')
self.pny_vocab = self.mk_lm_pny_vocab(self.pny_lst)
print('make lm hanzi vocab...')
self.han_vocab = self.mk_lm_han_vocab(self.han_lst)
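        # After init, sample i is described by three parallel lists:
        #   self.wav_lst[i] -> relative wav path
        #   self.pny_lst[i] -> pinyin token list (acoustic-model labels)
        #   self.han_lst[i] -> hanzi string (language-model targets)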
def get_am_batch(self):
shuffle_list = [i for i in range(len(self.wav_lst))]
while 1:
            if self.shuffle:
shuffle(shuffle_list)
for i in range(len(self.wav_lst) // self.batch_size):
wav_data_lst = []
label_data_lst = []
begin = i * self.batch_size
end = begin + self.batch_size
sub_list = shuffle_list[begin:end]
                for index in sub_list:
                    fbank = compute_fbank(self.data_path + self.wav_lst[index])
                    # pad the time axis up to a multiple of 8; the acoustic
                    # model downsamples time by a factor of 8
                    pad_fbank = np.zeros((fbank.shape[0] // 8 * 8 + 8, fbank.shape[1]))
                    pad_fbank[:fbank.shape[0], :] = fbank
                    label = self.pny2id(self.pny_lst[index], self.am_vocab)
                    label_ctc_len = self.ctc_len(label)
                    # drop samples whose downsampled input is too short for
                    # CTC to emit the full label sequence
                    if pad_fbank.shape[0] // 8 >= label_ctc_len:
                        wav_data_lst.append(pad_fbank)
                        label_data_lst.append(label)
pad_wav_data, input_length = self.wav_padding(wav_data_lst)
pad_label_data, label_length = self.label_padding(label_data_lst)
inputs = {'the_inputs': pad_wav_data,
'the_labels': pad_label_data,
'input_length': input_length,
'label_length': label_length,
}
                # dummy target: the CTC loss is computed inside the model, so
                # the 'ctc' output only needs a zero placeholder per sample
                outputs = {'ctc': np.zeros(pad_wav_data.shape[0], )}
yield inputs, outputs
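
    # get_am_batch() is an endless generator, matching what Keras'
    # fit_generator expects; a hypothetical training call might look like:
    #
    #     batch = train_data.get_am_batch()
    #     steps = len(train_data.wav_lst) // train_data.batch_size
    #     model.fit_generator(batch, steps_per_epoch=steps, epochs=10)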
def get_lm_batch(self):
batch_num = len(self.pny_lst) // self.batch_size
for k in range(batch_num):
begin = k * self.batch_size
end = begin + self.batch_size
input_batch = self.pny_lst[begin:end]
label_batch = self.han_lst[begin:end]
            # pinyin and hanzi sequences are aligned one-to-one (one syllable
            # per character), so a single max_len pads both batches
            max_len = max([len(line) for line in input_batch])
input_batch = np.array(
[self.pny2id(line, self.pny_vocab) + [0] * (max_len - len(line)) for line in input_batch])
label_batch = np.array(
[self.han2id(line, self.han_vocab) + [0] * (max_len - len(line)) for line in label_batch])
yield input_batch, label_batch
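
    # Unlike get_am_batch(), this generator is finite: it yields exactly
    # len(self.pny_lst) // batch_size batches, so re-create it each epoch:
    #
    #     for epoch in range(epochs):
    #         for input_batch, label_batch in train_data.get_lm_batch():
    #             ...  # one LM training step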
    def pny2id(self, line, vocab):
        # list.index is a linear scan; acceptable here, slow for large vocabs
        return [vocab.index(pny) for pny in line]

    def han2id(self, line, vocab):
        return [vocab.index(han) for han in line]
    def wav_padding(self, wav_data_lst):
        wav_lens = [len(data) for data in wav_data_lst]
        wav_max_len = max(wav_lens)
        # report lengths after the model's 8x downsampling of the time axis
        wav_lens = np.array([leng // 8 for leng in wav_lens])
        # shape: (batch, time, 200 frequency bins, 1 channel)
        new_wav_data_lst = np.zeros((len(wav_data_lst), wav_max_len, 200, 1))
        for i in range(len(wav_data_lst)):
            new_wav_data_lst[i, :wav_data_lst[i].shape[0], :, 0] = wav_data_lst[i]
        return new_wav_data_lst, wav_lens
def label_padding(self, label_data_lst):
label_lens = np.array([len(label) for label in label_data_lst])
max_label_len = max(label_lens)
new_label_data_lst = np.zeros((len(label_data_lst), max_label_len))
for i in range(len(label_data_lst)):
new_label_data_lst[i][:len(label_data_lst[i])] = label_data_lst[i]
return new_label_data_lst, label_lens
    def mk_am_vocab(self, data):
        vocab = []
        for line in tqdm(data):
            for pny in line:
                if pny not in vocab:
                    vocab.append(pny)
        # '_' is appended last and serves as the CTC blank symbol
        vocab.append('_')
        return vocab
def mk_lm_pny_vocab(self, data):
vocab = ['<PAD>']
for line in tqdm(data):
for pny in line:
if pny not in vocab:
vocab.append(pny)
return vocab
def mk_lm_han_vocab(self, data):
vocab = ['<PAD>']
for line in tqdm(data):
line = ''.join(line.split(' '))
for han in line:
if han not in vocab:
vocab.append(han)
return vocab
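
    # Index 0 in both LM vocabs is '<PAD>', which is what lets get_lm_batch()
    # pad shorter sequences with zeros.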
    def ctc_len(self, label):
        # CTC must emit a blank between two identical consecutive labels, so
        # the minimum input length is len(label) plus the number of adjacent
        # repeats
        add_len = 0
        label_len = len(label)
        for i in range(label_len - 1):
            if label[i] == label[i + 1]:
                add_len += 1
        return label_len + add_len
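
    # Example: label [3, 3, 5] contains one adjacent repeat, so ctc_len
    # returns 4: the shortest valid CTC path is "3, blank, 3, 5".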

# Extract MFCC features from an audio file
def compute_mfcc(file):
    fs, audio = wav.read(file)
    mfcc_feat = mfcc(audio, samplerate=fs, numcep=26)
    mfcc_feat = mfcc_feat[::3]  # keep every third frame (~30 ms hop)
    mfcc_feat = np.transpose(mfcc_feat)
    return mfcc_feat
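
# A quick shape check (the wav path below is illustrative only):
#
#     feat = compute_mfcc('data/A2_0.wav')
#     print(feat.shape)  # (26, n_frames) after the transpose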

# Compute the spectrogram (time-frequency representation) of a signal.
# The hard-coded 400-sample window and 160-sample hop assume 16 kHz audio
# (25 ms and 10 ms respectively).
def compute_fbank(file):
    x = np.linspace(0, 400 - 1, 400, dtype=np.int64)
    w = 0.54 - 0.46 * np.cos(2 * np.pi * x / (400 - 1))  # Hamming window
    fs, wavsignal = wav.read(file)
    # slide a 25 ms window over the waveform with a 10 ms shift
    time_window = 25  # in ms
    wav_arr = np.array(wavsignal)
    # number of windows that fit into the signal
    range0_end = int(len(wavsignal) / fs * 1000 - time_window) // 10 + 1
    # final frequency features (np.float was removed from NumPy; use float)
    data_input = np.zeros((range0_end, 200), dtype=float)
    for i in range(0, range0_end):
        p_start = i * 160
        p_end = p_start + 400
        data_line = wav_arr[p_start:p_end]
        data_line = data_line * w  # apply the window
        data_line = np.abs(fft(data_line))
        # keep the first 200 of the 400 FFT bins; the spectrum is symmetric
        data_input[i] = data_line[0:200]
    data_input = np.log(data_input + 1)
    return data_input
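
# Worked example at 16 kHz: 1 s of audio is 16000 samples, giving
# (1000 - 25) // 10 + 1 = 98 frames of 200 log-spectrum bins each.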

# Edit distance (for word/character error rate) -------------------------------
def GetEditDistance(str1, str2):
    # approximates the Levenshtein distance using difflib opcodes
    leven_cost = 0
    s = difflib.SequenceMatcher(None, str1, str2)
    for tag, i1, i2, j1, j2 in s.get_opcodes():
        if tag == 'replace':
            leven_cost += max(i2 - i1, j2 - j1)
        elif tag == 'insert':
            leven_cost += (j2 - j1)
        elif tag == 'delete':
            leven_cost += (i2 - i1)
    return leven_cost
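
# A hypothetical character error rate built on top of the edit distance:
#
#     def cer(ref, hyp):
#         return GetEditDistance(ref, hyp) / len(ref)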

# CTC decoder ------------------------------------------------------------------
def decode_ctc(num_result, num2word):
    result = num_result[:, :, :]
    in_len = np.zeros((1), dtype=np.int32)
    in_len[0] = result.shape[1]
    # greedy (best-path) decoding; beam_width is ignored when greedy=True
    r = K.ctc_decode(result, in_len, greedy=True, beam_width=10, top_paths=1)
    r1 = K.get_value(r[0][0])
    r1 = r1[0]
    text = []
    for i in r1:
        text.append(num2word[i])
    return r1, text
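
# A minimal smoke test; a sketch only, assuming the data/*.txt index files and
# the wav files they reference exist locally:
if __name__ == '__main__':
    args = data_hparams()
    args.data_type = 'train'
    args.data_length = 10  # small slice for a quick check
    train_data = get_data(args)
    inputs, outputs = next(train_data.get_am_batch())
    print('the_inputs:', inputs['the_inputs'].shape)  # (batch, time, 200, 1)
    print('the_labels:', inputs['the_labels'].shape)  # (batch, max_label_len)
    input_batch, label_batch = next(train_data.get_lm_batch())
    print('lm input batch:', input_batch.shape)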