# enhg-parsing / annotate.py
# Source: Hugging Face Space "nielklug/enhg-parsing" (commit 920b22f, "add tag")
import argparse
import sys
from functools import lru_cache

import torch
from transformers import AutoModelForTokenClassification
from transformers import AutoTokenizer
def print_sentence(sentences, inputs, logits, model):
    """Collect the predicted tag and its probability for every word.

    Only the first sub-token of each word is labelled; special tokens
    (word id ``None``) and continuation sub-tokens are skipped.

    Args:
        sentences: list of sentences, each a list of word strings.
        inputs: tokenizer output providing ``word_ids(batch_index=...)``.
        logits: tensor of shape (batch, tokens, labels) from the model.
        model: model whose ``config.id2label`` maps label ids to tag names.

    Returns:
        Three parallel lists: words, predicted tag labels, probabilities.
    """
    words, tags, prob_out = [], [], []
    token_probs = logits.softmax(dim=2)
    for sent_idx, sentence in enumerate(sentences):
        # Map each sub-token back to the word it came from.
        word_ids = inputs.word_ids(batch_index=sent_idx)
        prev_word_id = None
        for tok_idx, word_id in enumerate(word_ids):
            # Skip special tokens and non-initial sub-tokens of a word.
            if word_id is None or word_id == prev_word_id:
                continue
            best_prob, best_id = token_probs[sent_idx][tok_idx].max(dim=0)
            words.append(sentence[word_id])
            tags.append(model.config.id2label[best_id.item()])
            prob_out.append(best_prob.item())
            prev_word_id = word_id
    return words, tags, prob_out
@lru_cache(maxsize=1)
def _load_tagger():
    """Load the ENHG tagger once and cache it for all later calls.

    Loading ``from_pretrained`` on every request is expensive (disk/Hub
    access plus model construction), so hoist it out of ``tag_text``.

    Returns:
        (tokenizer, model, device) — the model is on CPU and in eval mode.
    """
    # device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device("cpu")
    tokenizer = AutoTokenizer.from_pretrained("nielklug/enhg_tagger")
    model = AutoModelForTokenClassification.from_pretrained("nielklug/enhg_tagger")
    model = model.to(device).eval()
    return tokenizer, model, device


def tag_text(text):
    """Tag a newline-separated string of words with the ENHG POS tagger.

    Args:
        text: input where each line holds one word (``text.split('\\n')``
            — empty lines become empty-string "words").

    Returns:
        Three parallel lists (words, tags, probabilities) as produced by
        ``print_sentence``.
    """
    tokenizer, model, device = _load_tagger()
    with torch.no_grad():
        words = text.split('\n')
        inputs = tokenizer(words, is_split_into_words=True, return_tensors="pt")
        logits = model(**inputs.to(device)).logits
    return print_sentence([words], inputs, logits, model)