yhavinga commited on
Commit
21e9f42
1 Parent(s): f2c9d90

Add scripts, model config and vocabulary

Browse files
README.md ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: nl
3
+ widget:
4
+ - text: "In het jaar 2030 zullen we"
5
+ - text: "Toen ik gisteren volledig in de ban was van"
6
+ - text: "Studenten en leraren van de Bogazici Universiteit in de Turkse stad Istanbul"
7
+ - text: "In Israël was een strenge lockdown"
8
+ tags:
9
+ - gpt2-medium
10
+ - gpt2
11
+ pipeline_tag: text-generation
12
+ datasets:
13
+ - yhavinga/mc4_nl_cleaned
14
+ ---
15
+ # GPT2-Medium pre-trained on cleaned Dutch mC4 🇳🇱
16
+
17
+ Dataset:
18
+
19
+ * [mC4 NL Cleaned](https://huggingface.co/datasets/yhavinga/mc4_nl_cleaned)
20
+ * dataset config: full (33B tokens)
21
+
22
+ Tokenizer:
23
+
24
+ * Tokenizer trained on mC4 with scripts from the Huggingface
25
+ Transformers [Flax examples](https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling)
26
+
27
+ Training details:
28
+
29
+ * Trained for 280k steps (30 dec 2021)
30
+ * Block size: 512
31
+ * Optimizer: adam, lr 8e-4, beta1 0.9, beta2 0.98
32
+ * Warmup steps: 5000
33
+ * Weight decay: 0.01
34
+
35
+ Work in progress. Dec 2021 - Jan 2022
36
+
37
+ * Many thanks to the [Google TPU Research Cloud](https://sites.research.google/trc/about/) for providing access to a TPU cluster!
38
+ * Thanks to @gsarti for creating the [t5-flax-gcp
39
+ repository](https://github.com/gsarti/t5-flax-gcp).
40
+ * Also thanks to the creators of [gpt2-medium-persian](https://huggingface.co/flax-community/gpt2-medium-persian) and
41
+ [gpt2-medium-indonesian](https://huggingface.co/flax-community/gpt2-medium-indonesian)
42
+ for sharing their training scripts!
added_tokens.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"<|endoftext|>": 50256}
commit.sh ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Poll the working tree once a minute and, whenever `git status` reports a
# modified flax_model file, commit the weights together with the TensorBoard
# logs under runs/.
source ~/venv/bin/activate

while true
do
    echo -n "Checking at .. "
    date
    UPDATED=$(git status | grep flax_model | grep modified)

    if [ -n "$UPDATED" ]
    then
        # Presumably gives the training process time to finish writing the
        # checkpoint before it is staged -- TODO confirm.
        sleep 120
        # Name of the most recently modified TensorBoard event file under
        # runs/, then its path relative to the current directory.
        EVENT_NAME=$(ls -tR runs | grep events | head -n 1)
        # Quoted so that names containing spaces or glob characters are
        # passed to find/tensorboard as a single argument.
        FILE=$(find . -name "$EVENT_NAME" | tail -n 1)
        STEP=$(tensorboard --load_fast=true --inspect --event_file="$FILE" | grep last_step | awk '{print $2}')
        git add runs
        git add flax_model.msgpack
        git commit -m "Saving weights and logs step $STEP"
    fi

    sleep 60
done
config.json ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/home/yeb/data/gpt2-medium-dutch",
3
+ "activation_function": "gelu_new",
4
+ "architectures": [
5
+ "GPT2LMHeadModel"
6
+ ],
7
+ "attn_pdrop": 0.0,
8
+ "bos_token_id": 50256,
9
+ "embd_pdrop": 0.0,
10
+ "eos_token_id": 50256,
11
+ "initializer_range": 0.02,
12
+ "layer_norm_epsilon": 1e-05,
13
+ "model_type": "gpt2",
14
+ "torch_dtype": "float32",
15
+ "n_ctx": 1024,
16
+ "n_embd": 1024,
17
+ "n_head": 16,
18
+ "n_inner": null,
19
+ "n_layer": 24,
20
+ "n_positions": 1024,
21
+ "n_special": 0,
22
+ "predict_special_tokens": true,
23
+ "reorder_and_upcast_attn": false,
24
+ "resid_pdrop": 0.0,
25
+ "scale_attn_by_inverse_layer_idx": false,
26
+ "scale_attn_weights": true,
27
+ "summary_activation": null,
28
+ "summary_first_dropout": 0.1,
29
+ "summary_proj_to_labels": true,
30
+ "summary_type": "cls_index",
31
+ "summary_use_proj": true,
32
+ "task_specific_params": {
33
+ "text-generation": {
34
+ "do_sample": true,
35
+ "max_length": 50
36
+ }
37
+ },
38
+ "transformers_version": "4.13.0",
39
+ "use_cache": true,
40
+ "vocab_size": 50257
41
+ }
flax_to_pytorch.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Convert the Flax checkpoint in the current directory to PyTorch.

The PyTorch weights are written back to the current directory, after which
both models are run on the same dummy batch so their logits can be compared.
"""
import torch
import numpy as np
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer
from transformers import FlaxGPT2LMHeadModel
from transformers import GPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained(".")
tokenizer.pad_token = tokenizer.eos_token

model_fx = FlaxGPT2LMHeadModel.from_pretrained(".")
# Optional recipe, kept for reference: cast bfloat16 parameters to float32
# and save the Flax model separately before converting.
# def to_f32(t):
#     return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
# model_fx.params = to_f32(model_fx.params)
# model_fx.save_pretrained("./fx")

model_pt = GPT2LMHeadModel.from_pretrained(".", from_flax=True)
model_pt.save_pretrained(".")

# Dummy batch: 2 sequences of 128 tokens, all token id 0.
input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
input_ids_pt = torch.tensor(input_ids)

# Inference only: no autograd graph is needed for the comparison.
with torch.no_grad():
    logits_pt = model_pt(input_ids_pt).logits
print(logits_pt)

logits_fx = model_fx(input_ids).logits
print(logits_fx)

# Report a single number instead of relying on eyeballing the two printouts.
max_abs_diff = float(
    np.max(np.abs(logits_pt.detach().numpy() - np.asarray(logits_fx, dtype=np.float32)))
)
print(f"max abs difference between PyTorch and Flax logits: {max_abs_diff}")
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
push.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Every three minutes, push to origin when `git rev-list origin..HEAD`
# lists any commits.
source ~/venv/bin/activate

while true
do
    echo -n "Checking at .. "
    date
    UNPUSHED=$(git rev-list origin..HEAD)

    if [ -n "$UNPUSHED" ]
    then
        git push origin
    fi

    sleep 180
done
replace_token_script.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
'''This script was used to replace the final index of tokenizer.json and vocab.json
with "<|endoftext|>" token. Also reassociate the corresponding merges'''

import json

tokenizer_path = 'tokenizer.json'
model_config_path = 'config.json'
vocab_path = 'vocab.json'

# The files are read and written as UTF-8 explicitly so the result does not
# depend on the platform's default locale encoding.
with open(vocab_path, "r", encoding="utf-8") as f:
    vocab_data = json.load(f)

with open(tokenizer_path, "r", encoding="utf-8") as f:
    tokenizer_data = json.load(f)

with open(model_config_path, "r", encoding="utf-8") as f:
    model_config = json.load(f)

model_vocab_size = model_config['vocab_size']
tokenizer_vocab = tokenizer_data['model']['vocab']

mergeslength = len(tokenizer_data['model']['merges'])

# readjust added_tokens 'id' to model_vocab_size - 1
tokenizer_data['added_tokens'][-1]['id'] = model_vocab_size - 1

final_index = model_vocab_size - 1
eos = '<|endoftext|>'

# retrieve the key of final index (relies on the JSON key order)
old_key_final_index_tokenizer = list(tokenizer_vocab.keys())[final_index]
old_key_final_index_vocab = list(vocab_data.keys())[final_index]
old_key_final_index_vocab_min2 = list(vocab_data.keys())[final_index - 1]
old_key_final_index_tokenizer_merges = tokenizer_data['model']['merges'][mergeslength - 1]

print(f"old_key_final_index_tokenizer = {old_key_final_index_tokenizer}")
print(f"old_key_final_index_vocab = {old_key_final_index_vocab}")
print(f"old_key_final_index_vocab_min2 = {old_key_final_index_vocab_min2}")
print(f"old_key_final_index_tokenizer_merges = {old_key_final_index_tokenizer_merges}")

# If the final slot already holds the eos token in both vocabularies there is
# nothing to replace. Continuing would delete the eos entry (it would be
# assigned to itself and then removed) and drop one more merge.
if old_key_final_index_tokenizer == eos and old_key_final_index_vocab == eos:
    raise SystemExit(f"{eos} already occupies index {final_index}; nothing to do")

# replace old key with new key, then delete the old key; skipped for a
# vocabulary whose final key already is the eos token
if old_key_final_index_tokenizer != eos:
    tokenizer_vocab[eos] = tokenizer_vocab[old_key_final_index_tokenizer]
    del tokenizer_vocab[old_key_final_index_tokenizer]
if old_key_final_index_vocab != eos:
    vocab_data[eos] = vocab_data[old_key_final_index_vocab]
    del vocab_data[old_key_final_index_vocab]

# drop the final merge
tokenizer_data['model']['merges'] = tokenizer_data['model']['merges'][: mergeslength - 1]

# check updated key
old_key_final_index_tokenizer = list(tokenizer_vocab.keys())[final_index]
old_key_final_index_vocab = list(vocab_data.keys())[final_index]
old_key_final_index_tokenizer_merges = tokenizer_data['model']['merges'][mergeslength - 2]

print(len(tokenizer_data['model']['merges']))
print()
print(f"updated old_key_final_index_tokenizer = {old_key_final_index_tokenizer}")
print(f"updated old_key_final_index_vocab = {old_key_final_index_vocab}")
print(f"updated old_key_final_index_tokenizer_merges = {old_key_final_index_tokenizer_merges}")

with open(tokenizer_path, "w", encoding="utf-8") as f:
    json.dump(tokenizer_data, f)

with open(vocab_path, "w", encoding="utf-8") as f:
    json.dump(vocab_data, f)

# Rewrite merges.txt without its last line.
with open('merges.txt', encoding="utf-8") as f:
    lines = f.readlines()

with open("merges.txt", "w", encoding="utf-8") as f:
    f.writelines(lines[:-1])

with open('merges.txt', encoding="utf-8") as f:
    newlines = f.readlines()

print(f"newlines[len(newlines) - 1] = {newlines[-1]}")
run_clm_flax.py ADDED
@@ -0,0 +1,889 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Pre-training/Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
18
+
19
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
20
+ https://huggingface.co/models?filter=text-generation
21
+ """
22
+ # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
23
+
24
+ import json
25
+ import logging
26
+ import math
27
+ import os
28
+ import sys
29
+ import time
30
+ from dataclasses import asdict, dataclass, field
31
+ from enum import Enum
32
+ from itertools import chain
33
+ from pathlib import Path
34
+ from typing import Callable, Optional
35
+ import json
36
+ import shutil
37
+
38
+ import datasets
39
+ import numpy as np
40
+ from datasets import Dataset, load_dataset
41
+ from tqdm import tqdm
42
+
43
+ import jax
44
+ import jax.numpy as jnp
45
+ import optax
46
+ import transformers
47
+ from flax import jax_utils, traverse_util
48
+ from flax.jax_utils import unreplicate
49
+ from flax.training import train_state
50
+ from flax.training.checkpoints import save_checkpoint, restore_checkpoint
51
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
52
+ from flax.serialization import to_bytes, from_bytes
53
+ from transformers import (
54
+ CONFIG_MAPPING,
55
+ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
56
+ AutoConfig,
57
+ AutoTokenizer,
58
+ FlaxAutoModelForCausalLM,
59
+ HfArgumentParser,
60
+ is_tensorboard_available,
61
+ set_seed,
62
+ )
63
+ from transformers.file_utils import get_full_repo_name
64
+ from transformers.testing_utils import CaptureLogger
65
+
66
+
67
+ logger = logging.getLogger(__name__)
68
+
69
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
70
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
71
+
72
+
73
@dataclass
class TrainingArguments:
    """
    Command-line arguments controlling the training run (output location, batch sizes, optimizer
    hyper-parameters, logging/saving/evaluation cadence and hub upload settings).
    """

    output_dir: str = field(
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
    per_device_train_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
    )
    per_device_eval_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
    )
    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
    adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
    adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
    adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
    adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
    num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
    warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
    logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
    save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
    # The three fields below default to None, so they are annotated Optional.
    eval_steps: Optional[int] = field(default=None, metadata={"help": "Run an evaluation every X steps."})
    seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
    push_to_hub: bool = field(
        default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
    )
    hub_model_id: Optional[str] = field(
        default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
    )
    hub_token: Optional[str] = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})

    def __post_init__(self):
        # Expand a leading "~" so later path handling sees a real directory.
        if self.output_dir is not None:
            self.output_dir = os.path.expanduser(self.output_dir)

    def to_dict(self):
        """
        Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
        the token values by removing their value.
        """
        d = asdict(self)
        # Only existing keys are reassigned, so mutating while iterating is safe.
        for k, v in d.items():
            if isinstance(v, Enum):
                d[k] = v.value
            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
                d[k] = [x.value for x in v]
            if k.endswith("_token"):
                # Every "*_token" entry is replaced by a placeholder, set or not.
                d[k] = f"<{k.upper()}>"
        return d
