Hanna Hjelmeland
Update model paths
4a49928
raw
history blame
4.56 kB
__all__ = ['is_flower', 'learn', 'classify_image', 'categories', 'image', 'label', 'examples', 'intf']
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import torch
model_name = "NbAiLab/nb-bert-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
first_model_path = "models/first_model"
first_model = AutoModelForSequenceClassification.from_pretrained(first_model_path)
second_model_path = "models/second_model"
second_model = AutoModelForSequenceClassification.from_pretrained(second_model_path)
f_30_40_model_path = "models/FEMALE_30_40model_new"
f_30_40_model = AutoModelForSequenceClassification.from_pretrained(f_30_40_model_path)
f_40_55_model_path = "models/FEMALE_40_55model_new"
f_40_55_model = AutoModelForSequenceClassification.from_pretrained(f_40_55_model_path)
m_30_40_model_path = "models/MALE_30_40model_new"
m_30_40_model = AutoModelForSequenceClassification.from_pretrained(m_30_40_model_path)
m_40_55_model_path = "models/MALE_40_55model_new"
m_40_55_model = AutoModelForSequenceClassification.from_pretrained(m_40_55_model_path)
def classify_text(test_text, selected_model='Model 3'):
categories = ['Kvinner 30-40', 'Kvinner 40-55', 'Menn 30-40', 'Menn 40-55']
if selected_model in ('Model 1', 'Model 2'):
if selected_model == 'Model 1':
model = first_model
elif selected_model == 'Model 2':
model = second_model
else:
raise ValueError("Invalid model selection")
inputs = tokenizer(test_text, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
probabilities = torch.softmax(logits, dim=1)
predicted_class = torch.argmax(probabilities, dim=1).item()
class_labels = model.config.id2label
predicted_label = class_labels[predicted_class]
probabilities = probabilities[0].tolist()
return dict(zip(categories, map(float,probabilities)))
elif selected_model == 'Model 3':
models = [f_30_40_model, f_40_55_model, m_30_40_model, m_40_55_model]
predicted_labels = []
probs = []
performance_labels = ['Lite god', 'Nokså god', 'God']
inputs = tokenizer(test_text, return_tensors="pt")
for model in models:
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
probabilities = torch.softmax(logits, dim=1)
prob, _ = torch.max(probabilities, dim=1)
prob = prob.item()
predicted_class = torch.argmax(probabilities, dim=1).item()
predicted_performance = performance_labels[predicted_class]
predicted_labels.append(predicted_performance)
probs.append(prob)
ret_str = '-------- Predicted performance ------ \n'
for cat, lab, prob in zip(categories, predicted_labels, probs):
ret_str += f' \t {cat}: {lab} \n \t Med sannsynlighet: {prob:.2f} \n'
ret_str += '------------------------------------ \n'
return ret_str
# Cell
label = gr.outputs.Label()
categories = ('Kvinner 30-40', 'Kvinner 40-55', 'Menn 30-40', 'Menn 40-55')
app_title = "Target group classifier"
examples_1 = [["Moren leter etter sønnen i et ihjelbombet leilighetskompleks.", 'Model 1'],
["Fotballstadion tok fyr i helgen", 'Model 2'],
["De første månedene av krigen gikk så som så. Nå har Putin skiftet strategi.", 'Model 1'],
]
examples_2 = [
["Title: Disse hadde størst formue i 2022, Text lead: Laksearvingen Gustav Magnar Witzøe økte formuen med nesten 7 milliarder i fjor, og troner nok en gang øverst på listen over Norges rikeste."],
["Title: Dette er de mest populære navnene i 2022, Text lead: Navnetoppen for 2022 er klar. Se hvilke navn som er mest populære i din kommune."],
["Title: 2023 er det varmeste året noen gang registrert, Text lead: En ny rapport viser at 2023 er det varmeste året registrert siden man startet målingene. Klimaforsker kaller tallene urovekkende."]
]
io1 = gr.Interface(fn=classify_text, inputs=["text", gr.Dropdown(['Model 1', 'Model 2'])], outputs='label', examples=examples_1, title=app_title)
io2 = gr.Interface(fn=classify_text, inputs=["text"], outputs='text', examples=examples_2, title=app_title)
gr.TabbedInterface(
[io1, io2], ["Model 1 & 2", "Model 3"]
).launch(inline=False, debug=True)