Spaces:
No application file
No application file
hieungo1410
commited on
Commit
•
8cb4f3b
1
Parent(s):
cbb24e0
'add'
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +6 -0
- README.md +156 -13
- bin/__init__.py +1 -0
- bin/main.py +73 -0
- bin/serve.py +108 -0
- config/bilingual_prototype.yml +52 -0
- config/prototype.json +25 -0
- layers/__init__.py +1 -0
- layers/prototypes.py +148 -0
- models/__init__.py +7 -0
- models/default.py +13 -0
- models/transformer.py +404 -0
- modules/__init__.py +3 -0
- modules/config.py +62 -0
- modules/constants.py +18 -0
- modules/default.py +54 -0
- modules/inference/__init__.py +10 -0
- modules/inference/__pycache__/__init__.cpython-36.pyc +0 -0
- modules/inference/__pycache__/beam_search.cpython-36.pyc +0 -0
- modules/inference/__pycache__/decode_strategy.cpython-36.pyc +0 -0
- modules/inference/__pycache__/prototypes.cpython-36.pyc +0 -0
- modules/inference/__pycache__/sampling_temperature.cpython-36.pyc +0 -0
- modules/inference/beam_search.py +336 -0
- modules/inference/beam_search1.py +346 -0
- modules/inference/decode_strategy.py +62 -0
- modules/inference/greedy_search.py +121 -0
- modules/inference/prototypes.py +144 -0
- modules/inference/sampling_temperature.py +119 -0
- modules/loader/__init__.py +4 -0
- modules/loader/__pycache__/__init__.cpython-36.pyc +0 -0
- modules/loader/__pycache__/default_loader.cpython-36.pyc +0 -0
- modules/loader/__pycache__/multilingual_loader.cpython-36.pyc +0 -0
- modules/loader/default_loader.py +114 -0
- modules/loader/multilingual_loader.py +139 -0
- modules/optim/__init__.py +5 -0
- modules/optim/__pycache__/__init__.cpython-36.pyc +0 -0
- modules/optim/__pycache__/adabelief.cpython-36.pyc +0 -0
- modules/optim/__pycache__/adam.cpython-36.pyc +0 -0
- modules/optim/__pycache__/scheduler.cpython-36.pyc +0 -0
- modules/optim/adabelief.py +42 -0
- modules/optim/adam.py +38 -0
- modules/optim/scheduler.py +56 -0
- modules/prototypes.py +209 -0
- requirements.txt +10 -0
- third-party/multi-bleu.perl +177 -0
- utils/data.py +137 -0
- utils/decode_old.py +163 -0
- utils/logging.py +22 -0
- utils/loss.py +25 -0
- utils/metric.py +60 -0
.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
bin/__pycache__
|
2 |
+
layers/__pycache__
|
3 |
+
models/__pycache__
|
4 |
+
modules/__pycache__
|
5 |
+
utils/__pycache__
|
6 |
+
data/raw/preprocess.ipynb
|
README.md
CHANGED
@@ -1,13 +1,156 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Dự án MultilingualMT-UET-KC4.0 là dự án open-source được phát triển bởi nhóm UETNLPLab.
|
2 |
+
|
3 |
+
# Setup
|
4 |
+
## Cài đặt công cụ Multilingual-NMT
|
5 |
+
|
6 |
+
**Note**:
|
7 |
+
Lưu ý:
|
8 |
+
Phiên bản hiện tại chỉ tương thích với python>=3.6
|
9 |
+
```bash
|
10 |
+
git clone https://github.com/KCDichDaNgu/KC4.0_MultilingualNMT.git
|
11 |
+
cd KC4.0_MultilingualNMT
|
12 |
+
pip install -r requirements.txt
|
13 |
+
|
14 |
+
# Quickstart
|
15 |
+
|
16 |
+
```
|
17 |
+
|
18 |
+
## Bước 1: Chuẩn bị dữ liệu
|
19 |
+
|
20 |
+
Ví dụ thực nghiệm dựa trên cặp dữ liệu Anh-Việt nguồn từ iwslt với 133k cặp câu:
|
21 |
+
|
22 |
+
```bash
|
23 |
+
cd data/iwslt_en_vi
|
24 |
+
```
|
25 |
+
|
26 |
+
Dữ liệu bao gồm câu nguồn (`src`) và câu đích (`tgt`) dữ liệu đã được tách từ:
|
27 |
+
|
28 |
+
* `train.en`
|
29 |
+
* `train.vi`
|
30 |
+
* `tst2012.en`
|
31 |
+
* `tst2012.vi`
|
32 |
+
|
33 |
+
| Data set | Sentences | Download |
|
34 |
+
| :---------: | :--------: | :-------------------------------------------: |
|
35 |
+
| Training | 133,317 | via GitHub or located in data/train-en-vi.tgz |
|
36 |
+
| Development | 1,553 | via GitHub or located in data/train-en-vi.tgz |
|
37 |
+
| Test | 1,268 | via GitHub or located in data/train-en-vi.tgz |
|
38 |
+
|
39 |
+
|
40 |
+
**Note**:
|
41 |
+
Lưu ý:
|
42 |
+
- Dữ liệu trước khi đưa vào huấn luyện cần phải được tokenize.
|
43 |
+
- $CONFIG là đường dẫn tới vị trí chứa file config
|
44 |
+
|
45 |
+
Tách dữ liệu dev để tính toán hội tụ trong quá trình huấn luyện, thường không lớn hơn 5k câu.
|
46 |
+
|
47 |
+
```text
|
48 |
+
$ head -n 5 data/iwslt_en_vi/train.en
|
49 |
+
Rachel Pike : The science behind a climate headline
|
50 |
+
In 4 minutes , atmospheric chemist Rachel Pike provides a glimpse of the massive scientific effort behind the bold headlines on climate change , with her team -- one of thousands who contributed -- taking a risky flight over the rainforest in pursuit of data on a key molecule .
|
51 |
+
I 'd like to talk to you today about the scale of the scientific effort that goes into making the headlines you see in the paper .
|
52 |
+
Headlines that look like this when they have to do with climate change , and headlines that look like this when they have to do with air quality or smog .
|
53 |
+
They are both two branches of the same field of atmospheric science .
|
54 |
+
```
|
55 |
+
|
56 |
+
## Bước 2: Huấn luyện mô hình
|
57 |
+
|
58 |
+
Để huấn luyện một mô hình mới **hãy chỉnh sửa file YAML config**:
|
59 |
+
Cần phải sửa lại file config en_vi.yml chỉnh siêu tham số và đường dẫn tới dữ liệu huấn luyện:
|
60 |
+
|
61 |
+
```yaml
|
62 |
+
# data location and config section
|
63 |
+
data:
|
64 |
+
train_data_location: data/iwslt_en_vi/train
|
65 |
+
eval_data_location: data/iwslt_en_vi/tst2013
|
66 |
+
src_lang: .en
|
67 |
+
trg_lang: .vi
|
68 |
+
log_file_models: 'model.log'
|
69 |
+
lowercase: false
|
70 |
+
build_vocab_kwargs: # additional arguments for build_vocab. See torchtext.vocab.Vocab for mode details
|
71 |
+
# max_size: 50000
|
72 |
+
min_freq: 5
|
73 |
+
# model parameters section
|
74 |
+
device: cuda
|
75 |
+
d_model: 512
|
76 |
+
n_layers: 6
|
77 |
+
heads: 8
|
78 |
+
# inference section
|
79 |
+
eval_batch_size: 8
|
80 |
+
decode_strategy: BeamSearch
|
81 |
+
decode_strategy_kwargs:
|
82 |
+
beam_size: 5 # beam search size
|
83 |
+
length_normalize: 0.6 # recalculate beam position by length. Currently only work in default BeamSearch
|
84 |
+
replace_unk: # tuple of layer/head attention to replace unknown words
|
85 |
+
- 0 # layer
|
86 |
+
- 0 # head
|
87 |
+
input_max_length: 200 # input longer than this value will be trimmed in inference. Note that this values are to be used during cached PE, hence, validation set with more than this much tokens will call a warning for the trimming.
|
88 |
+
max_length: 160 # only perform up to this much timestep during inference
|
89 |
+
train_max_length: 50 # training samples with this much length in src/trg will be discarded
|
90 |
+
# optimizer and learning arguments section
|
91 |
+
lr: 0.2
|
92 |
+
optimizer: AdaBelief
|
93 |
+
optimizer_params:
|
94 |
+
betas:
|
95 |
+
- 0.9 # beta1
|
96 |
+
- 0.98 # beta2
|
97 |
+
eps: !!float 1e-9
|
98 |
+
n_warmup_steps: 4000
|
99 |
+
label_smoothing: 0.1
|
100 |
+
dropout: 0.1
|
101 |
+
# training config, evaluation, save & load section
|
102 |
+
batch_size: 64
|
103 |
+
epochs: 20
|
104 |
+
printevery: 200
|
105 |
+
save_checkpoint_epochs: 1
|
106 |
+
maximum_saved_model_eval: 5
|
107 |
+
maximum_saved_model_train: 5
|
108 |
+
|
109 |
+
```
|
110 |
+
|
111 |
+
Sau đó có thể chạy với câu lệnh:
|
112 |
+
|
113 |
+
```bash
|
114 |
+
python -m bin.main train --model Transformer --model_dir $MODEL/en-vi.model --config $CONFIG/en_vi.yml
|
115 |
+
```
|
116 |
+
|
117 |
+
**Note**:
|
118 |
+
Ở đây:
|
119 |
+
- $MODEL là dường dẫn tới vị trí lưu mô hình. Sau khi huấn luyện mô hình, thư mục chứa mô hình bao gồm mô hình huyến luyện, file config, file log, vocab.
|
120 |
+
- $CONFIG là đường dẫn tới vị trí chứa file config
|
121 |
+
|
122 |
+
## Bước 3: Dịch
|
123 |
+
|
124 |
+
Mô hình dịch dựa trên thuật toán beam search và lưu bản dịch tại `$your_data_path/translate.en2vi.vi`.
|
125 |
+
|
126 |
+
```bash
|
127 |
+
python -m bin.main infer --model Transformer --model_dir $MODEL/en-vi.model --features_file $your_data_path/tst2012.en --predictions_file $your_data_path/translate.en2vi.vi
|
128 |
+
```
|
129 |
+
|
130 |
+
## Bước 4: Đánh giá chất lượng dựa trên điểm BLEU
|
131 |
+
|
132 |
+
Đánh giá điểm BLEU dựa trên multi-bleu
|
133 |
+
|
134 |
+
```bash
|
135 |
+
perl thrid-party/multi-bleu.perl $your_data_path/translate.en2vi.vi < $your_data_path/tst2012.vi
|
136 |
+
```
|
137 |
+
|
138 |
+
| MODEL | BLEU (Beam Search) |
|
139 |
+
| :-----------------:| :----------------: |
|
140 |
+
| Transformer (Base) | 25.64 |
|
141 |
+
|
142 |
+
|
143 |
+
## Chi tiết tham khảo tại
|
144 |
+
[nmtuet.ddns.net](http://nmtuet.ddns.net:1190/)
|
145 |
+
|
146 |
+
## Nếu có ý kiến đóng góp, xin hãy gửi thư tới địa chỉ mail kcdichdangu@gmail.com
|
147 |
+
|
148 |
+
## Xin trích dẫn bài báo sau:
|
149 |
+
```bash
|
150 |
+
@inproceedings{ViNMT2022,
|
151 |
+
title = {ViNMT: Neural Machine Translation Toolkit},
|
152 |
+
author = {Nguyen Hoang Quan, Nguyen Thanh Dat, Nguyen Hoang Minh Cong, Nguyen Van Vinh, Ngo Thi Vinh, Nguyen Phuong Thai, Tran Hong Viet},
|
153 |
+
booktitle = {https://arxiv.org/abs/2112.15272},
|
154 |
+
year = {2022},
|
155 |
+
}
|
156 |
+
```
|
bin/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
import bin.main as main
|
bin/main.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import models
|
2 |
+
import argparse, os
|
3 |
+
from shutil import copy2 as copy
|
4 |
+
from modules.config import find_all_config
|
5 |
+
|
6 |
+
OVERRIDE_RUN_MODE = {"serve": "infer", "debug": "eval"}
|
7 |
+
|
8 |
+
def check_valid_file(path):
|
9 |
+
if(os.path.isfile(path)):
|
10 |
+
return path
|
11 |
+
else:
|
12 |
+
raise argparse.ArgumentError("This path {:s} is not a valid file, check again.".format(path))
|
13 |
+
|
14 |
+
def create_torchscript_model(model, model_dir, model_name):
|
15 |
+
"""Create a torchscript model using junk data. NOTE: same as tensorflow, is a limited model with no native python script."""
|
16 |
+
import torch
|
17 |
+
junk_input = torch.rand(2, 10)
|
18 |
+
junk_output = torch.rand(2, 7)
|
19 |
+
traced_model = torch.jit.trace(model, junk_input, junk_output)
|
20 |
+
save_location = os.path.join(model_dir, model_name)
|
21 |
+
traced_model.save(save_location)
|
22 |
+
|
23 |
+
if __name__ == "__main__":
|
24 |
+
parser = argparse.ArgumentParser(description="Main argument parser")
|
25 |
+
parser.add_argument("run_mode", choices=("train", "eval", "infer", "debug", "serve"), help="Main running mode of the program")
|
26 |
+
parser.add_argument("--model", type=str, choices=models.AvailableModels.keys(), help="The type of model to be ran")
|
27 |
+
parser.add_argument("--model_dir", type=str, required=True, help="Location of model")
|
28 |
+
parser.add_argument("--config", type=str, nargs="+", default=None, help="Location of the config file")
|
29 |
+
parser.add_argument("--no_keeping_config", action="store_false", help="If set, do not copy the config file to the model directory")
|
30 |
+
# arguments for inference
|
31 |
+
parser.add_argument("--features_file", type=str, help="Inference mode: Provide the location of features file")
|
32 |
+
parser.add_argument("--predictions_file", type=str, help="Inference mode: Provide Location of output file which is predicted from features file")
|
33 |
+
parser.add_argument("--src_lang", type=str, help="Inference mode: Provide language used by source file")
|
34 |
+
parser.add_argument("--trg_lang", type=str, default=None, help="Inference mode: Choose language that is translated from source file. NOTE: only specify for multilingual model")
|
35 |
+
parser.add_argument("--infer_batch_size", type=int, default=None, help="Specify the batch_size to run the model with. Default use the config value.")
|
36 |
+
parser.add_argument("--checkpoint", type=str, default=None, help="All mode: specify to load the checkpoint into model.")
|
37 |
+
parser.add_argument("--checkpoint_idx", type=int, default=0, help="All mode: specify the epoch of the checkpoint loaded. Only useful for training.")
|
38 |
+
parser.add_argument("--serve_path", type=str, default=None, help="File to save TorchScript model into.")
|
39 |
+
|
40 |
+
args = parser.parse_args()
|
41 |
+
# create directory if not exist
|
42 |
+
os.makedirs(args.model_dir, exist_ok=True)
|
43 |
+
config_path = args.config
|
44 |
+
if(config_path is None):
|
45 |
+
config_path = find_all_config(args.model_dir)
|
46 |
+
print("Config path not specified, load the configs in model directory which is {}".format(config_path))
|
47 |
+
elif(args.no_keeping_config):
|
48 |
+
# store false variable, mean true is default
|
49 |
+
print("Config specified, copying all to model dir")
|
50 |
+
for subpath in config_path:
|
51 |
+
copy(subpath, args.model_dir)
|
52 |
+
|
53 |
+
# load model. Specific run mode required converting
|
54 |
+
run_mode = OVERRIDE_RUN_MODE.get(args.run_mode, args.run_mode)
|
55 |
+
model = models.AvailableModels[args.model](config=config_path, model_dir=args.model_dir, mode=run_mode)
|
56 |
+
model.load_checkpoint(args.model_dir, checkpoint=args.checkpoint, checkpoint_idx=args.checkpoint_idx)
|
57 |
+
# run model
|
58 |
+
run_mode = args.run_mode
|
59 |
+
if(run_mode == "train"):
|
60 |
+
model.run_train(model_dir=args.model_dir, config=config_path)
|
61 |
+
elif(run_mode == "eval"):
|
62 |
+
model.run_eval(model_dir=args.model_dir, config=config_path)
|
63 |
+
elif(run_mode == "infer"):
|
64 |
+
model.run_infer(args.features_file, args.predictions_file, src_lang=args.src_lang, trg_lang=args.trg_lang, config=config_path, batch_size=args.infer_batch_size)
|
65 |
+
elif(run_mode == "debug"):
|
66 |
+
raise NotImplementedError
|
67 |
+
model.run_debug(model_dir=args.model_dir, config=config_path)
|
68 |
+
elif(run_mode == "serve"):
|
69 |
+
if(args.serve_path is None):
|
70 |
+
raise parser.ArgumentError("In serving, --serve_path cannot be empty")
|
71 |
+
model.prepare_serve(args.serve_path, model_dir=args.model_dir, config=config_path)
|
72 |
+
else:
|
73 |
+
raise ValueError("Run mode {:s} not implemented.".format(run_mode))
|
bin/serve.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
#import utils.save as saver
|
4 |
+
#import models
|
5 |
+
#from models.transformer import Transformer
|
6 |
+
#from modules.config import find_all_config
|
7 |
+
|
8 |
+
class TransformerHandlerClass:
|
9 |
+
def __init__(self):
|
10 |
+
self.model = None
|
11 |
+
self.device = None
|
12 |
+
self.initialized = False
|
13 |
+
|
14 |
+
def _find_checkpoint(self, model_dir, best_model_prefix="best_model", model_prefix="model", validate=True):
|
15 |
+
"""Attempt to retrieve the best model checkpoint from model_dir. Failing that, the model of the latest iteration.
|
16 |
+
Args:
|
17 |
+
model_dir: location to search for checkpoint. str
|
18 |
+
Returns:
|
19 |
+
single str denoting the checkpoint path """
|
20 |
+
score_file_path = os.path.join(model_dir, saver.BEST_MODEL_FILE)
|
21 |
+
if(os.path.isfile(score_file_path)): # score exist -> best model
|
22 |
+
best_model_path = os.path.join(model_dir, saver.MODEL_FILE_FORMAT.format(best_model_prefix, 0, saver.MODEL_EXTENSION))
|
23 |
+
if(validate):
|
24 |
+
assert os.path.isfile(best_model_path), "Score file is available, but file {:s} is missing.".format(best_model_path)
|
25 |
+
return best_model_path
|
26 |
+
else: # score not exist -> latest model
|
27 |
+
last_checkpoint_idx = saver.check_model_in_path(name_prefix=model_prefix)
|
28 |
+
if(last_checkpoint_idx == 0):
|
29 |
+
raise ValueError("No checkpoint found in folder {:s} with prefix {:s}.".format(model_dir, model_prefix))
|
30 |
+
else:
|
31 |
+
return os.path.join(model_dir, saver.MODEL_FILE_FORMAT.format(model_prefix, last_checkpoint_idx, saver.MODEL_EXTENSION))
|
32 |
+
|
33 |
+
|
34 |
+
def initialize(self, ctx):
|
35 |
+
manifest = ctx.manifest
|
36 |
+
properties = ctx.system_properties
|
37 |
+
|
38 |
+
self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
|
39 |
+
self.model_dir = model_dir = properties.get("model_dir")
|
40 |
+
|
41 |
+
# extract checkpoint location, config & model name
|
42 |
+
model_serve_file = os.path.join(model_dir, saver.MODEL_SERVE_FILE)
|
43 |
+
with io.open(model_serve_file, "r") as serve_config:
|
44 |
+
model_name = serve_config.read().strip()
|
45 |
+
# model_cls = models.AvailableModels[model_name]
|
46 |
+
model_cls = Transformer # can't select due to nature of model file
|
47 |
+
checkpoint_path = manifest['model'].get('serializedFile', self._find_checkpoint(model_dir)) # attempt to use the checkpoint fed from archiver; else use the best checkpoint found
|
48 |
+
config_path = find_all_config(model_dir)
|
49 |
+
|
50 |
+
# load model with inbuilt config + vocab & without pretraining data
|
51 |
+
self.model = model = model_cls(config=config_path, model_dir=model_dir, mode="infer")
|
52 |
+
model.load_checkpoint(args.model_dir, checkpoint=checkpoint_path) # TODO find_checkpoint might do some redundant thing here since load_checkpoint had already done searching for latest
|
53 |
+
|
54 |
+
print("Model {:s} loaded successfully at location {:s}.".format(model_name, model_dir))
|
55 |
+
self.initialized = True
|
56 |
+
|
57 |
+
def handle(self, data):
|
58 |
+
"""The main bulk of handling. Process a batch of data received from client.
|
59 |
+
Args:
|
60 |
+
data: the object received from client. Should contain something in [batch_size] of str
|
61 |
+
Returns:
|
62 |
+
the expected translation, [batch_size] of str
|
63 |
+
"""
|
64 |
+
batch_sentences = data[0].get("data")
|
65 |
+
# assert batch_sentences is not None, "data is {}".format(data)
|
66 |
+
|
67 |
+
# make sure that sentences are detokenized before returning
|
68 |
+
translated_sentences = self.model.translate_batch(batch_sentences, output_tokens=False)
|
69 |
+
|
70 |
+
return translated_sentences
|
71 |
+
|
72 |
+
class BeamSearchHandlerClass:
|
73 |
+
def __init__(self):
|
74 |
+
self.model = None
|
75 |
+
self.inferrer = None
|
76 |
+
self.initialized = False
|
77 |
+
|
78 |
+
def initialize(self, ctx):
|
79 |
+
manifest = ctx.manifest
|
80 |
+
properties = ctx.system_properties
|
81 |
+
|
82 |
+
model_dir = properties['model_dir']
|
83 |
+
ts_modelpath = manifest['model']['serializedFile']
|
84 |
+
self.model = ts_model = torch.jit.load(os.path.join(model_dir, ts_modelpath))
|
85 |
+
|
86 |
+
from modules.inference.beam_search import BeamSearch
|
87 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
88 |
+
self.inferrer = BeamSearch(model, 160, device, beam_size=5)
|
89 |
+
|
90 |
+
self.initialized = True
|
91 |
+
|
92 |
+
def handle(self, data):
|
93 |
+
batch_sentences = data[0].get("data")
|
94 |
+
# assert batch_sentences is not None, "data is {}".format(data)
|
95 |
+
|
96 |
+
translated_sentences = self.inferrer.translate_batch_sentence(data, output_tokens=False)
|
97 |
+
return translated_sentences
|
98 |
+
|
99 |
+
RUNNING_MODEL = BeamSearchHandlerClass()
|
100 |
+
|
101 |
+
def handle(data, context):
|
102 |
+
if(not RUNNING_MODEL.initialized): # Lazy init
|
103 |
+
RUNNING_MODEL.initialize(context)
|
104 |
+
|
105 |
+
if(data is None):
|
106 |
+
return None
|
107 |
+
|
108 |
+
return RUNNING_MODEL.handle(data)
|
config/bilingual_prototype.yml
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# data location and config section
|
2 |
+
data:
|
3 |
+
train_data_location: data/test/train2023
|
4 |
+
eval_data_location: data/test/dev2023
|
5 |
+
src_lang: .lo
|
6 |
+
trg_lang: .vi
|
7 |
+
log_file_models: 'model.log'
|
8 |
+
lowercase: false
|
9 |
+
build_vocab_kwargs: # additional arguments for build_vocab. See torchtext.vocab.Vocab for mode details
|
10 |
+
# max_size: 50000
|
11 |
+
min_freq: 4
|
12 |
+
specials:
|
13 |
+
- <unk>
|
14 |
+
- <pad>
|
15 |
+
- <sos>
|
16 |
+
- <eos>
|
17 |
+
# data augmentation section
|
18 |
+
# model parameters section
|
19 |
+
device: cuda
|
20 |
+
d_model: 512
|
21 |
+
n_layers: 6
|
22 |
+
heads: 8
|
23 |
+
# inference section
|
24 |
+
eval_batch_size: 8
|
25 |
+
decode_strategy: BeamSearch
|
26 |
+
decode_strategy_kwargs:
|
27 |
+
beam_size: 5 # beam search size
|
28 |
+
length_normalize: 0.6 # recalculate beam position by length. Currently only work in default BeamSearch
|
29 |
+
replace_unk: # tuple of layer/head attention to replace unknown words
|
30 |
+
- 0 # layer
|
31 |
+
- 0 # head
|
32 |
+
input_max_length: 250 # input longer than this value will be trimmed in inference. Note that this values are to be used during cached PE, hence, validation set with more than this much tokens will call a warning for the trimming.
|
33 |
+
max_length: 160 # only perform up to this much timestep during inference
|
34 |
+
train_max_length: 140 # training samples with this much length in src/trg will be discarded
|
35 |
+
# optimizer and learning arguments section
|
36 |
+
lr: 0.2
|
37 |
+
optimizer: AdaBelief
|
38 |
+
optimizer_params:
|
39 |
+
betas:
|
40 |
+
- 0.9 # beta1
|
41 |
+
- 0.98 # beta2
|
42 |
+
eps: !!float 1e-9
|
43 |
+
n_warmup_steps: 4000
|
44 |
+
label_smoothing: 0.1
|
45 |
+
dropout: 0.05
|
46 |
+
# training config, evaluation, save & load section
|
47 |
+
batch_size: 32
|
48 |
+
epochs: 40
|
49 |
+
printevery: 200
|
50 |
+
save_checkpoint_epochs: 1
|
51 |
+
maximum_saved_model_eval: 5
|
52 |
+
maximum_saved_model_train: 5
|
config/prototype.json
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"train_src_data": "/workspace/khoai23/opennmt/data/iwslt_en_vi/train.en",
|
3 |
+
"train_trg_data": "/workspace/khoai23/opennmt/data/iwslt_en_vi/train.vi",
|
4 |
+
"valid_src_data": "/workspace/khoai23/opennmt/data/iwslt_en_vi/tst2013.en",
|
5 |
+
"valid_trg_data": "/workspace/khoai23/opennmt/data/iwslt_en_vi/tst2013.vi",
|
6 |
+
"src_lang": "en",
|
7 |
+
"trg_lang": "en",
|
8 |
+
"max_strlen": 160,
|
9 |
+
"batchsize": 1500,
|
10 |
+
"device": "cpu",
|
11 |
+
"d_model": 512,
|
12 |
+
"n_layers": 6,
|
13 |
+
"heads": 8,
|
14 |
+
"dropout": 0.1,
|
15 |
+
"lr": 0.0001,
|
16 |
+
"epochs": 30,
|
17 |
+
"printevery": 200,
|
18 |
+
"k": 5,
|
19 |
+
"n_warmup_steps": 4000,
|
20 |
+
"beta1": 0.9,
|
21 |
+
"beta2": 0.98,
|
22 |
+
"eps": 1e-09,
|
23 |
+
"label_smoothing": 0.1,
|
24 |
+
"save_checkpoint_epochs": 5
|
25 |
+
}
|
layers/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from layers.prototypes import *
|
layers/prototypes.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.autograd import Variable
|
4 |
+
import torch.nn.functional as functional
|
5 |
+
import math
|
6 |
+
import logging
|
7 |
+
|
8 |
+
class PositionalEncoder(nn.Module):
|
9 |
+
def __init__(self, d_model, max_seq_length=200, dropout=0.1):
|
10 |
+
super().__init__()
|
11 |
+
|
12 |
+
self.d_model = d_model
|
13 |
+
self.dropout = nn.Dropout(dropout)
|
14 |
+
self._max_seq_length = max_seq_length
|
15 |
+
|
16 |
+
pe = torch.zeros(max_seq_length, d_model)
|
17 |
+
|
18 |
+
for pos in range(max_seq_length):
|
19 |
+
for i in range(0, d_model, 2):
|
20 |
+
pe[pos, i] = math.sin(pos/(10000**(2*i/d_model)))
|
21 |
+
pe[pos, i+1] = math.cos(pos/(10000**((2*i+1)/d_model)))
|
22 |
+
pe = pe.unsqueeze(0)
|
23 |
+
self.register_buffer('pe', pe)
|
24 |
+
|
25 |
+
@torch.jit.script
|
26 |
+
def splice_by_size(source, target):
|
27 |
+
"""Custom function to splice the source by target's second dimension. Required due to torch.Size not a torchTensor. Why? hell if I know."""
|
28 |
+
length = target.size(1);
|
29 |
+
return source[:, :length]
|
30 |
+
|
31 |
+
self.splice_by_size = splice_by_size
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
if(x.shape[1] > self._max_seq_length):
|
35 |
+
logging.warn("Input longer than maximum supported length for PE detected. Build a model with a larger input_max_length limit if you want to keep the input; or ignore if you want the input trimmed")
|
36 |
+
x = x[:, :self._max_seq_length]
|
37 |
+
|
38 |
+
x = x * math.sqrt(self.d_model)
|
39 |
+
|
40 |
+
spliced_pe = self.splice_by_size(self.pe, x) # self.pe[:, :x.shape[1]]
|
41 |
+
# pe = Variable(spliced_pe, requires_grad=False)
|
42 |
+
pe = spliced_pe.requires_grad_(False)
|
43 |
+
|
44 |
+
# if x.is_cuda: # remove since it is a sub nn.Module
|
45 |
+
# pe.cuda()
|
46 |
+
# assert all([xs == ys for xs, ys in zip(x.shape[1:], pe.shape[1:])]), "{} - {}".format(x.shape, pe.shape)
|
47 |
+
|
48 |
+
x = x + pe
|
49 |
+
x = self.dropout(x)
|
50 |
+
|
51 |
+
return x
|
52 |
+
|
53 |
+
class MultiHeadAttention(nn.Module):
|
54 |
+
def __init__(self, heads, d_model, dropout=0.1):
|
55 |
+
super().__init__()
|
56 |
+
assert d_model % heads == 0
|
57 |
+
|
58 |
+
self.d_model = d_model
|
59 |
+
self.d_k = d_model // heads
|
60 |
+
self.h = heads
|
61 |
+
|
62 |
+
# three casting linear layer for query/key.value
|
63 |
+
self.q_linear = nn.Linear(d_model, d_model)
|
64 |
+
self.k_linear = nn.Linear(d_model, d_model)
|
65 |
+
self.v_linear = nn.Linear(d_model, d_model)
|
66 |
+
|
67 |
+
self.dropout = nn.Dropout(dropout)
|
68 |
+
self.out = nn.Linear(d_model, d_model)
|
69 |
+
|
70 |
+
def forward(self, q, k, v, mask=None):
|
71 |
+
"""
|
72 |
+
Args:
|
73 |
+
q / k / v: query/key/value, should all be [batch_size, sequence_length, d_model]. Only differ in decode attention, where q is tgt_len and k/v is src_len
|
74 |
+
mask: either [batch_size, 1, src_len] or [batch_size, tgt_len, tgt_len]. The last two dimensions must match or are broadcastable.
|
75 |
+
Returns:
|
76 |
+
the value of the attention process, [batch_size, sequence_length, d_model].
|
77 |
+
The used attention, [batch_size, q_length, k_v_length]
|
78 |
+
"""
|
79 |
+
bs = q.shape[0]
|
80 |
+
q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
|
81 |
+
k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
|
82 |
+
v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
|
83 |
+
|
84 |
+
q = q.transpose(1, 2)
|
85 |
+
k = k.transpose(1, 2)
|
86 |
+
v = v.transpose(1, 2)
|
87 |
+
|
88 |
+
value, attn = self.attention(q, k, v, mask, self.dropout)
|
89 |
+
concat = value.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
|
90 |
+
output = self.out(concat)
|
91 |
+
return output, attn
|
92 |
+
|
93 |
+
def attention(self, q, k, v, mask=None, dropout=None):
|
94 |
+
"""Calculate the attention and output the attention & value
|
95 |
+
Args:
|
96 |
+
q / k / v: query/key/value already transformed, should all be [batch_size, heads, sequence_length, d_k]. Only differ in decode attention, where q is tgt_len and k/v is src_len
|
97 |
+
mask: either [batch_size, 1, src_len] or [batch_size, tgt_len, tgt_len]. The last two dimensions must match or are broadcastable.
|
98 |
+
Returns:
|
99 |
+
the attentionized but raw values [batch_size, head, seq_length, d_k]
|
100 |
+
the attention calculated [batch_size, heads, sequence_length, sequence_length]
|
101 |
+
"""
|
102 |
+
|
103 |
+
# d_k = q.shape[-1]
|
104 |
+
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
105 |
+
|
106 |
+
if mask is not None:
|
107 |
+
mask = mask.unsqueeze(1) # add a dimension to account for head
|
108 |
+
scores = scores.masked_fill(mask==0, -1e9)
|
109 |
+
# softmax the padding/peeking masked attention
|
110 |
+
scores = functional.softmax(scores, dim=-1)
|
111 |
+
|
112 |
+
if dropout is not None:
|
113 |
+
scores = dropout(scores)
|
114 |
+
|
115 |
+
output = torch.matmul(scores, v)
|
116 |
+
return output, scores
|
117 |
+
|
118 |
+
class Norm(nn.Module):
|
119 |
+
def __init__(self, d_model, eps = 1e-6):
|
120 |
+
super().__init__()
|
121 |
+
|
122 |
+
self.size = d_model
|
123 |
+
|
124 |
+
# create two learnable parameters to calibrate normalisation
|
125 |
+
self.alpha = nn.Parameter(torch.ones(self.size))
|
126 |
+
self.bias = nn.Parameter(torch.zeros(self.size))
|
127 |
+
|
128 |
+
self.eps = eps
|
129 |
+
|
130 |
+
def forward(self, x):
|
131 |
+
norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) \
|
132 |
+
/ (x.std(dim=-1, keepdim=True) + self.eps) + self.bias
|
133 |
+
return norm
|
134 |
+
|
135 |
+
class FeedForward(nn.Module):
|
136 |
+
"""A two-hidden-linear feedforward layer that can activate and dropout its transition state"""
|
137 |
+
def __init__(self, d_model, d_ff=2048, internal_activation=functional.relu, dropout=0.1):
|
138 |
+
super().__init__()
|
139 |
+
self.linear_1 = nn.Linear(d_model, d_ff)
|
140 |
+
self.dropout = nn.Dropout(dropout)
|
141 |
+
self.linear_2 = nn.Linear(d_ff, d_model)
|
142 |
+
|
143 |
+
self.internal_activation = internal_activation
|
144 |
+
|
145 |
+
def forward(self, x):
|
146 |
+
x = self.dropout(self.internal_activation(self.linear_1(x)))
|
147 |
+
x = self.linear_2(x)
|
148 |
+
return x
|
models/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from models.default import MockModel
|
2 |
+
from models.transformer import Transformer
|
3 |
+
|
4 |
+
AvailableModels = {
|
5 |
+
"MockModel": MockModel,
|
6 |
+
"Transformer" : Transformer
|
7 |
+
}
|
models/default.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class MockModel:
|
2 |
+
"""A model that only output string to show flow"""
|
3 |
+
def __init__(self, *args, **kwargs):
|
4 |
+
print("Mock model initialization, with args/kwargs: {} {}".format(args, kwargs))
|
5 |
+
|
6 |
+
def run_train(self, **kwargs):
|
7 |
+
print("Model in training, with args: {}".format(kwargs))
|
8 |
+
|
9 |
+
def run_eval(self, **kwargs):
|
10 |
+
print("Model in evaluation, with args: {}".format(kwargs))
|
11 |
+
|
12 |
+
def run_debug(self, **kwargs):
|
13 |
+
print("Model in debuging, with args: {}".format(kwargs))
|
models/transformer.py
ADDED
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torchtext.data as data
|
4 |
+
import copy, time, io
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from modules.prototypes import Encoder, Decoder, Config as DefaultConfig
|
8 |
+
from modules.loader import DefaultLoader, MultiLoader
|
9 |
+
from modules.config import MultiplePathConfig as Config
|
10 |
+
from modules.inference import strategies
|
11 |
+
from modules import constants as const
|
12 |
+
from modules.optim import optimizers, ScheduledOptim
|
13 |
+
|
14 |
+
import utils.save as saver
|
15 |
+
from utils.decode_old import create_masks, translate_sentence
|
16 |
+
#from utils.data import create_fields, create_dataset, read_data, read_file, write_file
|
17 |
+
from utils.loss import LabelSmoothingLoss
|
18 |
+
from utils.metric import bleu, bleu_batch_iter, bleu_single, bleu_batch
|
19 |
+
#from utils.save import load_model_from_path, check_model_in_path, save_and_clear_model, write_model_score, load_model_score, save_model_best_to_path, load_model
|
20 |
+
|
21 |
+
class Transformer(nn.Module):
|
22 |
+
"""
|
23 |
+
Implementation of Transformer architecture based on the paper `Attention is all you need`.
|
24 |
+
Source: https://arxiv.org/abs/1706.03762
|
25 |
+
"""
|
26 |
+
def __init__(self, mode=None, model_dir=None, config=None):
|
27 |
+
super().__init__()
|
28 |
+
|
29 |
+
# Use specific config file if provided otherwise use the default config instead
|
30 |
+
self.config = DefaultConfig() if(config is None) else Config(config)
|
31 |
+
opt = self.config
|
32 |
+
self.device = opt.get('device', const.DEFAULT_DEVICE)
|
33 |
+
|
34 |
+
if('train_data_location' in opt or 'train_data_location' in opt.get("data", {})):
|
35 |
+
# monolingual data detected
|
36 |
+
data_opt = opt if 'train_data_location' in opt else opt["data"]
|
37 |
+
self.loader = DefaultLoader(data_opt['train_data_location'], eval_path=data_opt.get('eval_data_location', None), language_tuple=(data_opt["src_lang"], data_opt["trg_lang"]), option=opt)
|
38 |
+
elif('data' in opt):
|
39 |
+
# multilingual data with multiple corpus in [data][train] namespace
|
40 |
+
self.loader = MultiLoader(opt["data"]["train"], valid=opt["data"].get("valid", None), option=opt)
|
41 |
+
# input fields
|
42 |
+
self.SRC, self.TRG = self.loader.build_field(lower=opt.get("lowercase", const.DEFAULT_LOWERCASE))
|
43 |
+
# self.SRC = data.Field(lower=opt.get("lowercase", const.DEFAULT_LOWERCASE))
|
44 |
+
# self.TRG = data.Field(lower=opt.get("lowercase", const.DEFAULT_LOWERCASE), eos_token='<eos>')
|
45 |
+
|
46 |
+
# initialize dataset and by proxy the vocabulary
|
47 |
+
if(mode == "train"):
|
48 |
+
# training flow, necessitate the DataLoader and iterations. This will attempt to load vocab file from the dir instead of rebuilding, but can build a new vocab if no data is found
|
49 |
+
self.train_iter, self.valid_iter = self.loader.create_iterator(self.fields, model_path=model_dir)
|
50 |
+
elif(mode == "eval"):
|
51 |
+
# evaluation flow, which only require valid_iter
|
52 |
+
# TODO fix accordingly
|
53 |
+
self.train_iter, self.valid_iter = self.loader.create_iterator(self.fields, model_path=model_dir)
|
54 |
+
elif(mode == "infer"):
|
55 |
+
# inference, require pickled model and vocab in the path
|
56 |
+
self.loader.build_vocab(self.fields, model_path=model_dir)
|
57 |
+
else:
|
58 |
+
raise ValueError("Unknown model's mode: {}".format(mode))
|
59 |
+
|
60 |
+
|
61 |
+
# define the model
|
62 |
+
src_vocab_size, trg_vocab_size = len(self.SRC.vocab), len(self.TRG.vocab)
|
63 |
+
d_model, N, heads, dropout = opt['d_model'], opt['n_layers'], opt['heads'], opt['dropout']
|
64 |
+
# get the maximum amount of tokens per sample in encoder. This is useful due to PositionalEncoder requiring this value
|
65 |
+
train_ignore_length = self.config.get("train_max_length", const.DEFAULT_TRAIN_MAX_LENGTH)
|
66 |
+
input_max_length = self.config.get("input_max_length", const.DEFAULT_INPUT_MAX_LENGTH)
|
67 |
+
infer_max_length = self.config.get('max_length', const.DEFAULT_MAX_LENGTH)
|
68 |
+
encoder_max_length = max(input_max_length, train_ignore_length)
|
69 |
+
decoder_max_length = max(infer_max_length, train_ignore_length)
|
70 |
+
self.encoder = Encoder(src_vocab_size, d_model, N, heads, dropout, max_seq_length=encoder_max_length)
|
71 |
+
self.decoder = Decoder(trg_vocab_size, d_model, N, heads, dropout, max_seq_length=decoder_max_length)
|
72 |
+
self.out = nn.Linear(d_model, trg_vocab_size)
|
73 |
+
|
74 |
+
# load the beamsearch obj with preset values read from config. ALWAYS require the current model, max_length, and device used as per DecodeStrategy base
|
75 |
+
decode_strategy_class = strategies[opt.get('decode_strategy', const.DEFAULT_DECODE_STRATEGY)]
|
76 |
+
decode_strategy_kwargs = opt.get('decode_strategy_kwargs', const.DEFAULT_STRATEGY_KWARGS)
|
77 |
+
self.decode_strategy = decode_strategy_class(self, infer_max_length, self.device, **decode_strategy_kwargs)
|
78 |
+
|
79 |
+
self.to(self.device)
|
80 |
+
|
81 |
+
def load_checkpoint(self, model_dir, checkpoint=None, checkpoint_idx=0):
|
82 |
+
"""Attempt to load past checkpoint into the model. If a specified checkpoint is set, load it; otherwise load the latest checkpoint in model_dir.
|
83 |
+
Args:
|
84 |
+
model_dir: location of the current model. Not used if checkpoint is specified
|
85 |
+
checkpoint: location of the specific checkpoint to load
|
86 |
+
checkpoint_idx: the epoch of the checkpoint
|
87 |
+
NOTE: checkpoint_idx return -1 in the event of not found; while 0 is when checkpoint is forced
|
88 |
+
"""
|
89 |
+
if(checkpoint is not None):
|
90 |
+
saver.load_model(self, checkpoint)
|
91 |
+
self._checkpoint_idx = checkpoint_idx
|
92 |
+
else:
|
93 |
+
if model_dir is not None:
|
94 |
+
# load the latest available checkpoint, overriding the checkpoint value
|
95 |
+
checkpoint_idx = saver.check_model_in_path(model_dir)
|
96 |
+
if(checkpoint_idx > 0):
|
97 |
+
print("Found model with index {:d} already saved.".format(checkpoint_idx))
|
98 |
+
saver.load_model_from_path(self, model_dir, checkpoint_idx=checkpoint_idx)
|
99 |
+
else:
|
100 |
+
print("No checkpoint found, start from beginning.")
|
101 |
+
checkpoint_idx = -1
|
102 |
+
else:
|
103 |
+
print("No model_dir available, start from beginning.")
|
104 |
+
# train the model from begin
|
105 |
+
checkpoint_idx = -1
|
106 |
+
self._checkpoint_idx = checkpoint_idx
|
107 |
+
|
108 |
+
|
109 |
+
def forward(self, src, trg, src_mask, trg_mask, output_attention=False):
|
110 |
+
"""Run a full model with specified source-target batched set of data
|
111 |
+
Args:
|
112 |
+
src: the source input [batch_size, src_len]
|
113 |
+
trg: the target input (& expected output) [batch_size, trg len]
|
114 |
+
src_mask: the padding mask for src [batch_size, 1, src_len]
|
115 |
+
trg_mask: the triangle mask for trg [batch_size, trg_len, trg_len]
|
116 |
+
output_attention: if specified, output the attention as needed
|
117 |
+
Returns:
|
118 |
+
the logits (unsoftmaxed outputs), same shape as trg
|
119 |
+
"""
|
120 |
+
e_outputs = self.encoder(src, src_mask)
|
121 |
+
d_output, attn = self.decoder(trg, e_outputs, src_mask, trg_mask, output_attention=True)
|
122 |
+
output = self.out(d_output)
|
123 |
+
if(output_attention):
|
124 |
+
return output, attn
|
125 |
+
else:
|
126 |
+
return output
|
127 |
+
|
128 |
+
def train_step(self, optimizer, batch, criterion):
|
129 |
+
"""
|
130 |
+
Perform one training step.
|
131 |
+
"""
|
132 |
+
self.train()
|
133 |
+
opt = self.config
|
134 |
+
|
135 |
+
# move data to specific device's memory
|
136 |
+
src = batch.src.transpose(0, 1).to(opt.get('device', const.DEFAULT_DEVICE))
|
137 |
+
trg = batch.trg.transpose(0, 1).to(opt.get('device', const.DEFAULT_DEVICE))
|
138 |
+
|
139 |
+
trg_input = trg[:, :-1]
|
140 |
+
src_pad = self.SRC.vocab.stoi['<pad>']
|
141 |
+
trg_pad = self.TRG.vocab.stoi['<pad>']
|
142 |
+
ys = trg[:, 1:].contiguous().view(-1)
|
143 |
+
|
144 |
+
# create mask and perform network forward
|
145 |
+
src_mask, trg_mask = create_masks(src, trg_input, src_pad, trg_pad, opt.get('device', const.DEFAULT_DEVICE))
|
146 |
+
preds = self(src, trg_input, src_mask, trg_mask)
|
147 |
+
|
148 |
+
# perform backprogation
|
149 |
+
optimizer.zero_grad()
|
150 |
+
loss = criterion(preds.view(-1, preds.size(-1)), ys)
|
151 |
+
loss.backward()
|
152 |
+
optimizer.step_and_update_lr()
|
153 |
+
loss = loss.item()
|
154 |
+
|
155 |
+
return loss
|
156 |
+
|
157 |
+
def validate(self, valid_iter, criterion, maximum_length=None):
|
158 |
+
"""Compute loss in validation dataset. As we can't perform trimming the input in the valid_iter yet, using a crutch in maximum_input_length variable
|
159 |
+
Args:
|
160 |
+
valid_iter: the Iteration containing batches of data, accessed by .src and .trg
|
161 |
+
criterion: the loss function to use to evaluate
|
162 |
+
maximum_length: if fed, a tuple of max_input_len, max_output_len to trim the src/trg
|
163 |
+
Returns:
|
164 |
+
the avg loss of the criterion
|
165 |
+
"""
|
166 |
+
self.eval()
|
167 |
+
opt = self.config
|
168 |
+
src_pad = self.SRC.vocab.stoi['<pad>']
|
169 |
+
trg_pad = self.TRG.vocab.stoi['<pad>']
|
170 |
+
|
171 |
+
with torch.no_grad():
|
172 |
+
total_loss = []
|
173 |
+
for batch in valid_iter:
|
174 |
+
# load model into specific device (GPU/CPU) memory
|
175 |
+
src = batch.src.transpose(0, 1).to(opt.get('device', const.DEFAULT_DEVICE))
|
176 |
+
trg = batch.trg.transpose(0, 1).to(opt.get('device', const.DEFAULT_DEVICE))
|
177 |
+
if(maximum_length is not None):
|
178 |
+
src = src[:, :maximum_length[0]]
|
179 |
+
trg = trg[:, :maximum_length[1]-1] # using partials
|
180 |
+
trg_input = trg[:, :-1]
|
181 |
+
ys = trg[:, 1:].contiguous().view(-1)
|
182 |
+
|
183 |
+
# create mask and perform network forward
|
184 |
+
src_mask, trg_mask = create_masks(src, trg_input, src_pad, trg_pad, opt.get('device', const.DEFAULT_DEVICE))
|
185 |
+
preds = self(src, trg_input, src_mask, trg_mask)
|
186 |
+
|
187 |
+
# compute loss on current batch
|
188 |
+
loss = criterion(preds.view(-1, preds.size(-1)), ys)
|
189 |
+
loss = loss.item()
|
190 |
+
total_loss.append(loss)
|
191 |
+
|
192 |
+
avg_loss = np.mean(total_loss)
|
193 |
+
return avg_loss
|
194 |
+
|
195 |
+
def translate_sentence(self, sentence, device=None, k=None, max_len=None, debug=False):
|
196 |
+
"""
|
197 |
+
Receive a sentence string and output the prediction generated from the model.
|
198 |
+
NOTE: sentence input is a list of tokens instead of string due to change in loader. See the current DefaultLoader for further details
|
199 |
+
"""
|
200 |
+
self.eval()
|
201 |
+
if(device is None): device = self.config.get('device', const.DEFAULT_DEVICE)
|
202 |
+
if(k is None): k = self.config.get('k', const.DEFAULT_K)
|
203 |
+
if(max_len is None): max_len = self.config.get('max_length', const.DEFAULT_MAX_LENGTH)
|
204 |
+
|
205 |
+
# Get output from decode
|
206 |
+
translated_tokens = translate_sentence(sentence, self, self.SRC, self.TRG, device, k, max_len, debug=debug, output_list_of_tokens=True)
|
207 |
+
|
208 |
+
return translated_tokens
|
209 |
+
|
210 |
+
def translate_batch_sentence(self, sentences, src_lang=None, trg_lang=None, output_tokens=False, batch_size=None):
|
211 |
+
"""Translate sentences by splitting them to batches and process them simultaneously
|
212 |
+
Args:
|
213 |
+
sentences: the sentences in a list. Must NOT have been tokenized (due to SRC preprocess)
|
214 |
+
output_tokens: if set, do not detokenize the output
|
215 |
+
batch_size: if specified, use the value; else use config ones
|
216 |
+
Returns:
|
217 |
+
a matching translated sentences list in [detokenized format using loader.detokenize | list of tokens]
|
218 |
+
"""
|
219 |
+
if(batch_size is None):
|
220 |
+
batch_size = self.config.get("eval_batch_size", const.DEFAULT_EVAL_BATCH_SIZE)
|
221 |
+
input_max_length = self.config.get("input_max_length", const.DEFAULT_INPUT_MAX_LENGTH)
|
222 |
+
self.eval()
|
223 |
+
|
224 |
+
translated = []
|
225 |
+
for b_idx in range(0, len(sentences), batch_size):
|
226 |
+
batch = sentences[b_idx: b_idx+batch_size]
|
227 |
+
# raise Exception(batch)
|
228 |
+
trans_batch = self.translate_batch(batch, trg_lang=trg_lang, output_tokens=output_tokens, input_max_length=input_max_length)
|
229 |
+
# raise Exception(detokenized_batch)
|
230 |
+
translated.extend(trans_batch)
|
231 |
+
# for line in trans_batch:
|
232 |
+
# print(line)
|
233 |
+
return translated
|
234 |
+
|
235 |
+
def translate_batch(self, batch_sentences, src_lang=None, trg_lang=None, output_tokens=False, input_max_length=None):
|
236 |
+
"""Translate a single batch of sentences. Split to aid serving
|
237 |
+
Args:
|
238 |
+
sentences: the sentences in a list. Must NOT have been tokenized (due to SRC preprocess)
|
239 |
+
src_lang/trg_lang: the language from src->trg. Used for multilingual models only.
|
240 |
+
output_tokens: if set, do not detokenize the output
|
241 |
+
Returns:
|
242 |
+
a matching translated sentences list in [detokenized format using loader.detokenize | list of tokens]
|
243 |
+
"""
|
244 |
+
if(input_max_length is None):
|
245 |
+
input_max_length = self.config.get("input_max_length", const.DEFAULT_INPUT_MAX_LENGTH)
|
246 |
+
translated_batch = self.decode_strategy.translate_batch(batch_sentences, trg_lang=trg_lang, src_size_limit=input_max_length, output_tokens=True, debug=False)
|
247 |
+
return self.loader.detokenize(translated_batch) if not output_tokens else translated_batch
|
248 |
+
|
249 |
+
def run_train(self, model_dir=None, config=None):
|
250 |
+
opt = self.config
|
251 |
+
from utils.logging import init_logger
|
252 |
+
logging = init_logger(model_dir, opt.get('log_file_models'))
|
253 |
+
|
254 |
+
trg_pad = self.TRG.vocab.stoi['<pad>']
|
255 |
+
# load model into specific device (GPU/CPU) memory
|
256 |
+
logging.info("%s * src vocab size = %s"%(self.loader._language_tuple[0] ,len(self.SRC.vocab)))
|
257 |
+
logging.info("%s * tgt vocab size = %s"%(self.loader._language_tuple[1] ,len(self.TRG.vocab)))
|
258 |
+
logging.info("Building model...")
|
259 |
+
model = self.to(opt.get('device', const.DEFAULT_DEVICE))
|
260 |
+
|
261 |
+
checkpoint_idx = self._checkpoint_idx
|
262 |
+
if(checkpoint_idx < 0):
|
263 |
+
# initialize weights
|
264 |
+
print("Zero checkpoint detected, reinitialize the model")
|
265 |
+
for p in model.parameters():
|
266 |
+
if p.dim() > 1:
|
267 |
+
nn.init.xavier_uniform_(p)
|
268 |
+
checkpoint_idx = 0
|
269 |
+
|
270 |
+
# also, load the scores of the best model
|
271 |
+
best_model_score = saver.load_model_score(model_dir)
|
272 |
+
|
273 |
+
# set up optimizer
|
274 |
+
optim_algo = opt["optimizer"]
|
275 |
+
lr = opt["lr"]
|
276 |
+
d_model = opt["d_model"]
|
277 |
+
n_warmup_steps = opt["n_warmup_steps"]
|
278 |
+
optimizer_params = opt.get("optimizer_params", dict({}))
|
279 |
+
|
280 |
+
if optim_algo not in optimizers:
|
281 |
+
raise ValueError("Unknown optimizer: {}".format(optim_algo))
|
282 |
+
|
283 |
+
optimizer = ScheduledOptim(
|
284 |
+
optimizer=optimizers.get(optim_algo)(model.parameters(), **optimizer_params),
|
285 |
+
init_lr=lr,
|
286 |
+
d_model=d_model,
|
287 |
+
n_warmup_steps=n_warmup_steps
|
288 |
+
)
|
289 |
+
|
290 |
+
# define loss function
|
291 |
+
criterion = LabelSmoothingLoss(len(self.TRG.vocab), padding_idx=trg_pad, smoothing=opt['label_smoothing'])
|
292 |
+
|
293 |
+
# valid_src_data, valid_trg_data = self.loader._eval_data
|
294 |
+
# raise Exception("Initial bleu: %.3f" % bleu_batch_iter(self, self.valid_iter, debug=True))
|
295 |
+
logging.info(self)
|
296 |
+
model_encoder_parameters = filter(lambda p: p.requires_grad, self.encoder.parameters())
|
297 |
+
model_decoder_parameters = filter(lambda p: p.requires_grad, self.decoder.parameters())
|
298 |
+
params_encode = sum([np.prod(p.size()) for p in model_encoder_parameters])
|
299 |
+
params_decode = sum([np.prod(p.size()) for p in model_decoder_parameters])
|
300 |
+
|
301 |
+
logging.info("Encoder: %s"%(params_encode))
|
302 |
+
logging.info("Decoder: %s"%(params_decode))
|
303 |
+
logging.info("* Number of parameters: %s"%(params_encode+params_decode))
|
304 |
+
logging.info("Starting training on %s"%(opt.get('device', const.DEFAULT_DEVICE)))
|
305 |
+
|
306 |
+
for epoch in range(checkpoint_idx, opt['epochs']):
|
307 |
+
total_loss = 0.0
|
308 |
+
|
309 |
+
s = time.time()
|
310 |
+
for i, batch in enumerate(self.train_iter):
|
311 |
+
loss = self.train_step(optimizer, batch, criterion)
|
312 |
+
total_loss += loss
|
313 |
+
|
314 |
+
# print training loss after every {printevery} steps
|
315 |
+
if (i + 1) % opt['printevery'] == 0:
|
316 |
+
avg_loss = total_loss / opt['printevery']
|
317 |
+
et = time.time() - s
|
318 |
+
# print('epoch: {:03d} - iter: {:05d} - train loss: {:.4f} - time elapsed/per batch: {:.4f} {:.4f}'.format(epoch, i+1, avg_loss, et, et / opt['printevery']))
|
319 |
+
logging.info('epoch: {:03d} - iter: {:05d} - train loss: {:.4f} - time elapsed/per batch: {:.4f} {:.4f}'.format(epoch, i+1, avg_loss, et, et / opt['printevery']))
|
320 |
+
total_loss = 0
|
321 |
+
s = time.time()
|
322 |
+
|
323 |
+
# bleu calculation and evaluate, save checkpoint for every {save_checkpoint_epochs} epochs
|
324 |
+
s = time.time()
|
325 |
+
valid_loss = self.validate(self.valid_iter, criterion, maximum_length=(self.encoder._max_seq_length, self.decoder._max_seq_length))
|
326 |
+
if (epoch+1) % opt['save_checkpoint_epochs'] == 0 and model_dir is not None:
|
327 |
+
|
328 |
+
# evaluate loss and bleu score on validation dataset for each epoch
|
329 |
+
# bleuscore = bleu(valid_src_data, valid_trg_data, model, opt.get('device', const.DEFAULT_DEVICE), opt['k'], opt['max_strlen'])
|
330 |
+
# bleuscore = bleu_single(self, self.loader._eval_data)
|
331 |
+
# bleuscore = bleu_batch(self, self.loader._eval_data, batch_size=opt.get('eval_batch_size', const.DEFAULT_EVAL_BATCH_SIZE))
|
332 |
+
valid_src_lang, valid_trg_lang = self.loader.language_tuple
|
333 |
+
bleuscore = bleu_batch_iter(self, self.valid_iter, src_lang=valid_src_lang, trg_lang=valid_trg_lang)
|
334 |
+
|
335 |
+
# save_model_to_path(model, model_dir, checkpoint_idx=epoch+1)
|
336 |
+
saver.save_and_clear_model(model, model_dir, checkpoint_idx=epoch+1, maximum_saved_model=opt.get('maximum_saved_model_train', const.DEFAULT_NUM_KEEP_MODEL_TRAIN))
|
337 |
+
# keep the best models per every bleu calculation
|
338 |
+
best_model_score = saver.save_model_best_to_path(model, model_dir, best_model_score, bleuscore, maximum_saved_model=opt.get('maximum_saved_model_eval', const.DEFAULT_NUM_KEEP_MODEL_TRAIN))
|
339 |
+
# print('epoch: {:03d} - iter: {:05d} - valid loss: {:.4f} - bleu score: {:.4f} - full evaluation time: {:.4f}'.format(epoch, i, valid_loss, bleuscore, time.time() - s))
|
340 |
+
logging.info('epoch: {:03d} - iter: {:05d} - valid loss: {:.4f} - bleu score: {:.4f} - full evaluation time: {:.4f}'.format(epoch, i, valid_loss, bleuscore, time.time() - s))
|
341 |
+
else:
|
342 |
+
# print('epoch: {:03d} - iter: {:05d} - valid loss: {:.4f} - validation time: {:.4f}'.format(epoch, i, valid_loss, time.time() - s))
|
343 |
+
logging.info('epoch: {:03d} - iter: {:05d} - valid loss: {:.4f} - validation time: {:.4f}'.format(epoch, i, valid_loss, time.time() - s))
|
344 |
+
|
345 |
+
def run_infer(self, features_file, predictions_file, src_lang=None, trg_lang=None, config=None, batch_size=None):
|
346 |
+
opt = self.config
|
347 |
+
# load model into specific device (GPU/CPU) memory
|
348 |
+
model = self.to(opt.get('device', const.DEFAULT_DEVICE))
|
349 |
+
|
350 |
+
# Read inference file
|
351 |
+
print("Reading features file from {}...".format(features_file))
|
352 |
+
with io.open(features_file, "r", encoding="utf-8") as read_file:
|
353 |
+
inputs = [l.strip() for l in read_file.readlines()]
|
354 |
+
|
355 |
+
print("Performing inference ...")
|
356 |
+
# Append each translated sentence line by line
|
357 |
+
# results = "\n".join([model.loader.detokenize(model.translate_sentence(sentence)) for sentence in inputs])
|
358 |
+
# Translate by batched versions
|
359 |
+
start = time.time()
|
360 |
+
results = "\n".join( self.translate_batch_sentence(inputs, src_lang=src_lang, trg_lang=trg_lang, output_tokens=False, batch_size=batch_size))
|
361 |
+
print("Inference done, cost {:.2f} secs.".format(time.time() - start))
|
362 |
+
|
363 |
+
# Write results to system file
|
364 |
+
print("Writing results to {} ...".format(predictions_file))
|
365 |
+
with io.open(predictions_file, "w", encoding="utf-8") as write_file:
|
366 |
+
write_file.write(results)
|
367 |
+
|
368 |
+
print("All done!")
|
369 |
+
|
370 |
+
def encode(self, *args, **kwargs):
|
371 |
+
return self.encoder(*args, **kwargs)
|
372 |
+
|
373 |
+
def decode(self, *args, **kwargs):
|
374 |
+
return self.decoder(*args, **kwargs)
|
375 |
+
|
376 |
+
def to_logits(self, inputs): # function to include the logits. TODO use this in inference fns as well
|
377 |
+
return self.out(inputs)
|
378 |
+
|
379 |
+
def prepare_serve(self, serve_path, model_dir=None, check_trace=True, **kwargs):
|
380 |
+
self.eval()
|
381 |
+
"""Run to prepare for serving."""
|
382 |
+
saver.save_model_name(type(self).__name__, model_dir)
|
383 |
+
# return
|
384 |
+
# raise NotImplementedError("trace_module currently not supported")
|
385 |
+
# jit to convert model to ScriptModule.
|
386 |
+
# create junk arguments for necessary modules
|
387 |
+
fake_batch, fake_srclen, fake_trglen, fake_range = 3, 7, 4, 1000
|
388 |
+
sample_src, sample_trg = torch.randint(fake_range, (fake_batch, fake_srclen), dtype=torch.long), torch.randint(fake_range, (fake_batch, fake_trglen), dtype=torch.long)
|
389 |
+
sample_src_mask, sample_trg_mask = torch.rand(fake_batch, 1, fake_srclen) > 0.5, torch.rand(fake_batch, fake_trglen, fake_trglen) > 0.5
|
390 |
+
sample_src, sample_trg, sample_src_mask, sample_trg_mask = [t.to(self.device) for t in [sample_src, sample_trg, sample_src_mask, sample_trg_mask]]
|
391 |
+
sample_encoded = self.encode(sample_src, sample_src_mask)
|
392 |
+
sample_before_logits = self.decode(sample_trg, sample_encoded, sample_src_mask, sample_trg_mask)
|
393 |
+
# bundle within dictionary
|
394 |
+
needed_fn = {'forward': (sample_src, sample_trg, sample_src_mask, sample_trg_mask), "encode": (sample_src, sample_src_mask), "decode": (sample_trg, sample_encoded, sample_src_mask, sample_trg_mask), "to_logits": sample_before_logits}
|
395 |
+
# create the ScriptModule. Currently disabling deterministic check
|
396 |
+
traced_model = torch.jit.trace_module(self, needed_fn, check_trace=check_trace)
|
397 |
+
# save it down
|
398 |
+
torch.jit.save(traced_model, serve_path)
|
399 |
+
return serve_path
|
400 |
+
|
401 |
+
|
402 |
+
@property
|
403 |
+
def fields(self):
|
404 |
+
return (self.SRC, self.TRG)
|
modules/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from modules.default import *
|
2 |
+
from modules.prototypes import Decoder, Encoder
|
3 |
+
from modules.config import Config
|
modules/config.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import yaml, json
|
2 |
+
import os, io
|
3 |
+
|
4 |
+
def extension_check(pth):
|
5 |
+
ext = os.path.splitext(pth)[-1]
|
6 |
+
return any( ext == valid_ext for valid_ext in [".json", ".yaml", ".yml"])
|
7 |
+
|
8 |
+
def find_all_config(directory):
|
9 |
+
return [os.path.join(directory, f) for f in os.listdir(directory) if extension_check(f)]
|
10 |
+
|
11 |
+
class Config(dict):
|
12 |
+
def __init__(self, path=None, **elements):
|
13 |
+
"""Initiate a config object, where specified elements override the default config loaded"""
|
14 |
+
super(Config, self).__init__(self._try_load_path(path))
|
15 |
+
self.update(**elements)
|
16 |
+
|
17 |
+
def _load_json(self, json_path):
|
18 |
+
with io.open(json_path, "r", encoding="utf-8") as jf:
|
19 |
+
return json.load(jf)
|
20 |
+
|
21 |
+
def _load_yaml(self, yaml_path):
|
22 |
+
with io.open(yaml_path, "r", encoding="utf-8") as yf:
|
23 |
+
return yaml.safe_load(yf.read())
|
24 |
+
|
25 |
+
def _try_load_path(self, path):
|
26 |
+
assert isinstance(path, str), "Basic Config class can only support a single file path (str), but instead is {}({})".format(path, type(path))
|
27 |
+
assert os.path.isfile(path), "Config file {:s} does not exist".format(path)
|
28 |
+
extension = os.path.splitext(path)[-1]
|
29 |
+
if(extension == ".json"):
|
30 |
+
return self._load_json(path)
|
31 |
+
elif(extension == ".yml" or extension == ".yaml"):
|
32 |
+
return self._load_yaml(path)
|
33 |
+
else:
|
34 |
+
raise ValueError("Unrecognized extension ({:s}) from file {:s}".format(extension, path))
|
35 |
+
|
36 |
+
@property
|
37 |
+
def opt(self):
|
38 |
+
"""Backward compatibility to original. Remove once finished."""
|
39 |
+
return self
|
40 |
+
|
41 |
+
class MultiplePathConfig(Config):
|
42 |
+
def _try_load_path(self, paths):
|
43 |
+
"""Update to support multiple paths."""
|
44 |
+
if(isinstance(paths, list)):
|
45 |
+
print("Loaded path is a list of locations. Load in the order received, overriding and merging as needed.")
|
46 |
+
result = {}
|
47 |
+
for pth in paths:
|
48 |
+
self._recursive_update(result, super(MultiplePathConfig, self)._try_load_path(pth))
|
49 |
+
return result
|
50 |
+
else:
|
51 |
+
return super(MultiplePathConfig, self)._try_load_path(paths)
|
52 |
+
|
53 |
+
def _recursive_update(self, orig, new):
|
54 |
+
"""Instead of overriding dicts, merge them recursively."""
|
55 |
+
# print(orig, new)
|
56 |
+
for k, v in new.items():
|
57 |
+
if(k in orig and isinstance(orig[k], dict)):
|
58 |
+
assert isinstance(v, dict), "Mismatching config with key {}: {} - {}".format(k, orig[k], v)
|
59 |
+
orig[k] = self._recursive_update(orig[k], v)
|
60 |
+
else:
|
61 |
+
orig[k] = v;
|
62 |
+
return orig
|
modules/constants.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# DESIGNATE constants values for config
|
2 |
+
DEFAULT_DECODE_STRATEGY = "BeamSearch"
|
3 |
+
DEFAULT_STRATEGY_KWARGS = {}
|
4 |
+
DEFAULT_SEED = 101
|
5 |
+
DEFAULT_BATCH_SIZE = 64
|
6 |
+
DEFAULT_EVAL_BATCH_SIZE = 8
|
7 |
+
DEFAULT_TRAIN_TEST_SPLIT = 0.8
|
8 |
+
DEFAULT_DEVICE = "cpu"
|
9 |
+
DEFAULT_K = 5
|
10 |
+
DEFAULT_INPUT_MAX_LENGTH = 200
|
11 |
+
DEFAULT_MAX_LENGTH = 150
|
12 |
+
DEFAULT_TRAIN_MAX_LENGTH = 100
|
13 |
+
DEFAULT_LOWERCASE = True
|
14 |
+
DEFAULT_NUM_KEEP_MODEL_TRAIN = 5
|
15 |
+
DEFAULT_NUM_KEEP_MODEL_BEST = 5
|
16 |
+
DEFAULT_SOS = "<sos>"
|
17 |
+
DEFAULT_EOS = "<eos>"
|
18 |
+
DEFAULT_PAD = "<pad>"
|
modules/default.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class MockLoader:
|
2 |
+
def __init__(self, *args, **kwargs):
|
3 |
+
"""Only print stuff"""
|
4 |
+
print("MockLoader initialized, args/kwargs {} {}".format(args, kwargs))
|
5 |
+
|
6 |
+
def tokenize(self, inputs, **kwargs):
|
7 |
+
print("MockLoader tokenize called, inputs/kwargs {} {}".format(inputs, kwargs))
|
8 |
+
return inputs
|
9 |
+
|
10 |
+
def detokenize(self, inputs, **kwargs):
|
11 |
+
print("MockLoader detokenize called, inputs/kwargs {} {}".format(inputs, kwargs))
|
12 |
+
return inputs
|
13 |
+
|
14 |
+
def reverse_lookup(self, inputs, **kwargs):
|
15 |
+
print("MockLoader reverse_lookup called, inputs/kwargs {} {}".format(inputs, kwargs))
|
16 |
+
return inputs
|
17 |
+
|
18 |
+
def lookup(self, inputs, **kwargs):
|
19 |
+
print("MockLoader lookup called, inputs/kwargs {} {}".format(inputs, kwargs))
|
20 |
+
return inputs
|
21 |
+
|
22 |
+
def embed(self, inputs, **kwargs):
|
23 |
+
print("MockLoader embed called, inputs/kwargs {} {}".format(inputs, kwargs))
|
24 |
+
return inputs
|
25 |
+
|
26 |
+
class MockEncoder:
|
27 |
+
def __init__(self, *args, **kwargs):
|
28 |
+
"""Only print stuff"""
|
29 |
+
print("MockEncoder initialized, args/kwargs {} {}".format(args, kwargs))
|
30 |
+
|
31 |
+
def encode(self, inputs, **kwargs):
|
32 |
+
print("MockEncoder encode called, inputs/kwargs {} {}".format(inputs, kwargs))
|
33 |
+
return inputs
|
34 |
+
|
35 |
+
def __call__(self, inputs, num_layers=3, **kwargs):
|
36 |
+
print("MockEncoder __call__ called, inputs/num_layers/kwargs {} {} {}".format(inputs, num_layers, kwargs))
|
37 |
+
for i in range(num_layers):
|
38 |
+
inputs = encode(inputs, **kwargs)
|
39 |
+
return inputs
|
40 |
+
|
41 |
+
class MockDecoder:
|
42 |
+
def __init__(self, *args, **kwargs):
|
43 |
+
"""Only print stuff"""
|
44 |
+
print("MockDecoder initialized, args/kwargs {} {}".format(args, kwargs))
|
45 |
+
|
46 |
+
def decode(self, inputs, **kwargs):
|
47 |
+
print("MockDecoder decode called, inputs/kwargs {} {}".format(inputs, kwargs))
|
48 |
+
return inputs
|
49 |
+
|
50 |
+
def __call__(self, inputs, num_layers=3, **kwargs):
|
51 |
+
print("MockDecoder __call__ called, inputs/num_layers/kwargs {} {} {}".format(inputs, num_layers, kwargs))
|
52 |
+
for i in range(num_layers):
|
53 |
+
inputs = decode(inputs, **kwargs)
|
54 |
+
return inputs
|
modules/inference/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.inference.decode_strategy import DecodeStrategy
|
2 |
+
from modules.inference.beam_search import BeamSearch
|
3 |
+
from modules.inference.prototypes import BeamSearch2
|
4 |
+
from modules.inference.sampling_temperature import GreedySearch
|
5 |
+
|
6 |
+
strategies = {
|
7 |
+
"BeamSearch": BeamSearch,
|
8 |
+
"BeamSearch2": BeamSearch2,
|
9 |
+
"GreedySearch": GreedySearch
|
10 |
+
}
|
modules/inference/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (483 Bytes). View file
|
|
modules/inference/__pycache__/beam_search.cpython-36.pyc
ADDED
Binary file (15.3 kB). View file
|
|
modules/inference/__pycache__/decode_strategy.cpython-36.pyc
ADDED
Binary file (3.26 kB). View file
|
|
modules/inference/__pycache__/prototypes.cpython-36.pyc
ADDED
Binary file (4.73 kB). View file
|
|
modules/inference/__pycache__/sampling_temperature.cpython-36.pyc
ADDED
Binary file (4.42 kB). View file
|
|
modules/inference/beam_search.py
ADDED
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import math, time, operator
|
4 |
+
import torch.nn.functional as functional
|
5 |
+
import torch.nn as nn
|
6 |
+
import logging
|
7 |
+
from torch.autograd import Variable
|
8 |
+
from torch.nn.utils.rnn import pad_sequence
|
9 |
+
|
10 |
+
from modules.inference.decode_strategy import DecodeStrategy
|
11 |
+
import modules.constants as const
|
12 |
+
from utils.misc import no_peeking_mask
|
13 |
+
from utils.data import generate_language_token
|
14 |
+
|
15 |
+
class BeamSearch(DecodeStrategy):
|
16 |
+
def __init__(self, model, max_len, device, beam_size=5, use_synonym_fn=False, replace_unk=None, length_normalize=None):
|
17 |
+
"""
|
18 |
+
Args:
|
19 |
+
model: the used model
|
20 |
+
max_len: the maximum timestep to be used
|
21 |
+
device: the device to perform calculation
|
22 |
+
beam_size: the size of the beam itself
|
23 |
+
use_synonym_fn: if set, use the get_synonym fn from wordnet to try replace <unk>
|
24 |
+
replace_unk: a tuple of [layer, head] designation, to replace the unknown word by chosen attention
|
25 |
+
"""
|
26 |
+
super(BeamSearch, self).__init__(model, max_len, device)
|
27 |
+
self.beam_size = beam_size
|
28 |
+
self._use_synonym = use_synonym_fn
|
29 |
+
self._replace_unk = replace_unk
|
30 |
+
self._length_norm = length_normalize
|
31 |
+
|
32 |
+
def init_vars(self, src, start_token=const.DEFAULT_SOS):
|
33 |
+
"""
|
34 |
+
Calculate the required matrices during translation after the model is finished
|
35 |
+
Input:
|
36 |
+
:param src: The batch of sentences
|
37 |
+
|
38 |
+
Output: Initialize the first character includes outputs, e_outputs, log_scores
|
39 |
+
"""
|
40 |
+
model = self.model
|
41 |
+
batch_size = len(src)
|
42 |
+
row_b = self.beam_size * batch_size
|
43 |
+
|
44 |
+
init_tok = self.TRG.vocab.stoi[start_token]
|
45 |
+
src_mask = (src != self.SRC.vocab.stoi['<pad>']).unsqueeze(-2).to(self.device)
|
46 |
+
src = src.to(self.device)
|
47 |
+
|
48 |
+
# Encoder
|
49 |
+
# raise Exception(src.shape, src_mask.shape)
|
50 |
+
e_output = model.encode(src, src_mask)
|
51 |
+
outputs = torch.LongTensor([[init_tok] for i in range(batch_size)])
|
52 |
+
outputs = outputs.to(self.device)
|
53 |
+
trg_mask = no_peeking_mask(1, self.device)
|
54 |
+
|
55 |
+
# Decoder
|
56 |
+
out = model.to_logits(model.decode(outputs, e_output, src_mask, trg_mask))
|
57 |
+
out = functional.softmax(out, dim=-1)
|
58 |
+
probs, ix = out[:, -1].data.topk(self.beam_size)
|
59 |
+
|
60 |
+
log_scores = torch.Tensor([math.log(p) for p in probs.data.view(-1)]).view(-1, 1)
|
61 |
+
|
62 |
+
outputs = torch.zeros(row_b, self.max_len).long()
|
63 |
+
outputs = outputs.to(self.device)
|
64 |
+
outputs[:, 0] = init_tok
|
65 |
+
outputs[:, 1] = ix.view(-1)
|
66 |
+
|
67 |
+
e_outputs = torch.repeat_interleave(e_output, self.beam_size, 0)
|
68 |
+
|
69 |
+
# raise Exception(outputs[:, :2], e_outputs)
|
70 |
+
|
71 |
+
return outputs, e_outputs, log_scores
|
72 |
+
|
73 |
+
def compute_k_best(self, outputs, out, log_scores, i, debug=False):
|
74 |
+
"""
|
75 |
+
Compute k words with the highest conditional probability
|
76 |
+
Args:
|
77 |
+
outputs: Array has k previous candidate output sequences. [batch_size*beam_size, max_len]
|
78 |
+
i: the current timestep to execute. Int
|
79 |
+
out: current output of the model at timestep. [batch_size*beam_size, vocab_size]
|
80 |
+
log_scores: Conditional probability of past candidates (in outputs) [batch_size * beam_size]
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
new outputs has k best candidate output sequences
|
84 |
+
log_scores for each of those candidate
|
85 |
+
"""
|
86 |
+
row_b = len(out); batch_size = row_b // self.beam_size
|
87 |
+
eos_id = self.TRG.vocab.stoi['<eos>']
|
88 |
+
|
89 |
+
probs, ix = out[:, -1].data.topk(self.beam_size)
|
90 |
+
|
91 |
+
probs_rep = torch.Tensor([[1] + [1e-100] * (self.beam_size-1)]*row_b).view(row_b, self.beam_size).to(self.device)
|
92 |
+
ix_rep = torch.LongTensor([[eos_id] + [-1]*(self.beam_size-1)]*row_b).view(row_b, self.beam_size).to(self.device)
|
93 |
+
|
94 |
+
check_eos = torch.repeat_interleave((outputs[:, i-1] == eos_id).view(row_b, 1), self.beam_size, 1)
|
95 |
+
|
96 |
+
probs = torch.where(check_eos, probs_rep, probs)
|
97 |
+
ix = torch.where(check_eos, ix_rep, ix)
|
98 |
+
|
99 |
+
# if(debug):
|
100 |
+
# print("kprobs before debug: ", probs, probs_rep, ix, ix_rep, log_scores)
|
101 |
+
|
102 |
+
log_probs = torch.log(probs).to(self.device) + log_scores.to(self.device) # CPU
|
103 |
+
|
104 |
+
k_probs, k_ix = log_probs.view(batch_size, -1).topk(self.beam_size)
|
105 |
+
if(debug):
|
106 |
+
print("kprobs_after_select: ", log_probs, k_probs, k_ix)
|
107 |
+
|
108 |
+
# Use cpu
|
109 |
+
k_probs, k_ix = torch.Tensor(k_probs.cpu().data.numpy()), torch.LongTensor(k_ix.cpu().data.numpy())
|
110 |
+
row = k_ix // self.beam_size + torch.LongTensor([[v*self.beam_size] for v in range(batch_size)])
|
111 |
+
col = k_ix % self.beam_size
|
112 |
+
if(debug):
|
113 |
+
print("kprobs row/col", row, col, ix[row.view(-1), col.view(-1)])
|
114 |
+
assert False
|
115 |
+
|
116 |
+
outputs[:, :i] = outputs[row.view(-1), :i]
|
117 |
+
outputs[:, i] = ix[row.view(-1), col.view(-1)]
|
118 |
+
log_scores = k_probs.view(-1, 1)
|
119 |
+
|
120 |
+
return outputs, log_scores
|
121 |
+
|
122 |
+
def replace_unknown(self, outputs, sentences, attn, selector_tuple, unknown_token="<unk>"):
|
123 |
+
"""Replace the unknown words in the outputs with the highest valued attentionized words.
|
124 |
+
Args:
|
125 |
+
outputs: the output from decoding. [batch, beam] of list of str
|
126 |
+
sentences: the original wordings of the sentences. [batch_size, src_len] of str
|
127 |
+
attn: the attention received, in the form of list: [layers units of (self-attention, attention) with shapes of [batchbeam, heads, tgt_len, tgt_len] & [batchbeam, heads, tgt_len, src_len] respectively]
|
128 |
+
selector_tuple: (layer, head) used to select the attention
|
129 |
+
unknown_token: token used for checking. str
|
130 |
+
Returns:
|
131 |
+
the replaced version, in the same shape as outputs
|
132 |
+
"""
|
133 |
+
|
134 |
+
# is_finished = torch.LongTensor([[self.TRG.vocab.stoi['<eos>']] for i in range(self.beam_offset)]).view(-1).to(self.device)
|
135 |
+
# unk_token = self.SRC.vocab.stoi['<unk>']
|
136 |
+
layer_used, head_used = selector_tuple
|
137 |
+
used_attention = attn[layer_used][-1][:, head_used] # it should be [batchbeam, tgt_len, src_len], as we are using the attention to source
|
138 |
+
flattened_outputs = outputs.reshape((-1, )) # flatten the outputs back to batchbeam
|
139 |
+
|
140 |
+
select_id_src = torch.argmax(used_attention, dim=-1).cpu().numpy() # [batchbeam, tgt_len] of best indices. Also convert to numpy version (remove sos not needed as it is attention of outputs)
|
141 |
+
beam_size = select_id_src.shape[0] // len(sentences) # used custom-calculated beam_size as we might not output the entirety of beams. See beam_search fn for details
|
142 |
+
# select per batchbeam. source batch id is found by dividing batchbeam id per beam; we are selecting [tgt_len] indices from [src_len] tokens; then concat at the first dimensions to retrieve [batch_beam, tgt_len] of replacement tokens
|
143 |
+
# need itemgetter / map to retrieve from list
|
144 |
+
replace_tokens = [ operator.itemgetter(*src_idx)(sentences[bidx // beam_size]) for bidx, src_idx in enumerate(select_id_src)]
|
145 |
+
|
146 |
+
# zip together with sentences; then output { the token if not unk / the replacement if is }. Note that this will trim the orig version down to repl size.
|
147 |
+
zipped = zip(flattened_outputs, replace_tokens)
|
148 |
+
replaced = np.array([[tok if tok != unknown_token else rpl for rpl, tok in zip(repl, orig)] for orig, repl in zipped], dtype=object)
|
149 |
+
# reshape back to outputs shape [batch, beam] of list
|
150 |
+
return replaced.reshape(outputs.shape)
|
151 |
+
|
152 |
+
# for i in range(1, self.max_len):
|
153 |
+
# ix = attn[0, 0, i-1, :].argmax().data
|
154 |
+
# outputs[:, i][outputs[:, i] == unk_token] = sentences[0][ix.data]
|
155 |
+
# if torch.equal(outputs[:, i], is_finished):
|
156 |
+
# break
|
157 |
+
#
|
158 |
+
# return outputs
|
159 |
+
|
160 |
+
def beam_search(self, src, src_lang=None, trg_lang=None, src_tokens=None, n_best=1, length_norm=None, replace_unk=None, debug=False):
|
161 |
+
"""
|
162 |
+
Beam search select k words with the highest conditional probability
|
163 |
+
to be the first word of the k candidate output sequences.
|
164 |
+
Args:
|
165 |
+
src: The batch of sentences, already in [batch_size, tokens] of int
|
166 |
+
src_tokens: src in str version, same size as above. Used almost exclusively for replace unknown word
|
167 |
+
n_best: number of usable values per beam loaded
|
168 |
+
length_norm: if specified, normalize as per (Wu, 2016); note that if not inputted then it will still use __init__ value as default. float
|
169 |
+
replace_unk: if specified, do replace unknown word using attention of (layer, head); note that if not inputted, it will still use __init__ value as default. (int, int)
|
170 |
+
debug: if true, print some debug information during the search
|
171 |
+
Return:
|
172 |
+
An array of translated sentences, in list-of-tokens format.
|
173 |
+
Either [batch_size, n_best, tgt_len] when n_best > 1
|
174 |
+
Or [batch_size, tgt_len] when n_best == 1
|
175 |
+
"""
|
176 |
+
model = self.model
|
177 |
+
start_token = const.DEFAULT_SOS if trg_lang is None else generate_language_token(trg_lang)
|
178 |
+
outputs, e_outputs, log_scores = self.init_vars(src, start_token=start_token)
|
179 |
+
|
180 |
+
eos_tok = self.TRG.vocab.stoi[const.DEFAULT_EOS]
|
181 |
+
src_mask = (src != self.SRC.vocab.stoi[const.DEFAULT_PAD]).unsqueeze(-2)
|
182 |
+
src_mask = torch.repeat_interleave(src_mask, self.beam_size, 0).to(self.device)
|
183 |
+
is_finished = torch.LongTensor([[eos_tok] for i in range(self.beam_size*len(src))]).view(-1).to(self.device)
|
184 |
+
ind = None
|
185 |
+
for i in range(2, self.max_len):
|
186 |
+
trg_mask = no_peeking_mask(i, self.device)
|
187 |
+
|
188 |
+
decoder_output, attn = model.decoder(outputs[:, :i], e_outputs, src_mask, trg_mask, output_attention=True)
|
189 |
+
out = model.out(decoder_output)
|
190 |
+
out = functional.softmax(out, dim=-1)
|
191 |
+
outputs, log_scores = self.compute_k_best(outputs, out, log_scores, i)
|
192 |
+
|
193 |
+
# Occurrences of end symbols for all input sentences.
|
194 |
+
if torch.equal(outputs[:, i], is_finished):
|
195 |
+
break
|
196 |
+
|
197 |
+
|
198 |
+
# if(self._replace_unk):
|
199 |
+
# outputs = self.replace_unknown(attn, src, outputs)
|
200 |
+
|
201 |
+
# reshape outputs and log_probs to [batch, beam] numpy array
|
202 |
+
batch_size = src.shape[0]
|
203 |
+
outputs = outputs.cpu().numpy().reshape((batch_size, self.beam_size, self.max_len))
|
204 |
+
log_scores = log_scores.cpu().numpy().reshape((batch_size, self.beam_size))
|
205 |
+
|
206 |
+
# Get the best sentences for every beam: splice by length and itos the indices, result in an array of tokens
|
207 |
+
# also remove the first token in this timestep (as it is sos)
|
208 |
+
translated_sentences = np.empty(outputs.shape[:-1], dtype=object)
|
209 |
+
trim_and_itos = lambda sent: [self.TRG.vocab.itos[i] for i in sent[1:self._length(sent, eos_tok=eos_tok)]]
|
210 |
+
for ba in range(outputs.shape[0]):
|
211 |
+
for bm in range(outputs.shape[1]):
|
212 |
+
translated_sentences[ba, bm] = trim_and_itos(outputs[ba, bm])
|
213 |
+
# raise ValueError(translated_sentences)
|
214 |
+
#translated_sentences = np.apply_along_axis(lambda sent: tuple(sent.tolist()[:self._length(sent, eos_tok=eos_tok)]), -1, outputs)
|
215 |
+
#translated_sentences = np.vectorize(lambda sent: [self.TRG.vocab.itos[i] for i in sent])(translated_sentences)
|
216 |
+
if(replace_unk is None):
|
217 |
+
replace_unk = self._replace_unk
|
218 |
+
if(replace_unk):
|
219 |
+
# replace unknown words per translated sentences. Do it before normalization (since that is independent on actual tokens)
|
220 |
+
if(src_tokens is None):
|
221 |
+
logging.warn("replace_unknown option enabled but no src_tokens supplied for the task. The method will not run.")
|
222 |
+
else:
|
223 |
+
translated_sentences = self.replace_unknown(translated_sentences, src_tokens, attn, replace_unk)
|
224 |
+
|
225 |
+
if(length_norm is None):
|
226 |
+
length_norm = self._length_norm
|
227 |
+
if(length_norm is not None):
|
228 |
+
# raise ValueError(length_norm)
|
229 |
+
# perform length normalization calculation and reorganize the sentences accordingly
|
230 |
+
lengths = np.apply_along_axis(lambda x: self._length(x, eos_tok=eos_tok), -1, outputs)
|
231 |
+
log_scores, indices = self.length_normalize(lengths, log_scores, coff=length_norm)
|
232 |
+
translated_sentences = np.array([beams[ids] for beams, ids in zip(translated_sentences, indices)])
|
233 |
+
# outputs = np.array([beams[ids] for beams, ids in zip(outputs, indices)])
|
234 |
+
|
235 |
+
# assert n_best == 1, "Currently unsupported n_best > 1. TODO write."
|
236 |
+
if(n_best == 1):
|
237 |
+
return translated_sentences[:, 0]
|
238 |
+
else:
|
239 |
+
return translated_sentences[:, :n_best]
|
240 |
+
|
241 |
+
def translate_single_sentence(self, src, **kwargs):
|
242 |
+
"""Translate a single sentence. Currently unused."""
|
243 |
+
raise NotImplementedError
|
244 |
+
return self.translate_batch_sentence([src], **kwargs)
|
245 |
+
|
246 |
+
def translate_batch_sentence(self, src, src_lang=None, trg_lang=None, field_processed=False, src_size_limit=None, output_tokens=False, replace_unk=None, debug=False):
|
247 |
+
"""Translate a batch of sentences together. Currently disabling the synonym func.
|
248 |
+
Args:
|
249 |
+
src: the batch of sentences to be translated. list of str
|
250 |
+
src_lang: the language translated from. Only used with multilingual models, in preprocess. str
|
251 |
+
trg_lang: the language to be translated to. Only used with multilingual models, in beam_search. str
|
252 |
+
field_processed: bool, if the sentences had been already processed (i.e part of batched validation data)
|
253 |
+
src_size_limit: if set, trim the input if it cross this value. Added due to current positional encoding support only <=200 tokens
|
254 |
+
output_tokens: the output format. False will give a batch of sentences (str), while True will give batch of tokens (list of str)
|
255 |
+
replace_unk: see beam_search for usage. (int, int) or False to suppress __init__ value
|
256 |
+
debug: enable to print external values
|
257 |
+
Return:
|
258 |
+
the result of translation, with format dictated by output_tokens
|
259 |
+
"""
|
260 |
+
self.model.eval()
|
261 |
+
# create the indiced batch.
|
262 |
+
processed_batch = self.preprocess_batch(src, src_lang=src_lang, field_processed=field_processed, src_size_limit=src_size_limit, output_tokens=True, debug=debug)
|
263 |
+
sent_ids, sent_tokens = (processed_batch, None) if(field_processed) else processed_batch
|
264 |
+
assert isinstance(sent_ids, torch.Tensor), "sent_ids is instead {}".format(type(sent_ids))
|
265 |
+
|
266 |
+
batch_start = time.time()
|
267 |
+
translated_sentences = self.beam_search(sent_ids, trg_lang=trg_lang, src_tokens=sent_tokens, replace_unk=replace_unk, debug=debug)
|
268 |
+
if(debug):
|
269 |
+
print("Time performed for batch {}: {:.2f}s".format(sent_ids.shape, time.time() - batch_start))
|
270 |
+
|
271 |
+
if(not output_tokens):
|
272 |
+
translated_sentences = [' '.join(tokens) for tokens in translated_sentences]
|
273 |
+
|
274 |
+
return translated_sentences
|
275 |
+
|
276 |
+
def preprocess_batch(self, sentences, src_lang=None, field_processed=False, pad_token="<pad>", src_size_limit=None, output_tokens=False, debug=True):
|
277 |
+
"""Adding
|
278 |
+
src_size_limit: int, option to limit the length of src.
|
279 |
+
src_lang: if specified (not None), append this token <{src_lang}> to the start of the batch
|
280 |
+
field_processed: bool: if the sentences had been already processed (i.e part of batched validation data)
|
281 |
+
output_tokens: if set, output a token version aside the id version, in [batch of [src_len]] str. Note that it won't work with field_processed
|
282 |
+
"""
|
283 |
+
if(field_processed):
|
284 |
+
# do nothing, as it had already performed tokenizing/stoi.
|
285 |
+
# Still cap the length of the batch due to possible infraction in valid
|
286 |
+
if(src_size_limit is not None):
|
287 |
+
sentences = sentences[:, :src_size_limit]
|
288 |
+
return sentences
|
289 |
+
processed_sent = map(self.SRC.preprocess, sentences)
|
290 |
+
if(src_lang is not None):
|
291 |
+
src_token = generate_language_token(src_lang)
|
292 |
+
processed_sent = map(lambda x: [src_token] + x, processed_sent)
|
293 |
+
if(src_size_limit):
|
294 |
+
processed_sent = map(lambda x: x[:src_size_limit], processed_sent)
|
295 |
+
processed_sent = list(processed_sent)
|
296 |
+
tokenized_sent = [torch.LongTensor([self._token_to_index(t) for t in s]) for s in processed_sent] # convert to tensors, in indices format
|
297 |
+
sentences = Variable(pad_sequence(tokenized_sent, True, padding_value=self.SRC.vocab.stoi[pad_token])) # padding sentences
|
298 |
+
if(debug):
|
299 |
+
print("Input batch after process: ", sentences.shape, sentences)
|
300 |
+
|
301 |
+
if(output_tokens):
|
302 |
+
return sentences, processed_sent
|
303 |
+
else:
|
304 |
+
return sentences
|
305 |
+
|
306 |
+
def translate_batch(self, sentences, **kwargs):
|
307 |
+
return self.translate_batch_sentence(sentences, **kwargs)
|
308 |
+
|
309 |
+
def length_normalize(self, lengths, log_probs, coff=0.6):
|
310 |
+
"""Normalize the probabilty score as in (Wu 2016). Use pure numpy values
|
311 |
+
Args:
|
312 |
+
lengths: the length of the hypothesis. [batch, beam] of int->float
|
313 |
+
log_probs: the unchanged log probability for the whole hypothesis. [batch, beam] of float
|
314 |
+
coff: the alpha coefficient.
|
315 |
+
Returns:
|
316 |
+
Tuple of (penalized_values, indices) to reorganize outputs."""
|
317 |
+
lengths = ((lengths + 5) / 6) ** coff
|
318 |
+
penalized_probs = log_probs / lengths
|
319 |
+
indices = np.argsort(penalized_probs, axis=-1)[::-1]
|
320 |
+
# basically take log_probs values for every batch
|
321 |
+
reorganized_probs = np.array([prb[ids] for prb, ids in zip(penalized_probs, indices)])
|
322 |
+
return reorganized_probs, indices
|
323 |
+
|
324 |
+
def _length(self, tokens, eos_tok=None):
|
325 |
+
"""Retrieve the first location of eos_tok as length; else return the entire length"""
|
326 |
+
if(eos_tok is None):
|
327 |
+
eos_tok = self.TRG.vocab.stoi[const.DEFAULT_EOS]
|
328 |
+
eos, = np.nonzero(tokens==eos_tok)
|
329 |
+
return len(tokens) if len(eos) == 0 else eos[0]
|
330 |
+
|
331 |
+
def _token_to_index(self, tok):
|
332 |
+
"""Override to select, depending on the self._use_synonym param"""
|
333 |
+
if(self._use_synonym):
|
334 |
+
return super(BeamSearch, self)._token_to_index(tok)
|
335 |
+
else:
|
336 |
+
return self.SRC.vocab.stoi[tok]
|
modules/inference/beam_search1.py
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import math, time, operator
|
4 |
+
import torch.nn.functional as functional
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.autograd import Variable
|
7 |
+
from torch.nn.utils.rnn import pad_sequence
|
8 |
+
|
9 |
+
from modules.inference.decode_strategy import DecodeStrategy
|
10 |
+
from utils.misc import no_peeking_mask
|
11 |
+
|
12 |
+
class BeamSearch1(DecodeStrategy):
|
13 |
+
def __init__(self, model, max_len, device, beam_size=5, use_synonym_fn=False, replace_unk=None):
|
14 |
+
"""
|
15 |
+
Args:
|
16 |
+
model: the used model
|
17 |
+
max_len: the maximum timestep to be used
|
18 |
+
device: the device to perform calculation
|
19 |
+
beam_size: the size of the beam itself
|
20 |
+
use_synonym_fn: if set, use the get_synonym fn from wordnet to try replace <unk>
|
21 |
+
replace_unk: a tuple of [layer, head] designation, to replace the unknown word by chosen attention
|
22 |
+
"""
|
23 |
+
super(BeamSearch1, self).__init__(model, max_len, device)
|
24 |
+
self.beam_size = beam_size
|
25 |
+
self._use_synonym = use_synonym_fn
|
26 |
+
self._replace_unk = replace_unk
|
27 |
+
# print("Init BeamSearch ----------------")
|
28 |
+
|
29 |
+
def trg_init_vars(self, src, batch_size, trg_init_token, trg_eos_token, single_src_mask):
|
30 |
+
"""
|
31 |
+
Calculate the required matrices during translation after the model is finished
|
32 |
+
Input:
|
33 |
+
:param src: The batch of sentences
|
34 |
+
|
35 |
+
Output: Initialize the first character includes outputs, e_outputs, log_scores
|
36 |
+
"""
|
37 |
+
# Initialize target sequence (start with '<sos>' token) [batch_size x k x max_len]
|
38 |
+
trg = torch.zeros(batch_size, self.beam_size, self.max_len, device=self.device).long()
|
39 |
+
trg[:, :, 0] = trg_init_token
|
40 |
+
|
41 |
+
# Precalc output from model's encoder
|
42 |
+
e_out = self.model.encoder(src, single_src_mask) # [batch_size x S x d_model]
|
43 |
+
# Output model prob
|
44 |
+
trg_mask = no_peeking_mask(1, device=self.device)
|
45 |
+
# [batch_size x 1]
|
46 |
+
inp_decoder = trg[:, 0, 0].view(batch_size, 1)
|
47 |
+
# [batch_size x 1 x vocab_size]
|
48 |
+
prob = self.model.out(self.model.decoder(inp_decoder, e_out, single_src_mask, trg_mask))
|
49 |
+
prob = functional.softmax(prob, dim=-1)
|
50 |
+
|
51 |
+
# [batch_size x 1 x k]
|
52 |
+
k_prob, k_index = torch.topk(prob, self.beam_size, dim=-1)
|
53 |
+
trg[:, :, 1] = k_index.view(batch_size, self.beam_size)
|
54 |
+
# Init log scores from k beams [batch_size x k x 1]
|
55 |
+
log_scores = torch.log(k_prob.view(batch_size, self.beam_size, 1))
|
56 |
+
|
57 |
+
# Repeat encoder's output k times for searching [(k * batch_size) x S x d_model]
|
58 |
+
e_outs = torch.repeat_interleave(e_out, self.beam_size, dim=0)
|
59 |
+
src_mask = torch.repeat_interleave(single_src_mask, self.beam_size, dim=0)
|
60 |
+
|
61 |
+
# Create mask for checking eos
|
62 |
+
sent_eos = torch.tensor([trg_eos_token for _ in range(self.beam_size)], device=self.device).view(1, self.beam_size)
|
63 |
+
|
64 |
+
return sent_eos, log_scores, e_outs, e_out, src_mask, trg
|
65 |
+
|
66 |
+
def compute_k_best(self, outputs, out, log_scores, i, debug=False):
|
67 |
+
"""
|
68 |
+
Compute k words with the highest conditional probability
|
69 |
+
Args:
|
70 |
+
outputs: Array has k previous candidate output sequences. [batch_size*beam_size, max_len]
|
71 |
+
i: the current timestep to execute. Int
|
72 |
+
out: current output of the model at timestep. [batch_size*beam_size, vocab_size]
|
73 |
+
log_scores: Conditional probability of past candidates (in outputs) [batch_size * beam_size]
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
new outputs has k best candidate output sequences
|
77 |
+
log_scores for each of those candidate
|
78 |
+
"""
|
79 |
+
row_b = len(out);
|
80 |
+
batch_size = row_b // self.beam_size
|
81 |
+
eos_id = self.TRG.vocab.stoi['<eos>']
|
82 |
+
|
83 |
+
probs, ix = out[:, -1].data.topk(self.beam_size)
|
84 |
+
|
85 |
+
probs_rep = torch.Tensor([[1] + [1e-100] * (self.beam_size-1)]*row_b).view(row_b, self.beam_size).to(self.device)
|
86 |
+
ix_rep = torch.LongTensor([[eos_id] + [-1]*(self.beam_size-1)]*row_b).view(row_b, self.beam_size).to(self.device)
|
87 |
+
|
88 |
+
check_eos = torch.repeat_interleave((outputs[:, i-1] == eos_id).view(row_b, 1), self.beam_size, 1)
|
89 |
+
|
90 |
+
probs = torch.where(check_eos, probs_rep, probs)
|
91 |
+
ix = torch.where(check_eos, ix_rep, ix)
|
92 |
+
|
93 |
+
log_probs = torch.log(probs).to(self.device) + log_scores.to(self.device) # CPU
|
94 |
+
|
95 |
+
k_probs, k_ix = log_probs.view(batch_size, -1).topk(self.beam_size)
|
96 |
+
if(debug):
|
97 |
+
print("kprobs_after_select: ", log_probs, k_probs, k_ix)
|
98 |
+
|
99 |
+
# Use cpu
|
100 |
+
k_probs, k_ix = torch.Tensor(k_probs.cpu().data.numpy()), torch.LongTensor(k_ix.cpu().data.numpy())
|
101 |
+
row = k_ix // self.beam_size + torch.LongTensor([[v*self.beam_size] for v in range(batch_size)])
|
102 |
+
col = k_ix % self.beam_size
|
103 |
+
if(debug):
|
104 |
+
print("kprobs row/col", row, col, ix[row.view(-1), col.view(-1)])
|
105 |
+
assert False
|
106 |
+
|
107 |
+
outputs[:, :i] = outputs[row.view(-1), :i]
|
108 |
+
outputs[:, i] = ix[row.view(-1), col.view(-1)]
|
109 |
+
log_scores = k_probs.view(-1, 1)
|
110 |
+
|
111 |
+
return outputs, log_scores
|
112 |
+
|
113 |
+
def replace_unknown(self, outputs, sentences, attn, selector_tuple, unknown_token="<unk>"):
|
114 |
+
"""Replace the unknown words in the outputs with the highest valued attentionized words.
|
115 |
+
Args:
|
116 |
+
outputs: the output from decoding. [batchbeam] of list of str, with maximum values being
|
117 |
+
sentences: the original wordings of the sentences. [batch_size, src_len] of str
|
118 |
+
attn: the attention received, in the form of list: [layers units of (self-attention, attention) with shapes of [batchbeam, heads, tgt_len, tgt_len] & [batchbeam, heads, tgt_len, src_len] respectively]
|
119 |
+
selector_tuple: (layer, head) used to select the attention
|
120 |
+
unknown_token: token used for
|
121 |
+
Returns:
|
122 |
+
the replaced version, in the same shape as outputs
|
123 |
+
"""
|
124 |
+
layer_used, head_used = selector_tuple
|
125 |
+
# used_attention = attn[layer_used][-1][:, head_used] # it should be [batchbeam, tgt_len, src_len], as we are using the attention to source
|
126 |
+
inx = torch.arange(start=0,end=len(attn)-1, step=self.beam_size)
|
127 |
+
used_attention = attn[inx]
|
128 |
+
select_id_src = torch.argmax(used_attention, dim=-1).cpu().numpy() # [batchbeam, tgt_len] of best indices. Also convert to numpy version (remove sos not needed as it is attention of outputs)
|
129 |
+
# print(select_id_src, len(select_id_src))
|
130 |
+
beam_size = select_id_src.shape[0] // len(sentences) # used custom-calculated beam_size as we might not output the entirety of beams. See beam_search fn for details
|
131 |
+
# print("beam: ", beam_size)
|
132 |
+
# select per batchbeam. source batch id is found by dividing batchbeam id per beam; we are selecting [tgt_len] indices from [src_len] tokens; then concat at the first dimensions to retrieve [batch_beam, tgt_len] of replacement tokens
|
133 |
+
# need itemgetter / map to retrieve from list
|
134 |
+
# print([ operator.itemgetter(*src_idx)(sentences[bidx // beam_size]) for bidx, src_idx in enumerate(select_id_src)])
|
135 |
+
# print([print(sentences[bidx // beam_size], src_idx) for bidx, src_idx in enumerate(select_id_src)])
|
136 |
+
# replace_tokens = [ operator.itemgetter(*src_idx)(sentences[bidx // beam_size]) for bidx, src_idx in enumerate(select_id_src)]
|
137 |
+
|
138 |
+
for i in range(len(outputs)):
|
139 |
+
for j in range(len(outputs[i])):
|
140 |
+
if outputs[i][j] == unknown_token:
|
141 |
+
outputs[i][j] = sentences[i][select_id_src[i][j]]
|
142 |
+
|
143 |
+
# print(sentences[0][0], outputs[0][0])
|
144 |
+
|
145 |
+
# print(i)
|
146 |
+
# zip together with sentences; then output { the token if not unk / the replacement if is }. Note that this will trim the orig version down to repl size.
|
147 |
+
# replaced = [ [tok if tok != unknown_token else rpl for rpl, tok in zip(repl, orig)] for orig, repl in zipped ]
|
148 |
+
|
149 |
+
# return replaced
|
150 |
+
return outputs
|
151 |
+
|
152 |
+
# def beam_search(self, src, max_len, device, k=4):
|
153 |
+
def beam_search(self, src, src_tokens=None, n_best=1, debug=False):
|
154 |
+
"""
|
155 |
+
Beam search for a single sentence
|
156 |
+
Args:
|
157 |
+
model : a Transformer instance
|
158 |
+
src : a batch (tokenized + numerized) sentence [batch_size x S]
|
159 |
+
Returns:
|
160 |
+
trg : a batch (tokenized + numerized) sentence [batch_size x T]
|
161 |
+
"""
|
162 |
+
src = src.to(self.device)
|
163 |
+
trg_init_token = self.TRG.vocab.stoi["<sos>"]
|
164 |
+
trg_eos_token = self.TRG.vocab.stoi["<eos>"]
|
165 |
+
single_src_mask = (src != self.SRC.vocab.stoi['<pad>']).unsqueeze(1).to(self.device)
|
166 |
+
batch_size = src.size(0)
|
167 |
+
|
168 |
+
sent_eos, log_scores, e_outs, e_out, src_mask, trg = self.trg_init_vars(src, batch_size, trg_init_token, trg_eos_token, single_src_mask)
|
169 |
+
|
170 |
+
# The batch indexes
|
171 |
+
batch_index = torch.arange(batch_size)
|
172 |
+
finished_batches = torch.zeros(batch_size, device=self.device).long()
|
173 |
+
|
174 |
+
log_attn = torch.zeros([self.beam_size*batch_size, self.max_len, len(src[0])])
|
175 |
+
|
176 |
+
# Iteratively searching
|
177 |
+
for i in range(2, self.max_len):
|
178 |
+
trg_mask = no_peeking_mask(i, self.device)
|
179 |
+
|
180 |
+
# Flatten trg tensor for feeding into model [(k * batch_size) x i]
|
181 |
+
inp_decoder = trg[batch_index, :, :i].view(self.beam_size * len(batch_index), i)
|
182 |
+
# Output model prob [(k * batch_size) x i x vocab_size]
|
183 |
+
current_decode, attn = self.model.decoder(inp_decoder, e_outs, src_mask, trg_mask, output_attention=True)
|
184 |
+
# print(len(attn[0]))
|
185 |
+
|
186 |
+
prob = self.model.out(current_decode)
|
187 |
+
prob = functional.softmax(prob, dim=-1)
|
188 |
+
|
189 |
+
# Only care the last prob i-th
|
190 |
+
# [(k * batch_size) x 1 x vocab_size]
|
191 |
+
prob = prob[:, i-1, :].view(self.beam_size * len(batch_index), 1, -1)
|
192 |
+
|
193 |
+
# Truncate prob to top k [(k * batch_size) x 1 x k]
|
194 |
+
k_prob, k_index = prob.data.topk(self.beam_size, dim=-1)
|
195 |
+
|
196 |
+
# Deflatten k_prob & k_index
|
197 |
+
k_prob = k_prob.view(len(batch_index), self.beam_size, 1, self.beam_size)
|
198 |
+
k_index = k_index.view(len(batch_index), self.beam_size, 1, self.beam_size)
|
199 |
+
|
200 |
+
# Preserve eos beams
|
201 |
+
# [batch_size x k] -> view -> [batch_size x k x 1 x 1] (broadcastable)
|
202 |
+
eos_mask = (trg[batch_index, :, i-1] == trg_eos_token).view(len(batch_index), self.beam_size, 1, 1)
|
203 |
+
k_prob.masked_fill_(eos_mask, 1.0)
|
204 |
+
k_index.masked_fill_(eos_mask, trg_eos_token)
|
205 |
+
|
206 |
+
# Find the best k cases
|
207 |
+
# Compute log score at i-th timestep
|
208 |
+
# [batch_size x k x 1 x 1] + [batch_size x k x 1 x k] = [batch_size x k x 1 x k]
|
209 |
+
combine_probs = log_scores[batch_index].unsqueeze(-1) + torch.log(k_prob)
|
210 |
+
# [batch_size x k x 1]
|
211 |
+
log_scores[batch_index], positions = torch.topk(combine_probs.view(len(batch_index), self.beam_size * self.beam_size, 1), self.beam_size, dim=1)
|
212 |
+
|
213 |
+
# The rows selected from top k
|
214 |
+
rows = positions.view(len(batch_index), self.beam_size) // self.beam_size
|
215 |
+
# The indexes in vocab respected to these rows
|
216 |
+
cols = positions.view(len(batch_index), self.beam_size) % self.beam_size
|
217 |
+
|
218 |
+
batch_sim = torch.arange(len(batch_index)).view(-1, 1)
|
219 |
+
trg[batch_index, :, :] = trg[batch_index.view(-1, 1), rows, :]
|
220 |
+
trg[batch_index, :, i] = k_index[batch_sim, rows, :, cols].view(len(batch_index), self.beam_size)
|
221 |
+
|
222 |
+
# Update attn
|
223 |
+
inx = torch.repeat_interleave(finished_batches, self.beam_size, dim=0)
|
224 |
+
batch_attn = torch.nonzero(inx == 0).view(-1)
|
225 |
+
# import copy
|
226 |
+
# x = copy.deepcopy(attn[0][-1][:, 0].to("cpu"))
|
227 |
+
# log_attn[batch_attn, :i, :] = x
|
228 |
+
|
229 |
+
# if i == 7:
|
230 |
+
# print(log_attn[batch_attn, :i, :].shape, attn[0][-1][:, 0].shape)
|
231 |
+
# print(log_attn[batch_attn, :i, :])
|
232 |
+
# Update which sentences finished all its beams
|
233 |
+
mask = (trg[:, :, i] == sent_eos).all(1).view(-1).to(self.device)
|
234 |
+
finished_batches.masked_fill_(mask, value=1)
|
235 |
+
cnt = torch.sum(finished_batches).item()
|
236 |
+
if cnt == batch_size:
|
237 |
+
break
|
238 |
+
|
239 |
+
# # Continue with remaining batches (if any)
|
240 |
+
batch_index = torch.nonzero(finished_batches == 0).view(-1)
|
241 |
+
e_outs = torch.repeat_interleave(e_out[batch_index], self.beam_size, dim=0)
|
242 |
+
src_mask = torch.repeat_interleave(single_src_mask[batch_index], self.beam_size, dim=0)
|
243 |
+
# End loop
|
244 |
+
|
245 |
+
# Get the best beam
|
246 |
+
log_scores = log_scores.view(batch_size, self.beam_size)
|
247 |
+
results = []
|
248 |
+
for t, j in enumerate(torch.argmax(log_scores, dim=-1)):
|
249 |
+
sent = []
|
250 |
+
for i in range(self.max_len):
|
251 |
+
token_id = trg[t, j.item(), i].item()
|
252 |
+
if token_id == trg_init_token:
|
253 |
+
continue
|
254 |
+
if token_id == trg_eos_token:
|
255 |
+
break
|
256 |
+
sent.append(self.TRG.vocab.itos[token_id])
|
257 |
+
results.append(sent)
|
258 |
+
|
259 |
+
# if(self._replace_unk and src_tokens is not None):
|
260 |
+
# # replace unknown words per translated sentences.
|
261 |
+
# # NOTE: lacking a src_tokens does not raise any warning. Add that in when logging module is available, to support error catching
|
262 |
+
# # print("Replace unk -----------------------")
|
263 |
+
# results = self.replace_unknown(results, src_tokens, log_attn, self._replace_unk)
|
264 |
+
|
265 |
+
return results
|
266 |
+
|
267 |
+
def translate_single_sentence(self, src, **kwargs):
|
268 |
+
"""Translate a single sentence. Currently unused."""
|
269 |
+
raise NotImplementedError
|
270 |
+
return self.translate_batch_sentence([src], **kwargs)
|
271 |
+
|
272 |
+
def translate_batch_sentence(self, src, field_processed=False, src_size_limit=None, output_tokens=False, debug=False):
|
273 |
+
"""Translate a batch of sentences together. Currently disabling the synonym func.
|
274 |
+
Args:
|
275 |
+
src: the batch of sentences to be translated
|
276 |
+
field_processed: bool, if the sentences had been already processed (i.e part of batched validation data)
|
277 |
+
src_size_limit: if set, trim the input if it cross this value. Added due to current positional encoding support only <=200 tokens
|
278 |
+
output_tokens: the output format. False will give a batch of sentences (str), while True will give batch of tokens (list of str)
|
279 |
+
debug: enable to print external values
|
280 |
+
Return:
|
281 |
+
the result of translation, with format dictated by output_tokens
|
282 |
+
"""
|
283 |
+
# start = time.time()
|
284 |
+
|
285 |
+
self.model.eval()
|
286 |
+
# create the indiced batch.
|
287 |
+
processed_batch = self.preprocess_batch(src, field_processed=field_processed, src_size_limit=src_size_limit, output_tokens=True, debug=debug)
|
288 |
+
# print("Time preprocess_batch: ", time.time()-start)
|
289 |
+
|
290 |
+
sent_ids, sent_tokens = (processed_batch, None) if(field_processed) else processed_batch
|
291 |
+
assert isinstance(sent_ids, torch.Tensor), "sent_ids is instead {}".format(type(sent_ids))
|
292 |
+
|
293 |
+
translated_sentences = self.beam_search(sent_ids, src_tokens=sent_tokens, debug=debug)
|
294 |
+
|
295 |
+
# print("Time for one batch: ", time.time()-batch_start)
|
296 |
+
|
297 |
+
# if time.time()-batch_start > 2:
|
298 |
+
# [print("len src >2 : ++++++", len(i.split())) for i in src]
|
299 |
+
# [print("len translate >2: ++++++", len(i)) for i in translated_sentences]
|
300 |
+
# else:
|
301 |
+
# [print("len src : ====", len(i.split())) for i in src]
|
302 |
+
# [print("len translate : ====", len(i)) for i in translated_sentences]
|
303 |
+
# print("=====================================")
|
304 |
+
|
305 |
+
# time.sleep(4)
|
306 |
+
if(debug):
|
307 |
+
print("Time performed for batch {}: {:.2f}s".format(sent_ids.shape))
|
308 |
+
|
309 |
+
if(not output_tokens):
|
310 |
+
translated_sentences = [' '.join(tokens) for tokens in translated_sentences]
|
311 |
+
|
312 |
+
return translated_sentences
|
313 |
+
|
314 |
+
def preprocess_batch(self, sentences, field_processed=False, pad_token="<pad>", src_size_limit=None, output_tokens=False, debug=True):
|
315 |
+
"""Adding
|
316 |
+
src_size_limit: int, option to limit the length of src.
|
317 |
+
field_processed: bool: if the sentences had been already processed (i.e part of batched validation data)
|
318 |
+
output_tokens: if set, output a token version aside the id version, in [batch of [src_len]] str. Note that it won't work with field_processed
|
319 |
+
"""
|
320 |
+
|
321 |
+
if(field_processed):
|
322 |
+
# do nothing, as it had already performed tokenizing/stoi
|
323 |
+
return sentences
|
324 |
+
processed_sent = map(self.SRC.preprocess, sentences)
|
325 |
+
if(src_size_limit):
|
326 |
+
processed_sent = map(lambda x: x[:src_size_limit], processed_sent)
|
327 |
+
processed_sent = list(processed_sent)
|
328 |
+
tokenized_sent = [torch.LongTensor([self._token_to_index(t) for t in s]) for s in processed_sent] # convert to tensors, in indices format
|
329 |
+
sentences = Variable(pad_sequence(tokenized_sent, True, padding_value=self.SRC.vocab.stoi[pad_token])) # padding sentences
|
330 |
+
if(debug):
|
331 |
+
print("Input batch after process: ", sentences.shape, sentences)
|
332 |
+
|
333 |
+
if(output_tokens):
|
334 |
+
return sentences, processed_sent
|
335 |
+
else:
|
336 |
+
return sentences
|
337 |
+
|
338 |
+
def translate_batch(self, sentences, **kwargs):
|
339 |
+
return self.translate_batch_sentence(sentences, **kwargs)
|
340 |
+
|
341 |
+
def _token_to_index(self, tok):
|
342 |
+
"""Override to select, depending on the self._use_synonym param"""
|
343 |
+
if(self._use_synonym):
|
344 |
+
return super(BeamSearch1, self)._token_to_index(tok)
|
345 |
+
else:
|
346 |
+
return self.SRC.vocab.stoi[tok]
|
modules/inference/decode_strategy.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.autograd import Variable
|
3 |
+
from utils.data import get_synonym
|
4 |
+
from torch.nn.utils.rnn import pad_sequence
|
5 |
+
import abc
|
6 |
+
|
7 |
+
class DecodeStrategy(object):
|
8 |
+
"""
|
9 |
+
Base, abstract class for generation strategies. Contain specific call to base model that use it
|
10 |
+
|
11 |
+
"""
|
12 |
+
def __init__(self, model, max_len, device):
|
13 |
+
self.model = model
|
14 |
+
self.max_len = max_len
|
15 |
+
self.device = device
|
16 |
+
|
17 |
+
@property
|
18 |
+
def SRC(self):
|
19 |
+
return self.model.SRC
|
20 |
+
|
21 |
+
@property
|
22 |
+
def TRG(self):
|
23 |
+
return self.model.TRG
|
24 |
+
|
25 |
+
@abc.abstractmethod
|
26 |
+
def translate_single(self, src_lang, trg_lang, sentences):
|
27 |
+
"""Translate a single sentence. Might be useful as backcompatibility"""
|
28 |
+
raise NotImplementedError
|
29 |
+
|
30 |
+
@abc.abstractmethod
|
31 |
+
def translate_batch(self, src_lang, trg_lang, sentences):
|
32 |
+
"""Translate a batch of sentences.
|
33 |
+
Args:
|
34 |
+
sentences: The sentences, formatted as [batch_size] Tensor of str
|
35 |
+
Returns:
|
36 |
+
The detokenized output, most commonly [batch_size] of str
|
37 |
+
"""
|
38 |
+
|
39 |
+
raise NotImplementedError
|
40 |
+
|
41 |
+
@abc.abstractmethod
|
42 |
+
def replace_unknown(self, *args):
|
43 |
+
"""Replace unknown words from batched sentences"""
|
44 |
+
raise NotImplementedError
|
45 |
+
|
46 |
+
def preprocess_batch(self, lang, sentences, pad_token="<pad>"):
|
47 |
+
"""Feed a unprocessed batch into the torchtext.Field of source.
|
48 |
+
Args:
|
49 |
+
sentences: [batch_size] of str
|
50 |
+
pad_token: the pad token used to pad the sentences
|
51 |
+
Returns:
|
52 |
+
the sentences in Tensor format, padded with pad_value"""
|
53 |
+
processed_sent = list(map(self.SRC.preprocess, sentences)) # tokenizing
|
54 |
+
tokenized_sent = [Torch.LongTensor([self._token_to_index(t) for t in s]) for s in processed_sent] # convert to tensors and indices
|
55 |
+
sentences = Variable(pad_sequence(tokenized_sent, True, padding_value=self.SRC.vocab.stoi[pad_token])) # padding sentences
|
56 |
+
return sentences
|
57 |
+
|
58 |
+
def _token_to_index(self, tok):
|
59 |
+
"""Implementing get_synonym as default. Override if want to use default behavior (<unk> for unknown words, independent of wordnet)"""
|
60 |
+
if self.SRC.vocab.stoi[tok] != self.SRC.vocab.stoi['<eos>']:
|
61 |
+
return self.SRC.vocab.stoi[tok]
|
62 |
+
return get_synonym(tok, self.SRC)
|
modules/inference/greedy_search.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
##@title Beam của mình
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import math
|
5 |
+
import torch.nn.functional as functional
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch.autograd import Variable
|
8 |
+
|
9 |
+
from modules.inference.decode_strategy import DecodeStrategy
|
10 |
+
from utils.misc import no_peeking_mask
|
11 |
+
|
12 |
+
class GreedySearch(DecodeStrategy):
|
13 |
+
def __init__(self, model, max_len, device, replace_unk=None):
|
14 |
+
"""
|
15 |
+
:param beam_size
|
16 |
+
:param batch_size
|
17 |
+
:param beam_offset
|
18 |
+
"""
|
19 |
+
super(GreedySearch, self).__init__(model, max_len, device)
|
20 |
+
# self.replace_unk = replace_unk
|
21 |
+
# raise NotImplementedError("Replace unk was yeeted from base class DecodeStrategy. Fix first.")
|
22 |
+
|
23 |
+
def initilize_value(self, sentences):
|
24 |
+
"""
|
25 |
+
Calculate the required matrices during translation after the model is finished
|
26 |
+
Input:
|
27 |
+
:param src: Sentences
|
28 |
+
|
29 |
+
Output: Initialize the first character includes outputs, e_outputs, log_scores
|
30 |
+
"""
|
31 |
+
batch_size=len(sentences)
|
32 |
+
init_tok = self.TRG.vocab.stoi['<sos>']
|
33 |
+
src_mask = (sentences != self.SRC.vocab.stoi['<pad>']).unsqueeze(-2)
|
34 |
+
eos_tok = self.TRG.vocab.stoi['<eos>']
|
35 |
+
|
36 |
+
# Encoder
|
37 |
+
e_output = self.model.encoder(sentences, src_mask)
|
38 |
+
|
39 |
+
out = torch.LongTensor([[init_tok] for i in range(batch_size)])
|
40 |
+
outputs = torch.zeros(batch_size, self.max_len).long()
|
41 |
+
outputs[:, :1] = out
|
42 |
+
|
43 |
+
outputs = outputs.to(self.device)
|
44 |
+
is_finished = torch.LongTensor([[eos_tok] for i in range(batch_size)]).view(-1).to(self.device)
|
45 |
+
return eos_tok, src_mask, is_finished, e_output, outputs
|
46 |
+
|
47 |
+
def create_trg_mask(self, i, device):
|
48 |
+
return no_peeking_mask(i, device)
|
49 |
+
|
50 |
+
def current_predict(self, outputs, e_output, src_mask, trg_mask):
|
51 |
+
model = self.model
|
52 |
+
# out, attn = model.out(model.decoder(outputs, e_output, src_mask, trg_mask))
|
53 |
+
decoder_output, attn = model.decoder(outputs, e_output, src_mask, trg_mask, output_attention=True)
|
54 |
+
# total_time_decode += time.time()-decode_time
|
55 |
+
out = model.out(decoder_output)
|
56 |
+
|
57 |
+
out = functional.softmax(out, dim=-1)
|
58 |
+
return out, attn
|
59 |
+
|
60 |
+
def greedy_search(self, sentences, sampling_temp=0.0, keep_topk=1):
|
61 |
+
batch_size=len(sentences)
|
62 |
+
|
63 |
+
eos_tok, src_mask, is_finished, e_output, outputs = self.initilize_value(sentences)
|
64 |
+
|
65 |
+
for i in range(1, self.max_len):
|
66 |
+
out, attn = self.current_predict(outputs[:, :i], e_output, src_mask, self.create_trg_mask(i, self.device))
|
67 |
+
topk_ix, topk_prob = self.sample_with_temperature(out[:, -1], sampling_temp, keep_topk)
|
68 |
+
outputs[:, i] = topk_ix.view(-1)
|
69 |
+
if torch.equal(outputs[:, i], is_finished):
|
70 |
+
break
|
71 |
+
|
72 |
+
# if self.replace_unk == True:
|
73 |
+
# outputs = self.replace_unknown(attn, sentences, outputs)
|
74 |
+
|
75 |
+
# print("\n".join([' '.join([self.TRG.vocab.itos[tok] for tok in line[1:]]) for line in outputs]))
|
76 |
+
# Write to file or Print to the console
|
77 |
+
translated_sentences = []
|
78 |
+
# Get the best sentences: idx = 0 + i*k
|
79 |
+
for i in range(0, len(outputs)):
|
80 |
+
is_eos = torch.nonzero(outputs[i]==eos_tok)
|
81 |
+
if len(is_eos) == 0:
|
82 |
+
# if there is no sequence end, remove
|
83 |
+
sent = outputs[i, 1:]
|
84 |
+
else:
|
85 |
+
length = is_eos[0]
|
86 |
+
sent = outputs[i, 1:length]
|
87 |
+
translated_sentences.append([self.TRG.vocab.itos[tok] for tok in sent])
|
88 |
+
|
89 |
+
return translated_sentences
|
90 |
+
|
91 |
+
def sample_with_temperature(self, logits, sampling_temp, keep_topk):
|
92 |
+
if sampling_temp == 0.0 or keep_topk == 1:
|
93 |
+
# For temp=0.0, take the argmax to avoid divide-by-zero errors.
|
94 |
+
# keep_topk=1 is also equivalent to argmax.
|
95 |
+
topk_scores, topk_ids = logits.topk(1, dim=-1)
|
96 |
+
if sampling_temp > 0:
|
97 |
+
topk_scores /= sampling_temp
|
98 |
+
else:
|
99 |
+
logits = torch.div(logits, sampling_temp)
|
100 |
+
|
101 |
+
if keep_topk > 0:
|
102 |
+
top_values, top_indices = torch.topk(logits, keep_topk, dim=1)
|
103 |
+
kth_best = top_values[:, -1].view([-1, 1])
|
104 |
+
kth_best = kth_best.repeat([1, logits.shape[1]]).float()
|
105 |
+
|
106 |
+
# Set all logits that are not in the top-k to -10000.
|
107 |
+
# This puts the probabilities close to 0.
|
108 |
+
ignore = torch.lt(logits, kth_best)
|
109 |
+
logits = logits.masked_fill(ignore, -10000)
|
110 |
+
|
111 |
+
dist = torch.distributions.Multinomial(
|
112 |
+
logits=logits, total_count=1)
|
113 |
+
topk_ids = torch.argmax(dist.sample(), dim=1, keepdim=True)
|
114 |
+
topk_scores = logits.gather(dim=1, index=topk_ids)
|
115 |
+
return topk_ids, topk_scores
|
116 |
+
|
117 |
+
def translate_batch(self, sentences, src_size_limit, output_tokens=True, debug=False):
|
118 |
+
# super(BeamSearch, self).__init__()
|
119 |
+
sentences = self.preprocess_batch(sentences).to(self.device)
|
120 |
+
return self.greedy_search(sentences, 0.2, 2)
|
121 |
+
# print(self.initilize_value(sentences))
|
modules/inference/prototypes.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch, time
|
2 |
+
import torch.nn.functional as functional
|
3 |
+
from torch.autograd import Variable
|
4 |
+
from torch.nn.utils.rnn import pad_sequence
|
5 |
+
|
6 |
+
from modules.inference.beam_search import BeamSearch
|
7 |
+
from utils.data import generate_language_token
|
8 |
+
import modules.constants as const
|
9 |
+
|
10 |
+
def generate_subsequent_mask(sz, device):
|
11 |
+
return torch.triu(
|
12 |
+
torch.ones(sz, sz, dtype=torch.int, device=device)
|
13 |
+
).transpose(0, 1).unsqueeze(0)
|
14 |
+
|
15 |
+
class BeamSearch2(BeamSearch):
|
16 |
+
"""
|
17 |
+
Same with BeamSearch2 class.
|
18 |
+
Difference: remove the sentence which its beams terminated (reached <eos> token) from the time step loop.
|
19 |
+
Update to reuse functions already coded in normal BeamSearch. Note that replacing unknown words & n_best is not available.
|
20 |
+
"""
|
21 |
+
def _convert_to_sent(self, sent_id, eos_token_id):
|
22 |
+
eos = torch.nonzero(sent_id == eos_token_id).view(-1)
|
23 |
+
t = eos[0] if len(eos) > 0 else len(sent_id)
|
24 |
+
return [self.TRG.vocab.itos[j] for j in sent_id[1 : t]]
|
25 |
+
|
26 |
+
@torch.no_grad()
|
27 |
+
def beam_search(self, src, src_lang=None, trg_lang=None, src_tokens=None, n_best=1, debug=False):
|
28 |
+
"""
|
29 |
+
Beam search select k words with the highest conditional probability
|
30 |
+
to be the first word of the k candidate output sequences.
|
31 |
+
Args:
|
32 |
+
src: The batch of sentences, already in [batch_size, tokens] of int
|
33 |
+
src_tokens: src in str version, same size as above
|
34 |
+
n_best: number of usable values per beam loaded (Not implemented)
|
35 |
+
debug: if true, print some debug information during the search
|
36 |
+
Return:
|
37 |
+
An array of translated sentences, in list-of-tokens format. TODO convert [batch_size, n_best, tgt_len] instead of [batch_size, tgt_len]
|
38 |
+
"""
|
39 |
+
# Create some local variable
|
40 |
+
src_field, trg_field = self.SRC, self.TRG
|
41 |
+
sos_token = generate_language_token(trg_lang) if trg_lang is not None else const.DEFAULT_SOS
|
42 |
+
init_token = trg_field.vocab.stoi[sos_token]
|
43 |
+
eos_token_id = trg_field.vocab.stoi[const.DEFAULT_EOS]
|
44 |
+
src = src.to(self.device)
|
45 |
+
|
46 |
+
batch_size = src.size(0)
|
47 |
+
model = self.model
|
48 |
+
k = self.beam_size
|
49 |
+
max_len = self.max_len
|
50 |
+
device = self.device
|
51 |
+
|
52 |
+
# Initialize target sequence (start with '<sos>' token) [batch_size x k x max_len]
|
53 |
+
trg = torch.zeros(batch_size, k, max_len, device=device).long()
|
54 |
+
trg[:, :, 0] = init_token
|
55 |
+
|
56 |
+
# Precalc output from model's encoder
|
57 |
+
single_src_mask = (src != src_field.vocab.stoi['<pad>']).unsqueeze(1).to(device)
|
58 |
+
e_out = model.encoder(src, single_src_mask) # [batch_size x S x d_model]
|
59 |
+
|
60 |
+
# Output model prob
|
61 |
+
trg_mask = generate_subsequent_mask(1, device=device)
|
62 |
+
# [batch_size x 1]
|
63 |
+
inp_decoder = trg[:, 0, 0].view(batch_size, 1)
|
64 |
+
# [batch_size x 1 x vocab_size]
|
65 |
+
prob = model.out(model.decoder(inp_decoder, e_out, single_src_mask, trg_mask))
|
66 |
+
prob = functional.softmax(prob, dim=-1)
|
67 |
+
|
68 |
+
# [batch_size x 1 x k]
|
69 |
+
k_prob, k_index = torch.topk(prob, k, dim=-1)
|
70 |
+
trg[:, :, 1] = k_index.view(batch_size, k)
|
71 |
+
# Init log scores from k beams [batch_size x k x 1]
|
72 |
+
log_scores = torch.log(k_prob.view(batch_size, k, 1))
|
73 |
+
|
74 |
+
# Repeat encoder's output k times for searching [(k * batch_size) x S x d_model]
|
75 |
+
e_outs = torch.repeat_interleave(e_out, k, dim=0)
|
76 |
+
src_mask = torch.repeat_interleave(single_src_mask, k, dim=0)
|
77 |
+
|
78 |
+
# Create mask for checking eos
|
79 |
+
sent_eos = torch.tensor([eos_token_id for _ in range(k)], device=device).view(1, k)
|
80 |
+
|
81 |
+
# The batch indexes
|
82 |
+
batch_index = torch.arange(batch_size)
|
83 |
+
finished_batches = torch.zeros(batch_size, device=device).long()
|
84 |
+
|
85 |
+
# Iteratively searching
|
86 |
+
for i in range(2, max_len):
|
87 |
+
trg_mask = generate_subsequent_mask(i, device)
|
88 |
+
|
89 |
+
# Flatten trg tensor for feeding into model [(k * batch_size) x i]
|
90 |
+
inp_decoder = trg[batch_index, :, :i].view(k * len(batch_index), i)
|
91 |
+
# Output model prob [(k * batch_size) x i x vocab_size]
|
92 |
+
prob = model.out(model.decoder(inp_decoder, e_outs, src_mask, trg_mask))
|
93 |
+
prob = functional.softmax(prob, dim=-1)
|
94 |
+
|
95 |
+
# Only care the last prob i-th
|
96 |
+
# [(k * batch_size) x 1 x vocab_size]
|
97 |
+
prob = prob[:, i-1, :].view(k * len(batch_index), 1, -1)
|
98 |
+
|
99 |
+
# Truncate prob to top k [(k * batch_size) x 1 x k]
|
100 |
+
k_prob, k_index = prob.data.topk(k, dim=-1)
|
101 |
+
|
102 |
+
# Deflatten k_prob & k_index
|
103 |
+
k_prob = k_prob.view(len(batch_index), k, 1, k)
|
104 |
+
k_index = k_index.view(len(batch_index), k, 1, k)
|
105 |
+
|
106 |
+
# Preserve eos beams
|
107 |
+
# [batch_size x k] -> view -> [batch_size x k x 1 x 1] (broadcastable)
|
108 |
+
eos_mask = (trg[batch_index, :, i-1] == eos_token_id).view(len(batch_index), k, 1, 1)
|
109 |
+
k_prob.masked_fill_(eos_mask, 1.0)
|
110 |
+
k_index.masked_fill_(eos_mask, eos_token_id)
|
111 |
+
|
112 |
+
# Find the best k cases
|
113 |
+
# Compute log score at i-th timestep
|
114 |
+
# [batch_size x k x 1 x 1] + [batch_size x k x 1 x k] = [batch_size x k x 1 x k]
|
115 |
+
combine_probs = log_scores[batch_index].unsqueeze(-1) + torch.log(k_prob)
|
116 |
+
# [batch_size x k x 1]
|
117 |
+
log_scores[batch_index], positions = torch.topk(combine_probs.view(len(batch_index), k * k, 1), k, dim=1)
|
118 |
+
|
119 |
+
# The rows selected from top k
|
120 |
+
rows = positions.view(len(batch_index), k) // k
|
121 |
+
# The indexes in vocab respected to these rows
|
122 |
+
cols = positions.view(len(batch_index), k) % k
|
123 |
+
|
124 |
+
batch_sim = torch.arange(len(batch_index)).view(-1, 1)
|
125 |
+
trg[batch_index, :, :] = trg[batch_index.view(-1, 1), rows, :]
|
126 |
+
trg[batch_index, :, i] = k_index[batch_sim, rows, :, cols].view(len(batch_index), k)
|
127 |
+
|
128 |
+
# Update which sentences finished all its beams
|
129 |
+
mask = (trg[:, :, i] == sent_eos).all(1).view(-1).to(device)
|
130 |
+
finished_batches.masked_fill_(mask, value=1)
|
131 |
+
cnt = torch.sum(finished_batches).item()
|
132 |
+
if cnt == batch_size:
|
133 |
+
break
|
134 |
+
|
135 |
+
# Continue with remaining batches (if any)
|
136 |
+
batch_index = torch.nonzero(finished_batches == 0).view(-1)
|
137 |
+
e_outs = torch.repeat_interleave(e_out[batch_index], k, dim=0)
|
138 |
+
src_mask = torch.repeat_interleave(single_src_mask[batch_index], k, dim=0)
|
139 |
+
# End loop
|
140 |
+
|
141 |
+
# Get the best beam
|
142 |
+
log_scores = log_scores.view(batch_size, k)
|
143 |
+
results = [self._convert_to_sent(trg[t, j.item(), :], eos_token_id) for t, j in enumerate(torch.argmax(log_scores, dim=-1))]
|
144 |
+
return results
|
modules/inference/sampling_temperature.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
##@title Beam của mình
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import math
|
5 |
+
import torch.nn.functional as functional
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch.autograd import Variable
|
8 |
+
|
9 |
+
from modules.inference.decode_strategy import DecodeStrategy
|
10 |
+
from utils.misc import no_peeking_mask
|
11 |
+
|
12 |
+
class GreedySearch(DecodeStrategy):
|
13 |
+
def __init__(self, model, max_len, device, replace_unk=True):
|
14 |
+
"""
|
15 |
+
:param beam_size
|
16 |
+
:param batch_size
|
17 |
+
:param beam_offset
|
18 |
+
"""
|
19 |
+
super(GreedySearch, self).__init__(model, max_len, device)
|
20 |
+
self.batch_size = batch_size
|
21 |
+
self.replace_unk = replace_unk
|
22 |
+
raise NotImplementedError("Replace unk was yeeted from base class DecodeStrategy. Fix first.")
|
23 |
+
|
24 |
+
def initilize_value(self, sentences):
|
25 |
+
"""
|
26 |
+
Calculate the required matrices during translation after the model is finished
|
27 |
+
Input:
|
28 |
+
:param src: Sentences
|
29 |
+
|
30 |
+
Output: Initialize the first character includes outputs, e_outputs, log_scores
|
31 |
+
"""
|
32 |
+
|
33 |
+
init_tok = self.TRG.vocab.stoi['<sos>']
|
34 |
+
src_mask = (sentences != self.SRC.vocab.stoi['<pad>']).unsqueeze(-2)
|
35 |
+
eos_tok = self.TRG.vocab.stoi['<eos>']
|
36 |
+
|
37 |
+
# Encoder
|
38 |
+
e_output = self.model.encoder(sentences, src_mask)
|
39 |
+
|
40 |
+
out = torch.LongTensor([[init_tok] for i in range(self.batch_size)])
|
41 |
+
outputs = torch.zeros(self.batch_size, self.max_len).long()
|
42 |
+
outputs[:, :1] = out
|
43 |
+
|
44 |
+
outputs = outputs.to(self.device)
|
45 |
+
is_finished = torch.LongTensor([[eos_tok] for i in range(self.batch_size)]).view(-1).to(self.device)
|
46 |
+
return eos_tok, src_mask, is_finished, e_output, outputs
|
47 |
+
|
48 |
+
def create_trg_mask(self, i, device):
|
49 |
+
return no_peeking_mask(i, device)
|
50 |
+
|
51 |
+
def current_predict(self, outputs, e_output, src_mask, trg_mask):
|
52 |
+
out, attn = self.model.out(self.model.decoder(outputs,
|
53 |
+
e_output, src_mask, trg_mask))
|
54 |
+
out = functional.softmax(out, dim=-1)
|
55 |
+
return out, attn
|
56 |
+
|
57 |
+
def greedy_search(self, sentences, sampling_temp=0.0, keep_topk=1):
|
58 |
+
if len(sentences) < self.batch_size:
|
59 |
+
self.batch_size = len(sentences)
|
60 |
+
|
61 |
+
eos_tok, src_mask, is_finished, e_output, outputs = self.initilize_value(sentences)
|
62 |
+
|
63 |
+
for i in range(1, self.max_len):
|
64 |
+
out, attn = self.current_predict(outputs[:, :i], e_output, src_mask, self.create_trg_mask(i, self.device))
|
65 |
+
topk_ix, topk_prob = self.sample_with_temperature(out[:, -1], sampling_temp, keep_topk)
|
66 |
+
outputs[:, i] = topk_ix.view(-1)
|
67 |
+
if torch.equal(outputs[:, i], is_finished):
|
68 |
+
break
|
69 |
+
|
70 |
+
if self.replace_unk == True:
|
71 |
+
outputs = self.replace_unknown(attn, sentences, outputs)
|
72 |
+
|
73 |
+
# print("\n".join([' '.join([self.TRG.vocab.itos[tok] for tok in line[1:]]) for line in outputs]))
|
74 |
+
# Write to file or Print to the console
|
75 |
+
translated_sentences = []
|
76 |
+
# Get the best sentences: idx = 0 + i*k
|
77 |
+
for i in range(0, len(outputs)):
|
78 |
+
is_eos = torch.nonzero(outputs[i]==eos_tok)
|
79 |
+
if len(is_eos) == 0:
|
80 |
+
# if there is no sequence end, remove
|
81 |
+
sent = outputs[i, 1:]
|
82 |
+
else:
|
83 |
+
length = is_eos[0]
|
84 |
+
sent = outputs[i, 1:length]
|
85 |
+
translated_sentences.append([self.TRG.vocab.itos[tok] for tok in sent])
|
86 |
+
|
87 |
+
return translated_sentences
|
88 |
+
|
89 |
+
def sample_with_temperature(self, logits, sampling_temp, keep_topk):
|
90 |
+
if sampling_temp == 0.0 or keep_topk == 1:
|
91 |
+
# For temp=0.0, take the argmax to avoid divide-by-zero errors.
|
92 |
+
# keep_topk=1 is also equivalent to argmax.
|
93 |
+
topk_scores, topk_ids = logits.topk(1, dim=-1)
|
94 |
+
if sampling_temp > 0:
|
95 |
+
topk_scores /= sampling_temp
|
96 |
+
else:
|
97 |
+
logits = torch.div(logits, sampling_temp)
|
98 |
+
|
99 |
+
if keep_topk > 0:
|
100 |
+
top_values, top_indices = torch.topk(logits, keep_topk, dim=1)
|
101 |
+
kth_best = top_values[:, -1].view([-1, 1])
|
102 |
+
kth_best = kth_best.repeat([1, logits.shape[1]]).float()
|
103 |
+
|
104 |
+
# Set all logits that are not in the top-k to -10000.
|
105 |
+
# This puts the probabilities close to 0.
|
106 |
+
ignore = torch.lt(logits, kth_best)
|
107 |
+
logits = logits.masked_fill(ignore, -10000)
|
108 |
+
|
109 |
+
dist = torch.distributions.Multinomial(
|
110 |
+
logits=logits, total_count=1)
|
111 |
+
topk_ids = torch.argmax(dist.sample(), dim=1, keepdim=True)
|
112 |
+
topk_scores = logits.gather(dim=1, index=topk_ids)
|
113 |
+
return topk_ids, topk_scores
|
114 |
+
|
115 |
+
def translate_batch(self, sentences):
|
116 |
+
# super(BeamSearch, self).__init__()
|
117 |
+
sentences = self.preprocess_batch(sentences).to(self.device)
|
118 |
+
return self.greedy_search(sentences, 0.2, 2)
|
119 |
+
# print(self.initilize_value(sentences))
|
modules/loader/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .default_loader import DefaultLoader
|
2 |
+
from .multilingual_loader import MultiLoader
|
3 |
+
|
4 |
+
loaders = {"monoloader": DefaultLoader, "multiloader": MultiLoader}
|
modules/loader/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (313 Bytes). View file
|
|
modules/loader/__pycache__/default_loader.cpython-36.pyc
ADDED
Binary file (6.24 kB). View file
|
|
modules/loader/__pycache__/multilingual_loader.cpython-36.pyc
ADDED
Binary file (5.97 kB). View file
|
|
modules/loader/default_loader.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io, os
|
2 |
+
import dill as pickle
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
from torchtext.data import BucketIterator, Dataset, Example, Field
|
6 |
+
from torchtext.datasets import TranslationDataset, Multi30k, IWSLT, WMT14
|
7 |
+
from collections import Counter
|
8 |
+
|
9 |
+
import modules.constants as const
|
10 |
+
from utils.save import load_vocab_from_path
|
11 |
+
import laonlp
|
12 |
+
|
13 |
+
class DefaultLoader:
|
14 |
+
def __init__(self, train_path_or_name, language_tuple=None, valid_path=None, eval_path=None, option=None):
|
15 |
+
"""Load training/eval data file pairing, process and create data iterator for training """
|
16 |
+
self._language_tuple = language_tuple
|
17 |
+
self._train_path = train_path_or_name
|
18 |
+
self._eval_path = eval_path
|
19 |
+
self._option = option
|
20 |
+
|
21 |
+
@property
|
22 |
+
def language_tuple(self):
|
23 |
+
"""DefaultLoader will use the default lang option @bleu_batch_iter <sos>, hence, None"""
|
24 |
+
return None, None
|
25 |
+
|
26 |
+
def tokenize(self, sentence):
|
27 |
+
return sentence.strip().split()
|
28 |
+
|
29 |
+
def detokenize(self, list_of_tokens):
|
30 |
+
"""Differentiate between [batch, len] and [len]; joining tokens back to strings"""
|
31 |
+
if( len(list_of_tokens) == 0 or isinstance(list_of_tokens[0], str)):
|
32 |
+
# [len], single sentence version
|
33 |
+
return " ".join(list_of_tokens)
|
34 |
+
else:
|
35 |
+
# [batch, len], batch sentence version
|
36 |
+
return [" ".join(tokens) for tokens in list_of_tokens]
|
37 |
+
|
38 |
+
def _train_path_is_name(self):
|
39 |
+
return os.path.isfile(self._train_path + self._language_tuple[0]) and os.path.isfile(self._train_path + self._language_tuple[1])
|
40 |
+
|
41 |
+
def create_length_constraint(self, token_limit):
|
42 |
+
"""Filter an iterator if it pass a token limit"""
|
43 |
+
return lambda x: len(x.src) <= token_limit and len(x.trg) <= token_limit
|
44 |
+
|
45 |
+
def build_field(self, **kwargs):
|
46 |
+
"""Build fields that will handle the conversion from token->idx and vice versa. To be overriden by MultiLoader."""
|
47 |
+
return Field(lower=False, tokenize=laonlp.tokenize.word_tokenize), Field(lower=False, tokenize=self.tokenize, init_token=const.DEFAULT_SOS, eos_token=const.DEFAULT_EOS, is_target=True)
|
48 |
+
|
49 |
+
def build_vocab(self, fields, model_path=None, data=None, **kwargs):
|
50 |
+
"""Build the vocabulary object for torchtext Field. There are three flows:
|
51 |
+
- if the model path is present, it will first try to load the pickled/dilled vocab object from path. This is accessed on continued training & standalone inference
|
52 |
+
- if that failed and data is available, try to build the vocab from that data. This is accessed on first time training
|
53 |
+
- if data is not available, search for set of two vocab files and read them into the fields. This is accessed on first time training
|
54 |
+
TODO: expand on the vocab file option (loading pretrained vectors as well)
|
55 |
+
"""
|
56 |
+
src_field, trg_field = fields
|
57 |
+
if(model_path is None or not load_vocab_from_path(model_path, self._language_tuple, fields)):
|
58 |
+
# the condition will try to load vocab pickled to model path.
|
59 |
+
if(data is not None):
|
60 |
+
print("Building vocab from received data.")
|
61 |
+
# build the vocab using formatted data.
|
62 |
+
src_field.build_vocab(data, **kwargs)
|
63 |
+
trg_field.build_vocab(data, **kwargs)
|
64 |
+
else:
|
65 |
+
print("Building vocab from preloaded text file.")
|
66 |
+
# load the vocab values from external location (a formatted text file). Initialize values as random
|
67 |
+
external_vocab_location = self._option.get("external_vocab_location", None)
|
68 |
+
src_ext, trg_ext = self._language_tuple
|
69 |
+
# read the files and create a mock Counter object; which then is passed to vocab's class
|
70 |
+
# see Field.build_vocab for the options used with vocab_cls
|
71 |
+
vocab_src = external_vocab_location + src_ext
|
72 |
+
with io.open(vocab_src, "r", encoding="utf-8") as svf:
|
73 |
+
mock_counter = Counter({w.strip():1 for w in svf.readlines()})
|
74 |
+
special_tokens = [src_field.unk_token, src_field.pad_token, src_field.init_token, src_field.eos_token]
|
75 |
+
src_field.vocab = src_field.vocab_cls(mock_counter, specials=special_tokens, min_freq=5, **kwargs)
|
76 |
+
vocab_trg = external_vocab_location + trg_ext
|
77 |
+
with io.open(vocab_trg, "r", encoding="utf-8") as tvf:
|
78 |
+
mock_counter = Counter({w.strip():1 for w in tvf.readlines()})
|
79 |
+
special_tokens = [trg_field.unk_token, trg_field.pad_token, trg_field.init_token, trg_field.eos_token]
|
80 |
+
trg_field.vocab = trg_field.vocab_cls(mock_counter, specials=special_tokens, min_freq=5, **kwargs)
|
81 |
+
else:
|
82 |
+
print("Load vocab from path successful.")
|
83 |
+
|
84 |
+
def create_iterator(self, fields, model_path=None):
|
85 |
+
"""Create the iterator needed to load batches of data and bind them to existing fields
|
86 |
+
NOTE: unlike the previous loader, this one inputs list of tokens instead of a string, which necessitate redefinining of translate_sentence pipe"""
|
87 |
+
if(not self._train_path_is_name()):
|
88 |
+
# load the default torchtext dataset by name
|
89 |
+
# TODO load additional arguments in the config
|
90 |
+
dataset_cls = next( (s for s in [Multi30k, IWSLT, WMT14] if s.__name__ == self._train_path), None )
|
91 |
+
if(dataset_cls is None):
|
92 |
+
raise ValueError("The specified train path {:s}(+{:s}/{:s}) does neither point to a valid files path nor is a name of torchtext dataset class.".format(self._train_path, *self._language_tuple))
|
93 |
+
src_suffix, trg_suffix = ext = self._language_tuple
|
94 |
+
# print(ext, fields)
|
95 |
+
self._train_data, self._valid_data, self._eval_data = dataset_cls.splits(exts=ext, fields=fields) #, split_ratio=self._option.get("train_test_split", const.DEFAULT_TRAIN_TEST_SPLIT)
|
96 |
+
else:
|
97 |
+
# create dataset from path. Also add all necessary constraints (e.g lengths trimming/excluding)
|
98 |
+
src_suffix, trg_suffix = ext = self._language_tuple
|
99 |
+
filter_fn = self.create_length_constraint(self._option.get("train_max_length", const.DEFAULT_TRAIN_MAX_LENGTH))
|
100 |
+
self._train_data = TranslationDataset(self._train_path, ext, fields, filter_pred=filter_fn)
|
101 |
+
self._valid_data = self._eval_data = TranslationDataset(self._eval_path, ext, fields)
|
102 |
+
# first_sample = self._train_data[0]; raise Exception("{} {}".format(first_sample.src, first_sample.trg))
|
103 |
+
# whatever created, we now have the two set of data ready. add the necessary constraints/filtering/etc.
|
104 |
+
train_data = self._train_data
|
105 |
+
eval_data = self._eval_data
|
106 |
+
# now we can execute build_vocab. This function will try to load vocab from model_path, and if fail, build the vocab from train_data
|
107 |
+
build_vocab_kwargs = self._option.get("build_vocab_kwargs", {})
|
108 |
+
self.build_vocab(fields, data=train_data, model_path=model_path, **build_vocab_kwargs)
|
109 |
+
# raise Exception("{}".format(len(src_field.vocab)))
|
110 |
+
# crafting iterators
|
111 |
+
train_iter = BucketIterator(train_data, batch_size=self._option.get("batch_size", const.DEFAULT_BATCH_SIZE), device=self._option.get("device", const.DEFAULT_DEVICE) )
|
112 |
+
eval_iter = BucketIterator(eval_data, batch_size=self._option.get("eval_batch_size", const.DEFAULT_EVAL_BATCH_SIZE), device=self._option.get("device", const.DEFAULT_DEVICE), train=False )
|
113 |
+
return train_iter, eval_iter
|
114 |
+
|
modules/loader/multilingual_loader.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io, os
|
2 |
+
import dill as pickle
|
3 |
+
from collections import Counter
|
4 |
+
import torch
|
5 |
+
from torchtext.data import BucketIterator, Dataset, Example, Field, interleave_keys
|
6 |
+
import modules.constants as const
|
7 |
+
from utils.save import load_vocab_from_path
|
8 |
+
from utils.data import generate_language_token
|
9 |
+
from modules.loader.default_loader import DefaultLoader
|
10 |
+
import laonlp
|
11 |
+
|
12 |
+
class MultiDataset(Dataset):
|
13 |
+
"""
|
14 |
+
Ensemble one or more corpuses from different languages.
|
15 |
+
The corpuses use global source vocab and target vocab.
|
16 |
+
|
17 |
+
Constructor Args:
|
18 |
+
data_info: list of datasets info <See `train` argument in MultiLoader class>
|
19 |
+
fields: A tuple containing src field and trg field.
|
20 |
+
"""
|
21 |
+
@staticmethod
|
22 |
+
def sort_key(ex):
|
23 |
+
return interleave_keys(len(ex.src), len(ex.trg))
|
24 |
+
|
25 |
+
def __init__(self, data_info, fields, **kwargs):
|
26 |
+
self.languages = set()
|
27 |
+
|
28 |
+
if not isinstance(fields[0], (tuple, list)):
|
29 |
+
fields = [('src', fields[0]), ('trg', fields[1])]
|
30 |
+
|
31 |
+
examples = []
|
32 |
+
for corpus, info in data_info:
|
33 |
+
print("Loading corpus {} ...".format(corpus))
|
34 |
+
|
35 |
+
src_lang = info["src_lang"]
|
36 |
+
trg_lang = info["trg_lang"]
|
37 |
+
src_path = os.path.expanduser('.'.join([info["path"], src_lang]))
|
38 |
+
trg_path = os.path.expanduser('.'.join([info["path"], trg_lang]))
|
39 |
+
self.languages.add(src_lang)
|
40 |
+
self.languages.add(trg_lang)
|
41 |
+
|
42 |
+
with io.open(src_path, mode='r', encoding='utf-8') as src_file, \
|
43 |
+
io.open(trg_path, mode='r', encoding='utf-8') as trg_file:
|
44 |
+
for src_line, trg_line in zip(src_file, trg_file):
|
45 |
+
src_line, trg_line = src_line.strip(), trg_line.strip()
|
46 |
+
if src_line != '' and trg_line != '':
|
47 |
+
# Append language-specific prefix token
|
48 |
+
src_line = ' '.join([generate_language_token(src_lang), src_line])
|
49 |
+
trg_line = ' '.join([generate_language_token(trg_lang), trg_line])
|
50 |
+
examples.append(Example.fromlist([src_line, trg_line], fields))
|
51 |
+
print("Done!")
|
52 |
+
|
53 |
+
super(MultiDataset, self).__init__(examples, fields, **kwargs)
|
54 |
+
|
55 |
+
|
56 |
+
class MultiLoader(DefaultLoader):
|
57 |
+
def __init__(self, train, valid=None, option=None):
|
58 |
+
"""
|
59 |
+
Load multiple training/eval parallel data files, process and create data iterator
|
60 |
+
Constructor Args:
|
61 |
+
train: a dictionary contains training data information
|
62 |
+
valid (optional): a dictionary contains validation data information
|
63 |
+
option (optional): a dictionary contains configurable parameters
|
64 |
+
|
65 |
+
For example:
|
66 |
+
train = {
|
67 |
+
"corpus_1": {
|
68 |
+
"path": path/to/training/data,
|
69 |
+
"src_lang": src,
|
70 |
+
"trg_lang": trg
|
71 |
+
},
|
72 |
+
"corpus_2": {
|
73 |
+
...
|
74 |
+
}
|
75 |
+
}
|
76 |
+
"""
|
77 |
+
self._train_info = train
|
78 |
+
self._valid_info = valid
|
79 |
+
self._language_tuple = ('.src', '.trg')
|
80 |
+
self._option = option
|
81 |
+
|
82 |
+
@property
|
83 |
+
def tokenize(self, sentence):
|
84 |
+
return sentence.strip().split()
|
85 |
+
|
86 |
+
|
87 |
+
def language_tuple(self):
|
88 |
+
"""Currently output valid data's tuple for bleu_valid_iter, which would use <{trg_lang}> during inference. Since <{src_lang}> had already been added to the valid data, return None instead."""
|
89 |
+
return None, self._valid_info["trg_lang"]
|
90 |
+
|
91 |
+
def _is_path(self, path, lang):
|
92 |
+
"""Check whether the path is a system path or a corpus name"""
|
93 |
+
return os.path.isfile(path + '.' + lang)
|
94 |
+
|
95 |
+
def build_field(self, **kwargs):
|
96 |
+
return Field(lower=False, tokenize=laonlp.tokenize.word_tokenize), Field(lower=False, tokenize=self.tokenize, init_token=const.DEFAULT_SOS, eos_token=const.DEFAULT_EOS)
|
97 |
+
|
98 |
+
def build_vocab(self, fields, model_path=None, data=None, **kwargs):
|
99 |
+
"""Build the vocabulary object for torchtext Field. There are three flows:
|
100 |
+
- if the model path is present, it will first try to load the pickled/dilled vocab object from path. This is accessed on continued training & standalone inference
|
101 |
+
- if that failed and data is available, try to build the vocab from that data. This is accessed on first time training
|
102 |
+
- if data is not available, search for set of two vocab files and read them into the fields. This is accessed on first time training
|
103 |
+
TODO: expand on the vocab file option (loading pretrained vectors as well)
|
104 |
+
"""
|
105 |
+
src_field, trg_field = fields
|
106 |
+
if model_path is None or not load_vocab_from_path(model_path, self._language_tuple, fields):
|
107 |
+
# the condition will try to load vocab pickled to model path.
|
108 |
+
if data is not None:
|
109 |
+
print("Building vocab from received data.")
|
110 |
+
# build the vocab using formatted data.
|
111 |
+
src_field.build_vocab(data, **kwargs)
|
112 |
+
trg_field.build_vocab(data, **kwargs)
|
113 |
+
else:
|
114 |
+
# Not implemented mixing preloaded datasets and external datasets
|
115 |
+
raise ValueError("MultiLoader currently do not support preloaded text vocab")
|
116 |
+
else:
|
117 |
+
print("Load vocab from path successful.")
|
118 |
+
|
119 |
+
def create_iterator(self, fields, model_path=None):
|
120 |
+
"""Create the iterator needed to load batches of data and bind them to existing fields"""
|
121 |
+
# create dataset from path. Also add all necessary constraints (e.g lengths trimming/excluding)
|
122 |
+
filter_fn = self.create_length_constraint(self._option.get("train_max_length", const.DEFAULT_TRAIN_MAX_LENGTH))
|
123 |
+
self._train_data = MultiDataset(data_info=self._train_info.items(), fields=fields, filter_pred=filter_fn)
|
124 |
+
|
125 |
+
# now we can execute build_vocab. This function will try to load vocab from model_path, and if fail, build the vocab from train_data
|
126 |
+
build_vocab_kwargs = self._option.get("build_vocab_kwargs", {})
|
127 |
+
build_vocab_kwargs["specials"] = build_vocab_kwargs.pop("specials", []) + list(self._train_data.languages)
|
128 |
+
self.build_vocab(fields, data=self._train_data, model_path=model_path, **build_vocab_kwargs)
|
129 |
+
|
130 |
+
# Create train iterator
|
131 |
+
train_iter = BucketIterator(self._train_data, batch_size=self._option.get("batch_size", const.DEFAULT_BATCH_SIZE), device=self._option.get("device", const.DEFAULT_DEVICE))
|
132 |
+
|
133 |
+
if self._valid_info is not None:
|
134 |
+
self._valid_data = MultiDataset(data_info=[("valid", self._valid_info)], fields=fields)
|
135 |
+
valid_iter = BucketIterator(self._valid_data, batch_size=self._option.get("eval_batch_size", const.DEFAULT_EVAL_BATCH_SIZE), device=self._option.get("device", const.DEFAULT_DEVICE), train=False)
|
136 |
+
else:
|
137 |
+
valid_iter = None
|
138 |
+
|
139 |
+
return train_iter, valid_iter
|
modules/optim/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.optim.adam import AdamOptim
|
2 |
+
from modules.optim.adabelief import AdaBeliefOptim
|
3 |
+
from modules.optim.scheduler import ScheduledOptim
|
4 |
+
|
5 |
+
optimizers = {"Adam": AdamOptim, "AdaBelief": AdaBeliefOptim}
|
modules/optim/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (376 Bytes). View file
|
|
modules/optim/__pycache__/adabelief.cpython-36.pyc
ADDED
Binary file (1.48 kB). View file
|
|
modules/optim/__pycache__/adam.cpython-36.pyc
ADDED
Binary file (1.43 kB). View file
|
|
modules/optim/__pycache__/scheduler.cpython-36.pyc
ADDED
Binary file (2.09 kB). View file
|
|
modules/optim/adabelief.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch, math
|
2 |
+
|
3 |
+
class AdaBeliefOptim(torch.optim.Optimizer):
|
4 |
+
|
5 |
+
def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-9, **kwargs):
|
6 |
+
defaults = dict(lr=lr, betas=betas, eps=eps)
|
7 |
+
super().__init__(params, defaults)
|
8 |
+
|
9 |
+
@torch.no_grad()
|
10 |
+
def step(self, closure=None):
|
11 |
+
for group in self.param_groups:
|
12 |
+
for p in group['params']:
|
13 |
+
if p.grad is None:
|
14 |
+
# No backward
|
15 |
+
continue
|
16 |
+
grad = p.grad
|
17 |
+
state = self.state[p]
|
18 |
+
|
19 |
+
if len(state) == 0:
|
20 |
+
# Initial step
|
21 |
+
state['step'] = 0
|
22 |
+
state['exp_avg'] = torch.zeros_like(p)
|
23 |
+
state['exp_avg_sq'] = torch.zeros_like(p)
|
24 |
+
|
25 |
+
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
26 |
+
beta1, beta2 = group['betas']
|
27 |
+
|
28 |
+
state['step'] += 1
|
29 |
+
bias_correction1 = 1 - beta1 ** state['step']
|
30 |
+
bias_correction2 = 1 - beta2 ** state['step']
|
31 |
+
|
32 |
+
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
33 |
+
|
34 |
+
# This is the diff between Adam and Adabelief
|
35 |
+
centered_grad = grad.sub(exp_avg)
|
36 |
+
exp_avg_sq.mul_(beta2).addcmul_(centered_grad, centered_grad, value=1-beta2)
|
37 |
+
# !
|
38 |
+
|
39 |
+
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
|
40 |
+
step_size = group['lr'] / bias_correction1
|
41 |
+
|
42 |
+
p.addcdiv_(exp_avg, denom, value=-step_size)
|
modules/optim/adam.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch, math
|
2 |
+
|
3 |
+
class AdamOptim(torch.optim.Optimizer):
|
4 |
+
|
5 |
+
def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-9, **kwargs):
|
6 |
+
defaults = dict(lr=lr, betas=betas, eps=eps)
|
7 |
+
super().__init__(params, defaults)
|
8 |
+
|
9 |
+
@torch.no_grad()
|
10 |
+
def step(self, closure=None):
|
11 |
+
for group in self.param_groups:
|
12 |
+
for p in group['params']:
|
13 |
+
if p.grad is None:
|
14 |
+
# No backward
|
15 |
+
continue
|
16 |
+
grad = p.grad
|
17 |
+
state = self.state[p]
|
18 |
+
|
19 |
+
if len(state) == 0:
|
20 |
+
# Initial step
|
21 |
+
state['step'] = 0
|
22 |
+
state['exp_avg'] = torch.zeros_like(p)
|
23 |
+
state['exp_avg_sq'] = torch.zeros_like(p)
|
24 |
+
|
25 |
+
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
26 |
+
beta1, beta2 = group['betas']
|
27 |
+
|
28 |
+
state['step'] += 1
|
29 |
+
bias_correction1 = 1 - beta1 ** state['step']
|
30 |
+
bias_correction2 = 1 - beta2 ** state['step']
|
31 |
+
|
32 |
+
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
33 |
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
34 |
+
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
|
35 |
+
|
36 |
+
step_size = group['lr'] / bias_correction1
|
37 |
+
|
38 |
+
p.addcdiv_(exp_avg, denom, value=-step_size)
|
modules/optim/scheduler.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
class ScheduledOptim():
|
4 |
+
'''A simple wrapper class for learning rate scheduling'''
|
5 |
+
|
6 |
+
def __init__(self, optimizer, init_lr, d_model, n_warmup_steps):
|
7 |
+
self._optimizer = optimizer
|
8 |
+
self.init_lr = init_lr
|
9 |
+
self.d_model = d_model
|
10 |
+
self.n_warmup_steps = n_warmup_steps
|
11 |
+
self.n_steps = 0
|
12 |
+
|
13 |
+
|
14 |
+
def step_and_update_lr(self):
|
15 |
+
"Step with the inner optimizer"
|
16 |
+
self._update_learning_rate()
|
17 |
+
self._optimizer.step()
|
18 |
+
|
19 |
+
|
20 |
+
def zero_grad(self):
|
21 |
+
"Zero out the gradients with the inner optimizer"
|
22 |
+
self._optimizer.zero_grad()
|
23 |
+
|
24 |
+
|
25 |
+
def _get_lr_scale(self):
|
26 |
+
d_model = self.d_model
|
27 |
+
n_steps, n_warmup_steps = self.n_steps, self.n_warmup_steps
|
28 |
+
return (d_model ** -0.5) * min(n_steps ** (-0.5), n_steps * n_warmup_steps ** (-1.5))
|
29 |
+
|
30 |
+
def state_dict(self):
|
31 |
+
optimizer_state_dict = {
|
32 |
+
'init_lr':self.init_lr,
|
33 |
+
'd_model':self.d_model,
|
34 |
+
'n_warmup_steps':self.n_warmup_steps,
|
35 |
+
'n_steps':self.n_steps,
|
36 |
+
'_optimizer':self._optimizer.state_dict(),
|
37 |
+
}
|
38 |
+
|
39 |
+
return optimizer_state_dict
|
40 |
+
|
41 |
+
def load_state_dict(self, state_dict):
|
42 |
+
self.init_lr = state_dict['init_lr']
|
43 |
+
self.d_model = state_dict['d_model']
|
44 |
+
self.n_warmup_steps = state_dict['n_warmup_steps']
|
45 |
+
self.n_steps = state_dict['n_steps']
|
46 |
+
|
47 |
+
self._optimizer.load_state_dict(state_dict['_optimizer'])
|
48 |
+
|
49 |
+
def _update_learning_rate(self):
|
50 |
+
''' Learning rate scheduling per step '''
|
51 |
+
|
52 |
+
self.n_steps += 1
|
53 |
+
lr = self.init_lr * self._get_lr_scale()
|
54 |
+
|
55 |
+
for param_group in self._optimizer.param_groups:
|
56 |
+
param_group['lr'] = lr
|
modules/prototypes.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from torchtext import data
|
3 |
+
import copy
|
4 |
+
import layers as layers
|
5 |
+
|
6 |
+
class Embedder(nn.Module):
|
7 |
+
def __init__(self, vocab_size, d_model):
|
8 |
+
super().__init__()
|
9 |
+
self.vocab_size = vocab_size
|
10 |
+
self.d_model = d_model
|
11 |
+
|
12 |
+
self.embed = nn.Embedding(vocab_size, d_model)
|
13 |
+
|
14 |
+
def forward(self, x):
|
15 |
+
return self.embed(x)
|
16 |
+
|
17 |
+
class EncoderLayer(nn.Module):
|
18 |
+
def __init__(self, d_model, heads, dropout=0.1):
|
19 |
+
"""An layer of the encoder. Contain a self-attention accepting padding mask
|
20 |
+
Args:
|
21 |
+
d_model: the inner dimension size of the layer
|
22 |
+
heads: number of heads used in the attention
|
23 |
+
dropout: applied dropout value during training
|
24 |
+
"""
|
25 |
+
super().__init__()
|
26 |
+
self.norm_1 = layers.Norm(d_model)
|
27 |
+
self.norm_2 = layers.Norm(d_model)
|
28 |
+
self.attn = layers.MultiHeadAttention(heads, d_model, dropout=dropout)
|
29 |
+
self.ff = layers.FeedForward(d_model, dropout=dropout)
|
30 |
+
self.dropout_1 = nn.Dropout(dropout)
|
31 |
+
self.dropout_2 = nn.Dropout(dropout)
|
32 |
+
|
33 |
+
def forward(self, x, src_mask):
|
34 |
+
"""Run the encoding layer
|
35 |
+
Args:
|
36 |
+
x: the input (either embedding values or previous layer output), should be in shape [batch_size, src_len, d_model]
|
37 |
+
src_mask: the padding mask, should be [batch_size, 1, src_len]
|
38 |
+
Return:
|
39 |
+
an output that have the same shape as input, [batch_size, src_len, d_model]
|
40 |
+
the attention used [batch_size, src_len, src_len]
|
41 |
+
"""
|
42 |
+
x2 = self.norm_1(x)
|
43 |
+
# Self attention only
|
44 |
+
x_sa, sa = self.attn(x2, x2, x2, src_mask)
|
45 |
+
x = x + self.dropout_1(x_sa)
|
46 |
+
x2 = self.norm_2(x)
|
47 |
+
x = x + self.dropout_2(self.ff(x2))
|
48 |
+
return x, sa
|
49 |
+
|
50 |
+
class DecoderLayer(nn.Module):
|
51 |
+
def __init__(self, d_model, heads, dropout=0.1):
|
52 |
+
"""An layer of the decoder. Contain a self-attention that accept no-peeking mask and a normal attention tha t accept padding mask
|
53 |
+
Args:
|
54 |
+
d_model: the inner dimension size of the layer
|
55 |
+
heads: number of heads used in the attention
|
56 |
+
dropout: applied dropout value during training
|
57 |
+
"""
|
58 |
+
super().__init__()
|
59 |
+
self.norm_1 = layers.Norm(d_model)
|
60 |
+
self.norm_2 = layers.Norm(d_model)
|
61 |
+
self.norm_3 = layers.Norm(d_model)
|
62 |
+
|
63 |
+
self.dropout_1 = nn.Dropout(dropout)
|
64 |
+
self.dropout_2 = nn.Dropout(dropout)
|
65 |
+
self.dropout_3 = nn.Dropout(dropout)
|
66 |
+
|
67 |
+
self.attn_1 = layers.MultiHeadAttention(heads, d_model, dropout=dropout)
|
68 |
+
self.attn_2 = layers.MultiHeadAttention(heads, d_model, dropout=dropout)
|
69 |
+
self.ff = layers.FeedForward(d_model, dropout=dropout)
|
70 |
+
|
71 |
+
def forward(self, x, memory, src_mask, trg_mask):
|
72 |
+
"""Run the decoding layer
|
73 |
+
Args:
|
74 |
+
x: the input (either embedding values or previous layer output), should be in shape [batch_size, tgt_len, d_model]
|
75 |
+
memory: the outputs of the encoding section, used for normal attention. [batch_size, src_len, d_model]
|
76 |
+
src_mask: the padding mask for the memory, [batch_size, 1, src_len]
|
77 |
+
tgt_mask: the no-peeking mask for the decoder, [batch_size, tgt_len, tgt_len]
|
78 |
+
Return:
|
79 |
+
an output that have the same shape as input, [batch_size, tgt_len, d_model]
|
80 |
+
the self-attention and normal attention received [batch_size, head, tgt_len, tgt_len] & [batch_size, head, tgt_len, src_len]
|
81 |
+
"""
|
82 |
+
x2 = self.norm_1(x)
|
83 |
+
# Self-attention
|
84 |
+
x_sa, sa = self.attn_1(x2, x2, x2, trg_mask)
|
85 |
+
x = x + self.dropout_1(x_sa)
|
86 |
+
x2 = self.norm_2(x)
|
87 |
+
# Normal multi-head attention
|
88 |
+
x_na, na = self.attn_2(x2, memory, memory, src_mask)
|
89 |
+
x = x + self.dropout_2(x_na)
|
90 |
+
x2 = self.norm_3(x)
|
91 |
+
x = x + self.dropout_3(self.ff(x2))
|
92 |
+
return x, (sa, na)
|
93 |
+
|
94 |
+
def get_clones(module, N, keep_module=True):
|
95 |
+
if(keep_module and N >= 1):
|
96 |
+
# create N-1 copies in addition to the original
|
97 |
+
return nn.ModuleList([module] + [copy.deepcopy(module) for i in range(N-1)])
|
98 |
+
else:
|
99 |
+
# create N new copy
|
100 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
101 |
+
|
102 |
+
class Encoder(nn.Module):
|
103 |
+
"""A wrapper that embed, positional encode, and self-attention encode the inputs.
|
104 |
+
Args:
|
105 |
+
vocab_size: the size of the vocab. Used for embedding
|
106 |
+
d_model: the inner dim of the module
|
107 |
+
N: number of layers used
|
108 |
+
heads: number of heads used in the attention
|
109 |
+
dropout: applied dropout value during training
|
110 |
+
max_seq_length: the maximum length value used for this encoder. Needed for PositionalEncoder, due to caching
|
111 |
+
"""
|
112 |
+
def __init__(self, vocab_size, d_model, N, heads, dropout, max_seq_length=200):
|
113 |
+
super().__init__()
|
114 |
+
self.N = N
|
115 |
+
self.embed = nn.Embedding(vocab_size, d_model)
|
116 |
+
self.pe = layers.PositionalEncoder(d_model, dropout=dropout, max_seq_length=max_seq_length)
|
117 |
+
self.layers = get_clones(EncoderLayer(d_model, heads, dropout), N)
|
118 |
+
self.norm = layers.Norm(d_model)
|
119 |
+
|
120 |
+
self._max_seq_length = max_seq_length
|
121 |
+
|
122 |
+
def forward(self, src, src_mask, output_attention=False, seq_length_check=False):
|
123 |
+
"""Accepts a batch of indexed tokens, return the encoded values.
|
124 |
+
Args:
|
125 |
+
src: int Tensor of [batch_size, src_len]
|
126 |
+
src_mask: the padding mask, [batch_size, 1, src_len]
|
127 |
+
output_attention: if set, output a list containing used attention
|
128 |
+
seq_length_check: if set, automatically trim the input if it goes past the expected sequence length.
|
129 |
+
Returns:
|
130 |
+
the encoded values [batch_size, src_len, d_model]
|
131 |
+
if available, list of N (self-attention) calculated. They are in form of [batch_size, heads, src_len, src_len]
|
132 |
+
"""
|
133 |
+
if(seq_length_check and src.shape[1] > self._max_seq_length):
|
134 |
+
src = src[:, :self._max_seq_length]
|
135 |
+
src_mask = src_mask[:, :, :self._max_seq_length]
|
136 |
+
x = self.embed(src)
|
137 |
+
x = self.pe(x)
|
138 |
+
attentions = [None] * self.N
|
139 |
+
for i in range(self.N):
|
140 |
+
x, attn = self.layers[i](x, src_mask)
|
141 |
+
attentions[i] = attn
|
142 |
+
x = self.norm(x)
|
143 |
+
return x if(not output_attention) else (x, attentions)
|
144 |
+
|
145 |
+
class Decoder(nn.Module):
|
146 |
+
"""A wrapper that receive the encoder outputs, run through the decoder process for a determined input
|
147 |
+
Args:
|
148 |
+
vocab_size: the size of the vocab. Used for embedding
|
149 |
+
d_model: the inner dim of the module
|
150 |
+
N: number of layers used
|
151 |
+
heads: number of heads used in the attention
|
152 |
+
dropout: applied dropout value during training
|
153 |
+
max_seq_length: the maximum length value used for this encoder. Needed for PositionalEncoder, due to caching
|
154 |
+
"""
|
155 |
+
def __init__(self, vocab_size, d_model, N, heads, dropout, max_seq_length=200):
|
156 |
+
super().__init__()
|
157 |
+
self.N = N
|
158 |
+
self.embed = nn.Embedding(vocab_size, d_model)
|
159 |
+
self.pe = layers.PositionalEncoder(d_model, dropout=dropout, max_seq_length=max_seq_length)
|
160 |
+
self.layers = get_clones(DecoderLayer(d_model, heads, dropout), N)
|
161 |
+
self.norm = layers.Norm(d_model)
|
162 |
+
|
163 |
+
self._max_seq_length = max_seq_length
|
164 |
+
|
165 |
+
def forward(self, trg, memory, src_mask, trg_mask, output_attention=False):
|
166 |
+
"""Accepts a batch of indexed tokens and the encoding outputs, return the decoded values.
|
167 |
+
Args:
|
168 |
+
trg: input Tensor of [batch_size, trg_len]
|
169 |
+
memory: output of Encoder [batch_size, src_len, d_model]
|
170 |
+
src_mask: the padding mask, [batch_size, 1, src_len]
|
171 |
+
trg_mask: the no-peeking mask, [batch_size, tgt_len, tgt_len]
|
172 |
+
output_attention: if set, output a list containing used attention
|
173 |
+
Returns:
|
174 |
+
the decoded values [batch_size, tgt_len, d_model]
|
175 |
+
if available, list of N (self-attention, attention) calculated. They are in form of [batch_size, heads, tgt_len, tgt/src_len]
|
176 |
+
"""
|
177 |
+
x = self.embed(trg)
|
178 |
+
x = self.pe(x)
|
179 |
+
|
180 |
+
attentions = [None] * self.N
|
181 |
+
for i in range(self.N):
|
182 |
+
x, attn = self.layers[i](x, memory, src_mask, trg_mask)
|
183 |
+
attentions[i] = attn
|
184 |
+
x = self.norm(x)
|
185 |
+
return x if(not output_attention) else (x, attentions)
|
186 |
+
|
187 |
+
|
188 |
+
class Config:
|
189 |
+
"""Deprecated"""
|
190 |
+
def __init__(self):
|
191 |
+
self.opt = {
|
192 |
+
'train_src_data':'/workspace/khoai23/opennmt/data/iwslt_en_vi/train.en',
|
193 |
+
'train_trg_data':'/workspace/khoai23/opennmt/data/iwslt_en_vi/train.vi',
|
194 |
+
'valid_src_data':'/workspace/khoai23/opennmt/data/iwslt_en_vi/tst2013.en',
|
195 |
+
'valid_trg_data':'/workspace/khoai23/opennmt/data/iwslt_en_vi/tst2013.vi',
|
196 |
+
'src_lang':'en', # useless atm
|
197 |
+
'trg_lang':'en',#'vi_spacy_model', # useless atm
|
198 |
+
'max_strlen':160,
|
199 |
+
'batchsize':1500,
|
200 |
+
'device':'cuda',
|
201 |
+
'd_model': 512,
|
202 |
+
'n_layers': 6,
|
203 |
+
'heads': 8,
|
204 |
+
'dropout': 0.1,
|
205 |
+
'lr':0.0001,
|
206 |
+
'epochs':30,
|
207 |
+
'printevery': 200,
|
208 |
+
'k':5,
|
209 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
nltk
|
3 |
+
torchtext==0.6.0
|
4 |
+
pyvi
|
5 |
+
spacy
|
6 |
+
PyYAML
|
7 |
+
dill
|
8 |
+
pandas
|
9 |
+
laonlp
|
10 |
+
perl
|
third-party/multi-bleu.perl
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env perl
|
2 |
+
#
|
3 |
+
# This file is part of moses. Its use is licensed under the GNU Lesser General
|
4 |
+
# Public License version 2.1 or, at your option, any later version.
|
5 |
+
|
6 |
+
# $Id$
|
7 |
+
use warnings;
|
8 |
+
use strict;
|
9 |
+
|
10 |
+
my $lowercase = 0;
|
11 |
+
if ($ARGV[0] eq "-lc") {
|
12 |
+
$lowercase = 1;
|
13 |
+
shift;
|
14 |
+
}
|
15 |
+
|
16 |
+
my $stem = $ARGV[0];
|
17 |
+
if (!defined $stem) {
|
18 |
+
print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n";
|
19 |
+
print STDERR "Reads the references from reference or reference0, reference1, ...\n";
|
20 |
+
exit(1);
|
21 |
+
}
|
22 |
+
|
23 |
+
$stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0";
|
24 |
+
|
25 |
+
my @REF;
|
26 |
+
my $ref=0;
|
27 |
+
while(-e "$stem$ref") {
|
28 |
+
&add_to_ref("$stem$ref",\@REF);
|
29 |
+
$ref++;
|
30 |
+
}
|
31 |
+
&add_to_ref($stem,\@REF) if -e $stem;
|
32 |
+
die("ERROR: could not find reference file $stem") unless scalar @REF;
|
33 |
+
|
34 |
+
# add additional references explicitly specified on the command line
|
35 |
+
shift;
|
36 |
+
foreach my $stem (@ARGV) {
|
37 |
+
&add_to_ref($stem,\@REF) if -e $stem;
|
38 |
+
}
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
sub add_to_ref {
|
43 |
+
my ($file,$REF) = @_;
|
44 |
+
my $s=0;
|
45 |
+
if ($file =~ /.gz$/) {
|
46 |
+
open(REF,"gzip -dc $file|") or die "Can't read $file";
|
47 |
+
} else {
|
48 |
+
open(REF,$file) or die "Can't read $file";
|
49 |
+
}
|
50 |
+
while(<REF>) {
|
51 |
+
chomp;
|
52 |
+
push @{$$REF[$s++]}, $_;
|
53 |
+
}
|
54 |
+
close(REF);
|
55 |
+
}
|
56 |
+
|
57 |
+
my(@CORRECT,@TOTAL,$length_translation,$length_reference);
|
58 |
+
my $s=0;
|
59 |
+
while(<STDIN>) {
|
60 |
+
chomp;
|
61 |
+
$_ = lc if $lowercase;
|
62 |
+
my @WORD = split;
|
63 |
+
my %REF_NGRAM = ();
|
64 |
+
my $length_translation_this_sentence = scalar(@WORD);
|
65 |
+
my ($closest_diff,$closest_length) = (9999,9999);
|
66 |
+
foreach my $reference (@{$REF[$s]}) {
|
67 |
+
# print "$s $_ <=> $reference\n";
|
68 |
+
$reference = lc($reference) if $lowercase;
|
69 |
+
my @WORD = split(' ',$reference);
|
70 |
+
my $length = scalar(@WORD);
|
71 |
+
my $diff = abs($length_translation_this_sentence-$length);
|
72 |
+
if ($diff < $closest_diff) {
|
73 |
+
$closest_diff = $diff;
|
74 |
+
$closest_length = $length;
|
75 |
+
# print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n";
|
76 |
+
} elsif ($diff == $closest_diff) {
|
77 |
+
$closest_length = $length if $length < $closest_length;
|
78 |
+
# from two references with the same closeness to me
|
79 |
+
# take the *shorter* into account, not the "first" one.
|
80 |
+
}
|
81 |
+
for(my $n=1;$n<=4;$n++) {
|
82 |
+
my %REF_NGRAM_N = ();
|
83 |
+
for(my $start=0;$start<=$#WORD-($n-1);$start++) {
|
84 |
+
my $ngram = "$n";
|
85 |
+
for(my $w=0;$w<$n;$w++) {
|
86 |
+
$ngram .= " ".$WORD[$start+$w];
|
87 |
+
}
|
88 |
+
$REF_NGRAM_N{$ngram}++;
|
89 |
+
}
|
90 |
+
foreach my $ngram (keys %REF_NGRAM_N) {
|
91 |
+
if (!defined($REF_NGRAM{$ngram}) ||
|
92 |
+
$REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) {
|
93 |
+
$REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram};
|
94 |
+
# print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}<BR>\n";
|
95 |
+
}
|
96 |
+
}
|
97 |
+
}
|
98 |
+
}
|
99 |
+
$length_translation += $length_translation_this_sentence;
|
100 |
+
$length_reference += $closest_length;
|
101 |
+
for(my $n=1;$n<=4;$n++) {
|
102 |
+
my %T_NGRAM = ();
|
103 |
+
for(my $start=0;$start<=$#WORD-($n-1);$start++) {
|
104 |
+
my $ngram = "$n";
|
105 |
+
for(my $w=0;$w<$n;$w++) {
|
106 |
+
$ngram .= " ".$WORD[$start+$w];
|
107 |
+
}
|
108 |
+
$T_NGRAM{$ngram}++;
|
109 |
+
}
|
110 |
+
foreach my $ngram (keys %T_NGRAM) {
|
111 |
+
$ngram =~ /^(\d+) /;
|
112 |
+
my $n = $1;
|
113 |
+
# my $corr = 0;
|
114 |
+
# print "$i e $ngram $T_NGRAM{$ngram}<BR>\n";
|
115 |
+
$TOTAL[$n] += $T_NGRAM{$ngram};
|
116 |
+
if (defined($REF_NGRAM{$ngram})) {
|
117 |
+
if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) {
|
118 |
+
$CORRECT[$n] += $T_NGRAM{$ngram};
|
119 |
+
# $corr = $T_NGRAM{$ngram};
|
120 |
+
# print "$i e correct1 $T_NGRAM{$ngram}<BR>\n";
|
121 |
+
}
|
122 |
+
else {
|
123 |
+
$CORRECT[$n] += $REF_NGRAM{$ngram};
|
124 |
+
# $corr = $REF_NGRAM{$ngram};
|
125 |
+
# print "$i e correct2 $REF_NGRAM{$ngram}<BR>\n";
|
126 |
+
}
|
127 |
+
}
|
128 |
+
# $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram};
|
129 |
+
# print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n"
|
130 |
+
}
|
131 |
+
}
|
132 |
+
$s++;
|
133 |
+
}
|
134 |
+
my $brevity_penalty = 1;
|
135 |
+
my $bleu = 0;
|
136 |
+
|
137 |
+
my @bleu=();
|
138 |
+
|
139 |
+
for(my $n=1;$n<=4;$n++) {
|
140 |
+
if (defined ($TOTAL[$n])){
|
141 |
+
$bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0;
|
142 |
+
# print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n";
|
143 |
+
}else{
|
144 |
+
$bleu[$n]=0;
|
145 |
+
}
|
146 |
+
}
|
147 |
+
|
148 |
+
if ($length_reference==0){
|
149 |
+
printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n";
|
150 |
+
exit(1);
|
151 |
+
}
|
152 |
+
|
153 |
+
if ($length_translation<$length_reference) {
|
154 |
+
$brevity_penalty = exp(1-$length_reference/$length_translation);
|
155 |
+
}
|
156 |
+
$bleu = $brevity_penalty * exp((my_log( $bleu[1] ) +
|
157 |
+
my_log( $bleu[2] ) +
|
158 |
+
my_log( $bleu[3] ) +
|
159 |
+
my_log( $bleu[4] ) ) / 4) ;
|
160 |
+
printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n",
|
161 |
+
100*$bleu,
|
162 |
+
100*$bleu[1],
|
163 |
+
100*$bleu[2],
|
164 |
+
100*$bleu[3],
|
165 |
+
100*$bleu[4],
|
166 |
+
$brevity_penalty,
|
167 |
+
$length_translation / $length_reference,
|
168 |
+
$length_translation,
|
169 |
+
$length_reference;
|
170 |
+
|
171 |
+
|
172 |
+
print STDERR "It is in-advisable to publish scores from multi-bleu.perl. The scores depend on your tokenizer, which is unlikely to be reproducible from your paper or consistent across research groups. Instead you should detokenize then use mteval-v14.pl, which has a standard tokenization. Scores from multi-bleu.perl can still be used for internal purposes when you have a consistent tokenizer.\n";
|
173 |
+
|
174 |
+
sub my_log {
|
175 |
+
return -9999999999 unless $_[0];
|
176 |
+
return log($_[0]);
|
177 |
+
}
|
utils/data.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re, os
|
2 |
+
import nltk
|
3 |
+
from nltk.corpus import wordnet
|
4 |
+
import dill as pickle
|
5 |
+
import pandas as pd
|
6 |
+
from torchtext import data
|
7 |
+
from laonlp import tokenize
|
8 |
+
|
9 |
+
def multiple_replace(dict, text):
|
10 |
+
# Create a regular expression from the dictionary keys
|
11 |
+
regex = re.compile("(%s)" % "|".join(map(re.escape, dict.keys())))
|
12 |
+
|
13 |
+
# For each match, look-up corresponding value in dictionary
|
14 |
+
return regex.sub(lambda mo: dict[mo.string[mo.start():mo.end()]], text)
|
15 |
+
|
16 |
+
# get_synonym replace word with any synonym found among src
|
17 |
+
def get_synonym(word, SRC):
|
18 |
+
syns = wordnet.synsets(word)
|
19 |
+
for s in syns:
|
20 |
+
for l in s.lemmas():
|
21 |
+
if SRC.vocab.stoi[l.name()] != 0:
|
22 |
+
return SRC.vocab.stoi[l.name()]
|
23 |
+
|
24 |
+
return 0
|
25 |
+
|
26 |
+
class Tokenizer:
|
27 |
+
def __init__(self, lang=None):
|
28 |
+
if(lang is not None):
|
29 |
+
self.nlp = spacy.load(lang)
|
30 |
+
self.tokenizer_fn = self.nlp.tokenizer
|
31 |
+
else:
|
32 |
+
self.tokenizer_fn = lambda l: l.strip().split()
|
33 |
+
|
34 |
+
# def tokenize(self, sentence):
|
35 |
+
# sentence = re.sub(
|
36 |
+
# r"[\*\"“”\n\\…\+\-\/\=\(\)‘•:\[\]\|’\!;]", " ", str(sentence))
|
37 |
+
# sentence = re.sub(r"[ ]+", " ", sentence)
|
38 |
+
# sentence = re.sub(r"\!+", "!", sentence)
|
39 |
+
# sentence = re.sub(r"\,+", ",", sentence)
|
40 |
+
# sentence = re.sub(r"\?+", "?", sentence)
|
41 |
+
# sentence = sentence.lower()
|
42 |
+
# return [tok.text for tok in self.tokenizer_fn(sentence) if tok.text != " "]
|
43 |
+
|
44 |
+
|
45 |
+
def read_data(src_file, trg_file):
|
46 |
+
src_data = open(src_file).read().strip().split('\n')
|
47 |
+
|
48 |
+
trg_data = open(trg_file).read().strip().split('\n')
|
49 |
+
|
50 |
+
return src_data, trg_data
|
51 |
+
|
52 |
+
|
53 |
+
def read_file(file_dir):
|
54 |
+
f = open(file_dir, 'r')
|
55 |
+
data = f.read().strip().split('\n')
|
56 |
+
return data
|
57 |
+
|
58 |
+
def write_file(file_dir, content):
|
59 |
+
f = open(file_dir, "w")
|
60 |
+
f.write(content)
|
61 |
+
f.close()
|
62 |
+
|
63 |
+
def create_fields(src_lang, trg_lang):
|
64 |
+
|
65 |
+
#print("loading spacy tokenizers...")
|
66 |
+
#
|
67 |
+
# t_src = tokenize(src_lang)
|
68 |
+
# t_trg = tokenize(trg_lang)
|
69 |
+
# t_src_tokenizer = t_trg_tokenizer = lambda x: x.strip().split()
|
70 |
+
target_tokenizer = lambda x: x.strip().split()
|
71 |
+
|
72 |
+
TRG = data.Field(lower=True, tokenize=target_tokenizer, init_token='<sos>', eos_token='<eos>')
|
73 |
+
SRC = data.Field(lower=True, tokenize=tokenize.word_tokenize)
|
74 |
+
|
75 |
+
return SRC, TRG
|
76 |
+
|
77 |
+
def create_dataset(src_data, trg_data, max_strlen, batchsize, device, SRC, TRG, istrain=True):
|
78 |
+
|
79 |
+
print("creating dataset and iterator... ")
|
80 |
+
|
81 |
+
raw_data = {'src' : [line for line in src_data], 'trg': [line for line in trg_data]}
|
82 |
+
df = pd.DataFrame(raw_data, columns=["src", "trg"])
|
83 |
+
|
84 |
+
mask = (df['src'].str.count(' ') < max_strlen) & (df['trg'].str.count(' ') < max_strlen)
|
85 |
+
df = df.loc[mask]
|
86 |
+
|
87 |
+
df.to_csv("translate_transformer_temp.csv", index=False)
|
88 |
+
|
89 |
+
data_fields = [('src', SRC), ('trg', TRG)]
|
90 |
+
train = data.TabularDataset('./translate_transformer_temp.csv', format='csv', fields=data_fields)
|
91 |
+
|
92 |
+
train_iter = MyIterator(train, batch_size=batchsize, device=device,
|
93 |
+
repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
|
94 |
+
batch_size_fn=batch_size_fn, train=istrain, shuffle=True)
|
95 |
+
|
96 |
+
os.remove('translate_transformer_temp.csv')
|
97 |
+
|
98 |
+
if istrain:
|
99 |
+
SRC.build_vocab(train)
|
100 |
+
TRG.build_vocab(train)
|
101 |
+
|
102 |
+
return train_iter
|
103 |
+
|
104 |
+
class MyIterator(data.Iterator):
|
105 |
+
def create_batches(self):
|
106 |
+
if self.train:
|
107 |
+
def pool(d, random_shuffler):
|
108 |
+
for p in data.batch(d, self.batch_size * 100):
|
109 |
+
p_batch = data.batch(
|
110 |
+
sorted(p, key=self.sort_key),
|
111 |
+
self.batch_size, self.batch_size_fn)
|
112 |
+
for b in random_shuffler(list(p_batch)):
|
113 |
+
yield b
|
114 |
+
self.batches = pool(self.data(), self.random_shuffler)
|
115 |
+
|
116 |
+
else:
|
117 |
+
self.batches = []
|
118 |
+
for b in data.batch(self.data(), self.batch_size,
|
119 |
+
self.batch_size_fn):
|
120 |
+
self.batches.append(sorted(b, key=self.sort_key))
|
121 |
+
|
122 |
+
global max_src_in_batch, max_tgt_in_batch
|
123 |
+
|
124 |
+
def batch_size_fn(new, count, sofar):
|
125 |
+
"Keep augmenting batch and calculate total number of tokens + padding."
|
126 |
+
global max_src_in_batch, max_tgt_in_batch
|
127 |
+
if count == 1:
|
128 |
+
max_src_in_batch = 0
|
129 |
+
max_tgt_in_batch = 0
|
130 |
+
max_src_in_batch = max(max_src_in_batch, len(new.src))
|
131 |
+
max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2)
|
132 |
+
src_elements = count * max_src_in_batch
|
133 |
+
tgt_elements = count * max_tgt_in_batch
|
134 |
+
return max(src_elements, tgt_elements)
|
135 |
+
|
136 |
+
def generate_language_token(lang: str):
|
137 |
+
return '<{}>'.format(lang.strip())
|
utils/decode_old.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
from torch.autograd import Variable
|
5 |
+
import torch.nn.functional as functional
|
6 |
+
|
7 |
+
from utils.data import multiple_replace, get_synonym
|
8 |
+
|
9 |
+
def no_peeking_mask(size, device):
|
10 |
+
"""
|
11 |
+
Tạo mask được sử dụng trong decoder để lúc dự đoán trong quá trình huấn luyện
|
12 |
+
mô hình không nhìn thấy được các từ ở tương lai
|
13 |
+
"""
|
14 |
+
np_mask = np.triu(np.ones((1, size, size)),
|
15 |
+
k=1).astype('uint8')
|
16 |
+
np_mask = Variable(torch.from_numpy(np_mask) == 0)
|
17 |
+
np_mask = np_mask.to(device)
|
18 |
+
|
19 |
+
return np_mask
|
20 |
+
|
21 |
+
def create_masks(src, trg, src_pad, trg_pad, device):
|
22 |
+
""" Tạo mask cho encoder,
|
23 |
+
để mô hình không bỏ qua thông tin của các kí tự PAD do chúng ta thêm vào
|
24 |
+
"""
|
25 |
+
src_mask = (src != src_pad).unsqueeze(-2)
|
26 |
+
|
27 |
+
if trg is not None:
|
28 |
+
trg_mask = (trg != trg_pad).unsqueeze(-2)
|
29 |
+
size = trg.size(1) # get seq_len for matrix
|
30 |
+
np_mask = no_peeking_mask(size, device)
|
31 |
+
if trg.is_cuda:
|
32 |
+
np_mask.cuda()
|
33 |
+
trg_mask = trg_mask & np_mask
|
34 |
+
|
35 |
+
else:
|
36 |
+
trg_mask = None
|
37 |
+
return src_mask, trg_mask
|
38 |
+
|
39 |
+
def init_vars(src, model, SRC, TRG, device, k, max_len):
|
40 |
+
""" Tính toán các ma trận cần thiết trong quá trình translation sau khi mô hình học xong
|
41 |
+
"""
|
42 |
+
init_tok = TRG.vocab.stoi['<sos>']
|
43 |
+
src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
|
44 |
+
|
45 |
+
# tính sẵn output của encoder
|
46 |
+
e_output = model.encoder(src, src_mask)
|
47 |
+
|
48 |
+
outputs = torch.LongTensor([[init_tok]])
|
49 |
+
|
50 |
+
outputs = outputs.to(device)
|
51 |
+
|
52 |
+
trg_mask = no_peeking_mask(1, device)
|
53 |
+
# dự đoán kí tự đầu tiên
|
54 |
+
out = model.out(model.decoder(outputs,
|
55 |
+
e_output, src_mask, trg_mask))
|
56 |
+
out = functional.softmax(out, dim=-1)
|
57 |
+
|
58 |
+
probs, ix = out[:, -1].data.topk(k)
|
59 |
+
log_scores = torch.Tensor([math.log(prob) for prob in probs.data[0]]).unsqueeze(0)
|
60 |
+
|
61 |
+
outputs = torch.zeros(k, max_len).long()
|
62 |
+
outputs = outputs.to(device)
|
63 |
+
outputs[:, 0] = init_tok
|
64 |
+
outputs[:, 1] = ix[0]
|
65 |
+
|
66 |
+
e_outputs = torch.zeros(k, e_output.size(-2),e_output.size(-1))
|
67 |
+
|
68 |
+
e_outputs = e_outputs.to(device)
|
69 |
+
e_outputs[:, :] = e_output[0]
|
70 |
+
|
71 |
+
return outputs, e_outputs, log_scores
|
72 |
+
|
73 |
+
def k_best_outputs(outputs, out, log_scores, i, k):
|
74 |
+
# debug print
|
75 |
+
|
76 |
+
probs, ix = out[:, -1].data.topk(k)
|
77 |
+
log_probs = torch.Tensor([math.log(p) for p in probs.data.view(-1)]).view(k, -1) + log_scores.transpose(0,1)
|
78 |
+
k_probs, k_ix = log_probs.view(-1).topk(k)
|
79 |
+
|
80 |
+
row = k_ix // k
|
81 |
+
col = k_ix % k
|
82 |
+
|
83 |
+
outputs[:, :i] = outputs[row, :i]
|
84 |
+
outputs[:, i] = ix[row, col]
|
85 |
+
|
86 |
+
log_scores = k_probs.unsqueeze(0)
|
87 |
+
|
88 |
+
return outputs, log_scores
|
89 |
+
|
90 |
+
def beam_search(src, model, SRC, TRG, device, k, max_len, debug=False, output_list_of_tokens=False):
|
91 |
+
|
92 |
+
outputs, e_outputs, log_scores = init_vars(src, model, SRC, TRG, device, k, max_len)
|
93 |
+
eos_tok = TRG.vocab.stoi['<eos>']
|
94 |
+
src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
|
95 |
+
ind = None
|
96 |
+
for i in range(2, max_len):
|
97 |
+
if(debug):
|
98 |
+
print("Current iteration to maxlen: {:d}".format(i))
|
99 |
+
|
100 |
+
trg_mask = no_peeking_mask(i, device)
|
101 |
+
|
102 |
+
out = model.out(model.decoder(outputs[:,:i], e_outputs, src_mask, trg_mask))
|
103 |
+
|
104 |
+
out = functional.softmax(out, dim=-1)
|
105 |
+
|
106 |
+
outputs, log_scores = k_best_outputs(outputs, out, log_scores, i, k)
|
107 |
+
|
108 |
+
ones = (outputs==eos_tok).nonzero() # Occurrences of end symbols for all input sentences.
|
109 |
+
sentence_lengths = torch.zeros(len(outputs), dtype=torch.long).to(device)
|
110 |
+
for vec in ones:
|
111 |
+
i = vec[0]
|
112 |
+
if sentence_lengths[i]==0: # First end symbol has not been found yet
|
113 |
+
sentence_lengths[i] = vec[1] # Position of first end symbol
|
114 |
+
|
115 |
+
num_finished_sentences = len([s for s in sentence_lengths if s > 0])
|
116 |
+
|
117 |
+
if num_finished_sentences == k:
|
118 |
+
alpha = 0.7
|
119 |
+
div = 1/(sentence_lengths.type_as(log_scores)**alpha)
|
120 |
+
_, ind = torch.max(log_scores * div, 1)
|
121 |
+
ind = ind.data[0]
|
122 |
+
break
|
123 |
+
|
124 |
+
# additional change to output list of tokens instead of string
|
125 |
+
join_fn = (lambda x: x) if(output_list_of_tokens) else (lambda x: " ".join(x))
|
126 |
+
|
127 |
+
if ind is None:
|
128 |
+
length = (outputs[0]==eos_tok).nonzero()[0] if len((outputs[0]==eos_tok).nonzero()) > 0 else -1
|
129 |
+
return join_fn([TRG.vocab.itos[tok] for tok in outputs[0, 1:length]])
|
130 |
+
else:
|
131 |
+
length = (outputs[ind]==eos_tok).nonzero()[0]
|
132 |
+
return join_fn([TRG.vocab.itos[tok] for tok in outputs[ind, 1:length]])
|
133 |
+
|
134 |
+
def translate_sentence(raw_sentence, model, SRC, TRG, device, k, max_len, debug=False, output_list_of_tokens=False):
|
135 |
+
"""Dịch một câu sử dụng beamsearch
|
136 |
+
"""
|
137 |
+
model.eval()
|
138 |
+
indexed = []
|
139 |
+
if(isinstance(raw_sentence, str)):
|
140 |
+
# single sentence, require preprocessing
|
141 |
+
sentence = SRC.preprocess(raw_sentence)
|
142 |
+
else:
|
143 |
+
# already tokenized (taken from iterators, etc.)
|
144 |
+
sentence = raw_sentence
|
145 |
+
|
146 |
+
for tok in sentence:
|
147 |
+
if SRC.vocab.stoi[tok] != SRC.vocab.stoi['<eos>']:
|
148 |
+
indexed.append(SRC.vocab.stoi[tok])
|
149 |
+
else:
|
150 |
+
indexed.append(get_synonym(tok, SRC))
|
151 |
+
|
152 |
+
output = Variable(torch.LongTensor([indexed]))
|
153 |
+
|
154 |
+
output = output.to(device)
|
155 |
+
|
156 |
+
output = beam_search(output, model, SRC, TRG, device, k, max_len, output_list_of_tokens=output_list_of_tokens)
|
157 |
+
|
158 |
+
if(debug):
|
159 |
+
print("{} -> {}".format(raw_sentence, output))
|
160 |
+
|
161 |
+
return output
|
162 |
+
|
163 |
+
# return multiple_replace({' ?' : '?',' !':'!',' .':'.','\' ':'\'',' ,':','}, sentence)
|
utils/logging.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
|
3 |
+
import os
|
4 |
+
import logging
|
5 |
+
|
6 |
+
def init_logger(model_dir, log_file=None, rotate=False):
|
7 |
+
|
8 |
+
logging.basicConfig(level=logging.DEBUG,
|
9 |
+
format='[%(asctime)s %(levelname)s] %(message)s',
|
10 |
+
datefmt='%a, %d %b %Y %H:%M:%S',
|
11 |
+
filename=os.path.join(model_dir, log_file),
|
12 |
+
filemode='w')
|
13 |
+
console = logging.StreamHandler()
|
14 |
+
console.setLevel(logging.INFO)
|
15 |
+
# set a format which is simpler for console use
|
16 |
+
formatter = logging.Formatter('[%(asctime)s %(levelname)s] %(message)s', '%a, %d %b %Y %H:%M:%S')
|
17 |
+
# tell the handler to use this format
|
18 |
+
console.setFormatter(formatter)
|
19 |
+
# add the handler to the root logger
|
20 |
+
logging.getLogger('').addHandler(console)
|
21 |
+
|
22 |
+
return logging
|
utils/loss.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
class LabelSmoothingLoss(nn.Module):
|
5 |
+
def __init__(self, classes, padding_idx, smoothing=0.0, dim=-1):
|
6 |
+
super(LabelSmoothingLoss, self).__init__()
|
7 |
+
self.confidence = 1.0 - smoothing
|
8 |
+
self.smoothing = smoothing
|
9 |
+
self.cls = classes
|
10 |
+
self.dim = dim
|
11 |
+
self.padding_idx = padding_idx
|
12 |
+
|
13 |
+
def forward(self, pred, target):
|
14 |
+
pred = pred.log_softmax(dim=self.dim)
|
15 |
+
with torch.no_grad():
|
16 |
+
# true_dist = pred.data.clone()
|
17 |
+
true_dist = torch.zeros_like(pred)
|
18 |
+
true_dist.fill_(self.smoothing / (self.cls - 2))
|
19 |
+
true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
|
20 |
+
true_dist[:, self.padding_idx] = 0
|
21 |
+
mask = torch.nonzero(target.data == self.padding_idx) #, as_tuple=False is redundant and causing error
|
22 |
+
if mask.dim() > 0:
|
23 |
+
true_dist.index_fill_(0, mask.squeeze(), 0.0)
|
24 |
+
|
25 |
+
return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
|
utils/metric.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torchtext.data.metrics import bleu_score
|
2 |
+
|
3 |
+
def bleu(valid_src_data, valid_trg_data, model, device, k, max_strlen):
|
4 |
+
pred_sents = []
|
5 |
+
for sentence in valid_src_data:
|
6 |
+
pred_trg = model.translate_sentence(sentence, device, k, max_strlen)
|
7 |
+
pred_sents.append(pred_trg)
|
8 |
+
|
9 |
+
pred_sents = [self.TRG.preprocess(sent) for sent in pred_sents]
|
10 |
+
trg_sents = [[sent.split()] for sent in valid_trg_data]
|
11 |
+
|
12 |
+
return bleu_score(pred_sents, trg_sents)
|
13 |
+
|
14 |
+
def bleu_single(model, valid_dataset, debug=False):
|
15 |
+
"""Perform single sentence translation, then calculate bleu score. Update when batch beam search is online"""
|
16 |
+
# need to join the sentence back per sample (the iterator is the one that had already been split to tokens)
|
17 |
+
# THIS METRIC USE 2D vs 3D! AAAAAAHHHHHHH!!!!
|
18 |
+
translate_pair = ( ([pair.trg], model.translate_sentence(pair.src, debug=debug)) for pair in valid_dataset)
|
19 |
+
# raise Exception(next(translate_pair))
|
20 |
+
labels, predictions = [list(l) for l in zip(*translate_pair)] # zip( *((l, p.split()) for l, p in translate_pair) )
|
21 |
+
return bleu_score(predictions, labels)
|
22 |
+
|
23 |
+
def bleu_batch(model, valid_dataset, batch_size, debug=False):
|
24 |
+
"""Perform batch sentence translation in the same vein."""
|
25 |
+
predictions = model.translate_batch_sentence([s.src for s in valid_dataset], output_tokens=True, batch_size=batch_size)
|
26 |
+
labels = [[s.trg] for s in valid_dataset]
|
27 |
+
return bleu_score(predictions, labels)
|
28 |
+
|
29 |
+
|
30 |
+
def _revert_trg(sent, eos): # revert batching process on sentence
|
31 |
+
try:
|
32 |
+
endloc = sent.index(eos)
|
33 |
+
return sent[1:endloc]
|
34 |
+
except ValueError:
|
35 |
+
return sent[1:]
|
36 |
+
|
37 |
+
def bleu_batch_iter(model, valid_iter, src_lang=None, trg_lang=None, eos_token="<eos>", debug=False):
|
38 |
+
"""Perform batched translations; other metrics are the same. Note that the inputs/outputs had been preprocessed, but have [length, batch_size] shape as per BucketIterator"""
|
39 |
+
# raise NotImplementedError("Error during calculation, currently unusable.")
|
40 |
+
# raise Exception([[model.SRC.vocab.itos[t] for t in batch] for batch in next(iter(valid_iter)).src.transpose(0, 1)])
|
41 |
+
|
42 |
+
translated_batched_pair = (
|
43 |
+
(
|
44 |
+
pair.trg.transpose(0, 1), # transpose due to timestep-first batches
|
45 |
+
model.decode_strategy.translate_batch_sentence(
|
46 |
+
pair.src.transpose(0, 1),
|
47 |
+
src_lang=src_lang,
|
48 |
+
trg_lang=trg_lang,
|
49 |
+
output_tokens=True,
|
50 |
+
field_processed=True,
|
51 |
+
replace_unk=False, # do not replace in this version
|
52 |
+
debug=debug
|
53 |
+
)
|
54 |
+
)
|
55 |
+
for pair in valid_iter
|
56 |
+
)
|
57 |
+
flattened_pair = ( ([model.TRG.vocab.itos[i] for i in trg], pred) for batch_trg, batch_pred in translated_batched_pair for trg, pred in zip(batch_trg, batch_pred) )
|
58 |
+
flat_labels, predictions = [list(l) for l in zip(*flattened_pair)]
|
59 |
+
labels = [[_revert_trg(l, eos_token)] for l in flat_labels] # remove <sos> and <eos> also updim the trg for 3D requirements.
|
60 |
+
return bleu_score(predictions, labels)
|