feat: handle model parallel
- src/dalle_mini/data.py +6 -1
- src/dalle_mini/model/configuration.py +19 -18
- src/dalle_mini/model/modeling.py +4 -4
- tools/train/train.py +28 -13
src/dalle_mini/data.py
CHANGED
@@ -85,7 +85,12 @@ class Dataset:
             else self.eval_dataset.select(range(self.max_eval_samples))
         )

-    def preprocess(self, tokenizer, decoder_start_token_id, normalize_text, max_length):
+    def preprocess(self, tokenizer, config):
+        # get required config variables
+        decoder_start_token_id = config.decoder_start_token_id
+        normalize_text = config.normalize_text
+        max_length = config.max_text_length
+
         if self.streaming:
            # we need to shuffle early in streaming mode
            if hasattr(self, "train_dataset"):
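preprocess now reads its text options straight from the model config. A hypothetical call-site sketch (dataset, tokenizer and model are assumed to already exist, as in tools/train/train.py):

dataset.preprocess(tokenizer=tokenizer, config=model.config)

# roughly equivalent to the previous explicit form:
# dataset.preprocess(
#     tokenizer=tokenizer,
#     decoder_start_token_id=model.config.decoder_start_token_id,
#     normalize_text=model.config.normalize_text,
#     max_length=model.config.max_text_length,
# )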
src/dalle_mini/model/configuration.py
CHANGED
@@ -59,6 +59,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         is_encoder_decoder=True,
         forced_eos_token_id=None,
         tie_word_embeddings=False,  # different modalities and sizes
+        do_sample=True,
         **kwargs,
     ):
         self.normalize_text = normalize_text
@@ -87,28 +88,28 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
             scale_embedding  # scale factor will be sqrt(d_model) if True
         )

-        # …
-        …
-            kwargs.pop(k, None)
+        # special token id's are appended to vocab if not provided
+        decoder_start_token_id = kwargs.pop("decoder_start_token_id", image_vocab_size)
+        bos_token_id = kwargs.pop("bos_token_id", image_vocab_size)
+        pad_token_id = kwargs.pop("pad_token_id", image_vocab_size)
+        eos_token_id = kwargs.pop("eos_token_id", image_vocab_size)
+
+        # we generate to image_length + 1 (for bos) by default
+        min_length = kwargs.pop("min_length", image_length + 1)
+        max_length = kwargs.pop("max_length", image_length + 1)

         super().__init__(
-            pad_token_id=image_vocab_size
-            + 1,  # needed to avoid errors during generation (converted to jnp.array)
-            bos_token_id=image_vocab_size + 1,  # set to unreachable values
-            eos_token_id=image_vocab_size + 1,
+            # args required in parent class
             is_encoder_decoder=is_encoder_decoder,
-            decoder_start_token_id=image_vocab_size,  # BOS appended to vocab
-            forced_eos_token_id=forced_eos_token_id,
             tie_word_embeddings=tie_word_embeddings,
-            …
-            …
+            forced_eos_token_id=forced_eos_token_id,
+            decoder_start_token_id=decoder_start_token_id,
+            bos_token_id=bos_token_id,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            min_length=min_length,
+            max_length=max_length,
+            do_sample=do_sample,
             **kwargs,
         )

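A minimal illustration of the new defaulting logic (the numbers below are made up): ids already present in kwargs, e.g. from a saved config, are kept, while fresh configs fall back to the id appended right after the image codebook, and generation length defaults to one step past image_length for the BOS token.

image_vocab_size = 16384  # assumed codebook size
image_length = 256        # assumed number of image tokens

kwargs = {"pad_token_id": 16384}  # e.g. values coming from a saved config.json
pad_token_id = kwargs.pop("pad_token_id", image_vocab_size)  # 16384, the stored value wins
bos_token_id = kwargs.pop("bos_token_id", image_vocab_size)  # 16384, default appended past the codebook
max_length = kwargs.pop("max_length", image_length + 1)      # 257: image tokens + BOS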
src/dalle_mini/model/modeling.py
CHANGED
@@ -54,7 +54,7 @@ logger = logging.get_logger(__name__)
 class FlaxBartAttention(FlaxBartAttention):
     """
     Edits:
-    - causal mask is used only in decoder and considers image_length …
+    - causal mask is used only in decoder and considers image_length
     """

     def setup(self) -> None:
@@ -81,7 +81,7 @@ class FlaxBartAttention(FlaxBartAttention):
         if self.causal:
             # used only in decoder
             self.causal_mask = make_causal_mask(
-                jnp.ones((1, self.config.image_length …
+                jnp.ones((1, self.config.image_length), dtype="bool"), dtype="bool"
             )


@@ -240,7 +240,7 @@ class FlaxBartDecoder(FlaxBartDecoder):
     """
     Edits:
     - offset set to 0 (no padding token)
-    - use image_length …
+    - use image_length instead of max_position_embeddings
     - use custom FlaxBartDecoderLayerCollection
     - embed_tokens cannot be None (issue at compile time)
     """
@@ -258,7 +258,7 @@ class FlaxBartDecoder(FlaxBartDecoder):
         # and adjust num_embeddings appropriately. Other models don't have this hack
         self.offset = 0
         self.embed_positions = nn.Embed(
-            self.config.image_length + …
+            self.config.image_length + self.offset,  # image length for BOS
             embed_dim,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
         )
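For reference, a standalone sketch of the causal mask built in FlaxBartAttention.setup (image_length is an arbitrary placeholder here; the real value comes from DalleBartConfig):

import jax.numpy as jnp
from flax.linen import make_causal_mask

image_length = 256  # placeholder value

# boolean lower-triangular mask of shape (1, 1, image_length, image_length):
# image position i may only attend to positions <= i
causal_mask = make_causal_mask(jnp.ones((1, image_length), dtype="bool"), dtype="bool")
print(causal_mask.shape)  # (1, 1, 256, 256)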
tools/train/train.py
CHANGED
@@ -99,7 +99,7 @@ class ModelArguments:

     def __post_init__(self):
         if self.restore_state:
-            assert (
+            assert self.model_name_or_path is not None and (
                 "/model-" in self.model_name_or_path
             ), "Restoring state only available with W&B artifact reference"
             self.state_artifact = self.model_name_or_path.replace(
@@ -222,12 +222,13 @@ class TrainingArguments:
     )

     per_device_train_batch_size: int = field(
-        default=8, …
+        default=8,
+        metadata={"help": "Batch size per data parallel device for training."},
     )
     per_device_eval_batch_size: Optional[int] = field(
         default=None,
         metadata={
-            "help": "Batch size per …
+            "help": "Batch size per data parallel device for evaluation. Same as training batch size if not set."
         },
     )

@@ -523,12 +524,7 @@ def main():
     # Preprocessing the datasets.
     # We need to normalize and tokenize inputs and targets.

-    dataset.preprocess(
-        tokenizer=tokenizer,
-        decoder_start_token_id=model.config.decoder_start_token_id,
-        normalize_text=model.config.normalize_text,
-        max_length=model.config.max_text_length,
-    )
+    dataset.preprocess(tokenizer=tokenizer, config=model.config)

     # Initialize our training
     rng = jax.random.PRNGKey(training_args.seed_model)
@@ -874,9 +870,17 @@

     # Define eval fn
     def eval_step(state, batch):
-        …
-        …
-        …
+        def compute_eval_loss(batch):
+            batch, labels = batch.pop("labels")
+            logits = state.apply_fn(**batch, params=state.params, train=False)[0]
+            return loss_fn(logits, labels)
+
+        # calculate loss independently per dp_device
+        loss = jax.vmap(compute_eval_loss, in_axes=(0,), out_axes=0)(batch)
+        # ensure they are sharded over dp devices
+        loss = with_sharding_constraint(loss, PartitionSpec("batch"))
+        # average across all devices
+        loss = jnp.mean(loss)
         return loss

     # Create parallel version of the train and eval step
@@ -946,7 +950,18 @@
                 leave=False,
                 total=eval_steps,
             ):
-                # …
+                # reshape data into (dp_devices, batch_per_dp, ...)
+                batch = jax.tree_map(
+                    lambda x: x.reshape(
+                        (
+                            training_args.dp_devices,
+                            training_args.per_device_eval_batch_size,
+                        )
+                        + x.shape[1:]
+                    ),
+                    batch,
+                )
+                # freeze batch to pass safely to jax transforms
                 batch = freeze(batch)
                 # accumulate losses async
                 eval_loss.append(p_eval_step(state, batch))
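The new eval_step computes one loss per data-parallel device with jax.vmap over the leading axis and then averages. A toy, self-contained sketch of that pattern (shapes and the dummy loss function are invented; the real code additionally applies with_sharding_constraint under pjit):

import jax
import jax.numpy as jnp

dp_devices, per_device_batch, seq_len, vocab = 2, 4, 8, 16
logits = jnp.zeros((dp_devices, per_device_batch, seq_len, vocab))
labels = jnp.zeros((dp_devices, per_device_batch, seq_len), dtype=jnp.int32)

def dummy_loss(logits, labels):
    # stand-in cross-entropy: mean negative log-probability of the target token
    logprobs = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.take_along_axis(logprobs, labels[..., None], axis=-1))

per_device_loss = jax.vmap(dummy_loss)(logits, labels)  # shape (dp_devices,)
loss = jnp.mean(per_device_loss)                        # scalar averaged across devices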
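And the reshape that feeds it, shown on a made-up batch: every leaf goes from (global_batch, ...) to (dp_devices, per_device_eval_batch_size, ...) before being frozen and passed to p_eval_step (jax.tree_util.tree_map is the same call as the diff's jax.tree_map):

import jax
import jax.numpy as jnp

dp_devices, per_device_eval_batch_size = 2, 4  # illustrative values
batch = {
    "input_ids": jnp.zeros((8, 64), dtype=jnp.int32),  # global batch of 8
    "labels": jnp.zeros((8, 256), dtype=jnp.int32),
}
batch = jax.tree_util.tree_map(
    lambda x: x.reshape((dp_devices, per_device_eval_batch_size) + x.shape[1:]),
    batch,
)
print({k: v.shape for k, v in batch.items()})
# {'input_ids': (2, 4, 64), 'labels': (2, 4, 256)}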