Spaces:
Running
on
A10G
Running
on
A10G
# Copyright 2023 (authors: Feiteng Li) | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import random | |
from typing import Dict, Iterator, List, Tuple, Union | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
# from icefall.utils import make_pad_mask | |
# from torchmetrics.classification import MulticlassAccuracy | |
from modules.embedding import SinePositionalEmbedding, TokenEmbedding | |
from modules.transformer import ( | |
AdaptiveLayerNorm, | |
LayerNorm, | |
TransformerDecoderLayer, | |
TransformerEncoder, | |
TransformerEncoderLayer, | |
) | |
from .macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS | |
import psutil | |
def get_memory_usage():
    """Return the resident set size (RSS) of the current process in MiB.

    Returns:
      The ``rss`` field reported by ``psutil`` for this process, divided
      by 1024 * 1024.
    """
    bytes_per_mib = 1024 * 1024
    rss_bytes = psutil.Process().memory_info().rss
    return rss_bytes / bytes_per_mib
class Transpose(nn.Identity):
    """Swap dimensions 1 and 2 of the input: (N, T, D) -> (N, D, T)."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Functional form of `input.transpose(1, 2)`.
        return torch.transpose(input, 1, 2)
# NOTE: There are two ways to implement the model | |
# 1) [VALL-F] standard TransformerDecoder, use x as memory | |
# 2) [VALL-E] modified TransformerDecoder like GPT-x(e.g. causal TransformerEncoder), | |
# use x as the prefix of decoder inputs | |
class VALLF(nn.Module): | |
"""It implements https://arxiv.org/abs/2301.02111 | |
"Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers" | |
""" | |
def __init__(
    self,
    d_model: int,
    nhead: int,
    num_layers: int,
    norm_first: bool = True,
    add_prenet: bool = False,
    decoder_cls: Union[
        nn.TransformerDecoder, nn.TransformerEncoder
    ] = nn.TransformerDecoder,
    decoder_layer_cls: Union[
        TransformerDecoderLayer, TransformerEncoderLayer
    ] = TransformerDecoderLayer,
    prefix_mode: int = 0,
    share_embedding: bool = True,
    nar_scale_factor: float = 1.0,
    prepend_bos: bool = True,
    num_quantizers: int = 8,
):
    """Build the AR stack and, when ``num_quantizers > 1``, the NAR stack.

    Args:
      d_model:
        The number of expected features in the input (required).
      nhead:
        The number of heads in the multiheadattention models (required).
      num_layers:
        The number of sub-decoder-layers in the decoder (required).
      norm_first:
        Forwarded to every decoder layer; when true a final norm layer is
        also handed to the decoder stacks.
      add_prenet:
        When true, text/audio embeddings pass through conv / MLP prenets;
        otherwise the prenets are ``nn.Identity``.
      decoder_cls:
        Class used to build both decoder stacks.
      decoder_layer_cls:
        Class used to build the layer given to ``decoder_cls``.
      prefix_mode:
        Stored on the instance as ``self.prefix_mode``.
      share_embedding:
        When true, NAR prediction layer ``j`` shares its weight with NAR
        audio embedding ``j + 2``.
      nar_scale_factor:
        Multiplier applied to ``d_model``, ``nhead`` and ``num_layers``
        to size the NAR stack.
      prepend_bos:
        When true the AR audio embedding table gets one extra entry (BOS).
      num_quantizers:
        Number of codebooks; must be >= 1. The NAR stack is only built
        when it is greater than 1.
    """
    super().__init__()

    def build_text_prenet(width):
        # Transpose -> 3 x (Conv1d, BatchNorm1d, ReLU, Dropout) -> Transpose -> Linear
        stages = [Transpose()]
        for _ in range(3):
            stages += [
                nn.Conv1d(width, width, kernel_size=5, padding="same"),
                nn.BatchNorm1d(width),
                nn.ReLU(),
                nn.Dropout(0.5),
            ]
        stages += [Transpose(), nn.Linear(width, width)]
        return nn.Sequential(*stages)

    def build_audio_prenet(width):
        # Bottleneck MLP: width -> 256 -> 256 -> width
        return nn.Sequential(
            nn.Linear(width, 256),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(256, width),
        )

    def build_position(width, dropout, alpha):
        return SinePositionalEmbedding(
            width,
            dropout=dropout,
            scale=False,
            alpha=alpha,
        )

    nar_d_model = int(d_model * nar_scale_factor)

    # ---- Embeddings (creation order is kept: it fixes parameter order) ----
    self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS)  # W_x
    self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS)

    # ID NUM_AUDIO_TOKENS     -> PAD
    # ID NUM_AUDIO_TOKENS + 1 -> BOS
    self.ar_audio_prepend_bos = prepend_bos
    ar_audio_vocab_size = NUM_AUDIO_TOKENS + 1 + int(prepend_bos)
    self.ar_audio_embedding = TokenEmbedding(d_model, ar_audio_vocab_size)

    # ---- AR PreNet ----
    if add_prenet:
        self.ar_text_prenet = build_text_prenet(d_model)
        self.ar_audio_prenet = build_audio_prenet(d_model)
    else:
        self.ar_text_prenet = nn.Identity()
        self.ar_audio_prenet = nn.Identity()

    self.ar_text_position = build_position(d_model, 0.1, True)
    self.ar_audio_position = build_position(d_model, 0.1, True)

    # ---- AR decoder ----
    ar_layer = decoder_layer_cls(
        d_model,
        nhead,
        dim_feedforward=d_model * 4,
        dropout=0.1,
        batch_first=True,
        norm_first=norm_first,
    )
    ar_final_norm = LayerNorm(d_model) if norm_first else None
    self.ar_decoder = decoder_cls(
        ar_layer,
        num_layers=num_layers,
        norm=ar_final_norm,
    )
    self.ar_predict_layer = nn.Linear(
        d_model, NUM_AUDIO_TOKENS + 1, bias=False
    )

    self.rng = random.Random(0)
    self.num_heads = nhead
    self.prefix_mode = prefix_mode
    self.num_quantizers = num_quantizers

    assert num_quantizers >= 1
    if num_quantizers <= 1:
        # Single codebook: no NAR stack is needed.
        return

    num_nar_stages = num_quantizers - 1

    # W_a: the first table has one extra entry compared with the others.
    first_table = [TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)]
    other_tables = [
        TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS)
        for _ in range(num_nar_stages)
    ]
    self.nar_audio_embeddings = nn.ModuleList(first_table + other_tables)

    # ---- NAR PreNet ----
    if add_prenet:
        self.nar_text_prenet = build_text_prenet(nar_d_model)
        self.nar_audio_prenet = build_audio_prenet(nar_d_model)
    else:
        self.nar_text_prenet = nn.Identity()
        self.nar_audio_prenet = nn.Identity()

    self.nar_text_position = build_position(nar_d_model, 0.0, False)
    self.nar_audio_position = build_position(nar_d_model, 0.1, False)

    # ---- NAR decoder (layer is built before the final norm, as before) ----
    nar_layer = decoder_layer_cls(
        nar_d_model,
        int(nhead * nar_scale_factor),
        dim_feedforward=nar_d_model * 4,
        dropout=0.1,
        batch_first=True,
        norm_first=norm_first,
        adaptive_layer_norm=True,
    )
    nar_num_layers = int(num_layers * nar_scale_factor)
    if norm_first:
        nar_final_norm = AdaptiveLayerNorm(
            nar_d_model, norm=nn.LayerNorm(nar_d_model)
        )
    else:
        nar_final_norm = None
    self.nar_decoder = decoder_cls(
        nar_layer,
        num_layers=nar_num_layers,
        norm=nar_final_norm,
    )

    self.nar_predict_layers = nn.ModuleList(
        [
            nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False)
            for _ in range(num_nar_stages)
        ]
    )
    self.nar_stage_embeddings = nn.ModuleList(
        [TokenEmbedding(nar_d_model, 1) for _ in range(num_nar_stages)]
    )

    if share_embedding:
        # We share the parameters of the output projection layer with the parameters of the acoustic embedding Wa
        # NOTE(Feiteng): In the experiment, this undermines accuracy
        # self.ar_predict_layer.weight = self.ar_audio_embedding.weight

        # We also share the parameters of the acoustic embedding layer and the output prediction layer,
        # which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer.
        for j in range(num_quantizers - 2):
            tied_weight = self.nar_audio_embeddings[j + 2].weight
            self.nar_predict_layers[j].weight = tied_weight
def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
    """Yield the parameters belonging to a training stage, printing each name.

    Args:
      stage:
        1 selects parameters whose name starts with ``ar_``; 2 selects
        those starting with ``nar_``. Any other positive value yields
        nothing.
    """
    assert stage > 0
    for param_name, tensor in self.named_parameters():
        if stage == 1 and param_name.startswith("ar_"):
            print(f" AR parameter: {param_name}")
            yield tensor
        elif stage == 2 and param_name.startswith("nar_"):
            print(f"NAR parameter: {param_name}")
            yield tensor
def stage_named_parameters(
    self, stage: int = 1
) -> Iterator[Tuple[str, nn.Parameter]]:
    """Yield ``(name, parameter)`` pairs belonging to a training stage.

    Args:
      stage:
        1 selects names starting with ``ar_``; 2 selects names starting
        with ``nar_``. Any other positive value yields nothing.
    """
    assert stage > 0
    wanted_prefix = {1: "ar_", 2: "nar_"}.get(stage)
    if wanted_prefix is None:
        return
    for entry in self.named_parameters():
        if entry[0].startswith(wanted_prefix):
            yield entry
def pad_y_eos(self, y, y_mask_int, eos_id):
    """Append an EOS column to ``y`` and derive AR inputs and targets.

    One column is appended on the right of ``y``; ``eos_id`` is added at
    that new column and at every position where ``y_mask_int`` is 1.

    Args:
      y: audio token ids (presumably (N, T) -- TODO confirm with caller).
      y_mask_int: integer mask with the same shape as ``y``.
      eos_id: id added at masked positions and at the appended column.

    Returns:
      ``(inputs, targets)``. With ``self.ar_audio_prepend_bos`` the inputs
      are the targets shifted right by a BOS id (``NUM_AUDIO_TOKENS + 1``);
      otherwise inputs drop the last column and targets drop the first.
    """
    padded_codes = F.pad(y, (0, 1), value=0)
    eos_positions = F.pad(y_mask_int, (0, 1), value=1)
    targets = padded_codes + eos_id * eos_positions

    # inputs, targets
    if not self.ar_audio_prepend_bos:
        return targets[:, :-1], targets[:, 1:]

    bos_id = NUM_AUDIO_TOKENS + 1
    inputs = F.pad(targets[:, :-1], (1, 0), value=bos_id)
    return inputs, targets
def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes, prefix_mode):
    """Build the NAR audio embedding (with optional acoustic prompt prefix).

    Args:
      y: first-codebook token ids; fed to ``nar_audio_embeddings[0]``.
      y_lens: per-utterance lengths of ``y`` (only read for modes 1 and 2).
      codes: token ids for all codebooks, indexed as ``codes[..., j]``.
        NOTE: in mode 2 the selected prompt span of codebook ``nar_stage``
        is overwritten in place with ``NUM_AUDIO_TOKENS``.
      nar_stage: codebooks ``1 .. nar_stage - 1`` are summed into the
        non-prompt part of the embedding.
      y_prompts_codes: externally supplied prompt codes; only read in
        mode 4 (mode 2 builds its own).
      prefix_mode: 0 = no prefix, 1 = prefix taken from the beginning of
        the utterance, 2 = random segment of the utterance prepended,
        4 = ``y_prompts_codes`` prepended.

    Returns:
      ``(y_emb, prefix_len)``: the embedding and the number of leading
      prompt frames.

    Raises:
      ValueError: if ``prefix_mode`` is not one of 0, 1, 2, 4.
    """
    # 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds
    # from the same utterance.
    # We implement this differently.
    if prefix_mode == 0:
        # no prefix
        prefix_len = 0
        y_emb = self.nar_audio_embeddings[0](y)
        for j in range(1, nar_stage):
            # Formula (4) (5)
            y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
    elif prefix_mode == 1:
        # prefix at beginning
        int_low = (0.25 * y_lens.min()).type(torch.int64).item()
        if int_low > 0:
            prefix_len = torch.randint(0, int_low * 2, size=()).item()
            prefix_len = min(prefix_len, 225)  # 24000/320 * 3s = 225 frames
        else:
            # torch.randint requires low < high; with fewer than 4 frames
            # in the shortest utterance there is no room for a prefix.
            prefix_len = 0

        y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
        y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
        for j in range(1, self.num_quantizers):
            # The prompt part always sums every codebook ...
            y_prompts += self.nar_audio_embeddings[j](
                codes[:, :prefix_len, j]
            )
            # ... the target part only the codebooks below the current stage.
            if j < nar_stage:
                y_emb += self.nar_audio_embeddings[j](
                    codes[:, prefix_len:, j]
                )
        y_emb = torch.concat([y_prompts, y_emb], dim=1)
    elif prefix_mode in [2, 4]:
        if prefix_mode == 2:
            # random prefix
            prefix_len = min(225, int(0.25 * y_lens.min().item()))

            y_prompts_codes = []
            for b in range(codes.shape[0]):
                start = self.rng.randint(0, y_lens[b].item() - prefix_len)
                y_prompts_codes.append(
                    torch.clone(codes[b, start : start + prefix_len])
                )
                # Mask the copied span in the codebook being predicted.
                codes[
                    b, start : start + prefix_len, nar_stage
                ] = NUM_AUDIO_TOKENS
            y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
        else:
            prefix_len = y_prompts_codes.shape[1]

        y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
        y_emb = self.nar_audio_embeddings[0](y)
        for j in range(1, self.num_quantizers):
            y_prompts += self.nar_audio_embeddings[j](
                y_prompts_codes[..., j]
            )
            if j < nar_stage:
                y_emb += self.nar_audio_embeddings[j](codes[..., j])
        y_emb = torch.concat([y_prompts, y_emb], dim=1)
    else:
        raise ValueError(f"Unsupported prefix_mode: {prefix_mode}")

    return y_emb, prefix_len
def forward(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    y: Union[torch.Tensor],
    y_lens: Union[torch.Tensor],
    reduction: str = "sum",
    train_stage: int = 0,
    **kwargs,
) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
    """Training forward pass; not implemented in this file.

    Raises:
      NotImplementedError: always.
    """
    raise NotImplementedError()
def inference(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    y: torch.Tensor,
    enroll_x_lens: Union[torch.Tensor, None] = None,
    top_k: int = -100,
    temperature: float = 1.0,
) -> torch.Tensor:
    """Inference entry point; not implemented on this class.

    Raises:
      NotImplementedError: always.
    """
    raise NotImplementedError()
def visualize(
    self,
    predicts: Tuple[torch.Tensor],
    batch: Dict[str, Union[List, torch.Tensor]],
    output_dir: str,
    limit: int = 4,
) -> None:
    """Visualization hook; not implemented in this file.

    Raises:
      NotImplementedError: always.
    """
    raise NotImplementedError()
class VALLE(VALLF): | |
"""It implements https://arxiv.org/abs/2301.02111 | |
"Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers" | |
""" | |
def __init__(
    self,
    d_model: int,
    nhead: int,
    num_layers: int,
    norm_first: bool = True,
    add_prenet: bool = False,
    prefix_mode: int = 0,
    share_embedding: bool = True,
    nar_scale_factor: float = 1.0,
    **kwargs,
):
    """Build a VALLF that uses ``TransformerEncoder`` stacks, plus language embeddings.

    Args:
      d_model:
        The number of expected features in the input (required).
      nhead:
        The number of heads in the multiheadattention models (required).
      num_layers:
        The number of sub-decoder-layers in the decoder (required).
      norm_first, add_prenet, prefix_mode, share_embedding, nar_scale_factor:
        Forwarded unchanged to ``VALLF.__init__``.
      **kwargs:
        Any further keyword arguments accepted by ``VALLF.__init__``.
    """
    super().__init__(
        d_model,
        nhead,
        num_layers,
        norm_first=norm_first,
        add_prenet=add_prenet,
        decoder_cls=TransformerEncoder,
        decoder_layer_cls=TransformerEncoderLayer,
        prefix_mode=prefix_mode,
        share_embedding=share_embedding,
        nar_scale_factor=nar_scale_factor,
        **kwargs,
    )
    # Language code -> row index in the language embedding tables.
    self.language_ID = {
        'en': 0,
        'zh': 1,
        'ja': 2,
    }
    num_languages = len(self.language_ID)
    # The NAR language embedding is added to NAR text embeddings, whose
    # width is int(d_model * nar_scale_factor) in VALLF; it must use that
    # width (identical to d_model when nar_scale_factor == 1.0).
    nar_d_model = int(d_model * nar_scale_factor)
    self.ar_language_embedding = TokenEmbedding(d_model, num_languages)
    self.nar_language_embedding = TokenEmbedding(nar_d_model, num_languages)
def forward(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    y: Union[torch.Tensor],
    y_lens: Union[torch.Tensor],
    reduction: str = "sum",
    train_stage: int = 0,
    **kwargs,
):
    """Training forward pass; not implemented in this file.

    Raises:
      NotImplementedError: always.
    """
    raise NotImplementedError()
def inference(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    y: torch.Tensor,
    enroll_x_lens: torch.Tensor,
    top_k: int = -100,
    temperature: float = 1.0,
    prompt_language: str = None,
    text_language: str = None,
) -> torch.Tensor:
    """Generate audio codes: AR sampling of codebook 0, then greedy NAR decoding.

    Args:
      x:
        A 2-D tensor of shape (1, S).
      x_lens:
        A 1-D tensor of shape (1,). It contains the number of tokens in `x`
        before padding.
      y:
        A 3-D tensor of shape (1, T, 8).
      enroll_x_lens:
        Used as the split index between prompt text and text to synthesize
        when adding language embeddings, and (prefix modes 2 and 4) to drop
        the enrolled phonemes before the NAR pass.
      top_k: (`optional`) int
        The number of highest probability tokens to keep for top-k-filtering. Default to -100.
      temperature: (`optional`) float
        The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
      prompt_language:
        Key of `self.language_ID` for the prompt text.
        NOTE(review): the default None is looked up unconditionally and
        raises KeyError, so callers must pass a valid key.
      text_language:
        Key of `self.language_ID` (str) or a list of keys for the text.
        NOTE(review): any other value (including the default None) leaves
        `text_language_id` unassigned and fails at its first use.
    Returns:
      Return the predicted audio code matrix.
    Raises:
      SyntaxError: if the AR loop stops before generating any token.
    """
    assert x.ndim == 2, x.shape
    assert x_lens.ndim == 1, x_lens.shape
    assert y.ndim == 3, y.shape
    assert y.shape[0] == 1, y.shape

    assert torch.all(x_lens > 0)

    # NOTE: x has been padded in TextTokenCollater
    text = x
    x = self.ar_text_embedding(text)
    # Add language embedding
    prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
    if isinstance(text_language, str):
        text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
    elif isinstance(text_language, List):
        text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
    # Positions before enroll_x_lens get the prompt language, the rest the
    # text language (in-place on the embedding output).
    x[:, :enroll_x_lens, :] += self.ar_language_embedding(prompt_language_id)
    x[:, enroll_x_lens:, :] += self.ar_language_embedding(text_language_id)
    x = self.ar_text_prenet(x)
    x = self.ar_text_position(x)

    text_len = x_lens.max()
    prompts = y
    prefix_len = y.shape[1]

    # AR Decoder
    # TODO: Managing decoder steps avoid repetitive computation
    y = prompts[..., 0]
    if self.ar_audio_prepend_bos:
        y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1)

    x_len = x_lens.max()
    # All-False block: no text-to-text position is masked.
    x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)

    kv_cache = None
    use_kv_caching = True
    while True:
        y_emb = self.ar_audio_embedding(y)
        y_emb = self.ar_audio_prenet(y_emb)
        y_pos = self.ar_audio_position(y_emb)
        xy_pos = torch.concat([x, y_pos], dim=1)

        y_len = y.shape[1]
        # Text rows: True (masked) for every audio column.
        x_attn_mask_pad = F.pad(
            x_attn_mask,
            (0, y_len),
            value=True,
        )
        # Audio rows: all text columns visible, causal (upper-triangular
        # masked) over audio columns.
        y_attn_mask = F.pad(
            torch.triu(
                torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
            ),
            (x_len, 0),
            value=False,
        )
        xy_attn_mask = torch.concat(
            [x_attn_mask_pad, y_attn_mask], dim=0
        ).to(y.device)

        if use_kv_caching and kv_cache is not None:
            # After the first step only the newest position is fed; the
            # rest is expected to come from kv_cache.
            xy_pos = xy_pos[:, [-1]]
        else:
            pass

        xy_dec, kv_cache = self.ar_decoder.infer(
            xy_pos,
            mask=xy_attn_mask,
            past_kv=kv_cache,
            use_cache=use_kv_caching,
        )
        # xy_dec, _ = self.ar_decoder(
        #     (xy_pos, None),
        #     mask=xy_attn_mask,
        # )

        logits = self.ar_predict_layer(xy_dec[:, -1])
        samples = topk_sampling(
            logits, top_k=top_k, top_p=1, temperature=temperature
        )

        # Stop when the most likely or the sampled token is
        # NUM_AUDIO_TOKENS, or when more than 16 tokens per text token
        # have been generated.
        if (
            torch.argmax(logits, dim=-1)[0] == NUM_AUDIO_TOKENS
            or samples[0, 0] == NUM_AUDIO_TOKENS
            or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
        ):
            if prompts.shape[1] == y.shape[1]:
                # NOTE(review): SyntaxError is an unusual type for a
                # runtime condition; kept because callers may catch it.
                raise SyntaxError(
                    "well trained model shouldn't reach here."
                )

            print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")

            memory_used = get_memory_usage()
            print(f"Current memory used: {memory_used:.2f} MB")
            break

        y = torch.concat([y, samples], dim=1)

    # Generated first-codebook tokens only (prompt and BOS removed).
    codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
    if self.num_quantizers == 1:
        return torch.stack(codes, dim=-1)

    # Non-AR Decoders
    y_emb = self.nar_audio_embeddings[0](
        y[:, int(self.ar_audio_prepend_bos) :]
    )

    if self.prefix_mode in [2, 4]:  # Exclude enrolled_phonemes
        enrolled_len = enroll_x_lens.max().item()
        # SOS + Synthesis Text + EOS
        text = torch.concat(
            [
                text[:, :1],
                text[:, enrolled_len - 1 :],
            ],
            dim=1,
        )
        text_len = text_len - (enrolled_len - 2)
        assert text.shape[0] == 1

    x = self.nar_text_embedding(text)
    # Add language embedding
    # NOTE(review): the split index is still enroll_x_lens even when the
    # enrolled phonemes were removed from `text` above -- confirm this
    # matches how the model was trained.
    prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
    if isinstance(text_language, str):
        text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
    elif isinstance(text_language, List):
        text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
    x[:, :enroll_x_lens, :] += self.nar_language_embedding(prompt_language_id)
    x[:, enroll_x_lens:, :] += self.nar_language_embedding(text_language_id)
    x = self.nar_text_prenet(x)
    x = self.nar_text_position(x)

    if self.prefix_mode == 0:
        # Prompt codebooks are added to y_emb one stage at a time.
        for i, (predict_layer, embedding_layer) in enumerate(
            zip(
                self.nar_predict_layers,
                self.nar_audio_embeddings[1:],
            )
        ):
            y_pos = self.nar_audio_prenet(y_emb)
            y_pos = self.nar_audio_position(y_pos)
            xy_pos = torch.concat([x, y_pos], dim=1)

            xy_dec, _ = self.nar_decoder(
                (xy_pos, self.nar_stage_embeddings[i].weight)
            )
            # Keep only the positions after the text and the prompt.
            logits = predict_layer(xy_dec[:, text_len + prefix_len :])

            samples = torch.argmax(logits, dim=-1)
            codes.append(samples)

            if i < self.num_quantizers - 2:
                y_emb[:, :prefix_len] += embedding_layer(
                    prompts[..., i + 1]
                )
                y_emb[:, prefix_len:] += embedding_layer(samples)
    else:
        # All prompt codebooks are summed into the prefix up front.
        for j in range(1, self.num_quantizers):
            y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
                prompts[..., j]
            )

        for i, (predict_layer, embedding_layer) in enumerate(
            zip(
                self.nar_predict_layers,
                self.nar_audio_embeddings[1:],
            )
        ):
            y_pos = self.nar_audio_prenet(y_emb)
            y_pos = self.nar_audio_position(y_pos)
            xy_pos = torch.concat([x, y_pos], dim=1)

            xy_dec, _ = self.nar_decoder(
                (xy_pos, self.nar_stage_embeddings[i].weight)
            )
            logits = predict_layer(xy_dec[:, text_len + prefix_len :])

            samples = torch.argmax(logits, dim=-1)
            codes.append(samples)

            if i < self.num_quantizers - 2:
                y_emb[:, prefix_len:] += embedding_layer(samples)

    assert len(codes) == self.num_quantizers
    return torch.stack(codes, dim=-1)
def continual(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    y: torch.Tensor,
) -> torch.Tensor:
    """Re-predict codebooks 1..7 of ``y`` with the NAR decoder.

    The first half of ``y`` (capped at 3 * 75 frames) is used as the
    acoustic prompt; the first codebook of the remainder is taken from
    ``y`` as-is and the other codebooks are decoded greedily.

    Args:
      x:
        A 2-D tensor of shape (1, S).
      x_lens:
        A 1-D tensor of shape (1,). It contains the number of tokens in `x`
        before padding.
      y:
        A 3-D tensor of shape (1, T, 8).
    Returns:
      Return the predicted audio code matrix.
    """
    assert x.ndim == 2, x.shape
    assert x_lens.ndim == 1, x_lens.shape
    assert y.ndim == 3, y.shape
    assert y.shape[0] == 1, y.shape

    assert torch.all(x_lens > 0)
    assert self.num_quantizers == 8
    num_quantizers = self.num_quantizers

    # NOTE: x has been padded in TextTokenCollater
    text = x
    # The AR text embedding pass that used to be computed here was
    # overwritten before use, so only the NAR text pass is kept.
    text_len = x_lens.max()

    prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)

    # First codebook of the continuation is taken from y, not sampled.
    prompts = y[:, :prefix_len]
    codes = [y[:, prefix_len:, 0]]

    # Non-AR Decoders
    x = self.nar_text_embedding(text)
    x = self.nar_text_prenet(x)
    x = self.nar_text_position(x)

    y_emb = self.nar_audio_embeddings[0](y[..., 0])

    if self.prefix_mode == 0:
        for i, (predict_layer, embedding_layer) in enumerate(
            zip(
                self.nar_predict_layers,
                self.nar_audio_embeddings[1:],
            )
        ):
            # NOTE(review): position is applied before the prenet here,
            # the reverse of the other branch -- confirm this is intended.
            y_pos = self.nar_audio_position(y_emb)
            y_pos = self.nar_audio_prenet(y_pos)
            xy_pos = torch.concat([x, y_pos], dim=1)

            xy_dec, _ = self.nar_decoder(
                (xy_pos, self.nar_stage_embeddings[i].weight)
            )
            # Keep only the positions after the text and the prompt.
            logits = predict_layer(xy_dec[:, text_len + prefix_len :])

            samples = torch.argmax(logits, dim=-1)
            codes.append(samples)

            # No feedback is needed after the last stage.
            if i < num_quantizers - 2:
                y_emb[:, :prefix_len] += embedding_layer(
                    prompts[..., i + 1]
                )
                y_emb[:, prefix_len:] += embedding_layer(samples)
    else:
        # All prompt codebooks are summed into the prefix up front.
        for j in range(1, num_quantizers):
            y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
                prompts[..., j]
            )

        for i, (predict_layer, embedding_layer) in enumerate(
            zip(
                self.nar_predict_layers,
                self.nar_audio_embeddings[1:],
            )
        ):
            y_pos = self.nar_audio_prenet(y_emb)
            y_pos = self.nar_audio_position(y_pos)
            xy_pos = torch.concat([x, y_pos], dim=1)

            xy_dec, _ = self.nar_decoder(
                (xy_pos, self.nar_stage_embeddings[i].weight)
            )
            logits = predict_layer(xy_dec[:, text_len + prefix_len :])

            samples = torch.argmax(logits, dim=-1)
            codes.append(samples)

            if i < num_quantizers - 2:
                y_emb[:, prefix_len:] += embedding_layer(samples)

    assert len(codes) == num_quantizers
    return torch.stack(codes, dim=-1)
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py | |
def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    The input tensor is left untouched; a filtered copy is returned (the
    input itself is returned when neither filter is active).

    Args:
        logits: logits distribution, shape (..., vocabulary size);
            typically (batch size, vocabulary size).
        top_k: if > 0, keep only the top k tokens with highest probability
            (top-k filtering). Values <= 0 disable top-k filtering.
        top_p: if < 1.0, keep the top tokens with cumulative probability
            >= top_p (nucleus filtering). Nucleus filtering is described in
            Holtzman et al. (http://arxiv.org/abs/1904.09751)
        filter_value: value written at filtered-out positions.
        min_tokens_to_keep: make sure at least this many tokens per example
            survive the filtering.

    Returns:
        A tensor shaped like ``logits`` with removed entries set to
        ``filter_value``.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(
            max(top_k, min_tokens_to_keep), logits.size(-1)
        )  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        # masked_fill is out-of-place, so the caller's tensor is not modified.
        logits = logits.masked_fill(logits < kth_best, filter_value)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1
        )

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
            ..., :-1
        ].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the sorted mask back to the original vocabulary order;
        # the last dim is used so inputs of any rank work.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, filter_value)
    return logits
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
    """Sample one token id per row after temperature scaling and top-k/top-p filtering.

    Args:
        logits: unnormalized scores; the last dimension is the vocabulary.
        top_k: (`optional`) int
            The number of highest probability vocabulary tokens to keep for
            top-k-filtering. Default to 10.
        top_p: (`optional`) float
            Cumulative probability threshold for nucleus sampling. Must be
            between 0 and 1. Default to 1.
        temperature: (`optional`) float
            The value used to module the next token probabilities. Must be
            strictly positive (higher temperature => more likely to sample
            low probability tokens). Default to 1.0.

    Returns:
        The sampled token ids, one per row (``num_samples=1``).
    """
    # Temperature scaling is skipped entirely when it would be a no-op.
    scaled = logits if temperature == 1.0 else logits / temperature
    # Top-p/top-k filtering
    filtered = top_k_top_p_filtering(scaled, top_k=top_k, top_p=top_p)
    # Sample
    probabilities = F.softmax(filtered, dim=-1)
    return torch.multinomial(probabilities, num_samples=1)