File size: 1,677 Bytes
1d56378
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69791ae
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import gradio as gr
import torch
from torch import nn
from transformers import SiglipImageProcessor,SiglipModel
import dbimutils as utils

class ScoreClassifier(nn.Module):
    """Small MLP that maps a 768-wide feature vector to one sigmoid score.

    The network has two parts:

    * ``extractor`` -- 768 -> 512 -> 256 -> 256 linear layers with ReLU
      activations; the first two linear layers are followed by batch norm.
    * ``classifier`` -- a 256 -> 1 linear layer followed by a sigmoid, so
      every output value lies between 0 and 1.

    The attribute names ``classifier`` and ``extractor`` are part of the
    checkpoint format (they prefix the ``state_dict`` keys) and must not be
    renamed.
    """

    def __init__(self):
        super().__init__()

        # The head is created before the trunk on purpose: keeping the
        # construction order keeps parameter initialisation and state_dict
        # ordering the same as in existing checkpoints.
        self.classifier = nn.Sequential(
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

        self.extractor = nn.Sequential(
            nn.Linear(768, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
        )

    def forward(self, img):
        """Return the sigmoid score for a batch of feature vectors.

        Args:
            img: Tensor whose last dimension is 768 (the input width of the
                first linear layer). Presumably shaped (batch, 768) -- TODO
                confirm against the caller.

        Returns:
            Tensor of scores with a trailing dimension of size 1.
        """
        features = self.extractor(img)
        return self.classifier(features)

from huggingface_hub import hf_hub_download
# Fetch the scorer checkpoint from the Hugging Face Hub. hf_hub_download
# returns the local path of the cached file, which is generally NOT the
# current working directory, so that returned path is what must be loaded.
model_file = hf_hub_download(repo_id="Muinez/Image-scorer", filename="scorer.pth")

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ScoreClassifier().to(DEVICE)
# Load from the downloaded path (previously a hard-coded "scorer.pth" was read
# from the working directory and the download result was ignored).
# map_location remaps stored tensors onto DEVICE so a checkpoint saved from a
# GPU still loads on a CPU-only host.
# NOTE(review): torch.load unpickles the file; the checkpoint comes from a
# remote repository, so consider weights_only=True if the installed torch
# version supports it.
model.load_state_dict(torch.load(model_file, map_location=DEVICE))
# Inference only: switch batch-norm layers to their running statistics.
model.eval()

# SigLIP image preprocessor and encoder used to produce the image features
# that are fed to the scorer.
processor = SiglipImageProcessor.from_pretrained('google/siglip-base-patch16-512')
siglip = SiglipModel.from_pretrained('google/siglip-base-patch16-512').to(DEVICE)

def predict(img):
    """Score one image with the SigLIP encoder followed by the scorer head.

    Args:
        img: The image to score. It is passed straight to
            ``utils.preprocess_image``; the Gradio interface in this file
            supplies it as a PIL image.

    Returns:
        The model output as a plain Python number. ``.item()`` only works on
        a single-element tensor, so exactly one image is scored per call.
    """
    # NOTE(review): what preprocess_image does is defined in dbimutils, not
    # here -- presumably it normalises the image before encoding; confirm.
    prepared = utils.preprocess_image(img)

    batch = processor(prepared, return_tensors="pt")
    pixel_values = batch.pixel_values.to(DEVICE)

    # No gradients are needed for inference.
    with torch.no_grad():
        features = siglip.get_image_features(pixel_values)
        score = model(features)

    return score.item()

# Build the web UI around predict() and start serving it immediately.
gr.Interface(
    # One image in (handed to predict as a PIL image), one number out.
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Number(label="Score")],
    # Text shown on the page.
    title="Artwork scorer",
    description="Predicts score (0-1) for artwork.\nCould be wrong!!!\nDoes not work very well with nsfw i.e. it was not trained on it",
    # Flagging is turned off.
    allow_flagging="never",
).launch()