-
Notifications
You must be signed in to change notification settings - Fork 1
/
ds.py
40 lines (34 loc) · 1.59 KB
/
ds.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from torch.utils.data import Dataset
from inference.util import AudioUtil
class SoundDS(Dataset):
    """Map-style dataset yielding ``(spectrogram, class_id)`` pairs for audio clips.

    Each clip is processed through ``AudioUtil`` in this order:
    open, resample to ``sample_rate``, rechannel to ``channel``, pad/truncate
    to ``duration``, ``spectro_gram``, then ``spectro_augment`` when
    ``augment`` is True.

    Args:
        data_paths: Indexable collection of records where ``record[0]`` is the
            class id and ``record[1]`` is the audio path handed to
            ``AudioUtil.open``.
        duration: Target clip length in milliseconds, passed to
            ``AudioUtil.pad_trunc``.
        sample_rate: Target sample rate passed to ``AudioUtil.resample``.
        channel: Target channel count passed to ``AudioUtil.rechannel``.
        n_mels, n_fft, hop_len: Forwarded unchanged to
            ``AudioUtil.spectro_gram``.
        max_mask_pct, n_freq_masks, n_time_masks: Forwarded unchanged to
            ``AudioUtil.spectro_augment``.
        augment: When False, skip ``AudioUtil.spectro_augment`` and return the
            un-augmented spectrogram (useful for validation or inference).

    The keyword-only defaults reproduce the previously hard-coded values, so
    existing callers see identical behaviour.
    """

    def __init__(
        self,
        data_paths,
        duration,
        sample_rate: int = 44100,
        channel: int = 2,
        *,
        n_mels: int = 64,
        n_fft: int = 1024,
        hop_len=None,
        max_mask_pct: float = 0.1,
        n_freq_masks: int = 2,
        n_time_masks: int = 2,
        augment: bool = True,
    ):
        self.data_paths = data_paths
        self.duration = duration  # in ms
        self.sr = sample_rate
        self.channel = channel
        # Spectrogram parameters (forwarded to AudioUtil.spectro_gram).
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_len = hop_len
        # Augmentation parameters (forwarded to AudioUtil.spectro_augment).
        self.max_mask_pct = max_mask_pct
        self.n_freq_masks = n_freq_masks
        self.n_time_masks = n_time_masks
        self.augment = augment

    def __len__(self) -> int:
        """Return the number of records in ``data_paths``."""
        return len(self.data_paths)

    def __getitem__(self, idx):
        """Load and featurise the ``idx``-th clip.

        Returns:
            ``(spectrogram, class_id)``. The spectrogram is whatever
            ``AudioUtil.spectro_augment`` returns, or ``AudioUtil.spectro_gram``
            when ``augment`` is False. NOTE(review): presumably a tensor;
            confirm against ``inference.util``.
        """
        record = self.data_paths[idx]
        # Index rather than unpack so records carrying extra fields beyond
        # (class_id, path) keep working.
        class_id = record[0]
        audio_file = record[1]

        aud = AudioUtil.open(audio_file)
        # Some clips have a different sample rate or channel count than the
        # majority. Normalise both first. Otherwise pad_trunc to the same
        # duration would still yield arrays of different lengths for clips
        # recorded at different rates.
        reaud = AudioUtil.resample(aud, self.sr)
        rechan = AudioUtil.rechannel(reaud, self.channel)
        dur_aud = AudioUtil.pad_trunc(rechan, self.duration)

        sgram = AudioUtil.spectro_gram(
            dur_aud, n_mels=self.n_mels, n_fft=self.n_fft, hop_len=self.hop_len
        )
        if not self.augment:
            return sgram, class_id

        aug_sgram = AudioUtil.spectro_augment(
            sgram,
            max_mask_pct=self.max_mask_pct,
            n_freq_masks=self.n_freq_masks,
            n_time_masks=self.n_time_masks,
        )
        return aug_sgram, class_id