hieungo1410 committed on
Commit
8cb4f3b
1 Parent(s): cbb24e0
Files changed (50)
  1. .gitignore +6 -0
  2. README.md +156 -13
  3. bin/__init__.py +1 -0
  4. bin/main.py +73 -0
  5. bin/serve.py +108 -0
  6. config/bilingual_prototype.yml +52 -0
  7. config/prototype.json +25 -0
  8. layers/__init__.py +1 -0
  9. layers/prototypes.py +148 -0
  10. models/__init__.py +7 -0
  11. models/default.py +13 -0
  12. models/transformer.py +404 -0
  13. modules/__init__.py +3 -0
  14. modules/config.py +62 -0
  15. modules/constants.py +18 -0
  16. modules/default.py +54 -0
  17. modules/inference/__init__.py +10 -0
  18. modules/inference/__pycache__/__init__.cpython-36.pyc +0 -0
  19. modules/inference/__pycache__/beam_search.cpython-36.pyc +0 -0
  20. modules/inference/__pycache__/decode_strategy.cpython-36.pyc +0 -0
  21. modules/inference/__pycache__/prototypes.cpython-36.pyc +0 -0
  22. modules/inference/__pycache__/sampling_temperature.cpython-36.pyc +0 -0
  23. modules/inference/beam_search.py +336 -0
  24. modules/inference/beam_search1.py +346 -0
  25. modules/inference/decode_strategy.py +62 -0
  26. modules/inference/greedy_search.py +121 -0
  27. modules/inference/prototypes.py +144 -0
  28. modules/inference/sampling_temperature.py +119 -0
  29. modules/loader/__init__.py +4 -0
  30. modules/loader/__pycache__/__init__.cpython-36.pyc +0 -0
  31. modules/loader/__pycache__/default_loader.cpython-36.pyc +0 -0
  32. modules/loader/__pycache__/multilingual_loader.cpython-36.pyc +0 -0
  33. modules/loader/default_loader.py +114 -0
  34. modules/loader/multilingual_loader.py +139 -0
  35. modules/optim/__init__.py +5 -0
  36. modules/optim/__pycache__/__init__.cpython-36.pyc +0 -0
  37. modules/optim/__pycache__/adabelief.cpython-36.pyc +0 -0
  38. modules/optim/__pycache__/adam.cpython-36.pyc +0 -0
  39. modules/optim/__pycache__/scheduler.cpython-36.pyc +0 -0
  40. modules/optim/adabelief.py +42 -0
  41. modules/optim/adam.py +38 -0
  42. modules/optim/scheduler.py +56 -0
  43. modules/prototypes.py +209 -0
  44. requirements.txt +10 -0
  45. third-party/multi-bleu.perl +177 -0
  46. utils/data.py +137 -0
  47. utils/decode_old.py +163 -0
  48. utils/logging.py +22 -0
  49. utils/loss.py +25 -0
  50. utils/metric.py +60 -0
.gitignore ADDED
@@ -0,0 +1,6 @@
1
+ bin/__pycache__
2
+ layers/__pycache__
3
+ models/__pycache__
4
+ modules/__pycache__
5
+ utils/__pycache__
6
+ data/raw/preprocess.ipynb
README.md CHANGED
@@ -1,13 +1,156 @@
1
- ---
2
- title: NMT LaVi
3
- emoji: 🐢
4
- colorFrom: red
5
- colorTo: pink
6
- sdk: streamlit
7
- sdk_version: 1.28.2
8
- app_file: app.py
9
- pinned: false
10
- license: unknown
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ MultilingualMT-UET-KC4.0 is an open-source project developed by the UETNLPLab group.
+
+ # Setup
+ ## Installing the Multilingual-NMT toolkit
+
+ **Note**:
+ The current version is only compatible with python>=3.6.
+
+ ```bash
+ git clone https://github.com/KCDichDaNgu/KC4.0_MultilingualNMT.git
+ cd KC4.0_MultilingualNMT
+ pip install -r requirements.txt
+ ```
+
+ # Quickstart
+
+ ## Step 1: Prepare the data
+
+ The example experiment uses the English-Vietnamese pair from IWSLT, with 133k sentence pairs:
+
+ ```bash
+ cd data/iwslt_en_vi
+ ```
+
+ The data consists of source (`src`) and target (`tgt`) sentences that have already been tokenized:
+
+ * `train.en`
+ * `train.vi`
+ * `tst2012.en`
+ * `tst2012.vi`
+
+ | Data set | Sentences | Download |
+ | :---------: | :--------: | :-------------------------------------------: |
+ | Training | 133,317 | via GitHub or located in data/train-en-vi.tgz |
+ | Development | 1,553 | via GitHub or located in data/train-en-vi.tgz |
+ | Test | 1,268 | via GitHub or located in data/train-en-vi.tgz |
+
+
+ **Note**:
+ - The data must be tokenized before it is used for training.
+ - $CONFIG is the path to the directory containing the config file.
+
+ Hold out a dev set to measure convergence during training; it should normally be no larger than 5k sentences.
+
+ ```text
+ $ head -n 5 data/iwslt_en_vi/train.en
+ Rachel Pike : The science behind a climate headline
+ In 4 minutes , atmospheric chemist Rachel Pike provides a glimpse of the massive scientific effort behind the bold headlines on climate change , with her team -- one of thousands who contributed -- taking a risky flight over the rainforest in pursuit of data on a key molecule .
+ I 'd like to talk to you today about the scale of the scientific effort that goes into making the headlines you see in the paper .
+ Headlines that look like this when they have to do with climate change , and headlines that look like this when they have to do with air quality or smog .
+ They are both two branches of the same field of atmospheric science .
+ ```
+
+ ## Step 2: Train the model
+
+ To train a new model, **edit the YAML config file**.
+ Adjust the hyperparameters and the paths to the training data in `en_vi.yml`:
+
+ ```yaml
+ # data location and config section
+ data:
+   train_data_location: data/iwslt_en_vi/train
+   eval_data_location: data/iwslt_en_vi/tst2013
+   src_lang: .en
+   trg_lang: .vi
+ log_file_models: 'model.log'
+ lowercase: false
+ build_vocab_kwargs: # additional arguments for build_vocab. See torchtext.vocab.Vocab for more details
+   # max_size: 50000
+   min_freq: 5
+ # model parameters section
+ device: cuda
+ d_model: 512
+ n_layers: 6
+ heads: 8
+ # inference section
+ eval_batch_size: 8
+ decode_strategy: BeamSearch
+ decode_strategy_kwargs:
+   beam_size: 5 # beam search size
+   length_normalize: 0.6 # recalculate beam position by length. Currently only works in the default BeamSearch
+   replace_unk: # tuple of layer/head attention to replace unknown words
+     - 0 # layer
+     - 0 # head
+ input_max_length: 200 # inputs longer than this are trimmed in inference. This value also sizes the cached PE, so a validation set with longer samples will trigger a trimming warning.
+ max_length: 160 # only perform up to this many timesteps during inference
+ train_max_length: 50 # training samples longer than this in src/trg will be discarded
+ # optimizer and learning arguments section
+ lr: 0.2
+ optimizer: AdaBelief
+ optimizer_params:
+   betas:
+     - 0.9 # beta1
+     - 0.98 # beta2
+   eps: !!float 1e-9
+ n_warmup_steps: 4000
+ label_smoothing: 0.1
+ dropout: 0.1
+ # training config, evaluation, save & load section
+ batch_size: 64
+ epochs: 20
+ printevery: 200
+ save_checkpoint_epochs: 1
+ maximum_saved_model_eval: 5
+ maximum_saved_model_train: 5
+ ```
+
+ Then start training with:
+
+ ```bash
+ python -m bin.main train --model Transformer --model_dir $MODEL/en-vi.model --config $CONFIG/en_vi.yml
+ ```
+
+ **Note**:
+ - $MODEL is the path where the model is saved. After training, the model directory contains the trained model, the config file, the log file and the vocab.
+ - $CONFIG is the path to the directory containing the config file.
+
+ ## Step 3: Translate
+
+ The model translates with beam search and saves the output to `$your_data_path/translate.en2vi.vi`.
+
+ ```bash
+ python -m bin.main infer --model Transformer --model_dir $MODEL/en-vi.model --features_file $your_data_path/tst2012.en --predictions_file $your_data_path/translate.en2vi.vi
+ ```
+
+ ## Step 4: Evaluate translation quality with BLEU
+
+ BLEU is computed with multi-bleu (the reference file is the argument; the translation is read from stdin):
+
+ ```bash
+ perl third-party/multi-bleu.perl $your_data_path/tst2012.vi < $your_data_path/translate.en2vi.vi
+ ```
+
+ | MODEL | BLEU (Beam Search) |
+ | :-----------------:| :----------------: |
+ | Transformer (Base) | 25.64 |
+
+
+ ## Further details are available at
+ [nmtuet.ddns.net](http://nmtuet.ddns.net:1190/)
+
+ ## For feedback and suggestions, please email kcdichdangu@gmail.com
+
+ ## Please cite the following paper:
+ ```bibtex
+ @inproceedings{ViNMT2022,
+   title = {ViNMT: Neural Machine Translation Toolkit},
+   author = {Nguyen Hoang Quan, Nguyen Thanh Dat, Nguyen Hoang Minh Cong, Nguyen Van Vinh, Ngo Thi Vinh, Nguyen Phuong Thai, Tran Hong Viet},
+   booktitle = {https://arxiv.org/abs/2112.15272},
+   year = {2022},
+ }
+ ```
bin/__init__.py ADDED
@@ -0,0 +1 @@
1
+ import bin.main as main
bin/main.py ADDED
@@ -0,0 +1,73 @@
1
+ import models
2
+ import argparse, os
3
+ from shutil import copy2 as copy
4
+ from modules.config import find_all_config
5
+
6
+ OVERRIDE_RUN_MODE = {"serve": "infer", "debug": "eval"}
7
+
8
+ def check_valid_file(path):
9
+ if(os.path.isfile(path)):
10
+ return path
11
+ else:
12
+ raise argparse.ArgumentTypeError("This path {:s} is not a valid file, check again.".format(path))
13
+
14
+ def create_torchscript_model(model, model_dir, model_name):
15
+ """Create a torchscript model using junk data. NOTE: same as tensorflow, is a limited model with no native python script."""
16
+ import torch
17
+ junk_input = torch.rand(2, 10)
18
+ junk_output = torch.rand(2, 7)
19
+ traced_model = torch.jit.trace(model, (junk_input, junk_output)) # example inputs must be passed as a single tuple
20
+ save_location = os.path.join(model_dir, model_name)
21
+ traced_model.save(save_location)
22
+
23
+ if __name__ == "__main__":
24
+ parser = argparse.ArgumentParser(description="Main argument parser")
25
+ parser.add_argument("run_mode", choices=("train", "eval", "infer", "debug", "serve"), help="Main running mode of the program")
26
+ parser.add_argument("--model", type=str, choices=models.AvailableModels.keys(), help="The type of model to be ran")
27
+ parser.add_argument("--model_dir", type=str, required=True, help="Location of model")
28
+ parser.add_argument("--config", type=str, nargs="+", default=None, help="Location of the config file")
29
+ parser.add_argument("--no_keeping_config", action="store_false", help="If set, do not copy the config file to the model directory")
30
+ # arguments for inference
31
+ parser.add_argument("--features_file", type=str, help="Inference mode: Provide the location of features file")
32
+ parser.add_argument("--predictions_file", type=str, help="Inference mode: Provide Location of output file which is predicted from features file")
33
+ parser.add_argument("--src_lang", type=str, help="Inference mode: Provide language used by source file")
34
+ parser.add_argument("--trg_lang", type=str, default=None, help="Inference mode: Choose language that is translated from source file. NOTE: only specify for multilingual model")
35
+ parser.add_argument("--infer_batch_size", type=int, default=None, help="Specify the batch_size to run the model with. Default use the config value.")
36
+ parser.add_argument("--checkpoint", type=str, default=None, help="All mode: specify to load the checkpoint into model.")
37
+ parser.add_argument("--checkpoint_idx", type=int, default=0, help="All mode: specify the epoch of the checkpoint loaded. Only useful for training.")
38
+ parser.add_argument("--serve_path", type=str, default=None, help="File to save TorchScript model into.")
39
+
40
+ args = parser.parse_args()
41
+ # create directory if not exist
42
+ os.makedirs(args.model_dir, exist_ok=True)
43
+ config_path = args.config
44
+ if(config_path is None):
45
+ config_path = find_all_config(args.model_dir)
46
+ print("Config path not specified, load the configs in model directory which is {}".format(config_path))
47
+ elif(args.no_keeping_config):
48
+ # store_false flag: defaults to True, so configs are copied unless --no_keeping_config is given
49
+ print("Config specified, copying all to model dir")
50
+ for subpath in config_path:
51
+ copy(subpath, args.model_dir)
52
+
53
+ # load model. Specific run mode required converting
54
+ run_mode = OVERRIDE_RUN_MODE.get(args.run_mode, args.run_mode)
55
+ model = models.AvailableModels[args.model](config=config_path, model_dir=args.model_dir, mode=run_mode)
56
+ model.load_checkpoint(args.model_dir, checkpoint=args.checkpoint, checkpoint_idx=args.checkpoint_idx)
57
+ # run model
58
+ run_mode = args.run_mode
59
+ if(run_mode == "train"):
60
+ model.run_train(model_dir=args.model_dir, config=config_path)
61
+ elif(run_mode == "eval"):
62
+ model.run_eval(model_dir=args.model_dir, config=config_path)
63
+ elif(run_mode == "infer"):
64
+ model.run_infer(args.features_file, args.predictions_file, src_lang=args.src_lang, trg_lang=args.trg_lang, config=config_path, batch_size=args.infer_batch_size)
65
+ elif(run_mode == "debug"):
66
+ raise NotImplementedError
67
+ model.run_debug(model_dir=args.model_dir, config=config_path)
68
+ elif(run_mode == "serve"):
69
+ if(args.serve_path is None):
70
+ parser.error("In serving, --serve_path cannot be empty")
71
+ model.prepare_serve(args.serve_path, model_dir=args.model_dir, config=config_path)
72
+ else:
73
+ raise ValueError("Run mode {:s} not implemented.".format(run_mode))
bin/serve.py ADDED
@@ -0,0 +1,108 @@
1
+ import os, io
2
+ import torch
3
+ import utils.save as saver
4
+ import models
5
+ from models.transformer import Transformer
6
+ from modules.config import find_all_config
7
+
8
+ class TransformerHandlerClass:
9
+ def __init__(self):
10
+ self.model = None
11
+ self.device = None
12
+ self.initialized = False
13
+
14
+ def _find_checkpoint(self, model_dir, best_model_prefix="best_model", model_prefix="model", validate=True):
15
+ """Attempt to retrieve the best model checkpoint from model_dir. Failing that, the model of the latest iteration.
16
+ Args:
17
+ model_dir: location to search for checkpoint. str
18
+ Returns:
19
+ single str denoting the checkpoint path """
20
+ score_file_path = os.path.join(model_dir, saver.BEST_MODEL_FILE)
21
+ if(os.path.isfile(score_file_path)): # score exist -> best model
22
+ best_model_path = os.path.join(model_dir, saver.MODEL_FILE_FORMAT.format(best_model_prefix, 0, saver.MODEL_EXTENSION))
23
+ if(validate):
24
+ assert os.path.isfile(best_model_path), "Score file is available, but file {:s} is missing.".format(best_model_path)
25
+ return best_model_path
26
+ else: # score not exist -> latest model
27
+ last_checkpoint_idx = saver.check_model_in_path(model_dir, name_prefix=model_prefix)
28
+ if(last_checkpoint_idx == 0):
29
+ raise ValueError("No checkpoint found in folder {:s} with prefix {:s}.".format(model_dir, model_prefix))
30
+ else:
31
+ return os.path.join(model_dir, saver.MODEL_FILE_FORMAT.format(model_prefix, last_checkpoint_idx, saver.MODEL_EXTENSION))
32
+
33
+
34
+ def initialize(self, ctx):
35
+ manifest = ctx.manifest
36
+ properties = ctx.system_properties
37
+
38
+ self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
39
+ self.model_dir = model_dir = properties.get("model_dir")
40
+
41
+ # extract checkpoint location, config & model name
42
+ model_serve_file = os.path.join(model_dir, saver.MODEL_SERVE_FILE)
43
+ with io.open(model_serve_file, "r") as serve_config:
44
+ model_name = serve_config.read().strip()
45
+ # model_cls = models.AvailableModels[model_name]
46
+ model_cls = Transformer # can't select due to nature of model file
47
+ checkpoint_path = manifest['model'].get('serializedFile', self._find_checkpoint(model_dir)) # attempt to use the checkpoint fed from archiver; else use the best checkpoint found
48
+ config_path = find_all_config(model_dir)
49
+
50
+ # load model with inbuilt config + vocab & without pretraining data
51
+ self.model = model = model_cls(config=config_path, model_dir=model_dir, mode="infer")
52
+ model.load_checkpoint(model_dir, checkpoint=checkpoint_path) # TODO find_checkpoint might do some redundant thing here since load_checkpoint had already done searching for latest
53
+
54
+ print("Model {:s} loaded successfully at location {:s}.".format(model_name, model_dir))
55
+ self.initialized = True
56
+
57
+ def handle(self, data):
58
+ """The main bulk of handling. Process a batch of data received from client.
59
+ Args:
60
+ data: the object received from client. Should contain something in [batch_size] of str
61
+ Returns:
62
+ the expected translation, [batch_size] of str
63
+ """
64
+ batch_sentences = data[0].get("data")
65
+ # assert batch_sentences is not None, "data is {}".format(data)
66
+
67
+ # make sure that sentences are detokenized before returning
68
+ translated_sentences = self.model.translate_batch(batch_sentences, output_tokens=False)
69
+
70
+ return translated_sentences
71
+
72
+ class BeamSearchHandlerClass:
73
+ def __init__(self):
74
+ self.model = None
75
+ self.inferrer = None
76
+ self.initialized = False
77
+
78
+ def initialize(self, ctx):
79
+ manifest = ctx.manifest
80
+ properties = ctx.system_properties
81
+
82
+ model_dir = properties['model_dir']
83
+ ts_modelpath = manifest['model']['serializedFile']
84
+ self.model = ts_model = torch.jit.load(os.path.join(model_dir, ts_modelpath))
85
+
86
+ from modules.inference.beam_search import BeamSearch
87
+ device = "cuda" if torch.cuda.is_available() else "cpu"
88
+ self.inferrer = BeamSearch(ts_model, 160, device, beam_size=5)
89
+
90
+ self.initialized = True
91
+
92
+ def handle(self, data):
93
+ batch_sentences = data[0].get("data")
94
+ # assert batch_sentences is not None, "data is {}".format(data)
95
+
96
+ translated_sentences = self.inferrer.translate_batch_sentence(batch_sentences, output_tokens=False)
97
+ return translated_sentences
98
+
99
+ RUNNING_MODEL = BeamSearchHandlerClass()
100
+
101
+ def handle(data, context):
102
+ if(not RUNNING_MODEL.initialized): # Lazy init
103
+ RUNNING_MODEL.initialize(context)
104
+
105
+ if(data is None):
106
+ return None
107
+
108
+ return RUNNING_MODEL.handle(data)
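Both handler classes follow the TorchServe custom-handler shape: `initialize(ctx)` reads `ctx.system_properties` and `ctx.manifest`, and `handle` receives batches of the form `[{"data": [...]}]`. A rough local smoke test of the module-level `handle` entry point could look like the sketch below; `MockContext`, the archive name and the model directory are hypothetical stand-ins for what TorchServe normally supplies, and a TorchScript archive produced by the `serve` run mode must already exist.

```python
# Sketch: exercising bin/serve.handle(data, context) outside TorchServe
import bin.serve as serve

class MockContext:                                   # hypothetical stand-in for TorchServe's context
    manifest = {"model": {"serializedFile": "en-vi.ts"}}            # hypothetical archive name
    system_properties = {"model_dir": "experiments/en-vi.model",    # hypothetical model directory
                         "gpu_id": 0}

payload = [{"data": ["Rachel Pike : The science behind a climate headline"]}]
print(serve.handle(payload, MockContext()))          # lazy-initializes, then translates the batch
```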
config/bilingual_prototype.yml ADDED
@@ -0,0 +1,52 @@
1
+ # data location and config section
2
+ data:
3
+ train_data_location: data/test/train2023
4
+ eval_data_location: data/test/dev2023
5
+ src_lang: .lo
6
+ trg_lang: .vi
7
+ log_file_models: 'model.log'
8
+ lowercase: false
9
+ build_vocab_kwargs: # additional arguments for build_vocab. See torchtext.vocab.Vocab for mode details
10
+ # max_size: 50000
11
+ min_freq: 4
12
+ specials:
13
+ - <unk>
14
+ - <pad>
15
+ - <sos>
16
+ - <eos>
17
+ # data augmentation section
18
+ # model parameters section
19
+ device: cuda
20
+ d_model: 512
21
+ n_layers: 6
22
+ heads: 8
23
+ # inference section
24
+ eval_batch_size: 8
25
+ decode_strategy: BeamSearch
26
+ decode_strategy_kwargs:
27
+ beam_size: 5 # beam search size
28
+ length_normalize: 0.6 # recalculate beam position by length. Currently only work in default BeamSearch
29
+ replace_unk: # tuple of layer/head attention to replace unknown words
30
+ - 0 # layer
31
+ - 0 # head
32
+ input_max_length: 250 # input longer than this value will be trimmed in inference. Note that this values are to be used during cached PE, hence, validation set with more than this much tokens will call a warning for the trimming.
33
+ max_length: 160 # only perform up to this much timestep during inference
34
+ train_max_length: 140 # training samples with this much length in src/trg will be discarded
35
+ # optimizer and learning arguments section
36
+ lr: 0.2
37
+ optimizer: AdaBelief
38
+ optimizer_params:
39
+ betas:
40
+ - 0.9 # beta1
41
+ - 0.98 # beta2
42
+ eps: !!float 1e-9
43
+ n_warmup_steps: 4000
44
+ label_smoothing: 0.1
45
+ dropout: 0.05
46
+ # training config, evaluation, save & load section
47
+ batch_size: 32
48
+ epochs: 40
49
+ printevery: 200
50
+ save_checkpoint_epochs: 1
51
+ maximum_saved_model_eval: 5
52
+ maximum_saved_model_train: 5
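A quick way to sanity-check edits to this file is to parse it and print the top-level keys that `models/transformer.py` reads directly. This is only a sketch; it assumes PyYAML (already required by `modules/config.py`) and is run from the repository root.

```python
# Sketch: verify the config parses and carries the keys the model constructor expects
import yaml

with open("config/bilingual_prototype.yml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

for key in ("device", "d_model", "n_layers", "heads", "dropout", "decode_strategy"):
    print(key, "=", cfg.get(key))
```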
config/prototype.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "train_src_data": "/workspace/khoai23/opennmt/data/iwslt_en_vi/train.en",
3
+ "train_trg_data": "/workspace/khoai23/opennmt/data/iwslt_en_vi/train.vi",
4
+ "valid_src_data": "/workspace/khoai23/opennmt/data/iwslt_en_vi/tst2013.en",
5
+ "valid_trg_data": "/workspace/khoai23/opennmt/data/iwslt_en_vi/tst2013.vi",
6
+ "src_lang": "en",
7
+ "trg_lang": "en",
8
+ "max_strlen": 160,
9
+ "batchsize": 1500,
10
+ "device": "cpu",
11
+ "d_model": 512,
12
+ "n_layers": 6,
13
+ "heads": 8,
14
+ "dropout": 0.1,
15
+ "lr": 0.0001,
16
+ "epochs": 30,
17
+ "printevery": 200,
18
+ "k": 5,
19
+ "n_warmup_steps": 4000,
20
+ "beta1": 0.9,
21
+ "beta2": 0.98,
22
+ "eps": 1e-09,
23
+ "label_smoothing": 0.1,
24
+ "save_checkpoint_epochs": 5
25
+ }
layers/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from layers.prototypes import *
layers/prototypes.py ADDED
@@ -0,0 +1,148 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.autograd import Variable
4
+ import torch.nn.functional as functional
5
+ import math
6
+ import logging
7
+
8
+ class PositionalEncoder(nn.Module):
9
+ def __init__(self, d_model, max_seq_length=200, dropout=0.1):
10
+ super().__init__()
11
+
12
+ self.d_model = d_model
13
+ self.dropout = nn.Dropout(dropout)
14
+ self._max_seq_length = max_seq_length
15
+
16
+ pe = torch.zeros(max_seq_length, d_model)
17
+
18
+ for pos in range(max_seq_length):
19
+ for i in range(0, d_model, 2):
20
+ pe[pos, i] = math.sin(pos/(10000**(2*i/d_model)))
21
+ pe[pos, i+1] = math.cos(pos/(10000**((2*i+1)/d_model)))
22
+ pe = pe.unsqueeze(0)
23
+ self.register_buffer('pe', pe)
24
+
25
+ @torch.jit.script
26
+ def splice_by_size(source, target):
27
+ """Custom function to splice the source by target's second dimension. Required due to torch.Size not a torchTensor. Why? hell if I know."""
28
+ length = target.size(1);
29
+ return source[:, :length]
30
+
31
+ self.splice_by_size = splice_by_size
32
+
33
+ def forward(self, x):
34
+ if(x.shape[1] > self._max_seq_length):
35
+ logging.warn("Input longer than maximum supported length for PE detected. Build a model with a larger input_max_length limit if you want to keep the input; or ignore if you want the input trimmed")
36
+ x = x[:, :self._max_seq_length]
37
+
38
+ x = x * math.sqrt(self.d_model)
39
+
40
+ spliced_pe = self.splice_by_size(self.pe, x) # self.pe[:, :x.shape[1]]
41
+ # pe = Variable(spliced_pe, requires_grad=False)
42
+ pe = spliced_pe.requires_grad_(False)
43
+
44
+ # if x.is_cuda: # remove since it is a sub nn.Module
45
+ # pe.cuda()
46
+ # assert all([xs == ys for xs, ys in zip(x.shape[1:], pe.shape[1:])]), "{} - {}".format(x.shape, pe.shape)
47
+
48
+ x = x + pe
49
+ x = self.dropout(x)
50
+
51
+ return x
52
+
53
+ class MultiHeadAttention(nn.Module):
54
+ def __init__(self, heads, d_model, dropout=0.1):
55
+ super().__init__()
56
+ assert d_model % heads == 0
57
+
58
+ self.d_model = d_model
59
+ self.d_k = d_model // heads
60
+ self.h = heads
61
+
62
+ # three casting linear layer for query/key.value
63
+ self.q_linear = nn.Linear(d_model, d_model)
64
+ self.k_linear = nn.Linear(d_model, d_model)
65
+ self.v_linear = nn.Linear(d_model, d_model)
66
+
67
+ self.dropout = nn.Dropout(dropout)
68
+ self.out = nn.Linear(d_model, d_model)
69
+
70
+ def forward(self, q, k, v, mask=None):
71
+ """
72
+ Args:
73
+ q / k / v: query/key/value, should all be [batch_size, sequence_length, d_model]. Only differ in decode attention, where q is tgt_len and k/v is src_len
74
+ mask: either [batch_size, 1, src_len] or [batch_size, tgt_len, tgt_len]. The last two dimensions must match or are broadcastable.
75
+ Returns:
76
+ the value of the attention process, [batch_size, sequence_length, d_model].
77
+ The used attention, [batch_size, q_length, k_v_length]
78
+ """
79
+ bs = q.shape[0]
80
+ q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
81
+ k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
82
+ v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
83
+
84
+ q = q.transpose(1, 2)
85
+ k = k.transpose(1, 2)
86
+ v = v.transpose(1, 2)
87
+
88
+ value, attn = self.attention(q, k, v, mask, self.dropout)
89
+ concat = value.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
90
+ output = self.out(concat)
91
+ return output, attn
92
+
93
+ def attention(self, q, k, v, mask=None, dropout=None):
94
+ """Calculate the attention and output the attention & value
95
+ Args:
96
+ q / k / v: query/key/value already transformed, should all be [batch_size, heads, sequence_length, d_k]. Only differ in decode attention, where q is tgt_len and k/v is src_len
97
+ mask: either [batch_size, 1, src_len] or [batch_size, tgt_len, tgt_len]. The last two dimensions must match or are broadcastable.
98
+ Returns:
99
+ the attentionized but raw values [batch_size, head, seq_length, d_k]
100
+ the attention calculated [batch_size, heads, sequence_length, sequence_length]
101
+ """
102
+
103
+ # d_k = q.shape[-1]
104
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
105
+
106
+ if mask is not None:
107
+ mask = mask.unsqueeze(1) # add a dimension to account for head
108
+ scores = scores.masked_fill(mask==0, -1e9)
109
+ # softmax the padding/peeking masked attention
110
+ scores = functional.softmax(scores, dim=-1)
111
+
112
+ if dropout is not None:
113
+ scores = dropout(scores)
114
+
115
+ output = torch.matmul(scores, v)
116
+ return output, scores
117
+
118
+ class Norm(nn.Module):
119
+ def __init__(self, d_model, eps = 1e-6):
120
+ super().__init__()
121
+
122
+ self.size = d_model
123
+
124
+ # create two learnable parameters to calibrate normalisation
125
+ self.alpha = nn.Parameter(torch.ones(self.size))
126
+ self.bias = nn.Parameter(torch.zeros(self.size))
127
+
128
+ self.eps = eps
129
+
130
+ def forward(self, x):
131
+ norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) \
132
+ / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias
133
+ return norm
134
+
135
+ class FeedForward(nn.Module):
136
+ """A two-hidden-linear feedforward layer that can activate and dropout its transition state"""
137
+ def __init__(self, d_model, d_ff=2048, internal_activation=functional.relu, dropout=0.1):
138
+ super().__init__()
139
+ self.linear_1 = nn.Linear(d_model, d_ff)
140
+ self.dropout = nn.Dropout(dropout)
141
+ self.linear_2 = nn.Linear(d_ff, d_model)
142
+
143
+ self.internal_activation = internal_activation
144
+
145
+ def forward(self, x):
146
+ x = self.dropout(self.internal_activation(self.linear_1(x)))
147
+ x = self.linear_2(x)
148
+ return x
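The layers above are shape-compatible with one another; a small smoke test that chains them into one encoder-style block is sketched below. It is illustrative only and not part of the repository.

```python
# Sketch: one encoder-style block built from the layers defined above
import torch
from layers.prototypes import PositionalEncoder, MultiHeadAttention, Norm, FeedForward

d_model, heads = 512, 8
x = torch.rand(2, 9, d_model)               # [batch_size, seq_len, d_model] embeddings

pe = PositionalEncoder(d_model, max_seq_length=200)
attn = MultiHeadAttention(heads, d_model)
norm1, norm2 = Norm(d_model), Norm(d_model)
ff = FeedForward(d_model)

h = pe(x)                                   # add cached positional encoding + dropout
att_out, att_weights = attn(h, h, h)        # self-attention, no mask
h = norm1(h + att_out)                      # residual connection + layer norm
h = norm2(h + ff(h))                        # feed-forward sub-layer
print(h.shape, att_weights.shape)           # [2, 9, 512] and [2, 8, 9, 9]
```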
models/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from models.default import MockModel
2
+ from models.transformer import Transformer
3
+
4
+ AvailableModels = {
5
+ "MockModel": MockModel,
6
+ "Transformer" : Transformer
7
+ }
models/default.py ADDED
@@ -0,0 +1,13 @@
1
+ class MockModel:
2
+ """A model that only output string to show flow"""
3
+ def __init__(self, *args, **kwargs):
4
+ print("Mock model initialization, with args/kwargs: {} {}".format(args, kwargs))
5
+
6
+ def run_train(self, **kwargs):
7
+ print("Model in training, with args: {}".format(kwargs))
8
+
9
+ def run_eval(self, **kwargs):
10
+ print("Model in evaluation, with args: {}".format(kwargs))
11
+
12
+ def run_debug(self, **kwargs):
13
+ print("Model in debuging, with args: {}".format(kwargs))
models/transformer.py ADDED
@@ -0,0 +1,404 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchtext.data as data
4
+ import copy, time, io
5
+ import numpy as np
6
+
7
+ from modules.prototypes import Encoder, Decoder, Config as DefaultConfig
8
+ from modules.loader import DefaultLoader, MultiLoader
9
+ from modules.config import MultiplePathConfig as Config
10
+ from modules.inference import strategies
11
+ from modules import constants as const
12
+ from modules.optim import optimizers, ScheduledOptim
13
+
14
+ import utils.save as saver
15
+ from utils.decode_old import create_masks, translate_sentence
16
+ #from utils.data import create_fields, create_dataset, read_data, read_file, write_file
17
+ from utils.loss import LabelSmoothingLoss
18
+ from utils.metric import bleu, bleu_batch_iter, bleu_single, bleu_batch
19
+ #from utils.save import load_model_from_path, check_model_in_path, save_and_clear_model, write_model_score, load_model_score, save_model_best_to_path, load_model
20
+
21
+ class Transformer(nn.Module):
22
+ """
23
+ Implementation of Transformer architecture based on the paper `Attention is all you need`.
24
+ Source: https://arxiv.org/abs/1706.03762
25
+ """
26
+ def __init__(self, mode=None, model_dir=None, config=None):
27
+ super().__init__()
28
+
29
+ # Use specific config file if provided otherwise use the default config instead
30
+ self.config = DefaultConfig() if(config is None) else Config(config)
31
+ opt = self.config
32
+ self.device = opt.get('device', const.DEFAULT_DEVICE)
33
+
34
+ if('train_data_location' in opt or 'train_data_location' in opt.get("data", {})):
35
+ # monolingual data detected
36
+ data_opt = opt if 'train_data_location' in opt else opt["data"]
37
+ self.loader = DefaultLoader(data_opt['train_data_location'], eval_path=data_opt.get('eval_data_location', None), language_tuple=(data_opt["src_lang"], data_opt["trg_lang"]), option=opt)
38
+ elif('data' in opt):
39
+ # multilingual data with multiple corpus in [data][train] namespace
40
+ self.loader = MultiLoader(opt["data"]["train"], valid=opt["data"].get("valid", None), option=opt)
41
+ # input fields
42
+ self.SRC, self.TRG = self.loader.build_field(lower=opt.get("lowercase", const.DEFAULT_LOWERCASE))
43
+ # self.SRC = data.Field(lower=opt.get("lowercase", const.DEFAULT_LOWERCASE))
44
+ # self.TRG = data.Field(lower=opt.get("lowercase", const.DEFAULT_LOWERCASE), eos_token='<eos>')
45
+
46
+ # initialize dataset and by proxy the vocabulary
47
+ if(mode == "train"):
48
+ # training flow, necessitate the DataLoader and iterations. This will attempt to load vocab file from the dir instead of rebuilding, but can build a new vocab if no data is found
49
+ self.train_iter, self.valid_iter = self.loader.create_iterator(self.fields, model_path=model_dir)
50
+ elif(mode == "eval"):
51
+ # evaluation flow, which only require valid_iter
52
+ # TODO fix accordingly
53
+ self.train_iter, self.valid_iter = self.loader.create_iterator(self.fields, model_path=model_dir)
54
+ elif(mode == "infer"):
55
+ # inference, require pickled model and vocab in the path
56
+ self.loader.build_vocab(self.fields, model_path=model_dir)
57
+ else:
58
+ raise ValueError("Unknown model's mode: {}".format(mode))
59
+
60
+
61
+ # define the model
62
+ src_vocab_size, trg_vocab_size = len(self.SRC.vocab), len(self.TRG.vocab)
63
+ d_model, N, heads, dropout = opt['d_model'], opt['n_layers'], opt['heads'], opt['dropout']
64
+ # get the maximum amount of tokens per sample in encoder. This is useful due to PositionalEncoder requiring this value
65
+ train_ignore_length = self.config.get("train_max_length", const.DEFAULT_TRAIN_MAX_LENGTH)
66
+ input_max_length = self.config.get("input_max_length", const.DEFAULT_INPUT_MAX_LENGTH)
67
+ infer_max_length = self.config.get('max_length', const.DEFAULT_MAX_LENGTH)
68
+ encoder_max_length = max(input_max_length, train_ignore_length)
69
+ decoder_max_length = max(infer_max_length, train_ignore_length)
70
+ self.encoder = Encoder(src_vocab_size, d_model, N, heads, dropout, max_seq_length=encoder_max_length)
71
+ self.decoder = Decoder(trg_vocab_size, d_model, N, heads, dropout, max_seq_length=decoder_max_length)
72
+ self.out = nn.Linear(d_model, trg_vocab_size)
73
+
74
+ # load the beamsearch obj with preset values read from config. ALWAYS require the current model, max_length, and device used as per DecodeStrategy base
75
+ decode_strategy_class = strategies[opt.get('decode_strategy', const.DEFAULT_DECODE_STRATEGY)]
76
+ decode_strategy_kwargs = opt.get('decode_strategy_kwargs', const.DEFAULT_STRATEGY_KWARGS)
77
+ self.decode_strategy = decode_strategy_class(self, infer_max_length, self.device, **decode_strategy_kwargs)
78
+
79
+ self.to(self.device)
80
+
81
+ def load_checkpoint(self, model_dir, checkpoint=None, checkpoint_idx=0):
82
+ """Attempt to load past checkpoint into the model. If a specified checkpoint is set, load it; otherwise load the latest checkpoint in model_dir.
83
+ Args:
84
+ model_dir: location of the current model. Not used if checkpoint is specified
85
+ checkpoint: location of the specific checkpoint to load
86
+ checkpoint_idx: the epoch of the checkpoint
87
+ NOTE: checkpoint_idx return -1 in the event of not found; while 0 is when checkpoint is forced
88
+ """
89
+ if(checkpoint is not None):
90
+ saver.load_model(self, checkpoint)
91
+ self._checkpoint_idx = checkpoint_idx
92
+ else:
93
+ if model_dir is not None:
94
+ # load the latest available checkpoint, overriding the checkpoint value
95
+ checkpoint_idx = saver.check_model_in_path(model_dir)
96
+ if(checkpoint_idx > 0):
97
+ print("Found model with index {:d} already saved.".format(checkpoint_idx))
98
+ saver.load_model_from_path(self, model_dir, checkpoint_idx=checkpoint_idx)
99
+ else:
100
+ print("No checkpoint found, start from beginning.")
101
+ checkpoint_idx = -1
102
+ else:
103
+ print("No model_dir available, start from beginning.")
104
+ # train the model from begin
105
+ checkpoint_idx = -1
106
+ self._checkpoint_idx = checkpoint_idx
107
+
108
+
109
+ def forward(self, src, trg, src_mask, trg_mask, output_attention=False):
110
+ """Run a full model with specified source-target batched set of data
111
+ Args:
112
+ src: the source input [batch_size, src_len]
113
+ trg: the target input (& expected output) [batch_size, trg len]
114
+ src_mask: the padding mask for src [batch_size, 1, src_len]
115
+ trg_mask: the triangle mask for trg [batch_size, trg_len, trg_len]
116
+ output_attention: if specified, output the attention as needed
117
+ Returns:
118
+ the logits (unsoftmaxed outputs), same shape as trg
119
+ """
120
+ e_outputs = self.encoder(src, src_mask)
121
+ d_output, attn = self.decoder(trg, e_outputs, src_mask, trg_mask, output_attention=True)
122
+ output = self.out(d_output)
123
+ if(output_attention):
124
+ return output, attn
125
+ else:
126
+ return output
127
+
128
+ def train_step(self, optimizer, batch, criterion):
129
+ """
130
+ Perform one training step.
131
+ """
132
+ self.train()
133
+ opt = self.config
134
+
135
+ # move data to specific device's memory
136
+ src = batch.src.transpose(0, 1).to(opt.get('device', const.DEFAULT_DEVICE))
137
+ trg = batch.trg.transpose(0, 1).to(opt.get('device', const.DEFAULT_DEVICE))
138
+
139
+ trg_input = trg[:, :-1]
140
+ src_pad = self.SRC.vocab.stoi['<pad>']
141
+ trg_pad = self.TRG.vocab.stoi['<pad>']
142
+ ys = trg[:, 1:].contiguous().view(-1)
143
+
144
+ # create mask and perform network forward
145
+ src_mask, trg_mask = create_masks(src, trg_input, src_pad, trg_pad, opt.get('device', const.DEFAULT_DEVICE))
146
+ preds = self(src, trg_input, src_mask, trg_mask)
147
+
148
+ # perform backprogation
149
+ optimizer.zero_grad()
150
+ loss = criterion(preds.view(-1, preds.size(-1)), ys)
151
+ loss.backward()
152
+ optimizer.step_and_update_lr()
153
+ loss = loss.item()
154
+
155
+ return loss
156
+
157
+ def validate(self, valid_iter, criterion, maximum_length=None):
158
+ """Compute loss in validation dataset. As we can't perform trimming the input in the valid_iter yet, using a crutch in maximum_input_length variable
159
+ Args:
160
+ valid_iter: the Iteration containing batches of data, accessed by .src and .trg
161
+ criterion: the loss function to use to evaluate
162
+ maximum_length: if fed, a tuple of max_input_len, max_output_len to trim the src/trg
163
+ Returns:
164
+ the avg loss of the criterion
165
+ """
166
+ self.eval()
167
+ opt = self.config
168
+ src_pad = self.SRC.vocab.stoi['<pad>']
169
+ trg_pad = self.TRG.vocab.stoi['<pad>']
170
+
171
+ with torch.no_grad():
172
+ total_loss = []
173
+ for batch in valid_iter:
174
+ # load model into specific device (GPU/CPU) memory
175
+ src = batch.src.transpose(0, 1).to(opt.get('device', const.DEFAULT_DEVICE))
176
+ trg = batch.trg.transpose(0, 1).to(opt.get('device', const.DEFAULT_DEVICE))
177
+ if(maximum_length is not None):
178
+ src = src[:, :maximum_length[0]]
179
+ trg = trg[:, :maximum_length[1]-1] # using partials
180
+ trg_input = trg[:, :-1]
181
+ ys = trg[:, 1:].contiguous().view(-1)
182
+
183
+ # create mask and perform network forward
184
+ src_mask, trg_mask = create_masks(src, trg_input, src_pad, trg_pad, opt.get('device', const.DEFAULT_DEVICE))
185
+ preds = self(src, trg_input, src_mask, trg_mask)
186
+
187
+ # compute loss on current batch
188
+ loss = criterion(preds.view(-1, preds.size(-1)), ys)
189
+ loss = loss.item()
190
+ total_loss.append(loss)
191
+
192
+ avg_loss = np.mean(total_loss)
193
+ return avg_loss
194
+
195
+ def translate_sentence(self, sentence, device=None, k=None, max_len=None, debug=False):
196
+ """
197
+ Receive a sentence string and output the prediction generated from the model.
198
+ NOTE: sentence input is a list of tokens instead of string due to change in loader. See the current DefaultLoader for further details
199
+ """
200
+ self.eval()
201
+ if(device is None): device = self.config.get('device', const.DEFAULT_DEVICE)
202
+ if(k is None): k = self.config.get('k', const.DEFAULT_K)
203
+ if(max_len is None): max_len = self.config.get('max_length', const.DEFAULT_MAX_LENGTH)
204
+
205
+ # Get output from decode
206
+ translated_tokens = translate_sentence(sentence, self, self.SRC, self.TRG, device, k, max_len, debug=debug, output_list_of_tokens=True)
207
+
208
+ return translated_tokens
209
+
210
+ def translate_batch_sentence(self, sentences, src_lang=None, trg_lang=None, output_tokens=False, batch_size=None):
211
+ """Translate sentences by splitting them to batches and process them simultaneously
212
+ Args:
213
+ sentences: the sentences in a list. Must NOT have been tokenized (due to SRC preprocess)
214
+ output_tokens: if set, do not detokenize the output
215
+ batch_size: if specified, use the value; else use config ones
216
+ Returns:
217
+ a matching translated sentences list in [detokenized format using loader.detokenize | list of tokens]
218
+ """
219
+ if(batch_size is None):
220
+ batch_size = self.config.get("eval_batch_size", const.DEFAULT_EVAL_BATCH_SIZE)
221
+ input_max_length = self.config.get("input_max_length", const.DEFAULT_INPUT_MAX_LENGTH)
222
+ self.eval()
223
+
224
+ translated = []
225
+ for b_idx in range(0, len(sentences), batch_size):
226
+ batch = sentences[b_idx: b_idx+batch_size]
227
+ # raise Exception(batch)
228
+ trans_batch = self.translate_batch(batch, trg_lang=trg_lang, output_tokens=output_tokens, input_max_length=input_max_length)
229
+ # raise Exception(detokenized_batch)
230
+ translated.extend(trans_batch)
231
+ # for line in trans_batch:
232
+ # print(line)
233
+ return translated
234
+
235
+ def translate_batch(self, batch_sentences, src_lang=None, trg_lang=None, output_tokens=False, input_max_length=None):
236
+ """Translate a single batch of sentences. Split to aid serving
237
+ Args:
238
+ sentences: the sentences in a list. Must NOT have been tokenized (due to SRC preprocess)
239
+ src_lang/trg_lang: the language from src->trg. Used for multilingual models only.
240
+ output_tokens: if set, do not detokenize the output
241
+ Returns:
242
+ a matching translated sentences list in [detokenized format using loader.detokenize | list of tokens]
243
+ """
244
+ if(input_max_length is None):
245
+ input_max_length = self.config.get("input_max_length", const.DEFAULT_INPUT_MAX_LENGTH)
246
+ translated_batch = self.decode_strategy.translate_batch(batch_sentences, trg_lang=trg_lang, src_size_limit=input_max_length, output_tokens=True, debug=False)
247
+ return self.loader.detokenize(translated_batch) if not output_tokens else translated_batch
248
+
249
+ def run_train(self, model_dir=None, config=None):
250
+ opt = self.config
251
+ from utils.logging import init_logger
252
+ logging = init_logger(model_dir, opt.get('log_file_models'))
253
+
254
+ trg_pad = self.TRG.vocab.stoi['<pad>']
255
+ # load model into specific device (GPU/CPU) memory
256
+ logging.info("%s * src vocab size = %s"%(self.loader._language_tuple[0] ,len(self.SRC.vocab)))
257
+ logging.info("%s * tgt vocab size = %s"%(self.loader._language_tuple[1] ,len(self.TRG.vocab)))
258
+ logging.info("Building model...")
259
+ model = self.to(opt.get('device', const.DEFAULT_DEVICE))
260
+
261
+ checkpoint_idx = self._checkpoint_idx
262
+ if(checkpoint_idx < 0):
263
+ # initialize weights
264
+ print("Zero checkpoint detected, reinitialize the model")
265
+ for p in model.parameters():
266
+ if p.dim() > 1:
267
+ nn.init.xavier_uniform_(p)
268
+ checkpoint_idx = 0
269
+
270
+ # also, load the scores of the best model
271
+ best_model_score = saver.load_model_score(model_dir)
272
+
273
+ # set up optimizer
274
+ optim_algo = opt["optimizer"]
275
+ lr = opt["lr"]
276
+ d_model = opt["d_model"]
277
+ n_warmup_steps = opt["n_warmup_steps"]
278
+ optimizer_params = opt.get("optimizer_params", dict({}))
279
+
280
+ if optim_algo not in optimizers:
281
+ raise ValueError("Unknown optimizer: {}".format(optim_algo))
282
+
283
+ optimizer = ScheduledOptim(
284
+ optimizer=optimizers.get(optim_algo)(model.parameters(), **optimizer_params),
285
+ init_lr=lr,
286
+ d_model=d_model,
287
+ n_warmup_steps=n_warmup_steps
288
+ )
289
+
290
+ # define loss function
291
+ criterion = LabelSmoothingLoss(len(self.TRG.vocab), padding_idx=trg_pad, smoothing=opt['label_smoothing'])
292
+
293
+ # valid_src_data, valid_trg_data = self.loader._eval_data
294
+ # raise Exception("Initial bleu: %.3f" % bleu_batch_iter(self, self.valid_iter, debug=True))
295
+ logging.info(self)
296
+ model_encoder_parameters = filter(lambda p: p.requires_grad, self.encoder.parameters())
297
+ model_decoder_parameters = filter(lambda p: p.requires_grad, self.decoder.parameters())
298
+ params_encode = sum([np.prod(p.size()) for p in model_encoder_parameters])
299
+ params_decode = sum([np.prod(p.size()) for p in model_decoder_parameters])
300
+
301
+ logging.info("Encoder: %s"%(params_encode))
302
+ logging.info("Decoder: %s"%(params_decode))
303
+ logging.info("* Number of parameters: %s"%(params_encode+params_decode))
304
+ logging.info("Starting training on %s"%(opt.get('device', const.DEFAULT_DEVICE)))
305
+
306
+ for epoch in range(checkpoint_idx, opt['epochs']):
307
+ total_loss = 0.0
308
+
309
+ s = time.time()
310
+ for i, batch in enumerate(self.train_iter):
311
+ loss = self.train_step(optimizer, batch, criterion)
312
+ total_loss += loss
313
+
314
+ # print training loss after every {printevery} steps
315
+ if (i + 1) % opt['printevery'] == 0:
316
+ avg_loss = total_loss / opt['printevery']
317
+ et = time.time() - s
318
+ # print('epoch: {:03d} - iter: {:05d} - train loss: {:.4f} - time elapsed/per batch: {:.4f} {:.4f}'.format(epoch, i+1, avg_loss, et, et / opt['printevery']))
319
+ logging.info('epoch: {:03d} - iter: {:05d} - train loss: {:.4f} - time elapsed/per batch: {:.4f} {:.4f}'.format(epoch, i+1, avg_loss, et, et / opt['printevery']))
320
+ total_loss = 0
321
+ s = time.time()
322
+
323
+ # bleu calculation and evaluate, save checkpoint for every {save_checkpoint_epochs} epochs
324
+ s = time.time()
325
+ valid_loss = self.validate(self.valid_iter, criterion, maximum_length=(self.encoder._max_seq_length, self.decoder._max_seq_length))
326
+ if (epoch+1) % opt['save_checkpoint_epochs'] == 0 and model_dir is not None:
327
+
328
+ # evaluate loss and bleu score on validation dataset for each epoch
329
+ # bleuscore = bleu(valid_src_data, valid_trg_data, model, opt.get('device', const.DEFAULT_DEVICE), opt['k'], opt['max_strlen'])
330
+ # bleuscore = bleu_single(self, self.loader._eval_data)
331
+ # bleuscore = bleu_batch(self, self.loader._eval_data, batch_size=opt.get('eval_batch_size', const.DEFAULT_EVAL_BATCH_SIZE))
332
+ valid_src_lang, valid_trg_lang = self.loader.language_tuple
333
+ bleuscore = bleu_batch_iter(self, self.valid_iter, src_lang=valid_src_lang, trg_lang=valid_trg_lang)
334
+
335
+ # save_model_to_path(model, model_dir, checkpoint_idx=epoch+1)
336
+ saver.save_and_clear_model(model, model_dir, checkpoint_idx=epoch+1, maximum_saved_model=opt.get('maximum_saved_model_train', const.DEFAULT_NUM_KEEP_MODEL_TRAIN))
337
+ # keep the best models per every bleu calculation
338
+ best_model_score = saver.save_model_best_to_path(model, model_dir, best_model_score, bleuscore, maximum_saved_model=opt.get('maximum_saved_model_eval', const.DEFAULT_NUM_KEEP_MODEL_TRAIN))
339
+ # print('epoch: {:03d} - iter: {:05d} - valid loss: {:.4f} - bleu score: {:.4f} - full evaluation time: {:.4f}'.format(epoch, i, valid_loss, bleuscore, time.time() - s))
340
+ logging.info('epoch: {:03d} - iter: {:05d} - valid loss: {:.4f} - bleu score: {:.4f} - full evaluation time: {:.4f}'.format(epoch, i, valid_loss, bleuscore, time.time() - s))
341
+ else:
342
+ # print('epoch: {:03d} - iter: {:05d} - valid loss: {:.4f} - validation time: {:.4f}'.format(epoch, i, valid_loss, time.time() - s))
343
+ logging.info('epoch: {:03d} - iter: {:05d} - valid loss: {:.4f} - validation time: {:.4f}'.format(epoch, i, valid_loss, time.time() - s))
344
+
345
+ def run_infer(self, features_file, predictions_file, src_lang=None, trg_lang=None, config=None, batch_size=None):
346
+ opt = self.config
347
+ # load model into specific device (GPU/CPU) memory
348
+ model = self.to(opt.get('device', const.DEFAULT_DEVICE))
349
+
350
+ # Read inference file
351
+ print("Reading features file from {}...".format(features_file))
352
+ with io.open(features_file, "r", encoding="utf-8") as read_file:
353
+ inputs = [l.strip() for l in read_file.readlines()]
354
+
355
+ print("Performing inference ...")
356
+ # Append each translated sentence line by line
357
+ # results = "\n".join([model.loader.detokenize(model.translate_sentence(sentence)) for sentence in inputs])
358
+ # Translate by batched versions
359
+ start = time.time()
360
+ results = "\n".join( self.translate_batch_sentence(inputs, src_lang=src_lang, trg_lang=trg_lang, output_tokens=False, batch_size=batch_size))
361
+ print("Inference done, cost {:.2f} secs.".format(time.time() - start))
362
+
363
+ # Write results to system file
364
+ print("Writing results to {} ...".format(predictions_file))
365
+ with io.open(predictions_file, "w", encoding="utf-8") as write_file:
366
+ write_file.write(results)
367
+
368
+ print("All done!")
369
+
370
+ def encode(self, *args, **kwargs):
371
+ return self.encoder(*args, **kwargs)
372
+
373
+ def decode(self, *args, **kwargs):
374
+ return self.decoder(*args, **kwargs)
375
+
376
+ def to_logits(self, inputs): # function to include the logits. TODO use this in inference fns as well
377
+ return self.out(inputs)
378
+
379
+ def prepare_serve(self, serve_path, model_dir=None, check_trace=True, **kwargs):
380
+ """Run to prepare for serving."""
381
+ self.eval()
382
+ saver.save_model_name(type(self).__name__, model_dir)
383
+ # return
384
+ # raise NotImplementedError("trace_module currently not supported")
385
+ # jit to convert model to ScriptModule.
386
+ # create junk arguments for necessary modules
387
+ fake_batch, fake_srclen, fake_trglen, fake_range = 3, 7, 4, 1000
388
+ sample_src, sample_trg = torch.randint(fake_range, (fake_batch, fake_srclen), dtype=torch.long), torch.randint(fake_range, (fake_batch, fake_trglen), dtype=torch.long)
389
+ sample_src_mask, sample_trg_mask = torch.rand(fake_batch, 1, fake_srclen) > 0.5, torch.rand(fake_batch, fake_trglen, fake_trglen) > 0.5
390
+ sample_src, sample_trg, sample_src_mask, sample_trg_mask = [t.to(self.device) for t in [sample_src, sample_trg, sample_src_mask, sample_trg_mask]]
391
+ sample_encoded = self.encode(sample_src, sample_src_mask)
392
+ sample_before_logits = self.decode(sample_trg, sample_encoded, sample_src_mask, sample_trg_mask)
393
+ # bundle within dictionary
394
+ needed_fn = {'forward': (sample_src, sample_trg, sample_src_mask, sample_trg_mask), "encode": (sample_src, sample_src_mask), "decode": (sample_trg, sample_encoded, sample_src_mask, sample_trg_mask), "to_logits": sample_before_logits}
395
+ # create the ScriptModule. Currently disabling deterministic check
396
+ traced_model = torch.jit.trace_module(self, needed_fn, check_trace=check_trace)
397
+ # save it down
398
+ torch.jit.save(traced_model, serve_path)
399
+ return serve_path
400
+
401
+
402
+ @property
403
+ def fields(self):
404
+ return (self.SRC, self.TRG)
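Besides the file-to-file `run_infer` path, the model can be used in memory once built in `infer` mode. The sketch below assumes a hypothetical model directory that already contains a checkpoint, config and vocab.

```python
# Sketch: in-memory translation with a trained Transformer
from models.transformer import Transformer
from modules.config import find_all_config

model_dir = "experiments/en-vi.model"                   # hypothetical trained model directory
model = Transformer(mode="infer", model_dir=model_dir, config=find_all_config(model_dir))
model.load_checkpoint(model_dir)

# translate_batch_sentence takes raw (untokenized) strings and batches them internally
print(model.translate_batch_sentence(
    ["Rachel Pike : The science behind a climate headline"], output_tokens=False)[0])
```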
modules/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from modules.default import *
2
+ from modules.prototypes import Decoder, Encoder
3
+ from modules.config import Config
modules/config.py ADDED
@@ -0,0 +1,62 @@
1
+ import yaml, json
2
+ import os, io
3
+
4
+ def extension_check(pth):
5
+ ext = os.path.splitext(pth)[-1]
6
+ return any( ext == valid_ext for valid_ext in [".json", ".yaml", ".yml"])
7
+
8
+ def find_all_config(directory):
9
+ return [os.path.join(directory, f) for f in os.listdir(directory) if extension_check(f)]
10
+
11
+ class Config(dict):
12
+ def __init__(self, path=None, **elements):
13
+ """Initiate a config object, where specified elements override the default config loaded"""
14
+ super(Config, self).__init__(self._try_load_path(path))
15
+ self.update(**elements)
16
+
17
+ def _load_json(self, json_path):
18
+ with io.open(json_path, "r", encoding="utf-8") as jf:
19
+ return json.load(jf)
20
+
21
+ def _load_yaml(self, yaml_path):
22
+ with io.open(yaml_path, "r", encoding="utf-8") as yf:
23
+ return yaml.safe_load(yf.read())
24
+
25
+ def _try_load_path(self, path):
26
+ assert isinstance(path, str), "Basic Config class can only support a single file path (str), but instead is {}({})".format(path, type(path))
27
+ assert os.path.isfile(path), "Config file {:s} does not exist".format(path)
28
+ extension = os.path.splitext(path)[-1]
29
+ if(extension == ".json"):
30
+ return self._load_json(path)
31
+ elif(extension == ".yml" or extension == ".yaml"):
32
+ return self._load_yaml(path)
33
+ else:
34
+ raise ValueError("Unrecognized extension ({:s}) from file {:s}".format(extension, path))
35
+
36
+ @property
37
+ def opt(self):
38
+ """Backward compatibility to original. Remove once finished."""
39
+ return self
40
+
41
+ class MultiplePathConfig(Config):
42
+ def _try_load_path(self, paths):
43
+ """Update to support multiple paths."""
44
+ if(isinstance(paths, list)):
45
+ print("Loaded path is a list of locations. Load in the order received, overriding and merging as needed.")
46
+ result = {}
47
+ for pth in paths:
48
+ self._recursive_update(result, super(MultiplePathConfig, self)._try_load_path(pth))
49
+ return result
50
+ else:
51
+ return super(MultiplePathConfig, self)._try_load_path(paths)
52
+
53
+ def _recursive_update(self, orig, new):
54
+ """Instead of overriding dicts, merge them recursively."""
55
+ # print(orig, new)
56
+ for k, v in new.items():
57
+ if(k in orig and isinstance(orig[k], dict)):
58
+ assert isinstance(v, dict), "Mismatching config with key {}: {} - {}".format(k, orig[k], v)
59
+ orig[k] = self._recursive_update(orig[k], v)
60
+ else:
61
+ orig[k] = v;
62
+ return orig
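`Config` behaves like a dict built from a single YAML/JSON file, while `MultiplePathConfig` (the variant the Transformer model imports) merges several files recursively. A brief usage sketch, using the two config files added in this commit:

```python
# Sketch: loading and merging config files
from modules.config import Config, MultiplePathConfig, find_all_config

cfg = Config("config/prototype.json")                  # single file: plain dict access
print(cfg["d_model"], cfg.get("decode_strategy", "BeamSearch"))

paths = find_all_config("config")                      # every .json/.yml/.yaml in the directory
merged = MultiplePathConfig(paths)                     # later files override/merge into earlier ones
print(sorted(merged.keys()))
```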
modules/constants.py ADDED
@@ -0,0 +1,18 @@
1
+ # DESIGNATE constants values for config
2
+ DEFAULT_DECODE_STRATEGY = "BeamSearch"
3
+ DEFAULT_STRATEGY_KWARGS = {}
4
+ DEFAULT_SEED = 101
5
+ DEFAULT_BATCH_SIZE = 64
6
+ DEFAULT_EVAL_BATCH_SIZE = 8
7
+ DEFAULT_TRAIN_TEST_SPLIT = 0.8
8
+ DEFAULT_DEVICE = "cpu"
9
+ DEFAULT_K = 5
10
+ DEFAULT_INPUT_MAX_LENGTH = 200
11
+ DEFAULT_MAX_LENGTH = 150
12
+ DEFAULT_TRAIN_MAX_LENGTH = 100
13
+ DEFAULT_LOWERCASE = True
14
+ DEFAULT_NUM_KEEP_MODEL_TRAIN = 5
15
+ DEFAULT_NUM_KEEP_MODEL_BEST = 5
16
+ DEFAULT_SOS = "<sos>"
17
+ DEFAULT_EOS = "<eos>"
18
+ DEFAULT_PAD = "<pad>"
modules/default.py ADDED
@@ -0,0 +1,54 @@
1
+ class MockLoader:
2
+ def __init__(self, *args, **kwargs):
3
+ """Only print stuff"""
4
+ print("MockLoader initialized, args/kwargs {} {}".format(args, kwargs))
5
+
6
+ def tokenize(self, inputs, **kwargs):
7
+ print("MockLoader tokenize called, inputs/kwargs {} {}".format(inputs, kwargs))
8
+ return inputs
9
+
10
+ def detokenize(self, inputs, **kwargs):
11
+ print("MockLoader detokenize called, inputs/kwargs {} {}".format(inputs, kwargs))
12
+ return inputs
13
+
14
+ def reverse_lookup(self, inputs, **kwargs):
15
+ print("MockLoader reverse_lookup called, inputs/kwargs {} {}".format(inputs, kwargs))
16
+ return inputs
17
+
18
+ def lookup(self, inputs, **kwargs):
19
+ print("MockLoader lookup called, inputs/kwargs {} {}".format(inputs, kwargs))
20
+ return inputs
21
+
22
+ def embed(self, inputs, **kwargs):
23
+ print("MockLoader embed called, inputs/kwargs {} {}".format(inputs, kwargs))
24
+ return inputs
25
+
26
+ class MockEncoder:
27
+ def __init__(self, *args, **kwargs):
28
+ """Only print stuff"""
29
+ print("MockEncoder initialized, args/kwargs {} {}".format(args, kwargs))
30
+
31
+ def encode(self, inputs, **kwargs):
32
+ print("MockEncoder encode called, inputs/kwargs {} {}".format(inputs, kwargs))
33
+ return inputs
34
+
35
+ def __call__(self, inputs, num_layers=3, **kwargs):
36
+ print("MockEncoder __call__ called, inputs/num_layers/kwargs {} {} {}".format(inputs, num_layers, kwargs))
37
+ for i in range(num_layers):
38
+ inputs = self.encode(inputs, **kwargs)
39
+ return inputs
40
+
41
+ class MockDecoder:
42
+ def __init__(self, *args, **kwargs):
43
+ """Only print stuff"""
44
+ print("MockDecoder initialized, args/kwargs {} {}".format(args, kwargs))
45
+
46
+ def decode(self, inputs, **kwargs):
47
+ print("MockDecoder decode called, inputs/kwargs {} {}".format(inputs, kwargs))
48
+ return inputs
49
+
50
+ def __call__(self, inputs, num_layers=3, **kwargs):
51
+ print("MockDecoder __call__ called, inputs/num_layers/kwargs {} {} {}".format(inputs, num_layers, kwargs))
52
+ for i in range(num_layers):
53
+ inputs = self.decode(inputs, **kwargs)
54
+ return inputs
modules/inference/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ from modules.inference.decode_strategy import DecodeStrategy
2
+ from modules.inference.beam_search import BeamSearch
3
+ from modules.inference.prototypes import BeamSearch2
4
+ from modules.inference.sampling_temperature import GreedySearch
5
+
6
+ strategies = {
7
+ "BeamSearch": BeamSearch,
8
+ "BeamSearch2": BeamSearch2,
9
+ "GreedySearch": GreedySearch
10
+ }
modules/inference/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (483 Bytes).
 
modules/inference/__pycache__/beam_search.cpython-36.pyc ADDED
Binary file (15.3 kB).
 
modules/inference/__pycache__/decode_strategy.cpython-36.pyc ADDED
Binary file (3.26 kB). View file
 
modules/inference/__pycache__/prototypes.cpython-36.pyc ADDED
Binary file (4.73 kB). View file
 
modules/inference/__pycache__/sampling_temperature.cpython-36.pyc ADDED
Binary file (4.42 kB). View file
 
modules/inference/beam_search.py ADDED
@@ -0,0 +1,336 @@
1
+ import numpy as np
2
+ import torch
3
+ import math, time, operator
4
+ import torch.nn.functional as functional
5
+ import torch.nn as nn
6
+ import logging
7
+ from torch.autograd import Variable
8
+ from torch.nn.utils.rnn import pad_sequence
9
+
10
+ from modules.inference.decode_strategy import DecodeStrategy
11
+ import modules.constants as const
12
+ from utils.misc import no_peeking_mask
13
+ from utils.data import generate_language_token
14
+
15
+ class BeamSearch(DecodeStrategy):
16
+ def __init__(self, model, max_len, device, beam_size=5, use_synonym_fn=False, replace_unk=None, length_normalize=None):
17
+ """
18
+ Args:
19
+ model: the used model
20
+ max_len: the maximum timestep to be used
21
+ device: the device to perform calculation
22
+ beam_size: the size of the beam itself
23
+ use_synonym_fn: if set, use the get_synonym fn from wordnet to try replace <unk>
24
+ replace_unk: a tuple of [layer, head] designation, to replace the unknown word by chosen attention
25
+ """
26
+ super(BeamSearch, self).__init__(model, max_len, device)
27
+ self.beam_size = beam_size
28
+ self._use_synonym = use_synonym_fn
29
+ self._replace_unk = replace_unk
30
+ self._length_norm = length_normalize
31
+
32
+ def init_vars(self, src, start_token=const.DEFAULT_SOS):
33
+ """
34
+ Calculate the matrices required for translation once the model has been trained.
35
+ Input:
36
+ :param src: The batch of source sentences
37
+
38
+ Output: the initial outputs, e_outputs and log_scores (state after the first decoded token)
39
+ """
40
+ model = self.model
41
+ batch_size = len(src)
42
+ row_b = self.beam_size * batch_size
43
+
44
+ init_tok = self.TRG.vocab.stoi[start_token]
45
+ src_mask = (src != self.SRC.vocab.stoi['<pad>']).unsqueeze(-2).to(self.device)
46
+ src = src.to(self.device)
47
+
48
+ # Encoder
49
+ # raise Exception(src.shape, src_mask.shape)
50
+ e_output = model.encode(src, src_mask)
51
+ outputs = torch.LongTensor([[init_tok] for i in range(batch_size)])
52
+ outputs = outputs.to(self.device)
53
+ trg_mask = no_peeking_mask(1, self.device)
54
+
55
+ # Decoder
56
+ out = model.to_logits(model.decode(outputs, e_output, src_mask, trg_mask))
57
+ out = functional.softmax(out, dim=-1)
58
+ probs, ix = out[:, -1].data.topk(self.beam_size)
59
+
60
+ log_scores = torch.Tensor([math.log(p) for p in probs.data.view(-1)]).view(-1, 1)
61
+
62
+ outputs = torch.zeros(row_b, self.max_len).long()
63
+ outputs = outputs.to(self.device)
64
+ outputs[:, 0] = init_tok
65
+ outputs[:, 1] = ix.view(-1)
66
+
67
+ e_outputs = torch.repeat_interleave(e_output, self.beam_size, 0)
68
+
69
+ # raise Exception(outputs[:, :2], e_outputs)
70
+
71
+ return outputs, e_outputs, log_scores
72
+
73
+ def compute_k_best(self, outputs, out, log_scores, i, debug=False):
74
+ """
75
+ Compute k words with the highest conditional probability
76
+ Args:
77
+ outputs: Array has k previous candidate output sequences. [batch_size*beam_size, max_len]
78
+ i: the current timestep to execute. Int
79
+ out: current output of the model at timestep. [batch_size*beam_size, vocab_size]
80
+ log_scores: Conditional probability of past candidates (in outputs) [batch_size * beam_size]
81
+
82
+ Returns:
83
+ new outputs has k best candidate output sequences
84
+ log_scores for each of those candidate
85
+ """
86
+ row_b = len(out); batch_size = row_b // self.beam_size
87
+ eos_id = self.TRG.vocab.stoi['<eos>']
88
+
89
+ probs, ix = out[:, -1].data.topk(self.beam_size)
90
+
91
+ probs_rep = torch.Tensor([[1] + [1e-100] * (self.beam_size-1)]*row_b).view(row_b, self.beam_size).to(self.device)
92
+ ix_rep = torch.LongTensor([[eos_id] + [-1]*(self.beam_size-1)]*row_b).view(row_b, self.beam_size).to(self.device)
93
+
94
+ check_eos = torch.repeat_interleave((outputs[:, i-1] == eos_id).view(row_b, 1), self.beam_size, 1)
95
+
96
+ probs = torch.where(check_eos, probs_rep, probs)
97
+ ix = torch.where(check_eos, ix_rep, ix)
98
+
99
+ # if(debug):
100
+ # print("kprobs before debug: ", probs, probs_rep, ix, ix_rep, log_scores)
101
+
102
+ log_probs = torch.log(probs).to(self.device) + log_scores.to(self.device) # CPU
103
+
104
+ k_probs, k_ix = log_probs.view(batch_size, -1).topk(self.beam_size)
105
+ if(debug):
106
+ print("kprobs_after_select: ", log_probs, k_probs, k_ix)
107
+
108
+ # Use cpu
109
+ k_probs, k_ix = torch.Tensor(k_probs.cpu().data.numpy()), torch.LongTensor(k_ix.cpu().data.numpy())
110
+ row = k_ix // self.beam_size + torch.LongTensor([[v*self.beam_size] for v in range(batch_size)])
111
+ col = k_ix % self.beam_size
112
+ if(debug):
113
+ print("kprobs row/col", row, col, ix[row.view(-1), col.view(-1)])
114
+ assert False
115
+
116
+ outputs[:, :i] = outputs[row.view(-1), :i]
117
+ outputs[:, i] = ix[row.view(-1), col.view(-1)]
118
+ log_scores = k_probs.view(-1, 1)
119
+
120
+ return outputs, log_scores
121
+
122
+ def replace_unknown(self, outputs, sentences, attn, selector_tuple, unknown_token="<unk>"):
123
+ """Replace the unknown words in the outputs with the highest valued attentionized words.
124
+ Args:
125
+ outputs: the output from decoding. [batch, beam] of list of str
126
+ sentences: the original wordings of the sentences. [batch_size, src_len] of str
127
+ attn: the attention received, in the form of list: [layers units of (self-attention, attention) with shapes of [batchbeam, heads, tgt_len, tgt_len] & [batchbeam, heads, tgt_len, src_len] respectively]
128
+ selector_tuple: (layer, head) used to select the attention
129
+ unknown_token: token used for checking. str
130
+ Returns:
131
+ the replaced version, in the same shape as outputs
132
+ """
133
+
134
+ # is_finished = torch.LongTensor([[self.TRG.vocab.stoi['<eos>']] for i in range(self.beam_offset)]).view(-1).to(self.device)
135
+ # unk_token = self.SRC.vocab.stoi['<unk>']
136
+ layer_used, head_used = selector_tuple
137
+ used_attention = attn[layer_used][-1][:, head_used] # it should be [batchbeam, tgt_len, src_len], as we are using the attention to source
138
+ flattened_outputs = outputs.reshape((-1, )) # flatten the outputs back to batchbeam
139
+
140
+ select_id_src = torch.argmax(used_attention, dim=-1).cpu().numpy() # [batchbeam, tgt_len] of best indices. Also convert to numpy version (remove sos not needed as it is attention of outputs)
141
+ beam_size = select_id_src.shape[0] // len(sentences) # used custom-calculated beam_size as we might not output the entirety of beams. See beam_search fn for details
142
+ # select per batchbeam. source batch id is found by dividing batchbeam id per beam; we are selecting [tgt_len] indices from [src_len] tokens; then concat at the first dimensions to retrieve [batch_beam, tgt_len] of replacement tokens
143
+ # need itemgetter / map to retrieve from list
144
+ replace_tokens = [ operator.itemgetter(*src_idx)(sentences[bidx // beam_size]) for bidx, src_idx in enumerate(select_id_src)]
145
+
146
+ # zip together with sentences; then output { the token if not unk / the replacement if is }. Note that this will trim the orig version down to repl size.
147
+ zipped = zip(flattened_outputs, replace_tokens)
148
+ replaced = np.array([[tok if tok != unknown_token else rpl for rpl, tok in zip(repl, orig)] for orig, repl in zipped], dtype=object)
149
+ # reshape back to outputs shape [batch, beam] of list
150
+ return replaced.reshape(outputs.shape)
151
+
152
+ # for i in range(1, self.max_len):
153
+ # ix = attn[0, 0, i-1, :].argmax().data
154
+ # outputs[:, i][outputs[:, i] == unk_token] = sentences[0][ix.data]
155
+ # if torch.equal(outputs[:, i], is_finished):
156
+ # break
157
+ #
158
+ # return outputs
159
+
160
+ def beam_search(self, src, src_lang=None, trg_lang=None, src_tokens=None, n_best=1, length_norm=None, replace_unk=None, debug=False):
161
+ """
162
+ Beam search select k words with the highest conditional probability
163
+ to be the first word of the k candidate output sequences.
164
+ Args:
165
+ src: The batch of sentences, already in [batch_size, tokens] of int
166
+ src_tokens: src in str version, same size as above. Used almost exclusively for replace unknown word
167
+ n_best: number of usable values per beam loaded
168
+ length_norm: if specified, normalize as per (Wu, 2016); note that if not inputted then it will still use __init__ value as default. float
169
+ replace_unk: if specified, do replace unknown word using attention of (layer, head); note that if not inputted, it will still use __init__ value as default. (int, int)
170
+ debug: if true, print some debug information during the search
171
+ Return:
172
+ An array of translated sentences, in list-of-tokens format.
173
+ Either [batch_size, n_best, tgt_len] when n_best > 1
174
+ Or [batch_size, tgt_len] when n_best == 1
175
+ """
176
+ model = self.model
177
+ start_token = const.DEFAULT_SOS if trg_lang is None else generate_language_token(trg_lang)
178
+ outputs, e_outputs, log_scores = self.init_vars(src, start_token=start_token)
179
+
180
+ eos_tok = self.TRG.vocab.stoi[const.DEFAULT_EOS]
181
+ src_mask = (src != self.SRC.vocab.stoi[const.DEFAULT_PAD]).unsqueeze(-2)
182
+ src_mask = torch.repeat_interleave(src_mask, self.beam_size, 0).to(self.device)
183
+ is_finished = torch.LongTensor([[eos_tok] for i in range(self.beam_size*len(src))]).view(-1).to(self.device)
184
+ ind = None
185
+ for i in range(2, self.max_len):
186
+ trg_mask = no_peeking_mask(i, self.device)
187
+
188
+ decoder_output, attn = model.decoder(outputs[:, :i], e_outputs, src_mask, trg_mask, output_attention=True)
189
+ out = model.out(decoder_output)
190
+ out = functional.softmax(out, dim=-1)
191
+ outputs, log_scores = self.compute_k_best(outputs, out, log_scores, i)
192
+
193
+ # Occurrences of end symbols for all input sentences.
194
+ if torch.equal(outputs[:, i], is_finished):
195
+ break
196
+
197
+
198
+ # if(self._replace_unk):
199
+ # outputs = self.replace_unknown(attn, src, outputs)
200
+
201
+ # reshape outputs and log_probs to [batch, beam] numpy array
202
+ batch_size = src.shape[0]
203
+ outputs = outputs.cpu().numpy().reshape((batch_size, self.beam_size, self.max_len))
204
+ log_scores = log_scores.cpu().numpy().reshape((batch_size, self.beam_size))
205
+
206
+ # Get the best sentences for every beam: splice by length and itos the indices, result in an array of tokens
207
+ # also remove the first token in this timestep (as it is sos)
208
+ translated_sentences = np.empty(outputs.shape[:-1], dtype=object)
209
+ trim_and_itos = lambda sent: [self.TRG.vocab.itos[i] for i in sent[1:self._length(sent, eos_tok=eos_tok)]]
210
+ for ba in range(outputs.shape[0]):
211
+ for bm in range(outputs.shape[1]):
212
+ translated_sentences[ba, bm] = trim_and_itos(outputs[ba, bm])
213
+ # raise ValueError(translated_sentences)
214
+ #translated_sentences = np.apply_along_axis(lambda sent: tuple(sent.tolist()[:self._length(sent, eos_tok=eos_tok)]), -1, outputs)
215
+ #translated_sentences = np.vectorize(lambda sent: [self.TRG.vocab.itos[i] for i in sent])(translated_sentences)
216
+ if(replace_unk is None):
217
+ replace_unk = self._replace_unk
218
+ if(replace_unk):
219
+ # replace unknown words per translated sentences. Do it before normalization (since that is independent on actual tokens)
220
+ if(src_tokens is None):
221
+ logging.warning("replace_unknown option enabled but no src_tokens supplied for the task. The method will not run.")
222
+ else:
223
+ translated_sentences = self.replace_unknown(translated_sentences, src_tokens, attn, replace_unk)
224
+
225
+ if(length_norm is None):
226
+ length_norm = self._length_norm
227
+ if(length_norm is not None):
228
+ # raise ValueError(length_norm)
229
+ # perform length normalization calculation and reorganize the sentences accordingly
230
+ lengths = np.apply_along_axis(lambda x: self._length(x, eos_tok=eos_tok), -1, outputs)
231
+ log_scores, indices = self.length_normalize(lengths, log_scores, coff=length_norm)
232
+ translated_sentences = np.array([beams[ids] for beams, ids in zip(translated_sentences, indices)])
233
+ # outputs = np.array([beams[ids] for beams, ids in zip(outputs, indices)])
234
+
235
+ # assert n_best == 1, "Currently unsupported n_best > 1. TODO write."
236
+ if(n_best == 1):
237
+ return translated_sentences[:, 0]
238
+ else:
239
+ return translated_sentences[:, :n_best]
240
+
241
+ def translate_single_sentence(self, src, **kwargs):
242
+ """Translate a single sentence. Currently unused."""
243
+ raise NotImplementedError
244
+ return self.translate_batch_sentence([src], **kwargs)
245
+
246
+ def translate_batch_sentence(self, src, src_lang=None, trg_lang=None, field_processed=False, src_size_limit=None, output_tokens=False, replace_unk=None, debug=False):
247
+ """Translate a batch of sentences together. Currently disabling the synonym func.
248
+ Args:
249
+ src: the batch of sentences to be translated. list of str
250
+ src_lang: the language translated from. Only used with multilingual models, in preprocess. str
251
+ trg_lang: the language to be translated to. Only used with multilingual models, in beam_search. str
252
+ field_processed: bool, if the sentences had been already processed (i.e part of batched validation data)
253
+ src_size_limit: if set, trim the input if it cross this value. Added due to current positional encoding support only <=200 tokens
254
+ output_tokens: the output format. False will give a batch of sentences (str), while True will give batch of tokens (list of str)
255
+ replace_unk: see beam_search for usage. (int, int) or False to suppress __init__ value
256
+ debug: enable to print external values
257
+ Return:
258
+ the result of translation, with format dictated by output_tokens
259
+ """
260
+ self.model.eval()
261
+ # create the indexed batch.
262
+ processed_batch = self.preprocess_batch(src, src_lang=src_lang, field_processed=field_processed, src_size_limit=src_size_limit, output_tokens=True, debug=debug)
263
+ sent_ids, sent_tokens = (processed_batch, None) if(field_processed) else processed_batch
264
+ assert isinstance(sent_ids, torch.Tensor), "sent_ids is instead {}".format(type(sent_ids))
265
+
266
+ batch_start = time.time()
267
+ translated_sentences = self.beam_search(sent_ids, trg_lang=trg_lang, src_tokens=sent_tokens, replace_unk=replace_unk, debug=debug)
268
+ if(debug):
269
+ print("Time performed for batch {}: {:.2f}s".format(sent_ids.shape, time.time() - batch_start))
270
+
271
+ if(not output_tokens):
272
+ translated_sentences = [' '.join(tokens) for tokens in translated_sentences]
273
+
274
+ return translated_sentences
275
+
276
+ def preprocess_batch(self, sentences, src_lang=None, field_processed=False, pad_token="<pad>", src_size_limit=None, output_tokens=False, debug=True):
277
+ """Adding
278
+ src_size_limit: int, option to limit the length of src.
279
+ src_lang: if specified (not None), append this token <{src_lang}> to the start of the batch
280
+ field_processed: bool: if the sentences had been already processed (i.e part of batched validation data)
281
+ output_tokens: if set, output a token version aside the id version, in [batch of [src_len]] str. Note that it won't work with field_processed
282
+ """
283
+ if(field_processed):
284
+ # do nothing, as it had already performed tokenizing/stoi.
285
+ # Still cap the length of the batch due to possible infraction in valid
286
+ if(src_size_limit is not None):
287
+ sentences = sentences[:, :src_size_limit]
288
+ return sentences
289
+ processed_sent = map(self.SRC.preprocess, sentences)
290
+ if(src_lang is not None):
291
+ src_token = generate_language_token(src_lang)
292
+ processed_sent = map(lambda x: [src_token] + x, processed_sent)
293
+ if(src_size_limit):
294
+ processed_sent = map(lambda x: x[:src_size_limit], processed_sent)
295
+ processed_sent = list(processed_sent)
296
+ tokenized_sent = [torch.LongTensor([self._token_to_index(t) for t in s]) for s in processed_sent] # convert to tensors, in indices format
297
+ sentences = Variable(pad_sequence(tokenized_sent, True, padding_value=self.SRC.vocab.stoi[pad_token])) # padding sentences
298
+ if(debug):
299
+ print("Input batch after process: ", sentences.shape, sentences)
300
+
301
+ if(output_tokens):
302
+ return sentences, processed_sent
303
+ else:
304
+ return sentences
305
+
306
+ def translate_batch(self, sentences, **kwargs):
307
+ return self.translate_batch_sentence(sentences, **kwargs)
308
+
309
+ def length_normalize(self, lengths, log_probs, coff=0.6):
310
+ """Normalize the probabilty score as in (Wu 2016). Use pure numpy values
311
+ Args:
312
+ lengths: the length of the hypothesis. [batch, beam] of int->float
313
+ log_probs: the unchanged log probability for the whole hypothesis. [batch, beam] of float
314
+ coff: the alpha coefficient.
315
+ Returns:
316
+ Tuple of (penalized_values, indices) to reorganize outputs."""
317
+ lengths = ((lengths + 5) / 6) ** coff
318
+ penalized_probs = log_probs / lengths
319
+ indices = np.argsort(penalized_probs, axis=-1)[::-1]
320
+ # basically take log_probs values for every batch
321
+ reorganized_probs = np.array([prb[ids] for prb, ids in zip(penalized_probs, indices)])
322
+ return reorganized_probs, indices
323
+
324
+ def _length(self, tokens, eos_tok=None):
325
+ """Retrieve the first location of eos_tok as length; else return the entire length"""
326
+ if(eos_tok is None):
327
+ eos_tok = self.TRG.vocab.stoi[const.DEFAULT_EOS]
328
+ eos, = np.nonzero(tokens==eos_tok)
329
+ return len(tokens) if len(eos) == 0 else eos[0]
330
+
331
+ def _token_to_index(self, tok):
332
+ """Override to select, depending on the self._use_synonym param"""
333
+ if(self._use_synonym):
334
+ return super(BeamSearch, self)._token_to_index(tok)
335
+ else:
336
+ return self.SRC.vocab.stoi[tok]
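For reference, `length_normalize` above applies the length penalty of Wu et al. (2016), lp(Y) = ((5 + |Y|) / 6)^alpha, and divides the accumulated log-probability by it. A standalone sketch of the same computation with made-up scores:

```python
import numpy as np

def length_penalty(length, alpha=0.6):
    # lp(Y) = ((5 + |Y|) / 6) ** alpha, as in Wu et al. (2016)
    return ((5 + length) / 6) ** alpha

log_probs = np.array([-4.2, -4.5])   # toy beam scores
lengths = np.array([10, 14])         # toy hypothesis lengths
penalized = log_probs / length_penalty(lengths)
best = int(np.argmax(penalized))     # the longer hypothesis can win after normalization
```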
modules/inference/beam_search1.py ADDED
@@ -0,0 +1,346 @@
1
+ import numpy as np
2
+ import torch
3
+ import math, time, operator
4
+ import torch.nn.functional as functional
5
+ import torch.nn as nn
6
+ from torch.autograd import Variable
7
+ from torch.nn.utils.rnn import pad_sequence
8
+
9
+ from modules.inference.decode_strategy import DecodeStrategy
10
+ from utils.misc import no_peeking_mask
11
+
12
+ class BeamSearch1(DecodeStrategy):
13
+ def __init__(self, model, max_len, device, beam_size=5, use_synonym_fn=False, replace_unk=None):
14
+ """
15
+ Args:
16
+ model: the used model
17
+ max_len: the maximum timestep to be used
18
+ device: the device to perform calculation
19
+ beam_size: the size of the beam itself
20
+ use_synonym_fn: if set, use the get_synonym fn from wordnet to try replace <unk>
21
+ replace_unk: a tuple of [layer, head] designation, to replace the unknown word by chosen attention
22
+ """
23
+ super(BeamSearch1, self).__init__(model, max_len, device)
24
+ self.beam_size = beam_size
25
+ self._use_synonym = use_synonym_fn
26
+ self._replace_unk = replace_unk
27
+ # print("Init BeamSearch ----------------")
28
+
29
+ def trg_init_vars(self, src, batch_size, trg_init_token, trg_eos_token, single_src_mask):
30
+ """
31
+ Calculate the required matrices during translation after the model is finished
32
+ Input:
33
+ :param src: The batch of sentences
34
+
35
+ Output: Initialize the first character includes outputs, e_outputs, log_scores
36
+ """
37
+ # Initialize target sequence (start with '<sos>' token) [batch_size x k x max_len]
38
+ trg = torch.zeros(batch_size, self.beam_size, self.max_len, device=self.device).long()
39
+ trg[:, :, 0] = trg_init_token
40
+
41
+ # Precalc output from model's encoder
42
+ e_out = self.model.encoder(src, single_src_mask) # [batch_size x S x d_model]
43
+ # Output model prob
44
+ trg_mask = no_peeking_mask(1, device=self.device)
45
+ # [batch_size x 1]
46
+ inp_decoder = trg[:, 0, 0].view(batch_size, 1)
47
+ # [batch_size x 1 x vocab_size]
48
+ prob = self.model.out(self.model.decoder(inp_decoder, e_out, single_src_mask, trg_mask))
49
+ prob = functional.softmax(prob, dim=-1)
50
+
51
+ # [batch_size x 1 x k]
52
+ k_prob, k_index = torch.topk(prob, self.beam_size, dim=-1)
53
+ trg[:, :, 1] = k_index.view(batch_size, self.beam_size)
54
+ # Init log scores from k beams [batch_size x k x 1]
55
+ log_scores = torch.log(k_prob.view(batch_size, self.beam_size, 1))
56
+
57
+ # Repeat encoder's output k times for searching [(k * batch_size) x S x d_model]
58
+ e_outs = torch.repeat_interleave(e_out, self.beam_size, dim=0)
59
+ src_mask = torch.repeat_interleave(single_src_mask, self.beam_size, dim=0)
60
+
61
+ # Create mask for checking eos
62
+ sent_eos = torch.tensor([trg_eos_token for _ in range(self.beam_size)], device=self.device).view(1, self.beam_size)
63
+
64
+ return sent_eos, log_scores, e_outs, e_out, src_mask, trg
65
+
66
+ def compute_k_best(self, outputs, out, log_scores, i, debug=False):
67
+ """
68
+ Compute k words with the highest conditional probability
69
+ Args:
70
+ outputs: Array has k previous candidate output sequences. [batch_size*beam_size, max_len]
71
+ i: the current timestep to execute. Int
72
+ out: current output of the model at timestep. [batch_size*beam_size, vocab_size]
73
+ log_scores: Conditional probability of past candidates (in outputs) [batch_size * beam_size]
74
+
75
+ Returns:
76
+ new outputs has k best candidate output sequences
77
+ log_scores for each of those candidate
78
+ """
79
+ row_b = len(out);
80
+ batch_size = row_b // self.beam_size
81
+ eos_id = self.TRG.vocab.stoi['<eos>']
82
+
83
+ probs, ix = out[:, -1].data.topk(self.beam_size)
84
+
85
+ probs_rep = torch.Tensor([[1] + [1e-100] * (self.beam_size-1)]*row_b).view(row_b, self.beam_size).to(self.device)
86
+ ix_rep = torch.LongTensor([[eos_id] + [-1]*(self.beam_size-1)]*row_b).view(row_b, self.beam_size).to(self.device)
87
+
88
+ check_eos = torch.repeat_interleave((outputs[:, i-1] == eos_id).view(row_b, 1), self.beam_size, 1)
89
+
90
+ probs = torch.where(check_eos, probs_rep, probs)
91
+ ix = torch.where(check_eos, ix_rep, ix)
92
+
93
+ log_probs = torch.log(probs).to(self.device) + log_scores.to(self.device) # CPU
94
+
95
+ k_probs, k_ix = log_probs.view(batch_size, -1).topk(self.beam_size)
96
+ if(debug):
97
+ print("kprobs_after_select: ", log_probs, k_probs, k_ix)
98
+
99
+ # Use cpu
100
+ k_probs, k_ix = torch.Tensor(k_probs.cpu().data.numpy()), torch.LongTensor(k_ix.cpu().data.numpy())
101
+ row = k_ix // self.beam_size + torch.LongTensor([[v*self.beam_size] for v in range(batch_size)])
102
+ col = k_ix % self.beam_size
103
+ if(debug):
104
+ print("kprobs row/col", row, col, ix[row.view(-1), col.view(-1)])
105
+ assert False
106
+
107
+ outputs[:, :i] = outputs[row.view(-1), :i]
108
+ outputs[:, i] = ix[row.view(-1), col.view(-1)]
109
+ log_scores = k_probs.view(-1, 1)
110
+
111
+ return outputs, log_scores
112
+
113
+ def replace_unknown(self, outputs, sentences, attn, selector_tuple, unknown_token="<unk>"):
114
+ """Replace the unknown words in the outputs with the highest valued attentionized words.
115
+ Args:
116
+ outputs: the output from decoding. [batchbeam] of list of str
117
+ sentences: the original wordings of the sentences. [batch_size, src_len] of str
118
+ attn: the attention received, in the form of list: [layers units of (self-attention, attention) with shapes of [batchbeam, heads, tgt_len, tgt_len] & [batchbeam, heads, tgt_len, src_len] respectively]
119
+ selector_tuple: (layer, head) used to select the attention
120
+ unknown_token: token used for checking. str
121
+ Returns:
122
+ the replaced version, in the same shape as outputs
123
+ """
124
+ layer_used, head_used = selector_tuple
125
+ # used_attention = attn[layer_used][-1][:, head_used] # it should be [batchbeam, tgt_len, src_len], as we are using the attention to source
126
+ inx = torch.arange(start=0,end=len(attn)-1, step=self.beam_size)
127
+ used_attention = attn[inx]
128
+ select_id_src = torch.argmax(used_attention, dim=-1).cpu().numpy() # [batchbeam, tgt_len] of best indices. Also convert to numpy version (remove sos not needed as it is attention of outputs)
129
+ # print(select_id_src, len(select_id_src))
130
+ beam_size = select_id_src.shape[0] // len(sentences) # used custom-calculated beam_size as we might not output the entirety of beams. See beam_search fn for details
131
+ # print("beam: ", beam_size)
132
+ # select per batchbeam. source batch id is found by dividing batchbeam id per beam; we are selecting [tgt_len] indices from [src_len] tokens; then concat at the first dimensions to retrieve [batch_beam, tgt_len] of replacement tokens
133
+ # need itemgetter / map to retrieve from list
134
+ # print([ operator.itemgetter(*src_idx)(sentences[bidx // beam_size]) for bidx, src_idx in enumerate(select_id_src)])
135
+ # print([print(sentences[bidx // beam_size], src_idx) for bidx, src_idx in enumerate(select_id_src)])
136
+ # replace_tokens = [ operator.itemgetter(*src_idx)(sentences[bidx // beam_size]) for bidx, src_idx in enumerate(select_id_src)]
137
+
138
+ for i in range(len(outputs)):
139
+ for j in range(len(outputs[i])):
140
+ if outputs[i][j] == unknown_token:
141
+ outputs[i][j] = sentences[i][select_id_src[i][j]]
142
+
143
+ # print(sentences[0][0], outputs[0][0])
144
+
145
+ # print(i)
146
+ # zip together with sentences; then output { the token if not unk / the replacement if is }. Note that this will trim the orig version down to repl size.
147
+ # replaced = [ [tok if tok != unknown_token else rpl for rpl, tok in zip(repl, orig)] for orig, repl in zipped ]
148
+
149
+ # return replaced
150
+ return outputs
151
+
152
+ # def beam_search(self, src, max_len, device, k=4):
153
+ def beam_search(self, src, src_tokens=None, n_best=1, debug=False):
154
+ """
155
+ Beam search for a single sentence
156
+ Args:
157
+ model : a Transformer instance
158
+ src : a batch (tokenized + numerized) sentence [batch_size x S]
159
+ Returns:
160
+ trg : a batch (tokenized + numerized) sentence [batch_size x T]
161
+ """
162
+ src = src.to(self.device)
163
+ trg_init_token = self.TRG.vocab.stoi["<sos>"]
164
+ trg_eos_token = self.TRG.vocab.stoi["<eos>"]
165
+ single_src_mask = (src != self.SRC.vocab.stoi['<pad>']).unsqueeze(1).to(self.device)
166
+ batch_size = src.size(0)
167
+
168
+ sent_eos, log_scores, e_outs, e_out, src_mask, trg = self.trg_init_vars(src, batch_size, trg_init_token, trg_eos_token, single_src_mask)
169
+
170
+ # The batch indexes
171
+ batch_index = torch.arange(batch_size)
172
+ finished_batches = torch.zeros(batch_size, device=self.device).long()
173
+
174
+ log_attn = torch.zeros([self.beam_size*batch_size, self.max_len, len(src[0])])
175
+
176
+ # Iteratively searching
177
+ for i in range(2, self.max_len):
178
+ trg_mask = no_peeking_mask(i, self.device)
179
+
180
+ # Flatten trg tensor for feeding into model [(k * batch_size) x i]
181
+ inp_decoder = trg[batch_index, :, :i].view(self.beam_size * len(batch_index), i)
182
+ # Output model prob [(k * batch_size) x i x vocab_size]
183
+ current_decode, attn = self.model.decoder(inp_decoder, e_outs, src_mask, trg_mask, output_attention=True)
184
+ # print(len(attn[0]))
185
+
186
+ prob = self.model.out(current_decode)
187
+ prob = functional.softmax(prob, dim=-1)
188
+
189
+ # Only care the last prob i-th
190
+ # [(k * batch_size) x 1 x vocab_size]
191
+ prob = prob[:, i-1, :].view(self.beam_size * len(batch_index), 1, -1)
192
+
193
+ # Truncate prob to top k [(k * batch_size) x 1 x k]
194
+ k_prob, k_index = prob.data.topk(self.beam_size, dim=-1)
195
+
196
+ # Deflatten k_prob & k_index
197
+ k_prob = k_prob.view(len(batch_index), self.beam_size, 1, self.beam_size)
198
+ k_index = k_index.view(len(batch_index), self.beam_size, 1, self.beam_size)
199
+
200
+ # Preserve eos beams
201
+ # [batch_size x k] -> view -> [batch_size x k x 1 x 1] (broadcastable)
202
+ eos_mask = (trg[batch_index, :, i-1] == trg_eos_token).view(len(batch_index), self.beam_size, 1, 1)
203
+ k_prob.masked_fill_(eos_mask, 1.0)
204
+ k_index.masked_fill_(eos_mask, trg_eos_token)
205
+
206
+ # Find the best k cases
207
+ # Compute log score at i-th timestep
208
+ # [batch_size x k x 1 x 1] + [batch_size x k x 1 x k] = [batch_size x k x 1 x k]
209
+ combine_probs = log_scores[batch_index].unsqueeze(-1) + torch.log(k_prob)
210
+ # [batch_size x k x 1]
211
+ log_scores[batch_index], positions = torch.topk(combine_probs.view(len(batch_index), self.beam_size * self.beam_size, 1), self.beam_size, dim=1)
212
+
213
+ # The rows selected from top k
214
+ rows = positions.view(len(batch_index), self.beam_size) // self.beam_size
215
+ # The indexes in vocab respected to these rows
216
+ cols = positions.view(len(batch_index), self.beam_size) % self.beam_size
217
+
218
+ batch_sim = torch.arange(len(batch_index)).view(-1, 1)
219
+ trg[batch_index, :, :] = trg[batch_index.view(-1, 1), rows, :]
220
+ trg[batch_index, :, i] = k_index[batch_sim, rows, :, cols].view(len(batch_index), self.beam_size)
221
+
222
+ # Update attn
223
+ inx = torch.repeat_interleave(finished_batches, self.beam_size, dim=0)
224
+ batch_attn = torch.nonzero(inx == 0).view(-1)
225
+ # import copy
226
+ # x = copy.deepcopy(attn[0][-1][:, 0].to("cpu"))
227
+ # log_attn[batch_attn, :i, :] = x
228
+
229
+ # if i == 7:
230
+ # print(log_attn[batch_attn, :i, :].shape, attn[0][-1][:, 0].shape)
231
+ # print(log_attn[batch_attn, :i, :])
232
+ # Update which sentences finished all its beams
233
+ mask = (trg[:, :, i] == sent_eos).all(1).view(-1).to(self.device)
234
+ finished_batches.masked_fill_(mask, value=1)
235
+ cnt = torch.sum(finished_batches).item()
236
+ if cnt == batch_size:
237
+ break
238
+
239
+ # # Continue with remaining batches (if any)
240
+ batch_index = torch.nonzero(finished_batches == 0).view(-1)
241
+ e_outs = torch.repeat_interleave(e_out[batch_index], self.beam_size, dim=0)
242
+ src_mask = torch.repeat_interleave(single_src_mask[batch_index], self.beam_size, dim=0)
243
+ # End loop
244
+
245
+ # Get the best beam
246
+ log_scores = log_scores.view(batch_size, self.beam_size)
247
+ results = []
248
+ for t, j in enumerate(torch.argmax(log_scores, dim=-1)):
249
+ sent = []
250
+ for i in range(self.max_len):
251
+ token_id = trg[t, j.item(), i].item()
252
+ if token_id == trg_init_token:
253
+ continue
254
+ if token_id == trg_eos_token:
255
+ break
256
+ sent.append(self.TRG.vocab.itos[token_id])
257
+ results.append(sent)
258
+
259
+ # if(self._replace_unk and src_tokens is not None):
260
+ # # replace unknown words per translated sentences.
261
+ # # NOTE: lacking a src_tokens does not raise any warning. Add that in when logging module is available, to support error catching
262
+ # # print("Replace unk -----------------------")
263
+ # results = self.replace_unknown(results, src_tokens, log_attn, self._replace_unk)
264
+
265
+ return results
266
+
267
+ def translate_single_sentence(self, src, **kwargs):
268
+ """Translate a single sentence. Currently unused."""
269
+ raise NotImplementedError
270
+ return self.translate_batch_sentence([src], **kwargs)
271
+
272
+ def translate_batch_sentence(self, src, field_processed=False, src_size_limit=None, output_tokens=False, debug=False):
273
+ """Translate a batch of sentences together. Currently disabling the synonym func.
274
+ Args:
275
+ src: the batch of sentences to be translated
276
+ field_processed: bool, if the sentences had been already processed (i.e part of batched validation data)
277
+ src_size_limit: if set, trim the input if it cross this value. Added due to current positional encoding support only <=200 tokens
278
+ output_tokens: the output format. False will give a batch of sentences (str), while True will give batch of tokens (list of str)
279
+ debug: enable to print external values
280
+ Return:
281
+ the result of translation, with format dictated by output_tokens
282
+ """
283
+ batch_start = time.time()
284
+
285
+ self.model.eval()
286
+ # create the indexed batch.
287
+ processed_batch = self.preprocess_batch(src, field_processed=field_processed, src_size_limit=src_size_limit, output_tokens=True, debug=debug)
288
+ # print("Time preprocess_batch: ", time.time()-start)
289
+
290
+ sent_ids, sent_tokens = (processed_batch, None) if(field_processed) else processed_batch
291
+ assert isinstance(sent_ids, torch.Tensor), "sent_ids is instead {}".format(type(sent_ids))
292
+
293
+ translated_sentences = self.beam_search(sent_ids, src_tokens=sent_tokens, debug=debug)
294
+
295
+ # print("Time for one batch: ", time.time()-batch_start)
296
+
297
+ # if time.time()-batch_start > 2:
298
+ # [print("len src >2 : ++++++", len(i.split())) for i in src]
299
+ # [print("len translate >2: ++++++", len(i)) for i in translated_sentences]
300
+ # else:
301
+ # [print("len src : ====", len(i.split())) for i in src]
302
+ # [print("len translate : ====", len(i)) for i in translated_sentences]
303
+ # print("=====================================")
304
+
305
+ # time.sleep(4)
306
+ if(debug):
307
+ print("Time performed for batch {}: {:.2f}s".format(sent_ids.shape))
308
+
309
+ if(not output_tokens):
310
+ translated_sentences = [' '.join(tokens) for tokens in translated_sentences]
311
+
312
+ return translated_sentences
313
+
314
+ def preprocess_batch(self, sentences, field_processed=False, pad_token="<pad>", src_size_limit=None, output_tokens=False, debug=True):
315
+ """Adding
316
+ src_size_limit: int, option to limit the length of src.
317
+ field_processed: bool: if the sentences had been already processed (i.e part of batched validation data)
318
+ output_tokens: if set, output a token version aside the id version, in [batch of [src_len]] str. Note that it won't work with field_processed
319
+ """
320
+
321
+ if(field_processed):
322
+ # do nothing, as it had already performed tokenizing/stoi
323
+ return sentences
324
+ processed_sent = map(self.SRC.preprocess, sentences)
325
+ if(src_size_limit):
326
+ processed_sent = map(lambda x: x[:src_size_limit], processed_sent)
327
+ processed_sent = list(processed_sent)
328
+ tokenized_sent = [torch.LongTensor([self._token_to_index(t) for t in s]) for s in processed_sent] # convert to tensors, in indices format
329
+ sentences = Variable(pad_sequence(tokenized_sent, True, padding_value=self.SRC.vocab.stoi[pad_token])) # padding sentences
330
+ if(debug):
331
+ print("Input batch after process: ", sentences.shape, sentences)
332
+
333
+ if(output_tokens):
334
+ return sentences, processed_sent
335
+ else:
336
+ return sentences
337
+
338
+ def translate_batch(self, sentences, **kwargs):
339
+ return self.translate_batch_sentence(sentences, **kwargs)
340
+
341
+ def _token_to_index(self, tok):
342
+ """Override to select, depending on the self._use_synonym param"""
343
+ if(self._use_synonym):
344
+ return super(BeamSearch1, self)._token_to_index(tok)
345
+ else:
346
+ return self.SRC.vocab.stoi[tok]
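Both beam search variants recover the parent beam (`// k`) and the candidate slot (`% k`) from a flattened top-k index. A toy illustration of that bookkeeping (the scores are invented):

```python
import torch

k = 3
# toy combined scores for one sentence: k beams x k candidate tokens, flattened to k*k
combine_probs = torch.tensor([[-1.2, -3.0, -4.1, -0.7, -2.2, -5.0, -2.9, -3.3, -0.9]])
scores, positions = combine_probs.topk(k, dim=-1)
rows = positions // k   # which parent beam each surviving hypothesis extends
cols = positions % k    # which of that beam's top-k candidate tokens was chosen
```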
modules/inference/decode_strategy.py ADDED
@@ -0,0 +1,62 @@
1
+ import torch
2
+ from torch.autograd import Variable
3
+ from utils.data import get_synonym
4
+ from torch.nn.utils.rnn import pad_sequence
5
+ import abc
6
+
7
+ class DecodeStrategy(object):
8
+ """
9
+ Base, abstract class for generation strategies. Contains the specific calls made to the underlying model.
10
+
11
+ """
12
+ def __init__(self, model, max_len, device):
13
+ self.model = model
14
+ self.max_len = max_len
15
+ self.device = device
16
+
17
+ @property
18
+ def SRC(self):
19
+ return self.model.SRC
20
+
21
+ @property
22
+ def TRG(self):
23
+ return self.model.TRG
24
+
25
+ @abc.abstractmethod
26
+ def translate_single(self, src_lang, trg_lang, sentences):
27
+ """Translate a single sentence. Might be useful as backcompatibility"""
28
+ raise NotImplementedError
29
+
30
+ @abc.abstractmethod
31
+ def translate_batch(self, src_lang, trg_lang, sentences):
32
+ """Translate a batch of sentences.
33
+ Args:
34
+ sentences: The sentences, formatted as [batch_size] Tensor of str
35
+ Returns:
36
+ The detokenized output, most commonly [batch_size] of str
37
+ """
38
+
39
+ raise NotImplementedError
40
+
41
+ @abc.abstractmethod
42
+ def replace_unknown(self, *args):
43
+ """Replace unknown words from batched sentences"""
44
+ raise NotImplementedError
45
+
46
+ def preprocess_batch(self, sentences, lang=None, pad_token="<pad>"):
47
+ """Feed an unprocessed batch into the torchtext.Field of the source.
48
+ Args:
49
+ sentences: [batch_size] of str
50
+ pad_token: the pad token used to pad the sentences
51
+ Returns:
52
+ the sentences in Tensor format, padded with pad_value"""
53
+ processed_sent = list(map(self.SRC.preprocess, sentences)) # tokenizing
54
+ tokenized_sent = [torch.LongTensor([self._token_to_index(t) for t in s]) for s in processed_sent] # convert to tensors and indices
55
+ sentences = Variable(pad_sequence(tokenized_sent, True, padding_value=self.SRC.vocab.stoi[pad_token])) # padding sentences
56
+ return sentences
57
+
58
+ def _token_to_index(self, tok):
59
+ """Implementing get_synonym as default. Override if want to use default behavior (<unk> for unknown words, independent of wordnet)"""
60
+ if self.SRC.vocab.stoi[tok] != self.SRC.vocab.stoi['<unk>']:
61
+ return self.SRC.vocab.stoi[tok]
62
+ return get_synonym(tok, self.SRC)
modules/inference/greedy_search.py ADDED
@@ -0,0 +1,121 @@
1
+ ##@title Our beam search
2
+ import numpy as np
3
+ import torch
4
+ import math
5
+ import torch.nn.functional as functional
6
+ import torch.nn as nn
7
+ from torch.autograd import Variable
8
+
9
+ from modules.inference.decode_strategy import DecodeStrategy
10
+ from utils.misc import no_peeking_mask
11
+
12
+ class GreedySearch(DecodeStrategy):
13
+ def __init__(self, model, max_len, device, replace_unk=None):
14
+ """
15
+ :param model: the trained model used for decoding
16
+ :param max_len: the maximum number of decoding timesteps
17
+ :param device: the device to perform calculation on
18
+ """
19
+ super(GreedySearch, self).__init__(model, max_len, device)
20
+ # self.replace_unk = replace_unk
21
+ # raise NotImplementedError("Replace unk was yeeted from base class DecodeStrategy. Fix first.")
22
+
23
+ def initilize_value(self, sentences):
24
+ """
25
+ Calculate the required matrices during translation after the model is finished
26
+ Input:
27
+ :param src: Sentences
28
+
29
+ Output: Initialize the first character includes outputs, e_outputs, log_scores
30
+ """
31
+ batch_size=len(sentences)
32
+ init_tok = self.TRG.vocab.stoi['<sos>']
33
+ src_mask = (sentences != self.SRC.vocab.stoi['<pad>']).unsqueeze(-2)
34
+ eos_tok = self.TRG.vocab.stoi['<eos>']
35
+
36
+ # Encoder
37
+ e_output = self.model.encoder(sentences, src_mask)
38
+
39
+ out = torch.LongTensor([[init_tok] for i in range(batch_size)])
40
+ outputs = torch.zeros(batch_size, self.max_len).long()
41
+ outputs[:, :1] = out
42
+
43
+ outputs = outputs.to(self.device)
44
+ is_finished = torch.LongTensor([[eos_tok] for i in range(batch_size)]).view(-1).to(self.device)
45
+ return eos_tok, src_mask, is_finished, e_output, outputs
46
+
47
+ def create_trg_mask(self, i, device):
48
+ return no_peeking_mask(i, device)
49
+
50
+ def current_predict(self, outputs, e_output, src_mask, trg_mask):
51
+ model = self.model
52
+ # out, attn = model.out(model.decoder(outputs, e_output, src_mask, trg_mask))
53
+ decoder_output, attn = model.decoder(outputs, e_output, src_mask, trg_mask, output_attention=True)
54
+ # total_time_decode += time.time()-decode_time
55
+ out = model.out(decoder_output)
56
+
57
+ out = functional.softmax(out, dim=-1)
58
+ return out, attn
59
+
60
+ def greedy_search(self, sentences, sampling_temp=0.0, keep_topk=1):
61
+ batch_size=len(sentences)
62
+
63
+ eos_tok, src_mask, is_finished, e_output, outputs = self.initilize_value(sentences)
64
+
65
+ for i in range(1, self.max_len):
66
+ out, attn = self.current_predict(outputs[:, :i], e_output, src_mask, self.create_trg_mask(i, self.device))
67
+ topk_ix, topk_prob = self.sample_with_temperature(out[:, -1], sampling_temp, keep_topk)
68
+ outputs[:, i] = topk_ix.view(-1)
69
+ if torch.equal(outputs[:, i], is_finished):
70
+ break
71
+
72
+ # if self.replace_unk == True:
73
+ # outputs = self.replace_unknown(attn, sentences, outputs)
74
+
75
+ # print("\n".join([' '.join([self.TRG.vocab.itos[tok] for tok in line[1:]]) for line in outputs]))
76
+ # Write to file or Print to the console
77
+ translated_sentences = []
78
+ # Get the best sentences: idx = 0 + i*k
79
+ for i in range(0, len(outputs)):
80
+ is_eos = torch.nonzero(outputs[i]==eos_tok)
81
+ if len(is_eos) == 0:
82
+ # if there is no sequence end, remove
83
+ sent = outputs[i, 1:]
84
+ else:
85
+ length = is_eos[0]
86
+ sent = outputs[i, 1:length]
87
+ translated_sentences.append([self.TRG.vocab.itos[tok] for tok in sent])
88
+
89
+ return translated_sentences
90
+
91
+ def sample_with_temperature(self, logits, sampling_temp, keep_topk):
92
+ if sampling_temp == 0.0 or keep_topk == 1:
93
+ # For temp=0.0, take the argmax to avoid divide-by-zero errors.
94
+ # keep_topk=1 is also equivalent to argmax.
95
+ topk_scores, topk_ids = logits.topk(1, dim=-1)
96
+ if sampling_temp > 0:
97
+ topk_scores /= sampling_temp
98
+ else:
99
+ logits = torch.div(logits, sampling_temp)
100
+
101
+ if keep_topk > 0:
102
+ top_values, top_indices = torch.topk(logits, keep_topk, dim=1)
103
+ kth_best = top_values[:, -1].view([-1, 1])
104
+ kth_best = kth_best.repeat([1, logits.shape[1]]).float()
105
+
106
+ # Set all logits that are not in the top-k to -10000.
107
+ # This puts the probabilities close to 0.
108
+ ignore = torch.lt(logits, kth_best)
109
+ logits = logits.masked_fill(ignore, -10000)
110
+
111
+ dist = torch.distributions.Multinomial(
112
+ logits=logits, total_count=1)
113
+ topk_ids = torch.argmax(dist.sample(), dim=1, keepdim=True)
114
+ topk_scores = logits.gather(dim=1, index=topk_ids)
115
+ return topk_ids, topk_scores
116
+
117
+ def translate_batch(self, sentences, src_size_limit, output_tokens=True, debug=False):
118
+ # super(BeamSearch, self).__init__()
119
+ sentences = self.preprocess_batch(sentences).to(self.device)
120
+ return self.greedy_search(sentences, 0.2, 2)
121
+ # print(self.initilize_value(sentences))
modules/inference/prototypes.py ADDED
@@ -0,0 +1,144 @@
1
+ import torch, time
2
+ import torch.nn.functional as functional
3
+ from torch.autograd import Variable
4
+ from torch.nn.utils.rnn import pad_sequence
5
+
6
+ from modules.inference.beam_search import BeamSearch
7
+ from utils.data import generate_language_token
8
+ import modules.constants as const
9
+
10
+ def generate_subsequent_mask(sz, device):
11
+ return torch.triu(
12
+ torch.ones(sz, sz, dtype=torch.int, device=device)
13
+ ).transpose(0, 1).unsqueeze(0)
14
+
15
+ class BeamSearch2(BeamSearch):
16
+ """
17
+ Same as the BeamSearch class.
18
+ Difference: sentences whose beams have all terminated (reached the <eos> token) are removed from the time-step loop.
19
+ Update to reuse functions already coded in normal BeamSearch. Note that replacing unknown words & n_best is not available.
20
+ """
21
+ def _convert_to_sent(self, sent_id, eos_token_id):
22
+ eos = torch.nonzero(sent_id == eos_token_id).view(-1)
23
+ t = eos[0] if len(eos) > 0 else len(sent_id)
24
+ return [self.TRG.vocab.itos[j] for j in sent_id[1 : t]]
25
+
26
+ @torch.no_grad()
27
+ def beam_search(self, src, src_lang=None, trg_lang=None, src_tokens=None, n_best=1, debug=False):
28
+ """
29
+ Beam search select k words with the highest conditional probability
30
+ to be the first word of the k candidate output sequences.
31
+ Args:
32
+ src: The batch of sentences, already in [batch_size, tokens] of int
33
+ src_tokens: src in str version, same size as above
34
+ n_best: number of usable values per beam loaded (Not implemented)
35
+ debug: if true, print some debug information during the search
36
+ Return:
37
+ An array of translated sentences, in list-of-tokens format. TODO convert [batch_size, n_best, tgt_len] instead of [batch_size, tgt_len]
38
+ """
39
+ # Create some local variable
40
+ src_field, trg_field = self.SRC, self.TRG
41
+ sos_token = generate_language_token(trg_lang) if trg_lang is not None else const.DEFAULT_SOS
42
+ init_token = trg_field.vocab.stoi[sos_token]
43
+ eos_token_id = trg_field.vocab.stoi[const.DEFAULT_EOS]
44
+ src = src.to(self.device)
45
+
46
+ batch_size = src.size(0)
47
+ model = self.model
48
+ k = self.beam_size
49
+ max_len = self.max_len
50
+ device = self.device
51
+
52
+ # Initialize target sequence (start with '<sos>' token) [batch_size x k x max_len]
53
+ trg = torch.zeros(batch_size, k, max_len, device=device).long()
54
+ trg[:, :, 0] = init_token
55
+
56
+ # Precalc output from model's encoder
57
+ single_src_mask = (src != src_field.vocab.stoi['<pad>']).unsqueeze(1).to(device)
58
+ e_out = model.encoder(src, single_src_mask) # [batch_size x S x d_model]
59
+
60
+ # Output model prob
61
+ trg_mask = generate_subsequent_mask(1, device=device)
62
+ # [batch_size x 1]
63
+ inp_decoder = trg[:, 0, 0].view(batch_size, 1)
64
+ # [batch_size x 1 x vocab_size]
65
+ prob = model.out(model.decoder(inp_decoder, e_out, single_src_mask, trg_mask))
66
+ prob = functional.softmax(prob, dim=-1)
67
+
68
+ # [batch_size x 1 x k]
69
+ k_prob, k_index = torch.topk(prob, k, dim=-1)
70
+ trg[:, :, 1] = k_index.view(batch_size, k)
71
+ # Init log scores from k beams [batch_size x k x 1]
72
+ log_scores = torch.log(k_prob.view(batch_size, k, 1))
73
+
74
+ # Repeat encoder's output k times for searching [(k * batch_size) x S x d_model]
75
+ e_outs = torch.repeat_interleave(e_out, k, dim=0)
76
+ src_mask = torch.repeat_interleave(single_src_mask, k, dim=0)
77
+
78
+ # Create mask for checking eos
79
+ sent_eos = torch.tensor([eos_token_id for _ in range(k)], device=device).view(1, k)
80
+
81
+ # The batch indexes
82
+ batch_index = torch.arange(batch_size)
83
+ finished_batches = torch.zeros(batch_size, device=device).long()
84
+
85
+ # Iteratively searching
86
+ for i in range(2, max_len):
87
+ trg_mask = generate_subsequent_mask(i, device)
88
+
89
+ # Flatten trg tensor for feeding into model [(k * batch_size) x i]
90
+ inp_decoder = trg[batch_index, :, :i].view(k * len(batch_index), i)
91
+ # Output model prob [(k * batch_size) x i x vocab_size]
92
+ prob = model.out(model.decoder(inp_decoder, e_outs, src_mask, trg_mask))
93
+ prob = functional.softmax(prob, dim=-1)
94
+
95
+ # Only care the last prob i-th
96
+ # [(k * batch_size) x 1 x vocab_size]
97
+ prob = prob[:, i-1, :].view(k * len(batch_index), 1, -1)
98
+
99
+ # Truncate prob to top k [(k * batch_size) x 1 x k]
100
+ k_prob, k_index = prob.data.topk(k, dim=-1)
101
+
102
+ # Deflatten k_prob & k_index
103
+ k_prob = k_prob.view(len(batch_index), k, 1, k)
104
+ k_index = k_index.view(len(batch_index), k, 1, k)
105
+
106
+ # Preserve eos beams
107
+ # [batch_size x k] -> view -> [batch_size x k x 1 x 1] (broadcastable)
108
+ eos_mask = (trg[batch_index, :, i-1] == eos_token_id).view(len(batch_index), k, 1, 1)
109
+ k_prob.masked_fill_(eos_mask, 1.0)
110
+ k_index.masked_fill_(eos_mask, eos_token_id)
111
+
112
+ # Find the best k cases
113
+ # Compute log score at i-th timestep
114
+ # [batch_size x k x 1 x 1] + [batch_size x k x 1 x k] = [batch_size x k x 1 x k]
115
+ combine_probs = log_scores[batch_index].unsqueeze(-1) + torch.log(k_prob)
116
+ # [batch_size x k x 1]
117
+ log_scores[batch_index], positions = torch.topk(combine_probs.view(len(batch_index), k * k, 1), k, dim=1)
118
+
119
+ # The rows selected from top k
120
+ rows = positions.view(len(batch_index), k) // k
121
+ # The indexes in vocab respected to these rows
122
+ cols = positions.view(len(batch_index), k) % k
123
+
124
+ batch_sim = torch.arange(len(batch_index)).view(-1, 1)
125
+ trg[batch_index, :, :] = trg[batch_index.view(-1, 1), rows, :]
126
+ trg[batch_index, :, i] = k_index[batch_sim, rows, :, cols].view(len(batch_index), k)
127
+
128
+ # Update which sentences finished all its beams
129
+ mask = (trg[:, :, i] == sent_eos).all(1).view(-1).to(device)
130
+ finished_batches.masked_fill_(mask, value=1)
131
+ cnt = torch.sum(finished_batches).item()
132
+ if cnt == batch_size:
133
+ break
134
+
135
+ # Continue with remaining batches (if any)
136
+ batch_index = torch.nonzero(finished_batches == 0).view(-1)
137
+ e_outs = torch.repeat_interleave(e_out[batch_index], k, dim=0)
138
+ src_mask = torch.repeat_interleave(single_src_mask[batch_index], k, dim=0)
139
+ # End loop
140
+
141
+ # Get the best beam
142
+ log_scores = log_scores.view(batch_size, k)
143
+ results = [self._convert_to_sent(trg[t, j.item(), :], eos_token_id) for t, j in enumerate(torch.argmax(log_scores, dim=-1))]
144
+ return results
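`generate_subsequent_mask` builds the usual lower-triangular "no peeking" mask so the decoder cannot attend to future positions. A quick check of its output for a length of 4 (illustrative only, reusing the function as defined above):

```python
import torch

def generate_subsequent_mask(sz, device):
    return torch.triu(
        torch.ones(sz, sz, dtype=torch.int, device=device)
    ).transpose(0, 1).unsqueeze(0)

print(generate_subsequent_mask(4, "cpu")[0])
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)
```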
modules/inference/sampling_temperature.py ADDED
@@ -0,0 +1,119 @@
1
+ ##@title Our beam search
2
+ import numpy as np
3
+ import torch
4
+ import math
5
+ import torch.nn.functional as functional
6
+ import torch.nn as nn
7
+ from torch.autograd import Variable
8
+
9
+ from modules.inference.decode_strategy import DecodeStrategy
10
+ from utils.misc import no_peeking_mask
11
+
12
+ class GreedySearch(DecodeStrategy):
13
+ def __init__(self, model, max_len, device, batch_size=None, replace_unk=True):
14
+ """
15
+ :param model/max_len/device: see DecodeStrategy
16
+ :param batch_size: number of sentences decoded per batch
17
+ :param replace_unk: whether to replace <unk> tokens using attention
18
+ """
19
+ super(GreedySearch, self).__init__(model, max_len, device)
20
+ self.batch_size = batch_size
21
+ self.replace_unk = replace_unk
22
+ raise NotImplementedError("Replace unk was yeeted from base class DecodeStrategy. Fix first.")
23
+
24
+ def initilize_value(self, sentences):
25
+ """
26
+ Calculate the required matrices during translation after the model is finished
27
+ Input:
28
+ :param src: Sentences
29
+
30
+ Output: Initialize the first character includes outputs, e_outputs, log_scores
31
+ """
32
+
33
+ init_tok = self.TRG.vocab.stoi['<sos>']
34
+ src_mask = (sentences != self.SRC.vocab.stoi['<pad>']).unsqueeze(-2)
35
+ eos_tok = self.TRG.vocab.stoi['<eos>']
36
+
37
+ # Encoder
38
+ e_output = self.model.encoder(sentences, src_mask)
39
+
40
+ out = torch.LongTensor([[init_tok] for i in range(self.batch_size)])
41
+ outputs = torch.zeros(self.batch_size, self.max_len).long()
42
+ outputs[:, :1] = out
43
+
44
+ outputs = outputs.to(self.device)
45
+ is_finished = torch.LongTensor([[eos_tok] for i in range(self.batch_size)]).view(-1).to(self.device)
46
+ return eos_tok, src_mask, is_finished, e_output, outputs
47
+
48
+ def create_trg_mask(self, i, device):
49
+ return no_peeking_mask(i, device)
50
+
51
+ def current_predict(self, outputs, e_output, src_mask, trg_mask):
52
+ decoder_output, attn = self.model.decoder(outputs, e_output, src_mask, trg_mask, output_attention=True)
53
+ out = self.model.out(decoder_output)
54
+ out = functional.softmax(out, dim=-1)
55
+ return out, attn
56
+
57
+ def greedy_search(self, sentences, sampling_temp=0.0, keep_topk=1):
58
+ if len(sentences) < self.batch_size:
59
+ self.batch_size = len(sentences)
60
+
61
+ eos_tok, src_mask, is_finished, e_output, outputs = self.initilize_value(sentences)
62
+
63
+ for i in range(1, self.max_len):
64
+ out, attn = self.current_predict(outputs[:, :i], e_output, src_mask, self.create_trg_mask(i, self.device))
65
+ topk_ix, topk_prob = self.sample_with_temperature(out[:, -1], sampling_temp, keep_topk)
66
+ outputs[:, i] = topk_ix.view(-1)
67
+ if torch.equal(outputs[:, i], is_finished):
68
+ break
69
+
70
+ if self.replace_unk == True:
71
+ outputs = self.replace_unknown(attn, sentences, outputs)
72
+
73
+ # print("\n".join([' '.join([self.TRG.vocab.itos[tok] for tok in line[1:]]) for line in outputs]))
74
+ # Write to file or Print to the console
75
+ translated_sentences = []
76
+ # Get the best sentences: idx = 0 + i*k
77
+ for i in range(0, len(outputs)):
78
+ is_eos = torch.nonzero(outputs[i]==eos_tok)
79
+ if len(is_eos) == 0:
80
+ # if there is no sequence end, remove
81
+ sent = outputs[i, 1:]
82
+ else:
83
+ length = is_eos[0]
84
+ sent = outputs[i, 1:length]
85
+ translated_sentences.append([self.TRG.vocab.itos[tok] for tok in sent])
86
+
87
+ return translated_sentences
88
+
89
+ def sample_with_temperature(self, logits, sampling_temp, keep_topk):
90
+ if sampling_temp == 0.0 or keep_topk == 1:
91
+ # For temp=0.0, take the argmax to avoid divide-by-zero errors.
92
+ # keep_topk=1 is also equivalent to argmax.
93
+ topk_scores, topk_ids = logits.topk(1, dim=-1)
94
+ if sampling_temp > 0:
95
+ topk_scores /= sampling_temp
96
+ else:
97
+ logits = torch.div(logits, sampling_temp)
98
+
99
+ if keep_topk > 0:
100
+ top_values, top_indices = torch.topk(logits, keep_topk, dim=1)
101
+ kth_best = top_values[:, -1].view([-1, 1])
102
+ kth_best = kth_best.repeat([1, logits.shape[1]]).float()
103
+
104
+ # Set all logits that are not in the top-k to -10000.
105
+ # This puts the probabilities close to 0.
106
+ ignore = torch.lt(logits, kth_best)
107
+ logits = logits.masked_fill(ignore, -10000)
108
+
109
+ dist = torch.distributions.Multinomial(
110
+ logits=logits, total_count=1)
111
+ topk_ids = torch.argmax(dist.sample(), dim=1, keepdim=True)
112
+ topk_scores = logits.gather(dim=1, index=topk_ids)
113
+ return topk_ids, topk_scores
114
+
115
+ def translate_batch(self, sentences):
116
+ # super(BeamSearch, self).__init__()
117
+ sentences = self.preprocess_batch(sentences).to(self.device)
118
+ return self.greedy_search(sentences, 0.2, 2)
119
+ # print(self.initilize_value(sentences))
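`sample_with_temperature` combines temperature scaling with top-k filtering before sampling. A self-contained sketch of the same idea on dummy logits (values are made up; the temperature must be > 0 in this simplified version):

```python
import torch

def sample_with_temperature(logits, sampling_temp=0.7, keep_topk=2):
    logits = logits / sampling_temp                          # sharpen/flatten the distribution
    top_values, _ = torch.topk(logits, keep_topk, dim=1)
    kth_best = top_values[:, -1].view(-1, 1)
    logits = logits.masked_fill(logits < kth_best, -10000)   # mask everything outside the top-k
    dist = torch.distributions.Multinomial(logits=logits, total_count=1)
    return torch.argmax(dist.sample(), dim=1, keepdim=True)  # sampled token index

dummy_logits = torch.tensor([[2.0, 1.5, 0.1, -1.0]])
print(sample_with_temperature(dummy_logits))  # index 0 or 1, chosen stochastically
```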
modules/loader/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .default_loader import DefaultLoader
2
+ from .multilingual_loader import MultiLoader
3
+
4
+ loaders = {"monoloader": DefaultLoader, "multiloader": MultiLoader}
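Like the inference strategies, the loader class is presumably resolved by name from this registry. A hedged sketch (the data path, file extensions and empty option dict are placeholders):

```python
from modules.loader import loaders

loader_cls = loaders["monoloader"]   # -> DefaultLoader
loader = loader_cls("data/train", language_tuple=(".lo", ".vi"), option={})
src_field, trg_field = loader.build_field()   # torchtext Fields for source/target
```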
modules/loader/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (313 Bytes). View file
 
modules/loader/__pycache__/default_loader.cpython-36.pyc ADDED
Binary file (6.24 kB). View file
 
modules/loader/__pycache__/multilingual_loader.cpython-36.pyc ADDED
Binary file (5.97 kB). View file
 
modules/loader/default_loader.py ADDED
@@ -0,0 +1,114 @@
1
+ import io, os
2
+ import dill as pickle
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+ from torchtext.data import BucketIterator, Dataset, Example, Field
6
+ from torchtext.datasets import TranslationDataset, Multi30k, IWSLT, WMT14
7
+ from collections import Counter
8
+
9
+ import modules.constants as const
10
+ from utils.save import load_vocab_from_path
11
+ import laonlp
12
+
13
+ class DefaultLoader:
14
+ def __init__(self, train_path_or_name, language_tuple=None, valid_path=None, eval_path=None, option=None):
15
+ """Load training/eval data file pairing, process and create data iterator for training """
16
+ self._language_tuple = language_tuple
17
+ self._train_path = train_path_or_name
18
+ self._eval_path = eval_path
19
+ self._option = option
20
+
21
+ @property
22
+ def language_tuple(self):
23
+ """DefaultLoader uses the default language option of bleu_batch_iter (plain <sos>), hence returns (None, None)"""
24
+ return None, None
25
+
26
+ def tokenize(self, sentence):
27
+ return sentence.strip().split()
28
+
29
+ def detokenize(self, list_of_tokens):
30
+ """Differentiate between [batch, len] and [len]; joining tokens back to strings"""
31
+ if( len(list_of_tokens) == 0 or isinstance(list_of_tokens[0], str)):
32
+ # [len], single sentence version
33
+ return " ".join(list_of_tokens)
34
+ else:
35
+ # [batch, len], batch sentence version
36
+ return [" ".join(tokens) for tokens in list_of_tokens]
37
+
38
+ def _train_path_is_name(self):
39
+ return os.path.isfile(self._train_path + self._language_tuple[0]) and os.path.isfile(self._train_path + self._language_tuple[1])
40
+
41
+ def create_length_constraint(self, token_limit):
42
+ """Return a filter predicate that keeps only examples whose src/trg lengths are within the token limit"""
43
+ return lambda x: len(x.src) <= token_limit and len(x.trg) <= token_limit
44
+
45
+ def build_field(self, **kwargs):
46
+ """Build fields that will handle the conversion from token->idx and vice versa. To be overridden by MultiLoader."""
47
+ return Field(lower=False, tokenize=laonlp.tokenize.word_tokenize), Field(lower=False, tokenize=self.tokenize, init_token=const.DEFAULT_SOS, eos_token=const.DEFAULT_EOS, is_target=True)
48
+
49
+ def build_vocab(self, fields, model_path=None, data=None, **kwargs):
50
+ """Build the vocabulary object for torchtext Field. There are three flows:
51
+ - if the model path is present, it will first try to load the pickled/dilled vocab object from path. This is accessed on continued training & standalone inference
52
+ - if that failed and data is available, try to build the vocab from that data. This is accessed on first time training
53
+ - if data is not available, search for set of two vocab files and read them into the fields. This is accessed on first time training
54
+ TODO: expand on the vocab file option (loading pretrained vectors as well)
55
+ """
56
+ src_field, trg_field = fields
57
+ if(model_path is None or not load_vocab_from_path(model_path, self._language_tuple, fields)):
58
+ # the condition will try to load vocab pickled to model path.
59
+ if(data is not None):
60
+ print("Building vocab from received data.")
61
+ # build the vocab using formatted data.
62
+ src_field.build_vocab(data, **kwargs)
63
+ trg_field.build_vocab(data, **kwargs)
64
+ else:
65
+ print("Building vocab from preloaded text file.")
66
+ # load the vocab values from external location (a formatted text file). Initialize values as random
67
+ external_vocab_location = self._option.get("external_vocab_location", None)
68
+ src_ext, trg_ext = self._language_tuple
69
+ # read the files and create a mock Counter object; which then is passed to vocab's class
70
+ # see Field.build_vocab for the options used with vocab_cls
71
+ vocab_src = external_vocab_location + src_ext
72
+ with io.open(vocab_src, "r", encoding="utf-8") as svf:
73
+ mock_counter = Counter({w.strip():1 for w in svf.readlines()})
74
+ special_tokens = [src_field.unk_token, src_field.pad_token, src_field.init_token, src_field.eos_token]
75
+ src_field.vocab = src_field.vocab_cls(mock_counter, specials=special_tokens, min_freq=5, **kwargs)
76
+ vocab_trg = external_vocab_location + trg_ext
77
+ with io.open(vocab_trg, "r", encoding="utf-8") as tvf:
78
+ mock_counter = Counter({w.strip():1 for w in tvf.readlines()})
79
+ special_tokens = [trg_field.unk_token, trg_field.pad_token, trg_field.init_token, trg_field.eos_token]
80
+ trg_field.vocab = trg_field.vocab_cls(mock_counter, specials=special_tokens, min_freq=5, **kwargs)
81
+ else:
82
+ print("Load vocab from path successful.")
83
+
84
+ def create_iterator(self, fields, model_path=None):
85
+ """Create the iterator needed to load batches of data and bind them to existing fields
86
+ NOTE: unlike the previous loader, this one takes a list of tokens instead of a string, which necessitates redefining the translate_sentence pipeline"""
87
+ if(not self._train_path_is_name()):
88
+ # load the default torchtext dataset by name
89
+ # TODO load additional arguments in the config
90
+ dataset_cls = next( (s for s in [Multi30k, IWSLT, WMT14] if s.__name__ == self._train_path), None )
91
+ if(dataset_cls is None):
92
+ raise ValueError("The specified train path {:s}(+{:s}/{:s}) neither points to a valid file path nor names a torchtext dataset class.".format(self._train_path, *self._language_tuple))
93
+ src_suffix, trg_suffix = ext = self._language_tuple
94
+ # print(ext, fields)
95
+ self._train_data, self._valid_data, self._eval_data = dataset_cls.splits(exts=ext, fields=fields) #, split_ratio=self._option.get("train_test_split", const.DEFAULT_TRAIN_TEST_SPLIT)
96
+ else:
97
+ # create dataset from path. Also add all necessary constraints (e.g lengths trimming/excluding)
98
+ src_suffix, trg_suffix = ext = self._language_tuple
99
+ filter_fn = self.create_length_constraint(self._option.get("train_max_length", const.DEFAULT_TRAIN_MAX_LENGTH))
100
+ self._train_data = TranslationDataset(self._train_path, ext, fields, filter_pred=filter_fn)
101
+ self._valid_data = self._eval_data = TranslationDataset(self._eval_path, ext, fields)
102
+ # first_sample = self._train_data[0]; raise Exception("{} {}".format(first_sample.src, first_sample.trg))
103
+ # whatever created, we now have the two set of data ready. add the necessary constraints/filtering/etc.
104
+ train_data = self._train_data
105
+ eval_data = self._eval_data
106
+ # now we can execute build_vocab. This function will try to load vocab from model_path, and if fail, build the vocab from train_data
107
+ build_vocab_kwargs = self._option.get("build_vocab_kwargs", {})
108
+ self.build_vocab(fields, data=train_data, model_path=model_path, **build_vocab_kwargs)
109
+ # raise Exception("{}".format(len(src_field.vocab)))
110
+ # crafting iterators
111
+ train_iter = BucketIterator(train_data, batch_size=self._option.get("batch_size", const.DEFAULT_BATCH_SIZE), device=self._option.get("device", const.DEFAULT_DEVICE) )
112
+ eval_iter = BucketIterator(eval_data, batch_size=self._option.get("eval_batch_size", const.DEFAULT_EVAL_BATCH_SIZE), device=self._option.get("device", const.DEFAULT_DEVICE), train=False )
113
+ return train_iter, eval_iter
114
+
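
The external-vocab branch of `build_vocab` above expects one token per line and wraps the tokens in a mock `Counter`. A sketch of that idea, assuming torchtext 0.6's `Vocab` API and using an in-memory token list in place of a real vocab file:

```python
# Sketch of the mock-Counter vocab construction (token list stands in for vocab-file lines).
from collections import Counter
from torchtext.data import Field

field = Field(lower=False)
tokens = ["xin", "chao", "the", "gioi"]                # stand-in for open(vocab_path).readlines()
mock_counter = Counter({tok.strip(): 1 for tok in tokens})
specials = [t for t in (field.unk_token, field.pad_token, field.init_token, field.eos_token) if t is not None]
field.vocab = field.vocab_cls(mock_counter, specials=specials)
print(len(field.vocab))                                # 4 tokens plus the special symbols
```

Note that the loader itself also passes `min_freq=5` to `vocab_cls`; with mock counts of 1 that threshold would filter out every non-special token, so it presumably needs to match how the vocab files were produced.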
modules/loader/multilingual_loader.py ADDED
@@ -0,0 +1,139 @@
1
+ import io, os
2
+ import dill as pickle
3
+ from collections import Counter
4
+ import torch
5
+ from torchtext.data import BucketIterator, Dataset, Example, Field, interleave_keys
6
+ import modules.constants as const
7
+ from utils.save import load_vocab_from_path
8
+ from utils.data import generate_language_token
9
+ from modules.loader.default_loader import DefaultLoader
10
+ import laonlp
11
+
12
+ class MultiDataset(Dataset):
13
+ """
14
+ Ensemble one or more corpora from different languages.
15
+ The corpora share a global source vocab and target vocab.
16
+
17
+ Constructor Args:
18
+ data_info: list of datasets info <See `train` argument in MultiLoader class>
19
+ fields: A tuple containing src field and trg field.
20
+ """
21
+ @staticmethod
22
+ def sort_key(ex):
23
+ return interleave_keys(len(ex.src), len(ex.trg))
24
+
25
+ def __init__(self, data_info, fields, **kwargs):
26
+ self.languages = set()
27
+
28
+ if not isinstance(fields[0], (tuple, list)):
29
+ fields = [('src', fields[0]), ('trg', fields[1])]
30
+
31
+ examples = []
32
+ for corpus, info in data_info:
33
+ print("Loading corpus {} ...".format(corpus))
34
+
35
+ src_lang = info["src_lang"]
36
+ trg_lang = info["trg_lang"]
37
+ src_path = os.path.expanduser('.'.join([info["path"], src_lang]))
38
+ trg_path = os.path.expanduser('.'.join([info["path"], trg_lang]))
39
+ self.languages.add(src_lang)
40
+ self.languages.add(trg_lang)
41
+
42
+ with io.open(src_path, mode='r', encoding='utf-8') as src_file, \
43
+ io.open(trg_path, mode='r', encoding='utf-8') as trg_file:
44
+ for src_line, trg_line in zip(src_file, trg_file):
45
+ src_line, trg_line = src_line.strip(), trg_line.strip()
46
+ if src_line != '' and trg_line != '':
47
+ # Append language-specific prefix token
48
+ src_line = ' '.join([generate_language_token(src_lang), src_line])
49
+ trg_line = ' '.join([generate_language_token(trg_lang), trg_line])
50
+ examples.append(Example.fromlist([src_line, trg_line], fields))
51
+ print("Done!")
52
+
53
+ super(MultiDataset, self).__init__(examples, fields, **kwargs)
54
+
55
+
56
+ class MultiLoader(DefaultLoader):
57
+ def __init__(self, train, valid=None, option=None):
58
+ """
59
+ Load multiple training/eval parallel data files, process and create data iterator
60
+ Constructor Args:
61
+ train: a dictionary contains training data information
62
+ valid (optional): a dictionary contains validation data information
63
+ option (optional): a dictionary contains configurable parameters
64
+
65
+ For example:
66
+ train = {
67
+ "corpus_1": {
68
+ "path": path/to/training/data,
69
+ "src_lang": src,
70
+ "trg_lang": trg
71
+ },
72
+ "corpus_2": {
73
+ ...
74
+ }
75
+ }
76
+ """
77
+ self._train_info = train
78
+ self._valid_info = valid
79
+ self._language_tuple = ('.src', '.trg')
80
+ self._option = option
81
+
82
+
83
+ def tokenize(self, sentence):
84
+ return sentence.strip().split()
85
+
86
+ @property
87
+ def language_tuple(self):
88
+ """Output the valid data's language tuple for bleu_valid_iter, which uses <{trg_lang}> during inference. Since <{src_lang}> has already been prepended to the valid data, return None for the source side."""
89
+ return None, self._valid_info["trg_lang"]
90
+
91
+ def _is_path(self, path, lang):
92
+ """Check whether the path is a system path or a corpus name"""
93
+ return os.path.isfile(path + '.' + lang)
94
+
95
+ def build_field(self, **kwargs):
96
+ return Field(lower=False, tokenize=laonlp.tokenize.word_tokenize), Field(lower=False, tokenize=self.tokenize, init_token=const.DEFAULT_SOS, eos_token=const.DEFAULT_EOS)
97
+
98
+ def build_vocab(self, fields, model_path=None, data=None, **kwargs):
99
+ """Build the vocabulary object for torchtext Field. There are three flows:
100
+ - if the model path is present, it will first try to load the pickled/dilled vocab object from path. This is accessed on continued training & standalone inference
101
+ - if that failed and data is available, try to build the vocab from that data. This is accessed on first time training
102
+ - if data is not available, search for set of two vocab files and read them into the fields. This is accessed on first time training
103
+ TODO: expand on the vocab file option (loading pretrained vectors as well)
104
+ """
105
+ src_field, trg_field = fields
106
+ if model_path is None or not load_vocab_from_path(model_path, self._language_tuple, fields):
107
+ # the condition will try to load vocab pickled to model path.
108
+ if data is not None:
109
+ print("Building vocab from received data.")
110
+ # build the vocab using formatted data.
111
+ src_field.build_vocab(data, **kwargs)
112
+ trg_field.build_vocab(data, **kwargs)
113
+ else:
114
+ # Not implemented mixing preloaded datasets and external datasets
115
+ raise ValueError("MultiLoader currently does not support preloaded text vocab")
116
+ else:
117
+ print("Load vocab from path successful.")
118
+
119
+ def create_iterator(self, fields, model_path=None):
120
+ """Create the iterator needed to load batches of data and bind them to existing fields"""
121
+ # create dataset from path. Also add all necessary constraints (e.g lengths trimming/excluding)
122
+ filter_fn = self.create_length_constraint(self._option.get("train_max_length", const.DEFAULT_TRAIN_MAX_LENGTH))
123
+ self._train_data = MultiDataset(data_info=self._train_info.items(), fields=fields, filter_pred=filter_fn)
124
+
125
+ # now we can execute build_vocab. This function will try to load vocab from model_path, and if fail, build the vocab from train_data
126
+ build_vocab_kwargs = self._option.get("build_vocab_kwargs", {})
127
+ build_vocab_kwargs["specials"] = build_vocab_kwargs.pop("specials", []) + list(self._train_data.languages)
128
+ self.build_vocab(fields, data=self._train_data, model_path=model_path, **build_vocab_kwargs)
129
+
130
+ # Create train iterator
131
+ train_iter = BucketIterator(self._train_data, batch_size=self._option.get("batch_size", const.DEFAULT_BATCH_SIZE), device=self._option.get("device", const.DEFAULT_DEVICE))
132
+
133
+ if self._valid_info is not None:
134
+ self._valid_data = MultiDataset(data_info=[("valid", self._valid_info)], fields=fields)
135
+ valid_iter = BucketIterator(self._valid_data, batch_size=self._option.get("eval_batch_size", const.DEFAULT_EVAL_BATCH_SIZE), device=self._option.get("device", const.DEFAULT_DEVICE), train=False)
136
+ else:
137
+ valid_iter = None
138
+
139
+ return train_iter, valid_iter
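
The language-specific prefix tokens added in `MultiDataset` come from `utils.data.generate_language_token`, which wraps the language code in angle brackets. A short sketch with made-up sentences:

```python
# Sketch: prefixing parallel lines with language tokens, as MultiDataset does above.
from utils.data import generate_language_token

src_lang, trg_lang = "lo", "vi"
src_line, trg_line = "sabaidee", "xin chao"
src_line = " ".join([generate_language_token(src_lang), src_line])   # "<lo> sabaidee"
trg_line = " ".join([generate_language_token(trg_lang), trg_line])   # "<vi> xin chao"
print(src_line, "|", trg_line)
```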
modules/optim/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from modules.optim.adam import AdamOptim
2
+ from modules.optim.adabelief import AdaBeliefOptim
3
+ from modules.optim.scheduler import ScheduledOptim
4
+
5
+ optimizers = {"Adam": AdamOptim, "AdaBelief": AdaBeliefOptim}
modules/optim/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (376 Bytes).
 
modules/optim/__pycache__/adabelief.cpython-36.pyc ADDED
Binary file (1.48 kB).
 
modules/optim/__pycache__/adam.cpython-36.pyc ADDED
Binary file (1.43 kB).
 
modules/optim/__pycache__/scheduler.cpython-36.pyc ADDED
Binary file (2.09 kB).
 
modules/optim/adabelief.py ADDED
@@ -0,0 +1,42 @@
1
+ import torch, math
2
+
3
+ class AdaBeliefOptim(torch.optim.Optimizer):
4
+
5
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-9, **kwargs):
6
+ defaults = dict(lr=lr, betas=betas, eps=eps)
7
+ super().__init__(params, defaults)
8
+
9
+ @torch.no_grad()
10
+ def step(self, closure=None):
11
+ for group in self.param_groups:
12
+ for p in group['params']:
13
+ if p.grad is None:
14
+ # No backward
15
+ continue
16
+ grad = p.grad
17
+ state = self.state[p]
18
+
19
+ if len(state) == 0:
20
+ # Initial step
21
+ state['step'] = 0
22
+ state['exp_avg'] = torch.zeros_like(p)
23
+ state['exp_avg_sq'] = torch.zeros_like(p)
24
+
25
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
26
+ beta1, beta2 = group['betas']
27
+
28
+ state['step'] += 1
29
+ bias_correction1 = 1 - beta1 ** state['step']
30
+ bias_correction2 = 1 - beta2 ** state['step']
31
+
32
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
33
+
34
+ # This is the difference from Adam: accumulate (grad - exp_avg)^2 instead of grad^2
35
+ centered_grad = grad.sub(exp_avg)
36
+ exp_avg_sq.mul_(beta2).addcmul_(centered_grad, centered_grad, value=1-beta2)
37
+ # !
38
+
39
+ denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
40
+ step_size = group['lr'] / bias_correction1
41
+
42
+ p.addcdiv_(exp_avg, denom, value=-step_size)
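
The only change relative to the Adam optimizer that follows is the second-moment estimate: Adam accumulates squared gradients, while AdaBelief accumulates the squared deviation of the gradient from its running mean. A toy comparison (not taken from the commit, illustrative tensors only):

```python
# Toy comparison of the two second-moment updates (beta2 = 0.98).
import torch

grad, exp_avg, prev_sq = torch.tensor([0.5]), torch.tensor([0.4]), torch.tensor([0.1])
adam_sq      = 0.98 * prev_sq + 0.02 * grad * grad              # Adam:      running E[g^2]
adabelief_sq = 0.98 * prev_sq + 0.02 * (grad - exp_avg) ** 2    # AdaBelief: running E[(g - m)^2]
print(adam_sq, adabelief_sq)
```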
modules/optim/adam.py ADDED
@@ -0,0 +1,38 @@
1
+ import torch, math
2
+
3
+ class AdamOptim(torch.optim.Optimizer):
4
+
5
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-9, **kwargs):
6
+ defaults = dict(lr=lr, betas=betas, eps=eps)
7
+ super().__init__(params, defaults)
8
+
9
+ @torch.no_grad()
10
+ def step(self, closure=None):
11
+ for group in self.param_groups:
12
+ for p in group['params']:
13
+ if p.grad is None:
14
+ # No backward
15
+ continue
16
+ grad = p.grad
17
+ state = self.state[p]
18
+
19
+ if len(state) == 0:
20
+ # Initial step
21
+ state['step'] = 0
22
+ state['exp_avg'] = torch.zeros_like(p)
23
+ state['exp_avg_sq'] = torch.zeros_like(p)
24
+
25
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
26
+ beta1, beta2 = group['betas']
27
+
28
+ state['step'] += 1
29
+ bias_correction1 = 1 - beta1 ** state['step']
30
+ bias_correction2 = 1 - beta2 ** state['step']
31
+
32
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
33
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
34
+ denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
35
+
36
+ step_size = group['lr'] / bias_correction1
37
+
38
+ p.addcdiv_(exp_avg, denom, value=-step_size)
modules/optim/scheduler.py ADDED
@@ -0,0 +1,56 @@
1
+ import torch
2
+
3
+ class ScheduledOptim():
4
+ '''A simple wrapper class for learning rate scheduling'''
5
+
6
+ def __init__(self, optimizer, init_lr, d_model, n_warmup_steps):
7
+ self._optimizer = optimizer
8
+ self.init_lr = init_lr
9
+ self.d_model = d_model
10
+ self.n_warmup_steps = n_warmup_steps
11
+ self.n_steps = 0
12
+
13
+
14
+ def step_and_update_lr(self):
15
+ "Step with the inner optimizer"
16
+ self._update_learning_rate()
17
+ self._optimizer.step()
18
+
19
+
20
+ def zero_grad(self):
21
+ "Zero out the gradients with the inner optimizer"
22
+ self._optimizer.zero_grad()
23
+
24
+
25
+ def _get_lr_scale(self):
26
+ d_model = self.d_model
27
+ n_steps, n_warmup_steps = self.n_steps, self.n_warmup_steps
28
+ return (d_model ** -0.5) * min(n_steps ** (-0.5), n_steps * n_warmup_steps ** (-1.5))
29
+
30
+ def state_dict(self):
31
+ optimizer_state_dict = {
32
+ 'init_lr':self.init_lr,
33
+ 'd_model':self.d_model,
34
+ 'n_warmup_steps':self.n_warmup_steps,
35
+ 'n_steps':self.n_steps,
36
+ '_optimizer':self._optimizer.state_dict(),
37
+ }
38
+
39
+ return optimizer_state_dict
40
+
41
+ def load_state_dict(self, state_dict):
42
+ self.init_lr = state_dict['init_lr']
43
+ self.d_model = state_dict['d_model']
44
+ self.n_warmup_steps = state_dict['n_warmup_steps']
45
+ self.n_steps = state_dict['n_steps']
46
+
47
+ self._optimizer.load_state_dict(state_dict['_optimizer'])
48
+
49
+ def _update_learning_rate(self):
50
+ ''' Learning rate scheduling per step '''
51
+
52
+ self.n_steps += 1
53
+ lr = self.init_lr * self._get_lr_scale()
54
+
55
+ for param_group in self._optimizer.param_groups:
56
+ param_group['lr'] = lr
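
Putting the pieces together, a minimal sketch (placeholder model and hyper-parameters) of how `ScheduledOptim` drives the warmup schedule computed in `_get_lr_scale`:

```python
# Sketch: wrapping AdamOptim with the warmup scheduler above (placeholder model/values).
import torch
import torch.nn as nn
from modules.optim import AdamOptim, ScheduledOptim

model = nn.Linear(512, 512)                          # stand-in for the Transformer
base_optim = AdamOptim(model.parameters(), lr=1.0)   # lr gets overwritten every step by the scheduler
optim = ScheduledOptim(base_optim, init_lr=2.0, d_model=512, n_warmup_steps=4000)

loss = model(torch.randn(8, 512)).sum()
loss.backward()
optim.step_and_update_lr()   # lr = init_lr * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
optim.zero_grad()
```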
modules/prototypes.py ADDED
@@ -0,0 +1,209 @@
1
+ import torch.nn as nn
2
+ from torchtext import data
3
+ import copy
4
+ import layers as layers
5
+
6
+ class Embedder(nn.Module):
7
+ def __init__(self, vocab_size, d_model):
8
+ super().__init__()
9
+ self.vocab_size = vocab_size
10
+ self.d_model = d_model
11
+
12
+ self.embed = nn.Embedding(vocab_size, d_model)
13
+
14
+ def forward(self, x):
15
+ return self.embed(x)
16
+
17
+ class EncoderLayer(nn.Module):
18
+ def __init__(self, d_model, heads, dropout=0.1):
19
+ """A layer of the encoder. Contains a self-attention block that accepts a padding mask
20
+ Args:
21
+ d_model: the inner dimension size of the layer
22
+ heads: number of heads used in the attention
23
+ dropout: applied dropout value during training
24
+ """
25
+ super().__init__()
26
+ self.norm_1 = layers.Norm(d_model)
27
+ self.norm_2 = layers.Norm(d_model)
28
+ self.attn = layers.MultiHeadAttention(heads, d_model, dropout=dropout)
29
+ self.ff = layers.FeedForward(d_model, dropout=dropout)
30
+ self.dropout_1 = nn.Dropout(dropout)
31
+ self.dropout_2 = nn.Dropout(dropout)
32
+
33
+ def forward(self, x, src_mask):
34
+ """Run the encoding layer
35
+ Args:
36
+ x: the input (either embedding values or previous layer output), should be in shape [batch_size, src_len, d_model]
37
+ src_mask: the padding mask, should be [batch_size, 1, src_len]
38
+ Return:
39
+ an output that has the same shape as the input, [batch_size, src_len, d_model]
40
+ the attention used [batch_size, src_len, src_len]
41
+ """
42
+ x2 = self.norm_1(x)
43
+ # Self attention only
44
+ x_sa, sa = self.attn(x2, x2, x2, src_mask)
45
+ x = x + self.dropout_1(x_sa)
46
+ x2 = self.norm_2(x)
47
+ x = x + self.dropout_2(self.ff(x2))
48
+ return x, sa
49
+
50
+ class DecoderLayer(nn.Module):
51
+ def __init__(self, d_model, heads, dropout=0.1):
52
+ """A layer of the decoder. Contains a self-attention that accepts a no-peeking mask and an encoder-decoder attention that accepts a padding mask
53
+ Args:
54
+ d_model: the inner dimension size of the layer
55
+ heads: number of heads used in the attention
56
+ dropout: applied dropout value during training
57
+ """
58
+ super().__init__()
59
+ self.norm_1 = layers.Norm(d_model)
60
+ self.norm_2 = layers.Norm(d_model)
61
+ self.norm_3 = layers.Norm(d_model)
62
+
63
+ self.dropout_1 = nn.Dropout(dropout)
64
+ self.dropout_2 = nn.Dropout(dropout)
65
+ self.dropout_3 = nn.Dropout(dropout)
66
+
67
+ self.attn_1 = layers.MultiHeadAttention(heads, d_model, dropout=dropout)
68
+ self.attn_2 = layers.MultiHeadAttention(heads, d_model, dropout=dropout)
69
+ self.ff = layers.FeedForward(d_model, dropout=dropout)
70
+
71
+ def forward(self, x, memory, src_mask, trg_mask):
72
+ """Run the decoding layer
73
+ Args:
74
+ x: the input (either embedding values or previous layer output), should be in shape [batch_size, tgt_len, d_model]
75
+ memory: the outputs of the encoding section, used for normal attention. [batch_size, src_len, d_model]
76
+ src_mask: the padding mask for the memory, [batch_size, 1, src_len]
77
+ tgt_mask: the no-peeking mask for the decoder, [batch_size, tgt_len, tgt_len]
78
+ Return:
79
+ an output that has the same shape as the input, [batch_size, tgt_len, d_model]
80
+ the self-attention and normal attention received [batch_size, head, tgt_len, tgt_len] & [batch_size, head, tgt_len, src_len]
81
+ """
82
+ x2 = self.norm_1(x)
83
+ # Self-attention
84
+ x_sa, sa = self.attn_1(x2, x2, x2, trg_mask)
85
+ x = x + self.dropout_1(x_sa)
86
+ x2 = self.norm_2(x)
87
+ # Normal multi-head attention
88
+ x_na, na = self.attn_2(x2, memory, memory, src_mask)
89
+ x = x + self.dropout_2(x_na)
90
+ x2 = self.norm_3(x)
91
+ x = x + self.dropout_3(self.ff(x2))
92
+ return x, (sa, na)
93
+
94
+ def get_clones(module, N, keep_module=True):
95
+ if(keep_module and N >= 1):
96
+ # create N-1 copies in addition to the original
97
+ return nn.ModuleList([module] + [copy.deepcopy(module) for i in range(N-1)])
98
+ else:
99
+ # create N new copies
100
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
101
+
102
+ class Encoder(nn.Module):
103
+ """A wrapper that embeds, positionally encodes, and self-attention encodes the inputs.
104
+ Args:
105
+ vocab_size: the size of the vocab. Used for embedding
106
+ d_model: the inner dim of the module
107
+ N: number of layers used
108
+ heads: number of heads used in the attention
109
+ dropout: applied dropout value during training
110
+ max_seq_length: the maximum length value used for this encoder. Needed for PositionalEncoder, due to caching
111
+ """
112
+ def __init__(self, vocab_size, d_model, N, heads, dropout, max_seq_length=200):
113
+ super().__init__()
114
+ self.N = N
115
+ self.embed = nn.Embedding(vocab_size, d_model)
116
+ self.pe = layers.PositionalEncoder(d_model, dropout=dropout, max_seq_length=max_seq_length)
117
+ self.layers = get_clones(EncoderLayer(d_model, heads, dropout), N)
118
+ self.norm = layers.Norm(d_model)
119
+
120
+ self._max_seq_length = max_seq_length
121
+
122
+ def forward(self, src, src_mask, output_attention=False, seq_length_check=False):
123
+ """Accepts a batch of indexed tokens, return the encoded values.
124
+ Args:
125
+ src: int Tensor of [batch_size, src_len]
126
+ src_mask: the padding mask, [batch_size, 1, src_len]
127
+ output_attention: if set, output a list containing used attention
128
+ seq_length_check: if set, automatically trim the input if it goes past the expected sequence length.
129
+ Returns:
130
+ the encoded values [batch_size, src_len, d_model]
131
+ if available, list of N (self-attention) calculated. They are in form of [batch_size, heads, src_len, src_len]
132
+ """
133
+ if(seq_length_check and src.shape[1] > self._max_seq_length):
134
+ src = src[:, :self._max_seq_length]
135
+ src_mask = src_mask[:, :, :self._max_seq_length]
136
+ x = self.embed(src)
137
+ x = self.pe(x)
138
+ attentions = [None] * self.N
139
+ for i in range(self.N):
140
+ x, attn = self.layers[i](x, src_mask)
141
+ attentions[i] = attn
142
+ x = self.norm(x)
143
+ return x if(not output_attention) else (x, attentions)
144
+
145
+ class Decoder(nn.Module):
146
+ """A wrapper that receives the encoder outputs and runs the decoding process for a given input
147
+ Args:
148
+ vocab_size: the size of the vocab. Used for embedding
149
+ d_model: the inner dim of the module
150
+ N: number of layers used
151
+ heads: number of heads used in the attention
152
+ dropout: applied dropout value during training
153
+ max_seq_length: the maximum length value used for this decoder. Needed for PositionalEncoder, due to caching
154
+ """
155
+ def __init__(self, vocab_size, d_model, N, heads, dropout, max_seq_length=200):
156
+ super().__init__()
157
+ self.N = N
158
+ self.embed = nn.Embedding(vocab_size, d_model)
159
+ self.pe = layers.PositionalEncoder(d_model, dropout=dropout, max_seq_length=max_seq_length)
160
+ self.layers = get_clones(DecoderLayer(d_model, heads, dropout), N)
161
+ self.norm = layers.Norm(d_model)
162
+
163
+ self._max_seq_length = max_seq_length
164
+
165
+ def forward(self, trg, memory, src_mask, trg_mask, output_attention=False):
166
+ """Accepts a batch of indexed tokens and the encoding outputs, return the decoded values.
167
+ Args:
168
+ trg: input Tensor of [batch_size, trg_len]
169
+ memory: output of Encoder [batch_size, src_len, d_model]
170
+ src_mask: the padding mask, [batch_size, 1, src_len]
171
+ trg_mask: the no-peeking mask, [batch_size, tgt_len, tgt_len]
172
+ output_attention: if set, output a list containing used attention
173
+ Returns:
174
+ the decoded values [batch_size, tgt_len, d_model]
175
+ if available, list of N (self-attention, attention) calculated. They are in form of [batch_size, heads, tgt_len, tgt/src_len]
176
+ """
177
+ x = self.embed(trg)
178
+ x = self.pe(x)
179
+
180
+ attentions = [None] * self.N
181
+ for i in range(self.N):
182
+ x, attn = self.layers[i](x, memory, src_mask, trg_mask)
183
+ attentions[i] = attn
184
+ x = self.norm(x)
185
+ return x if(not output_attention) else (x, attentions)
186
+
187
+
188
+ class Config:
189
+ """Deprecated"""
190
+ def __init__(self):
191
+ self.opt = {
192
+ 'train_src_data':'/workspace/khoai23/opennmt/data/iwslt_en_vi/train.en',
193
+ 'train_trg_data':'/workspace/khoai23/opennmt/data/iwslt_en_vi/train.vi',
194
+ 'valid_src_data':'/workspace/khoai23/opennmt/data/iwslt_en_vi/tst2013.en',
195
+ 'valid_trg_data':'/workspace/khoai23/opennmt/data/iwslt_en_vi/tst2013.vi',
196
+ 'src_lang':'en', # useless atm
197
+ 'trg_lang':'en',#'vi_spacy_model', # useless atm
198
+ 'max_strlen':160,
199
+ 'batchsize':1500,
200
+ 'device':'cuda',
201
+ 'd_model': 512,
202
+ 'n_layers': 6,
203
+ 'heads': 8,
204
+ 'dropout': 0.1,
205
+ 'lr':0.0001,
206
+ 'epochs':30,
207
+ 'printevery': 200,
208
+ 'k':5,
209
+ }
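
As a shape sanity-check for the `Encoder`/`Decoder` wrappers above, a sketch assuming the `layers` module behaves as it is used in this file (masks are built by hand here):

```python
# Sketch: expected tensor shapes for the Encoder/Decoder wrappers (hand-built masks).
import torch
from modules.prototypes import Encoder, Decoder

enc = Encoder(vocab_size=1000, d_model=512, N=2, heads=8, dropout=0.1)
dec = Decoder(vocab_size=1200, d_model=512, N=2, heads=8, dropout=0.1)

src = torch.randint(0, 1000, (4, 15))                            # [batch_size, src_len]
trg = torch.randint(0, 1200, (4, 12))                            # [batch_size, trg_len]
src_mask = torch.ones(4, 1, 15, dtype=torch.bool)                # padding mask (no padding here)
trg_mask = torch.tril(torch.ones(1, 12, 12, dtype=torch.bool))   # no-peeking mask

memory = enc(src, src_mask)                                      # [4, 15, 512]
decoded = dec(trg, memory, src_mask, trg_mask)                   # [4, 12, 512]
print(memory.shape, decoded.shape)
```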
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ torch
2
+ nltk
3
+ torchtext==0.6.0
4
+ pyvi
5
+ spacy
6
+ PyYAML
7
+ dill
8
+ pandas
9
+ laonlp
10
+ perl
third-party/multi-bleu.perl ADDED
@@ -0,0 +1,177 @@
1
+ #!/usr/bin/env perl
2
+ #
3
+ # This file is part of moses. Its use is licensed under the GNU Lesser General
4
+ # Public License version 2.1 or, at your option, any later version.
5
+
6
+ # $Id$
7
+ use warnings;
8
+ use strict;
9
+
10
+ my $lowercase = 0;
11
+ if ($ARGV[0] eq "-lc") {
12
+ $lowercase = 1;
13
+ shift;
14
+ }
15
+
16
+ my $stem = $ARGV[0];
17
+ if (!defined $stem) {
18
+ print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n";
19
+ print STDERR "Reads the references from reference or reference0, reference1, ...\n";
20
+ exit(1);
21
+ }
22
+
23
+ $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0";
24
+
25
+ my @REF;
26
+ my $ref=0;
27
+ while(-e "$stem$ref") {
28
+ &add_to_ref("$stem$ref",\@REF);
29
+ $ref++;
30
+ }
31
+ &add_to_ref($stem,\@REF) if -e $stem;
32
+ die("ERROR: could not find reference file $stem") unless scalar @REF;
33
+
34
+ # add additional references explicitly specified on the command line
35
+ shift;
36
+ foreach my $stem (@ARGV) {
37
+ &add_to_ref($stem,\@REF) if -e $stem;
38
+ }
39
+
40
+
41
+
42
+ sub add_to_ref {
43
+ my ($file,$REF) = @_;
44
+ my $s=0;
45
+ if ($file =~ /.gz$/) {
46
+ open(REF,"gzip -dc $file|") or die "Can't read $file";
47
+ } else {
48
+ open(REF,$file) or die "Can't read $file";
49
+ }
50
+ while(<REF>) {
51
+ chomp;
52
+ push @{$$REF[$s++]}, $_;
53
+ }
54
+ close(REF);
55
+ }
56
+
57
+ my(@CORRECT,@TOTAL,$length_translation,$length_reference);
58
+ my $s=0;
59
+ while(<STDIN>) {
60
+ chomp;
61
+ $_ = lc if $lowercase;
62
+ my @WORD = split;
63
+ my %REF_NGRAM = ();
64
+ my $length_translation_this_sentence = scalar(@WORD);
65
+ my ($closest_diff,$closest_length) = (9999,9999);
66
+ foreach my $reference (@{$REF[$s]}) {
67
+ # print "$s $_ <=> $reference\n";
68
+ $reference = lc($reference) if $lowercase;
69
+ my @WORD = split(' ',$reference);
70
+ my $length = scalar(@WORD);
71
+ my $diff = abs($length_translation_this_sentence-$length);
72
+ if ($diff < $closest_diff) {
73
+ $closest_diff = $diff;
74
+ $closest_length = $length;
75
+ # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n";
76
+ } elsif ($diff == $closest_diff) {
77
+ $closest_length = $length if $length < $closest_length;
78
+ # from two references with the same closeness to me
79
+ # take the *shorter* into account, not the "first" one.
80
+ }
81
+ for(my $n=1;$n<=4;$n++) {
82
+ my %REF_NGRAM_N = ();
83
+ for(my $start=0;$start<=$#WORD-($n-1);$start++) {
84
+ my $ngram = "$n";
85
+ for(my $w=0;$w<$n;$w++) {
86
+ $ngram .= " ".$WORD[$start+$w];
87
+ }
88
+ $REF_NGRAM_N{$ngram}++;
89
+ }
90
+ foreach my $ngram (keys %REF_NGRAM_N) {
91
+ if (!defined($REF_NGRAM{$ngram}) ||
92
+ $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) {
93
+ $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram};
94
+ # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}<BR>\n";
95
+ }
96
+ }
97
+ }
98
+ }
99
+ $length_translation += $length_translation_this_sentence;
100
+ $length_reference += $closest_length;
101
+ for(my $n=1;$n<=4;$n++) {
102
+ my %T_NGRAM = ();
103
+ for(my $start=0;$start<=$#WORD-($n-1);$start++) {
104
+ my $ngram = "$n";
105
+ for(my $w=0;$w<$n;$w++) {
106
+ $ngram .= " ".$WORD[$start+$w];
107
+ }
108
+ $T_NGRAM{$ngram}++;
109
+ }
110
+ foreach my $ngram (keys %T_NGRAM) {
111
+ $ngram =~ /^(\d+) /;
112
+ my $n = $1;
113
+ # my $corr = 0;
114
+ # print "$i e $ngram $T_NGRAM{$ngram}<BR>\n";
115
+ $TOTAL[$n] += $T_NGRAM{$ngram};
116
+ if (defined($REF_NGRAM{$ngram})) {
117
+ if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) {
118
+ $CORRECT[$n] += $T_NGRAM{$ngram};
119
+ # $corr = $T_NGRAM{$ngram};
120
+ # print "$i e correct1 $T_NGRAM{$ngram}<BR>\n";
121
+ }
122
+ else {
123
+ $CORRECT[$n] += $REF_NGRAM{$ngram};
124
+ # $corr = $REF_NGRAM{$ngram};
125
+ # print "$i e correct2 $REF_NGRAM{$ngram}<BR>\n";
126
+ }
127
+ }
128
+ # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram};
129
+ # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n"
130
+ }
131
+ }
132
+ $s++;
133
+ }
134
+ my $brevity_penalty = 1;
135
+ my $bleu = 0;
136
+
137
+ my @bleu=();
138
+
139
+ for(my $n=1;$n<=4;$n++) {
140
+ if (defined ($TOTAL[$n])){
141
+ $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0;
142
+ # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n";
143
+ }else{
144
+ $bleu[$n]=0;
145
+ }
146
+ }
147
+
148
+ if ($length_reference==0){
149
+ printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n";
150
+ exit(1);
151
+ }
152
+
153
+ if ($length_translation<$length_reference) {
154
+ $brevity_penalty = exp(1-$length_reference/$length_translation);
155
+ }
156
+ $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) +
157
+ my_log( $bleu[2] ) +
158
+ my_log( $bleu[3] ) +
159
+ my_log( $bleu[4] ) ) / 4) ;
160
+ printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n",
161
+ 100*$bleu,
162
+ 100*$bleu[1],
163
+ 100*$bleu[2],
164
+ 100*$bleu[3],
165
+ 100*$bleu[4],
166
+ $brevity_penalty,
167
+ $length_translation / $length_reference,
168
+ $length_translation,
169
+ $length_reference;
170
+
171
+
172
+ print STDERR "It is in-advisable to publish scores from multi-bleu.perl. The scores depend on your tokenizer, which is unlikely to be reproducible from your paper or consistent across research groups. Instead you should detokenize then use mteval-v14.pl, which has a standard tokenization. Scores from multi-bleu.perl can still be used for internal purposes when you have a consistent tokenizer.\n";
173
+
174
+ sub my_log {
175
+ return -9999999999 unless $_[0];
176
+ return log($_[0]);
177
+ }
utils/data.py ADDED
@@ -0,0 +1,137 @@
1
+ import re, os
2
+ import nltk
3
+ from nltk.corpus import wordnet
4
+ import dill as pickle
5
+ import pandas as pd
6
+ from torchtext import data
7
+ from laonlp import tokenize
8
+
9
+ def multiple_replace(dict, text):
10
+ # Create a regular expression from the dictionary keys
11
+ regex = re.compile("(%s)" % "|".join(map(re.escape, dict.keys())))
12
+
13
+ # For each match, look-up corresponding value in dictionary
14
+ return regex.sub(lambda mo: dict[mo.string[mo.start():mo.end()]], text)
15
+
16
+ # get_synonym replace word with any synonym found among src
17
+ def get_synonym(word, SRC):
18
+ syns = wordnet.synsets(word)
19
+ for s in syns:
20
+ for l in s.lemmas():
21
+ if SRC.vocab.stoi[l.name()] != 0:
22
+ return SRC.vocab.stoi[l.name()]
23
+
24
+ return 0
25
+
26
+ class Tokenizer:
27
+ def __init__(self, lang=None):
28
+ if(lang is not None):
29
+ import spacy  # local import; only needed when a spacy language model is requested
+ self.nlp = spacy.load(lang)
30
+ self.tokenizer_fn = self.nlp.tokenizer
31
+ else:
32
+ self.tokenizer_fn = lambda l: l.strip().split()
33
+
34
+ # def tokenize(self, sentence):
35
+ # sentence = re.sub(
36
+ # r"[\*\"“”\n\\…\+\-\/\=\(\)‘•:\[\]\|’\!;]", " ", str(sentence))
37
+ # sentence = re.sub(r"[ ]+", " ", sentence)
38
+ # sentence = re.sub(r"\!+", "!", sentence)
39
+ # sentence = re.sub(r"\,+", ",", sentence)
40
+ # sentence = re.sub(r"\?+", "?", sentence)
41
+ # sentence = sentence.lower()
42
+ # return [tok.text for tok in self.tokenizer_fn(sentence) if tok.text != " "]
43
+
44
+
45
+ def read_data(src_file, trg_file):
46
+ src_data = open(src_file).read().strip().split('\n')
47
+
48
+ trg_data = open(trg_file).read().strip().split('\n')
49
+
50
+ return src_data, trg_data
51
+
52
+
53
+ def read_file(file_dir):
54
+ f = open(file_dir, 'r')
55
+ data = f.read().strip().split('\n')
56
+ return data
57
+
58
+ def write_file(file_dir, content):
59
+ f = open(file_dir, "w")
60
+ f.write(content)
61
+ f.close()
62
+
63
+ def create_fields(src_lang, trg_lang):
64
+
65
+ #print("loading spacy tokenizers...")
66
+ #
67
+ # t_src = tokenize(src_lang)
68
+ # t_trg = tokenize(trg_lang)
69
+ # t_src_tokenizer = t_trg_tokenizer = lambda x: x.strip().split()
70
+ target_tokenizer = lambda x: x.strip().split()
71
+
72
+ TRG = data.Field(lower=True, tokenize=target_tokenizer, init_token='<sos>', eos_token='<eos>')
73
+ SRC = data.Field(lower=True, tokenize=tokenize.word_tokenize)
74
+
75
+ return SRC, TRG
76
+
77
+ def create_dataset(src_data, trg_data, max_strlen, batchsize, device, SRC, TRG, istrain=True):
78
+
79
+ print("creating dataset and iterator... ")
80
+
81
+ raw_data = {'src' : [line for line in src_data], 'trg': [line for line in trg_data]}
82
+ df = pd.DataFrame(raw_data, columns=["src", "trg"])
83
+
84
+ mask = (df['src'].str.count(' ') < max_strlen) & (df['trg'].str.count(' ') < max_strlen)
85
+ df = df.loc[mask]
86
+
87
+ df.to_csv("translate_transformer_temp.csv", index=False)
88
+
89
+ data_fields = [('src', SRC), ('trg', TRG)]
90
+ train = data.TabularDataset('./translate_transformer_temp.csv', format='csv', fields=data_fields)
91
+
92
+ train_iter = MyIterator(train, batch_size=batchsize, device=device,
93
+ repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
94
+ batch_size_fn=batch_size_fn, train=istrain, shuffle=True)
95
+
96
+ os.remove('translate_transformer_temp.csv')
97
+
98
+ if istrain:
99
+ SRC.build_vocab(train)
100
+ TRG.build_vocab(train)
101
+
102
+ return train_iter
103
+
104
+ class MyIterator(data.Iterator):
105
+ def create_batches(self):
106
+ if self.train:
107
+ def pool(d, random_shuffler):
108
+ for p in data.batch(d, self.batch_size * 100):
109
+ p_batch = data.batch(
110
+ sorted(p, key=self.sort_key),
111
+ self.batch_size, self.batch_size_fn)
112
+ for b in random_shuffler(list(p_batch)):
113
+ yield b
114
+ self.batches = pool(self.data(), self.random_shuffler)
115
+
116
+ else:
117
+ self.batches = []
118
+ for b in data.batch(self.data(), self.batch_size,
119
+ self.batch_size_fn):
120
+ self.batches.append(sorted(b, key=self.sort_key))
121
+
122
+ global max_src_in_batch, max_tgt_in_batch
123
+
124
+ def batch_size_fn(new, count, sofar):
125
+ "Keep augmenting batch and calculate total number of tokens + padding."
126
+ global max_src_in_batch, max_tgt_in_batch
127
+ if count == 1:
128
+ max_src_in_batch = 0
129
+ max_tgt_in_batch = 0
130
+ max_src_in_batch = max(max_src_in_batch, len(new.src))
131
+ max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2)
132
+ src_elements = count * max_src_in_batch
133
+ tgt_elements = count * max_tgt_in_batch
134
+ return max(src_elements, tgt_elements)
135
+
136
+ def generate_language_token(lang: str):
137
+ return '<{}>'.format(lang.strip())
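
A short usage sketch of the helpers above (the data paths are placeholders): `create_fields` tokenizes the source side with laonlp and the target side by whitespace, and `create_dataset` builds the token-count-aware iterator.

```python
# Sketch: building fields and a training iterator with the helpers above (placeholder paths).
from utils.data import create_fields, create_dataset, read_file

SRC, TRG = create_fields('lo', 'vi')
src_data = read_file('data/train.lo')   # placeholder path
trg_data = read_file('data/train.vi')   # placeholder path
train_iter = create_dataset(src_data, trg_data, max_strlen=160, batchsize=1500,
                            device='cpu', SRC=SRC, TRG=TRG, istrain=True)
```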
utils/decode_old.py ADDED
@@ -0,0 +1,163 @@
1
+ import numpy as np
2
+ import math
3
+ import torch
4
+ from torch.autograd import Variable
5
+ import torch.nn.functional as functional
6
+
7
+ from utils.data import multiple_replace, get_synonym
8
+
9
+ def no_peeking_mask(size, device):
10
+ """
11
+ Create the mask used in the decoder so that, during training,
12
+ the model cannot peek at future tokens when predicting
13
+ """
14
+ np_mask = np.triu(np.ones((1, size, size)),
15
+ k=1).astype('uint8')
16
+ np_mask = Variable(torch.from_numpy(np_mask) == 0)
17
+ np_mask = np_mask.to(device)
18
+
19
+ return np_mask
20
+
21
+ def create_masks(src, trg, src_pad, trg_pad, device):
22
+ """ Create the padding mask for the encoder (and the combined mask for the decoder),
23
+ so the model does not attend to the PAD tokens that we added
24
+ """
25
+ src_mask = (src != src_pad).unsqueeze(-2)
26
+
27
+ if trg is not None:
28
+ trg_mask = (trg != trg_pad).unsqueeze(-2)
29
+ size = trg.size(1) # get seq_len for matrix
30
+ np_mask = no_peeking_mask(size, device)
31
+ if trg.is_cuda:
32
+ np_mask.cuda()
33
+ trg_mask = trg_mask & np_mask
34
+
35
+ else:
36
+ trg_mask = None
37
+ return src_mask, trg_mask
38
+
39
+ def init_vars(src, model, SRC, TRG, device, k, max_len):
40
+ """ Compute the tensors needed during translation once the model has finished training
41
+ """
42
+ init_tok = TRG.vocab.stoi['<sos>']
43
+ src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
44
+
45
+ # precompute the encoder output
46
+ e_output = model.encoder(src, src_mask)
47
+
48
+ outputs = torch.LongTensor([[init_tok]])
49
+
50
+ outputs = outputs.to(device)
51
+
52
+ trg_mask = no_peeking_mask(1, device)
53
+ # predict the first token
54
+ out = model.out(model.decoder(outputs,
55
+ e_output, src_mask, trg_mask))
56
+ out = functional.softmax(out, dim=-1)
57
+
58
+ probs, ix = out[:, -1].data.topk(k)
59
+ log_scores = torch.Tensor([math.log(prob) for prob in probs.data[0]]).unsqueeze(0)
60
+
61
+ outputs = torch.zeros(k, max_len).long()
62
+ outputs = outputs.to(device)
63
+ outputs[:, 0] = init_tok
64
+ outputs[:, 1] = ix[0]
65
+
66
+ e_outputs = torch.zeros(k, e_output.size(-2),e_output.size(-1))
67
+
68
+ e_outputs = e_outputs.to(device)
69
+ e_outputs[:, :] = e_output[0]
70
+
71
+ return outputs, e_outputs, log_scores
72
+
73
+ def k_best_outputs(outputs, out, log_scores, i, k):
74
+ # debug print
75
+
76
+ probs, ix = out[:, -1].data.topk(k)
77
+ log_probs = torch.Tensor([math.log(p) for p in probs.data.view(-1)]).view(k, -1) + log_scores.transpose(0,1)
78
+ k_probs, k_ix = log_probs.view(-1).topk(k)
79
+
80
+ row = k_ix // k
81
+ col = k_ix % k
82
+
83
+ outputs[:, :i] = outputs[row, :i]
84
+ outputs[:, i] = ix[row, col]
85
+
86
+ log_scores = k_probs.unsqueeze(0)
87
+
88
+ return outputs, log_scores
89
+
90
+ def beam_search(src, model, SRC, TRG, device, k, max_len, debug=False, output_list_of_tokens=False):
91
+
92
+ outputs, e_outputs, log_scores = init_vars(src, model, SRC, TRG, device, k, max_len)
93
+ eos_tok = TRG.vocab.stoi['<eos>']
94
+ src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
95
+ ind = None
96
+ for i in range(2, max_len):
97
+ if(debug):
98
+ print("Current iteration to maxlen: {:d}".format(i))
99
+
100
+ trg_mask = no_peeking_mask(i, device)
101
+
102
+ out = model.out(model.decoder(outputs[:,:i], e_outputs, src_mask, trg_mask))
103
+
104
+ out = functional.softmax(out, dim=-1)
105
+
106
+ outputs, log_scores = k_best_outputs(outputs, out, log_scores, i, k)
107
+
108
+ ones = (outputs==eos_tok).nonzero() # Occurrences of end symbols for all input sentences.
109
+ sentence_lengths = torch.zeros(len(outputs), dtype=torch.long).to(device)
110
+ for vec in ones:
111
+ i = vec[0]
112
+ if sentence_lengths[i]==0: # First end symbol has not been found yet
113
+ sentence_lengths[i] = vec[1] # Position of first end symbol
114
+
115
+ num_finished_sentences = len([s for s in sentence_lengths if s > 0])
116
+
117
+ if num_finished_sentences == k:
118
+ alpha = 0.7
119
+ div = 1/(sentence_lengths.type_as(log_scores)**alpha)
120
+ _, ind = torch.max(log_scores * div, 1)
121
+ ind = ind.data[0]
122
+ break
123
+
124
+ # additional change to output list of tokens instead of string
125
+ join_fn = (lambda x: x) if(output_list_of_tokens) else (lambda x: " ".join(x))
126
+
127
+ if ind is None:
128
+ length = (outputs[0]==eos_tok).nonzero()[0] if len((outputs[0]==eos_tok).nonzero()) > 0 else -1
129
+ return join_fn([TRG.vocab.itos[tok] for tok in outputs[0, 1:length]])
130
+ else:
131
+ length = (outputs[ind]==eos_tok).nonzero()[0]
132
+ return join_fn([TRG.vocab.itos[tok] for tok in outputs[ind, 1:length]])
133
+
134
+ def translate_sentence(raw_sentence, model, SRC, TRG, device, k, max_len, debug=False, output_list_of_tokens=False):
135
+ """Translate a single sentence using beam search
136
+ """
137
+ model.eval()
138
+ indexed = []
139
+ if(isinstance(raw_sentence, str)):
140
+ # single sentence, require preprocessing
141
+ sentence = SRC.preprocess(raw_sentence)
142
+ else:
143
+ # already tokenized (taken from iterators, etc.)
144
+ sentence = raw_sentence
145
+
146
+ for tok in sentence:
147
+ if SRC.vocab.stoi[tok] != SRC.vocab.stoi['<eos>']:
148
+ indexed.append(SRC.vocab.stoi[tok])
149
+ else:
150
+ indexed.append(get_synonym(tok, SRC))
151
+
152
+ output = Variable(torch.LongTensor([indexed]))
153
+
154
+ output = output.to(device)
155
+
156
+ output = beam_search(output, model, SRC, TRG, device, k, max_len, output_list_of_tokens=output_list_of_tokens)
157
+
158
+ if(debug):
159
+ print("{} -> {}".format(raw_sentence, output))
160
+
161
+ return output
162
+
163
+ # return multiple_replace({' ?' : '?',' !':'!',' .':'.','\' ':'\'',' ,':','}, sentence)
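
For intuition, a small sketch of the mask built by `no_peeking_mask` above: each target position may only attend to itself and earlier positions.

```python
# Sketch: the lower-triangular no-peeking mask for a length-4 target.
import torch
from utils.decode_old import no_peeking_mask

print(no_peeking_mask(4, torch.device('cpu')).int())
# tensor([[[1, 0, 0, 0],
#          [1, 1, 0, 0],
#          [1, 1, 1, 0],
#          [1, 1, 1, 1]]], dtype=torch.int32)
```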
utils/logging.py ADDED
@@ -0,0 +1,22 @@
1
+ from __future__ import absolute_import
2
+
3
+ import os
4
+ import logging
5
+
6
+ def init_logger(model_dir, log_file=None, rotate=False):
7
+
8
+ logging.basicConfig(level=logging.DEBUG,
9
+ format='[%(asctime)s %(levelname)s] %(message)s',
10
+ datefmt='%a, %d %b %Y %H:%M:%S',
11
+ filename=os.path.join(model_dir, log_file),
12
+ filemode='w')
13
+ console = logging.StreamHandler()
14
+ console.setLevel(logging.INFO)
15
+ # set a format which is simpler for console use
16
+ formatter = logging.Formatter('[%(asctime)s %(levelname)s] %(message)s', '%a, %d %b %Y %H:%M:%S')
17
+ # tell the handler to use this format
18
+ console.setFormatter(formatter)
19
+ # add the handler to the root logger
20
+ logging.getLogger('').addHandler(console)
21
+
22
+ return logging
utils/loss.py ADDED
@@ -0,0 +1,25 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class LabelSmoothingLoss(nn.Module):
5
+ def __init__(self, classes, padding_idx, smoothing=0.0, dim=-1):
6
+ super(LabelSmoothingLoss, self).__init__()
7
+ self.confidence = 1.0 - smoothing
8
+ self.smoothing = smoothing
9
+ self.cls = classes
10
+ self.dim = dim
11
+ self.padding_idx = padding_idx
12
+
13
+ def forward(self, pred, target):
14
+ pred = pred.log_softmax(dim=self.dim)
15
+ with torch.no_grad():
16
+ # true_dist = pred.data.clone()
17
+ true_dist = torch.zeros_like(pred)
18
+ true_dist.fill_(self.smoothing / (self.cls - 2))
19
+ true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
20
+ true_dist[:, self.padding_idx] = 0
21
+ mask = torch.nonzero(target.data == self.padding_idx) #, as_tuple=False is redundant and causing error
22
+ if mask.dim() > 0:
23
+ true_dist.index_fill_(0, mask.squeeze(), 0.0)
24
+
25
+ return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
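
A quick sketch of `LabelSmoothingLoss` on toy values (5 classes, padding index 1): the target class receives probability `1 - smoothing`, the remaining mass is spread over the other non-padding classes, and padding positions are zeroed out.

```python
# Sketch: label smoothing on a toy batch (5 classes, padding_idx=1).
import torch
from utils.loss import LabelSmoothingLoss

criterion = LabelSmoothingLoss(classes=5, padding_idx=1, smoothing=0.1)
pred = torch.randn(4, 5)              # unnormalised scores, [batch, classes]
target = torch.tensor([0, 2, 1, 1])   # the last two positions are padding and contribute nothing
print(criterion(pred, target))
```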
utils/metric.py ADDED
@@ -0,0 +1,60 @@
1
+ from torchtext.data.metrics import bleu_score
2
+
3
+ def bleu(valid_src_data, valid_trg_data, model, device, k, max_strlen):
4
+ pred_sents = []
5
+ for sentence in valid_src_data:
6
+ pred_trg = model.translate_sentence(sentence, device, k, max_strlen)
7
+ pred_sents.append(pred_trg)
8
+
9
+ pred_sents = [model.TRG.preprocess(sent) for sent in pred_sents]
10
+ trg_sents = [[sent.split()] for sent in valid_trg_data]
11
+
12
+ return bleu_score(pred_sents, trg_sents)
13
+
14
+ def bleu_single(model, valid_dataset, debug=False):
15
+ """Perform single sentence translation, then calculate bleu score. Update when batch beam search is online"""
16
+ # need to join the sentence back per sample (the iterator is the one that had already been split to tokens)
17
+ # NOTE: bleu_score expects 2D candidates (list of token lists) vs 3D references (list of lists of token lists)
18
+ translate_pair = ( ([pair.trg], model.translate_sentence(pair.src, debug=debug)) for pair in valid_dataset)
19
+ # raise Exception(next(translate_pair))
20
+ labels, predictions = [list(l) for l in zip(*translate_pair)] # zip( *((l, p.split()) for l, p in translate_pair) )
21
+ return bleu_score(predictions, labels)
22
+
23
+ def bleu_batch(model, valid_dataset, batch_size, debug=False):
24
+ """Perform batch sentence translation in the same vein."""
25
+ predictions = model.translate_batch_sentence([s.src for s in valid_dataset], output_tokens=True, batch_size=batch_size)
26
+ labels = [[s.trg] for s in valid_dataset]
27
+ return bleu_score(predictions, labels)
28
+
29
+
30
+ def _revert_trg(sent, eos): # revert batching process on sentence
31
+ try:
32
+ endloc = sent.index(eos)
33
+ return sent[1:endloc]
34
+ except ValueError:
35
+ return sent[1:]
36
+
37
+ def bleu_batch_iter(model, valid_iter, src_lang=None, trg_lang=None, eos_token="<eos>", debug=False):
38
+ """Perform batched translations; other metrics are the same. Note that the inputs/outputs have already been preprocessed, but come in [length, batch_size] shape as per BucketIterator"""
39
+ # raise NotImplementedError("Error during calculation, currently unusable.")
40
+ # raise Exception([[model.SRC.vocab.itos[t] for t in batch] for batch in next(iter(valid_iter)).src.transpose(0, 1)])
41
+
42
+ translated_batched_pair = (
43
+ (
44
+ pair.trg.transpose(0, 1), # transpose due to timestep-first batches
45
+ model.decode_strategy.translate_batch_sentence(
46
+ pair.src.transpose(0, 1),
47
+ src_lang=src_lang,
48
+ trg_lang=trg_lang,
49
+ output_tokens=True,
50
+ field_processed=True,
51
+ replace_unk=False, # do not replace in this version
52
+ debug=debug
53
+ )
54
+ )
55
+ for pair in valid_iter
56
+ )
57
+ flattened_pair = ( ([model.TRG.vocab.itos[i] for i in trg], pred) for batch_trg, batch_pred in translated_batched_pair for trg, pred in zip(batch_trg, batch_pred) )
58
+ flat_labels, predictions = [list(l) for l in zip(*flattened_pair)]
59
+ labels = [[_revert_trg(l, eos_token)] for l in flat_labels] # remove <sos> and <eos> also updim the trg for 3D requirements.
60
+ return bleu_score(predictions, labels)
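
Finally, the 2D-vs-3D note in `bleu_single` refers to torchtext's `bleu_score` convention: candidates are a list of token lists, while references are a list of *lists* of token lists (several references per sentence). A minimal sketch:

```python
# Sketch: candidate/reference shapes expected by torchtext's bleu_score.
from torchtext.data.metrics import bleu_score

candidates = [["the", "cat", "sat", "on", "the", "mat"]]    # 2D: one token list per candidate
references = [[["the", "cat", "sat", "on", "the", "mat"]]]  # 3D: a list of references per sentence
print(bleu_score(candidates, references))                   # 1.0 for an exact match
```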