Spaces:
Runtime error
Runtime error
File size: 2,912 Bytes
51f6859 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES
from .utils import weighted_loss
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def knowledge_distillation_kl_div_loss(pred,
                                       soft_label,
                                       T,
                                       detach_target=True):
    """Temperature-scaled KL-divergence loss for knowledge distillation.

    Both inputs are raw logits. They are divided by ``T`` and normalised
    along dim 1 before the divergence is taken, so neither argument should
    already be a probability distribution.

    Args:
        pred (Tensor): Predicted (student) logits with shape (N, n + 1).
        soft_label (Tensor): Target (teacher) logits with shape (N, n + 1).
            Must have exactly the same size as ``pred``.
        T (int): Temperature for distillation.
        detach_target (bool): If True, the softened target is detached so
            that no gradient flows back into ``soft_label``.
            Defaults to True.

    Returns:
        torch.Tensor: Per-sample loss with shape (N,), as computed by this
        function body.

    Note:
        ``weight``, ``reduction`` and ``avg_factor`` are not declared here,
        yet callers in this file pass them; presumably the ``weighted_loss``
        decorator (defined elsewhere) consumes them -- confirm against
        ``.utils``.
    """
    assert pred.size() == soft_label.size()

    # Softened teacher distribution, optionally cut off from autograd.
    teacher_prob = F.softmax(soft_label / T, dim=1)
    if detach_target:
        teacher_prob = teacher_prob.detach()

    # F.kl_div takes its first argument in log-space.
    student_log_prob = F.log_softmax(pred / T, dim=1)
    elementwise_kl = F.kl_div(
        student_log_prob, teacher_prob, reduction='none')

    # Average over dim 1, then rescale by T squared.
    return elementwise_kl.mean(1) * (T * T)
@LOSSES.register_module()
class KnowledgeDistillationKLDivLoss(nn.Module):
    """Module wrapper around the KL-divergence knowledge-distillation loss.

    Stores the default reduction, loss weight and temperature, and forwards
    them to ``knowledge_distillation_kl_div_loss`` on every call.

    Args:
        reduction (str): Options are ``'none'``, ``'mean'`` and ``'sum'``.
            Defaults to ``'mean'``.
        loss_weight (float): Multiplier applied to the computed loss.
            Defaults to 1.0.
        T (int): Temperature for distillation; must be at least 1.
            Defaults to 10.
    """

    def __init__(self, reduction='mean', loss_weight=1.0, T=10):
        super().__init__()
        assert T >= 1
        self.T = T
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self,
                pred,
                soft_label,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Compute the weighted distillation loss.

        Args:
            pred (Tensor): Predicted logits with shape (N, n + 1).
            soft_label (Tensor): Target logits with shape (N, n + 1).
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): One of ``'none'``,
                ``'mean'`` or ``'sum'``. When given, it replaces the
                reduction configured at construction time for this call
                only. Defaults to None.

        Returns:
            The value returned by ``knowledge_distillation_kl_div_loss``
            multiplied by ``self.loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A None override falls back to the reduction configured at init.
        reduction = reduction_override or self.reduction

        unweighted = knowledge_distillation_kl_div_loss(
            pred,
            soft_label,
            weight,
            reduction=reduction,
            avg_factor=avg_factor,
            T=self.T)
        return self.loss_weight * unweighted
|