# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import io
import json
import logging
import mmap
import pickle
import random
import struct
import tarfile
from pathlib import Path

import numpy as np
import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
import torchaudio
from torch.nn.utils.rnn import pad_sequence
torchaudio.set_audio_backend('soundfile')
AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
# Precomputed speaker embeddings; fall back to zero vectors when the
# checkpoint files are not available (e.g. outside the training container).
try:
    MAIN_SPK_EMBEDDING = torch.load("/workspace/audio_checkpoints/flow_model/spk_embedding/0909/mean_embedding.pt")
    GPT_SPK_EMBEDDING = torch.load("/workspace/audio_checkpoints/flow_model/spk_embedding/0909/spk_mean_embeddings.pt")
except Exception:
    MAIN_SPK_EMBEDDING = torch.zeros(1, 192)
    GPT_SPK_EMBEDDING = torch.zeros(1, 192)
def parquet_opener(data, mode='train', tts_data={}):
""" Give url or local file, return file descriptor
Inplace operation.
Args:
data(Iterable[str]): url or local file list
Returns:
Iterable[{src, stream}]
"""
for sample in data:
assert 'src' in sample
url = sample['src']
try:
df = pq.read_table(url).to_pandas()
for i in range(len(df)):
if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
continue
sample.update(dict(df.loc[i]))
if mode == 'train':
# NOTE do not return sample directly, must initialize a new dict
yield {**sample}
else:
for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
yield {**sample, 'tts_index': index, 'tts_text': text}
except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
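# The 16 fixed-width fields below follow the POSIX ustar tar header layout;
# they total 500 bytes, and each header block is padded to 512 bytes on disk.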
def parse_tar_header(header_bytes):
header = struct.unpack("!100s8s8s8s12s12s8s1s100s6s2s32s32s8s8s155s", header_bytes)
return TarHeader(*header)
TarHeader = collections.namedtuple(
"TarHeader",
[
"name",
"mode",
"uid",
"gid",
"size",
"mtime",
"chksum",
"typeflag",
"linkname",
"magic",
"version",
"uname",
"gname",
"devmajor",
"devminor",
"prefix",
],
)
class MMTar:
def __init__(self, file_path: Path | str):
self.stream = open(file_path, "rb")
self.mmap = mmap.mmap(self.stream.fileno(), 0, access=mmap.ACCESS_READ)
def __del__(self):
try:
self.mmap.close()
self.stream.close()
except: # noqa
pass
def get_at_offset(self, offset) -> tuple[str, bytes]:
header = parse_tar_header(self.mmap[offset : offset + 500])
name = header.name.decode("utf-8").strip("\x00")
start = offset + 512
end = start + int(header.size.decode("utf-8")[:-1], 8)
return name, self.mmap[start:end]
class Tar:
def __init__(self, path: Path):
self.tar = MMTar(path)
indices_path = path.with_suffix(".index")
self.index = pickle.loads(indices_path.read_bytes())
self.name_mapping = {}
for name, offset, _ in self.index:
self.name_mapping[name] = offset
def read(self, name: str) -> bytes:
return self.tar.get_at_offset(self.name_mapping[name])[1]
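# Minimal sketch (an assumption, not part of this file's original tooling) of
# how a ".index" file compatible with the Tar class could be built: pickle a
# list of (member name, header offset, size) triples next to the archive.
def build_tar_index(tar_path: Path) -> None:
    entries = []
    with tarfile.open(tar_path) as tf:
        for member in tf:
            if member.isfile():
                # TarInfo.offset is the offset of the member's 512-byte header,
                # which is what MMTar.get_at_offset expects to parse
                entries.append((member.name, member.offset, member.size))
    tar_path.with_suffix(".index").write_bytes(pickle.dumps(entries))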
def _cosy_jsonl_opener(data, token_suffix, mode='train', tts_data={}):
    """ Shared reader for cosy-token jsonl files: for each jsonl line, load the
        speech tokens and pull the matching audio out of the tar archive that
        sits next to the jsonl file
        Inplace operation.
        Args:
            data(Iterable[str]): sample dicts whose 'src' is a jsonl path
            token_suffix: jsonl suffix to strip when deriving the ".tar" path
        Returns:
            Iterable[{src, speech_token, speech, sample_rate}]
    """
    for sample in data:
        assert 'src' in sample
        cosy_jsonl_path = sample['src']
        tar_file_path = cosy_jsonl_path.replace(token_suffix, ".tar")
        try:
            tar_data = Tar(Path(tar_file_path))
            with open(cosy_jsonl_path, 'r') as f:
                for line in f:
                    item = json.loads(line)
                    sample['speech_token'] = torch.tensor(item['cosy_token'])
                    sample['speech'], sample['sample_rate'] = torchaudio.load(
                        io.BytesIO(tar_data.read(item['filename'])))
                    # NOTE do not return sample directly, must initialize a new dict
                    yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
def cosy_jsonl_opener(data, mode='train', tts_data={}):
    yield from _cosy_jsonl_opener(data, ".vq0907.jsonl", mode, tts_data)
def cosy_jsonl_opener_vq0918_nopool(data, mode='train', tts_data={}):
    yield from _cosy_jsonl_opener(data, ".vq0918-nopool.jsonl", mode, tts_data)
def cosy_jsonl_opener_vq0918_pool2(data, mode='train', tts_data={}):
    yield from _cosy_jsonl_opener(data, ".vq0918-pool2.jsonl", mode, tts_data)
def cosy_jsonl_opener_vq0918_pool4(data, mode='train', tts_data={}):
    yield from _cosy_jsonl_opener(data, ".vq0918-pool4.jsonl", mode, tts_data)
def cosy_jsonl_opener_vq0918_pool8(data, mode='train', tts_data={}):
    yield from _cosy_jsonl_opener(data, ".vq0918-pool8.jsonl", mode, tts_data)
def process_sft_vq0918_pool4(data, mode='train', tts_data={}):
    for sample in data:
        assert 'src' in sample
        token_npy_path = sample['src']
        wav_path = token_npy_path.replace(".vq0918-pool4.npy", "")
        try:
            sample['speech_token'] = torch.tensor(np.load(token_npy_path))
            sample['speech'], sample['sample_rate'] = torchaudio.load(wav_path)
            # downmix multi-channel audio to mono
            if sample['speech'].shape[0] > 1:
                sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
def process_sft_vq0918_pool4_split(data, mode='train', split_token=25, tts_data={}):
    for sample in data:
        assert 'src' in sample
        token_npy_path = sample['src']
        wav_path = token_npy_path.replace(".vq0918-pool4.npy", "")
        try:
            speech_token = torch.tensor(np.load(token_npy_path))
            speech, sample_rate = torchaudio.load(wav_path)
            if speech.shape[0] > 1:
                speech = speech.mean(dim=0, keepdim=True)
            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            sample['sample_rate'] = sample_rate
            # yield incrementally longer prefixes, split_token tokens at a time;
            # a prefix of N tokens covers ceil(N / 12.5 * sample_rate) samples
            num_splits = (speech_token.size(0) + split_token - 1) // split_token
            for split_id in range(num_splits):
                end_token_idx = min((split_id + 1) * split_token, speech_token.size(0))
                end_speech_idx = int(np.ceil(end_token_idx / 12.5 * sample_rate))
                sample['speech_token'] = speech_token[:end_token_idx]
                sample['speech'] = speech[:, :end_speech_idx]
                yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
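# Worked example of the split arithmetic above (pool4 tokens at 12.5 Hz is
# inferred from the `/ 12.5` constant): with split_token=25 and a 22050 Hz wav,
# the first chunk keeps 25 tokens and ceil(25 / 12.5 * 22050) = 44100 samples,
# i.e. exactly the first 2 seconds of audio.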
def process_sft_vq0918_pool2(data, mode='train', tts_data={}):
    for sample in data:
        assert 'src' in sample
        # src lists pool4 token paths; map them to the matching pool2 tokens
        token_npy_path = sample['src'].replace(".vq0918-pool4.npy", ".vq0918-pool2.npy")
        wav_path = token_npy_path.replace(".vq0918-pool2.npy", "")
        try:
            sample['speech_token'] = torch.tensor(np.load(token_npy_path))
            sample['speech'], sample['sample_rate'] = torchaudio.load(wav_path)
            if sample['speech'].shape[0] > 1:
                sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
def process_sft_vq0918_pool2_split(data, mode='train', split_token=50, tts_data={}):
    for sample in data:
        assert 'src' in sample
        token_npy_path = sample['src']
        wav_path = token_npy_path.replace(".vq0918-pool2.npy", "")
        try:
            speech_token = torch.tensor(np.load(token_npy_path))
            speech, sample_rate = torchaudio.load(wav_path)
            if speech.shape[0] > 1:
                speech = speech.mean(dim=0, keepdim=True)
            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            sample['sample_rate'] = sample_rate
            num_splits = (speech_token.size(0) + split_token - 1) // split_token
            for split_id in range(num_splits):
                end_token_idx = min((split_id + 1) * split_token, speech_token.size(0))
                end_speech_idx = int(np.ceil(end_token_idx / 25 * sample_rate))
                sample['speech_token'] = speech_token[:end_token_idx]
                sample['speech'] = speech[:, :end_speech_idx]
                yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
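# Same split arithmetic as the pool4 variant, but pool2 tokens run at 25 Hz
# (inferred from the `/ 25` constant), so N tokens cover N / 25 seconds.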
def process_sft_vq0918_pool4_gpt(data, mode='train', tts_data={}):
    for sample in data:
        assert 'src' in sample
        wav_path = None
        try:
            entry = json.loads(sample['src'])
            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            for conv in entry["conversations"]:
                if "response_wav" in conv:
                    wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}"
                    token_npy_path = wav_path.replace(".wav", ".wav.vq0918-pool4.npy")
                    sample['speech_token'] = torch.tensor(np.load(token_npy_path))
                    sample['speech'], sample['sample_rate'] = torchaudio.load(wav_path)
                    if sample['speech'].shape[0] > 1:
                        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
                    # assumption: attach the precomputed mean speaker embedding loaded above
                    sample['spk_embedding'] = MAIN_SPK_EMBEDDING
                    yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
def process_sft_vq0918_pool4_gpt_1010(data, mode='train', tts_data={}):
    for sample in data:
        assert 'src' in sample
        wav_path = None
        try:
            entry = json.loads(sample['src'])
            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            for conv in entry["conversations"]:
                # handle both response and prompt audio the same way
                for wav_key in ("response_wav", "prompt_wav"):
                    if wav_key not in conv:
                        continue
                    wav_path = f"/workspace/audio_data/sft/{conv[wav_key]}"
                    token_npy_path = wav_path.replace(".wav", ".wav.vq0918-pool4.npy")
                    sample['speech_token'] = torch.tensor(np.load(token_npy_path))
                    sample['speech'], sample['sample_rate'] = torchaudio.load(wav_path)
                    if sample['speech'].shape[0] > 1:
                        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
                    # assumption: attach the precomputed mean speaker embedding loaded above
                    sample['spk_embedding'] = MAIN_SPK_EMBEDDING
                    yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
def filter(data,
max_length=10240,
min_length=10,
token_max_length=200,
token_min_length=1,
min_output_input_ratio=0.0005,
max_output_input_ratio=1,
mode='train'):
""" Filter sample according to feature and label length
Inplace operation.
Args::
data: Iterable[{key, wav, label, sample_rate}]
max_length: drop utterance which is greater than max_length(10ms)
min_length: drop utterance which is less than min_length(10ms)
token_max_length: drop utterance which is greater than
token_max_length, especially when use char unit for
english modeling
token_min_length: drop utterance which is
less than token_max_length
min_output_input_ratio: minimal ration of
token_length / feats_length(10ms)
max_output_input_ratio: maximum ration of
token_length / feats_length(10ms)
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
        # sample['speech'] is a [channels, samples] torch.Tensor; frames are
        # counted at 100 per second (10 ms hop)
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
if num_frames < min_length:
continue
if num_frames > max_length:
continue
if len(sample['text_token']) < token_min_length:
continue
if len(sample['text_token']) > token_max_length:
continue
if len(sample['speech_token']) == 0:
continue
if num_frames != 0:
if len(sample['text_token']) / num_frames < min_output_input_ratio:
continue
if len(sample['text_token']) / num_frames > max_output_input_ratio:
continue
yield sample
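# Example of the frame arithmetic in filter(): frames are counted at 100 per
# second, so a 0.5 s clip yields 0.5 * 100 = 50 frames, which clears the
# default min_length=10 (0.1 s) and stays in the stream.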
def filter_speech_token(data,
max_length=10240,
min_length=10,
token_max_length=5000,
token_min_length=1,
min_output_input_ratio=0.0005,
max_output_input_ratio=30,
mode='train'):
""" Filter sample according to feature and label length
Inplace operation.
Args::
data: Iterable[{key, wav, label, sample_rate}]
max_length: drop utterance which is greater than max_length(10ms)
min_length: drop utterance which is less than min_length(10ms)
token_max_length: drop utterance which is greater than
token_max_length, especially when use char unit for
english modeling
token_min_length: drop utterance which is
less than token_max_length
min_output_input_ratio: minimal ration of
token_length / feats_length(10ms)
max_output_input_ratio: maximum ration of
token_length / feats_length(10ms)
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
        # sample['speech'] is a [channels, samples] torch.Tensor; frames are
        # counted at 100 per second (10 ms hop)
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
if num_frames < min_length:
continue
if num_frames > max_length:
continue
if len(sample['speech_token']) < token_min_length:
continue
if len(sample['speech_token']) > token_max_length:
continue
if len(sample['speech_token']) == 0:
continue
if num_frames != 0:
if len(sample['speech_token']) / num_frames < min_output_input_ratio:
continue
if len(sample['speech_token']) / num_frames > max_output_input_ratio:
continue
yield sample
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
""" Resample data.
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
resample_rate: target resample rate
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
sample_rate = sample['sample_rate']
waveform = sample['speech']
if sample_rate != resample_rate:
if sample_rate < min_sample_rate:
continue
sample['sample_rate'] = resample_rate
sample['speech'] = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
max_val = sample['speech'].abs().max()
if max_val > 1:
sample['speech'] /= max_val
yield sample
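# Note on the normalization step above: resampling can overshoot the [-1, 1]
# range slightly, so the waveform is peak-normalized only when its max exceeds 1.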
def compute_fbank(data,
feat_extractor,
mode='train'):
""" Extract fbank
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
waveform = sample['speech']
mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
sample['speech_feat'] = mat
del sample['speech']
yield sample
def parse_embedding(data, normalize, mode='train'):
""" Parse utt_embedding/spk_embedding
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
if normalize:
sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
yield sample
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
""" Decode text to chars or BPE
Inplace operation
Args:
data: Iterable[{key, wav, txt, sample_rate}]
Returns:
Iterable[{key, wav, txt, tokens, label, sample_rate}]
"""
tokenizer = get_tokenizer()
for sample in data:
assert 'text' in sample
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
if mode == 'inference':
sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
yield sample
def shuffle(data, shuffle_size=10000, mode='train'):
""" Local shuffle the data
Args:
data: Iterable[{key, feat, label}]
shuffle_size: buffer size for shuffle
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= shuffle_size:
random.shuffle(buf)
for x in buf:
yield x
buf = []
    # The samples left over
random.shuffle(buf)
for x in buf:
yield x
def sort(data, sort_size=500, mode='train'):
""" Sort the data by feature length.
Sort is used after shuffle and before batch, so we can group
utts with similar lengths into a batch, and `sort_size` should
be less than `shuffle_size`
Args:
data: Iterable[{key, feat, label}]
sort_size: buffer size for sort
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= sort_size:
buf.sort(key=lambda x: x['speech_feat'].size(0))
for x in buf:
yield x
buf = []
    # The samples left over
buf.sort(key=lambda x: x['speech_feat'].size(0))
for x in buf:
yield x
def static_batch(data, batch_size=16):
""" Static batch the data by `batch_size`
Args:
data: Iterable[{key, feat, label}]
batch_size: batch size
Returns:
Iterable[List[{key, feat, label}]]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= batch_size:
yield buf
buf = []
if len(buf) > 0:
yield buf
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
""" Dynamic batch the data until the total frames in batch
reach `max_frames_in_batch`
Args:
data: Iterable[{key, feat, label}]
max_frames_in_batch: max_frames in one batch
Returns:
Iterable[List[{key, feat, label}]]
"""
buf = []
longest_frames = 0
for sample in data:
assert 'speech_feat' in sample
assert isinstance(sample['speech_feat'], torch.Tensor)
new_sample_frames = sample['speech_feat'].size(0)
longest_frames = max(longest_frames, new_sample_frames)
frames_after_padding = longest_frames * (len(buf) + 1)
if frames_after_padding > max_frames_in_batch:
yield buf
buf = [sample]
longest_frames = new_sample_frames
else:
buf.append(sample)
if len(buf) > 0:
yield buf
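# Worked example of the padding-aware budget above: if the longest utterance in
# the buffer has 1000 frames, every item pads to 1000 frames, so a 12000-frame
# budget admits at most 12 utterances before the buffer is flushed.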
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
""" Wrapper for static/dynamic batch
"""
if mode == 'inference':
return static_batch(data, 1)
else:
if batch_type == 'static':
return static_batch(data, batch_size)
elif batch_type == 'dynamic':
return dynamic_batch(data, max_frames_in_batch)
        else:
            raise ValueError('Unsupported batch type {}'.format(batch_type))
def padding(data, use_spk_embedding, mode='train'):
""" Padding the data into training data
Args:
data: Iterable[List[{key, feat, label}]]
Returns:
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
"""
for sample in data:
assert isinstance(sample, list)
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
dtype=torch.int32)
order = torch.argsort(speech_feat_len, descending=True)
utts = [sample[i]['utt'] for i in order]
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
speech_token = pad_sequence(speech_token,
batch_first=True,
padding_value=0)
speech_feat = [sample[i]['speech_feat'] for i in order]
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
speech_feat = pad_sequence(speech_feat,
batch_first=True,
padding_value=0)
text = [sample[i]['text'] for i in order]
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
batch = {
"utts": utts,
"speech_token": speech_token,
"speech_token_len": speech_token_len,
"speech_feat": speech_feat,
"speech_feat_len": speech_feat_len,
"text": text,
"text_token": text_token,
"text_token_len": text_token_len,
"utt_embedding": utt_embedding,
"spk_embedding": spk_embedding,
}
if mode == 'inference':
tts_text = [sample[i]['tts_text'] for i in order]
tts_index = [sample[i]['tts_index'] for i in order]
tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
batch.update({'tts_text': tts_text,
'tts_index': tts_index,
'tts_text_token': tts_text_token,
'tts_text_token_len': tts_text_token_len})
        if use_spk_embedding:
            batch["embedding"] = batch["spk_embedding"]
        else:
            batch["embedding"] = batch["utt_embedding"]
yield batch
def padding_speech_token(data, use_spk_embedding, mode='train'):
""" Padding the data into training data
Args:
data: Iterable[List[{key, feat, label}]]
Returns:
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
"""
for sample in data:
assert isinstance(sample, list)
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
dtype=torch.int32)
order = torch.argsort(speech_feat_len, descending=True)
try:
speech_token = [sample[i]['speech_token'].clone().detach() for i in order]
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
speech_token = pad_sequence(speech_token,
batch_first=True,
padding_value=0)
speech_feat = [sample[i]['speech_feat'] for i in order]
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
speech_feat = pad_sequence(speech_feat,
batch_first=True,
padding_value=0)
batch = {
"speech_token": speech_token,
"speech_token_len": speech_token_len,
"speech_feat": speech_feat,
"speech_feat_len": speech_feat_len,
}
if mode == 'inference':
tts_text = [sample[i]['tts_text'] for i in order]
tts_index = [sample[i]['tts_index'] for i in order]
tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
batch.update({'tts_text': tts_text,
'tts_index': tts_index,
'tts_text_token': tts_text_token,
'tts_text_token_len': tts_text_token_len})
            # This variant carries no speaker embedding, so feed a zero
            # 192-dim vector in its place
            batch["embedding"] = torch.zeros(
                (batch["speech_feat"].size(0), 192), device=batch["speech_feat"].device)
yield batch
        except Exception as ex:
            logging.warning('Failed to pad batch, ex info {}'.format(ex))
def padding_speech_token_spk(data, use_spk_embedding, mode='train'):
""" Padding the data into training data
Args:
data: Iterable[List[{key, feat, label}]]
Returns:
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
"""
for sample in data:
assert isinstance(sample, list)
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
dtype=torch.int32)
order = torch.argsort(speech_feat_len, descending=True)
try:
speech_token = [sample[i]['speech_token'].clone().detach() for i in order]
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
speech_token = pad_sequence(speech_token,
batch_first=True,
padding_value=0)
speech_feat = [sample[i]['speech_feat'] for i in order]
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
speech_feat = pad_sequence(speech_feat,
batch_first=True,
padding_value=0)
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
batch = {
"speech_token": speech_token,
"speech_token_len": speech_token_len,
"speech_feat": speech_feat,
"speech_feat_len": speech_feat_len,
"spk_embedding": spk_embedding,
}
if mode == 'inference':
tts_text = [sample[i]['tts_text'] for i in order]
tts_index = [sample[i]['tts_index'] for i in order]
tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
batch.update({'tts_text': tts_text,
'tts_index': tts_index,
'tts_text_token': tts_text_token,
'tts_text_token_len': tts_text_token_len})
            batch["embedding"] = batch["spk_embedding"]
yield batch
        except Exception as ex:
            logging.warning('Failed to pad batch, ex info {}'.format(ex))
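# Hedged usage sketch (not part of the original file): every stage above is a
# plain generator over sample dicts, so a pipeline is just nested calls.
# `samples` (e.g. the output of process_sft_vq0918_pool4) and `feat_extractor`
# are illustrative assumptions.
def _example_pipeline(samples, feat_extractor):
    stream = resample(samples, resample_rate=22050)
    stream = compute_fbank(stream, feat_extractor)
    stream = shuffle(stream, shuffle_size=1000)
    stream = sort(stream, sort_size=500)
    stream = batch(stream, batch_type='dynamic', max_frames_in_batch=12000)
    # padding_speech_token fills the 192-dim embedding slot with zeros
    return padding_speech_token(stream, use_spk_embedding=False)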