|
import datasets |
|
import evaluate |
|
from typing import List |
|
import torch |
|
|
|
|
|
_DESCRIPTION = """ |
|
Quantifying encoder feature distribution properties, Alignment and Uniformity on the Hypersphere. |
|
(https://github.com/ssnl/align_uniform) |
|
""" |
|
|
|
_KWARGS_DESCRIPTION = """ |
|
Args: |
|
xs (`list` of a list of `int`): a group of embeddings |
|
ys (`list` of `int`): the other group of embeddings paired with the ys |
|
|
|
Returns: |
|
"align_loss": float(align_loss_val), |
|
"x_unif_loss": float(x_unif_loss_v), |
|
"y_unif_loss": float(y_unif_loss_v), |
|
"unif_loss": float(unif_loss) |
|
|
|
Examples: |
|
|
|
Example 1-A simple example |
|
>>> metrics = evaluate.load("ahnyeonchan/Alignment-and-Uniformity") |
|
>>> results = metrics.compute(xs=[[1.0, 1.0], [0.0, 1.0]], ys=[[1.0, 1.0], [0.0, 1.0]]) |
|
>>> print(results) |
|
{'align_loss': 0.0, 'x_unif_loss': -2.0, 'y_unif_loss': -2.0, 'unif_loss': -2.0} |
|
""" |
|
|
|
_CITATION = """""" |
|
|
|
|
|
def align_loss(x, y, alpha=2):
    """Return the alignment loss between paired rows of ``x`` and ``y``.

    For each row index ``i`` the L2 distance ``||x_i - y_i||`` is raised to
    the power ``alpha``; the result is the mean of those values, returned as
    a 0-dim tensor.
    """
    pair_distances = torch.norm(x - y, p=2, dim=1)
    return torch.mean(pair_distances ** alpha)
|
|
|
|
|
def uniform_loss(x, t=2):
    """Return the uniformity loss of the rows of ``x``.

    Computes ``log(mean(exp(-t * ||x_i - x_j||^2)))`` over all distinct row
    pairs ``i < j`` and returns it as a 0-dim tensor.

    Args:
        x: 2-D tensor with one embedding per row.
        t: temperature multiplying the squared pairwise distances.

    Raises:
        ValueError: if ``x`` has fewer than two rows (there are no pairs, so
            the mean is undefined).
    """
    import math

    num_rows = x.size(0)
    if num_rows < 2:
        raise ValueError(f"uniform_loss needs at least two rows, got {num_rows}")

    # torch.pdist yields one distance per pair i < j (condensed form).
    scaled = torch.pdist(x, p=2).pow(2).mul(-t)
    # log(mean(exp(s))) == logsumexp(s) - log(n). logsumexp keeps the value
    # finite where exp() would underflow to 0 (and log() give -inf) for
    # widely separated embeddings.
    return torch.logsumexp(scaled, dim=0) - math.log(scaled.numel())
|
|
|
|
|
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class AlignUniform(evaluate.Metric):
    """Alignment / uniformity metric over two paired groups of embeddings.

    Constructor options (all other arguments are forwarded to ``evaluate.Metric``):
        align_alpha: exponent applied to each paired L2 distance in ``align_loss``.
        unif_t: temperature ``t`` used by the uniformity losses.
    """

    def __init__(self, align_alpha: float = 2.0, unif_t: float = 2.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.align_alpha = align_alpha
        self.unif_t = unif_t

    def _info(self):
        """Describe the metric and its per-example input features to ``evaluate``."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    # Each example contributes one embedding vector per group.
                    "xs": datasets.Sequence(datasets.Value("float32")),
                    "ys": datasets.Sequence(datasets.Value("float32")),
                }
            ),
            reference_urls=["https://github.com/ssnl/align_uniform"],
        )

    @staticmethod
    def _to_tensor(values, name: str) -> torch.Tensor:
        """Convert one group of embeddings to a tensor of the default float dtype.

        Accepts a ``torch.Tensor``, a (nested) list/tuple, or an array-like
        object exposing ``__array__`` (e.g. a NumPy array).

        Raises:
            NotImplementedError: for any other input type.
        """
        dtype = torch.get_default_dtype()
        if isinstance(values, torch.Tensor):
            return values.to(dtype=dtype)
        if isinstance(values, (list, tuple)) or hasattr(values, "__array__"):
            return torch.as_tensor(values, dtype=dtype)
        raise NotImplementedError(
            f"`{name}` must be a list, an array or a torch.Tensor, got {type(values).__name__}"
        )

    def _compute(self, xs: List[List[float]], ys: List[List[float]]):
        """Return alignment and uniformity losses for the paired embeddings.

        Args:
            xs: first group of embeddings, one row per example.
            ys: second group; row ``i`` is paired with row ``i`` of ``xs``.

        Returns:
            dict with float values ``align_loss``, ``x_unif_loss``,
            ``y_unif_loss`` and ``unif_loss`` (mean of the two uniformity losses).

        Raises:
            NotImplementedError: if either input has an unsupported type.
            ValueError: if the inputs are not 2-D or their shapes differ.
        """
        # Each input is converted independently (the original xs branch
        # mistakenly tested the type of ys).
        xs = self._to_tensor(xs, "xs")
        ys = self._to_tensor(ys, "ys")

        if xs.dim() != 2:
            raise ValueError(
                f"`xs` must be 2-D (one embedding per row), got shape {tuple(xs.shape)}"
            )
        # Rows are paired one-to-one; without this check a mismatch could
        # broadcast silently inside align_loss and give a wrong result.
        if xs.shape != ys.shape:
            raise ValueError(
                f"`xs` and `ys` must have the same shape, got {tuple(xs.shape)} and {tuple(ys.shape)}"
            )

        align_loss_val = align_loss(xs, ys, self.align_alpha)
        x_unif_loss_v = uniform_loss(xs, t=self.unif_t)
        y_unif_loss_v = uniform_loss(ys, t=self.unif_t)
        unif_loss = (x_unif_loss_v + y_unif_loss_v) / 2

        return {
            "align_loss": float(align_loss_val),
            "x_unif_loss": float(x_unif_loss_v),
            "y_unif_loss": float(y_unif_loss_v),
            "unif_loss": float(unif_loss),
        }
|
|