dennistrujillo commited on
Commit
55223b8
·
verified ·
1 Parent(s): a6be0a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -9
app.py CHANGED
@@ -16,12 +16,14 @@ def load_bounding_boxes(csv_file):
16
  df = pd.read_csv(csv_file)
17
  return df
18
 
19
- def load_dicom_image(filename):
20
- if filename.endswith(".dcm"):
21
- ds = pydicom.dcmread(filename)
22
  img = ds.pixel_array
23
- H, W = img.shape
24
- return np.array(img), H, W
 
 
25
 
26
  # MedSAM inference function
27
  def medsam_inference(medsam_model, img, box, H, W, target_size):
@@ -59,12 +61,12 @@ def visualize(image, mask, box):
59
  return buf
60
 
61
  # Main function for Gradio app
62
- def process_images(dicom_file, x_min, y_min, x_max, y_max):
63
- image, H, W = load_dicom_image(dicom_file)
64
 
65
  # Initialize MedSAM model
66
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
67
- medsam_model = sam_model_registry['vit_b'](checkpoint=MedSAM_CKPT_PATH) # Ensure the correct path
68
  medsam_model = medsam_model.to(device)
69
  medsam_model.eval()
70
 
@@ -78,7 +80,7 @@ def process_images(dicom_file, x_min, y_min, x_max, y_max):
78
  iface = gr.Interface(
79
  fn=process_images,
80
  inputs=[
81
- gr.File(label="DICOM File"),
82
  gr.Number(label="X min"),
83
  gr.Number(label="Y min"),
84
  gr.Number(label="X max"),
 
16
  df = pd.read_csv(csv_file)
17
  return df
18
 
19
+ def load_image(file_path):
20
+ if file_path.endswith(".dcm"):
21
+ ds = pydicom.dcmread(file_path)
22
  img = ds.pixel_array
23
+ else:
24
+ img = np.array(Image.open(file_path).convert('L')) # Convert to grayscale
25
+ H, W = img.shape
26
+ return img, H, W
27
 
28
  # MedSAM inference function
29
  def medsam_inference(medsam_model, img, box, H, W, target_size):
 
61
  return buf
62
 
63
  # Main function for Gradio app
64
+ def process_images(file, x_min, y_min, x_max, y_max):
65
+ image, H, W = load_image(file)
66
 
67
  # Initialize MedSAM model
68
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
69
+ medsam_model = sam_model_registry['vit_b'](checkpoint="medsam_vit_b.pth") # Ensure the correct path
70
  medsam_model = medsam_model.to(device)
71
  medsam_model.eval()
72
 
 
80
  iface = gr.Interface(
81
  fn=process_images,
82
  inputs=[
83
+ gr.File(label="MRI Slice (DICOM, PNG, etc.)"),
84
  gr.Number(label="X min"),
85
  gr.Number(label="Y min"),
86
  gr.Number(label="X max"),