Henry Scheible commited on
Commit
29c7e10
1 Parent(s): 80b40b1

decrease number of crops

Browse files
Files changed (1) hide show
  1. app.py +2 -5
app.py CHANGED
@@ -71,8 +71,6 @@ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamP
71
  model = sam_model_registry["default"](checkpoint="./sam_vit_h_4b8939.pth")
72
  model.to(device)
73
 
74
- predictor = SamPredictor(model)
75
-
76
  mask_generator = SamAutomaticMaskGenerator(model)
77
 
78
  import gradio as gr
@@ -106,7 +104,7 @@ def count_barnacles(image_raw, split_num, progress=gr.Progress()):
106
 
107
  print(cropped_image.shape)
108
 
109
- split_num = 3
110
 
111
  x_inc = int(cropped_image.shape[0]/split_num)
112
  y_inc = int(cropped_image.shape[1]/split_num)
@@ -127,7 +125,6 @@ def count_barnacles(image_raw, split_num, progress=gr.Progress()):
127
  # plt.figure()
128
  # plt.imshow(small_image)
129
  # plt.axis('on')
130
- mask_generator.predictor.set_image(small_image)
131
  progress(0, desc=f"Generating masks for crop {r*split_num + c}/{split_num ** 2}")
132
  masks = mask_generator.generate(small_image)
133
  num_masks = len(masks)
@@ -144,7 +141,7 @@ def count_barnacles(image_raw, split_num, progress=gr.Progress()):
144
 
145
  progress(0, desc="Generating Plot")
146
  # Create a figure with a size of 10 inches by 10 inches
147
- fig = plt.figure(figsize=(40, 40))
148
 
149
  # Display the image using the imshow() function
150
  # plt.imshow(cropped_image)
 
71
  model = sam_model_registry["default"](checkpoint="./sam_vit_h_4b8939.pth")
72
  model.to(device)
73
 
 
 
74
  mask_generator = SamAutomaticMaskGenerator(model)
75
 
76
  import gradio as gr
 
104
 
105
  print(cropped_image.shape)
106
 
107
+ split_num = 2
108
 
109
  x_inc = int(cropped_image.shape[0]/split_num)
110
  y_inc = int(cropped_image.shape[1]/split_num)
 
125
  # plt.figure()
126
  # plt.imshow(small_image)
127
  # plt.axis('on')
 
128
  progress(0, desc=f"Generating masks for crop {r*split_num + c}/{split_num ** 2}")
129
  masks = mask_generator.generate(small_image)
130
  num_masks = len(masks)
 
141
 
142
  progress(0, desc="Generating Plot")
143
  # Create a figure with a size of 10 inches by 10 inches
144
+ fig = plt.figure(figsize=(10, 10))
145
 
146
  # Display the image using the imshow() function
147
  # plt.imshow(cropped_image)