File size: 2,368 Bytes
43f243a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89b1df6
 
43f243a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoProcessor, FocalNetForImageClassification
import gradio as gr

# Checkpoint identifier handed to both `from_pretrained` calls below.
MODEL_PATH = "MichalMlodawski/nsfw-image-detection-large"

# Load the processor and the classifier once, at import time, so every
# request reuses the same objects. `.eval()` returns the module itself, so
# the model is switched to inference mode in the same expression.
feature_extractor = AutoProcessor.from_pretrained(MODEL_PATH)
model = FocalNetForImageClassification.from_pretrained(MODEL_PATH).eval()

# Stand-alone torchvision pipeline: resize to 512x512, convert to a tensor,
# then normalise each channel with the mean/std below.
# NOTE(review): classify_image preprocesses through `feature_extractor` and
# does not call this pipeline in the visible code — confirm it is still needed.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
    ]
)

# Generic model label -> human-readable NSFW category.
LABEL_TO_CATEGORY = dict(
    LABEL_0="Safe",
    LABEL_1="Questionable",
    LABEL_2="Unsafe",
)

def classify_image(image) -> str:
    """Classify an uploaded image and return a formatted result string.

    Args:
        image: Array-like image data accepted by ``PIL.Image.fromarray``,
            or ``None`` when nothing was uploaded.

    Returns:
        ``"No image uploaded"`` when ``image`` is ``None``; otherwise a
        two-line string holding the predicted model label and its softmax
        confidence as a percentage followed by a ten-cell bar.
    """
    if image is None:
        return "No image uploaded"

    # Force 3-channel RGB (e.g. a PNG that carries an alpha channel).
    rgb_image = Image.fromarray(image).convert("RGB")

    # Build PyTorch input tensors with the processor loaded for this model.
    inputs = feature_extractor(images=rgb_image, return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**inputs).logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        confidence, predicted = torch.max(probabilities, 1)

    # Only the raw label from the model config is reported; the module-level
    # LABEL_TO_CATEGORY mapping is not applied to the displayed result.
    label = model.config.id2label[predicted.item()]
    confidence_value = confidence.item() * 100

    # One filled cell per full 10% of confidence, padded to ten cells.
    filled_cells = int(confidence_value // 10)
    confidence_bar = "🟩" * filled_cells + "⬜" * (10 - filled_cells)

    return (
        f"🏷️ Model Label: {label}\n"
        f"🎯 Confidence: {confidence_value:.2f}% {confidence_bar}"
    )

# Gradio UI: a single image input wired to classify_image, with the
# formatted result shown in a text box.
image_input = gr.Image(type="numpy")  # classify_image feeds this to Image.fromarray
result_output = gr.Textbox(label="Classification Result")

iface = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=result_output,
    title="🖼️ NSFW Image Classification 🔍",
    description="Upload an image to classify its safety level!",
    theme=gr.themes.Soft(primary_hue="purple"),
    allow_flagging="never",
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()