|
import gradio as gr |
|
from PIL import Image |
|
import torch |
|
import torchvision.models as models |
|
from torchvision.transforms import v2 as transforms |
|
import os |
|
|
|
|
|
# Human-readable labels, indexed by the class id the model predicts
# (index 0 -> fake / AI-generated, index 1 -> real).
class_names = [
    "Fake/AI-Generated Image",
    "Real/Not an AI-Generated Image",
]
|
|
|
|
|
# Serialized fine-tuned ViT-B/16 checkpoint. The object returned by torch.load
# is used directly as the model (it is called and has .eval()), so the file
# presumably holds a whole pickled nn.Module rather than a state_dict.
weights_path = "FaKe-ViT-B16.pth"

# weights_only=False is required to unpickle a full module: since PyTorch 2.6
# torch.load defaults to weights_only=True and would refuse this checkpoint.
# SECURITY: this executes arbitrary pickle code -- only load trusted files.
model = torch.load(
    weights_path,
    map_location=torch.device("cpu"),
    weights_only=False,
)

# Switch layers such as dropout to evaluation behaviour for inference.
model.eval()
|
|
|
# Image preprocessing applied before inference: resize the shorter side to 256,
# center-crop to 224x224, convert to a float32 tensor scaled to [0, 1], then
# normalize per channel with the standard ImageNet mean/std values.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    # transforms.v2.ToTensor is deprecated; ToImage + ToDtype(scale=True) is the
    # documented replacement and yields the same float32 tensor in [0, 1].
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
|
|
|
|
def predict_image(image):
    """Classify an image as AI-generated or real.

    Args:
        image: A PIL image, as delivered by the ``gr.Image(type="pil")``
            input, or ``None`` when the form is submitted without an upload.

    Returns:
        The predicted label from ``class_names``, or a human-readable error
        message when no image was given or the image is not in RGB mode.
    """
    # Gradio passes None when nothing was uploaded; preprocess() would crash.
    if image is None:
        return "No image provided. Please upload an image."

    # Validate before preprocessing: Normalize uses 3-channel statistics, so a
    # 1- or 4-channel tensor raises inside preprocess() and a post-hoc channel
    # check could never return this message.
    if image.mode != "RGB":
        return "Invalid Image: Image should be in RGB format. Please upload a valid image."

    # Add a leading batch dimension expected by the model.
    batch = preprocess(image).unsqueeze(0)

    with torch.inference_mode():
        logits = model(batch)

    # softmax is monotonic, so argmax over raw logits selects the same class.
    predicted_index = torch.argmax(logits, dim=1).item()
    return class_names[predicted_index]
|
|
|
|
|
|
|
# Gradio web UI: one RGB image input (as a PIL image), plain-text output
# from predict_image, and a set of clickable example images.
demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(image_mode="RGB", type="pil"),
    outputs="text",
    flagging_options=["incorrect prediction"],
    # One inner list per example, one entry per input component.
    examples=[
        ["images/cheetah.jpg"],
        ["images/cat.jpg"],
        ["images/astronaut.jpg"],
        ["images/mountain.jpg"],
        ["images/unicorn.jpg"],
    ],
    title="<u>FaKe-ViT-B/16: Robust and Fast AI-Generated Image Detection using Vision Transformer (ViT-B/16)</u>",
    description="<p style='font-size: 20px;'>This is a demo to detect AI-Generated images using a fine-tuned Vision Transformer (ViT-B/16). Upload an image and the model will predict whether the image is AI-Generated or Real.</p>",
    article="<p style='font-size: 20px;'><b>Paper</b>: 'An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale', Dosovitskiy et al.<br/><b>Dataset</b>: 'Fake or Real competition dataset' at <a href='https://huggingface.co/datasets/mncai/Fake_or_Real_Competition_Dataset'>Fake or Real competition dataset</a></p>",
)


if __name__ == "__main__":
    # Start the local Gradio server only when run as a script, not on import.
    demo.launch()