Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
from logging import getLogger | |
from pathlib import Path | |
from typing import Literal, Sequence | |
import librosa | |
import numpy as np | |
import soundfile | |
import torch | |
from cm_time import timer | |
from tqdm import tqdm | |
from so_vits_svc_fork.inference.core import RealtimeVC, RealtimeVC2, Svc | |
from so_vits_svc_fork.utils import get_optimal_device | |
LOG = getLogger(__name__) | |
def infer(
    *,
    # paths
    input_path: Path | str | Sequence[Path | str],
    output_path: Path | str | Sequence[Path | str],
    model_path: Path | str,
    config_path: Path | str,
    recursive: bool = False,
    # svc config
    speaker: int | str,
    cluster_model_path: Path | str | None = None,
    transpose: int = 0,
    auto_predict_f0: bool = False,
    cluster_infer_ratio: float = 0,
    noise_scale: float = 0.4,
    f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
    # slice config
    db_thresh: int = -40,
    pad_seconds: float = 0.5,
    chunk_seconds: float = 0.5,
    absolute_thresh: bool = False,
    max_chunk_seconds: float = 40,
    device: str | torch.device = get_optimal_device(),
) -> None:
    """Convert one or more audio files with an ``Svc`` model and write the results.

    ``input_path`` and ``output_path`` are paired element by element. When an
    input entry is a directory (and ``recursive`` is true) every file below it
    that matches ``*.*`` is converted and written to the same relative location
    under the paired output entry.

    Args:
        input_path: A single path or a sequence of paths (files or directories).
        output_path: A single path or a sequence of paths, one per input entry.
        model_path: Passed to ``Svc`` as ``net_g_path``.
        config_path: Passed to ``Svc`` as ``config_path``.
        recursive: Must be true for directory inputs to be accepted.
        speaker: Forwarded to ``Svc.infer_silence``.
        cluster_model_path: Optional path passed to ``Svc``; ``None`` (or any
            falsy value) disables it.
        transpose: Forwarded to ``Svc.infer_silence``.
        auto_predict_f0: Forwarded to ``Svc.infer_silence``.
        cluster_infer_ratio: Forwarded to ``Svc.infer_silence``.
        noise_scale: Forwarded to ``Svc.infer_silence``.
        f0_method: Forwarded to ``Svc.infer_silence``.
        db_thresh: Forwarded to ``Svc.infer_silence``.
        pad_seconds: Forwarded to ``Svc.infer_silence``.
        chunk_seconds: Forwarded to ``Svc.infer_silence``.
        absolute_thresh: Forwarded to ``Svc.infer_silence``.
        max_chunk_seconds: Forwarded to ``Svc.infer_silence``.
        device: Passed to ``Svc`` as ``device``.

    Raises:
        ValueError: If the number of inputs and outputs differ, or if an input
            is a directory while ``recursive`` is false.
    """
    if isinstance(input_path, (str, Path)):
        input_path = [input_path]
    if isinstance(output_path, (str, Path)):
        output_path = [output_path]
    if len(input_path) != len(output_path):
        raise ValueError(
            f"input_path and output_path must have same length, but got {len(input_path)} and {len(output_path)}"
        )

    model_path = Path(model_path)
    config_path = Path(config_path)
    cluster_model_path = Path(cluster_model_path) if cluster_model_path else None
    source_entries = [Path(p) for p in input_path]
    target_entries = [Path(p) for p in output_path]

    # Expand the (input, output) pairs into flat, index-aligned file lists.
    input_paths: list[Path] = []
    output_paths: list[Path] = []
    for source, target in zip(source_entries, target_entries):
        if not source.is_dir():
            input_paths.append(source)
            output_paths.append(target)
            continue
        if not recursive:
            raise ValueError(
                f"input_path is a directory, but recursive is False: {source}"
            )
        # Only the files found under *this* directory may be mapped relative to
        # it; iterating the accumulated list would call relative_to() on paths
        # from earlier entries and either raise or duplicate outputs.
        # is_file() drops directories whose names happen to contain a dot.
        found = [p for p in source.rglob("*.*") if p.is_file()]
        input_paths.extend(found)
        output_paths.extend(target / p.relative_to(source) for p in found)

    svc_model = Svc(
        net_g_path=model_path.as_posix(),
        config_path=config_path.as_posix(),
        cluster_model_path=cluster_model_path.as_posix()
        if cluster_model_path
        else None,
        device=device,
    )

    try:
        # The progress bar is pointless for a single file, so it is disabled then.
        pbar = tqdm(list(zip(input_paths, output_paths)), disable=len(input_paths) == 1)
        for source_file, target_file in pbar:
            pbar.set_description(f"{source_file}")
            try:
                audio, _ = librosa.load(str(source_file), sr=svc_model.target_sample)
            except Exception as e:
                # Best effort: an unreadable file is logged and skipped so the
                # remaining files are still processed.
                LOG.error(f"Failed to load {source_file}")
                LOG.exception(e)
                continue
            target_file.parent.mkdir(parents=True, exist_ok=True)
            audio = svc_model.infer_silence(
                audio.astype(np.float32),
                speaker=speaker,
                transpose=transpose,
                auto_predict_f0=auto_predict_f0,
                cluster_infer_ratio=cluster_infer_ratio,
                noise_scale=noise_scale,
                f0_method=f0_method,
                db_thresh=db_thresh,
                pad_seconds=pad_seconds,
                chunk_seconds=chunk_seconds,
                absolute_thresh=absolute_thresh,
                max_chunk_seconds=max_chunk_seconds,
            )
            soundfile.write(str(target_file), audio, svc_model.target_sample)
    finally:
        # Drop the model reference before asking torch to release cached memory.
        del svc_model
        torch.cuda.empty_cache()
