from __future__ import annotations

import math
from logging import getLogger
from pathlib import Path
from typing import Any

import numpy as np
import torch
from cm_time import timer
from joblib import Parallel, delayed
from sklearn.cluster import KMeans, MiniBatchKMeans
from tqdm_joblib import tqdm_joblib

LOG = getLogger(__name__)


def train_cluster(
    input_dir: Path | str,
    n_clusters: int,
    use_minibatch: bool = True,
    batch_size: int = 4096,
    partial_fit: bool = False,
    verbose: bool = False,
) -> dict:
    """Fit a k-means model on the HuBERT content features found under ``input_dir``.

    Returns the fitted attributes needed to rebuild the model later
    (``n_features_in_``, ``_n_threads``, ``cluster_centers_``).
    """
    input_dir = Path(input_dir)
    if not partial_fit:
        # Full fit: load every feature file into memory, then fit once.
        LOG.info(f"Loading features from {input_dir}")
        features = []
        for path in input_dir.rglob("*.data.pt"):
            with path.open("rb") as f:
                # Drop the leading batch dimension and transpose so each row is one frame.
                features.append(
                    torch.load(f, weights_only=True)["content"].squeeze(0).numpy().T
                )
        if not features:
            raise ValueError(f"No features found in {input_dir}")
        features = np.concatenate(features, axis=0).astype(np.float32)
        if features.shape[0] < n_clusters:
            raise ValueError(
                "Too few HuBERT features to cluster. "
                "Consider using a smaller number of clusters."
            )
        LOG.info(
            f"shape: {features.shape}, size: {features.nbytes / 1024**2:.2f} MB, "
            f"dtype: {features.dtype}"
        )
        with timer() as t:
            if use_minibatch:
                kmeans = MiniBatchKMeans(
                    n_clusters=n_clusters,
                    verbose=verbose,
                    batch_size=batch_size,
                    max_iter=80,
                    n_init="auto",
                ).fit(features)
            else:
                kmeans = KMeans(
                    n_clusters=n_clusters, verbose=verbose, n_init="auto"
                ).fit(features)
        LOG.info(f"Clustering took {t.elapsed:.2f} seconds")
        return {
            "n_features_in_": kmeans.n_features_in_,
            "_n_threads": kmeans._n_threads,
            "cluster_centers_": kmeans.cluster_centers_,
        }
    else:
        # Partial fit: stream the feature files in batches through MiniBatchKMeans
        # instead of loading the whole speaker into memory at once.
        paths = list(input_dir.rglob("*.data.pt"))
        if len(paths) == 0:
            raise ValueError(f"No features found in {input_dir}")
        LOG.info(f"Found {len(paths)} features in {input_dir}")
        n_batches = math.ceil(len(paths) / batch_size)
        LOG.info(f"Splitting into {n_batches} batches")
        with timer() as t:
            kmeans = MiniBatchKMeans(
                n_clusters=n_clusters,
                verbose=verbose,
                batch_size=batch_size,
                max_iter=80,
                n_init="auto",
            )
            for i in range(0, len(paths), batch_size):
                LOG.info(
                    f"Processing batch {i // batch_size + 1}/{n_batches} "
                    f"for speaker {input_dir.stem}"
                )
                features = []
                for path in paths[i : i + batch_size]:
                    with path.open("rb") as f:
                        features.append(
                            torch.load(f, weights_only=True)["content"]
                            .squeeze(0)
                            .numpy()
                            .T
                        )
                features = np.concatenate(features, axis=0).astype(np.float32)
                kmeans.partial_fit(features)
        LOG.info(f"Clustering took {t.elapsed:.2f} seconds")
        return {
            "n_features_in_": kmeans.n_features_in_,
            "_n_threads": kmeans._n_threads,
            "cluster_centers_": kmeans.cluster_centers_,
        }


def main(
    input_dir: Path | str,
    output_path: Path | str,
    n_clusters: int = 10000,
    use_minibatch: bool = True,
    batch_size: int = 4096,
    partial_fit: bool = False,
    verbose: bool = False,
) -> None:
    """Train one cluster model per speaker subdirectory and save them all to one checkpoint."""
    input_dir = Path(input_dir)
    output_path = Path(output_path)

    # partial_fit only makes sense with MiniBatchKMeans
    if partial_fit and not use_minibatch:
        raise ValueError("partial_fit requires use_minibatch")

    def train_cluster_(input_path: Path, **kwargs: Any) -> tuple[str, dict]:
        return input_path.stem, train_cluster(input_path, **kwargs)

    # Train each speaker's cluster model in parallel; one entry per speaker directory.
    with tqdm_joblib(desc="Training clusters", total=len(list(input_dir.iterdir()))):
        parallel_result = Parallel(n_jobs=-1)(
            delayed(train_cluster_)(
                speaker_name,
                n_clusters=n_clusters,
                use_minibatch=use_minibatch,
                batch_size=batch_size,
                partial_fit=partial_fit,
                verbose=verbose,
            )
            for speaker_name in input_dir.iterdir()
        )
    assert parallel_result is not None
    checkpoint = dict(parallel_result)
    output_path.parent.mkdir(exist_ok=True, parents=True)
    with output_path.open("wb") as f:
        torch.save(checkpoint, f)