# Author: franchesoni
# Commit 9754079: "add watershed"
from PIL import Image
import torch
import numpy as np
import gradio as gr
from pathlib import Path
from busam import Busam
# Runtime configuration: the side length handed to Busam (presumably its
# working resolution — confirm in busam.py), the weights file, and the torch device.
resize_to, checkpoint, device = 512, "weights.pth", "cpu"

# The model is built once at import time and shared by every inference callback.
print("Loading model...")
busam = Busam(checkpoint=checkpoint, device=device, side=resize_to)
def minmaxnorm(x):
    """Linearly rescale ``x`` so its minimum maps to 0 and its maximum to 1.

    Works on anything supporting ``.min()``, ``.max()`` and arithmetic (numpy
    arrays here). A constant input (max == min) would otherwise divide by zero
    and yield NaNs, which turn into arbitrary values when cast to uint8, so
    zeros of the same shape are returned instead.
    """
    x_min = x.min()
    span = x.max() - x_min
    if span == 0:
        return (x - x_min) * 0.0
    return (x - x_min) / span
def edge_inference(img, algorithm, th_low=None, th_high=None):
    """Compute an edge map of ``img`` from BUSAM features.

    Args:
        img: Input image indexable as ``img[:, :, :3]`` (the array produced by
            the ``gr.Image`` input); only the first three channels are used.
        algorithm: ``"sobel"`` or ``"canny"`` (case-insensitive).
        th_low: Canny low threshold; ``None`` selects 5000. Ignored for Sobel.
        th_high: Canny high threshold; ``None`` selects 10000. Ignored for Sobel.

    Returns:
        A grayscale ``PIL.Image`` of the edge map, min-max scaled to 0..255 and
        resized to ``size[::-1]``, where ``size`` comes from
        ``busam.process_image``.

    Raises:
        ValueError: if ``algorithm`` is neither sobel nor canny.
    """
    algorithm = algorithm.lower()
    # Validate before the expensive feature extraction.
    if algorithm not in ("sobel", "canny"):
        raise ValueError("algorithm should be sobel or canny")
    print("Loading image...")
    img = np.array(img[:, :, :3])
    print("Getting features...")
    pred, size = busam.process_image(img, do_activate=True)
    print(f"Computing {algorithm}...")
    if algorithm == "sobel":
        edge = busam.sobel_from_pred(pred, size)
    else:
        # Compare against None explicitly: 0 is a valid slider value and must
        # not be silently replaced by the default (as `th_low or 5000` did).
        th_low = 5000 if th_low is None else th_low
        th_high = 10000 if th_high is None else th_high
        edge = busam.canny_from_pred(pred, size, th_low=th_low, th_high=th_high)
    if isinstance(edge, torch.Tensor):
        edge = edge.detach().cpu().numpy()
    print("Done")
    return Image.fromarray(
        (minmaxnorm(edge) * 255).astype(np.uint8)
    ).resize(size[::-1])
