import sys

sys.path.append("src")

import io
import json
import math
import os
import random
import threading
import zlib

import pandas as pd
import yaml

import qa_mdt.audioldm_train.utilities.audio as Audio
from qa_mdt.audioldm_train.utilities.tools import load_json
from qa_mdt.audioldm_train.dataset_plugin import *

import librosa
from librosa.filters import mel as librosa_mel_fn
import lmdb
import numpy as np
import torch
import torch.nn.functional
import torchaudio
from torch.utils.data import Dataset
from pydub import AudioSegment

from .datum_all_pb2 import Datum_all as Datum_lmdb
from .datum_mos_pb2 import Datum_mos as Datum_lmdb_mos


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output
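
# The two pairs above implement the log-compression used by HiFi-GAN-style mel
# front ends (see mel_spectrogram_train below). A minimal round-trip sketch,
# assuming all magnitudes are above clip_val so the clamp is a no-op:
#   m = torch.rand(80, 100) + 1e-3
#   torch.allclose(spectral_de_normalize_torch(spectral_normalize_torch(m)), m)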


class AudioDataset(Dataset):
    def __init__(
        self,
        config,
        lmdb_path,
        key_path,
        mos_path,
        lock=True,
    ):
        """
        Dataset that manages audio recordings stored in LMDB databases.
        """
        self.config = config
        # self.lock = threading.Lock()
        self.pad_wav_start_sample = 0
        self.trim_wav = False
        self.build_setting_parameters()
        self.build_dsp()
        self.lmdb_path = [p.encode("utf-8") for p in lmdb_path]
        self.lmdb_env = [
            lmdb.open(p, readonly=True, lock=False) for p in self.lmdb_path
        ]
        self.mos_txn_env = lmdb.open(mos_path, readonly=True, lock=False)
        self.key_path = [p.encode("utf-8") for p in key_path]
        # self.keys: [(dataset_id, key), ...]
        self.keys = []
        for i in range(len(key_path)):
            with open(self.key_path[i]) as f:
                for line in f:
                    key = line.strip()
                    self.keys.append((i, key.split()[0].encode("utf-8")))
                    # only for test !!!
                    # if i > 20:
                    #     break
        # self.lmdb_env = lmdb.open(self.lmdb_path, readonly=True, lock=False)
        # self.txn = self.lmdb_env.begin()
        print(f"Dataset initialization finished, dataset length: {len(self.keys)}")
        print("Filter initialization started")
        with open("filter_all.lst", "r") as f:
            self.filter = {}
            for line in f.readlines():
                self.filter[line.strip()] = 1
        print("Filter initialization finished")
        # print("Caption fusion initialization started")
        # with open('new_file.txt', 'r') as f:
        #     self.refined_caption = {}
        #     for line in f.readlines():
        #         try:
        #             a, b = line.strip().split("@")
        #             b = b.strip('"\n')
        #             b = b.replace('\n', ',')
        #             self.refined_caption[a] = b
        #         except Exception:
        #             pass
        # print("Caption fusion initialization finished")
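
    # Expected data layout (a sketch inferred from this file, not
    # authoritative):
    #   - each audio LMDB maps key -> Datum_all protobuf holding a
    #     zlib-compressed float32 waveform plus caption fields;
    #   - the MOS LMDB maps the same key -> Datum_mos protobuf with a scalar
    #     `mos` quality score;
    #   - each key file lists one key per line (the first whitespace-separated
    #     token is used);
    #   - filter_all.lst lists keys to be skipped during sampling.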
    def __getitem__(self, index):
        (
            # file name; here we actually use the (dataset_id, key) pair
            fname,
            # waveform at sr = 16000
            waveform,
            # linear magnitude spectrogram
            stft,
            # log mel spectrogram
            log_mel_spec,
            label_vector,
            # the segment does not have to start at the beginning of the file
            random_start,
            # dict or single string that describes the wav file
            caption,
            # MOS score for a single music clip
            mos,
        ) = self.feature_extraction(index)
        data = {
            "text": [caption],  # list ... dict ?
            "fname": [fname],  # list
            # tensor, [batchsize, 1, samples_num]
            "waveform": "" if (waveform is None) else waveform.float(),
            # tensor, [batchsize, t-steps, f-bins]
            "stft": "" if (stft is None) else stft.float(),
            # tensor, [batchsize, t-steps, mel-bins]
            "log_mel_spec": "" if (log_mel_spec is None) else log_mel_spec.float(),
            "duration": self.duration,
            "sampling_rate": self.sampling_rate,
            "random_start_sample_in_original_audio_file": random_start,
            "label_vector": label_vector,
            "mos": mos,
        }
        if data["text"] is None:
            print("Warning: the model returned None on key 'text'", fname)
            data["text"] = ""
        return data

    def __len__(self):
        return len(self.keys)
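
    # feature_extraction proceeds in three stages:
    #   1. re-draw the index while the key is listed in self.filter;
    #   2. fetch and zlib-decompress the float32 waveform from LMDB,
    #      re-drawing on failure;
    #   3. rejection-sample the clip by its MOS quality score (see below),
    #      then assemble the caption, quality prefix, and features.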
    def feature_extraction(self, index):
        if index > len(self.keys) - 1:
            print(
                "The index of the dataloader is out of range: %s/%s"
                % (index, len(self.keys))
            )
            index = random.randint(0, len(self.keys) - 1)
        waveform = np.array([])
        retry_count = 0
        flag = 0
        last_index = index
        # Re-draw the index while the key is present in the filter list.
        while flag == 0:
            id_, k = self.keys[index]
            try:
                if self.filter[k.decode()] == 1:
                    index = random.randint(0, len(self.keys) - 1)
                else:
                    flag = 1
            except KeyError:
                # Key not in the filter list: keep it.
                flag = 1
                index = last_index
        # Fetch the compressed waveform from LMDB, re-drawing on failure.
        while len(waveform) < 1000:
            id_, k = self.keys[index]
            with self.lmdb_env[id_].begin(write=False) as txn:
                cursor = txn.cursor()
                try:
                    cursor.set_key(k)
                    datum_tmp = Datum_lmdb()
                    datum_tmp.ParseFromString(cursor.value())
                    # Streaming decompressor, in case the payload does not fit
                    # into memory at once.
                    zobj = zlib.decompressobj()
                    decompressed_bytes = zobj.decompress(datum_tmp.wav_file)
                    # decompressed_bytes = zlib.decompress(file)
                    waveform = np.frombuffer(decompressed_bytes, dtype=np.float32)
                except Exception:
                    retry_count += 1
            retry_count += 1
            last_index = index
            index = random.randint(0, len(self.keys) - 1)
        if retry_count > 1:
            print("error")
        index = last_index
        flag = 0
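        # Quality-aware rejection sampling: a clip with MOS score m is accepted
        # with probability exp(5 * m) / val. The magic constant below appears
        # to be roughly exp(5 * 4.05), so clips rated "high quality" (MOS >=
        # 4.05, see the prefix logic further down) are accepted almost always,
        # while low-MOS clips are very likely to be re-drawn.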
        val = 623787092.84794
        while flag == 0:
            id_, k = self.keys[index]
            with self.mos_txn_env.begin(write=False) as txn:
                cursor = txn.cursor()
                try:
                    if cursor.set_key(k):
                        datum_mos = Datum_lmdb_mos()
                        datum_mos.ParseFromString(cursor.value())
                        mos = datum_mos.mos
                    else:
                        mos = -1.0
                except Exception:
                    mos = -1.0
            if "pixa_" in k.decode() or "ifly_" in k.decode():
                mos = 5.0
            if np.random.rand() < math.exp(5.0 * mos) / val:
                flag = 1
                last_index = index
            index = random.randint(0, len(self.keys) - 1)
        index = last_index
        caption_original = datum_tmp.caption_original
        try:
            caption_generated = datum_tmp.caption_generated[0]
        except Exception:
            caption_generated = "None"
        assert len(caption_generated) > 1
        caption_original = caption_original.lower()
        caption_generated = caption_generated.lower()
        # Pick the caption source per dataset: MSD and MTT prefer the original
        # caption when available; AudioSet, FMA, Pixabay, and iFly prefer the
        # generated one.
        caption = "music"
        if "msd_" in k.decode():
            caption = caption_generated if caption_original == "none" else caption_original
        elif "audioset_" in k.decode():
            caption = caption_generated if caption_generated != "none" else caption_original
        elif "mtt_" in k.decode():
            caption = caption_generated if caption_original == "none" else caption_original
        elif "fma_" in k.decode():
            caption = caption_generated if caption_generated != "none" else caption_original
        elif "pixa_" in k.decode() or "ifly_" in k.decode():
            caption = caption_generated if caption_generated != "none" else caption_original
        else:
            caption = caption_original
        # Map the MOS score to a quality prefix that is prepended to the
        # caption. Pixabay/iFly clips are treated as high quality by default.
        prefix = "medium quality"
        if "pixa_" in k.decode() or "ifly_" in k.decode():
            if caption == "none":
                prefix = "high quality"
                caption = ""
            else:
                prefix = "high quality"
            mos = 5.00
        else:
            mos = float(mos)
            if 3.55 < mos < 4.05:
                prefix = "medium quality"
            elif mos >= 4.05:
                prefix = "high quality"
            elif mos <= 3.55:
                prefix = "low quality"
            else:
                print(f"MOS score for key {k.decode()} is missing, please check")
        # if 'low quality' or 'quality is low' in caption:
        #     prefix = 'low quality'
        caption = prefix + ", " + caption
        # Discretize the MOS score into a 1-5 token around miu = 3.80 with
        # step sigma = 0.20.
        miu = 3.80
        sigma = 0.20
        if miu - 2 * sigma <= mos < miu - sigma:
            vq_mos = 2
        elif miu - sigma <= mos < miu + sigma:
            vq_mos = 3
        elif miu + sigma <= mos < miu + 2 * sigma:
            vq_mos = 4
        elif mos >= miu + 2 * sigma:
            vq_mos = 5
        else:
            vq_mos = 1
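        # Resulting bins (miu = 3.80, sigma = 0.20):
        #   mos <  3.40        -> vq_mos = 1
        #   3.40 <= mos < 3.60 -> vq_mos = 2
        #   3.60 <= mos < 4.00 -> vq_mos = 3
        #   4.00 <= mos < 4.20 -> vq_mos = 4
        #   mos >= 4.20        -> vq_mos = 5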
""" | |
tags = datum_tmp.tags.decode() | |
caption_writing = datum_tmp.caption_writing.decode() | |
caption_paraphrase = datum_tmp.caption_paraphrase.decode() | |
caption_attribute_prediction = datum_tmp.caption_attribute_prediction.decode() | |
caption_summary = datum_tmp.caption_summary.decode() | |
""" | |
        (
            log_mel_spec,
            stft,
            waveform,
            random_start,
        ) = self.read_audio_file(waveform, k.decode())
        fname = self.keys[index]
        # t_step = log_mel_spec.size(0)
        # waveform = torch.FloatTensor(waveform[..., : int(self.hopsize * t_step)])
        waveform = torch.FloatTensor(waveform)
        label_vector = torch.FloatTensor(np.zeros(0, dtype=np.float32))
        # finally:
        #     self.lock.release()
        return (
            fname,
            waveform,
            stft,
            log_mel_spec,
            label_vector,
            random_start,
            caption,
            vq_mos,
        )

    def build_setting_parameters(self):
        # Read from the yaml config
        self.melbins = self.config["preprocessing"]["mel"]["n_mel_channels"]
        # self.freqm = self.config["preprocessing"]["mel"]["freqm"]
        # self.timem = self.config["preprocessing"]["mel"]["timem"]
        self.sampling_rate = self.config["preprocessing"]["audio"]["sampling_rate"]
        self.hopsize = self.config["preprocessing"]["stft"]["hop_length"]
        self.duration = self.config["preprocessing"]["audio"]["duration"]
        # Number of mel frames in a clip of `duration` seconds
        self.target_length = int(self.duration * self.sampling_rate / self.hopsize)
        self.mixup = self.config["augmentation"]["mixup"]
        # Calculate parameter derivations
        # self.waveform_sample_length = int(self.target_length * self.hopsize)
        # if (self.config["balance_sampling_weight"]):
        #     self.samples_weight = np.loadtxt(
        #         self.config["balance_sampling_weight"], delimiter=","
        #     )
        # if "train" not in self.split:
        #     self.mixup = 0.0
        #     # self.freqm = 0
        #     # self.timem = 0

    def build_dsp(self):
        self.mel_basis = {}
        self.hann_window = {}
        self.filter_length = self.config["preprocessing"]["stft"]["filter_length"]
        self.hop_length = self.config["preprocessing"]["stft"]["hop_length"]
        self.win_length = self.config["preprocessing"]["stft"]["win_length"]
        self.n_mel = self.config["preprocessing"]["mel"]["n_mel_channels"]
        self.sampling_rate = self.config["preprocessing"]["audio"]["sampling_rate"]
        self.mel_fmin = self.config["preprocessing"]["mel"]["mel_fmin"]
        self.mel_fmax = self.config["preprocessing"]["mel"]["mel_fmax"]
        self.STFT = Audio.stft.TacotronSTFT(
            self.config["preprocessing"]["stft"]["filter_length"],
            self.config["preprocessing"]["stft"]["hop_length"],
            self.config["preprocessing"]["stft"]["win_length"],
            self.config["preprocessing"]["mel"]["n_mel_channels"],
            self.config["preprocessing"]["audio"]["sampling_rate"],
            self.config["preprocessing"]["mel"]["mel_fmin"],
            self.config["preprocessing"]["mel"]["mel_fmax"],
        )
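
    # A sketch of the expected config layout (field names taken from the
    # lookups above; the concrete values are illustrative only):
    #   preprocessing:
    #     audio: {sampling_rate: 16000, duration: 10.24}
    #     stft: {filter_length: 1024, hop_length: 160, win_length: 1024}
    #     mel: {n_mel_channels: 64, mel_fmin: 0, mel_fmax: 8000}
    #   augmentation: {mixup: 0.0}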

    def resample(self, waveform, sr):
        waveform = torchaudio.functional.resample(waveform, sr, self.sampling_rate)
        # waveform = librosa.resample(waveform, sr, self.sampling_rate)
        return waveform
        # if sr == 16000:
        #     return waveform
        # if sr == 32000 and self.sampling_rate == 16000:
        #     waveform = waveform[::2]
        #     return waveform
        # if sr == 48000 and self.sampling_rate == 16000:
        #     waveform = waveform[::3]
        #     return waveform
        # else:
        #     raise ValueError(
        #         "We currently only support 16k audio generation. You need to resample your audio file to 16k, 32k, or 48k: %s, %s"
        #         % (sr, self.sampling_rate)
        #     )

    def normalize_wav(self, waveform):
        waveform = waveform - np.mean(waveform)
        waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
        return waveform * 0.5  # manually limit the maximum amplitude to 0.5

    def random_segment_wav(self, waveform, target_length):
        waveform = torch.tensor(waveform)
        waveform = waveform.unsqueeze(0)
        waveform_length = waveform.shape[-1]
        # assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
        if waveform_length < 100:
            waveform = torch.nn.functional.pad(
                waveform, (0, target_length - waveform_length)
            )
        # Too short: return the (possibly padded) waveform as-is
        if (waveform_length - target_length) <= 0:
            return waveform, 0
        # Try up to 10 random crops, keeping the first non-silent one
        for i in range(10):
            random_start = int(self.random_uniform(0, waveform_length - target_length))
            if torch.max(
                torch.abs(waveform[:, random_start : random_start + target_length])
                > 1e-4
            ):
                break
        return waveform[:, random_start : random_start + target_length], random_start

    def pad_wav(self, waveform, target_length):
        waveform_length = waveform.shape[-1]
        # assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
        if waveform_length < 100:
            waveform = np.pad(waveform, ((0, 0), (0, target_length - waveform_length)))
        if waveform_length == target_length:
            return waveform
        # Pad: place the waveform at a random offset when pad_wav_start_sample
        # is None, otherwise at the start
        temp_wav = np.zeros((1, target_length), dtype=np.float32)
        if self.pad_wav_start_sample is None:
            rand_start = int(self.random_uniform(0, target_length - waveform_length))
        else:
            rand_start = 0
        temp_wav[:, rand_start : rand_start + waveform_length] = waveform
        return temp_wav

    def trim_silence(self, waveform):
        """Trim leading and trailing silence from the waveform."""
        if np.max(np.abs(waveform)) < 0.0001:
            return waveform

        def detect_leading_silence(waveform, threshold=0.0001):
            chunk_size = 1000
            waveform_length = waveform.shape[0]
            start = 0
            while start + chunk_size < waveform_length:
                if np.max(np.abs(waveform[start : start + chunk_size])) < threshold:
                    start += chunk_size
                else:
                    break
            return start

        def detect_ending_silence(waveform, threshold=0.0001):
            chunk_size = 1000
            waveform_length = waveform.shape[0]
            start = waveform_length
            while start - chunk_size > 0:
                if np.max(np.abs(waveform[start - chunk_size : start])) < threshold:
                    start -= chunk_size
                else:
                    break
            if start == waveform_length:
                return start
            else:
                return start + chunk_size

        start = detect_leading_silence(waveform)
        end = detect_ending_silence(waveform)
        return waveform[start:end]

    def read_wav_file(self, file, k):
        # `file` is already a decoded float32 waveform (decompressed from LMDB
        # in feature_extraction), not raw bytes.
        waveform = file
        # # waveform, sr = librosa.load(filename, sr=None, mono=True)  # 4 times slower
        # if "msd" in k or "fma" in k:
        #     try:
        #         waveform = torch.tensor([(np.array(file.get_array_of_samples(array_type_override='i')) / 2147483648)], dtype=torch.float32)
        #     except:
        #         waveform = torch.tensor([(np.array(file.get_array_of_samples(array_type_override='h')) / 32768)], dtype=torch.float32)
        # else:
        #     waveform = torch.tensor([(np.array(file.get_array_of_samples(array_type_override='h')) / 32768)], dtype=torch.float32)
        sr = 16000
        waveform, random_start = self.random_segment_wav(
            waveform, target_length=int(sr * self.duration)
        )
        waveform = self.resample(waveform, sr)
        # random_start = int(random_start * (self.sampling_rate / sr))
        waveform = waveform.numpy()[0, ...]
        waveform = self.normalize_wav(waveform)
        if self.trim_wav:
            waveform = self.trim_silence(waveform)
        waveform = waveform[None, ...]
        waveform = self.pad_wav(
            waveform, target_length=int(self.sampling_rate * self.duration)
        )
        return waveform, random_start
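
    # Pipeline summary: random_segment_wav crops int(16000 * duration) samples,
    # resample converts 16 kHz audio to self.sampling_rate, normalize_wav
    # centers the waveform and scales it to a 0.5 peak, and pad_wav zero-pads
    # it up to int(self.sampling_rate * duration) samples.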

    def mix_two_waveforms(self, waveform1, waveform2):
        # Mixup-style interpolation with a Beta(5, 5) coefficient
        mix_lambda = np.random.beta(5, 5)
        mix_waveform = mix_lambda * waveform1 + (1 - mix_lambda) * waveform2
        return self.normalize_wav(mix_waveform), mix_lambda

    def read_audio_file(self, file, k):
        # target_length = int(self.sampling_rate * self.duration)
        waveform, random_start = self.read_wav_file(file, k)
        # log_mel_spec, stft = self.wav_feature_extraction_torchaudio(waveform)  # faster, but not aligned with HiFi-GAN
        log_mel_spec, stft = self.wav_feature_extraction(waveform)
        return log_mel_spec, stft, waveform, random_start

    def mel_spectrogram_train(self, y):
        if torch.min(y) < -1.0:
            print("train min value is ", torch.min(y))
        if torch.max(y) > 1.0:
            print("train max value is ", torch.max(y))
        # Cache the mel filterbank and window per (fmax, device) pair
        mel_key = str(self.mel_fmax) + "_" + str(y.device)
        if mel_key not in self.mel_basis:
            mel = librosa_mel_fn(
                sr=self.sampling_rate,
                n_fft=self.filter_length,
                n_mels=self.n_mel,
                fmin=self.mel_fmin,
                fmax=self.mel_fmax,
            )
            self.mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
            self.hann_window[str(y.device)] = torch.hann_window(self.win_length).to(
                y.device
            )
        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (
                int((self.filter_length - self.hop_length) / 2),
                int((self.filter_length - self.hop_length) / 2),
            ),
            mode="reflect",
        )
        y = y.squeeze(1)
        stft_spec = torch.stft(
            y,
            self.filter_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.hann_window[str(y.device)],
            center=False,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        stft_spec = torch.abs(stft_spec)
        mel = spectral_normalize_torch(
            torch.matmul(self.mel_basis[mel_key], stft_spec)
        )
        return mel[0], stft_spec[0]
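
    # Shape note: for input y of shape [1, samples], mel_spectrogram_train
    # returns mel [n_mel, T] and stft [filter_length // 2 + 1, T]; both are
    # transposed to [T, bins] in wav_feature_extraction below.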
    # This one is significantly slower than "wav_feature_extraction_torchaudio"
    # if num_worker > 1
    def wav_feature_extraction(self, waveform):
        waveform = waveform[0, ...]
        waveform = torch.FloatTensor(waveform)
        # log_mel_spec, stft, energy = Audio.tools.get_mel_from_wav(waveform, self.STFT)[0]
        log_mel_spec, stft = self.mel_spectrogram_train(waveform.unsqueeze(0))
        log_mel_spec = torch.FloatTensor(log_mel_spec.T)
        stft = torch.FloatTensor(stft.T)
        log_mel_spec, stft = self.pad_spec(log_mel_spec), self.pad_spec(stft)
        return log_mel_spec, stft

    def pad_spec(self, log_mel_spec):
        n_frames = log_mel_spec.shape[0]
        p = self.target_length - n_frames
        # Cut or pad along the time axis to exactly target_length frames
        if p > 0:
            m = torch.nn.ZeroPad2d((0, 0, 0, p))
            log_mel_spec = m(log_mel_spec)
        elif p < 0:
            log_mel_spec = log_mel_spec[0 : self.target_length, :]
        # Keep the last axis even-sized
        if log_mel_spec.size(-1) % 2 != 0:
            log_mel_spec = log_mel_spec[..., :-1]
        return log_mel_spec

    def _read_datum_caption(self, datum):
        caption_keys = [x for x in datum.keys() if ("caption" in x)]
        random_index = torch.randint(0, len(caption_keys), (1,))[0].item()
        return datum[caption_keys[random_index]]

    def _is_contain_caption(self, datum):
        caption_keys = [x for x in datum.keys() if ("caption" in x)]
        return len(caption_keys) > 0

    def label_indices_to_text(self, datum, label_indices):
        if self._is_contain_caption(datum):
            return self._read_datum_caption(datum)
        elif "label" in datum.keys():
            name_indices = torch.where(label_indices > 0.1)[0]
            # description_header = "This audio contains the sound of "
            description_header = ""
            labels = ""
            for id, each in enumerate(name_indices):
                if id == len(name_indices) - 1:
                    labels += "%s." % self.num2label[int(each)]
                else:
                    labels += "%s, " % self.num2label[int(each)]
            return description_header + labels
        else:
            # TODO: if neither label nor caption is provided, return an empty string
            return ""

    def random_uniform(self, start, end):
        val = torch.rand(1).item()
        return start + (end - start) * val

    def frequency_masking(self, log_mel_spec, freqm):
        bs, freq, tsteps = log_mel_spec.size()
        mask_len = int(self.random_uniform(freqm // 8, freqm))
        mask_start = int(self.random_uniform(start=0, end=freq - mask_len))
        log_mel_spec[:, mask_start : mask_start + mask_len, :] *= 0.0
        return log_mel_spec

    def time_masking(self, log_mel_spec, timem):
        bs, freq, tsteps = log_mel_spec.size()
        mask_len = int(self.random_uniform(timem // 8, timem))
        mask_start = int(self.random_uniform(start=0, end=tsteps - mask_len))
        log_mel_spec[:, :, mask_start : mask_start + mask_len] *= 0.0
        return log_mel_spec
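
    # These two helpers implement SpecAugment-style masking: a random band of
    # mel bins (frequency_masking) or frames (time_masking) is zeroed out,
    # with the band width drawn uniformly from [freqm // 8, freqm]
    # (respectively [timem // 8, timem]).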


class AudioDataset_infer(Dataset):
    def __init__(
        self,
        config,
        caption_list,
        lock=True,
    ):
        """
        Dataset that manages captions for inference.
        """
        self.config = config
        # self.lock = threading.Lock()
        self.captions = []
        with open(caption_list, "r") as f:
            for line in f:
                key = line.strip()
                # Note: only the first whitespace-separated token of each line
                # is kept.
                self.captions.append(key.split()[0])
        self.duration = self.config["preprocessing"]["audio"]["duration"]
        self.sampling_rate = self.config["variables"]["sampling_rate"]
        self.target_length = int(self.sampling_rate * self.duration)
        self.waveform = torch.zeros((1, self.target_length))

    def __getitem__(self, index):
        fname = [f"sample_{index}"]
        data = {
            "text": [self.captions[index]],  # list ... dict ?
            "fname": fname,  # list
            # tensor, [batchsize, 1, samples_num]
            "waveform": "",
            # tensor, [batchsize, t-steps, f-bins]
            "stft": "",
            # tensor, [batchsize, t-steps, mel-bins]
            "log_mel_spec": "",
            "duration": self.duration,
            "sampling_rate": self.sampling_rate,
            "random_start_sample_in_original_audio_file": 0,
            "label_vector": torch.FloatTensor(np.zeros(0, dtype=np.float32)),
            # No ground-truth MOS is available at inference; -1.0 is an
            # assumed placeholder.
            "mos": -1.0,
        }
        if data["text"] is None:
            print("Warning: the model returned None on key 'text'", fname)
            data["text"] = ""
        return data

    def __len__(self):
        return len(self.captions)
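

# A minimal usage sketch for AudioDataset_infer (hypothetical path; `config`
# is assumed to be a parsed yaml config as in the __main__ block below):
#   dataset = AudioDataset_infer(config=config, caption_list="captions.lst")
#   loader = DataLoader(dataset, batch_size=4, num_workers=0)
#   for batch in loader:
#       texts = batch["text"]  # captions fed to the generation model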


if __name__ == "__main__":
    import torch
    from tqdm import tqdm
    from pytorch_lightning import seed_everything
    from torch.utils.data import DataLoader

    seed_everything(0)

    def write_json(my_dict, fname):
        # print("Save json file at " + fname)
        json_str = json.dumps(my_dict)
        with open(fname, "w") as json_file:
            json_file.write(json_str)

    def load_json(fname):
        with open(fname, "r") as f:
            data = json.load(f)
            return data

    config = yaml.load(
        open(
            "/mnt/bn/lqhaoheliu/project/audio_generation_diffusion/config/vae_48k_256/ds_8_kl_1.0_ch_16.yaml",
            "r",
        ),
        Loader=yaml.FullLoader,
    )
    add_ons = config["data"]["dataloader_add_ons"]
    # load_json(data)
    # This smoke test predates the current AudioDataset signature
    # (config, lmdb_path, key_path, mos_path); the keyword arguments below
    # no longer match it.
    dataset = AudioDataset(
        config=config, split="train", waveform_only=False, add_ons=add_ons
    )
    loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=True)
    # for cnt, each in tqdm(enumerate(loader)):
    #     print(each["waveform"].size(), each["log_mel_spec"].size())
    #     print(each['freq_energy_percentile'])
    #     pass