# `spaces` (Hugging Face ZeroGPU) is kept as the first import, as in the original;
# ZeroGPU expects it to be imported before CUDA-related packages such as torch.
import spaces
from flask import Flask, request, jsonify
import os
import uuid
from werkzeug.utils import secure_filename
import cv2
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import base64

app = Flask(__name__)

# Configuration
UPLOAD_FOLDER = 'uploads'
ALLOWED_EXTENSIONS = {'mp4', 'avi', 'mov', 'webm'}
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# Flask rejects request bodies larger than this (16 MiB).
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Device configuration
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Face detector. post_process=False presumably leaves pixel values unstandardised
# (0-255 range), which is why process_frame divides by 255 — confirm against the
# facenet_pytorch version in use.
mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).to(DEVICE).eval()

# Real/fake classifier: InceptionResnetV1 with a single-logit head, whose weights
# are replaced by the fine-tuned checkpoint below.
model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1, device=DEVICE)
checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
model.to(DEVICE)
model.eval()

# GradCAM setup: explain output logit 0 using the last layer of block8.branch1.
target_layers = [model.block8.branch1[-1]]
cam = GradCAM(model=model, target_layers=target_layers)
targets = [ClassifierOutputTarget(0)]


def allowed_file(filename):
    """Return True if *filename* ends in one of ALLOWED_EXTENSIONS (case-insensitive)."""
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


@spaces.GPU
def process_frame(frame, threshold=0.5):
    """Classify the face in one RGB frame and build a GradCAM overlay for it.

    Args:
        frame: RGB image (the caller converts OpenCV's BGR frames to RGB).
        threshold: sigmoid score at or above which the face is labelled "fake".

    Returns:
        ``(prediction, score, visualization)``: ``prediction`` is ``"fake"`` or
        ``"real"``, ``score`` is the sigmoid of the model's logit as a Python
        float, and ``visualization`` is the RGB GradCAM overlay produced by
        ``show_cam_on_image(..., use_rgb=True)``. Returns ``(None, None, None)``
        when MTCNN detects no face.
    """
    face = mtcnn(frame)
    if face is None:
        return None, None, None

    # Add a batch dimension, resize to 256x256 and scale pixel values to [0, 1].
    face = face.unsqueeze(0)
    face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
    face = face.to(DEVICE).to(torch.float32) / 255.0

    with torch.no_grad():
        score = torch.sigmoid(model(face).squeeze(0)).item()
    prediction = "fake" if score >= threshold else "real"

    # GradCAM needs gradients, so it deliberately runs outside torch.no_grad().
    grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)[0, :]
    face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()
    visualization = show_cam_on_image(face_image_to_plot, grayscale_cam, use_rgb=True)
    return prediction, score, visualization


@spaces.GPU
def analyze_video(video_path, sample_rate=30, top_n=5, detection_threshold=0.5):
    """Sample frames from a video and aggregate per-face deepfake predictions.

    Args:
        video_path: path readable by ``cv2.VideoCapture``.
        sample_rate: analyse every ``sample_rate``-th frame (must be >= 1).
        top_n: number of highest-scoring frames to include in the result.
        detection_threshold: per-frame score at or above which a face counts
            as "fake"; forwarded to ``process_frame`` and echoed in the result.

    Returns:
        A summary dict (see the return statement for its keys), or ``None`` if
        no sampled frame contained a detectable face.

    Raises:
        ValueError: if ``sample_rate`` is less than 1.
    """
    if sample_rate < 1:
        raise ValueError(f"sample_rate must be >= 1, got {sample_rate}")

    frames_info = []
    cap = cv2.VideoCapture(video_path)
    try:
        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % sample_rate == 0:
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                prediction, confidence, visualization = process_frame(rgb_frame, detection_threshold)
                if prediction is not None:
                    frames_info.append({
                        'frame_number': frame_count,
                        'prediction': prediction,
                        'confidence': confidence,
                        'visualization': visualization,
                    })
            frame_count += 1
    finally:
        # Release the capture even if face detection / inference raises.
        cap.release()

    if not frames_info:
        return None

    total_processed = len(frames_info)
    confidence_scores = [info['confidence'] for info in frames_info]
    fake_count = sum(1 for info in frames_info if info['prediction'] == 'fake')

    fake_percentage = (fake_count / total_processed) * 100
    average_confidence = sum(confidence_scores) / total_processed
    # 1 minus the population variance of the per-frame scores: closer to 1
    # means the sampled frames agree more with each other.
    model_confidence = 1 - (
        sum((score - average_confidence) ** 2 for score in confidence_scores) / total_processed
    )

    # Highest scores (most "fake") first.
    frames_info.sort(key=lambda x: x['confidence'], reverse=True)
    top_frames = frames_info[:top_n]

    return {
        'fake_percentage': fake_percentage,
        'is_likely_deepfake': fake_percentage >= 60,
        'top_frames': top_frames,
        'model_confidence': model_confidence,
        'total_frames_analyzed': total_processed,
        'average_confidence_score': average_confidence,
        'detection_threshold': detection_threshold,
    }


def _encode_visualization(rgb_image):
    """PNG-encode an RGB image and return it as a base64 ASCII string.

    cv2.imencode interprets 3-channel arrays as BGR, so the RGB overlay is
    converted first; otherwise red and blue would be swapped in the PNG.
    """
    ok, buffer = cv2.imencode('.png', cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR))
    if not ok:
        raise RuntimeError('Failed to encode GradCAM visualization as PNG')
    return base64.b64encode(buffer).decode('utf-8')


def _unique_upload_name(original_filename):
    """Build a collision-free, filesystem-safe name that keeps the validated extension.

    secure_filename can reduce non-ASCII names to almost nothing, and
    concurrent uploads with the same name would otherwise overwrite each other.
    """
    stem, extension = original_filename.rsplit('.', 1)
    safe_stem = secure_filename(stem) or 'video'
    return f"{safe_stem}_{uuid.uuid4().hex}.{extension.lower()}"


@app.route('/analyze', methods=['POST'])
def analyze_video_api():
    """POST /analyze: run deepfake analysis on the uploaded ``video`` file.

    Responds 200 with the analysis summary (top-frame visualisations as base64
    PNG), 400 for a missing/invalid upload or when no face could be processed,
    and 500 with the error message if analysis fails.
    """
    if 'video' not in request.files:
        return jsonify({'error': 'No video file provided'}), 400
    file = request.files['video']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400
    if not allowed_file(file.filename):
        return jsonify({'error': f'Invalid file type: {file.filename}'}), 400

    filepath = os.path.join(app.config['UPLOAD_FOLDER'], _unique_upload_name(file.filename))
    file.save(filepath)
    try:
        result = analyze_video(filepath)
        if not result:
            return jsonify({'error': 'No frames could be processed'}), 400
        # Convert numpy arrays to base64 encoded strings for JSON transport.
        for frame in result['top_frames']:
            frame['visualization'] = _encode_visualization(frame['visualization'])
        return jsonify(result), 200
    except Exception as e:
        app.logger.exception('Video analysis failed for %s', filepath)
        return jsonify({'error': str(e)}), 500
    finally:
        # Remove the upload exactly once, whatever the outcome.
        try:
            os.remove(filepath)
        except FileNotFoundError:
            pass


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)