ZappY-AI commited on
Commit
79958cf
1 Parent(s): bf636c8

Added model and application file

Browse files
Files changed (3) hide show
  1. FaKe-ViT-B16.pth +3 -0
  2. app.py +59 -0
  3. requirements.txt +5 -0
FaKe-ViT-B16.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a2d9f5edce776c627c3797b1f1a6be5d243a188ce39b9546da2ee031b363c30
3
+ size 343286022
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import torch
4
+ import torchvision.models as models
5
+ from torchvision.transforms import v2 as transforms
6
+ import os
7
+
8
+ class_names = ['AI-Generated Image', "Real/Non-AI-Generated Image"]
9
+
10
+ # Downloading the model
11
+ # model = models.vit_b_16()
12
+ weights_path = "FaKe-ViT-B16.pth"
13
+ model = torch.load(weights_path).to("cpu")
14
+ model.eval()
15
+ # Preprocessing the image
16
+ preprocess = transforms.Compose([
17
+ transforms.Resize(256),
18
+ transforms.CenterCrop(224),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
21
+ ])
22
+
23
+ # Define the prediction function
24
+ def predict_image(image):
25
+ # inp = Image.fromarray(inp.astype('uint8'), 'RGB')
26
+ image = preprocess(image)
27
+ if image.shape[0] != 3:
28
+ image = image[:3, :, :]
29
+ image = image.unsqueeze(0)
30
+ with torch.inference_mode():
31
+ output = model(image)
32
+ output1 = torch.argmax(torch.softmax(output,dim=1),dim=1).item()
33
+ return class_names[output1]
34
+
35
+ # def image_mod(image):
36
+ # return image.rotate(45)
37
+
38
+
39
+ demo = gr.Interface(
40
+ predict_image,
41
+ gr.Image(image_mode="RGB",type="pil"),
42
+ "text",
43
+ flagging_options=["incorrect prediction"],
44
+ # examples=[
45
+ # os.path.join(os.path.dirname(__file__), "images/cheetah1.jpg"),
46
+ # os.path.join(os.path.dirname(__file__), "images/lion.jpg"),
47
+ # os.path.join(os.path.dirname(__file__), "images/logo.png"),
48
+ # os.path.join(os.path.dirname(__file__), "images/tower.jpg"),
49
+ # ],
50
+ title="FaKe-ViT-B/16: AI-Generated Image Detection using Vision Transformer(ViT-B/16)",
51
+ description="This is a demo to detect AI-Generated images using Vision Transformer(ViT-B/16). Upload an image and the model will predict whether the image is AI-Generated or Real",
52
+ css=""".gr-header, .gr-text {
53
+ font-size: 20px;
54
+ }""",
55
+ article=" \nBased on the paper:'An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale', Alexey et al.\nDataset: 'Fake or Real competition dataset' at https://huggingface.co/datasets/mncai/Fake_or_Real_Competition_Dataset"
56
+ )
57
+
58
+ if __name__ == "__main__":
59
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ uvicorn
3
+ starlette
4
+ torch
5
+ torchvision