dennistrujillo committed
Commit f0b3d8c · verified · 1 parent: 06cbd50

restored inference functionality

Files changed (1): app.py (+77, -5)
app.py CHANGED
@@ -29,15 +29,63 @@ def load_image(file_path):
     H, W = img.shape[:2]
     return img, H, W
 
-# The rest of the code remains the same...
+@torch.no_grad()
+def medsam_inference(medsam_model, img_embed, box_1024, H, W):
+    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
+    if len(box_torch.shape) == 2:
+        box_torch = box_torch[:, None, :]  # (B, 1, 4)
 
+    box_torch = box_torch.reshape(1, 4)
+    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
+        points=None,
+        boxes=box_torch,
+        masks=None,
+    )
+
+    low_res_logits, _ = medsam_model.mask_decoder(
+        image_embeddings=img_embed,  # (B, 256, 64, 64)
+        image_pe=medsam_model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
+        sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
+        dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
+        multimask_output=False,
+    )
+
+    low_res_pred = torch.sigmoid(low_res_logits)  # (1, 1, 256, 256)
+
+    low_res_pred = F.interpolate(
+        low_res_pred,
+        size=(H, W),
+        mode="bilinear",
+        align_corners=False,
+    )  # (1, 1, H, W)
+    low_res_pred = low_res_pred.squeeze().cpu().numpy()  # (H, W)
+    medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
+    return medsam_seg
+
+# Function for visualizing images with masks
+def visualize(image, mask, box):
+    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
+    ax[0].imshow(image, cmap='gray')
+    ax[0].add_patch(plt.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], edgecolor="red", facecolor="none"))
+    ax[1].imshow(image, cmap='gray')
+    ax[1].imshow(mask, alpha=0.5, cmap="jet")
+    plt.tight_layout()
+
+    # Convert matplotlib figure to a PIL Image
+    buf = io.BytesIO()
+    fig.savefig(buf, format='png')
+    plt.close(fig)  # Close the figure to release memory
+    buf.seek(0)
+    pil_img = Image.open(buf)
+
+    return pil_img
 # Main function for Gradio app
 def process_images(img_dict):
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
     # Load and preprocess image
+    print(img_dict)
     img = img_dict['image']
-    print(image.type())
     points = img_dict['points'][0]  # Accessing the first (and possibly only) set of points
     if len(points) >= 6:
         x_min, y_min, x_max, y_max = points[0], points[1], points[3], points[4]
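
Note on the restored medsam_inference helper: it takes a precomputed image embedding plus a box already scaled to the model's 1024x1024 input space, and post-processes the decoder output with a sigmoid, bilinear upsampling back to the original (H, W), and a 0.5 threshold. A minimal self-contained sketch of that post-processing step, with a dummy logits tensor standing in for the real mask-decoder output (all values illustrative):

# Stand-alone sketch of medsam_inference's post-processing; the dummy
# (1, 1, 256, 256) logits tensor replaces the real mask-decoder output.
import numpy as np
import torch
import torch.nn.functional as F

H, W = 512, 768                                # original image size (illustrative)
low_res_logits = torch.randn(1, 1, 256, 256)   # dummy decoder output

low_res_pred = torch.sigmoid(low_res_logits)   # probabilities in [0, 1]
low_res_pred = F.interpolate(                  # upsample to the original size
    low_res_pred, size=(H, W), mode="bilinear", align_corners=False
)
mask = (low_res_pred.squeeze().cpu().numpy() > 0.5).astype(np.uint8)
assert mask.shape == (H, W)                    # binary (H, W) segmentation mask

Also note that box_torch.reshape(1, 4) immediately after the (B, 1, 4) expansion flattens the prompt back down, so the helper as committed handles exactly one box per call.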
 
@@ -48,12 +96,36 @@ def process_images(img_dict):
         image = np.repeat(image[:, :, None], 3, axis=-1)
     H, W, _ = image.shape
 
-    # The rest of the function remains the same...
+    image_resized = transform.resize(image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True).astype(np.uint8)
+    image_resized = (image_resized - image_resized.min()) / np.clip(image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None)
+    image_tensor = torch.tensor(image_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)
 
+    # Initialize the MedSAM model and set the device
+    model_checkpoint_path = "medsam_vit_b.pth"  # Replace with the correct path to your checkpoint
+    medsam_model = sam_model_registry['vit_b'](checkpoint=model_checkpoint_path)
+    medsam_model = medsam_model.to(device)
+    medsam_model.eval()
+
+    # Generate image embedding
+    with torch.no_grad():
+        img_embed = medsam_model.image_encoder(image_tensor)
+
+    # Calculate resized box coordinates
+    scale_factors = np.array([1024 / W, 1024 / H, 1024 / W, 1024 / H])
+    box_1024 = np.array([x_min, y_min, x_max, y_max]) * scale_factors
+
+    # Perform inference
+    mask = medsam_inference(medsam_model, img_embed, box_1024, H, W)
+
+    # Visualization
+    visualization = visualize(image, mask, [x_min, y_min, x_max, y_max])
+    return visualization
 # Set up Gradio interface
 iface = gr.Interface(
     fn=process_images,
-    inputs=[gr.File(label="image or nrrd")],
+    inputs=[
+        ImagePrompter(label="Image")
+    ],
     outputs=[
         gr.Image(type="pil", label="Processed Image")
     ],
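
Note on the preprocessing added here: the image is resized to 1024x1024 for the encoder, so the drawn box must be scaled by the same factors before being passed to medsam_inference. A worked example of that rescaling with illustrative numbers:

# Worked example of the box rescaling in process_images (values illustrative):
# a box on a 512x256 image mapped into the 1024x1024 encoder-input space.
import numpy as np

H, W = 256, 512                                  # original image height/width
x_min, y_min, x_max, y_max = 100, 50, 300, 200   # box drawn on the original image

scale_factors = np.array([1024 / W, 1024 / H, 1024 / W, 1024 / H])
box_1024 = np.array([x_min, y_min, x_max, y_max]) * scale_factors
print(box_1024)  # [200. 200. 600. 800.]

One design consequence of this hunk: the checkpoint is loaded inside process_images, so the model is rebuilt on every request; hoisting the sam_model_registry call to module scope would amortize the load across calls.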
 
@@ -62,4 +134,4 @@ iface = gr.Interface(
 )
 
 # Launch the interface
-iface.launch()
+iface.launch()
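
The removed and re-added iface.launch() lines in the final hunk render identically, so that change is presumably whitespace-only. On the input side, the interface now expects an ImagePrompter rather than gr.File; assuming this is the component from the community gradio-image-prompter package (its import is outside this diff), process_images receives a dict like the sketch below:

# Assumed component and payload shape; gradio-image-prompter is an assumption,
# since the corresponding import is not shown in this diff.
from gradio_image_prompter import ImagePrompter  # pip install gradio-image-prompter

prompter = ImagePrompter(label="Image")
# A drawn box arrives as one six-element row in "points":
#   {"image": <ndarray, HxW or HxWx3>,
#    "points": [[x_min, y_min, 2, x_max, y_max, 3]]}
# which is why process_images checks len(points) >= 6 and reads
# points[0], points[1], points[3], points[4].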