multimodalart HF staff commited on
Commit
3cfa62d
1 Parent(s): 83eedef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -19
app.py CHANGED
@@ -229,11 +229,10 @@ h3{margin-top: 0}
229
  #component-1{text-align:center}
230
  .main_ui_logged_out{opacity: 0.3; pointer-events: none}
231
  .tabitem{border: 0px}
 
232
  """
233
 
234
- def swap_visibilty(profile: Union[gr.OAuthProfile, None], oauth_token: Union[gr.OAuthToken, None]):
235
- user = whoami(oauth_token.token)
236
- print(user)
237
  if is_spaces:
238
  if profile is None:
239
  return gr.update(elem_classes=["main_ui_logged_out"])
@@ -242,14 +241,23 @@ def swap_visibilty(profile: Union[gr.OAuthProfile, None], oauth_token: Union[gr.
242
  else:
243
  return gr.update(elem_classes=["main_ui_logged_in"])
244
 
245
- def update_pricing(steps):
246
- seconds_per_iteration = 7.54
247
- total_seconds = (steps * seconds_per_iteration) + 240
248
- cost_per_second = 0.80/60/60
249
- cost = round(cost_per_second * total_seconds, 2)
250
- cost_preview = f'''To train this LoRA, a paid L4 GPU will be hooked under the hood during training and then removed once finished.
251
- ## Estimated to cost <b>< US$ {str(cost)}</b> for {round(int(total_seconds)/60, 2)} minutes with your current train settings <small>({int(steps)} iterations at {seconds_per_iteration}s/it)</small>'''
252
- return gr.update(visible=True), cost_preview
 
 
 
 
 
 
 
 
 
253
 
254
  with gr.Blocks(theme=theme, css=css) as demo:
255
  gr.Markdown(
@@ -330,13 +338,14 @@ with gr.Blocks(theme=theme, css=css) as demo:
330
  sample_1 = gr.Textbox(label="Test prompt 1")
331
  sample_2 = gr.Textbox(label="Test prompt 2")
332
  sample_3 = gr.Textbox(label="Test prompt 3")
333
- with gr.Group(visible=False) as cost_preview:
334
- cost_preview_info = gr.Markdown()
 
335
  output_components.append(sample)
336
  output_components.append(sample_1)
337
  output_components.append(sample_2)
338
  output_components.append(sample_3)
339
- start = gr.Button("Start training")
340
  progress_area = gr.Markdown("")
341
 
342
  with gr.Tab("Train on your device" if is_spaces else "Instructions"):
@@ -387,13 +396,13 @@ with gr.Blocks(theme=theme, css=css) as demo:
387
  ).then(
388
  update_pricing,
389
  inputs=[steps],
390
- outputs=[cost_preview, cost_preview_info]
391
  )
392
-
393
- steps.change(
394
- update_pricing,
395
  inputs=[steps],
396
- outputs=[cost_preview, cost_preview_info]
397
  )
398
 
399
  start.click(fn=create_dataset, inputs=[images] + caption_list, outputs=dataset_folder).then(
 
229
  #component-1{text-align:center}
230
  .main_ui_logged_out{opacity: 0.3; pointer-events: none}
231
  .tabitem{border: 0px}
232
+ #cost_preview_info{padding: .5em}
233
  """
234
 
235
+ def swap_visibilty(profile: Union[gr.OAuthProfile, None]):
 
 
236
  if is_spaces:
237
  if profile is None:
238
  return gr.update(elem_classes=["main_ui_logged_out"])
 
241
  else:
242
  return gr.update(elem_classes=["main_ui_logged_in"])
243
 
244
+ def update_pricing(steps, oauth_token: Union[gr.OAuthToken, None]):
245
+ if(oauth_token and is_spaces):
246
+ user = whoami(oauth_token.token)
247
+ seconds_per_iteration = 7.54
248
+ total_seconds = (steps * seconds_per_iteration) + 240
249
+ cost_per_second = 0.80/60/60
250
+ cost = round(cost_per_second * total_seconds, 2)
251
+ cost_preview = f'''To train this LoRA, a paid L4 GPU will be hooked under the hood during training and then removed once finished.
252
+ ### Estimated to cost <b>< US$ {str(cost)}</b> for {round(int(total_seconds)/60, 2)} minutes with your current train settings <small>({int(steps)} iterations at {seconds_per_iteration}s/it)</small>'''
253
+ if(user["canPay"]):
254
+ return gr.update(visible=True), cost_preview, gr.update(visible=False), gr.update(visible=True)
255
+ else:
256
+ pay_disclaimer = f'''<b>## ⚠️ {user.name}, your account doesn't have a payment method. Set one up <a href='https://huggingface.co/settings/billing/payment' target='_blank'>here</a> and come back here to train your LoRA<br>'''
257
+ return gr.update(visible=True),
258
+ return gr.update(visible=True), pay_disclaimer+cost_preview, gr.update(visible=True), gr.update(visible=False)
259
+ else:
260
+ return gr.update(visible=False), "", gr.update(visible=False), gr.update(visible=True)
261
 
262
  with gr.Blocks(theme=theme, css=css) as demo:
263
  gr.Markdown(
 
338
  sample_1 = gr.Textbox(label="Test prompt 1")
339
  sample_2 = gr.Textbox(label="Test prompt 2")
340
  sample_3 = gr.Textbox(label="Test prompt 3")
341
+ with gr.Column(visible=False) as cost_preview:
342
+ cost_preview_info = gr.Markdown(elem_id="cost_preview_info")
343
+ payment_update = gr.Button("I have set up a payment method", visible=False)
344
  output_components.append(sample)
345
  output_components.append(sample_1)
346
  output_components.append(sample_2)
347
  output_components.append(sample_3)
348
+ start = gr.Button("Start training", visible=False)
349
  progress_area = gr.Markdown("")
350
 
351
  with gr.Tab("Train on your device" if is_spaces else "Instructions"):
 
396
  ).then(
397
  update_pricing,
398
  inputs=[steps],
399
+ outputs=[cost_preview, cost_preview_info, payment_update, start]
400
  )
401
+ gr.on(
402
+ triggers=[steps.change, payment_update.click],
403
+ fn=update_pricing,
404
  inputs=[steps],
405
+ outputs=[cost_preview, cost_preview_info, payment_update, start]
406
  )
407
 
408
  start.click(fn=create_dataset, inputs=[images] + caption_list, outputs=dataset_folder).then(