feat: split shards by host
dalle_mini/data.py  CHANGED  (+30 -11)
@@ -4,9 +4,9 @@ from functools import partial
 import jax
 import jax.numpy as jnp
 import numpy as np
+from braceexpand import braceexpand
 from datasets import Dataset, load_dataset
 from flax.training.common_utils import shard
-from braceexpand import braceexpand
 
 from .text import TextNormalizer
 
@@ -30,8 +30,10 @@ class Dataset:
     train_dataset: Dataset = field(init=False)
     eval_dataset: Dataset = field(init=False)
     rng_dataset: jnp.ndarray = field(init=False)
+    multi_hosts: bool = field(init=False)
 
     def __post_init__(self):
+        self.multi_hosts = jax.process_count() > 1
         # define data_files
         if self.train_file is not None or self.validation_file is not None:
             # accept braceexpand notation
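The new multi_hosts flag just records whether the run spans more than one JAX process. A quick sanity check, not part of the commit, on a single machine:

import jax

# On a single machine there is one process, so multi_hosts would be False:
print(jax.process_count(), jax.process_index())  # 1 0
print(jax.process_count() > 1)                   # False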
@@ -39,6 +41,11 @@ class Dataset:
                 f = getattr(self, k)
                 if isinstance(f, str):
                     setattr(self, k, list(braceexpand(f)))
+            # for list of files, split training data shards by host
+            if isinstance(self.train_file, list) and self.multi_hosts:
+                self.train_file = self.train_file[
+                    jax.process_index() :: jax.process_count()
+                ]
             data_files = {
                 "train": self.train_file,
                 "validation": self.validation_file,
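The added slice train_file[jax.process_index()::jax.process_count()] deals the brace-expanded shard list round-robin across hosts, so each host reads a disjoint subset. A minimal sketch of that stride slicing, with hypothetical shard names and a hard-coded host count standing in for the JAX calls:

from braceexpand import braceexpand

# hypothetical shard files; the commit gets these from self.train_file
shards = list(braceexpand("train-{0000..0015}.jsonl"))  # 16 shards

process_count = 4  # stand-in for jax.process_count()
for process_index in range(process_count):  # stand-in for jax.process_index()
    print(process_index, shards[process_index::process_count])

# process 0 gets train-0000, train-0004, train-0008, train-0012;
# process 1 gets train-0001, train-0005, ... and so on: disjoint,
# near-equal subsets as long as there are at least process_count shards.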
@@ -169,17 +176,29 @@ class Dataset:
                 batch = shard(batch)
                 yield batch
 
-        def _dataloader_datasets_streaming(dataset: Dataset, batch_size: int):
+        def _dataloader_datasets_streaming(
+            dataset: Dataset, batch_size: int, epoch: int
+        ):
+            # epoch is only used for multi-host
             keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
             batch = {k: [] for k in keys}
-            for item in dataset:
-                for k, v in item.items():
-                    batch[k].append(v)
-                if len(batch[keys[0]]) == batch_size:
-                    batch = {k: jnp.array(v) for k, v in batch.items()}
-                    batch = shard(batch)
-                    yield batch
-                    batch = {k: [] for k in keys}
+            first_loop = True
+            while self.multi_hosts or first_loop:
+                # in multi-host, we run forever (no epoch) as hosts need to stop
+                # at the same time and we don't know how much data is on each host
+                if not first_loop:
+                    # multi-host setting, we reshuffle shards
+                    epoch += 1
+                    dataset.set_epoch(epoch)
+                for item in dataset:
+                    for k, v in item.items():
+                        batch[k].append(v)
+                    if len(batch[keys[0]]) == batch_size:
+                        batch = {k: jnp.array(v) for k, v in batch.items()}
+                        batch = shard(batch)
+                        yield batch
+                        batch = {k: [] for k in keys}
+                first_loop = False
 
         if split == "train":
             ds = self.train_dataset
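As the comments above note, with shards split by host each process may hold a different amount of data, yet all hosts must stop at the same time; the rewritten generator therefore never terminates in the multi-host case and instead bumps the epoch to reshuffle shards on every pass. A self-contained sketch of that pattern (ToyStream is a hypothetical stand-in for a streaming datasets.Dataset):

class ToyStream:
    """Hypothetical stand-in for a streaming datasets.Dataset."""

    def __init__(self, n):
        self.n = n

    def set_epoch(self, epoch):
        # a real streaming dataset reshuffles its shards here
        pass

    def __iter__(self):
        return iter({"input_ids": i} for i in range(self.n))


def batches(dataset, batch_size, epoch, multi_hosts):
    batch = []
    first_loop = True
    while multi_hosts or first_loop:
        if not first_loop:
            # replay the stream with a new shuffle instead of stopping
            epoch += 1
            dataset.set_epoch(epoch)
        for item in dataset:
            batch.append(item["input_ids"])
            if len(batch) == batch_size:
                yield batch
                batch = []
        first_loop = False


gen = batches(ToyStream(5), batch_size=2, epoch=0, multi_hosts=True)
print([next(gen) for _ in range(4)])  # [[0, 1], [2, 3], [4, 0], [1, 2]]

Note that a partially filled batch rolls over into the next pass rather than being dropped, matching the behavior of the committed code.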
@@ -191,7 +210,7 @@ class Dataset:
         if self.streaming:
             if split == "train":
                 ds.set_epoch(epoch)
-            return _dataloader_datasets_streaming(ds, batch_size)
+            return _dataloader_datasets_streaming(ds, batch_size, epoch)
         else:
             if split == "train":
                 self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)