gabrielandrade2 committed
Commit 576d564 • Parent(s): 13dbcb6
Update model with additional negative examples, improve support scripts
Files changed:
- EntityNormalizer.py +17 -9
- NER_medNLP.py +77 -79
- README.md +86 -16
- config.json +13 -1
- id_to_tags.pkl +2 -2
- model.safetensors +0 -3
- predict.py +79 -42
- pytorch_model.bin +2 -2
- requirements.txt +30 -33
- tokenizer_config.json +1 -1
- utils.py +15 -0
EntityNormalizer.py
CHANGED
@@ -5,29 +5,38 @@ from rapidfuzz import fuzz, process
 
 class EntityDictionary:
 
-    def __init__(self, path):
+    def __init__(self, path, candidate_column, normalization_column):
+        if path is None:
+            raise ValueError('Path to dictionary file is not specified.')
+        if candidate_column is None:
+            raise ValueError('Candidate column is not specified.')
+        if normalization_column is None:
+            raise ValueError('Normalization column is not specified.')
+
         self.df = pd.read_csv(path)
+        self.candidate_column = candidate_column
+        self.normalization_column = normalization_column
 
     def get_candidates_list(self):
-        return self.df.iloc[:,
+        return self.df.iloc[:, self.candidate_column].to_list()
 
     def get_normalization_list(self):
-        return self.df.iloc[:,
+        return self.df.iloc[:, self.normalization_column].to_list()
 
     def get_normalized_term(self, term):
-        return self.df[self.df.iloc[:,
+        return self.df[self.df.iloc[:, self.candidate_column] == term].iloc[:, self.normalization_column].item()
 
 
-class
+class DefaultDiseaseDict(EntityDictionary):
 
     def __init__(self):
-        super().__init__('dictionaries/disease_dict.csv')
+        super().__init__('dictionaries/disease_dict.csv', 0, 2)
 
 
-class
+class DefaultDrugDict(EntityDictionary):
 
     def __init__(self):
-        super().__init__('dictionaries/drug_dict.csv')
+        super().__init__('dictionaries/drug_dict.csv', 0, 2)
 
 
 class EntityNormalizer:
@@ -48,4 +57,3 @@ class EntityNormalizer:
             return ('' if pd.isna(ret) else ret), score
         else:
            return '', score
-
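The refactor above turns the hard-coded column indices into constructor arguments. A minimal usage sketch (the `normalize` method name and the example CSV path are assumptions; only the constructor and the `matching_threshold` keyword are visible in this diff and in `predict.py`):

```python
from EntityNormalizer import EntityDictionary, EntityNormalizer

# Custom dictionary: column 0 holds surface forms, column 2 the normalized forms,
# mirroring the defaults used by DefaultDiseaseDict / DefaultDrugDict.
custom_dict = EntityDictionary('my_drug_dict.csv', 0, 2)   # hypothetical file
normalizer = EntityNormalizer(custom_dict, matching_threshold=50)

# The normalizer is assumed to expose a method returning (normalized_term, score),
# consistent with the `return ('' if pd.isna(ret) else ret), score` seen above.
term, score = normalizer.normalize('アミオダロン')          # method name assumed
```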
NER_medNLP.py
CHANGED
@@ -1,46 +1,47 @@
 # %%
 
 import itertools
 
 import numpy as np
-import torch
-from transformers import BertJapaneseTokenizer, BertForTokenClassification
 import pytorch_lightning as pl
+import torch
+from transformers import BertForTokenClassification, BertJapaneseTokenizer
 
-# from torch.utils.data import DataLoader
-# import from_XML_to_json as XtC
-# import random
-# import json
-# import unicodedata
-# import pandas as pd
 
 # %%
 # 8-16
 # PyTorch Lightningのモデル
 class BertForTokenClassification_pl(pl.LightningModule):
 
-    def __init__(self,
+    def __init__(self, num_labels, model='sociocom/MedNERN-CR-JA', lr=1e-4):
         super().__init__()
+        self.lr = lr
         self.save_hyperparameters()
         self.bert_tc = BertForTokenClassification.from_pretrained(
-            num_labels=num_labels
+            model,
+            num_labels=num_labels,
+            ignore_mismatched_sizes=True
         )
 
+    @classmethod
+    def from_pretrained_bin(cls, model_path, num_labels):
+        model = cls(num_labels)
+        model.load_state_dict(torch.load(model_path))
+        return model
+
     def training_step(self, batch, batch_idx):
         output = self.bert_tc(**batch)
         loss = output.loss
         self.log('train_loss', loss)
         return loss
 
     def validation_step(self, batch, batch_idx):
         output = self.bert_tc(**batch)
         val_loss = output.loss
         self.log('val_loss', val_loss)
 
     def configure_optimizers(self):
-        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
+        return torch.optim.Adam(self.parameters(), lr=self.lr)
 
 
 # %%
@@ -58,58 +59,58 @@ class NER_tokenizer_BIO(BertJapaneseTokenizer):
         符号化とラベル列の作成を行う。
         """
         # 固有表現の前後でtextを分割し、それぞれのラベルをつけておく。
-        splitted = []
+        splitted = [] # 分割後の文字列を追加していく
         position = 0
 
         for entity in entities:
             start = entity['span'][0]
             end = entity['span'][1]
             label = entity['type_id']
-            splitted.append({'text':text[position:start], 'label':0})
-            splitted.append({'text':text[start:end], 'label':label})
+            splitted.append({'text': text[position:start], 'label': 0})
+            splitted.append({'text': text[start:end], 'label': label})
             position = end
-        splitted.append({'text': text[position:], 'label':0})
-        splitted = [
+        splitted.append({'text': text[position:], 'label': 0})
+        splitted = [s for s in splitted if s['text']]
 
         # 分割されたそれぞれの文章をトークン化し、ラベルをつける。
-        tokens = []
-        labels = []
+        tokens = [] # トークンを追加していく
+        labels = [] # ラベルを追加していく
         for s in splitted:
             tokens_splitted = self.tokenize(s['text'])
             label = s['label']
-            if label > 0:
+            if label > 0: # 固有表現
                 # まずトークン全てにI-タグを付与
                 # 番号順O-tag:0, B-tag:1~タグの数,I-tag:タグの数〜
-                labels_splitted =
-                    [
+                labels_splitted = \
+                    [label + self.num_entity_type] * len(tokens_splitted)
                 # 先頭のトークンをB-タグにする
                 labels_splitted[0] = label
-            else:
-                labels_splitted =
+            else: # それ以外
+                labels_splitted = [0] * len(tokens_splitted)
 
             tokens.extend(tokens_splitted)
             labels.extend(labels_splitted)
 
         # 符号化を行いBERTに入力できる形式にする。
         input_ids = self.convert_tokens_to_ids(tokens)
         encoding = self.prepare_for_model(
             input_ids,
             max_length=max_length,
             padding='max_length',
             truncation=True
         )
 
         # ラベルに特殊トークンを追加
         # max_lengthで切り取って,その前後に[CLS]と[SEP]を追加するためのラベルを入れる
-        labels = [0] + labels[:max_length-2] + [0]
+        labels = [0] + labels[:max_length - 2] + [0]
         # max_lengthに満たない場合は,満たない分を後ろ側に追加する
-        labels = labels + [0]*(
+        labels = labels + [0] * (max_length - len(labels))
         encoding['labels'] = labels
 
         return encoding
 
     def encode_plus_untagged(
-
+        self, text, max_length=None, return_tensors=None
     ):
@@ -117,50 +118,50 @@ class NER_tokenizer_BIO(BertJapaneseTokenizer):
         """
         # 文章のトークン化を行い、
         # それぞれのトークンと文章中の文字列を対応づける。
-        tokens = []
-        tokens_original = []
-        words = self.word_tokenizer.tokenize(text)
+        tokens = [] # トークンを追加していく。
+        tokens_original = [] # トークンに対応する文章中の文字列を追加していく。
+        words = self.word_tokenizer.tokenize(text) # MeCabで単語に分割
         for word in words:
             # 単語をサブワードに分割
             tokens_word = self.subword_tokenizer.tokenize(word)
             tokens.extend(tokens_word)
-            if tokens_word[0] == '[UNK]':
+            if tokens_word[0] == '[UNK]': # 未知語への対応
                 tokens_original.append(word)
             else:
                 tokens_original.extend([
-                    token.replace('##','') for token in tokens_word
+                    token.replace('##', '') for token in tokens_word
                 ])
 
         # 各トークンの文章中での位置を調べる。(空白の位置を考慮する)
         position = 0
-        spans = []
+        spans = [] # トークンの位置を追加していく。
         for token in tokens_original:
             l = len(token)
             while 1:
-                if token != text[position:position+l]:
+                if token != text[position:position + l]:
                     position += 1
                 else:
-                    spans.append([position, position+l])
+                    spans.append([position, position + l])
                     position += l
                     break
 
         # 符号化を行いBERTに入力できる形式にする。
         input_ids = self.convert_tokens_to_ids(tokens)
         encoding = self.prepare_for_model(
             input_ids,
             max_length=max_length,
             padding='max_length' if max_length else False,
             truncation=True if max_length else False
         )
         sequence_length = len(encoding['input_ids'])
         # 特殊トークン[CLS]に対するダミーのspanを追加。
-        spans = [[-1, -1]] + spans[:sequence_length-2]
+        spans = [[-1, -1]] + spans[:sequence_length - 2]
         # 特殊トークン[SEP]、[PAD]に対するダミーのspanを追加。
-        spans = spans + [[-1, -1]] * (
+        spans = spans + [[-1, -1]] * (sequence_length - len(spans))
 
         # 必要に応じてtorch.Tensorにする。
         if return_tensors == 'pt':
-            encoding = {
+            encoding = {k: torch.tensor([v]) for k, v in encoding.items()}
 
         return encoding, spans
@@ -169,28 +170,26 @@ class NER_tokenizer_BIO(BertJapaneseTokenizer):
         """
         Viterbiアルゴリズムで最適解を求める。
         """
-        m = 2*num_entity_type + 1
+        m = 2 * num_entity_type + 1
         penalty_matrix = np.zeros([m, m])
         for i in range(m):
-            for j in range(1+num_entity_type, m):
-                if not (
-                    penalty_matrix[i,j] = penalty
-        path = [
-        scores_path = scores_bert[0] - penalty_matrix[0
+            for j in range(1 + num_entity_type, m):
+                if not ((i == j) or (i + num_entity_type == j)):
+                    penalty_matrix[i, j] = penalty
+        path = [[i] for i in range(m)]
+        scores_path = scores_bert[0] - penalty_matrix[0, :]
         scores_bert = scores_bert[1:]
 
         for scores in scores_bert:
-            assert len(scores) == 2*num_entity_type + 1
-            score_matrix = np.array(scores_path).reshape(-1,1) \
+            assert len(scores) == 2 * num_entity_type + 1
+            score_matrix = np.array(scores_path).reshape(-1, 1) \
+                + np.array(scores).reshape(1, -1) \
+                - penalty_matrix
             scores_path = score_matrix.max(axis=0)
             argmax = score_matrix.argmax(axis=0)
             path_new = []
             for i, idx in enumerate(argmax):
-                path_new.append(
+                path_new.append(path[idx] + [i])
             path = path_new
 
         labels_optimal = path[np.argmax(scores_path)]
@@ -203,26 +202,26 @@ class NER_tokenizer_BIO(BertJapaneseTokenizer):
         """
         assert len(spans) == len(scores)
         num_entity_type = self.num_entity_type
 
         # 特殊トークンに対応する部分を取り除く
-        scores = [score for score, span in zip(scores, spans) if span[0]
-        spans = [span for span in spans if span[0]
+        scores = [score for score, span in zip(scores, spans) if span[0] != -1]
+        spans = [span for span in spans if span[0] != -1]
 
         # Viterbiアルゴリズムでラベルの予測値を決める。
         labels = self.Viterbi(scores, num_entity_type)
 
         # 同じラベルが連続するトークンをまとめて、固有表現を抽出する。
         entities = []
         for label, group \
-
+            in itertools.groupby(enumerate(labels), key=lambda x: x[1]):
 
             group = list(group)
             start = spans[group[0][0]][0]
             end = spans[group[-1][0]][1]
 
-            if label != 0:
+            if label != 0: # 固有表現であれば
                 if 1 <= label <= num_entity_type:
+                    # ラベルが`B-`ならば、新しいentityを追加
                     entity = {
                         "name": text[start:end],
                         "span": [start, end],
@@ -231,8 +230,7 @@ class NER_tokenizer_BIO(BertJapaneseTokenizer):
                     entities.append(entity)
                 else:
                     # ラベルが`I-`ならば、直近のentityを更新
                     entity['span'][1] = end
                     entity['name'] = text[entity['span'][0]:entity['span'][1]]
 
-
-        return entities
+        return entities
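With the new `from_pretrained_bin` classmethod, the fine-tuned checkpoint can be loaded outside of `predict.py` as well. A rough sketch (the base weights are fetched from the hub, as in `__init__`; the final BIO-decoding call is left out because its method name is not visible in this hunk, only `Viterbi` is):

```python
import torch
import NER_medNLP as ner

NUM_ENTITY_TYPE = 43                   # value stored in tokenizer_config.json
NUM_LABELS = 2 * NUM_ENTITY_TYPE + 1   # O tag plus B-/I- tags, as in predict.py

# Rebuild the Lightning wrapper and load the plain state_dict checkpoint.
model = ner.BertForTokenClassification_pl.from_pretrained_bin(
    model_path='pytorch_model.bin', num_labels=NUM_LABELS)
tokenizer = ner.NER_tokenizer_BIO.from_pretrained(
    'cl-tohoku/bert-base-japanese-whole-word-masking',
    num_entity_type=NUM_ENTITY_TYPE)

text = '治療経過中に非持続性心室頻拍が認められた。'
encoding, spans = tokenizer.encode_plus_untagged(text, return_tensors='pt')

with torch.no_grad():
    scores = model.bert_tc(**encoding).logits[0].numpy()

# scores now holds one row of label logits per token; decoding them into
# entities goes through the tokenizer's Viterbi-based helper used by predict.py.
```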
README.md
CHANGED
@@ -12,26 +12,65 @@ metrics:
 - NTCIR-16 Real-MedNLP subtask 1
 ---
 
 This is a model for named entity recognition of Japanese medical documents.
 
-- key_attr.pkl
-- NER_medNLP.py
-- predict.py
-- text.txt (This is an input file which should be predicted, which could be changed.)
+# Introduction
+
+This repository contains the base model and a support prediction script that runs the model and produces an XML-tagged text output.
+
+The original model was trained on the [MedTxt-CR-JA](https://sociocom.naist.jp/medtxt/cr) dataset, so the provided prediction code outputs XML tags in the same format.
+
+The script also provides a normalization method for the output entities, which is not embedded in the model.
+
+If you want to re-train or update the model, we provide additional support scripts in [this GitHub repository](https://github.com/sociocom/MedNERN-CR-JA).
+Issues and suggestions can also be submitted there.
+
+### A note about loading the model using standard HuggingFace methods
+This model should also be loadable using the standard HuggingFace `from_pretrained` methods. However, the model by itself only outputs labels in the format "LABEL_0", "LABEL_1", etc.
+
+The conversion of model outputs to the actual labels (`<m-key>`, `<m-val>`, `<timex3>`, etc.) is not yet embedded into the model, so the extra `id_to_tags.pkl` file is necessary
+to make the conversion. It contains a mapping from the model output ids to the actual labels.
+
+This process can be done manually if needed, but the `predict.py` script already does that.
+
+We are currently working to better align the model with HuggingFace's standards.
+
+## How to use
+
+Clone the repository and install the requirements:
+
+```
+pip install -r requirements.txt
+```
+
+The code has been developed and tested with Python 3.9 on macOS 14.1 (M1 MacBook Pro).
+
+### Prediction
+
+The prediction script will output the results in the same XML format as the input file. It can be run with the following
+command:
 
 ```
 python3 predict.py
 ```
 
+The default parameters will take the model located in `pytorch_model.bin` and the input file `text.txt`.
+The resulting predictions will be output to the screen.
+
+To select a different model or input file, use the `-m` and `-i` parameters, respectively:
+
+```
+python3 predict.py -m <model_path> -i <your_input_file>.txt
+```
+
+The input file can be a single text file or a folder containing multiple `.txt` files, for batch processing. For example:
+
+```
+python3 predict.py -m <model_path> -i <your_input_folder>
+```
+
+### Entity normalization
 
 This model supports entity normalization via dictionary matching. The dictionary is a list of medical terms or
 drugs and their standard forms.
@@ -39,23 +78,42 @@ drugs and their standard forms.
 Two different dictionaries are used for drug and disease normalization, stored in the `dictionaries` folder as
 `drug_dict.csv` and `disease_dict.csv`, respectively.
 
 To enable normalization you can add the `--normalize` flag to the `predict.py` command.
 
 ```
-python3 predict.py --normalize
+python3 predict.py -m <model_path> --normalize
 ```
 
 Normalization will add the `norm` attribute to the output XML tags. This attribute can be empty if a normalized form of
 the term is not found.
 
-The provided disease normalization dictionary (`dictionaries/disease_dict.csv`) is based on
-The default drug dictionary (`dictionaries/drug_dict.csv`) is based on
+The provided disease normalization dictionary (`dictionaries/disease_dict.csv`) is based on
+the [Manbyo Dictionary](https://sociocom.naist.jp/manbyo-dic-en/) and provides normalization to the standard ICD code
+for the diseases.
 
+The default drug dictionary (`dictionaries/drug_dict.csv`) is based on
+the [Hyakuyaku Dictionary](https://sociocom.naist.jp/hyakuyaku-dic-en/).
 
 The dictionary is a CSV file with three columns: the first column is the surface form term and the third column contains
 its standard form. The second column is not used.
 
+### Replacing the default dictionaries
+
+Users can freely change the dictionaries to fit their needs by passing the path to a custom dictionary file.
+The dictionary file must have at least a column containing the list of surface forms and a column containing the list of
+normalized forms.
+
+The parameters `--disease-dict` and `--drug-dict` can be used to specify the path to the disease and drug dictionaries,
+respectively.
+When doing so, the respective parameters informing the column index of the surface form and the normalized form must also be
+provided.
+You don't need to replace both dictionaries at the same time; you can replace only one of them.
+
+E.g.:
+
+```
+python3 predict.py --normalize --drug-dict dictionaries/drug_dict.csv --drug-candidate-col 0 --drug-normalization-col 2 --disease-dict dictionaries/disease_dict.csv --disease-candidate-col 0 --disease-normalization-col 2
+```
@@ -71,5 +129,17 @@ User can freely change the dictionary to fit their needs, as long as the format
 <timex3 type="med">治療経過中</timex3>に<d certainty="positive" norm="I472">非持続性心室頻拍</d>が認められたため<m-key state="executed" norm="アミオダロン塩酸塩">アミオダロン</m-key>が併用となった。
 ```
 
+## Publication
+
+This model can be cited as:
 
+```
+@misc {social_computing_lab_2023,
+    author = { {Social Computing Lab} },
+    title = { MedNERN-CR-JA (Revision 13dbcb6) },
+    year = 2023,
+    url = { https://huggingface.co/sociocom/MedNERN-CR-JA },
+    doi = { 10.57967/hf/0620 },
+    publisher = { Hugging Face }
+}
+```
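To make the expected dictionary layout concrete, a minimal illustrative CSV compatible with `--disease-candidate-col 0 --disease-normalization-col 2` could look like the snippet below. The rows are made-up examples, not entries from the actual Manbyo or Hyakuyaku data, and the second column is ignored:

```
非持続性心室頻拍,unused,I472
心房細動,unused,I48
```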
config.json
CHANGED
@@ -89,7 +89,13 @@
     "77": "LABEL_77",
     "78": "LABEL_78",
     "79": "LABEL_79",
-    "80": "LABEL_80"
+    "80": "LABEL_80",
+    "81": "LABEL_81",
+    "82": "LABEL_82",
+    "83": "LABEL_83",
+    "84": "LABEL_84",
+    "85": "LABEL_85",
+    "86": "LABEL_86"
   },
   "initializer_range": 0.02,
   "intermediate_size": 3072,
@@ -174,6 +180,12 @@
     "LABEL_79": 79,
     "LABEL_8": 8,
     "LABEL_80": 80,
+    "LABEL_81": 81,
+    "LABEL_82": 82,
+    "LABEL_83": 83,
+    "LABEL_84": 84,
+    "LABEL_85": 85,
+    "LABEL_86": 86,
     "LABEL_9": 9
   },
   "layer_norm_eps": 1e-12,
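The six new labels are consistent with the BIO scheme used elsewhere in this commit: `tokenizer_config.json` now records 43 entity types, and `predict.py` builds the classifier with `2 * len_num_entity_type + 1` labels, which is exactly LABEL_0 through LABEL_86. A quick sanity check:

```python
# One O tag plus a B- and an I- tag per entity type.
num_entity_type = 43                 # from tokenizer_config.json
num_labels = 2 * num_entity_type + 1
assert num_labels == 87              # LABEL_0 .. LABEL_86 in config.json
```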
id_to_tags.pkl
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:26cbbc0594cf7a1c4439a1010c5e2c55c1f0fb0a9664d93248b7b7d7de0cc434
+size 671
model.safetensors
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:802f29afc4eae3cbf49f3957f9b2d27ae247e6e079ad6d947ccf181dff7c754c
-size 440383704
predict.py
CHANGED
@@ -1,23 +1,24 @@
 # %%
 import argparse
-
-from tqdm import tqdm
-import unicodedata
-import re
+import os.path
 import pickle
+import unicodedata
+
 import torch
+from tqdm import tqdm
 
+import NER_medNLP as ner
+import utils
+from EntityNormalizer import EntityNormalizer, EntityDictionary, DefaultDiseaseDict, DefaultDrugDict
 
-device = torch.device('cuda
+device = torch.device("mps" if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')
 
 # %% global変数として使う
 dict_key = {}
 
 
 # %%
-def to_xml(data):
+def to_xml(data, id_to_tags):
     with open("key_attr.pkl", "rb") as tf:
         key_attr = pickle.load(tf)
@@ -27,7 +28,11 @@ def to_xml(data):
         if entities == "":
             return
         span = entities['span']
+        try:
+            type_id = id_to_tags[entities['type_id']].split('_')
+        except:
+            print("out of range type_id", entities)
+            continue
         tag = type_id[0]
 
         if not type_id[1] == "":
@@ -49,17 +54,11 @@ def to_xml(data):
 
 
 def predict_entities(modelpath, sentences_list, len_num_entity_type):
-    # checkpoint_path = modelpath + ".ckpt"
-    # )
-    # bert_tc = model.bert_tc.cuda()
-
-    model = ner.BertForTokenClassification_pl(modelpath, num_labels=81, lr=1e-5)
+    model = ner.BertForTokenClassification_pl.from_pretrained_bin(model_path=modelpath, num_labels=2 * len_num_entity_type + 1)
     bert_tc = model.bert_tc.to(device)
 
-    MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'
     tokenizer = ner.NER_tokenizer_BIO.from_pretrained(
-
+        'cl-tohoku/bert-base-japanese-whole-word-masking',
         num_entity_type=len_num_entity_type # Entityの数を変え忘れないように!
     )
@@ -69,7 +68,7 @@ def predict_entities(modelpath, sentences_list, len_num_entity_type):
     text_entities_set = []
     for dataset in sentences_list:
         text_entities = []
-        for sample in tqdm(dataset):
+        for sample in tqdm(dataset, desc='Predict'):
             text = sample
             encoding, spans = tokenizer.encode_plus_untagged(
                 text, return_tensors='pt'
@@ -93,12 +92,12 @@ def predict_entities(modelpath, sentences_list, len_num_entity_type):
     return text_entities_set
 
 
-def combine_sentences(text_entities_set, insert: str):
+def combine_sentences(text_entities_set, id_to_tags, insert: str):
     documents = []
     for text_entities in tqdm(text_entities_set):
         document = []
         for t in text_entities:
-            document.append(to_xml(t))
+            document.append(to_xml(t, id_to_tags))
         documents.append('\n'.join(document))
     return documents
@@ -115,9 +114,19 @@ def value_to_key(value, key_attr): # attributeから属性名を取得
 
 
 # %%
-def normalize_entities(text_entities_set
+def normalize_entities(text_entities_set, id_to_tags, disease_dict=None, disease_candidate_col=None, disease_normalization_col=None, disease_matching_threshold=None, drug_dict=None,
+                       drug_candidate_col=None, drug_normalization_col=None, drug_matching_threshold=None):
+    if disease_dict:
+        disease_dict = EntityDictionary(disease_dict, disease_candidate_col, disease_normalization_col)
+    else:
+        disease_dict = DefaultDiseaseDict()
+    disease_normalizer = EntityNormalizer(disease_dict, matching_threshold=disease_matching_threshold)
+
+    if drug_dict:
+        drug_dict = EntityDictionary(drug_dict, drug_candidate_col, drug_normalization_col)
+    else:
+        drug_dict = DefaultDrugDict()
+    drug_normalizer = EntityNormalizer(drug_dict, matching_threshold=drug_matching_threshold)
 
     for entry in text_entities_set:
         for text_entities in entry:
@@ -136,31 +145,59 @@ def normalize_entities(text_entities_set):
             entity['norm'] = str(normalization)
 
 
-parser = argparse.ArgumentParser(description='Predict entities from text')
-parser.add_argument('--normalize', action=argparse.BooleanOptionalAction, help='Enable entity normalization')
-args = parser.parse_args()
-
-with open("key_attr.pkl", "rb") as tf:
-    key_attr = pickle.load(tf)
-with open('text.txt') as f:
-    articles_raw = f.read()
-
-sentences_norm = [s for s in re.split(r'\n', article_norm) if s != '']
-
-normalize_entities(text_entities_set)
-
+def run(model, input, output=None, normalize=False, **kwargs):
+    with open("id_to_tags.pkl", "rb") as tf:
+        id_to_tags = pickle.load(tf)
+
+    if (os.path.isdir(input)):
+        files = [f for f in os.listdir(input) if os.path.isfile(os.path.join(input, f))]
+    else:
+        files = [input]
+
+    for file in tqdm(files, desc="Input file"):
+        with open(file) as f:
+            articles_raw = f.read()
+
+        article_norm = unicodedata.normalize('NFKC', articles_raw)
+
+        sentences_raw = utils.split_sentences(articles_raw)
+        sentences_norm = utils.split_sentences(article_norm)
+
+        text_entities_set = predict_entities(model, [sentences_norm], len(id_to_tags))
+
+        for i, texts_ent in enumerate(text_entities_set[0]):
+            texts_ent['text'] = sentences_raw[i]
+
+        if normalize:
+            normalize_entities(text_entities_set, id_to_tags, **kwargs)
+
+        documents = combine_sentences(text_entities_set, id_to_tags, '\n')
+
+        print(documents[0])
+
+        if output:
+            with open(file.replace(input, output), 'w') as f:
+                f.write(documents[0])
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Predict entities from text')
+    parser.add_argument('-m', '--model', type=str, default='pytorch_model.bin', help='Path to model checkpoint')
+    parser.add_argument('-i', '--input', type=str, default='text.txt', help='Path to text file or directory')
+    parser.add_argument('-o', '--output', type=str, default=None, help='Path to output file or directory')
+    parser.add_argument('-n', '--normalize', action=argparse.BooleanOptionalAction, help='Enable entity normalization', default=False)
+
+    # Dictionary override arguments
+    parser.add_argument("--drug-dict", help="File path for overriding the default drug dictionary")
+    parser.add_argument("--drug-candidate-col", type=int, help="Column index for drug candidates in the CSV file (required if --drug-dict is specified)")
+    parser.add_argument("--drug-normalization-col", type=int, help="Column index for drug normalization in the CSV file (required if --drug-dict is specified)")
+    parser.add_argument('--drug-matching-threshold', type=int, default=50, help='Matching threshold for drug dictionary')
+
+    parser.add_argument("--disease-dict", help="File path for overriding the default disease dictionary")
+    parser.add_argument("--disease-candidate-col", type=int, help="Column index for disease candidates in the CSV file (required if --disease-dict is specified)")
+    parser.add_argument("--disease-normalization-col", type=int, help="Column index for disease normalization in the CSV file (required if --disease-dict is specified)")
+    parser.add_argument('--disease-matching-threshold', type=int, default=50, help='Matching threshold for disease dictionary')
+    args = parser.parse_args()
+
+    argument_dict = vars(args)
+    run(**argument_dict)
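Besides the new CLI flags, the refactored `run()` function can be driven directly from Python. A minimal sketch, assuming it is executed from the repository root so `id_to_tags.pkl` and `key_attr.pkl` resolve:

```python
from predict import run

# Tag a single file with the default model and the bundled dictionaries;
# passing output=None prints the XML-tagged result to the screen.
run(model='pytorch_model.bin',
    input='text.txt',
    output=None,
    normalize=True,
    disease_matching_threshold=50,   # CLI defaults
    drug_matching_threshold=50)
```

Per the README, `input` may also point at a folder of `.txt` files for batch processing.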
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:a92d7a2fc876d379593b3425f14aa001f439e2c4a3ca882768fd8a7a35be363d
+size 440466875
requirements.txt
CHANGED
@@ -1,45 +1,42 @@
-aiohttp==3.8.
+aiohttp==3.8.6
 aiosignal==1.3.1
-async-timeout==4.0.
-attrs==
-certifi==
-charset-normalizer==3.
-huggingface-hub==0.13.4
+async-timeout==4.0.3
+attrs==23.1.0
+certifi==2023.7.22
+charset-normalizer==3.3.2
+filelock==3.13.1
+frozenlist==1.4.0
+fsspec==2023.10.0
+fugashi==1.3.0
+huggingface-hub==0.17.3
 idna==3.4
 ipadic==1.0.0
 Jinja2==3.1.2
-MarkupSafe==2.1.2
+lightning-utilities==0.9.0
+MarkupSafe==2.1.3
 mojimoji==0.0.12
 mpmath==1.3.0
 multidict==6.0.4
-networkx==3.1
-numpy==1.
-pandas==2.0.0
+networkx==3.2.1
+numpy==1.26.2
+packaging==23.2
+pandas==2.1.3
 python-dateutil==2.8.2
-pytorch-lightning==2.
-pytz==2023.3
-PyYAML==6.0
-rapidfuzz==
-regex==2023.3
-requests==2.
+pytorch-lightning==2.1.1
+pytz==2023.3.post1
+PyYAML==6.0.1
+rapidfuzz==3.5.2
+regex==2023.10.3
+requests==2.31.0
+safetensors==0.4.0
 six==1.16.0
-sympy==1.11.1
+sympy==1.12
 tokenizers==0.13.3
-torch==2.
-torchmetrics==
-tqdm==4.
+torch==2.1.0
+torchmetrics==1.2.0
+tqdm==4.66.1
 transformers==4.27.4
-typing_extensions==4.
+typing_extensions==4.8.0
 tzdata==2023.3
-urllib3==1.
-yarl==1.
+urllib3==2.1.0
+yarl==1.9.2
tokenizer_config.json
CHANGED
@@ -20,5 +20,5 @@
     "NER_medNLP.NER_tokenizer_BIO",
     null
   ],
-  "num_entity_type": "
+  "num_entity_type": "43"
 }
utils.py
ADDED
@@ -0,0 +1,15 @@
+import re
+
+
+def split_sentences(text):
+    """Given a string, split it into sentences.
+
+    :param text: The string to be processed.
+    :return: The list of split sentences.
+    """
+    processed_text = re.split(
+        "(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=[.?!])\s\n*|(?<=[^A-zA-z0-90-9 ].)(?<=[。..??!!])(?![\.」])\n*", text)
+    # processed_text = re.split("(? <=[。??!!])") # In case only a simple regex is necessary
+    processed_text = [x.strip() for x in processed_text]
+    processed_text = [x for x in processed_text if x != '']
+    return processed_text
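For reference, a quick illustration of what the new splitter returns (the example string is made up):

```python
import utils

text = '患者は発熱を認めた。精査加療目的で入院した。\nSymptoms improved. He was discharged.'
print(utils.split_sentences(text))
# Expected (roughly): ['患者は発熱を認めた。', '精査加療目的で入院した。',
#                      'Symptoms improved.', 'He was discharged.']
```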