Spaces:
Running
Running
File size: 4,914 Bytes
d5d7329 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
from __future__ import annotations
import math
from logging import getLogger
from pathlib import Path
from typing import Any
import numpy as np
import torch
from cm_time import timer
from joblib import Parallel, delayed
from sklearn.cluster import KMeans, MiniBatchKMeans
from tqdm_joblib import tqdm_joblib
# Module-level logger used by the training functions in this file.
LOG = getLogger(__name__)
def train_cluster(
input_dir: Path | str,
n_clusters: int,
use_minibatch: bool = True,
batch_size: int = 4096,
partial_fit: bool = False,
verbose: bool = False,
) -> dict:
input_dir = Path(input_dir)
if not partial_fit:
LOG.info(f"Loading features from {input_dir}")
features = []
for path in input_dir.rglob("*.data.pt"):
with path.open("rb") as f:
features.append(
torch.load(f, weights_only=True)["content"].squeeze(0).numpy().T
)
if not features:
raise ValueError(f"No features found in {input_dir}")
features = np.concatenate(features, axis=0).astype(np.float32)
if features.shape[0] < n_clusters:
raise ValueError(
"Too few HuBERT features to cluster. Consider using a smaller number of clusters."
)
LOG.info(
f"shape: {features.shape}, size: {features.nbytes/1024**2:.2f} MB, dtype: {features.dtype}"
)
with timer() as t:
if use_minibatch:
kmeans = MiniBatchKMeans(
n_clusters=n_clusters,
verbose=verbose,
batch_size=batch_size,
max_iter=80,
n_init="auto",
).fit(features)
else:
kmeans = KMeans(
n_clusters=n_clusters, verbose=verbose, n_init="auto"
).fit(features)
LOG.info(f"Clustering took {t.elapsed:.2f} seconds")
x = {
"n_features_in_": kmeans.n_features_in_,
"_n_threads": kmeans._n_threads,
"cluster_centers_": kmeans.cluster_centers_,
}
return x
else:
# minibatch partial fit
paths = list(input_dir.rglob("*.data.pt"))
if len(paths) == 0:
raise ValueError(f"No features found in {input_dir}")
LOG.info(f"Found {len(paths)} features in {input_dir}")
n_batches = math.ceil(len(paths) / batch_size)
LOG.info(f"Splitting into {n_batches} batches")
with timer() as t:
kmeans = MiniBatchKMeans(
n_clusters=n_clusters,
verbose=verbose,
batch_size=batch_size,
max_iter=80,
n_init="auto",
)
for i in range(0, len(paths), batch_size):
LOG.info(
f"Processing batch {i//batch_size+1}/{n_batches} for speaker {input_dir.stem}"
)
features = []
for path in paths[i : i + batch_size]:
with path.open("rb") as f:
features.append(
torch.load(f, weights_only=True)["content"]
.squeeze(0)
.numpy()
.T
)
features = np.concatenate(features, axis=0).astype(np.float32)
kmeans.partial_fit(features)
LOG.info(f"Clustering took {t.elapsed:.2f} seconds")
x = {
"n_features_in_": kmeans.n_features_in_,
"_n_threads": kmeans._n_threads,
"cluster_centers_": kmeans.cluster_centers_,
}
return x
def main(
input_dir: Path | str,
output_path: Path | str,
n_clusters: int = 10000,
use_minibatch: bool = True,
batch_size: int = 4096,
partial_fit: bool = False,
verbose: bool = False,
) -> None:
input_dir = Path(input_dir)
output_path = Path(output_path)
if not (use_minibatch or not partial_fit):
raise ValueError("partial_fit requires use_minibatch")
def train_cluster_(input_path: Path, **kwargs: Any) -> tuple[str, dict]:
return input_path.stem, train_cluster(input_path, **kwargs)
with tqdm_joblib(desc="Training clusters", total=len(list(input_dir.iterdir()))):
parallel_result = Parallel(n_jobs=-1)(
delayed(train_cluster_)(
speaker_name,
n_clusters=n_clusters,
use_minibatch=use_minibatch,
batch_size=batch_size,
partial_fit=partial_fit,
verbose=verbose,
)
for speaker_name in input_dir.iterdir()
)
assert parallel_result is not None
checkpoint = dict(parallel_result)
output_path.parent.mkdir(exist_ok=True, parents=True)
with output_path.open("wb") as f:
torch.save(checkpoint, f)
|