"""Gradio demo that labels the face in an image as "real" or "fake".

A face is detected and cropped with facenet-pytorch's MTCNN, then scored by an
InceptionResnetV1 classifier (single output logit) whose weights are loaded
from ``resnetinceptionv1_final.pth``.
"""
import os

import gradio as gr
import matplotlib.pyplot as plt  # noqa: F401  (kept from the original script)
import numpy as np  # noqa: F401  (kept from the original script)
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
from PIL import Image

DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f'Running on device: {DEVICE.upper()}')

# Checkpoint holding the classifier's state_dict.
CHECKPOINT_PATH = 'resnetinceptionv1_final.pth'
# Folder of example images; file names are expected to start with "<label>_".
EXAMPLES_FOLDER = 'examples'
# Side length (pixels) the detected face crop is resized to before scoring.
FACE_SIZE = 256
# Maximum number of example images offered in the Gradio UI.
NUM_UI_EXAMPLES = 8

# Face detector. select_largest=False picks the most confident face rather
# than the largest; post_process=False leaves raw pixel values, which
# predict() scales to [0, 1] itself.
mtcnn = MTCNN(
    select_largest=False,
    post_process=False,
    device=DEVICE,
).to(DEVICE).eval()

# Binary classifier with one output logit (sigmoid -> probability of "fake").
model = InceptionResnetV1(
    pretrained="vggface2",
    classify=True,
    num_classes=1,
    device=DEVICE,
)
# Load onto the CPU first so a checkpoint saved from a GPU also loads on
# CPU-only hosts; the model is moved to DEVICE afterwards.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))
model.to(DEVICE)
model.eval()
print("MTCNN & Classfier models loaded")


def _load_examples(folder):
    """Return ``[{'path': ..., 'label': ...}, ...]`` for the files in *folder*.

    The label is the part of the file name before the first underscore.
    Names are sorted because ``os.listdir`` order is arbitrary, and hidden
    files (e.g. ``.DS_Store``) are skipped.
    """
    return [
        {'path': os.path.join(folder, name), 'label': name.split('_')[0]}
        for name in sorted(os.listdir(folder))
        if not name.startswith('.')
    ]


examples = _load_examples(EXAMPLES_FOLDER)


def predict(input_image: Image.Image):
    """Predict whether the face in *input_image* is real or fake.

    Args:
        input_image: PIL image that should contain a face.

    Returns:
        ``(confidences, face_image)``: ``confidences`` maps ``'real'`` and
        ``'fake'`` to probabilities that sum to 1; ``face_image`` is the
        resized face crop as an HWC ``uint8`` numpy array for display.

    Raises:
        ValueError: if MTCNN does not detect a face.
    """
    # MTCNN works on 3-channel images; normalise grayscale/RGBA inputs.
    face = mtcnn(input_image.convert('RGB'))
    if face is None:
        raise ValueError('No face detected')
    face = face.unsqueeze(0)  # add the batch dimension
    face = F.interpolate(
        face, size=(FACE_SIZE, FACE_SIZE), mode='bilinear', align_corners=False
    )

    # Display copy as uint8 HWC. A wider int dtype may be rescaled by the
    # gradio/PIL conversion and render incorrectly, so clamp and use uint8.
    face_image_to_plot = (
        face.squeeze(0)
        .permute(1, 2, 0)
        .clamp(0, 255)
        .to(torch.uint8)
        .cpu()
        .detach()
        .numpy()
    )

    # Scale raw pixel values to [0, 1] for the classifier.
    face = face.to(DEVICE, dtype=torch.float32) / 255.0

    with torch.no_grad():
        # Single logit -> probability that the face is fake.
        fake_prob = torch.sigmoid(model(face).squeeze(0)).item()

    confidences = {
        'real': 1 - fake_prob,
        'fake': fake_prob,
    }
    return confidences, face_image_to_plot


def _sanity_check(samples):
    """Print the true and predicted label for every example in *samples*."""
    for sample in samples:
        print(f"True label: {sample['label']}")
        with Image.open(sample['path']) as img:
            try:
                confidences, _ = predict(img)
            except ValueError as err:
                confidences = None
                print(f"Skipped {sample['path']}: {err}")
        if confidences is not None:
            predicted = "real" if confidences['real'] > 0.5 else "fake"
            print(f"Predicted label: {predicted}")
        print()


interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(label="Input Image", type="pil"),
    outputs=[
        gr.Label(label="Class"),
        gr.Image(label="Face"),
    ],
    # Slicing tolerates folders with fewer than NUM_UI_EXAMPLES images.
    # NOTE(review): with sorted names, "fake_*" files precede "real_*", so
    # these are presumably the fake examples — confirm against the folder.
    examples=[sample['path'] for sample in examples[:NUM_UI_EXAMPLES]],
)

if __name__ == "__main__":
    _sanity_check(examples)
    interface.launch()