molher commited on
Commit
7695113
1 Parent(s): 3876f09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -4
app.py CHANGED
@@ -1,7 +1,99 @@
1
  import gradio as gr
 
 
 
 
 
 
 
2
 
3
- def saludar(nombre):
4
- return f'Hola {nombre}'
5
- inf=gr.Interface(saludar,inputs='text',outputs='text')
6
 
7
- inf.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from facenet_pytorch import MTCNN, InceptionResnetV1
5
+ import os
6
+ import numpy as np
7
+ from PIL import Image
8
+ import matplotlib.pyplot as plt
9
 
10
+ DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
11
+ print(f'Running on device: {DEVICE.upper()}')
 
12
 
13
+ torch.load('./models/resnetinceptionv1_final.pth',map_location='cpu')
14
+
15
+ mtcnn = MTCNN(
16
+ select_largest=False,
17
+ post_process=False,
18
+ device=DEVICE
19
+ ).to(DEVICE).eval()
20
+
21
+ model = InceptionResnetV1(
22
+ pretrained="vggface2",
23
+ classify=True,
24
+ num_classes=1,
25
+ device=DEVICE
26
+ )
27
+ model.load_state_dict(torch.load('./models/resnetinceptionv1_final.pth',map_location='cpu'))
28
+ model.to(DEVICE)
29
+ model.eval()
30
+ print("MTCNN & Classfier models loaded")
31
+
32
+
33
+ EXAMPLES_FOLDER = 'examples'
34
+ examples_names = os.listdir(EXAMPLES_FOLDER)
35
+ examples = []
36
+ for example_name in examples_names:
37
+ example_path = os.path.join(EXAMPLES_FOLDER, example_name)
38
+ label = example_name.split('_')[0]
39
+ example = {
40
+ 'path': example_path,
41
+ 'label': label
42
+ }
43
+ examples.append(example)
44
+
45
+
46
+
47
+ def predict(input_image:Image.Image):
48
+ """Predict the label of the input_image"""
49
+ face = mtcnn(input_image)
50
+ if face is None:
51
+ raise Exception('No face detected')
52
+ face = face.unsqueeze(0) # add the batch dimension
53
+ face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
54
+
55
+ # convert the face into a numpy array to be able to plot it
56
+ face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()
57
+
58
+ face = face.to(DEVICE)
59
+ face = face.to(torch.float32)
60
+ face = face / 255.0
61
+ with torch.no_grad():
62
+ output = torch.sigmoid(model(face).squeeze(0))
63
+ prediction = "real" if output.item() < 0.5 else "fake"
64
+
65
+ real_prediction = 1 - output.item()
66
+ fake_prediction = output.item()
67
+
68
+ confidences = {
69
+ 'real': real_prediction,
70
+ 'fake': fake_prediction
71
+ }
72
+ return confidences, face_image_to_plot
73
+
74
+ for i in range(10):
75
+ example = examples[8]
76
+ example_img = example['path']
77
+ example_label = example['label']
78
+
79
+ print(f"True label: {example_label}")
80
+
81
+ example_img = Image.open(example_img)
82
+ confidences, _ = predict(example_img)
83
+ if confidences['real'] > 0.5:
84
+ print("Predicted label: real")
85
+ else:
86
+ print("Predicted label: fake")
87
+
88
+ print()
89
+
90
+
91
+ interface = gr.Interface(
92
+ fn=predict,
93
+ inputs=gr.inputs.Image(label="Input Image", type="pil"),
94
+ outputs=[
95
+ gr.outputs.Label(label="Class"),
96
+ gr.outputs.Image(label="Face")
97
+ ],
98
+ examples=[examples[i]["path"] for i in range(8)] # fake examples
99
+ ).launch()