vampnet-choir / vampnet /__init__.py
hugo flores garcia
recovering from a gittastrophe
41b9d24
raw
history blame
2.59 kB
from . import modules
from pathlib import Path
from . import scheduler
from .interface import Interface
from .modules.transformer import VampNet
__version__ = "0.0.1"
# Parent of the directory containing this file (presumably the repository
# root -- confirm if the package is installed elsewhere).
ROOT = Path(__file__).parent.parent
# Local directory that the download_* helpers below check for existing
# checkpoints and pass to hf_hub_download as ``local_dir``.
MODELS_DIR = ROOT / "models" / "vampnet"
from huggingface_hub import hf_hub_download, HfFileSystem
# Hugging Face Hub model repo that the codec, default, and fine-tuned (LoRA)
# checkpoints are downloaded from.
DEFAULT_HF_MODEL_REPO = "hugggof/vampnet"
# Hub filesystem client shared by list_finetuned. Note: it is constructed at
# import time of this package.
FS = HfFileSystem()
def download_codec():
    """Return a local path to the codec checkpoint (``codec.pth``).

    A copy already present at ``MODELS_DIR/codec.pth`` is reused. Otherwise
    the file is fetched from ``DEFAULT_HF_MODEL_REPO`` with
    ``hf_hub_download`` into ``MODELS_DIR``. The checkpoint is only located
    here, not loaded.

    Returns:
        The checkpoint path as a string.
    """
    filename = "codec.pth"
    # Same reuse-before-download check as download_default: once the file
    # exists locally, no Hub request is made.
    local_path = f"{MODELS_DIR}/{filename}"
    if Path(local_path).exists():
        return local_path
    return hf_hub_download(
        repo_id=DEFAULT_HF_MODEL_REPO,
        filename=filename,
        subfolder=None,
        local_dir=MODELS_DIR,
    )
def download_default():
    """Locate (downloading if needed) the default coarse and c2f checkpoints.

    Each of ``coarse.pth`` and ``c2f.pth`` is looked up directly under
    ``MODELS_DIR``. A file that is missing is fetched from
    ``DEFAULT_HF_MODEL_REPO`` via ``hf_hub_download`` into ``MODELS_DIR``.
    Nothing is loaded here; only paths are returned.

    Returns:
        A ``(coarse_path, c2f_path)`` tuple of path strings.
    """
    def _resolve(ckpt_name):
        # Prefer an existing local copy; fall back to the Hub.
        local_copy = f"{MODELS_DIR}/{ckpt_name}"
        if Path(local_copy).exists():
            return local_copy
        return hf_hub_download(
            repo_id=DEFAULT_HF_MODEL_REPO,
            filename=ckpt_name,
            subfolder=None,
            local_dir=MODELS_DIR,
            local_dir_use_symlinks=False,
            local_files_only=False,
        )

    # Tuple elements are evaluated left to right: coarse first, then c2f.
    return _resolve("coarse.pth"), _resolve("c2f.pth")
def download_finetuned(name):
    """Locate (downloading if needed) the LoRA checkpoints of a fine-tuned model.

    The files ``coarse.pth`` and ``c2f.pth`` are read from the
    ``loras/<name>/`` subfolder of ``DEFAULT_HF_MODEL_REPO``. A local copy
    is reused when present. Otherwise the file is fetched with
    ``hf_hub_download`` into ``MODELS_DIR``. Nothing is loaded here.

    Args:
        name: Name of the fine-tuned model, i.e. a directory under ``loras/``
            in the Hub repo (as returned by ``list_finetuned``).

    Returns:
        A ``(coarse_path, c2f_path)`` tuple of path strings.
    """
    repo_id = DEFAULT_HF_MODEL_REPO
    subfolder = f"loras/{name}"
    paths = []
    for filename in ("coarse.pth", "c2f.pth"):
        candidates = (
            # Location checked by earlier versions of this function; kept
            # first so files already placed there continue to be used.
            f"{MODELS_DIR}/{name}/loras/{filename}",
            # Where hf_hub_download(subfolder=..., local_dir=MODELS_DIR)
            # replicates the repo layout. Checking it lets a previous
            # download be reused instead of contacting the Hub again.
            f"{MODELS_DIR}/{subfolder}/{filename}",
        )
        path = next((c for c in candidates if Path(c).exists()), None)
        if path is None:
            path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                subfolder=subfolder,
                local_dir=MODELS_DIR,
                local_dir_use_symlinks=False,
                local_files_only=False,
            )
        paths.append(path)
    return paths[0], paths[1]
def list_finetuned():
    """List the fine-tuned models available in the Hub repo.

    Every entry under ``<DEFAULT_HF_MODEL_REPO>/loras`` is listed, and its
    name is kept only when that entry contains both ``c2f.pth`` and
    ``coarse.pth``. The returned names are the values ``download_finetuned``
    expects as ``name``.

    Returns:
        A list of entry names (final path component), in listing order.
    """
    required_files = {"c2f.pth", "coarse.pth"}
    model_names = []
    for entry in FS.listdir(f"{DEFAULT_HF_MODEL_REPO}/loras"):
        entry_path = entry["name"]
        present = {child["name"].split("/")[-1] for child in FS.listdir(entry_path)}
        # Only entries holding both checkpoints count as usable models.
        if required_files <= present:
            model_names.append(entry_path.split("/")[-1])
    return model_names