import sys
import argparse
import torch
from transformers import AutoTokenizer
from transformers import AutoModelForTokenClassification


def print_sentence(sentences, inputs, logits, model):
    """Collect the predicted tag and its probability for each word.

    Only the first sub-token of every word is labelled; returns three
    parallel lists: words, tags, and probabilities.
    """
    words, tags, prob_out = [], [], []
    # Convert the raw logits into per-token probability distributions.
    all_probs = logits.softmax(dim=2)
    for i, sentence in enumerate(sentences):
        # Map tokens to their respective word
        word_ids = inputs.word_ids(batch_index=i)
        previous_word_idx = None
        for k, word_idx in enumerate(word_ids):
            if word_idx is not None and word_idx != previous_word_idx:
                # Only label the first token of a given word.
                probs, tag_ids = all_probs[i][k].sort(descending=True)
                label = model.config.id2label[tag_ids[0].item()]
                prob = probs[0].item()
                word = sentence[word_idx]
                words.append(word)
                tags.append(label)
                prob_out.append(prob)
                previous_word_idx = word_idx

    return words, tags, prob_out


def tag_text(text):
    """Tokenize and tag a text (one word per line) with the ENHG tagger.

    Returns the (words, tags, probabilities) triple from print_sentence.
    """
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device("cpu")

    tokenizer = AutoTokenizer.from_pretrained("nielklug/enhg_tagger")
    model = AutoModelForTokenClassification.from_pretrained("nielklug/enhg_tagger")
    model = model.to(device).eval()

    with torch.no_grad():
        # The input is expected to contain one word per line.
        words = text.split('\n')
        inputs = tokenizer(words, is_split_into_words=True, return_tensors="pt")
        logits = model(**inputs.to(device)).logits
        return print_sentence([words], inputs, logits, model)
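

# A minimal sketch of a command-line entry point, assuming the input is one
# word per line (the sys/argparse imports above suggest usage along these
# lines): read the text from a file given as the first argument, or from
# stdin, then print one "word <TAB> tag <TAB> probability" row per word.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Tag Early New High German text (one word per line).")
    parser.add_argument("infile", nargs="?",
                        help="input file; read from stdin if omitted")
    args = parser.parse_args()

    if args.infile:
        with open(args.infile, encoding="utf-8") as f:
            text = f.read()
    else:
        text = sys.stdin.read()

    words, tags, probs = tag_text(text.strip())
    for word, tag, prob in zip(words, tags, probs):
        print(f"{word}\t{tag}\t{prob:.4f}")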