dhairyashah committed on
Commit
096a01b
1 Parent(s): 5c4efa2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -10
app.py CHANGED
@@ -25,7 +25,10 @@ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
25
  # Device configuration
26
  DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
27
 
28
- mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).to(DEVICE).eval()
 
 
 
29
 
30
  model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1, device=DEVICE)
31
  checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
@@ -41,18 +44,39 @@ targets = [ClassifierOutputTarget(0)]
41
  def allowed_file(filename):
42
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
43
 
 
 
 
 
 
 
 
 
 
 
 
44
  @spaces.GPU
45
  def process_frame(frame):
46
- face = mtcnn(frame)
47
- if face is None:
 
 
 
48
  return None, None, None
49
 
50
- face = face.unsqueeze(0)
51
- face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
 
 
 
 
 
 
 
52
 
53
- face = face.to(DEVICE)
54
- face = face.to(torch.float32)
55
- face = face / 255.0
56
 
57
  with torch.no_grad():
58
  output = torch.sigmoid(model(face).squeeze(0))
@@ -82,8 +106,7 @@ def analyze_video(video_path, sample_rate=30, top_n=5, detection_threshold=0.5):
82
  break
83
 
84
  if frame_count % sample_rate == 0:
85
- rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
86
- prediction, confidence, visualization = process_frame(rgb_frame)
87
 
88
  if prediction is not None:
89
  total_processed += 1
 
25
  # Device configuration
26
  DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
27
 
28
+ # Configure MTCNN with adjusted thresholds
29
+ mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE,
30
+ thresholds=[0.7, 0.8, 0.8], # Adjust these thresholds for P-Net, R-Net, O-Net
31
+ margin=20, min_face_size=50).to(DEVICE).eval()
32
 
33
  model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1, device=DEVICE)
34
  checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
 
44
  def allowed_file(filename):
45
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
46
 
47
+ def filter_low_quality_detections(detection, min_size=(50, 50)):
48
+ if detection is None or detection[0] is None:
49
+ return None
50
+ for i, (box, prob) in enumerate(zip(detection[0], detection[1])):
51
+ if prob < 0.9: # Filter out detections with low confidence
52
+ continue
53
+ if (box[2] - box[0] < min_size[0]) or (box[3] - box[1] < min_size[1]): # Check size
54
+ continue
55
+ return box # Return the first valid detection
56
+ return None
57
+
58
  @spaces.GPU
59
  def process_frame(frame):
60
+ rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
61
+ detection = mtcnn.detect(rgb_frame)
62
+ face_box = filter_low_quality_detections(detection)
63
+
64
+ if face_box is None:
65
  return None, None, None
66
 
67
+ x1, y1, x2, y2 = map(int, face_box)
68
+ h, w, _ = rgb_frame.shape
69
+
70
+ if x1 < 0 or y1 < 0 or x2 > w or y2 > h:
71
+ return None, None, None
72
+
73
+ face = rgb_frame[y1:y2, x1:x2]
74
+ if face.size == 0:
75
+ return None, None, None
76
 
77
+ face = cv2.resize(face, (256, 256))
78
+ face = torch.from_numpy(face).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
79
+ face = face.to(torch.float32) / 255.0
80
 
81
  with torch.no_grad():
82
  output = torch.sigmoid(model(face).squeeze(0))
 
106
  break
107
 
108
  if frame_count % sample_rate == 0:
109
+ prediction, confidence, visualization = process_frame(frame)
 
110
 
111
  if prediction is not None:
112
  total_processed += 1