import gradio as gr
from PIL import Image
import torch
import torchvision.models as models
from torchvision.transforms import v2 as transforms
import os

# Output labels, indexed by the class index the model predicts.
class_names = ['Fake/AI-Generated Image', "Real/Not an AI-Generated Image"]

# Load the model.
# The checkpoint is used directly as a module (``model.eval()`` / ``model(x)``
# below), i.e. it is a whole pickled nn.Module rather than a state_dict.
# torch >= 2.6 defaults to ``weights_only=True``, which refuses such pickles,
# so opt out explicitly.
# NOTE(security): this unpickles arbitrary Python objects — only ever point
# ``weights_path`` at a trusted, locally shipped checkpoint.
weights_path = "FaKe-ViT-B16.pth"
model = torch.load(
    weights_path,
    map_location=torch.device('cpu'),
    weights_only=False,
)
model.eval()

# Preprocessing: resize, center-crop to 224x224, convert to a float tensor and
# normalize with the standard ImageNet per-channel mean/std.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict_image(image):
    """Classify a PIL image as AI-generated or real.

    Args:
        image: PIL image supplied by the Gradio ``Image`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        One of ``class_names``, or a human-readable error message string when
        the input is missing or not an RGB image.
    """
    if image is None:
        return "Invalid Image: Please upload an image."
    # Validate the mode *before* preprocessing: Normalize is configured with
    # 3-channel statistics, so a 1- or 4-channel tensor may raise inside
    # Normalize (or be broadcast to 3 channels) instead of reaching a
    # channel-count check placed after preprocess.
    if image.mode != "RGB":
        return "Invalid Image: Image should be in RGB format. Please upload a valid image."

    batch = preprocess(image).unsqueeze(0)  # add a leading batch dimension
    with torch.inference_mode():
        logits = model(batch)
    # softmax is monotonic, so argmax over the raw logits picks the same class.
    predicted_idx = torch.argmax(logits, dim=1).item()
    return class_names[predicted_idx]


demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(image_mode="RGB", type="pil"),
    outputs="text",
    flagging_options=["incorrect prediction"],
    examples=[
        "images/cheetah.jpg",
        "images/cat.jpg",
        "images/astronaut.jpg",
        "images/mountain.jpg",
        "images/unicorn.jpg",
    ],
    title="FaKe-ViT-B/16: Robust and Fast AI-Generated Image Detection using Vision Transformer(ViT-B/16):",
    description=(
        "This is a demo to detect AI-Generated images using a fine-tuned "
        "Vision Transformer(ViT-B/16). Upload an image and the model will "
        "predict whether the image is AI-Generated or Real"
    ),
    # Rendered as markdown; the blank line keeps the two entries on separate
    # lines.
    article=(
        "Paper: 'An Image is Worth 16x16 Words: Transformers for Image "
        "Recognition at Scale', Alexey et al.\n\n"
        "Dataset: 'Fake or Real competition dataset' at Fake or Real "
        "competition dataset"
    ),
)
if __name__ == "__main__":
    # Start the Gradio web server for the demo defined above.
    demo.launch()