|
|
|
import pdb |
|
import sys |
|
|
|
|
|
import traceback, os |
|
from typing import Dict |
|
from typing import List |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import torch, json |
|
from torch.utils.data import DataLoader |
|
from torch.utils.data import Dataset |
|
from transformers import AutoTokenizer |
|
|
|
from text import cleaned_text_to_sequence |
|
|
|
|
|
|
|
|
|
def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):
    """Pad arrays to a common length along ``axis`` and stack them into one batch.

    Args:
        sequences: Non-empty list of arrays. The rank and dtype of the first
            array are used for every array in the list.
        axis: Axis along which the arrays differ in length; negative values
            count from the last axis.
        pad_value: Fill value for the padded tail, cast to the dtype of the
            first array.

    Returns:
        A new array with a leading batch axis, built with ``np.stack`` from the
        arrays after each has been right-padded to the longest length.
    """
    first = sequences[0]
    rank = first.ndim
    pad_axis = axis + rank if axis < 0 else axis
    fill = first.dtype.type(pad_value)

    lengths = [arr.shape[pad_axis] for arr in sequences]
    longest = np.max(lengths)

    def _pad_to_longest(arr, length):
        # Only the trailing side of ``pad_axis`` receives padding.
        widths = [(0, 0)] * rank
        widths[pad_axis] = (0, longest - length)
        return np.pad(arr, widths, mode="constant", constant_values=fill)

    return np.stack(
        [_pad_to_longest(arr, length) for arr, length in zip(sequences, lengths)]
    )
