import copy
import random
from typing import Optional, Tuple

from sklearn.cluster import KMeans

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

# Download locations of the released checkpoints, keyed by model name.
URLS = {
    "hubert-discrete": "https://github.com/bshall/hubert/releases/download/v0.1/hubert-discrete-e9416457.pt",
    "hubert-soft": "https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt",
    "kmeans100": "https://github.com/bshall/hubert/releases/download/v0.1/kmeans100-50f36a95.pt",
}


class Hubert(nn.Module):
    """HuBERT backbone: conv feature extractor, transformer encoder, label head.

    Args:
        num_label_embeddings: number of rows in the label embedding table that
            the projected features are compared against in :meth:`logits`.
        mask: whether span masking is applied to the features in training mode.
    """

    def __init__(self, num_label_embeddings: int = 100, mask: bool = True):
        super().__init__()
        self._mask = mask
        self.feature_extractor = FeatureExtractor()
        self.feature_projection = FeatureProjection()
        self.positional_embedding = PositionalConvEmbedding()
        self.norm = nn.LayerNorm(768)
        self.dropout = nn.Dropout(0.1)
        self.encoder = TransformerEncoder(
            nn.TransformerEncoderLayer(
                768, 12, 3072, activation="gelu", batch_first=True
            ),
            12,
        )
        self.proj = nn.Linear(768, 256)

        # Learned vector written over every masked frame.
        self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_())
        self.label_embedding = nn.Embedding(num_label_embeddings, 256)

    def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Overwrite random spans of ``x`` with the mask embedding (training only).

        ``x`` is modified in place. The returned mask is ``None`` when the
        module is in eval mode or masking was disabled in the constructor;
        otherwise it is a boolean tensor over the first two dims of ``x``.
        """
        mask = None
        if self.training and self._mask:
            mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2)
            x[mask] = self.masked_spec_embed.to(x.dtype)
        return x, mask

    def encode(
        self, x: torch.Tensor, layer: Optional[int] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the feature extractor and the transformer encoder.

        Args:
            x: input passed to the feature extractor, whose first convolution
                expects a single input channel.
            layer: number of transformer layers to apply; ``None`` applies all.

        Returns:
            The encoder output and the mask produced by :meth:`mask`.
        """
        x = self.feature_extractor(x)
        # The extractor output has channels in dim 1; the projection works on
        # the last dim, hence the transpose.
        x = self.feature_projection(x.transpose(1, 2))
        x, mask = self.mask(x)
        x = x + self.positional_embedding(x)
        x = self.dropout(self.norm(x))
        x = self.encoder(x, output_layer=layer)
        return x, mask

    def logits(self, x: torch.Tensor) -> torch.Tensor:
        """Return cosine similarity of ``x`` to every label embedding, divided by 0.1."""
        logits = torch.cosine_similarity(
            x.unsqueeze(2),
            self.label_embedding.weight.unsqueeze(0).unsqueeze(0),
            dim=-1,
        )
        return logits / 0.1

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Encode ``x``, project it and return ``(logits, mask)``."""
        x, mask = self.encode(x)
        x = self.proj(x)
        logits = self.logits(x)
        return logits, mask


class HubertSoft(Hubert):
    """HuBERT variant whose units are the projected encoder outputs."""

    def __init__(self):
        super().__init__()

    @torch.inference_mode()
    def units(self, wav: torch.Tensor) -> torch.Tensor:
        """Return the projected encoder output for ``wav``.

        The input is padded by ``(400 - 320) // 2`` samples on each side of its
        last dim; 400 and 320 match the receptive field and total stride of
        the convolution stack in ``FeatureExtractor``.
        """
        wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
        x, _ = self.encode(wav)
        return self.proj(x)


class HubertDiscrete(Hubert):
    """HuBERT variant whose units are k-means cluster ids of layer-7 features.

    Args:
        kmeans: fitted clustering model exposing ``predict`` on a 2-D array.
    """

    def __init__(self, kmeans):
        super().__init__(504)
        self.kmeans = kmeans

    @torch.inference_mode()
    def units(self, wav: torch.Tensor) -> torch.LongTensor:
        """Return one cluster id per encoded frame of ``wav``.

        Assumes a single utterance, i.e. a leading batch dim of size 1.
        """
        wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
        x, _ = self.encode(wav, layer=7)
        # Drop only the batch dim. A bare ``squeeze()`` would also drop the
        # frame dim when the utterance encodes to a single frame, handing
        # ``predict`` a 1-D array, which it rejects.
        x = self.kmeans.predict(x.squeeze(0).cpu().numpy())
        return torch.tensor(x, dtype=torch.long, device=wav.device)


class FeatureExtractor(nn.Module):
    """Stack of seven 1-D convolutions with GELU activations.

    The strides multiply to 320 and the combined receptive field is 400 input
    samples. Only the first convolution is followed by a normalization layer.
    """

    def __init__(self):
        super().__init__()
        self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False)
        self.norm0 = nn.GroupNorm(512, 512)
        self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False)
        self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False)
        self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False)
        self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False)
        self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False)
        self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution stack; the output has 512 channels in dim 1."""
        x = F.gelu(self.norm0(self.conv0(x)))
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = F.gelu(self.conv3(x))
        x = F.gelu(self.conv4(x))
        x = F.gelu(self.conv5(x))
        x = F.gelu(self.conv6(x))
        return x


