Spaces:
Runtime error
Runtime error
Dusan Svilarkovic
commited on
Commit
•
fc5ecba
1
Parent(s):
7872a22
Adding Fudge
Browse files- naacl-2021-fudge-controlled-generation/LICENSE +21 -0
- naacl-2021-fudge-controlled-generation/README.md +155 -0
- naacl-2021-fudge-controlled-generation/clickbait_classifier.py +128 -0
- naacl-2021-fudge-controlled-generation/constants.py +32 -0
- naacl-2021-fudge-controlled-generation/data.py +415 -0
- naacl-2021-fudge-controlled-generation/eval_formality_metrics.py +73 -0
- naacl-2021-fudge-controlled-generation/eval_poetry_metrics.py +135 -0
- naacl-2021-fudge-controlled-generation/eval_topic_metrics.py +134 -0
- naacl-2021-fudge-controlled-generation/evaluate_clickbait.py +200 -0
- naacl-2021-fudge-controlled-generation/evaluate_formality.py +104 -0
- naacl-2021-fudge-controlled-generation/evaluate_poetry.py +115 -0
- naacl-2021-fudge-controlled-generation/evaluate_topic.py +143 -0
- naacl-2021-fudge-controlled-generation/formality_data/README.md +2 -0
- naacl-2021-fudge-controlled-generation/formality_data/fisher_test_oracle.es +0 -0
- naacl-2021-fudge-controlled-generation/formality_data/test.noid.cleaned_0 +0 -0
- naacl-2021-fudge-controlled-generation/formality_data/test.noid.cleaned_1 +0 -0
- naacl-2021-fudge-controlled-generation/main.py +192 -0
- naacl-2021-fudge-controlled-generation/model.py +182 -0
- naacl-2021-fudge-controlled-generation/poetry_data/README.md +1 -0
- naacl-2021-fudge-controlled-generation/poetry_data/couplet_ends.txt +154 -0
- naacl-2021-fudge-controlled-generation/poetry_data/couplet_prefixes.txt +154 -0
- naacl-2021-fudge-controlled-generation/poetry_util.py +83 -0
- naacl-2021-fudge-controlled-generation/predict_clickbait.py +199 -0
- naacl-2021-fudge-controlled-generation/predict_formality.py +404 -0
- naacl-2021-fudge-controlled-generation/predict_poetry.py +219 -0
- naacl-2021-fudge-controlled-generation/predict_topic.py +126 -0
- naacl-2021-fudge-controlled-generation/requirements.txt +7 -0
- naacl-2021-fudge-controlled-generation/topic_data/README.md +3 -0
- naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/computers.txt +163 -0
- naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/legal.txt +108 -0
- naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/military.txt +136 -0
- naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/politics.txt +40 -0
- naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/religion.txt +207 -0
- naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/science.txt +47 -0
- naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/space.txt +16 -0
- naacl-2021-fudge-controlled-generation/topic_data/topic_prefixes.txt +20 -0
- naacl-2021-fudge-controlled-generation/topic_data/val_wordlists/fantasy.txt +26 -0
- naacl-2021-fudge-controlled-generation/topic_data/wordlists/computers.txt +176 -0
- naacl-2021-fudge-controlled-generation/topic_data/wordlists/legal.txt +131 -0
- naacl-2021-fudge-controlled-generation/topic_data/wordlists/military.txt +149 -0
- naacl-2021-fudge-controlled-generation/topic_data/wordlists/politics.txt +47 -0
- naacl-2021-fudge-controlled-generation/topic_data/wordlists/religion.txt +232 -0
- naacl-2021-fudge-controlled-generation/topic_data/wordlists/science.txt +48 -0
- naacl-2021-fudge-controlled-generation/topic_data/wordlists/space.txt +18 -0
- naacl-2021-fudge-controlled-generation/transcript.txt +415 -0
- naacl-2021-fudge-controlled-generation/util.py +110 -0
naacl-2021-fudge-controlled-generation/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2022 Kevin Yang
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
naacl-2021-fudge-controlled-generation/README.md
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# FUDGE: Controlled Text Generation With Future Discriminators
|
2 |
+
|
3 |
+
This repo contains code corresponding to the paper FUDGE: Controlled Text Generation With Future Discriminators (https://arxiv.org/abs/2104.05218) by Kevin Yang and Dan Klein, published at NAACL 2021.
|
4 |
+
|
5 |
+
You can also find a video presentation at http://somup.com/crhlVPFKN7 and the corresponding slides in `slides.pptx`.
|
6 |
+
|
7 |
+
## Setup/Installation
|
8 |
+
|
9 |
+
We tested on Python 3.8.5 but earlier versions of Python 3 are almost certainly fine. To get the required packages (other versions likely to work too):
|
10 |
+
|
11 |
+
```
|
12 |
+
pip install -r requirements.txt
|
13 |
+
```
|
14 |
+
|
15 |
+
Additionally, to get our pre-trained predictor checkpoints and training data, run:
|
16 |
+
|
17 |
+
```
|
18 |
+
wget https://naacl2021-fudge-files.s3.amazonaws.com/large_files.zip
|
19 |
+
```
|
20 |
+
|
21 |
+
and extract the zip to the top-level `lm-prediction/` folder. (There should be three folders, `ckpt/`, `train_data/`, and `topic_human_evals/`. The zip is 7GB.) Note: the zip seems to not work for some people actually, if this is the case you can get the files directly from https://drive.google.com/drive/folders/1GZfOGqpQxDmIfD2RvuhUQla9eX2OHUXU?usp=sharing (13GB).
|
22 |
+
|
23 |
+
`ckpt/` contains predictor checkpoints for each task if you are just interested in running inference. (Note that for the paper results, we used predictors trained with an older version of the code, but the new checkpoints get similar results, so you are OK to use the new predictors provided here if e.g. you just want to use FUDGE as a baseline. You can just run the evaluation commands provided below; it should take maybe 5-60 minutes depending on the task and your compute, assuming you have a GPU.)
|
24 |
+
|
25 |
+
`train_data/` contains our GPT2-generated training data for the poetry and topic tasks' predictors. See https://github.com/raosudha89/GYAFC-corpus for instructions on gaining access to the GYAFC data used for the machine translation formality task; replace our dummy folders with the corresponding folders/files if you want to train our formality predictor.
|
26 |
+
|
27 |
+
## Clickbait
|
28 |
+
To generate outputs, run:
|
29 |
+
|
30 |
+
```
|
31 |
+
python -u evaluate_clickbait.py --ckpt ckpt/topic/future_word_predictor/model.pth.tar --dataset_info ckpt/topic/future_word_predictor/dataset_info --in_file topic_data/topic_prefixes.txt --condition_lambda 4.0 --verbose --precondition_topk 200 --length_cutoff 80 --device cpu
|
32 |
+
|
33 |
+
python -u evaluate_clickbait.py --ckpt ckpt/formality/predictor_gyafc_entertainment_music/model.pth.tar --dataset_info ckpt/formality/predictor_gyafc_entertainment_music/dataset_info --in_file formality_data/fisher_test_oracle.es
|
34 |
+
|
35 |
+
python -u evaluate_clickbait.py --ckpt ckpt/topic/future_word_predictor/model.pth.tar --dataset_info ckpt/topic/future_word_predictor/dataset_info --in_file topic_data/topic_prefixes.txt --condition_lambda 4.0 --verbose --precondition_topk 200 --sample_size 3 --max_sample_batch 1 --length_cutoff 80 --log_file clickbait_preds.log
|
36 |
+
```
|
37 |
+
|
38 |
+
Then evaluate metrics using:
|
39 |
+
|
40 |
+
```
|
41 |
+
python eval_topic_metrics.py --log_file topic_preds.log --tw_dir topic_data/test_wordlists
|
42 |
+
```
|
43 |
+
|
44 |
+
|
45 |
+
## Poetry Couplet Completion
|
46 |
+
|
47 |
+
### Evaluation
|
48 |
+
|
49 |
+
To generate outputs, run:
|
50 |
+
|
51 |
+
```
|
52 |
+
python -u evaluate_poetry.py --iambic_ckpt ckpt/poetry/iambic_predictor/model.pth.tar --rhyme_ckpt ckpt/poetry/rhyme_predictor/model.pth.tar --newline_ckpt ckpt/poetry/newline_predictor/model.pth.tar --dataset_info ckpt/poetry/rhyme_predictor/dataset_info --rhyme_info ckpt/poetry/rhyme_predictor/rhyme_info --prefix_file poetry_data/couplet_prefixes.txt --precondition_topk 200 > poetry_preds.log
|
53 |
+
```
|
54 |
+
|
55 |
+
Then evaluate metrics using:
|
56 |
+
|
57 |
+
```
|
58 |
+
python eval_poetry_metrics.py --pred_file poetry_preds.log --prefix_file poetry_data/couplet_prefixes.txt
|
59 |
+
```
|
60 |
+
|
61 |
+
### Training your own predictors
|
62 |
+
|
63 |
+
Example commands for all three predictors used in the poetry task below. (You actually probably don't need so many epochs for iambic and rhyme; in any case the commands will save intermediate ckpts so you can just stop them early if needed by inspecting the log.)
|
64 |
+
|
65 |
+
Iambic predictor:
|
66 |
+
|
67 |
+
```
|
68 |
+
python -u main.py --task iambic --data_dir train_data/gpt2_generations --save_dir ckpt/poetry/iambic_retrain_predictor --num_workers 20 --batch_size 128 --epoch_max_len 100000 --validation_freq 10 --lr 2e-4 --epochs 1500 > iambic_retrain_predictor.log
|
69 |
+
```
|
70 |
+
|
71 |
+
Rhyme predictor:
|
72 |
+
|
73 |
+
```
|
74 |
+
python -u main.py --task rhyme --data_dir train_data/gpt2_generations --save_dir ckpt/poetry/rhyme_retrain_predictor --num_workers 20 --batch_size 128 --epoch_max_len 100000 --validation_freq 10 --lr 2e-4 --epochs 1500 > rhyme_retrain_predictor.log
|
75 |
+
```
|
76 |
+
|
77 |
+
End of sentence predictor (referred to as "newline" in the code; 50 epochs is more than enough for this one):
|
78 |
+
|
79 |
+
```
|
80 |
+
python -u main.py --task newline --data_dir train_data/gpt2_generations --save_dir ckpt/poetry/newline_retrain_predictor --num_workers 20 --batch_size 128 --epoch_max_len 100000 --validation_freq 10 --lr 2e-4 --epochs 50 > newline_retrain_predictor.log
|
81 |
+
```
|
82 |
+
|
83 |
+
The same evaluation commands as before will work; just modify the paths in the command to point to `model_best.pth.tar`, `dataset_info`, and `rhyme_info` from your newly trained ckpt folders.
|
84 |
+
|
85 |
+
## Topic Control
|
86 |
+
|
87 |
+
### Evaluation
|
88 |
+
|
89 |
+
To generate outputs, run:
|
90 |
+
|
91 |
+
```
|
92 |
+
python -u evaluate_topic.py --ckpt ckpt/topic/future_word_predictor/model.pth.tar --dataset_info ckpt/topic/future_word_predictor/dataset_info --prefix_file topic_data/topic_prefixes.txt --wordlist_dir topic_data/wordlists --condition_lambda 4.0 --verbose --precondition_topk 200 --topk 10 --sample_size 3 --max_sample_batch 1 --length_cutoff 80 --log_file topic_preds.log
|
93 |
+
```
|
94 |
+
|
95 |
+
Then evaluate metrics using:
|
96 |
+
|
97 |
+
```
|
98 |
+
python eval_topic_metrics.py --log_file topic_preds.log --tw_dir topic_data/test_wordlists
|
99 |
+
```
|
100 |
+
|
101 |
+
You can also find our original generations and baselines in `topic_human_evals/`.
|
102 |
+
|
103 |
+
### Training your own predictors
|
104 |
+
|
105 |
+
Example command below.
|
106 |
+
|
107 |
+
```
|
108 |
+
python -u main.py --task topic --data_dir train_data/gpt2_generations --save_dir ckpt/topic/future_word_retrain_predictor --num_workers 20 --batch_size 128 --epoch_max_len 100000 --validation_freq 10 --lr 2e-4 --epochs 500 --glove_file train_data/glove.840B.300d.txt > future_word_retrain_predictor.log
|
109 |
+
```
|
110 |
+
|
111 |
+
The same evaluation commands as before will work; just modify the paths in the command to point to `model_best.pth.tar`, `dataset_info`, and `rhyme_info` from your newly trained ckpt folders.
|
112 |
+
|
113 |
+
## Machine Translation Formality
|
114 |
+
|
115 |
+
### Evaluation
|
116 |
+
|
117 |
+
To generate outputs, run:
|
118 |
+
|
119 |
+
```
|
120 |
+
python -u evaluate_formality.py --ckpt ckpt/formality/predictor_gyafc_entertainment_music/model.pth.tar --dataset_info ckpt/formality/predictor_gyafc_entertainment_music/dataset_info --in_file formality_data/fisher_test_oracle.es --model_path ckpt/formality/marian_finetune_fisher > formality_preds.log
|
121 |
+
```
|
122 |
+
|
123 |
+
The above command generates predictions using the Marian model finetuned on the Fisher dataset; remove the `--model_path` argument to get predictions with the un-finetuned Marian model from HuggingFace (referred to as 0-shot in the paper)
|
124 |
+
|
125 |
+
Then evaluate metrics using:
|
126 |
+
|
127 |
+
```
|
128 |
+
python eval_formality_metrics.py --pred formality_preds.log --ref formality_data/test.noid.cleaned_0 formality_data/test.noid.cleaned_1 --ckpt ckpt/formality/test_evaluator_gyafc_family_relationships/model.pth.tar --dataset_info ckpt/formality/test_evaluator_gyafc_family_relationships/dataset_info
|
129 |
+
```
|
130 |
+
|
131 |
+
### Training your own predictors
|
132 |
+
|
133 |
+
Example command below. (Reminder: you need to go get the GYAFC dataset following the instructions in https://github.com/raosudha89/GYAFC-corpus.)
|
134 |
+
|
135 |
+
```
|
136 |
+
python -u main.py --task formality --data_dir train_data/GYAFC_Corpus/Entertainment_Music --save_dir ckpt/formality/formality_retrain_predictor --num_workers 20 --batch_size 32 --epoch_max_len 1000000 --validation_freq 1 --lr 2e-5 --epochs 20 > formality_retrain_predictor.log
|
137 |
+
```
|
138 |
+
|
139 |
+
(The test-time formality evaluator is trained in the same way, just using the Family/Relationships half of the GYAFC dataset.)
|
140 |
+
|
141 |
+
The same evaluation commands as before will work; just modify the paths in the command to point to `model_best.pth.tar`, `dataset_info`, and `rhyme_info` from your newly trained ckpt folders.
|
142 |
+
|
143 |
+
## Running FUDGE on your own data
|
144 |
+
|
145 |
+
The code has been refactored so that the iambic (poetry), rhyme (poetry), newline (poetry), future word (topic), and formality (machine translation) are controlled by the `--task` flag to `main.py`. You should add your task as another option here, then modify the data processing in `data.py` and the model in `model.py` as needed for your task. (In `data.py` you probably won't need all the entries of the tuple that is expected of the loader; you can just put dummy entries in the ones you don't need.) You might also need to modify the loss computation in the `train` and `validate` functions in `main.py`. You'll probably want to write new evaluation scripts, though the existing poetry/topic/formality ones are hopefully helpful as references.
|
146 |
+
|
147 |
+
Alternatively, the general FUDGE framework is pretty simple, so you could always try reimplementing things yourself. A few additional details based on questions I've received:
|
148 |
+
|
149 |
+
(1) The formality task setup is likely closest to what you want if you're just trying to run the simplest form of FUDGE (take a language model, and use a classifier to optimize toward a single attribute) although you may need to swap out the Marian translation model/tokenizer we use.
|
150 |
+
|
151 |
+
(2) When you construct your training data, if you have an example in your data e.g. "This movie is great!" for positive sentiment, you want to learn on all the pairs (This, +), (This movie, +), (This movie is, +), etc., as that's one of the main points of our approach.
|
152 |
+
|
153 |
+
(3) For computational efficiency, we first filter the base model's next token probabilities down to the top 200 (Sec. 3.1 in the paper), before adding the classifier logits. This way you only need to evaluate your classifier on 200 continuations. Then afterward, you filter down again to whatever top-k/greedy/nucleus sampling you're using for evaluation (we use top-k with k=10 for poetry and topic, greedy for formality).
|
154 |
+
|
155 |
+
(4) You can use a pretrained LM backbone instead of a simple LSTM backbone for the predictor as well. This should work better when your dataset is smaller.
|
naacl-2021-fudge-controlled-generation/clickbait_classifier.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import BertModel, BertConfig, PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig
|
3 |
+
from typing import List, Optional, Tuple, Union
|
4 |
+
from transformers.modeling_outputs import TokenClassifierOutput,SequenceClassifierOutput
|
5 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss, BCELoss
|
6 |
+
import torch.nn as nn
|
7 |
+
# from modeling_mpnet import MPNetModel, MPnetConfig
|
8 |
+
|
9 |
+
class ClickbaitConfig(PretrainedConfig):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
model_type: str = "bert",
|
13 |
+
pretrained_model: str = "bert-base-uncased",
|
14 |
+
num_labels: int = 1,
|
15 |
+
dropout: float = 0.1,
|
16 |
+
inner_dim1: int = 256,
|
17 |
+
inner_dim2: int = 32,
|
18 |
+
max_length: int = 512,
|
19 |
+
load_pretrained: bool = True,
|
20 |
+
freeze_bert: bool = True,
|
21 |
+
**kwargs
|
22 |
+
):
|
23 |
+
super(ClickbaitConfig, self).__init__(num_labels=num_labels, **kwargs)
|
24 |
+
self.model_type = model_type
|
25 |
+
self.pretrained_model = pretrained_model
|
26 |
+
self.dropout = dropout
|
27 |
+
self.inner_dim1 = inner_dim1
|
28 |
+
self.inner_dim2 = inner_dim2
|
29 |
+
self.max_length = max_length
|
30 |
+
self.load_pretrained = load_pretrained
|
31 |
+
self.freeze_bert = freeze_bert
|
32 |
+
|
33 |
+
|
34 |
+
class BertClickbaitClassifier(PreTrainedModel):
|
35 |
+
"""
|
36 |
+
Taken and extended from BertforSequenceClassification : https://github.com/huggingface/transformers/blob/v4.19.2/src/transformers/models/bert/modeling_bert.py#L1508
|
37 |
+
"""
|
38 |
+
config_class = ClickbaitConfig
|
39 |
+
def __init__(self, config: ClickbaitConfig):
|
40 |
+
super(BertClickbaitClassifier, self).__init__(config)
|
41 |
+
self.num_labels = config.num_labels
|
42 |
+
self.config = config
|
43 |
+
# self.bert_config = BertConfig.from_pretrained(config.pretrained_model)
|
44 |
+
self.bert_config = AutoConfig.from_pretrained(config.pretrained_model)
|
45 |
+
|
46 |
+
# self.bert = BertModel(self.bert_config)
|
47 |
+
self.bert = AutoModel.from_pretrained(config.pretrained_model, config=self.bert_config)
|
48 |
+
# self.bert = SentenceTransformer(config.pretrained_model, config=self.bert_config)
|
49 |
+
# self.bert = MPNetModel(config.pretrained_model, config=self.bert_config)
|
50 |
+
if config.load_pretrained:
|
51 |
+
print("Load pretrained weights from {}".format(config.pretrained_model))
|
52 |
+
self.bert = self.bert.from_pretrained(config.pretrained_model)
|
53 |
+
if config.freeze_bert:
|
54 |
+
print("Freeze weights in the BERT model. Just the classifier will be trained")
|
55 |
+
for param in self.bert.parameters():
|
56 |
+
param.requires_grad = False
|
57 |
+
|
58 |
+
self.linear_1 = nn.Linear(self.bert.config.hidden_size, config.inner_dim1)
|
59 |
+
self.dropout_1 = nn.Dropout(config.dropout)
|
60 |
+
self.relu_1 = nn.ReLU()
|
61 |
+
self.dropout_2 = nn.Dropout(config.dropout)
|
62 |
+
self.linear_2 = nn.Linear(config.inner_dim1, config.inner_dim2)
|
63 |
+
self.relu_2 = nn.ReLU()
|
64 |
+
self.dropout_3 = nn.Dropout(config.dropout)
|
65 |
+
self.classifier = nn.Linear(config.inner_dim2, config.num_labels)
|
66 |
+
self.sigmoid = nn.Sigmoid()
|
67 |
+
|
68 |
+
|
69 |
+
def forward(
|
70 |
+
self,
|
71 |
+
input_ids: Optional[torch.Tensor] = None,
|
72 |
+
attention_mask: Optional[torch.Tensor] = None,
|
73 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
74 |
+
position_ids: Optional[torch.Tensor] = None,
|
75 |
+
head_mask: Optional[torch.Tensor] = None,
|
76 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
77 |
+
labels: Optional[torch.Tensor] = None,
|
78 |
+
output_attentions: Optional[bool] = None,
|
79 |
+
output_hidden_states: Optional[bool] = None,
|
80 |
+
return_dict: Optional[bool] = None,
|
81 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
82 |
+
r"""
|
83 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
84 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
85 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
86 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
87 |
+
"""
|
88 |
+
|
89 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
90 |
+
|
91 |
+
outputs = self.bert(
|
92 |
+
input_ids,
|
93 |
+
attention_mask=attention_mask,
|
94 |
+
token_type_ids=token_type_ids,
|
95 |
+
position_ids=position_ids,
|
96 |
+
head_mask=head_mask,
|
97 |
+
inputs_embeds=inputs_embeds,
|
98 |
+
output_attentions=output_attentions,
|
99 |
+
output_hidden_states=output_hidden_states,
|
100 |
+
return_dict=return_dict,
|
101 |
+
)
|
102 |
+
|
103 |
+
output = outputs[0][:,0,:]
|
104 |
+
|
105 |
+
x = self.dropout_1(output)
|
106 |
+
x = self.linear_1(x)
|
107 |
+
x = self.relu_1(x)
|
108 |
+
x = self.dropout_2(x)
|
109 |
+
x = self.linear_2(x)
|
110 |
+
x = self.relu_2(x)
|
111 |
+
x = self.dropout_3(x)
|
112 |
+
|
113 |
+
logits = self.classifier(x)
|
114 |
+
logits = self.sigmoid(logits)
|
115 |
+
|
116 |
+
loss = None
|
117 |
+
if labels is not None:
|
118 |
+
loss_fct = BCELoss(weight=WEIGHT)
|
119 |
+
labels = 1.0*labels
|
120 |
+
loss = loss_fct(logits.view(-1), labels.view(-1))
|
121 |
+
if not return_dict:
|
122 |
+
output = (logits,) + outputs[2:]
|
123 |
+
return ((loss,) + output) if loss is not None else output
|
124 |
+
|
125 |
+
return SequenceClassifierOutput(
|
126 |
+
loss=loss,
|
127 |
+
logits=logits
|
128 |
+
)
|
naacl-2021-fudge-controlled-generation/constants.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
PAD_TOKEN = '[PAD]'
|
2 |
+
EOT_TOKEN = '<|endoftext|>'
|
3 |
+
SEP = 50256 # just use the weird eot token
|
4 |
+
|
5 |
+
TOPIC_MODEL_STRING = 'gpt2-medium'
|
6 |
+
FORMALITY_MODEL_STRING = 'Helsinki-NLP/opus-mt-es-en'
|
7 |
+
|
8 |
+
DIR_END_SPLIT_POSITIONS = 32
|
9 |
+
|
10 |
+
TOPIC_VAL_SIZE = 100000
|
11 |
+
FORMALITY_VAL_SIZE = 2000
|
12 |
+
VOCAB_SIZE = 50000
|
13 |
+
|
14 |
+
FORMALITY_MAX_LEN = 200
|
15 |
+
|
16 |
+
GLOVE_PRINT_PROGRESS_FREQ = 1000000
|
17 |
+
GLOVE_DIM = 300
|
18 |
+
HIDDEN_DIM = 300
|
19 |
+
RNN_DIM = 150
|
20 |
+
|
21 |
+
MIN_SENTENCE_LENGTH = 3
|
22 |
+
|
23 |
+
POETRY_LINE_SYLLABLES = 10
|
24 |
+
MAX_SYLLABLES_PER_WORD = 10 # no way anything is more
|
25 |
+
MAX_COUNT_SYLLABLE_DIST = 10
|
26 |
+
MAX_COUNT_SYLLABLE_INPUT_LENGTH = 25 # for just a couplet, shouldn't need more
|
27 |
+
COUNT_SYLLABLE_DIM = 100
|
28 |
+
UNKNOWN_RHYME_GROUP = 'UNKNOWN_RHYME_GROUP'
|
29 |
+
PHRASE_ENDS = '.?!'
|
30 |
+
|
31 |
+
POETRY_BANNED_TOKENS = [198, 50256, 628, 220] # newlines and eos and such
|
32 |
+
|
naacl-2021-fudge-controlled-generation/data.py
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
import pickle
|
5 |
+
from collections import defaultdict, namedtuple
|
6 |
+
import string
|
7 |
+
|
8 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'false' # turn off since we're using multiple threads for loading anyway
|
9 |
+
|
10 |
+
from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model
|
11 |
+
import numpy as np
|
12 |
+
from tqdm import tqdm
|
13 |
+
import torch
|
14 |
+
|
15 |
+
from util import suppress_stdout
|
16 |
+
from poetry_util import is_iambic, count_syllables, get_rhymes, get_rhyme_group
|
17 |
+
from constants import *
|
18 |
+
|
19 |
+
DatasetInfo = namedtuple('DatasetInfo',
|
20 |
+
['index2word', 'word2index', 'total_words', 'vocab', 'glove_embeddings'])
|
21 |
+
RhymeInfo = namedtuple('RhymeInfo',
|
22 |
+
['word2rhyme_group', 'rhyme_group_counts', 'rhyme_groups', 'index2rhyme_group', 'rhyme_group2index', 'total_rhyme_groups'])
|
23 |
+
|
24 |
+
def collate(batch):
|
25 |
+
pad_id = batch[0][4]
|
26 |
+
inputs = [b[0] for b in batch]
|
27 |
+
lengths = torch.LongTensor([b[1] for b in batch])
|
28 |
+
max_length = lengths.max()
|
29 |
+
for i in range(len(inputs)):
|
30 |
+
if len(inputs[i]) < max_length:
|
31 |
+
inputs[i] = torch.cat([inputs[i], torch.zeros(max_length - len(inputs[i])).long()], dim=0) # actually 0 is fine as pad since it's masked out
|
32 |
+
inputs = torch.stack(inputs, dim=0)
|
33 |
+
future_words = torch.LongTensor([b[2] for b in batch]).unsqueeze(0).expand(len(batch), -1).clone() # batch x N=batch
|
34 |
+
labels = torch.zeros_like(future_words).long()
|
35 |
+
labels = labels.scatter(1, torch.arange(len(batch)).unsqueeze(1), torch.ones(len(batch)).long().unsqueeze(1)).clone()
|
36 |
+
log_probs = torch.Tensor([b[3] for b in batch])
|
37 |
+
classification_labels = [b[5] for b in batch] # batch
|
38 |
+
if type(classification_labels[0]) == list:
|
39 |
+
for i in range(len(classification_labels)):
|
40 |
+
assert len(classification_labels[i]) == lengths[i]
|
41 |
+
if len(classification_labels[i]) < max_length:
|
42 |
+
classification_labels[i] = torch.cat([torch.LongTensor(classification_labels[i]), -1 + torch.zeros(max_length - len(classification_labels[i])).long()], dim=0)
|
43 |
+
else:
|
44 |
+
classification_labels[i] = torch.LongTensor(classification_labels[i])
|
45 |
+
classification_labels = torch.stack(classification_labels, dim=0) # batch x seq
|
46 |
+
else:
|
47 |
+
assert type(classification_labels[0]) == int
|
48 |
+
classification_labels = torch.LongTensor(classification_labels) # they're just int labels
|
49 |
+
syllables_to_go = torch.LongTensor([b[6] for b in batch])
|
50 |
+
future_word_num_syllables = torch.LongTensor([b[7] for b in batch])
|
51 |
+
rhyme_group_index = torch.LongTensor([b[8] for b in batch])
|
52 |
+
return (inputs, lengths, future_words, log_probs, labels, classification_labels, syllables_to_go, future_word_num_syllables, rhyme_group_index)
|
53 |
+
|
54 |
+
|
55 |
+
def load_rhyme_info(index2word, vocab):
|
56 |
+
word2rhyme_group = defaultdict(lambda: UNKNOWN_RHYME_GROUP)
|
57 |
+
rhyme_group_counts = defaultdict(lambda: 0)
|
58 |
+
rhyme_groups = set()
|
59 |
+
for word in index2word:
|
60 |
+
try:
|
61 |
+
rhyme_group = get_rhyme_group(word)
|
62 |
+
word2rhyme_group[word] = rhyme_group
|
63 |
+
rhyme_group_counts[rhyme_group] += (vocab[word] if word in vocab else 1) # for rare words not in vocab, just use 1
|
64 |
+
rhyme_groups.add(rhyme_group)
|
65 |
+
except:
|
66 |
+
rhyme_group_counts[UNKNOWN_RHYME_GROUP] += (vocab[word] if word in vocab else 1)
|
67 |
+
index2rhyme_group = [UNKNOWN_RHYME_GROUP] + sorted(list(rhyme_groups))
|
68 |
+
rhyme_group2index = {s: i for i, s in enumerate(index2rhyme_group)}
|
69 |
+
total_rhyme_groups = sum(rhyme_group_counts.values())
|
70 |
+
|
71 |
+
return RhymeInfo(word2rhyme_group=dict(word2rhyme_group),
|
72 |
+
rhyme_group_counts=dict(rhyme_group_counts),
|
73 |
+
rhyme_groups=rhyme_groups,
|
74 |
+
index2rhyme_group=index2rhyme_group,
|
75 |
+
rhyme_group2index=rhyme_group2index,
|
76 |
+
total_rhyme_groups=total_rhyme_groups)
|
77 |
+
|
78 |
+
|
79 |
+
class Dataset:
|
80 |
+
def __init__(self, args):
|
81 |
+
print('loading data')
|
82 |
+
random.seed(args.seed)
|
83 |
+
self.batch_size = args.batch_size
|
84 |
+
self.data_dir = args.data_dir
|
85 |
+
self.topic = args.task == 'topic'
|
86 |
+
self.formality = args.task == 'formality'
|
87 |
+
self.iambic = args.task == 'iambic'
|
88 |
+
self.rhyme = args.task == 'rhyme'
|
89 |
+
self.newline = args.task == 'newline'
|
90 |
+
|
91 |
+
self.tokenizer = AutoTokenizer.from_pretrained(FORMALITY_MODEL_STRING if self.formality else TOPIC_MODEL_STRING)
|
92 |
+
self.tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
|
93 |
+
self.gpt_pad_id = self.tokenizer.encode(PAD_TOKEN)[0] # actually just the vocab size
|
94 |
+
sentences = []
|
95 |
+
self.vocab = defaultdict(lambda: 0)
|
96 |
+
if self.formality:
|
97 |
+
self.vocab['placeholder'] = 1 # anything so we don't crash
|
98 |
+
train, val, test = [], [], []
|
99 |
+
for category, label in [('formal', 1), ('informal', 0)]:
|
100 |
+
with open(os.path.join(args.data_dir, 'train', category), 'r') as rf:
|
101 |
+
for i, line in enumerate(rf):
|
102 |
+
if len(line) > FORMALITY_MAX_LEN:
|
103 |
+
line = ' '.join(line.strip()[:FORMALITY_MAX_LEN].split()[:-1]) # cutoff words until below max len; chosen so only ~20 examples affected in dataset
|
104 |
+
if i < FORMALITY_VAL_SIZE // 2:
|
105 |
+
val.append((line.strip(), label))
|
106 |
+
else:
|
107 |
+
train.append((line.strip(), label))
|
108 |
+
with open(os.path.join(args.data_dir, 'test', category), 'r') as rf:
|
109 |
+
for line in rf:
|
110 |
+
if len(line) > FORMALITY_MAX_LEN:
|
111 |
+
line = ' '.join(line.strip()[:FORMALITY_MAX_LEN].split()[:-1]) # cutoff words until below max len
|
112 |
+
test.append((line.strip(), label))
|
113 |
+
self.splits = {}
|
114 |
+
self.splits['train'], self.splits['val'], self.splits['test'] = train, val, test
|
115 |
+
else: # topic / poetry
|
116 |
+
for root, _, filenames in os.walk(args.data_dir):
|
117 |
+
for fname in filenames:
|
118 |
+
with open(os.path.join(root, fname), 'r') as rf:
|
119 |
+
for line in rf:
|
120 |
+
sentences.append(line.strip())
|
121 |
+
for word in line.strip().split(' '):
|
122 |
+
self.vocab[word] += 1
|
123 |
+
random.shuffle(sentences)
|
124 |
+
self.splits = {}
|
125 |
+
if args.debug:
|
126 |
+
self.splits['val'] = sentences
|
127 |
+
self.splits['test'] = sentences
|
128 |
+
self.splits['train'] = sentences
|
129 |
+
else:
|
130 |
+
self.splits['val'] = sentences[:TOPIC_VAL_SIZE]
|
131 |
+
self.splits['test'] = sentences[TOPIC_VAL_SIZE:2*TOPIC_VAL_SIZE]
|
132 |
+
self.splits['train'] = sentences[2*TOPIC_VAL_SIZE:]
|
133 |
+
|
134 |
+
if args.dataset_info is not None:
|
135 |
+
print('loading dataset info from file')
|
136 |
+
with open(args.dataset_info, 'rb') as rf:
|
137 |
+
dataset_info = pickle.load(rf)
|
138 |
+
self.vocab, self.total_words, self.index2word, self.word2index, self.glove_embeddings = \
|
139 |
+
dataset_info.vocab, dataset_info.total_words, dataset_info.index2word, dataset_info.word2index, dataset_info.glove_embeddings
|
140 |
+
self.dataset_info = dataset_info
|
141 |
+
else:
|
142 |
+
print('generating dataset info from scratch')
|
143 |
+
words_values = list(self.vocab.items())
|
144 |
+
words_values = sorted(words_values, key=lambda x: x[1], reverse=True)
|
145 |
+
if args.glove_file is None:
|
146 |
+
print('no glove embeddings given')
|
147 |
+
for word, _ in words_values[VOCAB_SIZE:]: # only use somewhat common tokens
|
148 |
+
del self.vocab[word]
|
149 |
+
glove_embeddings = None
|
150 |
+
else:
|
151 |
+
print('loading glove embeddings')
|
152 |
+
glove_embeddings = {}
|
153 |
+
with open(args.glove_file, 'r') as rf:
|
154 |
+
for i, line in enumerate(rf):
|
155 |
+
if i % GLOVE_PRINT_PROGRESS_FREQ == 0:
|
156 |
+
print(i)
|
157 |
+
line = line.strip().split()
|
158 |
+
if len(line) != GLOVE_DIM + 1:
|
159 |
+
continue # skip multi-word embeddings which are rare anyway
|
160 |
+
glove_embeddings[line[0]] = [float(x) for x in line[1:]]
|
161 |
+
for word, _ in words_values:
|
162 |
+
if word not in glove_embeddings:
|
163 |
+
del self.vocab[word]
|
164 |
+
self.total_words = sum(self.vocab.values())
|
165 |
+
self.index2word = [PAD_TOKEN] + sorted(list(self.vocab.keys()))
|
166 |
+
self.word2index = {s: i for i, s in enumerate(self.index2word)}
|
167 |
+
self.vocab = dict(self.vocab) # so we can pickle later
|
168 |
+
if glove_embeddings is None:
|
169 |
+
self.glove_embeddings = None
|
170 |
+
else:
|
171 |
+
self.glove_embeddings = torch.stack([torch.zeros(GLOVE_DIM)] + [torch.Tensor(glove_embeddings[word]) for word in self.index2word[1:]], dim=0)
|
172 |
+
|
173 |
+
self.dataset_info = DatasetInfo(index2word=self.index2word,
|
174 |
+
word2index=self.word2index,
|
175 |
+
total_words=self.total_words,
|
176 |
+
vocab=self.vocab,
|
177 |
+
glove_embeddings=self.glove_embeddings)
|
178 |
+
|
179 |
+
if self.rhyme:
|
180 |
+
if args.rhyme_info is not None:
|
181 |
+
print('loading rhyme info from file')
|
182 |
+
with open(args.rhyme_info, 'rb') as rf:
|
183 |
+
self.rhyme_info = pickle.load(rf)
|
184 |
+
else:
|
185 |
+
self.rhyme_info = load_rhyme_info(self.index2word, self.vocab)
|
186 |
+
self.word2rhyme_group, self.rhyme_group_counts, self.rhyme_groups, self.index2rhyme_group, self.rhyme_group2index, self.total_rhyme_groups = \
|
187 |
+
defaultdict(lambda: UNKNOWN_RHYME_GROUP, self.rhyme_info.word2rhyme_group), self.rhyme_info.rhyme_group_counts, self.rhyme_info.rhyme_groups, self.rhyme_info.index2rhyme_group, self.rhyme_info.rhyme_group2index, self.rhyme_info.total_rhyme_groups
|
188 |
+
|
189 |
+
print('done loading data')
|
190 |
+
print('split sizes:')
|
191 |
+
for key in ['train', 'val', 'test']:
|
192 |
+
print(key, len(self.splits[key]))
|
193 |
+
if not self.formality:
|
194 |
+
print('total words', self.total_words)
|
195 |
+
print('vocab size', len(self.index2word))
|
196 |
+
|
197 |
+
|
198 |
+
def shuffle(self, split, seed=None):
|
199 |
+
assert split in ['train', 'val', 'test']
|
200 |
+
if seed is not None:
|
201 |
+
random.seed(seed)
|
202 |
+
random.shuffle(self.splits[split])
|
203 |
+
|
204 |
+
|
205 |
+
def loader(self, split, num_workers=20, indices=None):
|
206 |
+
assert split in ['train', 'val', 'test']
|
207 |
+
data = self.splits[split] if indices is None else [self.splits[split][i] for i in indices]
|
208 |
+
return torch.utils.data.DataLoader(SplitLoader(data, self), batch_size=self.batch_size, pin_memory=True, collate_fn=collate, num_workers=num_workers)
|
209 |
+
|
210 |
+
|
211 |
+
class SplitLoader(torch.utils.data.IterableDataset):
|
212 |
+
def __init__(self, data, parent):
|
213 |
+
super(SplitLoader).__init__()
|
214 |
+
self.data = data
|
215 |
+
self.pos = 0
|
216 |
+
self.parent = parent
|
217 |
+
|
218 |
+
|
219 |
+
def __len__(self):
|
220 |
+
return len(self.data)
|
221 |
+
|
222 |
+
|
223 |
+
def __iter__(self):
|
224 |
+
return self
|
225 |
+
|
226 |
+
|
227 |
+
def __next__(self):
|
228 |
+
increment = 1
|
229 |
+
worker_info = torch.utils.data.get_worker_info()
|
230 |
+
if worker_info is not None: # # in a worker process
|
231 |
+
increment = worker_info.num_workers
|
232 |
+
worker_id = worker_info.id
|
233 |
+
if self.pos == 0:
|
234 |
+
self.pos = worker_id
|
235 |
+
valid = False
|
236 |
+
while not valid:
|
237 |
+
if self.pos >= len(self):
|
238 |
+
raise StopIteration
|
239 |
+
if self.parent.topic:
|
240 |
+
failed = False
|
241 |
+
future_word_num_syllables, rhyme_group_index, syllables_to_go = -1, -1, -1
|
242 |
+
raw_sentence, classification_label = self.data[self.pos], -1
|
243 |
+
original_sentence = raw_sentence.split()
|
244 |
+
sentence = self.parent.tokenizer.encode(raw_sentence, return_tensors='pt')[0]
|
245 |
+
length = len(sentence)
|
246 |
+
min_sentence_length = MIN_SENTENCE_LENGTH
|
247 |
+
if len(sentence) > min_sentence_length: # set to 3. well, everything in data is > 3 for the bag of words task
|
248 |
+
pos_to_split = random.randint(1, length - 1) # for lm, learn all positions at once
|
249 |
+
inp = sentence[:pos_to_split]
|
250 |
+
length = len(inp)
|
251 |
+
num_words_in_input = len(self.parent.tokenizer.decode(inp).split())
|
252 |
+
if not failed and num_words_in_input < len(original_sentence):
|
253 |
+
future_word_position_max = len(original_sentence) - 1
|
254 |
+
future_word_position = random.randint(num_words_in_input-1, future_word_position_max) # allow the last possibly partial word though
|
255 |
+
future_word = original_sentence[future_word_position]
|
256 |
+
unstripped_future_word = future_word
|
257 |
+
future_word = future_word.strip().strip(string.punctuation) # NOTE: we didn't strip punctuation for the topic bag of words paper experiments for our method. it doesn't make much difference, though.
|
258 |
+
if not failed and future_word in self.parent.word2index.keys():
|
259 |
+
word_log_prob = math.log(self.parent.vocab[future_word] / self.parent.total_words) # roughly baseline prob of word under noise model
|
260 |
+
future_word = self.parent.word2index[future_word]
|
261 |
+
pad_id = self.parent.gpt_pad_id
|
262 |
+
example = (inp, length, future_word, word_log_prob, pad_id, classification_label, syllables_to_go, future_word_num_syllables, rhyme_group_index)
|
263 |
+
valid = not failed
|
264 |
+
elif self.parent.formality:
|
265 |
+
future_word_num_syllables, rhyme_group_index, syllables_to_go = -1, -1, -1
|
266 |
+
raw_sentence, classification_label = self.data[self.pos]
|
267 |
+
original_sentence = raw_sentence.split()
|
268 |
+
sentence = self.parent.tokenizer.encode(raw_sentence, return_tensors='pt')[0]
|
269 |
+
length = len(sentence)
|
270 |
+
min_sentence_length = MIN_SENTENCE_LENGTH
|
271 |
+
if len(sentence) > min_sentence_length: # set to 3. well, everything in data is > 3 for the bag of words task
|
272 |
+
pos_to_split = length # no need to split; we're going to train on all possible prefixes simultaneously for efficiency
|
273 |
+
inp = sentence[:pos_to_split]
|
274 |
+
length = len(inp)
|
275 |
+
num_words_in_input = len(self.parent.tokenizer.decode(inp).split())
|
276 |
+
# only look up to 10 words ahead if we're doing count syllables, since we'll filter out anything more than 10 syllables ahead anyway
|
277 |
+
future_word_position_max = len(original_sentence) - 1
|
278 |
+
future_word_position = 0
|
279 |
+
future_word = 'placeholder'
|
280 |
+
unstripped_future_word = future_word
|
281 |
+
future_word = future_word.strip().strip(string.punctuation) # NOTE: we didn't strip punctuation for the topic bag of words paper experiments for our method. it doesn't make much difference, though.
|
282 |
+
word_log_prob, future_word = 0, 0
|
283 |
+
pad_id = self.parent.gpt_pad_id
|
284 |
+
example = (inp, length, future_word, word_log_prob, pad_id, classification_label, syllables_to_go, future_word_num_syllables, rhyme_group_index)
|
285 |
+
valid = True
|
286 |
+
elif self.parent.iambic:
|
287 |
+
failed = False
|
288 |
+
future_word_num_syllables, rhyme_group_index, syllables_to_go = -1, -1, -1
|
289 |
+
raw_sentence, classification_label = self.data[self.pos], -1
|
290 |
+
original_sentence = raw_sentence.split()
|
291 |
+
sentence = self.parent.tokenizer.encode(raw_sentence, return_tensors='pt')[0]
|
292 |
+
length = len(sentence)
|
293 |
+
min_sentence_length = MIN_SENTENCE_LENGTH
|
294 |
+
if len(sentence) > min_sentence_length: # set to 3. well, everything in data is > 3 for the bag of words task
|
295 |
+
pos_to_split = random.randint(0, length - 1)
|
296 |
+
# try to get a subseq of exactly 10 syllables
|
297 |
+
inp = sentence[pos_to_split:]
|
298 |
+
num_syllables = 0
|
299 |
+
checked = False
|
300 |
+
for i in range(1, len(inp)):
|
301 |
+
decoded = self.parent.tokenizer.decode(inp[:i])
|
302 |
+
num_syllables = count_syllables(decoded)
|
303 |
+
if num_syllables > POETRY_LINE_SYLLABLES:
|
304 |
+
inp = inp[:i-1] # might get a few data points where the split is in the middle of a word, but it should be ok for learning.
|
305 |
+
last_line_length = i-1
|
306 |
+
decoded = self.parent.tokenizer.decode(inp)
|
307 |
+
num_syllables = count_syllables(decoded)
|
308 |
+
checked = True
|
309 |
+
break
|
310 |
+
if not checked or num_syllables != POETRY_LINE_SYLLABLES:
|
311 |
+
failed = True
|
312 |
+
length = len(inp)
|
313 |
+
num_words_in_input = len(self.parent.tokenizer.decode(inp).split())
|
314 |
+
classification_label = [is_iambic(self.parent.tokenizer.decode(inp)) for _ in range(length)] # predict for whole seq including future
|
315 |
+
# only look up to 10 words ahead if we're doing count syllables, since we'll filter out anything more than 10 syllables ahead anyway
|
316 |
+
future_word_position_max = len(original_sentence) - 1
|
317 |
+
future_word_position = 0
|
318 |
+
future_word = 'placeholder'
|
319 |
+
unstripped_future_word = future_word
|
320 |
+
future_word = future_word.strip().strip(string.punctuation) # NOTE: we didn't strip punctuation for the topic bag of words paper experiments for our method. it doesn't make much difference, though.
|
321 |
+
if not failed:
|
322 |
+
word_log_prob, future_word = 0, 0
|
323 |
+
pad_id = self.parent.gpt_pad_id
|
324 |
+
example = (inp, length, future_word, word_log_prob, pad_id, classification_label, syllables_to_go, future_word_num_syllables, rhyme_group_index)
|
325 |
+
valid = not failed
|
326 |
+
elif self.parent.rhyme:
|
327 |
+
failed = False
|
328 |
+
future_word_num_syllables, rhyme_group_index = -1, -1
|
329 |
+
raw_sentence, classification_label = self.data[self.pos], -1
|
330 |
+
original_sentence = raw_sentence.split()
|
331 |
+
sentence = self.parent.tokenizer.encode(raw_sentence, return_tensors='pt')[0]
|
332 |
+
length = len(sentence)
|
333 |
+
min_sentence_length = MIN_SENTENCE_LENGTH
|
334 |
+
if len(sentence) > min_sentence_length: # set to 3. well, everything in data is > 3 for the bag of words task
|
335 |
+
pos_to_split = random.randint(1, length - 1) # for lm, learn all positions at once
|
336 |
+
inp = sentence[:pos_to_split]
|
337 |
+
length = len(inp)
|
338 |
+
num_words_in_input = len(self.parent.tokenizer.decode(inp).split())
|
339 |
+
if not failed and num_words_in_input < len(original_sentence):
|
340 |
+
# only look up to 10 words ahead if we're doing count syllables, since we'll filter out anything more than 10 syllables ahead anyway
|
341 |
+
future_word_position_max = min(len(original_sentence) - 1, num_words_in_input + MAX_COUNT_SYLLABLE_DIST)
|
342 |
+
future_word_position = random.randint(num_words_in_input-1, future_word_position_max) # allow the last possibly partial word though
|
343 |
+
future_word = original_sentence[future_word_position]
|
344 |
+
unstripped_future_word = future_word
|
345 |
+
future_word = future_word.strip().strip(string.punctuation) # NOTE: we didn't strip punctuation for the topic bag of words paper experiments for our method. it doesn't make much difference, though.
|
346 |
+
|
347 |
+
words_in_between = original_sentence[num_words_in_input-1:future_word_position+1]
|
348 |
+
syllables_to_go = count_syllables(' '.join(words_in_between))
|
349 |
+
if syllables_to_go > MAX_COUNT_SYLLABLE_DIST:
|
350 |
+
failed = True
|
351 |
+
future_word_num_syllables = count_syllables(future_word)
|
352 |
+
rhyme_group = self.parent.word2rhyme_group[future_word]
|
353 |
+
rhyme_group_index = self.parent.rhyme_group2index[rhyme_group]
|
354 |
+
# truncate context a bit since we're just doing couplets. random length from 1 to max desired length for this purpose.
|
355 |
+
desired_length = random.randint(1, MAX_COUNT_SYLLABLE_INPUT_LENGTH)
|
356 |
+
inp = inp[-desired_length:]
|
357 |
+
length = len(inp)
|
358 |
+
|
359 |
+
if not failed and future_word in self.parent.word2index.keys():
|
360 |
+
word_log_prob = math.log(self.parent.rhyme_group_counts[rhyme_group] / self.parent.total_rhyme_groups)
|
361 |
+
future_word = rhyme_group_index # future conditioning is just the rhyme group in this case
|
362 |
+
pad_id = self.parent.gpt_pad_id
|
363 |
+
example = (inp, length, future_word, word_log_prob, pad_id, classification_label, syllables_to_go, future_word_num_syllables, rhyme_group_index)
|
364 |
+
valid = not failed
|
365 |
+
elif self.parent.newline:
|
366 |
+
failed = False
|
367 |
+
future_word_num_syllables, rhyme_group_index = -1, -1
|
368 |
+
raw_sentence, classification_label = self.data[self.pos], -1
|
369 |
+
original_sentence = raw_sentence.split()
|
370 |
+
sentence = self.parent.tokenizer.encode(raw_sentence, return_tensors='pt')[0]
|
371 |
+
length = len(sentence)
|
372 |
+
min_sentence_length = MIN_SENTENCE_LENGTH
|
373 |
+
if len(sentence) > min_sentence_length: # set to 3. well, everything in data is > 3 for the bag of words task
|
374 |
+
pos_to_split = random.randint(1, length - 1) # for lm, learn all positions at once
|
375 |
+
inp = sentence[:pos_to_split]
|
376 |
+
while pos_to_split < len(sentence):
|
377 |
+
if len(self.parent.tokenizer.decode(inp).split()) == len(self.parent.tokenizer.decode(sentence[:pos_to_split + 1]).split()):
|
378 |
+
pos_to_split += 1
|
379 |
+
inp = sentence[:pos_to_split]
|
380 |
+
else:
|
381 |
+
break
|
382 |
+
length = len(inp)
|
383 |
+
num_words_in_input = len(self.parent.tokenizer.decode(inp).split())
|
384 |
+
if not failed and num_words_in_input < len(original_sentence):
|
385 |
+
# only look up to 10 words ahead if we're doing count syllables, since we'll filter out anything more than 10 syllables ahead anyway
|
386 |
+
future_word_position_max = len(original_sentence) - 1
|
387 |
+
future_word_position = random.randint(num_words_in_input-1, future_word_position_max) # allow the last possibly partial word though
|
388 |
+
future_word = original_sentence[future_word_position]
|
389 |
+
unstripped_future_word = future_word
|
390 |
+
future_word = future_word.strip().strip(string.punctuation) # NOTE: we didn't strip punctuation for the topic bag of words paper experiments for our method. it doesn't make much difference, though.
|
391 |
+
|
392 |
+
# future_word = original_sentence[-1] # useful for debugging
|
393 |
+
words_in_between = original_sentence[num_words_in_input-1:future_word_position+1]
|
394 |
+
syllables_to_go = count_syllables(' '.join(words_in_between))
|
395 |
+
if syllables_to_go > MAX_COUNT_SYLLABLE_DIST:
|
396 |
+
failed = True
|
397 |
+
# truncate context a bit since we're just doing couplets. random length from 1 to max desired length for this purpose.
|
398 |
+
desired_length = random.randint(1, MAX_COUNT_SYLLABLE_INPUT_LENGTH)
|
399 |
+
# desired_length = 10 # useful for debugging
|
400 |
+
inp = inp[-desired_length:]
|
401 |
+
length = len(inp)
|
402 |
+
true_label = 1 if unstripped_future_word.strip()[-1] in PHRASE_ENDS else 0 # common ways to end a phrase
|
403 |
+
classification_label = [-1 for _ in range(length)]
|
404 |
+
classification_label[-1] = true_label # only learn at the last position
|
405 |
+
if not failed and future_word in self.parent.word2index.keys():
|
406 |
+
word_log_prob = math.log(self.parent.vocab[future_word] / self.parent.total_words) # roughly baseline prob of word under noise model
|
407 |
+
future_word = self.parent.word2index[future_word]
|
408 |
+
pad_id = self.parent.gpt_pad_id
|
409 |
+
example = (inp, length, future_word, word_log_prob, pad_id, classification_label, syllables_to_go, future_word_num_syllables, rhyme_group_index)
|
410 |
+
valid = not failed
|
411 |
+
else:
|
412 |
+
raise NotImplementedError
|
413 |
+
|
414 |
+
self.pos += increment
|
415 |
+
return example
|
naacl-2021-fudge-controlled-generation/eval_formality_metrics.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from argparse import ArgumentParser
|
2 |
+
import pickle
|
3 |
+
import os
|
4 |
+
import math
|
5 |
+
|
6 |
+
import sacrebleu
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model, MarianTokenizer, MarianMTModel
|
10 |
+
|
11 |
+
from constants import *
|
12 |
+
from model import Model
|
13 |
+
from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
|
14 |
+
|
15 |
+
def avg_formality(preds, model, tokenizer, device='cuda'):
|
16 |
+
probs = []
|
17 |
+
for sent in preds:
|
18 |
+
encoded_input = tokenizer.encode(sent, return_tensors='pt').to(device)
|
19 |
+
lengths = torch.LongTensor([encoded_input.shape[1]]).to(device)
|
20 |
+
scores = model(encoded_input, lengths=lengths) # batch x seq
|
21 |
+
score = scores.flatten()[-1].item()
|
22 |
+
probs.append(math.exp(score) / (1 + math.exp(score))) # sigmoided score = prob
|
23 |
+
return np.mean(probs)
|
24 |
+
|
25 |
+
if __name__=='__main__':
|
26 |
+
parser = ArgumentParser()
|
27 |
+
parser.add_argument('--pred', type=str)
|
28 |
+
parser.add_argument('--ref', type=str, nargs='*', help='bleu refs')
|
29 |
+
parser.add_argument('--ckpt', type=str, help='formality classifier')
|
30 |
+
parser.add_argument('--dataset_info', type=str)
|
31 |
+
parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
|
32 |
+
parser.add_argument('--model_string', type=str, default='Helsinki-NLP/opus-mt-es-en')
|
33 |
+
|
34 |
+
args = parser.parse_args()
|
35 |
+
|
36 |
+
# refs = [['The dog bit the man.', 'It was not unexpected.', 'The man bit him first.'],
|
37 |
+
# ['The dog had bit the man.', 'No one was surprised.', 'The man had bitten the dog.']]
|
38 |
+
# sys = ['The dog bit the man.', "It wasn't surprising.", 'The man had just bitten him.']
|
39 |
+
print('num ref files', len(args.ref))
|
40 |
+
pred = []
|
41 |
+
with open(args.pred, 'r') as rf:
|
42 |
+
for line in rf:
|
43 |
+
pred.append(line.strip())
|
44 |
+
refs = []
|
45 |
+
for ref_file in args.ref:
|
46 |
+
ref = []
|
47 |
+
with open(ref_file, 'r') as rf:
|
48 |
+
for line in rf:
|
49 |
+
ref.append(line.strip())
|
50 |
+
assert len(ref) == len(pred)
|
51 |
+
refs.append(ref)
|
52 |
+
bleu = sacrebleu.corpus_bleu(pred, refs)
|
53 |
+
print('BLEU score:', bleu.score)
|
54 |
+
|
55 |
+
with open(args.dataset_info, 'rb') as rf:
|
56 |
+
dataset_info = pickle.load(rf)
|
57 |
+
|
58 |
+
tokenizer = MarianTokenizer.from_pretrained(args.model_string)
|
59 |
+
tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
|
60 |
+
pad_id = tokenizer.encode(PAD_TOKEN)[0]
|
61 |
+
|
62 |
+
checkpoint = torch.load(args.ckpt, map_location=args.device)
|
63 |
+
model_args = checkpoint['args']
|
64 |
+
conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
|
65 |
+
conditioning_model.load_state_dict(checkpoint['state_dict'])
|
66 |
+
conditioning_model = conditioning_model.to(args.device)
|
67 |
+
conditioning_model.eval()
|
68 |
+
print("=> loaded checkpoint '{}' (epoch {})"
|
69 |
+
.format(args.ckpt, checkpoint['epoch']))
|
70 |
+
print('num params', num_params(conditioning_model))
|
71 |
+
|
72 |
+
print('avg formality prob according to model', avg_formality(pred, conditioning_model, tokenizer, device=args.device))
|
73 |
+
|
naacl-2021-fudge-controlled-generation/eval_poetry_metrics.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from argparse import ArgumentParser
|
2 |
+
import math
|
3 |
+
import string
|
4 |
+
|
5 |
+
from tqdm import tqdm
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForSequenceClassification
|
10 |
+
|
11 |
+
from poetry_util import is_iambic, perfect_rhyme_end, count_syllables
|
12 |
+
from constants import *
|
13 |
+
|
14 |
+
|
15 |
+
def conditional_perplexity(prefix, pred, tokenizer, model, device='cuda', sep_losses=False):
|
16 |
+
# calculate perplexity on pred only, conditioned on prefix
|
17 |
+
sentence = prefix + pred
|
18 |
+
sos_token = tokenizer.decode([0])
|
19 |
+
prefix_tensor_input = tokenizer.encode(sos_token + prefix.replace(EOT_TOKEN, ' ').strip(), return_tensors='pt').to(device)
|
20 |
+
full_tensor_input = tokenizer.encode(sos_token + sentence.replace(EOT_TOKEN, ' ').strip(), return_tensors='pt').to(device)
|
21 |
+
if sep_losses:
|
22 |
+
prefix_loss = model(prefix_tensor_input, labels=prefix_tensor_input)[0].sum()
|
23 |
+
full_loss = model(full_tensor_input, labels=full_tensor_input)[0].sum()
|
24 |
+
else:
|
25 |
+
prefix_loss = model(prefix_tensor_input, labels=prefix_tensor_input)[0] * (prefix_tensor_input.shape[1]-1) # neg log prob of prefix
|
26 |
+
full_loss = model(full_tensor_input, labels=full_tensor_input)[0] * (full_tensor_input.shape[1]-1) # neg log prob of full seq
|
27 |
+
pred_loss = full_loss - prefix_loss # neg log prob of preds given prefix
|
28 |
+
avg_pred_loss = pred_loss / (full_tensor_input.shape[1] - prefix_tensor_input.shape[1])
|
29 |
+
return math.exp(avg_pred_loss.item())
|
30 |
+
|
31 |
+
|
32 |
+
def grammaticality(sentences, tokenizer, model, device='cuda'):
|
33 |
+
with torch.no_grad():
|
34 |
+
total_good = 0
|
35 |
+
for sent in tqdm(sentences, total=len(sentences)):
|
36 |
+
good_prob = F.softmax(model(tokenizer.encode(sent, return_tensors='pt').to(device))[0].flatten(), dim=0)[1]
|
37 |
+
total_good += good_prob
|
38 |
+
return total_good / len(sentences) # avg probability of grammaticality according to model
|
39 |
+
|
40 |
+
|
41 |
+
def distinctness(sentences):
|
42 |
+
d1 = set()
|
43 |
+
d2 = set()
|
44 |
+
d3 = set()
|
45 |
+
total_words = 0
|
46 |
+
for sentence in sentences:
|
47 |
+
o = sentence.split(' ')
|
48 |
+
total_words += len(o)
|
49 |
+
d1.update(o)
|
50 |
+
for i in range(len(o) - 1):
|
51 |
+
d2.add(o[i] + '_' + o[i+1])
|
52 |
+
for i in range(len(o) - 2):
|
53 |
+
d3.add(o[i] + '_' + o[i+1] + '_' + o[i+2])
|
54 |
+
return len(d1) / total_words, len(d2) / total_words, len(d3) / total_words
|
55 |
+
|
56 |
+
|
57 |
+
if __name__=='__main__':
|
58 |
+
parser = ArgumentParser()
|
59 |
+
parser.add_argument('--pred_file', type=str)
|
60 |
+
parser.add_argument('--prefix_file', type=str)
|
61 |
+
parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
|
62 |
+
args = parser.parse_args()
|
63 |
+
|
64 |
+
preds = []
|
65 |
+
with open(args.pred_file, 'r') as rf:
|
66 |
+
for line in rf:
|
67 |
+
preds.append(line[:-1]) # drop \n but not beginning spaces if any
|
68 |
+
prefixes = []
|
69 |
+
with open(args.prefix_file, 'r') as rf:
|
70 |
+
for line in rf:
|
71 |
+
prefixes.append(line.strip())
|
72 |
+
assert len(prefixes) == len(preds)
|
73 |
+
rhymes = 0
|
74 |
+
iambic = 0
|
75 |
+
ten_syllables = 0
|
76 |
+
end = 0
|
77 |
+
diff_rhymes = 0
|
78 |
+
all_success = 0
|
79 |
+
total = len(prefixes)
|
80 |
+
for prefix, pred in zip(prefixes, preds):
|
81 |
+
if is_iambic(pred):
|
82 |
+
iambic += 1
|
83 |
+
if perfect_rhyme_end(prefix, pred):
|
84 |
+
rhymes += 1
|
85 |
+
if prefix.split()[-1].strip(string.punctuation) != pred.split()[-1].strip(string.punctuation):
|
86 |
+
diff_rhymes += 1
|
87 |
+
if count_syllables(pred) == 10:
|
88 |
+
ten_syllables += 1
|
89 |
+
if pred.strip()[-1] in PHRASE_ENDS:
|
90 |
+
end += 1
|
91 |
+
if is_iambic(pred) and perfect_rhyme_end(prefix, pred) and count_syllables(pred) == 10 and pred.strip()[-1] in PHRASE_ENDS:
|
92 |
+
all_success += 1
|
93 |
+
print('iambic', iambic, 'out of', total, ', frac', iambic / total)
|
94 |
+
print('rhymes', rhymes, 'out of', total, ', frac', rhymes / total)
|
95 |
+
print('end sentence', end, 'out of', total, ', frac', end / total)
|
96 |
+
print('10 syllables', ten_syllables, 'out of', total, ', frac', ten_syllables / total)
|
97 |
+
print('all success', all_success, 'out of', total, ', frac', all_success / total)
|
98 |
+
print('rhymes with diff word', diff_rhymes, 'out of', total, ', frac', diff_rhymes / total)
|
99 |
+
|
100 |
+
print('distinctness', distinctness(preds))
|
101 |
+
|
102 |
+
grammar_tokenizer = AutoTokenizer.from_pretrained('textattack/roberta-base-CoLA')
|
103 |
+
grammar_model = AutoModelForSequenceClassification.from_pretrained('textattack/roberta-base-CoLA').to(args.device)
|
104 |
+
grammar_model.eval()
|
105 |
+
print('grammaticality', grammaticality(preds, grammar_tokenizer, grammar_model, device=args.device))
|
106 |
+
|
107 |
+
perplexities = []
|
108 |
+
eval_tokenizer = AutoTokenizer.from_pretrained('transfo-xl-wt103')
|
109 |
+
eval_model = AutoModelWithLMHead.from_pretrained('transfo-xl-wt103').to(args.device)
|
110 |
+
eval_model.eval()
|
111 |
+
for prefix, pred in zip(prefixes, preds):
|
112 |
+
perplexities.append(conditional_perplexity(prefix, pred, eval_tokenizer, eval_model, device=args.device, sep_losses=True))
|
113 |
+
print('transformer xl perplexity', np.mean(perplexities), '+/-', np.std(perplexities))
|
114 |
+
|
115 |
+
perplexities = []
|
116 |
+
eval_tokenizer = AutoTokenizer.from_pretrained('openai-gpt')
|
117 |
+
eval_model = AutoModelWithLMHead.from_pretrained('openai-gpt').to(args.device)
|
118 |
+
eval_model.eval()
|
119 |
+
for prefix, pred in zip(prefixes, preds):
|
120 |
+
perplexities.append(conditional_perplexity(prefix, pred, eval_tokenizer, eval_model, device=args.device))
|
121 |
+
print('gpt perplexity', np.mean(perplexities), '+/-', np.std(perplexities))
|
122 |
+
|
123 |
+
# NOTE: uncomment this section with the path to the Shakespeare-finetuned GPT to evaluate this metric. it's in ckpt/poetry/gpt_finetune_shakespeare.pth.tar.
|
124 |
+
# eval_tokenizer = AutoTokenizer.from_pretrained('openai-gpt')
|
125 |
+
# eval_model = AutoModelWithLMHead.from_pretrained('openai-gpt').to(args.device)
|
126 |
+
# checkpoint = torch.load('***PATH_TO_SHAKESPEARE_FINETUNED_GPT***', map_location=args.device)
|
127 |
+
# mod_dict = {}
|
128 |
+
# for key in checkpoint['state_dict']:
|
129 |
+
# mod_dict[key.replace('classifier.', '')] = checkpoint['state_dict'][key]
|
130 |
+
# eval_model.load_state_dict(mod_dict)
|
131 |
+
# eval_model.eval()
|
132 |
+
# perplexities = []
|
133 |
+
# for prefix, pred in zip(prefixes, preds):
|
134 |
+
# perplexities.append(conditional_perplexity(prefix, pred, eval_tokenizer, eval_model, device=args.device))
|
135 |
+
# print('shakespeare finetuned perplexity', np.mean(perplexities), '+/-', np.std(perplexities))
|
naacl-2021-fudge-controlled-generation/eval_topic_metrics.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import time
|
4 |
+
import pickle
|
5 |
+
import math
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
from collections import defaultdict
|
8 |
+
import string
|
9 |
+
import csv
|
10 |
+
|
11 |
+
from tqdm import tqdm
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForSequenceClassification
|
17 |
+
|
18 |
+
from data import Dataset
|
19 |
+
from model import Model
|
20 |
+
from util import save_checkpoint, ProgressMeter, AverageMeter, num_params, pad_mask
|
21 |
+
from predict import predict
|
22 |
+
from constants import *
|
23 |
+
|
24 |
+
def tw_topic_eval(sentences, category, tw_dir, cap=None):
|
25 |
+
# num matches of distinct words
|
26 |
+
words = []
|
27 |
+
with open(os.path.join(tw_dir, category + '.txt'), 'r') as rf:
|
28 |
+
for line in rf:
|
29 |
+
words.append(line.strip().lower())
|
30 |
+
num_match = 0
|
31 |
+
for sent in sentences:
|
32 |
+
sent_match = 0
|
33 |
+
sent = sent.strip().lower().split()
|
34 |
+
sent = [tok.strip(string.punctuation) for tok in sent]
|
35 |
+
for word in words:
|
36 |
+
if word in sent:
|
37 |
+
sent_match += 1
|
38 |
+
if cap is None:
|
39 |
+
num_match += sent_match
|
40 |
+
else:
|
41 |
+
num_match += min(cap, sent_match)
|
42 |
+
return num_match
|
43 |
+
|
44 |
+
|
45 |
+
def perplexity(sentences, tokenizer, model, device='cuda'):
|
46 |
+
# calculate perplexity
|
47 |
+
with torch.no_grad():
|
48 |
+
ppl = []
|
49 |
+
sos_token = tokenizer.decode([0])
|
50 |
+
for sentence in tqdm(sentences, total=len(sentences)):
|
51 |
+
full_tensor_input = tokenizer.encode(sos_token + sentence.replace(EOT_TOKEN, ' ').strip(), return_tensors='pt').to(device)
|
52 |
+
full_loss = model(full_tensor_input, labels=full_tensor_input)[0].mean()
|
53 |
+
ppl.append(torch.exp(full_loss).flatten().cpu().item())
|
54 |
+
return np.mean(ppl), np.std(ppl)
|
55 |
+
|
56 |
+
|
57 |
+
def grammaticality(sentences, tokenizer, model, device='cuda'):
|
58 |
+
with torch.no_grad():
|
59 |
+
total_good = 0
|
60 |
+
for sent in tqdm(sentences, total=len(sentences)):
|
61 |
+
good_prob = F.softmax(model(tokenizer.encode(sent, return_tensors='pt').to(device))[0].flatten(), dim=0)[1]
|
62 |
+
total_good += good_prob
|
63 |
+
return total_good / len(sentences) # avg probability of grammaticality according to model
|
64 |
+
|
65 |
+
|
66 |
+
def distinctness(results):
|
67 |
+
d1, d2, d3 = defaultdict(lambda: set()), defaultdict(lambda: set()), defaultdict(lambda: set())
|
68 |
+
total_words = defaultdict(lambda: 0)
|
69 |
+
for cw, outputs in results.items():
|
70 |
+
for o in outputs:
|
71 |
+
o = o.replace(EOT_TOKEN, ' ').strip().split(' ')
|
72 |
+
o = [str(x) for x in o]
|
73 |
+
total_words[cw] += len(o)
|
74 |
+
d1[cw].update(o)
|
75 |
+
for i in range(len(o) - 1):
|
76 |
+
d2[cw].add(o[i] + ' ' + o[i+1])
|
77 |
+
for i in range(len(o) - 2):
|
78 |
+
d3[cw].add(o[i] + ' ' + o[i+1] + ' ' + o[i+2])
|
79 |
+
return_info = []
|
80 |
+
avg_d1, avg_d2, avg_d3 = 0, 0, 0
|
81 |
+
for cw in total_words.keys():
|
82 |
+
return_info.append((cw, 'DISTINCTNESS', len(d1[cw]) / total_words[cw], len(d2[cw]) / total_words[cw], len(d3[cw]) / total_words[cw]))
|
83 |
+
avg_d1 += len(d1[cw]) / total_words[cw]
|
84 |
+
avg_d2 += len(d2[cw]) / total_words[cw]
|
85 |
+
avg_d3 += len(d3[cw]) / total_words[cw]
|
86 |
+
avg_d1, avg_d2, avg_d3 = avg_d1 / len(total_words.keys()), avg_d2 / len(total_words.keys()), avg_d3 / len(total_words.keys())
|
87 |
+
return return_info, (avg_d1, avg_d2, avg_d3)
|
88 |
+
|
89 |
+
|
90 |
+
if __name__=='__main__':
|
91 |
+
parser = ArgumentParser()
|
92 |
+
parser.add_argument('--log_file', type=str, required=True, help='where to load results from')
|
93 |
+
parser.add_argument('--tw_dir', type=str, default='test_wordlists', help='test wordlists')
|
94 |
+
parser.add_argument('--batch_size', type=int, default=8, help='max samples at a time')
|
95 |
+
parser.add_argument('--cap_per_example', type=int, default=None, help='max matches to count per sentence')
|
96 |
+
parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
|
97 |
+
args = parser.parse_args()
|
98 |
+
|
99 |
+
tw_topic_match_c_total = 0
|
100 |
+
category_totals_c = defaultdict(lambda:0)
|
101 |
+
results = defaultdict(lambda: [])
|
102 |
+
with open(args.log_file, 'r') as rf:
|
103 |
+
data = list(csv.DictReader(rf))
|
104 |
+
for line in data:
|
105 |
+
results[line['category']].append(line['generation'])
|
106 |
+
|
107 |
+
all_c_sents = []
|
108 |
+
for category, condition_results in results.items():
|
109 |
+
tw_topic_match_c = tw_topic_eval(condition_results, category, args.tw_dir, cap=args.cap_per_example)
|
110 |
+
tw_topic_match_c_total += tw_topic_match_c
|
111 |
+
category_totals_c[category] += tw_topic_match_c
|
112 |
+
all_c_sents += condition_results
|
113 |
+
|
114 |
+
print('Test wordlist matches (divide by num outputs to get the Success metric):', tw_topic_match_c_total)
|
115 |
+
print('per category:', category_totals_c)
|
116 |
+
|
117 |
+
dist_info_by_category, dist_overall = distinctness(results)
|
118 |
+
print('Overall avg distinctness:', dist_overall)
|
119 |
+
print('per category:', dist_info_by_category)
|
120 |
+
|
121 |
+
grammar_tokenizer = AutoTokenizer.from_pretrained('textattack/roberta-base-CoLA')
|
122 |
+
grammar_model = AutoModelForSequenceClassification.from_pretrained('textattack/roberta-base-CoLA').to(args.device)
|
123 |
+
grammar_model.eval()
|
124 |
+
print('grammaticality:', grammaticality(all_c_sents, grammar_tokenizer, grammar_model, device=args.device))
|
125 |
+
|
126 |
+
eval_tokenizer = AutoTokenizer.from_pretrained('openai-gpt')
|
127 |
+
eval_model = AutoModelWithLMHead.from_pretrained('openai-gpt').to(args.device)
|
128 |
+
eval_model.eval()
|
129 |
+
print('GPT perplexity:', perplexity(all_c_sents, eval_tokenizer, eval_model))
|
130 |
+
|
131 |
+
eval_tokenizer = AutoTokenizer.from_pretrained('transfo-xl-wt103')
|
132 |
+
eval_model = AutoModelWithLMHead.from_pretrained('transfo-xl-wt103').to(args.device)
|
133 |
+
eval_model.eval()
|
134 |
+
print('TFXL perplexity:', perplexity(all_c_sents, eval_tokenizer, eval_model))
|
naacl-2021-fudge-controlled-generation/evaluate_clickbait.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import time
|
4 |
+
import pickle
|
5 |
+
import math
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
|
8 |
+
from typing import Iterable, List, Optional, Tuple
|
9 |
+
|
10 |
+
from tqdm import tqdm
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from transformers import AutoTokenizer, AutoModelWithLMHead
|
16 |
+
from torch import Tensor
|
17 |
+
|
18 |
+
from data import Dataset
|
19 |
+
from model import Model
|
20 |
+
from util import num_params
|
21 |
+
from constants import *
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
tokenizer = AutoTokenizer.from_pretrained('google/pegasus-xsum')
|
26 |
+
classifier_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
|
27 |
+
|
28 |
+
|
29 |
+
def main(args):
|
30 |
+
with open(args.dataset_info, 'rb') as rf:
|
31 |
+
dataset_info = pickle.load(rf)
|
32 |
+
|
33 |
+
article_content = """Australian actor Guy Pearce will return for the iconic soap Neighbours finale on August 1 to reprise his role as Mike Young.
|
34 |
+
Guy, 54, played the troubled Mike from 1986 to 1989, and is now set to make a comeback on the show after 33 years, Metro.co.uk reports.
|
35 |
+
The star's character arcs explored the implications of domestic abuse, student-teacher relationships and dealing with loss of loved ones.
|
36 |
+
Speaking to Metro.co.uk, Guy said: 'It is very exciting and surreal at the same time being back on set again, however it feels like coming home.
|
37 |
+
'It's where it all started for me professionally. I've been asked to come back on occasions over the years and wondered if it was the right thing
|
38 |
+
to do, but once I knew the show was finishing, I knew I had to do it.'He added that there is 'nothing like being here all together again'
|
39 |
+
, even though he's had a chance to catch-up with other cast members."""
|
40 |
+
|
41 |
+
tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
|
42 |
+
pad_id = tokenizer.encode(PAD_TOKEN)[0]
|
43 |
+
|
44 |
+
#For loading Clickbait summarizer
|
45 |
+
model = AutoModelWithLMHead.from_pretrained(args.model_string, return_dict=True).to(args.device)
|
46 |
+
|
47 |
+
model.eval()
|
48 |
+
|
49 |
+
checkpoint = torch.load(args.ckpt, map_location=args.device)
|
50 |
+
model_args = checkpoint['args']
|
51 |
+
conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
|
52 |
+
conditioning_model.load_state_dict(checkpoint['state_dict'])
|
53 |
+
conditioning_model = conditioning_model.to(args.device)
|
54 |
+
conditioning_model.eval()
|
55 |
+
print("=> loaded checkpoint '{}' (epoch {})"
|
56 |
+
.format(args.ckpt, checkpoint['epoch']))
|
57 |
+
print('num params', num_params(conditioning_model))
|
58 |
+
|
59 |
+
while True:
|
60 |
+
results = generate_clickbait(model,
|
61 |
+
tokenizer,
|
62 |
+
conditioning_model,
|
63 |
+
[args.input_text],
|
64 |
+
dataset_info,
|
65 |
+
precondition_topk=args.precondition_topk,
|
66 |
+
do_sample=args.do_sample,
|
67 |
+
length_cutoff=args.length_cutoff,
|
68 |
+
condition_lambda=args.condition_lambda,
|
69 |
+
article_content=article_content,
|
70 |
+
device=args.device)
|
71 |
+
# print(results)
|
72 |
+
import pdb; pdb.set_trace()
|
73 |
+
|
74 |
+
|
75 |
+
def generate_clickbait(model,
|
76 |
+
tokenizer,
|
77 |
+
conditioning_model,
|
78 |
+
input_text,
|
79 |
+
dataset_info,
|
80 |
+
precondition_topk,
|
81 |
+
length_cutoff,
|
82 |
+
condition_lambda=1.0,
|
83 |
+
article_content=None,
|
84 |
+
device='cuda'):
|
85 |
+
with torch.no_grad():
|
86 |
+
batch_size = len(input_text)
|
87 |
+
# encoded_input_article = [tokenizer.encode(article_content, return_tensors='pt',add_special_tokens=False).to(device)] # batch x seq
|
88 |
+
encoded_input_article = tokenizer(article_content, return_tensors='pt',add_special_tokens=False, max_length=512).to(device) # batch x seq
|
89 |
+
# encoded_input_article = torch.cat(encoded_input_article, dim=0)
|
90 |
+
# attention_mask = encoded_input_article.new_ones(encoded_input_article.shape).to(device)
|
91 |
+
|
92 |
+
# CHANGE=ko
|
93 |
+
encoded_input = tokenizer('<pad>', return_tensors='pt',add_special_tokens=False).to(device) # batch x seq
|
94 |
+
# encoded_input = tokenizer('<pad>'+ input_text[0], return_tensors='pt',add_special_tokens=False).to(device) # batch x seq
|
95 |
+
# encoded_input = torch.cat(encoded_input, dim=0)
|
96 |
+
encoded_input = encoded_input['input_ids']
|
97 |
+
|
98 |
+
|
99 |
+
lengths = torch.LongTensor([encoded_input.shape[1]]).to(device)
|
100 |
+
# lengths = 1
|
101 |
+
|
102 |
+
past = None
|
103 |
+
use_cache = True
|
104 |
+
|
105 |
+
# CHANGE
|
106 |
+
# model_kwargs = {'encoder_outputs': model.get_encoder()(encoded_input_article, attention_mask=attention_mask)}
|
107 |
+
# print(encoded_input_article)
|
108 |
+
# print(encoded_input_article['input_ids'].shape, encoded_input_article['attention_mask'].shape)
|
109 |
+
model_kwargs = {'encoder_outputs': model.get_encoder()(input_ids=encoded_input_article['input_ids'],
|
110 |
+
attention_mask=encoded_input_article['attention_mask'],
|
111 |
+
return_dict=True,
|
112 |
+
output_attentions=False,
|
113 |
+
output_hidden_states=False),
|
114 |
+
}
|
115 |
+
|
116 |
+
while lengths.max() < length_cutoff:
|
117 |
+
model_inputs = model.prepare_inputs_for_generation(
|
118 |
+
input_ids = encoded_input_article['input_ids'],
|
119 |
+
decoder_input_ids=encoded_input,
|
120 |
+
# past=past,
|
121 |
+
attention_mask=encoded_input_article['attention_mask'],
|
122 |
+
use_cache=use_cache,
|
123 |
+
**model_kwargs
|
124 |
+
)
|
125 |
+
|
126 |
+
outputs = model(**model_inputs, return_dict=True)
|
127 |
+
logits = outputs.logits[:, -1, :]
|
128 |
+
|
129 |
+
if "past_key_values" in outputs:
|
130 |
+
model_kwargs["past"] = outputs.past_key_values
|
131 |
+
|
132 |
+
# logits = model(encoded_input)[0][:, -1, :] # batch x vocab
|
133 |
+
top_logits, top_indices = logits.topk(precondition_topk, dim=1) # batch x topk
|
134 |
+
new_input_candidates = torch.cat([encoded_input.unsqueeze(1).expand(-1, precondition_topk, -1), top_indices.unsqueeze(2)], dim=2) # batch x topk x seq+1
|
135 |
+
expanded_lengths = (lengths + 1).unsqueeze(1).expand(batch_size, precondition_topk) # batch x topk
|
136 |
+
|
137 |
+
if condition_lambda == 0:
|
138 |
+
condition_logits = torch.zeros_like(top_logits).float()
|
139 |
+
condition_logits = condition_logits.view(batch_size, precondition_topk, -1) # batch x topk x N
|
140 |
+
else:
|
141 |
+
decoded_outputs = tokenizer.batch_decode(new_input_candidates.view(-1, new_input_candidates.size(-1)), clean_up_tokenization_spaces=False)
|
142 |
+
resulting_tokenization = classifier_tokenizer(decoded_outputs, add_special_tokens=False, padding='longest')
|
143 |
+
encoded_with_classifier = resulting_tokenization['input_ids']
|
144 |
+
attention_mask = torch.tensor(resulting_tokenization['attention_mask']).to(model.device)
|
145 |
+
tplus1_candidates_classifier = torch.tensor(encoded_with_classifier).view(batch_size, precondition_topk, -1).to(model.device)
|
146 |
+
|
147 |
+
condition_logits = conditioning_model(tplus1_candidates_classifier.flatten(0, 1), # batch*topk x seq+1
|
148 |
+
expanded_lengths.flatten(0, 1), # batch*topk
|
149 |
+
None,
|
150 |
+
None,
|
151 |
+
None,
|
152 |
+
attention_mask=attention_mask
|
153 |
+
)
|
154 |
+
condition_logits = condition_logits.view(batch_size, precondition_topk, -1) # batch x topk x N
|
155 |
+
condition_logits = condition_logits - torch.log(1 + torch.exp(condition_logits)) # get correct log probs
|
156 |
+
|
157 |
+
condition_logits = torch.mean(condition_logits, dim=2)
|
158 |
+
full_logits = top_logits + condition_logits * condition_lambda # batch x topk
|
159 |
+
post_logits, post_indices = full_logits.topk(precondition_topk, dim=1)
|
160 |
+
post_probs = F.softmax(post_logits, dim=1)
|
161 |
+
# index_into_top_indices = post_indices[torch.arange(batch_size).to(post_indices.device), torch.multinomial(post_probs, 1).flatten()] # batch
|
162 |
+
index_into_top_indices = post_indices[:, torch.multinomial(post_probs, 1).flatten()] # batch
|
163 |
+
|
164 |
+
# next_indices = top_indices[torch.arange(batch_size).to(top_indices.device), index_into_top_indices] # batch
|
165 |
+
next_indices = top_indices[:, index_into_top_indices] # batch
|
166 |
+
|
167 |
+
# encoded_input = torch.cat([encoded_input, next_indices.unsqueeze(1)], dim=1) # batch x seq+1
|
168 |
+
encoded_input = torch.cat([encoded_input, next_indices.squeeze(1)], dim=1)
|
169 |
+
lengths = lengths + 1 # batch
|
170 |
+
|
171 |
+
# print(tokenizer.decode(encoded_input[0], add_special_tokens=False))
|
172 |
+
return [tokenizer.decode(s) for s in encoded_input]
|
173 |
+
|
174 |
+
|
175 |
+
if __name__=='__main__':
|
176 |
+
parser = ArgumentParser()
|
177 |
+
|
178 |
+
# DATA
|
179 |
+
parser.add_argument('--ckpt', type=str, required=True)
|
180 |
+
parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
|
181 |
+
parser.add_argument('--model_string', type=str, default='Helsinki-NLP/opus-mt-es-en')
|
182 |
+
|
183 |
+
parser.add_argument('--in_file', type=str, default=None, required=True, help='text to run pred on')
|
184 |
+
|
185 |
+
parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from text generation at each step before conditioning and re-pruning')
|
186 |
+
parser.add_argument('--do_sample', action='store_true', default=False, help='sample instead of greedy')
|
187 |
+
parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
|
188 |
+
parser.add_argument('--length_cutoff', type=int, default=512, help='max length')
|
189 |
+
|
190 |
+
parser.add_argument('--seed', type=int, default=1, help='random seed')
|
191 |
+
parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
|
192 |
+
parser.add_argument('--debug', action='store_true', default=False)
|
193 |
+
|
194 |
+
args = parser.parse_args()
|
195 |
+
|
196 |
+
random.seed(args.seed)
|
197 |
+
np.random.seed(args.seed)
|
198 |
+
torch.manual_seed(args.seed)
|
199 |
+
|
200 |
+
main(args)
|
naacl-2021-fudge-controlled-generation/evaluate_formality.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import time
|
4 |
+
import pickle
|
5 |
+
import math
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
from collections import namedtuple
|
8 |
+
|
9 |
+
from tqdm import tqdm
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model, MarianTokenizer, MarianMTModel
|
15 |
+
|
16 |
+
from data import Dataset
|
17 |
+
from model import Model
|
18 |
+
from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
|
19 |
+
from constants import *
|
20 |
+
from predict_formality import predict_formality
|
21 |
+
|
22 |
+
def main(args):
|
23 |
+
with open(args.dataset_info, 'rb') as rf:
|
24 |
+
dataset_info = pickle.load(rf)
|
25 |
+
tokenizer = MarianTokenizer.from_pretrained(args.model_string)
|
26 |
+
tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
|
27 |
+
pad_id = tokenizer.encode(PAD_TOKEN)[0]
|
28 |
+
model = MarianMTModel.from_pretrained(args.model_string, return_dict=True).to(args.device)
|
29 |
+
if args.model_path is not None:
|
30 |
+
if os.path.isdir(args.model_path):
|
31 |
+
for _, _, files in os.walk(args.model_path):
|
32 |
+
for fname in files:
|
33 |
+
if fname.endswith('.ckpt'):
|
34 |
+
args.model_path = os.path.join(args.model_path, fname)
|
35 |
+
break
|
36 |
+
ckpt = torch.load(args.model_path, map_location=torch.device(args.device))
|
37 |
+
try:
|
38 |
+
model.load_state_dict(ckpt['state_dict'], strict=False)
|
39 |
+
except:
|
40 |
+
state_dict = {}
|
41 |
+
for key in ckpt['state_dict'].keys():
|
42 |
+
assert key.startswith('model.')
|
43 |
+
state_dict[key[6:]] = ckpt['state_dict'][key]
|
44 |
+
model.load_state_dict(state_dict)
|
45 |
+
model.eval()
|
46 |
+
|
47 |
+
checkpoint = torch.load(args.ckpt, map_location=args.device)
|
48 |
+
model_args = checkpoint['args']
|
49 |
+
conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
|
50 |
+
conditioning_model.load_state_dict(checkpoint['state_dict'])
|
51 |
+
conditioning_model = conditioning_model.to(args.device)
|
52 |
+
conditioning_model.eval()
|
53 |
+
if args.verbose:
|
54 |
+
print("=> loaded checkpoint '{}' (epoch {})"
|
55 |
+
.format(args.ckpt, checkpoint['epoch']))
|
56 |
+
print('num params', num_params(conditioning_model))
|
57 |
+
|
58 |
+
inputs = []
|
59 |
+
with open(args.in_file, 'r') as rf:
|
60 |
+
for line in rf:
|
61 |
+
inputs.append(line.strip())
|
62 |
+
|
63 |
+
for inp in tqdm(inputs, total=len(inputs)):
|
64 |
+
results = predict_formality(model,
|
65 |
+
tokenizer,
|
66 |
+
conditioning_model,
|
67 |
+
[inp],
|
68 |
+
dataset_info,
|
69 |
+
precondition_topk=args.precondition_topk,
|
70 |
+
do_sample=args.do_sample,
|
71 |
+
length_cutoff=args.length_cutoff,
|
72 |
+
condition_lambda=args.condition_lambda,
|
73 |
+
device=args.device)
|
74 |
+
print(results[0])
|
75 |
+
|
76 |
+
|
77 |
+
if __name__=='__main__':
|
78 |
+
parser = ArgumentParser()
|
79 |
+
|
80 |
+
# DATA
|
81 |
+
parser.add_argument('--ckpt', type=str, required=True)
|
82 |
+
parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
|
83 |
+
parser.add_argument('--model_string', type=str, default='Helsinki-NLP/opus-mt-es-en')
|
84 |
+
parser.add_argument('--model_path', type=str, default=None)
|
85 |
+
|
86 |
+
parser.add_argument('--in_file', type=str, default=None, required=True, help='file containing text to run pred on')
|
87 |
+
|
88 |
+
parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning')
|
89 |
+
parser.add_argument('--do_sample', action='store_true', default=False, help='sample or greedy; only greedy implemented')
|
90 |
+
parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
|
91 |
+
parser.add_argument('--length_cutoff', type=int, default=512, help='max length')
|
92 |
+
|
93 |
+
parser.add_argument('--seed', type=int, default=1, help='random seed')
|
94 |
+
parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
|
95 |
+
parser.add_argument('--debug', action='store_true', default=False)
|
96 |
+
parser.add_argument('--verbose', action='store_true', default=False)
|
97 |
+
|
98 |
+
args = parser.parse_args()
|
99 |
+
|
100 |
+
random.seed(args.seed)
|
101 |
+
np.random.seed(args.seed)
|
102 |
+
torch.manual_seed(args.seed)
|
103 |
+
|
104 |
+
main(args)
|
naacl-2021-fudge-controlled-generation/evaluate_poetry.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import time
|
4 |
+
import pickle
|
5 |
+
import math
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
import string
|
8 |
+
from collections import defaultdict
|
9 |
+
|
10 |
+
from tqdm import tqdm
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model
|
16 |
+
|
17 |
+
from data import Dataset, load_rhyme_info
|
18 |
+
from model import Model
|
19 |
+
from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
|
20 |
+
from constants import *
|
21 |
+
from poetry_util import get_rhymes, count_syllables
|
22 |
+
from predict_poetry import predict_couplet
|
23 |
+
|
24 |
+
def main(args):
|
25 |
+
with open(args.dataset_info, 'rb') as rf:
|
26 |
+
dataset_info = pickle.load(rf)
|
27 |
+
gpt_tokenizer = AutoTokenizer.from_pretrained(args.model_string)
|
28 |
+
gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
|
29 |
+
gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0]
|
30 |
+
gpt_model = AutoModelWithLMHead.from_pretrained(args.model_string).to(args.device)
|
31 |
+
gpt_model.eval()
|
32 |
+
|
33 |
+
checkpoint = torch.load(args.iambic_ckpt, map_location=args.device)
|
34 |
+
model_args = checkpoint['args']
|
35 |
+
iambic_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
|
36 |
+
iambic_model.load_state_dict(checkpoint['state_dict'])
|
37 |
+
iambic_model = iambic_model.to(args.device)
|
38 |
+
iambic_model.eval()
|
39 |
+
if args.verbose:
|
40 |
+
print("=> loaded checkpoint '{}' (epoch {})"
|
41 |
+
.format(args.iambic_ckpt, checkpoint['epoch']))
|
42 |
+
print('iambic model num params', num_params(iambic_model))
|
43 |
+
|
44 |
+
with open(args.rhyme_info, 'rb') as rf:
|
45 |
+
rhyme_info = pickle.load(rf)
|
46 |
+
checkpoint = torch.load(args.rhyme_ckpt, map_location=args.device)
|
47 |
+
model_args = checkpoint['args']
|
48 |
+
rhyme_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word), rhyme_group_size=len(rhyme_info.index2rhyme_group), verbose=args.verbose) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
|
49 |
+
rhyme_model.load_state_dict(checkpoint['state_dict'])
|
50 |
+
rhyme_model = rhyme_model.to(args.device)
|
51 |
+
rhyme_model.eval()
|
52 |
+
if args.verbose:
|
53 |
+
print("=> loaded checkpoint '{}' (epoch {})"
|
54 |
+
.format(args.rhyme_ckpt, checkpoint['epoch']))
|
55 |
+
print('rhyme model num params', num_params(rhyme_model))
|
56 |
+
|
57 |
+
checkpoint = torch.load(args.newline_ckpt, map_location=args.device)
|
58 |
+
model_args = checkpoint['args']
|
59 |
+
newline_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
|
60 |
+
newline_model.load_state_dict(checkpoint['state_dict'])
|
61 |
+
newline_model = newline_model.to(args.device)
|
62 |
+
newline_model.eval()
|
63 |
+
if args.verbose:
|
64 |
+
print("=> loaded checkpoint '{}' (epoch {})"
|
65 |
+
.format(args.newline_ckpt, checkpoint['epoch']))
|
66 |
+
print('iambic model num params', num_params(newline_model))
|
67 |
+
|
68 |
+
with open(args.prefix_file, 'r') as rf:
|
69 |
+
lines = rf.readlines()
|
70 |
+
for line in tqdm(lines, total=len(lines)):
|
71 |
+
couplet = predict_couplet(gpt_model,
|
72 |
+
gpt_tokenizer,
|
73 |
+
iambic_model,
|
74 |
+
rhyme_model,
|
75 |
+
newline_model,
|
76 |
+
[line],
|
77 |
+
dataset_info,
|
78 |
+
rhyme_info,
|
79 |
+
args.precondition_topk,
|
80 |
+
args.topk,
|
81 |
+
condition_lambda=args.condition_lambda,
|
82 |
+
device=args.device)
|
83 |
+
assert len(couplet) == 2
|
84 |
+
print(couplet[1].strip().replace('\n', ''))
|
85 |
+
|
86 |
+
|
87 |
+
if __name__=='__main__':
|
88 |
+
parser = ArgumentParser()
|
89 |
+
|
90 |
+
# DATA
|
91 |
+
parser.add_argument('--iambic_ckpt', type=str, required=True)
|
92 |
+
parser.add_argument('--rhyme_ckpt', type=str, required=True)
|
93 |
+
parser.add_argument('--newline_ckpt', type=str, required=True)
|
94 |
+
parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
|
95 |
+
parser.add_argument('--rhyme_info', type=str, required=True, help='saved rhyme info')
|
96 |
+
parser.add_argument('--model_string', type=str, default='gpt2-medium')
|
97 |
+
|
98 |
+
parser.add_argument('--prefix_file', type=str, default=None, required=True, help='file of prefix lines for couplets')
|
99 |
+
|
100 |
+
parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning')
|
101 |
+
parser.add_argument('--topk', type=int, default=10, help='consider top k outputs from gpt at each step')
|
102 |
+
parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
|
103 |
+
|
104 |
+
parser.add_argument('--seed', type=int, default=1, help='random seed')
|
105 |
+
parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
|
106 |
+
parser.add_argument('--debug', action='store_true', default=False)
|
107 |
+
parser.add_argument('--verbose', action='store_true', default=False)
|
108 |
+
|
109 |
+
args = parser.parse_args()
|
110 |
+
|
111 |
+
random.seed(args.seed)
|
112 |
+
np.random.seed(args.seed)
|
113 |
+
torch.manual_seed(args.seed)
|
114 |
+
|
115 |
+
main(args)
|
naacl-2021-fudge-controlled-generation/evaluate_topic.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import time
|
4 |
+
import pickle
|
5 |
+
import math
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
from collections import defaultdict
|
8 |
+
import string
|
9 |
+
import csv
|
10 |
+
|
11 |
+
from tqdm import tqdm
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model
|
17 |
+
|
18 |
+
from data import Dataset
|
19 |
+
from model import Model
|
20 |
+
from util import save_checkpoint, ProgressMeter, AverageMeter, num_params, pad_mask
|
21 |
+
from predict_topic import predict
|
22 |
+
from constants import *
|
23 |
+
|
24 |
+
|
25 |
+
def main(args):
|
26 |
+
with open(args.dataset_info, 'rb') as rf:
|
27 |
+
dataset_info = pickle.load(rf)
|
28 |
+
gpt_tokenizer = AutoTokenizer.from_pretrained(args.model_string)
|
29 |
+
gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
|
30 |
+
gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0]
|
31 |
+
gpt_model = AutoModelWithLMHead.from_pretrained(args.model_string).to(args.device)
|
32 |
+
gpt_model.eval()
|
33 |
+
|
34 |
+
checkpoint = torch.load(args.ckpt, map_location=args.device)
|
35 |
+
model_args = checkpoint['args']
|
36 |
+
conditioning_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
|
37 |
+
conditioning_model.load_state_dict(checkpoint['state_dict'])
|
38 |
+
conditioning_model = conditioning_model.to(args.device)
|
39 |
+
conditioning_model.eval()
|
40 |
+
if args.verbose:
|
41 |
+
print("=> loaded checkpoint '{}' (epoch {})"
|
42 |
+
.format(args.ckpt, checkpoint['epoch']))
|
43 |
+
print('num params', num_params(conditioning_model))
|
44 |
+
|
45 |
+
input_texts, conditions, categories = [], [], []
|
46 |
+
|
47 |
+
if args.condition_file is not None:
|
48 |
+
with open(args.condition_file, 'r') as rf:
|
49 |
+
for line in rf:
|
50 |
+
input_texts.append(line.strip().split('\t')[0])
|
51 |
+
conditions.append(line.strip().split('\t')[1])
|
52 |
+
categories.append(None)
|
53 |
+
for cw in conditions[-1].split():
|
54 |
+
assert cw in dataset_info.word2index
|
55 |
+
else:
|
56 |
+
prefixes = []
|
57 |
+
with open(args.prefix_file, 'r') as rf:
|
58 |
+
for line in rf:
|
59 |
+
prefixes.append(line.strip())
|
60 |
+
condition_wordlists = []
|
61 |
+
for root, _, files in os.walk(args.wordlist_dir):
|
62 |
+
for fname in files:
|
63 |
+
words = []
|
64 |
+
with open(os.path.join(root, fname), 'r') as rf:
|
65 |
+
for line in rf:
|
66 |
+
word = line.strip()
|
67 |
+
if word in dataset_info.word2index:
|
68 |
+
words.append(word)
|
69 |
+
else:
|
70 |
+
if args.verbose:
|
71 |
+
print('word not found:', word)
|
72 |
+
condition_wordlists.append((' '.join(words), fname.split('.')[0]))
|
73 |
+
for p in prefixes:
|
74 |
+
for c, category in condition_wordlists:
|
75 |
+
input_texts.append(p)
|
76 |
+
conditions.append(c)
|
77 |
+
categories.append(category)
|
78 |
+
|
79 |
+
all_cr = []
|
80 |
+
pair_num = 0
|
81 |
+
for input_text, condition_words, category in tqdm(zip(input_texts, conditions, categories), total=len(conditions)):
|
82 |
+
predict_function = predict
|
83 |
+
condition_results = []
|
84 |
+
for i in range(0, args.sample_size, args.max_sample_batch):
|
85 |
+
num_samples = min(args.max_sample_batch, args.sample_size - i)
|
86 |
+
condition_results += predict_function(gpt_model,
|
87 |
+
gpt_tokenizer,
|
88 |
+
conditioning_model,
|
89 |
+
[input_text for _ in range(num_samples)],
|
90 |
+
condition_words,
|
91 |
+
dataset_info,
|
92 |
+
args.precondition_topk,
|
93 |
+
args.topk,
|
94 |
+
args.length_cutoff,
|
95 |
+
condition_lambda=args.condition_lambda,
|
96 |
+
device=args.device)
|
97 |
+
all_cr.append((input_text, category, condition_results))
|
98 |
+
pair_num += 1
|
99 |
+
if args.max_pairs > 0 and pair_num >= args.max_pairs:
|
100 |
+
break
|
101 |
+
with open(args.log_file, 'w') as wf:
|
102 |
+
writer = csv.DictWriter(wf, fieldnames=['category', 'input_text', 'generation'])
|
103 |
+
writer.writeheader()
|
104 |
+
for cr_group in all_cr:
|
105 |
+
for cr in cr_group[2]:
|
106 |
+
writer.writerow({'category': cr_group[1], 'input_text': cr_group[0], 'generation': cr})
|
107 |
+
|
108 |
+
|
109 |
+
if __name__=='__main__':
|
110 |
+
parser = ArgumentParser()
|
111 |
+
|
112 |
+
# DATA
|
113 |
+
parser.add_argument('--ckpt', type=str, required=True)
|
114 |
+
parser.add_argument('--log_file', type=str, required=True, help='file to write outputs to (csv format)')
|
115 |
+
parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
|
116 |
+
parser.add_argument('--model_string', type=str, default='gpt2-medium')
|
117 |
+
|
118 |
+
parser.add_argument('--condition_file', type=str, default=None, help='file of inputs and conditions')
|
119 |
+
parser.add_argument('--prefix_file', type=str, default=None, help='prefix set')
|
120 |
+
parser.add_argument('--wordlist_dir', type=str, default=None, help='dir of bow wordlists for categories')
|
121 |
+
parser.add_argument('--sample_size', type=int, default=3, help='samples per input text-condition pair')
|
122 |
+
parser.add_argument('--max_sample_batch', type=int, default=3, help='max samples at a time')
|
123 |
+
parser.add_argument('--max_pairs', type=int, default=-1, help='max input-condition pairs, for debugging quickly')
|
124 |
+
|
125 |
+
parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning')
|
126 |
+
parser.add_argument('--topk', type=int, default=10, help='consider top k outputs from gpt at each step')
|
127 |
+
parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
|
128 |
+
parser.add_argument('--length_cutoff', type=int, default=80, help='max length')
|
129 |
+
|
130 |
+
parser.add_argument('--seed', type=int, default=1, help='random seed')
|
131 |
+
parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
|
132 |
+
parser.add_argument('--debug', action='store_true', default=False)
|
133 |
+
parser.add_argument('--verbose', action='store_true', default=False)
|
134 |
+
|
135 |
+
args = parser.parse_args()
|
136 |
+
|
137 |
+
assert (args.condition_file is not None) != (args.prefix_file is not None and args.wordlist_dir is not None) # one of two interfaces for specifying
|
138 |
+
|
139 |
+
random.seed(args.seed)
|
140 |
+
np.random.seed(args.seed)
|
141 |
+
torch.manual_seed(args.seed)
|
142 |
+
|
143 |
+
main(args)
|
naacl-2021-fudge-controlled-generation/formality_data/README.md
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
`fisher_test_oracle.es` is the source-side Spanish test set.
|
2 |
+
`test_noid.cleaned_0` and `test_noid.cleaned_1` are Salesky 2019's fluent English test-time references.
|
naacl-2021-fudge-controlled-generation/formality_data/fisher_test_oracle.es
ADDED
The diff for this file is too large to render.
See raw diff
|
|
naacl-2021-fudge-controlled-generation/formality_data/test.noid.cleaned_0
ADDED
The diff for this file is too large to render.
See raw diff
|
|
naacl-2021-fudge-controlled-generation/formality_data/test.noid.cleaned_1
ADDED
The diff for this file is too large to render.
See raw diff
|
|
naacl-2021-fudge-controlled-generation/main.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import time
|
4 |
+
import pickle
|
5 |
+
import math
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
|
8 |
+
from tqdm import tqdm
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
|
13 |
+
from data import Dataset
|
14 |
+
from model import Model
|
15 |
+
from util import save_checkpoint, ProgressMeter, AverageMeter, num_params, pad_mask
|
16 |
+
from constants import *
|
17 |
+
|
18 |
+
|
19 |
+
def train(model, dataset, optimizer, criterion, epoch, args, data_start_index):
|
20 |
+
model.train()
|
21 |
+
if data_start_index == 0:
|
22 |
+
dataset.shuffle('train', seed=epoch + args.seed)
|
23 |
+
if args.epoch_max_len is not None:
|
24 |
+
data_end_index = min(data_start_index + args.epoch_max_len, len(dataset.splits['train']))
|
25 |
+
loader = dataset.loader('train', num_workers=args.num_workers, indices=list(range(data_start_index, data_end_index)))
|
26 |
+
data_start_index = data_end_index if data_end_index < len(dataset.splits['train']) else 0
|
27 |
+
else:
|
28 |
+
loader = dataset.loader('train', num_workers=args.num_workers)
|
29 |
+
loss_meter = AverageMeter('loss', ':6.4f')
|
30 |
+
total_length = len(loader)
|
31 |
+
progress = ProgressMeter(total_length, [loss_meter], prefix='Training: ')
|
32 |
+
for batch_num, batch in enumerate(tqdm(loader, total=len(loader))):
|
33 |
+
batch = [tensor.to(args.device) for tensor in batch]
|
34 |
+
inputs, lengths, future_words, log_probs, labels, classification_targets, syllables_to_go, future_word_num_syllables, rhyme_group_index = batch
|
35 |
+
if args.task not in ['formality', 'iambic']:
|
36 |
+
if not args.debug and len(inputs) != args.batch_size: # it'll screw up the bias...?
|
37 |
+
continue
|
38 |
+
scores = model(inputs, lengths, future_words, log_probs, syllables_to_go, future_word_num_syllables, rhyme_group_index, run_classifier=True)
|
39 |
+
if args.task == 'formality': # we're learning for all positions at once. scores are batch x seq
|
40 |
+
expanded_labels = classification_targets.unsqueeze(1).expand(-1, scores.shape[1]) # batch x seq
|
41 |
+
length_mask = pad_mask(lengths).permute(1, 0) # batch x seq
|
42 |
+
loss = criterion(scores.flatten()[length_mask.flatten()==1], expanded_labels.flatten().float()[length_mask.flatten()==1])
|
43 |
+
elif args.task in ['iambic', 'newline']:
|
44 |
+
use_indices = classification_targets.flatten() != -1
|
45 |
+
loss = criterion(scores.flatten()[use_indices], classification_targets.flatten().float()[use_indices])
|
46 |
+
else: # topic, rhyme
|
47 |
+
loss = criterion(scores.flatten(), labels.flatten().float())
|
48 |
+
optimizer.zero_grad()
|
49 |
+
loss.backward()
|
50 |
+
optimizer.step()
|
51 |
+
loss_meter.update(loss.detach(), len(labels))
|
52 |
+
if batch_num % args.train_print_freq == 0:
|
53 |
+
progress.display(batch_num)
|
54 |
+
progress.display(total_length)
|
55 |
+
return data_start_index
|
56 |
+
|
57 |
+
|
58 |
+
def validate(model, dataset, criterion, epoch, args):
|
59 |
+
model.eval()
|
60 |
+
random.seed(0)
|
61 |
+
loader = dataset.loader('val', num_workers=args.num_workers)
|
62 |
+
loss_meter = AverageMeter('loss', ':6.4f')
|
63 |
+
total_length = len(loader)
|
64 |
+
progress = ProgressMeter(total_length, [loss_meter], prefix='Validation: ')
|
65 |
+
with torch.no_grad():
|
66 |
+
for batch_num, batch in enumerate(tqdm(loader, total=len(loader))):
|
67 |
+
batch = [tensor.to(args.device) for tensor in batch]
|
68 |
+
inputs, lengths, future_words, log_probs, labels, classification_targets, syllables_to_go, future_word_num_syllables, rhyme_group_index = batch
|
69 |
+
if args.task not in ['formality', 'iambic']: # topic predictor
|
70 |
+
if not args.debug and len(inputs) != args.batch_size:
|
71 |
+
continue
|
72 |
+
scores = model(inputs, lengths, future_words, log_probs, syllables_to_go, future_word_num_syllables, rhyme_group_index, run_classifier=True)
|
73 |
+
if args.task == 'formality': # we're learning for all positions at once. scores are batch x seq
|
74 |
+
expanded_labels = classification_targets.unsqueeze(1).expand(-1, scores.shape[1]) # batch x seq
|
75 |
+
length_mask = pad_mask(lengths).permute(1, 0) # batch x seq
|
76 |
+
loss = criterion(scores.flatten()[length_mask.flatten()==1], expanded_labels.flatten().float()[length_mask.flatten()==1])
|
77 |
+
elif args.task in ['iambic', 'newline']:
|
78 |
+
use_indices = classification_targets.flatten() != -1
|
79 |
+
loss = criterion(scores.flatten()[use_indices], classification_targets.flatten().float()[use_indices])
|
80 |
+
else: # topic, rhyme
|
81 |
+
loss = criterion(scores.flatten(), labels.flatten().float())
|
82 |
+
loss_meter.update(loss.detach(), len(labels))
|
83 |
+
if batch_num % args.train_print_freq == 0:
|
84 |
+
progress.display(batch_num)
|
85 |
+
progress.display(total_length)
|
86 |
+
return loss_meter.avg
|
87 |
+
|
88 |
+
|
89 |
+
def main(args):
|
90 |
+
dataset = Dataset(args)
|
91 |
+
os.makedirs(args.save_dir, exist_ok=True)
|
92 |
+
with open(os.path.join(args.save_dir, 'dataset_info'), 'wb') as wf:
|
93 |
+
pickle.dump(dataset.dataset_info, wf)
|
94 |
+
if args.task == 'rhyme':
|
95 |
+
with open(os.path.join(args.save_dir, 'rhyme_info'), 'wb') as wf:
|
96 |
+
pickle.dump(dataset.rhyme_info, wf)
|
97 |
+
if args.ckpt:
|
98 |
+
checkpoint = torch.load(args.ckpt, map_location=args.device)
|
99 |
+
start_epoch = checkpoint['epoch'] + 1
|
100 |
+
best_val_metric = checkpoint['best_metric']
|
101 |
+
model_args = checkpoint['args']
|
102 |
+
model = Model(model_args, dataset.gpt_pad_id, len(dataset.index2word), rhyme_group_size=len(dataset.index2rhyme_group) if args.task == 'rhyme' else None) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
|
103 |
+
model.load_state_dict(checkpoint['state_dict'])
|
104 |
+
model = model.to(args.device)
|
105 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=model_args.lr)
|
106 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
107 |
+
data_start_index = checkpoint['data_start_index']
|
108 |
+
print("=> loaded checkpoint '{}' (epoch {})"
|
109 |
+
.format(args.ckpt, checkpoint['epoch']))
|
110 |
+
# NOTE: just import pdb after loading the model here if you want to play with it, it's easy
|
111 |
+
# model.eval()
|
112 |
+
# import pdb; pdb.set_trace()
|
113 |
+
else:
|
114 |
+
model = Model(args, dataset.gpt_pad_id, len(dataset.index2word), rhyme_group_size=len(dataset.index2rhyme_group) if args.task == 'rhyme' else None, glove_embeddings=dataset.glove_embeddings)
|
115 |
+
model = model.to(args.device)
|
116 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
117 |
+
best_val_metric = 1e8 # lower is better for BCE
|
118 |
+
data_start_index = 0
|
119 |
+
print('num params', num_params(model))
|
120 |
+
criterion = nn.BCEWithLogitsLoss().to(args.device)
|
121 |
+
|
122 |
+
if args.evaluate:
|
123 |
+
epoch = 0
|
124 |
+
validate(model, dataset, criterion, epoch, args)
|
125 |
+
return
|
126 |
+
for epoch in range(args.epochs):
|
127 |
+
print("TRAINING: Epoch {} at {}".format(epoch, time.ctime()))
|
128 |
+
data_start_index = train(model, dataset, optimizer, criterion, epoch, args, data_start_index)
|
129 |
+
if epoch % args.validation_freq == 0:
|
130 |
+
print("VALIDATION: Epoch {} at {}".format(epoch, time.ctime()))
|
131 |
+
metric = validate(model, dataset, criterion, epoch, args)
|
132 |
+
|
133 |
+
if not args.debug:
|
134 |
+
if metric < best_val_metric:
|
135 |
+
print('new best val metric', metric)
|
136 |
+
best_val_metric = metric
|
137 |
+
save_checkpoint({
|
138 |
+
'epoch': epoch,
|
139 |
+
'state_dict': model.state_dict(),
|
140 |
+
'best_metric': best_val_metric,
|
141 |
+
'optimizer': optimizer.state_dict(),
|
142 |
+
'data_start_index': data_start_index,
|
143 |
+
'args': args
|
144 |
+
}, os.path.join(args.save_dir, 'model_best.pth.tar'))
|
145 |
+
save_checkpoint({
|
146 |
+
'epoch': epoch,
|
147 |
+
'state_dict': model.state_dict(),
|
148 |
+
'best_metric': metric,
|
149 |
+
'optimizer': optimizer.state_dict(),
|
150 |
+
'data_start_index': data_start_index,
|
151 |
+
'args': args
|
152 |
+
}, os.path.join(args.save_dir, 'model_epoch' + str(epoch) + '.pth.tar'))
|
153 |
+
|
154 |
+
|
155 |
+
if __name__=='__main__':
|
156 |
+
parser = ArgumentParser()
|
157 |
+
|
158 |
+
# DATA
|
159 |
+
parser.add_argument('--task', type=str, required=True, choices=['iambic', 'rhyme', 'newline', 'topic', 'formality', 'clickbait'])
|
160 |
+
parser.add_argument('--data_dir', type=str, required=True)
|
161 |
+
parser.add_argument('--glove_file', type=str, help='glove embedding init, for topic task')
|
162 |
+
|
163 |
+
# SAVE/LOAD
|
164 |
+
parser.add_argument('--save_dir', type=str, required=True, help='where to save ckpts')
|
165 |
+
parser.add_argument('--ckpt', type=str, default=None, help='load ckpt from file if given')
|
166 |
+
parser.add_argument('--dataset_info', type=str, help='saved dataset info')
|
167 |
+
parser.add_argument('--rhyme_info', type=str, help='saved dataset rhyme info, for a ckpt with task==rhyme')
|
168 |
+
|
169 |
+
# TRAINING
|
170 |
+
parser.add_argument('--batch_size', type=int, default=128)
|
171 |
+
parser.add_argument('--epochs', type=int, default=100)
|
172 |
+
parser.add_argument('--epoch_max_len', type=int, default=None, help='max batches per epoch if set, for more frequent validation')
|
173 |
+
parser.add_argument('--validation_freq', type=int, default=1, help='validate every X epochs')
|
174 |
+
parser.add_argument('--lr', type=float, default=1e-3, help='Adam learning rate')
|
175 |
+
parser.add_argument('--seed', type=int, default=1, help='random seed')
|
176 |
+
parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
|
177 |
+
parser.add_argument('--num_workers', type=int, default=20, help='num workers for data loader')
|
178 |
+
parser.add_argument('--evaluate', action='store_true', default=False)
|
179 |
+
parser.add_argument('--debug', action='store_true', default=False)
|
180 |
+
|
181 |
+
# PRINTING
|
182 |
+
parser.add_argument('--train_print_freq', type=int, default=100, help='how often to print metrics (every X batches)')
|
183 |
+
|
184 |
+
args = parser.parse_args()
|
185 |
+
|
186 |
+
random.seed(args.seed)
|
187 |
+
np.random.seed(args.seed)
|
188 |
+
torch.manual_seed(args.seed)
|
189 |
+
if args.evaluate:
|
190 |
+
assert args.ckpt is not None
|
191 |
+
|
192 |
+
main(args)
|
naacl-2021-fudge-controlled-generation/model.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
|
7 |
+
from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model, GPT2LMHeadModel, GPT2Config, GPT2ForSequenceClassification, GPT2LMHeadModel, MarianTokenizer
|
8 |
+
|
9 |
+
from constants import *
|
10 |
+
from util import pad_mask
|
11 |
+
from clickbait_classifier import BertClickbaitClassifier, ClickbaitConfig
|
12 |
+
|
13 |
+
class Model(nn.Module):
|
14 |
+
def __init__(self, args, gpt_pad_id, vocab_size, rhyme_group_size=None, glove_embeddings=None, verbose=True):
|
15 |
+
super(Model, self).__init__()
|
16 |
+
|
17 |
+
# self.topic = args.task == 'topic'
|
18 |
+
self.formality = args.task == 'formality'
|
19 |
+
self.iambic = args.task == 'iambic'
|
20 |
+
self.rhyme = args.task == 'rhyme'
|
21 |
+
self.newline = args.task == 'newline'
|
22 |
+
self.clickbait = args.task == 'clickbait'
|
23 |
+
# if self.topic:
|
24 |
+
# self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id) # these are subwords, not words
|
25 |
+
# if glove_embeddings is None:
|
26 |
+
# if verbose:
|
27 |
+
# print('initializing word embeddings from scratch')
|
28 |
+
# self.word_embed = nn.Embedding(vocab_size, GLOVE_DIM, padding_idx=0)
|
29 |
+
# else:
|
30 |
+
# if verbose:
|
31 |
+
# print('initializing word embeddings from glove')
|
32 |
+
# self.word_embed = nn.Embedding.from_pretrained(glove_embeddings, padding_idx=0)
|
33 |
+
# self.rnn = nn.LSTM(HIDDEN_DIM, RNN_DIM, num_layers=3, bidirectional=True)
|
34 |
+
# self.attention_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
|
35 |
+
# large_hidden_dim = HIDDEN_DIM
|
36 |
+
# self.embed_key_linear = nn.Linear(large_hidden_dim, HIDDEN_DIM)
|
37 |
+
# self.attention_value_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
|
38 |
+
# self.out_embed_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
|
39 |
+
# self.out_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
|
40 |
+
# self.out_linear2 = nn.Linear(HIDDEN_DIM + large_hidden_dim, HIDDEN_DIM)
|
41 |
+
# self.out_linear3 = nn.Linear(HIDDEN_DIM, 1)
|
42 |
+
# self.nonlinear = nn.ReLU()
|
43 |
+
# elif self.formality:
|
44 |
+
if self.formality:
|
45 |
+
self.marian_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=0) # 0 in marian is ''
|
46 |
+
self.rnn = nn.LSTM(HIDDEN_DIM, HIDDEN_DIM, num_layers=3, bidirectional=False, dropout=0.5) # want it to be causal so we can learn all positions
|
47 |
+
self.out_linear = nn.Linear(HIDDEN_DIM, 1)
|
48 |
+
elif self.iambic:
|
49 |
+
self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id)
|
50 |
+
self.rnn = nn.LSTM(HIDDEN_DIM, HIDDEN_DIM, num_layers=3, bidirectional=False, dropout=0) # want it to be causal so we can learn all positions
|
51 |
+
self.out_linear = nn.Linear(HIDDEN_DIM, 1)
|
52 |
+
elif self.rhyme:
|
53 |
+
self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id) # these are subwords, not words
|
54 |
+
self.word_embed = nn.Embedding(rhyme_group_size+1, GLOVE_DIM, padding_idx=0) # this embedding for future words will actually embed the rhyme group idx
|
55 |
+
self.rnn = nn.LSTM(HIDDEN_DIM, RNN_DIM, num_layers=3, bidirectional=True)
|
56 |
+
self.attention_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
|
57 |
+
large_hidden_dim = HIDDEN_DIM + COUNT_SYLLABLE_DIM
|
58 |
+
self.embed_key_linear = nn.Linear(large_hidden_dim, HIDDEN_DIM)
|
59 |
+
self.attention_value_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
|
60 |
+
self.out_embed_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
|
61 |
+
self.out_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
|
62 |
+
self.out_linear2 = nn.Linear(HIDDEN_DIM + large_hidden_dim, HIDDEN_DIM)
|
63 |
+
self.out_linear3 = nn.Linear(HIDDEN_DIM, 1)
|
64 |
+
self.count_syllable_embed = nn.Embedding(MAX_COUNT_SYLLABLE_DIST+1, COUNT_SYLLABLE_DIM)
|
65 |
+
self.nonlinear = nn.ReLU()
|
66 |
+
elif self.newline:
|
67 |
+
self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id) # these are subwords, not words
|
68 |
+
self.rnn = nn.LSTM(HIDDEN_DIM, HIDDEN_DIM, num_layers=3, bidirectional=False)
|
69 |
+
self.count_syllable_embed = nn.Embedding(MAX_COUNT_SYLLABLE_DIST+1, COUNT_SYLLABLE_DIM)
|
70 |
+
self.out_linear = nn.Linear(HIDDEN_DIM + COUNT_SYLLABLE_DIM, HIDDEN_DIM)
|
71 |
+
self.out_linear2 = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
|
72 |
+
self.out_linear3 = nn.Linear(HIDDEN_DIM, 1)
|
73 |
+
self.nonlinear = nn.ReLU()
|
74 |
+
elif self.clickbait:
|
75 |
+
# mpnet_config = ClickbaitConfig(
|
76 |
+
# model_type="mpnet",
|
77 |
+
# pretrained_model="sentence-transformers/all-mpnet-base-v2",
|
78 |
+
# num_labels=1,
|
79 |
+
# dropout=0.2,
|
80 |
+
# inner_dim1=256,
|
81 |
+
# inner_dim2=32,
|
82 |
+
# max_length=25,
|
83 |
+
# load_pretrained=True,
|
84 |
+
# freeze_bert=False,
|
85 |
+
# )
|
86 |
+
#TODO add a checkpoint to Classifier
|
87 |
+
# print('add a checkpoint to Classifier')
|
88 |
+
checkpoint = args.checkpoint #'ckpt/clickbait_classifier/checkpoint-1464'
|
89 |
+
# self.classifier = BertClickbaitClassifier(config=mpnet_config).to(torch.device(args.device))
|
90 |
+
self.classifier = BertClickbaitClassifier.from_pretrained(checkpoint).to(torch.device(args.device))
|
91 |
+
else:
|
92 |
+
raise NotImplementedError # TODO honestly this can/should be refactored into different models
|
93 |
+
|
94 |
+
|
95 |
+
def forward(self, inputs, lengths=None, future_words=None, log_probs=None, syllables_to_go=None, future_word_num_syllables=None, rhyme_group_index=None, run_classifier=False, attention_mask=None):
|
96 |
+
"""
|
97 |
+
inputs: token ids, batch x seq, right-padded with 0s
|
98 |
+
lengths: lengths of inputs; batch
|
99 |
+
future_words: batch x N words to check if not predict next token, else batch
|
100 |
+
log_probs: N
|
101 |
+
syllables_to_go: batch
|
102 |
+
"""
|
103 |
+
# if self.topic:
|
104 |
+
# inputs = self.gpt_embed(inputs) # batch x seq x 300
|
105 |
+
# inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
|
106 |
+
# rnn_output, _ = self.rnn(inputs)
|
107 |
+
# rnn_output, _ = pad_packed_sequence(rnn_output)
|
108 |
+
# rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
|
109 |
+
# hidden = rnn_output
|
110 |
+
# attention_mask = pad_mask(lengths).permute(1, 0) # batch x seq
|
111 |
+
# embed = self.word_embed(future_words) # batch x N x 300
|
112 |
+
# embed_query = self.embed_key_linear(embed)
|
113 |
+
# attention_tensor = self.attention_linear(hidden).unsqueeze(2) * embed_query.unsqueeze(1) # batch x seq x N x 300
|
114 |
+
# attention_weights = F.softmax(attention_tensor.sum(dim=3), dim=1) # batch x seq x N
|
115 |
+
# attention_weights = attention_weights * attention_mask.unsqueeze(2)
|
116 |
+
# hidden = self.attention_value_linear(hidden)
|
117 |
+
# weighted_hidden = (hidden.unsqueeze(2) * attention_weights.unsqueeze(3)).sum(dim=1) # batch x seq x N x 768 -> batch x N x 768
|
118 |
+
# unnormalized_scores = (self.out_linear(weighted_hidden) * self.out_embed_linear(embed)) # batch x N x 300
|
119 |
+
# unnormalized_scores = torch.cat([unnormalized_scores, embed], dim=2)
|
120 |
+
# unnormalized_scores = self.nonlinear(self.out_linear2(self.nonlinear(unnormalized_scores)))
|
121 |
+
# unnormalized_scores = self.out_linear3(unnormalized_scores)
|
122 |
+
# scores = unnormalized_scores.squeeze(2) - log_probs.unsqueeze(0)
|
123 |
+
# return scores # batch x N of normalized scores or batch x
|
124 |
+
# elif self.formality:
|
125 |
+
if self.formality:
|
126 |
+
inputs = self.marian_embed(inputs)
|
127 |
+
inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
|
128 |
+
rnn_output, _ = self.rnn(inputs)
|
129 |
+
rnn_output, _ = pad_packed_sequence(rnn_output)
|
130 |
+
rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
|
131 |
+
return self.out_linear(rnn_output).squeeze(2)
|
132 |
+
elif self.iambic:
|
133 |
+
inputs = self.gpt_embed(inputs)
|
134 |
+
inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
|
135 |
+
rnn_output, _ = self.rnn(inputs)
|
136 |
+
rnn_output, _ = pad_packed_sequence(rnn_output)
|
137 |
+
rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
|
138 |
+
return self.out_linear(rnn_output).squeeze(2)
|
139 |
+
elif self.rhyme:
|
140 |
+
inputs = self.gpt_embed(inputs) # batch x seq x 300
|
141 |
+
inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
|
142 |
+
rnn_output, _ = self.rnn(inputs)
|
143 |
+
rnn_output, _ = pad_packed_sequence(rnn_output)
|
144 |
+
rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
|
145 |
+
hidden = rnn_output
|
146 |
+
attention_mask = pad_mask(lengths).permute(1, 0) # batch x seq
|
147 |
+
embed = self.word_embed(future_words) # batch x N x 300
|
148 |
+
embedded_syllables_to_go = self.count_syllable_embed(syllables_to_go).unsqueeze(1).expand(-1, embed.shape[1], -1) # batch x N x 100
|
149 |
+
auxiliary_embed = embedded_syllables_to_go
|
150 |
+
embed_query = self.embed_key_linear(torch.cat([embed, auxiliary_embed], dim=2))
|
151 |
+
attention_tensor = self.attention_linear(hidden).unsqueeze(2) * embed_query.unsqueeze(1) # batch x seq x N x 300
|
152 |
+
attention_weights = F.softmax(attention_tensor.sum(dim=3), dim=1) # batch x seq x N
|
153 |
+
attention_weights = attention_weights * attention_mask.unsqueeze(2)
|
154 |
+
hidden = self.attention_value_linear(hidden)
|
155 |
+
weighted_hidden = (hidden.unsqueeze(2) * attention_weights.unsqueeze(3)).sum(dim=1) # batch x seq x N x 768 -> batch x N x 768
|
156 |
+
unnormalized_scores = (self.out_linear(weighted_hidden) * self.out_embed_linear(embed)) # batch x N x 300
|
157 |
+
unnormalized_scores = torch.cat([unnormalized_scores, embed, auxiliary_embed], dim=2)
|
158 |
+
unnormalized_scores = self.nonlinear(self.out_linear2(self.nonlinear(unnormalized_scores)))
|
159 |
+
unnormalized_scores = self.out_linear3(unnormalized_scores)
|
160 |
+
scores = unnormalized_scores.squeeze(2) - log_probs.unsqueeze(0)
|
161 |
+
return scores # batch x N of normalized scores or batch x
|
162 |
+
elif self.newline:
|
163 |
+
inputs = self.gpt_embed(inputs) # batch x seq x 300
|
164 |
+
inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
|
165 |
+
rnn_output, _ = self.rnn(inputs)
|
166 |
+
rnn_output, _ = pad_packed_sequence(rnn_output)
|
167 |
+
rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
|
168 |
+
hidden = torch.cat([rnn_output, self.count_syllable_embed(syllables_to_go).unsqueeze(1).expand(-1, rnn_output.shape[1], -1)], dim=2)
|
169 |
+
return self.out_linear3(self.nonlinear(self.out_linear2(self.nonlinear(self.out_linear(hidden))))).squeeze(2)
|
170 |
+
elif self.clickbait:
|
171 |
+
|
172 |
+
input_ids = torch.tensor(inputs)
|
173 |
+
classifer_output = self.classifier(input_ids = input_ids, attention_mask = attention_mask).logits
|
174 |
+
|
175 |
+
classifer_output = classifer_output[None,:,:] # batch x seq x 300
|
176 |
+
# return self.out_linear(rnn_output).squeeze(2)
|
177 |
+
return classifer_output.squeeze(2)
|
178 |
+
|
179 |
+
else:
|
180 |
+
raise NotImplementedError
|
181 |
+
|
182 |
+
|
naacl-2021-fudge-controlled-generation/poetry_data/README.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
`couplet_prefixes.txt` contains the 13th line of each of Shakespeare's sonnets. `couplet_ends.txt` contains the 14th. (Each 14-line sonnet ends with a couplet in the last two lines). The prefixes are our test set prefixes for the couplet completion task; the ends are Shakespeare's outputs.
|
naacl-2021-fudge-controlled-generation/poetry_data/couplet_ends.txt
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
To eat the world's due, by the grave and thee.
|
2 |
+
And see thy blood warm when thou feel'st it cold.
|
3 |
+
Die single, and thine image dies with thee.
|
4 |
+
Which, used, lives th' executor to be.
|
5 |
+
Leese but their show; their substance still lives sweet.
|
6 |
+
To be death's conquest and make worms thine heir.
|
7 |
+
Unlook'd on diest, unless thou get a son.
|
8 |
+
Sings this to thee: 'thou single wilt prove none.'
|
9 |
+
That on himself such murderous shame commits.
|
10 |
+
That beauty still may live in thine or thee.
|
11 |
+
Thou shouldst print more, not let that copy die.
|
12 |
+
Save breed, to brave him when he takes thee hence.
|
13 |
+
You had a father: let your son say so.
|
14 |
+
Thy end is truth's and beauty's doom and date.
|
15 |
+
As he takes from you, I engraft you new.
|
16 |
+
And you must live, drawn by your own sweet skill.
|
17 |
+
You should live twice; in it and in my rhyme.
|
18 |
+
So long lives this and this gives life to thee.
|
19 |
+
My love shall in my verse ever live young.
|
20 |
+
Mine be thy love and thy love's use their treasure.
|
21 |
+
I will not praise that purpose not to sell.
|
22 |
+
Thou gavest me thine, not to give back again.
|
23 |
+
To hear with eyes belongs to love's fine wit.
|
24 |
+
They draw but what they see, know not the heart.
|
25 |
+
Where I may not remove nor be removed.
|
26 |
+
Till then not show my head where thou mayst prove me.
|
27 |
+
For thee and for myself no quiet find.
|
28 |
+
And night doth nightly make grief's strength seem stronger.
|
29 |
+
That then I scorn to change my state with kings.
|
30 |
+
All losses are restored and sorrows end.
|
31 |
+
And thou, all they, hast all the all of me.
|
32 |
+
Theirs for their style I'll read, his for his love.'
|
33 |
+
Suns of the world may stain when heaven's sun staineth.
|
34 |
+
And they are rich and ransom all ill deeds.
|
35 |
+
To that sweet thief which sourly robs from me.
|
36 |
+
As, thou being mine, mine is thy good report.
|
37 |
+
This wish I have; then ten times happy me!
|
38 |
+
The pain be mine, but thine shall be the praise.
|
39 |
+
By praising him here who doth hence remain!
|
40 |
+
Kill me with spites; yet we must not be foes.
|
41 |
+
Thine, by thy beauty being false to me.
|
42 |
+
Sweet flattery! then she loves but me alone.
|
43 |
+
And nights bright days when dreams do show thee me.
|
44 |
+
But heavy tears, badges of either's woe.
|
45 |
+
I send them back again and straight grow sad.
|
46 |
+
And my heart's right thy inward love of heart.
|
47 |
+
Awakes my heart to heart's and eye's delight.
|
48 |
+
For truth proves thievish for a prize so dear.
|
49 |
+
Since why to love I can allege no cause.
|
50 |
+
My grief lies onward and my joy behind.
|
51 |
+
Towards thee I'll run, and give him leave to go.
|
52 |
+
Being had, to triumph, being lack'd, to hope.
|
53 |
+
But you like none, none you, for constant heart.
|
54 |
+
When that shall fade, my verse distills your truth.
|
55 |
+
You live in this, and dwell in lover's eyes.
|
56 |
+
Makes summer's welcome thrice more wish'd, more rare.
|
57 |
+
Though you do any thing, he thinks no ill.
|
58 |
+
Not blame your pleasure, be it ill or well.
|
59 |
+
To subjects worse have given admiring praise.
|
60 |
+
Praising thy worth, despite his cruel hand.
|
61 |
+
From me far off, with others all too near.
|
62 |
+
Painting my age with beauty of thy days.
|
63 |
+
And they shall live, and he in them still green.
|
64 |
+
But weep to have that which it fears to lose.
|
65 |
+
That in black ink my love may still shine bright.
|
66 |
+
Save that, to die, I leave my love alone.
|
67 |
+
In days long since, before these last so bad.
|
68 |
+
To show false Art what beauty was of yore.
|
69 |
+
The solve is this, that thou dost common grow.
|
70 |
+
Then thou alone kingdoms of hearts shouldst owe.
|
71 |
+
And mock you with me after I am gone.
|
72 |
+
And so should you, to love things nothing worth.
|
73 |
+
To love that well which thou must leave ere long.
|
74 |
+
And that is this, and this with thee remains.
|
75 |
+
Or gluttoning on all, or all away.
|
76 |
+
So is my love still telling what is told.
|
77 |
+
Shall profit thee and much enrich thy book.
|
78 |
+
As high as learning my rude ignorance.
|
79 |
+
Since what he owes thee thou thyself dost pay.
|
80 |
+
The worst was this; my love was my decay.
|
81 |
+
Where breath most breathes, even in the mouths of men.
|
82 |
+
Where cheeks need blood; in thee it is abused.
|
83 |
+
Than both your poets can in praise devise.
|
84 |
+
Being fond on praise, which makes your praises worse.
|
85 |
+
Me for my dumb thoughts, speaking in effect.
|
86 |
+
Then lack'd I matter; that enfeebled mine.
|
87 |
+
In sleep a king, but waking no such matter.
|
88 |
+
That for thy right myself will bear all wrong.
|
89 |
+
For I must ne'er love him whom thou dost hate.
|
90 |
+
Compared with loss of thee will not seem so.
|
91 |
+
All this away and me most wretched make.
|
92 |
+
Thou mayst be false, and yet I know it not.
|
93 |
+
if thy sweet virtue answer not thy show!
|
94 |
+
Lilies that fester smell far worse than weeds.
|
95 |
+
The hardest knife ill-used doth lose his edge.
|
96 |
+
As, thou being mine, mine is thy good report.
|
97 |
+
That leaves look pale, dreading the winter's near.
|
98 |
+
As with your shadow I with these did play:
|
99 |
+
But sweet or colour it had stol'n from thee.
|
100 |
+
So thou prevent'st his scythe and crooked knife.
|
101 |
+
To make him seem long hence as he shows now.
|
102 |
+
Because I would not dull you with my song.
|
103 |
+
Your own glass shows you when you look in it.
|
104 |
+
Ere you were born was beauty's summer dead.
|
105 |
+
Which three till now never kept seat in one.
|
106 |
+
Had eyes to wonder, but lack tongues to praise.
|
107 |
+
When tyrants' crests and tombs of brass are spent.
|
108 |
+
Where time and outward form would show it dead.
|
109 |
+
Save thou, my rose; in it thou art my all.
|
110 |
+
Even to thy pure and most most loving breast.
|
111 |
+
Even that your pity is enough to cure me.
|
112 |
+
That all the world besides methinks are dead.
|
113 |
+
My most true mind thus makes mine eye untrue.
|
114 |
+
That mine eye loves it and doth first begin.
|
115 |
+
To give full growth to that which still doth grow?
|
116 |
+
I never writ, nor no man ever loved.
|
117 |
+
The constancy and virtue of your love.
|
118 |
+
Drugs poison him that so fell sick of you.
|
119 |
+
And gain by ill thrice more than I have spent.
|
120 |
+
Mine ransoms yours, and yours must ransom me.
|
121 |
+
All men are bad, and in their badness reign.
|
122 |
+
Were to import forgetfulness in me.
|
123 |
+
I will be true, despite thy scythe and thee.
|
124 |
+
Which die for goodness, who have lived for crime.
|
125 |
+
When most impeach'd stands least in thy control.
|
126 |
+
And her quietus is to render thee.
|
127 |
+
That every tongue says beauty should look so.
|
128 |
+
Give them thy fingers, me thy lips to kiss.
|
129 |
+
To shun the heaven that leads men to this hell.
|
130 |
+
As any she belied with false compare.
|
131 |
+
And thence this slander, as I think, proceeds.
|
132 |
+
And all they foul that thy complexion lack.
|
133 |
+
Perforce am thine, and all that is in me.
|
134 |
+
He pays the whole, and yet am I not free.
|
135 |
+
Think all but one, and me in that one 'Will.'
|
136 |
+
And then thou lovest me, for my name is 'Will.'
|
137 |
+
And to this false plague are they now transferr'd.
|
138 |
+
And in our faults by lies we flatter'd be.
|
139 |
+
Kill me outright with looks and rid my pain.
|
140 |
+
Bear thine eyes straight, though thy proud heart go wide.
|
141 |
+
That she that makes me sin awards me pain.
|
142 |
+
By self-example mayst thou be denied!
|
143 |
+
If thou turn back, and my loud crying still.
|
144 |
+
Till my bad angel fire my good one out.
|
145 |
+
And saved my life, saying 'not you.'
|
146 |
+
And Death once dead, there's no more dying then.
|
147 |
+
Who art as black as hell, as dark as night.
|
148 |
+
Lest eyes well-seeing thy foul faults should find.
|
149 |
+
Those that can see thou lovest, and I am blind.
|
150 |
+
More worthy I to be beloved of thee.
|
151 |
+
Her 'love' for whose dear love I rise and fall.
|
152 |
+
To swear against the truth so foul a lie!
|
153 |
+
Where Cupid got new fire--my mistress' eyes.
|
154 |
+
Love's fire heats water, water cools not love.
|
naacl-2021-fudge-controlled-generation/poetry_data/couplet_prefixes.txt
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Pity the world, or else this glutton be,
|
2 |
+
This were to be new made when thou art old,
|
3 |
+
But if thou live, remember'd not to be,
|
4 |
+
Thy unused beauty must be tomb'd with thee,
|
5 |
+
But flowers distill'd though they with winter meet,
|
6 |
+
Be not self-will'd, for thou art much too fair
|
7 |
+
So thou, thyself out-going in thy noon,
|
8 |
+
Whose speechless song, being many, seeming one,
|
9 |
+
No love toward others in that bosom sits
|
10 |
+
Make thee another self, for love of me,
|
11 |
+
She carved thee for her seal, and meant thereby
|
12 |
+
And nothing 'gainst Time's scythe can make defence
|
13 |
+
O, none but unthrifts! Dear my love, you know
|
14 |
+
Or else of thee this I prognosticate:
|
15 |
+
And all in war with Time for love of you,
|
16 |
+
To give away yourself keeps yourself still,
|
17 |
+
But were some child of yours alive that time,
|
18 |
+
So long as men can breathe or eyes can see,
|
19 |
+
Yet, do thy worst, old Time: despite thy wrong,
|
20 |
+
But since she prick'd thee out for women's pleasure,
|
21 |
+
Let them say more than like of hearsay well;
|
22 |
+
Presume not on thy heart when mine is slain;
|
23 |
+
O, learn to read what silent love hath writ:
|
24 |
+
Yet eyes this cunning want to grace their art;
|
25 |
+
Then happy I, that love and am beloved
|
26 |
+
Then may I dare to boast how I do love thee;
|
27 |
+
Lo! thus, by day my limbs, by night my mind,
|
28 |
+
But day doth daily draw my sorrows longer
|
29 |
+
For thy sweet love remember'd such wealth brings
|
30 |
+
But if the while I think on thee, dear friend,
|
31 |
+
Their images I loved I view in thee,
|
32 |
+
But since he died and poets better prove,
|
33 |
+
Yet him for this my love no whit disdaineth;
|
34 |
+
Ah! but those tears are pearl which thy love sheds,
|
35 |
+
That I an accessary needs must be
|
36 |
+
But do not so; I love thee in such sort
|
37 |
+
Look, what is best, that best I wish in thee:
|
38 |
+
If my slight Muse do please these curious days,
|
39 |
+
And that thou teachest how to make one twain,
|
40 |
+
Lascivious grace, in whom all ill well shows,
|
41 |
+
Hers by thy beauty tempting her to thee,
|
42 |
+
But here's the joy; my friend and I are one;
|
43 |
+
All days are nights to see till I see thee,
|
44 |
+
Receiving nought by elements so slow
|
45 |
+
This told, I joy; but then no longer glad,
|
46 |
+
As thus; mine eye's due is thy outward part,
|
47 |
+
Or, if they sleep, thy picture in my sight
|
48 |
+
And even thence thou wilt be stol'n, I fear,
|
49 |
+
To leave poor me thou hast the strength of laws,
|
50 |
+
For that same groan doth put this in my mind;
|
51 |
+
Since from thee going he went wilful-slow,
|
52 |
+
Blessed are you, whose worthiness gives scope,
|
53 |
+
In all external grace you have some part,
|
54 |
+
And so of you, beauteous and lovely youth,
|
55 |
+
So, till the judgment that yourself arise,
|
56 |
+
Else call it winter, which being full of care
|
57 |
+
So true a fool is love that in your will,
|
58 |
+
I am to wait, though waiting so be hell;
|
59 |
+
O, sure I am, the wits of former days
|
60 |
+
And yet to times in hope my verse shall stand,
|
61 |
+
For thee watch I whilst thou dost wake elsewhere,
|
62 |
+
'Tis thee, myself, that for myself I praise,
|
63 |
+
His beauty shall in these black lines be seen,
|
64 |
+
This thought is as a death, which cannot choose
|
65 |
+
O, none, unless this miracle have might,
|
66 |
+
Tired with all these, from these would I be gone,
|
67 |
+
O, him she stores, to show what wealth she had
|
68 |
+
And him as for a map doth Nature store,
|
69 |
+
But why thy odour matcheth not thy show,
|
70 |
+
If some suspect of ill mask'd not thy show,
|
71 |
+
Lest the wise world should look into your moan
|
72 |
+
For I am shamed by that which I bring forth,
|
73 |
+
This thou perceivest, which makes thy love more strong,
|
74 |
+
The worth of that is that which it contains,
|
75 |
+
Thus do I pine and surfeit day by day,
|
76 |
+
For as the sun is daily new and old,
|
77 |
+
These offices, so oft as thou wilt look,
|
78 |
+
But thou art all my art and dost advance
|
79 |
+
Then thank him not for that which he doth say,
|
80 |
+
Then if he thrive and I be cast away,
|
81 |
+
You still shall live--such virtue hath my pen--
|
82 |
+
And their gross painting might be better used
|
83 |
+
There lives more life in one of your fair eyes
|
84 |
+
You to your beauteous blessings add a curse,
|
85 |
+
Then others for the breath of words respect,
|
86 |
+
But when your countenance fill'd up his line,
|
87 |
+
Thus have I had thee, as a dream doth flatter,
|
88 |
+
Such is my love, to thee I so belong,
|
89 |
+
For thee against myself I'll vow debate,
|
90 |
+
And other strains of woe, which now seem woe,
|
91 |
+
Wretched in this alone, that thou mayst take
|
92 |
+
But what's so blessed-fair that fears no blot?
|
93 |
+
How like Eve's apple doth thy beauty grow,
|
94 |
+
For sweetest things turn sourest by their deeds;
|
95 |
+
Take heed, dear heart, of this large privilege;
|
96 |
+
But do not so; I love thee in such sort
|
97 |
+
Or, if they sing, 'tis with so dull a cheer
|
98 |
+
Yet seem'd it winter still, and, you away,
|
99 |
+
More flowers I noted, yet I none could see
|
100 |
+
Give my love fame faster than Time wastes life;
|
101 |
+
Then do thy office, Muse; I teach thee how
|
102 |
+
Therefore like her I sometime hold my tongue,
|
103 |
+
And more, much more, than in my verse can sit
|
104 |
+
For fear of which, hear this, thou age unbred;
|
105 |
+
'Fair, kind, and true,' have often lived alone,
|
106 |
+
For we, which now behold these present days,
|
107 |
+
And thou in this shalt find thy monument,
|
108 |
+
Finding the first conceit of love there bred
|
109 |
+
For nothing this wide universe I call,
|
110 |
+
Then give me welcome, next my heaven the best,
|
111 |
+
Pity me then, dear friend, and I assure ye
|
112 |
+
You are so strongly in my purpose bred
|
113 |
+
Incapable of more, replete with you,
|
114 |
+
If it be poison'd, 'tis the lesser sin
|
115 |
+
Love is a babe; then might I not say so,
|
116 |
+
If this be error and upon me proved,
|
117 |
+
Since my appeal says I did strive to prove
|
118 |
+
But thence I learn, and find the lesson true,
|
119 |
+
So I return rebuked to my content
|
120 |
+
But that your trespass now becomes a fee;
|
121 |
+
Unless this general evil they maintain,
|
122 |
+
To keep an adjunct to remember thee
|
123 |
+
This I do vow and this shall ever be;
|
124 |
+
To this I witness call the fools of time,
|
125 |
+
Hence, thou suborn'd informer! a true soul
|
126 |
+
Her audit, though delay'd, answer'd must be,
|
127 |
+
Yet so they mourn, becoming of their woe,
|
128 |
+
Since saucy jacks so happy are in this,
|
129 |
+
All this the world well knows; yet none knows well
|
130 |
+
And yet, by heaven, I think my love as rare
|
131 |
+
In nothing art thou black save in thy deeds,
|
132 |
+
Then will I swear beauty herself is black
|
133 |
+
And yet thou wilt; for I, being pent in thee,
|
134 |
+
Him have I lost; thou hast both him and me:
|
135 |
+
Let no unkind, no fair beseechers kill;
|
136 |
+
Make but my name thy love, and love that still,
|
137 |
+
In things right true my heart and eyes have erred,
|
138 |
+
Therefore I lie with her and she with me,
|
139 |
+
Yet do not so; but since I am near slain,
|
140 |
+
That I may not be so, nor thou belied,
|
141 |
+
Only my plague thus far I count my gain,
|
142 |
+
If thou dost seek to have what thou dost hide,
|
143 |
+
So will I pray that thou mayst have thy 'Will,'
|
144 |
+
Yet this shall I ne'er know, but live in doubt,
|
145 |
+
'I hate' from hate away she threw,
|
146 |
+
So shalt thou feed on Death, that feeds on men,
|
147 |
+
For I have sworn thee fair and thought thee bright,
|
148 |
+
O cunning Love! with tears thou keep'st me blind,
|
149 |
+
But, love, hate on, for now I know thy mind;
|
150 |
+
If thy unworthiness raised love in me,
|
151 |
+
No want of conscience hold it that I call
|
152 |
+
For I have sworn thee fair; more perjured I,
|
153 |
+
But found no cure: the bath for my help lies
|
154 |
+
Came there for cure, and this by that I prove,
|
naacl-2021-fudge-controlled-generation/poetry_util.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import string
|
2 |
+
|
3 |
+
import pronouncing
|
4 |
+
from Phyme import Phyme
|
5 |
+
phyme = Phyme()
|
6 |
+
|
7 |
+
from constants import *
|
8 |
+
|
9 |
+
def is_iambic(phrase):
|
10 |
+
"""
|
11 |
+
check that we satisfy iambic meter.
|
12 |
+
return 1 if so, otherwise 0.
|
13 |
+
definitely an imperfect check...
|
14 |
+
if we end up needing to check a word that's not in the CMU dictionary, just return 0.
|
15 |
+
"""
|
16 |
+
meter = ''
|
17 |
+
for word in phrase.split():
|
18 |
+
word = word.strip().strip(string.punctuation).lower()
|
19 |
+
try:
|
20 |
+
phones_list = pronouncing.phones_for_word(word)
|
21 |
+
stresses = pronouncing.stresses(phones_list[0])
|
22 |
+
if len(stresses) == 1:
|
23 |
+
if stresses == '1':
|
24 |
+
stresses = '2' # allow ambiguity for 1-syllable words with stress 1
|
25 |
+
meter += stresses # just default to the first pronunciation if > 1 given
|
26 |
+
except:
|
27 |
+
return 0 # word not found
|
28 |
+
meter = [int(x) for x in meter]
|
29 |
+
even_stresses_full = [meter[i] for i in range(0, len(meter), 2)]
|
30 |
+
odd_stresses_full = [meter[i] for i in range(1, len(meter), 2)]
|
31 |
+
even_stresses = set(even_stresses_full)
|
32 |
+
odd_stresses = set(odd_stresses_full)
|
33 |
+
if 0 in odd_stresses:
|
34 |
+
return 0
|
35 |
+
if 1 in even_stresses:
|
36 |
+
return 0
|
37 |
+
return 1
|
38 |
+
|
39 |
+
|
40 |
+
def count_syllables(words):
|
41 |
+
syllables = 0
|
42 |
+
for word in words.split():
|
43 |
+
word = word.strip().strip(string.punctuation)
|
44 |
+
try:
|
45 |
+
phones_list = pronouncing.phones_for_word(word)
|
46 |
+
stresses = pronouncing.stresses(phones_list[0])
|
47 |
+
syllables += min(MAX_SYLLABLES_PER_WORD, len(stresses))
|
48 |
+
except:
|
49 |
+
# if we don't know, just do a quick approximation here; it shouldn't come up too often
|
50 |
+
syllables += min(MAX_SYLLABLES_PER_WORD, round(len(word) / 3))
|
51 |
+
return syllables
|
52 |
+
|
53 |
+
|
54 |
+
def get_rhymes(word):
|
55 |
+
# throws exception if word not in the rhyme dict (rare)
|
56 |
+
rhymes = []
|
57 |
+
rhyme_dict = phyme.get_perfect_rhymes(word)
|
58 |
+
for length_dict in rhyme_dict.values():
|
59 |
+
for word in length_dict:
|
60 |
+
if '(' in word: # sometimes you have stuff like preferred(1) where they indicate a particular pronunciation
|
61 |
+
rhymes.append(word.split('(')[0])
|
62 |
+
else:
|
63 |
+
rhymes.append(word)
|
64 |
+
return sorted(list(set(rhymes)))
|
65 |
+
|
66 |
+
|
67 |
+
def get_rhyme_group(word):
|
68 |
+
sorted_rhyme_list = get_rhymes(word)
|
69 |
+
return ' '.join(sorted_rhyme_list)
|
70 |
+
|
71 |
+
|
72 |
+
def perfect_rhyme_end(s1, s2):
|
73 |
+
ending_word1 = s1.split()[-1].strip(string.punctuation)
|
74 |
+
ending_word2 = s2.split()[-1].strip(string.punctuation)
|
75 |
+
try:
|
76 |
+
return get_rhyme_group(ending_word1) == get_rhyme_group(ending_word2)
|
77 |
+
except:
|
78 |
+
return False # unknown words
|
79 |
+
|
80 |
+
if __name__=='__main__':
|
81 |
+
result = is_iambic('Shall I compare thee to a summer day')
|
82 |
+
result2 = count_syllables('Shall I compare thee to a summer day')
|
83 |
+
import pdb; pdb.set_trace()
|
naacl-2021-fudge-controlled-generation/predict_clickbait.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import time
|
4 |
+
import pickle
|
5 |
+
import math
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
|
8 |
+
from typing import Iterable, List, Optional, Tuple
|
9 |
+
|
10 |
+
from tqdm import tqdm
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from transformers import AutoTokenizer, AutoModelWithLMHead
|
16 |
+
from torch import Tensor
|
17 |
+
|
18 |
+
from data import Dataset
|
19 |
+
from model import Model
|
20 |
+
from util import num_params
|
21 |
+
from constants import *
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
tokenizer = AutoTokenizer.from_pretrained('google/pegasus-xsum')
|
26 |
+
classifier_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
|
27 |
+
|
28 |
+
|
29 |
+
def main(args):
|
30 |
+
with open(args.dataset_info, 'rb') as rf:
|
31 |
+
dataset_info = pickle.load(rf)
|
32 |
+
|
33 |
+
article_content = """Australian actor Guy Pearce will return for the iconic soap Neighbours finale on August 1 to reprise his role as Mike Young.
|
34 |
+
Guy, 54, played the troubled Mike from 1986 to 1989, and is now set to make a comeback on the show after 33 years, Metro.co.uk reports.
|
35 |
+
The star's character arcs explored the implications of domestic abuse, student-teacher relationships and dealing with loss of loved ones.
|
36 |
+
Speaking to Metro.co.uk, Guy said: 'It is very exciting and surreal at the same time being back on set again, however it feels like coming home.
|
37 |
+
'It's where it all started for me professionally. I've been asked to come back on occasions over the years and wondered if it was the right thing
|
38 |
+
to do, but once I knew the show was finishing, I knew I had to do it.'He added that there is 'nothing like being here all together again'
|
39 |
+
, even though he's had a chance to catch-up with other cast members."""
|
40 |
+
|
41 |
+
tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
|
42 |
+
pad_id = tokenizer.encode(PAD_TOKEN)[0]
|
43 |
+
|
44 |
+
#For loading Clickbait summarizer
|
45 |
+
model = AutoModelWithLMHead.from_pretrained(args.model_string, return_dict=True).to(args.device)
|
46 |
+
|
47 |
+
model.eval()
|
48 |
+
|
49 |
+
checkpoint = torch.load(args.ckpt, map_location=args.device)
|
50 |
+
model_args = checkpoint['args']
|
51 |
+
conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
|
52 |
+
conditioning_model.load_state_dict(checkpoint['state_dict'])
|
53 |
+
conditioning_model = conditioning_model.to(args.device)
|
54 |
+
conditioning_model.eval()
|
55 |
+
print("=> loaded checkpoint '{}' (epoch {})"
|
56 |
+
.format(args.ckpt, checkpoint['epoch']))
|
57 |
+
print('num params', num_params(conditioning_model))
|
58 |
+
|
59 |
+
while True:
|
60 |
+
results = generate_clickbait(model,
|
61 |
+
tokenizer,
|
62 |
+
conditioning_model,
|
63 |
+
[args.input_text],
|
64 |
+
dataset_info,
|
65 |
+
precondition_topk=args.precondition_topk,
|
66 |
+
do_sample=args.do_sample,
|
67 |
+
length_cutoff=args.length_cutoff,
|
68 |
+
condition_lambda=args.condition_lambda,
|
69 |
+
article_content=article_content,
|
70 |
+
device=args.device)
|
71 |
+
# print(results)
|
72 |
+
import pdb; pdb.set_trace()
|
73 |
+
|
74 |
+
|
75 |
+
def generate_clickbait(model,
|
76 |
+
tokenizer,
|
77 |
+
conditioning_model,
|
78 |
+
input_text,
|
79 |
+
dataset_info,
|
80 |
+
precondition_topk,
|
81 |
+
length_cutoff,
|
82 |
+
condition_lambda=1.0,
|
83 |
+
article_content=None,
|
84 |
+
device='cuda'):
|
85 |
+
with torch.no_grad():
|
86 |
+
batch_size = len(input_text)
|
87 |
+
# encoded_input_article = [tokenizer.encode(article_content, return_tensors='pt',add_special_tokens=False).to(device)] # batch x seq
|
88 |
+
max_input_length = 512
|
89 |
+
encoded_input_article = tokenizer(article_content, return_tensors='pt',add_special_tokens=False, max_length = max_input_length).to(device) # batch x seq
|
90 |
+
# encoded_input_article = torch.cat(encoded_input_article, dim=0)
|
91 |
+
# attention_mask = encoded_input_article.new_ones(encoded_input_article.shape).to(device)
|
92 |
+
|
93 |
+
# CHANGE=ko
|
94 |
+
encoded_input = tokenizer('<pad>', return_tensors='pt',add_special_tokens=False).to(device) # batch x seq
|
95 |
+
# encoded_input = tokenizer('<pad>'+ input_text[0], return_tensors='pt',add_special_tokens=False).to(device) # batch x seq
|
96 |
+
# encoded_input = torch.cat(encoded_input, dim=0)
|
97 |
+
encoded_input = encoded_input['input_ids']
|
98 |
+
|
99 |
+
|
100 |
+
lengths = torch.LongTensor([encoded_input.shape[1]]).to(device)
|
101 |
+
# lengths = 1
|
102 |
+
|
103 |
+
past = None
|
104 |
+
use_cache = True
|
105 |
+
|
106 |
+
# CHANGE
|
107 |
+
# model_kwargs = {'encoder_outputs': model.get_encoder()(encoded_input_article, attention_mask=attention_mask)}
|
108 |
+
model_kwargs = {'encoder_outputs': model.get_encoder()(input_ids=encoded_input_article['input_ids'],
|
109 |
+
attention_mask=encoded_input_article['attention_mask'],
|
110 |
+
return_dict=True,
|
111 |
+
output_attentions=False,
|
112 |
+
output_hidden_states=False),
|
113 |
+
}
|
114 |
+
|
115 |
+
while lengths.max() < length_cutoff:
|
116 |
+
model_inputs = model.prepare_inputs_for_generation(
|
117 |
+
input_ids = encoded_input_article['input_ids'],
|
118 |
+
decoder_input_ids=encoded_input,
|
119 |
+
# past=past,
|
120 |
+
attention_mask=encoded_input_article['attention_mask'],
|
121 |
+
use_cache=use_cache,
|
122 |
+
**model_kwargs
|
123 |
+
)
|
124 |
+
|
125 |
+
outputs = model(**model_inputs, return_dict=True)
|
126 |
+
logits = outputs.logits[:, -1, :]
|
127 |
+
|
128 |
+
if "past_key_values" in outputs:
|
129 |
+
model_kwargs["past"] = outputs.past_key_values
|
130 |
+
|
131 |
+
# logits = model(encoded_input)[0][:, -1, :] # batch x vocab
|
132 |
+
top_logits, top_indices = logits.topk(precondition_topk, dim=1) # batch x topk
|
133 |
+
new_input_candidates = torch.cat([encoded_input.unsqueeze(1).expand(-1, precondition_topk, -1), top_indices.unsqueeze(2)], dim=2) # batch x topk x seq+1
|
134 |
+
expanded_lengths = (lengths + 1).unsqueeze(1).expand(batch_size, precondition_topk) # batch x topk
|
135 |
+
|
136 |
+
if condition_lambda == 0:
|
137 |
+
condition_logits = torch.zeros_like(top_logits).float()
|
138 |
+
condition_logits = condition_logits.view(batch_size, precondition_topk, -1) # batch x topk x N
|
139 |
+
else:
|
140 |
+
decoded_outputs = tokenizer.batch_decode(new_input_candidates.view(-1, new_input_candidates.size(-1)), clean_up_tokenization_spaces=False)
|
141 |
+
resulting_tokenization = classifier_tokenizer(decoded_outputs, add_special_tokens=False, padding='longest')
|
142 |
+
encoded_with_classifier = resulting_tokenization['input_ids']
|
143 |
+
attention_mask = torch.tensor(resulting_tokenization['attention_mask']).to(model.device)
|
144 |
+
tplus1_candidates_classifier = torch.tensor(encoded_with_classifier).view(batch_size, precondition_topk, -1).to(model.device)
|
145 |
+
|
146 |
+
condition_logits = conditioning_model(tplus1_candidates_classifier.flatten(0, 1), # batch*topk x seq+1
|
147 |
+
expanded_lengths.flatten(0, 1), # batch*topk
|
148 |
+
None,
|
149 |
+
None,
|
150 |
+
None,
|
151 |
+
attention_mask=attention_mask
|
152 |
+
)
|
153 |
+
condition_logits = condition_logits.view(batch_size, precondition_topk, -1) # batch x topk x N
|
154 |
+
condition_logits = condition_logits - torch.log(1 + torch.exp(condition_logits)) # get correct log probs
|
155 |
+
|
156 |
+
condition_logits = torch.mean(condition_logits, dim=2)
|
157 |
+
full_logits = top_logits + condition_logits * condition_lambda # batch x topk
|
158 |
+
post_logits, post_indices = full_logits.topk(precondition_topk, dim=1)
|
159 |
+
post_probs = F.softmax(post_logits, dim=1)
|
160 |
+
# index_into_top_indices = post_indices[torch.arange(batch_size).to(post_indices.device), torch.multinomial(post_probs, 1).flatten()] # batch
|
161 |
+
index_into_top_indices = post_indices[:, torch.multinomial(post_probs, 1).flatten()] # batch
|
162 |
+
|
163 |
+
# next_indices = top_indices[torch.arange(batch_size).to(top_indices.device), index_into_top_indices] # batch
|
164 |
+
next_indices = top_indices[:, index_into_top_indices] # batch
|
165 |
+
|
166 |
+
# encoded_input = torch.cat([encoded_input, next_indices.unsqueeze(1)], dim=1) # batch x seq+1
|
167 |
+
encoded_input = torch.cat([encoded_input, next_indices.squeeze(1)], dim=1)
|
168 |
+
lengths = lengths + 1 # batch
|
169 |
+
|
170 |
+
# print(tokenizer.decode(encoded_input[0], add_special_tokens=False))
|
171 |
+
return [tokenizer.decode(s) for s in encoded_input]
|
172 |
+
|
173 |
+
|
174 |
+
if __name__=='__main__':
|
175 |
+
parser = ArgumentParser()
|
176 |
+
|
177 |
+
# DATA
|
178 |
+
parser.add_argument('--ckpt', type=str, required=True)
|
179 |
+
parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
|
180 |
+
parser.add_argument('--model_string', type=str, default='Helsinki-NLP/opus-mt-es-en')
|
181 |
+
|
182 |
+
parser.add_argument('--in_file', type=str, default=None, required=True, help='text to run pred on')
|
183 |
+
|
184 |
+
parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from text generation at each step before conditioning and re-pruning')
|
185 |
+
parser.add_argument('--do_sample', action='store_true', default=False, help='sample instead of greedy')
|
186 |
+
parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
|
187 |
+
parser.add_argument('--length_cutoff', type=int, default=512, help='max length')
|
188 |
+
|
189 |
+
parser.add_argument('--seed', type=int, default=1, help='random seed')
|
190 |
+
parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
|
191 |
+
parser.add_argument('--debug', action='store_true', default=False)
|
192 |
+
|
193 |
+
args = parser.parse_args()
|
194 |
+
|
195 |
+
random.seed(args.seed)
|
196 |
+
np.random.seed(args.seed)
|
197 |
+
torch.manual_seed(args.seed)
|
198 |
+
|
199 |
+
main(args)
|
naacl-2021-fudge-controlled-generation/predict_formality.py
ADDED
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import time
|
4 |
+
import pickle
|
5 |
+
import math
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
|
8 |
+
from typing import Iterable, List, Optional, Tuple
|
9 |
+
|
10 |
+
from tqdm import tqdm
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model, MarianTokenizer, MarianMTModel
|
16 |
+
from torch import Tensor
|
17 |
+
|
18 |
+
from data import Dataset
|
19 |
+
from model import Model
|
20 |
+
from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
|
21 |
+
from constants import *
|
22 |
+
|
23 |
+
def main(args):
|
24 |
+
with open(args.dataset_info, 'rb') as rf:
|
25 |
+
dataset_info = pickle.load(rf)
|
26 |
+
tokenizer = MarianTokenizer.from_pretrained(args.model_string)
|
27 |
+
tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
|
28 |
+
pad_id = tokenizer.encode(PAD_TOKEN)[0]
|
29 |
+
model = MarianMTModel.from_pretrained(args.model_string, return_dict=True).to(args.device)
|
30 |
+
model.eval()
|
31 |
+
|
32 |
+
checkpoint = torch.load(args.ckpt, map_location=args.device)
|
33 |
+
model_args = checkpoint['args']
|
34 |
+
conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
|
35 |
+
conditioning_model.load_state_dict(checkpoint['state_dict'])
|
36 |
+
conditioning_model = conditioning_model.to(args.device)
|
37 |
+
conditioning_model.eval()
|
38 |
+
print("=> loaded checkpoint '{}' (epoch {})"
|
39 |
+
.format(args.ckpt, checkpoint['epoch']))
|
40 |
+
print('num params', num_params(conditioning_model))
|
41 |
+
|
42 |
+
while True:
|
43 |
+
results = predict_formality(model,
|
44 |
+
tokenizer,
|
45 |
+
conditioning_model,
|
46 |
+
[args.input_text],
|
47 |
+
dataset_info,
|
48 |
+
precondition_topk=args.precondition_topk,
|
49 |
+
do_sample=args.do_sample,
|
50 |
+
length_cutoff=args.length_cutoff,
|
51 |
+
condition_lambda=args.condition_lambda,
|
52 |
+
device=args.device)
|
53 |
+
print(results)
|
54 |
+
import pdb; pdb.set_trace()
|
55 |
+
|
56 |
+
|
57 |
+
def predict_formality(model, tokenizer, conditioning_model, input_text, dataset_info, precondition_topk=200, do_sample=False, length_cutoff=512, condition_lambda=1.0, device='cuda'):
|
58 |
+
with torch.no_grad():
|
59 |
+
batch_size = len(input_text)
|
60 |
+
|
61 |
+
# assumes initially all same length.
|
62 |
+
# encode every x_i i \in [seq] word to respectable embedding
|
63 |
+
encoded_input = [tokenizer.encode(it, return_tensors='pt').to(device) for it in input_text] # batch x seq
|
64 |
+
encoded_input = torch.cat(encoded_input, dim=0)
|
65 |
+
|
66 |
+
input_ids = torch.LongTensor([[58100]]).to(device)
|
67 |
+
cur_len = 1
|
68 |
+
max_length = length_cutoff
|
69 |
+
min_length = 0
|
70 |
+
temperature = 1.0
|
71 |
+
top_k = 50
|
72 |
+
top_p = 1.0
|
73 |
+
repetition_penalty = 1.0
|
74 |
+
no_repeat_ngram_size = 0
|
75 |
+
bad_words_ids = [[58100]]
|
76 |
+
pad_token_id = 58100
|
77 |
+
eos_token_id = 0
|
78 |
+
effective_batch_size = batch_size
|
79 |
+
attention_mask = encoded_input.new_ones(encoded_input.shape)
|
80 |
+
use_cache = True
|
81 |
+
model_specific_kwargs = {'encoder_outputs': model.get_encoder()(encoded_input, attention_mask=attention_mask)}
|
82 |
+
|
83 |
+
output = _generate_no_beam_search(model,
|
84 |
+
conditioning_model,
|
85 |
+
condition_lambda,
|
86 |
+
precondition_topk,
|
87 |
+
input_ids,
|
88 |
+
cur_len,
|
89 |
+
max_length,
|
90 |
+
min_length,
|
91 |
+
do_sample,
|
92 |
+
temperature,
|
93 |
+
top_k,
|
94 |
+
top_p,
|
95 |
+
repetition_penalty,
|
96 |
+
no_repeat_ngram_size,
|
97 |
+
bad_words_ids,
|
98 |
+
pad_token_id,
|
99 |
+
eos_token_id,
|
100 |
+
batch_size,
|
101 |
+
attention_mask,
|
102 |
+
use_cache,
|
103 |
+
model_specific_kwargs)
|
104 |
+
|
105 |
+
return [tokenizer.decode(s[1:]) for s in output] # 1: to delete the pad token
|
106 |
+
|
107 |
+
|
108 |
+
# hack of code from transformers/generation_utils.py
|
109 |
+
# to get our conditioning
|
110 |
+
def postprocess_next_token_scores(
|
111 |
+
model,
|
112 |
+
scores,
|
113 |
+
input_ids,
|
114 |
+
no_repeat_ngram_size,
|
115 |
+
bad_words_ids,
|
116 |
+
cur_len,
|
117 |
+
min_length,
|
118 |
+
max_length,
|
119 |
+
eos_token_id,
|
120 |
+
repetition_penalty,
|
121 |
+
batch_size,
|
122 |
+
num_beams,
|
123 |
+
):
|
124 |
+
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
125 |
+
if repetition_penalty != 1.0:
|
126 |
+
model.enforce_repetition_penalty_(
|
127 |
+
scores,
|
128 |
+
batch_size,
|
129 |
+
num_beams,
|
130 |
+
input_ids,
|
131 |
+
repetition_penalty,
|
132 |
+
)
|
133 |
+
|
134 |
+
# set eos token prob to zero if min_length is not reached
|
135 |
+
if eos_token_id is not None and cur_len < min_length:
|
136 |
+
scores[:, eos_token_id] = -float("inf")
|
137 |
+
|
138 |
+
if no_repeat_ngram_size > 0:
|
139 |
+
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
140 |
+
num_batch_hypotheses = batch_size * num_beams
|
141 |
+
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
142 |
+
banned_batch_tokens = calc_banned_ngram_tokens(
|
143 |
+
input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
|
144 |
+
)
|
145 |
+
for i, banned_tokens in enumerate(banned_batch_tokens):
|
146 |
+
scores[i, banned_tokens] = -float("inf")
|
147 |
+
|
148 |
+
if bad_words_ids is not None:
|
149 |
+
# Exclude EOS token (already processed)
|
150 |
+
bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
|
151 |
+
# calculate a list of banned tokens according to bad words
|
152 |
+
banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids)
|
153 |
+
# Modify the scores in place by setting the banned tokens logits to `-inf`
|
154 |
+
set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
|
155 |
+
|
156 |
+
return scores
|
157 |
+
|
158 |
+
def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> None:
|
159 |
+
"""Copied from fairseq for no_repeat_ngram in beam_search"""
|
160 |
+
if cur_len + 1 < no_repeat_ngram_size:
|
161 |
+
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
|
162 |
+
return [[] for _ in range(num_hypos)]
|
163 |
+
generated_ngrams = [{} for _ in range(num_hypos)]
|
164 |
+
for idx in range(num_hypos):
|
165 |
+
gen_tokens = prev_input_ids[idx].tolist()
|
166 |
+
generated_ngram = generated_ngrams[idx]
|
167 |
+
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
|
168 |
+
prev_ngram_tuple = tuple(ngram[:-1])
|
169 |
+
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
|
170 |
+
|
171 |
+
def _get_generated_ngrams(hypo_idx):
|
172 |
+
# Before decoding the next token, prevent decoding of ngrams that have already appeared
|
173 |
+
start_idx = cur_len + 1 - no_repeat_ngram_size
|
174 |
+
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
|
175 |
+
return generated_ngrams[hypo_idx].get(ngram_idx, [])
|
176 |
+
|
177 |
+
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
|
178 |
+
return banned_tokens
|
179 |
+
|
180 |
+
|
181 |
+
def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
|
182 |
+
banned_tokens = []
|
183 |
+
|
184 |
+
def _tokens_match(prev_tokens, tokens):
|
185 |
+
if len(tokens) == 0:
|
186 |
+
# if bad word tokens is just one token always ban it
|
187 |
+
return True
|
188 |
+
if len(tokens) > len(prev_tokens):
|
189 |
+
# if bad word tokens are longer than prev tokens they can't be equal
|
190 |
+
return False
|
191 |
+
|
192 |
+
if prev_tokens[-len(tokens) :] == tokens:
|
193 |
+
# if tokens match
|
194 |
+
return True
|
195 |
+
else:
|
196 |
+
return False
|
197 |
+
|
198 |
+
for prev_input_ids_slice in prev_input_ids:
|
199 |
+
banned_tokens_slice = []
|
200 |
+
|
201 |
+
for banned_token_seq in bad_words_ids:
|
202 |
+
assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
|
203 |
+
bad_words_ids
|
204 |
+
)
|
205 |
+
|
206 |
+
if _tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
|
207 |
+
# if tokens do not match continue
|
208 |
+
continue
|
209 |
+
|
210 |
+
banned_tokens_slice.append(banned_token_seq[-1])
|
211 |
+
|
212 |
+
banned_tokens.append(banned_tokens_slice)
|
213 |
+
|
214 |
+
return banned_tokens
|
215 |
+
|
216 |
+
def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
|
217 |
+
"""Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be
|
218 |
+
a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...]
|
219 |
+
Args:
|
220 |
+
scores: logits distribution of shape (batch size, vocabulary size)
|
221 |
+
banned_tokens: list of list of tokens to ban of length (batch_size)
|
222 |
+
"""
|
223 |
+
banned_mask_list = []
|
224 |
+
for idx, batch_banned_tokens in enumerate(banned_tokens):
|
225 |
+
for token in batch_banned_tokens:
|
226 |
+
banned_mask_list.append([idx, token])
|
227 |
+
if not banned_mask_list:
|
228 |
+
return
|
229 |
+
banned_mask = torch.LongTensor(banned_mask_list)
|
230 |
+
indices = torch.ones(len(banned_mask))
|
231 |
+
# A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
|
232 |
+
# [ 0 1 1 ]
|
233 |
+
# [ 0 0 0 ]
|
234 |
+
# [ 1 0 0 ]
|
235 |
+
|
236 |
+
banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
|
237 |
+
scores.masked_fill_(banned_mask, -float("inf"))
|
238 |
+
|
239 |
+
def _generate_no_beam_search(
|
240 |
+
model,
|
241 |
+
conditioning_model,
|
242 |
+
condition_lambda,
|
243 |
+
precondition_topk,
|
244 |
+
input_ids,
|
245 |
+
cur_len,
|
246 |
+
max_length,
|
247 |
+
min_length,
|
248 |
+
do_sample,
|
249 |
+
temperature,
|
250 |
+
top_k,
|
251 |
+
top_p,
|
252 |
+
repetition_penalty,
|
253 |
+
no_repeat_ngram_size,
|
254 |
+
bad_words_ids,
|
255 |
+
pad_token_id,
|
256 |
+
eos_token_id,
|
257 |
+
batch_size,
|
258 |
+
attention_mask,
|
259 |
+
use_cache,
|
260 |
+
model_kwargs,
|
261 |
+
):
|
262 |
+
"""Generate sequences for each example without beam search (num_beams == 1).
|
263 |
+
All returned sequence are generated independantly.
|
264 |
+
"""
|
265 |
+
# length of generated sentences / unfinished sentences
|
266 |
+
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
267 |
+
sent_lengths = input_ids.new(batch_size).fill_(max_length)
|
268 |
+
past = None
|
269 |
+
while cur_len < max_length:
|
270 |
+
model_inputs = model.prepare_inputs_for_generation(
|
271 |
+
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
|
272 |
+
)
|
273 |
+
|
274 |
+
outputs = model(**model_inputs, return_dict=True)
|
275 |
+
next_token_logits = outputs.logits[:, -1, :]
|
276 |
+
|
277 |
+
# scores = model.postprocess_next_token_scores(
|
278 |
+
# scores=next_token_logits,
|
279 |
+
# input_ids=input_ids,
|
280 |
+
# no_repeat_ngram_size=no_repeat_ngram_size,
|
281 |
+
# bad_words_ids=bad_words_ids,
|
282 |
+
# cur_len=cur_len,
|
283 |
+
# min_length=min_length,
|
284 |
+
# max_length=max_length,
|
285 |
+
# eos_token_id=eos_token_id,
|
286 |
+
# repetition_penalty=repetition_penalty,
|
287 |
+
# batch_size=batch_size,
|
288 |
+
# num_beams=1,
|
289 |
+
# )
|
290 |
+
|
291 |
+
scores = postprocess_next_token_scores(
|
292 |
+
model=model,
|
293 |
+
scores=next_token_logits,
|
294 |
+
input_ids=input_ids,
|
295 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
296 |
+
bad_words_ids=bad_words_ids,
|
297 |
+
cur_len=cur_len,
|
298 |
+
min_length=min_length,
|
299 |
+
max_length=max_length,
|
300 |
+
eos_token_id=eos_token_id,
|
301 |
+
repetition_penalty=repetition_penalty,
|
302 |
+
batch_size=batch_size,
|
303 |
+
num_beams=1,
|
304 |
+
)
|
305 |
+
|
306 |
+
# if model has past, then set the past variable to speed up decoding
|
307 |
+
if "past_key_values" in outputs:
|
308 |
+
past = outputs.past_key_values
|
309 |
+
elif "mems" in outputs:
|
310 |
+
past = outputs.mems
|
311 |
+
|
312 |
+
top_logits, top_indices = scores.topk(precondition_topk, dim=1) # batch x topk
|
313 |
+
tplus1_candidates = torch.cat([input_ids.unsqueeze(1).expand(-1, precondition_topk, -1), top_indices.unsqueeze(2)], dim=2)[:, :, 1:] # batch x topk x seq+1, with pad dropped
|
314 |
+
expanded_lengths = torch.LongTensor([[cur_len for _ in range(precondition_topk)] for _ in range(batch_size)]).to(scores.device)
|
315 |
+
if condition_lambda == 0:
|
316 |
+
condition_logits = torch.zeros_like(top_logits).float()
|
317 |
+
else:
|
318 |
+
condition_logits = conditioning_model(tplus1_candidates.flatten(0, 1), # batch*topk x seq+1
|
319 |
+
expanded_lengths.flatten(0, 1), # batch*topk
|
320 |
+
None,
|
321 |
+
None,
|
322 |
+
None)
|
323 |
+
condition_logits = condition_logits.view(batch_size, precondition_topk, -1)[:, :, -1] # batch x topk of last formality pred
|
324 |
+
condition_logits = condition_logits - torch.log(1 + torch.exp(condition_logits)) # get correct log probs
|
325 |
+
# condition_logits = - torch.log(1 + torch.exp(condition_logits)) # for informal
|
326 |
+
full_logits = top_logits + condition_lambda * condition_logits
|
327 |
+
if do_sample:
|
328 |
+
raise NotImplementedError
|
329 |
+
else:
|
330 |
+
# Greedy decoding
|
331 |
+
next_token = top_indices[torch.arange(batch_size).to(top_indices.device), torch.argmax(full_logits, dim=-1)]
|
332 |
+
|
333 |
+
# if do_sample:
|
334 |
+
# # Temperature (higher temperature => more likely to sample low probability tokens)
|
335 |
+
# if temperature != 1.0:
|
336 |
+
# scores = scores / temperature
|
337 |
+
# # Top-p/top-k filtering
|
338 |
+
# next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
|
339 |
+
# # Sample
|
340 |
+
# probs = F.softmax(next_token_logscores, dim=-1)
|
341 |
+
# next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
|
342 |
+
# else:
|
343 |
+
# # Greedy decoding
|
344 |
+
# next_token = torch.argmax(next_token_logits, dim=-1)
|
345 |
+
|
346 |
+
# update generations and finished sentences
|
347 |
+
if eos_token_id is not None:
|
348 |
+
# pad finished sentences if eos_token_id exist
|
349 |
+
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
|
350 |
+
else:
|
351 |
+
tokens_to_add = next_token
|
352 |
+
|
353 |
+
# add token and increase length by one
|
354 |
+
input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
|
355 |
+
cur_len = cur_len + 1
|
356 |
+
|
357 |
+
if eos_token_id is not None:
|
358 |
+
eos_in_sents = tokens_to_add == eos_token_id
|
359 |
+
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
|
360 |
+
is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
|
361 |
+
sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
|
362 |
+
# unfinished_sents is set to zero if eos in sentence
|
363 |
+
unfinished_sents.mul_((~eos_in_sents).long())
|
364 |
+
|
365 |
+
# stop when there is a </s> in each sentence, or if we exceed the maximul length
|
366 |
+
if unfinished_sents.max() == 0:
|
367 |
+
break
|
368 |
+
|
369 |
+
# extend attention_mask for new generated input if only decoder
|
370 |
+
if model.config.is_encoder_decoder is False:
|
371 |
+
attention_mask = torch.cat(
|
372 |
+
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
373 |
+
)
|
374 |
+
|
375 |
+
return input_ids
|
376 |
+
|
377 |
+
if __name__=='__main__':
|
378 |
+
parser = ArgumentParser()
|
379 |
+
|
380 |
+
# DATA
|
381 |
+
parser.add_argument('--ckpt', type=str, required=True)
|
382 |
+
parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
|
383 |
+
parser.add_argument('--model_string', type=str, default='Helsinki-NLP/opus-mt-es-en')
|
384 |
+
|
385 |
+
parser.add_argument('--input_text', type=str, default=None, required=True, help='text to run pred on')
|
386 |
+
|
387 |
+
parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning')
|
388 |
+
parser.add_argument('--do_sample', action='store_true', default=False, help='sample instead of greedy')
|
389 |
+
parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
|
390 |
+
parser.add_argument('--length_cutoff', type=int, default=512, help='max length')
|
391 |
+
|
392 |
+
parser.add_argument('--seed', type=int, default=1, help='random seed')
|
393 |
+
parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
|
394 |
+
parser.add_argument('--debug', action='store_true', default=False)
|
395 |
+
|
396 |
+
args = parser.parse_args()
|
397 |
+
|
398 |
+
random.seed(args.seed)
|
399 |
+
np.random.seed(args.seed)
|
400 |
+
torch.manual_seed(args.seed)
|
401 |
+
|
402 |
+
main(args)
|
403 |
+
|
404 |
+
|
naacl-2021-fudge-controlled-generation/predict_poetry.py
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import time
|
4 |
+
import pickle
|
5 |
+
import math
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
import string
|
8 |
+
from collections import defaultdict
|
9 |
+
|
10 |
+
from tqdm import tqdm
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model
|
16 |
+
|
17 |
+
from data import Dataset, load_rhyme_info
|
18 |
+
from model import Model
|
19 |
+
from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
|
20 |
+
from constants import *
|
21 |
+
from poetry_util import get_rhymes, count_syllables
|
22 |
+
|
23 |
+
def main(args):
|
24 |
+
with open(args.dataset_info, 'rb') as rf:
|
25 |
+
dataset_info = pickle.load(rf)
|
26 |
+
gpt_tokenizer = AutoTokenizer.from_pretrained(args.model_string)
|
27 |
+
gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
|
28 |
+
gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0]
|
29 |
+
gpt_model = AutoModelWithLMHead.from_pretrained(args.model_string).to(args.device)
|
30 |
+
gpt_model.eval()
|
31 |
+
|
32 |
+
checkpoint = torch.load(args.iambic_ckpt, map_location=args.device)
|
33 |
+
model_args = checkpoint['args']
|
34 |
+
iambic_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
|
35 |
+
iambic_model.load_state_dict(checkpoint['state_dict'])
|
36 |
+
iambic_model = iambic_model.to(args.device)
|
37 |
+
iambic_model.eval()
|
38 |
+
print("=> loaded checkpoint '{}' (epoch {})"
|
39 |
+
.format(args.iambic_ckpt, checkpoint['epoch']))
|
40 |
+
print('iambic model num params', num_params(iambic_model))
|
41 |
+
|
42 |
+
with open(args.rhyme_info, 'rb') as rf:
|
43 |
+
rhyme_info = pickle.load(rf)
|
44 |
+
checkpoint = torch.load(args.rhyme_ckpt, map_location=args.device)
|
45 |
+
model_args = checkpoint['args']
|
46 |
+
rhyme_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word), rhyme_group_size=len(rhyme_info.index2rhyme_group)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
|
47 |
+
rhyme_model.load_state_dict(checkpoint['state_dict'])
|
48 |
+
rhyme_model = rhyme_model.to(args.device)
|
49 |
+
rhyme_model.eval()
|
50 |
+
print("=> loaded checkpoint '{}' (epoch {})"
|
51 |
+
.format(args.rhyme_ckpt, checkpoint['epoch']))
|
52 |
+
print('rhyme model num params', num_params(rhyme_model))
|
53 |
+
|
54 |
+
checkpoint = torch.load(args.newline_ckpt, map_location=args.device)
|
55 |
+
model_args = checkpoint['args']
|
56 |
+
newline_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
|
57 |
+
newline_model.load_state_dict(checkpoint['state_dict'])
|
58 |
+
newline_model = newline_model.to(args.device)
|
59 |
+
newline_model.eval()
|
60 |
+
print("=> loaded checkpoint '{}' (epoch {})"
|
61 |
+
.format(args.newline_ckpt, checkpoint['epoch']))
|
62 |
+
print('iambic model num params', num_params(newline_model))
|
63 |
+
|
64 |
+
while True:
|
65 |
+
results = predict_couplet(gpt_model,
|
66 |
+
gpt_tokenizer,
|
67 |
+
iambic_model,
|
68 |
+
rhyme_model,
|
69 |
+
newline_model,
|
70 |
+
[args.input_text],
|
71 |
+
dataset_info,
|
72 |
+
rhyme_info,
|
73 |
+
args.precondition_topk,
|
74 |
+
args.topk,
|
75 |
+
condition_lambda=args.condition_lambda,
|
76 |
+
device=args.device)
|
77 |
+
for line in results:
|
78 |
+
print(line)
|
79 |
+
import pdb; pdb.set_trace()
|
80 |
+
|
81 |
+
|
82 |
+
def predict_couplet(gpt_model, gpt_tokenizer, iambic_model, rhyme_model, newline_model, input_text, dataset_info, rhyme_info, precondition_topk, postcondition_topk, condition_lambda=1.0, device='cuda'):
|
83 |
+
assert len(input_text) == 1 # only do one at a time for now
|
84 |
+
current_text = input_text[0]
|
85 |
+
current_line_text = ''
|
86 |
+
all_lines = [current_text]
|
87 |
+
ending_word = current_text.split()[-1].strip(string.punctuation)
|
88 |
+
word2rhyme_group = defaultdict(lambda: UNKNOWN_RHYME_GROUP, rhyme_info.word2rhyme_group)
|
89 |
+
rhyme_group = word2rhyme_group[ending_word]
|
90 |
+
|
91 |
+
line = predict_iambic_pentameter_line(gpt_model,
|
92 |
+
gpt_tokenizer,
|
93 |
+
iambic_model,
|
94 |
+
rhyme_model,
|
95 |
+
newline_model,
|
96 |
+
current_text,
|
97 |
+
current_line_text,
|
98 |
+
rhyme_group,
|
99 |
+
dataset_info,
|
100 |
+
rhyme_info,
|
101 |
+
precondition_topk,
|
102 |
+
postcondition_topk,
|
103 |
+
condition_lambda=condition_lambda,
|
104 |
+
device=device)
|
105 |
+
all_lines.append(line)
|
106 |
+
|
107 |
+
return all_lines
|
108 |
+
|
109 |
+
|
110 |
+
def predict_iambic_pentameter_line(gpt_model, gpt_tokenizer, iambic_model, rhyme_model, newline_model, current_text, current_line_text, rhyme_group, dataset_info, rhyme_info, precondition_topk, postcondition_topk, banned_tokens=POETRY_BANNED_TOKENS, condition_lambda=1.0, device='cuda', length_cutoff=30):
|
111 |
+
# TODO(poetry) delete banned tokens?
|
112 |
+
with torch.no_grad():
|
113 |
+
batch_size = 1
|
114 |
+
|
115 |
+
rhyme_group_index = rhyme_info.rhyme_group2index[rhyme_group]
|
116 |
+
future_words = torch.LongTensor([rhyme_group_index]).to(device) # 1
|
117 |
+
log_probs = torch.Tensor([math.log(rhyme_info.rhyme_group_counts[rhyme_group] / rhyme_info.total_rhyme_groups)]).to(device) # 1
|
118 |
+
|
119 |
+
# assumes initially all same length.
|
120 |
+
previous_encoded_text = [gpt_tokenizer.encode(it, return_tensors='pt').to(device) for it in [current_text]]
|
121 |
+
previous_enc_len = previous_encoded_text[0].shape[1]
|
122 |
+
encoded_input = [gpt_tokenizer.encode(it, return_tensors='pt').to(device) for it in [current_text + current_line_text]] # batch x seq
|
123 |
+
encoded_input = torch.cat(encoded_input, dim=0)
|
124 |
+
lengths = torch.LongTensor([encoded_input.shape[1]]).to(device)
|
125 |
+
|
126 |
+
line_syllable_count = count_syllables(current_line_text)
|
127 |
+
assert line_syllable_count < POETRY_LINE_SYLLABLES # assume we started with less than one full line
|
128 |
+
syllables_to_go = POETRY_LINE_SYLLABLES - line_syllable_count
|
129 |
+
|
130 |
+
for _ in range(length_cutoff): # really shouldn't have a line this long anyway
|
131 |
+
gpt_logits = gpt_model(encoded_input)[0][:, -1, :] # batch x vocab
|
132 |
+
gpt_logits[:, banned_tokens] = -1e8
|
133 |
+
top_logits, top_indices = gpt_logits.topk(precondition_topk, dim=1)
|
134 |
+
|
135 |
+
new_input_candidates = torch.cat([encoded_input.unsqueeze(1).expand(-1, precondition_topk, -1), top_indices.unsqueeze(2)], dim=2) # batch x topk x seq+1
|
136 |
+
expanded_lengths = (lengths + 1).unsqueeze(1).expand(batch_size, precondition_topk) # batch x topk
|
137 |
+
expanded_future_words = future_words.unsqueeze(0).unsqueeze(1).expand(batch_size, precondition_topk, -1) # batch x topk x N
|
138 |
+
candidate_syllables_to_go = []
|
139 |
+
for candidate in new_input_candidates[0]:
|
140 |
+
candidate_until_last_word_text = ' '.join(gpt_tokenizer.decode(candidate[previous_enc_len:]).split()[:-1])
|
141 |
+
candidate_syllables_to_go.append(10 - count_syllables(candidate_until_last_word_text))
|
142 |
+
# usually these are all the same, but run them all for correctness. could do more efficiently but it's not too slow anyway.
|
143 |
+
expanded_syllables_to_go = torch.LongTensor(candidate_syllables_to_go).to(device).view(1, precondition_topk)
|
144 |
+
|
145 |
+
if condition_lambda == 0:
|
146 |
+
iambic_logits = torch.zeros_like(expanded_lengths).float()
|
147 |
+
else:
|
148 |
+
# truncate prefix because we trained on single lines
|
149 |
+
iambic_logits = iambic_model(new_input_candidates[:, :, previous_enc_len:].flatten(0, 1), expanded_lengths.flatten(0, 1) - previous_enc_len, None, None, None)[:, -1] # batch*topk x seq+1 -> batch*topk
|
150 |
+
iambic_logits = iambic_logits.view(batch_size, precondition_topk)
|
151 |
+
iambic_logits = iambic_logits - torch.log(1 + torch.exp(iambic_logits))
|
152 |
+
if condition_lambda == 0:
|
153 |
+
rhyme_logits = torch.zeros_like(expanded_lengths).float()
|
154 |
+
else:
|
155 |
+
rhyme_logits = rhyme_model(new_input_candidates.flatten(0, 1), # batch*topk x seq+1
|
156 |
+
expanded_lengths.flatten(0, 1), # batch*topk
|
157 |
+
expanded_future_words.flatten(0, 1), # batch*topk x N
|
158 |
+
log_probs, # N
|
159 |
+
expanded_syllables_to_go.flatten(0, 1)) # batch*topk
|
160 |
+
rhyme_logits = rhyme_logits.view(batch_size, precondition_topk, -1) # batch x topk x N
|
161 |
+
rhyme_logits = rhyme_logits - torch.log(1 + torch.exp(rhyme_logits)) # batch x topk x N
|
162 |
+
rhyme_logits = rhyme_logits.squeeze(2) # batch x topk
|
163 |
+
if condition_lambda == 0:
|
164 |
+
newline_logits = torch.zeros_like(expanded_lengths).float()
|
165 |
+
else:
|
166 |
+
newline_logits = newline_model(new_input_candidates.flatten(0, 1), # batch*topk x seq+1
|
167 |
+
expanded_lengths.flatten(0, 1), # batch*topk
|
168 |
+
expanded_future_words.flatten(0, 1), # batch*topk x N
|
169 |
+
log_probs, # N
|
170 |
+
expanded_syllables_to_go.flatten(0, 1)) # batch*topk
|
171 |
+
newline_logits = newline_logits[:, -1].view(batch_size, precondition_topk, -1) # batch x topk x N
|
172 |
+
newline_logits = newline_logits - torch.log(1 + torch.exp(newline_logits)) # batch x topk x N
|
173 |
+
newline_logits = newline_logits.squeeze(2) # batch x topk
|
174 |
+
|
175 |
+
full_logits = top_logits + condition_lambda * iambic_logits + condition_lambda * rhyme_logits + condition_lambda * newline_logits
|
176 |
+
post_logits, post_indices = full_logits.topk(postcondition_topk, dim=1)
|
177 |
+
post_probs = F.softmax(post_logits, dim=1)
|
178 |
+
index_into_top_indices = post_indices[torch.arange(batch_size).to(post_indices.device), torch.multinomial(post_probs, 1).flatten()] # batch
|
179 |
+
next_indices = top_indices[torch.arange(batch_size).to(top_indices.device), index_into_top_indices] # batch
|
180 |
+
encoded_input = torch.cat([encoded_input, next_indices.unsqueeze(1)], dim=1) # batch x seq+1
|
181 |
+
lengths = lengths + 1
|
182 |
+
syllables_to_go = POETRY_LINE_SYLLABLES - count_syllables(gpt_tokenizer.decode(encoded_input[0][previous_enc_len:])) # if we get very unlucky with a partial word that the syllable counter doesn't recognize we might end early, but it's unlikely
|
183 |
+
if syllables_to_go <= 0 and [gpt_tokenizer.decode(s) for s in encoded_input][0][-1] in PHRASE_ENDS:
|
184 |
+
break
|
185 |
+
if syllables_to_go < 0:
|
186 |
+
# encoded_input = encoded_input[:, :-1]
|
187 |
+
break
|
188 |
+
|
189 |
+
return [gpt_tokenizer.decode(s) for s in encoded_input][0][len(current_text):]
|
190 |
+
|
191 |
+
|
192 |
+
if __name__=='__main__':
|
193 |
+
parser = ArgumentParser()
|
194 |
+
|
195 |
+
# DATA
|
196 |
+
parser.add_argument('--iambic_ckpt', type=str, required=True)
|
197 |
+
parser.add_argument('--rhyme_ckpt', type=str, required=True)
|
198 |
+
parser.add_argument('--newline_ckpt', type=str, required=True)
|
199 |
+
parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
|
200 |
+
parser.add_argument('--rhyme_info', type=str, required=True, help='saved rhyme info')
|
201 |
+
parser.add_argument('--model_string', type=str, default='gpt2-medium')
|
202 |
+
|
203 |
+
parser.add_argument('--input_text', type=str, default=None, required=True, help='initial text')
|
204 |
+
|
205 |
+
parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning')
|
206 |
+
parser.add_argument('--topk', type=int, default=10, help='consider top k outputs from gpt at each step')
|
207 |
+
parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
|
208 |
+
|
209 |
+
parser.add_argument('--seed', type=int, default=1, help='random seed')
|
210 |
+
parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
|
211 |
+
parser.add_argument('--debug', action='store_true', default=False)
|
212 |
+
|
213 |
+
args = parser.parse_args()
|
214 |
+
|
215 |
+
random.seed(args.seed)
|
216 |
+
np.random.seed(args.seed)
|
217 |
+
torch.manual_seed(args.seed)
|
218 |
+
|
219 |
+
main(args)
|
naacl-2021-fudge-controlled-generation/predict_topic.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import time
|
4 |
+
import pickle
|
5 |
+
import math
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
|
8 |
+
from tqdm import tqdm
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model
|
14 |
+
|
15 |
+
from data import Dataset
|
16 |
+
from model import Model
|
17 |
+
from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
|
18 |
+
from constants import *
|
19 |
+
|
20 |
+
def main(args):
|
21 |
+
with open(args.dataset_info, 'rb') as rf:
|
22 |
+
dataset_info = pickle.load(rf)
|
23 |
+
for cw in args.condition_words.split():
|
24 |
+
assert cw in dataset_info.word2index
|
25 |
+
gpt_tokenizer = AutoTokenizer.from_pretrained(args.model_string)
|
26 |
+
gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
|
27 |
+
gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0]
|
28 |
+
gpt_model = AutoModelWithLMHead.from_pretrained(args.model_string).to(args.device)
|
29 |
+
gpt_model.eval()
|
30 |
+
|
31 |
+
checkpoint = torch.load(args.ckpt, map_location=args.device)
|
32 |
+
model_args = checkpoint['args']
|
33 |
+
conditioning_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
|
34 |
+
conditioning_model.load_state_dict(checkpoint['state_dict'])
|
35 |
+
conditioning_model = conditioning_model.to(args.device)
|
36 |
+
conditioning_model.eval()
|
37 |
+
print("=> loaded checkpoint '{}' (epoch {})"
|
38 |
+
.format(args.ckpt, checkpoint['epoch']))
|
39 |
+
print('num params', num_params(conditioning_model))
|
40 |
+
|
41 |
+
while True:
|
42 |
+
results = predict(gpt_model,
|
43 |
+
gpt_tokenizer,
|
44 |
+
conditioning_model,
|
45 |
+
[args.input_text],
|
46 |
+
args.condition_words,
|
47 |
+
dataset_info,
|
48 |
+
args.precondition_topk,
|
49 |
+
args.topk,
|
50 |
+
args.length_cutoff,
|
51 |
+
condition_lambda=args.condition_lambda,
|
52 |
+
device=args.device)
|
53 |
+
print(results)
|
54 |
+
import pdb; pdb.set_trace()
|
55 |
+
|
56 |
+
def predict(gpt_model, gpt_tokenizer, conditioning_model, input_text, condition_words, dataset_info, precondition_topk, postcondition_topk, length_cutoff, condition_lambda=1.0, device='cuda'):
|
57 |
+
with torch.no_grad():
|
58 |
+
batch_size = len(input_text)
|
59 |
+
|
60 |
+
condition_words = condition_words.split()
|
61 |
+
future_words = torch.LongTensor([dataset_info.word2index[cw] for cw in condition_words]).to(device) # N
|
62 |
+
log_probs = torch.Tensor([math.log(dataset_info.vocab[cw] / dataset_info.total_words) for cw in condition_words]).to(device) # N
|
63 |
+
|
64 |
+
# assumes initially all same length.
|
65 |
+
encoded_input = [gpt_tokenizer.encode(it, return_tensors='pt').to(device) for it in input_text] # batch x seq
|
66 |
+
encoded_input = torch.cat(encoded_input, dim=0)
|
67 |
+
lengths = torch.LongTensor([encoded_input.shape[1]]).to(device)
|
68 |
+
|
69 |
+
gpt_encoded_future_words = [gpt_tokenizer.encode(' ' + cw, return_tensors='pt')[0].to(device) for cw in condition_words]
|
70 |
+
while lengths.max() < length_cutoff:
|
71 |
+
tokens_left = torch.LongTensor([length_cutoff - lengths.max() for _ in range(batch_size)]).to(device)
|
72 |
+
gpt_logits = gpt_model(encoded_input)[0][:, -1, :] # batch x vocab
|
73 |
+
top_logits, top_indices = gpt_logits.topk(precondition_topk, dim=1) # batch x topk
|
74 |
+
new_input_candidates = torch.cat([encoded_input.unsqueeze(1).expand(-1, precondition_topk, -1), top_indices.unsqueeze(2)], dim=2) # batch x topk x seq+1
|
75 |
+
expanded_lengths = (lengths + 1).unsqueeze(1).expand(batch_size, precondition_topk) # batch x topk
|
76 |
+
expanded_future_words = future_words.unsqueeze(0).unsqueeze(1).expand(batch_size, precondition_topk, -1) # batch x topk x N
|
77 |
+
expanded_tokens_left = tokens_left.unsqueeze(1).expand(-1, precondition_topk) # batch x topk
|
78 |
+
if condition_lambda == 0:
|
79 |
+
condition_logits = torch.zeros_like(expanded_future_words).float()
|
80 |
+
else:
|
81 |
+
condition_logits = conditioning_model(new_input_candidates.flatten(0, 1), # batch*topk x seq+1
|
82 |
+
expanded_lengths.flatten(0, 1), # batch*topk
|
83 |
+
expanded_future_words.flatten(0, 1), # batch*topk x N
|
84 |
+
log_probs, # N
|
85 |
+
expanded_tokens_left.flatten(0, 1)) # batch*topk
|
86 |
+
condition_logits = condition_logits.view(batch_size, precondition_topk, -1) # batch x topk x N
|
87 |
+
condition_logits = condition_logits - torch.log(1 + torch.exp(condition_logits)) # get correct log probs
|
88 |
+
|
89 |
+
condition_logits = torch.mean(condition_logits, dim=2)
|
90 |
+
full_logits = top_logits + condition_logits * condition_lambda # batch x topk
|
91 |
+
post_logits, post_indices = full_logits.topk(postcondition_topk, dim=1)
|
92 |
+
post_probs = F.softmax(post_logits, dim=1)
|
93 |
+
index_into_top_indices = post_indices[torch.arange(batch_size).to(post_indices.device), torch.multinomial(post_probs, 1).flatten()] # batch
|
94 |
+
next_indices = top_indices[torch.arange(batch_size).to(top_indices.device), index_into_top_indices] # batch
|
95 |
+
encoded_input = torch.cat([encoded_input, next_indices.unsqueeze(1)], dim=1) # batch x seq+1
|
96 |
+
lengths = lengths + 1 # batch
|
97 |
+
return [gpt_tokenizer.decode(s) for s in encoded_input]
|
98 |
+
|
99 |
+
|
100 |
+
if __name__=='__main__':
|
101 |
+
parser = ArgumentParser()
|
102 |
+
|
103 |
+
# DATA
|
104 |
+
parser.add_argument('--ckpt', type=str, required=True)
|
105 |
+
parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
|
106 |
+
parser.add_argument('--model_string', type=str, default='gpt2-medium')
|
107 |
+
|
108 |
+
parser.add_argument('--input_text', type=str, default=None, required=True, help='initial text')
|
109 |
+
parser.add_argument('--condition_words', type=str, default=None, required=True, help='word(s) to optimize for')
|
110 |
+
|
111 |
+
parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning')
|
112 |
+
parser.add_argument('--topk', type=int, default=10, help='consider top k outputs from gpt at each step')
|
113 |
+
parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
|
114 |
+
parser.add_argument('--length_cutoff', type=int, default=80, help='max length')
|
115 |
+
|
116 |
+
parser.add_argument('--seed', type=int, default=1, help='random seed')
|
117 |
+
parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
|
118 |
+
parser.add_argument('--debug', action='store_true', default=False)
|
119 |
+
|
120 |
+
args = parser.parse_args()
|
121 |
+
|
122 |
+
random.seed(args.seed)
|
123 |
+
np.random.seed(args.seed)
|
124 |
+
torch.manual_seed(args.seed)
|
125 |
+
|
126 |
+
main(args)
|
naacl-2021-fudge-controlled-generation/requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Phyme==0.0.9
|
2 |
+
pronouncing==0.2.0
|
3 |
+
pytorch-lightning==1.0.6
|
4 |
+
torch==1.7.0
|
5 |
+
tqdm==4.49.0
|
6 |
+
sacrebleu==1.4.14
|
7 |
+
sacremoses==0.0.43
|
naacl-2021-fudge-controlled-generation/topic_data/README.md
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
`topic_prefixes.txt` contains the 20 prefixes used at test time for starting the generations.
|
2 |
+
|
3 |
+
`wordlists/` contains the wordlists for each of the 7 topics used during testing. The heldout bags used to evaluate the generalization of the topic words to other related words are in `test_wordlists/`. `val_wordlists/` contains just one extra wordlist used for tuning.
|
naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/computers.txt
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
sailor
|
2 |
+
memories
|
3 |
+
article
|
4 |
+
phishing
|
5 |
+
crucial
|
6 |
+
interactive
|
7 |
+
capabilities
|
8 |
+
ISP
|
9 |
+
query
|
10 |
+
signal
|
11 |
+
computation
|
12 |
+
detect
|
13 |
+
compiling
|
14 |
+
workstation
|
15 |
+
barcode
|
16 |
+
XP
|
17 |
+
cake
|
18 |
+
counterfeiting
|
19 |
+
decimal
|
20 |
+
back-up
|
21 |
+
reasoning
|
22 |
+
DSL
|
23 |
+
C++
|
24 |
+
DVD
|
25 |
+
Frequently
|
26 |
+
wifi
|
27 |
+
deleting
|
28 |
+
paper
|
29 |
+
DNS
|
30 |
+
CyanogenMod
|
31 |
+
overflow
|
32 |
+
Android
|
33 |
+
latency
|
34 |
+
creating
|
35 |
+
redirect
|
36 |
+
sites
|
37 |
+
sidebar
|
38 |
+
Jacket
|
39 |
+
prev
|
40 |
+
connections
|
41 |
+
PDF
|
42 |
+
torrent
|
43 |
+
original
|
44 |
+
gmail
|
45 |
+
rename
|
46 |
+
coder
|
47 |
+
mainboard
|
48 |
+
parasite
|
49 |
+
casing
|
50 |
+
lurks
|
51 |
+
pixels
|
52 |
+
touchpad
|
53 |
+
update
|
54 |
+
visuals
|
55 |
+
encyclopedia
|
56 |
+
mice
|
57 |
+
Solaris
|
58 |
+
caching
|
59 |
+
copies
|
60 |
+
usb
|
61 |
+
chew
|
62 |
+
fixes
|
63 |
+
house
|
64 |
+
operand
|
65 |
+
input
|
66 |
+
pull
|
67 |
+
iterative
|
68 |
+
educational
|
69 |
+
autocomplete
|
70 |
+
on-line
|
71 |
+
confidentiality
|
72 |
+
decrypt
|
73 |
+
beach
|
74 |
+
mails
|
75 |
+
rectangular
|
76 |
+
jQuery
|
77 |
+
Excel
|
78 |
+
point-in-time
|
79 |
+
Ubuntu
|
80 |
+
decryption
|
81 |
+
dialup
|
82 |
+
profit
|
83 |
+
off-line
|
84 |
+
developing
|
85 |
+
choice
|
86 |
+
notebook
|
87 |
+
storing
|
88 |
+
typeface
|
89 |
+
little
|
90 |
+
customer
|
91 |
+
step
|
92 |
+
text
|
93 |
+
run-time
|
94 |
+
interview
|
95 |
+
layout
|
96 |
+
computing
|
97 |
+
chairs
|
98 |
+
infected
|
99 |
+
must
|
100 |
+
tools
|
101 |
+
search
|
102 |
+
pane
|
103 |
+
gamepad
|
104 |
+
disc
|
105 |
+
initialize
|
106 |
+
display
|
107 |
+
button
|
108 |
+
Firefox
|
109 |
+
automatically
|
110 |
+
garbage
|
111 |
+
512MB
|
112 |
+
cyber
|
113 |
+
logon
|
114 |
+
elements
|
115 |
+
restoring
|
116 |
+
writer
|
117 |
+
saving
|
118 |
+
parsing
|
119 |
+
execute
|
120 |
+
configuring
|
121 |
+
telephoto
|
122 |
+
popup
|
123 |
+
utilities
|
124 |
+
packet
|
125 |
+
pasting
|
126 |
+
guest
|
127 |
+
edit
|
128 |
+
glass
|
129 |
+
e-mail
|
130 |
+
components
|
131 |
+
binaries
|
132 |
+
subdirectory
|
133 |
+
restart
|
134 |
+
XSLT
|
135 |
+
inkjet
|
136 |
+
allows
|
137 |
+
functionality
|
138 |
+
debian
|
139 |
+
change
|
140 |
+
click
|
141 |
+
dialog
|
142 |
+
GPU
|
143 |
+
stored
|
144 |
+
attribute
|
145 |
+
deflate
|
146 |
+
cheat
|
147 |
+
direction
|
148 |
+
camera
|
149 |
+
hats
|
150 |
+
topic
|
151 |
+
journalists
|
152 |
+
taxi
|
153 |
+
console
|
154 |
+
identifier
|
155 |
+
VPN
|
156 |
+
flames
|
157 |
+
spyware
|
158 |
+
secure
|
159 |
+
shoe
|
160 |
+
Macs
|
161 |
+
php
|
162 |
+
demo
|
163 |
+
extract
|
naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/legal.txt
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
waived
|
2 |
+
homicide
|
3 |
+
repress
|
4 |
+
statutory
|
5 |
+
sentencing
|
6 |
+
respondent
|
7 |
+
maintain
|
8 |
+
legislative
|
9 |
+
prosecution
|
10 |
+
whether
|
11 |
+
forgive
|
12 |
+
mandamus
|
13 |
+
democratic
|
14 |
+
treasurer
|
15 |
+
acquittal
|
16 |
+
offender
|
17 |
+
sued
|
18 |
+
edict
|
19 |
+
malpractice
|
20 |
+
debatable
|
21 |
+
criminal
|
22 |
+
injunctive
|
23 |
+
appellant
|
24 |
+
convicted
|
25 |
+
admit
|
26 |
+
proxies
|
27 |
+
aggrieved
|
28 |
+
enforcement
|
29 |
+
second-degree
|
30 |
+
ethical
|
31 |
+
knowing
|
32 |
+
liability
|
33 |
+
event
|
34 |
+
property
|
35 |
+
conviction
|
36 |
+
deposited
|
37 |
+
immune
|
38 |
+
assertion
|
39 |
+
assualt
|
40 |
+
regulations
|
41 |
+
exams
|
42 |
+
pixels
|
43 |
+
prosecuting
|
44 |
+
insolvent
|
45 |
+
felonies
|
46 |
+
families
|
47 |
+
mediator
|
48 |
+
rulings
|
49 |
+
heard
|
50 |
+
wrongs
|
51 |
+
wrongful
|
52 |
+
folder
|
53 |
+
federal
|
54 |
+
widget
|
55 |
+
restaurant
|
56 |
+
incarcerated
|
57 |
+
burglary
|
58 |
+
pants
|
59 |
+
land-use
|
60 |
+
quash
|
61 |
+
sitting
|
62 |
+
rescind
|
63 |
+
dispute
|
64 |
+
leave
|
65 |
+
requesting
|
66 |
+
appearing
|
67 |
+
testify
|
68 |
+
discoveries
|
69 |
+
championship
|
70 |
+
police
|
71 |
+
judgment
|
72 |
+
purchase
|
73 |
+
revelation
|
74 |
+
solicitor
|
75 |
+
disagree
|
76 |
+
judicial
|
77 |
+
reversing
|
78 |
+
jurors
|
79 |
+
decision
|
80 |
+
negligent
|
81 |
+
mutual
|
82 |
+
track
|
83 |
+
objecting
|
84 |
+
major
|
85 |
+
amendment
|
86 |
+
alleging
|
87 |
+
agreement
|
88 |
+
investment
|
89 |
+
custodial
|
90 |
+
accusation
|
91 |
+
passageways
|
92 |
+
asserted
|
93 |
+
authority
|
94 |
+
deputies
|
95 |
+
insolvency
|
96 |
+
sworn
|
97 |
+
defensive
|
98 |
+
embezzlement
|
99 |
+
disputes
|
100 |
+
findings
|
101 |
+
reservation
|
102 |
+
litem
|
103 |
+
inmates
|
104 |
+
step-by-step
|
105 |
+
innocence
|
106 |
+
parties
|
107 |
+
transcribed
|
108 |
+
inept
|
naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/military.txt
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
team
|
2 |
+
threat
|
3 |
+
sloop
|
4 |
+
offensively
|
5 |
+
guerilla
|
6 |
+
invading
|
7 |
+
samurai
|
8 |
+
propel
|
9 |
+
sunk
|
10 |
+
concern
|
11 |
+
persuade
|
12 |
+
Maj.
|
13 |
+
wear
|
14 |
+
fatigues
|
15 |
+
subsidiary
|
16 |
+
glider
|
17 |
+
advancing
|
18 |
+
ICBM
|
19 |
+
won
|
20 |
+
cargo
|
21 |
+
groan
|
22 |
+
knowledge
|
23 |
+
proposal
|
24 |
+
terms
|
25 |
+
deputy
|
26 |
+
taken
|
27 |
+
bricks
|
28 |
+
operation
|
29 |
+
Iraq
|
30 |
+
zoning
|
31 |
+
offices
|
32 |
+
fought
|
33 |
+
detonated
|
34 |
+
adjutant
|
35 |
+
skipper
|
36 |
+
batteries
|
37 |
+
medical
|
38 |
+
strategic
|
39 |
+
armistice
|
40 |
+
rocket
|
41 |
+
enemies
|
42 |
+
tensions
|
43 |
+
forming
|
44 |
+
inundate
|
45 |
+
engaging
|
46 |
+
dormitories
|
47 |
+
flying
|
48 |
+
allies
|
49 |
+
cursor
|
50 |
+
casing
|
51 |
+
zone
|
52 |
+
scouts
|
53 |
+
stationed
|
54 |
+
pistol
|
55 |
+
paragraph
|
56 |
+
highest
|
57 |
+
tribute
|
58 |
+
strategy
|
59 |
+
pump
|
60 |
+
decoding
|
61 |
+
argue
|
62 |
+
public
|
63 |
+
policeman
|
64 |
+
lob
|
65 |
+
sword
|
66 |
+
bleeding
|
67 |
+
civilians
|
68 |
+
rifles
|
69 |
+
airmen
|
70 |
+
freedom
|
71 |
+
explosion
|
72 |
+
capturing
|
73 |
+
skirmish
|
74 |
+
conquered
|
75 |
+
frigate
|
76 |
+
armour
|
77 |
+
leaving
|
78 |
+
customer
|
79 |
+
expert
|
80 |
+
armies
|
81 |
+
aviation
|
82 |
+
armoury
|
83 |
+
rifleman
|
84 |
+
lace
|
85 |
+
khaki
|
86 |
+
barrage
|
87 |
+
civilian
|
88 |
+
secluded
|
89 |
+
casualties
|
90 |
+
injuries
|
91 |
+
academies
|
92 |
+
hires
|
93 |
+
dead
|
94 |
+
ATL
|
95 |
+
late
|
96 |
+
relinquish
|
97 |
+
naval
|
98 |
+
riflemen
|
99 |
+
seige
|
100 |
+
sonar
|
101 |
+
aboard
|
102 |
+
longtime
|
103 |
+
bottom
|
104 |
+
gatling
|
105 |
+
militia
|
106 |
+
clandestine
|
107 |
+
execute
|
108 |
+
assets
|
109 |
+
significant
|
110 |
+
personnel
|
111 |
+
escorting
|
112 |
+
manoeuvre
|
113 |
+
Sgt.
|
114 |
+
rear
|
115 |
+
shoulders
|
116 |
+
rescuing
|
117 |
+
hand-to-hand
|
118 |
+
howitzer
|
119 |
+
committee
|
120 |
+
rifle
|
121 |
+
victory
|
122 |
+
defensive
|
123 |
+
forcing
|
124 |
+
honour
|
125 |
+
companies
|
126 |
+
pirate
|
127 |
+
evacuating
|
128 |
+
sabotaging
|
129 |
+
citadel
|
130 |
+
cadre
|
131 |
+
camera
|
132 |
+
launchers
|
133 |
+
flames
|
134 |
+
encoding
|
135 |
+
visor
|
136 |
+
ship
|
naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/politics.txt
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
credibility
|
2 |
+
Nazism
|
3 |
+
imported
|
4 |
+
remember
|
5 |
+
progressivism
|
6 |
+
legislative
|
7 |
+
communist
|
8 |
+
gender
|
9 |
+
democratic
|
10 |
+
immediate
|
11 |
+
capitalist
|
12 |
+
purchase
|
13 |
+
energy
|
14 |
+
referenda
|
15 |
+
ratify
|
16 |
+
lengthy
|
17 |
+
authorisation
|
18 |
+
aristocrats
|
19 |
+
jurisdiction
|
20 |
+
judge
|
21 |
+
socialist
|
22 |
+
excise
|
23 |
+
fascist
|
24 |
+
secondary
|
25 |
+
subsidies
|
26 |
+
autocratic
|
27 |
+
shortfall
|
28 |
+
appropriated
|
29 |
+
uphold
|
30 |
+
income
|
31 |
+
federated
|
32 |
+
federal
|
33 |
+
efforts
|
34 |
+
diplomatic
|
35 |
+
freedom
|
36 |
+
properties
|
37 |
+
ideologies
|
38 |
+
exporting
|
39 |
+
minority
|
40 |
+
cultural
|
naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/religion.txt
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Elegant
|
2 |
+
Catholicism
|
3 |
+
Metatron
|
4 |
+
Mind
|
5 |
+
Empires
|
6 |
+
SWF
|
7 |
+
Secular
|
8 |
+
Judas
|
9 |
+
Prime
|
10 |
+
Terrier
|
11 |
+
Preview
|
12 |
+
Existence
|
13 |
+
Silent
|
14 |
+
sanctuaries
|
15 |
+
Answer
|
16 |
+
Balancing
|
17 |
+
Mutual
|
18 |
+
Constantinople
|
19 |
+
Scrolls
|
20 |
+
Network
|
21 |
+
Almighty
|
22 |
+
Attorney
|
23 |
+
Liberation
|
24 |
+
Database
|
25 |
+
Practicing
|
26 |
+
St.
|
27 |
+
Eucharist
|
28 |
+
Glorious
|
29 |
+
Catholic
|
30 |
+
Compassion
|
31 |
+
Volume
|
32 |
+
Saviour
|
33 |
+
Meditation
|
34 |
+
Testament
|
35 |
+
Morality
|
36 |
+
Heart
|
37 |
+
Aramaic
|
38 |
+
Court
|
39 |
+
Baskets
|
40 |
+
Fervor
|
41 |
+
Date
|
42 |
+
Curriculum
|
43 |
+
Liberal
|
44 |
+
Creativity
|
45 |
+
Everlasting
|
46 |
+
PDF
|
47 |
+
Rev.
|
48 |
+
Thank
|
49 |
+
Nanak
|
50 |
+
Dangerous
|
51 |
+
Shari'a
|
52 |
+
Policy
|
53 |
+
Talmud
|
54 |
+
Best
|
55 |
+
Supply
|
56 |
+
Oneness
|
57 |
+
Punishment
|
58 |
+
Reincarnation
|
59 |
+
TransCanada
|
60 |
+
Forums
|
61 |
+
VoIP
|
62 |
+
Factors
|
63 |
+
Assistance
|
64 |
+
Charities
|
65 |
+
Calculator
|
66 |
+
Shadows
|
67 |
+
Him
|
68 |
+
Natural
|
69 |
+
Lamp
|
70 |
+
Thyme
|
71 |
+
Templar
|
72 |
+
Muhammad
|
73 |
+
Venue
|
74 |
+
Hell
|
75 |
+
Bunyan
|
76 |
+
Songs
|
77 |
+
Epistle
|
78 |
+
Suites
|
79 |
+
Economic
|
80 |
+
Intel
|
81 |
+
Spanish
|
82 |
+
Lives
|
83 |
+
Married
|
84 |
+
Hypothesis
|
85 |
+
Cosmic
|
86 |
+
Injunction
|
87 |
+
Involvement
|
88 |
+
Leviticus
|
89 |
+
Self
|
90 |
+
Truth
|
91 |
+
Mystical
|
92 |
+
Melody
|
93 |
+
Pure
|
94 |
+
Sermon
|
95 |
+
Atlantic
|
96 |
+
Excel
|
97 |
+
Sonata
|
98 |
+
SPCA
|
99 |
+
Saturday
|
100 |
+
Adventure
|
101 |
+
Honour
|
102 |
+
Resurrection
|
103 |
+
Emanuel
|
104 |
+
Connery
|
105 |
+
Rites
|
106 |
+
United
|
107 |
+
Pope
|
108 |
+
Mary
|
109 |
+
Chen
|
110 |
+
Lisa
|
111 |
+
ODST
|
112 |
+
Videos
|
113 |
+
Modernity
|
114 |
+
Sculpture
|
115 |
+
Jewish
|
116 |
+
Heavy
|
117 |
+
Remote
|
118 |
+
Praise
|
119 |
+
Foods
|
120 |
+
Merrell
|
121 |
+
Safety
|
122 |
+
Influencing
|
123 |
+
Tie
|
124 |
+
Outreach
|
125 |
+
Kenichi
|
126 |
+
Criminal
|
127 |
+
Stevie
|
128 |
+
Judgement
|
129 |
+
SQL
|
130 |
+
Basilica
|
131 |
+
Piano
|
132 |
+
Reiki
|
133 |
+
Understanding
|
134 |
+
Cognition
|
135 |
+
Maker
|
136 |
+
Diocese
|
137 |
+
Marital
|
138 |
+
Masjid
|
139 |
+
Militant
|
140 |
+
Methodist
|
141 |
+
Political
|
142 |
+
Appeals
|
143 |
+
Deities
|
144 |
+
Purchase
|
145 |
+
Rallies
|
146 |
+
Testing
|
147 |
+
Contemporary
|
148 |
+
Help
|
149 |
+
Sweet
|
150 |
+
Fallen
|
151 |
+
Spangled
|
152 |
+
Renewable
|
153 |
+
Laughter
|
154 |
+
Provider
|
155 |
+
Charitable
|
156 |
+
Ethical
|
157 |
+
Families
|
158 |
+
Cure
|
159 |
+
Significance
|
160 |
+
Communities
|
161 |
+
Cost
|
162 |
+
Demon
|
163 |
+
Motivation
|
164 |
+
Calvary
|
165 |
+
Double
|
166 |
+
Mysteries
|
167 |
+
Determining
|
168 |
+
Baptist
|
169 |
+
Mandir
|
170 |
+
Qi
|
171 |
+
Loss
|
172 |
+
Lust
|
173 |
+
Echoes
|
174 |
+
Lord
|
175 |
+
Vote
|
176 |
+
Glad
|
177 |
+
Dharma
|
178 |
+
Kombat
|
179 |
+
Prostitute
|
180 |
+
Wetlands
|
181 |
+
Queries
|
182 |
+
Always
|
183 |
+
Focus
|
184 |
+
EOS
|
185 |
+
Worship
|
186 |
+
Implications
|
187 |
+
Wiccan
|
188 |
+
Invitations
|
189 |
+
Theology
|
190 |
+
Hospital
|
191 |
+
Freedom
|
192 |
+
Mirror
|
193 |
+
Uncharted
|
194 |
+
Radiance
|
195 |
+
Serving
|
196 |
+
Buddhist
|
197 |
+
Kiss
|
198 |
+
Mother
|
199 |
+
Death
|
200 |
+
Episcopal
|
201 |
+
Impact
|
202 |
+
Shinto
|
203 |
+
Crisis
|
204 |
+
Secure
|
205 |
+
Learning
|
206 |
+
Dreams
|
207 |
+
Association
|
naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/science.txt
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
astronomical
|
2 |
+
evolved
|
3 |
+
tests
|
4 |
+
reason
|
5 |
+
idea
|
6 |
+
component
|
7 |
+
jug
|
8 |
+
rain
|
9 |
+
renewable
|
10 |
+
scaling
|
11 |
+
phone
|
12 |
+
action
|
13 |
+
studies
|
14 |
+
humidity
|
15 |
+
siphon
|
16 |
+
warming
|
17 |
+
compounds
|
18 |
+
genomics
|
19 |
+
electrons
|
20 |
+
mathematics
|
21 |
+
clinical
|
22 |
+
physiology
|
23 |
+
hypotheses
|
24 |
+
stored
|
25 |
+
statutes
|
26 |
+
magnesium
|
27 |
+
measuring
|
28 |
+
fuels
|
29 |
+
scientific
|
30 |
+
bone
|
31 |
+
molecular
|
32 |
+
microscopy
|
33 |
+
observing
|
34 |
+
parameter
|
35 |
+
transition
|
36 |
+
system
|
37 |
+
bacterium
|
38 |
+
ligand
|
39 |
+
increasing
|
40 |
+
theories
|
41 |
+
physicist
|
42 |
+
flow
|
43 |
+
pounds
|
44 |
+
nothing
|
45 |
+
observatory
|
46 |
+
gravitational
|
47 |
+
electron
|
naacl-2021-fudge-controlled-generation/topic_data/test_wordlists/space.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
cosmos
|
2 |
+
mothership
|
3 |
+
flyby
|
4 |
+
broadband
|
5 |
+
aeronautics
|
6 |
+
fireball
|
7 |
+
Romulan
|
8 |
+
room
|
9 |
+
cosmonaut
|
10 |
+
actress
|
11 |
+
worlds
|
12 |
+
heavens
|
13 |
+
lunar
|
14 |
+
interstellar
|
15 |
+
galaxies
|
16 |
+
lander
|
naacl-2021-fudge-controlled-generation/topic_data/topic_prefixes.txt
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
In summary
|
2 |
+
This essay discusses
|
3 |
+
Views on
|
4 |
+
The connection
|
5 |
+
Foundational to this is
|
6 |
+
To review,
|
7 |
+
In brief,
|
8 |
+
An illustration of
|
9 |
+
Furthermore,
|
10 |
+
The central theme
|
11 |
+
To conclude,
|
12 |
+
The key aspect
|
13 |
+
Prior to this
|
14 |
+
Emphasised are
|
15 |
+
To summarise
|
16 |
+
The relationship
|
17 |
+
More importantly,
|
18 |
+
It has been shown
|
19 |
+
The issue focused on
|
20 |
+
In this essay
|
naacl-2021-fudge-controlled-generation/topic_data/val_wordlists/fantasy.txt
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
beast
|
2 |
+
Cerberus
|
3 |
+
demon
|
4 |
+
dragon
|
5 |
+
fairy
|
6 |
+
Frankenstein
|
7 |
+
ghost
|
8 |
+
Godzilla
|
9 |
+
giant
|
10 |
+
horror
|
11 |
+
hydra
|
12 |
+
imp
|
13 |
+
monster
|
14 |
+
mummy
|
15 |
+
ogre
|
16 |
+
orc
|
17 |
+
savage
|
18 |
+
spirit
|
19 |
+
sprite
|
20 |
+
titan
|
21 |
+
troll
|
22 |
+
undead
|
23 |
+
unicorn
|
24 |
+
vampire
|
25 |
+
witch
|
26 |
+
zombie
|
naacl-2021-fudge-controlled-generation/topic_data/wordlists/computers.txt
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
algorithm
|
2 |
+
analog
|
3 |
+
app
|
4 |
+
application
|
5 |
+
array
|
6 |
+
backup
|
7 |
+
bandwidth
|
8 |
+
binary
|
9 |
+
bit
|
10 |
+
bite
|
11 |
+
blog
|
12 |
+
blogger
|
13 |
+
bookmark
|
14 |
+
boot
|
15 |
+
broadband
|
16 |
+
browser
|
17 |
+
buffer
|
18 |
+
bug
|
19 |
+
bus
|
20 |
+
byte
|
21 |
+
cache
|
22 |
+
caps
|
23 |
+
captcha
|
24 |
+
CD
|
25 |
+
client
|
26 |
+
command
|
27 |
+
compile
|
28 |
+
compress
|
29 |
+
computer
|
30 |
+
configure
|
31 |
+
cookie
|
32 |
+
copy
|
33 |
+
CPU
|
34 |
+
dashboard
|
35 |
+
data
|
36 |
+
database
|
37 |
+
debug
|
38 |
+
delete
|
39 |
+
desktop
|
40 |
+
development
|
41 |
+
digital
|
42 |
+
disk
|
43 |
+
document
|
44 |
+
domain
|
45 |
+
dot
|
46 |
+
download
|
47 |
+
drag
|
48 |
+
dynamic
|
49 |
+
email
|
50 |
+
encrypt
|
51 |
+
encryption
|
52 |
+
enter
|
53 |
+
FAQ
|
54 |
+
file
|
55 |
+
firewall
|
56 |
+
firmware
|
57 |
+
flaming
|
58 |
+
flash
|
59 |
+
folder
|
60 |
+
font
|
61 |
+
format
|
62 |
+
frame
|
63 |
+
graphics
|
64 |
+
hack
|
65 |
+
hacker
|
66 |
+
hardware
|
67 |
+
home
|
68 |
+
host
|
69 |
+
html
|
70 |
+
icon
|
71 |
+
inbox
|
72 |
+
integer
|
73 |
+
interface
|
74 |
+
Internet
|
75 |
+
IP
|
76 |
+
iteration
|
77 |
+
Java
|
78 |
+
joystick
|
79 |
+
kernel
|
80 |
+
key
|
81 |
+
keyboard
|
82 |
+
keyword
|
83 |
+
laptop
|
84 |
+
link
|
85 |
+
Linux
|
86 |
+
logic
|
87 |
+
login
|
88 |
+
lurking
|
89 |
+
Macintosh
|
90 |
+
macro
|
91 |
+
malware
|
92 |
+
media
|
93 |
+
memory
|
94 |
+
mirror
|
95 |
+
modem
|
96 |
+
monitor
|
97 |
+
motherboard
|
98 |
+
mouse
|
99 |
+
multimedia
|
100 |
+
net
|
101 |
+
network
|
102 |
+
node
|
103 |
+
offline
|
104 |
+
online
|
105 |
+
OS
|
106 |
+
option
|
107 |
+
output
|
108 |
+
page
|
109 |
+
password
|
110 |
+
paste
|
111 |
+
path
|
112 |
+
piracy
|
113 |
+
pirate
|
114 |
+
platform
|
115 |
+
podcast
|
116 |
+
portal
|
117 |
+
print
|
118 |
+
printer
|
119 |
+
privacy
|
120 |
+
process
|
121 |
+
program
|
122 |
+
programmer
|
123 |
+
protocol
|
124 |
+
RAM
|
125 |
+
reboot
|
126 |
+
resolution
|
127 |
+
restore
|
128 |
+
ROM
|
129 |
+
root
|
130 |
+
router
|
131 |
+
runtime
|
132 |
+
save
|
133 |
+
scan
|
134 |
+
scanner
|
135 |
+
screen
|
136 |
+
screenshot
|
137 |
+
script
|
138 |
+
scroll
|
139 |
+
security
|
140 |
+
server
|
141 |
+
shell
|
142 |
+
shift
|
143 |
+
snapshot
|
144 |
+
software
|
145 |
+
spam
|
146 |
+
spreadsheet
|
147 |
+
storage
|
148 |
+
surf
|
149 |
+
syntax
|
150 |
+
table
|
151 |
+
tag
|
152 |
+
template
|
153 |
+
thread
|
154 |
+
toolbar
|
155 |
+
trash
|
156 |
+
undo
|
157 |
+
Unix
|
158 |
+
upload
|
159 |
+
URL
|
160 |
+
user
|
161 |
+
UI
|
162 |
+
username
|
163 |
+
utility
|
164 |
+
version
|
165 |
+
virtual
|
166 |
+
virus
|
167 |
+
web
|
168 |
+
website
|
169 |
+
widget
|
170 |
+
wiki
|
171 |
+
window
|
172 |
+
Windows
|
173 |
+
wireless
|
174 |
+
worm
|
175 |
+
XML
|
176 |
+
Zip
|
naacl-2021-fudge-controlled-generation/topic_data/wordlists/legal.txt
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
affidavit
|
2 |
+
allegation
|
3 |
+
appeal
|
4 |
+
appearance
|
5 |
+
argument
|
6 |
+
arrest
|
7 |
+
assault
|
8 |
+
attorney
|
9 |
+
bail
|
10 |
+
bankrupt
|
11 |
+
bankruptcy
|
12 |
+
bar
|
13 |
+
bench
|
14 |
+
warrant
|
15 |
+
bond
|
16 |
+
booking
|
17 |
+
capital
|
18 |
+
crime
|
19 |
+
case
|
20 |
+
chambers
|
21 |
+
claim
|
22 |
+
complainant
|
23 |
+
complaint
|
24 |
+
confess
|
25 |
+
confession
|
26 |
+
constitution
|
27 |
+
constitutional
|
28 |
+
contract
|
29 |
+
counsel
|
30 |
+
court
|
31 |
+
custody
|
32 |
+
damages
|
33 |
+
decree
|
34 |
+
defendant
|
35 |
+
defense
|
36 |
+
deposition
|
37 |
+
discovery
|
38 |
+
equity
|
39 |
+
estate
|
40 |
+
ethics
|
41 |
+
evidence
|
42 |
+
examination
|
43 |
+
family
|
44 |
+
law
|
45 |
+
felony
|
46 |
+
file
|
47 |
+
fraud
|
48 |
+
grievance
|
49 |
+
guardian
|
50 |
+
guilty
|
51 |
+
hearing
|
52 |
+
immunity
|
53 |
+
incarceration
|
54 |
+
incompetent
|
55 |
+
indictment
|
56 |
+
injunction
|
57 |
+
innocent
|
58 |
+
instructions
|
59 |
+
jail
|
60 |
+
judge
|
61 |
+
judiciary
|
62 |
+
jurisdiction
|
63 |
+
jury
|
64 |
+
justice
|
65 |
+
law
|
66 |
+
lawsuit
|
67 |
+
lawyer
|
68 |
+
legal
|
69 |
+
legislation
|
70 |
+
liable
|
71 |
+
litigation
|
72 |
+
manslaughter
|
73 |
+
mediation
|
74 |
+
minor
|
75 |
+
misdemeanor
|
76 |
+
moot
|
77 |
+
murder
|
78 |
+
negligence
|
79 |
+
oath
|
80 |
+
objection
|
81 |
+
opinion
|
82 |
+
order
|
83 |
+
ordinance
|
84 |
+
pardon
|
85 |
+
parole
|
86 |
+
party
|
87 |
+
perjury
|
88 |
+
petition
|
89 |
+
plaintiff
|
90 |
+
plea
|
91 |
+
precedent
|
92 |
+
prison
|
93 |
+
probation
|
94 |
+
prosecute
|
95 |
+
prosecutor
|
96 |
+
proxy
|
97 |
+
record
|
98 |
+
redress
|
99 |
+
resolution
|
100 |
+
reverse
|
101 |
+
revoke
|
102 |
+
robbery
|
103 |
+
rules
|
104 |
+
sentence
|
105 |
+
settlement
|
106 |
+
sheriff
|
107 |
+
sidebar
|
108 |
+
standing
|
109 |
+
state
|
110 |
+
statute
|
111 |
+
stay
|
112 |
+
subpoena
|
113 |
+
suit
|
114 |
+
suppress
|
115 |
+
sustain
|
116 |
+
testimony
|
117 |
+
theft
|
118 |
+
title
|
119 |
+
tort
|
120 |
+
transcript
|
121 |
+
trial
|
122 |
+
trust
|
123 |
+
trustee
|
124 |
+
venue
|
125 |
+
verdict
|
126 |
+
waiver
|
127 |
+
warrant
|
128 |
+
will
|
129 |
+
witness
|
130 |
+
writ
|
131 |
+
zoning
|
naacl-2021-fudge-controlled-generation/topic_data/wordlists/military.txt
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
academy
|
2 |
+
advance
|
3 |
+
aircraft
|
4 |
+
ally
|
5 |
+
ammo
|
6 |
+
ammunition
|
7 |
+
armor
|
8 |
+
arms
|
9 |
+
army
|
10 |
+
arrow
|
11 |
+
arsenal
|
12 |
+
artillery
|
13 |
+
attack
|
14 |
+
attention
|
15 |
+
ballistic
|
16 |
+
barracks
|
17 |
+
base
|
18 |
+
battalion
|
19 |
+
battery
|
20 |
+
battle
|
21 |
+
battlefield
|
22 |
+
bomb
|
23 |
+
bombard
|
24 |
+
bombardment
|
25 |
+
brig
|
26 |
+
brigade
|
27 |
+
bullet
|
28 |
+
camouflage
|
29 |
+
camp
|
30 |
+
cannon
|
31 |
+
captain
|
32 |
+
capture
|
33 |
+
carrier
|
34 |
+
casualty
|
35 |
+
catapult
|
36 |
+
cavalry
|
37 |
+
colonel
|
38 |
+
combat
|
39 |
+
command
|
40 |
+
commander
|
41 |
+
commission
|
42 |
+
company
|
43 |
+
conflict
|
44 |
+
conquest
|
45 |
+
convoy
|
46 |
+
corps
|
47 |
+
covert
|
48 |
+
crew
|
49 |
+
decode
|
50 |
+
defeat
|
51 |
+
defend
|
52 |
+
defense
|
53 |
+
destroyer
|
54 |
+
division
|
55 |
+
draft
|
56 |
+
encode
|
57 |
+
enemy
|
58 |
+
engage
|
59 |
+
enlist
|
60 |
+
evacuate
|
61 |
+
explosive
|
62 |
+
fight
|
63 |
+
fire
|
64 |
+
fleet
|
65 |
+
force
|
66 |
+
formation
|
67 |
+
fort
|
68 |
+
front
|
69 |
+
garrison
|
70 |
+
general
|
71 |
+
grenade
|
72 |
+
grunt
|
73 |
+
guerrilla
|
74 |
+
gun
|
75 |
+
headquarters
|
76 |
+
helmet
|
77 |
+
honor
|
78 |
+
hospital
|
79 |
+
infantry
|
80 |
+
injury
|
81 |
+
intelligence
|
82 |
+
invade
|
83 |
+
invasion
|
84 |
+
jet
|
85 |
+
kill
|
86 |
+
leave
|
87 |
+
lieutenant
|
88 |
+
major
|
89 |
+
maneuver
|
90 |
+
marines
|
91 |
+
MIA
|
92 |
+
mid
|
93 |
+
military
|
94 |
+
mine
|
95 |
+
missile
|
96 |
+
mortar
|
97 |
+
navy
|
98 |
+
neutral
|
99 |
+
offense
|
100 |
+
officer
|
101 |
+
ordinance
|
102 |
+
parachute
|
103 |
+
peace
|
104 |
+
plane
|
105 |
+
platoon
|
106 |
+
private
|
107 |
+
radar
|
108 |
+
rank
|
109 |
+
recruit
|
110 |
+
regiment
|
111 |
+
rescue
|
112 |
+
reserves
|
113 |
+
retreat
|
114 |
+
ribbon
|
115 |
+
sabotage
|
116 |
+
sailor
|
117 |
+
salute
|
118 |
+
section
|
119 |
+
sergeant
|
120 |
+
service
|
121 |
+
shell
|
122 |
+
shoot
|
123 |
+
shot
|
124 |
+
siege
|
125 |
+
sniper
|
126 |
+
soldier
|
127 |
+
spear
|
128 |
+
specialist
|
129 |
+
squad
|
130 |
+
squadron
|
131 |
+
staff
|
132 |
+
submarine
|
133 |
+
surrender
|
134 |
+
tactical
|
135 |
+
tactics
|
136 |
+
tank
|
137 |
+
torpedo
|
138 |
+
troops
|
139 |
+
truce
|
140 |
+
uniform
|
141 |
+
unit
|
142 |
+
veteran
|
143 |
+
volley
|
144 |
+
war
|
145 |
+
warfare
|
146 |
+
warrior
|
147 |
+
weapon
|
148 |
+
win
|
149 |
+
wound
|
naacl-2021-fudge-controlled-generation/topic_data/wordlists/politics.txt
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
affirm
|
2 |
+
appropriation
|
3 |
+
aristocracy
|
4 |
+
authoritarian
|
5 |
+
authority
|
6 |
+
authorization
|
7 |
+
brief
|
8 |
+
capitalism
|
9 |
+
communism
|
10 |
+
constitution
|
11 |
+
conservatism
|
12 |
+
court
|
13 |
+
deficit
|
14 |
+
diplomacy
|
15 |
+
direct
|
16 |
+
democracy
|
17 |
+
equality
|
18 |
+
exports
|
19 |
+
fascism
|
20 |
+
federation
|
21 |
+
government
|
22 |
+
ideology
|
23 |
+
imports
|
24 |
+
initiative
|
25 |
+
legislature
|
26 |
+
legitimacy
|
27 |
+
liberalism
|
28 |
+
liberty
|
29 |
+
majority
|
30 |
+
order
|
31 |
+
political
|
32 |
+
culture
|
33 |
+
politics
|
34 |
+
power
|
35 |
+
primary
|
36 |
+
property
|
37 |
+
ratification
|
38 |
+
recall
|
39 |
+
referendum
|
40 |
+
republic
|
41 |
+
socialism
|
42 |
+
state
|
43 |
+
subsidy
|
44 |
+
tariff
|
45 |
+
imports
|
46 |
+
tax
|
47 |
+
totalitarian
|
naacl-2021-fudge-controlled-generation/topic_data/wordlists/religion.txt
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absolute
|
2 |
+
affect
|
3 |
+
aid
|
4 |
+
angel
|
5 |
+
anthem
|
6 |
+
apostle
|
7 |
+
archangel
|
8 |
+
Archbishop
|
9 |
+
balance
|
10 |
+
ban
|
11 |
+
belief
|
12 |
+
benefit
|
13 |
+
Bible
|
14 |
+
bishop
|
15 |
+
bless
|
16 |
+
blessing
|
17 |
+
bliss
|
18 |
+
bond
|
19 |
+
bow
|
20 |
+
Buddhism
|
21 |
+
canon
|
22 |
+
Cantor
|
23 |
+
cathedral
|
24 |
+
celestial
|
25 |
+
chapel
|
26 |
+
charity
|
27 |
+
choice
|
28 |
+
Christianity
|
29 |
+
church
|
30 |
+
comfort
|
31 |
+
community
|
32 |
+
conflict
|
33 |
+
connection
|
34 |
+
conquest
|
35 |
+
conservative
|
36 |
+
control
|
37 |
+
conversion
|
38 |
+
convert
|
39 |
+
core
|
40 |
+
counsel
|
41 |
+
courage
|
42 |
+
Covenant
|
43 |
+
creative
|
44 |
+
Creator
|
45 |
+
creed
|
46 |
+
cross
|
47 |
+
Crusade
|
48 |
+
Darkness
|
49 |
+
decision
|
50 |
+
deity
|
51 |
+
destiny
|
52 |
+
Devil
|
53 |
+
disciple
|
54 |
+
discipline
|
55 |
+
discussion
|
56 |
+
divine
|
57 |
+
divinity
|
58 |
+
doctrine
|
59 |
+
duty
|
60 |
+
effect
|
61 |
+
elder
|
62 |
+
energy
|
63 |
+
essence
|
64 |
+
eternal
|
65 |
+
ethics
|
66 |
+
event
|
67 |
+
evidence
|
68 |
+
exile
|
69 |
+
Exodus
|
70 |
+
faith
|
71 |
+
family
|
72 |
+
fate
|
73 |
+
Father
|
74 |
+
favor
|
75 |
+
fundamental
|
76 |
+
gift
|
77 |
+
glory
|
78 |
+
God
|
79 |
+
gospel
|
80 |
+
grace
|
81 |
+
growth
|
82 |
+
guru
|
83 |
+
habit
|
84 |
+
hallow
|
85 |
+
halo
|
86 |
+
happiness
|
87 |
+
harmony
|
88 |
+
healing
|
89 |
+
Heaven
|
90 |
+
Hebrew
|
91 |
+
holy
|
92 |
+
honor
|
93 |
+
hope
|
94 |
+
host
|
95 |
+
humane
|
96 |
+
immortal
|
97 |
+
influence
|
98 |
+
insight
|
99 |
+
instruction
|
100 |
+
issue
|
101 |
+
Jesuit
|
102 |
+
Jesus
|
103 |
+
joy
|
104 |
+
Judaism
|
105 |
+
judgment
|
106 |
+
justice
|
107 |
+
karma
|
108 |
+
keen
|
109 |
+
Keystone
|
110 |
+
Kingdom
|
111 |
+
Latin
|
112 |
+
life
|
113 |
+
light
|
114 |
+
love
|
115 |
+
loving
|
116 |
+
marriage
|
117 |
+
meaning
|
118 |
+
mercy
|
119 |
+
Messiah
|
120 |
+
minister
|
121 |
+
miracle
|
122 |
+
mission
|
123 |
+
mortal
|
124 |
+
mosque
|
125 |
+
movement
|
126 |
+
music
|
127 |
+
mystery
|
128 |
+
nature
|
129 |
+
nun
|
130 |
+
official
|
131 |
+
oracle
|
132 |
+
order
|
133 |
+
organ
|
134 |
+
Orthodox
|
135 |
+
outlook
|
136 |
+
pacific
|
137 |
+
pagan
|
138 |
+
parish
|
139 |
+
participation
|
140 |
+
pastor
|
141 |
+
patriarch
|
142 |
+
peace
|
143 |
+
perception
|
144 |
+
personal
|
145 |
+
perspective
|
146 |
+
petition
|
147 |
+
pilgrim
|
148 |
+
politics
|
149 |
+
power
|
150 |
+
practice
|
151 |
+
prayer
|
152 |
+
prelude
|
153 |
+
presence
|
154 |
+
priest
|
155 |
+
principle
|
156 |
+
privacy
|
157 |
+
prophet
|
158 |
+
protection
|
159 |
+
purpose
|
160 |
+
query
|
161 |
+
quest
|
162 |
+
question
|
163 |
+
quiet
|
164 |
+
radiant
|
165 |
+
radical
|
166 |
+
rally
|
167 |
+
rebirth
|
168 |
+
redemption
|
169 |
+
refuge
|
170 |
+
relationship
|
171 |
+
relative
|
172 |
+
religion
|
173 |
+
religious
|
174 |
+
Revelation
|
175 |
+
ritual
|
176 |
+
role
|
177 |
+
Sacrament
|
178 |
+
sacred
|
179 |
+
sacrifice
|
180 |
+
sage
|
181 |
+
saint
|
182 |
+
salvation
|
183 |
+
sanctuary
|
184 |
+
savior
|
185 |
+
scripture
|
186 |
+
scriptures
|
187 |
+
sect
|
188 |
+
security
|
189 |
+
sense
|
190 |
+
serious
|
191 |
+
serve
|
192 |
+
service
|
193 |
+
Sharia
|
194 |
+
shepherd
|
195 |
+
shrine
|
196 |
+
silence
|
197 |
+
sin
|
198 |
+
society
|
199 |
+
soul
|
200 |
+
source
|
201 |
+
spirit
|
202 |
+
spiritual
|
203 |
+
split
|
204 |
+
statue
|
205 |
+
Sunday
|
206 |
+
support
|
207 |
+
Supreme
|
208 |
+
teaching
|
209 |
+
temple
|
210 |
+
tests
|
211 |
+
text
|
212 |
+
Torah
|
213 |
+
tradition
|
214 |
+
traditional
|
215 |
+
trust
|
216 |
+
unique
|
217 |
+
unity
|
218 |
+
unknown
|
219 |
+
value
|
220 |
+
vanity
|
221 |
+
virtue
|
222 |
+
vision
|
223 |
+
voice
|
224 |
+
voices
|
225 |
+
watch
|
226 |
+
weight
|
227 |
+
whole
|
228 |
+
wisdom
|
229 |
+
wonder
|
230 |
+
yang
|
231 |
+
yin
|
232 |
+
zeal
|
naacl-2021-fudge-controlled-generation/topic_data/wordlists/science.txt
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
astronomy
|
2 |
+
atom
|
3 |
+
biology
|
4 |
+
cell
|
5 |
+
chemical
|
6 |
+
chemistry
|
7 |
+
climate
|
8 |
+
control
|
9 |
+
data
|
10 |
+
electricity
|
11 |
+
element
|
12 |
+
energy
|
13 |
+
evolution
|
14 |
+
experiment
|
15 |
+
fact
|
16 |
+
flask
|
17 |
+
fossil
|
18 |
+
funnel
|
19 |
+
genetics
|
20 |
+
gravity
|
21 |
+
hypothesis
|
22 |
+
lab
|
23 |
+
laboratory
|
24 |
+
laws
|
25 |
+
mass
|
26 |
+
matter
|
27 |
+
measure
|
28 |
+
microscope
|
29 |
+
mineral
|
30 |
+
molecule
|
31 |
+
motion
|
32 |
+
observe
|
33 |
+
organism
|
34 |
+
particle
|
35 |
+
phase
|
36 |
+
physics
|
37 |
+
research
|
38 |
+
scale
|
39 |
+
science
|
40 |
+
scientist
|
41 |
+
telescope
|
42 |
+
temperature
|
43 |
+
theory
|
44 |
+
tissue
|
45 |
+
variable
|
46 |
+
volume
|
47 |
+
weather
|
48 |
+
weigh
|
naacl-2021-fudge-controlled-generation/topic_data/wordlists/space.txt
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
planet
|
2 |
+
galaxy
|
3 |
+
space
|
4 |
+
universe
|
5 |
+
orbit
|
6 |
+
spacecraft
|
7 |
+
earth
|
8 |
+
moon
|
9 |
+
comet
|
10 |
+
star
|
11 |
+
astronaut
|
12 |
+
aerospace
|
13 |
+
asteroid
|
14 |
+
spaceship
|
15 |
+
starship
|
16 |
+
galactic
|
17 |
+
satellite
|
18 |
+
meteor
|
naacl-2021-fudge-controlled-generation/transcript.txt
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
(Sorry, the slide numbers got a bit misaligned as I added slides. Not an exact transcript for the video but roughly correct.)
|
2 |
+
|
3 |
+
1:
|
4 |
+
Hi! I'm Kevin from UC Berkeley, and today I'll be presenting my paper FUDGE: Controlled Text Generation with Future Discriminators, by me and my advisor Dan Klein.
|
5 |
+
|
6 |
+
2:
|
7 |
+
So first a quick overview.
|
8 |
+
|
9 |
+
3:
|
10 |
+
I'll start by explaining the problem of controlled text generation with some examples,
|
11 |
+
|
12 |
+
4:
|
13 |
+
then describe our method, FUDGE, Future Discriminators for Generation,
|
14 |
+
|
15 |
+
5:
|
16 |
+
and in doing so I'll also show experimental results and example model outputs on three diverse controlled generation tasks.
|
17 |
+
|
18 |
+
6:
|
19 |
+
So what's controlled text generation?
|
20 |
+
|
21 |
+
7:
|
22 |
+
Well let's start with our autoregressive language model that we use for text generation, without the controlled part.
|
23 |
+
|
24 |
+
8:
|
25 |
+
The language model models a distribution over next tokens x i+1 given the prefix x1 to x i. for example, you might tell it to
|
26 |
+
|
27 |
+
9:
|
28 |
+
Generate text according to a prompt like
|
29 |
+
|
30 |
+
9:
|
31 |
+
THIS, the issue focused on.
|
32 |
+
|
33 |
+
10:
|
34 |
+
and then it'll chug along and generate text,
|
35 |
+
|
36 |
+
11:
|
37 |
+
and these days language models are pretty good.
|
38 |
+
But in controlled generation, you have an additional attribute constraint
|
39 |
+
|
40 |
+
12:
|
41 |
+
like wanting the output to be about politics.
|
42 |
+
|
43 |
+
13:
|
44 |
+
Specifically, we have an attribute function a(X) which says whether or not the attribute a is true for your output X, in this case whether or not the output is on topic. There's no probabilities involved in a(X) since it operates on the completed generation output, not on partial sequences.
|
45 |
+
|
46 |
+
14:
|
47 |
+
More precisely, the task of controlled text generation is to sample from the distribution P(X given a = True), so the distribution of outputs X which satisfy a.
|
48 |
+
|
49 |
+
15:
|
50 |
+
By default the language model isn't equipped to handle this additional constraint a, so its output is not going to pass.
|
51 |
+
So we need a method for *controlled* text generation.
|
52 |
+
|
53 |
+
15:
|
54 |
+
For example, our method FUDGE.
|
55 |
+
|
56 |
+
16:
|
57 |
+
Given the same prompt with the politics topic,
|
58 |
+
|
59 |
+
16:
|
60 |
+
Here's what FUDGE says.
|
61 |
+
|
62 |
+
17:
|
63 |
+
It worked pretty well in this example. It's talking about institutions and constitutions, which seems clearly on topic.
|
64 |
+
|
65 |
+
18:
|
66 |
+
And I'll point out here that controlled generation makes sense in addition to the usual conditioning on the input that you might see in translation or summarization.
|
67 |
+
|
68 |
+
19:
|
69 |
+
Say we're translating Spanish to English. There's input conditioning on the original Spanish, but we're also imposing the additional constraint that the output be formal, which is where controlled text generation comes in.
|
70 |
+
|
71 |
+
20:
|
72 |
+
So say we have this Spanish input
|
73 |
+
|
74 |
+
20:
|
75 |
+
and let me just move it to the corner so we can still see it
|
76 |
+
|
77 |
+
20:
|
78 |
+
If you ask your off-the-shelf translation model it'll get the meaning right,
|
79 |
+
but it copies some ungrammatical parts of the original Spanish
|
80 |
+
|
81 |
+
21:
|
82 |
+
like these repeated words in bold.
|
83 |
+
|
84 |
+
22:
|
85 |
+
So at the end when we ask our formality classifier,
|
86 |
+
|
87 |
+
23:
|
88 |
+
it might not be super happy.
|
89 |
+
|
90 |
+
24:
|
91 |
+
But if you use a controlled text generation approach like FUDGE,
|
92 |
+
|
93 |
+
25:
|
94 |
+
You can get this translation which preserves the meaning, while also better matching the formal style constraint.
|
95 |
+
|
96 |
+
26:
|
97 |
+
|
98 |
+
|
99 |
+
27:
|
100 |
+
You might wonder, why don't we just do rejection sampling?
|
101 |
+
|
102 |
+
28:
|
103 |
+
Just sample a bunch of times from the translator
|
104 |
+
|
105 |
+
29:
|
106 |
+
until you get one that passes.
|
107 |
+
That might work for some simpler constraints like topics, but it's going to be totally intractable when you use constraints that are very rarely satisfied by your generator distribution.
|
108 |
+
|
109 |
+
30:
|
110 |
+
What are some more difficult attribute constraints?
|
111 |
+
|
112 |
+
31:
|
113 |
+
Consider this task, effectively, complete the poem.
|
114 |
+
|
115 |
+
32:
|
116 |
+
Let’s see what the language model says when we give it this input from Shakespeare. And even thence thou wilt be stol'n I fear
|
117 |
+
|
118 |
+
32:
|
119 |
+
and thou art a good friend of mine. The king's guard.
|
120 |
+
|
121 |
+
33:
|
122 |
+
This is terrible! It doesn't roll off the tongue, it doesn't rhyme, it doesn't even end the sentence properly at the end.
|
123 |
+
|
124 |
+
34:
|
125 |
+
Shakespeare hates it. You could generate any number of poems using your language model, and Shakespeare is gonna hate every last one.
|
126 |
+
|
127 |
+
35:
|
128 |
+
But if you ask FUDGE, you get this. And even thence thou wilt be stol'n I fear, for this shall be the end. That's pretty clear.
|
129 |
+
|
130 |
+
36:
|
131 |
+
So it's not Shakespeare, but it gets the meter, or rhythm right, it rhymes, and it ends the sentence in about the right place at the end. Not too bad.
|
132 |
+
|
133 |
+
37:
|
134 |
+
Ok. So how does controlled generation work anyway? Let me give an incredibly oversimplified summary of some ideas in this line of work to put FUDGE in context.
|
135 |
+
|
136 |
+
38:
|
137 |
+
First, you can finetune.
|
138 |
+
|
139 |
+
39:
|
140 |
+
We'll use the politics topic as an example.
|
141 |
+
|
142 |
+
39:
|
143 |
+
You can train on a bunch of text about politics. Depending on how good your data is, this can work great! or it could be rather bad. It also might be annoying to have to finetune again next time, when you want to write about science instead.
|
144 |
+
|
145 |
+
40:
|
146 |
+
Another idea is to use a classifier.
|
147 |
+
|
148 |
+
41:
|
149 |
+
We're already using a classifier to evaluate.
|
150 |
+
|
151 |
+
42:
|
152 |
+
We can use a classifier to help us generate too. There's many different ways to do this.
|
153 |
+
|
154 |
+
43:
|
155 |
+
For example, you might propagate gradients to modify the model's activations,
|
156 |
+
|
157 |
+
44:
|
158 |
+
or you could just directly modify the model's output probabilities. One advantage of the latter method is that you don't need to access the original language model's gradients at all, which is nice if you're using something like GPT3. You can also swap the generator out as better models become available, like GPT4. Our approach FUDGE falls in this category of just modifying the output logits.
|
159 |
+
|
160 |
+
45:
|
161 |
+
Ok, so what's FUDGE?
|
162 |
+
|
163 |
+
46:
|
164 |
+
FUDGE at its core learns a lightweight classifier for the attribute constraint, and then follows a Bayesian factorization to combine it with the original generator, like the pretrained language model.
|
165 |
+
|
166 |
+
47:
|
167 |
+
A key difference from prior work is that we plan for the future, not the immediate present.
|
168 |
+
|
169 |
+
48:
|
170 |
+
And finally, FUDGE can easily and flexibly compose multiple constraints.
|
171 |
+
|
172 |
+
49:
|
173 |
+
Let's start with the classifier and Bayesian factorization.
|
174 |
+
|
175 |
+
50:
|
176 |
+
Since FUDGE builds off the base language model, let's review:
|
177 |
+
|
178 |
+
51:
|
179 |
+
You feed whatever tokens you have so far
|
180 |
+
|
181 |
+
52:
|
182 |
+
into your model,
|
183 |
+
|
184 |
+
53:
|
185 |
+
which models the distribution over possible next tokens.
|
186 |
+
|
187 |
+
54:
|
188 |
+
And then you sample from this distribution to pick your continuation.
|
189 |
+
|
190 |
+
55:
|
191 |
+
Now, we completely ignored the formal style constraint.
|
192 |
+
|
193 |
+
56:
|
194 |
+
So it's gonna be unhappy.
|
195 |
+
|
196 |
+
57:
|
197 |
+
So what do you want to do instead?
|
198 |
+
|
199 |
+
58:
|
200 |
+
Well, what you really want is to use your classifier to judge continuations.
|
201 |
+
|
202 |
+
59:
|
203 |
+
and mark which ones are acceptable given your constraint. So the classifier looks at each possible next continuation Do you want, Do you prefer, Do you thus, and so on maybe up to some limit, and judges each one individually to decide which it's ok with.
|
204 |
+
|
205 |
+
60:
|
206 |
+
So putting it together, we throw out whatever the classifier didn't like,
|
207 |
+
|
208 |
+
61:
|
209 |
+
and then we select from whatever the classifier is ok with depending on the base generator's probabilities.
|
210 |
+
And this gets you "Do you prefer" instead of "Do you want"
|
211 |
+
|
212 |
+
62:
|
213 |
+
which sounds a bit more formal.
|
214 |
+
|
215 |
+
63:
|
216 |
+
But there's a subtle problem in this diagram.
|
217 |
+
The classifier is supposed to judge the finished sentence, not the prefixes,
|
218 |
+
|
219 |
+
64:
|
220 |
+
but here we've shoved it into our generation procedure where it's gonna operate on prefixes.
|
221 |
+
What we actually need is
|
222 |
+
|
223 |
+
65:
|
224 |
+
kind of a future looking crystal ball version of the classifier, which judges whether the whole sentence will eventually be formal, given the current prefix.
|
225 |
+
|
226 |
+
65:
|
227 |
+
And in practice, we implement the judge as a learned binary classifier, which runs on each possible continuation, and for each one outputs the probability that in the end the desired attribute a will be True, or in this case whether the finished sentence would be formal, given just the current prefix plus next token.
|
228 |
+
So in the red table, this 0.2 by "want" means it thinks that there's a 20% chance that the eventual sentence would be formal if we started with Do you want, whereas it assigns a much higher probability for Do you prefer and Do you thus because those are more formal.
|
229 |
+
|
230 |
+
68:
|
231 |
+
And then we sample proportionally from the probabilities in the purple table,
|
232 |
+
which are now just the elementwise product of the blue and red tables' probabilities.
|
233 |
+
This corresponds exactly to a Bayesian factorization for the probability distribution over sentences generated by the language model that possess the desired attribute, and you can check the math in the paper.
|
234 |
+
But the Bayesian motivation is not new.
|
235 |
+
|
236 |
+
70:
|
237 |
+
What's really new in FUDGE is that we explicitly distinguish the final classifier from the crystal ball future-predicting version that we use during the generation procedure, and making this distinction is critical for performance.
|
238 |
+
|
239 |
+
71:
|
240 |
+
Let's see FUDGE in action.
|
241 |
+
|
242 |
+
72:
|
243 |
+
if you recall our Spanish to English formal translation example.
|
244 |
+
|
245 |
+
73:
|
246 |
+
Let's backtrack FUDGE to this step.
|
247 |
+
|
248 |
+
74:
|
249 |
+
Again we have the repeated Spanish que que in bold, which the base model translated verbatim as that, that.
|
250 |
+
|
251 |
+
75:
|
252 |
+
But by having our classifier judge the formality of possible continuations, FUDGE is able to modify its continuation so that it doesn't repeat the words here.
|
253 |
+
|
254 |
+
76:
|
255 |
+
And the end result preserves the meaning while being also a bit more formal.
|
256 |
+
|
257 |
+
77:
|
258 |
+
And finally this all holds up in our experiments. So we have a classifier trained on a heldout dataset of formality, and it indeed judges FUDGE's outputs to be significantly more formal than those of the best prior method.
|
259 |
+
|
260 |
+
78:
|
261 |
+
At the same time, FUDGE is able to preserve the content, based on measuring BLEU against cleaned reference translations.
|
262 |
+
|
263 |
+
79:
|
264 |
+
Ok great. So next I'll elaborate more about planning for the future vs present,
|
265 |
+
|
266 |
+
80:
|
267 |
+
and I'll try to show more clearly *why* we really need this crystal ball classifier.
|
268 |
+
|
269 |
+
81:
|
270 |
+
Let's go back to our politics topic constraint.
|
271 |
+
|
272 |
+
82:
|
273 |
+
For simplicity, let's pretend just for this talk that the politics topic just means whether or not you use the word "constitution."
|
274 |
+
|
275 |
+
83:
|
276 |
+
So the constraint that we check at the end of generation is literally just grep for constitution.
|
277 |
+
|
278 |
+
84:
|
279 |
+
The crystal ball classifier has a much harder task. For a given prefix, it needs to predict whether each possible word makes "constitution" more likely to appear later.
|
280 |
+
|
281 |
+
85:
|
282 |
+
So how do we learn this?
|
283 |
+
|
284 |
+
86:
|
285 |
+
Say you have this example in your training data containing "constitution"
|
286 |
+
|
287 |
+
87:
|
288 |
+
The crystal ball classifier takes this and makes a bunch of prefix examples, labeled with the attribute function a(X)=True because we saw those prefixes led to the word "constitution" later.
|
289 |
+
|
290 |
+
88:
|
291 |
+
And similarly if you have this example without the word "constitution"
|
292 |
+
|
293 |
+
89:
|
294 |
+
It'll label those prefixes as False.
|
295 |
+
|
296 |
+
90:
|
297 |
+
Ok
|
298 |
+
|
299 |
+
91:
|
300 |
+
So let's examine what FUDGE generates.
|
301 |
+
|
302 |
+
92:
|
303 |
+
After a couple of steps, we have It has been shown whether the two
|
304 |
+
|
305 |
+
93:
|
306 |
+
What if you hypothetically use the non crystal ball classifier to guide generation?
|
307 |
+
|
308 |
+
94:
|
309 |
+
The issue focused on whether the two constitution
|
310 |
+
(pause) Maybe not. We don't really want to sacrifice fluency. But this classifier is too shortsighted. It's all or nothing, you either have to use constitution immediately or bust.
|
311 |
+
|
312 |
+
95:
|
313 |
+
Ok
|
314 |
+
|
315 |
+
96:
|
316 |
+
Good thing FUDGE is actually using the future looking classifier.
|
317 |
+
|
318 |
+
97:
|
319 |
+
So instead, FUDGE is going to generate something which is still reasonably likely under the original language model, but makes constitution more likely to be generated later on. This classifier doesn't care whether constitution is generated now or later, as long as it shows up eventually.
|
320 |
+
|
321 |
+
98:
|
322 |
+
So here it's going to write about institutions, so it's on the right topic
|
323 |
+
|
324 |
+
99:
|
325 |
+
which eventually leads it to write about the constitution.
|
326 |
+
|
327 |
+
100:
|
328 |
+
Great.
|
329 |
+
|
330 |
+
101:
|
331 |
+
And indeed in our experiments, FUDGE is great according to human evaluations too. It substantially beats the best prior method in pairwise evaluations of being on topic,
|
332 |
+
|
333 |
+
102:
|
334 |
+
while also beating it in fluency.
|
335 |
+
|
336 |
+
103:
|
337 |
+
Cool. So I've now demonstrated the importance of planning for the future through this topic control task.
|
338 |
+
And finally, i'll highlight FUDGE's compositional potential, using a third task.
|
339 |
+
|
340 |
+
104:
|
341 |
+
Ok.
|
342 |
+
|
343 |
+
105:
|
344 |
+
So remember our schematic diagram where we have the judge of formality.
|
345 |
+
|
346 |
+
106:
|
347 |
+
This works great when we have just one attribute we care about.
|
348 |
+
|
349 |
+
107:
|
350 |
+
Now, what if you have another attribute? Maybe you want it to be formal but also about math
|
351 |
+
|
352 |
+
108:
|
353 |
+
Now our old crystal ball classifier of just formality isn't good enough anymore.
|
354 |
+
Of course, you could construct a classifier which predicts both attributes simultaneously, but FUDGE lets you do something more scalable and also i think a bit more elegant.
|
355 |
+
|
356 |
+
109:
|
357 |
+
Just reuse the formality predictor, while adding a second crystal ball for the math topic.
|
358 |
+
So now your generation is guided by one classifier for each constraint,
|
359 |
+
|
360 |
+
110:
|
361 |
+
and it picks something which it thinks sounds more mathy.
|
362 |
+
|
363 |
+
111:
|
364 |
+
So let's see this in practice.
|
365 |
+
|
366 |
+
112:
|
367 |
+
Remember our poetry examples? where FUDGE's example isn't quite Shakespeare but is at least pretty well-formed.
|
368 |
+
This task actually uses three separate constraints:
|
369 |
+
|
370 |
+
113:
|
371 |
+
We want iambic meter, which means that every other syllable should be a stressed syllable when we're reading it,
|
372 |
+
|
373 |
+
114:
|
374 |
+
we want the two lines to rhyme, and since the first line is 10 syllables that means the second line should be 10 syllables too,
|
375 |
+
|
376 |
+
115:
|
377 |
+
and the second line that we generate should end the sentence afterward too.
|
378 |
+
|
379 |
+
116:
|
380 |
+
So let's backtrack to halfway through FUDGE's generation, before it's generated the last couple of words, pretty clear.
|
381 |
+
|
382 |
+
117:
|
383 |
+
FUDGE is using its crystal ball poetry classifier, which is a combination of three classifiers, one for each of the three constraints.
|
384 |
+
|
385 |
+
118:
|
386 |
+
It would be perfectly grammatical to just directly say "clear". This works for the iambic meter constraint. But this is only the 8th syllable, so you'd still have to rhyme and end a new sentence in just two more syllables.
|
387 |
+
|
388 |
+
119:
|
389 |
+
Then we're probably back to angry Shakespeare.
|
390 |
+
|
391 |
+
120:
|
392 |
+
So FUDGE first generates pretty
|
393 |
+
|
394 |
+
121:
|
395 |
+
before finishing with clear and a period,
|
396 |
+
|
397 |
+
122:
|
398 |
+
and this show how FUDGE is able to compose multiple attributes using multiple classifiers, while simultaneously planning for the future as I described previously.
|
399 |
+
|
400 |
+
123:
|
401 |
+
Finally, if we look at the experiments, FUDGE's performance holds up, with the success rate on simultaneously satisfying all three constraints being more than double that of the best prior method.
|
402 |
+
|
403 |
+
124:
|
404 |
+
So that wraps things up. The takeaways are that FUDGE is a simple, flexible method for controlled text generation.
|
405 |
+
|
406 |
+
125:
|
407 |
+
To reiterate our three main points from earlier, FUDGE learns a classifier in a Bayesian factorization to guide the generation,
|
408 |
+
it plans for the future rather than the present,
|
409 |
+
and it can easily and flexibly compose different constraints as needed while maintaining strong performance.
|
410 |
+
|
411 |
+
126:
|
412 |
+
And our code is all publicly available.
|
413 |
+
|
414 |
+
127:
|
415 |
+
Thanks for watching! And please check out our paper for the full details.
|
naacl-2021-fudge-controlled-generation/util.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import sys
|
4 |
+
from contextlib import contextmanager
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from constants import *
|
9 |
+
|
10 |
+
@contextmanager
|
11 |
+
def suppress_stdout():
|
12 |
+
with open(os.devnull, "w") as devnull:
|
13 |
+
old_stdout = sys.stdout
|
14 |
+
sys.stdout = devnull
|
15 |
+
try:
|
16 |
+
yield
|
17 |
+
finally:
|
18 |
+
sys.stdout = old_stdout
|
19 |
+
|
20 |
+
|
21 |
+
def save_checkpoint(state, save_path):
|
22 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
23 |
+
torch.save(state, save_path)
|
24 |
+
|
25 |
+
|
26 |
+
def freeze(module):
|
27 |
+
for param in module.parameters():
|
28 |
+
param.requires_grad = False
|
29 |
+
|
30 |
+
|
31 |
+
def num_params(model):
|
32 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
33 |
+
|
34 |
+
|
35 |
+
def clamp(x, limit):
|
36 |
+
return max(-limit, min(x, limit))
|
37 |
+
|
38 |
+
|
39 |
+
def pad_to_length(tensor, length, dim, value=0):
|
40 |
+
"""
|
41 |
+
Pad tensor to given length in given dim using given value (value should be numeric)
|
42 |
+
"""
|
43 |
+
assert tensor.size(dim) <= length
|
44 |
+
if tensor.size(dim) < length:
|
45 |
+
zeros_shape = list(tensor.shape)
|
46 |
+
zeros_shape[dim] = length - tensor.size(dim)
|
47 |
+
zeros_shape = tuple(zeros_shape)
|
48 |
+
return torch.cat([tensor, torch.zeros(zeros_shape).type(tensor.type()).to(tensor.device).fill_(value)], dim=dim)
|
49 |
+
else:
|
50 |
+
return tensor
|
51 |
+
|
52 |
+
|
53 |
+
def pad_mask(lengths: torch.LongTensor) -> torch.ByteTensor:
|
54 |
+
"""
|
55 |
+
Create a mask of seq x batch where seq = max(lengths), with 0 in padding locations and 1 otherwise.
|
56 |
+
"""
|
57 |
+
# lengths: bs. Ex: [2, 3, 1]
|
58 |
+
max_seqlen = torch.max(lengths)
|
59 |
+
expanded_lengths = lengths.unsqueeze(0).repeat((max_seqlen, 1)) # [[2, 3, 1], [2, 3, 1], [2, 3, 1]]
|
60 |
+
indices = torch.arange(max_seqlen).unsqueeze(1).repeat((1, lengths.size(0))).to(lengths.device) # [[0, 0, 0], [1, 1, 1], [2, 2, 2]]
|
61 |
+
|
62 |
+
return expanded_lengths > indices # pad locations are 0. #[[1, 1, 1], [1, 1, 0], [0, 1, 0]]. seqlen x bs
|
63 |
+
|
64 |
+
|
65 |
+
class ProgressMeter(object):
|
66 |
+
"""
|
67 |
+
Display meter
|
68 |
+
"""
|
69 |
+
def __init__(self, num_batches, meters, prefix=""):
|
70 |
+
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
|
71 |
+
self.meters = meters
|
72 |
+
self.prefix = prefix
|
73 |
+
|
74 |
+
def display(self, batch):
|
75 |
+
entries = [self.prefix + self.batch_fmtstr.format(batch)]
|
76 |
+
entries.append(time.ctime(time.time()))
|
77 |
+
entries += [str(meter) for meter in self.meters]
|
78 |
+
print('\t'.join(entries))
|
79 |
+
|
80 |
+
def _get_batch_fmtstr(self, num_batches):
|
81 |
+
num_digits = len(str(num_batches // 1))
|
82 |
+
fmt = '{:' + str(num_digits) + 'd}'
|
83 |
+
return '[' + fmt + '/' + fmt.format(num_batches) + ']'
|
84 |
+
|
85 |
+
|
86 |
+
class AverageMeter(object):
|
87 |
+
"""
|
88 |
+
Display meter
|
89 |
+
Computes and stores the average and current value
|
90 |
+
"""
|
91 |
+
def __init__(self, name, fmt=':f'):
|
92 |
+
self.name = name
|
93 |
+
self.fmt = fmt
|
94 |
+
self.reset()
|
95 |
+
|
96 |
+
def reset(self):
|
97 |
+
self.val = 0
|
98 |
+
self.avg = 0
|
99 |
+
self.sum = 0
|
100 |
+
self.count = 0
|
101 |
+
|
102 |
+
def update(self, val, n=1):
|
103 |
+
self.val = val
|
104 |
+
self.sum += val * n
|
105 |
+
self.count += n
|
106 |
+
self.avg = self.sum / self.count
|
107 |
+
|
108 |
+
def __str__(self):
|
109 |
+
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
|
110 |
+
return fmtstr.format(**self.__dict__)
|