yotamsapi committed on
Commit
3fc64a0
•
1 Parent(s): 86de714

Create models.py

Browse files
Files changed (1) hide show
  1. retinaface/models.py +301 -0
retinaface/models.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras import Model
3
+ from tensorflow.keras.applications import MobileNetV2, ResNet50
4
+ from tensorflow.keras.layers import Input, Conv2D, ReLU, LeakyReLU
5
+ from retinaface.anchor import decode_tf, prior_box_tf
6
+
7
+
8
def _regularizer(weights_decay):
    """Build an L2 kernel regularizer.

    Args:
        weights_decay: Penalty factor handed to the Keras L2 regularizer.

    Returns:
        The result of ``tf.keras.regularizers.l2(weights_decay)``.
    """
    l2_penalty = tf.keras.regularizers.l2(weights_decay)
    return l2_penalty
11
+
12
+
13
def _kernel_init(scale=1.0, seed=None):
    """Return a He-normal kernel initializer.

    Args:
        scale: Accepted for interface compatibility only. The He-normal
            initializer has no scale argument, so this value is not used.
        seed: Optional seed forwarded to the initializer. The default
            ``None`` leaves the initializer unseeded, which matches what
            every existing caller in this file gets.

    Returns:
        The result of ``tf.keras.initializers.he_normal(seed=seed)``.
    """
    # Previously `seed` was accepted but silently dropped; forward it so a
    # caller asking for reproducible initialization actually gets it.
    return tf.keras.initializers.he_normal(seed=seed)
16
+
17
+
18
class BatchNormalization(tf.keras.layers.BatchNormalization):
    """Batch normalization that only runs in training mode when trainable.

    The ``training`` flag passed to ``call`` is AND-ed with
    ``self.trainable`` before being handed to the parent layer, so setting
    ``trainable=False`` also forces the parent's inference path.

    ref: https://github.com/zzh8829/yolov3-tf2
    """

    def __init__(self, axis=-1, momentum=0.9, epsilon=1e-5, center=True,
                 scale=True, name=None, **kwargs):
        # Same arguments as the parent layer; only the defaults for
        # `momentum` and `epsilon` are chosen here.
        super().__init__(
            axis=axis,
            momentum=momentum,
            epsilon=epsilon,
            center=center,
            scale=scale,
            name=name,
            **kwargs)

    def call(self, x, training=False):
        # An explicit None is treated as "not training".
        if training is None:
            training = tf.constant(False)

        # Training mode requires both the caller's flag and a trainable layer.
        effective_training = tf.logical_and(training, self.trainable)
        return super().call(x, effective_training)
34
+
35
+
36
def Backbone(backbone_type='ResNet50', use_pretrain=True):
    """Create a callable that applies a feature-extraction backbone.

    Args:
        backbone_type: ``'ResNet50'`` or ``'MobileNetV2'``. Any other value
            raises ``NotImplementedError`` when the returned callable is
            invoked (not when ``Backbone`` itself is called).
        use_pretrain: When truthy the extractor is built with
            ``weights='imagenet'``; otherwise with ``weights=None``.

    Returns:
        A function ``backbone(x)`` that preprocesses ``x`` with the matching
        Keras ``preprocess_input`` and returns the outputs of three
        intermediate extractor layers.
    """
    weights = 'imagenet' if use_pretrain else None

    def backbone(x):
        # Select the extractor class, the three tapped layer indices and the
        # matching preprocessing function for the requested backbone.
        if backbone_type == 'ResNet50':
            extractor_cls = ResNet50
            # original annotations: [80, 80, 512], [40, 40, 1024], [20, 20, 2048]
            pick_layers = (80, 142, 174)
            preprocess = tf.keras.applications.resnet.preprocess_input
        elif backbone_type == 'MobileNetV2':
            extractor_cls = MobileNetV2
            # original annotations: [80, 80, 32], [40, 40, 96], [20, 20, 160]
            pick_layers = (54, 116, 143)
            preprocess = tf.keras.applications.mobilenet_v2.preprocess_input
        else:
            raise NotImplementedError(
                'Backbone type {} is not recognized.'.format(backbone_type))

        # The extractor is sized from the incoming tensor (batch dim dropped).
        extractor = extractor_cls(
            input_shape=x.shape[1:], include_top=False, weights=weights)
        tapped_outputs = tuple(
            extractor.layers[index].output for index in pick_layers)

        # The model name is kept exactly as-is (including its spelling),
        # since it is a runtime identifier.
        return Model(extractor.input,
                     tapped_outputs,
                     name=backbone_type + '_extrator')(preprocess(x))

    return backbone
