# NOTE(review): the lines "Spaces: / Sleeping / Sleeping" were Hugging Face
# Spaces page-status text captured when this file was exported; kept here as a
# comment so the module remains valid Python.
# NOTE(review): the original __all__ listed names from an unrelated
# image-classifier template (is_flower, learn, classify_image, image, intf)
# that are never defined in this module. Replaced with the names this module
# actually defines and exposes.
__all__ = ['classify_text', 'categories', 'label', 'app_title',
           'examples_1', 'examples_2', 'io1', 'io2']

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import torch

# Every fine-tuned model below derives from the same Norwegian BERT base
# checkpoint, so a single shared tokenizer is sufficient.
model_name = "NbAiLab/nb-bert-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# "Model 1" / "Model 2": single multi-class classifiers over the four
# demographic target groups (see `categories` in classify_text).
first_model_path = "models/first_model"
first_model = AutoModelForSequenceClassification.from_pretrained(first_model_path)
second_model_path = "models/second_model"
second_model = AutoModelForSequenceClassification.from_pretrained(second_model_path)

# "Model 3": four per-demographic classifiers (female/male x 30-40/40-55),
# each predicting a performance grade for its own audience segment.
f_30_40_model_path = "models/FEMALE_30_40model_new"
f_30_40_model = AutoModelForSequenceClassification.from_pretrained(f_30_40_model_path)
f_40_55_model_path = "models/FEMALE_40_55model_new"
f_40_55_model = AutoModelForSequenceClassification.from_pretrained(f_40_55_model_path)
m_30_40_model_path = "models/MALE_30_40model_new"
m_30_40_model = AutoModelForSequenceClassification.from_pretrained(m_30_40_model_path)
m_40_55_model_path = "models/MALE_40_55model_new"
m_40_55_model = AutoModelForSequenceClassification.from_pretrained(m_40_55_model_path)
def classify_text(test_text, selected_model='Model 3'):
    """Classify a Norwegian news text for target-group appeal.

    Parameters
    ----------
    test_text : str
        The text (e.g. a title / text lead) to classify.
    selected_model : str, default 'Model 3'
        'Model 1' or 'Model 2': run one multi-class model over the four
        demographic groups and return a {category: probability} dict
        (suitable for a gradio Label output).
        'Model 3': run four per-demographic models, each grading expected
        performance ('Lite god' / 'Nokså god' / 'God'), and return a
        formatted report string.

    Raises
    ------
    ValueError
        If `selected_model` is not one of the three known options.
        (The original silently returned None in that case; the ValueError
        matches the intent of the original's unreachable inner `raise`.)
    """
    categories = ['Kvinner 30-40', 'Kvinner 40-55', 'Menn 30-40', 'Menn 40-55']
    if selected_model in ('Model 1', 'Model 2'):
        model = first_model if selected_model == 'Model 1' else second_model
        inputs = tokenizer(test_text, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        # Softmax over the class axis; one probability per target group.
        probabilities = torch.softmax(outputs.logits, dim=1)[0].tolist()
        return dict(zip(categories, map(float, probabilities)))
    elif selected_model == 'Model 3':
        # One model per demographic group, in the same order as `categories`.
        models = [f_30_40_model, f_40_55_model, m_30_40_model, m_40_55_model]
        performance_labels = ['Lite god', 'Nokså god', 'God']
        predicted_labels = []
        probs = []
        inputs = tokenizer(test_text, return_tensors="pt")
        for model in models:
            with torch.no_grad():
                outputs = model(**inputs)
            probabilities = torch.softmax(outputs.logits, dim=1)
            # Winning class and its probability in one pass.
            prob, predicted_class = torch.max(probabilities, dim=1)
            predicted_labels.append(performance_labels[predicted_class.item()])
            probs.append(prob.item())
        ret_str = '-------- Predicted performance ------ \n'
        for cat, lab, prob in zip(categories, predicted_labels, probs):
            ret_str += f' \t {cat}: {lab} \n \t Med sannsynlighet: {prob:.2f} \n'
        ret_str += '------------------------------------ \n'
        return ret_str
    else:
        # Fail loudly instead of falling through and returning None.
        raise ValueError("Invalid model selection")
# Cell
# gr.outputs.Label() used the deprecated `gr.outputs` namespace (removed in
# Gradio 4); gr.Label is the equivalent component in both Gradio 3.x and 4.x.
label = gr.Label()
categories = ('Kvinner 30-40', 'Kvinner 40-55', 'Menn 30-40', 'Menn 40-55')
app_title = "Target group classifier"

# Examples for the Model 1 / Model 2 tab: (text, model-choice) pairs.
examples_1 = [["Moren leter etter sønnen i et ihjelbombet leilighetskompleks.", 'Model 1'],
              ["Fotballstadion tok fyr i helgen", 'Model 2'],
              ["De første månedene av krigen gikk så som så. Nå har Putin skiftet strategi.", 'Model 1'],
              ]

# Examples for the Model 3 tab: text only (classify_text defaults to 'Model 3').
examples_2 = [
    ["Title: Disse hadde størst formue i 2022, Text lead: Laksearvingen Gustav Magnar Witzøe økte formuen med nesten 7 milliarder i fjor, og troner nok en gang øverst på listen over Norges rikeste."],
    ["Title: Dette er de mest populære navnene i 2022, Text lead: Navnetoppen for 2022 er klar. Se hvilke navn som er mest populære i din kommune."],
    ["Title: 2023 er det varmeste året noen gang registrert, Text lead: En ny rapport viser at 2023 er det varmeste året registrert siden man startet målingene. Klimaforsker kaller tallene urovekkende."]
]

# Tab 1: model picker + probability-label output. Tab 2: plain-text report.
io1 = gr.Interface(fn=classify_text, inputs=["text", gr.Dropdown(['Model 1', 'Model 2'])], outputs='label', examples=examples_1, title=app_title)
io2 = gr.Interface(fn=classify_text, inputs=["text"], outputs='text', examples=examples_2, title=app_title)

gr.TabbedInterface(
    [io1, io2], ["Model 1 & 2", "Model 3"]
).launch(inline=False, debug=True)