Muinez commited on
Commit
3d42bd2
1 Parent(s): ad8d862

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -56
app.py CHANGED
@@ -1,57 +1,57 @@
1
- import gradio as gr
2
- import torch
3
- from torch import nn
4
- from transformers import SiglipImageProcessor,SiglipModel
5
- import dbimutils as utils
6
-
7
- class ScoreClassifier(nn.Module):
8
- def __init__(self):
9
- super(ScoreClassifier, self).__init__()
10
-
11
- self.classifier = nn.Sequential(
12
- nn.Linear(256, 1),
13
- nn.Sigmoid()
14
- )
15
-
16
- self.extractor = nn.Sequential(
17
- nn.Linear(768, 512),
18
- nn.BatchNorm1d(512),
19
- nn.ReLU(),
20
- nn.Linear(512, 256),
21
- nn.BatchNorm1d(256),
22
- nn.ReLU(),
23
- nn.Linear(256, 256),
24
- nn.ReLU(),
25
- )
26
-
27
- def forward(self, img):
28
- return self.classifier(self.extractor(img))
29
-
30
- from huggingface_hub import hf_hub_download
31
- model_file = hf_hub_download(repo_id="Muinez/Image-scorer", filename="scorer.pth")
32
-
33
- DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
34
- model = ScoreClassifier().to(DEVICE)
35
- model.load_state_dict(torch.load(model_file))
36
- model.eval()
37
-
38
- processor = SiglipImageProcessor.from_pretrained('google/siglip-base-patch16-512')
39
- siglip = SiglipModel.from_pretrained('google/siglip-base-patch16-512').to(DEVICE)
40
-
41
- def predict(img):
42
- img = utils.preprocess_image(img)
43
- encoded = processor(img, return_tensors="pt").pixel_values.to(DEVICE)
44
-
45
- with torch.no_grad():
46
- score = model(siglip.get_image_features(encoded))
47
-
48
- return score.item()
49
-
50
- gr.Interface(
51
- title="Artwork scorer",
52
- description="Predicts score (0-1) for artwork.\nCould be wrong!!!\nDoes not work very well with nsfw i.e. it was not trained on it",
53
- fn=predict,
54
- allow_flagging="never",
55
- inputs=gr.Image(type="pil"),
56
- outputs=[gr.Number(label="Score")]
57
  ).launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ from torch import nn
4
+ from transformers import SiglipImageProcessor,SiglipModel
5
+ import dbimutils as utils
6
+
7
+ class ScoreClassifier(nn.Module):
8
+ def __init__(self):
9
+ super(ScoreClassifier, self).__init__()
10
+
11
+ self.classifier = nn.Sequential(
12
+ nn.Linear(256, 1),
13
+ nn.Sigmoid()
14
+ )
15
+
16
+ self.extractor = nn.Sequential(
17
+ nn.Linear(768, 512),
18
+ nn.BatchNorm1d(512),
19
+ nn.ReLU(),
20
+ nn.Linear(512, 256),
21
+ nn.BatchNorm1d(256),
22
+ nn.ReLU(),
23
+ nn.Linear(256, 256),
24
+ nn.ReLU(),
25
+ )
26
+
27
+ def forward(self, img):
28
+ return self.classifier(self.extractor(img))
29
+
30
+ from huggingface_hub import hf_hub_download
31
+ model_file = hf_hub_download(repo_id="Muinez/Image-scorer", filename="scorer.pth")
32
+
33
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
34
+ model = ScoreClassifier().to(DEVICE)
35
+ model.load_state_dict(torch.load(model_file, map_location=torch.device('cpu')))
36
+ model.eval()
37
+
38
+ processor = SiglipImageProcessor.from_pretrained('google/siglip-base-patch16-512')
39
+ siglip = SiglipModel.from_pretrained('google/siglip-base-patch16-512').to(DEVICE)
40
+
41
+ def predict(img):
42
+ img = utils.preprocess_image(img)
43
+ encoded = processor(img, return_tensors="pt").pixel_values.to(DEVICE)
44
+
45
+ with torch.no_grad():
46
+ score = model(siglip.get_image_features(encoded))
47
+
48
+ return score.item()
49
+
50
+ gr.Interface(
51
+ title="Artwork scorer",
52
+ description="Predicts score (0-1) for artwork.\nCould be wrong!!!\nDoes not work very well with nsfw i.e. it was not trained on it",
53
+ fn=predict,
54
+ allow_flagging="never",
55
+ inputs=gr.Image(type="pil"),
56
+ outputs=[gr.Number(label="Score")]
57
  ).launch()