|
import spaces |
|
from flask import Flask, request, jsonify |
|
import os |
|
from werkzeug.utils import secure_filename |
|
import cv2 |
|
import torch |
|
import torch.nn.functional as F |
|
from facenet_pytorch import MTCNN, InceptionResnetV1 |
|
import numpy as np |
|
from pytorch_grad_cam import GradCAM |
|
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget |
|
from pytorch_grad_cam.utils.image import show_cam_on_image |
|
import base64 |
|
|
|
app = Flask(__name__)

# Directory where /analyze stores an uploaded video while it is being analysed.
UPLOAD_FOLDER = 'uploads'
# Lower-case file extensions that allowed_file() accepts for upload.
ALLOWED_EXTENSIONS = {'mp4', 'avi', 'mov', 'webm'}

app.config.update(
    UPLOAD_FOLDER=UPLOAD_FOLDER,
    # Flask rejects request bodies larger than this (16 MiB) with HTTP 413.
    MAX_CONTENT_LENGTH=16 * 1024 * 1024,
)

# Make sure the upload directory exists before the first request arrives.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
|
|
|
|
# First CUDA device if one is visible when this module is imported, else CPU.
# NOTE(review): the @spaces.GPU decorators suggest a host where the GPU may only
# be attached inside decorated calls; confirm DEVICE resolves as intended there.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Face detector. select_largest=False returns the detection with the highest
# probability instead of the largest box. post_process=False skips
# facenet-pytorch's standardization, so crops keep raw pixel values
# (process_frame rescales them to [0, 1]).
mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).to(DEVICE).eval()

# Binary real/fake classifier with a single output unit (num_classes=1).
# It starts from the VGGFace2 backbone, and the fine-tuned checkpoint below
# then replaces its weights.
model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1, device=DEVICE)
# Load onto the CPU first so the checkpoint can be read regardless of the device
# it was saved from.
# NOTE(review): torch.load unpickles the file; only ever point this at a trusted
# checkpoint.
checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(DEVICE).eval()

# Grad-CAM explains output unit 0 (the single real/fake logit) using the
# activations of the last module in block8.branch1.
target_layers = [model.block8.branch1[-1]]
cam = GradCAM(model=model, target_layers=target_layers)
targets = [ClassifierOutputTarget(0)]
|
|
|
def allowed_file(filename):
    """Tell whether *filename* carries an extension listed in ALLOWED_EXTENSIONS.

    Only the text after the last '.' is checked, case-insensitively; a name
    that contains no '.' at all is rejected.
    """
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS
|
|
|
@spaces.GPU
def process_frame(frame):
    """Classify the most confident face in a frame and explain it with Grad-CAM.

    Args:
        frame: image handed straight to the MTCNN detector. It is expected to
            be RGB; analyze_video converts its BGR frames before calling.

    Returns:
        ``(prediction, probability, visualization)``:
        - ``prediction`` is ``"fake"`` when the sigmoid score is >= 0.5,
          otherwise ``"real"``.
        - ``probability`` is that score as a Python float.
        - ``visualization`` is the Grad-CAM heat map blended onto the
          256x256 face crop by ``show_cam_on_image(..., use_rgb=True)``.

        ``(None, None, None)`` is returned when MTCNN finds no face.
    """
    detected = mtcnn(frame)
    if detected is None:
        return None, None, None

    # Add a batch axis and resize the crop to 256x256. The crop holds raw
    # pixel values (post_process=False), so scale it into [0, 1].
    batch = F.interpolate(
        detected.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False
    )
    batch = batch.to(DEVICE).to(torch.float32) / 255.0

    with torch.no_grad():
        probability = torch.sigmoid(model(batch).squeeze(0))
    score = probability.item()
    label = "fake" if score >= 0.5 else "real"

    # Grad-CAM needs gradients, so it runs outside the no_grad context.
    heatmap = cam(input_tensor=batch, targets=targets, eigen_smooth=True)[0, :]

    # CHW tensor -> HWC float array in [0, 1], the layout show_cam_on_image expects.
    rgb_face = batch.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()
    overlay = show_cam_on_image(rgb_face, heatmap, use_rgb=True)

    return label, score, overlay
