# app.py (Hugging Face Space by MichalMlodawski)
# Last change: commit 89b1df6 "Update app.py" (verified), file size 2.37 kB.
import os
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoProcessor, FocalNetForImageClassification
import gradio as gr
# Hub repository id (or local directory) of the checkpoint; both the
# preprocessing pipeline and the classifier are loaded from it.
MODEL_PATH = "MichalMlodawski/nsfw-image-detection-large"

# Preprocessor and classifier come from the same checkpoint so their
# expectations (input size, normalization) stay in sync. `eval()` returns the
# module itself, so the model is put into evaluation mode as it is bound.
feature_extractor = AutoProcessor.from_pretrained(MODEL_PATH)
model = FocalNetForImageClassification.from_pretrained(MODEL_PATH).eval()
# Torchvision preprocessing pipeline: resize to a fixed 512x512, convert to a
# tensor, then normalize per channel (the mean/std below are the commonly used
# ImageNet statistics).
# NOTE(review): `transform` is not referenced by classify_image, which relies on
# `feature_extractor` for preprocessing; it appears unused in this file.
transform = transforms.Compose(
    [
        transforms.Resize(size=(512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Human-readable NSFW category for each raw model label name.
# NOTE(review): assumes the checkpoint's id2label uses generic "LABEL_<n>"
# names — confirm against model.config.id2label.
LABEL_TO_CATEGORY = dict(
    LABEL_0="Safe",
    LABEL_1="Questionable",
    LABEL_2="Unsafe",
)
def classify_image(image):
    """Classify an uploaded image and return a textual summary.

    Args:
        image: Image array from the Gradio ``gr.Image(type="numpy")`` input,
            or ``None`` when nothing has been uploaded.

    Returns:
        ``"No image uploaded"`` when ``image`` is ``None``; otherwise a
        two-line string with the model's top label and its softmax
        confidence, shown as a percentage followed by a 10-cell bar.
    """
    if image is None:
        return "No image uploaded"

    # Force 3-channel RGB (e.g. strips the alpha channel of PNG uploads).
    pil_image = Image.fromarray(image).convert("RGB")

    # The checkpoint's own processor performs resizing/normalization.
    inputs = feature_extractor(images=pil_image, return_tensors="pt")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    confidence, predicted = torch.max(probabilities, 1)

    # Raw label name as defined in the model's configuration.
    label = model.config.id2label[predicted.item()]
    confidence_value = confidence.item() * 100

    # One filled cell per full 10 percentage points of confidence.
    filled = int(confidence_value // 10)
    confidence_bar = "🟩" * filled + "⬜" * (10 - filled)

    return (
        f"🏷️ Model Label: {label}\n"
        f"🎯 Confidence: {confidence_value:.2f}% {confidence_bar}"
    )
# Gradio UI: one numpy image input routed through classify_image, with the
# returned text shown in a labeled textbox. fn / inputs / outputs are the
# first three positional parameters of gr.Interface.
iface = gr.Interface(
    classify_image,
    gr.Image(type="numpy"),
    gr.Textbox(label="Classification Result"),
    title="🖼️ NSFW Image Classification 🔍",
    description="Upload an image to classify its safety level!",
    theme=gr.themes.Soft(primary_hue="purple"),
    allow_flagging="never",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()