|
import json |
|
import os |
|
from functools import lru_cache |
|
from typing import Mapping, List |
|
|
|
from huggingface_hub import HfFileSystem |
|
from huggingface_hub import hf_hub_download |
|
from imgutils.data import ImageTyping, load_image |
|
from natsort import natsorted |
|
|
|
from onnx_ import _open_onnx_model |
|
from preprocess import _img_encode |
|
|
|
# Filesystem view of the Hugging Face Hub, used below to list files in _REPO.
hfs = HfFileSystem()

# Hub repository holding the anime rating models, one sub-directory per model.
_REPO = 'deepghs/anime_rating'

# Names of the available models: each top-level directory of _REPO that
# contains a ``model.onnx`` file, in natural sort order.
# NOTE: the glob below queries the Hub when this module is imported.
_RATING_MODELS = natsorted(
    os.path.dirname(os.path.relpath(onnx_path, _REPO))
    for onnx_path in hfs.glob(f'{_REPO}/*/model.onnx')
)

# Name of the preferred model; not referenced in this chunk. Presumably it is
# one of _RATING_MODELS and used as a UI/default choice elsewhere — confirm.
_DEFAULT_RATING_MODEL = 'mobilenetv3_sce_dist'
|
|
|
|
|
@lru_cache(maxsize=128)
def _open_anime_rating_model(model_name: str):
    """Download a rating model from the Hub and open it with ``_open_onnx_model``.

    The file fetched is ``<model_name>/model.onnx`` in the ``_REPO`` repository.
    Results are memoized per ``model_name`` (up to 128 entries, the
    ``lru_cache`` default), so each model is downloaded and opened once.

    The returned object is used via ``.run(...)`` by ``_gr_rating``; presumably
    it is an onnxruntime ``InferenceSession`` — confirm in ``onnx_``.
    """
    onnx_file = hf_hub_download(_REPO, f'{model_name}/model.onnx')
    return _open_onnx_model(onnx_file)
|
|
|
|
|
@lru_cache()
def _get_tags(model_name) -> List[str]:
    """Return the ``labels`` list stored in a model's ``meta.json``.

    The file fetched is ``<model_name>/meta.json`` in the ``_REPO`` repository.
    Results are memoized per ``model_name``. Because of this, the same list
    object is handed to every caller, and it must not be mutated.

    ``_gr_rating`` pairs these labels positionally with the model outputs.
    The labels are therefore assumed to be in model-output order, which
    should be confirmed against the repo's ``meta.json``.

    Raises:
        KeyError: if ``meta.json`` has no ``labels`` key.
    """
    meta_path = hf_hub_download(_REPO, f'{model_name}/meta.json')
    # JSON text is UTF-8. Without an explicit encoding, open() uses the locale
    # encoding (e.g. cp1252 on Windows) and could mis-decode non-ASCII labels.
    with open(meta_path, 'r', encoding='utf-8') as f:
        return json.load(f)['labels']
|
|
|
|
|
def _gr_rating(image: ImageTyping, model_name: str, size: int = 384) -> Mapping[str, float]:
    """Score one image with the named rating model.

    The image is loaded as RGB and encoded by ``_img_encode`` at ``size x size``.
    It is then given a leading batch axis and fed to the ONNX graph's
    ``input``. The first (only) row of the ``output`` tensor is paired
    positionally with the labels from ``_get_tags``.

    Returns:
        A dict mapping each label to its score as a Python number. Whether the
        scores are probabilities or raw logits depends on the exported model
        and is not visible here.
    """
    rgb_image = load_image(image, mode='RGB')
    batch = _img_encode(rgb_image, size=(size, size))[None, ...]  # add batch axis

    session = _open_anime_rating_model(model_name)
    (scores,) = session.run(['output'], {'input': batch})

    # .item() converts each scalar element to a plain Python number.
    return {label: score.item() for label, score in zip(_get_tags(model_name), scores[0])}
|
|