Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
from pathlib import Path | |
from typing import Any | |
import torch | |
from sklearn.cluster import KMeans | |
def get_cluster_model(ckpt_path: Path | str): | |
with Path(ckpt_path).open("rb") as f: | |
checkpoint = torch.load( | |
f, map_location="cpu" | |
) # Danger of arbitrary code execution | |
kmeans_dict = {} | |
for spk, ckpt in checkpoint.items(): | |
km = KMeans(ckpt["n_features_in_"]) | |
km.__dict__["n_features_in_"] = ckpt["n_features_in_"] | |
km.__dict__["_n_threads"] = ckpt["_n_threads"] | |
km.__dict__["cluster_centers_"] = ckpt["cluster_centers_"] | |
kmeans_dict[spk] = km | |
return kmeans_dict | |
def check_speaker(model: Any, speaker: Any): | |
if speaker not in model: | |
raise ValueError(f"Speaker {speaker} not in {list(model.keys())}") | |
def get_cluster_result(model: Any, x: Any, speaker: Any): | |
""" | |
x: np.array [t, 256] | |
return cluster class result | |
""" | |
check_speaker(model, speaker) | |
return model[speaker].predict(x) | |
def get_cluster_center_result(model: Any, x: Any, speaker: Any): | |
"""x: np.array [t, 256]""" | |
check_speaker(model, speaker) | |
predict = model[speaker].predict(x) | |
return model[speaker].cluster_centers_[predict] | |
def get_center(model: Any, x: Any, speaker: Any): | |
check_speaker(model, speaker) | |
return model[speaker].cluster_centers_[x] | |