nikigoli committed on
Commit 31bd1d5
1 Parent(s): e480de7

Changed device to global variable in gradio (ie gr.State instance)

Files changed (1)
  1. app.py +13 -12
app.py CHANGED
@@ -208,26 +208,26 @@ def get_ind_to_filter(text, word_ids, keywords):
     return inds_to_filter
 
 @spaces.GPU
-def count(image, text, prompts, state):
+def count(image, text, prompts, state, device):
     print("state: " + str(state))
     keywords = "" # do not handle this for now
     # Handle no prompt case.
     if prompts is None:
         prompts = {"image": image, "points": []}
     input_image, _ = transform(image, {"exemplars": torch.tensor([])})
-    input_image = input_image.unsqueeze(0).to(args.device)
+    input_image = input_image.unsqueeze(0).to(device)
     exemplars = get_box_inputs(prompts["points"])
     print(exemplars)
     input_image_exemplars, exemplars = transform(prompts["image"], {"exemplars": torch.tensor(exemplars)})
-    input_image_exemplars = input_image_exemplars.unsqueeze(0).to(args.device)
-    exemplars = [exemplars["exemplars"].to(args.device)]
+    input_image_exemplars = input_image_exemplars.unsqueeze(0).to(device)
+    exemplars = [exemplars["exemplars"].to(device)]
 
     with torch.no_grad():
         model_output = model(
             nested_tensor_from_tensor_list(input_image),
             nested_tensor_from_tensor_list(input_image_exemplars),
             exemplars,
-            [torch.tensor([0]).to(args.device) for _ in range(len(input_image))],
+            [torch.tensor([0]).to(device) for _ in range(len(input_image))],
             captions=[text + " ."] * len(input_image),
         )
 
@@ -297,25 +297,25 @@ def count(image, text, prompts, state):
     return (gr.Image(output_img, visible=True, label=out_label, show_label=True), gr.Number(label="Predicted Count", visible=True, value=boxes.shape[0]), new_submit_btn, gr.Tab(visible=True), step_3, state)
 
 @spaces.GPU
-def count_main(image, text, prompts):
+def count_main(image, text, prompts, device):
     keywords = "" # do not handle this for now
     # Handle no prompt case.
     if prompts is None:
         prompts = {"image": image, "points": []}
     input_image, _ = transform(image, {"exemplars": torch.tensor([])})
-    input_image = input_image.unsqueeze(0).to(args.device)
+    input_image = input_image.unsqueeze(0).to(device)
     exemplars = get_box_inputs(prompts["points"])
     print(exemplars)
     input_image_exemplars, exemplars = transform(prompts["image"], {"exemplars": torch.tensor(exemplars)})
-    input_image_exemplars = input_image_exemplars.unsqueeze(0).to(args.device)
-    exemplars = [exemplars["exemplars"].to(args.device)]
+    input_image_exemplars = input_image_exemplars.unsqueeze(0).to(device)
+    exemplars = [exemplars["exemplars"].to(device)]
 
     with torch.no_grad():
         model_output = model(
             nested_tensor_from_tensor_list(input_image),
             nested_tensor_from_tensor_list(input_image_exemplars),
             exemplars,
-            [torch.tensor([0]).to(args.device) for _ in range(len(input_image))],
+            [torch.tensor([0]).to(device) for _ in range(len(input_image))],
             captions=[text + " ."] * len(input_image),
         )
 
@@ -396,6 +396,7 @@ As shown earlier, there are 3 ways to specify the object to count: (1) with text
 
 with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", head="""<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=1">""") as demo:
     state = gr.State(value=[AppSteps.JUST_TEXT])
+    device = gr.State(args.device)
     with gr.Tab("Tutorial"):
         with gr.Row():
             with gr.Column():
@@ -419,7 +420,7 @@ with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", h
         pred_count = gr.Number(label="Predicted Count", visible=False)
         submit_btn = gr.Button("Count", variant="primary", interactive=True)
 
-        submit_btn.click(fn=remove_label, inputs=[detected_instances], outputs=[detected_instances]).then(fn=count, inputs=[input_image, input_text, exemplar_image, state], outputs=[detected_instances, pred_count, submit_btn, step_2, step_3, state])
+        submit_btn.click(fn=remove_label, inputs=[detected_instances], outputs=[detected_instances]).then(fn=count, inputs=[input_image, input_text, exemplar_image, state, device], outputs=[detected_instances, pred_count, submit_btn, step_2, step_3, state])
         exemplar_image.change(check_submit_btn, inputs=[exemplar_image, state], outputs=[submit_btn])
     with gr.Tab("App", visible=True) as main_app:
 
@@ -445,7 +446,7 @@ with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", h
         submit_btn_main = gr.Button("Count", variant="primary")
         clear_btn_main = gr.ClearButton(variant="secondary")
         gr.Examples(label="Examples: click on a row to load the example. Add visual exemplars by drawing boxes on the loaded \"Visual Exemplar Image.\"", examples=examples, inputs=[input_image_main, input_text_main, exemplar_image_main])
-        submit_btn_main.click(fn=remove_label, inputs=[detected_instances_main], outputs=[detected_instances_main]).then(fn=count_main, inputs=[input_image_main, input_text_main, exemplar_image_main], outputs=[detected_instances_main, pred_count_main])
+        submit_btn_main.click(fn=remove_label, inputs=[detected_instances_main], outputs=[detected_instances_main]).then(fn=count_main, inputs=[input_image_main, input_text_main, exemplar_image_main, device], outputs=[detected_instances_main, pred_count_main])
         clear_btn_main.add([input_image_main, input_text_main, exemplar_image_main, detected_instances_main, pred_count_main])
 
 
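
For context, below is a minimal, self-contained sketch of the pattern this commit adopts: the device string is held in a gr.State component and passed to the event handler as an input, rather than read from a module-level args.device inside the @spaces.GPU function. The component and function names (predict, inp, out, btn) are illustrative and not taken from app.py.

# Minimal sketch, assuming only gradio and torch are installed.
import gradio as gr
import torch

def predict(image, device):
    # `device` is the value held by the gr.State component wired in below.
    if image is None:
        return "no image provided"
    tensor = torch.as_tensor(image).unsqueeze(0).to(device)
    return f"input tensor of shape {tuple(tensor.shape)} on {tensor.device}"

with gr.Blocks() as demo:
    device = gr.State("cuda" if torch.cuda.is_available() else "cpu")
    inp = gr.Image()
    out = gr.Textbox()
    btn = gr.Button("Run")
    # The State component is listed alongside the other inputs, mirroring how
    # count(image, text, prompts, state, device) receives `device` in the diff above.
    btn.click(fn=predict, inputs=[inp, device], outputs=[out])

if __name__ == "__main__":
    demo.launch()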