Spaces:
Running
Running
fix: typos
Browse files
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -643,7 +643,7 @@ def main():
|
|
643 |
state = TrainState.create(
|
644 |
apply_fn=model.__call__,
|
645 |
params=model.params,
|
646 |
-
tx=
|
647 |
dropout_rng=dropout_rng,
|
648 |
grad_accum=jax.tree_map(jnp.zeros_like, model.params),
|
649 |
optimizer_step=0,
|
@@ -755,7 +755,7 @@ def main():
|
|
755 |
|
756 |
if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
|
757 |
for k, v in unreplicate(train_metric).items():
|
758 |
-
wandb.log(f
|
759 |
|
760 |
train_time += time.time() - train_start
|
761 |
|
|
|
643 |
state = TrainState.create(
|
644 |
apply_fn=model.__call__,
|
645 |
params=model.params,
|
646 |
+
tx=optimizer,
|
647 |
dropout_rng=dropout_rng,
|
648 |
grad_accum=jax.tree_map(jnp.zeros_like, model.params),
|
649 |
optimizer_step=0,
|
|
|
755 |
|
756 |
if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
|
757 |
for k, v in unreplicate(train_metric).items():
|
758 |
+
wandb.log({f"train/{k}": jax.device_get(v)}, step=global_step)
|
759 |
|
760 |
train_time += time.time() - train_start
|
761 |
|