# TorchFakes / app.py
# Last commit 05b6126 ("Update app.py") by molher, 3.13 kB.
import gradio as gr
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import pickle
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f'Running on device: {DEVICE.upper()}')

# Face detector used by predict() to crop the face out of the input image.
# select_largest=False: per the facenet-pytorch docs, keep the most confident
#   detection instead of the largest box.
# post_process=False: skip MTCNN's own standardisation. The crop then keeps
#   raw pixel values, which predict() rescales by 1/255 itself.
mtcnn = MTCNN(
    select_largest=False,
    post_process=False,
    device=DEVICE,
).to(DEVICE).eval()

# Real/fake classifier: an InceptionResnetV1 backbone initialised from VGGFace2
# weights, with a single-logit classification head (num_classes=1).
model = InceptionResnetV1(
    pretrained="vggface2",
    classify=True,
    num_classes=1,
    device=DEVICE,
)
# Load the fine-tuned checkpoint exactly once. Mapping to CPU first keeps it
# loadable on CPU-only machines; the model is then moved to DEVICE.
model.load_state_dict(torch.load('resnetinceptionv1_final.pth', map_location='cpu'))
model.to(DEVICE)
model.eval()
print("MTCNN & Classfier models loaded")
# Load the pickled list of example images bundled with the app.
# Each entry is a dict with:
#   'path'  - the image file,
#   'label' - the class name.
# The list was presumably generated by scanning an 'examples' folder, taking
# each label from the file-name prefix before the first '_' (e.g. 'fake',
# 'real').
# NOTE(review): pickle.load can execute arbitrary code. Only ever point this
# at a trusted file that ships with the app.
with open('file_examples.pkl', mode='rb') as examples_file:
    examples = pickle.load(examples_file)
def predict(input_image: Image.Image):
    """Classify the face in ``input_image`` as real or fake.

    The face is located with ``mtcnn``, resized to 256x256, scaled to [0, 1]
    and passed through ``model``. A sigmoid over the single output logit is
    the probability that the face is fake.

    Args:
        input_image: PIL image expected to contain a face.

    Returns:
        A tuple ``(confidences, face_image)``:
        - ``confidences`` maps ``'real'`` and ``'fake'`` to probabilities
          that sum to 1.
        - ``face_image`` is the resized face crop as a height x width x
          channels ``uint8`` array, for display.

    Raises:
        ValueError: if no image is given or no face is detected.
    """
    if input_image is None:
        raise ValueError('No input image provided')
    # Normalise RGBA / grayscale uploads to the 3-channel RGB the detector
    # expects. This is a no-op copy for RGB images.
    face = mtcnn(input_image.convert('RGB'))
    if face is None:
        raise ValueError('No face detected')
    face = face.unsqueeze(0)  # add the batch dimension
    face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)

    # Displayable copy of the crop. post_process=False leaves raw 0-255 pixel
    # values. Clamp defensively and cast to uint8 so the array is read as
    # 8-bit pixels; an int32 array is not a valid 8-bit image.
    face_image_to_plot = (
        face.squeeze(0)
        .permute(1, 2, 0)
        .detach()
        .clamp(0, 255)
        .to(torch.uint8)
        .cpu()
        .numpy()
    )

    # Classifier input: float32 on DEVICE, scaled from 0-255 to [0, 1].
    face = face.to(DEVICE, dtype=torch.float32) / 255.0
    with torch.no_grad():
        fake_prob = torch.sigmoid(model(face).squeeze(0)).item()

    confidences = {
        'real': 1 - fake_prob,
        'fake': fake_prob,
    }
    return confidences, face_image_to_plot
# Startup sanity check: classify each bundled sample frame once and print the
# expected label next to the predicted one. The expected label is the
# file-name prefix before the first '_' (e.g. 'fake' for 'fake_frame_0.jpg').
for example_path in (
    'fake_frame_0.jpg', 'fake_frame_1.jpg', 'fake_frame_2.jpg',
    'fake_frame_3.jpg', 'fake_frame_4.jpg',
    'real_frame_0.jpg', 'real_frame_1.jpg', 'real_frame_2.jpg',
    'real_frame_3.jpg',
):
    example_label = os.path.basename(example_path).split('_')[0]
    print(f"True label: {example_label}")
    try:
        # Context manager closes the file handle once predict() has read it.
        with Image.open(example_path) as example_img:
            confidences, _ = predict(example_img)
    except Exception as err:  # best-effort diagnostic: never block app launch
        print(f"Prediction failed for {example_path}: {err}")
    else:
        if confidences['real'] > 0.5:
            print("Predicted label: real")
        else:
            print("Predicted label: fake")
    print()
# Gradio UI: upload an image; get real/fake confidences plus the detected face.
# The Interface object is kept in `interface`; launch() is called separately
# so the name does not end up bound to launch()'s return value.
interface = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(label="Input Image", type="pil"),
    outputs=[
        gr.outputs.Label(label="Class"),
        gr.outputs.Image(label="Face"),
    ],
    # Sample frames bundled with the app, shown as clickable examples.
    # NOTE(review): add a real_frame_4.jpg here if one is bundled, to balance
    # the five fake examples.
    examples=[
        'fake_frame_0.jpg', 'fake_frame_1.jpg', 'fake_frame_2.jpg',
        'fake_frame_3.jpg', 'fake_frame_4.jpg',
        'real_frame_0.jpg', 'real_frame_1.jpg', 'real_frame_2.jpg',
        'real_frame_3.jpg',
    ],
)
interface.launch()