Spaces:
Runtime error
Runtime error
File size: 2,405 Bytes
ed49033 894b24d ed49033 894b24d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
from sentence_transformers import util
def calc_recall(true_pos, false_neg, eps=1e-8):
    """Return recall, TP / (TP + FN).

    ``eps`` is added to the denominator so that zero counts yield 0.0
    instead of raising ZeroDivisionError.
    """
    denom = true_pos + false_neg + eps
    return true_pos / denom
def calc_precision(true_pos, false_pos, eps=1e-8):
    """Return precision, TP / (TP + FP).

    ``eps`` is added to the denominator so that zero counts yield 0.0
    instead of raising ZeroDivisionError.
    """
    denom = true_pos + false_pos + eps
    return true_pos / denom
def calc_f1_score(precision, recall, eps=1e-8):
    """Return the F1 score, the harmonic mean of precision and recall.

    ``eps`` is added to the denominator so that precision == recall == 0
    yields 0.0 instead of raising ZeroDivisionError.
    """
    numer = 2 * precision * recall
    denom = precision + recall + eps
    return numer / denom
def calc_metrics(true, predicted, model, threshold=0.95, eps=1e-8,
                 return_details=False):
    """Score predicted entity lists against reference lists by embedding similarity.

    For each sample, reference and predicted entities are embedded with
    ``model.encode`` and compared pairwise with ``util.pytorch_cos_sim``.

    - A reference entity is a true positive if at least one predicted entity
      reaches ``threshold``; otherwise it is a false negative.
    - A predicted entity with no reference entity reaching ``threshold`` is a
      false positive.
    - Samples with an empty side contribute only false positives
      (no references) or false negatives (no predictions).

    Args:
        true: Sized sequence of per-sample lists of reference entities.
        predicted: Per-sample lists of predicted entities, aligned with
            ``true``. Samples are paired with ``zip``, so trailing items of
            the longer input are ignored.
        model: Object exposing ``encode(texts, convert_to_tensor=True)``;
            presumably a SentenceTransformer -- confirm at call sites.
        threshold: Minimum cosine similarity for two entities to match.
        eps: Smoothing constant forwarded to recall, precision and F1 so
            empty counts do not divide by zero.
        return_details: If True, also include the raw counts and the sorted
            indices of samples that produced false positives / negatives.

    Returns:
        dict with ``'recall'``, ``'precision'`` and ``'f1'``. When
        ``return_details`` is True it additionally has ``'true_pos'``,
        ``'false_pos'``, ``'false_neg'``, ``'false_pos_ids'`` and
        ``'false_neg_ids'``.
    """
    true_pos = 0
    false_pos = 0
    false_neg = 0
    # Sets so a sample index is recorded once even with several misses.
    false_pos_ids = set()
    false_neg_ids = set()

    for idx, (true_ents, pred_ents) in enumerate(zip(true, predicted)):
        if len(true_ents) == 0:
            # No reference entities: every prediction is spurious.
            false_pos += len(pred_ents)
            if len(pred_ents) > 0:
                false_pos_ids.add(idx)
            continue
        if len(pred_ents) == 0:
            # Nothing predicted: every reference entity was missed.
            false_neg += len(true_ents)
            false_neg_ids.add(idx)
            continue

        embed_true = model.encode(true_ents, convert_to_tensor=True)
        embed_pred = model.encode(pred_ents, convert_to_tensor=True)
        # Boolean match matrix: rows = reference entities, cols = predictions.
        # Reducing on the tensor avoids one host/device sync per row.
        matches = util.pytorch_cos_sim(embed_true, embed_pred) >= threshold

        matched_true = int(matches.any(dim=1).sum())
        missed_true = matches.shape[0] - matched_true
        unmatched_pred = matches.shape[1] - int(matches.any(dim=0).sum())

        true_pos += matched_true
        false_neg += missed_true
        false_pos += unmatched_pred
        if missed_true:
            false_neg_ids.add(idx)
        if unmatched_pred:
            false_pos_ids.add(idx)

    # Forward eps consistently; previously only the F1 call received it.
    recall = calc_recall(true_pos, false_neg, eps=eps)
    precision = calc_precision(true_pos, false_pos, eps=eps)
    f1_score = calc_f1_score(precision, recall, eps=eps)

    metrics = {
        'recall': recall,
        'precision': precision,
        'f1': f1_score,
    }
    if return_details:
        metrics.update({
            'true_pos': true_pos,
            'false_pos': false_pos,
            'false_neg': false_neg,
            'false_pos_ids': sorted(false_pos_ids),
            'false_neg_ids': sorted(false_neg_ids),
        })
    return metrics