|
|
|
|
|
class Text2SemanticDataset(Dataset):
    """Dataset of (semantic ids, phoneme ids) pairs for text-to-semantic training.

    The constructor joins a tab-separated semantic file with a tab-separated
    phoneme file on the item name, drops samples that are too long or whose
    phoneme rate is out of range, and keeps the surviving pairs in memory.
    Optional per-item BERT features are loaded lazily in ``__getitem__`` from a
    ``3-bert`` directory located next to the phoneme file.
    """

    def __init__(
        self,
        phoneme_path: str,
        semantic_path: str,
        max_sample: int = None,
        max_sec: int = 100,
        pad_val: int = 1024,
        min_ps_ratio: int = 3,
        max_ps_ratio: int = 25,
    ) -> None:
        """Load both files and build the filtered in-memory sample list.

        Args:
            phoneme_path: Text file whose lines have four tab-separated fields;
                the first is the item name and the second the space-separated
                phoneme string. Lines with a different field count are skipped.
            semantic_path: Tab-separated file read with ``pandas.read_csv``;
                column 0 is the item name, column 1 a space-separated list of
                integer semantic ids.
            max_sample: If given, only the first ``max_sample`` semantic rows
                are considered.
            max_sec: Maximum sample duration in seconds.
            pad_val: Padding id used for semantic sequences in ``collate``.
            min_ps_ratio: Minimum accepted phonemes per second.
            max_ps_ratio: Maximum accepted phonemes per second.
        """
        super().__init__()

        self.semantic_data = pd.read_csv(
            semantic_path, delimiter="\t", encoding="utf-8"
        )

        self.path2 = phoneme_path
        # The BERT feature directory is a sibling of the phoneme file, so it is
        # derived from the file's parent directory (using the file name itself
        # would yield a path that can never exist).
        self.path3 = os.path.join(os.path.dirname(phoneme_path), "3-bert")
        self.path6 = semantic_path
        assert os.path.exists(self.path2)
        assert os.path.exists(self.path6)

        # item name -> [phoneme string, word2ph, text]
        self.phoneme_data = {}
        with open(self.path2, "r", encoding="utf8") as f:
            lines = f.read().strip("\n").split("\n")

        for line in lines:
            tmp = line.split("\t")
            if len(tmp) != 4:
                continue
            self.phoneme_data[tmp[0]] = [tmp[1], tmp[2], tmp[3]]

        self.PAD: int = pad_val

        # Semantic frame rate, taken from the "hz" environment variable
        # (e.g. "25hz" -> 25).
        self.hz = int(os.environ.get("hz", "25hz")[:-2])

        self.max_sec = max_sec
        self.min_ps_ratio = min_ps_ratio
        self.max_ps_ratio = max_ps_ratio

        if max_sample is not None:
            self.semantic_data = self.semantic_data[:max_sample]

        self.semantic_phoneme = []
        self.item_names = []

        self.inited = False
        self.init_batch()
        self.inited = True

        # The raw tables are only needed while building the sample list.
        del self.semantic_data
        del self.phoneme_data

    def init_batch(self):
        """Populate ``semantic_phoneme`` and ``item_names`` from the raw tables.

        A row is dropped when its item name has no phoneme entry, when the
        phoneme string cannot be converted to ids, when either sequence is too
        long for ``max_sec``, or when the phonemes-per-second ratio is outside
        ``[min_ps_ratio, max_ps_ratio]``. If fewer than 100 (but at least one)
        samples survive, the list is repeated so short datasets still yield a
        reasonable number of items per epoch.
        """
        semantic_data_len = len(self.semantic_data)
        phoneme_data_len = len(self.phoneme_data)
        print("semantic_data_len:", semantic_data_len)
        print("phoneme_data_len:", phoneme_data_len)
        print(self.semantic_data)

        num_not_in = 0
        num_deleted_bigger = 0
        num_deleted_ps = 0

        max_semantic_len = self.max_sec * self.hz
        max_phoneme_len = max_semantic_len / 2.5

        for i in range(semantic_data_len):
            item_name = self.semantic_data.iloc[i, 0]

            try:
                phoneme = self.phoneme_data[item_name][0]
            except KeyError:
                traceback.print_exc()
                num_not_in += 1
                continue

            semantic_str = self.semantic_data.iloc[i, 1]
            semantic_ids = [int(token) for token in semantic_str.split(" ")]

            # Drop samples longer than max_sec seconds of semantic frames.
            if len(semantic_ids) > max_semantic_len:
                num_deleted_bigger += 1
                continue

            phoneme = phoneme.split(" ")

            try:
                phoneme_ids = cleaned_text_to_sequence(phoneme)
            except Exception:
                # The phoneme string could not be mapped to ids; skip the row
                # but let KeyboardInterrupt/SystemExit propagate.
                traceback.print_exc()
                num_not_in += 1
                continue

            if len(phoneme_ids) > max_phoneme_len:
                num_deleted_ps += 1
                continue

            # Phonemes per second of audio.
            ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
            if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio:
                num_deleted_ps += 1
                continue

            self.semantic_phoneme.append((semantic_ids, phoneme_ids))
            self.item_names.append(item_name)

        min_num = 100
        leng = len(self.semantic_phoneme)
        # Oversample tiny datasets; an empty dataset is left empty instead of
        # dividing by zero.
        if 0 < leng < min_num:
            repeats = max(2, int(min_num / leng))
            self.semantic_phoneme = self.semantic_phoneme * repeats
            self.item_names = self.item_names * repeats

        if num_not_in > 0:
            print(f"there are {num_not_in} semantic datas not in phoneme datas")
        if num_deleted_bigger > 0:
            print(
                f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds"
            )
        if num_deleted_ps > 0:
            print(
                f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
            )
        # Example output recorded from a past run:
        #   there are 31 semantic datas not in phoneme datas
        #   deleted 34 audios who's duration are bigger than 54 seconds
        #   deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3
        #   dataset.__len__(): 366463

        print("dataset.__len__():", self.__len__())

    def __get_item_names__(self) -> List[str]:
        """Return the item names, aligned index-for-index with the samples."""
        return self.item_names

    def __len__(self) -> int:
        """Return the number of samples (including oversampled repeats)."""
        return len(self.semantic_phoneme)

    def __getitem__(self, idx: int) -> Dict:
        """Return one sample as a dict of id lists, lengths and BERT feature.

        ``bert_feature`` is the object loaded from ``<path3>/<item_name>.pt``
        when that file exists, otherwise ``None``.
        """
        semantic_ids, phoneme_ids = self.semantic_phoneme[idx]
        item_name = self.item_names[idx]

        path_bert = "%s/%s.pt" % (self.path3, item_name)
        if os.path.exists(path_bert):
            # NOTE: torch.load unpickles the file; only load trusted features.
            bert_feature = torch.load(path_bert, map_location="cpu")
            # The feature's last axis must line up with the phoneme sequence.
            assert bert_feature.shape[-1] == len(phoneme_ids)
        else:
            bert_feature = None

        return {
            "idx": idx,
            "phoneme_ids": phoneme_ids,
            "phoneme_ids_len": len(phoneme_ids),
            "semantic_ids": semantic_ids,
            "semantic_ids_len": len(semantic_ids),
            "bert_feature": bert_feature,
        }

    def get_sample_length(self, idx: int):
        """Return the duration of sample ``idx`` in seconds (frames / hz)."""
        semantic_ids = self.semantic_phoneme[idx][0]
        sec = 1.0 * len(semantic_ids) / self.hz
        return sec

    def collate(self, examples: List[Dict]) -> Dict:
        """Merge samples from ``__getitem__`` into one padded batch.

        Phoneme ids are padded with 0 and semantic ids with ``self.PAD``. BERT
        features are copied into a zero-initialised float32 tensor of shape
        ``(len(examples), 1024, max phoneme length)``; samples without a
        feature keep all-zero rows.

        Returns:
            Dict with ``ids`` (list of sample indices) and the tensors
            ``phoneme_ids``, ``phoneme_ids_len``, ``semantic_ids``,
            ``semantic_ids_len`` and ``bert_feature``.
        """
        sample_index: List[int] = []
        phoneme_ids: List[np.ndarray] = []
        phoneme_ids_lens: List[int] = []
        semantic_ids: List[np.ndarray] = []
        semantic_ids_lens: List[int] = []

        for item in examples:
            sample_index.append(item["idx"])
            phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
            semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64))
            phoneme_ids_lens.append(item["phoneme_ids_len"])
            semantic_ids_lens.append(item["semantic_ids_len"])

        phoneme_ids = batch_sequences(phoneme_ids)
        semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD)

        max_phoneme_len = int(max(phoneme_ids_lens))

        phoneme_ids = torch.tensor(phoneme_ids)
        semantic_ids = torch.tensor(semantic_ids)
        phoneme_ids_lens = torch.tensor(phoneme_ids_lens)
        semantic_ids_lens = torch.tensor(semantic_ids_lens)

        bert_padded = torch.zeros(
            len(examples), 1024, max_phoneme_len, dtype=torch.float32
        )
        for row, item in enumerate(examples):
            bert = item["bert_feature"]
            # Identity check: comparing a tensor to None with != is fragile.
            if bert is not None:
                bert_padded[row, :, : bert.shape[-1]] = bert

        return {
            "ids": sample_index,
            "phoneme_ids": phoneme_ids,
            "phoneme_ids_len": phoneme_ids_lens,
            "semantic_ids": semantic_ids,
            "semantic_ids_len": semantic_ids_lens,
            "bert_feature": bert_padded,
        }
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: build the dataset from a fixed dump directory and
    # iterate every batch once, printing progress every 1000 batches.
    root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/"
    phoneme_file = root_dir + "phoneme_train.npy"
    semantic_file = root_dir + "semantic_train.tsv"
    dataset = Text2SemanticDataset(
        phoneme_path=phoneme_file,
        semantic_path=semantic_file,
    )

    batch_size = 12
    dataloader = DataLoader(
        dataset,
        shuffle=False,
        collate_fn=dataset.collate,
        batch_size=batch_size,
    )
    for step, _ in enumerate(dataloader):
        if step % 1000:
            continue
        print(step)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|