skytnt commited on
Commit
98bc719
·
1 Parent(s): 6532db1

add retry to avoid runtime error

Browse files
Files changed (1) hide show
  1. app.py +15 -3
app.py CHANGED
@@ -199,6 +199,18 @@ def load_javascript(dir="javascript"):
199
  gr.routes.templates.TemplateResponse = template_response
200
 
201
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
203
  40: "Blush", 48: "Orchestra"}
204
  patch2number = {v: k for k, v in MIDI.Number2patch.items()}
@@ -210,7 +222,7 @@ if __name__ == "__main__":
210
  parser.add_argument("--port", type=int, default=7860, help="gradio server port")
211
  parser.add_argument("--max-gen", type=int, default=1024, help="max")
212
  opt = parser.parse_args()
213
- soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
214
  models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
215
  "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
216
  "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
@@ -219,8 +231,8 @@ if __name__ == "__main__":
219
  tokenizer = MIDITokenizer()
220
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
221
  for name, (repo_id, path) in models_info.items():
222
- model_base_path = hf_hub_download(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
223
- model_token_path = hf_hub_download(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
224
  model_base = rt.InferenceSession(model_base_path, providers=providers)
225
  model_token = rt.InferenceSession(model_token_path, providers=providers)
226
  models[name] = [model_base, model_token]
 
199
  gr.routes.templates.TemplateResponse = template_response
200
 
201
 
202
+ def hf_hub_download_retry(repo_id, filename):
203
+ retry = 0
204
+ err = None
205
+ while retry < 30:
206
+ try:
207
+ return hf_hub_download(repo_id=repo_id, filename=filename)
208
+ except Exception as e:
209
+ err = e
210
+ retry += 1
211
+ if err:
212
+ raise err
213
+
214
  number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
215
  40: "Blush", 48: "Orchestra"}
216
  patch2number = {v: k for k, v in MIDI.Number2patch.items()}
 
222
  parser.add_argument("--port", type=int, default=7860, help="gradio server port")
223
  parser.add_argument("--max-gen", type=int, default=1024, help="max")
224
  opt = parser.parse_args()
225
+ soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
226
  models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
227
  "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
228
  "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
 
231
  tokenizer = MIDITokenizer()
232
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
233
  for name, (repo_id, path) in models_info.items():
234
+ model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
235
+ model_token_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
236
  model_base = rt.InferenceSession(model_base_path, providers=providers)
237
  model_token = rt.InferenceSession(model_token_path, providers=providers)
238
  models[name] = [model_base, model_token]