boris commited on
Commit
d368fb6
1 Parent(s): d5d442a

feat: add bucket reference to artifact

Browse files
Files changed (1) hide show
  1. tools/train/train.py +9 -4
tools/train/train.py CHANGED
@@ -135,11 +135,12 @@ class ModelArguments:
135
  artifact = wandb.run.use_artifact(state_artifact)
136
  else:
137
  artifact = wandb.Api().artifact(state_artifact)
138
- artifact_dir = artifact.download(tmp_dir)
139
  if artifact.metadata.get("bucket_path"):
 
140
  self.restore_state = artifact.metadata["bucket_path"]
141
  else:
142
- self.restore_state = Path(artifact_dir) / "opt_state.msgpack"
 
143
 
144
  if self.restore_state.startswith("gs://"):
145
  bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
@@ -1130,7 +1131,9 @@ def main():
1130
  type="DalleBart_model",
1131
  metadata=metadata,
1132
  )
1133
- if not use_bucket:
 
 
1134
  for filename in [
1135
  "config.json",
1136
  "flax_model.msgpack",
@@ -1153,7 +1156,9 @@ def main():
1153
  type="DalleBart_state",
1154
  metadata=metadata,
1155
  )
1156
- if not use_bucket:
 
 
1157
  artifact_state.add_file(
1158
  f"{Path(training_args.output_dir) / 'opt_state.msgpack'}"
1159
  )
 
135
  artifact = wandb.run.use_artifact(state_artifact)
136
  else:
137
  artifact = wandb.Api().artifact(state_artifact)
 
138
  if artifact.metadata.get("bucket_path"):
139
+ # we will read directly file contents
140
  self.restore_state = artifact.metadata["bucket_path"]
141
  else:
142
+ artifact_dir = artifact.download(tmp_dir)
143
+ self.restore_state = str(Path(artifact_dir) / "opt_state.msgpack")
144
 
145
  if self.restore_state.startswith("gs://"):
146
  bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
 
1131
  type="DalleBart_model",
1132
  metadata=metadata,
1133
  )
1134
+ if use_bucket:
1135
+ artifact.add_reference(metadata["bucket_path"])
1136
+ else:
1137
  for filename in [
1138
  "config.json",
1139
  "flax_model.msgpack",
 
1156
  type="DalleBart_state",
1157
  metadata=metadata,
1158
  )
1159
+ if use_bucket:
1160
+ artifact_state.add_reference(metadata["bucket_path"])
1161
+ else:
1162
  artifact_state.add_file(
1163
  f"{Path(training_args.output_dir) / 'opt_state.msgpack'}"
1164
  )