Spaces (status: Running)
Commit: "feat: update defaults" — Browse files
File: dev/seq2seq/run_seq2seq_flax.py (CHANGED)
@@ -253,7 +253,7 @@ class DataTrainingArguments:
         metadata={"help": "Overwrite the cached training and evaluation sets"},
     )
     save_model_steps: Optional[int] = field(
-        default=
+        default=5000,  # about once every 1.5h in our experiments
         metadata={
             "help": "For logging the model more frequently. Used only when `log_model` is set."
         },
@@ -290,9 +290,9 @@ class DataTrainingArguments:


 class TrainState(train_state.TrainState):
-    dropout_rng: jnp.ndarray
-    grad_accum: jnp.ndarray
-    optimizer_step: int
+    dropout_rng: jnp.ndarray = None
+    grad_accum: jnp.ndarray = None
+    optimizer_step: int = None

     def replicate(self):
         return jax_utils.replicate(self).replace(

(Note: the removed line at old line 256 reads only "default=" in the extracted page; the original removed value was truncated by extraction and is reproduced here as shown.)