68
+
69
+
70
class ConvUnit(tf.keras.layers.Layer):
    """Convolution, then batch normalization, then an optional activation.

    Args:
        f: Number of convolution filters.
        k: Convolution kernel size.
        s: Convolution stride.
        wd: Weight-decay factor for the kernel's L2 regularizer.
        act: ``None`` (identity), ``'relu'`` or ``'lrelu'`` (LeakyReLU with
            0.1). Any other value raises ``NotImplementedError``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, f, k, s, wd, act=None, **kwargs):
        super().__init__(**kwargs)
        # 'same'-padded, bias-free convolution; the bias is left out because
        # batch normalization follows immediately.
        self.conv = Conv2D(
            filters=f,
            kernel_size=k,
            strides=s,
            padding='same',
            kernel_initializer=_kernel_init(),
            kernel_regularizer=_regularizer(wd),
            use_bias=False)
        self.bn = BatchNormalization()
        self.act_fn = self._make_activation(act)

    @staticmethod
    def _make_activation(act):
        """Map the ``act`` selector to the callable applied after BN."""
        if act is None:
            return tf.identity
        if act == 'relu':
            return ReLU()
        if act == 'lrelu':
            return LeakyReLU(0.1)
        raise NotImplementedError(
            'Activation function type {} is not recognized.'.format(act))

    def call(self, x):
        convolved = self.conv(x)
        normalized = self.bn(convolved)
        return self.act_fn(normalized)
92
+
93
+
94
class FPN(tf.keras.layers.Layer):
    """Feature Pyramid Network over three feature maps.

    Each input level is projected with a 1x1 ``ConvUnit``. The coarser level
    is then resized (nearest neighbour) to the next finer level's spatial
    size, added to it, and smoothed with a 3x3 ``ConvUnit``.

    Args:
        out_ch: Channel count of every ``ConvUnit`` in the pyramid. Values
            of 64 or less switch the activation from ReLU to LeakyReLU.
        wd: Weight-decay factor passed to every ``ConvUnit``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, out_ch, wd, **kwargs):
        super().__init__(**kwargs)
        self.out_ch = out_ch
        self.wd = wd
        act = 'lrelu' if out_ch <= 64 else 'relu'

        # 1x1 lateral projections, one per input level.
        self.output1 = ConvUnit(f=out_ch, k=1, s=1, wd=wd, act=act)
        self.output2 = ConvUnit(f=out_ch, k=1, s=1, wd=wd, act=act)
        self.output3 = ConvUnit(f=out_ch, k=1, s=1, wd=wd, act=act)
        # 3x3 smoothing applied after each top-down merge.
        self.merge1 = ConvUnit(f=out_ch, k=3, s=1, wd=wd, act=act)
        self.merge2 = ConvUnit(f=out_ch, k=3, s=1, wd=wd, act=act)

    @staticmethod
    def _resize_like(source, reference):
        """Nearest-neighbour resize of ``source`` to ``reference``'s H x W."""
        target_h, target_w = tf.shape(reference)[1], tf.shape(reference)[2]
        return tf.image.resize(source, [target_h, target_w], method='nearest')

    def call(self, x):
        # x is indexed as three feature maps, finest first.
        fine = self.output1(x[0])    # original annotation: [80, 80, out_ch]
        middle = self.output2(x[1])  # original annotation: [40, 40, out_ch]
        coarse = self.output3(x[2])  # original annotation: [20, 20, out_ch]

        # Top-down pathway: coarse -> middle.
        middle = middle + self._resize_like(coarse, middle)
        middle = self.merge2(middle)

        # Top-down pathway: (merged) middle -> fine.
        fine = fine + self._resize_like(middle, fine)
        fine = self.merge1(fine)

        return fine, middle, coarse

    def get_config(self):
        # Parent config first, then this layer's constructor arguments.
        merged = dict(super().get_config())
        merged.update({
            'out_ch': self.out_ch,
            'wd': self.wd,
        })
        return merged
