Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import json | |
import os | |
import re | |
import subprocess | |
import warnings | |
from itertools import groupby | |
from logging import getLogger | |
from pathlib import Path | |
from typing import Any, Literal, Sequence | |
import matplotlib | |
import matplotlib.pylab as plt | |
import numpy as np | |
import requests | |
import torch | |
import torch.backends.mps | |
import torch.nn as nn | |
import torchaudio | |
from cm_time import timer | |
from numpy import ndarray | |
from tqdm import tqdm | |
from transformers import HubertModel | |
from so_vits_svc_fork.hparams import HParams | |
LOG = getLogger(__name__) | |
HUBERT_SAMPLING_RATE = 16000 | |
IS_COLAB = os.getenv("COLAB_RELEASE_TAG", False) | |
def get_optimal_device(index: int = 0) -> torch.device: | |
if torch.cuda.is_available(): | |
return torch.device(f"cuda:{index % torch.cuda.device_count()}") | |
elif torch.backends.mps.is_available(): | |
return torch.device("mps") | |
else: | |
try: | |
import torch_xla.core.xla_model as xm # noqa | |
if xm.xrt_world_size() > 0: | |
return torch.device("xla") | |
# return xm.xla_device() | |
except ImportError: | |
pass | |
return torch.device("cpu") | |
def download_file( | |
url: str, | |
filepath: Path | str, | |
chunk_size: int = 64 * 1024, | |
tqdm_cls: type = tqdm, | |
skip_if_exists: bool = False, | |
overwrite: bool = False, | |
**tqdm_kwargs: Any, | |
): | |
if skip_if_exists is True and overwrite is True: | |
raise ValueError("skip_if_exists and overwrite cannot be both True") | |
filepath = Path(filepath) | |
filepath.parent.mkdir(parents=True, exist_ok=True) | |
temppath = filepath.parent / f"{filepath.name}.download" | |
if filepath.exists(): | |
if skip_if_exists: | |
return | |
elif not overwrite: | |
filepath.unlink() | |
else: | |
raise FileExistsError(f"{filepath} already exists") | |
temppath.unlink(missing_ok=True) | |
resp = requests.get(url, stream=True) | |
total = int(resp.headers.get("content-length", 0)) | |
kwargs = dict( | |
total=total, | |
unit="iB", | |
unit_scale=True, | |
unit_divisor=1024, | |
desc=f"Downloading {filepath.name}", | |
) | |
kwargs.update(tqdm_kwargs) | |
with temppath.open("wb") as f, tqdm_cls(**kwargs) as pbar: | |
for data in resp.iter_content(chunk_size=chunk_size): | |
size = f.write(data) | |
pbar.update(size) | |
temppath.rename(filepath) | |
PRETRAINED_MODEL_URLS = { | |
"hifi-gan": [ | |
[ | |
"https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/D_0.pth", | |
"https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/G_0.pth", | |
], | |
[ | |
"https://huggingface.co/Himawari00/so-vits-svc4.0-pretrain-models/resolve/main/D_0.pth", | |
"https://huggingface.co/Himawari00/so-vits-svc4.0-pretrain-models/resolve/main/G_0.pth", | |
], | |
], | |
"contentvec": [ | |
[ | |
"https://huggingface.co/therealvul/so-vits-svc-4.0-init/resolve/main/checkpoint_best_legacy_500.pt" | |
], | |
[ | |
"https://huggingface.co/Himawari00/so-vits-svc4.0-pretrain-models/resolve/main/checkpoint_best_legacy_500.pt" | |
], | |
[ | |
"http://obs.cstcloud.cn/share/obs/sankagenkeshi/checkpoint_best_legacy_500.pt" | |
], | |
], | |
} | |
from joblib import Parallel, delayed | |
def ensure_pretrained_model( | |
folder_path: Path | str, type_: str | dict[str, str], **tqdm_kwargs: Any | |
) -> tuple[Path, ...] | None: | |
folder_path = Path(folder_path) | |
# new code | |
if not isinstance(type_, str): | |
try: | |
Parallel(n_jobs=len(type_))( | |
[ | |
delayed(download_file)( | |
url, | |
folder_path / filename, | |
position=i, | |
skip_if_exists=True, | |
**tqdm_kwargs, | |
) | |
for i, (filename, url) in enumerate(type_.items()) | |
] | |
) | |
return tuple(folder_path / filename for filename in type_.values()) | |
except Exception as e: | |
LOG.error(f"Failed to download {type_}") | |
LOG.exception(e) | |
# old code | |
models_candidates = PRETRAINED_MODEL_URLS.get(type_, None) | |
if models_candidates is None: | |
LOG.warning(f"Unknown pretrained model type: {type_}") | |
return | |
for model_urls in models_candidates: | |
paths = [folder_path / model_url.split("/")[-1] for model_url in model_urls] | |
try: | |
Parallel(n_jobs=len(paths))( | |
[ | |
delayed(download_file)( | |
url, path, position=i, skip_if_exists=True, **tqdm_kwargs | |
) | |
for i, (url, path) in enumerate(zip(model_urls, paths)) | |
] | |
) | |
return tuple(paths) | |
except Exception as e: | |
LOG.error(f"Failed to download {model_urls}") | |
LOG.exception(e) | |
class HubertModelWithFinalProj(HubertModel): | |
def __init__(self, config): | |
super().__init__(config) | |
# The final projection layer is only used for backward compatibility. | |
# Following https://github.com/auspicious3000/contentvec/issues/6 | |
# Remove this layer is necessary to achieve the desired outcome. | |
self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) | |
def remove_weight_norm_if_exists(module, name: str = "weight"): | |
r"""Removes the weight normalization reparameterization from a module. | |
Args: | |
module (Module): containing module | |
name (str, optional): name of weight parameter | |
Example: | |
>>> m = weight_norm(nn.Linear(20, 40)) | |
>>> remove_weight_norm(m) | |
""" | |
from torch.nn.utils.weight_norm import WeightNorm | |
for k, hook in module._forward_pre_hooks.items(): | |
if isinstance(hook, WeightNorm) and hook.name == name: | |
hook.remove(module) | |
del module._forward_pre_hooks[k] | |
return module | |
def get_hubert_model( | |
device: str | torch.device, final_proj: bool = True | |
) -> HubertModel: | |
if final_proj: | |
model = HubertModelWithFinalProj.from_pretrained("lengyue233/content-vec-best") | |
else: | |
model = HubertModel.from_pretrained("lengyue233/content-vec-best") | |
# Hubert is always used in inference mode, we can safely remove weight-norms | |
for m in model.modules(): | |
if isinstance(m, (nn.Conv2d, nn.Conv1d)): | |
remove_weight_norm_if_exists(m) | |
return model.to(device) | |
def get_content( | |
cmodel: HubertModel, | |
audio: torch.Tensor | ndarray[Any, Any], | |
device: torch.device | str, | |
sr: int, | |
legacy_final_proj: bool = False, | |
) -> torch.Tensor: | |
audio = torch.as_tensor(audio) | |
if sr != HUBERT_SAMPLING_RATE: | |
audio = ( | |
torchaudio.transforms.Resample(sr, HUBERT_SAMPLING_RATE) | |
.to(audio.device)(audio) | |
.to(device) | |
) | |
if audio.ndim == 1: | |
audio = audio.unsqueeze(0) | |
with torch.no_grad(), timer() as t: | |
if legacy_final_proj: | |
warnings.warn("legacy_final_proj is deprecated") | |
if not hasattr(cmodel, "final_proj"): | |
raise ValueError("HubertModel does not have final_proj") | |
c = cmodel(audio, output_hidden_states=True)["hidden_states"][9] | |
c = cmodel.final_proj(c) | |
else: | |
c = cmodel(audio)["last_hidden_state"] | |
c = c.transpose(1, 2) | |
wav_len = audio.shape[-1] / HUBERT_SAMPLING_RATE | |
LOG.info( | |
f"HuBERT inference time : {t.elapsed:.3f}s, RTF: {t.elapsed / wav_len:.3f}" | |
) | |
return c | |
def _substitute_if_same_shape(to_: dict[str, Any], from_: dict[str, Any]) -> None: | |
not_in_to = list(filter(lambda x: x not in to_, from_.keys())) | |
not_in_from = list(filter(lambda x: x not in from_, to_.keys())) | |
if not_in_to: | |
warnings.warn(f"Keys not found in model state dict:" f"{not_in_to}") | |
if not_in_from: | |
warnings.warn(f"Keys not found in checkpoint state dict:" f"{not_in_from}") | |
shape_missmatch = [] | |
for k, v in from_.items(): | |
if k not in to_: | |
pass | |
elif hasattr(v, "shape"): | |
if not hasattr(to_[k], "shape"): | |
raise ValueError(f"Key {k} is not a tensor") | |
if to_[k].shape == v.shape: | |
to_[k] = v | |
else: | |
shape_missmatch.append((k, to_[k].shape, v.shape)) | |
elif isinstance(v, dict): | |
assert isinstance(to_[k], dict) | |
_substitute_if_same_shape(to_[k], v) | |
else: | |
to_[k] = v | |
if shape_missmatch: | |
warnings.warn( | |
f"Shape mismatch: {[f'{k}: {v1} -> {v2}' for k, v1, v2 in shape_missmatch]}" | |
) | |
def safe_load(model: torch.nn.Module, state_dict: dict[str, Any]) -> None: | |
model_state_dict = model.state_dict() | |
_substitute_if_same_shape(model_state_dict, state_dict) | |
model.load_state_dict(model_state_dict) | |
def load_checkpoint( | |
checkpoint_path: Path | str, | |
model: torch.nn.Module, | |
optimizer: torch.optim.Optimizer | None = None, | |
skip_optimizer: bool = False, | |
) -> tuple[torch.nn.Module, torch.optim.Optimizer | None, float, int]: | |
if not Path(checkpoint_path).is_file(): | |
raise FileNotFoundError(f"File {checkpoint_path} not found") | |
with Path(checkpoint_path).open("rb") as f: | |
with warnings.catch_warnings(): | |
warnings.filterwarnings( | |
"ignore", category=UserWarning, message="TypedStorage is deprecated" | |
) | |
checkpoint_dict = torch.load(f, map_location="cpu", weights_only=True) | |
iteration = checkpoint_dict["iteration"] | |
learning_rate = checkpoint_dict["learning_rate"] | |
# safe load module | |
if hasattr(model, "module"): | |
safe_load(model.module, checkpoint_dict["model"]) | |
else: | |
safe_load(model, checkpoint_dict["model"]) | |
# safe load optim | |
if ( | |
optimizer is not None | |
and not skip_optimizer | |
and checkpoint_dict["optimizer"] is not None | |
): | |
with warnings.catch_warnings(): | |
warnings.simplefilter("ignore") | |
safe_load(optimizer, checkpoint_dict["optimizer"]) | |
LOG.info(f"Loaded checkpoint '{checkpoint_path}' (epoch {iteration})") | |
return model, optimizer, learning_rate, iteration | |
def save_checkpoint( | |
model: torch.nn.Module, | |
optimizer: torch.optim.Optimizer, | |
learning_rate: float, | |
iteration: int, | |
checkpoint_path: Path | str, | |
) -> None: | |
LOG.info( | |
"Saving model and optimizer state at epoch {} to {}".format( | |
iteration, checkpoint_path | |
) | |
) | |
if hasattr(model, "module"): | |
state_dict = model.module.state_dict() | |
else: | |
state_dict = model.state_dict() | |
with Path(checkpoint_path).open("wb") as f: | |
torch.save( | |
{ | |
"model": state_dict, | |
"iteration": iteration, | |
"optimizer": optimizer.state_dict(), | |
"learning_rate": learning_rate, | |
}, | |
f, | |
) | |
def clean_checkpoints( | |
path_to_models: Path | str, n_ckpts_to_keep: int = 2, sort_by_time: bool = True | |
) -> None: | |
"""Freeing up space by deleting saved ckpts | |
Arguments: | |
path_to_models -- Path to the model directory | |
n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth | |
sort_by_time -- True -> chronologically delete ckpts | |
False -> lexicographically delete ckpts | |
""" | |
LOG.info("Cleaning old checkpoints...") | |
path_to_models = Path(path_to_models) | |
# Define sort key functions | |
name_key = lambda p: int(re.match(r"[GD]_(\d+)", p.stem).group(1)) | |
time_key = lambda p: p.stat().st_mtime | |
path_key = lambda p: (p.stem[0], time_key(p) if sort_by_time else name_key(p)) | |
models = list( | |
filter( | |
lambda p: ( | |
p.is_file() | |
and re.match(r"[GD]_\d+", p.stem) | |
and not p.stem.endswith("_0") | |
), | |
path_to_models.glob("*.pth"), | |
) | |
) | |
models_sorted = sorted(models, key=path_key) | |
models_sorted_grouped = groupby(models_sorted, lambda p: p.stem[0]) | |
for group_name, group_items in models_sorted_grouped: | |
to_delete_list = list(group_items)[:-n_ckpts_to_keep] | |
for to_delete in to_delete_list: | |
if to_delete.exists(): | |
LOG.info(f"Removing {to_delete}") | |
if IS_COLAB: | |
to_delete.write_text("") | |
to_delete.unlink() | |
def latest_checkpoint_path(dir_path: Path | str, regex: str = "G_*.pth") -> Path | None: | |
dir_path = Path(dir_path) | |
name_key = lambda p: int(re.match(r"._(\d+)\.pth", p.name).group(1)) | |
paths = list(sorted(dir_path.glob(regex), key=name_key)) | |
if len(paths) == 0: | |
return None | |
return paths[-1] | |
def plot_spectrogram_to_numpy(spectrogram: ndarray) -> ndarray: | |
matplotlib.use("Agg") | |
fig, ax = plt.subplots(figsize=(10, 2)) | |
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") | |
plt.colorbar(im, ax=ax) | |
plt.xlabel("Frames") | |
plt.ylabel("Channels") | |
plt.tight_layout() | |
fig.canvas.draw() | |
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") | |
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) | |
plt.close() | |
return data | |
def get_backup_hparams( | |
config_path: Path, model_path: Path, init: bool = True | |
) -> HParams: | |
model_path.mkdir(parents=True, exist_ok=True) | |
config_save_path = model_path / "config.json" | |
if init: | |
with config_path.open() as f: | |
data = f.read() | |
with config_save_path.open("w") as f: | |
f.write(data) | |
else: | |
with config_save_path.open() as f: | |
data = f.read() | |
config = json.loads(data) | |
hparams = HParams(**config) | |
hparams.model_dir = model_path.as_posix() | |
return hparams | |
def get_hparams(config_path: Path | str) -> HParams: | |
config = json.loads(Path(config_path).read_text("utf-8")) | |
hparams = HParams(**config) | |
return hparams | |
def repeat_expand_2d(content: torch.Tensor, target_len: int) -> torch.Tensor: | |
# content : [h, t] | |
src_len = content.shape[-1] | |
if target_len < src_len: | |
return content[:, :target_len] | |
else: | |
return torch.nn.functional.interpolate( | |
content.unsqueeze(0), size=target_len, mode="nearest" | |
).squeeze(0) | |
def plot_data_to_numpy(x: ndarray, y: ndarray) -> ndarray: | |
matplotlib.use("Agg") | |
fig, ax = plt.subplots(figsize=(10, 2)) | |
plt.plot(x) | |
plt.plot(y) | |
plt.tight_layout() | |
fig.canvas.draw() | |
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") | |
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) | |
plt.close() | |
return data | |
def get_gpu_memory(type_: Literal["total", "free", "used"]) -> Sequence[int] | None: | |
command = f"nvidia-smi --query-gpu=memory.{type_} --format=csv" | |
try: | |
memory_free_info = ( | |
subprocess.check_output(command.split()) | |
.decode("ascii") | |
.split("\n")[:-1][1:] | |
) | |
memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)] | |
return memory_free_values | |
except Exception: | |
return | |
def get_total_gpu_memory(type_: Literal["total", "free", "used"]) -> int | None: | |
memories = get_gpu_memory(type_) | |
if memories is None: | |
return | |
return sum(memories) | |