import pathlib
import pickle
import random

import numpy as np
import pysptk
import pytorch_lightning as pl
import pyworld
import torch
import torchaudio
import yaml
from torch.utils.data.dataloader import DataLoader

class DataModule(pl.LightningDataModule):
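    """LightningDataModule that serves train/val/test DataLoaders built from
    the file lists and pickles produced by the preprocessing step."""
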
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.batchsize = config["train"]["batchsize"]
        self.preprocessed_dir = pathlib.Path(config["general"]["preprocessed_path"])

    def setup(self, stage):
        if not self.preprocessed_dir.exists():
            raise RuntimeError("Preprocessed directory was not found")

        if "dual" in self.config:
            if self.config["dual"]["enable"]:
                with open(self.config["dual"]["config_path"], "r") as fr:
                    task_config = yaml.load(fr, Loader=yaml.FullLoader)
                task_preprocessed_dir = (
                    self.preprocessed_dir.parent
                    / pathlib.Path(task_config["general"]["preprocessed_path"]).name
                )
                if not task_preprocessed_dir.exists():
                    raise RuntimeError(
                        "Preprocessed directory for multi-task learning was not found"
                    )

        self.flnames = {
            "train": "train.txt",
            "val": "val.txt",
            "test": "test.txt",
        }

    def get_ds(self, phase):
        ds = Dataset(self.flnames[phase], self.config)
        return ds

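    # Only the training loader is shuffled; drop_last keeps every batch full.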
    def get_loader(self, phase):
        ds = self.get_ds(phase)
        dl = DataLoader(
            ds,
            self.batchsize,
            shuffle=(phase == "train"),
            num_workers=self.config["train"]["num_workers"],
            drop_last=True,
        )
        return dl

    def train_dataloader(self):
        return self.get_loader(phase="train")

    def val_dataloader(self):
        return self.get_loader(phase="val")

    def test_dataloader(self):
        return self.get_loader(phase="test")


class Dataset(torch.utils.data.Dataset):
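    """Dataset of preprocessed waveforms with on-the-fly feature extraction.

    Each pickle holds "wavs" (the input waveforms) and optionally "wavsaux"
    (auxiliary reference waveforms); mel spectrograms, mel-cepstra, and F0 are
    computed per item depending on the configured stage and feature type."""
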
    def __init__(self, filetxt, config):
        self.preprocessed_dir = pathlib.Path(config["general"]["preprocessed_path"])
        self.config = config
        self.spec_module = torchaudio.transforms.MelSpectrogram(
            sample_rate=config["preprocess"]["sampling_rate"],
            n_fft=config["preprocess"]["fft_length"],
            win_length=config["preprocess"]["frame_length"],
            hop_length=config["preprocess"]["frame_shift"],
            f_min=config["preprocess"]["fmin"],
            f_max=config["preprocess"]["fmax"],
            n_mels=config["preprocess"]["n_mels"],
            power=1,
            center=True,
            norm="slaney",
            mel_scale="slaney",
        )
        self.resample_candidate = [8000, 11025, 12000, 16000]
        self.quantization_candidate = range(2 ** 6, 2 ** 10 + 2, 2)
        self.segment_length = config["preprocess"]["segment_length"]

        with open(self.preprocessed_dir / filetxt, "r") as fr:
            self.filelist = [pathlib.Path(path.strip("\n")) for path in fr]

        self.d_out = dict()
        for item in ["wavs", "wavsaux"]:
            self.d_out[item] = []

        for wp in self.filelist:
            if config["general"]["corpus_type"] == "single":
                basename = str(wp.stem)
            else:
                basename = str(wp.parent.name) + "-" + str(wp.stem)

            with open(self.preprocessed_dir / "{}.pickle".format(basename), "rb") as fr:
                d_preprocessed = pickle.load(fr)

            for item in ["wavs", "wavsaux"]:
                try:
                    self.d_out[item].extend(d_preprocessed[item])
                except KeyError:
                    # Skip items missing from this pickle (e.g. no auxiliary waveforms).
                    pass

        for item in ["wavs", "wavsaux"]:
            self.d_out[item] = np.asarray(self.d_out[item])

if "dual" in self.config: |
|
if self.config["dual"]["enable"]: |
|
task_config = yaml.load( |
|
open(config["dual"]["config_path"], "r"), |
|
Loader=yaml.FullLoader, |
|
) |
|
task_preprocessed_dir = ( |
|
self.preprocessed_dir.parent |
|
/ pathlib.Path(task_config["general"]["preprocessed_path"]).name |
|
) |
|
with open(task_preprocessed_dir / filetxt, "r") as fr: |
|
task_filelist = [pathlib.Path(path.strip("\n")) for path in fr] |
|
self.d_out["wavstask"] = [] |
|
for wp in task_filelist: |
|
if task_config["general"]["corpus_type"] == "single": |
|
basename = str(wp.stem) |
|
else: |
|
basename = str(wp.parent.name) + "-" + str(wp.stem) |
|
with open( |
|
task_preprocessed_dir / "{}.pickle".format(basename), "rb" |
|
) as fw: |
|
d_preprocessed = pickle.load(fw) |
|
self.d_out["wavstask"].extend(d_preprocessed["wavs"]) |
|
self.d_out["wavstask"] = np.asarray(self.d_out["wavstask"]) |
|
|
|
    def __len__(self):
        return len(self.d_out["wavs"])

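    # Returns a dict whose keys depend on the configured stage: "pretrain"
    # pairs (possibly augmented) inputs with clean targets, while "ssl*"
    # stages compute features from the input waveform alone.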
    def __getitem__(self, idx):
        d_batch = {}

        if self.d_out["wavs"].size > 0:
            d_batch["wavs"] = torch.from_numpy(self.d_out["wavs"][idx])
            if self.segment_length > 0:
                d_batch["wavs"] = self.get_segment(d_batch["wavs"], self.segment_length)

        if self.d_out["wavsaux"].size > 0:
            d_batch["wavsaux"] = torch.from_numpy(self.d_out["wavsaux"][idx])
            if self.segment_length > 0:
                d_batch["wavsaux"] = self.get_segment(
                    d_batch["wavsaux"], self.segment_length
                )

        if self.config["general"]["stage"] == "pretrain":
            if self.config["train"]["augment"]:
                d_batch["wavs"] = self.augmentation(d_batch["wavsaux"])
            d_batch["wavs"] = self.normalize_waveform(d_batch["wavs"], db=-3)
            d_batch["wavsaux"] = self.normalize_waveform(d_batch["wavsaux"], db=-3)
            if len(d_batch["wavs"]) != len(d_batch["wavsaux"]):
                min_seq_len = min(len(d_batch["wavs"]), len(d_batch["wavsaux"]))
                d_batch["wavs"] = d_batch["wavs"][:min_seq_len]
                d_batch["wavsaux"] = d_batch["wavsaux"][:min_seq_len]
            d_batch["melspecs"] = self.calc_spectrogram(d_batch["wavs"])
            if self.config["general"]["feature_type"] == "melspec":
                d_batch["melspecsaux"] = self.calc_spectrogram(d_batch["wavsaux"])
            elif self.config["general"]["feature_type"] == "vocfeats":
                d_batch["melceps"] = self.calc_melcep(d_batch["wavsaux"])
                d_batch["f0s"] = self.calc_f0(d_batch["wavs"])
                d_batch["melcepssrc"] = self.calc_melcep(d_batch["wavs"])
            else:
                raise NotImplementedError()

        elif self.config["general"]["stage"].startswith("ssl"):
            d_batch["wavs"] = self.normalize_waveform(d_batch["wavs"], db=-3)
            d_batch["melspecs"] = self.calc_spectrogram(d_batch["wavs"])
            if self.config["general"]["feature_type"] == "vocfeats":
                d_batch["f0s"] = self.calc_f0(d_batch["wavs"])
                d_batch["melcepssrc"] = self.calc_melcep(d_batch["wavs"])
            if self.d_out["wavsaux"].size > 0:
                d_batch["wavsaux"] = self.normalize_waveform(d_batch["wavsaux"], db=-3)
                if self.config["general"]["feature_type"] == "melspec":
                    d_batch["melspecsaux"] = self.calc_spectrogram(d_batch["wavsaux"])
                elif self.config["general"]["feature_type"] == "vocfeats":
                    d_batch["melceps"] = self.calc_melcep(d_batch["wavsaux"])
            if "dual" in self.config:
                if self.config["dual"]["enable"]:
                    d_batch["wavstask"] = torch.from_numpy(self.d_out["wavstask"][idx])
                    d_batch["wavstask"] = self.get_segment(
                        d_batch["wavstask"], self.segment_length
                    )
                    d_batch["wavstask"] = self.normalize_waveform(
                        d_batch["wavstask"], db=-3
                    )
                    if self.config["general"]["feature_type"] == "melspec":
                        d_batch["melspecstask"] = self.calc_spectrogram(
                            d_batch["wavstask"]
                        )
                    elif self.config["general"]["feature_type"] == "vocfeats":
                        d_batch["melcepstask"] = self.calc_melcep(d_batch["wavstask"])
                    else:
                        raise NotImplementedError()

        else:
            raise NotImplementedError()

        return d_batch

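    # Amplitude mel spectrogram (power=1), clamped from below so the log never
    # sees zeros, scaled by the configured compression factor inside the log.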
    def calc_spectrogram(self, wav):
        specs = self.spec_module(wav)
        log_spec = torch.log(
            torch.clamp_min(specs, self.config["preprocess"]["min_magnitude"])
            * self.config["preprocess"]["comp_factor"]
        ).to(torch.float32)
        return log_spec

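    # WORLD extracts the spectral envelope, which pysptk converts to a
    # mel-cepstrum; the warping constant alpha is chosen for the sampling rate.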
    def calc_melcep(self, wav):
        wav = wav.numpy()
        _, sp, _ = pyworld.wav2world(
            wav.astype(np.float64),
            self.config["preprocess"]["sampling_rate"],
            fft_size=self.config["preprocess"]["fft_length"],
            frame_period=(
                self.config["preprocess"]["frame_shift"]
                / self.config["preprocess"]["sampling_rate"]
                * 1000
            ),
        )
        melcep = pysptk.sp2mc(
            sp,
            order=self.config["preprocess"]["cep_order"],
            alpha=pysptk.util.mcepalpha(self.config["preprocess"]["sampling_rate"]),
        ).transpose(1, 0)
        melcep = torch.from_numpy(melcep).to(torch.float32)
        return melcep

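    # Dispatch to the F0 extractor named in the config.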
    def calc_f0(self, wav):
        if self.config["preprocess"]["f0_extractor"] == "dio":
            return self.calc_f0_dio(wav)
        elif self.config["preprocess"]["f0_extractor"] == "harvest":
            return self.calc_f0_harvest(wav)
        elif self.config["preprocess"]["f0_extractor"] == "swipe":
            return self.calc_f0_swipe(wav)
        else:
            raise NotImplementedError()

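    # DIO gives a fast, coarse F0 track that StoneMask then refines.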
    def calc_f0_dio(self, wav):
        wav = wav.numpy()
        _f0, _t = pyworld.dio(
            wav.astype(np.float64),
            self.config["preprocess"]["sampling_rate"],
            frame_period=(
                self.config["preprocess"]["frame_shift"]
                / self.config["preprocess"]["sampling_rate"]
                * 1000
            ),
        )
        f0 = pyworld.stonemask(
            wav.astype(np.float64), _f0, _t, self.config["preprocess"]["sampling_rate"]
        )
        f0 = torch.from_numpy(f0).to(torch.float32)
        return f0

    def calc_f0_harvest(self, wav):
        wav = wav.numpy()
        _f0, _t = pyworld.harvest(
            wav.astype(np.float64),
            self.config["preprocess"]["sampling_rate"],
            frame_period=(
                self.config["preprocess"]["frame_shift"]
                / self.config["preprocess"]["sampling_rate"]
                * 1000
            ),
        )
        f0 = pyworld.stonemask(
            wav.astype(np.float64), _f0, _t, self.config["preprocess"]["sampling_rate"]
        )
        f0 = torch.from_numpy(f0).to(torch.float32)
        return f0

    def calc_f0_swipe(self, wav):
        wav = wav.numpy()
        f0 = pysptk.sptk.swipe(
            wav.astype(np.float64),
            fs=self.config["preprocess"]["sampling_rate"],
            min=71,
            max=800,
            hopsize=self.config["preprocess"]["frame_shift"],
            otype="f0",
        )
        f0 = torch.from_numpy(f0).to(torch.float32)
        return f0

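    # Simulates low-quality audio: peak-normalize, mu-law quantize at a random
    # bit depth, then pass through a random lower sampling rate and back.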
    def augmentation(self, wav):
        # Out-of-place division: the caller's tensor (and any numpy array it
        # shares memory with) must not be modified.
        wav = wav / torch.max(torch.abs(wav))
        new_freq = random.choice(self.resample_candidate)
        new_quantization = random.choice(self.quantization_candidate)
        mulaw_encoder = torchaudio.transforms.MuLawEncoding(
            quantization_channels=new_quantization
        )
        wav_quantized = mulaw_encoder(wav) / new_quantization * 2.0 - 1.0
        # Note: newer torchaudio versions rename "sinc_interpolation" to
        # "sinc_interp_hann".
        downsampler = torchaudio.transforms.Resample(
            orig_freq=self.config["preprocess"]["sampling_rate"],
            new_freq=new_freq,
            resampling_method="sinc_interpolation",
            lowpass_filter_width=6,
            dtype=torch.float32,
        )
        upsampler = torchaudio.transforms.Resample(
            orig_freq=new_freq,
            new_freq=self.config["preprocess"]["sampling_rate"],
            resampling_method="sinc_interpolation",
            lowpass_filter_width=6,
            dtype=torch.float32,
        )
        wav_processed = upsampler(downsampler(wav_quantized))
        return wav_processed

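    # Peak-normalizes to the given dB level via the sox "norm" effect.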
    def normalize_waveform(self, wav, db=-3):
        wav, _ = torchaudio.sox_effects.apply_effects_tensor(
            wav.unsqueeze(0),
            self.config["preprocess"]["sampling_rate"],
            [["norm", "{}".format(db)]],
        )
        return wav.squeeze(0)

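    # Randomly crops a fixed-length segment; shorter clips are zero-padded.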
    def get_segment(self, wav, segment_length):
        seg_size = self.config["preprocess"]["sampling_rate"] * segment_length
        if len(wav) >= seg_size:
            max_wav_start = len(wav) - seg_size
            wav_start = random.randint(0, max_wav_start)
            wav = wav[wav_start : wav_start + seg_size]
        else:
            wav = torch.nn.functional.pad(wav, (0, seg_size - len(wav)), "constant")
        return wav