class FeatureProjection(nn.Module):
    """Layer-normalize 512-dim features, project them to 768 dims, apply dropout."""

    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(512)
        self.projection = nn.Linear(512, 768)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(projection(norm(x)))``; the last dim of ``x`` must be 512."""
        x = self.norm(x)
        x = self.projection(x)
        x = self.dropout(x)
        return x


class PositionalConvEmbedding(nn.Module):
    """Positional embedding computed by a grouped, weight-normalized convolution."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(
            768,
            768,
            kernel_size=128,
            padding=128 // 2,
            groups=16,
        )
        # NOTE: the weight-norm wrapper determines the parameter names stored
        # in the state dict, so it has to stay in sync with the checkpoints.
        self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the embedding for ``x``, with the same frame count as ``x``."""
        x = self.conv(x.transpose(1, 2))
        # An even kernel with padding ``kernel // 2`` yields one extra frame;
        # drop it so the output length equals the input length.
        x = F.gelu(x[:, :, :-1])
        return x.transpose(1, 2)


class TransformerEncoder(nn.Module):
    """Stack of encoder layers that can stop after a chosen number of layers.

    Args:
        encoder_layer: prototype layer; it is deep-copied ``num_layers`` times.
        num_layers: number of layers in the stack.
    """

    def __init__(
        self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int
    ) -> None:
        super(TransformerEncoder, self).__init__()
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
        )
        self.num_layers = num_layers

    def forward(
        self,
        src: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        output_layer: Optional[int] = None,
    ) -> torch.Tensor:
        """Apply the first ``output_layer`` layers to ``src`` (all when ``None``).

        ``mask`` and ``src_key_padding_mask`` are forwarded unchanged to every
        layer as ``src_mask`` and ``src_key_padding_mask``.
        """
        output = src
        for layer in self.layers[:output_layer]:
            output = layer(
                output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
            )
        return output


def _compute_mask(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    device: torch.device,
    min_masks: int = 0,
) -> torch.Tensor:
    """Sample a boolean span mask of the given ``(batch, length)`` shape.

    Args:
        shape: ``(batch_size, sequence_length)`` of the mask to create.
        mask_prob: target fraction of masked positions; the number of spans is
            ``mask_prob * sequence_length / mask_length``, randomly rounded.
        mask_length: length of every masked span.
        device: device on which the mask is created.
        min_masks: lower bound on the number of spans per row.

    Returns:
        A ``torch.bool`` tensor of shape ``shape``; ``True`` marks masked
        positions. Spans in a row may overlap.

    Raises:
        ValueError: if ``mask_length`` is below 1 or above ``sequence_length``.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
        )

    # compute number of masked spans in batch
    num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random())
    num_masked_spans = max(num_masked_spans, min_masks)

    # make sure num masked indices <= sequence_length
    if num_masked_spans * mask_length > sequence_length:
        num_masked_spans = sequence_length // mask_length

    # SpecAugment mask to fill
    mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)

    # Nothing to mask: `torch.multinomial` refuses to draw zero samples, so
    # return the empty mask instead of raising.
    if num_masked_spans < 1:
        return mask

    # uniform distribution to sample from, make sure that offset samples are < sequence_length
    uniform_dist = torch.ones(
        (batch_size, sequence_length - (mask_length - 1)), device=device
    )

    # get random indices to mask
    mask_indices = torch.multinomial(uniform_dist, num_masked_spans)

    # expand masked indices to masked spans
    mask_indices = (
        mask_indices.unsqueeze(dim=-1)
        .expand((batch_size, num_masked_spans, mask_length))
        .reshape(batch_size, num_masked_spans * mask_length)
    )
    offsets = (
        torch.arange(mask_length, device=device)[None, None, :]
        .expand((batch_size, num_masked_spans, mask_length))
        .reshape(batch_size, num_masked_spans * mask_length)
    )
    mask_idxs = mask_indices + offsets

    # scatter indices to mask
    mask = mask.scatter(1, mask_idxs, True)

    return mask


