import itertools

import numpy as np
import pytorch_lightning as pl
import torch
from transformers import BertForTokenClassification, BertJapaneseTokenizer


class BertForTokenClassification_pl(pl.LightningModule):

    def __init__(self, num_labels, model='sociocom/MedNERN-CR-JA', lr=1e-4):
        super().__init__()
        self.lr = lr
        self.save_hyperparameters()
        # ignore_mismatched_sizes=True lets the checkpoint load even when its
        # classification head was trained with a different number of labels;
        # the mismatched weights are then freshly initialized.
        self.bert_tc = BertForTokenClassification.from_pretrained(
            model,
            num_labels=num_labels,
            ignore_mismatched_sizes=True
        )

    @classmethod
    def from_pretrained_bin(cls, model_path, num_labels):
        # Load fine-tuned weights saved as a raw state_dict (e.g. a .bin file).
        model = cls(num_labels)
        model.load_state_dict(torch.load(model_path))
        return model

    def training_step(self, batch, batch_idx):
        output = self.bert_tc(**batch)
        loss = output.loss
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        output = self.bert_tc(**batch)
        val_loss = output.loss
        self.log('val_loss', val_loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

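
# A minimal fine-tuning sketch (hypothetical: `train_dataloader` and
# `val_dataloader` are assumed to yield batches built with
# NER_tokenizer_BIO.encode_plus_tagged and collated into tensors):
#
#   model = BertForTokenClassification_pl(num_labels=2 * 8 + 1, lr=1e-5)
#   trainer = pl.Trainer(max_epochs=5)
#   trainer.fit(model, train_dataloader, val_dataloader)
#   trainer.save_checkpoint('model.ckpt')

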
class NER_tokenizer_BIO(BertJapaneseTokenizer):

    def __init__(self, *args, **kwargs):
        self.num_entity_type = kwargs.pop('num_entity_type')
        super().__init__(*args, **kwargs)

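    # Label ids used throughout this tokenizer (BIO scheme):
    #   0                                         -> O (outside any entity)
    #   1 .. num_entity_type                      -> B (first token of an entity)
    #   num_entity_type + 1 .. 2*num_entity_type  -> I (continuation tokens)
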
    def encode_plus_tagged(self, text, entities, max_length):
        """
        Given a sentence and the named entities it contains,
        encode the text and build the matching label sequence.
        """
        # Split the text into entity and non-entity segments.
        splitted = []
        position = 0
        for entity in entities:
            start = entity['span'][0]
            end = entity['span'][1]
            label = entity['type_id']
            splitted.append({'text': text[position:start], 'label': 0})
            splitted.append({'text': text[start:end], 'label': label})
            position = end
        splitted.append({'text': text[position:], 'label': 0})
        splitted = [s for s in splitted if s['text']]

        # Tokenize each segment and assign BIO labels: the first token of an
        # entity keeps the B-tag `label`, the remaining tokens get the I-tag
        # `label + num_entity_type`.
        tokens = []
        labels = []
        for s in splitted:
            tokens_splitted = self.tokenize(s['text'])
            label = s['label']
            if label > 0:
                labels_splitted = \
                    [label + self.num_entity_type] * len(tokens_splitted)
                labels_splitted[0] = label
            else:
                labels_splitted = [0] * len(tokens_splitted)
            tokens.extend(tokens_splitted)
            labels.extend(labels_splitted)

        # Encode the tokens.
        input_ids = self.convert_tokens_to_ids(tokens)
        encoding = self.prepare_for_model(
            input_ids,
            max_length=max_length,
            padding='max_length',
            truncation=True
        )

        # Label the special tokens ([CLS], [SEP]) and the padding with 0.
        labels = [0] + labels[:max_length - 2] + [0]
        labels = labels + [0] * (max_length - len(labels))
        encoding['labels'] = labels

        return encoding

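    # A minimal usage sketch (hypothetical values; the exact token and label
    # sequence depends on the vocabulary of the loaded tokenizer):
    #
    #   tokenizer = NER_tokenizer_BIO.from_pretrained(
    #       'sociocom/MedNERN-CR-JA', num_entity_type=8
    #   )
    #   enc = tokenizer.encode_plus_tagged(
    #       '腹痛があります。',
    #       entities=[{'span': [0, 2], 'type_id': 1}],
    #       max_length=32
    #   )
    #   # enc['labels'] marks the tokens covering '腹痛' with the B-tag 1
    #   # (continuation subwords get the I-tag 9), everything else with 0.
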
    def encode_plus_untagged(
        self, text, max_length=None, return_tensors=None
    ):
        """
        Tokenize the text and record the position of each token in it.
        Identical to encode_plus_untagged of the IO-scheme tokenizer.
        """
        # Tokenize the text while keeping each token's original surface form.
        tokens = []
        tokens_original = []
        words = self.word_tokenizer.tokenize(text)
        for word in words:
            tokens_word = self.subword_tokenizer.tokenize(word)
            tokens.extend(tokens_word)
            if tokens_word[0] == '[UNK]':
                tokens_original.append(word)
            else:
                tokens_original.extend([
                    token.replace('##', '') for token in tokens_word
                ])

        # Find the character span of each token in the original text.
        position = 0
        spans = []
        for token in tokens_original:
            l = len(token)
            while True:
                if token != text[position:position + l]:
                    position += 1
                else:
                    spans.append([position, position + l])
                    position += l
                    break

        # Encode the tokens.
        input_ids = self.convert_tokens_to_ids(tokens)
        encoding = self.prepare_for_model(
            input_ids,
            max_length=max_length,
            padding='max_length' if max_length else False,
            truncation=True if max_length else False
        )
        sequence_length = len(encoding['input_ids'])
        # Special tokens and padding get the dummy span [-1, -1].
        spans = [[-1, -1]] + spans[:sequence_length - 2]
        spans = spans + [[-1, -1]] * (sequence_length - len(spans))

        # Convert to tensors if requested.
        if return_tensors == 'pt':
            encoding = {k: torch.tensor([v]) for k, v in encoding.items()}

        return encoding, spans

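    # encode_plus_untagged returns a pair (encoding, spans): `encoding` is the
    # usual tokenizer output (optionally as PyTorch tensors) and `spans[i]` is
    # the [start, end) character range of the i-th token in `text`, with
    # [-1, -1] for special tokens and padding.
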
    @staticmethod
    def Viterbi(scores_bert, num_entity_type, penalty=10000):
        """
        Find the optimal label sequence with the Viterbi algorithm.
        """
        m = 2 * num_entity_type + 1
        # Penalize transitions into an I-tag that do not come from the
        # matching B-tag or from the same I-tag (invalid BIO sequences).
        penalty_matrix = np.zeros([m, m])
        for i in range(m):
            for j in range(1 + num_entity_type, m):
                if not ((i == j) or (i + num_entity_type == j)):
                    penalty_matrix[i, j] = penalty
        path = [[i] for i in range(m)]
        scores_path = scores_bert[0] - penalty_matrix[0, :]
        scores_bert = scores_bert[1:]

        for scores in scores_bert:
            assert len(scores) == 2 * num_entity_type + 1
            score_matrix = np.array(scores_path).reshape(-1, 1) \
                + np.array(scores).reshape(1, -1) \
                - penalty_matrix
            scores_path = score_matrix.max(axis=0)
            argmax = score_matrix.argmax(axis=0)
            path_new = []
            for i, idx in enumerate(argmax):
                path_new.append(path[idx] + [i])
            path = path_new

        labels_optimal = path[np.argmax(scores_path)]
        return labels_optimal

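    # Illustration: with num_entity_type=1 the labels are {0: O, 1: B, 2: I}.
    # Viterbi([[0.1, 0.2, 5.0], [0.1, 0.2, 5.0]], 1) does not return [2, 2],
    # because starting an entity with an I-tag is penalized; the valid B-I
    # sequence [1, 2] wins instead.
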
    def convert_bert_output_to_entities(self, text, scores, spans):
        """
        Extract named entities from the text, the classification scores and
        the span of each token.
        The scores are a 2D array of shape (sequence length, number of labels).
        """
        assert len(spans) == len(scores)
        num_entity_type = self.num_entity_type

        # Drop special tokens and padding (span [-1, -1]).
        scores = [score for score, span in zip(scores, spans) if span[0] != -1]
        spans = [span for span in spans if span[0] != -1]

        # Find the optimal label sequence with the Viterbi algorithm.
        labels = self.Viterbi(scores, num_entity_type)

        # Group consecutive tokens with the same label into entities.
        entities = []
        for label, group \
                in itertools.groupby(enumerate(labels), key=lambda x: x[1]):

            group = list(group)
            start = spans[group[0][0]][0]
            end = spans[group[-1][0]][1]

            if label != 0:
                if 1 <= label <= num_entity_type:
                    # B-tag: start a new entity.
                    entity = {
                        "name": text[start:end],
                        "span": [start, end],
                        "type_id": label
                    }
                    entities.append(entity)
                else:
                    # I-tag: extend the entity opened by the preceding B-tag.
                    entity['span'][1] = end
                    entity['name'] = text[entity['span'][0]:entity['span'][1]]

        return entities
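

# A minimal, hypothetical inference sketch. The checkpoint name, the number of
# entity types and the example text are assumptions for illustration; they
# must match whatever the model was actually fine-tuned with.
if __name__ == '__main__':
    MODEL_NAME = 'sociocom/MedNERN-CR-JA'
    NUM_ENTITY_TYPE = 8                    # assumed number of entity types
    NUM_LABELS = 2 * NUM_ENTITY_TYPE + 1   # O + B-tags + I-tags

    tokenizer = NER_tokenizer_BIO.from_pretrained(
        MODEL_NAME, num_entity_type=NUM_ENTITY_TYPE
    )
    model = BertForTokenClassification_pl(NUM_LABELS, model=MODEL_NAME)
    model.eval()

    text = '腹痛と嘔吐が認められた。'
    encoding, spans = tokenizer.encode_plus_untagged(
        text, return_tensors='pt'
    )
    with torch.no_grad():
        output = model.bert_tc(**encoding)
    scores = output.logits[0].cpu().numpy().tolist()
    entities = tokenizer.convert_bert_output_to_entities(text, scores, spans)
    print(entities)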