Spaces:
Running
on
Zero
Running
on
Zero
Update hf_merge.py
Browse files- hf_merge.py +34 -14
hf_merge.py
CHANGED
@@ -121,19 +121,33 @@ class ModelMerger:
|
|
121 |
return True # for safe
|
122 |
|
123 |
|
124 |
-
def upload_model(
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
def get_max_vocab_size(repo_list):
|
138 |
"""
|
139 |
Get the maximum vocabulary size from a list of repositories.
|
@@ -211,7 +225,13 @@ def main():
|
|
211 |
if not args.token or not args.repo:
|
212 |
logging.error("Error: HuggingFace token and repo name are required for uploading.")
|
213 |
else:
|
214 |
-
model_merger.upload_model(output_dir, args.
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
|
216 |
if __name__ == "__main__":
|
217 |
main()
|
|
|
121 |
return True # for safe
|
122 |
|
123 |
|
124 |
+
def upload_model(self, output_dir, commit_message):
|
125 |
+
"""
|
126 |
+
Upload the merged model to HuggingFace.
|
127 |
+
|
128 |
+
Args:
|
129 |
+
output_dir (str): The path to the output directory containing the merged model.
|
130 |
+
commit_message (str): The commit message for the upload.
|
131 |
+
"""
|
132 |
+
try:
|
133 |
+
# Create a new repository if it doesn't exist
|
134 |
+
if not self.api.repo_exists(repo_id=self.repo_id, token=self.token):
|
135 |
+
self.api.create_repo(repo_id=self.repo_id, token=self.token, private=True)
|
136 |
+
|
137 |
+
# Upload the folder to the repository
|
138 |
+
self.api.upload_folder(
|
139 |
+
repo_id=self.repo_id,
|
140 |
+
folder_path=output_dir,
|
141 |
+
commit_message=commit_message,
|
142 |
+
token=self.token
|
143 |
+
)
|
144 |
+
url = f"https://huggingface.co/{self.repo_id}"
|
145 |
+
logging.info(f"Model uploaded successfully to {url}")
|
146 |
+
return url
|
147 |
+
except Exception as e:
|
148 |
+
logging.error(f"Error: Failed to upload to {self.repo_id}.")
|
149 |
+
logging.error(e)
|
150 |
+
return ""
|
151 |
def get_max_vocab_size(repo_list):
|
152 |
"""
|
153 |
Get the maximum vocabulary size from a list of repositories.
|
|
|
225 |
if not args.token or not args.repo:
|
226 |
logging.error("Error: HuggingFace token and repo name are required for uploading.")
|
227 |
else:
|
228 |
+
url = model_merger.upload_model(output_dir, args.commit_message)
|
229 |
+
if url:
|
230 |
+
logging.info(f"Model uploaded successfully to {url}")
|
231 |
+
|
232 |
+
if __name__ == "__main__":
|
233 |
+
main()
|
234 |
+
|
235 |
|
236 |
if __name__ == "__main__":
|
237 |
main()
|