gabrielandrade2 committed on
Commit
576d564
1 Parent(s): 13dbcb6

Update model with additional negative examples, improve support scripts

Files changed (11)
  1. EntityNormalizer.py +17 -9
  2. NER_medNLP.py +77 -79
  3. README.md +86 -16
  4. config.json +13 -1
  5. id_to_tags.pkl +2 -2
  6. model.safetensors +0 -3
  7. predict.py +79 -42
  8. pytorch_model.bin +2 -2
  9. requirements.txt +30 -33
  10. tokenizer_config.json +1 -1
  11. utils.py +15 -0
EntityNormalizer.py CHANGED
@@ -5,29 +5,38 @@ from rapidfuzz import fuzz, process
 
 class EntityDictionary:
 
-    def __init__(self, path):
+    def __init__(self, path, candidate_column, normalization_column):
+        if path is None:
+            raise ValueError('Path to dictionary file is not specified.')
+        if candidate_column is None:
+            raise ValueError('Candidate column is not specified.')
+        if normalization_column is None:
+            raise ValueError('Normalization column is not specified.')
+
         self.df = pd.read_csv(path)
+        self.candidate_column = candidate_column
+        self.normalization_column = normalization_column
 
     def get_candidates_list(self):
-        return self.df.iloc[:, 0].to_list()
+        return self.df.iloc[:, self.candidate_column].to_list()
 
     def get_normalization_list(self):
-        return self.df.iloc[:, 2].to_list()
+        return self.df.iloc[:, self.normalization_column].to_list()
 
     def get_normalized_term(self, term):
-        return self.df[self.df.iloc[:, 0] == term].iloc[:, 2].item()
+        return self.df[self.df.iloc[:, self.candidate_column] == term].iloc[:, self.normalization_column].item()
 
 
-class DiseaseDict(EntityDictionary):
+class DefaultDiseaseDict(EntityDictionary):
 
     def __init__(self):
-        super().__init__('dictionaries/disease_dict.csv')
+        super().__init__('dictionaries/disease_dict.csv', 0, 2)
 
 
-class DrugDict(EntityDictionary):
+class DefaultDrugDict(EntityDictionary):
 
     def __init__(self):
-        super().__init__('dictionaries/drug_dict.csv')
+        super().__init__('dictionaries/drug_dict.csv', 0, 2)
 
 
 class EntityNormalizer:
@@ -48,4 +57,3 @@ class EntityNormalizer:
             return ('' if pd.isna(ret) else ret), score
         else:
             return '', score
-
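
For reference, a minimal usage sketch of the refactored dictionary API. The class names, constructor arguments, `matching_threshold`, and the `(term, score)` return shape come from the diff above; the `normalize` method name and the file path are assumptions made for illustration, since the `EntityNormalizer` body is not shown in full.

```python
from EntityNormalizer import DefaultDiseaseDict, EntityDictionary, EntityNormalizer

# Built-in disease dictionary: surface forms in column 0, normalized forms (ICD codes) in column 2.
disease_dict = DefaultDiseaseDict()

# A custom dictionary now only needs a CSV path plus the two column indices (path is a placeholder).
custom_dict = EntityDictionary('my_terms.csv', candidate_column=0, normalization_column=1)

normalizer = EntityNormalizer(disease_dict, matching_threshold=50)

# Assumed lookup call: returns the normalized form (or '' when no candidate clears
# the threshold) together with the fuzzy-matching score.
normalized, score = normalizer.normalize('非持続性心室頻拍')
```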
 
NER_medNLP.py CHANGED
@@ -1,46 +1,47 @@
 # %%
 
 import itertools
-from tqdm import tqdm
+
 import numpy as np
-import torch
-from transformers import BertJapaneseTokenizer, BertForTokenClassification
 import pytorch_lightning as pl
+import torch
+from transformers import BertForTokenClassification, BertJapaneseTokenizer
 
-# from torch.utils.data import DataLoader
-# import from_XML_to_json as XtC
-# import random
-# import json
-# import unicodedata
-# import pandas as pd
 
 # %%
 # 8-16
 # PyTorch Lightningのモデル
 class BertForTokenClassification_pl(pl.LightningModule):
 
-    def __init__(self, model_name, num_labels, lr):
+    def __init__(self, num_labels, model='sociocom/MedNERN-CR-JA', lr=1e-4):
         super().__init__()
+        self.lr = lr
         self.save_hyperparameters()
         self.bert_tc = BertForTokenClassification.from_pretrained(
-            model_name,
-            num_labels=num_labels
+            model,
+            num_labels=num_labels,
+            ignore_mismatched_sizes=True
         )
 
+    @classmethod
+    def from_pretrained_bin(cls, model_path, num_labels):
+        model = cls(num_labels)
+        model.load_state_dict(torch.load(model_path))
+        return model
+
     def training_step(self, batch, batch_idx):
         output = self.bert_tc(**batch)
         loss = output.loss
         self.log('train_loss', loss)
         return loss
 
     def validation_step(self, batch, batch_idx):
         output = self.bert_tc(**batch)
         val_loss = output.loss
         self.log('val_loss', val_loss)
 
     def configure_optimizers(self):
-        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
+        return torch.optim.Adam(self.parameters(), lr=self.lr)
 
 
 # %%
@@ -58,58 +59,58 @@ class NER_tokenizer_BIO(BertJapaneseTokenizer):
         符号化とラベル列の作成を行う。
         """
         # 固有表現の前後でtextを分割し、それぞれのラベルをつけておく。
         splitted = []  # 分割後の文字列を追加していく
         position = 0
 
         for entity in entities:
             start = entity['span'][0]
             end = entity['span'][1]
             label = entity['type_id']
-            splitted.append({'text':text[position:start], 'label':0})
-            splitted.append({'text':text[start:end], 'label':label})
+            splitted.append({'text': text[position:start], 'label': 0})
+            splitted.append({'text': text[start:end], 'label': label})
             position = end
-        splitted.append({'text': text[position:], 'label':0})
-        splitted = [ s for s in splitted if s['text'] ]
+        splitted.append({'text': text[position:], 'label': 0})
+        splitted = [s for s in splitted if s['text']]
 
         # 分割されたそれぞれの文章をトークン化し、ラベルをつける。
         tokens = []  # トークンを追加していく
         labels = []  # ラベルを追加していく
         for s in splitted:
             tokens_splitted = self.tokenize(s['text'])
             label = s['label']
             if label > 0:  # 固有表現
                 # まずトークン全てにI-タグを付与
                 # 番号順O-tag:0, B-tag:1~タグの数,I-tag:タグの数〜
-                labels_splitted = \
-                    [ label + self.num_entity_type ] * len(tokens_splitted)
+                labels_splitted = \
+                    [label + self.num_entity_type] * len(tokens_splitted)
                 # 先頭のトークンをB-タグにする
                 labels_splitted[0] = label
             else:  # それ以外
                 labels_splitted = [0] * len(tokens_splitted)
 
             tokens.extend(tokens_splitted)
             labels.extend(labels_splitted)
 
         # 符号化を行いBERTに入力できる形式にする。
         input_ids = self.convert_tokens_to_ids(tokens)
         encoding = self.prepare_for_model(
             input_ids,
             max_length=max_length,
             padding='max_length',
             truncation=True
         )
 
         # ラベルに特殊トークンを追加
         # max_lengthで切り取って,その前後に[CLS]と[SEP]を追加するためのラベルを入れる
-        labels = [0] + labels[:max_length-2] + [0]
+        labels = [0] + labels[:max_length - 2] + [0]
         # max_lengthに満たない場合は,満たない分を後ろ側に追加する
-        labels = labels + [0]*( max_length - len(labels) )
+        labels = labels + [0] * (max_length - len(labels))
         encoding['labels'] = labels
 
         return encoding
 
     def encode_plus_untagged(
             self, text, max_length=None, return_tensors=None
     ):
         """
         文章をトークン化し、それぞれのトークンの文章中の位置も特定しておく。
@@ -117,50 +118,50 @@ class NER_tokenizer_BIO(BertJapaneseTokenizer):
         """
         # 文章のトークン化を行い、
         # それぞれのトークンと文章中の文字列を対応づける。
         tokens = []  # トークンを追加していく。
         tokens_original = []  # トークンに対応する文章中の文字列を追加していく。
         words = self.word_tokenizer.tokenize(text)  # MeCabで単語に分割
         for word in words:
             # 単語をサブワードに分割
             tokens_word = self.subword_tokenizer.tokenize(word)
             tokens.extend(tokens_word)
             if tokens_word[0] == '[UNK]':  # 未知語への対応
                 tokens_original.append(word)
             else:
                 tokens_original.extend([
-                    token.replace('##','') for token in tokens_word
+                    token.replace('##', '') for token in tokens_word
                 ])
 
         # 各トークンの文章中での位置を調べる。(空白の位置を考慮する)
         position = 0
         spans = []  # トークンの位置を追加していく。
         for token in tokens_original:
             l = len(token)
             while 1:
-                if token != text[position:position+l]:
+                if token != text[position:position + l]:
                     position += 1
                 else:
-                    spans.append([position, position+l])
+                    spans.append([position, position + l])
                     position += l
                     break
 
         # 符号化を行いBERTに入力できる形式にする。
         input_ids = self.convert_tokens_to_ids(tokens)
         encoding = self.prepare_for_model(
             input_ids,
             max_length=max_length,
             padding='max_length' if max_length else False,
             truncation=True if max_length else False
         )
         sequence_length = len(encoding['input_ids'])
         # 特殊トークン[CLS]に対するダミーのspanを追加。
-        spans = [[-1, -1]] + spans[:sequence_length-2]
+        spans = [[-1, -1]] + spans[:sequence_length - 2]
         # 特殊トークン[SEP]、[PAD]に対するダミーのspanを追加。
-        spans = spans + [[-1, -1]] * ( sequence_length - len(spans) )
+        spans = spans + [[-1, -1]] * (sequence_length - len(spans))
 
         # 必要に応じてtorch.Tensorにする。
         if return_tensors == 'pt':
-            encoding = { k: torch.tensor([v]) for k, v in encoding.items() }
+            encoding = {k: torch.tensor([v]) for k, v in encoding.items()}
 
         return encoding, spans
 
@@ -169,28 +170,26 @@ class NER_tokenizer_BIO(BertJapaneseTokenizer):
         """
         Viterbiアルゴリズムで最適解を求める。
         """
-        m = 2*num_entity_type + 1
+        m = 2 * num_entity_type + 1
         penalty_matrix = np.zeros([m, m])
         for i in range(m):
-            for j in range(1+num_entity_type, m):
-                if not ( (i == j) or (i+num_entity_type == j) ):
-                    penalty_matrix[i,j] = penalty
-        path = [ [i] for i in range(m) ]
-        scores_path = scores_bert[0] - penalty_matrix[0,:]
+            for j in range(1 + num_entity_type, m):
+                if not ((i == j) or (i + num_entity_type == j)):
+                    penalty_matrix[i, j] = penalty
+        path = [[i] for i in range(m)]
+        scores_path = scores_bert[0] - penalty_matrix[0, :]
         scores_bert = scores_bert[1:]
 
-
-
         for scores in scores_bert:
-            assert len(scores) == 2*num_entity_type + 1
-            score_matrix = np.array(scores_path).reshape(-1,1) \
-                + np.array(scores).reshape(1,-1) \
-                - penalty_matrix
+            assert len(scores) == 2 * num_entity_type + 1
+            score_matrix = np.array(scores_path).reshape(-1, 1) \
+                + np.array(scores).reshape(1, -1) \
+                - penalty_matrix
             scores_path = score_matrix.max(axis=0)
             argmax = score_matrix.argmax(axis=0)
             path_new = []
             for i, idx in enumerate(argmax):
-                path_new.append( path[idx] + [i] )
+                path_new.append(path[idx] + [i])
             path = path_new
 
         labels_optimal = path[np.argmax(scores_path)]
@@ -203,26 +202,26 @@ class NER_tokenizer_BIO(BertJapaneseTokenizer):
         """
         assert len(spans) == len(scores)
         num_entity_type = self.num_entity_type
 
         # 特殊トークンに対応する部分を取り除く
-        scores = [score for score, span in zip(scores, spans) if span[0]!=-1]
-        spans = [span for span in spans if span[0]!=-1]
+        scores = [score for score, span in zip(scores, spans) if span[0] != -1]
+        spans = [span for span in spans if span[0] != -1]
 
         # Viterbiアルゴリズムでラベルの予測値を決める。
         labels = self.Viterbi(scores, num_entity_type)
 
         # 同じラベルが連続するトークンをまとめて、固有表現を抽出する。
         entities = []
         for label, group \
                 in itertools.groupby(enumerate(labels), key=lambda x: x[1]):
 
             group = list(group)
             start = spans[group[0][0]][0]
             end = spans[group[-1][0]][1]
 
             if label != 0:  # 固有表現であれば
                 if 1 <= label <= num_entity_type:
                     # ラベルが`B-`ならば、新しいentityを追加
                     entity = {
                         "name": text[start:end],
                         "span": [start, end],
@@ -231,8 +230,7 @@ class NER_tokenizer_BIO(BertJapaneseTokenizer):
                     entities.append(entity)
                 else:
                     # ラベルが`I-`ならば、直近のentityを更新
                     entity['span'][1] = end
                     entity['name'] = text[entity['span'][0]:entity['span'][1]]
-
-        return entities
 
+        return entities
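
A small sketch of the new loading path introduced by `from_pretrained_bin`, mirroring the call made in `predict.py` below; the file paths are placeholders and the `.eval()` call is added here only to mark inference mode.

```python
import pickle

import NER_medNLP as ner

with open('id_to_tags.pkl', 'rb') as tf:
    id_to_tags = pickle.load(tf)

num_entity_type = len(id_to_tags)

# Build the LightningModule around the base BERT model, then load the fine-tuned
# weights from a local checkpoint, exactly as predict.py does.
model = ner.BertForTokenClassification_pl.from_pretrained_bin(
    model_path='pytorch_model.bin',
    num_labels=2 * num_entity_type + 1,  # O tag + B- tags + I- tags
)
bert_tc = model.bert_tc.eval()  # underlying BertForTokenClassification, ready for inference
```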
README.md CHANGED
@@ -12,26 +12,65 @@ metrics:
 - NTCIR-16 Real-MedNLP subtask 1
 ---
 
-
 This is a model for named entity recognition of Japanese medical documents.
 
-### How to use
+# Introduction
+
+This repository contains the base model and a support prediction script that runs the model and produces XML-tagged text output.
+
+The original model was trained on the [MedTxt-CR-JA](https://sociocom.naist.jp/medtxt/cr) dataset, so the provided prediction code outputs XML tags in the same format.
+
+The script also provides normalization of the output entities, which is not embedded in the model.
+
+If you want to re-train or update the model, we provide additional support scripts in [this GitHub repository](https://github.com/sociocom/MedNERN-CR-JA).
+Issues and suggestions can also be submitted there.
+
+### A note about loading the model using standard HuggingFace methods
+This model should also be loadable using standard HuggingFace `from_pretrained` methods. However, the model by itself only outputs labels in the format "LABEL_0", "LABEL_1", etc.
 
-Download the following five files and put into the same folder.
+The conversion of model outputs to the actual labels ("<m-key>", "<m-val>", "<timex3>", etc.) is not yet embedded into the model, so the extra `id_to_tags.pkl` file is necessary
+to make the conversion. It contains a mapping from the model output ids to the actual labels.
 
-- id_to_tags.pkl
-- key_attr.pkl
-- NER_medNLP.py
-- predict.py
-- text.txt (This is an input file which should be predicted, which could be changed.)
+This conversion can be done manually if needed, but the `predict.py` script already does that.
 
-You can use this model by running `predict.py`.
+We are currently working to bring the model closer to HuggingFace's standards.
+
+## How to use
+
+Clone the repository and install the requirements:
+
+```
+pip install -r requirements.txt
+```
+
+The code has been developed and tested with Python 3.9 on macOS 14.1 (M1 MacBook Pro).
+
+### Prediction
+
+The prediction script will output the results in the same XML format as the input file. It can be run with the following
+command:
 
 ```
 python3 predict.py
 ```
 
-#### Entity normalization
+The default parameters will take the model located in `pytorch_model.bin` and the input file `text.txt`.
+The resulting predictions will be output to the screen.
+
+To select a different model or input file, use the `-m` and `-i` parameters, respectively:
+
+```
+python3 predict.py -m <model_path> -i <your_input_file>.txt
+```
+
+The input file can be a single text file or a folder containing multiple `.txt` files, for batch processing. For example:
+
+```
+python3 predict.py -m <model_path> -i <your_input_folder>
+```
+
+### Entity normalization
 
 This model supports entity normalization via dictionary matching. The dictionary is a list of medical terms or
 drugs and their standard forms.
@@ -39,23 +78,42 @@ drugs and their standard forms.
 Two different dictionaries are used for drug and disease normalization, stored in the `dictionaries` folder as
 `drug_dict.csv` and `disease_dict.csv`, respectively.
 
 To enable normalization you can add the `--normalize` flag to the `predict.py` command.
 
 ```
-python3 predict.py --normalize
+python3 predict.py -m <model_path> --normalize
 ```
 
 Normalization will add the `norm` attribute to the output XML tags. This attribute can be empty if a normalized form of
 the term is not found.
 
-The provided disease normalization dictionary (`dictionaties/disease_dict.csv`) is based on the [Manbyo Dictionary](https://sociocom.naist.jp/manbyo-dic-en/) and provides normalization to the standard ICD code for the diseases.
+The provided disease normalization dictionary (`dictionaries/disease_dict.csv`) is based on
+the [Manbyo Dictionary](https://sociocom.naist.jp/manbyo-dic-en/) and provides normalization to the standard ICD code
+for the diseases.
 
-The default drug dictionary (`dictionaties/drug_dict.csv`) is based on the [Hyakuyaku Dictionary](https://sociocom.naist.jp/hyakuyaku-dic-en/).
+The default drug dictionary (`dictionaries/drug_dict.csv`) is based on
+the [Hyakuyaku Dictionary](https://sociocom.naist.jp/hyakuyaku-dic-en/).
 
 The dictionary is a CSV file with three columns: the first column is the surface form term and the third column contains
 its standard form. The second column is not used.
 
-User can freely change the dictionary to fit their needs, as long as the format and filename are kept.
+### Replacing the default dictionaries
+
+Users can freely change the dictionaries to fit their needs by passing the path to a custom dictionary file.
+The dictionary file must have at least a column containing the list of surface forms and a column containing the list of
+normalized forms.
+
+The parameters `--drug-dict` and `--disease-dict` can be used to specify the path to the drug and disease dictionaries,
+respectively.
+When doing so, the respective parameters specifying the column indices of the surface forms and normalized forms must also be
+provided.
+You don't need to replace both dictionaries at the same time; you can replace only one of them.
+
+E.g.:
+
+```
+python3 predict.py --normalize --drug-dict dictionaries/drug_dict.csv --drug-candidate-col 0 --drug-normalization-col 2 --disease-dict dictionaries/disease_dict.csv --disease-candidate-col 0 --disease-normalization-col 2
+```
 
 ### Input Example
 
@@ -71,5 +129,17 @@ User can freely change the dictionary to fit their needs, as long as the format
 <timex3 type="med">治療経過中</timex3>に<d certainty="positive" norm="I472">非持続性心室頻拍</d>が認められたため<m-key state="executed" norm="アミオダロン塩酸塩">アミオダロン</m-key>が併用となった。
 ```
 
-### Publication
+## Publication
+
+This model can be cited as:
+
+```
+@misc {social_computing_lab_2023,
+    author = { {Social Computing Lab} },
+    title = { MedNERN-CR-JA (Revision 13dbcb6) },
+    year = 2023,
+    url = { https://huggingface.co/sociocom/MedNERN-CR-JA },
+    doi = { 10.57967/hf/0620 },
+    publisher = { Hugging Face }
+}
+```
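
Related to the README note on standard HuggingFace loading above, a minimal sketch of the manual id-to-tag conversion. It assumes `id_to_tags.pkl` can be indexed by the entity-type id the same way `predict.py` does, and it leaves BIO decoding, Viterbi smoothing and span recovery to the provided scripts; the helper name is illustrative only.

```python
import pickle

with open('id_to_tags.pkl', 'rb') as f:
    id_to_tags = pickle.load(f)

num_entity_type = len(id_to_tags)  # 43 in this revision (see tokenizer_config.json)


def label_to_tag(raw_label: str) -> str:
    """Map a raw model output such as 'LABEL_12' to its tag name."""
    label_id = int(raw_label.split('_')[1])
    if label_id == 0:
        return 'O'  # not an entity
    # Labels 1..N are B- tags and N+1..2N are I- tags for the same N entity types,
    # which is why config.json now lists 2 * 43 + 1 = 87 labels (LABEL_0..LABEL_86).
    type_id = label_id if label_id <= num_entity_type else label_id - num_entity_type
    return id_to_tags[type_id]


print(label_to_tag('LABEL_12'))
```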
config.json CHANGED
@@ -89,7 +89,13 @@
     "77": "LABEL_77",
     "78": "LABEL_78",
     "79": "LABEL_79",
-    "80": "LABEL_80"
+    "80": "LABEL_80",
+    "81": "LABEL_81",
+    "82": "LABEL_82",
+    "83": "LABEL_83",
+    "84": "LABEL_84",
+    "85": "LABEL_85",
+    "86": "LABEL_86"
   },
   "initializer_range": 0.02,
   "intermediate_size": 3072,
@@ -174,6 +180,12 @@
     "LABEL_79": 79,
     "LABEL_8": 8,
     "LABEL_80": 80,
+    "LABEL_81": 81,
+    "LABEL_82": 82,
+    "LABEL_83": 83,
+    "LABEL_84": 84,
+    "LABEL_85": 85,
+    "LABEL_86": 86,
     "LABEL_9": 9
   },
   "layer_norm_eps": 1e-12,
id_to_tags.pkl CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:57e7ea0bc4bdcaf4b19f7eec5c6edf2fce867cc9895cb20079b48881bc32ee5a
-size 620
+oid sha256:26cbbc0594cf7a1c4439a1010c5e2c55c1f0fb0a9664d93248b7b7d7de0cc434
+size 671
model.safetensors DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:802f29afc4eae3cbf49f3957f9b2d27ae247e6e079ad6d947ccf181dff7c754c
-size 440383704
predict.py CHANGED
@@ -1,23 +1,24 @@
 # %%
 import argparse
-
-from tqdm import tqdm
-import unicodedata
-import re
+import os.path
 import pickle
+import unicodedata
+
 import torch
-import NER_medNLP as ner
+from tqdm import tqdm
 
-from EntityNormalizer import EntityNormalizer, DiseaseDict, DrugDict
+import NER_medNLP as ner
+import utils
+from EntityNormalizer import EntityNormalizer, EntityDictionary, DefaultDiseaseDict, DefaultDrugDict
 
-device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+device = torch.device("mps" if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')
 
 # %% global変数として使う
 dict_key = {}
 
 
 # %%
-def to_xml(data):
+def to_xml(data, id_to_tags):
     with open("key_attr.pkl", "rb") as tf:
         key_attr = pickle.load(tf)
 
@@ -27,7 +28,11 @@ def to_xml(data):
         if entities == "":
             return
         span = entities['span']
-        type_id = id_to_tags[entities['type_id']].split('_')
+        try:
+            type_id = id_to_tags[entities['type_id']].split('_')
+        except:
+            print("out of range type_id", entities)
+            continue
         tag = type_id[0]
 
         if not type_id[1] == "":
@@ -49,17 +54,11 @@ def to_xml(data):
 
 
 def predict_entities(modelpath, sentences_list, len_num_entity_type):
-    # model = ner.BertForTokenClassification_pl.load_from_checkpoint(
-    #     checkpoint_path = modelpath + ".ckpt"
-    # )
-    # bert_tc = model.bert_tc.cuda()
-
-    model = ner.BertForTokenClassification_pl(modelpath, num_labels=81, lr=1e-5)
+    model = ner.BertForTokenClassification_pl.from_pretrained_bin(model_path=modelpath, num_labels=2 * len_num_entity_type + 1)
     bert_tc = model.bert_tc.to(device)
 
-    MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'
     tokenizer = ner.NER_tokenizer_BIO.from_pretrained(
-        MODEL_NAME,
+        'cl-tohoku/bert-base-japanese-whole-word-masking',
        num_entity_type=len_num_entity_type  # Entityの数を変え忘れないように!
     )
 
@@ -69,7 +68,7 @@ def predict_entities(modelpath, sentences_list, len_num_entity_type):
     text_entities_set = []
     for dataset in sentences_list:
         text_entities = []
-        for sample in tqdm(dataset):
+        for sample in tqdm(dataset, desc='Predict'):
             text = sample
             encoding, spans = tokenizer.encode_plus_untagged(
                 text, return_tensors='pt'
@@ -93,12 +92,12 @@ def predict_entities(modelpath, sentences_list, len_num_entity_type):
     return text_entities_set
 
 
-def combine_sentences(text_entities_set, insert: str):
+def combine_sentences(text_entities_set, id_to_tags, insert: str):
     documents = []
     for text_entities in tqdm(text_entities_set):
         document = []
         for t in text_entities:
-            document.append(to_xml(t))
+            document.append(to_xml(t, id_to_tags))
         documents.append('\n'.join(document))
     return documents
 
@@ -115,9 +114,19 @@ def value_to_key(value, key_attr):  # attributeから属性名を取得
 
 
 # %%
-def normalize_entities(text_entities_set):
-    disease_normalizer = EntityNormalizer(DiseaseDict(), matching_threshold=50)
-    drug_normalizer = EntityNormalizer(DrugDict(), matching_threshold=50)
+def normalize_entities(text_entities_set, id_to_tags, disease_dict=None, disease_candidate_col=None, disease_normalization_col=None, disease_matching_threshold=None, drug_dict=None,
+                       drug_candidate_col=None, drug_normalization_col=None, drug_matching_threshold=None):
+    if disease_dict:
+        disease_dict = EntityDictionary(disease_dict, disease_candidate_col, disease_normalization_col)
+    else:
+        disease_dict = DefaultDiseaseDict()
+    disease_normalizer = EntityNormalizer(disease_dict, matching_threshold=disease_matching_threshold)
+
+    if drug_dict:
+        drug_dict = EntityDictionary(drug_dict, drug_candidate_col, drug_normalization_col)
+    else:
+        drug_dict = DefaultDrugDict()
+    drug_normalizer = EntityNormalizer(drug_dict, matching_threshold=drug_matching_threshold)
 
     for entry in text_entities_set:
         for text_entities in entry:
@@ -136,31 +145,59 @@ def normalize_entities(text_entities_set):
                     entity['norm'] = str(normalization)
 
 
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Predict entities from text')
-    parser.add_argument('--normalize', action=argparse.BooleanOptionalAction, help='Enable entity normalization')
-    args = parser.parse_args()
-
+def run(model, input, output=None, normalize=False, **kwargs):
     with open("id_to_tags.pkl", "rb") as tf:
         id_to_tags = pickle.load(tf)
-    with open("key_attr.pkl", "rb") as tf:
-        key_attr = pickle.load(tf)
-    with open('text.txt') as f:
-        articles_raw = f.read()
 
-    article_norm = unicodedata.normalize('NFKC', articles_raw)
+    if (os.path.isdir(input)):
+        files = [f for f in os.listdir(input) if os.path.isfile(os.path.join(input, f))]
+    else:
+        files = [input]
+
+    for file in tqdm(files, desc="Input file"):
+        with open(file) as f:
+            articles_raw = f.read()
+
+        article_norm = unicodedata.normalize('NFKC', articles_raw)
+
+        sentences_raw = utils.split_sentences(articles_raw)
+        sentences_norm = utils.split_sentences(article_norm)
 
-    sentences_raw = [s for s in re.split(r'\n', articles_raw) if s != '']
-    sentences_norm = [s for s in re.split(r'\n', article_norm) if s != '']
+        text_entities_set = predict_entities(model, [sentences_norm], len(id_to_tags))
 
-    text_entities_set = predict_entities("sociocom/RealMedNLP_CR_JA", [sentences_norm], len(id_to_tags))
+        for i, texts_ent in enumerate(text_entities_set[0]):
+            texts_ent['text'] = sentences_raw[i]
 
-    for i, texts_ent in enumerate(text_entities_set[0]):
-        texts_ent['text'] = sentences_raw[i]
+        if normalize:
+            normalize_entities(text_entities_set, id_to_tags, **kwargs)
 
-    if args.normalize:
-        normalize_entities(text_entities_set)
+        documents = combine_sentences(text_entities_set, id_to_tags, '\n')
 
-    documents = combine_sentences(text_entities_set, '\n')
+        print(documents[0])
 
-    print(documents[0])
+        if output:
+            with open(file.replace(input, output), 'w') as f:
+                f.write(documents[0])
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Predict entities from text')
+    parser.add_argument('-m', '--model', type=str, default='pytorch_model.bin', help='Path to model checkpoint')
+    parser.add_argument('-i', '--input', type=str, default='text.txt', help='Path to text file or directory')
+    parser.add_argument('-o', '--output', type=str, default=None, help='Path to output file or directory')
+    parser.add_argument('-n', '--normalize', action=argparse.BooleanOptionalAction, help='Enable entity normalization', default=False)
+
+    # Dictionary override arguments
+    parser.add_argument("--drug-dict", help="File path for overriding the default drug dictionary")
+    parser.add_argument("--drug-candidate-col", type=int, help="Column index for drug candidates in the CSV file (required if --drug-dict is specified)")
+    parser.add_argument("--drug-normalization-col", type=int, help="Column index for drug normalization in the CSV file (required if --drug-dict is specified)")
+    parser.add_argument('--drug-matching-threshold', type=int, default=50, help='Matching threshold for drug dictionary')
+
+    parser.add_argument("--disease-dict", help="File path for overriding the default disease dictionary")
+    parser.add_argument("--disease-candidate-col", type=int, help="Column index for disease candidates in the CSV file (required if --disease-dict is specified)")
+    parser.add_argument("--disease-normalization-col", type=int, help="Column index for disease normalization in the CSV file (required if --disease-dict is specified)")
+    parser.add_argument('--disease-matching-threshold', type=int, default=50, help='Matching threshold for disease dictionary')
+    args = parser.parse_args()
+
+    argument_dict = vars(args)
+    run(**argument_dict)
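
Since `predict.py` now exposes a `run()` entry point, it can also be driven from Python instead of the command line. This is a sketch based on the signature above (argparse turns the dashed option names into the underscore keyword names used here); paths and values are placeholders.

```python
import predict

predict.run(
    model='pytorch_model.bin',
    input='text.txt',               # a single .txt file or a directory of .txt files
    output=None,                    # None prints the tagged text instead of writing files
    normalize=True,
    disease_matching_threshold=50,  # same defaults as the CLI
    drug_matching_threshold=50,
    disease_dict=None,              # None keeps the bundled dictionaries
    drug_dict=None,
)
```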
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ccce71084b8f6e81415e8f8e07cf27f59087aa2fda02c296959322ef8acb8a6a
-size 440439601
+oid sha256:a92d7a2fc876d379593b3425f14aa001f439e2c4a3ca882768fd8a7a35be363d
+size 440466875
requirements.txt CHANGED
@@ -1,45 +1,42 @@
-aiohttp==3.8.4
+aiohttp==3.8.6
 aiosignal==1.3.1
-async-timeout==4.0.2
-attrs==22.2.0
-certifi==2022.12.7
-charset-normalizer==3.1.0
-et-xmlfile==1.1.0
-filelock==3.11.0
-frozenlist==1.3.3
-fsspec==2023.4.0
-fugashi==1.2.1
-huggingface-hub==0.13.4
+async-timeout==4.0.3
+attrs==23.1.0
+certifi==2023.7.22
+charset-normalizer==3.3.2
+filelock==3.13.1
+frozenlist==1.4.0
+fsspec==2023.10.0
+fugashi==1.3.0
+huggingface-hub==0.17.3
 idna==3.4
 ipadic==1.0.0
 Jinja2==3.1.2
-Levenshtein==0.20.9
-lightning-utilities==0.8.0
-MarkupSafe==2.1.2
+lightning-utilities==0.9.0
+MarkupSafe==2.1.3
 mojimoji==0.0.12
 mpmath==1.3.0
 multidict==6.0.4
-networkx==3.1
-numpy==1.24.2
-openpyxl==3.1.2
-packaging==23.0
-pandas==2.0.0
+networkx==3.2.1
+numpy==1.26.2
+packaging==23.2
+pandas==2.1.3
 python-dateutil==2.8.2
-pytorch-lightning==2.0.1.post0
-pytz==2023.3
-PyYAML==6.0
-rapidfuzz==2.15.1
-regex==2023.3.23
-requests==2.28.2
+pytorch-lightning==2.1.1
+pytz==2023.3.post1
+PyYAML==6.0.1
+rapidfuzz==3.5.2
+regex==2023.10.3
+requests==2.31.0
+safetensors==0.4.0
 six==1.16.0
-soupsieve==2.4
-sympy==1.11.1
+sympy==1.12
 tokenizers==0.13.3
-torch==2.0.0
-torchmetrics==0.11.4
-tqdm==4.65.0
+torch==2.1.0
+torchmetrics==1.2.0
+tqdm==4.66.1
 transformers==4.27.4
-typing_extensions==4.5.0
+typing_extensions==4.8.0
 tzdata==2023.3
-urllib3==1.26.15
-yarl==1.8.2
+urllib3==2.1.0
+yarl==1.9.2
tokenizer_config.json CHANGED
@@ -20,5 +20,5 @@
     "NER_medNLP.NER_tokenizer_BIO",
     null
   ],
-  "num_entity_type": "40"
+  "num_entity_type": "43"
 }
utils.py ADDED
@@ -0,0 +1,15 @@
+import re
+
+
+def split_sentences(text):
+    """Given a string, split it into sentences.
+
+    :param text: The string to be processed.
+    :return: The list of split sentences.
+    """
+    processed_text = re.split(
+        "(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=[.?!])\s\n*|(?<=[^A-zA-z0-90-9 ].)(?<=[。..??!!])(?![\.」])\n*", text)
+    # processed_text = re.split("(?<=[。??!!])")  # In case only a simple regex is necessary
+    processed_text = [x.strip() for x in processed_text]
+    processed_text = [x for x in processed_text if x != '']
+    return processed_text
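
A quick, hypothetical usage example of the new helper; it splits on Japanese and Western sentence-final punctuation as well as line breaks, and the sample text below is only illustrative.

```python
import utils

text = '治療経過中に非持続性心室頻拍が認められた。アミオダロンが併用となった。\n経過は良好である。'
for sentence in utils.split_sentences(text):
    print(sentence)  # one sentence per line, stripped of surrounding whitespace
```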