# bark-voice-cloning / hubert / pre_kmeans_hubert.py
# Mylo — "Initial commit" (9449f27)
# [file-viewer residue: raw | history | blame | 2.2 kB]
from pathlib import Path
import torch
from torch import nn
from einops import pack, unpack
import fairseq
from torchaudio.functional import resample
import logging
# Raise the root logger's threshold to ERROR at import time. This affects every
# logger in the process that inherits its level from root, not just this module.
# NOTE(review): presumably meant to silence fairseq's INFO/WARNING output while
# loading checkpoints — a fairseq-specific logger would be less invasive; confirm.
logging.getLogger().setLevel(logging.ERROR)
def exists(val):
    """Report whether *val* holds a value, i.e. is anything other than ``None``."""
    if val is None:
        return False
    return True


def default(val, d):
    """Return *val*, falling back to *d* only when *val* is ``None``.

    Falsy-but-present values such as ``0`` or ``""`` are returned unchanged.
    """
    if not exists(val):
        return d
    return val
class CustomHubert(nn.Module):
    """
    HuBERT feature extractor that returns continuous hidden states.

    The checkpoint and k-means model can be downloaded at
    https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
    or you can train your own. This "pre k-means" variant does NOT apply a
    k-means quantiser: ``forward`` returns the model's hidden states taken from
    ``output_layer``.
    """

    def __init__(
        self,
        checkpoint_path,
        target_sample_hz=16000,
        seq_len_multiple_of=None,
        output_layer=9
    ):
        """
        Load a fairseq HuBERT checkpoint and switch the model to eval mode.

        Args:
            checkpoint_path: path (``str`` or ``pathlib.Path``) to a fairseq
                HuBERT checkpoint file.
            target_sample_hz: sample rate that ``forward`` resamples audio to
                when it is given ``input_sample_hz``.
            seq_len_multiple_of: stored on the instance; not used anywhere in
                this class.
            output_layer: forwarded to the fairseq model as ``output_layer``,
                selecting which layer's hidden states are returned.

        Raises:
            FileNotFoundError: if ``checkpoint_path`` does not exist.
        """
        super().__init__()
        self.target_sample_hz = target_sample_hz
        self.seq_len_multiple_of = seq_len_multiple_of
        self.output_layer = output_layer

        model_path = Path(checkpoint_path)
        # Explicit raise instead of `assert`, which disappears under `python -O`.
        if not model_path.exists():
            raise FileNotFoundError(f'path {checkpoint_path} does not exist')

        # fairseq reads the checkpoint from the path itself, so it is not
        # pre-loaded with torch.load here (that copy was never used and only
        # doubled load time and memory). fairseq does string operations on the
        # filename, hence str() so that Path arguments work too.
        model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task([str(model_path)])
        self.model = model[0]
        self.model.eval()

    @property
    def groups(self):
        """Always 1.

        NOTE(review): presumably kept for interface compatibility with callers
        that expect a ``groups`` attribute on tokenizer-like modules — confirm.
        """
        return 1

    @torch.no_grad()
    def forward(
        self,
        wav_input,
        flatten=True,
        input_sample_hz=None
    ):
        """
        Extract hidden states from layer ``self.output_layer`` for a waveform.

        Args:
            wav_input: audio tensor passed straight to the fairseq model;
                presumably (batch, samples) — TODO confirm against callers.
            flatten: if True, merge all leading dims of the features into one,
                returning shape (N, d). If False, restore the leading dims the
                model produced.
            input_sample_hz: sample rate of ``wav_input``. When given, the audio
                is resampled to ``self.target_sample_hz`` first.

        Returns:
            Tensor of hidden states on the same device as ``wav_input``.
        """
        device = wav_input.device

        if exists(input_sample_hz):
            wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)

        embed = self.model(
            wav_input,
            features_only=True,
            mask=False,  # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code
            output_layer=self.output_layer
        )

        # Merge every leading dim so the features are (N, d); packed_shape
        # remembers the merged dims so they can be restored below.
        embed, packed_shape = pack([embed['x']], '* d')

        # No k-means step in this variant: the "codebook indices" are the raw
        # features. Stay in torch instead of round-tripping through NumPy, which
        # forced a device->host->device copy and failed for dtypes NumPy lacks
        # (e.g. bfloat16).
        codebook_indices = embed.detach().to(device)

        if flatten:
            return codebook_indices

        # The tensor is (N, d), so the pattern must keep the trailing feature
        # axis; a bare '*' expects a 1-D tensor and makes einops raise.
        codebook_indices, = unpack(codebook_indices, packed_shape, '* d')
        return codebook_indices