1 Star 0 Fork 0

KunCheng-He/Light-SERNet

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
data_set.py 2.91 KB
一键复制 编辑 原始数据 按行查看 历史
KunCheng-He 提交于 2022-09-08 20:53 . 取消打乱测试集
"""
模型需要加载的数据
"""
import os
import numpy as np
import hyperparameters
from torch.utils.data import Dataset
import torchaudio
from torchaudio.transforms import MFCC
def label_to_index(labels_list, label):
""" 将情感标签转为对应的索引
inputs:
labels_list: 所有分类的列表
label: 其中一个分类
output:
输出这个分类对应的索引
"""
return labels_list.index(label)
def index_to_label(labels_list, index):
""" 将对应的索引转为具体的分类
inputs:
labels_list: 所有分类的列表
index: 索引
output:
输出这个索引对应的分类
"""
return labels_list[index]
class AudioDataset(Dataset):
""" 构建音频数据集
inputs:
wav_list: 音频文件的路径列表
label_list: 音频文件对应的标签(两个列表内容的顺序一一对应)
"""
def __init__(self, wav_list, label_list) -> None:
self.wav_list = wav_list
self.label_list = label_list
self.MFCC = MFCC(
sample_rate=hyperparameters.SAMPLE_RATE, n_mfcc=hyperparameters.N_MFCC,
melkwargs={
"n_fft": hyperparameters.N_FTT,
"n_mels": hyperparameters.N_MELS,
"hop_length": hyperparameters.HOP_LENGTH,
"mel_scale": "htk"
}
)
def __getitem__(self, index):
filename = self.wav_list[index]
label = self.label_list[index]
waveform, sample_rate = torchaudio.load(filename)
return self.MFCC(waveform), label_to_index(hyperparameters.CASIA_LABELS, label)
def __len__(self):
return len(self.wav_list)
def make_dataset(data_segment_path):
""" 构造数据集
inputs:
data_segment_path: 划分完成数据集的路径
"""
# 获取训练集与测试集所有文件的路径
train_datas, train_labels = [], []
test_datas, test_labels = [], []
for i in ["train", "test"]:
for path, dirs, filenames in os.walk(os.path.join(data_segment_path, i)):
for filename in filenames:
if i == "train":
train_datas.append(os.path.join(path, filename))
train_labels.append(path.split(os.path.sep)[-1])
else:
test_datas.append(os.path.join(path, filename))
test_labels.append(path.split(os.path.sep)[-1])
# 将训练集的顺序随机打乱
index = list(range(len(train_datas)))
np.random.shuffle(index)
train_datas = [train_datas[i] for i in index]
train_labels = [train_labels[i] for i in index]
# 将打乱后的列表传入数据集构造的类,直接返回
return AudioDataset(train_datas, train_labels), AudioDataset(test_datas, test_labels)
if __name__ == "__main__":
train_dataset, test_dataset = make_dataset("data/casia_4_segment")
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/byack/sernet.git
git@gitee.com:byack/sernet.git
byack
sernet
Light-SERNet
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385