Hanna Hjelmeland commited on
Commit
8219eaa
1 Parent(s): b3bfc37

Update app

Browse files
Files changed (1) hide show
  1. app.py +56 -27
app.py CHANGED
@@ -14,36 +14,62 @@ first_model = AutoModelForSequenceClassification.from_pretrained(first_model_pat
14
  second_model_path = "models/second_model"
15
  second_model = AutoModelForSequenceClassification.from_pretrained(second_model_path)
16
 
17
- def classify_text(test_text, selected_model):
18
-
19
- if selected_model == 'Model 1':
20
- model = first_model
21
- elif selected_model == 'Model 2':
22
- model = second_model
23
- else:
24
- raise ValueError("Invalid model selection")
25
-
26
- inputs = tokenizer(test_text, return_tensors="pt")
27
 
28
- with torch.no_grad():
29
- outputs = model(**inputs)
30
- logits = outputs.logits
31
- probabilities = torch.softmax(logits, dim=1)
32
 
33
- predicted_class = torch.argmax(probabilities, dim=1).item()
34
- class_labels = model.config.id2label
35
- predicted_label = class_labels[predicted_class]
36
- probabilities = probabilities[0].tolist()
37
-
38
- categories = ['Kvinner 30-40', 'Kvinner 40-55', 'Menn 30-40', 'Menn 40-55']
39
 
40
- #category_probabilities = list(zip(categories, probabilities))
 
41
 
42
- #max_category = max(category_probabilities, key=lambda x: x[1])
43
-
44
- #print('The model predicts that this text lead would have a majority of readers in the target group', max_category[0])
45
 
46
- return dict(zip(categories, map(float,probabilities)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  # Cell
49
  label = gr.outputs.Label()
@@ -52,7 +78,10 @@ app_title = "Target group classifier"
52
 
53
  examples = [["Moren leter etter sønnen i et ihjelbombet leilighetskompleks.", 'Model 1'],
54
  ["Fotballstadion tok fyr i helgen", 'Model 2'],
55
- ["De første månedene av krigen gikk så som så. Nå har Putin skiftet strategi.", 'Model 1']
 
 
 
56
  ]
57
- intf = gr.Interface(fn=classify_text, inputs=["text", gr.Dropdown(['Model 1', 'Model 2'])], outputs=label, examples=examples, title=app_title)
58
  intf.launch(inline=False)
 
14
  second_model_path = "models/second_model"
15
  second_model = AutoModelForSequenceClassification.from_pretrained(second_model_path)
16
 
17
+ f_30_40_model_path = "models/FEMALE_30_40model"
18
+ f_30_40_model = AutoModelForSequenceClassification.from_pretrained(f_30_40_model_path)
 
 
 
 
 
 
 
 
19
 
20
+ f_40_55_model_path = "models/FEMALE_40_55model"
21
+ f_40_55_model = AutoModelForSequenceClassification.from_pretrained(f_40_55_model_path)
 
 
22
 
23
+ m_30_40_model_path = "models/MALE_30_40model"
24
+ m_30_40_model = AutoModelForSequenceClassification.from_pretrained(m_30_40_model_path)
 
 
 
 
25
 
26
+ m_40_55_model_path = "models/MALE_40_55model"
27
+ m_40_55_model = AutoModelForSequenceClassification.from_pretrained(m_40_55_model_path)
28
 
29
+ def classify_text(test_text, selected_model):
 
 
30
 
31
+ categories = ['Kvinner 30-40', 'Kvinner 40-55', 'Menn 30-40', 'Menn 40-55']
32
+
33
+ if selected_model in ('Model 1', 'Model 2'):
34
+ if selected_model == 'Model 1':
35
+ model = first_model
36
+ elif selected_model == 'Model 2':
37
+ model = second_model
38
+ else:
39
+ raise ValueError("Invalid model selection")
40
+ inputs = tokenizer(test_text, return_tensors="pt")
41
+
42
+ with torch.no_grad():
43
+ outputs = model(**inputs)
44
+ logits = outputs.logits
45
+ probabilities = torch.softmax(logits, dim=1)
46
+
47
+ predicted_class = torch.argmax(probabilities, dim=1).item()
48
+ class_labels = model.config.id2label
49
+ predicted_label = class_labels[predicted_class]
50
+ probabilities = probabilities[0].tolist()
51
+
52
+ return dict(zip(categories, map(float,probabilities)))
53
+ elif selected_model == 'Model 3':
54
+ models = [f_30_40_model, f_40_55_model, m_30_40_model, m_40_55_model]
55
+ performance_labels = []
56
+ inputs = tokenizer(test_text, return_tensors="pt")
57
+
58
+ for model in models:
59
+ with torch.no_grad():
60
+ outputs = model(**inputs)
61
+ logits = outputs.logits
62
+ probabilities = torch.softmax(logits, dim=1)
63
+
64
+ predicted_class = torch.argmax(probabilities, dim=1).item()
65
+ performance_labels = ['Lite god', 'Nokså god', 'God']
66
+ predicted_performance = performance_labels[predicted_class]
67
+
68
+ class_labels = model.config.id2label
69
+ predicted_label = class_labels[predicted_class]
70
+ performance_labels.append(predicted_label)
71
+
72
+ return dict(zip(categories, map(float,performance_labels)))
73
 
74
  # Cell
75
  label = gr.outputs.Label()
 
78
 
79
  examples = [["Moren leter etter sønnen i et ihjelbombet leilighetskompleks.", 'Model 1'],
80
  ["Fotballstadion tok fyr i helgen", 'Model 2'],
81
+ ["De første månedene av krigen gikk så som så. Nå har Putin skiftet strategi.", 'Model 1'],
82
+ ["Title: Disse hadde størst formue i 2022, Text lead: Laksearvingen Gustav Magnar Witzøe økte formuen med nesten 7 milliarder i fjor, og troner nok en gang øverst på listen over Norges rikeste.", "Model 3"],
83
+ ["Title: Dette er de mest populære navnene i 2022, Text lead: Navnetoppen for 2022 er klar. Se hvilke navn som er mest populære i din kommune.", "Model 3"],
84
+ ["Title: 2023 er det varmeste året noen gang registrert, Text lead: En ny rapport viser at 2023 er det varmeste året registrert siden man startet målingene. Klimaforsker kaller tallene urovekkende.", "Model 3"]
85
  ]
86
+ intf = gr.Interface(fn=classify_text, inputs=["text", gr.Dropdown(['Model 1', 'Model 2', 'Model 3'])], outputs=label, examples=examples, title=app_title)
87
  intf.launch(inline=False)