Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,611 Bytes
46ff99b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
# import torch.distributed as dist
# Logger obtained by the name "dinov2"; logging.getLogger returns the same
# logger object to every caller that asks for that name.
logger = logging.getLogger("dinov2")
class KoLeoLoss(nn.Module):
    """Kozachenko-Leonenko entropic loss regularizer.

    From Sablayrolles et al. (2018), "Spreading vectors for similarity search".
    Each row of the batch is L2-normalized and paired with its nearest other
    row. Small nearest-neighbour distances are penalised through
    ``-log(distance)``, which pushes the features to spread out.
    """

    def __init__(self):
        super().__init__()
        # Euclidean (p=2) distance. PairwiseDistance adds eps to the difference
        # before taking the norm, so identical rows give a tiny, non-zero value.
        self.pdist = nn.PairwiseDistance(2, eps=1e-8)

    def pairwise_NNs_inner(self, x):
        """Return, for every row of ``x``, the index of its nearest other row.

        Intended for L2-normalized rows. For unit vectors
        ``||a - b||^2 = 2 - 2 <a, b>``, so the largest inner product marks the
        smallest Euclidean distance. Uses torch rather than Faiss so the
        computation stays on the input's device.

        Args:
            x (BxD): L2-normalized vectors.

        Returns:
            Integer tensor of shape (B,). ``result[i]`` is the index of the row
            closest to ``x[i]``. With ``B == 1`` the only candidate is row 0
            itself.
        """
        # The argmax is not differentiable, so skip building a graph for the
        # BxB similarity matrix.
        with torch.no_grad():
            similarities = torch.mm(x, x.t())
            # Exclude self-matches. Use -inf rather than -1: an antipodal unit
            # vector also has an inner product of -1, and ties resolve to the
            # first index. A -1 mask could therefore pick the row itself.
            similarities.fill_diagonal_(float("-inf"))
            # max inner product -> min distance
            nn_idx = similarities.max(dim=1).indices
        return nn_idx

    def forward(self, student_output, eps=1e-8):
        """Compute the KoLeo loss for a batch of features.

        Args:
            student_output (BxD): backbone output of student
            eps: small constant. It is the norm floor in ``F.normalize`` and
                is added to the distances before the log to avoid ``log(0)``.

        Returns:
            Scalar tensor: the batch mean of ``-log(nn_distance + eps)``.
            float16/bfloat16 inputs are promoted, so the result is float32
            for them. float32/float64 inputs keep their dtype.
        """
        # Run the whole computation in at least float32. Disabling autocast
        # alone does not upcast a half-precision input, and in float16
        # eps=1e-8 rounds to 0, which turns coincident rows into log(0) = -inf.
        compute_dtype = torch.promote_types(student_output.dtype, torch.float32)
        no_cuda_amp = torch.autocast("cuda", enabled=False)
        no_cpu_amp = torch.autocast("cpu", enabled=False)
        with no_cuda_amp, no_cpu_amp:
            student_output = F.normalize(
                student_output.to(compute_dtype), eps=eps, p=2, dim=-1
            )
            nn_idx = self.pairwise_NNs_inner(student_output)
            # BxD, BxD -> B
            distances = self.pdist(student_output, student_output[nn_idx])
            loss = -torch.log(distances + eps).mean()
        return loss
|