dennistrujillo commited on
Commit
9840e47
·
verified ·
1 Parent(s): 30feade

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -28
app.py CHANGED
@@ -16,16 +16,12 @@ def load_bounding_boxes(csv_file):
16
  df = pd.read_csv(csv_file)
17
  return df
18
 
19
- # Function to load DICOM images
20
- def load_dicom_images(folder_path):
21
- images = []
22
- for filename in sorted(os.listdir(folder_path)):
23
- if filename.endswith(".dcm"):
24
- filepath = os.path.join(folder_path, filename)
25
- ds = pydicom.dcmread(filepath)
26
- img = ds.pixel_array
27
- images.append(img)
28
- return np.array(images)
29
 
30
  # MedSAM inference function
31
  def medsam_inference(medsam_model, img, box, H, W, target_size):
@@ -64,42 +60,34 @@ def visualize(images, masks, box):
64
  return buf
65
 
66
  # Main function for Gradio app
67
- def process_images(csv_file, dicom_folder):
68
  bounding_boxes = load_bounding_boxes(csv_file)
69
- dicom_images = load_dicom_images(dicom_folder)
70
 
71
  # Initialize MedSAM model
72
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
73
- medsam_model = sam_model_registry['your_model_version'](checkpoint='path_to_your_checkpoint')
74
  medsam_model = medsam_model.to(device)
75
  medsam_model.eval()
76
 
77
  masks = []
 
78
  for index, row in bounding_boxes.iterrows():
79
- if index >= len(dicom_images):
80
- continue # Skip if the index exceeds the number of images
81
-
82
- image = dicom_images[index]
83
- H, W = image.shape
84
  box = [row['x_min'], row['y_min'], row['x_max'], row['y_max']]
85
-
86
- mask = medsam_inference(medsam_model, image, box, H, W, target_size)
87
  masks.append(mask)
 
88
 
89
- visualizations = visualize(dicom_images, masks, box)
90
-
91
- return visualizations, np.array(masks)
92
 
93
  # Set up Gradio interface
94
  iface = gr.Interface(
95
  fn=process_images,
96
  inputs=[
97
  gr.File(label="CSV File"),
98
- gr.File(label="Zipped DICOM stack or nrrd file")],
99
- outputs=[
100
- gr.Image(type="pil"),
101
- gr.File(type="numpy")
102
- ]
103
  )
104
 
105
  iface.launch()
 
16
  df = pd.read_csv(csv_file)
17
  return df
18
 
19
+ def load_dicom_image(filename):
20
+ if filename.endswith(".dcm"):
21
+ ds = pydicom.dcmread(filename)
22
+ img = ds.pixel_array
23
+ H, W = img.shape
24
+ return np.array(img), H, W
 
 
 
 
25
 
26
  # MedSAM inference function
27
  def medsam_inference(medsam_model, img, box, H, W, target_size):
 
60
  return buf
61
 
62
  # Main function for Gradio app
63
+ def process_images(csv_file, dicom_file):
64
  bounding_boxes = load_bounding_boxes(csv_file)
65
+ image, H, W = load_dicom_image(dicom_file)
66
 
67
  # Initialize MedSAM model
68
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
69
+ medsam_model = sam_model_registry['vit_b'](checkpoint="medsam_vit_b.pth") # Ensure the correct path
70
  medsam_model = medsam_model.to(device)
71
  medsam_model.eval()
72
 
73
  masks = []
74
+ boxes = []
75
  for index, row in bounding_boxes.iterrows():
 
 
 
 
 
76
  box = [row['x_min'], row['y_min'], row['x_max'], row['y_max']]
77
+ mask = medsam_inference(medsam_model, image, box, H, W, H) # Assuming target size is the same as the image height
 
78
  masks.append(mask)
79
+ boxes.append(box)
80
 
81
+ visualizations = visualize([image] * len(masks), masks, boxes)
82
+ return visualizations.getvalue()
 
83
 
84
  # Set up Gradio interface
85
  iface = gr.Interface(
86
  fn=process_images,
87
  inputs=[
88
  gr.File(label="CSV File"),
89
+ gr.File(label="DICOM File")],
90
+ outputs="plot"
 
 
 
91
  )
92
 
93
  iface.launch()