import torch


def loss_generator(ignore: list = None):
    """Return a name -> callable mapping of the available loss functions,
    optionally dropping the names listed in `ignore`."""
    loss_fn = {'mse': mse,
               'lut_mse': lut_mse,
               'masked_mse': masked_mse,
               'sample_weighted_mse': sample_weighted_mse,
               'regularize_LUT': regularize_LUT,
               'MaskWeightedMSE': MaskWeightedMSE}

    if ignore:
        for fn in ignore:
            loss_fn.pop(fn, None)

    return loss_fn


def mse(pred, gt):
    return torch.mean((pred - gt) ** 2)


def masked_mse(pred, gt, mask):
    # Sum squared errors over the masked region only; the per-sample pixel
    # count is clamped at 100 so nearly empty masks do not blow up the loss.
    spatial_dims = list(range(1, mask.dim()))
    denom = torch.clamp_min(torch.sum(mask, dim=spatial_dims), 100).cuda()
    out = torch.sum((mask > 100 / 255.) * (pred - gt) ** 2, dim=spatial_dims)
    out = out / denom
    return torch.mean(out)


def sample_weighted_mse(pred, gt, mask):
    # Weight each sample's MSE by its (clamped) mask area, normalized over
    # the batch, so samples with larger foreground regions contribute more.
    spatial_dims = list(range(1, mask.dim()))
    multi_factor = torch.clamp_min(torch.sum(mask, dim=spatial_dims), 100).cuda()
    multi_factor = multi_factor / multi_factor.sum()
    out = torch.mean((pred - gt) ** 2, dim=spatial_dims)
    out = out * multi_factor
    return torch.sum(out)


def regularize_LUT(lut):
    # Penalize LUT entries that fall outside the valid [0, 1] range.
    st = lut[lut < 0.]
    reg_st = (st ** 2).mean() if st.numel() != 0 else 0.

    lt = lut[lut > 1.]
    reg_lt = ((lt - 1.) ** 2).mean() if lt.numel() != 0 else 0.

    return reg_lt + reg_st


def lut_mse(feat, lut_batch):
    # Pairwise MSE between LUT features within each group of `lut_batch`
    # consecutive samples, encouraging consistent predictions per group.
    loss = 0
    for group in range(feat.shape[0] // lut_batch):
        chunk = feat[group * lut_batch: (group + 1) * lut_batch]
        for i in chunk:
            for j in chunk:
                loss += mse(i, j)

    return loss / lut_batch


def MaskWeightedMSE(pred, label, mask):
    # Squared error summed over all non-batch dimensions, normalized by the
    # (clamped) mask area times the number of channels.
    label = label.view(pred.size())
    reduce_dims = get_dims_with_exclusion(label.dim(), 0)

    loss = (pred - label) ** 2
    denom = pred.size(1) * torch.clamp_min(torch.sum(mask, dim=reduce_dims), 100)
    loss = torch.sum(loss, dim=reduce_dims) / denom

    return torch.mean(loss)


def get_dims_with_exclusion(dim, exclude=None):
    dims = list(range(dim))
    if exclude is not None:
        dims.remove(exclude)

    return dims
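

# A minimal usage sketch (illustrative only, not part of the original module):
# random CPU tensors of an assumed shape [N, C, H, W] with a [N, 1, H, W] mask,
# exercising only the losses that do not force CUDA.
if __name__ == '__main__':
    losses = loss_generator(ignore=['lut_mse'])

    pred = torch.rand(2, 3, 64, 64)
    gt = torch.rand(2, 3, 64, 64)
    mask = (torch.rand(2, 1, 64, 64) > 0.5).float()

    print('mse:', losses['mse'](pred, gt).item())
    print('MaskWeightedMSE:', losses['MaskWeightedMSE'](pred, gt, mask).item())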