# ner-demo-evaluate / metrics.py
# Author: elshehawy
# Commit ed49033: update code to work with sentence-transformers instead of simcse
from sentence_transformers import util
def calc_recall(true_pos, false_neg, eps=1e-8):
    """Return recall, TP / (TP + FN).

    ``eps`` is added to the denominator so that zero counts yield 0.0
    instead of raising ZeroDivisionError.
    """
    actual_positives = true_pos + false_neg + eps
    return true_pos / actual_positives
def calc_precision(true_pos, false_pos, eps=1e-8):
    """Return precision, TP / (TP + FP).

    ``eps`` is added to the denominator so that zero counts yield 0.0
    instead of raising ZeroDivisionError.
    """
    predicted_positives = true_pos + false_pos + eps
    return true_pos / predicted_positives
def calc_f1_score(precision, recall, eps=1e-8):
    """Return the F1 score, the harmonic mean of ``precision`` and ``recall``.

    ``eps`` is added to the denominator so that precision == recall == 0
    yields 0.0 instead of raising ZeroDivisionError.
    """
    numerator = 2 * precision * recall
    denominator = precision + recall + eps
    return numerator / denominator
def _count_entity_matches(true_ents, pred_ents, model, threshold):
    """Return ``(true_pos, false_neg, false_pos)`` for a single sample."""
    if len(true_ents) == 0:
        # Nothing to find: every prediction is spurious.
        return 0, 0, len(pred_ents)
    if len(pred_ents) == 0:
        # Nothing predicted: every gold entity is missed.
        return 0, len(true_ents), 0

    embed_true = model.encode(true_ents, convert_to_tensor=True)
    embed_pred = model.encode(pred_ents, convert_to_tensor=True)
    # Rows index gold entities, columns index predicted entities.
    matched = util.pytorch_cos_sim(embed_true, embed_pred) >= threshold
    true_hit = matched.any(dim=1)  # gold entity matched by some prediction
    pred_hit = matched.any(dim=0)  # prediction matched by some gold entity

    true_pos = int(true_hit.sum())
    false_neg = len(true_ents) - true_pos
    false_pos = len(pred_ents) - int(pred_hit.sum())
    return true_pos, false_neg, false_pos


def calc_metrics(true, predicted, model, threshold=0.95, eps=1e-8, *, return_details=False):
    """Compute entity-level recall, precision and F1 using embedding similarity.

    ``true`` and ``predicted`` are aligned per sample; each element holds that
    sample's entities and is passed directly to ``model.encode`` (presumably a
    list of entity strings — confirm against callers).

    Counting rules:

    * A gold entity is a true positive when its cosine similarity to at least
      one predicted entity of the same sample is ``>= threshold``. Otherwise
      it is a false negative.
    * A predicted entity is a false positive when no gold entity of the same
      sample reaches ``threshold``.

    Matching is not one-to-one: a single predicted entity may satisfy several
    gold entities.

    Args:
        true: Per-sample gold entities.
        predicted: Per-sample predicted entities, same length as ``true``.
        model: Encoder exposing ``encode(texts, convert_to_tensor=True)``
            (e.g. a sentence-transformers model).
        threshold: Minimum cosine similarity that counts as a match.
        eps: Smoothing term added to the recall, precision and F1
            denominators.
        return_details: When True, also include the raw counts
            (``true_pos``, ``false_pos``, ``false_neg``) and the sorted,
            de-duplicated indices of samples containing false positives /
            false negatives (``false_pos_ids``, ``false_neg_ids``).

    Returns:
        dict with ``'recall'``, ``'precision'`` and ``'f1'``, plus the detail
        keys above when ``return_details`` is True.

    Raises:
        ValueError: If ``true`` and ``predicted`` have different lengths.
    """
    true = list(true)
    predicted = list(predicted)
    if len(true) != len(predicted):
        # zip() would silently drop the tail and skew every metric.
        raise ValueError(
            f"true and predicted must have the same length "
            f"({len(true)} != {len(predicted)})"
        )

    true_pos = false_pos = false_neg = 0
    false_pos_ids = set()
    false_neg_ids = set()
    for idx, (true_ents, pred_ents) in enumerate(zip(true, predicted)):
        tp, fn, fp = _count_entity_matches(true_ents, pred_ents, model, threshold)
        true_pos += tp
        false_neg += fn
        false_pos += fp
        if fp:
            false_pos_ids.add(idx)
        if fn:
            false_neg_ids.add(idx)

    # Use the caller's eps consistently; previously only F1 honoured it.
    recall = calc_recall(true_pos, false_neg, eps=eps)
    precision = calc_precision(true_pos, false_pos, eps=eps)
    metrics = {
        'recall': recall,
        'precision': precision,
        'f1': calc_f1_score(precision, recall, eps=eps),
    }
    if return_details:
        metrics.update({
            'true_pos': true_pos,
            'false_pos': false_pos,
            'false_neg': false_neg,
            'false_pos_ids': sorted(false_pos_ids),
            'false_neg_ids': sorted(false_neg_ids),
        })
    return metrics