fix(train): handle seed_dataset
Files changed:
- src/dalle_mini/data.py (+3 -8)
- tools/train/train.py (+2 -2)
src/dalle_mini/data.py

@@ -161,7 +161,7 @@ class Dataset:
         ):
             """
             Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
-            Shuffle batches if
+            Shuffle batches if rng is set.
             """
             steps_per_epoch = len(dataset) // batch_size
 
@@ -184,17 +184,13 @@ class Dataset:
         def _dataloader_datasets_streaming(
             dataset: Dataset, batch_size: int, epoch: int
         ):
-            # epoch is only use for multi-host
             keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
             batch = {k: [] for k in keys}
             first_loop = True
             while self.multi_hosts or first_loop:
                 # in multi-host, we run forever (no epoch) as hosts need to stop
                 # at the same time and we don't know how much data is on each host
-
-                # multi-host setting, we reshuffle shards
-                epoch += 1
-                dataset.set_epoch(epoch)
+                dataset.set_epoch(epoch)  # reshuffle data at each epoch
                 for item in dataset:
                     for k, v in item.items():
                         batch[k].append(v)
@@ -203,6 +199,7 @@ class Dataset:
                         batch = shard(batch)
                         yield batch
                         batch = {k: [] for k in keys}
+                epoch += 1
                 first_loop = False
 
         if split == "train":
@@ -213,8 +210,6 @@ class Dataset:
             raise ValueError(f'split must be "train" or "eval", got {split}')
 
         if self.streaming:
-            if split == "train":
-                ds.set_epoch(epoch)
            return _dataloader_datasets_streaming(ds, batch_size, epoch)
         else:
             if split == "train":
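A minimal sketch, not the repo's own code, of why calling `set_epoch(epoch)` inside the streaming loader (and bumping `epoch` after each pass) is enough to reshuffle the data: for a shuffled Hugging Face `datasets` IterableDataset, the effective shuffling seed is `seed + epoch`, so each epoch iterates in a different order. The toy data, column name, and shard count below are assumptions, and `to_iterable_dataset` needs a reasonably recent `datasets` release.

from datasets import Dataset

# Build a tiny shuffled streaming dataset purely for illustration.
ds = Dataset.from_dict({"x": list(range(10))}).to_iterable_dataset(num_shards=2)
ds = ds.shuffle(seed=0, buffer_size=10)

for epoch in range(2):
    ds.set_epoch(epoch)  # same call the patch adds: reshuffle data at each epoch
    print(epoch, [item["x"] for item in ds])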
tools/train/train.py

@@ -241,7 +241,7 @@ class TrainingArguments:
     )
     optim_quantized: bool = field(
         default=False,
-
+        metadata={
             "help": "Whether to quantize optimizer (only supported with distributed_shampoo)."
         },
     )
@@ -845,7 +845,7 @@ def main():
         metrics_logger.log({"train/epoch": epoch}, step=unreplicate(state.step))
 
         # Generate an epoch by shuffling sampling indices from the train dataset
-        train_loader = dataset.dataloader("train", train_batch_size)
+        train_loader = dataset.dataloader("train", train_batch_size, epoch)
         # train
         for batch in tqdm(
             train_loader,
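A hedged, standard-library illustration of what the `metadata={` line added above carries: `dataclasses.field` stores the help string in its `metadata` mapping, which argument parsers such as `HfArgumentParser` can read back. The class name `Args` is invented for this sketch; it is not the repo's `TrainingArguments`.

from dataclasses import dataclass, field, fields

@dataclass
class Args:
    optim_quantized: bool = field(
        default=False,
        metadata={"help": "Whether to quantize optimizer (only supported with distributed_shampoo)."},
    )

# The metadata mapping is retrievable from the dataclass field definition.
print(fields(Args)[0].metadata["help"])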