nreimers committed on
Commit
69dbbb4
1 Parent(s): c497e50
README.md ADDED
@@ -0,0 +1,5 @@
+ # DistilBERT with word2vec token embeddings
+
+ This model has a word2vec token embedding matrix with 256k entries. The word2vec token embeddings were trained for 3 epochs on 100GB of data from C4, MSMARCO, News, Wikipedia, and S2ORC.
+
+ The model was then trained with MLM on this dataset for 250k steps (batch size 64). The token embeddings were NOT updated during training.
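A minimal usage sketch for the checkpoint in this commit (the model path below is a placeholder; point it at this repository or a local copy):

```python
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline

# Placeholder path: replace with the repo id or a local directory containing this checkpoint
model_path = "path/to/distilbert-word2vec_256k-MLM"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForMaskedLM.from_pretrained(model_path)

# Fill-mask inference; [MASK] is the mask token defined in special_tokens_map.json
fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)
print(fill_mask("Paris is the [MASK] of France."))
```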
config.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "_name_or_path": "train-w2v-model/c4_msmarco_news_s2orc_wiki/distilbert-256k/",
+   "activation": "gelu",
+   "architectures": [
+     "DistilBertForMaskedLM"
+   ],
+   "attention_dropout": 0.1,
+   "dim": 768,
+   "dropout": 0.1,
+   "hidden_dim": 3072,
+   "initializer_range": 0.02,
+   "max_position_embeddings": 512,
+   "model_type": "distilbert",
+   "n_heads": 12,
+   "n_layers": 6,
+   "pad_token_id": 0,
+   "qa_dropout": 0.1,
+   "seq_classif_dropout": 0.2,
+   "sinusoidal_pos_embds": false,
+   "tie_weights_": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.17.0",
+   "vocab_size": 256000
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1d86d806578dfb9255ebc056205c99ac0622768fe42427eb3c9b457ef0631444
+ size 961553391
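A quick sanity check relating the checkpoint size to the config above (a sketch; it assumes the float32 dtype declared in config.json, i.e. 4 bytes per weight):

```python
# Rough parameter accounting from config.json and the LFS pointer above
vocab_size, dim = 256_000, 768
checkpoint_bytes = 961_553_391

embedding_params = vocab_size * dim   # ~196.6M parameters in the frozen word2vec embedding matrix
total_params = checkpoint_bytes / 4   # float32 -> roughly 240M parameters overall
print(f"embeddings: {embedding_params/1e6:.1f}M, total: {total_params/1e6:.1f}M, "
      f"transformer body: {(total_params - embedding_params)/1e6:.1f}M")
```

So the 256k-entry embedding matrix accounts for roughly 80% of the parameters, leaving about 44M for the 6-layer DistilBERT body.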
special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"model_max_length": 512, "unk_token": "[UNK]", "cls_token": "[CLS]", "sep_token": "[SEP]", "pad_token": "[PAD]", "mask_token": "[MASK]", "model_input_names": ["input_ids", "attention_mask"], "special_tokens_map_file": "c4_msmarco_news_s2orc_wiki/tokenizer-256k/special_tokens_map.json", "name_or_path": "train-w2v-model/c4_msmarco_news_s2orc_wiki/distilbert-256k/", "tokenizer_class": "PreTrainedTokenizerFast"}
train_script.py ADDED
@@ -0,0 +1,398 @@
+
+ import argparse
+ import logging
+ import math
+ import os
+ from datetime import datetime
+ import datasets
+ import torch
+ from torch.utils.data import DataLoader
+ from tqdm.auto import tqdm
+ import sys
+ import transformers
+ from accelerate import Accelerator, DistributedType
+ from shutil import copyfile
+ import wandb
+ import numpy as np
+
+ from transformers import (
+     MODEL_MAPPING,
+     AutoModelForMaskedLM,
+     AutoTokenizer,
+     DataCollatorForLanguageModeling,
+     SchedulerType,
+     get_scheduler
+ )
+ from transformers.utils.versions import require_version
+
+
+ # Iterable train dataset: streams sentences from a text file and tokenizes them
+ # in chunks of batch_size.
+ class TrainDataset(torch.utils.data.IterableDataset):
+     def __init__(self, filepath, tokenizer, max_length, batch_size, train_samples):
+         self.tokenizer = tokenizer
+         self.fIn = open(filepath)
+         self.max_length = max_length
+         self.batch_size = batch_size
+         self.train_samples = train_samples
+
+     def __iter__(self):
+         batch = []
+         for sent in self.fIn:
+             batch.append(sent.strip()[0:1000])
+
+             if len(batch) >= self.batch_size:
+                 # Tokenize the whole chunk at once with the fast tokenizer
+                 encoded = self.tokenizer(batch, add_special_tokens=True, truncation=True, max_length=self.max_length, return_special_tokens_mask=True, padding=True)
+                 #print(len(encoded['input_ids'][0]))
+                 for idx in range(len(batch)):
+                     single_sample = {key: encoded[key][idx] for key in encoded}
+                     yield single_sample
+
+                 batch = []
+
+     def __len__(self):
+         return self.train_samples
+
+
+ ## Dev dataset
+ class DevDataset(torch.utils.data.Dataset):
+     def __init__(self, filepath, tokenizer, max_length):
+         self.tokenizer = tokenizer
+         self.max_length = max_length
+         with open(filepath) as fIn:
+             sentences = [sent.strip() for sent in fIn]
+
+         self.num_sentences = len(sentences)
+         self.tokenized = self.tokenizer(sentences, add_special_tokens=True, truncation=True, max_length=self.max_length, return_special_tokens_mask=True)
+
+     def __getitem__(self, idx):
+         return {key: self.tokenized[key][idx] for key in self.tokenized}
+
+     def __len__(self):
+         return self.num_sentences
+
+
+ logger = logging.getLogger(__name__)
+ require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
+ MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Finetune a transformers model on a Masked Language Modeling task")
+     parser.add_argument(
+         "--dataset_config_name",
+         type=str,
+         default=None,
+         help="The configuration name of the dataset to use (via the datasets library).",
+     )
+     parser.add_argument(
+         "--train_file", type=str, default=None, help="A text file with training data (one text per line)."
+     )
+     parser.add_argument(
+         "--dev_file", type=str, default=None, help="A text file with dev data (one text per line)."
+     )
+     parser.add_argument(
+         "--model_name",
+         default="nicoladecao/msmarco-word2vec256000-distilbert-base-uncased",
+         type=str,
+         help="Path to pretrained model or model identifier from huggingface.co/models."
+     )
+     parser.add_argument(
+         "--per_device_batch_size",
+         type=int,
+         default=16,
+         help="Batch size (per device) for the training dataloader.",
+     )
+     parser.add_argument(
+         "--learning_rate",
+         type=float,
+         default=5e-5,
+         help="Initial learning rate (after the potential warmup period) to use.",
+     )
+     parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.")
+     parser.add_argument("--num_train_epochs", type=int, default=1, help="Total number of training epochs to perform.")
+     parser.add_argument(
+         "--max_train_steps",
+         type=int,
+         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+     )
+     parser.add_argument(
+         "--gradient_accumulation_steps",
+         type=int,
+         default=1,
+         help="Number of update steps to accumulate before performing a backward/update pass.",
+     )
+     parser.add_argument(
+         "--lr_scheduler_type",
+         type=SchedulerType,
+         default="linear",
+         help="The scheduler type to use.",
+         choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
+     )
+     parser.add_argument(
+         "--num_warmup_steps", type=int, default=1000, help="Number of steps for the warmup in the lr scheduler."
+     )
+     parser.add_argument(
+         "--model_type",
+         type=str,
+         default=None,
+         help="Model type to use if training from scratch.",
+         choices=MODEL_TYPES,
+     )
+     parser.add_argument(
+         "--max_seq_length",
+         type=int,
+         default=256,
+         help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated.",
+     )
+     parser.add_argument(
+         "--line_by_line",
+         type=bool,
+         default=True,
+         help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.",
+     )
+     parser.add_argument(
+         "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
+     )
+     parser.add_argument(
+         "--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss"
+     )
+     parser.add_argument("--mixed_precision", default="fp16")
+     parser.add_argument("--train_samples", required=True, type=int)
+     parser.add_argument("--eval_steps", default=10000, type=int)
+     parser.add_argument("--max_grad_norm", default=1.0, type=float)
+     parser.add_argument("--project", default="bert-word2vec")
+     parser.add_argument("--freeze_emb_layer", default=False, action='store_true')
+     parser.add_argument("--log_interval", default=1000, type=int)
+     parser.add_argument("--ckp_steps", default=50000, type=int)
+
+     args = parser.parse_args()
+
+     return args
+
+
+ def main():
+     args = parse_args()
+
+     # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
+     accelerator = Accelerator(mixed_precision=args.mixed_precision)
+     # Make one log on every process with the configuration for debugging.
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         level=logging.INFO,
+     )
+     logger.info(accelerator.state)
+
+     # Setup logging, we only want one process per machine to log things on the screen.
+     # accelerator.is_local_main_process is only True for one process per machine.
+     logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+     if accelerator.is_local_main_process:
+         datasets.utils.logging.set_verbosity_warning()
+         transformers.utils.logging.set_verbosity_info()
+     else:
+         datasets.utils.logging.set_verbosity_error()
+         transformers.utils.logging.set_verbosity_error()
+
+     accelerator.wait_for_everyone()
+
+     # Load model
+     tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+     model = AutoModelForMaskedLM.from_pretrained(args.model_name)
+
+     # Freeze the token embedding layer so the word2vec embeddings stay fixed during MLM training
+     if args.freeze_emb_layer:
+         model.distilbert.embeddings.word_embeddings.requires_grad_(False)
+
+     # Logging & Co on main process
+     if accelerator.is_main_process:
+         exp_name = f'{args.model_name.replace("/", "-")}-{"freeze_emb" if args.freeze_emb_layer else "update_emb"}-{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'
+         output_dir = os.path.join("output-mlm", exp_name)
+         wandb.init(project=args.project, name=exp_name, config=args)
+
+         os.makedirs(output_dir, exist_ok=False)
+
+         # Save tokenizer
+         tokenizer.save_pretrained(output_dir)
+
+         # Save train script
+         train_script_path = os.path.join(output_dir, 'train_script.py')
+         copyfile(__file__, train_script_path)
+         with open(train_script_path, 'a') as fOut:
+             fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
+
+     total_batch_size = args.per_device_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+     train_dataset = TrainDataset(args.train_file, tokenizer, args.max_seq_length, batch_size=total_batch_size, train_samples=args.train_samples)
+     eval_dataset = DevDataset(args.dev_file, tokenizer, args.max_seq_length)
+
+     # Data collator
+     # This one will take care of randomly masking the tokens.
+     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=args.mlm_probability)
+
+     # DataLoaders creation:
+     train_dataloader = DataLoader(train_dataset, collate_fn=data_collator, batch_size=args.per_device_batch_size)
+     eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_batch_size)
+
+     # Optimizer
+     # Split weights in two groups, one with weight decay and the other not.
+     no_decay = ["bias", "LayerNorm.weight"]
+     optimizer_grouped_parameters = [
+         {
+             "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+             "weight_decay": args.weight_decay,
+         },
+         {
+             "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+             "weight_decay": 0.0,
+         },
+     ]
+     optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+
+     # Prepare everything with our `accelerator`.
+     model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader)
+
+     # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
+     if accelerator.distributed_type == DistributedType.TPU:
+         model.tie_weights()
+
+     # Note -> the training dataloader needs to be prepared before we grab its length below (because its length will be
+     # shorter in multiprocess)
+
+     # Scheduler and math around the number of training steps.
+     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+     if args.max_train_steps is None:
+         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+     else:
+         args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+     lr_scheduler = get_scheduler(
+         name=args.lr_scheduler_type,
+         optimizer=optimizer,
+         num_warmup_steps=args.num_warmup_steps,
+         num_training_steps=args.max_train_steps,
+     )
+
+     # Train!
+     logger.info("***** Running training *****")
+     logger.info(f"  Num examples = {args.train_samples}")
+     logger.info(f"  Num Epochs = {args.num_train_epochs}")
+     logger.info(f"  Instantaneous batch size per device = {args.per_device_batch_size}")
+     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+     logger.info(f"  Total optimization steps = {args.max_train_steps}")
+     # Only show the progress bar once on each machine.
+     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process, smoothing=0.05)
+     completed_steps = 0
+     train_loss_values = []
+
+     best_eval_loss = 999999
+     if accelerator.is_main_process:
+         best_ckp_dir = os.path.join(output_dir, "best")
+         tokenizer.save_pretrained(best_ckp_dir)
+
+     for epoch in range(args.num_train_epochs):
+         logger.info(f"Start epoch {epoch}")
+         model.train()
+         for step, batch in enumerate(train_dataloader):
+             outputs = model(**batch)
+             loss = outputs.loss
+             loss = loss / args.gradient_accumulation_steps
+
+             if accelerator.is_main_process:
+                 train_loss_values.append(loss.cpu().item())
+
+             accelerator.backward(loss)
+             accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+             if step % args.gradient_accumulation_steps == 0:
+                 optimizer.step()
+                 lr_scheduler.step()
+                 optimizer.zero_grad()
+                 progress_bar.update(1)
+                 completed_steps += 1
+
+                 ### Do logging
+                 if accelerator.is_main_process:
+                     if completed_steps % args.log_interval == 0:
+                         wandb.log({"train/loss": np.mean(train_loss_values)}, step=completed_steps)
+                         train_loss_values = []
+
+                 if completed_steps % args.eval_steps == 0:
+                     model.eval()
+                     losses = []
+                     for eval_step, eval_batch in enumerate(eval_dataloader):
+                         with torch.no_grad():
+                             outputs = model(**eval_batch)
+
+                         loss = outputs.loss
+                         losses.append(accelerator.gather(loss.repeat(args.per_device_batch_size)))
+
+                     losses = torch.cat(losses)
+                     losses = losses[: len(eval_dataset)]
+                     try:
+                         eval_loss = torch.mean(losses)
+                     except OverflowError:
+                         eval_loss = float("inf")
+
+                     logger.info(f"step {completed_steps}: eval loss: {eval_loss}")
+                     if accelerator.is_main_process:
+                         wandb.log({"eval/loss": eval_loss}, step=completed_steps)
+
+                     model.train()
+
+                     # Save model
+                     accelerator.wait_for_everyone()
+                     if accelerator.is_main_process:
+                         unwrapped_model = accelerator.unwrap_model(model)
+                         unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
+                         with open(os.path.join(output_dir, "train_steps.log"), 'a') as fOut:
+                             fOut.write(f"{completed_steps}: {eval_loss}\n")
+
+                         # Save best model
+                         if eval_loss < best_eval_loss:
+                             best_eval_loss = eval_loss
+                             unwrapped_model.save_pretrained(best_ckp_dir, save_function=accelerator.save)
+                             with open(os.path.join(best_ckp_dir, "train_steps.log"), 'a') as fOut:
+                                 fOut.write(f"{completed_steps}: {eval_loss}\n")
+
+                 if accelerator.is_main_process and completed_steps % args.ckp_steps == 0:
+                     ckp_dir = os.path.join(output_dir, f"ckp-{int(completed_steps/1000)}k")
+                     unwrapped_model = accelerator.unwrap_model(model)
+                     unwrapped_model.save_pretrained(ckp_dir, save_function=accelerator.save)
+                     tokenizer.save_pretrained(ckp_dir)
+                     with open(os.path.join(ckp_dir, "train_steps.log"), 'a') as fOut:
+                         fOut.write(f"{completed_steps}: {eval_loss}\n")
+
+             if completed_steps >= args.max_train_steps:
+                 break
+
+     # Final save (the script defines no --output_dir argument; output_dir is created on the main process above)
+     accelerator.wait_for_everyone()
+     if accelerator.is_main_process:
+         unwrapped_model = accelerator.unwrap_model(model)
+         unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
+         with open(os.path.join(output_dir, "train_steps.log"), 'a') as fOut:
+             fOut.write(f"{completed_steps}\n")
+
+
+ if __name__ == "__main__":
+     main()
+
+
+ # Script was called via:
+ #python train_mlm-iterable.py --train_file data/c4_msmarco_news_s2orc_wiki_train.txt --dev_file data/c4_msmarco_news_s2orc_wiki_dev.txt --train_samples 100000000 --model_name train-w2v-model/c4_msmarco_news_s2orc_wiki/distilbert-256k/ --freeze_emb_layer
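The script was launched with `--freeze_emb_layer`, so the word embedding matrix of the saved checkpoint should be identical to that of the word2vec-initialized starting model. A verification sketch (both paths are placeholders):

```python
import torch
from transformers import AutoModelForMaskedLM

# Placeholder paths: the word2vec-initialized starting model and this MLM-trained checkpoint
base = AutoModelForMaskedLM.from_pretrained("path/to/distilbert-256k-word2vec-init")
trained = AutoModelForMaskedLM.from_pretrained("path/to/distilbert-word2vec_256k-MLM")

# With the embedding layer frozen, no gradient updates reach the token embeddings,
# so the two matrices should match exactly.
unchanged = torch.equal(
    base.distilbert.embeddings.word_embeddings.weight,
    trained.distilbert.embeddings.word_embeddings.weight,
)
print("word embeddings unchanged:", unchanged)
```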
train_steps.log ADDED
@@ -0,0 +1,39 @@
+ 10000: 3.6185991764068604
+ 20000: 3.181567430496216
+ 30000: 3.019852638244629
+ 40000: 2.8929433822631836
+ 50000: 2.865853786468506
+ 60000: 2.8218629360198975
+ 70000: 2.7376461029052734
+ 90000: 2.698227882385254
+ 100000: 2.6650893688201904
+ 120000: 2.6339340209960938
+ 130000: 2.593796730041504
+ 160000: 2.570080280303955
+ 180000: 2.5539512634277344
+ 190000: 2.5419578552246094
+ 210000: 2.4972760677337646
+ 260000: 2.4895386695861816
+ 270000: 2.481090545654297
+ 290000: 2.4765520095825195
+ 300000: 2.463596820831299
+ 320000: 2.4584429264068604
+ 350000: 2.450732469558716
+ 360000: 2.443289279937744
+ 370000: 2.4305179119110107
+ 410000: 2.4060347080230713
+ 470000: 2.376832962036133
+ 510000: 2.3685810565948486
+ 550000: 2.3647472858428955
+ 600000: 2.3556222915649414
+ 670000: 2.3360767364501953
+ 690000: 2.327178955078125
+ 730000: 2.3191168308258057
+ 740000: 2.3143470287323
+ 830000: 2.3057608604431152
+ 840000: 2.2876601219177246
+ 980000: 2.253411293029785
+ 1080000: 2.241132974624634
+ 1230000: 2.234037160873413
+ 1320000: 2.2321970462799072
+ 1370000: 2.2040650844573975
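Each entry above is `completed_steps: mean eval loss`, as written by the training loop. If a perplexity figure is preferred, it is exp(loss); for the final entry that is roughly exp(2.204) ≈ 9.06. A one-line check:

```python
import math
print(math.exp(2.2040650844573975))  # ≈ 9.06, perplexity corresponding to the final logged eval loss
```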