Spaces:
Running
Running
feat: add bucket reference to artifact
Browse files- tools/train/train.py +9 -4
tools/train/train.py
CHANGED
@@ -135,11 +135,12 @@ class ModelArguments:
|
|
135 |
artifact = wandb.run.use_artifact(state_artifact)
|
136 |
else:
|
137 |
artifact = wandb.Api().artifact(state_artifact)
|
138 |
-
artifact_dir = artifact.download(tmp_dir)
|
139 |
if artifact.metadata.get("bucket_path"):
|
|
|
140 |
self.restore_state = artifact.metadata["bucket_path"]
|
141 |
else:
|
142 |
-
|
|
|
143 |
|
144 |
if self.restore_state.startswith("gs://"):
|
145 |
bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
|
@@ -1130,7 +1131,9 @@ def main():
|
|
1130 |
type="DalleBart_model",
|
1131 |
metadata=metadata,
|
1132 |
)
|
1133 |
-
if
|
|
|
|
|
1134 |
for filename in [
|
1135 |
"config.json",
|
1136 |
"flax_model.msgpack",
|
@@ -1153,7 +1156,9 @@ def main():
|
|
1153 |
type="DalleBart_state",
|
1154 |
metadata=metadata,
|
1155 |
)
|
1156 |
-
if
|
|
|
|
|
1157 |
artifact_state.add_file(
|
1158 |
f"{Path(training_args.output_dir) / 'opt_state.msgpack'}"
|
1159 |
)
|
|
|
135 |
artifact = wandb.run.use_artifact(state_artifact)
|
136 |
else:
|
137 |
artifact = wandb.Api().artifact(state_artifact)
|
|
|
138 |
if artifact.metadata.get("bucket_path"):
|
139 |
+
# we will read directly file contents
|
140 |
self.restore_state = artifact.metadata["bucket_path"]
|
141 |
else:
|
142 |
+
artifact_dir = artifact.download(tmp_dir)
|
143 |
+
self.restore_state = str(Path(artifact_dir) / "opt_state.msgpack")
|
144 |
|
145 |
if self.restore_state.startswith("gs://"):
|
146 |
bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
|
|
|
1131 |
type="DalleBart_model",
|
1132 |
metadata=metadata,
|
1133 |
)
|
1134 |
+
if use_bucket:
|
1135 |
+
artifact.add_reference(metadata["bucket_path"])
|
1136 |
+
else:
|
1137 |
for filename in [
|
1138 |
"config.json",
|
1139 |
"flax_model.msgpack",
|
|
|
1156 |
type="DalleBart_state",
|
1157 |
metadata=metadata,
|
1158 |
)
|
1159 |
+
if use_bucket:
|
1160 |
+
artifact_state.add_reference(metadata["bucket_path"])
|
1161 |
+
else:
|
1162 |
artifact_state.add_file(
|
1163 |
f"{Path(training_args.output_dir) / 'opt_state.msgpack'}"
|
1164 |
)
|