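# Downloads and processes the HI-MIA far-field text-dependent speaker
# verification corpus (OpenSLR resource 85) and writes JSON-lines manifests.
#
# Usage (the script filename below is illustrative, not fixed by this code):
#   python get_hi_mia_data.py --data_root=<where to put the data>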
import argparse
import json
import logging as _logging
import os
import tarfile
import urllib.request
from glob import glob

import librosa
import soundfile as sf  # librosa.output.write_wav was removed in librosa 0.8; soundfile replaces it
from sklearn.model_selection import StratifiedShuffleSplit
from tqdm import tqdm

parser = argparse.ArgumentParser(description="HI-MIA Data download")
parser.add_argument("--data_root", required=True, type=str)
parser.add_argument("--log_level", default=20, type=int)  # 20 == logging.INFO
args = parser.parse_args()

logging = _logging.getLogger(__name__)
logging.addHandler(_logging.StreamHandler())
logging.setLevel(args.log_level)

URL = {
    "dev": "http://www.openslr.org/resources/85/dev.tar.gz",
    "test": "http://www.openslr.org/resources/85/test.tar.gz",
    "train": "http://www.openslr.org/resources/85/train.tar.gz",
}
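# Each key doubles as the local archive stem (<data_root>/<key>.tgz) and the
# folder the archive is extracted into (<data_root>/<key>); see main().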


def __retrieve_with_progress(source: str, filename: str):
    """
    Downloads source to filename, displaying a progress bar.

    Args:
        source: url of resource
        filename: local filepath to write to
    """
    with open(filename, "wb") as f:
        response = urllib.request.urlopen(source)
        total = response.length

        if total is None:
            # Unknown size: no progress bar, just write the whole body.
            f.write(response.read())
        else:
            with tqdm(total=total, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
                while True:
                    data = response.read(8192)
                    if not data:
                        break
                    f.write(data)
                    pbar.update(len(data))
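# A minimal usage sketch (the destination path is hypothetical):
#   __retrieve_with_progress(URL["dev"], "/tmp/dev.tgz")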


def __maybe_download_file(destination: str, source: str):
    """
    Downloads source to destination if it doesn't exist.
    If it exists, skips the download.

    Args:
        destination: local filepath
        source: key into URL pointing at the resource to fetch

    Returns:
        destination
    """
    source = URL[source]
    if not os.path.exists(destination) and not os.path.exists(os.path.splitext(destination)[0]):
        logging.info("%s does not exist. Downloading ...", destination)
        __retrieve_with_progress(source, filename=destination + ".tmp")
        os.rename(destination + ".tmp", destination)
        logging.info("Downloaded %s.", destination)
    elif os.path.exists(destination):
        logging.info("Destination %s exists. Skipping.", destination)
    elif os.path.exists(os.path.splitext(destination)[0]):
        logging.warning(
            "Assuming extracted folder %s contains the extracted files from %s. Will not download.",
            os.path.splitext(destination)[0],
            destination,
        )
    return destination


def __extract_all_files(filepath: str, data_root: str, data_dir: str):
    if not os.path.exists(data_dir):
        extract_file(filepath, data_root)
        # The archives contain nested tarballs under wav/; extract each of
        # those in place as well.
        audio_dir = os.path.join(data_dir, "wav")
        for subfolder, _, filelist in os.walk(audio_dir):
            for ftar in filelist:
                extract_file(os.path.join(subfolder, ftar), subfolder)
    else:
        logging.info("Skipping extracting. Data already there %s", data_dir)


def extract_file(filepath: str, data_dir: str):
    try:
        with tarfile.open(filepath, encoding="utf-8") as tar:
            tar.extractall(data_dir)
    except Exception:
        logging.info("Not extracting. Maybe already there?")


def __remove_tarred_files(filepath: str, data_dir: str):
    if os.path.exists(data_dir) and os.path.isfile(filepath):
        logging.info("Deleting %s", filepath)
        os.remove(filepath)


def write_file(name, lines, idx):
    """Writes lines[i] for each i in idx to name, one JSON object per line."""
    with open(name, "w") as fout:
        for i in idx:
            dic = lines[i]
            json.dump(dic, fout)
            fout.write("\n")
    logging.info("wrote %s", name)
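# Each manifest line is one flat JSON object; the values below are illustrative:
#   {"audio_filepath": "/data/train/wav/SPK001/utt.wav", "duration": 3.2, "label": "SPK001"}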


def __process_data(data_folder: str, data_set: str):
    """
    Generates a manifest (JSON lines) and an utt2spk mapping file.

    Args:
        data_folder: source folder with wav files
        data_set: one of "train", "dev" or "test"
    """
    fullpath = os.path.abspath(data_folder)
    filelist = glob(fullpath + "/**/*.wav", recursive=True)
    out = os.path.join(fullpath, data_set + "_all.json")

    if os.path.exists(out):
        logging.warning(
            "%s already exists and is assumed to be processed. If not, please delete %s and rerun this script",
            out,
            out,
        )
        return

    # Open utt2spk only after the existence check, so a rerun on already
    # processed data does not truncate an existing file.
    utt2spk = os.path.join(fullpath, "utt2spk")
    utt2spk_file = open(utt2spk, "w")
    # For train/dev, the speaker ID is the second-to-last path component.
    speaker_idx = -2

    speakers = []
    lines = []
    with open(out, "w") as outfile:
        for line in tqdm(filelist):
            line = line.strip()
            y, sr = librosa.load(line, sr=None)
            if sr != 16000:
                # Resample to 16 kHz and overwrite the file in place.
                y, sr = librosa.load(line, sr=16000)
                sf.write(line, y, sr)
            dur = librosa.get_duration(y=y, sr=sr)
            if data_set == "test":
                # Test filenames encode the speaker before the first underscore.
                speaker = line.split("/")[-1].split(".")[0].split("_")[0]
            else:
                speaker = line.split("/")[speaker_idx]
            speakers.append(speaker)
            meta = {"audio_filepath": line, "duration": float(dur), "label": speaker}
            lines.append(meta)
            json.dump(meta, outfile)
            outfile.write("\n")
            utt2spk_file.write(line.split("/")[-1] + "\t" + speaker + "\n")

    utt2spk_file.close()
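    # utt2spk maps each wav filename to its speaker label, one tab-separated
    # pair per line (the usual Kaldi convention).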

    if data_set != "test":
        # Hold out 10% of utterances as a dev split, stratified by speaker.
        sss = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=42)
        for train_idx, test_idx in sss.split(speakers, speakers):
            logging.info("Train split size: %d", len(train_idx))

        out = os.path.join(fullpath, "train.json")
        write_file(out, lines, train_idx)
        out = os.path.join(fullpath, "dev.json")
        write_file(out, lines, test_idx)
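    # StratifiedShuffleSplit stratifies on its label argument (speakers here),
    # so the 90/10 split keeps each speaker's share of utterances in both files.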


def main():
    data_root = args.data_root
    for data_set in URL.keys():
        logging.info("\n\nWorking on: %s", data_set)
        file_path = os.path.join(data_root, data_set + ".tgz")
        logging.info("Getting %s", data_set)
        __maybe_download_file(file_path, data_set)
        logging.info("Extracting %s", data_set)
        data_folder = os.path.join(data_root, data_set)
        __extract_all_files(file_path, data_root, data_folder)
        __remove_tarred_files(file_path, data_folder)
        logging.info("Processing %s", data_set)
        __process_data(data_folder, data_set)
    logging.info("Done!")


if __name__ == "__main__":
    main()