narugo1992
dev(narugo): add new models
32ef351
raw
history blame
1.26 kB
import json
import os
from functools import lru_cache
from typing import Mapping, List
from huggingface_hub import HfFileSystem
from huggingface_hub import hf_hub_download
from imgutils.data import ImageTyping, load_image
from natsort import natsorted
from onnx_ import _open_onnx_model
from preprocess import _img_encode
# Filesystem-style client used to list files stored on the Hugging Face Hub.
hfs = HfFileSystem()

# Hub repository hosting the rating models, one sub-directory per model.
_REPO = 'deepghs/anime_rating'

# Model names discovered at import time: every directory directly under the
# repo that contains a `model.onnx`, naturally sorted (so "v2" precedes "v10").
# NOTE: this queries the Hub when the module is imported.
_RATING_MODELS = natsorted(
    os.path.dirname(os.path.relpath(onnx_path, _REPO))
    for onnx_path in hfs.glob(f'{_REPO}/*/model.onnx')
)

# Default model name; presumably one of _RATING_MODELS — TODO confirm.
_DEFAULT_RATING_MODEL = 'mobilenetv3_sce_dist'
@lru_cache()
def _open_anime_rating_model(model_name):
    """Open the ONNX rating model stored under ``model_name`` in ``_REPO``.

    The file ``<model_name>/model.onnx`` is fetched with ``hf_hub_download``
    and the resulting local path is handed to ``_open_onnx_model``. Results
    are memoized per ``model_name`` by ``lru_cache``, so repeated calls reuse
    the already-opened model while it remains in the cache.
    """
    onnx_file = f'{model_name}/model.onnx'
    local_path = hf_hub_download(_REPO, onnx_file)
    return _open_onnx_model(local_path)
@lru_cache()
def _get_tags(model_name) -> List[str]:
    """Return the label names declared for the model ``model_name``.

    Downloads ``<model_name>/meta.json`` from ``_REPO`` and returns its
    ``labels`` entry (presumably in the same order as the model's output
    values — confirm against the model export).

    The result is memoized per ``model_name``; every caller receives the same
    cached object, so it must not be mutated.
    """
    meta_path = hf_hub_download(_REPO, f'{model_name}/meta.json')
    # JSON text is UTF-8 by specification; pin the encoding so decoding does
    # not depend on the platform's locale default (e.g. cp1252 on Windows).
    with open(meta_path, 'r', encoding='utf-8') as f:
        return json.load(f)['labels']
def _gr_rating(image: ImageTyping, model_name: str, size=384) -> Mapping[str, float]:
    """Score ``image`` with the rating model ``model_name``.

    :param image: Anything ``load_image`` accepts; it is loaded in RGB mode.
    :param model_name: Name of a model directory in ``_REPO``.
    :param size: Passed to ``_img_encode`` as ``(size, size)``; presumably the
        square input resolution the model expects — confirm per model.
    :return: Mapping from each label listed in the model's ``meta.json`` to the
        matching value of the model's ``output`` tensor, converted with
        ``.item()``.
    :raises ValueError: If the number of labels differs from the number of
        values the model produced.
    """
    image = load_image(image, mode='RGB')
    # `[None, ...]` prepends a batch axis of length 1.
    input_ = _img_encode(image, size=(size, size))[None, ...]
    output, = _open_anime_rating_model(model_name).run(['output'], {'input': input_})
    scores = output[0]  # values for the single image in the batch
    labels = _get_tags(model_name)
    # zip() would silently drop unmatched entries and yield a misleading
    # partial result, so reject a label/output mismatch explicitly.
    if len(labels) != len(scores):
        raise ValueError(
            f'Model {model_name!r} produced {len(scores)} values '
            f'but meta.json lists {len(labels)} labels.'
        )
    return {label: score.item() for label, score in zip(labels, scores)}