Vladimir Alabov
Refactor #3
46b0a70
raw
history blame
9.42 kB
from __future__ import annotations
from logging import getLogger
from pathlib import Path
from typing import Literal, Sequence
import librosa
import numpy as np
import soundfile
import torch
from cm_time import timer
from tqdm import tqdm
from so_vits_svc_fork.inference.core import RealtimeVC, RealtimeVC2, Svc
from so_vits_svc_fork.utils import get_optimal_device
LOG = getLogger(__name__)
def infer(
    *,
    # paths
    input_path: Path | str | Sequence[Path | str],
    output_path: Path | str | Sequence[Path | str],
    model_path: Path | str,
    config_path: Path | str,
    recursive: bool = False,
    # svc config
    speaker: int | str,
    cluster_model_path: Path | str | None = None,
    transpose: int = 0,
    auto_predict_f0: bool = False,
    cluster_infer_ratio: float = 0,
    noise_scale: float = 0.4,
    f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
    # slice config
    db_thresh: int = -40,
    pad_seconds: float = 0.5,
    chunk_seconds: float = 0.5,
    absolute_thresh: bool = False,
    max_chunk_seconds: float = 40,
    device: str | torch.device = get_optimal_device(),
):
    """Convert one or more audio files with an ``Svc`` model and write the results.

    ``input_path`` and ``output_path`` may each be a single path or a sequence
    of paths; the two are paired positionally. When an input entry is a
    directory (and ``recursive`` is true), every file below it matching
    ``*.*`` is converted and written to the same relative location below the
    paired output entry.

    Args:
        input_path: Source file(s) or directory(ies).
        output_path: Destination file(s) or directory(ies), paired with
            ``input_path`` by position.
        model_path: Passed to ``Svc`` as ``net_g_path``.
        config_path: Passed to ``Svc`` as ``config_path``.
        recursive: Must be true for directory inputs to be accepted.
        speaker, transpose, auto_predict_f0, cluster_infer_ratio, noise_scale,
        f0_method, db_thresh, pad_seconds, chunk_seconds, absolute_thresh,
        max_chunk_seconds: Forwarded unchanged to ``Svc.infer_silence``.
        cluster_model_path: Optional path passed to ``Svc``; ``None`` when
            not given.
        device: Passed to ``Svc`` as ``device``.

    Raises:
        ValueError: If the number of inputs and outputs differ, or if an
            input is a directory while ``recursive`` is false.

    Files that fail to load are logged and skipped; the remaining files are
    still processed.
    """
    if isinstance(input_path, (str, Path)):
        input_path = [input_path]
    if isinstance(output_path, (str, Path)):
        output_path = [output_path]
    if len(input_path) != len(output_path):
        raise ValueError(
            f"input_path and output_path must have same length, but got {len(input_path)} and {len(output_path)}"
        )

    model_path = Path(model_path)
    config_path = Path(config_path)
    source_roots = [Path(p) for p in input_path]
    target_roots = [Path(p) for p in output_path]

    # Expand directory entries into flat, index-aligned lists of files.
    input_paths: list[Path] = []
    output_paths: list[Path] = []
    for source, target in zip(source_roots, target_roots):
        if not source.is_dir():
            input_paths.append(source)
            output_paths.append(target)
            continue
        if not recursive:
            raise ValueError(
                f"input_path is a directory, but recursive is False: {source}"
            )
        # Map ONLY the files found under this directory. Iterating over the
        # whole accumulated input list here would duplicate earlier entries
        # (or fail in relative_to) and desynchronise the two lists.
        # "*.*" also matches directories with a dot in their name, so keep
        # regular files only.
        found = [p for p in source.rglob("*.*") if p.is_file()]
        input_paths.extend(found)
        output_paths.extend(target / p.relative_to(source) for p in found)

    cluster_model_path = Path(cluster_model_path) if cluster_model_path else None
    svc_model = Svc(
        net_g_path=model_path.as_posix(),
        config_path=config_path.as_posix(),
        cluster_model_path=cluster_model_path.as_posix()
        if cluster_model_path
        else None,
        device=device,
    )

    try:
        # The progress bar is suppressed when there is a single file.
        pbar = tqdm(list(zip(input_paths, output_paths)), disable=len(input_paths) == 1)
        for source_file, target_file in pbar:
            pbar.set_description(f"{source_file}")
            try:
                audio, _ = librosa.load(str(source_file), sr=svc_model.target_sample)
            except Exception as e:
                # Best effort: an unreadable file must not abort the batch.
                LOG.error(f"Failed to load {source_file}")
                LOG.exception(e)
                continue
            target_file.parent.mkdir(parents=True, exist_ok=True)
            audio = svc_model.infer_silence(
                audio.astype(np.float32),
                speaker=speaker,
                transpose=transpose,
                auto_predict_f0=auto_predict_f0,
                cluster_infer_ratio=cluster_infer_ratio,
                noise_scale=noise_scale,
                f0_method=f0_method,
                db_thresh=db_thresh,
                pad_seconds=pad_seconds,
                chunk_seconds=chunk_seconds,
                absolute_thresh=absolute_thresh,
                max_chunk_seconds=max_chunk_seconds,
            )
            soundfile.write(str(target_file), audio, svc_model.target_sample)
    finally:
        # Drop the model reference and release cached CUDA memory even when
        # a conversion raised.
        del svc_model
        torch.cuda.empty_cache()
