import torch
from torch import nn
import torch.nn.functional as F


class Flatten(nn.Module):
    """Flattens all dimensions after the batch dimension."""

    def forward(self, x):
        return x.view(x.size(0), -1)
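
# Usage sketch (shapes are assumptions): Flatten typically bridges a pooled
# feature map and a linear classifier head, e.g.
#
#   head = nn.Sequential(nn.AdaptiveAvgPool2d(1), Flatten(), nn.Linear(2048, 14))
#   out = head(torch.randn(4, 2048, 7, 7))  # -> 4 x 14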

class LSEPool2d(nn.Module):
    """Log-Sum-Exp (LSE) pooling over the spatial dimensions.

    Computes x_max + (1/r) * log((1/(h*w)) * sum(exp(r * (x - x_max)))),
    a smooth interpolation between max pooling (r -> inf) and average
    pooling (r -> 0). Subtracting x_max before the exp keeps it
    numerically stable.
    """

    def __init__(self, r=3):
        super().__init__()
        self.r = r

    def forward(self, x):
        # x: bs x C x h x w (e.g. bs x 2048 x 7 x 7 for a ResNet backbone)
        h, w = x.size(2), x.size(3)
        r = self.r
        x_max = F.adaptive_max_pool2d(x, 1)  # bs x C x 1 x 1
        p = (1 / r) * torch.log((1 / (h * w)) * torch.exp(r * (x - x_max)).sum(3).sum(2))
        x_max = x_max.view(x.size(0), -1)  # bs x C
        return x_max + p
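
# Shape sketch (a ResNet-style bs x 2048 x 7 x 7 feature map is an assumption):
#
#   pool = LSEPool2d(r=3)
#   feats = pool(torch.randn(4, 2048, 7, 7))  # -> 4 x 2048
#
# Larger r pushes the result toward max pooling, smaller r toward average pooling.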


class WeightedBCEWithLogitsLoss(nn.Module):
    """BCE-with-logits loss whose per-element weights rebalance the
    positive/negative label counts within each batch."""

    def forward(self, input, target):
        w = self.get_weight(input, target)
        return F.binary_cross_entropy_with_logits(input, target, w, reduction='mean')

    def get_weight(self, input, target):
        y = target.detach()
        P = (y == 1).sum().float()
        N = (y == 0).sum().float()
        # "+ 1" guards against batches with no positive (or no negative)
        # labels, e.g. a batch that may not contain any disease
        beta_p = (P + N) / (P + 1)
        beta_n = (P + N) / (N + 1)
        w = torch.where(y == 1, beta_p, beta_n)
        return w.to(input.device)
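
# Usage sketch (shapes are assumptions): for multi-label logits and 0/1
# targets of shape bs x n_classes,
#
#   criterion = WeightedBCEWithLogitsLoss()
#   loss = criterion(logits, target)
#
# The weights are recomputed per batch, so rare positive labels are
# up-weighted relative to the abundant negatives.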

class SaveFeature:
    """Keeps the latest output of a module via a forward hook
    (e.g. to grab feature maps for class activation mapping)."""
    features = None

    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        self.features = output

    def remove(self):
        self.hook.remove()
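
# Usage sketch (model.layer4 is an assumption; any nn.Module works):
#
#   sf = SaveFeature(model.layer4)  # capture the last conv block's output
#   model(x)                        # forward pass populates sf.features
#   fmap = sf.features              # e.g. bs x 2048 x 7 x 7
#   sf.remove()                     # detach the hook when done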

# class FocalLoss(WeightedBCEWithLogitsLoss):

#     def __init__(self, theta=2):
#         super().__init__()
#         self.theta = theta

#     def forward(self, input, target):
# #         pt = target*input + (1-target)*(1-input)
# #         target *= (1-pt)**self.theta
#         w = self.get_weight(input, target)
#         return F.binary_cross_entropy_with_logits(input, target, w)



# class FocalLoss(nn.Module):
#     def __init__(self, gamma=0, alpha=None, size_average=True):
#         super().__init__()
#         self.gamma = gamma
#         self.alpha = alpha
#         if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha])
#         if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
#         self.size_average = size_average

#     def forward(self, input, target):
#         if input.dim() > 2:
#             input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
#             input = input.transpose(1, 2)                         # N,C,H*W => N,H*W,C
#             input = input.contiguous().view(-1, input.size(2))    # N,H*W,C => N*H*W,C
#         target = target.view(-1, 1)

#         logpt = F.log_softmax(input, dim=1)
#         logpt = logpt.gather(1, target)
#         logpt = logpt.view(-1)
#         pt = logpt.detach().exp()

#         if self.alpha is not None:
#             if self.alpha.type() != input.type():
#                 self.alpha = self.alpha.type_as(input)
#             at = self.alpha.gather(0, target.view(-1))
#             logpt = logpt * at

#         loss = -1 * (1 - pt) ** self.gamma * logpt
#         return loss.mean() if self.size_average else loss.sum()
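

# A minimal working sketch of the focal-loss idea the commented blocks above
# explore, adapted to this file's multi-label sigmoid setting (the commented
# versions target multi-class softmax). The class name and gamma default are
# assumptions, not part of the original code.
class BinaryFocalLoss(nn.Module):

    def __init__(self, gamma=2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, input, target):
        # bce = -log(pt) per element, so pt = exp(-bce) is the predicted
        # probability of the true label
        bce = F.binary_cross_entropy_with_logits(input, target, reduction='none')
        pt = torch.exp(-bce)
        # down-weight easy examples by (1 - pt)^gamma
        return ((1 - pt) ** self.gamma * bce).mean()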