"""Clean a TTS metadata list, split it into train/val lists and write a config."""

import json
import os
from collections import defaultdict
from random import shuffle
from typing import Optional

import click
import torch
from tqdm import tqdm

from text.cleaner import clean_text_bert
from text.symbols import symbols, num_languages, num_tones


def _clean_metadata(metadata: str, cleaned_path: str) -> None:
    """Run text cleaning on every metadata line and write the cleaned list.

    Each input line is expected to be ``utt|spk|language|text``.  For every
    line that processes successfully, a 7-field line
    ``utt|spk|language|norm_text|phones|tones|word2ph`` is appended to
    *cleaned_path* and the BERT features are saved next to the audio file
    (``.wav`` replaced by ``.bert.pt``).  Lines that fail are reported on
    stdout and skipped (best-effort processing).
    """
    # Read the input completely before opening the output so the source
    # handle is closed promptly.
    with open(metadata, encoding="utf-8") as src:
        raw_lines = src.readlines()

    # Phones produced by the cleaner that are missing from the symbol table.
    new_symbols = []
    with open(cleaned_path, "w", encoding="utf-8") as out_file:
        for line in tqdm(raw_lines):
            try:
                utt, spk, language, text = line.strip().split("|")

                bert_path = utt.replace(".wav", ".bert.pt")
                if bert_path == utt:
                    # Without a ".wav" component the feature file would be
                    # written over the audio file itself.
                    raise ValueError(
                        f"utterance path {utt!r} does not contain '.wav'; "
                        "refusing to overwrite it with BERT features"
                    )

                # NOTE(review): the device is hard-coded to the first CUDA
                # GPU; confirm whether clean_text_bert accepts other devices
                # before making this configurable.
                norm_text, phones, tones, word2ph, bert = clean_text_bert(
                    text, language, device='cuda:0'
                )

                for ph in phones:
                    if ph not in symbols and ph not in new_symbols:
                        new_symbols.append(ph)
                        print('update!, now symbols:')
                        print(new_symbols)
                        # Explicit UTF-8: phones may be non-ASCII and the
                        # platform default encoding may not represent them.
                        with open(
                            f'{language}_symbol.txt', 'w', encoding="utf-8"
                        ) as f:
                            f.write(f'{new_symbols}')

                # Explicit checks instead of ``assert`` so they survive
                # ``python -O`` and produce a readable message.
                if len(phones) != len(tones):
                    raise ValueError(
                        f"phones/tones length mismatch: "
                        f"{len(phones)} != {len(tones)}"
                    )
                if len(phones) != sum(word2ph):
                    raise ValueError(
                        f"phones/word2ph mismatch: "
                        f"{len(phones)} != {sum(word2ph)}"
                    )

                out_file.write(
                    "{}|{}|{}|{}|{}|{}|{}\n".format(
                        utt,
                        spk,
                        language,
                        norm_text,
                        " ".join(phones),
                        " ".join([str(i) for i in tones]),
                        " ".join([str(i) for i in word2ph]),
                    )
                )

                # ``os.makedirs("")`` raises, so only create the directory
                # when the utterance path actually has one.
                bert_dir = os.path.dirname(bert_path)
                if bert_dir:
                    os.makedirs(bert_dir, exist_ok=True)
                torch.save(bert.cpu(), bert_path)
            except Exception as error:
                # Deliberate best-effort: report the bad line and continue.
                print("err!", line, error)


def _split_train_val(metadata: str, val_per_spk: int, max_val_total: int):
    """Split a cleaned metadata list into train and validation lines.

    Returns ``(train_list, val_list, spk_id_map)`` where the lists hold
    newline-terminated lines and ``spk_id_map`` maps each speaker name to an
    integer id assigned in order of first appearance.

    Raises:
        ValueError: if a non-blank line does not have exactly 7
            ``|``-separated fields.
    """
    spk_utt_map = defaultdict(list)
    spk_id_map = {}
    with open(metadata, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            stripped = line.strip()
            if not stripped:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            fields = stripped.split("|")
            if len(fields) != 7:
                raise ValueError(
                    f"{metadata}:{line_no}: expected 7 '|'-separated fields, "
                    f"got {len(fields)}"
                )
            spk = fields[1]
            # Lines are shuffled and concatenated later, so every entry must
            # carry its own newline even if the file's last line lacks one.
            spk_utt_map[spk].append(
                line if line.endswith("\n") else line + "\n"
            )
            if spk not in spk_id_map:
                spk_id_map[spk] = len(spk_id_map)

    train_list = []
    val_list = []
    for utts in spk_utt_map.values():
        shuffle(utts)
        val_list += utts[:val_per_spk]
        train_list += utts[val_per_spk:]

    # Cap the validation set; the overflow goes back into training.
    if len(val_list) > max_val_total:
        train_list += val_list[max_val_total:]
        val_list = val_list[:max_val_total]

    return train_list, val_list, spk_id_map


@click.command()
@click.option(
    "--metadata",
    default="data/example/metadata.list",
    type=click.Path(exists=True, file_okay=True, dir_okay=False),
)
@click.option("--cleaned-path", default=None)
@click.option("--train-path", default=None)
@click.option("--val-path", default=None)
@click.option(
    "--config_path",
    default="configs/config.json",
    type=click.Path(exists=True, file_okay=True, dir_okay=False),
)
@click.option("--val-per-spk", default=4)
@click.option("--max-val-total", default=8)
@click.option("--clean/--no-clean", default=True)
def main(
    metadata: str,
    cleaned_path: Optional[str],
    train_path: Optional[str],
    val_path: Optional[str],
    config_path: str,
    val_per_spk: int,
    max_val_total: int,
    clean: bool,
):
    """Clean a metadata list, split it into train/val and write a config.
    \f

    Args:
        metadata: Path to the input list (``utt|spk|language|text`` lines
            when cleaning; already-cleaned 7-field lines with --no-clean).
        cleaned_path: Output path for the cleaned list; defaults to
            ``metadata + ".cleaned"``.
        train_path: Output path for the training list; defaults to
            ``train.list`` next to *metadata*.
        val_path: Output path for the validation list; defaults to
            ``val.list`` next to *metadata*.
        config_path: Template JSON config to read; the updated copy is
            written to ``config.json`` next to *metadata*.
        val_per_spk: Number of utterances per speaker put into validation.
        max_val_total: Upper bound on the total validation set size.
        clean: Whether to run the text-cleaning pass first.
    """
    # Resolve the output locations before ``metadata`` is re-pointed at the
    # cleaned file below.
    data_dir = os.path.dirname(metadata)
    if train_path is None:
        train_path = os.path.join(data_dir, 'train.list')
    if val_path is None:
        val_path = os.path.join(data_dir, 'val.list')
    out_config_path = os.path.join(data_dir, 'config.json')

    if cleaned_path is None:
        cleaned_path = metadata + ".cleaned"

    if clean:
        _clean_metadata(metadata, cleaned_path)
        metadata = cleaned_path

    train_list, val_list, spk_id_map = _split_train_val(
        metadata, val_per_spk, max_val_total
    )

    with open(train_path, "w", encoding="utf-8") as f:
        f.writelines(train_list)

    with open(val_path, "w", encoding="utf-8") as f:
        f.writelines(val_list)

    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    config["data"]["spk2id"] = spk_id_map
    config["data"]["training_files"] = train_path
    config["data"]["validation_files"] = val_path
    config["data"]["n_speakers"] = len(spk_id_map)
    config["num_languages"] = num_languages
    config["num_tones"] = num_tones
    config["symbols"] = symbols

    with open(out_config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)


if __name__ == "__main__":
    main()