yhavinga commited on
Commit
95cb6e5
1 Parent(s): 6ff213f

Add scripts

Browse files
Files changed (3) hide show
  1. flax_to_pytorch.py +26 -5
  2. run_t5.sh +13 -12
  3. run_t5_mlm_flax.py +966 -0
flax_to_pytorch.py CHANGED
@@ -1,5 +1,26 @@
1
- from transformers import T5ForConditionalGeneration, TFT5ForConditionalGeneration
2
- pt_model = T5ForConditionalGeneration.from_pretrained(".", from_flax=True)
3
- pt_model.save_pretrained(".")
4
- tf_model = TFT5ForConditionalGeneration.from_pretrained(".", from_pt=True)
5
- tf_model.save_pretrained(".")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import jax.numpy as jnp
4
+ from transformers import AutoTokenizer
5
+ from transformers import FlaxT5ForConditionalGeneration
6
+ from transformers import T5ForConditionalGeneration
7
+ tokenizer = AutoTokenizer.from_pretrained(".")
8
+ model_fx = FlaxT5ForConditionalGeneration.from_pretrained(".")
9
+ model_pt = T5ForConditionalGeneration.from_pretrained(".", from_flax=True)
10
+ model_pt.save_pretrained("./")
11
+ text = "Hoe gaat het?"
12
+ e_input_ids_fx = tokenizer(text, return_tensors="np", padding=True, max_length=128, truncation=True)
13
+ d_input_ids_fx = jnp.ones((e_input_ids_fx.input_ids.shape[0], 1), dtype="i4") * model_fx.config.decoder_start_token_id
14
+ e_input_ids_pt = tokenizer(text, return_tensors="pt", padding=True, max_length=128, truncation=True)
15
+ d_input_ids_pt = np.ones((e_input_ids_pt.input_ids.shape[0], 1), dtype="i4") * model_pt.config.decoder_start_token_id
16
+ print(e_input_ids_fx)
17
+ print(d_input_ids_fx)
18
+ print()
19
+ encoder_pt = model_fx.encode(**e_input_ids_pt)
20
+ decoder_pt = model_fx.decode(d_input_ids_pt, encoder_pt)
21
+ logits_pt = decoder_pt.logits
22
+ print(logits_pt)
23
+ encoder_fx = model_fx.encode(**e_input_ids_fx)
24
+ decoder_fx = model_fx.decode(d_input_ids_fx, encoder_fx)
25
+ logits_fx = decoder_fx.logits
26
+ print(logits_fx)
run_t5.sh CHANGED
@@ -1,5 +1,17 @@
1
  #!/bin/bash
2
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  python run_t5_mlm_flax.py \
4
  --output_dir="${MODEL_PATH}" \
5
  --model_type="t5" \
@@ -22,15 +34,4 @@ python run_t5_mlm_flax.py \
22
  --weight_decay="0.01" \
23
  --warmup_steps="10000" \
24
  --validation_split_count="15000" \
25
- --push_to_hub \
26
- # --adam_beta1="0.9" \
27
- # --adam_beta2="0.98" \
28
- # --resume_from_checkpoint="${MODEL_DIR}" \ # Uncomment to resume from ckpt
29
- # --max_train_samples 100000 \
30
- # --max_eval_samples 1000 \
31
- # --adafactor \
32
- # --save_steps="80000" \
33
-
34
-
35
- # Instead of adafactor: adamw
36
-
 
1
  #!/bin/bash
2
 
3
+ export HF_PROJECT="t5-base-dutch"
4
+
5
+ # Variables for training the tokenizer and creating the config
6
+ export VOCAB_SIZE="32000"
7
+ export N_INPUT_SENTENCES="1000000" # Num of sentences to train the tokenizer
8
+ export DATASET="yhavinga/mc4_nl_cleaned" # Name of the dataset in the Huggingface Hub
9
+ export DATASET_CONFIG="full" # Config of the dataset in the Huggingface Hub
10
+ export DATASET_SPLIT="train" # Split to use for training tokenizer and model
11
+ export TEXT_FIELD="text" # Field containing the text to be used for training
12
+ export CONFIG_TYPE="t5-base" # Config that our model will use
13
+ export MODEL_PATH="${HOME}/data/${HF_PROJECT}" # Path to the model, e.g. here inside the mount
14
+
15
  python run_t5_mlm_flax.py \
16
  --output_dir="${MODEL_PATH}" \
17
  --model_type="t5" \
 
34
  --weight_decay="0.01" \
35
  --warmup_steps="10000" \
36
  --validation_split_count="15000" \
37
+ --push_to_hub
 
 
 
 
 
 
 
 
 
 
 
