Spaces:
Running
Running
Pedro Cuenca
commited on
Commit
•
566d5f2
1
Parent(s):
835ea55
Add eval_interval to evaluate and log every so often.
Browse files- seq2seq/run_seq2seq_flax.py +54 -34
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -225,6 +225,12 @@ class DataTrainingArguments:
|
|
225 |
"value if set."
|
226 |
},
|
227 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
228 |
log_model: bool = field(
|
229 |
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
230 |
)
|
@@ -738,37 +744,8 @@ def main():
|
|
738 |
train_time = 0
|
739 |
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
740 |
global_step = 0
|
741 |
-
for epoch in epochs:
|
742 |
-
# ======================== Training ================================
|
743 |
-
train_start = time.time()
|
744 |
-
|
745 |
-
# Create sampling rng
|
746 |
-
rng, input_rng = jax.random.split(rng)
|
747 |
-
train_metrics = []
|
748 |
-
|
749 |
-
# Generate an epoch by shuffling sampling indices from the train dataset
|
750 |
-
train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
|
751 |
-
steps_per_epoch = len(train_dataset) // train_batch_size
|
752 |
-
# train
|
753 |
-
for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
|
754 |
-
global_step +=1
|
755 |
-
batch = next(train_loader)
|
756 |
-
state, train_metric = p_train_step(state, batch)
|
757 |
-
train_metrics.append(train_metric)
|
758 |
-
|
759 |
-
if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
|
760 |
-
for k, v in unreplicate(train_metric).items():
|
761 |
-
wandb.log({"train/step": global_step})
|
762 |
-
wandb.log({f"train/{k}": jax.device_get(v)})
|
763 |
-
|
764 |
-
train_time += time.time() - train_start
|
765 |
-
|
766 |
-
train_metric = unreplicate(train_metric)
|
767 |
-
|
768 |
-
epochs.write(
|
769 |
-
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
770 |
-
)
|
771 |
|
|
|
772 |
# ======================== Evaluating ==============================
|
773 |
eval_metrics = []
|
774 |
if training_args.do_eval:
|
@@ -795,17 +772,60 @@ def main():
|
|
795 |
eval_metrics = get_metrics(eval_metrics)
|
796 |
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
797 |
|
|
|
|
|
|
|
|
|
|
|
798 |
# compute ROUGE metrics
|
799 |
rouge_desc = ""
|
800 |
-
|
801 |
-
|
802 |
-
|
803 |
-
|
804 |
|
805 |
# Print metrics and update progress bar
|
806 |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
|
807 |
epochs.write(desc)
|
808 |
epochs.desc = desc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
809 |
|
810 |
# Save metrics
|
811 |
if has_tensorboard and jax.process_index() == 0:
|
|
|
225 |
"value if set."
|
226 |
},
|
227 |
)
|
228 |
+
eval_interval: Optional[int] = field(
|
229 |
+
default=40,
|
230 |
+
metadata={
|
231 |
+
"help": "Evaluation will be performed every eval_interval steps"
|
232 |
+
},
|
233 |
+
)
|
234 |
log_model: bool = field(
|
235 |
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
236 |
)
|
|
|
744 |
train_time = 0
|
745 |
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
746 |
global_step = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
747 |
|
748 |
+
def run_evaluation():
|
749 |
# ======================== Evaluating ==============================
|
750 |
eval_metrics = []
|
751 |
if training_args.do_eval:
|
|
|
772 |
eval_metrics = get_metrics(eval_metrics)
|
773 |
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
774 |
|
775 |
+
if jax.process_index() == 0:
|
776 |
+
for k, v in eval_metrics.items():
|
777 |
+
wandb.log({"eval/step": global_step})
|
778 |
+
wandb.log({f"eval/{k}": jax.device_get(v)})
|
779 |
+
|
780 |
# compute ROUGE metrics
|
781 |
rouge_desc = ""
|
782 |
+
# if data_args.predict_with_generate:
|
783 |
+
# rouge_metrics = compute_metrics(eval_preds, eval_labels)
|
784 |
+
# eval_metrics.update(rouge_metrics)
|
785 |
+
# rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
|
786 |
|
787 |
# Print metrics and update progress bar
|
788 |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
|
789 |
epochs.write(desc)
|
790 |
epochs.desc = desc
|
791 |
+
return eval_metrics
|
792 |
+
|
793 |
+
for epoch in epochs:
|
794 |
+
# ======================== Training ================================
|
795 |
+
train_start = time.time()
|
796 |
+
|
797 |
+
# Create sampling rng
|
798 |
+
rng, input_rng = jax.random.split(rng)
|
799 |
+
train_metrics = []
|
800 |
+
|
801 |
+
# Generate an epoch by shuffling sampling indices from the train dataset
|
802 |
+
train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
|
803 |
+
steps_per_epoch = len(train_dataset) // train_batch_size
|
804 |
+
# train
|
805 |
+
for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
|
806 |
+
global_step +=1
|
807 |
+
batch = next(train_loader)
|
808 |
+
state, train_metric = p_train_step(state, batch)
|
809 |
+
train_metrics.append(train_metric)
|
810 |
+
|
811 |
+
if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
|
812 |
+
print("logging train loss")
|
813 |
+
for k, v in unreplicate(train_metric).items():
|
814 |
+
wandb.log({"train/step": global_step})
|
815 |
+
wandb.log({f"train/{k}": jax.device_get(v)})
|
816 |
+
|
817 |
+
if global_step % data_args.eval_interval == 0 and jax.process_index() == 0:
|
818 |
+
run_evaluation()
|
819 |
+
|
820 |
+
train_time += time.time() - train_start
|
821 |
+
|
822 |
+
train_metric = unreplicate(train_metric)
|
823 |
+
|
824 |
+
epochs.write(
|
825 |
+
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
826 |
+
)
|
827 |
+
|
828 |
+
eval_metrics = run_evaluation()
|
829 |
|
830 |
# Save metrics
|
831 |
if has_tensorboard and jax.process_index() == 0:
|