File size: 29,765 Bytes
6127b48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
import collections
import os
import random
from typing import Dict, List, Union

import numpy as np
import torch
import tqdm
from torch.utils.data import Dataset

from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.utils.audio import AudioProcessor

# Module-level side effect: this runs once on import and switches torch.multiprocessing to the
# "file_system" sharing strategy, to prevent the "too many open files" error raised by DataLoader
# workers, as suggested here:
# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
torch.multiprocessing.set_sharing_strategy("file_system")


def _parse_sample(item):
    language_name = None
    attn_file = None
    if len(item) == 5:
        text, wav_file, speaker_name, language_name, attn_file = item
    elif len(item) == 4:
        text, wav_file, speaker_name, language_name = item
    elif len(item) == 3:
        text, wav_file, speaker_name = item
    else:
        raise ValueError(" [!] Dataset cannot parse the sample.")
    return text, wav_file, speaker_name, language_name, attn_file


def noise_augment_audio(wav):
    """Return `wav` plus uniform noise drawn from [0, 1/32768), i.e. below one 16-bit quantisation step.

    The input array is not modified; a new array is returned.
    """
    noise = np.random.rand(*wav.shape)
    return wav + noise * (1.0 / 32768.0)


class TTSDataset(Dataset):
    def __init__(
        self,
        outputs_per_step: int = 1,
        compute_linear_spec: bool = False,
        ap: AudioProcessor = None,
        samples: List[Dict] = None,
        tokenizer: "TTSTokenizer" = None,
        compute_f0: bool = False,
        f0_cache_path: str = None,
        return_wav: bool = False,
        batch_group_size: int = 0,
        min_text_len: int = 0,
        max_text_len: int = float("inf"),
        min_audio_len: int = 0,
        max_audio_len: int = float("inf"),
        phoneme_cache_path: str = None,
        precompute_num_workers: int = 0,
        speaker_id_mapping: Dict = None,
        d_vector_mapping: Dict = None,
        language_id_mapping: Dict = None,
        use_noise_augment: bool = False,
        start_by_longest: bool = False,
        verbose: bool = False,
    ):
        """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.

        If you need something different, you can subclass and override.

        Args:
            outputs_per_step (int): Number of time frames predicted per step.

            compute_linear_spec (bool): compute linear spectrogram if True.

            ap (TTS.tts.utils.AudioProcessor): Audio processor object.

            samples (list): List of dataset samples. Each sample is a dict; the keys read by this class are
                `text`, `audio_file`, `speaker_name`, `language` and, optionally, `alignment_file`.

            tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. A tokenizer must be given:
                `__init__` reads `tokenizer.use_phonemes`, so passing None fails. Defaults to None.

            compute_f0 (bool): compute f0 if True. Defaults to False.

            f0_cache_path (str): Path to store f0 cache. Defaults to None.

            return_wav (bool): Return the waveform of the sample. Defaults to False.

            batch_group_size (int): Range of batch randomization after sorting
                sequences by length. It shuffles each batch with bucketing to gather similar length sequences in a
                batch. Set 0 to disable. Defaults to 0.

            min_text_len (int): Minimum length of input text to be used. All shorter samples will be ignored.
                Defaults to 0.

            max_text_len (int): Maximum length of input text to be used. All longer samples will be ignored.
                Defaults to float("inf").

            min_audio_len (int): Minimum length of input audio to be used. All shorter samples will be ignored.
                Defaults to 0.

            max_audio_len (int): Maximum length of input audio to be used. All longer samples will be ignored.
                The maximum length in the dataset defines the VRAM used in the training. Hence, pay attention to
                this value if you encounter an OOM error in training. Defaults to float("inf").

            phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a
                separate file. Defaults to None.

            precompute_num_workers (int): Number of workers to precompute features. Defaults to 0.

            speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the
                embedding layer. Defaults to None.

            d_vector_mapping (dict): Mapping of wav files to computed d-vectors. Defaults to None.

            language_id_mapping (dict): Mapping of language names to IDs, used by `collate_fn` to build
                `language_ids`. Defaults to None.

            use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.

            start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False.

            verbose (bool): Print diagnostic information. Defaults to false.
        """
        super().__init__()
        self.batch_group_size = batch_group_size
        self._samples = samples
        self.outputs_per_step = outputs_per_step
        self.compute_linear_spec = compute_linear_spec
        self.return_wav = return_wav
        self.compute_f0 = compute_f0
        self.f0_cache_path = f0_cache_path
        self.min_audio_len = min_audio_len
        self.max_audio_len = max_audio_len
        self.min_text_len = min_text_len
        self.max_text_len = max_text_len
        self.ap = ap
        self.phoneme_cache_path = phoneme_cache_path
        self.speaker_id_mapping = speaker_id_mapping
        self.d_vector_mapping = d_vector_mapping
        self.language_id_mapping = language_id_mapping
        self.use_noise_augment = use_noise_augment
        self.start_by_longest = start_by_longest

        self.verbose = verbose
        # index of the next sample to substitute for one rejected in `load_data`
        self.rescue_item_idx = 1
        self.pitch_computed = False
        self.tokenizer = tokenizer

        if self.tokenizer.use_phonemes:
            self.phoneme_dataset = PhonemeDataset(
                self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers
            )

        if compute_f0:
            self.f0_dataset = F0Dataset(
                self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
            )

        if self.verbose:
            self.print_logs()

    @property
    def lengths(self):
        """Approximate audio length (in samples) of every sample, estimated from its file size.

        Assumes 16-bit audio (2 bytes per sample); container headers are not subtracted.
        Both dict samples (`item["audio_file"]`) and legacy list samples are supported.
        """
        lens = []
        for item in self.samples:
            if isinstance(item, collections.abc.Mapping):
                # dict samples: unpacking a dict would yield its keys, not its values
                wav_file = item["audio_file"]
            else:
                # legacy list/tuple samples: [text, wav_file, speaker_name, ...]
                _, wav_file, *_ = _parse_sample(item)
            audio_len = os.path.getsize(wav_file) / 16 * 8  # assuming 16bit audio
            lens.append(audio_len)
        return lens

    @property
    def samples(self):
        return self._samples

    @samples.setter
    def samples(self, new_samples):
        """Replace the samples and keep the phoneme/F0 sub-datasets (if any) in sync."""
        self._samples = new_samples
        if hasattr(self, "f0_dataset"):
            self.f0_dataset.samples = new_samples
        if hasattr(self, "phoneme_dataset"):
            self.phoneme_dataset.samples = new_samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.load_data(idx)

    def print_logs(self, level: int = 0) -> None:
        indent = "\t" * level
        print("\n")
        print(f"{indent}> DataLoader initialization")
        print(f"{indent}| > Tokenizer:")
        self.tokenizer.print_logs(level + 1)
        print(f"{indent}| > Number of instances : {len(self.samples)}")

    def load_wav(self, filename):
        """Load a waveform through the audio processor and assert that it is not empty."""
        waveform = self.ap.load_wav(filename)
        assert waveform.size > 0
        return waveform

    def get_phonemes(self, idx, text):
        """Fetch the phoneme entry for `idx`, checking it belongs to `text` and is non-empty."""
        out_dict = self.phoneme_dataset[idx]
        assert text == out_dict["text"], f"{text} != {out_dict['text']}"
        assert len(out_dict["token_ids"]) > 0
        return out_dict

    def get_f0(self, idx):
        """Fetch the F0 entry for `idx`, checking it was computed from the same audio file."""
        out_dict = self.f0_dataset[idx]
        item = self.samples[idx]
        assert item["audio_file"] == out_dict["audio_file"]
        return out_dict

    @staticmethod
    def get_attn_mask(attn_file):
        return np.load(attn_file)

    def get_token_ids(self, idx, text):
        """Return the token ids of `text` as an int32 array (phoneme ids when the tokenizer uses phonemes)."""
        if self.tokenizer.use_phonemes:
            token_ids = self.get_phonemes(idx, text)["token_ids"]
        else:
            token_ids = self.tokenizer.text_to_ids(text)
        return np.array(token_ids, dtype=np.int32)

    def load_data(self, idx):
        """Load one sample as a dict of raw text, token ids, waveform, pitch, attention and metadata.

        Samples whose token sequence is longer than `max_text_len` or whose waveform is shorter than
        `min_audio_len` are replaced by another sample chosen via `self.rescue_item_idx`.
        """
        item = self.samples[idx]

        raw_text = item["text"]

        wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32)

        # apply noise for augmentation
        if self.use_noise_augment:
            wav = noise_augment_audio(wav)

        # get token ids
        token_ids = self.get_token_ids(idx, item["text"])

        # get pre-computed attention maps
        attn = None
        if "alignment_file" in item:
            attn = self.get_attn_mask(item["alignment_file"])

        # after phonemization the text length may change
        # this is a shameful 🤭 hack to prevent longer phonemes
        # TODO: find a better fix
        if len(token_ids) > self.max_text_len or len(wav) < self.min_audio_len:
            # wrap around so the rescue index never runs past the end of `self.samples`
            # NOTE(review): if no sample passes the check this still recurses until RecursionError
            self.rescue_item_idx = (self.rescue_item_idx + 1) % len(self.samples)
            return self.load_data(self.rescue_item_idx)

        # get f0 values
        f0 = None
        if self.compute_f0:
            f0 = self.get_f0(idx)["f0"]

        sample = {
            "raw_text": raw_text,
            "token_ids": token_ids,
            "wav": wav,
            "pitch": f0,
            "attn": attn,
            "item_idx": item["audio_file"],
            "speaker_name": item["speaker_name"],
            "language_name": item["language"],
            "wav_file_name": os.path.basename(item["audio_file"]),
        }
        return sample

    @staticmethod
    def _compute_lengths(samples):
        """Annotate every sample (in place) with `audio_length` and `text_length` and return them as a list."""
        new_samples = []
        for item in samples:
            audio_length = os.path.getsize(item["audio_file"]) / 16 * 8  # assuming 16bit audio
            text_length = len(item["text"])
            item["audio_length"] = audio_length
            item["text_length"] = text_length
            new_samples += [item]
        return new_samples

    @staticmethod
    def filter_by_length(lengths: List[int], min_len: int, max_len: int):
        """Split indices (visited in ascending length order) into those outside/inside `[min_len, max_len]`."""
        idxs = np.argsort(lengths)  # ascending order
        ignore_idx = []
        keep_idx = []
        for idx in idxs:
            length = lengths[idx]
            if length < min_len or length > max_len:
                ignore_idx.append(idx)
            else:
                keep_idx.append(idx)
        return ignore_idx, keep_idx

    @staticmethod
    def sort_by_length(samples: List[List]):
        """Return the indices that sort `samples` by `audio_length` in ascending order."""
        audio_lengths = [s["audio_length"] for s in samples]
        idxs = np.argsort(audio_lengths)  # ascending order
        return idxs

    @staticmethod
    def create_buckets(samples, batch_group_size: int):
        """Shuffle each consecutive, full group of `batch_group_size` samples in place.

        A trailing partial group (fewer than `batch_group_size` samples) is left unshuffled.
        """
        assert batch_group_size > 0
        for i in range(len(samples) // batch_group_size):
            offset = i * batch_group_size
            end_offset = offset + batch_group_size
            temp_items = samples[offset:end_offset]
            random.shuffle(temp_items)
            samples[offset:end_offset] = temp_items
        return samples

    @staticmethod
    def _select_samples_by_idx(idxs, samples):
        return [samples[idx] for idx in idxs]

    def preprocess_samples(self):
        r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out of the
        length range.

        Raises:
            RuntimeError: if no sample survives the length filters.
        """
        samples = self._compute_lengths(self.samples)

        # sort items based on the sequence length in ascending order
        text_lengths = [i["text_length"] for i in samples]
        audio_lengths = [i["audio_length"] for i in samples]
        text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len)
        audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len)
        keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))
        ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))

        samples = self._select_samples_by_idx(keep_idx, samples)

        # check before sorting: `start_by_longest` below indexes the first/last entries
        if len(samples) == 0:
            raise RuntimeError(" [!] No samples left")

        sorted_idxs = self.sort_by_length(samples)

        if self.start_by_longest:
            # swap the longest sample to the front to surface OOM issues on the first batch
            longest_idxs = sorted_idxs[-1]
            sorted_idxs[-1] = sorted_idxs[0]
            sorted_idxs[0] = longest_idxs

        samples = self._select_samples_by_idx(sorted_idxs, samples)

        # shuffle batch groups
        # create batches with similar length items
        # the larger the `batch_group_size`, the higher the length variety in a batch.
        if self.batch_group_size > 0:
            samples = self.create_buckets(samples, self.batch_group_size)

        # update items to the new sorted items
        audio_lengths = [s["audio_length"] for s in samples]
        text_lengths = [s["text_length"] for s in samples]
        self.samples = samples

        if self.verbose:
            print(" | > Preprocessing samples")
            print(" | > Max text length: {}".format(np.max(text_lengths)))
            print(" | > Min text length: {}".format(np.min(text_lengths)))
            print(" | > Avg text length: {}".format(np.mean(text_lengths)))
            print(" | ")
            print(" | > Max audio length: {}".format(np.max(audio_lengths)))
            print(" | > Min audio length: {}".format(np.min(audio_lengths)))
            print(" | > Avg audio length: {}".format(np.mean(audio_lengths)))
            print(f" | > Num. instances discarded samples: {len(ignore_idx)}")
            print(" | > Batch group size: {}.".format(self.batch_group_size))

    @staticmethod
    def _sort_batch(batch, text_lengths):
        """Sort the batch by the input text length for RNN efficiency.

        Args:
            batch (Dict): Batch returned by `__getitem__`.
            text_lengths (List[int]): Lengths of the input character sequences.

        Returns:
            The reordered batch, the sorted lengths (LongTensor) and the sorting permutation.
        """
        text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True)
        batch = [batch[idx] for idx in ids_sorted_decreasing]
        return batch, text_lengths, ids_sorted_decreasing

    def collate_fn(self, batch):
        r"""
        Perform preprocessing and create a final data batch:
        1. Sort batch instances by text-length
        2. Convert Audio signal to features.
        3. PAD sequences wrt r.
        4. Load to Torch.

        Raises:
            TypeError: if the batch elements are not mappings.
        """

        # Puts each data field into a tensor with outer dimension batch size
        if isinstance(batch[0], collections.abc.Mapping):

            token_ids_lengths = np.array([len(d["token_ids"]) for d in batch])

            # sort items with text input length for RNN efficiency; every per-sample list below is
            # already in this sorted order, so the permutation itself is not needed again
            batch, token_ids_lengths, _ = self._sort_batch(batch, token_ids_lengths)

            # convert list of dicts to dict of lists
            batch = {k: [dic[k] for dic in batch] for k in batch[0]}

            # get language ids from language names
            if self.language_id_mapping is not None:
                language_ids = [self.language_id_mapping[ln] for ln in batch["language_name"]]
            else:
                language_ids = None
            # get pre-computed d-vectors
            if self.d_vector_mapping is not None:
                wav_files_names = list(batch["wav_file_name"])
                d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
            else:
                d_vectors = None

            # get numerical speaker ids from speaker names
            if self.speaker_id_mapping:
                speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]]
            else:
                speaker_ids = None
            # compute features
            mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]]

            mel_lengths = [m.shape[1] for m in mel]

            # lengths adjusted by the reduction factor (rounded up to a multiple of `outputs_per_step`)
            mel_lengths_adjusted = [
                m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step))
                if m.shape[1] % self.outputs_per_step
                else m.shape[1]
                for m in mel
            ]

            # compute 'stop token' targets
            stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths]

            # PAD stop targets
            stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)

            # PAD sequences with longest instance in the batch
            token_ids = prepare_data(batch["token_ids"]).astype(np.int32)

            # PAD features with longest instance
            mel = prepare_tensor(mel, self.outputs_per_step)

            # B x D x T --> B x T x D
            mel = mel.transpose(0, 2, 1)

            # convert things to pytorch
            token_ids_lengths = torch.LongTensor(token_ids_lengths)
            token_ids = torch.LongTensor(token_ids)
            mel = torch.FloatTensor(mel).contiguous()
            mel_lengths = torch.LongTensor(mel_lengths)
            stop_targets = torch.FloatTensor(stop_targets)

            # speaker vectors
            if d_vectors is not None:
                d_vectors = torch.FloatTensor(d_vectors)

            if speaker_ids is not None:
                speaker_ids = torch.LongTensor(speaker_ids)

            if language_ids is not None:
                language_ids = torch.LongTensor(language_ids)

            # compute linear spectrogram
            linear = None
            if self.compute_linear_spec:
                linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
                linear = prepare_tensor(linear, self.outputs_per_step)
                linear = linear.transpose(0, 2, 1)
                assert mel.shape[1] == linear.shape[1]
                linear = torch.FloatTensor(linear).contiguous()

            # format waveforms: edge-pad each wav, then cut it to its adjusted mel length in samples
            wav_padded = None
            if self.return_wav:
                max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
                wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len)
                for i, w in enumerate(batch["wav"]):
                    mel_length = mel_lengths_adjusted[i]
                    w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
                    w = w[: mel_length * self.ap.hop_length]
                    wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
                wav_padded.transpose_(1, 2)

            # format F0
            if self.compute_f0:
                pitch = prepare_data(batch["pitch"])
                assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
                pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 xT
            else:
                pitch = None

            # format attention masks
            attns = None
            if batch["attn"][0] is not None:
                # `batch` is already in sorted order, so the masks line up with `token_ids`/`mel`;
                # indexing them with the sort permutation again would mismatch masks and samples.
                attns = [attn.T for attn in batch["attn"]]
                for idx, attn in enumerate(attns):
                    pad2 = mel.shape[1] - attn.shape[1]
                    pad1 = token_ids.shape[1] - attn.shape[0]
                    assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}"
                    attn = np.pad(attn, [[0, pad1], [0, pad2]])
                    attns[idx] = attn
                attns = prepare_tensor(attns, self.outputs_per_step)
                attns = torch.FloatTensor(attns).unsqueeze(1)

            return {
                "token_id": token_ids,
                "token_id_lengths": token_ids_lengths,
                "speaker_names": batch["speaker_name"],
                "linear": linear,
                "mel": mel,
                "mel_lengths": mel_lengths,
                "stop_targets": stop_targets,
                "item_idxs": batch["item_idx"],
                "d_vectors": d_vectors,
                "speaker_ids": speaker_ids,
                "attns": attns,
                "waveform": wav_padded,
                "raw_text": batch["raw_text"],
                "pitch": pitch,
                "language_ids": language_ids,
            }

        raise TypeError(f"batch must contain tensors, numbers, dicts or lists; found {type(batch[0])}")


class PhonemeDataset(Dataset):
    """Phoneme Dataset for converting input text to phonemes and then token IDs

    At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data
    loading latency. If `cache_path` is already present, it skips the pre-computation and any missing entry is
    computed (and cached) lazily on first access.

    Args:
        samples (Union[List[List], List[Dict]]):
            List of samples. Each sample is a list or a dict.

        tokenizer (TTSTokenizer):
            Tokenizer to convert input text to phonemes.

        cache_path (str):
            Path to cache phonemes. If `cache_path` is already present or None, it skips the pre-computation.
            With None, nothing is cached and the token ids are computed on every access.

        precompute_num_workers (int):
            Number of workers used for pre-computing the phonemes. Defaults to 0.
    """

    def __init__(
        self,
        samples: Union[List[Dict], List[List]],
        tokenizer: "TTSTokenizer",
        cache_path: str,
        precompute_num_workers=0,
    ):
        self.samples = samples
        self.tokenizer = tokenizer
        self.cache_path = cache_path
        if cache_path is not None and not os.path.exists(cache_path):
            os.makedirs(cache_path)
            self.precompute(precompute_num_workers)

    def __getitem__(self, index):
        item = self.samples[index]
        ids = self.compute_or_load(item["audio_file"], item["text"])
        ph_hat = self.tokenizer.ids_to_text(ids)
        return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}

    def __len__(self):
        return len(self.samples)

    def compute_or_load(self, wav_file, text):
        """Compute token ids for the given text.

        If the phonemes are already cached, load them from cache. The cache file is named after the audio file
        (`<basename>_phoneme.npy`). When `cache_path` is None the ids are computed without caching.
        """
        if self.cache_path is None:
            return self.tokenizer.text_to_ids(text)
        file_name = os.path.splitext(os.path.basename(wav_file))[0]
        cache_file = os.path.join(self.cache_path, file_name + "_phoneme.npy")
        try:
            ids = np.load(cache_file)
        except FileNotFoundError:
            ids = self.tokenizer.text_to_ids(text)
            np.save(cache_file, ids)
        return ids

    def get_pad_id(self):
        """Get pad token ID for sequence padding"""
        return self.tokenizer.pad_id

    def precompute(self, num_workers=1):
        """Precompute phonemes for all samples.

        Iterating a pytorch DataLoader over `self` triggers `compute_or_load` for every sample, which writes the
        cache files, optionally in parallel worker processes.
        """
        print("[*] Pre-computing phonemes...")
        with tqdm.tqdm(total=len(self)) as pbar:
            batch_size = num_workers if num_workers > 0 else 1
            dataloader = torch.utils.data.DataLoader(
                batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
            )
            for batch in dataloader:
                # the last batch may hold fewer than `batch_size` samples
                pbar.update(len(batch["text"]))

    def collate_fn(self, batch):
        """Right-pad the token ids of a batch with the tokenizer's pad id into a LongTensor."""
        ids = [item["token_ids"] for item in batch]
        ids_lens = [item["token_ids_len"] for item in batch]
        texts = [item["text"] for item in batch]
        texts_hat = [item["ph_hat"] for item in batch]
        ids_lens_max = max(ids_lens)
        ids_torch = torch.LongTensor(len(ids), ids_lens_max).fill_(self.get_pad_id())
        for i, ids_len in enumerate(ids_lens):
            ids_torch[i, :ids_len] = torch.LongTensor(ids[i])
        return {"text": texts, "ph_hat": texts_hat, "token_ids": ids_torch}

    def print_logs(self, level: int = 0) -> None:
        indent = "\t" * level
        print("\n")
        print(f"{indent}> PhonemeDataset ")
        print(f"{indent}| > Tokenizer:")
        self.tokenizer.print_logs(level + 1)
        print(f"{indent}| > Number of instances : {len(self.samples)}")


class F0Dataset:
    """F0 Dataset for computing F0 from wav files in CPU

    Pre-compute F0 values for all the samples at initialization if `cache_path` is not None and not already
    present. It also computes the mean and std of F0 values if `normalize_f0` is True.

    Args:
        samples (Union[List[List], List[Dict]]):
            List of samples. Each sample is a list or a dict.

        ap (AudioProcessor):
            AudioProcessor to compute F0 from wav files.

        cache_path (str):
            Path to cache F0 values. If `cache_path` is already present or None, it skips the pre-computation.
            With None, F0 is computed on every access and no statistics are loaded, so `normalize_f0=True`
            requires `mean`/`std` to be set manually before indexing. Defaults to None.

        precompute_num_workers (int):
            Number of workers used for pre-computing the F0 values. Defaults to 0.

        normalize_f0 (bool):
            Whether to normalize F0 values by mean and std. Defaults to True.
    """

    def __init__(
        self,
        samples: Union[List[List], List[Dict]],
        ap: "AudioProcessor",
        verbose=False,
        cache_path: str = None,
        precompute_num_workers=0,
        normalize_f0=True,
    ):
        self.samples = samples
        self.ap = ap
        self.verbose = verbose
        self.cache_path = cache_path
        self.normalize_f0 = normalize_f0
        self.pad_id = 0.0
        self.mean = None
        self.std = None
        if cache_path is not None and not os.path.exists(cache_path):
            os.makedirs(cache_path)
            self.precompute(precompute_num_workers)
        # the statistics live in the cache directory; without one there is nothing to load
        if normalize_f0 and cache_path is not None:
            self.load_stats(cache_path)

    def __getitem__(self, idx):
        item = self.samples[idx]
        f0 = self.compute_or_load(item["audio_file"])
        if self.normalize_f0:
            assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
            f0 = self.normalize(f0)
        return {"audio_file": item["audio_file"], "f0": f0}

    def __len__(self):
        return len(self.samples)

    def precompute(self, num_workers=0):
        """Compute and cache F0 for all samples; with `normalize_f0`, also save `pitch_stats.npy`."""
        print("[*] Pre-computing F0s...")
        # we do not normalize at preprocessing; restore the flag even if iteration fails
        normalize_f0 = self.normalize_f0
        self.normalize_f0 = False
        computed_data = []
        try:
            with tqdm.tqdm(total=len(self)) as pbar:
                batch_size = num_workers if num_workers > 0 else 1
                dataloader = torch.utils.data.DataLoader(
                    batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
                )
                for batch in dataloader:
                    # drop the batch padding so the statistics only see real frames
                    for f0, f0_len in zip(batch["f0"], batch["f0_lens"]):
                        computed_data.append(f0[:f0_len].numpy())
                    # the last batch may hold fewer than `batch_size` samples
                    pbar.update(len(batch["f0_lens"]))
        finally:
            self.normalize_f0 = normalize_f0

        if self.normalize_f0:
            pitch_mean, pitch_std = self.compute_pitch_stats(computed_data)
            pitch_stats = {"mean": pitch_mean, "std": pitch_std}
            np.save(os.path.join(self.cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)

    def get_pad_id(self):
        return self.pad_id

    @staticmethod
    def create_pitch_file_path(wav_file, cache_path):
        """Return `<cache_path>/<wav basename without extension>_pitch.npy`."""
        file_name = os.path.splitext(os.path.basename(wav_file))[0]
        pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
        return pitch_file

    @staticmethod
    def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
        """Compute F0 of `wav_file` with `ap`; save it to `pitch_file` when one is given."""
        wav = ap.load_wav(wav_file)
        pitch = ap.compute_f0(wav)
        if pitch_file:
            np.save(pitch_file, pitch)
        return pitch

    @staticmethod
    def compute_pitch_stats(pitch_vecs):
        """Mean and std over all non-zero F0 values (zeros are excluded from the statistics)."""
        nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs])
        mean, std = np.mean(nonzeros), np.std(nonzeros)
        return mean, std

    def load_stats(self, cache_path):
        """Load `mean`/`std` from `<cache_path>/pitch_stats.npy` as float32."""
        stats_path = os.path.join(cache_path, "pitch_stats.npy")
        stats = np.load(stats_path, allow_pickle=True).item()
        self.mean = stats["mean"].astype(np.float32)
        self.std = stats["std"].astype(np.float32)

    def normalize(self, pitch):
        """Standardize `pitch` with `mean`/`std`, keeping zero-valued frames at zero."""
        zero_idxs = np.where(pitch == 0.0)[0]
        pitch = pitch - self.mean
        pitch = pitch / self.std
        pitch[zero_idxs] = 0.0
        return pitch

    def denormalize(self, pitch):
        """Invert `normalize`, keeping zero-valued frames at zero. The input is not modified."""
        zero_idxs = np.where(pitch == 0.0)[0]
        pitch = pitch * self.std
        pitch = pitch + self.mean
        pitch[zero_idxs] = 0.0
        return pitch

    def compute_or_load(self, wav_file):
        """
        compute pitch and return a numpy array of pitch values (float32), using the cache when available
        """
        if self.cache_path is None:
            pitch = self._compute_and_save_pitch(self.ap, wav_file)
        else:
            pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
            if not os.path.exists(pitch_file):
                pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
            else:
                pitch = np.load(pitch_file)
        return pitch.astype(np.float32)

    def collate_fn(self, batch):
        """Right-pad the F0 sequences of a batch with `pad_id` into a float32 tensor."""
        audio_file = [item["audio_file"] for item in batch]
        f0s = [item["f0"] for item in batch]
        f0_lens = [len(item["f0"]) for item in batch]
        f0_lens_max = max(f0_lens)
        # float tensor: F0 values are fractional and must not be truncated to integers
        f0s_torch = torch.full((len(f0s), f0_lens_max), self.get_pad_id(), dtype=torch.float32)
        for i, f0_len in enumerate(f0_lens):
            f0s_torch[i, :f0_len] = torch.as_tensor(f0s[i], dtype=torch.float32)
        return {"audio_file": audio_file, "f0": f0s_torch, "f0_lens": f0_lens}

    def print_logs(self, level: int = 0) -> None:
        indent = "\t" * level
        print("\n")
        print(f"{indent}> F0Dataset ")
        print(f"{indent}| > Number of instances : {len(self.samples)}")