Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from omegaconf import OmegaConf | |
def weight_loss(log_assignment, weights, gamma=0.0):
    """Compute weighted negative log-likelihood terms of an assignment matrix.

    The last row and the last column of ``log_assignment`` are the "negative"
    slots: they are summed separately from the ``[:m, :n]`` core block.

    Args:
        log_assignment: 3-D tensor ``(B, M+1, N+1)``. Presumably
            log-probabilities (values <= 0). The focal term below relies on
            this; TODO confirm against the producer.
        weights: tensor of the same shape. It selects, and may scale, the
            entries that contribute to the loss.
        gamma: focal-loss exponent. When ``gamma > 0``, each term is scaled by
            ``(1 - p) ** gamma`` with ``p = exp(log_assignment)``, which
            down-weights confident entries. ``0`` (the default) yields the
            plain weighted NLL.

    Returns:
        A tuple ``(nll_pos, nll_neg, num_pos, num_neg)`` of per-batch tensors:
        - nll_pos: mean NLL over the core ``[:M, :N]`` block.
        - nll_neg: mean NLL over the last column and last row combined.
        - num_pos: total core weight.
        - num_neg: the average of the two negative-slot weight totals.
        Each weight total is clamped to at least 1 before use, so no division
        by zero can occur.
    """
    _, m, n = log_assignment.shape
    m -= 1  # number of rows excluding the last (negative) row
    n -= 1  # number of columns excluding the last (negative) column

    loss_sc = log_assignment * weights
    if gamma > 0:
        # Focal modulation. The clamp keeps 1 - p non-negative if round-off
        # pushes p slightly above 1; a fractional power would otherwise give NaN.
        prob = log_assignment.exp()
        loss_sc = loss_sc * (1.0 - prob).clamp(min=0.0).pow(gamma)

    num_neg0 = weights[:, :m, -1].sum(-1).clamp(min=1.0)
    num_neg1 = weights[:, -1, :n].sum(-1).clamp(min=1.0)
    num_pos = weights[:, :m, :n].sum((-1, -2)).clamp(min=1.0)

    # num_pos is already clamped to >= 1, so a second clamp is unnecessary.
    nll_pos = -loss_sc[:, :m, :n].sum((-1, -2)) / num_pos

    nll_neg0 = -loss_sc[:, :m, -1].sum(-1)
    nll_neg1 = -loss_sc[:, -1, :n].sum(-1)
    num_neg = num_neg0 + num_neg1
    nll_neg = (nll_neg0 + nll_neg1) / num_neg

    return nll_pos, nll_neg, num_pos, num_neg / 2.0
class NLLLoss(nn.Module):
    """Balanced negative log-likelihood loss over an assignment matrix.

    Configuration keys (user values are merged over ``default_conf`` via
    ``OmegaConf.merge``):
        nll_balancing: weight given to the positive term. The negative term is
            weighted by ``1 - nll_balancing``.
        gamma_f: passed as ``gamma`` to ``weight_loss`` (labelled "focal loss"
            in the defaults).
    """

    default_conf = {
        "nll_balancing": 0.5,
        "gamma_f": 0.0,  # focal loss
    }

    def __init__(self, conf):
        super().__init__()
        # User-supplied entries override the class-level defaults.
        self.conf = OmegaConf.merge(self.default_conf, conf)
        # Weight builder used by ``forward`` when the caller passes no weights.
        self.loss_fn = self.nll_loss

    def forward(self, pred, data, weights=None):
        """Compute the balanced NLL for ``pred["log_assignment"]``.

        Args:
            pred: mapping that provides ``"log_assignment"``.
            data: ground-truth mapping. It is only read by ``self.loss_fn``,
                and only when ``weights`` is None.
            weights: optional precomputed weights with the same shape as the
                log-assignment. If omitted, they are built by ``self.loss_fn``.

        Returns:
            A tuple ``(nll, weights, metrics)``:
            - nll: the balanced loss.
            - weights: the weights that were used.
            - metrics: a dict with the keys ``assignment_nll``, ``nll_pos``,
              ``nll_neg``, ``num_matchable`` and ``num_unmatchable``.
        """
        scores = pred["log_assignment"]
        if weights is None:
            weights = self.loss_fn(scores, data)

        pos_term, neg_term, pos_count, neg_count = weight_loss(
            scores, weights, gamma=self.conf.gamma_f
        )

        alpha = self.conf.nll_balancing
        total = alpha * pos_term + (1 - alpha) * neg_term

        metrics = {
            "assignment_nll": total,
            "nll_pos": pos_term,
            "nll_neg": neg_term,
            "num_matchable": pos_count,
            "num_unmatchable": neg_count,
        }
        return total, weights, metrics

    def nll_loss(self, log_assignment, data):
        """Build the default target weights for ``log_assignment``.

        The returned tensor is zero everywhere except in three regions:
        - ``[:, :M, :N]`` holds ``data["gt_assignment"]`` cast to float.
        - The last column of the first M rows is 1 where
          ``data["gt_matches0"] == -1``.
        - The last row of the first N columns is 1 where
          ``data["gt_matches1"] == -1``.

        M and N are the last-dimension sizes of ``gt_matches0`` and
        ``gt_matches1`` respectively.
        """
        matches0 = data["gt_matches0"]
        matches1 = data["gt_matches1"]
        rows, cols = matches0.shape[-1], matches1.shape[-1]
        core = data["gt_assignment"].float()
        unmatched0 = matches0.eq(-1).float()
        unmatched1 = matches1.eq(-1).float()

        target = torch.zeros_like(log_assignment)
        target[:, :rows, :cols] = core
        target[:, :rows, -1] = unmatched0
        target[:, -1, :cols] = unmatched1
        return target