zhigangjiang's picture
no message
88b0dcb
raw
history blame
1.23 kB
"""
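@Date: 2021/08/12
@description: Loss modules for object heatmap training: HeatmapLoss (a focal loss
between a target heatmap and a predicted heatmap) and ObjectLoss (a placeholder
whose forward currently returns 0).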
"""
import torch
import torch.nn as nn
from loss.grad_loss import GradLoss
class ObjectLoss(nn.Module):
    """Placeholder loss for object predictions (not implemented yet).

    The constructor registers a heatmap focal loss (``heat_map_loss``) and a
    Smooth-L1 loss (``l1_loss``) as submodules. ``forward`` does not use them
    yet and always returns the integer ``0``.
    """

    def __init__(self):
        super(ObjectLoss, self).__init__()
        # An alternative that was considered here: FocalLoss(reduction='mean').
        heatmap_criterion = HeatmapLoss(reduction='mean')
        self.heat_map_loss = heatmap_criterion
        self.l1_loss = nn.SmoothL1Loss()

    def forward(self, gt, dt):
        """Compute the object loss; currently a stub.

        Args:
            gt: Ground-truth data (format not defined yet; presumably targets).
            dt: Predicted data (format not defined yet; presumably detections).

        Returns:
            The integer ``0`` (not a tensor) until the loss is implemented.
        """
        # TODO: implement; the submodules built in __init__ are not used yet.
        return 0
class HeatmapLoss(nn.Module):
    """Penalty-reduced focal loss between a target heatmap and a predicted heatmap.

    Element-wise, with ``p`` the prediction and ``y`` the target:

    * where ``y == 1`` (centers):  ``-(1 - p) ** alpha * log(p)``
    * elsewhere (``y != 1``):      ``-(1 - y) ** beta * p ** alpha * log(1 - p)``

    This is the form of the focal loss used by CornerNet/CenterNet. Note that
    the ``'mean'`` reduction here divides by the size of dim 0 (treated as the
    batch size), not by the number of centers.

    Args:
        weight: Accepted for interface compatibility but currently ignored.
        alpha: Focusing exponent applied to the prediction term.
        beta: Exponent that down-weights non-center cells whose target is
            close to 1.
        reduction: ``'mean'`` (sum of all elements divided by
            ``loss.size(0)``), ``'sum'`` (plain sum), or ``'none'``
            (element-wise loss with the broadcast shape of the inputs).

    Raises:
        ValueError: If ``reduction`` is not one of the values above.
    """

    _REDUCTIONS = ('none', 'mean', 'sum')

    def __init__(self, weight=None, alpha=2, beta=4, reduction='mean'):
        super(HeatmapLoss, self).__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"Invalid reduction {reduction!r}; expected one of {self._REDUCTIONS}"
            )
        self.alpha = alpha
        self.beta = beta
        self.reduction = reduction

    def forward(self, targets, inputs):
        """Compute the focal loss of ``inputs`` against ``targets``.

        Args:
            targets: Target heatmap tensor. Cells exactly equal to 1.0 are
                centers; all other cells are negatives weighted by
                ``(1 - target) ** beta``.
            inputs: Predicted heatmap tensor, broadcastable with ``targets``.
                Assumed to hold probabilities in [0, 1] (e.g. after a
                sigmoid); otherwise the logs below can produce NaN. TODO
                confirm with the callers.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'``, or the element-wise loss
            tensor for ``'none'``.
        """
        # Exact comparison: only cells whose target is exactly 1.0 count as centers.
        center_id = (targets == 1.0).float()
        other_id = (targets != 1.0).float()
        # The 1e-14 offset guards against log(0) when a prediction is exactly 0 or 1.
        center_loss = -center_id * (1.0 - inputs) ** self.alpha * torch.log(inputs + 1e-14)
        other_loss = (-other_id * (1 - targets) ** self.beta * inputs ** self.alpha
                      * torch.log(1.0 - inputs + 1e-14))
        loss = center_loss + other_loss

        if self.reduction == 'none':
            return loss
        total = torch.sum(loss)
        if self.reduction == 'mean':
            # Normalise by the batch size (dim 0), not by element or center count.
            return total / loss.size(0)
        # 'sum': plain sum with no normalisation.
        return total