|
|
|
@spaces.GPU
def analyze_video(video_path, sample_rate=30, top_n=5):
    """Score every ``sample_rate``-th frame of a video for deepfaked faces.

    Args:
        video_path: path of a video file readable by ``cv2.VideoCapture``.
        sample_rate: analyse the frames whose 0-based index is a multiple of
            this value. Must be >= 1.
        top_n: maximum number of per-frame results to return.

    Returns:
        ``(fake_percentage, top_frames)``:
        - ``fake_percentage`` is the share (0-100) of analysed frames in
          which a face was detected and classified ``"fake"``.
        - ``top_frames`` holds up to ``top_n`` dicts with keys
          ``frame_number``, ``prediction``, ``confidence`` and
          ``visualization``, ordered by descending fake probability.

        ``(0, [])`` is returned when no sampled frame contained a detectable
        face, which also covers a video that cannot be opened or decoded.

    Raises:
        ValueError: if ``sample_rate`` is less than 1 (0 would otherwise
            raise ZeroDivisionError mid-decode).
    """
    if sample_rate < 1:
        raise ValueError(f'sample_rate must be >= 1, got {sample_rate}')

    frames_info = []
    fake_count = 0
    frame_count = 0

    cap = cv2.VideoCapture(video_path)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            if frame_count % sample_rate == 0:
                # OpenCV decodes to BGR; the face pipeline expects RGB.
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                prediction, confidence, visualization = process_frame(rgb_frame)

                if prediction is not None:
                    if prediction == "fake":
                        fake_count += 1
                    frames_info.append({
                        'frame_number': frame_count,
                        'prediction': prediction,
                        'confidence': confidence,
                        'visualization': visualization,
                    })

            frame_count += 1
    finally:
        # Release the decoder even when process_frame raises part-way through.
        cap.release()

    if not frames_info:
        return 0, []

    # One dict is appended per frame with a detected face.
    fake_percentage = fake_count / len(frames_info) * 100
    frames_info.sort(key=lambda info: info['confidence'], reverse=True)
    return fake_percentage, frames_info[:top_n]
|
|
|
def _rgb_to_base64_png(rgb_image):
    """Encode an RGB image as a base64 PNG string.

    The input is an RGB array such as the one returned by
    ``show_cam_on_image(..., use_rgb=True)``. ``cv2.imencode`` treats
    3-channel arrays as BGR, so the image is converted first; otherwise red
    and blue would be swapped in the PNG.

    Raises:
        RuntimeError: if OpenCV reports that encoding failed.
    """
    bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    ok, png = cv2.imencode('.png', bgr_image)
    if not ok:
        raise RuntimeError('Failed to encode Grad-CAM visualization as PNG')
    return base64.b64encode(png).decode('utf-8')


@app.route('/analyze', methods=['POST'])
def analyze_video_api():
    """Handle ``POST /analyze``: run deepfake analysis on an uploaded video.

    Expects a multipart form field ``video`` whose filename passes
    ``allowed_file``. The upload is written to ``UPLOAD_FOLDER`` under a
    unique, server-generated name and is always deleted afterwards.

    Returns:
        200: JSON with ``fake_percentage`` (rounded to 2 decimals),
            ``is_likely_deepfake`` (``fake_percentage >= 60``) and
            ``top_frames``, whose ``visualization`` is a base64 PNG string.
        400: the file is missing, unnamed, or has a disallowed extension.
        500: saving or analysing the upload fails; the body carries the
            error message.
    """
    import tempfile
    from contextlib import suppress

    if 'video' not in request.files:
        return jsonify({'error': 'No video file provided'}), 400

    file = request.files['video']

    if not file.filename:
        return jsonify({'error': 'No selected file'}), 400

    if not allowed_file(file.filename):
        return jsonify({'error': f'Invalid file type: {file.filename}'}), 400

    # Never build the path from the client filename. Identical names from
    # concurrent requests would overwrite (and then delete) each other's
    # uploads, and secure_filename() can return ''. Only the extension is
    # reused; it was already validated against ALLOWED_EXTENSIONS above.
    extension = file.filename.rsplit('.', 1)[1].lower()
    filepath = None
    try:
        fd, filepath = tempfile.mkstemp(
            suffix='.' + extension, dir=app.config['UPLOAD_FOLDER']
        )
        os.close(fd)
        file.save(filepath)

        fake_percentage, top_frames = analyze_video(filepath)

        for frame in top_frames:
            frame['visualization'] = _rgb_to_base64_png(frame['visualization'])

        result = {
            'fake_percentage': round(fake_percentage, 2),
            'is_likely_deepfake': fake_percentage >= 60,
            'top_frames': top_frames,
        }
        return jsonify(result), 200
    except Exception as e:
        app.logger.exception('Video analysis failed')
        return jsonify({'error': str(e)}), 500
    finally:
        # Cleanup runs exactly once, whichever step failed. It is best-effort
        # so a failed delete cannot mask the real response.
        if filepath is not None:
            with suppress(OSError):
                os.remove(filepath)
|
|
|
if __name__ == '__main__':
    # Development-server entry point: listen on every interface, port 7860.
    # NOTE(review): 7860 is presumably the port the hosting Space expects; confirm.
    app.run(port=7860, host='0.0.0.0')