Ashoka74 committed
Commit c15f1a2
1 Parent(s): 4193881

Update app_merged.py

Files changed (1)
  1. app_merged.py +28 -113
app_merged.py CHANGED
```diff
@@ -841,8 +841,6 @@ def use_orientation(selected_image:gr.SelectData):
 def process_image(input_image, input_text):
     """Main processing function for the Gradio interface"""
 
-
-
     if isinstance(input_image, Image.Image):
         input_image = np.array(input_image)
 
@@ -857,7 +855,6 @@ def process_image(input_image, input_text):
     HEIGHT = 768
     WIDTH = 768
 
-
     # Initialize DDS client
     config = Config(API_TOKEN)
     client = Client(config)
@@ -867,8 +864,6 @@ def process_image(input_image, input_text):
     class_name_to_id = {name: id for id, name in enumerate(classes)}
     class_id_to_name = {id: name for name, id in class_name_to_id.items()}
 
-
-
     # Save input image to temp file and get URL
     with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
         cv2.imwrite(tmpfile.name, input_image)
@@ -884,11 +879,11 @@ def process_image(input_image, input_text):
 
     if len(input_text) == 0:
         task = DinoxTask(
-            image_url=image_url,
-            prompts=[TextPrompt(text="<prompt_free>")],
-            # targets=[DetectionTarget.BBox, DetectionTarget.Mask]
+            image_url=image_url,
+            prompts=[TextPrompt(text="<prompt_free>")],
+            # targets=[DetectionTarget.BBox, DetectionTarget.Mask]
         )
-
+
         client.run_task(task)
         predictions = task.result.objects
         classes = [pred.category for pred in predictions]
```
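When the text box is empty, the branch above falls back to DINO-X's open-vocabulary mode: the literal prompt `<prompt_free>` asks the hosted model to name whatever it finds, and the returned categories become the class list. A minimal sketch of that round trip, with import paths assumed to match the dds-cloudapi-sdk imports already at the top of app_merged.py (they may differ between SDK releases):

```python
# Sketch only: mirrors the app's prompt-free DINO-X call. Import paths are
# assumptions based on app_merged.py's existing dds-cloudapi-sdk usage.
from dds_cloudapi_sdk import Config, Client, TextPrompt
from dds_cloudapi_sdk.tasks.dinox import DinoxTask

def detect_prompt_free(image_url: str, api_token: str):
    """Run open-vocabulary detection on an already-uploaded image URL."""
    client = Client(Config(api_token))
    task = DinoxTask(
        image_url=image_url,
        prompts=[TextPrompt(text="<prompt_free>")],  # no fixed class list
    )
    client.run_task(task)  # blocks until the hosted model responds
    # Each returned object exposes category, score, bbox, and mask fields
    return [(obj.category, obj.score, obj.bbox) for obj in task.result.objects]
```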
```diff
@@ -931,38 +926,24 @@ def process_image(input_image, input_text):
         if len(detections) > 0:
             # Get first mask
             first_mask = detections.mask[0]
-
+
             # Get original RGB image
             img = input_image.copy()
-
             H, W, C = img.shape
-
-            # Create RGBA image
+
+            # Create RGBA image with default 255 alpha
             alpha = np.zeros((H, W, 1), dtype=np.uint8)
-
-            alpha[first_mask] = 255
-
-            # rgba = np.dstack((img, alpha)).astype(np.uint8)
-
-            # Crop to mask bounds to minimize image size
-            # y_indices, x_indices = np.where(first_mask)
-            # y_min, y_max = y_indices.min(), y_indices.max()
-            # x_min, x_max = x_indices.min(), x_indices.max()
-
-            # Crop the RGBA image
-            # cropped_rgba = rgba[y_min:y_max+1, x_min:x_max+1]
-
-            # Set extracted foreground for mask mover
-            # mask_mover.set_extracted_fg(cropped_rgba)
+            alpha[~first_mask] = 0 # 128 # for semi-transparency background
+            alpha[first_mask] = 255 # Make the foreground opaque
+            alpha = alpha.squeeze(-1) # Remove singleton dimension to become 2D
+            rgba = np.dstack((img, alpha)).astype(np.uint8)
 
-            # alpha = img[..., 3] > 0
-            H, W = alpha.shape
             # get the bounding box of alpha
             y, x = np.where(alpha > 0)
             y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
             x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
-
-            image_center = img[y0:y1, x0:x1]
+
+            image_center = rgba[y0:y1, x0:x1]
             # resize the longer side to H * 0.9
             H, W, _ = image_center.shape
             if H > W:
@@ -972,7 +953,7 @@ def process_image(input_image, input_text):
                 H = int(H * (WIDTH * 0.9) / W)
                 W = int(WIDTH * 0.9)
 
-            image_center = np.array(Image.fromarray(image_center).resize((W, H)))
+            image_center = np.array(Image.fromarray(image_center).resize((W, H), Image.LANCZOS))
             # pad to H, W
             start_h = (HEIGHT - H) // 2
             start_w = (WIDTH - W) // 2
@@ -982,10 +963,9 @@ def process_image(input_image, input_text):
             image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
             image = (image * 255).clip(0, 255).astype(np.uint8)
             image = Image.fromarray(image)
-
-            return annotated_frame, image, gr.update(visible=False), gr.update(visible=False)
 
-
+            return annotated_frame, image, gr.update(visible=False), gr.update(visible=False)
+            return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
     else:
         # Run DINO-X detection
         task = DinoxTask(
```
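The hunk above is the heart of the commit: the commented-out cropping experiments give way to a straight-line pass that builds an alpha channel from the first mask, crops to the alpha's bounding box, resizes the longer side to 90% of the 768x768 canvas, center-pads, and composites over 50% gray. A self-contained sketch of that pipeline (the helper name and signature are illustrative, not from the app):

```python
# Illustrative restatement of the commit's cutout-and-recenter pipeline;
# recenter_cutout() is a hypothetical helper, not a function in app_merged.py.
import numpy as np
from PIL import Image

def recenter_cutout(img: np.ndarray, mask: np.ndarray,
                    height: int = 768, width: int = 768) -> Image.Image:
    """img: HxWx3 uint8; mask: HxW bool. Returns the masked object centered
    on a height x width canvas and composited over 50% gray."""
    H, W = mask.shape
    alpha = np.where(mask, 255, 0).astype(np.uint8)  # opaque foreground only
    rgba = np.dstack((img, alpha))                   # HxWx4 uint8

    # Crop to the mask's bounding box, with 1 px of slack clamped to the frame
    y, x = np.where(alpha > 0)
    y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
    x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
    center = rgba[y0:y1, x0:x1]

    # Resize so the longer side fills 90% of the target canvas
    h, w, _ = center.shape
    if h > w:
        w, h = int(w * (height * 0.9) / h), int(height * 0.9)
    else:
        h, w = int(h * (width * 0.9) / w), int(width * 0.9)
    center = np.array(Image.fromarray(center).resize((w, h), Image.LANCZOS))

    # Center-pad onto a transparent canvas, then alpha-composite over gray
    canvas = np.zeros((height, width, 4), dtype=np.float32)
    sh, sw = (height - h) // 2, (width - w) // 2
    canvas[sh:sh + h, sw:sw + w] = center
    canvas /= 255.0
    rgb = canvas[..., :3] * canvas[..., 3:4] + (1 - canvas[..., 3:4]) * 0.5
    return Image.fromarray((rgb * 255).clip(0, 255).astype(np.uint8))
```

Called as `recenter_cutout(input_image, detections.mask[0])`, this reproduces the `image_center`/`image` stages the diff threads through `HEIGHT` and `WIDTH`.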
```diff
@@ -998,18 +978,6 @@ def process_image(input_image, input_text):
         result = task.result
         objects = result.objects
 
-
-
-        # for obj in objects:
-        #     input_boxes.append(obj.bbox)
-        #     confidences.append(obj.score)
-        #     cls_name = obj.category.lower().strip()
-        #     class_names.append(cls_name)
-        #     class_ids.append(class_name_to_id[cls_name])
-
-        # input_boxes = np.array(input_boxes)
-        # class_ids = np.array(class_ids)
-
         predictions = task.result.objects
         classes = [x.strip().lower() for x in input_text.split('.') if x]
         class_name_to_id = {name: id for id, name in enumerate(classes)}
@@ -1037,46 +1005,12 @@ def process_image(input_image, input_text):
             for class_name, confidence
             in zip(class_names, confidences)
         ]
-
-        # Initialize SAM2
-        # torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
-        # if torch.cuda.get_device_properties(0).major >= 8:
-        #     torch.backends.cuda.matmul.allow_tf32 = True
-        #     torch.backends.cudnn.allow_tf32 = True
-
-        # sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=DEVICE)
-        # sam2_predictor = SAM2ImagePredictor(sam2_model)
-        # sam2_predictor.set_image(input_image)
-
-        # sam2_predictor = run_sam_inference(SAM_IMAGE_MODEL, input_image, detections)
-
-
-        # Get masks from SAM2
-        # masks, scores, logits = sam2_predictor.predict(
-        #     point_coords=None,
-        #     point_labels=None,
-        #     box=input_boxes,
-        #     multimask_output=False,
-        # )
-
-        if masks.ndim == 4:
-            masks = masks.squeeze(1)
-
-        # Create visualization
-        # labels = [f"{class_name} {confidence:.2f}"
-        #           for class_name, confidence in zip(class_names, confidences)]
-
-        # detections = sv.Detections(
-        #     xyxy=input_boxes,
-        #     mask=masks.astype(bool),
-        #     class_id=class_ids
-        # )
 
         detections = sv.Detections(
-            xyxy = boxes,
-            mask = masks.astype(bool),
-            class_id = class_ids,
-        )
+            xyxy=boxes,
+            mask=masks.astype(bool),
+            class_id=class_ids,
+        )
 
         box_annotator = sv.BoxAnnotator()
         label_annotator = sv.LabelAnnotator()
```
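With boxes, masks, and class IDs assembled, the diff hands everything to supervision's annotators. A short usage sketch under the same `sv` API the app imports (the wrapper function is illustrative):

```python
# Sketch of the supervision annotation step; annotate_frame() is a
# hypothetical wrapper around the calls app_merged.py makes inline.
import numpy as np
import supervision as sv

def annotate_frame(frame, boxes, masks, class_ids, labels):
    """frame: HxWx3 uint8; boxes: (N, 4) xyxy; masks: (N, H, W) bool."""
    detections = sv.Detections(
        xyxy=np.asarray(boxes, dtype=np.float32),
        mask=np.asarray(masks, dtype=bool),
        class_id=np.asarray(class_ids),
    )
    annotated = sv.BoxAnnotator().annotate(scene=frame.copy(),
                                           detections=detections)
    annotated = sv.LabelAnnotator().annotate(scene=annotated,
                                             detections=detections,
                                             labels=labels)
    return annotated
```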
```diff
@@ -1096,36 +1030,18 @@ def process_image(input_image, input_text):
         img = input_image.copy()
         H, W, C = img.shape
 
-        first_mask = detections.mask[0]
-
-
-
-        # Create RGBA image
+        # Create RGBA image with default 255 alpha
         alpha = np.zeros((H, W, 1), dtype=np.uint8)
-
-        alpha[first_mask] = 255
-
-        # rgba = np.dstack((img, alpha)).astype(np.uint8)
-
-        # Crop to mask bounds to minimize image size
-        # y_indices, x_indices = np.where(first_mask)
-        # y_min, y_max = y_indices.min(), y_indices.max()
-        # x_min, x_max = x_indices.min(), x_indices.max()
-
-        # Crop the RGBA image
-        # cropped_rgba = rgba[y_min:y_max+1, x_min:x_max+1]
-
-        # Set extracted foreground for mask mover
-        # mask_mover.set_extracted_fg(cropped_rgba)
-
-        # alpha = img[..., 3] > 0
-        H, W = alpha.shape
+        alpha[~first_mask] = 0 # 128 for semi-transparency background
+        alpha[first_mask] = 255 # Make the foreground opaque
+        alpha = alpha.squeeze(-1) # Remove singleton dimension to become 2D
+        rgba = np.dstack((img, alpha)).astype(np.uint8)
         # get the bounding box of alpha
         y, x = np.where(alpha > 0)
         y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
         x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
-
-        image_center = img[y0:y1, x0:x1]
+
+        image_center = rgba[y0:y1, x0:x1]
         # resize the longer side to H * 0.9
         H, W, _ = image_center.shape
         if H > W:
@@ -1135,7 +1051,7 @@ def process_image(input_image, input_text):
            H = int(H * (WIDTH * 0.9) / W)
            W = int(WIDTH * 0.9)
 
-        image_center = np.array(Image.fromarray(image_center).resize((W, H)))
+        image_center = np.array(Image.fromarray(image_center).resize((W, H), Image.LANCZOS))
         # pad to H, W
         start_h = (HEIGHT - H) // 2
         start_w = (WIDTH - W) // 2
@@ -1148,7 +1064,6 @@ def process_image(input_image, input_text):
 
         return annotated_frame, image, gr.update(visible=False), gr.update(visible=False)
         return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
-
 
 
 
```