# Image-scorer / app.py — uploaded by Muinez
# Commit c37bca7 ("Update app.py", verified); file size 1.58 kB; scan: no virus.
import gradio as gr
import torch
from torch import nn
from transformers import SiglipImageProcessor,SiglipModel
import dbimutils as utils
class ScoreClassifier(nn.Module):
    """MLP head that turns a 768-dimensional image embedding into a score in [0, 1].

    The attribute names ``classifier`` / ``extractor`` and the layer order
    inside each ``nn.Sequential`` define the state-dict keys
    (e.g. ``extractor.0.weight``, ``classifier.0.bias``).  They must stay
    exactly as they are for the pretrained ``scorer.pth`` checkpoint to load.

    The extractor contains BatchNorm layers, so the module should be put in
    ``eval()`` mode before scoring single images.
    """

    def __init__(self):
        super().__init__()
        # Scoring head: 256 features -> one logit -> sigmoid.
        # Built before the extractor to keep the original parameter
        # initialisation order.
        head_layers = [nn.Linear(256, 1), nn.Sigmoid()]
        self.classifier = nn.Sequential(*head_layers)
        # Feature trunk: two Linear/BatchNorm/ReLU stages (768 -> 512 -> 256),
        # followed by a final Linear(256, 256) + ReLU.
        trunk_layers = [
            nn.Linear(768, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
        ]
        self.extractor = nn.Sequential(*trunk_layers)

    def forward(self, img):
        """Score a batch of embeddings shaped (batch, 768); returns (batch, 1)."""
        features = self.extractor(img)
        return self.classifier(features)
from huggingface_hub import hf_hub_download
# Fetch the trained scoring-head weights from the Hub; the returned value is a
# local file path that torch.load reads below.
model_file = hf_hub_download(repo_id="Muinez/Image-scorer", filename="scorer.pth")
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

model = ScoreClassifier().to(DEVICE)
# The checkpoint is a plain state dict (it is passed straight to
# load_state_dict), so weights_only=True is sufficient and prevents a tampered
# download from executing arbitrary code during unpickling.  Tensors are
# mapped to CPU first; load_state_dict copies them onto DEVICE.
model.load_state_dict(
    torch.load(model_file, map_location=torch.device('cpu'), weights_only=True)
)
model.eval()  # BatchNorm must use running stats when scoring single images.

# Processor and backbone must come from the same checkpoint so the pixel
# preprocessing matches what the model was trained with.
SIGLIP_CHECKPOINT = 'google/siglip-base-patch16-512'
processor = SiglipImageProcessor.from_pretrained(SIGLIP_CHECKPOINT)
siglip = SiglipModel.from_pretrained(SIGLIP_CHECKPOINT).to(DEVICE)
siglip.eval()  # Explicit inference mode for the embedding backbone.
def predict(img):
    """Score one image with the SigLIP backbone and the ScoreClassifier head.

    Args:
        img: Image from the Gradio ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        The head's sigmoid output as a Python float in [0, 1], or ``None`` when
        no image was provided (the Score output is then left empty).
    """
    # Gradio hands over None for an empty image input; don't push it through
    # preprocessing and the model.
    if img is None:
        return None
    # NOTE(review): dbimutils.preprocess_image is project-local; presumably it
    # normalises mode/size before SigLIP preprocessing — confirm.
    img = utils.preprocess_image(img)
    encoded = processor(img, return_tensors="pt").pixel_values.to(DEVICE)
    with torch.no_grad():
        # SigLIP image embedding -> MLP head.  `.item()` below requires a
        # single-element result, i.e. one image in, one score out.
        score = model(siglip.get_image_features(encoded))
    return score.item()
# Web UI: one image input -> one numeric score.  The input component is
# constructed before the output component, as in the original call.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Number(label="Score")],
    title="Image scorer",
    description="Predicts score (0-1) for image.\nCould be wrong",
    allow_flagging="never",
)
demo.launch()