PKaushik committed
Commit
d570deb
1 Parent(s): b9ee486
Files changed (1)
  1. yolov6/models/loss.py +411 -0
yolov6/models/loss.py ADDED
@@ -0,0 +1,411 @@
+ #!/usr/bin/env python3
+ # -*- coding:utf-8 -*-
+
+ # The code is based on
+ # https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/models/yolo_head.py
+ # Copyright (c) Megvii, Inc. and its affiliates.
+
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import torch.nn.functional as F
+ from yolov6.utils.figure_iou import IOUloss, pairwise_bbox_iou
+
+
+ class ComputeLoss:
+     '''Loss computation class.
+     Combines SimOTA label assignment with an IoU-based box loss (SIoU, CIoU, etc., selected by `iou_type`).
+     '''
+     def __init__(self,
+                  reg_weight=5.0,
+                  iou_weight=3.0,
+                  cls_weight=1.0,
+                  center_radius=2.5,
+                  eps=1e-7,
+                  in_channels=[256, 512, 1024],
+                  strides=[8, 16, 32],
+                  n_anchors=1,
+                  iou_type='ciou'
+                  ):
+
+         self.reg_weight = reg_weight
+         self.iou_weight = iou_weight
+         self.cls_weight = cls_weight
+
+         self.center_radius = center_radius
+         self.eps = eps
+         self.n_anchors = n_anchors
+         self.strides = strides
+         self.grids = [torch.zeros(1)] * len(in_channels)
+
+         # Define criteria
+         self.l1_loss = nn.L1Loss(reduction="none")
+         self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
+         self.iou_loss = IOUloss(iou_type=iou_type, reduction="none")
+
+     def __call__(
+         self,
+         outputs,
+         targets
+     ):
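+         # Shape notes (inferred from the slicing below, not documented upstream):
+         # `outputs` is a list of per-level head maps [batch, n_anchors, h, w, 5 + n_cls];
+         # `targets` is a [n_labels, 6] tensor of (image_idx, class, cx, cy, w, h)
+         # rows with boxes normalized to [0, 1] (they are rescaled by image size below).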
+         dtype = outputs[0].type()
+         device = targets.device
+         loss_cls, loss_obj, loss_iou, loss_l1 = torch.zeros(1, device=device), torch.zeros(1, device=device), \
+             torch.zeros(1, device=device), torch.zeros(1, device=device)
+         num_classes = outputs[0].shape[-1] - 5
+
+         outputs, outputs_origin, gt_bboxes_scale, xy_shifts, expanded_strides = self.get_outputs_and_grids(
+             outputs, self.strides, dtype, device)
+
+         total_num_anchors = outputs.shape[1]
+         bbox_preds = outputs[:, :, :4]  # [batch, n_anchors_all, 4]
+         bbox_preds_org = outputs_origin[:, :, :4]  # [batch, n_anchors_all, 4]
+         obj_preds = outputs[:, :, 4].unsqueeze(-1)  # [batch, n_anchors_all, 1]
+         cls_preds = outputs[:, :, 5:]  # [batch, n_anchors_all, n_cls]
+
+         # targets
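+         # Regroup the flat [n_labels, 6] target tensor into one padded list per
+         # image: a dummy all-zero row is prepended to every image (so empty
+         # images survive the grouping) and stripped again by the [:, 1:, :]
+         # slice; shorter lists are padded with [-1, 0, 0, 0, 0] rows.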
+         batch_size = bbox_preds.shape[0]
+         targets_list = np.zeros((batch_size, 1, 5)).tolist()
+         for item in targets.cpu().numpy().tolist():
+             targets_list[int(item[0])].append(item[1:])
+         max_len = max(len(l) for l in targets_list)
+
+         targets = torch.from_numpy(np.array(list(map(lambda l: l + [[-1, 0, 0, 0, 0]] * (max_len - len(l)), targets_list)))[:, 1:, :]).to(targets.device)
+         num_targets_list = (targets.sum(dim=2) > 0).sum(dim=1)  # number of objects
+
+         num_fg, num_gts = 0, 0
+         cls_targets, reg_targets, l1_targets, obj_targets, fg_masks = [], [], [], [], []
+
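+         # Per-image loop: run SimOTA label assignment to select foreground
+         # anchors, then build cls / box / objectness / L1 targets from the
+         # matches. A CUDA OOM during assignment is retried on CPU (see the
+         # except branch below).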
+         for batch_idx in range(batch_size):
+             num_gt = int(num_targets_list[batch_idx])
+             num_gts += num_gt
+             if num_gt == 0:
+                 cls_target = outputs.new_zeros((0, num_classes))
+                 reg_target = outputs.new_zeros((0, 4))
+                 l1_target = outputs.new_zeros((0, 4))
+                 obj_target = outputs.new_zeros((total_num_anchors, 1))
+                 fg_mask = outputs.new_zeros(total_num_anchors).bool()
+             else:
+
+                 gt_bboxes_per_image = targets[batch_idx, :num_gt, 1:5].mul_(gt_bboxes_scale)
+                 gt_classes = targets[batch_idx, :num_gt, 0]
+                 bboxes_preds_per_image = bbox_preds[batch_idx]
+                 cls_preds_per_image = cls_preds[batch_idx]
+                 obj_preds_per_image = obj_preds[batch_idx]
+
+                 try:
+                     (
+                         gt_matched_classes,
+                         fg_mask,
+                         pred_ious_this_matching,
+                         matched_gt_inds,
+                         num_fg_img,
+                     ) = self.get_assignments(
+                         batch_idx,
+                         num_gt,
+                         total_num_anchors,
+                         gt_bboxes_per_image,
+                         gt_classes,
+                         bboxes_preds_per_image,
+                         cls_preds_per_image,
+                         obj_preds_per_image,
+                         expanded_strides,
+                         xy_shifts,
+                         num_classes
+                     )
+
+                 except RuntimeError:
+                     print(
+                         "OOM RuntimeError raised due to the high memory cost of label assignment. "
+                         "CPU mode is applied for this batch. To avoid this, "
+                         "try reducing the batch size or image size."
+                     )
+                     torch.cuda.empty_cache()
+                     print("------------CPU Mode for This Batch-------------")
+
+                     _gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
+                     _gt_classes = gt_classes.cpu().float()
+                     _bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
+                     _cls_preds_per_image = cls_preds_per_image.cpu().float()
+                     _obj_preds_per_image = obj_preds_per_image.cpu().float()
+
+                     _expanded_strides = expanded_strides.cpu().float()
+                     _xy_shifts = xy_shifts.cpu()
+
+                     (
+                         gt_matched_classes,
+                         fg_mask,
+                         pred_ious_this_matching,
+                         matched_gt_inds,
+                         num_fg_img,
+                     ) = self.get_assignments(
+                         batch_idx,
+                         num_gt,
+                         total_num_anchors,
+                         _gt_bboxes_per_image,
+                         _gt_classes,
+                         _bboxes_preds_per_image,
+                         _cls_preds_per_image,
+                         _obj_preds_per_image,
+                         _expanded_strides,
+                         _xy_shifts,
+                         num_classes
+                     )
+
+                     gt_matched_classes = gt_matched_classes.cuda()
+                     fg_mask = fg_mask.cuda()
+                     pred_ious_this_matching = pred_ious_this_matching.cuda()
+                     matched_gt_inds = matched_gt_inds.cuda()
+
+                 torch.cuda.empty_cache()
+                 num_fg += num_fg_img
+                 if num_fg_img > 0:
+                     cls_target = F.one_hot(
+                         gt_matched_classes.to(torch.int64), num_classes
+                     ) * pred_ious_this_matching.unsqueeze(-1)
+                     obj_target = fg_mask.unsqueeze(-1)
+                     reg_target = gt_bboxes_per_image[matched_gt_inds]
+
+                     l1_target = self.get_l1_target(
+                         outputs.new_zeros((num_fg_img, 4)),
+                         gt_bboxes_per_image[matched_gt_inds],
+                         expanded_strides[0][fg_mask],
+                         xy_shifts=xy_shifts[0][fg_mask],
+                     )
+
+             cls_targets.append(cls_target)
+             reg_targets.append(reg_target)
+             obj_targets.append(obj_target)
+             l1_targets.append(l1_target)
+             fg_masks.append(fg_mask)
+
+         cls_targets = torch.cat(cls_targets, 0)
+         reg_targets = torch.cat(reg_targets, 0)
+         obj_targets = torch.cat(obj_targets, 0)
+         l1_targets = torch.cat(l1_targets, 0)
+         fg_masks = torch.cat(fg_masks, 0)
+
+         num_fg = max(num_fg, 1)
+         # loss
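+         # All terms are normalized by the total number of foreground anchors.
+         # Note: reg_weight scales only the IoU term here; iou_weight and
+         # cls_weight are used in the assignment cost, not in the final loss.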
+         loss_iou += (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks].T, reg_targets)).sum() / num_fg
+         loss_l1 += (self.l1_loss(bbox_preds_org.view(-1, 4)[fg_masks], l1_targets)).sum() / num_fg
+
+         loss_obj += (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets * 1.0)).sum() / num_fg
+         loss_cls += (self.bcewithlog_loss(cls_preds.view(-1, num_classes)[fg_masks], cls_targets)).sum() / num_fg
+
+         total_losses = self.reg_weight * loss_iou + loss_l1 + loss_obj + loss_cls
+         return total_losses, torch.cat((self.reg_weight * loss_iou, loss_l1, loss_obj, loss_cls)).detach()
+
+     def decode_output(self, output, k, stride, dtype, device):
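+         # YOLOX-style decoding for feature level k: add the grid-cell offset to
+         # xy and multiply by stride; exponentiate wh and multiply by stride.
+         # The pre-decode copy (output_origin) is kept for the L1 loss.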
+         grid = self.grids[k].to(device)
+         batch_size = output.shape[0]
+         hsize, wsize = output.shape[2:4]
+         if grid.shape[2:4] != output.shape[2:4]:
+             yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
+             grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype).to(device)
+             self.grids[k] = grid
+
+         output = output.reshape(batch_size, self.n_anchors * hsize * wsize, -1)
+         output_origin = output.clone()
+         grid = grid.view(1, -1, 2)
+
+         output[..., :2] = (output[..., :2] + grid) * stride
+         output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
+
+         return output, output_origin, grid, hsize, wsize
+
+     def get_outputs_and_grids(self, outputs, strides, dtype, device):
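+         # Decode every feature level, then concatenate predictions, grid xy
+         # shifts, and per-anchor strides along the anchor dimension. Also
+         # derives the input image size (used to rescale normalized gt boxes)
+         # from the last level's feature size times its stride.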
+         xy_shifts = []
+         expanded_strides = []
+         outputs_new = []
+         outputs_origin = []
+
+         for k, output in enumerate(outputs):
+             output, output_origin, grid, feat_h, feat_w = self.decode_output(
+                 output, k, strides[k], dtype, device)
+
+             xy_shift = grid
+             expanded_stride = torch.full((1, grid.shape[1], 1), strides[k], dtype=grid.dtype, device=grid.device)
+
+             xy_shifts.append(xy_shift)
+             expanded_strides.append(expanded_stride)
+             outputs_new.append(output)
+             outputs_origin.append(output_origin)
+
+         xy_shifts = torch.cat(xy_shifts, 1)  # [1, n_anchors_all, 2]
+         expanded_strides = torch.cat(expanded_strides, 1)  # [1, n_anchors_all, 1]
+         outputs_origin = torch.cat(outputs_origin, 1)
+         outputs = torch.cat(outputs_new, 1)
+
+         feat_h *= strides[-1]
+         feat_w *= strides[-1]
+         gt_bboxes_scale = torch.Tensor([[feat_w, feat_h, feat_w, feat_h]]).type_as(outputs)
+
+         return outputs, outputs_origin, gt_bboxes_scale, xy_shifts, expanded_strides
+
+     def get_l1_target(self, l1_target, gt, stride, xy_shifts, eps=1e-8):
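+         # Inverse of the decode step: map gt boxes back to the raw prediction
+         # space (grid-relative xy, log-space wh) for the L1 loss.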
+
+         l1_target[:, 0:2] = gt[:, 0:2] / stride - xy_shifts
+         l1_target[:, 2:4] = torch.log(gt[:, 2:4] / stride + eps)
+         return l1_target
+
+     @torch.no_grad()
+     def get_assignments(
+         self,
+         batch_idx,
+         num_gt,
+         total_num_anchors,
+         gt_bboxes_per_image,
+         gt_classes,
+         bboxes_preds_per_image,
+         cls_preds_per_image,
+         obj_preds_per_image,
+         expanded_strides,
+         xy_shifts,
+         num_classes
+     ):
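+         # SimOTA assignment for one image: restrict to candidate anchors that
+         # fall inside a gt box or its center region, build a pairwise cost from
+         # classification and IoU losses, then keep a dynamic top-k per gt.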
+
+         fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
+             gt_bboxes_per_image,
+             expanded_strides,
+             xy_shifts,
+             total_num_anchors,
+             num_gt,
+         )
+
+         bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
+         cls_preds_ = cls_preds_per_image[fg_mask]
+         obj_preds_ = obj_preds_per_image[fg_mask]
+         num_in_boxes_anchor = bboxes_preds_per_image.shape[0]
+
+         # cost
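+         # Pairwise (gt, candidate-anchor) cost: BCE on the joint cls*obj score,
+         # -log(IoU), and a large constant that effectively excludes anchors
+         # outside the box-and-center intersection.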
+         pair_wise_ious = pairwise_bbox_iou(gt_bboxes_per_image, bboxes_preds_per_image, box_format='xywh')
+         pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)
+
+         gt_cls_per_image = (
+             F.one_hot(gt_classes.to(torch.int64), num_classes)
+             .float()
+             .unsqueeze(1)
+             .repeat(1, num_in_boxes_anchor, 1)
+         )
+
+         with torch.cuda.amp.autocast(enabled=False):
+             cls_preds_ = (
+                 cls_preds_.float().sigmoid_().unsqueeze(0).repeat(num_gt, 1, 1)
+                 * obj_preds_.float().sigmoid_().unsqueeze(0).repeat(num_gt, 1, 1)
+             )
+             pair_wise_cls_loss = F.binary_cross_entropy(
+                 cls_preds_.sqrt_(), gt_cls_per_image, reduction="none"
+             ).sum(-1)
+         del cls_preds_, obj_preds_
+
+         cost = (
+             self.cls_weight * pair_wise_cls_loss
+             + self.iou_weight * pair_wise_ious_loss
+             + 100000.0 * (~is_in_boxes_and_center)
+         )
+
+         (
+             num_fg,
+             gt_matched_classes,
+             pred_ious_this_matching,
+             matched_gt_inds,
+         ) = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
+
+         del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
+
+         return (
+             gt_matched_classes,
+             fg_mask,
+             pred_ious_this_matching,
+             matched_gt_inds,
+             num_fg,
+         )
+
+     def get_in_boxes_info(
+         self,
+         gt_bboxes_per_image,
+         expanded_strides,
+         xy_shifts,
+         total_num_anchors,
+         num_gt,
+     ):
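+         # For each (gt, anchor) pair, test whether the anchor center lies inside
+         # the gt box, and inside a center_radius*stride square around the gt
+         # center. Returns the union mask over all gts plus, for anchors in the
+         # union, the per-gt intersection of both tests.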
+         expanded_strides_per_image = expanded_strides[0]
+         xy_shifts_per_image = xy_shifts[0] * expanded_strides_per_image
+         xy_centers_per_image = (
+             (xy_shifts_per_image + 0.5 * expanded_strides_per_image)
+             .unsqueeze(0)
+             .repeat(num_gt, 1, 1)
+         )  # [n_anchor, 2] -> [n_gt, n_anchor, 2]
+
+         gt_bboxes_per_image_lt = (
+             (gt_bboxes_per_image[:, 0:2] - 0.5 * gt_bboxes_per_image[:, 2:4])
+             .unsqueeze(1)
+             .repeat(1, total_num_anchors, 1)
+         )
+         gt_bboxes_per_image_rb = (
+             (gt_bboxes_per_image[:, 0:2] + 0.5 * gt_bboxes_per_image[:, 2:4])
+             .unsqueeze(1)
+             .repeat(1, total_num_anchors, 1)
+         )  # [n_gt, 2] -> [n_gt, n_anchor, 2]
+
+         b_lt = xy_centers_per_image - gt_bboxes_per_image_lt
+         b_rb = gt_bboxes_per_image_rb - xy_centers_per_image
+         bbox_deltas = torch.cat([b_lt, b_rb], 2)
+
+         is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
+         is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
+
+         # in fixed center
+         gt_bboxes_per_image_lt = (gt_bboxes_per_image[:, 0:2]).unsqueeze(1).repeat(
+             1, total_num_anchors, 1
+         ) - self.center_radius * expanded_strides_per_image.unsqueeze(0)
+         gt_bboxes_per_image_rb = (gt_bboxes_per_image[:, 0:2]).unsqueeze(1).repeat(
+             1, total_num_anchors, 1
+         ) + self.center_radius * expanded_strides_per_image.unsqueeze(0)
+
+         c_lt = xy_centers_per_image - gt_bboxes_per_image_lt
+         c_rb = gt_bboxes_per_image_rb - xy_centers_per_image
+         center_deltas = torch.cat([c_lt, c_rb], 2)
+         is_in_centers = center_deltas.min(dim=-1).values > 0.0
+         is_in_centers_all = is_in_centers.sum(dim=0) > 0
+
+         # in boxes and in centers
+         is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
+
+         is_in_boxes_and_center = (
+             is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
+         )
+         return is_in_boxes_anchor, is_in_boxes_and_center
+
+     def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
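+         # Dynamic-k matching: each gt's k is the clamped integer sum of its
+         # top-10 candidate IoUs; its k lowest-cost anchors are selected, and an
+         # anchor claimed by several gts keeps only its lowest-cost match.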
+         matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
+         ious_in_boxes_matrix = pair_wise_ious
+         n_candidate_k = min(10, ious_in_boxes_matrix.size(1))
+         topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
+         dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
+         dynamic_ks = dynamic_ks.tolist()
+
+         for gt_idx in range(num_gt):
+             _, pos_idx = torch.topk(
+                 cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
+             )
+             matching_matrix[gt_idx][pos_idx] = 1
+         del topk_ious, dynamic_ks, pos_idx
+
+         anchor_matching_gt = matching_matrix.sum(0)
+         if (anchor_matching_gt > 1).sum() > 0:
+             _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
+             matching_matrix[:, anchor_matching_gt > 1] *= 0
+             matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
+         fg_mask_inboxes = matching_matrix.sum(0) > 0
+         num_fg = fg_mask_inboxes.sum().item()
+         fg_mask[fg_mask.clone()] = fg_mask_inboxes
+         matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
+         gt_matched_classes = gt_classes[matched_gt_inds]
+
+         pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
+             fg_mask_inboxes
+         ]
+
+         return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
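+
+
+ # Usage sketch (assumed training-loop wiring; `model` and the head-output
+ # layout are assumptions based on the shape notes in __call__, not verified
+ # against a specific trainer):
+ #
+ #   compute_loss = ComputeLoss(iou_type='ciou')
+ #   outputs = model(imgs)  # list of per-level maps [b, n_anchors, h, w, 5 + n_cls]
+ #   total_loss, loss_items = compute_loss(outputs, targets)
+ #   total_loss.backward()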