Spaces:
Running
Running
dennistrujillo
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -16,16 +16,12 @@ def load_bounding_boxes(csv_file):
|
|
16 |
df = pd.read_csv(csv_file)
|
17 |
return df
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
ds = pydicom.dcmread(filepath)
|
26 |
-
img = ds.pixel_array
|
27 |
-
images.append(img)
|
28 |
-
return np.array(images)
|
29 |
|
30 |
# MedSAM inference function
|
31 |
def medsam_inference(medsam_model, img, box, H, W, target_size):
|
@@ -64,42 +60,34 @@ def visualize(images, masks, box):
|
|
64 |
return buf
|
65 |
|
66 |
# Main function for Gradio app
|
67 |
-
def process_images(csv_file,
|
68 |
bounding_boxes = load_bounding_boxes(csv_file)
|
69 |
-
|
70 |
|
71 |
# Initialize MedSAM model
|
72 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
73 |
-
medsam_model = sam_model_registry['
|
74 |
medsam_model = medsam_model.to(device)
|
75 |
medsam_model.eval()
|
76 |
|
77 |
masks = []
|
|
|
78 |
for index, row in bounding_boxes.iterrows():
|
79 |
-
if index >= len(dicom_images):
|
80 |
-
continue # Skip if the index exceeds the number of images
|
81 |
-
|
82 |
-
image = dicom_images[index]
|
83 |
-
H, W = image.shape
|
84 |
box = [row['x_min'], row['y_min'], row['x_max'], row['y_max']]
|
85 |
-
|
86 |
-
mask = medsam_inference(medsam_model, image, box, H, W, target_size)
|
87 |
masks.append(mask)
|
|
|
88 |
|
89 |
-
visualizations = visualize(
|
90 |
-
|
91 |
-
return visualizations, np.array(masks)
|
92 |
|
93 |
# Set up Gradio interface
|
94 |
iface = gr.Interface(
|
95 |
fn=process_images,
|
96 |
inputs=[
|
97 |
gr.File(label="CSV File"),
|
98 |
-
gr.File(label="
|
99 |
-
outputs=
|
100 |
-
gr.Image(type="pil"),
|
101 |
-
gr.File(type="numpy")
|
102 |
-
]
|
103 |
)
|
104 |
|
105 |
iface.launch()
|
|
|
16 |
df = pd.read_csv(csv_file)
|
17 |
return df
|
18 |
|
19 |
+
def load_dicom_image(filename):
|
20 |
+
if filename.endswith(".dcm"):
|
21 |
+
ds = pydicom.dcmread(filename)
|
22 |
+
img = ds.pixel_array
|
23 |
+
H, W = img.shape
|
24 |
+
return np.array(img), H, W
|
|
|
|
|
|
|
|
|
25 |
|
26 |
# MedSAM inference function
|
27 |
def medsam_inference(medsam_model, img, box, H, W, target_size):
|
|
|
60 |
return buf
|
61 |
|
62 |
# Main function for Gradio app
|
63 |
+
def process_images(csv_file, dicom_file):
|
64 |
bounding_boxes = load_bounding_boxes(csv_file)
|
65 |
+
image, H, W = load_dicom_image(dicom_file)
|
66 |
|
67 |
# Initialize MedSAM model
|
68 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
69 |
+
medsam_model = sam_model_registry['vit_b'](checkpoint="medsam_vit_b.pth") # Ensure the correct path
|
70 |
medsam_model = medsam_model.to(device)
|
71 |
medsam_model.eval()
|
72 |
|
73 |
masks = []
|
74 |
+
boxes = []
|
75 |
for index, row in bounding_boxes.iterrows():
|
|
|
|
|
|
|
|
|
|
|
76 |
box = [row['x_min'], row['y_min'], row['x_max'], row['y_max']]
|
77 |
+
mask = medsam_inference(medsam_model, image, box, H, W, H) # Assuming target size is the same as the image height
|
|
|
78 |
masks.append(mask)
|
79 |
+
boxes.append(box)
|
80 |
|
81 |
+
visualizations = visualize([image] * len(masks), masks, boxes)
|
82 |
+
return visualizations.getvalue()
|
|
|
83 |
|
84 |
# Set up Gradio interface
|
85 |
iface = gr.Interface(
|
86 |
fn=process_images,
|
87 |
inputs=[
|
88 |
gr.File(label="CSV File"),
|
89 |
+
gr.File(label="DICOM File")],
|
90 |
+
outputs="plot"
|
|
|
|
|
|
|
91 |
)
|
92 |
|
93 |
iface.launch()
|