# coding=utf-8
# Copyright 2022 rinna Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
import json
import os

import torch
from torchvision import transforms as T
from huggingface_hub import hf_hub_url, cached_download

from .clip import CLIPModel
from .cloob import CLOOBModel

# TODO: Fill in repo_ids
MODELS = {
    'rinna/japanese-clip-vit-b-16': {
        'repo_id': 'rinna/japanese-clip-vit-b-16',
        'model_class': CLIPModel,
    },
    'rinna/japanese-cloob-vit-b-16': {
        'repo_id': 'rinna/japanese-cloob-vit-b-16',
        'model_class': CLOOBModel,
    },
}
MODEL_CLASSES = {
    "cloob": CLOOBModel,
    "clip": CLIPModel,
}
MODEL_FILE = "pytorch_model.bin"
CONFIG_FILE = "config.json"
def available_models():
    return list(MODELS.keys())


def _convert_to_rgb(image):
    return image.convert('RGB')


def _transform(image_size):
    # CLIP-style preprocessing: resize, center-crop, convert to RGB, then
    # normalize with the original CLIP image mean/std.
    return T.Compose([
        T.Resize(image_size, interpolation=T.InterpolationMode.BILINEAR),
        T.CenterCrop(image_size),
        _convert_to_rgb,
        T.ToTensor(),
        T.Normalize(
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711),
        ),
    ])


def _download(repo_id: str, cache_dir: str):
    # Pre-fetch the config and weight files for `repo_id` into the local cache.
    config_file_url = hf_hub_url(repo_id=repo_id, filename=CONFIG_FILE)
    cached_download(config_file_url, cache_dir=cache_dir)
    model_file_url = hf_hub_url(repo_id=repo_id, filename=MODEL_FILE)
    cached_download(model_file_url, cache_dir=cache_dir)
def load(
    model_name: str,
    device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
    **kwargs
):
    """
    Args:
        model_name: model unique name or path to a pre-downloaded model
        device: device to put the loaded model on
        kwargs: kwargs for the huggingface pretrained model class
    Returns:
        (torch.nn.Module, a torchvision transform)
    """
    if model_name in MODELS:
        # Known Hub model: infer the model class from the model name.
        ModelClass = CLIPModel if 'clip' in model_name else CLOOBModel
    elif os.path.exists(model_name):
        # Local directory: read the model type from its config file.
        assert os.path.exists(os.path.join(model_name, CONFIG_FILE))
        with open(os.path.join(model_name, CONFIG_FILE), "r", encoding="utf-8") as f:
            j = json.load(f)
        ModelClass = MODEL_CLASSES[j["model_type"]]
    else:
        raise RuntimeError(f"Model {model_name} not found; available models = {available_models()}")

    model = ModelClass.from_pretrained(model_name, **kwargs)
    model = model.eval().requires_grad_(False).to(device)
    return model, _transform(model.config.vision_config.image_size)
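
# Usage sketch: loading a model and encoding an image with the returned transform.
# This assumes the package is importable as `japanese_clip`, that Pillow is installed,
# and that the model class exposes `get_image_features` (an assumption for illustration);
# "example.jpg" is a placeholder path.
#
#   import torch
#   from PIL import Image
#   import japanese_clip as ja_clip
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model, preprocess = ja_clip.load("rinna/japanese-clip-vit-b-16", device=device)
#   image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)
#   with torch.no_grad():
#       image_features = model.get_image_features(image)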