def dimred_inference(
    img,
    algorithm,
    resample_pct,
):
    """Project per-pixel BUSAM features to 3 dimensions and render them as RGB.

    Args:
        img: Input image indexable as ``img[:, :, :3]``.
        algorithm: ``"pca"``, ``"tsne"`` or ``"umap"`` (case-insensitive).
        resample_pct: Base-10 exponent of the fraction of pixels used to fit
            the reducer (e.g. -3 -> 0.1% of the pixels). Not used for t-SNE,
            which cannot embed unseen points and is fit on every pixel.

    Returns:
        A ``PIL.Image`` whose RGB channels are the min-max-scaled 3D embedding,
        resized to ``size[::-1]`` (``size`` from ``busam.process_image``).

    Raises:
        ValueError: if the features have fewer than 3 channels or
            ``algorithm`` is unknown.
    """
    algorithm = algorithm.lower()
    img = np.array(img[:, :, :3])
    print("Getting features...")
    pred, size = busam.process_image(img, do_activate=True)
    # pred is 1, F, S, S
    if pred.shape[1] < 3:
        raise ValueError("should have at least 3 channels")
    if algorithm == 'pca':
        from sklearn.decomposition import PCA
        reducer = PCA(n_components=3)
    elif algorithm == 'tsne':
        from sklearn.manifold import TSNE
        reducer = TSNE(n_components=3)
    elif algorithm == 'umap':
        from umap import UMAP
        reducer = UMAP(n_components=3)
    else:
        raise ValueError('algorithm should be pca, tsne or umap')

    features = pred.detach().cpu().permute(1, 0, 2, 3).numpy()  # F, B, H, W
    features = features.reshape(features.shape[0], -1).T  # BHW, F
    n_pixels = features.shape[0]

    if algorithm == 'tsne':
        # sklearn's TSNE has no transform() method, so the fit-on-a-subset /
        # transform-everything scheme below cannot work; embed all pixels.
        print("dim reduction fit_transform..." + " " * 30, end="\r")
        embedded = reducer.fit_transform(features)
    else:
        fraction = 10 ** resample_pct
        # Fit on an evenly strided subset of pixels. Keep at least 3 samples
        # (one per output component) and a stride >= 1, so very small fractions
        # no longer produce a zero sample count and a ZeroDivisionError.
        n_samples = min(n_pixels, max(3, int(fraction * n_pixels)))
        stride = max(1, n_pixels // max(1, n_samples))
        print("dim reduction fit..." + " " * 30, end="\r")
        reducer = reducer.fit(features[::stride])
        print("dim reduction transform..." + " " * 30, end="\r")
        reducer.transform(features[:10])  # to numba compile the function (UMAP)
        embedded = reducer.transform(features)  # BHW, 3
    print()
    print('Done. Saving...')
    # Revert back to the spatial layout (assumes a batch of one, as noted above).
    colors = embedded.reshape(pred.shape[2], pred.shape[3], 3)
    return Image.fromarray((minmaxnorm(colors) * 255).astype(np.uint8)).resize(
        size[::-1]
    )
def segmentation_inference(img, algorithm, scale):
    """Segment ``img`` with a classical algorithm run on BUSAM features.

    Args:
        img: Input image indexable as ``img[:, :, :3]``.
        algorithm: ``"kmeans"``, ``"felzenszwalb"``, ``"slic"`` or
            ``"watershed"`` (case-insensitive).
        scale: Coarseness control (the UI slider spans 0.1..1). For every
            algorithm below, larger values yield fewer / larger segments.

    Returns:
        A grayscale ``PIL.Image`` in which each segment has its own gray level,
        resized to ``size[::-1]`` (``size`` from ``busam.process_image``).

    Raises:
        ValueError: if ``algorithm`` is unknown.
    """
    algorithm = algorithm.lower()
    img = np.array(img[:, :, :3])
    print("Getting features...")
    pred, size = busam.process_image(img, do_activate=True)
    print("Computing segmentation...")
    n_channels, height, width = pred.shape[1], pred.shape[2], pred.shape[3]

    def features_as_uint8_image():
        # F, H, W features of the (single) batch item -> H, W, F uint8 image.
        feats = pred[0].detach().cpu().numpy()
        return (minmaxnorm(feats) * 255).astype(np.uint8).transpose(1, 2, 0)

    if algorithm == "kmeans":
        from sklearn.cluster import KMeans
        n_clusters = max(1, int(100 / 100**scale))
        # One row per pixel, one column per feature channel, as a numpy array.
        pixels = pred[0].detach().cpu().reshape(n_channels, -1).T.numpy()
        kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(pixels)
        labels = kmeans.labels_.reshape(height, width)
    elif algorithm == "felzenszwalb":
        from skimage.segmentation import felzenszwalb
        labels = felzenszwalb(
            features_as_uint8_image(),
            scale=10**(8*scale-3),
            sigma=0,
            min_size=50,
        )
    elif algorithm == "slic":
        from skimage.segmentation import slic
        labels = slic(
            features_as_uint8_image(),
            n_segments=max(1, int(100 / 100**scale)),
            compactness=0.00001,
            sigma=1,
        )
    elif algorithm == 'watershed':
        from skimage.segmentation import watershed
        from skimage.feature import peak_local_max
        from scipy import ndimage as ndi
        sobel = busam.sobel_from_pred(pred, size)
        sobel = sobel.detach().cpu().numpy() if isinstance(sobel, torch.Tensor) else sobel
        # Contrast-stretch so the strongest 5% of edge responses saturate at 1
        # (those pixels act as borders). If the 95th percentile is 0, dividing
        # would give inf/NaN; treat any nonzero response as a border instead.
        p95 = np.percentile(sobel, 95)
        if p95 > 0:
            sobel = np.clip(sobel / p95, 0, 1)
        else:
            sobel = (sobel > 0).astype(float)
        distance = ndi.distance_transform_edt(sobel < 1)  # distance to the borders
        coords = peak_local_max(distance, min_distance=int(1+100*scale), labels=sobel<1)
        mask = np.zeros(sobel.shape, dtype=bool)
        mask[tuple(coords.T)] = True
        markers, _ = ndi.label(mask)
        labels = watershed(sobel, markers)
    else:
        raise ValueError("algorithm should be kmeans, felzenszwalb, slic or watershed")
    print("Done")

    # Neighbouring segments tend to receive consecutive label values, which
    # render as near-identical gray levels. Map every distinct label to a
    # distinct gray level through a fixed-seed random permutation (a bijection,
    # so no two segments are merged) and spread the levels over 0..255.
    labels = np.asarray(labels)
    unique_labels, inverse = np.unique(labels, return_inverse=True)
    n_labels = len(unique_labels)
    order = np.random.default_rng(0).permutation(n_labels)
    shuffled = order[inverse].reshape(labels.shape)
    # max(..., 1): a single segment maps to 0 instead of dividing by zero.
    gray = np.round(shuffled * (255 / max(n_labels - 1, 1))).astype(np.uint8)
    return Image.fromarray(gray).resize(size[::-1])
def one_click_segmentation(img, row, col, threshold):
    """Predict the mask of the object under a clicked pixel.

    Args:
        img: Input image indexable as ``img[:, :, :3]``.
        row: Click row in pixels (anything ``int()`` accepts, e.g. a Textbox string).
        col: Click column in pixels (same conversion as ``row``).
        threshold: Not used by this function. NOTE(review): it is not
            forwarded to ``busam.get_mask`` — confirm whether it should be.

    Returns:
        ``(img, [(mask, 'Prediction'), (click_map, 'Click')])`` for
        ``gr.AnnotatedImage``. ``click_map`` is a boolean cross centred on the
        click.
    """
    row, col = int(row), int(col)
    img = np.array(img[:, :, :3])
    height, width = img.shape[:2]
    click_map = np.zeros((height, width), dtype=bool)
    # Cross arm half-length is ~1% of the shorter side and its half-thickness
    # a fifth of that. Both are kept >= 1 so the marker stays visible on small
    # images (previously a < 500 px image gave zero thickness and no marker).
    arm = max(1, min(height, width) // 100)
    thickness = max(1, arm // 5)
    # Rows clip against the height and columns against the width
    # (columns were previously clipped against the height).
    click_map[
        max(0, row - arm):min(height, row + arm),
        max(0, col - thickness):min(width, col + thickness),
    ] = True
    click_map[
        max(0, row - thickness):min(height, row + thickness),
        max(0, col - arm):min(width, col + arm),
    ] = True
    print("Getting features...")
    pred, size = busam.process_image(img, do_activate=True)
    print("Getting mask...")
    mask = busam.get_mask((pred, size), (row, col))
    print("Done")
    print('shapes=', img.shape, mask.shape, click_map.shape)
    return (img, [(mask, 'Prediction'), (click_map, 'Click')])
# Example images shown under every tab; globbed once and sorted so the order
# is deterministic across platforms.
example_paths = sorted(str(p) for p in Path('demoimgs').glob('*'))


def enable_sliders(algorithm):
    """Show the two Canny threshold sliders only while "Canny" is selected."""
    algorithm = algorithm.lower()
    return gr.Slider(visible=algorithm == "canny"), gr.Slider(visible=algorithm == "canny")


with gr.Blocks() as demo:
    with gr.Tab('Edge detection'):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image")
                run_button = gr.Button("Run")
                algorithm = gr.Radio(["Sobel", "Canny"], label="Algorithm", value="Sobel")
                # Canny thresholds: hidden initially because "Sobel" is the default.
                th_low_slider = gr.Slider(0, 32768, 10000, label="Canny's low threshold", visible=False)
                th_high_slider = gr.Slider(0, 32768, 20000, label="Canny's high threshold", visible=False)
                algorithm.change(enable_sliders, inputs=[algorithm], outputs=[th_low_slider, th_high_slider])
            with gr.Column():
                output_image = gr.Image(label="Output Image")
        run_button.click(edge_inference, inputs=[image_input, algorithm, th_low_slider, th_high_slider], outputs=output_image)
        gr.Examples(example_paths, inputs=image_input)

    with gr.Tab('Reduction to 3D'):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image")
                algorithm = gr.Radio(["PCA", "TSNE", "UMAP"], label="Algorithm", value="PCA")
                run_button = gr.Button("Run")
                gr.Markdown("⚠️ UMAP is slow, TSNE is ULTRA-slow. They won't run on time. ⚠️")
                resample_pct = gr.Slider(-5, 0, -3, label="Resample (10^x)*100%")
            with gr.Column():
                output_image = gr.Image(label="Output Image")
        run_button.click(dimred_inference, inputs=[image_input, algorithm, resample_pct], outputs=output_image)
        gr.Examples(example_paths, inputs=image_input)

    with gr.Tab('Classical Segmentation'):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image")
                algorithm = gr.Radio(['KMeans', 'Felzenszwalb', 'SLIC', 'Watershed'], label="Algorithm", value="SLIC")
                scale = gr.Slider(0.1, 1.0, 0.5, label="Scale")
                run_button = gr.Button("Run")
            with gr.Column():
                output_image = gr.Image(label="Output Image")
        run_button.click(segmentation_inference, inputs=[image_input, algorithm, scale], outputs=output_image)
        gr.Examples(example_paths, inputs=image_input)

    with gr.Tab('One-click segmentation'):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image")
                threshold = gr.Slider(0, 1, 0.5, label="Threshold")
                with gr.Row():
                    # Textbox values are strings; one_click_segmentation int()-converts them.
                    row = gr.Textbox("10", label="Click's row")
                    col = gr.Textbox("10", label="Click's column")
                run_button = gr.Button("Run")
            with gr.Column():
                output_image = gr.AnnotatedImage(label="Output")
        run_button.click(one_click_segmentation, inputs=[image_input, row, col, threshold], outputs=output_image)
        gr.Examples(example_paths, inputs=image_input)

demo.launch(share=False)