Fetch the repository succeeded.
import os
import torch
import torchaudio
from torch.utils.data import Dataset
from matplotlib import pyplot as plt
class AudioDirId():
""" 存储音频文件夹的路径和ID
(方便后续无需从路径上在分离ID)
"""
def __init__(self, dir, idx) -> None:
self.dir = dir
self.id = idx
class AudioDataset(Dataset):
""" 构建音频数据集
root: 数据集路径,该路径下包含按地区分类的文件夹
use_dr: 是否仅使用指定地区的数据,默认 False 不使用
dr: 如果使用地区,指定地区名称,默认 None
"""
def __init__(self, root, use_dr=False, dr=None) -> None:
super().__init__()
self.audio_list = [] # 所有音频的文件所在文件夹的路径,如:./.../DR1/0001
if use_dr: # 只使用指定地区
p = os.path.join(root, dr)
for i in os.listdir(p):
self.audio_list.append(AudioDirId(os.path.join(p, i), i))
else: # 默认将所有地区的数据混合在一起训练
for d in os.listdir(root):
p = os.path.join(root, d)
for i in os.listdir(p):
self.audio_list.append(AudioDirId(os.path.join(p, i), i))
def __len__(self):
""" 返回整体数据的数量 """
return len(self.audio_list)
def __getitem__(self, index):
""" 根据 index 返回指定的数据 """
audio = self.audio_list[index]
# mix audio
wf, sr = torchaudio.load(os.path.join(audio.dir, f"{audio.id}-mix.wav"))
mix_audio = wf.reshape(-1)
# target audio
wf1, sr = torchaudio.load(os.path.join(audio.dir, f"{audio.id}-s1.wav"))
wf2, sr = torchaudio.load(os.path.join(audio.dir, f"{audio.id}-s2.wav"))
target_audio = torch.cat((wf1, wf2), dim=0)
return mix_audio, target_audio
def plot_line(data, root):
""" 绘制折线图
data: 数据
root: 存储的文件名
"""
plt.figure(figsize=(20, 8), dpi=80)
plt.plot(range(len(data)), data)
plt.title(root)
plt.savefig(f"./result/{root}.png")
plt.close()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。