Hanna Hjelmeland commited on
Commit
58c23a6
1 Parent(s): cfb23d9

Add tabed interface and model 3

Browse files
Files changed (1) hide show
  1. app.py +23 -7
app.py CHANGED
@@ -53,6 +53,7 @@ def classify_text(test_text, selected_model):
53
  elif selected_model == 'Model 3':
54
  models = [f_30_40_model, f_40_55_model, m_30_40_model, m_40_55_model]
55
  predicted_labels = []
 
56
  performance_labels = ['Lite god', 'Nokså god', 'God']
57
  inputs = tokenizer(test_text, return_tensors="pt")
58
 
@@ -62,23 +63,38 @@ def classify_text(test_text, selected_model):
62
  logits = outputs.logits
63
  probabilities = torch.softmax(logits, dim=1)
64
 
 
 
65
  predicted_class = torch.argmax(probabilities, dim=1).item()
66
  predicted_performance = performance_labels[predicted_class]
67
  predicted_labels.append(predicted_performance)
 
68
 
69
- return dict(zip(categories, map(str,predicted_labels)))
 
 
 
 
70
 
71
  # Cell
72
  label = gr.outputs.Label()
73
  categories = ('Kvinner 30-40', 'Kvinner 40-55', 'Menn 30-40', 'Menn 40-55')
74
  app_title = "Target group classifier"
75
 
76
- examples = [["Moren leter etter sønnen i et ihjelbombet leilighetskompleks.", 'Model 1'],
77
  ["Fotballstadion tok fyr i helgen", 'Model 2'],
78
  ["De første månedene av krigen gikk så som så. Nå har Putin skiftet strategi.", 'Model 1'],
79
- ["Title: Disse hadde størst formue i 2022, Text lead: Laksearvingen Gustav Magnar Witzøe økte formuen med nesten 7 milliarder i fjor, og troner nok en gang øverst på listen over Norges rikeste.", "Model 3"],
80
- ["Title: Dette er de mest populære navnene i 2022, Text lead: Navnetoppen for 2022 er klar. Se hvilke navn som er mest populære i din kommune.", "Model 3"],
81
- ["Title: 2023 er det varmeste året noen gang registrert, Text lead: En ny rapport viser at 2023 er det varmeste året registrert siden man startet målingene. Klimaforsker kaller tallene urovekkende.", "Model 3"]
82
  ]
83
- intf = gr.Interface(fn=classify_text, inputs=["text", gr.Dropdown(['Model 1', 'Model 2', 'Model 3'])], outputs=label, examples=examples, title=app_title)
84
- intf.launch(inline=False)
 
 
 
 
 
 
 
 
 
 
 
 
53
  elif selected_model == 'Model 3':
54
  models = [f_30_40_model, f_40_55_model, m_30_40_model, m_40_55_model]
55
  predicted_labels = []
56
+ probs = []
57
  performance_labels = ['Lite god', 'Nokså god', 'God']
58
  inputs = tokenizer(test_text, return_tensors="pt")
59
 
 
63
  logits = outputs.logits
64
  probabilities = torch.softmax(logits, dim=1)
65
 
66
+ prob, _ = torch.max(probabilities, dim=1)
67
+ prob = prob.item()
68
  predicted_class = torch.argmax(probabilities, dim=1).item()
69
  predicted_performance = performance_labels[predicted_class]
70
  predicted_labels.append(predicted_performance)
71
+ probs.append(prob)
72
 
73
+ ret_str = '-------- Predicted performance ------ \n'
74
+ for cat, lab, prob in zip(categories, predicted_labels, probs):
75
+ ret_str += f' \t {cat}: {lab} \n \t Med sannsynlighet: {prob:.2f} \n'
76
+ ret_str += '------------------------------------ \n'
77
+ return ret_str
78
 
79
  # Cell
80
  label = gr.outputs.Label()
81
  categories = ('Kvinner 30-40', 'Kvinner 40-55', 'Menn 30-40', 'Menn 40-55')
82
  app_title = "Target group classifier"
83
 
84
+ examples_1 = [["Moren leter etter sønnen i et ihjelbombet leilighetskompleks.", 'Model 1'],
85
  ["Fotballstadion tok fyr i helgen", 'Model 2'],
86
  ["De første månedene av krigen gikk så som så. Nå har Putin skiftet strategi.", 'Model 1'],
 
 
 
87
  ]
88
+
89
+ examples_2 = [
90
+ ["Title: Disse hadde størst formue i 2022, Text lead: Laksearvingen Gustav Magnar Witzøe økte formuen med nesten 7 milliarder i fjor, og troner nok en gang øverst på listen over Norges rikeste."],
91
+ ["Title: Dette er de mest populære navnene i 2022, Text lead: Navnetoppen for 2022 er klar. Se hvilke navn som er mest populære i din kommune."],
92
+ ["Title: 2023 er det varmeste året noen gang registrert, Text lead: En ny rapport viser at 2023 er det varmeste året registrert siden man startet målingene. Klimaforsker kaller tallene urovekkende."]
93
+ ]
94
+
95
+ io1 = gr.Interface(fn=classify_text, inputs=["text", gr.Dropdown(['Model 1', 'Model 2'])], outputs='text', examples=examples_1, title=app_title)
96
+ io2 = gr.Interface(fn=classify_text, inputs=["text", 'Model 3'], outputs='text', examples=examples_2, title=app_title)
97
+
98
+ gr.TabbedInterface(
99
+ [io1, io2], ["Model 1 & 2", "Model 3"]
100
+ ).launch(inline=False, debug=True)