def hubert_discrete(
    pretrained: bool = True,
    progress: bool = True,
) -> HubertDiscrete:
    r"""HuBERT-Discrete from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.

    The returned model is in eval mode.

    Args:
        pretrained (bool): load pretrained weights into the model
        progress (bool): show progress bar when downloading model
    """
    kmeans = kmeans100(pretrained=pretrained, progress=progress)
    hubert = HubertDiscrete(kmeans)
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            URLS["hubert-discrete"], progress=progress
        )
        # Strip the "module." prefix from the checkpoint keys if present.
        consume_prefix_in_state_dict_if_present(checkpoint, "module.")
        hubert.load_state_dict(checkpoint)
        hubert.eval()
    return hubert


def hubert_soft(
    pretrained: bool = True,
    progress: bool = True,
) -> HubertSoft:
    r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.

    The returned model is in eval mode only when ``pretrained`` is true.

    Args:
        pretrained (bool): load pretrained weights into the model
        progress (bool): show progress bar when downloading model
    """
    hubert = HubertSoft()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            URLS["hubert-soft"], progress=progress
        )
        # Strip the "module." prefix from the checkpoint keys if present.
        consume_prefix_in_state_dict_if_present(checkpoint, "module.")
        hubert.load_state_dict(checkpoint)
        hubert.eval()
    return hubert


def _kmeans(
    num_clusters: int, pretrained: bool = True, progress: bool = True
) -> KMeans:
    """Build a ``KMeans`` with ``num_clusters`` clusters, optionally pretrained.

    When ``pretrained`` is true the fitted attributes are copied from the
    downloaded checkpoint straight into the estimator's ``__dict__``, so no
    fitting takes place.
    """
    kmeans = KMeans(num_clusters)
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            URLS[f"kmeans{num_clusters}"], progress=progress
        )
        kmeans.__dict__["n_features_in_"] = checkpoint["n_features_in_"]
        kmeans.__dict__["_n_threads"] = checkpoint["_n_threads"]
        kmeans.__dict__["cluster_centers_"] = checkpoint["cluster_centers_"].numpy()
    return kmeans


def kmeans100(pretrained: bool = True, progress: bool = True) -> KMeans:
    r"""
    k-means checkpoint for HuBERT-Discrete with 100 clusters.

    Args:
        pretrained (bool): load pretrained weights into the model
        progress (bool): show progress bar when downloading model
    """
    return _kmeans(100, pretrained, progress)