Spaces:
Running
Running
danielhajialigol
commited on
Commit
•
bc31c45
1
Parent(s):
542f530
removing redundant examples
Browse files
app.py
CHANGED
@@ -8,7 +8,7 @@ from utils import load_rule, get_attribution, get_diseases, get_drg_link, get_ic
|
|
8 |
from transformers import AutoTokenizer, AutoModel, set_seed, pipeline
|
9 |
set_seed(42)
|
10 |
model_path = 'checkpoint_0_9113.bin'
|
11 |
-
related_tensor = torch.
|
12 |
all_summaries = pd.read_csv('all_summaries.csv')['SUMMARIES'].to_list()
|
13 |
|
14 |
similarity_tokenizer = AutoTokenizer.from_pretrained('kamalkraj/BioSimCSE-BioLinkBERT-BASE')
|
@@ -72,13 +72,20 @@ def find_related_summaries(text):
|
|
72 |
embedding = mean_pooling(outputs, attention_mask=inputs.attention_mask)
|
73 |
embedding = torch.nn.functional.normalize(embedding)
|
74 |
scores = torch.mm(related_tensor, embedding.transpose(1,0))
|
75 |
-
scores_indices = scores.topk(k=
|
76 |
indices, scores = scores_indices[-1], torch.round(100 * scores_indices[0], decimals=2)
|
77 |
summaries = []
|
|
|
78 |
for summary_idx, score in zip(indices, scores):
|
|
|
|
|
|
|
79 |
corresp_summary = all_summaries[summary_idx]
|
80 |
-
|
|
|
|
|
81 |
summaries.append([summary])
|
|
|
82 |
return summaries
|
83 |
|
84 |
|
@@ -208,11 +215,10 @@ def main():
|
|
208 |
|
209 |
# input to related summaries
|
210 |
with gr.Row() as row:
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
sbm_btn = gr.Button(value="Submit Related Summaries", components=[input_related], visible=False)
|
216 |
|
217 |
with gr.Row() as row:
|
218 |
related = gr.Dataset(samples=[], components=[input_related], visible=False, type='index')
|
|
|
8 |
from transformers import AutoTokenizer, AutoModel, set_seed, pipeline
|
9 |
set_seed(42)
|
10 |
model_path = 'checkpoint_0_9113.bin'
|
11 |
+
related_tensor = torch.load('discharge_embeddings.pt')
|
12 |
all_summaries = pd.read_csv('all_summaries.csv')['SUMMARIES'].to_list()
|
13 |
|
14 |
similarity_tokenizer = AutoTokenizer.from_pretrained('kamalkraj/BioSimCSE-BioLinkBERT-BASE')
|
|
|
72 |
embedding = mean_pooling(outputs, attention_mask=inputs.attention_mask)
|
73 |
embedding = torch.nn.functional.normalize(embedding)
|
74 |
scores = torch.mm(related_tensor, embedding.transpose(1,0))
|
75 |
+
scores_indices = scores.topk(k=50, dim=0)
|
76 |
indices, scores = scores_indices[-1], torch.round(100 * scores_indices[0], decimals=2)
|
77 |
summaries = []
|
78 |
+
score_set = set()
|
79 |
for summary_idx, score in zip(indices, scores):
|
80 |
+
score = score.item()
|
81 |
+
if len(summaries) == 5:
|
82 |
+
break
|
83 |
corresp_summary = all_summaries[summary_idx]
|
84 |
+
if score in score_set:
|
85 |
+
continue
|
86 |
+
summary = f'{round(score,2)}% Similarity Rate for the following Discharge Summary:\n\n{corresp_summary}'
|
87 |
summaries.append([summary])
|
88 |
+
score_set.add(score)
|
89 |
return summaries
|
90 |
|
91 |
|
|
|
215 |
|
216 |
# input to related summaries
|
217 |
with gr.Row() as row:
|
218 |
+
input_related = gr.TextArea(label="Input up to 3 Related Discharge Summary/Summaries Here", visible=False)
|
219 |
+
with gr.Row() as row:
|
220 |
+
rmv_related_btn = gr.Button(value='Remove Related Summary', visible=False)
|
221 |
+
sbm_btn = gr.Button(value="Submit Related Summaries", components=[input_related], visible=False)
|
|
|
222 |
|
223 |
with gr.Row() as row:
|
224 |
related = gr.Dataset(samples=[], components=[input_related], visible=False, type='index')
|
utils.py
CHANGED
@@ -317,7 +317,7 @@ def visualize_text(datarecord, drg_link, icd_annotations, diseases):
|
|
317 |
"<th style='text-align: left'>Predicted DRG</th>"
|
318 |
"<th style='text-align: left'>Word Importance</th>"
|
319 |
"<th style='text-align: left'>Diseases</th>"
|
320 |
-
"<th style='text-align: left'>ICD
|
321 |
]
|
322 |
pred_class_html = visualization.format_classname(datarecord.pred_class)
|
323 |
icd_class_html = get_icd_html(icd_annotations)
|
|
|
317 |
"<th style='text-align: left'>Predicted DRG</th>"
|
318 |
"<th style='text-align: left'>Word Importance</th>"
|
319 |
"<th style='text-align: left'>Diseases</th>"
|
320 |
+
"<th style='text-align: left'>ICD Concepts</th>"
|
321 |
]
|
322 |
pred_class_html = visualization.format_classname(datarecord.pred_class)
|
323 |
icd_class_html = get_icd_html(icd_annotations)
|