133
+
134
+
135
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            # A separating space was missing between the two sentences of this help text.
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
        },
    )
171
+
172
+
173
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    # Declared once only; the original repeated this field further down with an identical definition.
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    block_size: Optional[int] = field(
        default=None,
        metadata={
            "help": "Optional input sequence length after tokenization. "
            "The training dataset will be truncated in block of this size for training. "
            "Default to the model max input length for single sentence inputs (take into account special tokens)."
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    keep_linebreaks: bool = field(
        default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
    )

    def __post_init__(self):
        """
        Validate that a data source was given and that any supplied file has a supported extension.

        Raises:
            ValueError: if neither a dataset name nor a training/validation file is provided.
            AssertionError: if `train_file` or `validation_file` is not a csv, json or txt file.
        """
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        else:
            # The asserts are kept (rather than raising ValueError) so the exception type is unchanged.
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
242
+
243
+
244
class TrainState(train_state.TrainState):
    """Train state that additionally carries a dropout PRNG key."""

    # PRNG key used for dropout.
    dropout_rng: jnp.ndarray

    def replicate(self):
        """Return a replicated copy of this state whose `dropout_rng` is sharded via `shard_prng_key`."""
        replicated_state = jax_utils.replicate(self)
        sharded_rng = shard_prng_key(self.dropout_rng)
        return replicated_state.replace(dropout_rng=sharded_rng)
249
+
250
+
251
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
    """
    Yield batches of `batch_size` examples from `dataset` as dicts of numpy arrays.

    Examples that do not fill a complete batch are dropped. When `shuffle` is `True` the example
    order is a random permutation drawn from `rng`; otherwise the dataset order is kept.
    """
    num_examples = len(dataset)
    steps_per_epoch = num_examples // batch_size

    order = jax.random.permutation(rng, num_examples) if shuffle else jnp.arange(num_examples)

    # Drop the trailing incomplete batch, then lay the indices out one row per batch.
    order = order[: steps_per_epoch * batch_size].reshape((steps_per_epoch, batch_size))

    for batch_indices in order:
        examples = dataset[batch_indices]
        yield {column: np.array(values) for column, values in examples.items()}
271
+
272
+
273
def write_train_metric(summary_writer, train_metrics, train_time, step):
    """
    Write the elapsed training time and the accumulated training metrics to `summary_writer`.

    The metric values are written so that the last one lands on `step` and earlier ones on the
    steps immediately before it.
    """
    summary_writer.scalar("train_time", train_time, step)

    stacked_metrics = get_metrics(train_metrics)
    for name, values in stacked_metrics.items():
        tag = f"train_{name}"
        # Step at which the first of the `len(values)` entries is recorded.
        first_step = step - len(values)
        for offset, value in enumerate(values):
            summary_writer.scalar(tag, value, first_step + offset + 1)
281
+
282
+
283
def write_eval_metric(summary_writer, eval_metrics, step):
    """Write every entry of `eval_metrics` to `summary_writer` at `step`, tagged `eval_<name>`."""
    for name in eval_metrics:
        tag = f"eval_{name}"
        summary_writer.scalar(tag, eval_metrics[name], step)
286
+
287
+
288
def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
    """
    Build a learning-rate schedule: linear warmup from 0 to `learning_rate` over `num_warmup_steps`,
    followed by a linear decay to 0 over the remaining training steps.
    """
    total_steps = (train_ds_size // train_batch_size) * num_train_epochs
    decay_steps = total_steps - num_warmup_steps

    warmup = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=decay_steps)

    # Switch from the warmup schedule to the decay schedule once warmup is over.
    return optax.join_schedules(schedules=[warmup, decay], boundaries=[num_warmup_steps])
300
+
301
+
302
+ # utils
303
def mb_item(x):
    """Return `x.item()` when `x` has an `item` attribute, otherwise return `x` unchanged."""
    if hasattr(x, "item"):
        return x.item()
    return x
305
+
306
+
307
+ # checkpoint functions
308
def save_model_checkpoint(model, save_dir, state, with_opt: bool = True, push_to_hub: bool = False):
    """
    Save the model parameters and, optionally, the optimizer state and step counter.

    If `push_to_hub` is True, will save to `save_dir`. Otherwise will save to `save_dir/ckpt-{step}`,
    where `{step}` is the state's step minus one.

    Args:
        model: object exposing `save_pretrained(save_dir, params=..., push_to_hub=..., commit_message=...)`.
        save_dir: base directory for the checkpoint.
        state: replicated train state; it is unreplicated before saving.
        with_opt: when True, also write `opt_state.msgpack` and `training_state.json`.
        push_to_hub: forwarded to `model.save_pretrained`.
    """
    state = jax_utils.unreplicate(state)
    # Convert the step once, through the same helper everywhere, so a plain int step works too.
    step = mb_item(state.step)
    if not push_to_hub:
        save_dir = f"{save_dir}/ckpt-{step - 1}"
    # Logged after the directory is resolved so the message names the real target.
    logger.info(f"SAVING CHECKPOINT IN {save_dir}...")
    model.save_pretrained(
        save_dir,
        params=state.params,
        push_to_hub=push_to_hub,
        commit_message=f"Saving weights and logs at step {step - 1}",
    )
    if with_opt:
        with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
            f.write(to_bytes(state.opt_state))
        # NOTE: the JSON stores the raw step, while the directory name uses step - 1.
        with open(os.path.join(save_dir, "training_state.json"), "w") as f:
            json.dump({"step": step}, f)
    logger.info("checkpoint saved")
328
+
329
+
330
+ # this is added to make resuming from checkpoint to work with adafactor
331
+ # to be removed when issue is fixed
332
+ # notice that adafactor state is perturbed by fake_update
333
+ def _zeros_tree_like(inp_tree):
334
+ return jax.tree_map(jnp.zeros_like, inp_tree)
335
+
336
+
337
def fake_update(state):
    """
    Run the inner optimizer once with all-zero updates and return `state` with the resulting
    inner optimizer state installed; the other `MultiStepsState` fields are carried over.
    """
    current = state.opt_state
    zero_updates = _zeros_tree_like(state.params)
    # Only the new inner optimizer state is kept; the returned updates are discarded.
    _, refreshed_inner = state.tx.inner_opt.update(zero_updates, current.inner_opt_state, state.params)
    rebuilt = optax.MultiStepsState(
        mini_step=current.mini_step,
        gradient_step=current.gradient_step,
        inner_opt_state=refreshed_inner,
        acc_grads=current.acc_grads,
    )
    return state.replace(opt_state=rebuilt)
346
+
347
+
348
def reinstantiate_states(opt_state):
    """
    Rebuild every entry of `opt_state` as an instance of the `optax` class with the same name.

    Nested lists are processed recursively; any other entry is reconstructed from its `_fields`.
    """

    def _rebuild(entry):
        if isinstance(entry, list):
            return reinstantiate_states(entry)
        state_cls = getattr(optax, type(entry).__name__)
        return state_cls(**{name: getattr(entry, name) for name in entry._fields})

    return [_rebuild(entry) for entry in opt_state]
357
+
358
+
359
def restore_model_checkpoint(save_dir, state):
    """
    Load parameters, optimizer state and step counter from `save_dir` into a copy of `state`.

    Reads `flax_model.msgpack`, `opt_state.msgpack` and `training_state.json` from `save_dir`,
    using `state.params` / `state.opt_state` as deserialization targets.
    """
    logger.info(f"RESTORING CHECKPOINT FROM {save_dir}...")

    params_path = os.path.join(save_dir, "flax_model.msgpack")
    with open(params_path, "rb") as params_file:
        params = from_bytes(state.params, params_file.read())

    opt_state_path = os.path.join(save_dir, "opt_state.msgpack")
    with open(opt_state_path, "rb") as opt_file:
        opt_state = from_bytes(state.opt_state, opt_file.read())

    training_state_path = os.path.join(save_dir, "training_state.json")
    with open(training_state_path, "r") as training_file:
        step = json.load(training_file)["step"]

    logger.info("checkpoint restored")

    if not hasattr(opt_state, "inner_opt_state"):
        return state.replace(step=step, params=params, opt_state=opt_state)

    # reinstantiate inner opt state to avoid type conflict
    print("restoring state of multisteps optimizer")
    inner_opt_state = reinstantiate_states(opt_state.inner_opt_state)
    # NOTE(review): the remaining MultiStepsState fields are read from the incoming
    # `state.opt_state`, not from the restored `opt_state` -- confirm this is intended.
    ms_fields = {name: getattr(state.opt_state, name) for name in state.opt_state._fields}
    ms_fields["inner_opt_state"] = inner_opt_state
    opt_state = optax.MultiStepsState(**ms_fields)

    return state.replace(step=step, params=params, opt_state=opt_state)
381
+
382
+
383
def rotate_checkpoints(ckpt_dir: str, save_total_limit: int):
    """
    Removes older checkpoints so that `save_total_limit` checkpoints are kept.

    Checkpoints are the `ckpt-<step>` directories directly under `ckpt_dir`; the ones with the
    lowest step numbers are deleted first. Entries whose suffix is not a number, or that are not
    directories, are left untouched. A `save_total_limit` of `None`, zero or a negative value
    means "no limit" and nothing is deleted.
    """
    # TODO: what to remove is decided using step number only, we might want to improve that
    if save_total_limit is None or save_total_limit <= 0:
        # `[:-0]` is an empty slice, so zero already deleted nothing; None and negative
        # values previously raised or removed the wrong checkpoints.
        return

    numbered = []
    for path in Path(ckpt_dir).glob("ckpt-*"):
        suffix = path.name.split("-")[-1]
        if suffix.isdigit() and path.is_dir():
            numbered.append((int(suffix), str(path)))

    # sort checkpoints by step
    numbered.sort(key=lambda item: item[0])
    for _, ckpt in numbered[:-save_total_limit]:
        logger.info(f"Deleting older checkpoint [{ckpt}] due to save_total_limit ({save_total_limit})")
        shutil.rmtree(ckpt)
393
+
394
+
395
+ def main():
396
+ # See all possible arguments in src/transformers/training_args.py
397
+ # or by passing the --help flag to this script.
398
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
399
+
400
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
401
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
402
+ # If we pass only one argument to the script and it's the path to a json file,
403
+ # let's parse it to get our arguments.
404
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
405
+ else:
406
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
407
+
408
+ if (
409
+ os.path.exists(training_args.output_dir)
410
+ and os.listdir(training_args.output_dir)
411
+ and training_args.do_train
412
+ and not training_args.overwrite_output_dir
413
+ ):
414
+ raise ValueError(
415
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
416
+ "Use --overwrite_output_dir to overcome."
417
+ )
418
+
419
+ # Make one log on every process with the configuration for debugging.
420
+ logging.basicConfig(
421
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
422
+ datefmt="%m/%d/%Y %H:%M:%S",
423
+ level=logging.INFO,
424
+ )
425
+ # Setup logging, we only want one process per machine to log things on the screen.
426
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
427
+ if jax.process_index() == 0:
428
+ datasets.utils.logging.set_verbosity_warning()
429
+ transformers.utils.logging.set_verbosity_info()
430
+ else:
431
+ datasets.utils.logging.set_verbosity_error()
432
+ transformers.utils.logging.set_verbosity_error()
433
+
434
+ # Set the verbosity to info of the Transformers logger (on main process only):
435
+ logger.info(f"Training/evaluation parameters {training_args}")
436
+
437
+ # Set seed before initializing model.
438
+ set_seed(training_args.seed)
439
+
440
+ # # Handle the repository creation
441
+ # if training_args.push_to_hub:
442
+ # if training_args.hub_model_id is None:
443
+ # repo_name = get_full_repo_name(
444
+ # Path(training_args.output_dir).absolute().name, token=training_args.hub_token
445
+ # )
446
+ # else:
447
+ # repo_name = training_args.hub_model_id
448
+ # repo = Repository(training_args.output_dir, clone_from=repo_name)
449
+
450
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
451
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
452
+ # (the dataset will be downloaded automatically from the datasets Hub).
453
+ #
454
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
455
+ # 'text' is found. You can easily tweak this behavior (see below).
456
+ #
457
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
458
+ # download the dataset.
459
+ if data_args.dataset_name is not None:
460
+ # Downloading and loading a dataset from the hub.
461
+ dataset = load_dataset(
462
+ data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
463
+ )
464
+
465
+ if "validation" not in dataset.keys():
466
+ dataset["validation"] = load_dataset(
467
+ data_args.dataset_name,
468
+ data_args.dataset_config_name,
469
+ split=f"train[:{data_args.validation_split_percentage}%]",
470
+ cache_dir=model_args.cache_dir,
471
+ )
472
+ dataset["train"] = load_dataset(
473
+ data_args.dataset_name,
474
+ data_args.dataset_config_name,
475
+ split=f"train[{data_args.validation_split_percentage}%:]",
476
+ cache_dir=model_args.cache_dir,
477
+ )
478
+ else:
479
+ data_files = {}
480
+ dataset_args = {}
481
+ if data_args.train_file is not None:
482
+ data_files["train"] = data_args.train_file
483
+ if data_args.validation_file is not None:
484
+ data_files["validation"] = data_args.validation_file
485
+ extension = data_args.train_file.split(".")[-1]
486
+ if extension == "txt":
487
+ extension = "text"
488
+ dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
489
+ dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir, **dataset_args)
490
+
491
+ if "validation" not in dataset.keys():
492
+ dataset["validation"] = load_dataset(
493
+ extension,
494
+ data_files=data_files,
495
+ split=f"train[:{data_args.validation_split_percentage}%]",
496
+ cache_dir=model_args.cache_dir,
497
+ **dataset_args,
498
+ )
499
+ dataset["train"] = load_dataset(
500
+ extension,
501
+ data_files=data_files,
502
+ split=f"train[{data_args.validation_split_percentage}%:]",
503
+ cache_dir=model_args.cache_dir,
504
+ **dataset_args,
505
+ )
506
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
507
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
508
+
509
+ # Load pretrained model and tokenizer
510
+
511
+ # Distributed training:
512
+ # The .from_pretrained methods guarantee that only one local process can concurrently
513
+ # download model & vocab.
514
+ if model_args.config_name:
515
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
516
+ elif model_args.model_name_or_path:
517
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
518
+ else:
519
+ config = CONFIG_MAPPING[model_args.model_type]()
520
+ logger.warning("You are instantiating a new config instance from scratch.")
521
+
522
+ if model_args.tokenizer_name:
523
+ tokenizer = AutoTokenizer.from_pretrained(
524
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
525
+ )
526
+ elif model_args.model_name_or_path:
527
+ tokenizer = AutoTokenizer.from_pretrained(
528
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
529
+ )
530
+ else:
531
+ raise ValueError(
532
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
533
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
534
+ )
535
+
536
+ if model_args.model_name_or_path:
537
+ model = FlaxAutoModelForCausalLM.from_pretrained(
538
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
539
+ )
540
+ else:
541
+ model = FlaxAutoModelForCausalLM.from_config(
542
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
543
+ )
544
+
545
+ # Preprocessing the datasets.
546
+ # First we tokenize all the texts.
547
+ if training_args.do_train:
548
+ column_names = dataset["train"].column_names
549
+ else:
550
+ column_names = dataset["validation"].column_names
551
+ text_column_name = "text" if "text" in column_names else column_names[0]
552
+
553
+ # tokenize_function will be pickled; to avoid a _LazyModule error in Hasher, force the logger to load before tokenize_function is defined
554
+ tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
555
+
556
+ def tokenize_function(examples):
557
+ with CaptureLogger(tok_logger) as cl:
558
+ output = tokenizer(examples[text_column_name])
559
+ # clm input could be much much longer than block_size
560
+ if "Token indices sequence length is longer than the" in cl.out:
561
+ tok_logger.warning(
562
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
563
+ )
564
+ return output
565
+
566
+ tokenized_datasets = dataset.map(
567
+ tokenize_function,
568
+ batched=True,
569
+ num_proc=data_args.preprocessing_num_workers,
570
+ remove_columns=column_names,
571
+ load_from_cache_file=not data_args.overwrite_cache,
572
+ )
573
+
574
+ if data_args.block_size is None:
575
+ block_size = tokenizer.model_max_length
576
+ if block_size > config.max_position_embeddings:
577
+ logger.warning(
578
+ f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
579
+ "Picking 1024 instead. You can change that default value by passing --block_size xxx."
580
+ )
581
+ block_size = 1024
582
+ else:
583
+ if data_args.block_size > tokenizer.model_max_length:
584
+ logger.warning(
585
+ f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
586
+ f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
587
+ )
588
+ block_size = min(data_args.block_size, tokenizer.model_max_length)
589
+
590
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
591
+ def group_texts(examples):
592
+ # Concatenate all texts.
593
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
594
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
595
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
596
+ # customize this part to your needs.
597
+ if total_length >= block_size:
598
+ total_length = (total_length // block_size) * block_size
599
+ # Split by chunks of max_len.
600
+ result = {
601
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
602
+ for k, t in concatenated_examples.items()
603
+ }
604
+ result["labels"] = result["input_ids"].copy()
605
+ return result
606
+
607
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
608
+ # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
609
+ # to preprocess.
610
+ #
611
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
612
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
613
+
614
+ lm_datasets = tokenized_datasets.map(
615
+ group_texts,
616
+ batched=True,
617
+ num_proc=data_args.preprocessing_num_workers,
618
+ load_from_cache_file=not data_args.overwrite_cache,
619
+ )
620
+
621
+ if training_args.do_train:
622
+ if "train" not in tokenized_datasets:
623
+ raise ValueError("--do_train requires a train dataset")
624
+ train_dataset = lm_datasets["train"]
625
+ if data_args.max_train_samples is not None:
626
+ train_dataset = train_dataset.select(range(data_args.max_train_samples))
627
+
628
+ if training_args.do_eval:
629
+ if "validation" not in tokenized_datasets:
630
+ raise ValueError("--do_eval requires a validation dataset")
631
+ eval_dataset = lm_datasets["validation"]
632
+ if data_args.max_eval_samples is not None:
633
+ eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
634
+
635
+ # Enable tensorboard only on the master node
636
+ has_tensorboard = is_tensorboard_available()
637
+ if has_tensorboard and jax.process_index() == 0:
638
+ try:
639
+ from flax.metrics.tensorboard import SummaryWriter
640
+
641
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir + "/runs"))
642
+ except ImportError as ie:
643
+ has_tensorboard = False
644
+ logger.warning(
645
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
646
+ )
647
+ else:
648
+ logger.warning(
649
+ "Unable to display metrics through TensorBoard because the package is not installed: "
650
+ "Please run pip install tensorboard to enable."
651
+ )
652
+
653
+ # Initialize our training
654
+ rng = jax.random.PRNGKey(training_args.seed)
655
+ rng, dropout_rng = jax.random.split(rng)
656
+
657
+ # Store some constants
658
+ num_epochs = int(training_args.num_train_epochs)
659
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
660
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
661
+ steps_per_epoch = len(train_dataset) // train_batch_size
662
+ total_train_steps = steps_per_epoch * num_epochs
663
+
664
+ # Create learning rate schedule
665
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
666
+ len(train_dataset),
667
+ train_batch_size,
668
+ training_args.num_train_epochs,
669
+ training_args.warmup_steps,
670
+ training_args.learning_rate,
671
+ )
672
+
673
+ # We use Optax's "masking" functionality to not apply weight decay
674
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
675
+ # mask boolean with the same structure as the parameters.
676
+ # The mask is True for parameters that should be decayed.
677
+ # Note that this mask is specifically adapted for FlaxGPT2.
678
+ # For other models, one should correct the layer norm parameter naming
679
+ # accordingly.
680
+ def decay_mask_fn(params):
681
+ flat_params = traverse_util.flatten_dict(params)
682
+ flat_mask = {
683
+ path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")])
684
+ for path in flat_params
685
+ }
686
+ return traverse_util.unflatten_dict(flat_mask)
687
+
688
+ # create the optimizer: Adafactor if training_args.adafactor is set, otherwise AdamW
689
+ if training_args.adafactor:
690
+ # We use the default parameters here to initialize adafactor,
691
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
692
+ optimizer = optax.adafactor(
693
+ learning_rate=linear_decay_lr_schedule_fn,
694
+ )
695
+ else:
696
+ optimizer = optax.adamw(
697
+ learning_rate=linear_decay_lr_schedule_fn,
698
+ b1=training_args.adam_beta1,
699
+ b2=training_args.adam_beta2,
700
+ eps=training_args.adam_epsilon,
701
+ weight_decay=training_args.weight_decay,
702
+ mask=decay_mask_fn,
703
+ )
704
+
705
+ # Setup train state
706
+ state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
707
+
708
+ # if training_args.resume_from_checkpoint:
709
+ # state = restore_model_checkpoint(training_args.resume_from_checkpoint, state)
710
+ # resume_step = mb_item(state.step)
711
+ # if training_args.adafactor:
712
+ # state = fake_update(state)
713
+ # else:
714
+ resume_step = 0
715
+
716
+ def loss_fn(logits, labels):
717
+ shift_logits = logits[..., :-1, :]
718
+ shift_labels = labels[..., 1:]
719
+ loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
720
+ return loss.mean()
721
+
722
+ # Define gradient update step fn
723
+ def train_step(state, batch):
724
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
725
+
726
+ def compute_loss(params):
727
+ labels = batch.pop("labels")
728
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
729
+ loss = loss_fn(logits, labels)
730
+ return loss
731
+
732
+ grad_fn = jax.value_and_grad(compute_loss)
733
+ loss, grad = grad_fn(state.params)
734
+ grad = jax.lax.pmean(grad, "batch")
735
+
736
+ new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
737
+
738
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
739
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
740
+
741
+ return new_state, metrics
742
+
743
+ # Define eval fn
744
+ def eval_step(params, batch):
745
+ labels = batch.pop("labels")
746
+ logits = model(**batch, params=params, train=False)[0]
747
+ loss = loss_fn(logits, labels)
748
+
749
+ # summarize metrics
750
+ metrics = {"loss": loss}
751
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
752
+ return metrics
753
+
754
+ # Create parallel version of the train and eval step
755
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
756
+ p_eval_step = jax.pmap(eval_step, "batch")
757
+
758
+ # Replicate the train state on each device
759
+ state = state.replicate()
760
+
761
+ logger.info("***** Running training *****")
762
+ logger.info(f" Num examples = {len(train_dataset)}")
763
+ logger.info(f" Num Epochs = {num_epochs}")
764
+ logger.info(f" Num tokenized group examples {len(tokenized_datasets['train'])}")
765
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
766
+ logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
767
+ logger.info(f" Total optimization steps = {total_train_steps}")
768
+
769
+ train_time = 0
770
+ train_metrics = []
771
+ resume_epoch = resume_step // (steps_per_epoch)
772
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... ({resume_epoch + 1}/{num_epochs})", position=0)
773
+ if resume_step != 0:
774
+ logger.info(f"Skipping to epoch {resume_epoch} step {resume_step}")
775
+ for epoch in epochs:
776
+ # ======================== Training ================================
777
+ if epoch < resume_epoch:
778
+ continue
779
+
780
+ train_start = time.time()
781
+
782
+ # Create sampling rng
783
+ rng, input_rng = jax.random.split(rng)
784
+
785
+ # Generate an epoch by shuffling sampling indices from the train dataset
786
+ train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
787
+ steps_per_epoch = len(train_dataset) // train_batch_size
788
+ # train
789
+ for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
790
+ cur_step = epoch * (len(train_dataset) // train_batch_size) + step
791
+ # skip to the step from which we are resuming
792
+ if cur_step < resume_step:
793
+ continue
794
+
795
+ batch = next(train_loader)
796
+ batch = shard(batch)
797
+ state, train_metric = p_train_step(state, batch)
798
+ train_metrics.append(train_metric)
799
+
800
+
801
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
802
+ # Save metrics
803
+ train_metric = unreplicate(train_metric)
804
+ train_time += time.time() - train_start
805
+ if has_tensorboard and jax.process_index() == 0:
806
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
807
+
808
+ epochs.write(
809
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
810
+ )
811
+
812
+ train_metrics = []
813
+
814
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
815
+ # ======================== Evaluating ==============================
816
+ eval_metrics = []
817
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
818
+ eval_steps = len(eval_dataset) // eval_batch_size
819
+ for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
820
+ # Model forward
821
+ batch = next(eval_loader)
822
+ batch = shard(batch)
823
+ metrics = p_eval_step(state.params, batch)
824
+ eval_metrics.append(metrics)
825
+
826
+ # normalize eval metrics
827
+ eval_metrics = get_metrics(eval_metrics)
828
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
829
+
830
+ try:
831
+ eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
832
+ except OverflowError:
833
+ eval_metrics["perplexity"] = float("inf")
834
+
835
+ # Print metrics and update progress bar
836
+ desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
837
+ epochs.write(desc)
838
+ epochs.desc = desc
839
+
840
+ # Save metrics
841
+ if has_tensorboard and jax.process_index() == 0:
842
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
843
+
844
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
845
+ # save a checkpoint every `save_steps` steps (passing `push_to_hub` through to the checkpoint helper)
846
+ if jax.process_index() == 0:
847
+ save_model_checkpoint(model, training_args.output_dir, state, with_opt=False,
848
+ push_to_hub=training_args.push_to_hub)
849
+ # params = jax.device_get(unreplicate(state.params))
850
+ # model.save_pretrained(training_args.output_dir, params=params)
851
+ # tokenizer.save_pretrained(training_args.output_dir)
852
+ # if training_args.push_to_hub:
853
+ # repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
854
+
855
+ # Eval after training
856
+ if training_args.do_eval:
857
+ eval_metrics = []
858
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
859
+ eval_steps = len(eval_dataset) // eval_batch_size
860
+ for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
861
+ # Model forward
862
+ batch = shard(next(eval_loader))
863
+ metrics = p_eval_step(state.params, batch)
864
+ eval_metrics.append(metrics)
865
+
866
+ # normalize eval metrics
867
+ eval_metrics = get_metrics(eval_metrics)
868
+ eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(), eval_metrics)
869
+
870
+ try:
871
+ eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
872
+ except OverflowError:
873
+ eval_metrics["perplexity"] = float("inf")
874
+
875
+ if jax.process_index() == 0:
876
+ eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
877
+ path = os.path.join(training_args.output_dir, "eval_results.json")
878
+ with open(path, "w") as f:
879
+ json.dump(eval_metrics, f, indent=4, sort_keys=True)
880
+
881
+ # save model after training is over
882
+ if jax.process_index() == 0:
883
+ save_model_checkpoint(model, training_args.output_dir, state, with_opt=False,
884
+ push_to_hub=training_args.push_to_hub)
885
+
886
+
887
+
888
+ if __name__ == "__main__":
889
+ main()
run_gpt.sh ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ export HF_PROJECT="gpt2-medium-dutch-nedd-2"
4
+
5
+ # Variables for training the tokenizer and creating the config
6
+ export VOCAB_SIZE="50257"
7
+ export DATASET="${HOME}/data/nedd_wiki_news/nedd_wiki_news.py" # Path to the local dataset script (passed as --dataset_name)
8
+ export DATASET_CONFIG="ddwn500_nl" # Config name of the dataset (passed as --dataset_config_name)
9
+ export DATASET_SPLIT="train" # Split to use for training tokenizer and model
10
+ export TEXT_FIELD="text" # Field containing the text to be used for training
11
+ export CONFIG_TYPE="gpt2-medium" # Config that our model will use
12
+ export MODEL_PATH="${HOME}/data/${HF_PROJECT}" # Path to the model, e.g. here inside the mount
13
+
14
+ python run_clm_flax.py \
15
+ --output_dir="${MODEL_PATH}" \
16
+ --model_type="gpt2" \
17
+ --config_name="${MODEL_PATH}" \
18
+ --tokenizer_name="${MODEL_PATH}" \
19
+ --preprocessing_num_workers="96" \
20
+ --do_train --do_eval \
21
+ --dataset_name="${DATASET}" \
22
+ --dataset_config_name="${DATASET_CONFIG}" \
23
+ --block_size="512" \
24
+ --per_device_train_batch_size="16" \
25
+ --per_device_eval_batch_size="16" \
26
+ --learning_rate="0.0024" --warmup_steps="5000" \
27
+ --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
28
+ --overwrite_output_dir \
29
+ --num_train_epochs="1" \
30
+ --logging_steps="500" \
31
+ --save_steps="40000" \
32
+ --eval_steps="2500"
33
+
34
+ # \
35
+ # --push_to_hub
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "special_tokens_map_file": null, "name_or_path": ".", "tokenizer_class": "GPT2Tokenizer"}
vocab.json ADDED
The diff for this file is too large to render. See raw diff