代码拉取完成,页面将自动刷新
"""
模型需要加载的数据
"""
import os
import numpy as np
import hyperparameters
from torch.utils.data import Dataset
import torchaudio
from torchaudio.transforms import MFCC
def label_to_index(labels_list, label):
""" 将情感标签转为对应的索引
inputs:
labels_list: 所有分类的列表
label: 其中一个分类
output:
输出这个分类对应的索引
"""
return labels_list.index(label)
def index_to_label(labels_list, index):
""" 将对应的索引转为具体的分类
inputs:
labels_list: 所有分类的列表
index: 索引
output:
输出这个索引对应的分类
"""
return labels_list[index]
class AudioDataset(Dataset):
""" 构建音频数据集
inputs:
wav_list: 音频文件的路径列表
label_list: 音频文件对应的标签(两个列表内容的顺序一一对应)
"""
def __init__(self, wav_list, label_list) -> None:
self.wav_list = wav_list
self.label_list = label_list
self.MFCC = MFCC(
sample_rate=hyperparameters.SAMPLE_RATE, n_mfcc=hyperparameters.N_MFCC,
melkwargs={
"n_fft": hyperparameters.N_FTT,
"n_mels": hyperparameters.N_MELS,
"hop_length": hyperparameters.HOP_LENGTH,
"mel_scale": "htk"
}
)
def __getitem__(self, index):
filename = self.wav_list[index]
label = self.label_list[index]
waveform, sample_rate = torchaudio.load(filename)
return self.MFCC(waveform), label_to_index(hyperparameters.CASIA_LABELS, label)
def __len__(self):
return len(self.wav_list)
def make_dataset(data_segment_path):
""" 构造数据集
inputs:
data_segment_path: 划分完成数据集的路径
"""
# 获取训练集与测试集所有文件的路径
train_datas, train_labels = [], []
test_datas, test_labels = [], []
for i in ["train", "test"]:
for path, dirs, filenames in os.walk(os.path.join(data_segment_path, i)):
for filename in filenames:
if i == "train":
train_datas.append(os.path.join(path, filename))
train_labels.append(path.split(os.path.sep)[-1])
else:
test_datas.append(os.path.join(path, filename))
test_labels.append(path.split(os.path.sep)[-1])
# 将训练集的顺序随机打乱
index = list(range(len(train_datas)))
np.random.shuffle(index)
train_datas = [train_datas[i] for i in index]
train_labels = [train_labels[i] for i in index]
# 将打乱后的列表传入数据集构造的类,直接返回
return AudioDataset(train_datas, train_labels), AudioDataset(test_datas, test_labels)
if __name__ == "__main__":
train_dataset, test_dataset = make_dataset("data/casia_4_segment")
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。