elshehawy commited on
Commit
ed49033
β€’
1 Parent(s): 3611192

update code to work with sentence-transformers instead of simcse

Browse files
Files changed (3) hide show
  1. app.py +4 -2
  2. metrics.py +8 -2
  3. requirements.txt +1 -1
app.py CHANGED
@@ -9,6 +9,7 @@ import huggingface_hub
9
  import json
10
  from simcse import SimCSE # use for gpt
11
  from evaluate_data import store_sample_data, get_metrics_trf
 
12
 
13
  # store_sample_data()
14
 
@@ -108,8 +109,9 @@ def find_orgs(uploaded_file):
108
  true_orgs.append(sent['orgs'])
109
 
110
 
111
- sim_model = SimCSE('sentence-transformers/all-MiniLM-L6-v2')
112
- all_metrics['gpt'] = calc_metrics(true_orgs, gpt_orgs, sim_model)
 
113
 
114
  return all_metrics
115
  # radio_btn = gr.Radio(choices=['GPT', 'iSemantics'], value='iSemantics', label='Available models', show_label=True)
 
9
  import json
10
  from simcse import SimCSE # use for gpt
11
  from evaluate_data import store_sample_data, get_metrics_trf
12
+ from sentence_transformers import SentenceTransformer
13
 
14
  # store_sample_data()
15
 
 
109
  true_orgs.append(sent['orgs'])
110
 
111
 
112
+ # sim_model = SimCSE('sentence-transformers/all-MiniLM-L6-v2')
113
+ sim_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
114
+ all_metrics['gpt'] = calc_metrics(true_orgs, gpt_orgs, sim_model, threshold=0.85)
115
 
116
  return all_metrics
117
  # radio_btn = gr.Radio(choices=['GPT', 'iSemantics'], value='iSemantics', label='Available models', show_label=True)
metrics.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  def calc_recall(true_pos, false_neg, eps=1e-8):
2
  return true_pos / (true_pos + false_neg + eps)
3
 
@@ -44,8 +46,12 @@ def calc_metrics(true, predicted, model, threshold=0.95, eps=1e-8):
44
  false_neg_ids.append(j)
45
 
46
  continue
47
-
48
- similarities = model.similarity(true_ents, pred_ents, device='cuda')
 
 
 
 
49
 
50
  for row in similarities:
51
  if (row >= threshold).any():
 
1
+ from sentence_transformers import util
2
+
3
  def calc_recall(true_pos, false_neg, eps=1e-8):
4
  return true_pos / (true_pos + false_neg + eps)
5
 
 
46
  false_neg_ids.append(j)
47
 
48
  continue
49
+
50
+ embed_true = model.encode(true_ents, convert_to_tensor=True)
51
+ embed_pred = model.encode(pred_ents, convert_to_tensor=True)
52
+
53
+ similarities = util.pytorch_cos_sim(embed_true, embed_pred)
54
+ # similarities = model.similarity(true_ents, pred_ents, device='cuda')
55
 
56
  for row in similarities:
57
  if (row >= threshold).any():
requirements.txt CHANGED
@@ -5,4 +5,4 @@ datasets==2.18.0
5
  evaluate
6
  seqeval
7
  rich
8
- simcse
 
5
  evaluate
6
  seqeval
7
  rich
8
+ sentence-transformers