# audioldm_train/modules/audiomae/audiovisual_dataset.py
import json
import random
from math import ceil

import decord
import numpy as np
import pandas as pd
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

decord.bridge.set_bridge("torch")

class AudioVisualDataset(Dataset):
"""Can sample data from audio-visual databases
Params:
min_video_frames: used to drop short video clips
video_resize: resize for CLIP processing
sampling_rate: audio sampling rate
max_clip_len: max length (seconds) of audiovisual clip to be sampled
num_sample_frames: number of image frames to be uniformly sampled from video
"""
def __init__(
self,
datafiles=[
"/mnt/bn/data-xubo/dataset/audioset_videos/datafiles/audioset_balanced_train.json",
],
min_video_frames=30,
video_resize=[224, 224],
sampling_rate=16000,
sample_av_clip=True,
max_clip_len=10,
num_sample_frames=10,
        # hyperparameters used for SpecAugment
freqm=48,
timem=192,
return_label=False,
):
all_data_json = []
for datafile in datafiles:
with open(datafile, "r") as fp:
data_json = json.load(fp)["data"]
all_data_json.extend(data_json)
# drop short video clips
self.all_data_json = [
data
for data in all_data_json
if int(data["video_shape"][0]) >= min_video_frames
]
self.max_clip_len = max_clip_len
self.video_resize = video_resize
self.sampling_rate = sampling_rate
self.sample_av_clip = sample_av_clip
self.num_sample_frames = num_sample_frames
self.corresponding_audio_len = self.sampling_rate * self.max_clip_len
        # hyperparameters used for the AudioMAE fbank front-end
        self.freqm = freqm
        self.timem = timem
        # dataset-level fbank mean/std (AudioSet statistics, as used by AudioMAE/AST)
        self.norm_mean = -4.2677393
        self.norm_std = 4.5689974
        self.melbins = 128
        # 1024 frames at a 10 ms frame shift ~= 10.24 s of audio
        self.TARGET_LEN = 1024
self.return_label = return_label
if self.return_label:
self.audioset_label2idx = self._prepare_audioset()
def __len__(self):
return len(self.all_data_json)
    def _read_audio_video(self, index):
        video_path = self.all_data_json[index]["mp4"]
        try:
            # read audio
            ar = decord.AudioReader(
                video_path, sample_rate=self.sampling_rate, mono=True
            )
            # read video frames
            vr = decord.VideoReader(
                video_path,
                height=self.video_resize[0],
                width=self.video_resize[1],
            )
            labels = self.all_data_json[index]["labels"]
            return vr, ar, labels
        except Exception as e:
            # fall back to a randomly chosen clip if this one fails to decode
            print(f"error: {e} occurs when loading {video_path}")
            random_index = random.randint(0, len(self.all_data_json) - 1)
            return self._read_audio_video(index=random_index)
def _prepare_audioset(self):
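        # build a mapping from AudioSet label code ("mid") to its integer class index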
df1 = pd.read_csv(
"/mnt/bn/lqhaoheliu/datasets/audioset/metadata/class_labels_indices.csv",
delimiter=",",
skiprows=0,
)
label_set = df1.to_numpy()
code2id = {}
for i in range(len(label_set)):
code2id[label_set[i][1]] = label_set[i][0]
return code2id
def __getitem__(self, index):
# read audio and video
vr, ar, labels = self._read_audio_video(index)
        # create an audio tensor
        audio_data = ar[:]  # [1, samples]
        audio_len = audio_data.shape[1] / self.sampling_rate  # duration in seconds
audio_data = audio_data.squeeze(0) # [samples]
# create a video tensor
full_vid_length = len(vr)
video_rate = ceil(vr.get_avg_fps())
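        # number of audio samples per video frame, used to keep audio and video aligned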
samples_per_frame = float(self.sampling_rate) / video_rate
start_frame = 0
# sample video clip
if audio_len > self.max_clip_len and self.sample_av_clip:
start_frame = random.randint(
0, max(full_vid_length - video_rate * self.max_clip_len, 0)
)
end_frame = min(start_frame + video_rate * self.max_clip_len, full_vid_length)
video_data = vr.get_batch(range(start_frame, end_frame))
# sample audio clip
if audio_len > self.max_clip_len and self.sample_av_clip:
# corresponding_audio_len = int(video_data.size()[0] * samples_per_frame)
corresponding_audio_start = int(start_frame * samples_per_frame)
audio_data = audio_data[corresponding_audio_start:]
# cut or pad audio clip with respect to the sampled video clip
if audio_data.shape[0] < self.corresponding_audio_len:
zero_data = torch.zeros(self.corresponding_audio_len)
zero_data[: audio_data.shape[0]] = audio_data
audio_data = zero_data
elif audio_data.shape[0] > self.corresponding_audio_len:
audio_data = audio_data[: self.corresponding_audio_len]
        # uniformly sample image frames from video [tentative solution]
        # guard against clips shorter than num_sample_frames (interval must be >= 1)
        interval = max(video_data.shape[0] // self.num_sample_frames, 1)
video_data = video_data[::interval][: self.num_sample_frames]
assert (
video_data.shape[0] == self.num_sample_frames
), f"number of sampled image frames is {video_data.shape[0]}"
assert (
audio_data.shape[0] == self.corresponding_audio_len
), f"number of audio samples is {audio_data.shape[0]}"
# video transformation
video_data = video_data / 255.0
video_data = video_data.permute(0, 3, 1, 2) # [N, H, W, C] -> [N, C, H, W]
# calculate mel fbank of waveform for audio encoder
audio_data = audio_data.unsqueeze(0) # [1, samples]
audio_data = audio_data - audio_data.mean()
fbank = torchaudio.compliance.kaldi.fbank(
audio_data,
htk_compat=True,
sample_frequency=self.sampling_rate,
use_energy=False,
window_type="hanning",
num_mel_bins=self.melbins,
dither=0.0,
frame_shift=10,
)
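        # kaldi fbank uses a 10 ms frame shift, i.e. ~100 frames per second of audio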
# cut and pad
n_frames = fbank.shape[0]
p = self.TARGET_LEN - n_frames
if p > 0:
m = torch.nn.ZeroPad2d((0, 0, 0, p))
fbank = m(fbank)
elif p < 0:
fbank = fbank[0 : self.TARGET_LEN, :]
        # SpecAugment masking (set freqm/timem to 0 to disable, e.g. for evaluation)
freqm = torchaudio.transforms.FrequencyMasking(self.freqm)
timem = torchaudio.transforms.TimeMasking(self.timem)
fbank = fbank.transpose(0, 1).unsqueeze(0) # 1, 128, 1024 (...,freq,time)
if self.freqm != 0:
fbank = freqm(fbank)
if self.timem != 0:
fbank = timem(fbank) # (..., freq, time)
fbank = torch.transpose(fbank.squeeze(), 0, 1) # time, freq
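        # normalize with dataset-level stats; dividing by 2 * std follows the AudioMAE data pipeline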
fbank = (fbank - self.norm_mean) / (self.norm_std * 2)
fbank = fbank.unsqueeze(0)
if self.return_label:
            # build a multi-hot AudioSet label vector (527 classes)
label_indices = np.zeros(527)
for label_str in labels.split(","):
label_indices[int(self.audioset_label2idx[label_str])] = 1.0
label_indices = torch.FloatTensor(label_indices)
data_dict = {
"labels": label_indices,
"images": video_data,
"fbank": fbank,
# 'modality': 'audio_visual'
}
else:
data_dict = {
"images": video_data,
"fbank": fbank,
# 'modality': 'audio_visual'
}
return data_dict


def collate_fn(list_data_dict):
    r"""Collate a mini-batch of data dicts into batched tensors for training.

    Args:
        list_data_dict: e.g., [
            {'images': (num_sample_frames, 3, H, W),
             'fbank': (1, time_frames, mel_bins),
             'labels': (527,),
            },
            ...]

    Returns:
        data_dict: e.g., {
            'images': (batch_size, num_sample_frames, 3, H, W),
            'fbank': (batch_size, 1, time_frames, mel_bins),
            'labels': (batch_size, 527),
        }
    """
    data_dict = {}
    for key in list_data_dict[0].keys():
        # stack per-sample tensors along a new batch dimension
        data_dict[key] = torch.stack([dd[key] for dd in list_data_dict])
    return data_dict
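

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the datafile path below is a
    # placeholder; point it at a JSON manifest with the same structure as the
    # AudioSet datafiles referenced above, i.e.
    # {"data": [{"mp4": ..., "labels": ..., "video_shape": ...}, ...]}.
    dataset = AudioVisualDataset(
        datafiles=["/path/to/your_datafile.json"],  # placeholder path
        return_label=False,
    )
    loader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=4,
        collate_fn=collate_fn,
    )
    for batch in loader:
        print(batch["images"].shape)  # [B, num_sample_frames, 3, 224, 224]
        print(batch["fbank"].shape)   # [B, 1, 1024, 128]
        break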