Spaces:
Running
Running
feat: log more metrics
Browse files- tools/train/train.py +41 -20
tools/train/train.py
CHANGED
@@ -331,14 +331,37 @@ def create_learning_rate_fn(
|
|
331 |
return schedule_fn
|
332 |
|
333 |
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
338 |
}
|
339 |
-
|
340 |
-
|
341 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
342 |
|
343 |
|
344 |
def main():
|
@@ -628,9 +651,10 @@ def main():
|
|
628 |
range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
|
629 |
)
|
630 |
|
|
|
631 |
if jax.process_index() == 0:
|
632 |
# set default x-axis as 'train/step'
|
633 |
-
|
634 |
wandb.define_metric("*", step_metric="train/step")
|
635 |
|
636 |
# add interesting config parameters
|
@@ -672,7 +696,9 @@ def main():
|
|
672 |
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
673 |
|
674 |
# log metrics
|
675 |
-
|
|
|
|
|
676 |
|
677 |
# Print metrics and update progress bar
|
678 |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
|
@@ -772,7 +798,7 @@ def main():
|
|
772 |
for epoch in epochs:
|
773 |
state.replace(epoch=jax_utils.replicate(epoch))
|
774 |
# ======================== Training ================================
|
775 |
-
|
776 |
|
777 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
778 |
train_loader = dataset.dataloader("train", train_batch_size)
|
@@ -797,17 +823,12 @@ def main():
|
|
797 |
step = unreplicate(state.step)
|
798 |
|
799 |
if step % training_args.logging_steps == 0 and jax.process_index() == 0:
|
800 |
-
|
801 |
-
|
802 |
-
# log state parameters
|
803 |
-
state_dict = {
|
804 |
-
k.split("_")[-1]: unreplicate(getattr(state, k))
|
805 |
-
for k in ["epoch", "train_time", "train_samples"]
|
806 |
-
}
|
807 |
-
wandb_log({**metrics, **state_dict}, step=step, prefix="train")
|
808 |
|
809 |
eval_metrics = None
|
810 |
if training_args.eval_steps and step % training_args.eval_steps == 0:
|
|
|
811 |
eval_metrics = run_evaluation()
|
812 |
|
813 |
if step % training_args.save_steps == 0:
|
@@ -815,8 +836,8 @@ def main():
|
|
815 |
|
816 |
# log final train metrics
|
817 |
if train_metrics is not None:
|
818 |
-
|
819 |
-
|
820 |
|
821 |
epochs.write(
|
822 |
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
|
|
|
331 |
return schedule_fn
|
332 |
|
333 |
|
334 |
+
class MetricsLogger:
|
335 |
+
def __init__(self, state):
|
336 |
+
self.step = state.step
|
337 |
+
self.time = time.perf_counter()
|
338 |
+
|
339 |
+
def get_all_train_metrics(self, train_metrics, state):
|
340 |
+
"""Make a dict of training metrics to be logged"""
|
341 |
+
metrics = unreplicate(train_metrics)
|
342 |
+
# get state parameters
|
343 |
+
state_dict = {
|
344 |
+
k.split("_")[-1]: unreplicate(getattr(state, k))
|
345 |
+
for k in ["epoch", "train_time", "train_samples"]
|
346 |
}
|
347 |
+
# timing metrics
|
348 |
+
new_step = int(unreplicate(state.step))
|
349 |
+
new_time = time.perf_counter()
|
350 |
+
time_per_step = (new_time - self.time) / (new_step - self.step)
|
351 |
+
self.step = new_step
|
352 |
+
self.time = new_time
|
353 |
+
return {**metrics, **state_dict, "time_per_step": time_per_step}
|
354 |
+
|
355 |
+
@staticmethod
|
356 |
+
def log(metrics, step=None, prefix=None):
|
357 |
+
if jax.process_index() == 0:
|
358 |
+
log_metrics = {
|
359 |
+
f"{prefix}/{k}" if prefix is not None else k: v
|
360 |
+
for k, v in metrics.items()
|
361 |
+
}
|
362 |
+
if step is not None:
|
363 |
+
log_metrics["train/step"] = step
|
364 |
+
wandb.log(log_metrics)
|
365 |
|
366 |
|
367 |
def main():
|
|
|
651 |
range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
|
652 |
)
|
653 |
|
654 |
+
metrics_logger = MetricsLogger(state)
|
655 |
if jax.process_index() == 0:
|
656 |
# set default x-axis as 'train/step'
|
657 |
+
metrics_logger.log({}, step=state.step)
|
658 |
wandb.define_metric("*", step_metric="train/step")
|
659 |
|
660 |
# add interesting config parameters
|
|
|
696 |
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
697 |
|
698 |
# log metrics
|
699 |
+
metrics_logger.log(
|
700 |
+
eval_metrics, step=unreplicate(state.step), prefix="eval"
|
701 |
+
)
|
702 |
|
703 |
# Print metrics and update progress bar
|
704 |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
|
|
|
798 |
for epoch in epochs:
|
799 |
state.replace(epoch=jax_utils.replicate(epoch))
|
800 |
# ======================== Training ================================
|
801 |
+
metrics_logger.log({"train/epoch": epoch}, step=unreplicate(state.step))
|
802 |
|
803 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
804 |
train_loader = dataset.dataloader("train", train_batch_size)
|
|
|
823 |
step = unreplicate(state.step)
|
824 |
|
825 |
if step % training_args.logging_steps == 0 and jax.process_index() == 0:
|
826 |
+
all_metrics = metrics_logger.get_all_train_metrics(train_metrics, state)
|
827 |
+
metrics_logger.log(all_metrics, step=step, prefix="train")
|
|
|
|
|
|
|
|
|
|
|
|
|
828 |
|
829 |
eval_metrics = None
|
830 |
if training_args.eval_steps and step % training_args.eval_steps == 0:
|
831 |
+
return
|
832 |
eval_metrics = run_evaluation()
|
833 |
|
834 |
if step % training_args.save_steps == 0:
|
|
|
836 |
|
837 |
# log final train metrics
|
838 |
if train_metrics is not None:
|
839 |
+
all_metrics = metrics_logger.get_all_train_metrics(train_metrics, state)
|
840 |
+
metrics_logger.log(all_metrics, step=step, prefix="train")
|
841 |
|
842 |
epochs.write(
|
843 |
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
|