Spaces:
Running
on
Zero
Running
on
Zero
import importlib | |
from functools import partialmethod | |
from pathlib import Path | |
from torchvision.datasets.utils import download_url | |
import gdown | |
from utils.aux import ensure_annotation_class | |
def get_class(cls_path: str, *args, **kwargs):
    """Import a class by dotted path and pre-bind constructor arguments.

    Args:
        cls_path: Fully qualified dotted path, e.g. ``"package.module.Class"``.
            Everything before the last dot is imported as a module; the last
            component is looked up as an attribute of that module.
        *args: Positional arguments placed before any the caller passes when
            instantiating the returned class.
        **kwargs: Keyword arguments pre-bound to ``__init__``; keywords given
            at instantiation time override them.

    Returns:
        The imported class itself when no arguments are given. Otherwise a
        subclass of it (same ``__name__``, ``__qualname__``, ``__module__`` and
        ``__doc__``) whose ``__init__`` has the given arguments pre-bound.

    Raises:
        ValueError: If ``cls_path`` has no module part.
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute with the class name.
    """
    module_name, _, class_name = cls_path.rpartition(".")
    if not module_name:
        raise ValueError(
            f"Expected a dotted path like 'package.module.Class', got {cls_path!r}")
    module = importlib.import_module(module_name)
    base = getattr(module, class_name)

    if not args and not kwargs:
        # Nothing to bind, so the imported class can be handed back as-is.
        return base

    # Bind the arguments on a subclass instead of assigning to base.__init__:
    # patching the imported class would change it for every other user of the
    # module, and a second call for the same class would stack its positional
    # arguments on top of the ones bound by the first call.
    class _Bound(base):
        __doc__ = base.__doc__
        __init__ = partialmethod(base.__init__, *args, **kwargs)

    # Keep the original identity strings so code that reports the class name
    # sees the same values as it would for the imported class.
    _Bound.__name__ = base.__name__
    _Bound.__qualname__ = base.__qualname__
    _Bound.__module__ = base.__module__
    return _Bound
def download_ckpt(local_path: Path, global_path: str) -> str:
    """Return the path of a checkpoint, downloading it first if it is missing.

    The download source is chosen from ``global_path`` and ``local_path``:

    * a ``drive.google.com`` URL containing ``"file"`` is fetched with
      ``gdown.download`` into ``local_path``;
    * a ``drive.google.com`` URL containing ``"folder"`` is fetched with
      ``gdown.download_folder`` into the parent directory of ``local_path``;
    * otherwise, when ``local_path`` ends in ``.safetensors`` or has no file
      extension, ``global_path`` is appended to ``https://huggingface.co/``
      and fetched with ``download_url``.

    Args:
        local_path: Where the checkpoint is expected on disk.
        global_path: Remote location (Google Drive URL, or a path appended to
            ``https://huggingface.co/``).

    Returns:
        ``local_path`` as a POSIX-style string.

    Raises:
        NotImplementedError: If none of the download rules above applies.
        AssertionError: If the download finished but ``local_path`` still does
            not exist.
        Exception: Whatever ``download_url`` raised, re-raised unchanged after
            an error line is printed.
    """
    if local_path.exists():
        return local_path.as_posix()

    # exist_ok avoids a crash when something else creates the directory
    # between an existence check and the mkdir call.
    local_path.parent.mkdir(parents=True, exist_ok=True)

    on_google_drive = "drive.google.com" in global_path
    if on_google_drive and "file" in global_path:
        gdown.download(url=global_path, output=local_path.as_posix(), fuzzy=True)
    elif on_google_drive and "folder" in global_path:
        gdown.download_folder(url=global_path, output=local_path.parent.as_posix())
    elif local_path.suffix in (".safetensors", ""):
        # Look only at the file's own extension: scanning the whole path for a
        # dot rejects extensionless files that live under a directory whose
        # name contains one (e.g. "../weights/model").
        ckpt_url = f"https://huggingface.co/{global_path}"
        try:
            download_url(ckpt_url, local_path.parent.as_posix(), local_path.name)
        except Exception:
            print(
                f"Error: Failed to download model from {ckpt_url} to {local_path}")
            raise
    else:
        raise NotImplementedError(
            f"Download model file {global_path} not supported")

    # Raised explicitly (same type and message as the former assert statement)
    # so the check is not stripped when Python runs with -O.
    if not local_path.exists():
        raise AssertionError(f"Missing checkpoint {local_path}")
    return local_path.as_posix()