franchesoni committed
Commit e1b51e5 • Parent(s): 2df2c09

v0

Files changed:
- .gitignore +5 -0
- app.py +197 -0
- busam.py +137 -0
- losses.py +211 -0
- network.py +267 -0
- utils.py +219 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
*.sh
*.pth
*.pkl
__pycache__/
flagged/
app.py
ADDED
@@ -0,0 +1,197 @@
from PIL import Image
import torch
import numpy as np
import gradio as gr
from pathlib import Path

from busam import Busam

resize_to = 512
checkpoint = "weights.pth"
device = "cpu"
print("Loading model...")
busam = Busam(checkpoint=checkpoint, device=device, side=resize_to)
minmaxnorm = lambda x: (x - x.min()) / (x.max() - x.min())


def edge_inference(img, algorithm, th_low=None, th_high=None):
    algorithm = algorithm.lower()
    print("Loading image...")
    img = np.array(img[:, :, :3])
    print("Getting features...")
    pred, size = busam.process_image(img, do_activate=True)
    print("Computing edges...")
    if algorithm == "sobel":
        edge = busam.sobel_from_pred(pred, size)
    elif algorithm == "canny":
        th_low, th_high = th_low or 5000, th_high or 10000
        edge = busam.canny_from_pred(pred, size, th_low=th_low, th_high=th_high)
    else:
        raise ValueError("algorithm should be sobel or canny")
    edge = edge.cpu().numpy() if isinstance(edge, torch.Tensor) else edge

    print("Done")
    return Image.fromarray(
        (minmaxnorm(edge) * 255).astype(np.uint8)
    ).resize(size[::-1])


def dimred_inference(
    img,
    algorithm,
    resample_pct,
):
    algorithm = algorithm.lower()
    img = np.array(img[:, :, :3])
    print("Getting features...")
    pred, size = busam.process_image(img, do_activate=True)
    # pred is 1, F, S, S
    assert pred.shape[1] >= 3, "should have at least 3 channels"
    if algorithm == "pca":
        from sklearn.decomposition import PCA
        reducer = PCA(n_components=3)
    elif algorithm == "tsne":
        from sklearn.manifold import TSNE
        reducer = TSNE(n_components=3)
    elif algorithm == "umap":
        from umap import UMAP
        reducer = UMAP(n_components=3)
    else:
        raise ValueError("algorithm should be pca, tsne or umap")
    np_y_hat = pred.detach().cpu().permute(1, 0, 2, 3).numpy()  # F, B, H, W
    np_y_hat = np_y_hat.reshape(np_y_hat.shape[0], -1)  # F, BHW
    np_y_hat = np_y_hat.T  # BHW, F
    resample_pct = 10**resample_pct
    resample_size = int(resample_pct * np_y_hat.shape[0])
    sampled_pixels = np_y_hat[:: np_y_hat.shape[0] // resample_size]
    print("dim reduction fit..." + " " * 30, end="\r")
    reducer = reducer.fit(sampled_pixels)
    print("dim reduction transform..." + " " * 30, end="\r")
    reducer.transform(np_y_hat[:10])  # warm-up call so numba compiles the function
    np_y_hat = reducer.transform(np_y_hat)  # BHW, 3
    print()
    print("Done. Saving...")
    # revert back to original shape
    colors = np_y_hat.reshape(pred.shape[2], pred.shape[3], 3)
    return Image.fromarray((minmaxnorm(colors) * 255).astype(np.uint8)).resize(
        size[::-1]
    )


def segmentation_inference(img, algorithm, scale):
    algorithm = algorithm.lower()
    img = np.array(img[:, :, :3])
    print("Getting features...")
    pred, size = busam.process_image(img, do_activate=True)
    print("Computing segmentation...")
    if algorithm == "kmeans":
        from sklearn.cluster import KMeans

        n_clusters = int(100 / 100**scale)
        kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(
            pred.view(pred.shape[1], -1).T
        )
        labels = kmeans.labels_
        labels = labels.reshape(pred.shape[2], pred.shape[3])
    elif algorithm == "felzenszwalb":
        from skimage.segmentation import felzenszwalb

        labels = felzenszwalb(
            (minmaxnorm(pred[0].cpu().numpy()) * 255).astype(np.uint8).transpose(1, 2, 0),
            scale=10 ** (8 * scale - 3),
            sigma=0,
            min_size=50,
        )
    elif algorithm == "slic":
        from skimage.segmentation import slic

        labels = slic(
            (minmaxnorm(pred[0].cpu().numpy()) * 255).astype(np.uint8).transpose(1, 2, 0),
            n_segments=int(100 / 100**scale),
            compactness=0.00001,
            sigma=1,
        )
    else:
        raise ValueError("algorithm should be kmeans, felzenszwalb or slic")
    print("Done")
    # label values are usually close to each other both spatially and in magnitude,
    # which complicates visualization; remap them so neighboring labels look distinct
    out = labels.copy()
    out[labels % 4 == 0] = labels[labels % 4 == 0] * 1 // 4
    out[labels % 4 == 1] = labels[labels % 4 == 1] * 4 // 4 + 1
    out[labels % 4 == 2] = labels[labels % 4 == 2] * 2 // 4 + 2
    out[labels % 4 == 3] = labels[labels % 4 == 3] * 3 // 4 + 3
    return Image.fromarray(
        (minmaxnorm(out) * 255).astype(np.uint8)
    ).resize(size[::-1])


def one_click_segmentation(img, row, col, threshold):  # threshold is currently unused
    row, col = int(row), int(col)
    img = np.array(img[:, :, :3])
    # draw a small cross at the click position for visualization
    click_map = np.zeros(img.shape[:2], dtype=bool)
    click_map[max(0, row - 5):min(img.shape[0], row + 5), col] = True
    click_map[row, max(0, col - 5):min(img.shape[1], col + 5)] = True
    print("Getting features...")
    pred, size = busam.process_image(img, do_activate=True)
    print("Getting mask...")
    mask = busam.get_mask((pred, size), (row, col))
    print("Done")
    print("shapes=", img.shape, mask.shape, click_map.shape)
    return (img, [(mask, "Prediction"), (click_map, "Click")])


with gr.Blocks() as demo:
    with gr.Tab("Edge detection"):
        algorithm = "canny"  # placeholder; replaced by the Radio component below
        with gr.Row():
            def enable_sliders(algorithm):
                algorithm = algorithm.lower()
                return gr.Slider(visible=algorithm == "canny"), gr.Slider(visible=algorithm == "canny")

            with gr.Column():
                image_input = gr.Image(label="Input Image")
                run_button = gr.Button("Run")
                algorithm = gr.Radio(["Sobel", "Canny"], label="Algorithm", value="Sobel")
                # sliders for th_low, th_high (shown only when Canny is selected)
                th_low_slider = gr.Slider(0, 32768, 10000, label="Canny's low threshold", visible=False)
                th_high_slider = gr.Slider(0, 32768, 20000, label="Canny's high threshold", visible=False)
                algorithm.change(enable_sliders, inputs=[algorithm], outputs=[th_low_slider, th_high_slider])
            with gr.Column():
                output_image = gr.Image(label="Output Image")
        run_button.click(edge_inference, inputs=[image_input, algorithm, th_low_slider, th_high_slider], outputs=output_image)
        gr.Examples([str(p) for p in Path("demoimgs").glob("*")], inputs=image_input)

    with gr.Tab("Reduction to 3D"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image")
                algorithm = gr.Radio(["PCA", "TSNE", "UMAP"], label="Algorithm")
                run_button = gr.Button("Run")
                gr.Markdown("⚠️ UMAP is slow, TSNE is ultra-slow, use resample x<-3 ⚠️")
                resample_pct = gr.Slider(-5, 0, -3, label="Resample (10^x)*100%")
            with gr.Column():
                output_image = gr.Image(label="Output Image")
        run_button.click(dimred_inference, inputs=[image_input, algorithm, resample_pct], outputs=output_image)
        gr.Examples([str(p) for p in Path("demoimgs").glob("*")], inputs=image_input)

    with gr.Tab("Classical Segmentation"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image")
                algorithm = gr.Radio(["KMeans", "Felzenszwalb", "SLIC"], label="Algorithm", value="SLIC")
                scale = gr.Slider(0.1, 1.0, 0.5, label="Scale")
                run_button = gr.Button("Run")
            with gr.Column():
                output_image = gr.Image(label="Output Image")
        run_button.click(segmentation_inference, inputs=[image_input, algorithm, scale], outputs=output_image)
        gr.Examples([str(p) for p in Path("demoimgs").glob("*")], inputs=image_input)

    with gr.Tab("One-click segmentation"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image")
                threshold = gr.Slider(0, 1, 0.5, label="Threshold")
                with gr.Row():
                    row = gr.Textbox(10, label="Click's row")
                    col = gr.Textbox(10, label="Click's column")
                run_button = gr.Button("Run")
            with gr.Column():
                output_image = gr.AnnotatedImage(label="Output")
        run_button.click(one_click_segmentation, inputs=[image_input, row, col, threshold], outputs=output_image)
        gr.Examples([str(p) for p in Path("demoimgs").glob("*")], inputs=image_input)


demo.launch(share=False)
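
Usage note: app.py expects weights.pth and a demoimgs/ folder at runtime; neither is part of this commit. Below is a minimal, self-contained sketch of the same Blocks/Tab/Button.click wiring with a stub handler in place of the model, so it runs without the checkpoint (illustrative only, not part of the commit):

import gradio as gr

def invert(img):
    # stand-in for a model-backed handler; img arrives as an H, W, C uint8 array
    return 255 - img[:, :, :3]

with gr.Blocks() as mini_demo:
    with gr.Tab("Invert"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image")
                run_button = gr.Button("Run")
            with gr.Column():
                output_image = gr.Image(label="Output Image")
        run_button.click(invert, inputs=[image_input], outputs=output_image)

mini_demo.launch(share=False)
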
busam.py
ADDED
@@ -0,0 +1,137 @@
import torch
from torch import nn
import numpy as np
from cv2 import resize
import cv2
from pathlib import Path

from network import EfficientViT_l1_r224
from losses import IISLoss, activate
from utils import minmaxnorm, load_from_ckpt


class Busam:
    def __init__(self, checkpoint, device, side=224):
        out_channels = 16
        use_norm_params = False
        net = EfficientViT_l1_r224(
            out_channels=out_channels, use_norm_params=use_norm_params, pretrained=False
        )
        net = load_from_ckpt(net, checkpoint)
        net = net.to(device)
        net.eval()
        self.net = net
        self.device = device
        self.side = side

    def prepare_img(self, img):
        """
        assume H, W, 3 image
        """
        assert len(img.shape) == 3, "should be H, W, 3 but is " + str(img.shape)
        assert img.shape[2] == 3, "should be H, W, 3 but is " + str(img.shape)
        assert img.min() >= 0, "min should be more than 0 but is " + str(img.min())
        assert img.max() <= 255, "max should be less than 255 but is " + str(img.max())
        assert img.dtype == np.uint8, "dtype should be np.uint8 but is " + str(
            img.dtype
        )
        nimg = resize(img, (self.side, self.side))
        tensorimg = (
            (torch.from_numpy(nimg / 255).permute(2, 0, 1) - 0.5)
            .float()[None]
            .to(self.device)
        )
        return tensorimg

    def process_image(self, img, do_activate=False):
        with torch.no_grad():
            x = self.prepare_img(img)
            pred = self.net(x)
            H, W = img.shape[:2]
            if do_activate:
                B, F, pH, pW = pred.shape
                features, _, _, _ = activate(
                    pred.view(F, pH * pW), None, "symlog", False, False, False
                )
                pred = features.view(B, F, pH, pW)
        return pred, (H, W)

    def get_mask(self, aux, click):
        """assume click is (row, col)"""
        pred = aux[0][0]  # remove batch dim
        oH, oW = aux[1]
        F, H, W = pred.shape
        features = pred.view(F, H * W)
        rclick = click[0] * H // oH, click[1] * W // oW  # click in feature-map coords
        sindex = rclick[0] * W + rclick[1]
        mask = IISLoss.get_mask_from_query(features, sindex)
        mask = mask.reshape(H, W)
        mask = (
            resize((mask.cpu().numpy() * 255).astype(np.uint8), (oW, oH)) > 100
        ).astype(bool)
        return mask

    def get_gradients(self, pred, size):
        F, H, W = pred[0].shape
        sobel_x = (
            torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).float().to(pred.device)
        )
        sobel_y = sobel_x.T
        sobel_x = sobel_x.repeat(F, 1, 1, 1)
        sobel_y = sobel_y.repeat(F, 1, 1, 1)
        edge_x = torch.nn.functional.conv2d(pred, sobel_x, padding=1, groups=F).view(
            F, H, W
        )  # conv output is 1, F, H, W
        edge_y = torch.nn.functional.conv2d(pred, sobel_y, padding=1, groups=F).view(
            F, H, W
        )
        edge_x = torch.norm(edge_x, dim=0, p=2)  # L2 norm over channels (takes sqrt)
        edge_y = torch.norm(edge_y, dim=0, p=2)  # H, W
        return edge_x, edge_y

    def sobel_from_pred(self, pred, size):
        edge_x, edge_y = self.get_gradients(pred, size)
        edge = torch.sqrt(edge_x**2 + edge_y**2)
        return edge

    def canny_from_pred(self, pred, size, th_low=10000, th_high=20000):
        th_low = th_low or th_high
        th_high = th_high or th_low

        edge_x, edge_y = self.get_gradients(pred, size)
        amin = min(edge_x.min(), edge_y.min())
        amax = max(edge_x.max(), edge_y.max())
        edge_x, edge_y = (edge_x - amin) / (amax - amin), (edge_y - amin) / (
            amax - amin
        )
        canny = cv2.Canny(cast_to_int16(edge_x), cast_to_int16(edge_y), th_low, th_high)
        return canny


def cast_to_int16(x):
    if isinstance(x, torch.Tensor):
        x = x.cpu().numpy()
    return (x * 32767).astype(np.int16)


# from segment_anything import sam_model_registry, SamPredictor
# class SAM:
#     sam_checkpoint = "sam_vit_b_01ec64.pth"
#     model_type = "vit_b"

#     def __init__(self, device):
#         sam = sam_model_registry[self.model_type](checkpoint=self.sam_checkpoint)
#         sam.to(device=device)
#         self.predictor = SamPredictor(sam)

#     def process_image(self, img):
#         self.predictor.set_image(img)
#         return None

#     def get_mask(self, aux, click):
#         input_point = np.array([[click[1], click[0]]])
#         input_label = np.array([1])
#         masks, scores, logits = self.predictor.predict(
#             point_coords=input_point, point_labels=input_label, multimask_output=False
#         )
#         return masks[0]
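
Usage note: a minimal sketch of driving the Busam wrapper directly, without the Gradio UI. It assumes weights.pth exists and the repo's dependencies (torch, timm, opencv-python, scikit-learn) are installed; the image path is a placeholder:

import numpy as np
from PIL import Image
from busam import Busam

busam = Busam(checkpoint="weights.pth", device="cpu", side=512)
img = np.array(Image.open("demoimgs/example.jpg").convert("RGB"))  # placeholder path
pred, size = busam.process_image(img, do_activate=True)  # pred: 1, 16, side, side; size: original H, W
edge = busam.sobel_from_pred(pred, size)                 # per-pixel edge strength, side x side
mask = busam.get_mask((pred, size), (size[0] // 2, size[1] // 2))  # bool mask at original resolution
print(pred.shape, edge.shape, mask.shape)
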
losses.py
ADDED
@@ -0,0 +1,211 @@
print("Importing standard...")
from abc import ABC, abstractmethod

print("Importing external...")
import torch
from torch.nn.functional import binary_cross_entropy

# from matplotlib import pyplot as plt

print("Importing internal...")
from utils import preprocess_masks_features, get_row_col, symlog, calculate_iou


######### BINARY LOSSES ###############
def my_lovasz_hinge(logits, gt, downsample=False):
    if downsample:
        offset = int(torch.randint(downsample - 1, (1,)))
        logits, gt = logits[:, offset::downsample], gt[:, offset::downsample]
    # B, HW
    gt = 1.0 * gt  # go float
    areas = gt.sum(dim=1, keepdim=True)  # B, 1
    # per_image = True, ignore = None
    signs = 2 * gt - 1
    errors = 1 - logits * signs
    errors_sorted, perm = torch.sort(errors, dim=1, descending=True)
    gt_sorted = torch.gather(gt, 1, perm)  # B, HW
    # lovasz grad
    intersection = areas - gt_sorted.cumsum(dim=1)  # B, HW
    union = areas + (1 - gt_sorted).cumsum(dim=1)  # B, HW
    jaccard = 1 - intersection / union  # B, HW
    jaccard[:, 1:] = jaccard[:, 1:] - jaccard[:, :-1]
    loss = (torch.relu(errors_sorted) * jaccard).sum(dim=1)  # B,
    return torch.nanmean(loss)


def focal_loss(scores, targets, alpha=0.25, gamma=2):
    p = scores
    ce_loss = binary_cross_entropy(p, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss


# also binary_cross_entropy and lovasz


########## SUBFUNCTIONS ##############
def get_distances(features, refs, sigma, norm_p, square_distances, H, W):
    # features: B, 1, F, HW
    # refs: B, M, F, 1
    # sigma: B, M, 1, 1
    B, M = refs.shape[0], refs.shape[1]
    distances = torch.norm(
        features - refs, dim=2, p=norm_p, keepdim=True
    )  # B, M, 1, H*W
    distances = distances**2 if square_distances else distances
    distances = (distances / (2 * sigma**2)).reshape(B, M, H * W)
    return distances


def activate(features, masks, activation, use_sigma, offset_pos, ret_prediction):
    # sigmoid is very similar to exp
    # prepare features
    assert activation in ["sigmoid", "symlog"]
    if masks is None:  # when inferencing
        B, M = 1, 1
        F, N = sorted(features.shape)  # assumes fewer channels than pixels
        H, W = [int(N ** (0.5))] * 2
        features = features.reshape(1, 1, -1, H * W)
    else:
        masks, features, M, B, H, W, F = preprocess_masks_features(masks, features)
    # features: B, 1, F, H*W
    # masks: B, M, 1, H*W
    if use_sigma:
        sigma = torch.nn.functional.softplus(features)[:, :, -1:]  # B, 1, 1, H*W
        features = features[:, :, :-1]
        F = features.shape[2]
    else:
        sigma = 1
    features = symlog(features) if activation == "symlog" else torch.sigmoid(features)
    if offset_pos:
        assert F >= 2
        row, col = get_row_col(H, W, features.device)
        row = row.reshape(1, 1, 1, H, 1).expand(B, 1, 1, H, W).reshape(B, 1, 1, H * W)
        col = col.reshape(1, 1, 1, 1, W).expand(B, 1, 1, H, W).reshape(B, 1, 1, H * W)
        positional_features = torch.cat([row, col], dim=2)  # B, 1, 2, H*W
        features[:, :, :2] = features[:, :, :2] + positional_features
    prediction = features.reshape(B, 1, -1, H, W) if ret_prediction else None
    if masks is None:
        features = features.reshape(-1, H * W)
        sigma = sigma.reshape(-1, H * W) if use_sigma else 1
        return features, sigma, H, W
    return features, masks, sigma, prediction, B, M, F, H, W


class AbstractLoss(ABC):
    @staticmethod
    @abstractmethod
    def loss(features, masks, ret_prediction=False, **kwargs):
        pass

    @staticmethod
    @abstractmethod
    def get_mask_from_query(features, sindex, **kwargs):
        pass


class IISLoss(AbstractLoss):
    @staticmethod
    def loss(features, masks, ret_prediction=False, K=3, logger=None):
        features, masks, sigma, prediction, B, M, F, H, W = activate(
            features, masks, "symlog", False, False, ret_prediction
        )
        rindices = torch.randperm(H * W, device=masks.device)
        # the following should work if all masks have more than K pixels
        sindices = torch.stack(
            [
                torch.stack([rindices[masks[b, m, 0, rindices]][:K] for m in range(M)])
                for b in range(B)
            ]
        )  # B, M, K
        feats_at_sindices = torch.gather(
            features.permute(0, 3, 1, 2).expand(B, H * W, K, F),
            dim=1,
            index=sindices.reshape(B, M, K, 1).expand(B, M, K, F),
        )  # B, M, K, F
        feats_at_sindices = feats_at_sindices.reshape(B, M, K, F, 1)  # B, M, K, F, 1
        dists = get_distances(
            features, feats_at_sindices.reshape(B, M * K, F, 1), sigma, 2, True, H, W
        )
        score = torch.exp(-dists)  # B, M*K, H*W in [0, 1]
        targets = (
            masks.expand(B, M, K, H * W).reshape(B, M * K, H * W).float()
        )  # B, M*K, H*W
        floss = focal_loss(score, targets).mean()
        lloss = my_lovasz_hinge(
            score.view(B * M * K, H * W) * 2 - 1,
            targets.view(B * M * K, H * W),
        )
        loss = floss + lloss
        return loss, prediction

    @staticmethod
    def get_mask_from_query(features, sindex):
        features, _, H, W = activate(features, None, "symlog", False, False, False)
        F = features.shape[0]
        query_feat = features[:, sindex]
        dists = get_distances(
            features.reshape(1, 1, F, H * W),
            query_feat.reshape(1, 1, F, 1),
            1,
            2,
            True,
            H,
            W,
        )
        score = torch.exp(-dists)  # 1, 1, H*W
        pred = score > 0.5
        return pred


def iis_iou(features, masks, get_mask_from_query, K=20):
    masks, features, M, B, H, W, F = preprocess_masks_features(masks, features)
    # features: B, 1, F, H*W
    # masks: B, M, 1, H*W
    rindices = torch.randperm(H * W).to(masks.device)
    sindices = torch.stack(
        [
            torch.stack([rindices[masks[b, m, 0, rindices]][:K] for m in range(M)])
            for b in range(B)
        ]
    )  # B, M, K
    cum_iou, n_samples = 0, 0
    for b in range(B):
        for m in range(M):
            for k in range(K):
                sindex = sindices[b, m, k]
                pred = get_mask_from_query(features[b, 0], sindex)
                iou = calculate_iou(pred, masks[b, m, 0, :])
                cum_iou += iou
                n_samples += 1

    return cum_iou / n_samples


losses_names = [
    "iis",
]


def get_loss_class(loss_name):
    if loss_name == "iis":
        return IISLoss
    else:
        raise NotImplementedError


def get_get_mask_from_query(loss_name):
    loss_class = get_loss_class(loss_name)
    return loss_class.get_mask_from_query


def get_loss(loss_name):
    loss_class = get_loss_class(loss_name)
    return loss_class.loss
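
Usage note: a self-contained sketch of calling IISLoss.loss on random features and two synthetic masks, following the shape conventions in the comments above (features B, F, H, W; masks B, M, H, W boolean, each mask with at least K pixels). It assumes scikit-learn is installed, since utils imports PCA at module load:

import torch
from losses import IISLoss

B, F, H, W = 1, 16, 32, 32
features = torch.randn(B, F, H, W)
masks = torch.zeros(B, 2, H, W, dtype=torch.bool)
masks[0, 0, : H // 2] = True  # top half
masks[0, 1, H // 2 :] = True  # bottom half
loss, _ = IISLoss.loss(features, masks)  # ret_prediction=False, so prediction is None
print(float(loss))
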
network.py
ADDED
@@ -0,0 +1,267 @@
print("Importing external...")
import torch
from torch import nn
import torch.nn.functional as F

from timm.models.efficientvit_mit import (
    ConvNormAct,
    FusedMBConv,
    MBConv,
    ResidualBlock,
    efficientvit_l1,
)
from timm.layers import GELUTanh


def val2list(x: list or tuple or any, repeat_time=1):
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x for _ in range(repeat_time)]


def resize(
    x: torch.Tensor,
    size: any or None = None,
    scale_factor: list[float] or None = None,
    mode: str = "bicubic",
    align_corners: bool or None = False,
) -> torch.Tensor:
    if mode in {"bilinear", "bicubic"}:
        return F.interpolate(
            x,
            size=size,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
        )
    elif mode in {"nearest", "area"}:
        return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode)
    else:
        raise NotImplementedError(f"resize(mode={mode}) not implemented.")


class UpSampleLayer(nn.Module):
    def __init__(
        self,
        mode="bicubic",
        size: int or tuple[int, int] or list[int] or None = None,
        factor=2,
        align_corners=False,
    ):
        super(UpSampleLayer, self).__init__()
        self.mode = mode
        self.size = val2list(size, 2) if size is not None else None
        self.factor = None if self.size is not None else factor
        self.align_corners = align_corners

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if (
            self.size is not None and tuple(x.shape[-2:]) == self.size
        ) or self.factor == 1:
            return x
        return resize(x, self.size, self.factor, self.mode, self.align_corners)


class DAGBlock(nn.Module):
    def __init__(
        self,
        inputs: dict[str, nn.Module],
        merge: str,
        post_input: nn.Module or None,
        middle: nn.Module,
        outputs: dict[str, nn.Module],
    ):
        super(DAGBlock, self).__init__()

        self.input_keys = list(inputs.keys())
        self.input_ops = nn.ModuleList(list(inputs.values()))
        self.merge = merge
        self.post_input = post_input

        self.middle = middle

        self.output_keys = list(outputs.keys())
        self.output_ops = nn.ModuleList(list(outputs.values()))

    def forward(self, feature_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        feat = [
            op(feature_dict[key]) for key, op in zip(self.input_keys, self.input_ops)
        ]
        if self.merge == "add":
            feat = list_sum(feat)
        elif self.merge == "cat":
            feat = torch.concat(feat, dim=1)
        else:
            raise NotImplementedError
        if self.post_input is not None:
            feat = self.post_input(feat)
        feat = self.middle(feat)
        for key, op in zip(self.output_keys, self.output_ops):
            feature_dict[key] = op(feat)
        return feature_dict


def list_sum(x: list) -> any:
    return x[0] if len(x) == 1 else x[0] + list_sum(x[1:])


class SegHead(nn.Module):
    def __init__(
        self,
        fid_list: list[str],
        in_channel_list: list[int],
        stride_list: list[int],
        head_stride: int,
        head_width: int,
        head_depth: int,
        expand_ratio: float,
        middle_op: str,
        final_expand: float or None,
        n_classes: int,
        dropout=0,
        norm="bn2d",
        act_func="hswish",
    ):
        super(SegHead, self).__init__()
        # exceptions to adapt effvit to timm
        if act_func == "gelu":
            act_func = GELUTanh
        else:
            raise ValueError(f"act_func {act_func} not supported")
        if norm == "bn2d":
            norm_layer = nn.BatchNorm2d
        else:
            raise ValueError(f"norm {norm} not supported")

        inputs = {}
        for fid, in_channel, stride in zip(fid_list, in_channel_list, stride_list):
            factor = stride // head_stride
            if factor == 1:
                inputs[fid] = ConvNormAct(
                    in_channel, head_width, 1, norm_layer=norm_layer, act_layer=act_func
                )
            else:
                inputs[fid] = nn.Sequential(
                    ConvNormAct(
                        in_channel,
                        head_width,
                        1,
                        norm_layer=norm_layer,
                        act_layer=act_func,
                    ),
                    UpSampleLayer(factor=factor),
                )
        self.in_keys = inputs.keys()
        self.in_ops = nn.ModuleList(inputs.values())

        middle = []
        for _ in range(head_depth):
            if middle_op == "mbconv":
                block = MBConv(
                    head_width,
                    head_width,
                    expand_ratio=expand_ratio,
                    norm_layer=norm_layer,
                    act_layer=(act_func, act_func, None),
                )
            elif middle_op == "fmbconv":
                block = FusedMBConv(
                    head_width,
                    head_width,
                    expand_ratio=expand_ratio,
                    norm_layer=norm_layer,
                    act_layer=(act_func, None),
                )
            else:
                raise NotImplementedError
            middle.append(ResidualBlock(block, nn.Identity()))
        self.middle = nn.Sequential(*middle)

        # nn.Sequential cannot hold None, so build the layer list conditionally
        out_layers = []
        if final_expand is not None:
            out_layers.append(
                ConvNormAct(
                    head_width,
                    head_width * final_expand,
                    1,
                    norm_layer=norm_layer,
                    act_layer=act_func,
                )
            )
        out_layers.append(
            ConvNormAct(
                head_width * (final_expand or 1),
                n_classes,
                1,
                bias=True,
                dropout=dropout,
                norm_layer=None,
                act_layer=None,
            )
        )
        self.out_layer = nn.Sequential(*out_layers)

    def forward(self, feature_map_list):
        t_feat_maps = [
            self.in_ops[ind](feature_map_list[ind])
            for ind in range(len(feature_map_list))
        ]
        t_feat_map = list_sum(t_feat_maps)
        t_feat_map = self.middle(t_feat_map)
        out = self.out_layer(t_feat_map)
        return out


class EfficientViT_l1_r224(nn.Module):
    def __init__(
        self,
        out_channels,
        out_ds_factor=1,
        decoder_size="small",
        pretrained=False,
        use_norm_params=False,
    ):
        if decoder_size == "small":
            head_width = 32
            head_depth = 1
            middle_op = "mbconv"
        elif decoder_size == "medium":
            head_width = 64
            head_depth = 3
            middle_op = "mbconv"
        elif decoder_size == "large":
            head_width = 256
            head_depth = 3
            middle_op = "fmbconv"
        else:
            raise ValueError(f"decoder_size {decoder_size} not supported")

        super(EfficientViT_l1_r224, self).__init__()
        self.bbone = efficientvit_l1(
            num_classes=0, features_only=True, pretrained=pretrained
        )
        self.head = SegHead(
            fid_list=["stage4", "stage3", "stage2"],
            in_channel_list=[512, 256, 128],
            stride_list=[32, 16, 8],
            head_stride=out_ds_factor,
            head_width=head_width,
            head_depth=head_depth,
            expand_ratio=4,
            middle_op=middle_op,
            final_expand=8,
            n_classes=out_channels,
            act_func="gelu",
        )
        # [optional] deactivate normalization
        if not use_norm_params:
            for module in self.modules():
                if (
                    isinstance(module, nn.LayerNorm)
                    or isinstance(module, nn.BatchNorm2d)
                    or isinstance(module, nn.BatchNorm1d)
                ):
                    module.weight.requires_grad_(False)
                    module.bias.requires_grad_(False)

    def forward(self, x):
        feat = self.bbone(x)
        out = self.head([feat[3], feat[2], feat[1]])
        return out
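
Usage note: a quick forward-pass shape check for the network, assuming a recent timm that provides timm.models.efficientvit_mit and timm.layers.GELUTanh (pretrained=False avoids any weight download):

import torch
from network import EfficientViT_l1_r224

net = EfficientViT_l1_r224(out_channels=16, pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    y = net(x)
print(y.shape)  # expected 1, 16, 224, 224: with out_ds_factor=1 every stage is upsampled to input resolution
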
utils.py
ADDED
@@ -0,0 +1,219 @@
print("Importing standard...")
import subprocess
import shutil
from pathlib import Path

print("Importing external...")
import torch
import numpy as np
from PIL import Image

REDUCTION = "pca"
if REDUCTION == "umap":
    from umap import UMAP
elif REDUCTION == "tsne":
    from sklearn.manifold import TSNE
elif REDUCTION == "pca":
    from sklearn.decomposition import PCA


def symlog(x):
    return torch.sign(x) * torch.log(torch.abs(x) + 1)


def preprocess_masks_features(masks, features):
    # Get shapes right
    B, M, H, W = masks.shape
    Bf, F, Hf, Wf = features.shape
    masks = masks.reshape(B, M, 1, H * W)
    # # the following assertions should hold; removed for speed
    # assert H == Hf and W == Wf and B == Bf
    # assert masks.dtype == torch.bool
    # assert (mask_areas > 0).all(), "you shouldn't have empty masks"

    # Reduce M if there are empty masks (note: mask_areas is computed but not used)
    mask_areas = masks.sum(dim=3)  # B, M, 1
    features = features.reshape(B, 1, F, H * W)
    # output shapes
    # features: B, 1, F, H*W
    # masks: B, M, 1, H*W

    return masks, features, M, B, H, W, F


def get_row_col(H, W, device):
    # get position of pixels in [0, 1]
    row = torch.linspace(0, 1, H, device=device)
    col = torch.linspace(0, 1, W, device=device)
    return row, col


def get_current_git_commit():
    try:
        # Run the git command to get the current commit hash
        commit_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()
        # Decode from bytes to a string
        return commit_hash.decode("utf-8")
    except subprocess.CalledProcessError:
        # Handle the case where the command fails (e.g., not a Git repository)
        print("An error occurred while trying to retrieve the git commit hash.")
        return None


def clean_dir(dirname):
    """Removes all directories in dirname that don't have a done.txt file"""
    dstdir = Path(dirname)
    dstdir.mkdir(exist_ok=True, parents=True)
    for f in dstdir.iterdir():
        # if the directory doesn't have a done.txt file, remove it
        if f.is_dir() and not (f / "done.txt").exists():
            shutil.rmtree(f)


def save_tensor_as_image(tensor, dstfile, global_step):
    dstfile = Path(dstfile)
    dstfile = (dstfile.parent / (dstfile.stem + "_" + str(global_step))).with_suffix(
        ".jpg"
    )
    save(tensor, str(dstfile))


def minmaxnorm(x):
    return (x - x.min()) / (x.max() - x.min())


def save(tensor, name, channel_offset=0):
    tensor = to_img(tensor, channel_offset=channel_offset)
    Image.fromarray(tensor).save(name)


def to_img(tensor, channel_offset=0):
    tensor = minmaxnorm(tensor)
    tensor = (tensor * 255).to(torch.uint8)
    C, H, W = tensor.shape
    if tensor.shape[0] == 1:
        tensor = tensor[0]
    elif tensor.shape[0] == 2:
        # show two channels as red and blue, leaving green empty
        tensor = torch.stack([tensor[0], torch.zeros_like(tensor[0]), tensor[1]], dim=0)
        tensor = tensor.permute(1, 2, 0)
    elif tensor.shape[0] >= 3:
        tensor = tensor[channel_offset : channel_offset + 3]
        tensor = tensor.permute(1, 2, 0)
    tensor = tensor.cpu().numpy()
    return tensor


def log_input_output(
    name,
    x,
    y_hat,
    global_step,
    img_dstdir,
    out_dstdir,
    reduce_dim=True,
    reduction=REDUCTION,
    resample_size=20000,
):
    y_hat = y_hat.reshape(
        y_hat.shape[0], y_hat.shape[2], y_hat.shape[3], y_hat.shape[4]
    )
    if reduce_dim and y_hat.shape[1] >= 3:
        reducer = (
            UMAP(n_components=3)
            if (reduction == "umap")
            else (
                TSNE(n_components=3)
                if reduction == "tsne"
                else PCA(n_components=3)
                if reduction == "pca"
                else None
            )
        )
        np_y_hat = y_hat.detach().cpu().permute(1, 0, 2, 3).numpy()  # F, B, H, W
        np_y_hat = np_y_hat.reshape(np_y_hat.shape[0], -1)  # F, BHW
        np_y_hat = np_y_hat.T  # BHW, F
        sampled_pixels = np_y_hat[:: np_y_hat.shape[0] // resample_size]  # assumes at least resample_size pixels
        print("dim reduction fit..." + " " * 30, end="\r")
        reducer = reducer.fit(sampled_pixels)
        print("dim reduction transform..." + " " * 30, end="\r")
        reducer.transform(np_y_hat[:10])  # warm-up call so numba compiles the function
        np_y_hat = reducer.transform(np_y_hat)  # BHW, 3
        # revert back to original shape
        y_hat2 = (
            torch.from_numpy(
                np_y_hat.T.reshape(3, y_hat.shape[0], y_hat.shape[2], y_hat.shape[3])
            )
            .to(y_hat.device)
            .permute(1, 0, 2, 3)
        )
        print("done" + " " * 30, end="\r")
    else:
        y_hat2 = y_hat

    for i in range(min(len(x), 8)):
        save_tensor_as_image(
            x[i],
            img_dstdir / f"input_{name}_{str(i).zfill(2)}",
            global_step=global_step,
        )
        for c in range(y_hat.shape[1]):
            save_tensor_as_image(
                y_hat[i, c : c + 1],
                out_dstdir / f"pred_channel_{name}_{str(i).zfill(2)}_{c}",
                global_step=global_step,
            )
        # log color image

        assert len(y_hat2.shape) == 4, "should be B, F, H, W"
        if reduce_dim:
            save_tensor_as_image(
                y_hat2[i][:3],
                out_dstdir / f"pred_reduced_{name}_{str(i).zfill(2)}",
                global_step=global_step,
            )
        save_tensor_as_image(
            y_hat[i][:3],
            out_dstdir / f"pred_colorchs_{name}_{str(i).zfill(2)}",
            global_step=global_step,
        )


def check_for_nan(loss, model, batch):
    try:
        assert not torch.isnan(loss)
    except Exception as e:
        # print things useful to debug
        # does the batch contain nan?
        print("img batch contains nan?", torch.isnan(batch[0]).any())
        print("mask batch contains nan?", torch.isnan(batch[1]).any())
        # do the model weights contain nan?
        for name, param in model.named_parameters():
            if torch.isnan(param).any():
                print(name, "contains nan")
        # does the output contain nan?
        print("output contains nan?", torch.isnan(model(batch[0])).any())
        # now raise the error
        raise e


def calculate_iou(pred, label):
    intersection = ((label == 1) & (pred == 1)).sum()
    union = ((label == 1) | (pred == 1)).sum()
    if not union:
        return 0
    else:
        iou = intersection.item() / union.item()
        return iou


def load_from_ckpt(net, ckpt_path, strict=True):
    """Load network weights"""
    if ckpt_path and Path(ckpt_path).exists():
        ckpt = torch.load(ckpt_path, map_location="cpu")
        if "MODEL_STATE" in ckpt:
            ckpt = ckpt["MODEL_STATE"]
        elif "state_dict" in ckpt:
            ckpt = ckpt["state_dict"]
        net.load_state_dict(ckpt, strict=strict)
        print("Loaded checkpoint from", ckpt_path)
    return net
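
Usage note: a few of the helpers above at a glance (self-contained, assuming scikit-learn is installed since the module imports PCA at load time):

import torch
from utils import symlog, minmaxnorm, calculate_iou

x = torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0])
print(symlog(x))      # sign(x) * log(|x| + 1): compresses magnitudes while keeping sign
print(minmaxnorm(x))  # linear rescale to [0, 1]

pred = torch.tensor([1, 1, 0, 0])
label = torch.tensor([1, 0, 0, 0])
print(calculate_iou(pred, label))  # intersection 1, union 2 -> 0.5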