|
import copy |
|
from typing import Optional, Tuple |
|
import random |
|
|
|
from sklearn.cluster import KMeans |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present |
|
|
|
# Download locations of the released checkpoints, keyed by model name.
# Every file lives under the same v0.1 release and is named
# "<model name>-<hash suffix>.pt".
URLS = {
    name: f"https://github.com/bshall/hubert/releases/download/v0.1/{name}-{suffix}.pt"
    for name, suffix in (
        ("hubert-discrete", "e9416457"),
        ("hubert-soft", "0d54a1f4"),
        ("kmeans100", "50f36a95"),
    )
}
|
|
|
|
|
class Hubert(nn.Module):
    """HuBERT backbone: convolutional front end, transformer encoder and a
    cosine-similarity classification head over learned label embeddings.

    Args:
        num_label_embeddings (int): number of rows in the label embedding
            table, i.e. the number of classes produced by :meth:`logits`.
        mask (bool): when True, spans of projected features are replaced by a
            learned embedding while the module is in training mode.
    """

    def __init__(self, num_label_embeddings: int = 100, mask: bool = True):
        super().__init__()
        self._mask = mask

        # Front end: raw audio -> 512-d conv features -> 768-d projection.
        self.feature_extractor = FeatureExtractor()
        self.feature_projection = FeatureProjection()
        self.positional_embedding = PositionalConvEmbedding()
        self.norm = nn.LayerNorm(768)
        self.dropout = nn.Dropout(0.1)

        # The template layer is deep-copied 12 times by TransformerEncoder.
        layer_template = nn.TransformerEncoderLayer(
            d_model=768,
            nhead=12,
            dim_feedforward=3072,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = TransformerEncoder(layer_template, 12)
        self.proj = nn.Linear(768, 256)

        # Learned vector written over masked frames during training.
        self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_())
        self.label_embedding = nn.Embedding(num_label_embeddings, 256)

    def mask(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Overwrite random spans of ``x`` with the learned mask embedding.

        Masking only happens in training mode and when it was enabled at
        construction time; otherwise ``x`` is returned untouched together
        with ``None``. Note that ``x`` is modified in place.

        Returns:
            The (possibly modified) ``x`` and the boolean span mask over the
            first two dimensions of ``x``, or ``None`` when nothing was masked.
        """
        if not (self.training and self._mask):
            return x, None

        # mask_prob=0.8, spans of 10 frames, at least 2 spans per sequence.
        span_mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2)
        x[span_mask] = self.masked_spec_embed.to(x.dtype)
        return x, span_mask

    def encode(
        self, x: torch.Tensor, layer: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the front end and the transformer encoder.

        Args:
            x: input audio; assumed to be (batch, 1, samples) since the first
                convolution takes one input channel -- confirm with callers.
            layer: number of leading encoder layers to apply; ``None`` runs
                all of them.

        Returns:
            The encoder output and the mask returned by :meth:`mask`.
        """
        features = self.feature_extractor(x)
        # Convolutions are channel-first; everything after is channel-last.
        projected = self.feature_projection(features.transpose(1, 2))
        projected, mask = self.mask(projected)
        hidden = projected + self.positional_embedding(projected)
        hidden = self.dropout(self.norm(hidden))
        encoded = self.encoder(hidden, output_layer=layer)
        return encoded, mask

    def logits(self, x: torch.Tensor) -> torch.Tensor:
        """Score ``x`` against every label embedding.

        The score is the cosine similarity along the last dimension between
        each vector in ``x`` and each label embedding, divided by 0.1.
        """
        # (1, 1, num_labels, dim) so it broadcasts against x[:, :, None, :].
        label_table = self.label_embedding.weight[None, None]
        similarity = torch.cosine_similarity(x.unsqueeze(2), label_table, dim=-1)
        return similarity / 0.1

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return per-frame label logits and the mask used while encoding."""
        encoded, mask = self.encode(x)
        scores = self.logits(self.proj(encoded))
        return scores, mask
|
|
|
|
|
class HubertSoft(Hubert):
    """HuBERT variant whose units are the continuous projected features."""

    def __init__(self):
        super().__init__()

    @torch.inference_mode()
    def units(self, wav: torch.Tensor) -> torch.Tensor:
        """Return soft units for ``wav``.

        The waveform is padded on both sides of its last dimension, encoded
        with all encoder layers, and passed through the output projection.
        """
        # 400 and 320 are the receptive field and total stride (in samples)
        # of the convolutional feature extractor; pad half the difference on
        # each side.
        side = (400 - 320) // 2
        padded = F.pad(wav, (side, side))
        encoded, _ = self.encode(padded)
        return self.proj(encoded)
|
|
|
|
|
class HubertDiscrete(Hubert):
    """HuBERT variant that maps audio to discrete k-means cluster ids.

    The base model is built with 504 label embeddings.

    Args:
        kmeans: clustering model exposing ``predict``; it is applied to the
            output of the first 7 encoder layers to pick one id per frame.
    """

    def __init__(self, kmeans):
        super().__init__(504)
        self.kmeans = kmeans

    @torch.inference_mode()
    def units(self, wav: torch.Tensor) -> torch.LongTensor:
        """Return the cluster id of every encoded frame of ``wav``.

        Args:
            wav: input audio; assumed to be (batch, 1, samples) -- confirm
                with callers.

        Returns:
            A long tensor on ``wav``'s device. For a batch of one the shape
            is (frames,); for larger batches it is (batch, frames).
        """
        side = (400 - 320) // 2
        wav = F.pad(wav, (side, side))
        x, _ = self.encode(wav, layer=7)

        # ``predict`` needs a 2-D (n_samples, n_features) array. Flatten the
        # leading dimensions explicitly instead of using a bare ``squeeze()``:
        # that would also drop a length-1 frame axis and hand ``predict`` a
        # 1-D array for very short inputs.
        features = x.reshape(-1, x.size(-1)).cpu().numpy()
        ids = torch.tensor(
            self.kmeans.predict(features), dtype=torch.long, device=wav.device
        )

        # Restore the leading dimensions; only a singleton batch axis is
        # dropped, so single-utterance callers still get a 1-D result.
        return ids.reshape(x.shape[:-1]).squeeze(0)
|
|
|
|
|
class FeatureExtractor(nn.Module):
    """Stack of seven bias-free 1-D convolutions with GELU activations.

    The first convolution maps 1 channel to 512 and is followed by a group
    norm with one group per channel; the remaining six keep 512 channels.
    All convolutions are strided, so the time axis shrinks at every stage.
    """

    def __init__(self):
        super().__init__()
        self.conv0 = nn.Conv1d(1, 512, kernel_size=10, stride=5, bias=False)
        self.norm0 = nn.GroupNorm(num_groups=512, num_channels=512)
        self.conv1 = nn.Conv1d(512, 512, kernel_size=3, stride=2, bias=False)
        self.conv2 = nn.Conv1d(512, 512, kernel_size=3, stride=2, bias=False)
        self.conv3 = nn.Conv1d(512, 512, kernel_size=3, stride=2, bias=False)
        self.conv4 = nn.Conv1d(512, 512, kernel_size=3, stride=2, bias=False)
        self.conv5 = nn.Conv1d(512, 512, kernel_size=2, stride=2, bias=False)
        self.conv6 = nn.Conv1d(512, 512, kernel_size=2, stride=2, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution stack to channel-first input ``x``."""
        # Only the first stage is normalised before its activation.
        hidden = F.gelu(self.norm0(self.conv0(x)))
        later_stages = (
            self.conv1,
            self.conv2,
            self.conv3,
            self.conv4,
            self.conv5,
            self.conv6,
        )
        for conv in later_stages:
            hidden = F.gelu(conv(hidden))
        return hidden
|
|
|
|
|
class FeatureProjection(nn.Module):
    """Layer-normalise 512-d features and project them to 768 dimensions.

    Dropout with probability 0.1 is applied to the projected output.
    """

    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(512)
        self.projection = nn.Linear(in_features=512, out_features=768)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(projection(norm(x)))`` over the last dimension."""
        normalised = self.norm(x)
        projected = self.projection(normalised)
        return self.dropout(projected)
|
|
|
|
|
class PositionalConvEmbedding(nn.Module):
    """Positional embedding computed by a grouped, weight-normalised 1-D
    convolution over the time axis, followed by GELU."""

    def __init__(self):
        super().__init__()
        kernel_size = 128
        grouped_conv = nn.Conv1d(
            768,
            768,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=16,
        )
        # weight_norm reparameterises the conv in place and returns it.
        self.conv = nn.utils.weight_norm(grouped_conv, name="weight", dim=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the embedding for channel-last input ``x``.

        The convolution runs channel-first, so ``x`` is transposed on the way
        in and out. With an even kernel and ``kernel_size // 2`` padding the
        convolution yields one extra time step, which is dropped so that the
        output has the same length as the input.
        """
        conv_out = self.conv(x.transpose(1, 2))
        trimmed = conv_out[:, :, :-1]
        return F.gelu(trimmed).transpose(1, 2)
|
|
|
|
|
class TransformerEncoder(nn.Module):
    """Stack of independent copies of one encoder layer that can stop early.

    Args:
        encoder_layer: template layer; it is deep-copied ``num_layers`` times
            so the copies do not share parameters.
        num_layers: number of copies in the stack.
    """

    def __init__(
        self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int
    ) -> None:
        super().__init__()
        clones = [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
        self.layers = nn.ModuleList(clones)
        self.num_layers = num_layers

    def forward(
        self,
        src: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        output_layer: Optional[int] = None,
    ) -> torch.Tensor:
        """Apply the first ``output_layer`` layers to ``src``.

        Args:
            src: input sequence.
            mask: forwarded to every layer as ``src_mask``.
            src_key_padding_mask: forwarded to every layer unchanged.
            output_layer: slice bound on the layer list; ``None`` applies
                every layer.

        Returns:
            The output of the last layer that was applied.
        """
        hidden = src
        selected = self.layers[:output_layer]
        for block in selected:
            hidden = block(
                hidden,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
            )
        return hidden
|
|
|
|
|
def _compute_mask(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    device: torch.device,
    min_masks: int = 0,
) -> torch.Tensor:
    """Sample a boolean mask made of fixed-length spans for each sequence.

    Every row gets the same number of spans; their start positions are drawn
    without replacement, but spans may still overlap each other.

    Args:
        shape: ``(batch_size, sequence_length)`` of the mask to create.
        mask_prob: target fraction of masked positions; it determines the
            number of spans as ``mask_prob * sequence_length / mask_length``
            (rounded stochastically).
        mask_length: number of consecutive positions in each span.
        device: device on which the mask is created.
        min_masks: lower bound on the number of spans per sequence.

    Returns:
        A bool tensor of shape ``(batch_size, sequence_length)`` that is True
        at masked positions. If the span count comes out as zero the mask is
        all False.

    Raises:
        ValueError: if ``mask_length`` is smaller than 1 or larger than
            ``sequence_length``.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
        )

    # Number of spans per sequence; adding a uniform sample before truncating
    # rounds the expected count up or down at random.
    num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random())
    num_masked_spans = max(num_masked_spans, min_masks)

    # Never ask for more spans than fit side by side in the sequence.
    if num_masked_spans * mask_length > sequence_length:
        num_masked_spans = sequence_length // mask_length

    mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)

    # torch.multinomial rejects a request for zero samples, so a span count
    # of zero (small mask_prob with min_masks=0) means "mask nothing".
    if num_masked_spans < 1:
        return mask

    # Every start position that leaves room for a full span is equally likely.
    uniform_dist = torch.ones(
        (batch_size, sequence_length - (mask_length - 1)), device=device
    )
    span_starts = torch.multinomial(uniform_dist, num_masked_spans)

    # Expand each start into the indices it covers:
    # (batch, spans, 1) + (mask_length,) -> (batch, spans, mask_length).
    offsets = torch.arange(mask_length, device=device)
    mask_idxs = (span_starts.unsqueeze(-1) + offsets).reshape(
        batch_size, num_masked_spans * mask_length
    )

    return mask.scatter(1, mask_idxs, True)
|
|
|
|
|
def hubert_discrete(
    pretrained: bool = True,
    progress: bool = True,
) -> HubertDiscrete:
    r"""HuBERT-Discrete from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.

    The k-means model is always created through ``kmeans100`` with the same
    flags. The model is switched to eval mode only after pretrained weights
    have been loaded.

    Args:
        pretrained (bool): load pretrained weights into the model
        progress (bool): show progress bar when downloading model
    """
    model = HubertDiscrete(kmeans100(pretrained=pretrained, progress=progress))
    if not pretrained:
        return model

    state_dict = torch.hub.load_state_dict_from_url(
        URLS["hubert-discrete"], progress=progress
    )
    # Strip the "module." prefix from the checkpoint keys if present.
    consume_prefix_in_state_dict_if_present(state_dict, "module.")
    model.load_state_dict(state_dict)
    model.eval()
    return model
|
|
|
|
|
def hubert_soft(
    pretrained: bool = True,
    progress: bool = True,
) -> HubertSoft:
    r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.

    The model is switched to eval mode only after pretrained weights have
    been loaded.

    Args:
        pretrained (bool): load pretrained weights into the model
        progress (bool): show progress bar when downloading model
    """
    model = HubertSoft()
    if not pretrained:
        return model

    state_dict = torch.hub.load_state_dict_from_url(
        URLS["hubert-soft"], progress=progress
    )
    # Strip the "module." prefix from the checkpoint keys if present.
    consume_prefix_in_state_dict_if_present(state_dict, "module.")
    model.load_state_dict(state_dict)
    model.eval()
    return model
|
|
|
|
|
def _kmeans(
    num_clusters: int, pretrained: bool = True, progress: bool = True
) -> KMeans:
    """Build a ``KMeans`` with ``num_clusters`` clusters, optionally restored
    from the released checkpoint.

    With ``pretrained`` the checkpoint stored under ``kmeans{num_clusters}``
    in ``URLS`` is downloaded and its values are written straight into the
    instance dictionary, so the estimator is not fitted through ``fit``.
    Without it a fresh ``KMeans`` is returned as constructed.
    """
    model = KMeans(num_clusters)
    if not pretrained:
        return model

    url = URLS[f"kmeans{num_clusters}"]
    state = torch.hub.load_state_dict_from_url(url, progress=progress)

    attributes = vars(model)
    for key in ("n_features_in_", "_n_threads"):
        attributes[key] = state[key]
    # The centres are stored as a tensor; the estimator gets a NumPy array.
    attributes["cluster_centers_"] = state["cluster_centers_"].numpy()
    return model
|
|
|
|
|
def kmeans100(pretrained: bool = True, progress: bool = True) -> KMeans:
    r"""
    k-means checkpoint for HuBERT-Discrete with 100 clusters.

    Args:
        pretrained (bool): load pretrained weights into the model
        progress (bool): show progress bar when downloading model
    """
    return _kmeans(100, pretrained=pretrained, progress=progress)