File size: 3,357 Bytes
0673028
7695113
 
 
 
 
 
 
ca301d8
0673028
7695113
 
0673028
878334f
7695113
 
 
 
 
 
 
 
 
 
 
 
 
e0006c0
7695113
 
 
 
 
84823eb
 
f95cba1
 
84823eb
 
 
 
 
 
 
 
 
 
14f526a
84823eb
7695113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c3e34b
7695113
523fb6e
 
7695113
 
 
 
 
 
 
 
 
 
 
 
 
671719d
 
859b347
671719d
 
8c3e34b
7695113
 
 
 
 
 
671719d
859b347
1e67343
a9f1065
7695113
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import pickle

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f'Running on device: {DEVICE.upper()}')

# Face detector used by predict() to crop the face before classification.
# select_largest=False: per the facenet_pytorch docs, the most confident face is
# kept rather than the largest one. post_process=False: presumably returns raw
# (un-standardized) 0-255 pixels, which is why predict() divides by 255 itself
# — confirm against the facenet_pytorch MTCNN docs.
mtcnn = MTCNN(
    select_largest=False,
    post_process=False,
    device=DEVICE
).to(DEVICE).eval()

# Binary real/fake classifier: InceptionResnetV1 with a single-logit head.
# Its weights are replaced wholesale by the fine-tuned checkpoint loaded below.
# NOTE(review): pretrained="vggface2" presumably makes facenet_pytorch fetch the
# VGGFace2 weights first, only for them to be overwritten; if the checkpoint
# covers every parameter, pretrained=None would skip that — confirm before changing.
model = InceptionResnetV1(
    pretrained="vggface2",
    classify=True,
    num_classes=1,
    device=DEVICE
)
# Load the checkpoint exactly once. map_location='cpu' lets a GPU-saved checkpoint
# load on CPU-only hosts; the model is moved to DEVICE right after.
model.load_state_dict(torch.load('resnetinceptionv1_final.pth', map_location='cpu'))
model.to(DEVICE)
model.eval()
print("MTCNN & Classfier models loaded")


# Open the pickle file of image examples.
# Presumably a sequence of dicts with 'path' and 'label' keys, the shape produced
# by the folder-scanning builder this file replaced — confirm before relying on it.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only ever
# load this file from a trusted source (never from user uploads).
with open('file_examples.pkl', 'rb') as examples_file:
    examples = pickle.load(examples_file)
 
   
       
def predict(input_image:Image.Image):
    """Classify the face in ``input_image`` as real or fake.

    The face is cropped with the module-level ``mtcnn`` detector, resized to
    256x256, scaled by 1/255 and scored by the module-level ``model``. A
    sigmoid turns the model's single output into the probability of "fake".

    Args:
        input_image: image to classify. It is converted to RGB first, so RGBA
            or grayscale inputs reach the detector as 3-channel images.

    Returns:
        A ``(confidences, face_image)`` tuple:

        - ``confidences`` maps ``'real'`` and ``'fake'`` to probabilities that
          sum to 1.
        - ``face_image`` is the resized face crop as an (H, W, 3) uint8 array
          for display.

    Raises:
        ValueError: if no image is given or no face is detected.
    """
    if input_image is None:
        # Gradio passes None when the user submits without an image.
        raise ValueError('No image provided')
    face = mtcnn(input_image.convert('RGB'))
    if face is None:
        raise ValueError('No face detected')
    face = face.unsqueeze(0)  # add the batch dimension
    face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)

    # CHW -> HWC copy for display. The crop presumably holds 0-255 values
    # (post_process=False). Round, clamp and cast to uint8 so the array is a
    # valid 8-bit RGB image; int32 arrays are not.
    face_image_to_plot = (
        face.squeeze(0)
        .permute(1, 2, 0)
        .detach()
        .cpu()
        .round()
        .clamp(0, 255)
        .to(torch.uint8)
        .numpy()
    )

    # The classifier input is float32 scaled down by 255.
    face = face.to(DEVICE, dtype=torch.float32) / 255.0
    with torch.no_grad():
        fake_probability = torch.sigmoid(model(face).squeeze(0)).item()

    confidences = {
        'real': 1 - fake_probability,
        'fake': fake_probability,
    }
    return confidences, face_image_to_plot
    
# Startup smoke test: classify one bundled sample frame once and print the
# ground-truth label next to the prediction. The truth comes from the file-name
# prefix ('fake_...' / 'real_...'), so it always describes the image that is
# actually scored.
SMOKE_TEST_IMAGE = 'fake_frame_0.jpg'
smoke_label = SMOKE_TEST_IMAGE.split('_')[0]

print(f"True label: {smoke_label}")

# Keep the file open while predict() reads the pixels, then close it.
with Image.open(SMOKE_TEST_IMAGE) as smoke_image:
    confidences, _ = predict(smoke_image)
if confidences['real'] > 0.5:
    print("Predicted label: real")
else:
    print("Predicted label: fake")

print()
    
# Static text rendered around the demo in the Gradio page.
title='Fake or not Fake? that is the question'
description='Modelo de deeplearning para clasificar las imagenes en reales o falsas'
article='Proyecto Saturdays.AI DemoDay 11/06/2022'

# Single PIL image in; predicted class confidences plus the detected face crop out.
_input_component = gr.inputs.Image(label="Input Image", type="pil")
_output_components = [
    gr.outputs.Label(label="Class"),
    gr.outputs.Image(label="Face"),
]

# Clickable sample frames: four fake ones followed by four real ones.
_sample_frames = [
    'fake_frame_0.jpg',
    'fake_frame_1.jpg',
    'fake_frame_2.jpg',
    'fake_frame_3.jpg',
    'real_frame_0.jpg',
    'real_frame_1.jpg',
    'real_frame_2.jpg',
    'real_frame_3.jpg',
]

# Build the UI around predict() and start serving it.
interface = gr.Interface(
    fn=predict,
    inputs=_input_component,
    outputs=_output_components,
    title=title,
    description=description,
    article=article,
    theme='peach',
    examples=_sample_frames,
).launch()