feat: custom gradient accumulation
- src/dalle_mini/data.py +33 -6
- tools/train/train.py +43 -18
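Both files implement gradient accumulation by hand rather than through optax.MultiSteps: the dataloader now yields batches with an extra leading accumulation axis, and train_step sums gradients over that axis before a single optimizer update. As a quick worked example of the effective batch size this produces (the concrete numbers below are illustrative assumptions, not values from the commit):

    # Hypothetical numbers for illustration only.
    per_device_train_batch_size = 56
    local_devices = 8                  # e.g. one TPU v3-8 host
    gradient_accumulation_steps = 4

    # batch processed per train_step on one node
    train_batch_size = per_device_train_batch_size * local_devices        # 448
    # examples contributing to each optimizer update on that node
    examples_per_update = train_batch_size * gradient_accumulation_steps  # 1792
    # across hosts, multiply again by jax.process_count()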
src/dalle_mini/data.py
CHANGED
@@ -153,16 +153,21 @@ class Dataset:
                 ),
             )
 
-    def dataloader(self, split, batch_size, epoch=None):
+    def dataloader(
+        self, split, per_device_batch_size, gradient_accumulation_steps=None, epoch=None
+    ):
+        num_devices = jax.local_device_count()
+
         def _dataloader_datasets_non_streaming(
             dataset: Dataset,
-            batch_size: int,
+            per_device_batch_size: int,
             rng: jax.random.PRNGKey = None,
         ):
             """
             Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
             Shuffle batches if rng is set.
             """
+            batch_size = per_device_batch_size * num_devices
             steps_per_epoch = len(dataset) // batch_size
 
             if rng is not None:
@@ -182,7 +187,11 @@ class Dataset:
                 yield batch
 
         def _dataloader_datasets_streaming(
-            dataset: Dataset, batch_size: int, epoch: int
+            dataset: Dataset,
+            split: str,
+            per_device_batch_size: int,
+            gradient_accumulation_steps: int,
+            epoch: int,
         ):
             keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
             batch = {k: [] for k in keys}
@@ -199,8 +208,22 @@ class Dataset:
                 for item in dataset:
                     for k, v in item.items():
                         batch[k].append(v)
-                    if len(batch[keys[0]]) == batch_size:
+                    # batch = 5, devices = 8, accumulation = 2 / batch_size = 5 x 8
+                    # (40, 3, 3) -> shard 8 x (5, 3, 3)
+                    # (16, 5, 3, 3) -> shard 8 x (2, 5, 3, 3)
+                    if len(batch[keys[0]]) == per_device_batch_size * num_devices * (
+                        gradient_accumulation_steps
+                        if gradient_accumulation_steps is not None
+                        else 1
+                    ):
                         batch = {k: jnp.array(v) for k, v in batch.items()}
+                        if gradient_accumulation_steps is not None:
+                            batch = jax.tree_map(
+                                lambda x: x.reshape(
+                                    (-1, per_device_batch_size) + x.shape[1:]
+                                ),
+                                batch,
+                            )
                         batch = shard(batch)
                         yield batch
                         batch = {k: [] for k in keys}
@@ -214,11 +237,15 @@ class Dataset:
             raise ValueError(f'split must be "train" or "eval", got {split}')
 
         if self.streaming:
-            return _dataloader_datasets_streaming(ds, batch_size, epoch)
+            return _dataloader_datasets_streaming(
+                ds, split, per_device_batch_size, gradient_accumulation_steps, epoch
+            )
         else:
             if split == "train":
                 self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
-            return _dataloader_datasets_non_streaming(ds, batch_size, input_rng)
+            return _dataloader_datasets_non_streaming(
+                ds, per_device_batch_size, input_rng
+            )
 
     @property
     def length(self):
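The shape comments in the streaming dataloader are worth unpacking. Below is a minimal sketch (not code from the repository) of how the reshape plus flax's shard lay out one yielded batch, assuming 8 local devices, a per-device batch of 5, 2 accumulation steps, and items of shape (3, 3):

    import jax
    import jax.numpy as jnp
    from flax.training.common_utils import shard

    num_devices = jax.local_device_count()   # assumed 8 below; 1 on a CPU-only machine
    per_device_batch_size = 5
    gradient_accumulation_steps = 2

    # Collect per_device_batch_size * num_devices * gradient_accumulation_steps items.
    n_items = per_device_batch_size * num_devices * gradient_accumulation_steps
    batch = {"input_ids": jnp.zeros((n_items, 3, 3))}   # (80, 3, 3) with 8 devices

    # Group by accumulation step: (num_devices * steps, per_device_batch_size, ...)
    batch = jax.tree_util.tree_map(
        lambda x: x.reshape((-1, per_device_batch_size) + x.shape[1:]), batch
    )                                                    # (16, 5, 3, 3)

    # shard() splits the leading axis over local devices, so each device receives
    # (gradient_accumulation_steps, per_device_batch_size, ...) per train_step.
    batch = shard(batch)                                 # (8, 2, 5, 3, 3)
    print(batch["input_ids"].shape)

When gradient_accumulation_steps is None the reshape is skipped, and shard produces the usual (num_devices, per_device_batch_size, ...) layout.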
tools/train/train.py
CHANGED
@@ -277,8 +277,8 @@ class TrainingArguments:
         },
     )
 
-    num_train_epochs: float = field(
-        default=3.0, metadata={"help": "Total number of training epochs to perform."}
+    num_train_epochs: int = field(
+        default=3, metadata={"help": "Total number of training epochs to perform."}
     )
     warmup_steps: int = field(
         default=0, metadata={"help": "Linear warmup over warmup_steps."}
@@ -515,10 +515,10 @@ def main():
     rng, dropout_rng = jax.random.split(rng)
 
     # Store some constant
-    num_epochs = int(training_args.num_train_epochs)
+    num_epochs = training_args.num_train_epochs
     # batch size per node
     train_batch_size = (
-        int(training_args.per_device_train_batch_size) * jax.local_device_count()
+        training_args.per_device_train_batch_size * jax.local_device_count()
     )
     batch_size_per_update = (
         train_batch_size
@@ -526,7 +526,7 @@ def main():
         * jax.process_count()
     )
     eval_batch_size = (
-        int(training_args.per_device_eval_batch_size) * jax.local_device_count()
+        training_args.per_device_eval_batch_size * jax.local_device_count()
     )
     len_train_dataset, len_eval_dataset = dataset.length
     steps_per_epoch = (
@@ -645,12 +645,6 @@ def main():
         clipping_threshold=training_args.max_grad_norm,
     )
 
-    # add gradient accumulation
-    if training_args.gradient_accumulation_steps > 1:
-        optimizer = optax.MultiSteps(
-            optimizer, training_args.gradient_accumulation_steps
-        )
-
     # Setup train state
     state = TrainState.create(
         apply_fn=model.__call__,
@@ -673,16 +667,42 @@ def main():
     def train_step(state, batch, delta_time):
         dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
 
-        def compute_loss(params, batch):
-            labels = batch.pop("labels")
+        def compute_loss(params, minibatch):
+            labels = minibatch.pop("labels")
             logits = state.apply_fn(
-                **batch, params=params, dropout_rng=dropout_rng, train=True
+                **minibatch, params=params, dropout_rng=dropout_rng, train=True
             )[0]
-            loss = loss_fn(logits, labels)
-            return loss
+            return loss_fn(logits, labels)
 
         grad_fn = jax.value_and_grad(compute_loss)
-        loss, grads = grad_fn(state.params, batch)
+
+        if training_args.gradient_accumulation_steps == 1:
+            minibatch = jax.tree_map(lambda x: x[0], batch)
+            loss, grads = grad_fn(state.params, minibatch)
+        else:
+
+            def _cumul_loss_grads(i, cumul_loss_grads):
+                minibatch = jax.tree_map(lambda x: x[i], batch)
+                return jax.tree_map(
+                    lambda x, y: x + y,
+                    cumul_loss_grads,
+                    grad_fn(state.params, minibatch),
+                )
+
+            init_loss_grads = (
+                0.0,
+                jax.tree_map(jnp.zeros_like, state.params),
+            )
+            loss, grads = jax.tree_map(
+                lambda x: x / training_args.gradient_accumulation_steps,
+                jax.lax.fori_loop(
+                    0,
+                    training_args.gradient_accumulation_steps,
+                    _cumul_loss_grads,
+                    init_loss_grads,
+                ),
+            )
+
         grads = jax.lax.pmean(grads, "batch")
         state = state.apply_gradients(
             grads=grads,
@@ -871,7 +891,12 @@ def main():
             metrics_logger.log({"train/epoch": epoch}, step=unreplicate(state.step))
 
             # Generate an epoch by shuffling sampling indices from the train dataset
-            train_loader = dataset.dataloader("train", train_batch_size, epoch)
+            train_loader = dataset.dataloader(
+                "train",
+                training_args.per_device_train_batch_size,
+                training_args.gradient_accumulation_steps,
+                epoch,
+            )
             # train
             for batch in tqdm(
                 train_loader,
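For reference, here is a self-contained sketch of the accumulation pattern train_step now uses: a jax.lax.fori_loop slices minibatch i off the leading accumulation axis, sums the (loss, grads) pairs, and divides by the number of steps. Everything in it (toy_loss, the toy params and data) is invented for illustration, and the pmap/pmean data parallelism of the real train_step is left out:

    import jax
    import jax.numpy as jnp

    accum_steps = 4
    params = {"w": jnp.ones((3,))}
    # Leading axis = accumulation step, as produced by the new dataloader:
    # (accum_steps, per_device_batch_size, features)
    batch = {
        "x": jnp.arange(accum_steps * 5 * 3, dtype=jnp.float32).reshape(accum_steps, 5, 3),
        "y": jnp.ones((accum_steps, 5)),
    }

    def toy_loss(params, minibatch):
        pred = minibatch["x"] @ params["w"]
        return jnp.mean((pred - minibatch["y"]) ** 2)

    grad_fn = jax.value_and_grad(toy_loss)

    def _cumul_loss_grads(i, cumul_loss_grads):
        # Add minibatch i's (loss, grads) to the running totals.
        minibatch = jax.tree_util.tree_map(lambda x: x[i], batch)
        return jax.tree_util.tree_map(
            lambda acc, new: acc + new, cumul_loss_grads, grad_fn(params, minibatch)
        )

    # Accumulate over all steps, then average.
    init = (jnp.zeros(()), jax.tree_util.tree_map(jnp.zeros_like, params))
    loss, grads = jax.tree_util.tree_map(
        lambda x: x / accum_steps,
        jax.lax.fori_loop(0, accum_steps, _cumul_loss_grads, init),
    )
    print(loss, grads["w"])

Dividing by the number of accumulation steps keeps the gradient at the scale of a single large-batch step, so the optimizer update that follows (and the pmean across devices in the real code) behaves as if one big batch had been processed in a single forward/backward pass.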