Spaces:
Runtime error
Runtime error
from sentence_transformers import util | |
from tqdm import tqdm | |
def calc_recall(true_pos, false_neg, eps=1e-8):
    """Return recall, ``TP / (TP + FN)``, with ``eps`` added to the denominator.

    The ``eps`` term keeps the division defined when there are no actual
    positives (``true_pos + false_neg == 0``); with a non-zero ``eps`` the
    result is then 0 instead of a ``ZeroDivisionError``.
    """
    actual_positives = true_pos + false_neg
    return true_pos / (actual_positives + eps)
def calc_precision(true_pos, false_pos, eps=1e-8):
    """Return precision, ``TP / (TP + FP)``, with ``eps`` added to the denominator.

    The ``eps`` term keeps the division defined when nothing was predicted
    (``true_pos + false_pos == 0``); with a non-zero ``eps`` the result is
    then 0 instead of a ``ZeroDivisionError``.
    """
    predicted_positives = true_pos + false_pos
    return true_pos / (predicted_positives + eps)
def calc_f1_score(precision, recall, eps=1e-8):
    """Return the F1 score, the harmonic mean of ``precision`` and ``recall``.

    Computed as ``2 * P * R / (P + R + eps)``. With a non-zero ``eps`` the
    result is 0 rather than a ``ZeroDivisionError`` when both inputs are 0.
    """
    numerator = 2 * precision * recall
    denominator = precision + recall + eps
    return numerator / denominator
def calc_metrics(true, predicted, model, threshold=0.95, eps=1e-8,
                 return_details=False):
    """Compute entity-level recall, precision and F1 via embedding similarity.

    For each aligned pair ``(true_ents, pred_ents)`` taken from ``true`` and
    ``predicted``, both entity lists are embedded with ``model.encode`` and
    compared pairwise with cosine similarity:

    * a true entity counts as a true positive if at least one predicted
      entity has similarity >= ``threshold``, otherwise as a false negative;
    * a predicted entity counts as a false positive if no true entity has
      similarity >= ``threshold``.

    If one side of a pair is empty, every entity on the other side is counted
    as a false positive (no true entities) or a false negative (no
    predictions). A single predicted entity may match, and so credit, several
    true entities.

    Args:
        true: sized sequence of reference entity lists, one per example.
        predicted: iterable of predicted entity lists aligned with ``true``.
            Iteration stops at the shorter of the two (``zip``).
        model: object exposing ``encode(texts, convert_to_tensor=True)``,
            e.g. a sentence-transformers model.
        threshold: minimum cosine similarity for two entities to match.
        eps: smoothing term forwarded to the recall, precision and F1 helpers.
        return_details: if True, also return the raw counts and the sorted,
            de-duplicated indices of examples containing false positives or
            false negatives.

    Returns:
        dict with keys ``'recall'``, ``'precision'`` and ``'f1'``. When
        ``return_details`` is True it also has ``'true_pos'``,
        ``'false_pos'``, ``'false_neg'``, ``'false_pos_ids'`` and
        ``'false_neg_ids'``.
    """
    true_pos = 0
    false_pos = 0
    false_neg = 0
    false_pos_ids = []
    false_neg_ids = []
    pairs = enumerate(zip(true, predicted))
    for j, (true_ents, pred_ents) in tqdm(pairs, total=len(true)):
        if len(true_ents) == 0:
            # Nothing to find: every prediction is spurious.
            false_pos += len(pred_ents)
            if len(pred_ents) > 0:
                false_pos_ids.append(j)
            continue
        if len(pred_ents) == 0:
            # true_ents is non-empty here, so every true entity is missed.
            false_neg += len(true_ents)
            false_neg_ids.append(j)
            continue
        embed_true = model.encode(true_ents, convert_to_tensor=True)
        embed_pred = model.encode(pred_ents, convert_to_tensor=True)
        # Rows index true entities; columns index predicted entities.
        matches = util.pytorch_cos_sim(embed_true, embed_pred) >= threshold
        # Reduce on-device so there is one host sync per example, not one per row.
        n_missed_true = int((~matches.any(dim=1)).sum())
        n_unmatched_pred = int((~matches.any(dim=0)).sum())
        true_pos += matches.shape[0] - n_missed_true
        false_neg += n_missed_true
        false_pos += n_unmatched_pred
        if n_missed_true:
            false_neg_ids.append(j)
        if n_unmatched_pred:
            false_pos_ids.append(j)
    # Forward eps consistently to all three metrics (previously only F1 used it).
    recall = calc_recall(true_pos, false_neg, eps=eps)
    precision = calc_precision(true_pos, false_pos, eps=eps)
    f1_score = calc_f1_score(precision, recall, eps=eps)
    metrics = {
        'recall': recall,
        'precision': precision,
        'f1': f1_score,
    }
    if return_details:
        metrics.update({
            'true_pos': true_pos,
            'false_pos': false_pos,
            'false_neg': false_neg,
            'false_pos_ids': sorted(set(false_pos_ids)),
            'false_neg_ids': sorted(set(false_neg_ids)),
        })
    return metrics