Spaces: jungwoonshin — Runtime error
jungwoonshin committed · Commit 7199166 · Parent(s): a8ff7ce

Browse files:
- .gitattributes +1 -0
- app.py +57 -60
- classifier_DeepFakeClassifier_tf_efficientnet_b7_ns_1_best_dice +3 -0
- classifiers.py +172 -0
- predict/kernel_utils.py → kernel_utils.py +9 -9
- predict/app.py +0 -68
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+classifier_DeepFakeClassifier_tf_efficientnet_b7_ns_1_best_dice filter=lfs diff=lfs merge=lfs -text
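The added attribute routes the ~255 MB checkpoint through Git LFS instead of committing the blob directly. A quick way to confirm the attribute applies is `git check-attr`; a minimal sketch run from the repository root (the subprocess wrapper is illustrative, not part of the commit):

    import subprocess

    # ask git which filter applies to the checkpoint; expect "... filter: lfs"
    out = subprocess.run(
        ["git", "check-attr", "filter", "--",
         "classifier_DeepFakeClassifier_tf_efficientnet_b7_ns_1_best_dice"],
        capture_output=True, text=True, check=True,
    ).stdout
    print(out.strip())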
app.py CHANGED
@@ -1,74 +1,71 @@
 import gradio as gr
+import argparse
+import os
+import re
+import time
 
+import torch
+import pandas as pd
 
+# import os, sys
+# root_folder = os.path.abspath(
+#     os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+# )
+# sys.path.append(root_folder)
+from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video_set
+from classifiers import DeepFakeClassifier
+import gradio as gr
 
+
+
+def predict(video):
 
+    frames_per_video = 32
+    video_reader = VideoReader()
+    video_read_fn = lambda x: video_reader.read_frames(x, num_frames=frames_per_video)
+    face_extractor = FaceExtractor(video_read_fn)
+    input_size = 380
+    strategy = confident_strategy
+
+    # test_videos = sorted([x for x in os.listdir(args.test_dir) if x[-4:] == ".mp4"])[video_index]
+    # print(f"Predicting {video_index} videos")
-    # num_workers=6, test_dir=args.test_dir)
-    # return predictions
+    predictions = predict_on_video_set(face_extractor=face_extractor, input_size=input_size, models=models,
+                                       strategy=strategy, frames_per_video=frames_per_video, videos=video,
+                                       num_workers=6, test_dir=args.test_dir)
+    return predictions
 
+def get_args_models():
+    parser = argparse.ArgumentParser("Predict test videos")
+    arg = parser.add_argument
+    arg('--weights-dir', type=str, default=".", help="path to directory with checkpoints")
+    arg('--models', type=str, default='classifier_DeepFakeClassifier_tf_efficientnet_b7_ns_1_best_dice', help="checkpoint files")  # nargs='+',
+    arg('--test-dir', type=str, default='test_dataset', help="path to directory with videos")
+    arg('--output', type=str, required=False, help="path to output csv", default="submission.csv")
+    args = parser.parse_args()
 
+    models = []
+    # model_paths = [os.path.join(args.weights_dir, model) for model in args.models]
+    model_paths = [os.path.join(args.weights_dir, args.models)]
+    for path in model_paths:
+        model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to("cpu")
+        print("loading state dict {}".format(path))
+        checkpoint = torch.load(path, map_location="cpu")
+        state_dict = checkpoint.get("state_dict", checkpoint)
+        model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=True)
+        model.eval()
+        del checkpoint
+        models.append(model)
+    return args, models
 
 def greet(name):
     return "Hello " + name + "!!"
 
 if __name__ == '__main__':
+    global args, models
+    args, models = get_args_models()
+
     # stime = time.time()
     # print("Elapsed:", time.time() - stime)
+
-    demo = gr.Interface(fn=
-    demo.launch()
+    demo = gr.Interface(fn=predict, inputs="video", outputs="text")
+    demo.launch()
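The loading loop above strips a leading "module." from every checkpoint key because checkpoints saved from an nn.DataParallel-wrapped model carry that prefix, while a plain module expects bare parameter names. A standalone sketch of the same idea (the strip_module_prefix helper name is mine, not from the commit; note the commit's pattern "^module." leaves the dot unescaped, so it also matches keys like "moduleX"):

    import re
    from collections import OrderedDict

    import torch

    def strip_module_prefix(state_dict):
        # nn.DataParallel saves parameters as "module.<name>"; plain modules expect "<name>"
        return OrderedDict((re.sub(r"^module\.", "", k), v) for k, v in state_dict.items())

    checkpoint = torch.load("classifier_DeepFakeClassifier_tf_efficientnet_b7_ns_1_best_dice",
                            map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint)  # some checkpoints nest under "state_dict"
    weights = strip_module_prefix(state_dict)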
classifier_DeepFakeClassifier_tf_efficientnet_b7_ns_1_best_dice ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a8ec8b8c200d260679069d022ca396ba55ded41c7a92c6b99f5ae52406a304ba
+size 267100135
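What lands in the repository here is not the checkpoint itself but a Git LFS pointer: three "key value" lines giving the spec version, the SHA-256 of the real blob, and its size in bytes (267,100,135 ≈ 255 MB). A minimal parser sketch (parse_lfs_pointer is a hypothetical helper, for illustration only):

    def parse_lfs_pointer(path):
        # each pointer line is "key value"; yields {"version": ..., "oid": ..., "size": ...}
        with open(path) as f:
            return dict(line.strip().split(" ", 1) for line in f if line.strip())

    fields = parse_lfs_pointer("classifier_DeepFakeClassifier_tf_efficientnet_b7_ns_1_best_dice")
    assert fields["size"] == "267100135"
    assert fields["oid"].startswith("sha256:")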
classifiers.py ADDED
@@ -0,0 +1,172 @@
+from functools import partial
+
+import numpy as np
+import torch
+from timm.models.efficientnet import tf_efficientnet_b4_ns, tf_efficientnet_b3_ns, \
+    tf_efficientnet_b5_ns, tf_efficientnet_b2_ns, tf_efficientnet_b6_ns, tf_efficientnet_b7_ns
+from torch import nn
+from torch.nn.modules.dropout import Dropout
+from torch.nn.modules.linear import Linear
+from torch.nn.modules.pooling import AdaptiveAvgPool2d
+
+encoder_params = {
+    "tf_efficientnet_b3_ns": {
+        "features": 1536,
+        "init_op": partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2)
+    },
+    "tf_efficientnet_b2_ns": {
+        "features": 1408,
+        "init_op": partial(tf_efficientnet_b2_ns, pretrained=False, drop_path_rate=0.2)
+    },
+    "tf_efficientnet_b4_ns": {
+        "features": 1792,
+        "init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.5)
+    },
+    "tf_efficientnet_b5_ns": {
+        "features": 2048,
+        "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2)
+    },
+    "tf_efficientnet_b4_ns_03d": {
+        "features": 1792,
+        "init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.3)
+    },
+    "tf_efficientnet_b5_ns_03d": {
+        "features": 2048,
+        "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.3)
+    },
+    "tf_efficientnet_b5_ns_04d": {
+        "features": 2048,
+        "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.4)
+    },
+    "tf_efficientnet_b6_ns": {
+        "features": 2304,
+        "init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.2)
+    },
+    "tf_efficientnet_b7_ns": {
+        "features": 2560,
+        "init_op": partial(tf_efficientnet_b7_ns, pretrained=True, drop_path_rate=0.2)
+    },
+    "tf_efficientnet_b6_ns_04d": {
+        "features": 2304,
+        "init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.4)
+    },
+}
+
+
+def setup_srm_weights(input_channels: int = 3) -> torch.Tensor:
+    """Creates the SRM kernels for noise analysis."""
+    # note: values taken from Zhou et al., "Learning Rich Features for Image Manipulation Detection", CVPR2018
+    srm_kernel = torch.from_numpy(np.array([
+        [  # srm 1/2 horiz
+            [0., 0., 0., 0., 0.],  # noqa: E241,E201
+            [0., 0., 0., 0., 0.],  # noqa: E241,E201
+            [0., 1., -2., 1., 0.],  # noqa: E241,E201
+            [0., 0., 0., 0., 0.],  # noqa: E241,E201
+            [0., 0., 0., 0., 0.],  # noqa: E241,E201
+        ], [  # srm 1/4
+            [0., 0., 0., 0., 0.],  # noqa: E241,E201
+            [0., -1., 2., -1., 0.],  # noqa: E241,E201
+            [0., 2., -4., 2., 0.],  # noqa: E241,E201
+            [0., -1., 2., -1., 0.],  # noqa: E241,E201
+            [0., 0., 0., 0., 0.],  # noqa: E241,E201
+        ], [  # srm 1/12
+            [-1., 2., -2., 2., -1.],  # noqa: E241,E201
+            [2., -6., 8., -6., 2.],  # noqa: E241,E201
+            [-2., 8., -12., 8., -2.],  # noqa: E241,E201
+            [2., -6., 8., -6., 2.],  # noqa: E241,E201
+            [-1., 2., -2., 2., -1.],  # noqa: E241,E201
+        ]
+    ])).float()
+    srm_kernel[0] /= 2
+    srm_kernel[1] /= 4
+    srm_kernel[2] /= 12
+    return srm_kernel.view(3, 1, 5, 5).repeat(1, input_channels, 1, 1)
+
+
+def setup_srm_layer(input_channels: int = 3) -> torch.nn.Module:
+    """Creates a SRM convolution layer for noise analysis."""
+    weights = setup_srm_weights(input_channels)
+    conv = torch.nn.Conv2d(input_channels, out_channels=3, kernel_size=5, stride=1, padding=2, bias=False)
+    with torch.no_grad():
+        conv.weight = torch.nn.Parameter(weights, requires_grad=False)
+    return conv
+
+
+class DeepFakeClassifierSRM(nn.Module):
+    def __init__(self, encoder, dropout_rate=0.5) -> None:
+        super().__init__()
+        self.encoder = encoder_params[encoder]["init_op"]()
+        self.avg_pool = AdaptiveAvgPool2d((1, 1))
+        self.srm_conv = setup_srm_layer(3)
+        self.dropout = Dropout(dropout_rate)
+        self.fc = Linear(encoder_params[encoder]["features"], 1)
+
+    def forward(self, x):
+        noise = self.srm_conv(x)
+        x = self.encoder.forward_features(noise)
+        x = self.avg_pool(x).flatten(1)
+        x = self.dropout(x)
+        x = self.fc(x)
+        return x
+
+
+class GlobalWeightedAvgPool2d(nn.Module):
+    """
+    Global Weighted Average Pooling from paper "Global Weighted Average
+    Pooling Bridges Pixel-level Localization and Image-level Classification"
+    """
+
+    def __init__(self, features: int, flatten=False):
+        super().__init__()
+        self.conv = nn.Conv2d(features, 1, kernel_size=1, bias=True)
+        self.flatten = flatten
+
+    def fscore(self, x):
+        m = self.conv(x)
+        m = m.sigmoid().exp()
+        return m
+
+    def norm(self, x: torch.Tensor):
+        return x / x.sum(dim=[2, 3], keepdim=True)
+
+    def forward(self, x):
+        input_x = x
+        x = self.fscore(x)
+        x = self.norm(x)
+        x = x * input_x
+        x = x.sum(dim=[2, 3], keepdim=not self.flatten)
+        return x
+
+
+class DeepFakeClassifier(nn.Module):
+    def __init__(self, encoder, dropout_rate=0.0) -> None:
+        super().__init__()
+        self.encoder = encoder_params[encoder]["init_op"]()
+        self.avg_pool = AdaptiveAvgPool2d((1, 1))
+        self.dropout = Dropout(dropout_rate)
+        self.fc = Linear(encoder_params[encoder]["features"], 1)
+
+    def forward(self, x):
+        x = self.encoder.forward_features(x)
+        x = self.avg_pool(x).flatten(1)
+        x = self.dropout(x)
+        x = self.fc(x)
+        return x
+
+
+
+
+class DeepFakeClassifierGWAP(nn.Module):
+    def __init__(self, encoder, dropout_rate=0.5) -> None:
+        super().__init__()
+        self.encoder = encoder_params[encoder]["init_op"]()
+        self.avg_pool = GlobalWeightedAvgPool2d(encoder_params[encoder]["features"])
+        self.dropout = Dropout(dropout_rate)
+        self.fc = Linear(encoder_params[encoder]["features"], 1)
+
+    def forward(self, x):
+        x = self.encoder.forward_features(x)
+        x = self.avg_pool(x).flatten(1)
+        x = self.dropout(x)
+        x = self.fc(x)
+        return x
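For context, DeepFakeClassifier is a thin binary head over a timm EfficientNet backbone: forward_features → adaptive average pool → dropout → a single logit. A minimal smoke test, assuming a timm version old enough to still export tf_efficientnet_b7_ns from timm.models.efficientnet (newer timm releases rename these models), and noting that pretrained=True triggers a weight download on first construction:

    import torch
    from classifiers import DeepFakeClassifier

    model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns")
    model.eval()
    with torch.no_grad():
        faces = torch.randn(2, 3, 380, 380)   # two face crops at app.py's input_size
        probs = torch.sigmoid(model(faces))   # (2, 1): per-crop fake probability
    print(probs.shape)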
predict/kernel_utils.py → kernel_utils.py RENAMED
@@ -50,7 +50,7 @@ class VideoReader:
         frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
         if frame_count <= 0: return None
 
-        frame_idxs = np.linspace(0, frame_count - 1, num_frames, endpoint=True, dtype=np.
+        frame_idxs = np.linspace(0, frame_count - 1, num_frames, endpoint=True, dtype=np.int32)
         if jitter > 0:
             np.random.seed(seed)
             jitter_offsets = np.random.randint(-jitter, jitter, len(frame_idxs))
@@ -201,7 +201,7 @@ class FaceExtractor:
         self.video_read_fn = video_read_fn
         self.detector = MTCNN(margin=0, thresholds=[0.7, 0.8, 0.8], device="cuda")
 
-    def process_videos(self,
+    def process_videos(self, video_path):
         videos_read = []
         frames_read = []
         frames = []
@@ -211,7 +211,8 @@ class FaceExtractor:
         # filename = filenames[video_idx]
         # video_path = os.path.join(input_dir, filename)
         # result = self.video_read_fn(video_path)
-        result =
+        result = self.video_read_fn(video_path)
+        # result = video
         # Error? Then skip this video.
 
         # Keep track of the original frames (need them later).
@@ -241,7 +242,7 @@ class FaceExtractor:
                 faces.append(crop)
                 scores.append(score)
 
-            frame_dict = {"video_idx": video_idx,
+            frame_dict = { #"video_idx": video_idx,
                           "frame_idx": my_idxs[i],
                           "frame_w": w,
                           "frame_h": h,
@@ -255,7 +256,7 @@ class FaceExtractor:
         """Convenience method for doing face extraction on a single video."""
         input_dir = os.path.dirname(video_path)
         filenames = [os.path.basename(video_path)]
-        return self.process_videos(
+        return self.process_videos(video_path)
 
 
 
@@ -320,7 +321,7 @@ def predict_on_video(face_extractor, video_path, videos, batch_size, input_size,
         else:
             pass
         if n > 0:
-            x = torch.tensor(x, device="
+            x = torch.tensor(x, device="cpu").float()
             # Preprocess the images.
            x = x.permute((0, 3, 1, 2))
            for i in range(len(x)):
@@ -329,7 +330,7 @@ def predict_on_video(face_extractor, video_path, videos, batch_size, input_size,
            with torch.no_grad():
                preds = []
                for model in models:
-                   y_pred = model(x[:n]
+                   y_pred = model(x[:n])  #
                    y_pred = torch.sigmoid(y_pred.squeeze())
                    bpred = y_pred[:n].cpu().numpy()
                    preds.append(strategy(bpred))
@@ -354,5 +355,4 @@ def predict_on_video_set(face_extractor, videos, input_size, num_workers, test_d
 
     with ThreadPoolExecutor(max_workers=num_workers) as ex:
         predictions = ex.map(process_file, [1])
-    return list(predictions)
-
+    return list(predictions)
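One detail worth flagging: this commit moves the inference tensors to device="cpu", but FaceExtractor still constructs MTCNN(..., device="cuda"), which raises on a CPU-only Space and is a plausible cause of the "Runtime error" banner above. A hedged sketch of the usual fix, picking the device at runtime (not part of this commit):

    import torch

    # choose once, then reuse wherever a device string is currently hard-coded
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # e.g. in FaceExtractor.__init__:
    #     self.detector = MTCNN(margin=0, thresholds=[0.7, 0.8, 0.8], device=device)
    # and in predict_on_video:
    #     x = torch.tensor(x, device=device).float()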
predict/app.py DELETED
@@ -1,68 +0,0 @@
-import gradio as gr
-import argparse
-import os
-import re
-import time
-
-import torch
-import pandas as pd
-
-import os, sys
-root_folder = os.path.abspath(
-    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-)
-sys.path.append(root_folder)
-from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video_set
-from training.zoo.classifiers import DeepFakeClassifier
-
-
-
-def predict(video):
-    # video_index = int(video_index)
-
-    frames_per_video = 32
-    video_reader = VideoReader()
-    video_read_fn = lambda x: video_reader.read_frames(x, num_frames=frames_per_video)
-    face_extractor = FaceExtractor(video_read_fn)
-    input_size = 380
-    strategy = confident_strategy
-
-    # test_videos = sorted([x for x in os.listdir(args.test_dir) if x[-4:] == ".mp4"])[video_index]
-    # print(f"Predicting {video_index} videos")
-    predictions = predict_on_video_set(face_extractor=face_extractor, input_size=input_size, models=models,
-                                       strategy=strategy, frames_per_video=frames_per_video, videos=video,
-                                       num_workers=6, test_dir=args.test_dir)
-    return predictions
-
-def get_args_models():
-    parser = argparse.ArgumentParser("Predict test videos")
-    arg = parser.add_argument
-    arg('--weights-dir', type=str, default="weights", help="path to directory with checkpoints")
-    arg('--models', type=str, default='classifier_DeepFakeClassifier_tf_efficientnet_b7_ns_1_best_dice', help="checkpoint files")  # nargs='+',
-    arg('--test-dir', type=str, default='test_dataset', help="path to directory with videos")
-    arg('--output', type=str, required=False, help="path to output csv", default="submission.csv")
-    args = parser.parse_args()
-
-    models = []
-    # model_paths = [os.path.join(args.weights_dir, model) for model in args.models]
-    model_paths = [os.path.join(args.weights_dir, args.models)]
-    for path in model_paths:
-        model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to("cpu")
-        print("loading state dict {}".format(path))
-        checkpoint = torch.load(path, map_location="cpu")
-        state_dict = checkpoint.get("state_dict", checkpoint)
-        model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=True)
-        model.eval()
-        del checkpoint
-        models.append(model.half())
-    return args, models
-
-if __name__ == '__main__':
-    global models, args
-    stime = time.time()
-    print("Elapsed:", time.time() - stime)
-    args, models = get_args_models()
-
-    demo = gr.Interface(fn=predict, inputs="image", outputs="text")
-    demo.launch()
-
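The deleted predict/app.py differs from the new top-level app.py in three ways that matter on a CPU Space: it loaded weights from a "weights" directory, wired the Gradio input as "image" rather than "video", and appended models.append(model.half()). Half precision is the notable one; many CPU ops in PyTorch do not support float16, so keeping the model in float32, as the new app.py does, is the safer default. A tiny illustration of the failure mode (behavior varies by PyTorch version, so treat this as an assumption):

    import torch

    layer = torch.nn.Linear(8, 1).half()   # float16 weights on CPU
    x = torch.randn(1, 8).half()
    try:
        layer(x)
    except RuntimeError as err:
        # older CPU builds raise e.g. "addmm_impl_cpu_" not implemented for 'Half'
        print("half precision unsupported here:", err)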