def find_broken_examples(data):
    """Return the IDs of examples that contain a non-printable token.

    Args:
        data: Mapping from split name to a dataset-like object whose
            ``['tokens']`` entry is a sequence of token lists (one list of
            strings per example).

    Returns:
        A list of ``"<split>-<index>"`` strings, one per affected example,
        in iteration order. An example is reported once no matter how many
        of its tokens are non-printable.
    """
    broken = []
    for split in data:
        for index, tokens in enumerate(data[split]['tokens']):
            # any() stops at the first bad token, so each example is
            # listed at most once.
            if any(not token.isprintable() for token in tokens):
                broken.append(f"{split}-{index}")
    return broken


def update_data(examples, split, broken_ids):
    """Drop broken examples from a batch and renumber the survivors.

    Args:
        examples: Mapping with parallel ``'id'``, ``'tokens'`` and
            ``'ner_tags'`` sequences. Each id is converted with ``int()``
            and used as a positional index into ``'tokens'``/``'ner_tags'``,
            so ids must be valid indices into this batch
            (presumably the whole split is passed at once -- confirm
            against the caller).
        split: Name of the split; combined with each id as
            ``"<split>-<id>"`` to look it up in ``broken_ids``.
        broken_ids: Iterable of ``"<split>-<id>"`` strings to remove.

    Returns:
        A dict with the kept ``'tokens'`` and ``'ner_tags'`` and fresh
        consecutive string ids ``'0'``, ``'1'``, ...
    """
    # Build the lookup set once instead of scanning a list per example.
    broken = frozenset(broken_ids)

    new_tokens = []
    new_tags = []
    for id_ in examples['id']:
        if f"{split}-{id_}" in broken:
            continue
        position = int(id_)
        new_tokens.append(examples['tokens'][position])
        new_tags.append(examples['ner_tags'][position])

    assert len(new_tokens) == len(new_tags)
    # Sanity-check the last kept example; skipped when nothing survived the
    # filter (indexing [-1] on an empty list would raise IndexError).
    if new_tokens:
        assert len(new_tokens[-1]) == len(new_tags[-1])

    return {
        'id': [str(i) for i in range(len(new_tokens))],
        'tokens': new_tokens,
        'ner_tags': new_tags,
    }


def align_labels_with_tokens(labels, word_ids):
    """Expand word-level labels to token level.

    Only the first token of each word receives the word's label. Special
    tokens (``word_id is None``) and the remaining tokens of a word get
    ``-100`` (presumably the loss ignore index -- confirm against the
    training setup).

    Args:
        labels: Sequence of labels, indexable by word id.
        word_ids: Per-token word ids; ``None`` marks a special token.

    Returns:
        A list of labels with the same length as ``word_ids``.
    """
    new_labels = []
    current_word = None
    for word_id in word_ids:
        starts_new_word = word_id is not None and word_id != current_word
        # NOTE: continuation tokens are masked with -100. An earlier
        # (disabled) variant instead reused the word's label, bumping an
        # odd label by one to turn B-XXX into I-XXX.
        new_labels.append(labels[word_id] if starts_new_word else -100)
        current_word = word_id
    return new_labels


def tokenize_and_align_labels(examples, tokenizer):
    """Tokenize pre-split words and attach token-aligned labels.

    Args:
        examples: Mapping with ``"tokens"`` (list of word lists) and
            ``"ner_tags"`` (list of per-word label lists).
        tokenizer: Callable invoked with the word lists and
            ``truncation=True, is_split_into_words=True,
            padding='max_length'``. Its result must support
            ``.word_ids(i)`` and item assignment.

    Returns:
        The tokenizer output with two entries added: ``"labels"`` (aligned
        labels per example) and ``'word_ids'`` (word ids per example).
    """
    tokenized_inputs = tokenizer(
        examples["tokens"],
        truncation=True,
        is_split_into_words=True,
        padding='max_length',
    )

    new_labels = []
    word_ids = []
    for i, labels in enumerate(examples["ner_tags"]):
        example_word_ids = tokenized_inputs.word_ids(i)
        word_ids.append(example_word_ids)
        new_labels.append(align_labels_with_tokens(labels, example_word_ids))

    tokenized_inputs["labels"] = new_labels
    tokenized_inputs['word_ids'] = word_ids
    return tokenized_inputs


# def model_init(checkpoint, id2label, label2id):
#     model = AutoModelForTokenClassification.from_pretrained(
#         checkpoint,
#         id2label=id2label,
#         label2id=label2id
#     )
#     return model