gomoku / DI-engine /ding /torch_utils /loss /multi_logits_loss.py
zjowowen's picture
init space
079c32c
raw
history blame
6.2 kB
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.torch_utils.network import one_hot
class MultiLogitsLoss(nn.Module):
"""
Overview:
Base class for supervised learning on linklink, including basic processes.
Interfaces:
``__init__``, ``forward``.
"""
def __init__(self, criterion: str = None, smooth_ratio: float = 0.1) -> None:
"""
Overview:
Initialization method, use cross_entropy as default criterion.
Arguments:
- criterion (:obj:`str`): Criterion type, supports ['cross_entropy', 'label_smooth_ce'].
- smooth_ratio (:obs:`float`): Smoothing ratio for label smoothing.
"""
super(MultiLogitsLoss, self).__init__()
if criterion is None:
criterion = 'cross_entropy'
assert (criterion in ['cross_entropy', 'label_smooth_ce'])
self.criterion = criterion
if self.criterion == 'label_smooth_ce':
self.ratio = smooth_ratio
def _label_process(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.LongTensor:
"""
Overview:
Process the label according to the criterion.
Arguments:
- logits (:obj:`torch.Tensor`): Predicted logits.
- labels (:obj:`torch.LongTensor`): Ground truth.
Returns:
- ret (:obj:`torch.LongTensor`): Processed label.
"""
N = logits.shape[1]
if self.criterion == 'cross_entropy':
return one_hot(labels, num=N)
elif self.criterion == 'label_smooth_ce':
val = float(self.ratio) / (N - 1)
ret = torch.full_like(logits, val)
ret.scatter_(1, labels.unsqueeze(1), 1 - val)
return ret
def _nll_loss(self, nlls: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
"""
Overview:
Calculate the negative log likelihood loss.
Arguments:
- nlls (:obj:`torch.Tensor`): Negative log likelihood loss.
- labels (:obj:`torch.LongTensor`): Ground truth.
Returns:
- ret (:obj:`torch.Tensor`): Calculated loss.
"""
ret = (-nlls * (labels.detach()))
return ret.sum(dim=1)
def _get_metric_matrix(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
"""
Overview:
Calculate the metric matrix.
Arguments:
- logits (:obj:`torch.Tensor`): Predicted logits.
- labels (:obj:`torch.LongTensor`): Ground truth.
Returns:
- metric (:obj:`torch.Tensor`): Calculated metric matrix.
"""
M, N = logits.shape
labels = self._label_process(logits, labels)
logits = F.log_softmax(logits, dim=1)
metric = []
for i in range(M):
logit = logits[i]
logit = logit.repeat(M).reshape(M, N)
metric.append(self._nll_loss(logit, labels))
return torch.stack(metric, dim=0)
def _match(self, matrix: torch.Tensor):
"""
Overview:
Match the metric matrix.
Arguments:
- matrix (:obj:`torch.Tensor`): Metric matrix.
Returns:
- index (:obj:`np.ndarray`): Matched index.
"""
mat = matrix.clone().detach().to('cpu').numpy()
mat = -mat # maximize
M = mat.shape[0]
index = np.full(M, -1, dtype=np.int32) # -1 note not find link
lx = mat.max(axis=1)
ly = np.zeros(M, dtype=np.float32)
visx = np.zeros(M, dtype=np.bool_)
visy = np.zeros(M, dtype=np.bool_)
def has_augmented_path(t, binary_distance_matrix):
# What is changed? visx, visy, distance_matrix, index
# What is changed within this function? visx, visy, index
visx[t] = True
for i in range(M):
if not visy[i] and binary_distance_matrix[t, i]:
visy[i] = True
if index[i] == -1 or has_augmented_path(index[i], binary_distance_matrix):
index[i] = t
return True
return False
for i in range(M):
while True:
visx.fill(False)
visy.fill(False)
distance_matrix = self._get_distance_matrix(lx, ly, mat, M)
binary_distance_matrix = np.abs(distance_matrix) < 1e-4
if has_augmented_path(i, binary_distance_matrix):
break
masked_distance_matrix = distance_matrix[:, ~visy][visx]
if 0 in masked_distance_matrix.shape: # empty matrix
raise RuntimeError("match error, matrix: {}".format(matrix))
else:
d = masked_distance_matrix.min()
lx[visx] -= d
ly[visy] += d
return index
@staticmethod
def _get_distance_matrix(lx: np.ndarray, ly: np.ndarray, mat: np.ndarray, M: int) -> np.ndarray:
"""
Overview:
Get distance matrix.
Arguments:
- lx (:obj:`np.ndarray`): lx.
- ly (:obj:`np.ndarray`): ly.
- mat (:obj:`np.ndarray`): mat.
- M (:obj:`int`): M.
"""
nlx = np.broadcast_to(lx, [M, M]).T
nly = np.broadcast_to(ly, [M, M])
nret = nlx + nly - mat
return nret
def forward(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
"""
Overview:
Calculate multiple logits loss.
Arguments:
- logits (:obj:`torch.Tensor`): Predicted logits, whose shape must be 2-dim, like (B, N).
- labels (:obj:`torch.LongTensor`): Ground truth.
Returns:
- loss (:obj:`torch.Tensor`): Calculated loss.
"""
assert (len(logits.shape) == 2)
metric_matrix = self._get_metric_matrix(logits, labels)
index = self._match(metric_matrix)
loss = []
for i in range(metric_matrix.shape[0]):
loss.append(metric_matrix[index[i], i])
return sum(loss) / len(loss)