134
+
135
+
136
class SSH(tf.keras.layers.Layer):
    """Single Stage Headless context module.

    Three branches built from stacked 3x3 ``ConvUnit`` layers (one, two and
    three convolutions deep) are concatenated along the channel axis and
    passed through a ReLU. The branches contribute ``out_ch // 2``,
    ``out_ch // 4`` and ``out_ch // 4`` channels respectively.

    Args:
        out_ch: Total output channels; must be divisible by 4. Values of 64
            or less switch the intermediate activation to LeakyReLU.
        wd: Weight-decay factor passed to every ``ConvUnit``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, out_ch, wd, **kwargs):
        super().__init__(**kwargs)
        assert out_ch % 4 == 0
        self.out_ch = out_ch
        self.wd = wd
        act = 'lrelu' if out_ch <= 64 else 'relu'
        half, quarter = out_ch // 2, out_ch // 4

        # Branch 1: a single 3x3 convolution.
        self.conv_3x3 = ConvUnit(f=half, k=3, s=1, wd=wd, act=None)

        # Branch 2: two stacked 3x3 convolutions.
        self.conv_5x5_1 = ConvUnit(f=quarter, k=3, s=1, wd=wd, act=act)
        self.conv_5x5_2 = ConvUnit(f=quarter, k=3, s=1, wd=wd, act=None)

        # Branch 3: reuses conv_5x5_1 and adds two more 3x3 convolutions.
        self.conv_7x7_2 = ConvUnit(f=quarter, k=3, s=1, wd=wd, act=act)
        self.conv_7x7_3 = ConvUnit(f=quarter, k=3, s=1, wd=wd, act=None)

        self.relu = ReLU()

    def call(self, x):
        branch_3x3 = self.conv_3x3(x)

        # The first convolution is shared by the two deeper branches.
        shared = self.conv_5x5_1(x)
        branch_5x5 = self.conv_5x5_2(shared)
        branch_7x7 = self.conv_7x7_3(self.conv_7x7_2(shared))

        stacked = tf.concat([branch_3x3, branch_5x5, branch_7x7], axis=3)
        return self.relu(stacked)

    def get_config(self):
        # Parent config first, then this layer's constructor arguments.
        merged = dict(super().get_config())
        merged.update({
            'out_ch': self.out_ch,
            'wd': self.wd,
        })
        return merged
178
+
179
+
180
class BboxHead(tf.keras.layers.Layer):
    """Bounding-box regression head.

    A 1x1 convolution produces ``num_anchor * 4`` channels, which are
    reshaped to ``[-1, H * W * num_anchor, 4]``.

    Args:
        num_anchor: Number of anchors per spatial location.
        wd: Stored and reported by ``get_config``; it is not applied to the
            convolution in this layer.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, num_anchor, wd, **kwargs):
        super().__init__(**kwargs)
        self.num_anchor = num_anchor
        self.wd = wd
        self.conv = Conv2D(filters=num_anchor * 4, kernel_size=1, strides=1)

    def call(self, x):
        # Spatial size is read from the input, before the convolution.
        rows, cols = tf.shape(x)[1], tf.shape(x)[2]
        raw = self.conv(x)

        return tf.reshape(raw, [-1, rows * cols * self.num_anchor, 4])

    def get_config(self):
        # Parent config first, then this layer's constructor arguments.
        merged = dict(super().get_config())
        merged.update({
            'num_anchor': self.num_anchor,
            'wd': self.wd,
        })
        return merged
201
+
202
+
203
class LandmarkHead(tf.keras.layers.Layer):
    """Landmark regression head.

    A 1x1 convolution produces ``num_anchor * 10`` channels, which are
    reshaped to ``[-1, H * W * num_anchor, 10]``.

    Args:
        num_anchor: Number of anchors per spatial location.
        wd: Stored and reported by ``get_config``; it is not applied to the
            convolution in this layer.
        name: Layer name, ``'LandmarkHead'`` by default.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, num_anchor, wd, name='LandmarkHead', **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_anchor = num_anchor
        self.wd = wd
        self.conv = Conv2D(filters=num_anchor * 10, kernel_size=1, strides=1)

    def call(self, x):
        # Spatial size is read from the input, before the convolution.
        rows, cols = tf.shape(x)[1], tf.shape(x)[2]
        raw = self.conv(x)

        return tf.reshape(raw, [-1, rows * cols * self.num_anchor, 10])

    def get_config(self):
        # Parent config first, then this layer's constructor arguments.
        merged = dict(super().get_config())
        merged.update({
            'num_anchor': self.num_anchor,
            'wd': self.wd,
        })
        return merged
224
+
225
+
226
class ClassHead(tf.keras.layers.Layer):
    """Classification head.

    A 1x1 convolution produces ``num_anchor * 2`` channels, which are
    reshaped to ``[-1, H * W * num_anchor, 2]``.

    Args:
        num_anchor: Number of anchors per spatial location.
        wd: Stored and reported by ``get_config``; it is not applied to the
            convolution in this layer.
        name: Layer name, ``'ClassHead'`` by default.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, num_anchor, wd, name='ClassHead', **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_anchor = num_anchor
        self.wd = wd
        self.conv = Conv2D(filters=num_anchor * 2, kernel_size=1, strides=1)

    def call(self, x):
        # Spatial size is read from the input, before the convolution.
        rows, cols = tf.shape(x)[1], tf.shape(x)[2]
        raw = self.conv(x)

        return tf.reshape(raw, [-1, rows * cols * self.num_anchor, 2])

    def get_config(self):
        # Parent config first, then this layer's constructor arguments.
        merged = dict(super().get_config())
        merged.update({
            'num_anchor': self.num_anchor,
            'wd': self.wd,
        })
        return merged
