Dusan commited on
Commit
96063e3
β€’
1 Parent(s): bf92bc8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -113
app.py CHANGED
@@ -1,114 +1,114 @@
1
- # import os
2
-
3
- # os.chdir('naacl-2021-fudge-controlled-generation/')
4
-
5
- import gradio as gr
6
- from fudge.predict_clickbait import generate_clickbait, tokenizer, classifier_tokenizer
7
- from datasets import load_dataset,DatasetDict,Dataset
8
- # from datasets import
9
- from transformers import AutoTokenizer,AutoModelForSeq2SeqLM
10
- import numpy as np
11
- from sklearn.model_selection import train_test_split
12
- import pandas as pd
13
- from sklearn.utils.class_weight import compute_class_weight
14
- import torch
15
- import pandas as pd
16
- from fudge.model import Model
17
- import os
18
- from argparse import ArgumentParser
19
- from collections import namedtuple
20
- import mock
21
-
22
- from tqdm import tqdm
23
- import numpy as np
24
- import torch.nn as nn
25
- import torch.nn.functional as F
26
- from data import Dataset
27
- from fudge.util import save_checkpoint, ProgressMeter, AverageMeter, num_params
28
- from fudge.constants import *
29
-
30
-
31
- # imp.reload(model)
32
- pretrained_model = "../checkpoint-150/"
33
- generation_model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model, return_dict=True).to(device)
34
-
35
- device = 'cuda'
36
- pad_id = 0
37
-
38
- generation_model.eval()
39
-
40
- model_args = mock.Mock()
41
- model_args.task = 'clickbait'
42
- model_args.device = device
43
- model_args.checkpoint = '../checkpoint-1464/'
44
-
45
- # conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
46
- conditioning_model = Model(model_args, pad_id, vocab_size=None) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
47
- conditioning_model = conditioning_model.to(device)
48
- conditioning_model.eval()
49
-
50
- condition_lambda = 5.0
51
- length_cutoff = 50
52
- precondition_topk = 200
53
-
54
-
55
- conditioning_model.classifier
56
-
57
- model_args.checkpoint
58
-
59
- classifier_tokenizer = AutoTokenizer.from_pretrained(model_args.checkpoint, load_best_model_at_end=True)
60
-
61
-
62
- def rate_title(input_text, model, tokenizer, device='cuda'):
63
- # input_text = {
64
- # "postText": input_text['postText'],
65
- # "truthClass" : input_text['truthClass']
66
- # }
67
- tokenized_input = preprocess_function_title_only_classification(input_text,tokenizer=tokenizer)
68
- # print(tokenized_input.items())
69
- dict_tokenized_input = {k : torch.tensor([v]).to(device) for k,v in tokenized_input.items() if k != 'labels'}
70
- predicted_class = float(model(**dict_tokenized_input).logits)
71
- actual_class = input_text['truthClass']
72
-
73
- # print(predicted_class, actual_class)
74
- return {'predicted_class' : predicted_class}
75
-
76
- def preprocess_function_title_only_classification(examples,tokenizer=None):
77
- model_inputs = tokenizer(examples['postText'], padding="longest", truncation=True, max_length=25)
78
-
79
- model_inputs['labels'] = examples['truthClass']
80
-
81
- return model_inputs
82
-
83
-
84
-
85
- def clickbait_generator(article_content, condition_lambda=5.0):
86
- # result = "Hi {}! 😎. The Mulitple of {} is {}".format(name, number, round(number**2, 2))
87
- results = generate_clickbait(model=generation_model,
88
- tokenizer=tokenizer,
89
- conditioning_model=conditioning_model,
90
- input_text=[None],
91
- dataset_info=dataset_info,
92
- precondition_topk=precondition_topk,
93
- length_cutoff=length_cutoff,
94
- condition_lambda=condition_lambda,
95
- article_content=article_content,
96
- device=device)
97
-
98
- return results[0].replace('</s>', '').replace('<pad>', '')
99
-
100
- title = "Clickbait generator"
101
- description = """
102
- "Use the [Fudge](https://github.com/yangkevin2/naacl-2021-fudge-controlled-generation) implementation fine tuned for our purposes to try and create news headline you are looking for!"
103
- """
104
-
105
- article = "Check out [the codebase for our model](https://github.com/dsvilarkovic/naacl-2021-fudge-controlled-generation) that this demo is based off of."
106
-
107
-
108
- app = gr.Interface(
109
- title = title,
110
- description = description,
111
- label = 'Article content or paragraph',
112
- fn = clickbait_generator,
113
- inputs=["text", gr.Slider(0, 100, step=0.1, value=5.0)], outputs="text")
114
  app.launch()
 
