# Fake-ViT / app.py
# Last commit by ZappY-AI: "some font changes" (f88be9d)
import gradio as gr
from PIL import Image
import torch
import torchvision.models as models
from torchvision.transforms import v2 as transforms
import os
# Human-readable labels, indexed by the class id that ``predict_image``
# obtains from ``argmax`` over the model output (id 0 -> first entry,
# id 1 -> second entry).
class_names = [
    "Fake/AI-Generated Image",
    "Real/Not an AI-Generated Image",
]
# Load the fine-tuned detector.
# The loaded object is used directly as a model (``model.eval()`` below), so
# the checkpoint appears to be a fully pickled ``nn.Module`` rather than a
# state dict. Since PyTorch 2.6, ``torch.load`` defaults to
# ``weights_only=True``, which refuses to unpickle arbitrary classes such as
# a whole model, so opt out explicitly.
# SECURITY: full unpickling can execute arbitrary code -- only ever point
# ``weights_path`` at a trusted checkpoint shipped with this app.
weights_path = "FaKe-ViT-B16.pth"
model = torch.load(
    weights_path,
    map_location=torch.device("cpu"),  # remap any GPU tensors so CPU-only hosts work
    weights_only=False,
)
model.eval()  # evaluation mode: disables training-only behavior such as dropout
# Per-channel normalization statistics (these match the standard ImageNet
# mean/std values).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Image -> model-input pipeline:
#   1. resize so the shorter side is 256 px,
#   2. take the central 224x224 crop,
#   3. convert to a float tensor scaled to [0, 1],
#   4. normalize each channel with the statistics above.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# Define the prediction function
def predict_image(image):
image = preprocess(image)
if image.shape[0] != 3:
# image = image[:3, :, :]
return "Invalid Image: Image should be in RGB format. Please upload a valid image."
image = image.unsqueeze(0)
with torch.inference_mode():
output = model(image)
output1 = torch.argmax(torch.softmax(output,dim=1),dim=1).item()
return class_names[output1]
# Gradio UI: one PIL image in (converted to RGB by the component), one text
# label out, with a flag button for reporting wrong predictions.
demo = gr.Interface(
    predict_image,
    gr.Image(image_mode="RGB", type="pil"),
    "text",
    flagging_options=["incorrect prediction"],
    # With a single input component, examples may be a flat list of values.
    examples=[
        "images/cheetah.jpg",
        "images/cat.jpg",
        "images/astronaut.jpg",
        "images/mountain.jpg",
        "images/unicorn.jpg",
    ],
    title="<u>FaKe-ViT-B/16: Robust and Fast AI-Generated Image Detection using Vision Transformer(ViT-B/16):</u>",
    # HTML paragraphs are closed explicitly so the markup is well-formed.
    description=(
        "<p style='font-size: 20px;'>This is a demo to detect AI-Generated images using a fine-tuned "
        "Vision Transformer(ViT-B/16). Upload an image and the model will predict whether the image "
        "is AI-Generated or Real.</p>"
    ),
    # Cite the ViT paper by its first author's surname.
    article=(
        "<p style='font-size: 20px;'><b>Paper</b>: 'An Image is Worth 16x16 Words: Transformers for "
        "Image Recognition at Scale', Dosovitskiy et al.<br/><b>Dataset</b>: 'Fake or Real competition "
        "dataset' at <a href='https://huggingface.co/datasets/mncai/Fake_or_Real_Competition_Dataset'>"
        "Fake or Real competition dataset</a></p>"
    ),
)

if __name__ == "__main__":
    demo.launch()