File size: 6,256 Bytes
6127b48 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import argparse
import importlib
import os
from argparse import RawTextHelpFormatter
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from TTS.config import load_config
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.models import setup_model
from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_checkpoint
def str2bool(value):
    """Convert a command-line token such as "True"/"false"/"1"/"no" to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, and every
    non-empty string (including "False") is truthy. This converter makes
    ``--use_cuda False`` actually disable CUDA.

    Args:
        value: The raw argument string, or an existing bool.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, as it was under plain ``bool("")``.
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


def get_attention_mask_path(wav_path):
    """Return the attention-mask path for ``wav_path``.

    The mask sits next to the wav file and is named ``<wav stem>_attn.npy``.
    Only the final path component is rewritten, so directories whose names
    contain the wav file name are left untouched.

    Args:
        wav_path: Path of the input audio file.

    Returns:
        Path of the ``.npy`` file the attention mask is saved to.
    """
    directory, wav_file_name = os.path.split(wav_path)
    align_file_name = os.path.splitext(wav_file_name)[0] + "_attn.npy"
    return os.path.join(directory, align_file_name)


def build_arg_parser():
    """Build the command-line parser for the attention-mask extraction script.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    description = (
        "Extract attention masks from trained Tacotron/Tacotron2 models.\n"
        "These masks can be used for different purposes including training a TTS model "
        "with a Duration Predictor.\n\n"
        "Each attention mask is written to the same directory as the input wav file "
        'with the "_attn.npy" suffix.\n'
        "(e.g. path/bla.wav (wav file) --> path/bla_attn.npy (attention mask))\n"
        "\n"
        "Example run:\n"
        '    CUDA_VISIBLE_DEVICES="0" python TTS/bin/compute_attention_masks.py\n'
        "        --model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth\n"
        "        --config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json\n"
        "        --dataset_metafile metadata.csv\n"
        "        --data_path /root/LJSpeech-1.1/\n"
        "        --batch_size 32\n"
        "        --dataset ljspeech\n"
        "        --use_cuda True\n"
    )
    # pylint: disable=bad-option-value
    parser = argparse.ArgumentParser(description=description, formatter_class=RawTextHelpFormatter)
    parser.add_argument("--model_path", type=str, required=True, help="Path to Tacotron/Tacotron2 model file ")
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Path to Tacotron/Tacotron2 config file.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Target dataset processor name from TTS.tts.datasets.formatters.",
    )
    parser.add_argument(
        "--dataset_metafile",
        type=str,
        required=True,
        help="Dataset metafile including file paths with transcripts.",
    )
    parser.add_argument("--data_path", type=str, default="", help="Defines the data path. It overwrites config.json.")
    parser.add_argument("--use_cuda", type=str2bool, default=False, help="enable/disable cuda.")
    parser.add_argument(
        "--batch_size", default=16, type=int, help="Batch size for the model. Use batch_size=1 if you have no CUDA."
    )
    return parser


if __name__ == "__main__":
    parser = build_arg_parser()
    args = parser.parse_args()

    C = load_config(args.config_path)
    ap = AudioProcessor(**C.audio)

    # load the model
    # TODO: handle multi-speaker
    model = setup_model(C)
    # NOTE(review): the trailing positional ``True`` is presumably the eval-mode
    # flag of load_checkpoint -- confirm against its signature.
    model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True)
    # Reduction factor: used both by the dataset and to upsample the alignments.
    r = model.decoder.r

    # data loader
    formatters = importlib.import_module("TTS.tts.datasets.formatters")
    preprocessor = getattr(formatters, args.dataset, None)
    if preprocessor is None:
        # Report an unknown formatter as a usage error instead of an AttributeError.
        parser.error(f"Unknown dataset formatter '{args.dataset}' in TTS.tts.datasets.formatters.")
    meta_data = preprocessor(args.data_path, args.dataset_metafile)
    dataset = TTSDataset(
        r,
        C.text_cleaner,
        compute_linear_spec=False,
        ap=ap,
        meta_data=meta_data,
        characters=C.characters if "characters" in C.keys() else None,
        add_blank=C["add_blank"] if "add_blank" in C.keys() else False,
        use_phonemes=C.use_phonemes,
        phoneme_cache_path=C.phoneme_cache_path,
        phoneme_language=C.phoneme_language,
        enable_eos_bos=C.enable_eos_bos_chars,
    )
    dataset.sort_and_filter_items(C.get("sort_by_audio_len", default=False))
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=4,
        collate_fn=dataset.collate_fn,
        shuffle=False,
        drop_last=False,
    )

    # compute attentions
    file_paths = []
    with torch.no_grad():
        for data in tqdm(loader):
            # setup input data (the collated batch is indexed positionally)
            text_input = data[0]
            text_lengths = data[1]
            mel_input = data[4]
            mel_lengths = data[5]
            item_idxs = data[7]

            # dispatch data to GPU
            if args.use_cuda:
                text_input = text_input.cuda()
                text_lengths = text_lengths.cuda()
                mel_input = mel_input.cuda()
                mel_lengths = mel_lengths.cuda()

            model_outputs = model.forward(text_input, text_lengths, mel_input)
            alignments = model_outputs["alignments"].detach()

            for idx, alignment in enumerate(alignments):
                item_idx = item_idxs[idx]
                # interpolate if r > 1: repeat each decoder step r times so that
                # the first axis can be cropped with the mel length below
                alignment = (
                    torch.nn.functional.interpolate(
                        alignment.transpose(0, 1).unsqueeze(0),
                        size=None,
                        scale_factor=r,
                        mode="nearest",
                        align_corners=None,
                        recompute_scale_factor=None,
                    )
                    .squeeze(0)
                    .transpose(0, 1)
                )
                # remove paddings
                alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy()
                # set file paths
                file_path = get_attention_mask_path(item_idx)
                # save output
                wav_file_abs_path = os.path.abspath(item_idx)
                file_abs_path = os.path.abspath(file_path)
                file_paths.append([wav_file_abs_path, file_abs_path])
                np.save(file_path, alignment)

    # output metafile: one "<wav path>|<attention mask path>" line per item
    metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")
    with open(metafile, "w", encoding="utf-8") as f:
        f.writelines(f"{wav_path}|{mask_path}\n" for wav_path, mask_path in file_paths)
    print(f" >> Metafile created: {metafile}")
|