def realtime(
    *,
    # paths
    model_path: Path | str,
    config_path: Path | str,
    # svc config
    speaker: str,
    cluster_model_path: Path | str | None = None,
    transpose: int = 0,
    auto_predict_f0: bool = False,
    cluster_infer_ratio: float = 0,
    noise_scale: float = 0.4,
    f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
    # slice config
    db_thresh: int = -40,
    pad_seconds: float = 0.5,
    chunk_seconds: float = 0.5,
    # realtime config
    crossfade_seconds: float = 0.05,
    additional_infer_before_seconds: float = 0.2,
    additional_infer_after_seconds: float = 0.1,
    block_seconds: float = 0.5,
    version: int = 2,
    input_device: int | str | None = None,
    output_device: int | str | None = None,
    device: str | torch.device = get_optimal_device(),
    passthrough_original: bool = False,
):
    """Stream audio from an input device through the model to an output device.

    Builds an ``Svc`` model, wraps it in ``RealtimeVC`` (``version == 1``) or
    ``RealtimeVC2`` (any other value), runs one warm-up inference, then opens a
    ``sounddevice`` stream whose callback converts each block. The function
    loops forever once the stream is open; it only leaves through an exception,
    after which ``torch.cuda.empty_cache()`` is called.

    Args:
        model_path: Passed to ``Svc`` as ``net_g_path``.
        config_path: Passed to ``Svc`` as ``config_path``.
        speaker: Forwarded to the warm-up inference and to ``model.process``.
        cluster_model_path: Optional path passed to ``Svc``.
        transpose: Forwarded to inference.
        auto_predict_f0: Forwarded to inference.
        cluster_infer_ratio: Forwarded to inference.
        noise_scale: Forwarded to inference.
        f0_method: Forwarded to inference.
        db_thresh: Forwarded to ``model.process``.
        pad_seconds: Forwarded to ``model.process`` only when ``version == 1``.
        chunk_seconds: Forwarded to ``model.process``.
        crossfade_seconds: Converted to samples for ``RealtimeVC`` (version 1).
        additional_infer_before_seconds: Converted to samples for ``RealtimeVC``
            (version 1).
        additional_infer_after_seconds: Converted to samples for ``RealtimeVC``
            (version 1).
        block_seconds: Stream block length in seconds; also the denominator of
            the logged real-time factor.
        version: ``1`` selects ``RealtimeVC``; anything else ``RealtimeVC2``.
        input_device: Device index, device name, or ``None`` for the default.
        output_device: Device index, device name, or ``None`` for the default.
        device: Passed to ``Svc`` as ``device``.
        passthrough_original: When true, the output is the average of the
            original input block and the converted block.
    """
    import sounddevice as sd

    net_g_posix = Path(model_path).as_posix()
    config_posix = Path(config_path).as_posix()
    cluster_posix = (
        Path(cluster_model_path).as_posix() if cluster_model_path else None
    )
    svc_model = Svc(
        net_g_path=net_g_posix,
        config_path=config_posix,
        cluster_model_path=cluster_posix,
        device=device,
    )

    LOG.info("Creating realtime model...")
    if version != 1:
        model = RealtimeVC2(svc_model=svc_model)
    else:
        model = RealtimeVC(
            svc_model=svc_model,
            crossfade_len=int(crossfade_seconds * svc_model.target_sample),
            additional_infer_before_len=int(
                additional_infer_before_seconds * svc_model.target_sample
            ),
            additional_infer_after_len=int(
                additional_infer_after_seconds * svc_model.target_sample
            ),
        )

    # Log every device sounddevice knows about.
    devices = sd.query_devices()
    LOG.info(f"Device: {devices}")

    # Device names are resolved to the index of the first device with that
    # exact name; an unknown name falls back to the default device.
    if isinstance(input_device, str):
        input_index = next(
            (idx for idx, info in enumerate(devices) if info["name"] == input_device),
            None,
        )
        if input_index is None:
            LOG.warning(f"Input device {input_device} not found, using default")
        input_device = input_index
    if isinstance(output_device, str):
        output_index = next(
            (idx for idx, info in enumerate(devices) if info["name"] == output_device),
            None,
        )
        if output_index is None:
            LOG.warning(f"Output device {output_device} not found, using default")
        output_device = output_index

    # Missing or out-of-range indices are replaced by sounddevice's defaults.
    device_count = len(devices)
    if input_device is None or input_device >= device_count:
        input_device = sd.default.device[0]
    if output_device is None or output_device >= device_count:
        output_device = sd.default.device[1]
    LOG.info(
        f"Input Device: {devices[input_device]['name']}, Output Device: {devices[output_device]['name']}"
    )

    # The first inference is noticeably slower than later ones, so run a dummy
    # inference on one second of silence before the stream starts.
    LOG.info("Warming up the model...")
    one_second_of_silence = np.zeros(svc_model.target_sample, dtype=np.float32)
    svc_model.infer(
        speaker=speaker,
        transpose=transpose,
        auto_predict_f0=auto_predict_f0,
        cluster_infer_ratio=cluster_infer_ratio,
        noise_scale=noise_scale,
        f0_method=f0_method,
        audio=one_second_of_silence,
    )

    def process_block(
        indata: np.ndarray,
        outdata: np.ndarray,
        frames: int,
        time: int,
        status: sd.CallbackFlags,
    ) -> None:
        """Convert one input block and write the result into ``outdata``."""
        LOG.debug(
            f"Frames: {frames}, Status: {status}, Shape: {indata.shape}, Time: {time}"
        )
        process_kwargs = {
            # average over axis 1 to get a 1-D float32 signal
            "input_audio": indata.mean(axis=1).astype(np.float32),
            # svc config
            "speaker": speaker,
            "transpose": transpose,
            "auto_predict_f0": auto_predict_f0,
            "cluster_infer_ratio": cluster_infer_ratio,
            "noise_scale": noise_scale,
            "f0_method": f0_method,
            # slice config
            "db_thresh": db_thresh,
            "chunk_seconds": chunk_seconds,
        }
        # pad_seconds is only handed to the version 1 model.
        if version == 1:
            process_kwargs["pad_seconds"] = pad_seconds

        with timer() as t:
            converted = model.process(**process_kwargs).reshape(-1, 1)

        outdata[:] = (indata + converted) / 2 if passthrough_original else converted

        # Real-time factor: processing time relative to the block duration.
        rtf = t.elapsed / block_seconds
        LOG.info(f"Realtime inference time: {t.elapsed:.3f}s, RTF: {rtf:.3f}")
        if rtf > 1:
            LOG.warning("RTF is too high, consider increasing block_seconds")

    try:
        stream_options = dict(
            device=(input_device, output_device),
            channels=1,
            callback=process_block,
            samplerate=svc_model.target_sample,
            blocksize=int(block_seconds * svc_model.target_sample),
            latency="low",
        )
        with sd.Stream(**stream_options) as stream:
            LOG.info(f"Latency: {stream.latency}")
            # Keep the calling thread alive while the stream is open.
            while True:
                sd.sleep(1000)
    finally:
        torch.cuda.empty_cache()