def find_broken_examples(data):
    """Return ids of the form '<split>-<index>' for examples that contain non-printable tokens."""
    broken = []
    for split in data.keys():
        for i, tokens in enumerate(data[split]['tokens']):
            if any(not token.isprintable() for token in tokens):
                broken.append(f'{split}-{i}')
    return broken


def update_data(examples, split, broken_ids):
    """Drop the examples whose '<split>-<id>' is listed in broken_ids and re-number the rest.

    Assumes it is applied to a whole split at once (e.g. a batched map with
    batch_size=None), so that each id_ is also the example's position in the batch.
    """
    new_tags = []
    new_tokens = []
    for id_ in examples['id']:
        sent_id = split + '-' + id_
        if sent_id in broken_ids:
            continue

        new_tokens.append(examples['tokens'][int(id_)])
        new_tags.append(examples['ner_tags'][int(id_)])

        # Sanity checks: one tag sequence per example, one tag per token.
        assert len(new_tokens) == len(new_tags)
        assert len(new_tokens[-1]) == len(new_tags[-1])

    return {
        'id': [str(i) for i in range(len(new_tokens))],
        'tokens': new_tokens,
        'ner_tags': new_tags
    }


def align_labels_with_tokens(labels, word_ids):
    """Expand word-level labels to token level: the first sub-token of each word
    keeps the word's label, everything else (special tokens, padding, remaining
    sub-tokens) is masked with -100."""
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # Start of a new word (or a special token / padding position)
            current_word = word_id
            label = -100 if word_id is None else labels[word_id]
            new_labels.append(label)
        elif word_id is None:
            # Special token or padding
            new_labels.append(-100)
        else:
            # Same word as the previous token: mask it so only the first
            # sub-token of each word carries a label.
            # (Alternative: repeat labels[word_id], turning B-XXX into I-XXX.)
            new_labels.append(-100)

    return new_labels
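

# A quick illustration of the alignment (hypothetical labels / word_ids, shown
# only to document the behaviour; not taken from the original data):
#   labels   = [3, 0, 7]                       # one label per word
#   word_ids = [None, 0, 0, 1, 2, 2, 2, None]  # as returned by tokenizer.word_ids()
#   align_labels_with_tokens(labels, word_ids)
#   -> [-100, 3, -100, 0, 7, -100, -100, -100]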


def tokenize_and_align_labels(examples, tokenizer):
    """Tokenize pre-split words and align the word-level ner_tags to the sub-tokens."""
    tokenized_inputs = tokenizer(
        examples["tokens"], truncation=True, is_split_into_words=True, padding='max_length'
    )
    all_labels = examples["ner_tags"]
    new_labels = []
    word_ids = []
    for i, labels in enumerate(all_labels):
        word_ids.append(tokenized_inputs.word_ids(i))
        new_labels.append(align_labels_with_tokens(labels, word_ids[i]))

    tokenized_inputs["labels"] = new_labels
    # Also keep the raw word ids alongside the aligned labels.
    tokenized_inputs['word_ids'] = word_ids

    return tokenized_inputs


# def model_init(checkpoint, id2label, label2id):
#     model = AutoModelForTokenClassification.from_pretrained(
#         checkpoint,
#         id2label=id2label,
#         label2id=label2id
#     )

#     return model
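

# --- Usage sketch -----------------------------------------------------------
# Illustrative only: the dataset name, checkpoint, and map() arguments below are
# assumptions, not part of the original pipeline. It shows one way the helpers
# above could be wired together with Hugging Face `datasets`.
if __name__ == '__main__':
    from datasets import load_dataset
    from transformers import AutoTokenizer

    raw_datasets = load_dataset('conll2003')                       # assumed dataset
    tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')   # assumed checkpoint

    # Drop every example that contains a non-printable token.
    broken_ids = find_broken_examples(raw_datasets)
    cleaned = {
        split: raw_datasets[split].map(
            update_data,
            batched=True,
            batch_size=None,  # whole split in one batch so id_ indexes the batch
            remove_columns=raw_datasets[split].column_names,
            fn_kwargs={'split': split, 'broken_ids': broken_ids},
        )
        for split in raw_datasets
    }

    # Tokenize and align the word-level NER labels with the sub-tokens.
    tokenized = {
        split: ds.map(
            tokenize_and_align_labels,
            batched=True,
            remove_columns=ds.column_names,
            fn_kwargs={'tokenizer': tokenizer},
        )
        for split, ds in cleaned.items()
    }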