1
+ # import os
2
+
3
+ # os.chdir('naacl-2021-fudge-controlled-generation/')
4
+
5
+ import gradio as gr
6
+ from fudge.predict_clickbait import generate_clickbait, tokenizer, classifier_tokenizer
7
+ from datasets import load_dataset,DatasetDict,Dataset
8
+ # from datasets import
9
+ from transformers import AutoTokenizer,AutoModelForSeq2SeqLM
10
+ import numpy as np
11
+ from sklearn.model_selection import train_test_split
12
+ import pandas as pd
13
+ from sklearn.utils.class_weight import compute_class_weight
14
+ import torch
15
+ import pandas as pd
16
+ from fudge.model import Model
17
+ import os
18
+ from argparse import ArgumentParser
19
+ from collections import namedtuple
20
+ import mock
21
+
22
+ from tqdm import tqdm
23
+ import numpy as np
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+ from fudge.data import Dataset
27
+ from fudge.util import save_checkpoint, ProgressMeter, AverageMeter, num_params
28
+ from fudge.constants import *
29
+
30
+
31
+ # imp.reload(model)
32
+ pretrained_model = "../checkpoint-150/"
33
+ generation_model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model, return_dict=True).to(device)
34
+
35
+ device = 'cuda'
36
+ pad_id = 0
37
+
38
+ generation_model.eval()
39
+
40
+ model_args = mock.Mock()
41
+ model_args.task = 'clickbait'
42
+ model_args.device = device
43
+ model_args.checkpoint = '../checkpoint-1464/'
44
+
45
+ # conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
46
+ conditioning_model = Model(model_args, pad_id, vocab_size=None) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
47
+ conditioning_model = conditioning_model.to(device)
48
+ conditioning_model.eval()
49
+
50
+ condition_lambda = 5.0
51
+ length_cutoff = 50
52
+ precondition_topk = 200
53
+
54
+
55
+ conditioning_model.classifier
56
+
57
+ model_args.checkpoint
58
+
59
+ classifier_tokenizer = AutoTokenizer.from_pretrained(model_args.checkpoint, load_best_model_at_end=True)
60
+
61
+
62
+ def rate_title(input_text, model, tokenizer, device='cuda'):
63
+ # input_text = {
64
+ # "postText": input_text['postText'],
65
+ # "truthClass" : input_text['truthClass']
66
+ # }
67
+ tokenized_input = preprocess_function_title_only_classification(input_text,tokenizer=tokenizer)
68
+ # print(tokenized_input.items())
69
+ dict_tokenized_input = {k : torch.tensor([v]).to(device) for k,v in tokenized_input.items() if k != 'labels'}
70
+ predicted_class = float(model(**dict_tokenized_input).logits)
71
+ actual_class = input_text['truthClass']
72
+
73
+ # print(predicted_class, actual_class)
74
+ return {'predicted_class' : predicted_class}
75
+
76
+ def preprocess_function_title_only_classification(examples,tokenizer=None):
77
+ model_inputs = tokenizer(examples['postText'], padding="longest", truncation=True, max_length=25)
78
+
79
+ model_inputs['labels'] = examples['truthClass']
80
+
81
+ return model_inputs
82
+
83
+
84
+
85
+ def clickbait_generator(article_content, condition_lambda=5.0):
86
+ # result = "Hi {}! 😎. The Mulitple of {} is {}".format(name, number, round(number**2, 2))
87
+ results = generate_clickbait(model=generation_model,
88
+ tokenizer=tokenizer,
89
+ conditioning_model=conditioning_model,
90
+ input_text=[None],
91
+ dataset_info=dataset_info,
92
+ precondition_topk=precondition_topk,
93
+ length_cutoff=length_cutoff,
94
+ condition_lambda=condition_lambda,
95
+ article_content=article_content,
96
+ device=device)
97
+
98
+ return results[0].replace('</s>', '').replace('<pad>', '')
99
+
100
+ title = "Clickbait generator"
101
+ description = """
102
+ "Use the [Fudge](https://github.com/yangkevin2/naacl-2021-fudge-controlled-generation) implementation fine tuned for our purposes to try and create news headline you are looking for!"
103
+ """
104
+
105
+ article = "Check out [the codebase for our model](https://github.com/dsvilarkovic/naacl-2021-fudge-controlled-generation) that this demo is based off of."
106
+
107
+
108
+ app = gr.Interface(
109
+ title = title,
110
+ description = description,
111
+ label = 'Article content or paragraph',
112
+ fn = clickbait_generator,
113
+ inputs=["text", gr.Slider(0, 100, step=0.1, value=5.0)], outputs="text")
114
  app.launch()