import gradio as gr
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import pickle
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f'Running on device: {DEVICE.upper()}')
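# Face detector: MTCNN crops the most confident face from the input image.
# select_largest=False keeps the highest-probability detection and
# post_process=False leaves pixel values in the 0-255 range.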
mtcnn = MTCNN(
    select_largest=False,
    post_process=False,
    device=DEVICE
).to(DEVICE).eval()
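# Deepfake classifier: InceptionResnetV1 backbone pretrained on VGGFace2,
# reconfigured with a single output logit and loaded with the fine-tuned
# weights from resnetinceptionv1_final.pth.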
model = InceptionResnetV1(
    pretrained="vggface2",
    classify=True,
    num_classes=1,
    device=DEVICE
)
model.load_state_dict(torch.load('resnetinceptionv1_final.pth', map_location='cpu'))
model.to(DEVICE)
model.eval()
print("MTCNN & Classfier models loaded")
# Open the pickle file with the example images
with open('file_examples.pkl', 'rb') as file:
    examples = pickle.load(file)
#EXAMPLES_FOLDER = 'examples'
#examples_names = os.listdir(EXAMPLES_FOLDER)
#examples = []
#for example_name in examples_names:
# example_path = os.path.join(EXAMPLES_FOLDER, example_name)
# label = example_name.split('_')[0]
# example = {
# 'path': example_path,
# 'label': label
# }
# examples.append(example)
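# Inference pipeline: detect and crop the face, resize it to 256x256, scale it
# to [0, 1], and run the classifier. The sigmoid output is the probability of
# the "fake" class, so values below 0.5 are reported as "real".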
def predict(input_image: Image.Image):
    """Predict the label of the input_image"""
    face = mtcnn(input_image)
    if face is None:
        raise Exception('No face detected')
    face = face.unsqueeze(0)  # add the batch dimension
    face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)

    # convert the face into a numpy array to be able to plot it
    face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()

    face = face.to(DEVICE)
    face = face.to(torch.float32)
    face = face / 255.0

    with torch.no_grad():
        output = torch.sigmoid(model(face).squeeze(0))
        prediction = "real" if output.item() < 0.5 else "fake"

        real_prediction = 1 - output.item()
        fake_prediction = output.item()

        confidences = {
            'real': real_prediction,
            'fake': fake_prediction
        }
    return confidences, face_image_to_plot
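# Quick smoke test before launching the UI: the loop below repeatedly scores the
# bundled frame 'fake_frame_0.jpg' and prints the prediction next to the stored
# ground-truth label of examples[8].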
for i in range(10):
    example = examples[8]
    #example_img = example['path']
    example_img = 'fake_frame_0.jpg'
    example_label = example['label']
    print(f"True label: {example_label}")

    example_img = Image.open(example_img)
    confidences, _ = predict(example_img)
    if confidences['real'] > 0.5:
        print("Predicted label: real")
    else:
        print("Predicted label: fake")
    print()
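# Gradio UI: a single-image interface that shows the class confidences and the
# cropped face. gr.inputs / gr.outputs are the legacy Gradio namespaces; recent
# Gradio releases expose gr.Image and gr.Label at the top level instead.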
title = 'Fake or not Fake? that is the question'
description = 'Deep learning model that classifies images as real or fake'
article = 'Saturdays.AI DemoDay project, 11/06/2022'
interface = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(label="Input Image", type="pil"),
    outputs=[
        gr.outputs.Label(label="Class"),
        gr.outputs.Image(label="Face")
    ],
    title=title,
    description=description,
    article=article,
    theme='peach',
    #examples=[examples[i]["path"] for i in range(8)]  # fake examples
    examples=[
        'fake_frame_0.jpg', 'fake_frame_1.jpg', 'fake_frame_2.jpg', 'fake_frame_3.jpg',
        'real_frame_0.jpg', 'real_frame_1.jpg', 'real_frame_2.jpg', 'real_frame_3.jpg'
    ]
).launch()