boris committed
Commit: 635402d
Parents: 7aa2f4b 566d5f2

Merge pull request #20 from borisdayma/eval-interval


Add eval_interval so that evaluation runs and logs metrics every eval_interval steps.
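
In outline, the change gates evaluation on a step counter inside the training loop: metrics are logged every log_interval steps, and a full evaluation pass runs every eval_interval steps. The sketch below is a minimal, self-contained illustration of that pattern, not the script's actual code; train_step and evaluate are toy stand-ins for its p_train_step and run_evaluation, and the interval values are illustrative.

# Minimal sketch of the periodic-evaluation pattern this PR adds.
# train_step and evaluate are toy stand-ins for the script's real
# p_train_step and run_evaluation; intervals are illustrative.
num_epochs, steps_per_epoch = 2, 100
log_interval, eval_interval = 10, 40

def train_step(state, batch):
    return state + 1  # toy parameter update

def evaluate(state):
    return {"loss": 1.0 / (1 + state)}  # toy eval metric

state, global_step = 0, 0
for epoch in range(num_epochs):
    for batch in range(steps_per_epoch):
        state = train_step(state, batch)
        global_step += 1
        if global_step % log_interval == 0:
            print(f"step {global_step}: train metrics logged")  # cheap, frequent
        if global_step % eval_interval == 0:
            print(f"step {global_step}: {evaluate(state)}")     # costly, rarer
    # evaluate once more at the end of each epoch, as the PR also does
    print(f"epoch {epoch + 1} done: {evaluate(state)}")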

Files changed (1):
  seq2seq/run_seq2seq_flax.py  (+54 -34)
seq2seq/run_seq2seq_flax.py CHANGED
@@ -226,6 +226,12 @@ class DataTrainingArguments:
             "value if set."
         },
     )
+    eval_interval: Optional[int] = field(
+        default=40,
+        metadata={
+            "help": "Evaluation will be performed every eval_interval steps"
+        },
+    )
     log_model: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
@@ -740,37 +746,8 @@ def main():
     train_time = 0
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
     global_step = 0
-    for epoch in epochs:
-        # ======================== Training ================================
-        train_start = time.time()
-
-        # Create sampling rng
-        rng, input_rng = jax.random.split(rng)
-        train_metrics = []
-
-        # Generate an epoch by shuffling sampling indices from the train dataset
-        train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
-        steps_per_epoch = len(train_dataset) // train_batch_size
-        # train
-        for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
-            global_step += 1
-            batch = next(train_loader)
-            state, train_metric = p_train_step(state, batch)
-            train_metrics.append(train_metric)
-
-            if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
-                for k, v in unreplicate(train_metric).items():
-                    wandb.log({"train/step": global_step})
-                    wandb.log({f"train/{k}": jax.device_get(v)})
-
-        train_time += time.time() - train_start
-
-        train_metric = unreplicate(train_metric)
-
-        epochs.write(
-            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
-        )
 
+    def run_evaluation():
         # ======================== Evaluating ==============================
         eval_metrics = []
         if training_args.do_eval:
@@ -797,17 +774,60 @@ def main():
             eval_metrics = get_metrics(eval_metrics)
             eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
 
+            if jax.process_index() == 0:
+                for k, v in eval_metrics.items():
+                    wandb.log({"eval/step": global_step})
+                    wandb.log({f"eval/{k}": jax.device_get(v)})
+
             # compute ROUGE metrics
             rouge_desc = ""
-            # if data_args.predict_with_generate:
-            #     rouge_metrics = compute_metrics(eval_preds, eval_labels)
-            #     eval_metrics.update(rouge_metrics)
-            #     rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
+            # if data_args.predict_with_generate:
+            #     rouge_metrics = compute_metrics(eval_preds, eval_labels)
+            #     eval_metrics.update(rouge_metrics)
+            #     rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
 
             # Print metrics and update progress bar
             desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
             epochs.write(desc)
             epochs.desc = desc
+            return eval_metrics
+
+    for epoch in epochs:
+        # ======================== Training ================================
+        train_start = time.time()
+
+        # Create sampling rng
+        rng, input_rng = jax.random.split(rng)
+        train_metrics = []
+
+        # Generate an epoch by shuffling sampling indices from the train dataset
+        train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
+        steps_per_epoch = len(train_dataset) // train_batch_size
+        # train
+        for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
+            global_step += 1
+            batch = next(train_loader)
+            state, train_metric = p_train_step(state, batch)
+            train_metrics.append(train_metric)
+
+            if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
+                print("logging train loss")
+                for k, v in unreplicate(train_metric).items():
+                    wandb.log({"train/step": global_step})
+                    wandb.log({f"train/{k}": jax.device_get(v)})
+
+            if global_step % data_args.eval_interval == 0 and jax.process_index() == 0:
+                run_evaluation()
+
+        train_time += time.time() - train_start
+
+        train_metric = unreplicate(train_metric)
+
+        epochs.write(
+            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+        )
+
+    eval_metrics = run_evaluation()
 
     # Save metrics
     if has_tensorboard and jax.process_index() == 0:
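
Because eval_interval is an ordinary DataTrainingArguments field, it should surface as a --eval_interval command-line flag when the script is launched, assuming the usual HfArgumentParser convention of the Hugging Face Flax examples (this diff does not show the parser itself). A minimal sketch of that assumption, reusing only the field added here:

# Hypothetical sanity check that the new field parses from the CLI,
# assuming the script builds its args with HfArgumentParser (standard
# in the HF Flax examples). Field name, default, and help are from this diff.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class DataTrainingArguments:
    eval_interval: Optional[int] = field(
        default=40,
        metadata={"help": "Evaluation will be performed every eval_interval steps"},
    )

(data_args,) = HfArgumentParser(DataTrainingArguments).parse_args_into_dataclasses(
    args=["--eval_interval", "80"]
)
print(data_args.eval_interval)  # 80

The default of 40 steps makes the wandb eval curves update far less often than the train metrics gated by log_interval, which matches the intent of the commit message: evaluation is the expensive pass, so it runs only "every so often".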