File size: 1,807 Bytes
28981b1
2f01d97
 
e494bde
ad3a545
 
e494bde
2f01d97
ad3a545
 
2f01d97
f303bd5
ad3a545
2f01d97
 
ad3a545
 
 
 
 
 
 
 
 
 
 
 
e494bde
ad3a545
28981b1
ad3a545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f01d97
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Public names exported by ``from <module> import *``. Every entry must be
# defined in this module: a listed name that does not exist makes the star
# import raise AttributeError.
__all__ = ['classify_text', 'categories', 'label', 'examples', 'intf']

from fastai.vision.all import *
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import torch


# Checkpoint locations. The tokenizer is loaded from the base hub checkpoint,
# while the fine-tuned sequence-classification weights come from a local
# directory.
# NOTE(review): assumes the fine-tuned model was trained with this same base
# tokenizer's vocabulary — confirm.
model_name = "NbAiLab/nb-bert-base"
model_path = "models/first_model"

# Loaded once at import time; classify_text reuses these module-level objects
# for every call.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_path)


def classify_text(test_text):
    """Score *test_text* against each target-group category.

    The text is tokenized with the module-level ``tokenizer`` and passed
    through the module-level sequence-classification ``model``. The resulting
    logits are turned into probabilities with a softmax.

    Args:
        test_text: The text to classify, as a single string (the Gradio
            ``"text"`` input supplies one string per call).

    Returns:
        dict mapping each category name to its probability as a ``float``.
        This is the mapping shape displayed by the Gradio ``Label`` output.
    """
    # truncation=True keeps over-long texts within the tokenizer's maximum
    # length. Without it, inputs longer than the model's position-embedding
    # limit fail inside the forward pass.
    inputs = tokenizer(test_text, return_tensors="pt", truncation=True)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Only one text is in the batch, so keep the first (and only) row.
    probabilities = torch.softmax(logits, dim=-1)[0].tolist()

    # NOTE(review): assumes model output index i corresponds to categories[i].
    # Confirm this against the label order used in training. zip() silently
    # drops extras if the counts ever differ.
    categories = ['Kvinner 30-40', 'Kvinner 40-55', 'Menn 30-40', 'Menn 40-55']

    return {category: float(prob) for category, prob in zip(categories, probabilities)}

# Cell
# Gradio UI: one text box as input, a label/confidence display as output.
label = gr.outputs.Label()
# NOTE(review): not used below; presumably kept so it can be exported via
# __all__. classify_text holds its own copy of these names — keep both in sync.
categories = ('Kvinner 30-40', 'Kvinner 40-55', 'Menn 30-40', 'Menn 40-55')
app_title = "Target group classifier"

# Each string is one clickable example. The separating commas matter: adjacent
# string literals without a comma are silently concatenated into one example.
examples = [
    "Moren leter etter sønnen i et ihjelbombet leilighetskompleks.",
    "De første månedene av krigen gikk så som så. Nå har Putin skiftet strategi.",
    "Fotballstadion tok fyr i helgen",
]
intf = gr.Interface(fn=classify_text, inputs="text", outputs=label, examples=examples, title=app_title)
# inline=False: do not embed the app in a notebook output cell.
intf.launch(inline=False)