def realtime(
    *,
    # paths
    model_path: Path | str,
    config_path: Path | str,
    # svc config
    speaker: str,
    cluster_model_path: Path | str | None = None,
    transpose: int = 0,
    auto_predict_f0: bool = False,
    cluster_infer_ratio: float = 0,
    noise_scale: float = 0.4,
    f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
    # slice config
    db_thresh: int = -40,
    pad_seconds: float = 0.5,
    chunk_seconds: float = 0.5,
    # realtime config
    crossfade_seconds: float = 0.05,
    additional_infer_before_seconds: float = 0.2,
    additional_infer_after_seconds: float = 0.1,
    block_seconds: float = 0.5,
    version: int = 2,
    input_device: int | str | None = None,
    output_device: int | str | None = None,
    device: str | torch.device = get_optimal_device(),
    passthrough_original: bool = False,
):
    """Run voice conversion on a live audio stream until interrupted.

    Loads an ``Svc`` model, wraps it in ``RealtimeVC`` (``version == 1``) or
    ``RealtimeVC2`` (any other value), resolves the audio devices, performs
    one warm-up inference and then opens a ``sounddevice.Stream`` whose
    callback converts each incoming block. The function blocks in a sleep
    loop while the stream is open; ``torch.cuda.empty_cache()`` is called on
    the way out.

    Args:
        model_path: Passed to ``Svc`` as ``net_g_path``.
        config_path: Passed to ``Svc`` as ``config_path``.
        speaker, transpose, auto_predict_f0, cluster_infer_ratio, noise_scale,
        f0_method: Forwarded to the warm-up ``Svc.infer`` call and to every
            ``model.process`` call.
        cluster_model_path: Optional path passed to ``Svc``.
        db_thresh, chunk_seconds: Forwarded to ``model.process``.
        pad_seconds: Forwarded to ``model.process`` only when ``version == 1``.
        crossfade_seconds, additional_infer_before_seconds,
        additional_infer_after_seconds: Converted to sample counts and passed
            to ``RealtimeVC``; unused for other versions.
        block_seconds: Determines the stream block size and is the reference
            duration for the logged real-time factor.
        version: Selects ``RealtimeVC`` (1) or ``RealtimeVC2`` (otherwise).
        input_device, output_device: Device index, exact device name, or
            ``None``. Unknown names and out-of-range indices fall back to the
            ``sounddevice`` default.
        device: Passed to ``Svc`` as ``device``.
        passthrough_original: When true, the output is the average of the
            input block and the converted block.
    """
    import sounddevice as sd

    net_g_posix = Path(model_path).as_posix()
    config_posix = Path(config_path).as_posix()
    cluster_posix = (
        Path(cluster_model_path).as_posix() if cluster_model_path else None
    )
    svc_model = Svc(
        net_g_path=net_g_posix,
        config_path=config_posix,
        cluster_model_path=cluster_posix,
        device=device,
    )
    sample_rate = svc_model.target_sample

    LOG.info("Creating realtime model...")
    if version != 1:
        model = RealtimeVC2(svc_model=svc_model)
    else:
        model = RealtimeVC(
            svc_model=svc_model,
            crossfade_len=int(crossfade_seconds * sample_rate),
            additional_infer_before_len=int(
                additional_infer_before_seconds * sample_rate
            ),
            additional_infer_after_len=int(
                additional_infer_after_seconds * sample_rate
            ),
        )

    # Log every available audio device.
    devices = sd.query_devices()
    LOG.info(f"Device: {devices}")

    def index_of(device_name):
        """Return the index of the first device with this exact name, else None."""
        return next(
            (idx for idx, info in enumerate(devices) if info["name"] == device_name),
            None,
        )

    # Device names are translated to indices; unknown names become None.
    if isinstance(input_device, str):
        found_input = index_of(input_device)
        if found_input is None:
            LOG.warning(f"Input device {input_device} not found, using default")
        input_device = found_input
    if isinstance(output_device, str):
        found_output = index_of(output_device)
        if found_output is None:
            LOG.warning(f"Output device {output_device} not found, using default")
        output_device = found_output

    # Missing or out-of-range indices fall back to the sounddevice defaults.
    device_count = len(devices)
    if input_device is None or input_device >= device_count:
        input_device = sd.default.device[0]
    if output_device is None or output_device >= device_count:
        output_device = sd.default.device[1]
    LOG.info(
        f"Input Device: {devices[input_device]['name']}, Output Device: {devices[output_device]['name']}"
    )

    # Only the very first inference is noticeably slow, so run a throwaway
    # inference on one second of silence before the stream starts.
    LOG.info("Warming up the model...")
    svc_model.infer(
        audio=np.zeros(sample_rate, dtype=np.float32),
        speaker=speaker,
        transpose=transpose,
        auto_predict_f0=auto_predict_f0,
        cluster_infer_ratio=cluster_infer_ratio,
        noise_scale=noise_scale,
        f0_method=f0_method,
    )

    def callback(
        indata: np.ndarray,
        outdata: np.ndarray,
        frames: int,
        time: int,
        status: sd.CallbackFlags,
    ) -> None:
        """Convert one input block and write the result into ``outdata``."""
        LOG.debug(
            f"Frames: {frames}, Status: {status}, Shape: {indata.shape}, Time: {time}"
        )
        process_args = {
            # Average across axis 1 to get a 1-D signal for the model.
            "input_audio": indata.mean(axis=1).astype(np.float32),
            # svc config
            "speaker": speaker,
            "transpose": transpose,
            "auto_predict_f0": auto_predict_f0,
            "cluster_infer_ratio": cluster_infer_ratio,
            "noise_scale": noise_scale,
            "f0_method": f0_method,
            # slice config
            "db_thresh": db_thresh,
            "chunk_seconds": chunk_seconds,
        }
        # pad_seconds is only passed for version 1.
        if version == 1:
            process_args["pad_seconds"] = pad_seconds
        with timer() as clock:
            converted = model.process(**process_args)
            inference = converted.reshape(-1, 1)
        outdata[:] = (indata + inference) / 2 if passthrough_original else inference
        rtf = clock.elapsed / block_seconds
        LOG.info(f"Realtime inference time: {clock.elapsed:.3f}s, RTF: {rtf:.3f}")
        if rtf > 1:
            LOG.warning("RTF is too high, consider increasing block_seconds")

    stream_options = dict(
        device=(input_device, output_device),
        channels=1,
        callback=callback,
        samplerate=sample_rate,
        blocksize=int(block_seconds * sample_rate),
        latency="low",
    )
    try:
        with sd.Stream(**stream_options) as stream:
            LOG.info(f"Latency: {stream.latency}")
            # Keep this thread alive; audio is handled by the stream callback.
            while True:
                sd.sleep(1000)
    finally:
        torch.cuda.empty_cache()