Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,753 Bytes
8fd2f2f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import importlib
from functools import partialmethod
from pathlib import Path
from torchvision.datasets.utils import download_url
import gdown
from utils.aux import ensure_annotation_class
def get_class(cls_path: str, *args, **kwargs):
    """Import a class from a dotted path and optionally bind constructor defaults.

    Args:
        cls_path: Fully qualified class path, e.g. ``"package.module.ClassName"``.
        *args: Positional arguments prepended to every ``__init__`` call.
        **kwargs: Keyword arguments passed to every ``__init__`` call; keywords
            given explicitly at instantiation time take precedence.

    Returns:
        The class object itself (not an instance).

    Raises:
        ValueError: If ``cls_path`` has no module component.
        ModuleNotFoundError: If the module cannot be imported.
        AttributeError: If the module has no attribute with that name.

    NOTE: when arguments are supplied, the class's ``__init__`` is patched in
    place. The defaults therefore apply to every user of that class in the
    process. Repeated calls wrap the already-patched ``__init__`` again, so
    bindings accumulate (positional args are prepended once per call).
    """
    module_name, _, class_name = cls_path.rpartition(".")
    if not module_name:
        raise ValueError(
            f"Expected a fully qualified 'module.ClassName' path, got {cls_path!r}")
    module = importlib.import_module(module_name)
    class_ = getattr(module, class_name)
    # Only patch when there is something to bind. Wrapping with no arguments
    # is a no-op that still nests another layer on every call, and it fails
    # outright on immutable (e.g. C-implemented) types.
    if args or kwargs:
        class_.__init__ = partialmethod(class_.__init__, *args, **kwargs)
    return class_
@ensure_annotation_class
def download_ckpt(local_path: Path, global_path: str) -> str:
    """Return a local checkpoint path, downloading the checkpoint first if missing.

    The download source is chosen from ``global_path`` / ``local_path``:

    * Google Drive file URL (contains ``"drive.google.com"`` and ``"file"``):
      fetched with ``gdown.download`` directly to ``local_path``.
    * Google Drive folder URL (contains ``"drive.google.com"`` and ``"folder"``):
      fetched with ``gdown.download_folder`` into ``local_path.parent``.
    * Otherwise, when ``local_path`` is a ``.safetensors`` file or its file name
      has no dot: ``global_path`` is appended to ``https://huggingface.co/`` and
      fetched with torchvision's ``download_url``.

    Args:
        local_path: Where the checkpoint should live on disk.
        global_path: Remote location (Google Drive URL or Hugging Face path).

    Returns:
        ``local_path`` as a POSIX string.

    Raises:
        NotImplementedError: If no supported source matches.
        AssertionError: If ``local_path`` still does not exist after the download.
    """
    if local_path.exists():
        return local_path.as_posix()

    # exist_ok avoids a crash if the directory appears between a check and mkdir.
    local_path.parent.mkdir(parents=True, exist_ok=True)

    if "drive.google.com" in global_path and "file" in global_path:
        gdown.download(url=global_path, output=local_path.as_posix(), fuzzy=True)
    elif "drive.google.com" in global_path and "folder" in global_path:
        # NOTE(review): assumes the downloaded folder yields local_path; only the
        # existence check at the end verifies this.
        gdown.download_folder(url=global_path, output=local_path.parent.as_posix())
    elif local_path.suffix == ".safetensors" or "." not in local_path.name:
        # Check only the file name for a dot. Dots in parent directories
        # (e.g. "../ckpts", "~/.cache") must not disqualify extensionless files.
        ckpt_url = f"https://huggingface.co/{global_path}"
        try:
            download_url(ckpt_url, local_path.parent.as_posix(), local_path.name)
        except Exception:
            print(
                f"Error: Failed to download model from {ckpt_url} to {local_path}")
            # A failed download may leave a partial or empty file behind. Remove
            # it so the next call retries instead of returning a corrupt checkpoint.
            if local_path.exists():
                local_path.unlink()
            raise
    else:
        raise NotImplementedError(
            f"Download model file {global_path} not supported")

    # Postcondition: some download tools report failure without raising.
    assert local_path.exists(), f"Missing checkpoint {local_path}"
    return local_path.as_posix()
|