Henry Scheible committed
Commit 54e4e45
1 Parent(s): a34b545

change app.py

Files changed (1):
  1. app.py +7 -244
app.py CHANGED
@@ -66,122 +66,6 @@ def show_anns(anns):
         img[:,:,i] = color_mask[i]
     ax.imshow(np.dstack((img, m*0.35)))
 
-
-# def find_contours(img, color):
-#     low = color - 10
-#     high = color + 10
-
-#     mask = cv2.inRange(img, low, high)
-#     contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
-
-#     print(f"Total Contours: {len(contours)}")
-#     nonempty_contours = list()
-#     for i in range(len(contours)):
-#         if hierarchy[0,i,3] == -1 and cv2.contourArea(contours[i]) > cv2.arcLength(contours[i], True):
-#             nonempty_contours += [contours[i]]
-#     print(f"Nonempty Contours: {len(nonempty_contours)}")
-#     contour_plot = img.copy()
-#     contour_plot = cv2.drawContours(contour_plot, nonempty_contours, -1, (0,255,0), -1)
-
-#     sorted_contours = sorted(nonempty_contours, key=cv2.contourArea, reverse=True)
-
-#     bounding_rects = [cv2.boundingRect(cnt) for cnt in contours]
-
-#     for (i, c) in enumerate(sorted_contours):
-#         M = cv2.moments(c)
-#         cx = int(M['m10']/M['m00'])
-#         cy = int(M['m01']/M['m00'])
-#         cv2.putText(contour_plot, text=str(i), org=(cx, cy),
-#                     fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.25, color=(255,255,255),
-#                     thickness=1, lineType=cv2.LINE_AA)
-
-#     N = len(sorted_contours)
-#     H, W, C = img.shape
-#     boxes_array_xywh = [cv2.boundingRect(cnt) for cnt in sorted_contours]
-#     boxes_array_corners = [[x, y, x+w, y+h] for x, y, w, h in boxes_array_xywh]
-#     boxes = torch.tensor(boxes_array_corners)
-
-#     labels = torch.ones(N)
-#     masks = np.zeros([N, H, W])
-#     for idx in range(len(sorted_contours)):
-#         cnt = sorted_contours[idx]
-#         cv2.drawContours(masks[idx,:,:], [cnt], 0, (255), -1)
-#     masks = masks / 255.0
-#     masks = torch.tensor(masks)
-
-#     # for box in boxes:
-#     #     cv2.rectangle(contour_plot, (box[0].item(), box[1].item()), (box[2].item(), box[3].item()), (255,0,0), 2)
-
-#     return contour_plot, (boxes, masks)
-
-
-# def get_dataset_x(blank_image, filter_size=50, filter_stride=2):
-#     full_image_tensor = torch.tensor(blank_image).type(torch.FloatTensor).permute(2, 0, 1).unsqueeze(0)
-#     num_windows_h = math.floor((full_image_tensor.shape[2] - filter_size) / filter_stride) + 1
-#     num_windows_w = math.floor((full_image_tensor.shape[3] - filter_size) / filter_stride) + 1
-#     windows = torch.nn.functional.unfold(full_image_tensor, (filter_size, filter_size), stride=filter_stride).reshape(
-#         [1, 3, 50, 50, num_windows_h * num_windows_w]).permute([0, 4, 1, 2, 3]).squeeze()
-
-#     dataset_images = [windows[idx] for idx in range(len(windows))]
-#     dataset = list(dataset_images)
-#     return dataset
-
-
-# def get_dataset(labeled_image, blank_image, color, filter_size=50, filter_stride=2, label_size=5):
-#     contour_plot, (blue_boxes, blue_masks) = find_contours(labeled_image, color)
-
-#     mask = torch.sum(blue_masks, 0)
-
-#     label_dim = int((labeled_image.shape[0] - filter_size) / filter_stride + 1)
-#     labels = torch.zeros(label_dim, label_dim)
-#     mask_labels = torch.zeros(label_dim, label_dim, filter_size, filter_size)
-
-#     for lx in range(label_dim):
-#         for ly in range(label_dim):
-#             mask_labels[lx, ly, :, :] = mask[
-#                 lx * filter_stride: lx * filter_stride + filter_size,
-#                 ly * filter_stride: ly * filter_stride + filter_size
-#             ]
-
-#     print(labels.shape)
-#     for box in blue_boxes:
-#         x = int((box[0] + box[2]) / 2)
-#         y = int((box[1] + box[3]) / 2)
-
-#         window_x = int((x - int(filter_size / 2)) / filter_stride)
-#         window_y = int((y - int(filter_size / 2)) / filter_stride)
-
-#         clamp = lambda n, minn, maxn: max(min(maxn, n), minn)
-
-#         labels[
-#             clamp(window_y - label_size, 0, labels.shape[0] - 1):clamp(window_y + label_size, 0, labels.shape[0] - 1),
-#             clamp(window_x - label_size, 0, labels.shape[0] - 1):clamp(window_x + label_size, 0, labels.shape[0] - 1),
-#         ] = 1
-
-#     positive_labels = labels.flatten() / labels.max()
-#     negative_labels = 1 - positive_labels
-#     pos_mask_labels = torch.flatten(mask_labels, end_dim=1)
-#     neg_mask_labels = 1 - pos_mask_labels
-#     mask_labels = torch.stack([pos_mask_labels, neg_mask_labels], dim=1)
-#     dataset_labels = torch.tensor(list(zip(positive_labels, negative_labels)))
-#     dataset = list(zip(
-#         get_dataset_x(blank_image, filter_size=filter_size, filter_stride=filter_stride),
-#         dataset_labels,
-#         mask_labels
-#     ))
-#     return dataset, (labels, mask_labels)
-
-
-# from torchvision.models.resnet import resnet50
-# from torchvision.models.resnet import ResNet50_Weights
-
-# print("Loading resnet...")
-# model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
-# hidden_state_size = model.fc.in_features
-# model.fc = torch.nn.Linear(in_features=hidden_state_size, out_features=2, bias=True)
-# model.to(device)
-# model.load_state_dict(torch.load("model_best_epoch_4_59.62.pth", map_location=torch.device(device)))
-# model.to(device)
 from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
 
 model = sam_model_registry["default"](checkpoint="./sam_vit_h_4b8939.pth")
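
Note: the mask_generator called in the hunks below is presumably built from this model object. A minimal sketch of the standard segment_anything setup, under that assumption (the actual construction sits outside the diff context shown):

    import numpy as np
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

    model = sam_model_registry["default"](checkpoint="./sam_vit_h_4b8939.pth")
    mask_generator = SamAutomaticMaskGenerator(model)

    # generate() takes an HxWx3 uint8 RGB array and returns a list of dicts;
    # each carries 'segmentation' (a boolean HxW mask), 'area' (pixel count),
    # and 'bbox' in XYWH form -- the fields count_barnacles() reads below.
    image = np.zeros((256, 256, 3), dtype=np.uint8)  # placeholder input
    masks = mask_generator.generate(image)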
@@ -215,72 +99,14 @@ def check_circularity(segmentation):
 def count_barnacles(image_raw, split_num, progress=gr.Progress()):
     progress(0, desc="Finding bounding wire")
 
-    # crop image
-    # h, w = raw_input_img.shape[:2]
-    # imghsv = cv2.cvtColor(raw_input_img, cv2.COLOR_RGB2HSV)
-    # hsvblur = cv2.GaussianBlur(imghsv, (9, 9), 0)
-
-    # lower = np.array([70, 20, 20])
-    # upper = np.array([130, 255, 255])
-
-    # color_mask = cv2.inRange(hsvblur, lower, upper)
-
-    # invert = cv2.bitwise_not(color_mask)
-
-    # contours, _ = cv2.findContours(invert, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
-
-    # max_contour = contours[0]
-    # largest_area = 0
-    # for index, contour in enumerate(contours):
-    #     area = cv2.contourArea(contour)
-    #     if area > largest_area:
-    #         if cv2.pointPolygonTest(contour, (w / 2, h / 2), False) == 1:
-    #             largest_area = area
-    #             max_contour = contour
-
-    # x, y, w, h = cv2.boundingRect(max_contour)
-
-
-    # image = cv2.cvtColor(image_raw, cv2.COLOR_BGR2RGB)
-    # image = Image.fromarray(image_raw)
-    # image = image[:,:,::-1]
-    # image = image_raw
-    # print(image.shape)
-    # print(type(image))
-    # print(image.dtype)
-    # print(image)
     corners = wireframe_extractor(image_raw)
     print(corners)  # (0, 0, 1254, 1152)
 
     cropped_image = image_raw[corners[1]:corners[3]+corners[1], corners[0]:corners[2]+corners[0], :]
 
     print(cropped_image.shape)
-    # cropped_image = cropped_image[100:400, 100:400]
-    # print(cropped_image)
 
-
-    # progress(0, desc="Generating Masks by point in window")
-
-    # # get center point of windows
-    # predictor.set_image(image)
-    # mask_counter = 0
-    # masks = []
-
-    # for x in range(1, 20, 2):
-    #     for y in range(1, 20, 2):
-    #         point = np.array([[x*25, y*25]])
-    #         input_label = np.array([1])
-    #         mask, score, logit = predictor.predict(
-    #             point_coords=point,
-    #             point_labels=input_label,
-    #             multimask_output=False,
-    #         )
-    #         if score[0] > 0.8:
-    #             mask_counter += 1
-    #             masks.append(mask)
-
-    # return mask_counter
-    split_num = 2
+    split_num = 5
 
     x_inc = int(cropped_image.shape[0]/split_num)
     y_inc = int(cropped_image.shape[1]/split_num)
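
The r/c loop that produces small_image, startx, and starty in the next hunk is outside the diff context; a plausible reconstruction given the x_inc/y_inc strides above (an assumption about the surrounding code, not lines from this commit):

    # Hypothetical shape of the tiling loop implied by the offsets used
    # below; the exact axis/offset convention is a guess.
    for r in range(split_num):
        for c in range(split_num):
            startx, starty = r * x_inc, c * y_inc
            small_image = cropped_image[startx:startx + x_inc,
                                        starty:starty + y_inc, :]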
@@ -301,23 +127,17 @@ def count_barnacles(image_raw, split_num, progress=gr.Progress()):
             # plt.figure()
             # plt.imshow(small_image)
             # plt.axis('on')
-
+            progress(0, desc=f"Encoding crop {r*split_num + c}/{split_num ** 2}")
+            mask_generator.predictor.set_image(small_image)
+            progress(0, desc=f"Generating masks for crop {r*split_num + c}/{split_num ** 2}")
             masks = mask_generator.generate(small_image)
-
+            num_masks = len(masks)
 
-            for mask in masks:
+            for idx, mask in enumerate(masks):
+                progress(float(idx)/float(num_masks), desc=f"Processing masks for crop {r*split_num + c}/{split_num ** 2}")
                 circular = check_circularity(mask['segmentation'])
                 if circular and mask['area'] > 500 and mask['area'] < 10000:
                     mask_counter += 1
-                    # if cropped_image.shape != image_raw.shape:
-                    #     add_to_row = [False] * corners[0]
-                    #     temp = [False]*(corners[2]+corners[0])
-                    #     temp = [temp]*corners[1]
-                    #     new_seg = np.array(temp)
-                    #     for row in mask['segmentation']:
-                    #         row = np.append(add_to_row, row)
-                    #         new_seg = np.vstack([new_seg, row])
-                    #     mask['segmentation'] = new_seg
                     good_masks.append(mask)
                     box = mask['bbox']
                     centers.append((box[0] + box[2]/2 + corners[0] + startx, box[1] + box[3]/2 + corners[1] + starty))
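
check_circularity() is referenced here but defined outside the hunks shown. A sketch of one common way to score a boolean mask's roundness, the 4*pi*A/P^2 circularity of its largest contour, offered as an illustration rather than the app's actual implementation:

    import cv2
    import numpy as np

    def check_circularity_sketch(segmentation, threshold=0.6):
        # Hypothetical stand-in for check_circularity(): True when the
        # mask's largest contour is roughly circular (metric near 1.0).
        mask = segmentation.astype(np.uint8)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                       cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return False
        cnt = max(contours, key=cv2.contourArea)
        perimeter = cv2.arcLength(cnt, True)
        if perimeter == 0:
            return False
        return 4 * np.pi * cv2.contourArea(cnt) / perimeter ** 2 >= threshold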
@@ -358,63 +178,6 @@ def count_barnacles(image_raw, split_num, progress=gr.Progress()):
     # return annotated, mask_counter, centers
     return fig, mask_counter, centers
 
-
-    # return len(masks)
-
-    # progress(0, desc="Resizing Image")
-    # cropped_img = raw_input_img[x:x+w, y:y+h]
-    # cropped_image_tensor = torch.transpose(torch.tensor(cropped_img).to(device), 0, 2)
-    # resize = Resize((1500, 1500))
-    # input_img = cropped_image_tensor
-    # blank_img_copy = torch.transpose(input_img, 0, 2).to("cpu").detach().numpy().copy()
-
-    # progress(0, desc="Generating Windows")
-    # test_dataset = get_dataset_x(input_img)
-    # test_dataloader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
-    # model.eval()
-    # predicted_labels_list = []
-    # for data in progress.tqdm(test_dataloader):
-    #     with torch.no_grad():
-    #         data = data.to(device)
-    #         predicted_labels_list += [model(data)]
-    # predicted_labels = torch.cat(predicted_labels_list)
-    # x = int(math.sqrt(predicted_labels.shape[0]))
-    # predicted_labels = predicted_labels.reshape([x, x, 2]).detach()
-    # label_img = predicted_labels[:, :, :1].cpu().numpy()
-    # label_img -= label_img.min()
-    # label_img /= label_img.max()
-    # label_img = (label_img * 255).astype(np.uint8)
-    # mask = np.array(label_img > 180, np.uint8)
-    # contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
-
-    # gt_contours = find_contours(labeled_input_img[x:x+w, y:y+h], cropped_img, np.array([59, 76, 160]))
-
-
-    # def extract_contour_center(cnt):
-    #     M = cv2.moments(cnt)
-    #     cx = int(M['m10'] / M['m00'])
-    #     cy = int(M['m01'] / M['m00'])
-    #     return cx, cy
-
-    # filter_width = 50
-    # filter_stride = 2
-
-    # def rev_window_transform(point):
-    #     wx, wy = point
-    #     x = int(filter_width / 2) + wx * filter_stride
-    #     y = int(filter_width / 2) + wy * filter_stride
-    #     return x, y
-
-    # nonempty_contours = filter(lambda cnt: cv2.contourArea(cnt) != 0, contours)
-    # windows = map(extract_contour_center, nonempty_contours)
-    # points = list(map(rev_window_transform, windows))
-    # for x, y in points:
-    #     blank_img_copy = cv2.circle(blank_img_copy, (x, y), radius=4, color=(255, 0, 0), thickness=-1)
-    # print(f"pointlist: {len(points)}")
-    # return blank_img_copy, len(points)
-
-
 demo = gr.Interface(count_barnacles,
                     inputs=[
                         gr.Image(type="numpy", label="Input Image"),
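
The inputs list is truncated by the diff context and the outputs side is not shown; since count_barnacles() returns a Matplotlib figure, a count, and a list of centers, the full wiring plausibly resembles this sketch (every component except the first input is an assumption):

    # Hypothetical completion of the Interface; only the gr.Image input
    # above is confirmed by the diff context.
    demo = gr.Interface(count_barnacles,
                        inputs=[
                            gr.Image(type="numpy", label="Input Image"),
                            gr.Number(label="Split Number", precision=0),
                        ],
                        outputs=[
                            gr.Plot(label="Annotated Image"),    # fig
                            gr.Number(label="Barnacle Count"),   # mask_counter
                            gr.JSON(label="Mask Centers"),       # centers
                        ])
    demo.launch()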
 