run_t5_mlm_flax.py ADDED
@@ -0,0 +1,966 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Pretraining the library models for T5-like span-masked language modeling on a text file or a dataset.
18
+
19
+ Here is the full list of checkpoints on the hub that can be pretrained by this script:
20
+ https://huggingface.co/models?filter=t5
21
+
22
+ Adapted from the original version to support gradient accumulation and restarting.
23
+ """
24
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
25
+ import logging
26
+ import os
27
+ import sys
28
+ import time
29
+ import json
30
+ from dataclasses import dataclass, field
31
+ from itertools import chain
32
+ from pathlib import Path
33
+ from typing import Dict, List, Optional
34
+
35
+ import numpy as np
36
+ from datasets import load_dataset
37
+ from tqdm import tqdm
38
+
39
+ import flax
40
+ import jax
41
+ import jax.numpy as jnp
42
+ import optax
43
+ from flax import jax_utils, traverse_util
44
+ from flax.serialization import to_bytes, from_bytes
45
+ from flax.training import train_state
46
+ from flax.training.common_utils import get_metrics, onehot, shard
47
+ # from huggingface_hub import Repository
48
+ from transformers import (
49
+ CONFIG_MAPPING,
50
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
51
+ AutoTokenizer,
52
+ BatchEncoding,
53
+ FlaxT5ForConditionalGeneration,
54
+ HfArgumentParser,
55
+ PreTrainedTokenizerBase,
56
+ T5Config,
57
+ TrainingArguments,
58
+ is_tensorboard_available,
59
+ set_seed,
60
+ )
61
+ # from transformers.file_utils import get_full_repo_name
62
+ from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
63
+
64
+ logger = logging.getLogger(__name__)
65
+
66
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
67
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
68
+
69
+ @dataclass
70
+ class ModelArguments:
71
+ """
72
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
73
+ """
74
+
75
+ model_name_or_path: Optional[str] = field(
76
+ default=None,
77
+ metadata={
78
+ "help": "The model checkpoint for weights initialization."
79
+ "Don't set if you want to train a model from scratch."
80
+ },
81
+ )
82
+ model_type: Optional[str] = field(
83
+ default=None,
84
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
85
+ )
86
+ config_name: Optional[str] = field(
87
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
88
+ )
89
+ tokenizer_name: Optional[str] = field(
90
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
91
+ )
92
+ cache_dir: Optional[str] = field(
93
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
94
+ )
95
+ use_fast_tokenizer: bool = field(
96
+ default=True,
97
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
98
+ )
99
+ dtype: Optional[str] = field(
100
+ default="float32",
101
+ metadata={
102
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
103
+ },
104
+ )
105
+ auth_token: Optional[str] = field(
106
+ default=None,
107
+ metadata={
108
+ "help": "Auth token for private repositories on the Huggingface Hub"
109
+ }
110
+ )
111
+
112
+
113
+ @dataclass
114
+ class DataTrainingArguments:
115
+ """
116
+ Arguments pertaining to what data we are going to input our model for training and eval.
117
+ """
118
+
119
+ dataset_name: Optional[str] = field(
120
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
121
+ )
122
+ dataset_config_name: Optional[str] = field(
123
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
124
+ )
125
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
126
+ validation_file: Optional[str] = field(
127
+ default=None,
128
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
129
+ )
130
+ train_ref_file: Optional[str] = field(
131
+ default=None,
132
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
133
+ )
134
+ validation_ref_file: Optional[str] = field(
135
+ default=None,
136
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
137
+ )
138
+ overwrite_cache: bool = field(
139
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
140
+ )
141
+ validation_split_count: Optional[int] = field(
142
+ default=10000,
143
+ metadata={
144
+ "help": "The count of the train set used as validation set in case there's no validation split"
145
+ },
146
+ )
147
+ max_seq_length: Optional[int] = field(
148
+ default=None,
149
+ metadata={
150
+ "help": "The maximum total input sequence length after tokenization and masking. Sequences longer than this will be truncated. Default to the max input length of the model."
151
+ },
152
+ )
153
+ preprocessing_num_workers: Optional[int] = field(
154
+ default=None,
155
+ metadata={"help": "The number of processes to use for the preprocessing."},
156
+ )
157
+ mlm_probability: float = field(
158
+ default=0.15, metadata={"help": "Ratio of tokens to mask for span masked language modeling loss"}
159
+ )
160
+ mean_noise_span_length: float = field(
161
+ default=3.0,
162
+ metadata={"help": "Mean span length of masked tokens"},
163
+ )
164
+ max_train_samples: Optional[int] = field(
165
+ default=None,
166
+ metadata={
167
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
168
+ "value if set."
169
+ },
170
+ )
171
+ max_eval_samples: Optional[int] = field(
172
+ default=None,
173
+ metadata={
174
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
175
+ "value if set."
176
+ },
177
+ )
178
+
179
+ def __post_init__(self):
180
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
181
+ raise ValueError("Need either a dataset name or a training/validation file.")
182
+ else:
183
+ if self.train_file is not None:
184
+ extension = self.train_file.split(".")[-1]
185
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
186
+ if self.validation_file is not None:
187
+ extension = self.validation_file.split(".")[-1]
188
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
189
+
190
+
191
+ def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
192
+ """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ .
193
+
194
+ Training parameters to avoid padding with random_spans_noise_mask.
195
+ When training a model with random_spans_noise_mask, we would like to set the other
196
+ training hyperparmeters in a way that avoids padding.
197
+ This function helps us compute these hyperparameters.
198
+ We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
199
+ and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
200
+ This function tells us the required number of tokens in the raw example (for split_tokens())
201
+ as well as the length of the encoded targets. Note that this function assumes
202
+ the inputs and targets will have EOS appended and includes that in the reported length.
203
+
204
+ Args:
205
+ inputs_length: an integer - desired length of the tokenized inputs sequence
206
+ noise_density: a float
207
+ mean_noise_span_length: a float
208
+ Returns:
209
+ tokens_length: length of original text in tokens
210
+ targets_length: an integer - length in tokens of encoded targets sequence
211
+ """
212
+
213
+ def _tokens_length_to_inputs_length_targets_length(tokens_length):
214
+ num_noise_tokens = int(round(tokens_length * noise_density))
215
+ num_nonnoise_tokens = tokens_length - num_noise_tokens
216
+ num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
217
+ # inputs contain all nonnoise tokens, sentinels for all noise spans
218
+ # and one EOS token.
219
+ _input_length = num_nonnoise_tokens + num_noise_spans + 1
220
+ _output_length = num_noise_tokens + num_noise_spans + 1
221
+ return _input_length, _output_length
222
+
223
+ tokens_length = inputs_length
224
+
225
+ while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
226
+ tokens_length += 1
227
+
228
+ inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)
229
+
230
+ # minor hack to get the targets length to be equal to inputs length
231
+ # which is more likely to have been set to a nice round number.
232
+ if noise_density == 0.5 and targets_length > inputs_length:
233
+ tokens_length -= 1
234
+ targets_length -= 1
235
+ return tokens_length, targets_length
236
+
237
+
238
+ @flax.struct.dataclass
239
+ class FlaxDataCollatorForT5MLM:
240
+ """
241
+ Data collator used for T5 span-masked language modeling.
242
+ It is made sure that after masking the inputs are of length `data_args.max_seq_length` and targets are also of fixed length.
243
+ For more information on how T5 span-masked language modeling works, one can take a look
244
+ at the `official paper <https://arxiv.org/pdf/1910.10683.pdf>`__
245
+ or the `official code for preprocessing <https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/preprocessors.py>`__ .
246
+
247
+ Args:
248
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
249
+ The tokenizer used for encoding the data.
250
+ noise_density (:obj:`float`):
251
+ The probability with which to (randomly) mask tokens in the input.
252
+ mean_noise_span_length (:obj:`float`):
253
+ The average span length of the masked tokens.
254
+ input_length (:obj:`int`):
255
+ The expected input length after masking.
256
+ target_length (:obj:`int`):
257
+ The expected target length after masking.
258
+ pad_token_id: (:obj:`int`):
259
+ The pad token id of the model
260
+ decoder_start_token_id: (:obj:`int):
261
+ The decoder start token id of the model
262
+ """
263
+
264
+ tokenizer: PreTrainedTokenizerBase
265
+ noise_density: float
266
+ mean_noise_span_length: float
267
+ input_length: int
268
+ target_length: int
269
+ pad_token_id: int
270
+ decoder_start_token_id: int
271
+
272
+ def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
273
+
274
+ # convert list to dict and tensorize input
275
+ batch = BatchEncoding(
276
+ {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
277
+ )
278
+
279
+ input_ids = batch["input_ids"]
280
+ batch_size, expandend_input_length = input_ids.shape
281
+
282
+ mask_indices = np.asarray([self.random_spans_noise_mask(expandend_input_length) for i in range(batch_size)])
283
+ labels_mask = ~mask_indices
284
+
285
+ input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8))
286
+ labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8))
287
+
288
+ batch["input_ids"] = self.filter_input_ids(input_ids, input_ids_sentinel)
289
+ batch["labels"] = self.filter_input_ids(input_ids, labels_sentinel)
290
+
291
+ if batch["input_ids"].shape[-1] != self.input_length:
292
+ raise ValueError(
293
+ f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but should be {self.target_length}."
294
+ )
295
+
296
+ if batch["labels"].shape[-1] != self.target_length:
297
+ raise ValueError(
298
+ f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be {self.target_length}."
299
+ )
300
+
301
+ # to check that tokens are correctly proprocessed, one can run `self.tokenizer.batch_decode(input_ids)` and `self.tokenizer.batch_decode(labels)` here...
302
+ batch["decoder_input_ids"] = shift_tokens_right(
303
+ batch["labels"], self.pad_token_id, self.decoder_start_token_id
304
+ )
305
+
306
+ return batch
307
+
308
+ def create_sentinel_ids(self, mask_indices):
309
+ """
310
+ Sentinel ids creation given the indices that should be masked.
311
+ The start indices of each mask are replaced by the sentinel ids in increasing
312
+ order. Consecutive mask indices to be deleted are replaced with `-1`.
313
+ """
314
+ start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
315
+ start_indices[:, 0] = mask_indices[:, 0]
316
+
317
+ sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices)
318
+ sentinel_ids = np.where(sentinel_ids != 0, (len(self.tokenizer) - sentinel_ids), 0)
319
+ sentinel_ids -= mask_indices - start_indices
320
+
321
+ return sentinel_ids
322
+
323
+ def filter_input_ids(self, input_ids, sentinel_ids):
324
+ """
325
+ Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting.
326
+ This will reduce the sequence length from `expanded_inputs_length` to `input_length`.
327
+ """
328
+ batch_size = input_ids.shape[0]
329
+
330
+ input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
331
+ input_ids = input_ids_full[input_ids_full > 0].reshape((batch_size, -1))
332
+ input_ids = np.concatenate(
333
+ [input_ids, np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32)], axis=-1
334
+ )
335
+ return input_ids
336
+
337
+ def random_spans_noise_mask(self, length):
338
+
339
+ """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
340
+
341
+ Noise mask consisting of random spans of noise tokens.
342
+ The number of noise tokens and the number of noise spans and non-noise spans
343
+ are determined deterministically as follows:
344
+ num_noise_tokens = round(length * noise_density)
345
+ num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
346
+ Spans alternate between non-noise and noise, beginning with non-noise.
347
+ Subject to the above restrictions, all masks are equally likely.
348
+
349
+ Args:
350
+ length: an int32 scalar (length of the incoming token sequence)
351
+ noise_density: a float - approximate density of output mask
352
+ mean_noise_span_length: a number
353
+
354
+ Returns:
355
+ a boolean tensor with shape [length]
356
+ """
357
+
358
+ orig_length = length
359
+
360
+ num_noise_tokens = int(np.round(length * self.noise_density))
361
+ # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
362
+ num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
363
+ num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))
364
+
365
+ # avoid degeneracy by ensuring positive number of noise spans
366
+ num_noise_spans = max(num_noise_spans, 1)
367
+ num_nonnoise_tokens = length - num_noise_tokens
368
+
369
+ # pick the lengths of the noise spans and the non-noise spans
370
+ def _random_segmentation(num_items, num_segments):
371
+ """Partition a sequence of items randomly into non-empty segments.
372
+ Args:
373
+ num_items: an integer scalar > 0
374
+ num_segments: an integer scalar in [1, num_items]
375
+ Returns:
376
+ a Tensor with shape [num_segments] containing positive integers that add
377
+ up to num_items
378
+ """
379
+ mask_indices = np.arange(num_items - 1) < (num_segments - 1)
380
+ np.random.shuffle(mask_indices)
381
+ first_in_segment = np.pad(mask_indices, [[1, 0]])
382
+ segment_id = np.cumsum(first_in_segment)
383
+ # count length of sub segments assuming that list is sorted
384
+ _, segment_length = np.unique(segment_id, return_counts=True)
385
+ return segment_length
386
+
387
+ noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
388
+ nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)
389
+
390
+ interleaved_span_lengths = np.reshape(
391
+ np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2]
392
+ )
393
+ span_starts = np.cumsum(interleaved_span_lengths)[:-1]
394
+ span_start_indicator = np.zeros((length,), dtype=np.int8)
395
+ span_start_indicator[span_starts] = True
396
+ span_num = np.cumsum(span_start_indicator)
397
+ is_noise = np.equal(span_num % 2, 1)
398
+
399
+ return is_noise[:orig_length]
400
+
401
+
402
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
403
+ num_samples = len(samples_idx)
404
+ samples_to_remove = num_samples % batch_size
405
+
406
+ if samples_to_remove != 0:
407
+ samples_idx = samples_idx[:-samples_to_remove]
408
+ sections_split = num_samples // batch_size
409
+ batch_idx = np.split(samples_idx, sections_split)
410
+ return batch_idx
411
+
412
+
413
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
414
+ summary_writer.scalar("train_time", train_time, step)
415
+
416
+ train_metrics = get_metrics(train_metrics)
417
+ for key, vals in train_metrics.items():
418
+ tag = f"train_{key}"
419
+ for i, val in enumerate(vals):
420
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
421
+
422
+
423
+ def write_eval_metric(summary_writer, eval_metrics, step):
424
+ for metric_name, value in eval_metrics.items():
425
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
426
+
427
+
428
+ def mb_item(x):
429
+ return x.item() if hasattr(x, "item") else x
430
+
431
+
432
+ def save_checkpoint(model, save_dir, state, cur_step: int, with_opt: bool = True, push_to_hub: bool = False):
433
+ state = jax_utils.unreplicate(state)
434
+ if with_opt:
435
+ logger.info(f'Saving optimizer and training state in {save_dir}...')
436
+ with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
437
+ f.write(to_bytes(state.opt_state))
438
+ with open(os.path.join(save_dir, "training_state.json"), "w") as f:
439
+ json.dump({"step": state.step.item()}, f)
440
+ logger.info(f'Saving model in {save_dir} {"and pushing it to HF Hub" if push_to_hub else ""}')
441
+ model.save_pretrained(
442
+ save_dir,
443
+ params=state.params,
444
+ push_to_hub=push_to_hub,
445
+ commit_message=f"Saving weights and logs of step {cur_step}",
446
+ )
447
+
448
+ def restore_checkpoint(load_dir, state):
449
+ logger.info(f"Restoring checkpoint from {load_dir}")
450
+ with open(os.path.join(load_dir, "flax_model.msgpack"), "rb") as f:
451
+ params = from_bytes(state.params, f.read())
452
+ with open(os.path.join(load_dir, "opt_state.msgpack"), "rb") as f:
453
+ opt_state = from_bytes(state.opt_state, f.read())
454
+ with open(os.path.join(load_dir, "training_state.json"), "r") as f:
455
+ training_state = json.load(f)
456
+ step = training_state["step"]
457
+ logger.info(f"Checkpoint restored at step {step}")
458
+ return state.replace(step=step, params=params, opt_state=opt_state), step
459
+
460
+
461
+ if __name__ == "__main__":
462
+ # See all possible arguments in src/transformers/training_args.py
463
+ # or by passing the --help flag to this script.
464
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
465
+
466
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
467
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
468
+ # If we pass only one argument to the script and it's the path to a json file,
469
+ # let's parse it to get our arguments.
470
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
471
+ else:
472
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
473
+
474
+ if (
475
+ os.path.exists(training_args.output_dir)
476
+ and os.listdir(training_args.output_dir)
477
+ and training_args.do_train
478
+ and not training_args.overwrite_output_dir
479
+ ):
480
+ raise ValueError(
481
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
482
+ "Use --overwrite_output_dir to overcome."
483
+ )
484
+
485
+ # Setup logging
486
+ logging.basicConfig(
487
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
488
+ level="NOTSET",
489
+ datefmt="[%X]",
490
+ )
491
+
492
+ # Log on each process the small summary:
493
+ logger = logging.getLogger(__name__)
494
+
495
+ # Set the verbosity to info of the Transformers logger (on main process only):
496
+ logger.info(f"Training/evaluation parameters {training_args}")
497
+
498
+ # Set seed before initializing model.
499
+ set_seed(training_args.seed)
500
+
501
+ # Handle the repository creation
502
+ # if training_args.push_to_hub:
503
+ # if training_args.hub_model_id is None:
504
+ # repo_name = get_full_repo_name(
505
+ # Path(training_args.output_dir).absolute().name, token=training_args.hub_token
506
+ # )
507
+ # else:
508
+ # repo_name = training_args.hub_model_id
509
+ # repo = Repository(training_args.output_dir, clone_from=repo_name)
510
+
511
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
512
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
513
+ # (the dataset will be downloaded automatically from the datasets Hub).
514
+ #
515
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
516
+ # 'text' is found. You can easily tweak this behavior (see below).
517
+ if data_args.dataset_name is not None:
518
+ # Downloading and loading a dataset from the hub.
519
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
520
+
521
+ if "validation" not in datasets.keys():
522
+ datasets["validation"] = load_dataset(
523
+ data_args.dataset_name,
524
+ data_args.dataset_config_name,
525
+ split=f"train[:{data_args.validation_split_count}]",
526
+ cache_dir=model_args.cache_dir,
527
+ )
528
+ datasets["train"] = load_dataset(
529
+ data_args.dataset_name,
530
+ data_args.dataset_config_name,
531
+ split=f"train[{data_args.validation_split_count}:]",
532
+ cache_dir=model_args.cache_dir,
533
+ )
534
+ else:
535
+ datasets["validation"] = load_dataset(
536
+ data_args.dataset_name,
537
+ data_args.dataset_config_name,
538
+ split=f"validation[:{data_args.validation_split_count}]",
539
+ cache_dir=model_args.cache_dir,
540
+ )
541
+ datasets["train"] = load_dataset(
542
+ data_args.dataset_name,
543
+ data_args.dataset_config_name,
544
+ split="train",
545
+ cache_dir=model_args.cache_dir,
546
+ )
547
+ else:
548
+ data_files = {}
549
+ if data_args.train_file is not None:
550
+ data_files["train"] = data_args.train_file
551
+ if data_args.validation_file is not None:
552
+ data_files["validation"] = data_args.validation_file
553
+ extension = data_args.train_file.split(".")[-1]
554
+ if extension == "txt":
555
+ extension = "text"
556
+ datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
557
+
558
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
559
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
560
+
561
+ # Load pretrained model and tokenizer
562
+
563
+ if model_args.tokenizer_name:
564
+ tokenizer = AutoTokenizer.from_pretrained(
565
+ model_args.tokenizer_name,
566
+ cache_dir=model_args.cache_dir,
567
+ use_fast=model_args.use_fast_tokenizer,
568
+ use_auth_token=model_args.auth_token
569
+ )
570
+ elif model_args.model_name_or_path:
571
+ tokenizer = AutoTokenizer.from_pretrained(
572
+ model_args.model_name_or_path,
573
+ cache_dir=model_args.cache_dir,
574
+ use_fast=model_args.use_fast_tokenizer,
575
+ use_auth_token=model_args.auth_token
576
+ )
577
+ else:
578
+ raise ValueError(
579
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
580
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
581
+ )
582
+
583
+ if model_args.config_name:
584
+ config = T5Config.from_pretrained(
585
+ model_args.config_name, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
586
+ )
587
+ elif model_args.model_name_or_path:
588
+ config = T5Config.from_pretrained(
589
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
590
+ )
591
+ else:
592
+ config = CONFIG_MAPPING[model_args.model_type]()
593
+ logger.warning("You are instantiating a new config instance from scratch.")
594
+
595
+ # Preprocessing the datasets.
596
+ # First we tokenize all the texts.
597
+ if training_args.do_train:
598
+ column_names = datasets["train"].column_names
599
+ else:
600
+ column_names = datasets["validation"].column_names
601
+ text_column_name = "text" if "text" in column_names else column_names[0]
602
+
603
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
604
+
605
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
606
+ # Since we make sure that all sequences are of the same length, no attention_mask is needed.
607
+ def tokenize_function(examples):
608
+ return tokenizer(examples[text_column_name], return_attention_mask=False)
609
+
610
+ logger.info(f"Start tokenization, remove_column_names = {column_names}")
611
+ tokenized_datasets = datasets.map(
612
+ tokenize_function,
613
+ batched=True,
614
+ num_proc=data_args.preprocessing_num_workers,
615
+ remove_columns=column_names,
616
+ load_from_cache_file=not data_args.overwrite_cache,
617
+ )
618
+
619
+ # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
620
+ # To ensure that the input length is `max_seq_length`, we need to increase the maximum length
621
+ # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
622
+ expanded_inputs_length, targets_length = compute_input_and_target_lengths(
623
+ inputs_length=max_seq_length,
624
+ noise_density=data_args.mlm_probability,
625
+ mean_noise_span_length=data_args.mean_noise_span_length,
626
+ )
627
+ logger.info(f"Max seq length: {max_seq_length}, expanded_inputs_length: {expanded_inputs_length}, targets_length: {targets_length}")
628
+
629
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
630
+ def group_texts(examples):
631
+ # Concatenate all texts.
632
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
633
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
634
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
635
+ # customize this part to your needs.
636
+ if total_length >= expanded_inputs_length:
637
+ total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
638
+ # Split by chunks of max_len.
639
+ result = {
640
+ k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
641
+ for k, t in concatenated_examples.items()
642
+ }
643
+ return result
644
+
645
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
646
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
647
+ # might be slower to preprocess.
648
+ #
649
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
650
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
651
+ logger.info(f"Start group_texts")
652
+ tokenized_datasets = tokenized_datasets.map(
653
+ group_texts,
654
+ batched=True,
655
+ batch_size=200,
656
+ num_proc=data_args.preprocessing_num_workers,
657
+ load_from_cache_file=not data_args.overwrite_cache,
658
+ )
659
+
660
+ # Enable tensorboard only on the master node
661
+ has_tensorboard = is_tensorboard_available()
662
+ if has_tensorboard and jax.process_index() == 0:
663
+ try:
664
+ from flax.metrics.tensorboard import SummaryWriter
665
+
666
+ summary_writer = SummaryWriter(log_dir=Path(training_args.logging_dir))
667
+ except ImportError as ie:
668
+ has_tensorboard = False
669
+ logger.warning(
670
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
671
+ )
672
+ else:
673
+ logger.warning(
674
+ "Unable to display metrics through TensorBoard because the package is not installed: "
675
+ "Please run pip install tensorboard to enable."
676
+ )
677
+
678
+ # Initialize our training
679
+ rng = jax.random.PRNGKey(training_args.seed)
680
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
681
+
682
+ if model_args.model_name_or_path:
683
+ model = FlaxT5ForConditionalGeneration.from_pretrained(
684
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
685
+ )
686
+ else:
687
+ config.vocab_size = len(tokenizer)
688
+ model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
689
+
690
+ # Data collator
691
+ # This one will take care of randomly masking the tokens.
692
+ data_collator = FlaxDataCollatorForT5MLM(
693
+ tokenizer=tokenizer,
694
+ noise_density=data_args.mlm_probability,
695
+ mean_noise_span_length=data_args.mean_noise_span_length,
696
+ input_length=max_seq_length,
697
+ target_length=targets_length,
698
+ pad_token_id=model.config.pad_token_id,
699
+ decoder_start_token_id=model.config.decoder_start_token_id,
700
+ )
701
+
702
+ # Store some constant
703
+ num_epochs = int(training_args.num_train_epochs)
704
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
705
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
706
+
707
+ steps_per_epoch = len(tokenized_datasets['train']) // train_batch_size
708
+ num_train_steps = steps_per_epoch * num_epochs
709
+
710
+ # Create learning rate schedule
711
+ if training_args.warmup_steps:
712
+ warmup_steps = training_args.warmup_steps
713
+ elif training_args.warmup_ratio:
714
+ # See https://arxiv.org/pdf/2104.07705.pdf for rationale of choosing the peak at % of training steps
715
+ warmup_steps = int(training_args.warmup_ratio * num_train_steps)
716
+ logging.info(f"Warmup steps set to {100*training_args.warmup_ratio}% = {warmup_steps} of total train steps {num_train_steps}")
717
+ else:
718
+ raise Exception("Need either --warmup_steps or --warmup_ratio")
719
+ warmup_fn = optax.linear_schedule(
720
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=warmup_steps
721
+ )
722
+ decay_fn = optax.linear_schedule(
723
+ init_value=training_args.learning_rate,
724
+ end_value=0,
725
+ transition_steps=num_train_steps - warmup_steps,
726
+ )
727
+ linear_decay_lr_schedule_fn = optax.join_schedules(
728
+ schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps]
729
+ )
730
+
731
+ # We use Optax's "masking" functionality to not apply weight decay
732
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
733
+ # mask boolean with the same structure as the parameters.
734
+ # The mask is True for parameters that should be decayed.
735
+ def decay_mask_fn(params):
736
+ flat_params = traverse_util.flatten_dict(params)
737
+ flat_mask = {
738
+ path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
739
+ for path in flat_params
740
+ }
741
+ return traverse_util.unflatten_dict(flat_mask)
742
+
743
+ # create adam optimizer
744
+ if training_args.adafactor:
745
+ # We use the default parameters here to initialize adafactor,
746
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
747
+ optimizer = optax.adafactor(
748
+ learning_rate=linear_decay_lr_schedule_fn,
749
+ )
750
+ else:
751
+ optimizer = optax.adamw(
752
+ learning_rate=linear_decay_lr_schedule_fn,
753
+ b1=training_args.adam_beta1,
754
+ b2=training_args.adam_beta2,
755
+ weight_decay=training_args.weight_decay,
756
+ mask=decay_mask_fn,
757
+ )
758
+
759
+ if training_args.gradient_accumulation_steps > 1:
760
+ optimizer = optax.MultiSteps(optimizer, training_args.gradient_accumulation_steps)
761
+ grad_accum_steps = training_args.gradient_accumulation_steps
762
+
763
+ # Setup train state
764
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
765
+
766
+ if training_args.resume_from_checkpoint:
767
+ state, resume_step = restore_checkpoint(training_args.resume_from_checkpoint, state)
768
+ else:
769
+ resume_step = 0
770
+
771
+ # Define gradient update step fn
772
+ def train_step(state, batch, dropout_rng):
773
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
774
+
775
+ def loss_fn(params):
776
+ labels = batch.pop("labels")
777
+
778
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
779
+
780
+ # compute loss
781
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
782
+
783
+ return loss
784
+
785
+ grad_fn = jax.value_and_grad(loss_fn)
786
+ loss, grad = grad_fn(state.params)
787
+ grad = jax.lax.pmean(grad, "batch")
788
+ new_state = state.apply_gradients(grads=grad)
789
+
790
+ metrics = jax.lax.pmean(
791
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step // grad_accum_steps)},
792
+ axis_name="batch"
793
+ )
794
+
795
+ return new_state, metrics, new_dropout_rng
796
+
797
+ # Create parallel version of the train step
798
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
799
+
800
+ # Define eval fn
801
+ def eval_step(params, batch):
802
+ labels = batch.pop("labels")
803
+
804
+ logits = model(**batch, params=params, train=False)[0]
805
+
806
+ # compute loss
807
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
808
+
809
+ # compute accuracy
810
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)
811
+
812
+ # summarize metrics
813
+ metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()}
814
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
815
+
816
+ return metrics
817
+
818
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
819
+
820
+ logger.info("Replicate the train state on each device")
821
+
822
+ # import pydevd_pycharm
823
+ #
824
+ # pydevd_pycharm.settrace('localhost', port=12345, stdoutToServer=True, stderrToServer=True)
825
+
826
+ # Replicate the train state on each device
827
+ state = jax_utils.replicate(state)
828
+
829
+ logger.info("***** Running training *****")
830
+ logger.info(f" Num examples = {len(datasets['train'])}")
831
+ logger.info(f" Num tokenized group examples {len(tokenized_datasets['train'])}")
832
+ logger.info(f" Num Epochs = {num_epochs}")
833
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
834
+ logger.info(f" Total train batch size (w. parallel, distributed and grad_accum) = {train_batch_size}")
835
+ logger.info(f" Total optimization steps = {num_train_steps}")
836
+
837
+ train_time = 0
838
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
839
+ for epoch in epochs:
840
+ # ======================== Training ================================
841
+ train_start = time.time()
842
+ train_metrics = []
843
+
844
+ # Create sampling rng
845
+ rng, input_rng = jax.random.split(rng)
846
+
847
+ # Generate an epoch by shuffling sampling indices from the train dataset
848
+ num_train_samples = len(tokenized_datasets["train"])
849
+ # train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
850
+ # train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
851
+
852
+ ## IF THE DATASET IS TOO LONG, WE ONLY PROCEED SEQUENTIALLY WITHOUT SHUFFLING
853
+ samples_to_remove = num_train_samples % (train_batch_size // grad_accum_steps)
854
+ samples_idx = np.arange(num_train_samples)
855
+ if samples_to_remove != 0:
856
+ samples_idx = samples_idx[:-samples_to_remove]
857
+ steps = num_train_samples // (train_batch_size // grad_accum_steps)
858
+
859
+ # Gather the indexes for creating the batch and do a training step
860
+ # for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
861
+ # samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
862
+ for step in tqdm(range(steps), desc="Training...", position=1):
863
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
864
+ # skip to the step from which we are resuming
865
+ if cur_step < resume_step:
866
+ continue
867
+
868
+ batch_idx = [x for x in range(step * train_batch_size, (step + 1) * train_batch_size)]
869
+ samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
870
+ try:
871
+ model_inputs = data_collator(samples)
872
+ except ValueError as e:
873
+ logger.warning(str(e))
874
+ logger.info(f"Continuing with the next batch")
875
+ continue
876
+
877
+ # Model forward
878
+ model_inputs = shard(model_inputs.data)
879
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
880
+ train_metrics.append(train_metric)
881
+
882
+ if cur_step % training_args.logging_steps * grad_accum_steps == 0 and cur_step > 0:
883
+ # Save metrics
884
+ train_metric = jax_utils.unreplicate(train_metric)
885
+ train_time += time.time() - train_start
886
+ if has_tensorboard and jax.process_index() == 0:
887
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
888
+
889
+ epochs.write(
890
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
891
+ )
892
+
893
+ train_metrics = []
894
+
895
+ if cur_step % training_args.eval_steps * grad_accum_steps == 0 and cur_step > 0:
896
+ # ======================== Evaluating ==============================
897
+ num_eval_samples = len(tokenized_datasets["validation"])
898
+ eval_samples_idx = jnp.arange(num_eval_samples)
899
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
900
+
901
+ eval_metrics = []
902
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
903
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
904
+ model_inputs = data_collator(samples)
905
+
906
+ # Model forward
907
+ model_inputs = shard(model_inputs.data)
908
+ metrics = p_eval_step(state.params, model_inputs)
909
+ eval_metrics.append(metrics)
910
+
911
+ # get eval metrics
912
+ eval_metrics = get_metrics(eval_metrics)
913
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
914
+
915
+ # Update progress bar
916
+ epochs.write(f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})")
917
+
918
+ # Save metrics
919
+ if has_tensorboard and jax.process_index() == 0:
920
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
921
+
922
+ if cur_step % training_args.save_steps * grad_accum_steps == 0 and cur_step > 0:
923
+ # save checkpoint after each epoch and push checkpoint to the hub
924
+ if jax.process_index() == 0:
925
+ # params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
926
+ # model.save_pretrained(training_args.output_dir, params=params)
927
+ # tokenizer.save_pretrained(training_args.output_dir)
928
+ # if training_args.push_to_hub:
929
+ # repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
930
+ save_checkpoint(model, training_args.output_dir, state, cur_step, with_opt=False, push_to_hub=True)
931
+
932
+ # Eval after training
933
+ if training_args.do_eval:
934
+ num_eval_samples = len(tokenized_datasets["validation"])
935
+ eval_samples_idx = jnp.arange(num_eval_samples)
936
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
937
+
938
+ eval_metrics = []
939
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
940
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
941
+ model_inputs = data_collator(samples)
942
+
943
+ # Model forward
944
+ model_inputs = shard(model_inputs.data)
945
+ metrics = p_eval_step(state.params, model_inputs)
946
+ eval_metrics.append(metrics)
947
+
948
+ # get eval metrics
949
+ eval_metrics = get_metrics(eval_metrics)
950
+ eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics)
951
+
952
+ if jax.process_index() == 0:
953
+ eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
954
+ path = os.path.join(training_args.output_dir, "eval_results.json")
955
+ with open(path, "w") as f:
956
+ json.dump(eval_metrics, f, indent=4, sort_keys=True)
957
+
958
+ # Save model at end
959
+ if jax.process_index() == 0:
960
+ # params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
961
+ # model.save_pretrained(training_args.output_dir, params=params)
962
+ # tokenizer.save_pretrained(training_args.output_dir)
963
+ # if training_args.push_to_hub:
964
+ # repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
965
+ #
966
+ save_checkpoint(model, training_args.output_dir, state, cur_step, with_opt=False, push_to_hub=True)