|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from ding.torch_utils.network import one_hot |
|
|
|
|
|
class MultiLogitsLoss(nn.Module):
    """
    Overview:
        Permutation-invariant classification loss. Given ``M`` rows of logits and ``M`` labels, it builds the
        ``M x M`` matrix of losses between every logits row and every label, finds the one-to-one assignment
        of logits rows to labels with the smallest total loss, and returns the mean loss of the matched pairs.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(self, criterion: str = None, smooth_ratio: float = 0.1) -> None:
        """
        Overview:
            Initialization method, use cross_entropy as default criterion.
        Arguments:
            - criterion (:obj:`str`): Criterion type, supports ['cross_entropy', 'label_smooth_ce'].
            - smooth_ratio (:obj:`float`): Smoothing ratio for label smoothing, only used by 'label_smooth_ce'.
        """
        super(MultiLogitsLoss, self).__init__()
        if criterion is None:
            criterion = 'cross_entropy'
        assert criterion in ['cross_entropy', 'label_smooth_ce'], criterion
        self.criterion = criterion
        # Always keep the ratio so the attribute exists whichever criterion is selected.
        self.ratio = smooth_ratio

    def _label_process(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Turn class indices into per-class target weights according to the criterion.
        Arguments:
            - logits (:obj:`torch.Tensor`): Predicted logits; only the class count ``logits.shape[1]`` and, \
                for label smoothing, its dtype/device are used.
            - labels (:obj:`torch.LongTensor`): Ground truth class indices.
        Returns:
            - ret (:obj:`torch.Tensor`): Target weights with one row per label and one column per class.
        """
        num_classes = logits.shape[1]
        if self.criterion == 'cross_entropy':
            return one_hot(labels, num=num_classes)
        if self.criterion == 'label_smooth_ce':
            if num_classes < 2:
                raise ValueError("label smoothing needs at least 2 classes, but got {}".format(num_classes))
            ratio = float(self.ratio)
            # Spread ``ratio`` evenly over the wrong classes and keep ``1 - ratio`` on the true class,
            # so that every target row sums to exactly 1.
            off_value = ratio / (num_classes - 1)
            ret = torch.full_like(logits, off_value)
            ret.scatter_(1, labels.unsqueeze(1), 1 - ratio)
            return ret
        raise ValueError("unsupported criterion: {}".format(self.criterion))

    def _nll_loss(self, nlls: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Calculate the negative log likelihood loss of each row.
        Arguments:
            - nlls (:obj:`torch.Tensor`): Log-probabilities (this class passes the output of \
                ``log_softmax``); they are negated here.
            - labels (:obj:`torch.LongTensor`): Per-class target weights, treated as constants (detached).
        Returns:
            - ret (:obj:`torch.Tensor`): Loss of every row, i.e. the weighted sum over dim 1.
        """
        ret = (-nlls * (labels.detach()))
        return ret.sum(dim=1)

    def _get_metric_matrix(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Calculate the metric matrix whose entry ``[i, j]`` is the loss of logits row ``i`` w.r.t. label ``j``.
        Arguments:
            - logits (:obj:`torch.Tensor`): Predicted logits.
            - labels (:obj:`torch.LongTensor`): Ground truth.
        Returns:
            - metric (:obj:`torch.Tensor`): Calculated metric matrix, one row per logits row.
        """
        M, N = logits.shape
        targets = self._label_process(logits, labels)
        log_probs = F.log_softmax(logits, dim=1)
        # A (1, N) row broadcasts against the targets, so no explicit per-row copy is needed.
        metric = [self._nll_loss(log_probs[i].unsqueeze(0), targets) for i in range(M)]
        return torch.stack(metric, dim=0)

    def _match(self, matrix: torch.Tensor) -> np.ndarray:
        """
        Overview:
            Solve the assignment problem on the metric matrix with the Kuhn-Munkres algorithm: find the \
            one-to-one pairing of rows and columns with minimal total cost.
        Arguments:
            - matrix (:obj:`torch.Tensor`): Square metric (cost) matrix.
        Returns:
            - index (:obj:`np.ndarray`): ``index[j]`` is the row matched to column ``j``.
        """
        # Negating the cost turns the problem into a maximum-weight matching. The negation already
        # produces a fresh array, so the tensor does not need to be cloned first.
        mat = -matrix.detach().to('cpu').numpy()
        M = mat.shape[0]
        index = np.full(M, -1, dtype=np.int32)  # -1 marks a column that is not matched yet
        # Feasible vertex labels: lx[i] + ly[j] >= mat[i, j] holds for every pair.
        lx = mat.max(axis=1)
        ly = np.zeros(M, dtype=np.float32)
        visx = np.zeros(M, dtype=np.bool_)
        visy = np.zeros(M, dtype=np.bool_)

        def has_augmented_path(row, tight_edges):
            # Depth-first search for an augmenting path that starts at ``row`` and only uses tight edges.
            visx[row] = True
            for col in range(M):
                if not visy[col] and tight_edges[row, col]:
                    visy[col] = True
                    # Take a free column, or re-route the row that currently owns this column.
                    if index[col] == -1 or has_augmented_path(index[col], tight_edges):
                        index[col] = row
                        return True
            return False

        for row in range(M):
            while True:
                visx.fill(False)
                visy.fill(False)
                distance_matrix = self._get_distance_matrix(lx, ly, mat, M)
                # An edge is tight when its slack is zero, up to a float tolerance.
                tight_edges = np.abs(distance_matrix) < 1e-4
                if has_augmented_path(row, tight_edges):
                    break
                # No augmenting path: take the smallest slack between visited rows and unvisited columns
                # and shift the labels by it, which makes at least one more edge tight.
                masked_distance_matrix = distance_matrix[:, ~visy][visx]
                if 0 in masked_distance_matrix.shape:
                    raise RuntimeError("match error, matrix: {}".format(matrix))
                d = masked_distance_matrix.min()
                lx[visx] -= d
                ly[visy] += d
        return index

    @staticmethod
    def _get_distance_matrix(lx: np.ndarray, ly: np.ndarray, mat: np.ndarray, M: int) -> np.ndarray:
        """
        Overview:
            Get the slack of every (row, column) pair: ``lx[i] + ly[j] - mat[i, j]``.
        Arguments:
            - lx (:obj:`np.ndarray`): Row labels, length ``M``.
            - ly (:obj:`np.ndarray`): Column labels, length ``M``.
            - mat (:obj:`np.ndarray`): Weight matrix of shape ``(M, M)``.
            - M (:obj:`int`): Number of rows/columns.
        Returns:
            - nret (:obj:`np.ndarray`): Slack matrix of shape ``(M, M)``.
        """
        nlx = np.broadcast_to(lx, [M, M]).T  # nlx[i, j] == lx[i]
        nly = np.broadcast_to(ly, [M, M])  # nly[i, j] == ly[j]
        nret = nlx + nly - mat
        return nret

    def forward(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Calculate multiple logits loss.
        Arguments:
            - logits (:obj:`torch.Tensor`): Predicted logits, whose shape must be 2-dim, like (B, N).
            - labels (:obj:`torch.LongTensor`): Ground truth.
        Returns:
            - loss (:obj:`torch.Tensor`): Mean loss over the matched (logits row, label) pairs, a scalar tensor.
        """
        assert len(logits.shape) == 2, logits.shape
        metric_matrix = self._get_metric_matrix(logits, labels)
        # The matching runs on a detached copy; gradients flow through the gathered entries below.
        index = self._match(metric_matrix)
        rows = torch.as_tensor(index, dtype=torch.long, device=metric_matrix.device)
        cols = torch.arange(metric_matrix.shape[0], device=metric_matrix.device)
        return metric_matrix[rows, cols].mean()
|
|