Spaces:
Running
Running
import torch | |
def matcher_metrics(pred, data, prefix="", prefix_gt=None):
    """Compute per-sample matching metrics from predicted and ground-truth matches.

    Reads ``pred[f"{prefix}matches0"]`` and ``pred[f"{prefix}matching_scores0"]``
    and compares them against ``data[f"gt_{prefix_gt}matches0"]``.

    The tensors are reduced along dimension 1, so they are expected to be 2-D
    (presumably ``(batch, num_keypoints)`` -- confirm against the caller).
    Ground-truth entries are treated as follows by the masks below:

    * ``gt > -1``  : counted as a ground-truth match (used for recall).
    * ``gt == -1`` : counted for accuracy/precision but not for recall
      (presumably "no match" -- confirm against the data pipeline).
    * ``gt < -1``  : excluded from every metric.

    Args:
        pred: Mapping holding the predicted match indices and matching scores.
        data: Mapping holding the ground-truth match indices.
        prefix: String prepended to the prediction keys and to the names of
            the returned metrics.
        prefix_gt: String inserted after ``"gt_"`` in the ground-truth key.
            Defaults to ``prefix`` when ``None``.

    Returns:
        A dict with one tensor per metric (one value per row of the input):
        ``{prefix}match_recall``, ``{prefix}match_precision``,
        ``{prefix}accuracy`` and ``{prefix}average_precision``.
    """
    eps = 1e-8  # keeps the ratios finite when a mask selects no entries

    def masked_ratio(m, gt_m, mask):
        # Fraction of entries selected by `mask` where prediction == ground truth.
        return ((m == gt_m) * mask).sum(1) / (eps + mask.sum(1))

    def recall(m, gt_m):
        return masked_ratio(m, gt_m, (gt_m > -1).float())

    def accuracy(m, gt_m):
        return masked_ratio(m, gt_m, (gt_m >= -1).float())

    def precision(m, gt_m):
        return masked_ratio(m, gt_m, ((m > -1) & (gt_m >= -1)).float())

    def ranking_ap(m, gt_m, scores):
        # Average precision over the score-ranked predictions:
        #   AP = sum_k (R_k - R_{k-1}) * P_k,   with R_{-1} = 0.
        p_mask = ((m > -1) & (gt_m >= -1)).float()
        r_mask = (gt_m > -1).float()

        # Rank entries from highest to lowest score.
        order = torch.argsort(-scores)
        ranked_p_mask = torch.gather(p_mask, -1, order)
        ranked_r_mask = torch.gather(r_mask, -1, order)
        ranked_hit = torch.gather(m == gt_m, -1, order).float()

        # Precision after each rank k (P_k).
        precision_at_k = torch.cumsum(ranked_hit * ranked_p_mask, -1) / (
            eps + torch.cumsum(ranked_p_mask, -1)
        )
        # Recall gained at each rank k (R_k - R_{k-1}). Using the per-rank
        # increment directly keeps the contribution of the top-ranked entry,
        # and pairs every recall step with the precision at that same rank.
        recall_step = (ranked_hit * ranked_r_mask) / (
            eps + ranked_r_mask.sum(-1, keepdim=True)
        )
        return torch.sum(recall_step * precision_at_k, dim=-1)

    if prefix_gt is None:
        prefix_gt = prefix

    matches = pred[f"{prefix}matches0"]
    gt_matches = data[f"gt_{prefix_gt}matches0"]
    scores = pred[f"{prefix}matching_scores0"]

    return {
        f"{prefix}match_recall": recall(matches, gt_matches),
        f"{prefix}match_precision": precision(matches, gt_matches),
        f"{prefix}accuracy": accuracy(matches, gt_matches),
        f"{prefix}average_precision": ranking_ap(matches, gt_matches, scores),
    }