Spaces:
Sleeping
Sleeping
File size: 4,562 Bytes
28981b1 2f01d97 e494bde ad3a545 e494bde 2f01d97 ad3a545 2f01d97 480db19 2f01d97 480db19 2f01d97 4a49928 8219eaa ad3a545 4a49928 8219eaa ad3a545 4a49928 8219eaa 28981b1 4a49928 8219eaa ad3a545 4578655 ad3a545 8219eaa 483b69e 58c23a6 483b69e 8219eaa 58c23a6 8219eaa 483b69e 58c23a6 8219eaa 58c23a6 ad3a545 58c23a6 688d5f2 8219eaa ad3a545 58c23a6 4578655 58c23a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
__all__ = ['is_flower', 'learn', 'classify_image', 'categories', 'image', 'label', 'examples', 'intf']
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import torch
# One shared Norwegian BERT tokenizer is used for every model below.
model_name = "NbAiLab/nb-bert-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)


def _load_classifier(path):
    """Load a fine-tuned sequence-classification model from a local directory."""
    return AutoModelForSequenceClassification.from_pretrained(path)


# Models 1 and 2: single classifiers predicting the target demographic.
first_model_path = "models/first_model"
first_model = _load_classifier(first_model_path)
second_model_path = "models/second_model"
second_model = _load_classifier(second_model_path)

# Model 3: one classifier per demographic segment (used together by
# classify_text to report predicted performance for each segment).
f_30_40_model_path = "models/FEMALE_30_40model_new"
f_30_40_model = _load_classifier(f_30_40_model_path)
f_40_55_model_path = "models/FEMALE_40_55model_new"
f_40_55_model = _load_classifier(f_40_55_model_path)
m_30_40_model_path = "models/MALE_30_40model_new"
m_30_40_model = _load_classifier(m_30_40_model_path)
m_40_55_model_path = "models/MALE_40_55model_new"
m_40_55_model = _load_classifier(m_40_55_model_path)
def classify_text(test_text, selected_model='Model 3'):
    """Classify a Norwegian news text against reader demographics.

    Parameters
    ----------
    test_text : str
        The text to classify (title/lead of an article).
    selected_model : str, optional
        'Model 1' or 'Model 2': run a single multi-class model and return
        a ``{category: probability}`` dict suitable for a gr.Label output.
        'Model 3' (default): run the four per-segment performance models
        and return a formatted multi-line report string.

    Returns
    -------
    dict[str, float] | str
        Probabilities per demographic category, or a text report.

    Raises
    ------
    ValueError
        If *selected_model* is not one of the three known model names.
        (Previously an unknown name silently returned ``None``; the old
        ``raise`` sat on an unreachable branch inside the Model 1/2 arm.)
    """
    categories = ['Kvinner 30-40', 'Kvinner 40-55', 'Menn 30-40', 'Menn 40-55']
    if selected_model in ('Model 1', 'Model 2'):
        model = first_model if selected_model == 'Model 1' else second_model
        inputs = tokenizer(test_text, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        # Softmax over the class axis; first (only) row of the batch.
        probabilities = torch.softmax(outputs.logits, dim=1)[0].tolist()
        return dict(zip(categories, map(float, probabilities)))
    elif selected_model == 'Model 3':
        # One model per demographic segment, in the same order as `categories`.
        models = [f_30_40_model, f_40_55_model, m_30_40_model, m_40_55_model]
        performance_labels = ['Lite god', 'Nokså god', 'God']
        predicted_labels = []
        probs = []
        inputs = tokenizer(test_text, return_tensors="pt")
        for model in models:
            with torch.no_grad():
                outputs = model(**inputs)
            probabilities = torch.softmax(outputs.logits, dim=1)
            # torch.max returns (values, indices): the winning probability
            # and the winning class in one call (the old code re-ran argmax).
            prob, predicted_class = torch.max(probabilities, dim=1)
            predicted_labels.append(performance_labels[predicted_class.item()])
            probs.append(prob.item())
        ret_str = '-------- Predicted performance ------ \n'
        for cat, lab, prob in zip(categories, predicted_labels, probs):
            ret_str += f' \t {cat}: {lab} \n \t Med sannsynlighet: {prob:.2f} \n'
        ret_str += '------------------------------------ \n'
        return ret_str
    else:
        raise ValueError(f"Invalid model selection: {selected_model!r}")
# Cell
# Gradio front-end: tab 1 serves Models 1 & 2 (probability-per-category
# label widget), tab 2 serves Model 3 (formatted text report).
# NOTE: gr.outputs.Label() belonged to the deprecated gr.outputs namespace
# (removed in modern Gradio); gr.Label() is the current equivalent and
# matches the modern API already used below (gr.Dropdown, gr.TabbedInterface).
label = gr.Label()
categories = ('Kvinner 30-40', 'Kvinner 40-55', 'Menn 30-40', 'Menn 40-55')
app_title = "Target group classifier"

# Examples for tab 1 pair a text with an explicit model choice.
examples_1 = [["Moren leter etter sønnen i et ihjelbombet leilighetskompleks.", 'Model 1'],
              ["Fotballstadion tok fyr i helgen", 'Model 2'],
              ["De første månedene av krigen gikk så som så. Nå har Putin skiftet strategi.", 'Model 1'],
              ]
# Examples for tab 2 carry only the text; classify_text then runs with its
# default selected_model='Model 3'.
examples_2 = [
    ["Title: Disse hadde størst formue i 2022, Text lead: Laksearvingen Gustav Magnar Witzøe økte formuen med nesten 7 milliarder i fjor, og troner nok en gang øverst på listen over Norges rikeste."],
    ["Title: Dette er de mest populære navnene i 2022, Text lead: Navnetoppen for 2022 er klar. Se hvilke navn som er mest populære i din kommune."],
    ["Title: 2023 er det varmeste året noen gang registrert, Text lead: En ny rapport viser at 2023 er det varmeste året registrert siden man startet målingene. Klimaforsker kaller tallene urovekkende."]
]
io1 = gr.Interface(fn=classify_text, inputs=["text", gr.Dropdown(['Model 1', 'Model 2'])], outputs='label', examples=examples_1, title=app_title)
io2 = gr.Interface(fn=classify_text, inputs=["text"], outputs='text', examples=examples_2, title=app_title)
gr.TabbedInterface(
    [io1, io2], ["Model 1 & 2", "Model 3"]
).launch(inline=False, debug=True)
|