mrcuddle commited on
Commit
980a843
1 Parent(s): df39ff5

Update hf_merge.py

Browse files
Files changed (1) hide show
  1. hf_merge.py +34 -14
hf_merge.py CHANGED
@@ -121,19 +121,33 @@ class ModelMerger:
121
  return True # for safe
122
 
123
 
124
- def upload_model(new_repo_id, diffusers_folder, is_private, token):
125
- from huggingface_hub import HfApi
126
- api = HfApi(token=token)
127
- try:
128
- api.create_repo(repo_id=new_repo_id, token=token)
129
- api.upload_folder(repo_id=new_repo_id, folder_path=output_dir, token=token)
130
- url = f"https://huggingface.co/{repo_name}"
131
- except Exception as e:
132
- print(f"Error: Failed to upload to {repo_name}.")
133
- print(e)
134
- return ""
135
- return url
136
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  def get_max_vocab_size(repo_list):
138
  """
139
  Get the maximum vocabulary size from a list of repositories.
@@ -211,7 +225,13 @@ def main():
211
  if not args.token or not args.repo:
212
  logging.error("Error: HuggingFace token and repo name are required for uploading.")
213
  else:
214
- model_merger.upload_model(output_dir, args.repo, args.commit_message)
 
 
 
 
 
 
215
 
216
  if __name__ == "__main__":
217
  main()
 
121
  return True # for safe
122
 
123
 
124
+ def upload_model(self, output_dir, commit_message):
125
+ """
126
+ Upload the merged model to HuggingFace.
127
+
128
+ Args:
129
+ output_dir (str): The path to the output directory containing the merged model.
130
+ commit_message (str): The commit message for the upload.
131
+ """
132
+ try:
133
+ # Create a new repository if it doesn't exist
134
+ if not self.api.repo_exists(repo_id=self.repo_id, token=self.token):
135
+ self.api.create_repo(repo_id=self.repo_id, token=self.token, private=True)
136
+
137
+ # Upload the folder to the repository
138
+ self.api.upload_folder(
139
+ repo_id=self.repo_id,
140
+ folder_path=output_dir,
141
+ commit_message=commit_message,
142
+ token=self.token
143
+ )
144
+ url = f"https://huggingface.co/{self.repo_id}"
145
+ logging.info(f"Model uploaded successfully to {url}")
146
+ return url
147
+ except Exception as e:
148
+ logging.error(f"Error: Failed to upload to {self.repo_id}.")
149
+ logging.error(e)
150
+ return ""
151
  def get_max_vocab_size(repo_list):
152
  """
153
  Get the maximum vocabulary size from a list of repositories.
 
225
  if not args.token or not args.repo:
226
  logging.error("Error: HuggingFace token and repo name are required for uploading.")
227
  else:
228
+ url = model_merger.upload_model(output_dir, args.commit_message)
229
+ if url:
230
+ logging.info(f"Model uploaded successfully to {url}")
231
+
232
+ if __name__ == "__main__":
233
+ main()
234
+
235
 
236
  if __name__ == "__main__":
237
  main()