Spaces:
Running
Running
fix: OOM with checkpoints
Browse files
dev/seq2seq/run_seq2seq_flax.py
CHANGED
@@ -262,15 +262,15 @@ class TrainState(train_state.TrainState):
|
|
262 |
def restore_state(self, artifact_dir):
|
263 |
# restore optimizer state
|
264 |
with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
|
265 |
-
|
266 |
|
267 |
# restore steps
|
268 |
with (Path(artifact_dir) / "training_state.json").open("r") as f:
|
269 |
training_state = json.load(f)
|
270 |
-
|
271 |
|
272 |
# replace state
|
273 |
-
return self.replace(step=
|
274 |
|
275 |
|
276 |
class CustomFlaxBartModule(FlaxBartModule):
|
@@ -802,6 +802,7 @@ def main():
|
|
802 |
|
803 |
# Replicate the train state on each device
|
804 |
state = state.replicate()
|
|
|
805 |
|
806 |
logger.info("***** Running training *****")
|
807 |
logger.info(f" Num examples = {len_train_dataset}")
|
|
|
262 |
def restore_state(self, artifact_dir):
    """Load the saved step and optimizer state from ``artifact_dir``.

    Two files are read from the directory:

    * ``opt_state.msgpack`` -- its bytes are handed to ``from_bytes``
      together with the current ``self.opt_state``.
    * ``training_state.json`` -- only its ``"step"`` entry is used.

    Returns:
        The result of ``self.replace(step=..., opt_state=...)`` built from
        the two restored values.
    """
    checkpoint_dir = Path(artifact_dir)

    # optimizer state: read the raw bytes, then decode them
    with (checkpoint_dir / "opt_state.msgpack").open("rb") as opt_file:
        serialized_opt_state = opt_file.read()
    restored_opt_state = from_bytes(self.opt_state, serialized_opt_state)

    # step counter stored next to the optimizer state
    with (checkpoint_dir / "training_state.json").open("r") as state_file:
        restored_step = json.load(state_file)["step"]

    # swap both restored values into the state in a single replace call
    return self.replace(step=restored_step, opt_state=restored_opt_state)
|
274 |
|
275 |
|
276 |
class CustomFlaxBartModule(FlaxBartModule):
|
|
|
802 |
|
803 |
# Replicate the train state on each device
|
804 |
state = state.replicate()
|
805 |
+
del model._params
|
806 |
|
807 |
logger.info("***** Running training *****")
|
808 |
logger.info(f" Num examples = {len_train_dataset}")
|