Merge pull request #22 from borisdayma/feat-axis
seq2seq/run_seq2seq_flax.py  (+25 -57)
@@ -57,7 +57,6 @@ from transformers import (
     FlaxBartForConditionalGeneration,
     HfArgumentParser,
     TrainingArguments,
-    is_tensorboard_available,
 )
 from transformers.models.bart.modeling_flax_bart import *
 from transformers.file_utils import is_offline_mode
@@ -229,12 +228,6 @@ class DataTrainingArguments:
             "value if set."
         },
     )
-    eval_interval: Optional[int] = field(
-        default=400,
-        metadata={
-            "help": "Evaluation will be performed every eval_interval steps"
-        },
-    )
     log_model: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
@@ -327,19 +320,6 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
         yield batch
 
 
-def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
-    summary_writer.scalar("train_time", train_time, step)
-
-    train_metrics = get_metrics(train_metrics)
-    for key, vals in train_metrics.items():
-        tag = f"train_epoch/{key}"
-        for i, val in enumerate(vals):
-            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
-
-    for metric_name, value in eval_metrics.items():
-        summary_writer.scalar(f"eval/{metric_name}", value, step)
-
-
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float, no_decay: bool
 ) -> Callable[[int], jnp.array]:
@@ -356,6 +336,14 @@ def create_learning_rate_fn(
     return schedule_fn
 
 
+def wandb_log(metrics, step=None, prefix=None):
+    if jax.process_index() == 0:
+        log_metrics = {f'{prefix}/{k}' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
+        if step is not None:
+            log_metrics = {**log_metrics, 'train/step': step}
+        wandb.log(log_metrics)
+
+
 def main():
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
@@ -369,6 +357,9 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
+    logger.warning(f"eval_steps has been manually hardcoded")  # TODO: remove it later, convenient for now
+    training_args.eval_steps = 400
+
     if (
         os.path.exists(training_args.output_dir)
         and os.listdir(training_args.output_dir)
@@ -382,13 +373,16 @@ def main():
 
     # Set up wandb run
     wandb.init(
-        sync_tensorboard=True,
         entity='wandb',
         project='hf-flax-dalle-mini',
         job_type='Seq2SeqVQGAN',
         config=parser.parse_args()
     )
 
+    # set default x-axis as 'train/step'
+    wandb.define_metric('train/step')
+    wandb.define_metric('*', step_metric='train/step')
+
     # Make one log on every process with the configuration for debugging.
     pylogging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -583,24 +577,6 @@ def main():
         result = {k: round(v, 4) for k, v in result.items()}
         return result
 
-    # Enable tensorboard only on the master node
-    has_tensorboard = is_tensorboard_available()
-    if has_tensorboard and jax.process_index() == 0:
-        try:
-            from flax.metrics.tensorboard import SummaryWriter
-
-            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
-        except ImportError as ie:
-            has_tensorboard = False
-            logger.warning(
-                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
-            )
-    else:
-        logger.warning(
-            "Unable to display metrics through TensorBoard because the package is not installed: "
-            "Please run pip install tensorboard to enable."
-        )
-
     # Initialize our training
     rng = jax.random.PRNGKey(training_args.seed)
     rng, dropout_rng = jax.random.split(rng)
@@ -780,10 +756,8 @@ def main():
         eval_metrics = get_metrics(eval_metrics)
         eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
 
-
-
-        wandb.log({"eval/step": global_step})
-        wandb.log({f"eval/{k}": jax.device_get(v)})
+        # log metrics
+        wandb_log(eval_metrics, step=global_step, prefix='eval')
 
         # compute ROUGE metrics
         rouge_desc = ""
@@ -796,6 +770,7 @@ def main():
         desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
         epochs.write(desc)
         epochs.desc = desc
+
         return eval_metrics
 
     for epoch in epochs:
@@ -804,7 +779,6 @@ def main():
 
         # Create sampling rng
         rng, input_rng = jax.random.split(rng)
-        train_metrics = []
 
         # Generate an epoch by shuffling sampling indices from the train dataset
         train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
@@ -814,32 +788,26 @@ def main():
             global_step +=1
             batch = next(train_loader)
             state, train_metric = p_train_step(state, batch)
-            train_metrics.append(train_metric)
 
             if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
-
-
-                wandb.log({"train/step": global_step})
-                wandb.log({f"train/{k}": jax.device_get(v)})
+                # log metrics
+                wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
 
-            if global_step %
+            if global_step % training_args.eval_steps == 0:
                 run_evaluation()
+
+        # log final train metrics
+        wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
 
         train_time += time.time() - train_start
-
         train_metric = unreplicate(train_metric)
-
        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
        )
 
+        # Final evaluation
        eval_metrics = run_evaluation()
 
-        # Save metrics
-        if has_tensorboard and jax.process_index() == 0:
-            cur_step = epoch * (len(train_dataset) // train_batch_size)
-            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
-
        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
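
For reference, a minimal standalone sketch of the logging pattern this change introduces: wandb.define_metric registers 'train/step' as the shared x-axis for all metrics, and the wandb_log helper prefixes metric names and attaches that step. The project name and the toy loop below are illustrative placeholders, not part of the PR; they only show how the pieces fit together under the assumption that wandb and jax are installed.

import jax
import wandb


def wandb_log(metrics, step=None, prefix=None):
    # Same idea as the helper added in this PR: log only from the primary process,
    # prefix metric names, and attach the shared 'train/step' axis when a step is given.
    if jax.process_index() == 0:
        log_metrics = {f"{prefix}/{k}" if prefix is not None else k: jax.device_get(v)
                       for k, v in metrics.items()}
        if step is not None:
            log_metrics = {**log_metrics, "train/step": step}
        wandb.log(log_metrics)


if __name__ == "__main__":
    wandb.init(project="define-metric-demo")  # placeholder project name

    # Make 'train/step' the default x-axis for every logged metric,
    # so train and eval charts line up on the same step axis.
    wandb.define_metric("train/step")
    wandb.define_metric("*", step_metric="train/step")

    for step in range(1, 101):
        wandb_log({"loss": 1.0 / step}, step=step, prefix="train")
        if step % 25 == 0:
            wandb_log({"loss": 1.2 / step}, step=step, prefix="eval")

Run with WANDB_MODE=offline to just check that it executes. Keying '*' to a step metric is what lets train/ and eval/ logs issued at different times land on one common axis instead of wandb's internal call counter.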