|
""" |
|
@Date: 2021/08/12 |
|
@description: |
|
""" |
|
import torch |
|
import torch.nn as nn |
|
from loss.grad_loss import GradLoss |
|
|
|
|
|
class ObjectLoss(nn.Module):
    """Composite object loss (placeholder implementation).

    Owns a heatmap focal loss and a Smooth-L1 loss as sub-modules, but
    ``forward`` does not use either of them yet and always returns ``0``.
    """

    def __init__(self):
        super(ObjectLoss, self).__init__()
        # Heatmap loss whose 'mean' reduction divides the summed loss by the batch size.
        self.heat_map_loss = HeatmapLoss(reduction='mean')
        # Smooth-L1 loss; presumably meant for box/offset regression -- TODO confirm.
        self.l1_loss = nn.SmoothL1Loss()

    def forward(self, gt, dt):
        """Compute the loss between ``gt`` and ``dt``.

        ``gt``/``dt`` are presumably ground truth and detections (verify
        against callers).

        NOTE(review): this is a stub. Both arguments are ignored and the
        plain Python integer ``0`` is returned, not a tensor, so calling
        ``.backward()`` on the result will fail.
        """
        del gt, dt  # unused until the real loss is implemented
        return 0
|
|
|
|
|
class HeatmapLoss(nn.Module):
    """Penalty-reduced pixel-wise focal loss over heatmaps.

    With prediction ``p`` and target ``y`` (element-wise):

    * where ``y == 1.0`` (a "center"): ``-(1 - p) ** alpha * log(p)``
    * everywhere else:                 ``-(1 - y) ** beta * p ** alpha * log(1 - p)``

    Args:
        weight: Accepted for interface compatibility; currently unused.
        alpha: Exponent applied to the prediction-dependent focusing term.
        beta: Exponent on ``(1 - y)`` that down-weights non-center elements
            whose target is close to 1.
        reduction: ``'mean'`` returns the summed loss divided by the batch
            size (``size(0)``); ``'sum'`` returns the summed loss; any other
            value returns the unreduced element-wise loss tensor.
    """

    # Keeps log() finite when a prediction is exactly 0 or exactly 1.
    _EPS = 1e-14

    def __init__(self, weight=None, alpha=2, beta=4, reduction='mean'):
        super(HeatmapLoss, self).__init__()
        # NOTE(review): `weight` is accepted but never used.
        self.alpha = alpha
        self.beta = beta
        self.reduction = reduction

    def forward(self, targets, inputs):
        """Compute the heatmap focal loss.

        Args:
            targets: Ground-truth heatmap. Elements exactly equal to 1.0 are
                treated as centers. Must be broadcast-compatible with ``inputs``.
            inputs: Predicted heatmap. Assumed to lie in [0, 1] (e.g. after a
                sigmoid) -- TODO confirm; values outside that range give NaN
                from ``log``.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'``, otherwise the
            element-wise loss tensor.
        """
        eps = self._EPS
        center_id = (targets == 1.0).float()
        other_id = (targets != 1.0).float()

        center_loss = -center_id * (1.0 - inputs) ** self.alpha * torch.log(inputs + eps)
        other_loss = (-other_id * (1 - targets) ** self.beta
                      * inputs ** self.alpha * torch.log(1.0 - inputs + eps))
        loss = center_loss + other_loss

        if self.reduction == 'mean':
            # Per-sample sum averaged over the batch (dim 0), not a per-element
            # mean. The max() avoids 0/0 = NaN for an empty batch.
            batch_size = max(loss.size(0), 1)
            return torch.sum(loss) / batch_size
        if self.reduction == 'sum':
            # Previously divided by the batch size too, which made 'sum' identical to 'mean'.
            return torch.sum(loss)
        return loss
|
|