MichalMlodawski commited on
Commit
43f243a
1 Parent(s): c62aa71

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ from transformers import AutoProcessor, FocalNetForImageClassification
6
+ import gradio as gr
7
+
8
+ # Path to the model
9
+ MODEL_PATH = "MichalMlodawski/nsfw-image-detection-large"
10
+
11
+ # Load the model and feature extractor
12
+ feature_extractor = AutoProcessor.from_pretrained(MODEL_PATH)
13
+ model = FocalNetForImageClassification.from_pretrained(MODEL_PATH)
14
+ model.eval()
15
+
16
+ # Image transformations
17
+ transform = transforms.Compose([
18
+ transforms.Resize((512, 512)),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
21
+ ])
22
+
23
+ # Mapping from model labels to NSFW categories
24
+ LABEL_TO_CATEGORY = {
25
+ "LABEL_0": "Safe",
26
+ "LABEL_1": "Questionable",
27
+ "LABEL_2": "Unsafe"
28
+ }
29
+
30
+ def classify_image(image):
31
+ if image is None:
32
+ return "No image uploaded"
33
+
34
+ # Convert to RGB (in case of PNG with alpha channel)
35
+ image = Image.fromarray(image).convert("RGB")
36
+
37
+ # Process image using feature_extractor
38
+ inputs = feature_extractor(images=image, return_tensors="pt")
39
+
40
+ # Prediction using the model
41
+ with torch.no_grad():
42
+ outputs = model(**inputs)
43
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
44
+ confidence, predicted = torch.max(probabilities, 1)
45
+
46
+ # Get the label from the model's configuration
47
+ label = model.config.id2label[predicted.item()]
48
+ category = LABEL_TO_CATEGORY.get(label, "Unknown")
49
+ confidence_value = confidence.item() * 100
50
+
51
+ # Prepare the result string
52
+ emoji = {"Safe": "✅", "Questionable": "⚠️", "Unsafe": "🔞"}.get(category, "❓")
53
+ confidence_bar = "🟩" * int(confidence_value // 10) + "⬜" * (10 - int(confidence_value // 10))
54
+
55
+ result = f"{emoji} NSFW Category: {category}\n"
56
+ result += f"🏷️ Model Label: {label}\n"
57
+ result += f"🎯 Confidence: {confidence_value:.2f}% {confidence_bar}"
58
+
59
+ return result
60
+
61
+ # Define Gradio interface
62
+ iface = gr.Interface(
63
+ fn=classify_image,
64
+ inputs=gr.Image(type="numpy"),
65
+ outputs=gr.Textbox(label="Classification Result"),
66
+ title="🖼️ NSFW Image Classification 🔍",
67
+ description="Upload an image to classify its safety level!",
68
+ theme=gr.themes.Soft(primary_hue="purple"),
69
+ allow_flagging="never"
70
+ )
71
+
72
+ if __name__ == "__main__":
73
+ iface.launch()