File size: 2,382 Bytes
88b0dcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""
@date: 2021/7/19
@description: Build the weighted loss criteria from the config and combine them into a single training loss.
"""
import torch
import loss

from utils.misc import tensor2np


def build_criterion(config, logger):
    """Instantiate every loss term enabled in ``config.TRAIN.CRITERION``.

    For each entry of ``config.TRAIN.CRITERION`` the fields ``NAME``,
    ``LOSS``, ``WEIGHT``, ``WEIGHTS`` and ``NEED_ALL`` are read. Entries whose
    ``WEIGHT`` is ``None`` or numerically zero are skipped. The loss class is
    looked up by name in the project ``loss`` module, instantiated with no
    arguments and moved to ``config.TRAIN.DEVICE``. When
    ``config.AMP_OPT_LEVEL`` is not ``"O0"`` and the device is a CUDA device,
    the loss module is additionally cast to ``torch.float16``.

    Args:
        config: Experiment config; reads ``TRAIN.DEVICE``, ``TRAIN.CRITERION``
            and (only for enabled entries) ``AMP_OPT_LEVEL``.
        logger: Currently unused; kept for interface compatibility.

    Returns:
        dict mapping each criterion ``NAME`` to a dict with keys ``'loss'``
        (the instantiated loss module), ``'weight'`` (float), ``'sub_weights'``
        (the raw ``WEIGHTS`` value) and ``'need_all'`` (the raw ``NEED_ALL``
        value). A later entry with the same ``NAME`` overwrites an earlier one.
    """
    criterion = {}
    device = config.TRAIN.DEVICE
    criteria_cfg = config.TRAIN.CRITERION

    for key in criteria_cfg.keys():
        sc = criteria_cfg[key]
        # Skip disabled terms: no weight configured, or a weight of zero.
        if sc.WEIGHT is None:
            continue
        weight = float(sc.WEIGHT)
        if weight == 0:
            continue

        # Loss classes are looked up by name in the project ``loss`` module.
        loss_module = getattr(loss, sc.LOSS)().to(device)
        # Use half precision only when AMP is on and the target is CUDA.
        # ``str(device)`` lets DEVICE be either a string such as "cuda:0" or a
        # ``torch.device``; ``'cuda' in torch.device(...)`` would raise
        # TypeError because torch.device is not a container.
        if config.AMP_OPT_LEVEL != "O0" and 'cuda' in str(device):
            loss_module = loss_module.type(torch.float16)

        criterion[sc.NAME] = {
            'loss': loss_module,
            'weight': weight,
            'sub_weights': sc.WEIGHTS,
            'need_all': sc.NEED_ALL,
        }
    return criterion


def calc_criterion(criterion, gt, dt, epoch_loss_d):
    """Evaluate every criterion and combine them into one weighted loss.

    Args:
        criterion: Mapping as produced by ``build_criterion``: name -> dict
            with ``'loss'`` (callable), ``'weight'`` (multiplier for the whole
            term), ``'sub_weights'`` (per-component weights, used only when
            ``'need_all'`` is truthy) and ``'need_all'``.
        gt: Dict of ground-truth values keyed by criterion name. This function
            does not write to it.
        dt: Dict of predicted values keyed by criterion name.
        epoch_loss_d: Accumulator dict. For every criterion name, and for the
            key ``'loss'``, the ``tensor2np`` value from this call is appended
            to a list stored under that key. It is mutated in place and also
            returned.

    Returns:
        Tuple ``(loss, postfix_d, epoch_loss_d)``:
        - ``loss``: the sum over criteria of ``term * weight``.
        - ``postfix_d``: maps each criterion name, plus ``'loss'``, to its
          ``tensor2np`` value from this call.
        - ``epoch_loss_d``: the accumulator described above.
    """
    # Named ``total_loss`` so it does not shadow the module-level ``loss``
    # import.
    total_loss = None
    postfix_d = {}
    for name, item in criterion.items():
        if item['need_all']:
            # This loss receives the whole gt/dt dicts and returns an
            # indexable set of components. Combine them with the per-component
            # sub_weights, skipping zero weights.
            single_loss = item['loss'](gt, dt)
            ws_loss = None
            for i, sub_weight in enumerate(item['sub_weights']):
                if sub_weight == 0:
                    continue
                term = single_loss[i] * sub_weight
                ws_loss = term if ws_loss is None else ws_loss + term
            # If every sub-weight is zero, fall back to the raw loss output.
            if ws_loss is not None:
                single_loss = ws_loss
        else:
            assert name in gt, "ground label is None:" + name
            assert name in dt, "detection key is None:" + name
            gt_value = gt[name]
            if name == 'ratio' and gt_value.shape[-1] != dt[name].shape[-1]:
                # Tile the target to match the prediction's last dimension.
                # Kept local so the caller's gt dict (and any need_all loss
                # evaluated later in this loop) still sees the original.
                # NOTE(review): presumably gt['ratio'] has last dim 1, so
                # repeat(1, n) yields last dim n; verify against the dataset.
                gt_value = gt_value.repeat(1, dt[name].shape[-1])
            single_loss = item['loss'](gt_value, dt[name])

        value = tensor2np(single_loss)
        postfix_d[name] = value
        epoch_loss_d.setdefault(name, []).append(value)

        weighted = single_loss * item['weight']
        total_loss = weighted if total_loss is None else total_loss + weighted

    postfix_d['loss'] = tensor2np(total_loss)
    epoch_loss_d.setdefault('loss', []).append(postfix_d['loss'])
    return total_loss, postfix_d, epoch_loss_d