nikigoli committed on
Commit 1cc5944
1 parent: b812b2c

Changed how the device is allocated

Files changed (1)
  1. app.py +7 -8
app.py CHANGED

@@ -57,6 +57,9 @@ def get_args_parser():
     >> python main.py --output_dir ./gdino_test -c config/cfg_fsc147_vit_b_test.py --eval --datasets config/datasets_fsc147.json --pretrain_model_path ../checkpoints_and_logs/gdino_train/checkpoint_best_regular.pth --options text_encoder_type=checkpoints/bert-base-uncased --sam_tt_norm --crop
     """
     parser = argparse.ArgumentParser("Set transformer detector", add_help=False)
+    parser.add_argument(
+        "--device", default="cuda", help="device to use for inference"
+    )
     parser.add_argument(
         "--options",
         nargs="+",
@@ -108,11 +111,12 @@ def get_args_parser():
     parser.add_argument("--amp", action="store_true", help="Train with mixed precision")
     return parser
 
-@spaces.GPU
 def get_device():
     if torch.cuda.is_available():
+        print("USING GPU")
         return torch.device('cuda')
     else:
+        print("USING CPU")
         return torch.device('cpu')
 
 # Get counting model.
@@ -136,8 +140,6 @@ def build_model_and_transforms(args):
         else:
             raise ValueError("Key {} can used by args only".format(k))
 
-    # Start with model on CPU.
-    args.device = "cpu"
     # fix the seed for reproducibility
     seed = 42
     torch.manual_seed(seed)
@@ -220,7 +222,6 @@ def get_ind_to_filter(text, word_ids, keywords):
 
 @spaces.GPU
 def count(image, text, prompts, state, device):
-    model.to(device)
 
     keywords = "" # do not handle this for now
 
@@ -303,13 +304,11 @@ def count(image, text, prompts, state, device):
         out_label += " " + str(exemplars[0].size()[0]) + " visual exemplars."
     else:
         out_label = "Nothing specified to detect."
-    #model.cpu()
+
     return (gr.Image(output_img, visible=True, label=out_label, show_label=True), gr.Number(label="Predicted Count", visible=True, value=boxes.shape[0]), new_submit_btn, gr.Tab(visible=True), step_3, state)
 
 @spaces.GPU
 def count_main(image, text, prompts, device):
-    model.to(device)
-
     keywords = "" # do not handle this for now
     # Handle no prompt case.
     if prompts is None:
@@ -372,7 +371,7 @@ def count_main(image, text, prompts, device):
         out_label += " " + str(exemplars[0].size()[0]) + " visual exemplars."
     else:
         out_label = "Nothing specified to detect."
-    model.cpu()
+
     return (gr.Image(output_img, visible=True, label=out_label, show_label=True), gr.Number(label="Predicted Count", visible=True, value=boxes.shape[0]))
 
 def remove_label(image):
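
Below is a minimal, self-contained sketch of the device-handling pattern this commit appears to move toward: expose a --device flag, resolve the device once at startup with get_device(), put the model on that device a single time, and keep the @spaces.GPU decorator only on the inference handlers (count, count_main) rather than on get_device(). The CountingModel class, the build_model() helper, and the no-op spaces fallback are hypothetical stand-ins for illustration, not code from this repository.

import argparse

import torch
import torch.nn as nn

try:
    import spaces  # available on Hugging Face Spaces (ZeroGPU)
except ImportError:  # running locally without the Spaces runtime
    class spaces:  # no-op stand-in so @spaces.GPU still works
        @staticmethod
        def GPU(fn):
            return fn


def get_args_parser():
    parser = argparse.ArgumentParser("Set transformer detector", add_help=False)
    # Mirrors the --device option added in this commit.
    parser.add_argument("--device", default="cuda", help="device to use for inference")
    return parser


def get_device():
    # No @spaces.GPU here: this runs once at startup, not per request.
    if torch.cuda.is_available():
        print("USING GPU")
        return torch.device("cuda")
    else:
        print("USING CPU")
        return torch.device("cpu")


class CountingModel(nn.Module):  # hypothetical placeholder for the real model
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 1)

    def forward(self, x):
        return self.backbone(x)


def build_model(device):
    # Build the model and move it to the resolved device once, instead of
    # forcing args.device = "cpu" and calling model.to(device) in each handler.
    model = CountingModel().to(device)
    model.eval()
    return model


device = get_device()
model = build_model(device)


@spaces.GPU  # only handlers that actually run inference keep the decorator
def count(x):
    with torch.no_grad():
        return model(x.to(device))


if __name__ == "__main__":
    args = get_args_parser().parse_args([])  # --device defaults to "cuda"
    print("requested device:", args.device, "| resolved device:", device)
    print(count(torch.randn(2, 8)))

Keeping the model resident on one device avoids the per-request model.to(device) and model.cpu() shuffling that the removed lines performed.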