boris committed
Commit 2d07559
Parent: df01fa8

feat: custom gradient accumulation

Files changed (2):
  1. src/dalle_mini/data.py +33 -6
  2. tools/train/train.py +43 -18
src/dalle_mini/data.py CHANGED
@@ -153,16 +153,21 @@ class Dataset:
             ),
         )

-    def dataloader(self, split, batch_size, epoch=None):
+    def dataloader(
+        self, split, per_device_batch_size, gradient_accumulation_steps=None, epoch=None
+    ):
+        num_devices = jax.local_device_count()
+
         def _dataloader_datasets_non_streaming(
             dataset: Dataset,
-            batch_size: int,
+            per_device_batch_size: int,
             rng: jax.random.PRNGKey = None,
         ):
             """
             Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
             Shuffle batches if rng is set.
             """
+            batch_size = per_device_batch_size * num_devices
             steps_per_epoch = len(dataset) // batch_size

             if rng is not None:
@@ -182,7 +187,11 @@ class Dataset:
                 yield batch

         def _dataloader_datasets_streaming(
-            dataset: Dataset, split: str, batch_size: int, epoch: int
+            dataset: Dataset,
+            split: str,
+            per_device_batch_size: int,
+            gradient_accumulation_steps: int,
+            epoch: int,
         ):
             keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
             batch = {k: [] for k in keys}
@@ -199,8 +208,22 @@ class Dataset:
                 for item in dataset:
                     for k, v in item.items():
                         batch[k].append(v)
-                    if len(batch[keys[0]]) == batch_size:
+                    # batch = 5, devices = 8, accumulation = 2 / batch_size = 5 x 8
+                    # (40, 3, 3) -> shard 8 x (5, 3, 3)
+                    # (16, 5, 3, 3) -> shard 8 x (2, 5, 3, 3)
+                    if len(batch[keys[0]]) == per_device_batch_size * num_devices * (
+                        gradient_accumulation_steps
+                        if gradient_accumulation_steps is not None
+                        else 1
+                    ):
                         batch = {k: jnp.array(v) for k, v in batch.items()}
+                        if gradient_accumulation_steps is not None:
+                            batch = jax.tree_map(
+                                lambda x: x.reshape(
+                                    (-1, per_device_batch_size) + x.shape[1:]
+                                ),
+                                batch,
+                            )
                         batch = shard(batch)
                         yield batch
                         batch = {k: [] for k in keys}
@@ -214,11 +237,15 @@ class Dataset:
             raise ValueError(f'split must be "train" or "eval", got {split}')

         if self.streaming:
-            return _dataloader_datasets_streaming(ds, split, batch_size, epoch)
+            return _dataloader_datasets_streaming(
+                ds, split, per_device_batch_size, gradient_accumulation_steps, epoch
+            )
         else:
             if split == "train":
                 self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
-            return _dataloader_datasets_non_streaming(ds, batch_size, input_rng)
+            return _dataloader_datasets_non_streaming(
+                ds, per_device_batch_size, input_rng
+            )

     @property
     def length(self):
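Note on the streaming loader change above: once per_device_batch_size * num_devices * gradient_accumulation_steps examples have been collected, the extra reshape inserts a per-device batch axis so that shard leaves each local device with a stack of gradient_accumulation_steps minibatches. The sketch below is not part of the commit; it only traces those shapes, and the "input_ids" key plus the trailing (3, 3) sample shape are arbitrary placeholders.

# Hedged sketch of the reshape-then-shard layout used in _dataloader_datasets_streaming.
import jax
import jax.numpy as jnp
from flax.training.common_utils import shard

per_device_batch_size = 5
gradient_accumulation_steps = 2
num_devices = jax.local_device_count()

# Number of examples collected before a batch is yielded.
total = per_device_batch_size * num_devices * gradient_accumulation_steps
batch = {"input_ids": jnp.zeros((total, 3, 3))}

# Insert the per-device batch axis:
# (total, 3, 3) -> (num_devices * gradient_accumulation_steps, per_device_batch_size, 3, 3)
batch = jax.tree_util.tree_map(
    lambda x: x.reshape((-1, per_device_batch_size) + x.shape[1:]), batch
)

# shard() splits the leading axis across local devices, so under pmap each device
# sees (gradient_accumulation_steps, per_device_batch_size, 3, 3).
batch = shard(batch)
print(jax.tree_util.tree_map(lambda x: x.shape, batch))

On a host with 8 devices this prints a leading shape of (8, 2, 5, 3, 3), matching the "shard 8 x (2, 5, 3, 3)" example in the diff comments.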
tools/train/train.py CHANGED
@@ -277,8 +277,8 @@ class TrainingArguments:
         },
     )

-    num_train_epochs: float = field(
-        default=3.0, metadata={"help": "Total number of training epochs to perform."}
+    num_train_epochs: int = field(
+        default=3, metadata={"help": "Total number of training epochs to perform."}
     )
     warmup_steps: int = field(
         default=0, metadata={"help": "Linear warmup over warmup_steps."}
@@ -515,10 +515,10 @@ def main():
     rng, dropout_rng = jax.random.split(rng)

     # Store some constant
-    num_epochs = int(training_args.num_train_epochs)
+    num_epochs = training_args.num_train_epochs
     # batch size per node
     train_batch_size = (
-        int(training_args.per_device_train_batch_size) * jax.local_device_count()
+        training_args.per_device_train_batch_size * jax.local_device_count()
     )
     batch_size_per_update = (
         train_batch_size
@@ -526,7 +526,7 @@ def main():
         * jax.process_count()
     )
     eval_batch_size = (
-        int(training_args.per_device_eval_batch_size) * jax.local_device_count()
+        training_args.per_device_eval_batch_size * jax.local_device_count()
     )
     len_train_dataset, len_eval_dataset = dataset.length
     steps_per_epoch = (
@@ -645,12 +645,6 @@ def main():
         clipping_threshold=training_args.max_grad_norm,
     )

-    # add gradient accumulation
-    if training_args.gradient_accumulation_steps > 1:
-        optimizer = optax.MultiSteps(
-            optimizer, training_args.gradient_accumulation_steps
-        )
-
     # Setup train state
     state = TrainState.create(
         apply_fn=model.__call__,
@@ -673,16 +667,42 @@ def main():
     def train_step(state, batch, delta_time):
         dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

-        def compute_loss(params, batch):
-            labels = batch.pop("labels")
+        def compute_loss(params, minibatch):
+            labels = minibatch.pop("labels")
             logits = state.apply_fn(
-                **batch, params=params, dropout_rng=dropout_rng, train=True
+                **minibatch, params=params, dropout_rng=dropout_rng, train=True
             )[0]
-            loss = loss_fn(logits, labels)
-            return loss
+            return loss_fn(logits, labels)

         grad_fn = jax.value_and_grad(compute_loss)
-        loss, grads = grad_fn(state.params, batch)
+
+        if training_args.gradient_accumulation_steps == 1:
+            minibatch = jax.tree_map(lambda x: x[0], batch)
+            loss, grads = grad_fn(state.params, minibatch)
+        else:
+
+            def _cumul_loss_grads(i, cumul_loss_grads):
+                minibatch = jax.tree_map(lambda x: x[i], batch)
+                return jax.tree_map(
+                    lambda x, y: x + y,
+                    cumul_loss_grads,
+                    grad_fn(state.params, minibatch),
+                )
+
+            init_loss_grads = (
+                0.0,
+                jax.tree_map(jnp.zeros_like, state.params),
+            )
+            loss, grads = jax.tree_map(
+                lambda x: x / training_args.gradient_accumulation_steps,
+                jax.lax.fori_loop(
+                    0,
+                    training_args.gradient_accumulation_steps,
+                    _cumul_loss_grads,
+                    init_loss_grads,
+                ),
+            )
+
         grads = jax.lax.pmean(grads, "batch")
         state = state.apply_gradients(
             grads=grads,
@@ -871,7 +891,12 @@ def main():
         metrics_logger.log({"train/epoch": epoch}, step=unreplicate(state.step))

         # Generate an epoch by shuffling sampling indices from the train dataset
-        train_loader = dataset.dataloader("train", train_batch_size, epoch)
+        train_loader = dataset.dataloader(
+            "train",
+            training_args.per_device_train_batch_size,
+            training_args.gradient_accumulation_steps,
+            epoch,
+        )
         # train
         for batch in tqdm(
             train_loader,
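The train_step change above drops optax.MultiSteps and accumulates gradients by hand: a jax.lax.fori_loop sums (loss, grads) over the leading accumulation axis of the batch, then divides by gradient_accumulation_steps. Below is a self-contained sketch of that pattern, not the repository's code: params, toy_loss, and all shapes are made up, jax.tree_util.tree_map stands in for the jax.tree_map calls in the diff, and the running loss is initialized with jnp.zeros(()) rather than the literal 0.0 so the fori_loop carry keeps a consistent dtype.

# Standalone sketch of summing gradients with jax.lax.fori_loop, then averaging.
# params, toy_loss, and all shapes are hypothetical; only the accumulation pattern
# mirrors the train_step changes above.
import jax
import jax.numpy as jnp

gradient_accumulation_steps = 4
params = {"w": jnp.ones((3,))}
# The dataloader yields batches with a leading accumulation axis:
# (gradient_accumulation_steps, per_device_batch_size, features)
batch = {"x": jnp.ones((gradient_accumulation_steps, 5, 3))}

def toy_loss(params, minibatch):
    return jnp.mean(minibatch["x"] @ params["w"])

grad_fn = jax.value_and_grad(toy_loss)

def _cumul_loss_grads(i, cumul_loss_grads):
    # Pick minibatch i along the accumulation axis and add its (loss, grads)
    # to the running totals.
    minibatch = jax.tree_util.tree_map(lambda x: x[i], batch)
    return jax.tree_util.tree_map(
        lambda x, y: x + y, cumul_loss_grads, grad_fn(params, minibatch)
    )

init_loss_grads = (jnp.zeros(()), jax.tree_util.tree_map(jnp.zeros_like, params))
loss, grads = jax.tree_util.tree_map(
    lambda x: x / gradient_accumulation_steps,
    jax.lax.fori_loop(
        0, gradient_accumulation_steps, _cumul_loss_grads, init_loss_grads
    ),
)
print(loss, jax.tree_util.tree_map(lambda g: g.shape, grads))

Dividing the accumulated sums by gradient_accumulation_steps recovers the mean loss and gradients over the full effective batch, which train_step then pmean-s across devices exactly as before.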