# Image-scorer / app.py — Hugging Face Space by Muinez
# Commit 1d56378 ("Upload 2 files", verified); original file size 1.68 kB; scan result: no virus.
import gradio as gr
import torch
from torch import nn
from transformers import SiglipImageProcessor,SiglipModel
import dbimutils as utils
class ScoreClassifier(nn.Module):
    """MLP head that maps a 768-dimensional embedding to a single sigmoid score.

    The submodule names (``classifier``, ``extractor``) and the layer order inside
    each ``nn.Sequential`` determine the ``state_dict`` keys (e.g.
    ``extractor.1.running_mean``). Keep them unchanged, or pretrained weights
    will no longer load.
    """

    def __init__(self):
        super().__init__()
        # Output head: 256 features -> 1 logit, squashed by a sigmoid.
        head_layers = [nn.Linear(256, 1), nn.Sigmoid()]
        self.classifier = nn.Sequential(*head_layers)
        # Feature trunk: 768 -> 512 -> 256 -> 256, BatchNorm on the first two stages.
        trunk_layers = [
            nn.Linear(768, 512), nn.BatchNorm1d(512), nn.ReLU(),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
        ]
        self.extractor = nn.Sequential(*trunk_layers)

    def forward(self, img):
        """Return one score per row of ``img`` (expected shape ``(batch, 768)``)."""
        features = self.extractor(img)
        return self.classifier(features)
from huggingface_hub import hf_hub_download
# Download the trained scorer-head weights from the Hugging Face Hub;
# hf_hub_download returns the local (cached) path of the file.
model_file = hf_hub_download(repo_id="Muinez/Image-scorer", filename="scorer.pth")
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ScoreClassifier().to(DEVICE)
# Load from the path hf_hub_download returned, not a CWD-relative "scorer.pth".
# map_location lets a checkpoint saved on a GPU load on a CPU-only host.
# NOTE(review): torch.load unpickles the file and this file comes from the
# network. Consider weights_only=True if the installed torch supports it.
model.load_state_dict(torch.load(model_file, map_location=DEVICE))
# Eval mode makes BatchNorm use its running statistics, so single-image batches work.
model.eval()
# SigLIP image preprocessing plus backbone; its image features feed ScoreClassifier.
processor = SiglipImageProcessor.from_pretrained('google/siglip-base-patch16-512')
siglip = SiglipModel.from_pretrained('google/siglip-base-patch16-512').to(DEVICE)
def predict(img):
    """Score a single image with the SigLIP backbone and ScoreClassifier head.

    Args:
        img: The image from the Gradio ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        The predicted score as a Python float, or ``None`` when no image was
        given, so the output is left empty instead of raising.
    """
    # Gradio passes None for an empty image input; stop before preprocessing sees it.
    if img is None:
        return None
    # NOTE(review): utils.preprocess_image comes from the project-local dbimutils;
    # its exact transformation is not visible here — confirm.
    img = utils.preprocess_image(img)
    encoded = processor(img, return_tensors="pt").pixel_values.to(DEVICE)
    with torch.no_grad():
        score = model(siglip.get_image_features(encoded))
    # .item() needs a single-element tensor: one image in, one score out.
    return score.item()
# Build the Gradio UI. Binding it to `demo` makes it importable (and discoverable
# by Gradio's reload mode). The __main__ guard keeps a plain import of this
# module from starting a server.
demo = gr.Interface(
    title="Artwork scorer",
    # The description is rendered as Markdown, so paragraphs need blank lines.
    description=(
        "Predicts score (0-1) for artwork.\n\n"
        "Could be wrong!!!\n\n"
        "Does not work very well with NSFW images, since it was not trained on them."
    ),
    fn=predict,
    allow_flagging="never",
    inputs=gr.Image(type="pil"),
    outputs=[gr.Number(label="Score")],
)

if __name__ == "__main__":
    demo.launch()