1 Star 0 Fork 0

Bytedance Inc./SALMONN

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
dataset.py 3.15 KB
一键复制 编辑 原始数据 按行查看 历史
yuwenyi.1 提交于 2024-05-27 23:57 . fix bug in dataset
# Copyright (2024) Tsinghua University, Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import soundfile as sf
import numpy as np
from transformers import WhisperFeatureExtractor
class SALMONNDataset(Dataset):
    """Dataset yielding raw waveforms together with Whisper input features.

    The annotation file is a JSON document whose top-level ``"annotation"``
    key holds a list of entries.  Each entry is read as a dict with:

    * ``"path"`` (required): audio file to load; also returned as the
      sample ``"id"``.
    * ``"text"`` (required): target text, passed through unchanged.
    * ``"task"`` (optional): defaults to ``"asr"``.
    * ``"Q"`` (optional): defaults to ``""``.
    * ``"expand_wav"`` (optional): extra audio files appended after the
      main one, each preceded by a short run of zeros.
    """

    # Number of zero samples inserted before every ``expand_wav`` clip
    # (0.1 s when the audio is sampled at 16 kHz).
    _GAP_SAMPLES = 1600
    # Audio is padded up to 1 s and cut down to this many seconds.
    _MAX_SECONDS = 30

    def __init__(self, ann_path, whisper_path):
        """Load the annotation list and the Whisper feature extractor.

        Args:
            ann_path: Path of the JSON annotation file.
            whisper_path: Name or path handed to
                ``WhisperFeatureExtractor.from_pretrained``.
        """
        super().__init__()

        # Context manager so the handle is closed deterministically instead
        # of being left to the garbage collector.
        with open(ann_path, "r", encoding="utf-8") as ann_file:
            self.annotation = json.load(ann_file)["annotation"]

        self.wav_processor = WhisperFeatureExtractor.from_pretrained(whisper_path)

    def __len__(self):
        """Return the number of annotation entries."""
        return len(self.annotation)

    def collater(self, samples):
        """Merge a list of ``__getitem__`` results into one batch dict.

        Args:
            samples: Sequence of dicts with the keys ``"spectrogram"``
                (tensor), ``"raw_wav"`` (NumPy array), ``"text"``,
                ``"task"``, ``"Q"`` and ``"id"``.

        Returns:
            A dict with:

            * ``"spectrogram"``: the per-sample spectrograms stacked along
              a new leading dimension.
            * ``"raw_wav"``: waveforms zero-padded on the right to the
              longest one, batch first.
            * ``"padding_mask"``: boolean tensor, ``True`` at positions that
              are padding (index >= the sample's original length).
            * ``"text"``, ``"task"``, ``"Q"``, ``"id"``: plain Python lists
              in sample order.
        """
        spectrogram = torch.stack([s["spectrogram"] for s in samples], dim=0)

        wavs = [torch.from_numpy(s["raw_wav"]) for s in samples]
        wav_lengths = torch.tensor([len(s["raw_wav"]) for s in samples])
        raw_wav = pad_sequence(wavs, batch_first=True, padding_value=0)

        # Compare every time index (row vector) against every sample length
        # (column vector); broadcasting yields a (batch, max_len) mask.
        positions = torch.arange(raw_wav.size(1)).unsqueeze(0)
        padding_mask = positions >= wav_lengths.unsqueeze(1)

        return {
            "spectrogram": spectrogram,
            "raw_wav": raw_wav,
            "padding_mask": padding_mask,
            "text": [s["text"] for s in samples],
            "task": [s["task"] for s in samples],
            "Q": [s["Q"] for s in samples],
            "id": [s["id"] for s in samples],
        }

    @staticmethod
    def _read_mono(path):
        """Read an audio file, keeping only the first channel if it has two dims.

        Returns:
            ``(samples, sample_rate)`` as produced by ``soundfile.read``.
        """
        audio, sr = sf.read(path)
        if audio.ndim == 2:  # multi-channel: keep the first channel only
            audio = audio[:, 0]
        return audio, sr

    def __getitem__(self, index):
        """Load, assemble and featurize the audio of one annotation entry.

        Args:
            index: Position of the entry in the annotation list.

        Returns:
            A dict with ``"spectrogram"`` (Whisper ``input_features`` with
            the leading batch dimension removed), ``"raw_wav"`` (the
            assembled waveform array), ``"text"``, ``"task"``, ``"Q"`` and
            ``"id"`` (the entry's ``"path"``).
        """
        ann = self.annotation[index]
        audio, sr = self._read_mono(ann["path"])

        # Gather the main clip and every extra clip (each preceded by a gap
        # of zeros), then concatenate once rather than once per clip.
        # NOTE(review): the sample rate of the extra clips is not checked
        # against ``sr`` -- confirm they are all recorded at the same rate.
        extra_paths = ann.get("expand_wav") or ()
        if extra_paths:
            gap = np.zeros(self._GAP_SAMPLES, dtype=float)
            pieces = [audio]
            for extra_path in extra_paths:
                extra_audio, _ = self._read_mono(extra_path)
                pieces.extend((gap, extra_audio))
            audio = np.concatenate(pieces, axis=0)

        if len(audio) < sr:  # pad audio to at least 1s
            tail = np.zeros(sr - len(audio), dtype=float)
            audio = np.concatenate((audio, tail), axis=0)
        audio = audio[: sr * self._MAX_SECONDS]  # truncate audio to at most 30s

        features = self.wav_processor(audio, sampling_rate=sr, return_tensors="pt")
        # squeeze(0) removes only the leading batch dimension; a bare
        # squeeze() would also collapse any other dimension of size one.
        spectrogram = features["input_features"].squeeze(0)

        return {
            "spectrogram": spectrogram,
            "raw_wav": audio,
            "text": ann["text"],
            "task": ann.get("task", "asr"),
            "Q": ann.get("Q", ""),
            "id": ann["path"],
        }
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/ByteDance/SALMONN.git
git@gitee.com:ByteDance/SALMONN.git
ByteDance
SALMONN
SALMONN
main

搜索帮助