247
+
248
+
249
def RetinaFaceModel(cfg, training=False, iou_th=0.4, score_th=0.02,
                    name='RetinaFaceModel'):
    """Assemble the RetinaFace detector and a heads-only companion model.

    Args:
        cfg: Mapping read for the keys ``'input_size'`` (training only),
            ``'weights_decay'``, ``'out_channel'``, ``'min_sizes'``,
            ``'backbone_type'`` and, when not training, ``'steps'``,
            ``'clip'`` and ``'variances'``.
        training: When truthy the input has a fixed square size and the
            first model outputs the raw head tensors. Otherwise the spatial
            input size is left unspecified and the first model outputs
            decoded, NMS-filtered detections for the first batch element.
        iou_th: IoU threshold passed to non-max suppression.
        score_th: Score threshold passed to non-max suppression.
        name: Name of the first model; the second model is named
            ``name + '_bb_only'``.

    Returns:
        A 2-tuple ``(model, heads_model)``. ``heads_model`` always outputs
        ``[bbox_regressions, landm_regressions, classifications]``.
    """
    input_size = cfg['input_size'] if training else None
    wd = cfg['weights_decay']
    out_ch = cfg['out_channel']
    num_anchor = len(cfg['min_sizes'][0])
    backbone_type = cfg['backbone_type']

    # define model
    inputs = Input([input_size, input_size, 3], name='input_image')

    # Backbone is built with its default use_pretrain setting.
    backbone_maps = Backbone(backbone_type=backbone_type)(inputs)

    pyramid = FPN(out_ch=out_ch, wd=wd)(backbone_maps)

    features = [SSH(out_ch=out_ch, wd=wd)(level) for level in pyramid]

    # One head of each kind per pyramid level, joined along the anchor axis.
    bbox_regressions = tf.concat(
        [BboxHead(num_anchor, wd=wd)(feature) for feature in features],
        axis=1)
    landm_regressions = tf.concat(
        [LandmarkHead(num_anchor, wd=wd, name=f'LandmarkHead_{i}')(feature)
         for i, feature in enumerate(features)],
        axis=1)
    classifications = tf.concat(
        [ClassHead(num_anchor, wd=wd, name=f'ClassHead_{i}')(feature)
         for i, feature in enumerate(features)],
        axis=1)

    classifications = tf.keras.layers.Softmax(axis=-1)(classifications)

    if training:
        out = (bbox_regressions, landm_regressions, classifications)
    else:
        # only for batch size 1
        # Column layout: [bboxes, landms, landms_valid, conf]
        bbox_cols = bbox_regressions[0]
        landm_cols = landm_regressions[0]
        # landms_valid is filled with ones for every anchor.
        valid_col = tf.ones_like(classifications[0, :, 0][..., tf.newaxis])
        conf_col = classifications[0, :, 1][..., tf.newaxis]
        preds = tf.concat([bbox_cols, landm_cols, valid_col, conf_col], 1)

        priors = prior_box_tf(
            (tf.shape(inputs)[1], tf.shape(inputs)[2]),
            cfg['min_sizes'], cfg['steps'], cfg['clip'])
        decode_preds = decode_tf(preds, priors, cfg['variances'])

        # Boxes are the first four columns and the score is the last one;
        # up to every decoded row may be kept.
        selected_indices = tf.image.non_max_suppression(
            boxes=decode_preds[:, :4],
            scores=decode_preds[:, -1],
            max_output_size=tf.shape(decode_preds)[0],
            iou_threshold=iou_th,
            score_threshold=score_th)

        out = tf.gather(decode_preds, selected_indices)

    full_model = Model(inputs, out, name=name)
    heads_model = Model(
        inputs,
        [bbox_regressions, landm_regressions, classifications],
        name=name + '_bb_only')
    return full_model, heads_model