#!/usr/bin/python2
# -*- coding: utf-8 -*-
'''
By kyubyong park. kbpark.linguist@gmail.com.
https://www.github.com/kyubyong/tacotron
'''
from __future__ import print_function
from hyperparams import Hyperparams as hp
import numpy as np
import tensorflow as tf
from utils import *
import codecs
import re
import os
import unicodedata
def load_vocab():
char2idx = {char: idx for idx, char in enumerate(hp.vocab)}
idx2char = {idx: char for idx, char in enumerate(hp.vocab)}
return char2idx, idx2char
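
# A minimal sanity check (hypothetical values; assumes hp.vocab begins with the
# padding symbol "P" and the EOS symbol "E", as in this repo's default hyperparams):
# >>> char2idx, idx2char = load_vocab()
# >>> char2idx['P'], char2idx['E']
# (0, 1)
# >>> idx2char[0]
# u'P'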
# Text normalization
def text_normalize(text):
text = ''.join(char for char in unicodedata.normalize('NFD', text)
if unicodedata.category(char) != 'Mn') # Strip accents
text = text.lower()
text = re.sub("[^{}]".format(hp.vocab), " ", text)
text = re.sub("[ ]+", " ", text)
return text
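
# Example (assuming the default hp.vocab of lowercase letters, space, and '.?'):
# accents are stripped via NFD decomposition, the text is lower-cased, and each
# run of out-of-vocabulary characters collapses to a single space:
# >>> text_normalize(u"Café déjà-vu!")
# u'cafe deja vu '   (note the trailing space; callers strip or append EOS)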
# Load data
def load_data(mode="train"):
# Load vocabulary
char2idx, idx2char = load_vocab()
    if mode in ("train", "eval"):
        # Parse file paths, text lengths, and texts
        fpaths, text_lengths, texts = [], [], []
        transcript = os.path.join(hp.data, 'metadata.csv')
        # Read every line of the transcript
        lines = codecs.open(transcript, 'r', 'utf-8').readlines()
        if mode=="train":
            lines = lines[1:] # Skip the first line
        else: # Evaluate on a single sample only
            lines = lines[:1]

        for line in lines:
            fname, _, text = line.strip().split("|") # Fields are |-separated
            fpath = os.path.join(hp.data, "wavs", fname + ".wav")
            fpaths.append(fpath) # Collect every wav file path

            text = text_normalize(text) + "E" # E: EOS
            text = [char2idx[char] for char in text] # Map characters to integer ids
            text_lengths.append(len(text)) # Record the text length
            texts.append(np.array(text, np.int32).tostring()) # Serialize to bytes for the TF queue
        return fpaths, text_lengths, texts
    else:
        # Parse the test sentences
        lines = codecs.open(hp.test_data, 'r', 'utf-8').readlines()[1:]
        sents = [text_normalize(line.split(" ", 1)[-1]).strip() + "E" for line in lines] # text normalization, E: EOS
        lengths = [len(sent) for sent in sents]
        maxlen = max(lengths)
        texts = np.zeros((len(sents), maxlen), np.int32)
        for i, sent in enumerate(sents):
            texts[i, :len(sent)] = [char2idx[char] for char in sent]
        return texts
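
# Usage sketch (hypothetical paths; assumes hp.data points at an LJSpeech-style
# corpus with a metadata.csv transcript, and hp.test_data at a sentence list):
# >>> fpaths, text_lengths, texts = load_data(mode="train")
# >>> len(fpaths) == len(text_lengths) == len(texts)
# True
# >>> test_texts = load_data(mode="synthesize")  # any mode outside train/eval
# >>> test_texts.shape   # (num_sentences, maxlen), zero-padded int32 ids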
def get_batch():
    """Loads training data and puts it in queues."""
with tf.device('/cpu:0'):
# Load data
fpaths, text_lengths, texts = load_data() # list
maxlen, minlen = max(text_lengths), min(text_lengths)
# Calc total batch count
num_batch = len(fpaths) // hp.batch_size
fpaths = tf.convert_to_tensor(fpaths)
text_lengths = tf.convert_to_tensor(text_lengths)
texts = tf.convert_to_tensor(texts)
# Create Queues
fpath, text_length, text = tf.train.slice_input_producer([fpaths, text_lengths, texts], shuffle=True)
# Parse
text = tf.decode_raw(text, tf.int32) # (None,)
        if hp.prepro:
            # Pre-processed case: load spectrograms precomputed by prepro.py
            # from the mels/ and mags/ directories (relative to the cwd)
            def _load_spectrograms(fpath):
                fname = os.path.basename(fpath)
                mel = "mels/{}".format(fname.replace("wav", "npy"))
                mag = "mags/{}".format(fname.replace("wav", "npy"))
                return fname, np.load(mel), np.load(mag)

            fname, mel, mag = tf.py_func(_load_spectrograms, [fpath], [tf.string, tf.float32, tf.float32])
        else:
            # Otherwise compute the spectrograms on the fly with utils.load_spectrograms
            fname, mel, mag = tf.py_func(load_spectrograms, [fpath], [tf.string, tf.float32, tf.float32]) # (None, n_mels)
# Add shape information
fname.set_shape(())
text.set_shape((None,))
mel.set_shape((None, hp.n_mels*hp.r))
mag.set_shape((None, hp.n_fft//2+1))
# Batching
_, (texts, mels, mags, fnames) = tf.contrib.training.bucket_by_sequence_length(
input_length=text_length,
tensors=[text, mel, mag, fname],
batch_size=hp.batch_size,
bucket_boundaries=[i for i in range(minlen + 1, maxlen - 1, 20)],
num_threads=16,
capacity=hp.batch_size * 4,
dynamic_pad=True)
return texts, mels, mags, fnames, num_batch
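
# Training-loop sketch (a minimal illustration, not this repo's train.py; assumes
# TF1-style queue runners, which slice_input_producer and the bucketing op require):
# >>> texts, mels, mags, fnames, num_batch = get_batch()
# >>> with tf.Session() as sess:
# ...     coord = tf.train.Coordinator()
# ...     threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# ...     _texts, _mels = sess.run([texts, mels])  # one padded, bucketed batch
# ...     coord.request_stop(); coord.join(threads)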