franchesoni committed
Commit e1b51e5
1 Parent(s): 2df2c09
Files changed (6)
  1. .gitignore +5 -0
  2. app.py +197 -0
  3. busam.py +137 -0
  4. losses.py +211 -0
  5. network.py +267 -0
  6. utils.py +219 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+ *.sh
+ *.pth
+ *.pkl
+ __pycache__/
+ flagged/
app.py ADDED
@@ -0,0 +1,197 @@
+ from PIL import Image
+ import torch
+ import numpy as np
+ import gradio as gr
+ from pathlib import Path
+
+ from busam import Busam
+
+ resize_to = 512
+ checkpoint = "weights.pth"
+ device = "cpu"
+ print("Loading model...")
+ busam = Busam(checkpoint=checkpoint, device=device, side=resize_to)
+ minmaxnorm = lambda x: (x - x.min()) / (x.max() - x.min())
+
+
+ def edge_inference(img, algorithm, th_low=None, th_high=None):
+     algorithm = algorithm.lower()
+     print("Loading image...")
+     img = np.array(img[:, :, :3])
+     print("Getting features...")
+     pred, size = busam.process_image(img, do_activate=True)
+     print("Computing edges...")
+     if algorithm == "sobel":
+         edge = busam.sobel_from_pred(pred, size)
+     elif algorithm == "canny":
+         th_low, th_high = th_low or 5000, th_high or 10000
+         edge = busam.canny_from_pred(pred, size, th_low=th_low, th_high=th_high)
+     else:
+         raise ValueError("algorithm should be sobel or canny")
+     edge = edge.cpu().numpy() if isinstance(edge, torch.Tensor) else edge
+     print("Done")
+     return Image.fromarray(
+         (minmaxnorm(edge) * 255).astype(np.uint8)
+     ).resize(size[::-1])
+
+
+ def dimred_inference(img, algorithm, resample_pct):
+     algorithm = algorithm.lower()
+     img = np.array(img[:, :, :3])
+     print("Getting features...")
+     pred, size = busam.process_image(img, do_activate=True)
+     # pred is 1, F, S, S
+     assert pred.shape[1] >= 3, "should have at least 3 channels"
+     if algorithm == "pca":
+         from sklearn.decomposition import PCA
+         reducer = PCA(n_components=3)
+     elif algorithm == "tsne":
+         from sklearn.manifold import TSNE
+         reducer = TSNE(n_components=3)
+     elif algorithm == "umap":
+         from umap import UMAP
+         reducer = UMAP(n_components=3)
+     else:
+         raise ValueError("algorithm should be pca, tsne or umap")
+     np_y_hat = pred.detach().cpu().permute(1, 0, 2, 3).numpy()  # F, B, H, W
+     np_y_hat = np_y_hat.reshape(np_y_hat.shape[0], -1)  # F, BHW
+     np_y_hat = np_y_hat.T  # BHW, F
+     resample_pct = 10**resample_pct
+     resample_size = max(1, int(resample_pct * np_y_hat.shape[0]))
+     sampled_pixels = np_y_hat[:: np_y_hat.shape[0] // resample_size]
+     if algorithm == "tsne":
+         # sklearn's TSNE has no transform(); embed everything in one pass
+         print("dim reduction fit_transform..." + " " * 30, end="\r")
+         np_y_hat = reducer.fit_transform(np_y_hat)  # BHW, 3
+     else:
+         print("dim reduction fit..." + " " * 30, end="\r")
+         reducer = reducer.fit(sampled_pixels)
+         print("dim reduction transform..." + " " * 30, end="\r")
+         reducer.transform(np_y_hat[:10])  # warm-up call (numba-compiles UMAP's transform)
+         np_y_hat = reducer.transform(np_y_hat)  # BHW, 3
+     print()
+     print("Done. Saving...")
+     # revert back to original shape
+     colors = np_y_hat.reshape(pred.shape[2], pred.shape[3], 3)
+     return Image.fromarray((minmaxnorm(colors) * 255).astype(np.uint8)).resize(
+         size[::-1]
+     )
+
+
+ def segmentation_inference(img, algorithm, scale):
+     algorithm = algorithm.lower()
+     img = np.array(img[:, :, :3])
+     print("Getting features...")
+     pred, size = busam.process_image(img, do_activate=True)
+     print("Computing segmentation...")
+     if algorithm == "kmeans":
+         from sklearn.cluster import KMeans
+         n_clusters = int(100 / 100**scale)
+         kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(
+             pred.view(pred.shape[1], -1).T.cpu().numpy()
+         )
+         labels = kmeans.labels_
+         labels = labels.reshape(pred.shape[2], pred.shape[3])
+     elif algorithm == "felzenszwalb":
+         from skimage.segmentation import felzenszwalb
+         labels = felzenszwalb(
+             (minmaxnorm(pred[0].cpu().numpy()) * 255).astype(np.uint8).transpose(1, 2, 0),
+             scale=10 ** (8 * scale - 3),
+             sigma=0,
+             min_size=50,
+         )
+     elif algorithm == "slic":
+         from skimage.segmentation import slic
+         labels = slic(
+             (minmaxnorm(pred[0].cpu().numpy()) * 255).astype(np.uint8).transpose(1, 2, 0),
+             n_segments=int(100 / 100**scale),
+             compactness=0.00001,
+             sigma=1,
+         )
+     else:
+         raise ValueError("algorithm should be kmeans, felzenszwalb or slic")
+     print("Done")
+     # consecutive label ids tend to be close both in the image and in magnitude,
+     # which complicates visualization; remap them so neighbors get distinct values
+     out = labels.copy()
+     out[labels % 4 == 0] = labels[labels % 4 == 0] * 1 // 4
+     out[labels % 4 == 1] = labels[labels % 4 == 1] * 4 // 4 + 1
+     out[labels % 4 == 2] = labels[labels % 4 == 2] * 2 // 4 + 2
+     out[labels % 4 == 3] = labels[labels % 4 == 3] * 3 // 4 + 3
+     return Image.fromarray(
+         (minmaxnorm(out) * 255).astype(np.uint8)
+     ).resize(size[::-1])
+
+
+ def one_click_segmentation(img, row, col, threshold):
+     # threshold is currently unused; the mask is cut at a fixed 0.5 in IISLoss
+     row, col = int(row), int(col)
+     img = np.array(img[:, :, :3])
+     click_map = np.zeros(img.shape[:2], dtype=bool)
+     click_map[max(0, row - 5) : min(img.shape[0], row + 5), col] = True
+     click_map[row, max(0, col - 5) : min(img.shape[1], col + 5)] = True
+     print("Getting features...")
+     pred, size = busam.process_image(img, do_activate=True)
+     print("Getting mask...")
+     mask = busam.get_mask((pred, size), (row, col))
+     print("Done")
+     print("shapes=", img.shape, mask.shape, click_map.shape)
+     return (img, [(mask, "Prediction"), (click_map, "Click")])
+
+
+ with gr.Blocks() as demo:
+     with gr.Tab("Edge detection"):
+         def enable_sliders(algorithm):
+             algorithm = algorithm.lower()
+             return gr.Slider(visible=algorithm == "canny"), gr.Slider(visible=algorithm == "canny")
+
+         with gr.Row():
+             with gr.Column():
+                 image_input = gr.Image(label="Input Image")
+                 run_button = gr.Button("Run")
+                 algorithm = gr.Radio(["Sobel", "Canny"], label="Algorithm", value="Sobel")
+                 # sliders for th_low and th_high, only visible for Canny
+                 th_low_slider = gr.Slider(0, 32768, 10000, label="Canny's low threshold", visible=False)
+                 th_high_slider = gr.Slider(0, 32768, 20000, label="Canny's high threshold", visible=False)
+                 algorithm.change(enable_sliders, inputs=[algorithm], outputs=[th_low_slider, th_high_slider])
+             with gr.Column():
+                 output_image = gr.Image(label="Output Image")
+         run_button.click(edge_inference, inputs=[image_input, algorithm, th_low_slider, th_high_slider], outputs=output_image)
+         gr.Examples([str(p) for p in Path("demoimgs").glob("*")], inputs=image_input)
+
+     with gr.Tab("Reduction to 3D"):
+         with gr.Row():
+             with gr.Column():
+                 image_input = gr.Image(label="Input Image")
+                 algorithm = gr.Radio(["PCA", "TSNE", "UMAP"], label="Algorithm")
+                 run_button = gr.Button("Run")
+                 gr.Markdown("⚠️ UMAP is slow, TSNE is ultra-slow, use resample x <= -3 ⚠️")
+                 resample_pct = gr.Slider(-5, 0, -3, label="Resample (10^x)*100%")
+             with gr.Column():
+                 output_image = gr.Image(label="Output Image")
+         run_button.click(dimred_inference, inputs=[image_input, algorithm, resample_pct], outputs=output_image)
+         gr.Examples([str(p) for p in Path("demoimgs").glob("*")], inputs=image_input)
+
+     with gr.Tab("Classical Segmentation"):
+         with gr.Row():
+             with gr.Column():
+                 image_input = gr.Image(label="Input Image")
+                 algorithm = gr.Radio(["KMeans", "Felzenszwalb", "SLIC"], label="Algorithm", value="SLIC")
+                 scale = gr.Slider(0.1, 1.0, 0.5, label="Scale")
+                 run_button = gr.Button("Run")
+             with gr.Column():
+                 output_image = gr.Image(label="Output Image")
+         run_button.click(segmentation_inference, inputs=[image_input, algorithm, scale], outputs=output_image)
+         gr.Examples([str(p) for p in Path("demoimgs").glob("*")], inputs=image_input)
+
+     with gr.Tab("One-click segmentation"):
+         with gr.Row():
+             with gr.Column():
+                 image_input = gr.Image(label="Input Image")
+                 threshold = gr.Slider(0, 1, 0.5, label="Threshold")
+                 with gr.Row():
+                     row = gr.Textbox(10, label="Click's row")
+                     col = gr.Textbox(10, label="Click's column")
+                 run_button = gr.Button("Run")
+             with gr.Column():
+                 output_image = gr.AnnotatedImage(label="Output")
+         run_button.click(one_click_segmentation, inputs=[image_input, row, col, threshold], outputs=output_image)
+         gr.Examples([str(p) for p in Path("demoimgs").glob("*")], inputs=image_input)
+
+
+ demo.launch(share=False)
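
For reference, a minimal sketch (not part of the commit) of the edge path the "Edge detection" tab wires up, run outside Gradio. It assumes weights.pth is present; the image path is hypothetical:

import numpy as np
from PIL import Image
from busam import Busam

busam = Busam(checkpoint="weights.pth", device="cpu", side=512)
img = np.array(Image.open("demoimgs/example.jpg").convert("RGB"))  # hypothetical path
pred, size = busam.process_image(img, do_activate=True)  # activated features + original size
edge = busam.sobel_from_pred(pred, size).cpu().numpy()
edge = (edge - edge.min()) / (edge.max() - edge.min())  # same minmaxnorm as the app
Image.fromarray((edge * 255).astype(np.uint8)).resize(size[::-1]).save("edges.png")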
busam.py ADDED
@@ -0,0 +1,137 @@
+ import torch
+ import numpy as np
+ import cv2
+ from cv2 import resize
+
+ from network import EfficientViT_l1_r224
+ from losses import IISLoss, activate
+ from utils import load_from_ckpt
+
+
+ class Busam:
+     def __init__(self, checkpoint, device, side=224):
+         out_channels = 16
+         use_norm_params = False
+         net = EfficientViT_l1_r224(
+             out_channels=out_channels, use_norm_params=use_norm_params, pretrained=False
+         )
+         net = load_from_ckpt(net, checkpoint)
+         net = net.to(device)
+         net.eval()
+         self.net = net
+         self.device = device
+         self.side = side
+
+     def prepare_img(self, img):
+         """Assume an H, W, 3 uint8 image; return a 1, 3, side, side float tensor."""
+         assert len(img.shape) == 3, "should be H, W, 3 but is " + str(img.shape)
+         assert img.shape[2] == 3, "should be H, W, 3 but is " + str(img.shape)
+         assert img.min() >= 0, "min should be at least 0 but is " + str(img.min())
+         assert img.max() <= 255, "max should be at most 255 but is " + str(img.max())
+         assert img.dtype == np.uint8, "dtype should be np.uint8 but is " + str(
+             img.dtype
+         )
+         nimg = resize(img, (self.side, self.side))
+         tensorimg = (
+             (torch.from_numpy(nimg / 255).permute(2, 0, 1) - 0.5)
+             .float()[None]
+             .to(self.device)
+         )
+         return tensorimg
+
+     def process_image(self, img, do_activate=False):
+         with torch.no_grad():
+             x = self.prepare_img(img)
+             pred = self.net(x)
+             H, W = img.shape[:2]
+             if do_activate:
+                 B, F, pH, pW = pred.shape
+                 features, _, _, _ = activate(
+                     pred.view(F, pH * pW), None, "symlog", False, False, False
+                 )
+                 pred = features.view(B, F, pH, pW)
+         return pred, (H, W)
+
+     def get_mask(self, aux, click):
+         """Assume click is (row, col) in original-image coordinates."""
+         pred = aux[0][0]  # remove batch dim
+         oH, oW = aux[1]
+         F, H, W = pred.shape
+         features = pred.view(F, H * W)
+         rclick = click[0] * H // oH, click[1] * W // oW
+         sindex = rclick[0] * W + rclick[1]
+         mask = IISLoss.get_mask_from_query(features, sindex)
+         mask = mask.reshape(H, W)
+         mask = (
+             resize((mask.cpu().numpy() * 255).astype(np.uint8), (oW, oH)) > 100
+         ).astype(bool)
+         return mask
+
+     def get_gradients(self, pred, size):
+         F, H, W = pred[0].shape
+         sobel_x = (
+             torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).float().to(pred.device)
+         )
+         sobel_y = sobel_x.T
+         sobel_x = sobel_x.repeat(F, 1, 1, 1)  # F, 1, 3, 3 (one filter per channel)
+         sobel_y = sobel_y.repeat(F, 1, 1, 1)
+         edge_x = torch.nn.functional.conv2d(pred, sobel_x, padding=1, groups=F).view(
+             F, H, W
+         )  # from 1, F, H, W
+         edge_y = torch.nn.functional.conv2d(pred, sobel_y, padding=1, groups=F).view(
+             F, H, W
+         )
+         edge_x = torch.norm(edge_x, dim=0, p=2)  # l2 norm over channels (takes sqrt)
+         edge_y = torch.norm(edge_y, dim=0, p=2)  # H, W
+         return edge_x, edge_y
+
+     def sobel_from_pred(self, pred, size):
+         edge_x, edge_y = self.get_gradients(pred, size)
+         edge = torch.sqrt(edge_x**2 + edge_y**2)
+         return edge
+
+     def canny_from_pred(self, pred, size, th_low=10000, th_high=20000):
+         th_low = th_low or th_high
+         th_high = th_high or th_low
+         edge_x, edge_y = self.get_gradients(pred, size)
+         amin = min(edge_x.min(), edge_y.min())
+         amax = max(edge_x.max(), edge_y.max())
+         edge_x, edge_y = (edge_x - amin) / (amax - amin), (edge_y - amin) / (
+             amax - amin
+         )
+         canny = cv2.Canny(cast_to_int16(edge_x), cast_to_int16(edge_y), th_low, th_high)
+         return canny
+
+
+ def cast_to_int16(x):
+     if isinstance(x, torch.Tensor):
+         x = x.cpu().numpy()
+     return (x * 32767).astype(np.int16)
+
+
+ # from segment_anything import sam_model_registry, SamPredictor
+ # class SAM:
+ #     sam_checkpoint = "sam_vit_b_01ec64.pth"
+ #     model_type = "vit_b"
+
+ #     def __init__(self, device):
+ #         sam = sam_model_registry[self.model_type](checkpoint=self.sam_checkpoint)
+ #         sam.to(device=device)
+ #         self.predictor = SamPredictor(sam)
+
+ #     def process_image(self, img):
+ #         self.predictor.set_image(img)
+ #         return None
+
+ #     def get_mask(self, aux, click):
+ #         input_point = np.array([[click[1], click[0]]])
+ #         input_label = np.array([1])
+ #         masks, scores, logits = self.predictor.predict(
+ #             point_coords=input_point, point_labels=input_label, multimask_output=False
+ #         )
+ #         return masks[0]
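
A minimal usage sketch for the one-click path above (not part of the commit; assumes a checkpoint at weights.pth and uses a random stand-in image):

import numpy as np
from busam import Busam

busam = Busam(checkpoint="weights.pth", device="cpu", side=224)
img = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)  # stand-in image
aux = busam.process_image(img, do_activate=True)  # (pred, (H, W))
mask = busam.get_mask(aux, (240, 320))  # click at (row, col); boolean H x W mask
print(mask.shape, mask.sum())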
losses.py ADDED
@@ -0,0 +1,211 @@
+ print("Importing standard...")
+ from abc import ABC, abstractmethod
+
+ print("Importing external...")
+ import torch
+ from torch.nn.functional import binary_cross_entropy
+
+ # from matplotlib import pyplot as plt
+
+ print("Importing internal...")
+ from utils import preprocess_masks_features, get_row_col, symlog, calculate_iou
+
+
+ ######### BINARY LOSSES ###############
+ def my_lovasz_hinge(logits, gt, downsample=False):
+     if downsample:
+         offset = int(torch.randint(downsample - 1, (1,)))
+         logits, gt = logits[:, offset::downsample], gt[:, offset::downsample]
+     # B, HW
+     gt = 1.0 * gt  # go float
+     areas = gt.sum(dim=1, keepdims=True)  # B, 1
+     # per_image = True, ignore = None
+     signs = 2 * gt - 1
+     errors = 1 - logits * signs
+     errors_sorted, perm = torch.sort(errors, dim=1, descending=True)
+     gt_sorted = torch.gather(gt, 1, perm)  # B, HW
+     # lovasz grad
+     intersection = areas - gt_sorted.cumsum(dim=1)  # B, HW
+     union = areas + (1 - gt_sorted).cumsum(dim=1)  # B, HW
+     jaccard = 1 - intersection / union  # B, HW
+     jaccard[:, 1:] = jaccard[:, 1:] - jaccard[:, :-1]
+     loss = (torch.relu(errors_sorted) * jaccard).sum(dim=1)  # B,
+     return torch.nanmean(loss)
+
+
+ def focal_loss(scores, targets, alpha=0.25, gamma=2):
+     p = scores
+     ce_loss = binary_cross_entropy(p, targets, reduction="none")
+     p_t = p * targets + (1 - p) * (1 - targets)
+     loss = ce_loss * ((1 - p_t) ** gamma)
+     if alpha >= 0:
+         alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+         loss = alpha_t * loss
+     return loss
+
+
+ # also binary_cross_entropy and lovasz
+
+
+ ########## SUBFUNCTIONS ##############
+ def get_distances(features, refs, sigma, norm_p, square_distances, H, W):
+     # features: B, 1, F, HW
+     # refs: B, M, F, 1
+     # sigma: B, M, 1, 1
+     B, M = refs.shape[0], refs.shape[1]
+     distances = torch.norm(
+         features - refs, dim=2, p=norm_p, keepdim=True
+     )  # B, M, 1, H*W
+     distances = distances**2 if square_distances else distances
+     distances = (distances / (2 * sigma**2)).reshape(B, M, H * W)
+     return distances
+
+
+ def activate(features, masks, activation, use_sigma, offset_pos, ret_prediction):
+     # sigmoid is very similar to exp
+     # prepare features
+     assert activation in ["sigmoid", "symlog"]
+     if masks is None:  # when inferencing
+         B, M = 1, 1
+         F, N = sorted(features.shape)  # assumes fewer channels than pixels
+         H, W = [int(N ** (0.5))] * 2
+         features = features.reshape(1, 1, -1, H * W)
+     else:
+         masks, features, M, B, H, W, F = preprocess_masks_features(masks, features)
+     # features: B, 1, F, H*W
+     # masks: B, M, 1, H*W
+     if use_sigma:
+         sigma = torch.nn.functional.softplus(features)[:, :, -1:]  # B, 1, 1, H*W
+         features = features[:, :, :-1]
+         F = features.shape[2]
+     else:
+         sigma = 1
+     features = symlog(features) if activation == "symlog" else torch.sigmoid(features)
+     if offset_pos:
+         assert F >= 2
+         row, col = get_row_col(H, W, features.device)
+         row = row.reshape(1, 1, 1, H, 1).expand(B, 1, 1, H, W).reshape(B, 1, 1, H * W)
+         col = col.reshape(1, 1, 1, 1, W).expand(B, 1, 1, H, W).reshape(B, 1, 1, H * W)
+         positional_features = torch.cat([row, col], dim=2)  # B, 1, 2, H*W
+         features[:, :, :2] = features[:, :, :2] + positional_features
+     prediction = features.reshape(B, 1, -1, H, W) if ret_prediction else None
+     if masks is None:
+         features = features.reshape(-1, H * W)
+         sigma = sigma.reshape(-1, H * W) if use_sigma else 1
+         return features, sigma, H, W
+     return features, masks, sigma, prediction, B, M, F, H, W
+
+
+ class AbstractLoss(ABC):
+     @staticmethod
+     @abstractmethod
+     def loss(features, masks, ret_prediction=False, **kwargs):
+         pass
+
+     @staticmethod
+     @abstractmethod
+     def get_mask_from_query(features, sindex, **kwargs):
+         pass
+
+
+ class IISLoss(AbstractLoss):
+     @staticmethod
+     def loss(features, masks, ret_prediction=False, K=3, logger=None):
+         features, masks, sigma, prediction, B, M, F, H, W = activate(
+             features, masks, "symlog", False, False, ret_prediction
+         )
+         rindices = torch.randperm(H * W, device=masks.device)
+         # the following works as long as all masks have more than K pixels
+         sindices = torch.stack(
+             [
+                 torch.stack([rindices[masks[b, m, 0, rindices]][:K] for m in range(M)])
+                 for b in range(B)
+             ]
+         )  # B, M, K
+         feats_at_sindices = torch.gather(
+             features.permute(0, 3, 1, 2).expand(B, H * W, K, F),
+             dim=1,
+             index=sindices.reshape(B, M, K, 1).expand(B, M, K, F),
+         )  # B, M, K, F
+         feats_at_sindices = feats_at_sindices.reshape(B, M, K, F, 1)  # B, M, K, F, 1
+         dists = get_distances(
+             features, feats_at_sindices.reshape(B, M * K, F, 1), sigma, 2, True, H, W
+         )
+         score = torch.exp(-dists)  # B, M*K, H*W in [0, 1]
+         targets = (
+             masks.expand(B, M, K, H * W).reshape(B, M * K, H * W).float()
+         )  # B, M*K, H*W
+         floss = focal_loss(score, targets).mean()
+         lloss = my_lovasz_hinge(
+             score.view(B * M * K, H * W) * 2 - 1,
+             targets.view(B * M * K, H * W),
+         )
+         loss = floss + lloss
+         return loss, prediction
+
+     @staticmethod
+     def get_mask_from_query(features, sindex):
+         features, _, H, W = activate(features, None, "symlog", False, False, False)
+         F = features.shape[0]
+         query_feat = features[:, sindex]
+         dists = get_distances(
+             features.reshape(1, 1, F, H * W),
+             query_feat.reshape(1, 1, F, 1),
+             1,
+             2,
+             True,
+             H,
+             W,
+         )
+         score = torch.exp(-dists)  # 1, 1, H*W
+         pred = score > 0.5
+         return pred
+
+
+ def iis_iou(features, masks, get_mask_from_query, K=20):
+     masks, features, M, B, H, W, F = preprocess_masks_features(masks, features)
+     # features: B, 1, F, H*W
+     # masks: B, M, 1, H*W
+     rindices = torch.randperm(H * W).to(masks.device)
+     sindices = torch.stack(
+         [
+             torch.stack([rindices[masks[b, m, 0, rindices]][:K] for m in range(M)])
+             for b in range(B)
+         ]
+     )  # B, M, K
+     cum_iou, n_samples = 0, 0
+     for b in range(B):
+         for m in range(M):
+             for k in range(K):
+                 sindex = sindices[b, m, k]
+                 pred = get_mask_from_query(features[b, 0], sindex)
+                 iou = calculate_iou(pred, masks[b, m, 0, :])
+                 cum_iou += iou
+                 n_samples += 1
+     return cum_iou / n_samples
+
+
+ losses_names = [
+     "iis",
+ ]
+
+
+ def get_loss_class(loss_name):
+     if loss_name == "iis":
+         return IISLoss
+     else:
+         raise NotImplementedError
+
+
+ def get_get_mask_from_query(loss_name):
+     loss_class = get_loss_class(loss_name)
+     return loss_class.get_mask_from_query
+
+
+ def get_loss(loss_name):
+     loss_class = get_loss_class(loss_name)
+     return loss_class.loss
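
A minimal sketch of driving IISLoss.loss with random data (not part of the commit); the shapes follow preprocess_masks_features (features B, F, H, W; boolean masks B, M, H, W, each mask with at least K pixels):

import torch
from losses import IISLoss

B, F, H, W = 2, 16, 32, 32
features = torch.randn(B, F, H, W, requires_grad=True)
masks = torch.zeros(B, 3, H, W, dtype=torch.bool)
masks[:, 0, :16], masks[:, 1, 16:], masks[:, 2, :, :8] = True, True, True
loss, _ = IISLoss.loss(features, masks, ret_prediction=False, K=3)
loss.backward()  # focal + lovasz terms are differentiable w.r.t. the features
print(float(loss))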
network.py ADDED
@@ -0,0 +1,267 @@
+ print("Importing external...")
+ from typing import Any
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from timm.models.efficientvit_mit import (
+     ConvNormAct,
+     FusedMBConv,
+     MBConv,
+     ResidualBlock,
+     efficientvit_l1,
+ )
+ from timm.layers import GELUTanh
+
+
+ def val2list(x: Any, repeat_time: int = 1) -> list:
+     if isinstance(x, (list, tuple)):
+         return list(x)
+     return [x for _ in range(repeat_time)]
+
+
+ def resize(
+     x: torch.Tensor,
+     size: Any = None,
+     scale_factor: list[float] | None = None,
+     mode: str = "bicubic",
+     align_corners: bool | None = False,
+ ) -> torch.Tensor:
+     if mode in {"bilinear", "bicubic"}:
+         return F.interpolate(
+             x,
+             size=size,
+             scale_factor=scale_factor,
+             mode=mode,
+             align_corners=align_corners,
+         )
+     elif mode in {"nearest", "area"}:
+         return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode)
+     else:
+         raise NotImplementedError(f"resize(mode={mode}) not implemented.")
+
+
+ class UpSampleLayer(nn.Module):
+     def __init__(
+         self,
+         mode="bicubic",
+         size: int | tuple[int, int] | list[int] | None = None,
+         factor=2,
+         align_corners=False,
+     ):
+         super(UpSampleLayer, self).__init__()
+         self.mode = mode
+         self.size = val2list(size, 2) if size is not None else None
+         self.factor = None if self.size is not None else factor
+         self.align_corners = align_corners
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if (
+             self.size is not None and tuple(x.shape[-2:]) == tuple(self.size)
+         ) or self.factor == 1:
+             return x
+         return resize(x, self.size, self.factor, self.mode, self.align_corners)
+
+
+ class DAGBlock(nn.Module):
+     def __init__(
+         self,
+         inputs: dict[str, nn.Module],
+         merge: str,
+         post_input: nn.Module | None,
+         middle: nn.Module,
+         outputs: dict[str, nn.Module],
+     ):
+         super(DAGBlock, self).__init__()
+
+         self.input_keys = list(inputs.keys())
+         self.input_ops = nn.ModuleList(list(inputs.values()))
+         self.merge = merge
+         self.post_input = post_input
+
+         self.middle = middle
+
+         self.output_keys = list(outputs.keys())
+         self.output_ops = nn.ModuleList(list(outputs.values()))
+
+     def forward(self, feature_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+         feat = [
+             op(feature_dict[key]) for key, op in zip(self.input_keys, self.input_ops)
+         ]
+         if self.merge == "add":
+             feat = list_sum(feat)
+         elif self.merge == "cat":
+             feat = torch.concat(feat, dim=1)
+         else:
+             raise NotImplementedError
+         if self.post_input is not None:
+             feat = self.post_input(feat)
+         feat = self.middle(feat)
+         for key, op in zip(self.output_keys, self.output_ops):
+             feature_dict[key] = op(feat)
+         return feature_dict
+
+
+ def list_sum(x: list) -> Any:
+     return x[0] if len(x) == 1 else x[0] + list_sum(x[1:])
+
+
+ class SegHead(nn.Module):
+     def __init__(
+         self,
+         fid_list: list[str],
+         in_channel_list: list[int],
+         stride_list: list[int],
+         head_stride: int,
+         head_width: int,
+         head_depth: int,
+         expand_ratio: float,
+         middle_op: str,
+         final_expand: float | None,
+         n_classes: int,
+         dropout=0,
+         norm="bn2d",
+         act_func="hswish",
+     ):
+         super(SegHead, self).__init__()
+         # exceptions to adapt the original EfficientViT code to timm
+         if act_func == "gelu":
+             act_func = GELUTanh
+         else:
+             raise ValueError(f"act_func {act_func} not supported")
+         if norm == "bn2d":
+             norm_layer = nn.BatchNorm2d
+         else:
+             raise ValueError(f"norm {norm} not supported")
+
+         inputs = {}
+         for fid, in_channel, stride in zip(fid_list, in_channel_list, stride_list):
+             factor = stride // head_stride
+             if factor == 1:
+                 inputs[fid] = ConvNormAct(
+                     in_channel, head_width, 1, norm_layer=norm_layer, act_layer=act_func
+                 )
+             else:
+                 inputs[fid] = nn.Sequential(
+                     ConvNormAct(
+                         in_channel,
+                         head_width,
+                         1,
+                         norm_layer=norm_layer,
+                         act_layer=act_func,
+                     ),
+                     UpSampleLayer(factor=factor),
+                 )
+         self.in_keys = inputs.keys()
+         self.in_ops = nn.ModuleList(inputs.values())
+
+         middle = []
+         for _ in range(head_depth):
+             if middle_op == "mbconv":
+                 block = MBConv(
+                     head_width,
+                     head_width,
+                     expand_ratio=expand_ratio,
+                     norm_layer=norm_layer,
+                     act_layer=(act_func, act_func, None),
+                 )
+             elif middle_op == "fmbconv":
+                 block = FusedMBConv(
+                     head_width,
+                     head_width,
+                     expand_ratio=expand_ratio,
+                     norm_layer=norm_layer,
+                     act_layer=(act_func, None),
+                 )
+             else:
+                 raise NotImplementedError
+             middle.append(ResidualBlock(block, nn.Identity()))
+         self.middle = nn.Sequential(*middle)
+
+         # build the output projection, skipping the expansion conv when
+         # final_expand is None (nn.Sequential cannot hold None entries)
+         out_layers = []
+         if final_expand is not None:
+             out_layers.append(
+                 ConvNormAct(
+                     head_width,
+                     head_width * final_expand,
+                     1,
+                     norm_layer=norm_layer,
+                     act_layer=act_func,
+                 )
+             )
+         out_layers.append(
+             ConvNormAct(
+                 head_width * (final_expand or 1),
+                 n_classes,
+                 1,
+                 bias=True,
+                 dropout=dropout,
+                 norm_layer=None,
+                 act_layer=None,
+             )
+         )
+         self.out_layer = nn.Sequential(*out_layers)
+
+     def forward(self, feature_map_list):
+         t_feat_maps = [
+             self.in_ops[ind](feature_map_list[ind])
+             for ind in range(len(feature_map_list))
+         ]
+         t_feat_map = list_sum(t_feat_maps)
+         t_feat_map = self.middle(t_feat_map)
+         out = self.out_layer(t_feat_map)
+         return out
+
+
+ class EfficientViT_l1_r224(nn.Module):
+     def __init__(
+         self,
+         out_channels,
+         out_ds_factor=1,
+         decoder_size="small",
+         pretrained=False,
+         use_norm_params=False,
+     ):
+         if decoder_size == "small":
+             head_width = 32
+             head_depth = 1
+             middle_op = "mbconv"
+         elif decoder_size == "medium":
+             head_width = 64
+             head_depth = 3
+             middle_op = "mbconv"
+         elif decoder_size == "large":
+             head_width = 256
+             head_depth = 3
+             middle_op = "fmbconv"
+         else:
+             raise ValueError(f"decoder_size {decoder_size} not supported")
+
+         super(EfficientViT_l1_r224, self).__init__()
+         self.bbone = efficientvit_l1(
+             num_classes=0, features_only=True, pretrained=pretrained
+         )
+         self.head = SegHead(
+             fid_list=["stage4", "stage3", "stage2"],
+             in_channel_list=[512, 256, 128],
+             stride_list=[32, 16, 8],
+             head_stride=out_ds_factor,
+             head_width=head_width,
+             head_depth=head_depth,
+             expand_ratio=4,
+             middle_op=middle_op,
+             final_expand=8,
+             n_classes=out_channels,
+             act_func="gelu",
+         )
+         # [optional] deactivate normalization parameters
+         if not use_norm_params:
+             for module in self.modules():
+                 if isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.BatchNorm1d)):
+                     module.weight.requires_grad_(False)
+                     module.bias.requires_grad_(False)
+
+     def forward(self, x):
+         feat = self.bbone(x)
+         out = self.head([feat[3], feat[2], feat[1]])
+         return out
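
A minimal shape-check sketch (not part of the commit; assumes a timm version that ships efficientvit_l1 and the imports above):

import torch
from network import EfficientViT_l1_r224

net = EfficientViT_l1_r224(out_channels=16, pretrained=False)
net.eval()
with torch.no_grad():
    out = net(torch.randn(1, 3, 224, 224))
# with out_ds_factor=1 the head upsamples stages 2-4 back to input resolution,
# so the expected shape is (1, 16, 224, 224)
print(out.shape)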
utils.py ADDED
@@ -0,0 +1,219 @@
+ print("Importing standard...")
+ import subprocess
+ import shutil
+ from pathlib import Path
+
+ print("Importing external...")
+ import torch
+ import numpy as np
+ from PIL import Image
+
+ REDUCTION = "pca"
+ if REDUCTION == "umap":
+     from umap import UMAP
+ elif REDUCTION == "tsne":
+     from sklearn.manifold import TSNE
+ elif REDUCTION == "pca":
+     from sklearn.decomposition import PCA
+
+
+ def symlog(x):
+     return torch.sign(x) * torch.log(torch.abs(x) + 1)
+
+
+ def preprocess_masks_features(masks, features):
+     # Get shapes right
+     B, M, H, W = masks.shape
+     Bf, F, Hf, Wf = features.shape
+     masks = masks.reshape(B, M, 1, H * W)
+     # # the following assertions should hold; they are skipped for speed
+     # assert H == Hf and W == Wf and B == Bf
+     # assert masks.dtype == torch.bool
+     # mask_areas = masks.sum(dim=3)  # B, M, 1
+     # assert (mask_areas > 0).all(), "you shouldn't have empty masks"
+     features = features.reshape(B, 1, F, H * W)
+     # output shapes
+     # features: B, 1, F, H*W
+     # masks: B, M, 1, H*W
+     return masks, features, M, B, H, W, F
+
+
+ def get_row_col(H, W, device):
+     # get position of pixels in [0, 1]
+     row = torch.linspace(0, 1, H, device=device)
+     col = torch.linspace(0, 1, W, device=device)
+     return row, col
+
+
+ def get_current_git_commit():
+     try:
+         # Run the git command to get the current commit hash
+         commit_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()
+         # Decode from bytes to a string
+         return commit_hash.decode("utf-8")
+     except subprocess.CalledProcessError:
+         # Handle the case where the command fails (e.g., not a Git repository)
+         print("An error occurred while trying to retrieve the git commit hash.")
+         return None
+
+
+ def clean_dir(dirname):
+     """Removes all directories in dirname that don't have a done.txt file"""
+     dstdir = Path(dirname)
+     dstdir.mkdir(exist_ok=True, parents=True)
+     for f in dstdir.iterdir():
+         # if the directory doesn't have a done.txt file, remove it
+         if f.is_dir() and not (f / "done.txt").exists():
+             shutil.rmtree(f)
+
+
+ def save_tensor_as_image(tensor, dstfile, global_step):
+     dstfile = Path(dstfile)
+     dstfile = (dstfile.parent / (dstfile.stem + "_" + str(global_step))).with_suffix(
+         ".jpg"
+     )
+     save(tensor, str(dstfile))
+
+
+ def minmaxnorm(x):
+     return (x - x.min()) / (x.max() - x.min())
+
+
+ def save(tensor, name, channel_offset=0):
+     tensor = to_img(tensor, channel_offset=channel_offset)
+     Image.fromarray(tensor).save(name)
+
+
+ def to_img(tensor, channel_offset=0):
+     tensor = minmaxnorm(tensor)
+     tensor = (tensor * 255).to(torch.uint8)
+     C, H, W = tensor.shape
+     if C == 1:
+         tensor = tensor[0]
+     elif C == 2:
+         # two channels: show them as red and blue
+         tensor = torch.stack([tensor[0], torch.zeros_like(tensor[0]), tensor[1]], dim=0)
+         tensor = tensor.permute(1, 2, 0)
+     elif C >= 3:
+         tensor = tensor[channel_offset : channel_offset + 3]
+         tensor = tensor.permute(1, 2, 0)
+     tensor = tensor.cpu().numpy()
+     return tensor
+
+
+ def log_input_output(
+     name,
+     x,
+     y_hat,
+     global_step,
+     img_dstdir,
+     out_dstdir,
+     reduce_dim=True,
+     reduction=REDUCTION,
+     resample_size=20000,
+ ):
+     y_hat = y_hat.reshape(
+         y_hat.shape[0], y_hat.shape[2], y_hat.shape[3], y_hat.shape[4]
+     )  # B, 1, F, H, W -> B, F, H, W
+     if reduce_dim and y_hat.shape[1] >= 3:
+         reducer = (
+             UMAP(n_components=3)
+             if reduction == "umap"
+             else TSNE(n_components=3)
+             if reduction == "tsne"
+             else PCA(n_components=3)
+             if reduction == "pca"
+             else None
+         )
+         np_y_hat = y_hat.detach().cpu().permute(1, 0, 2, 3).numpy()  # F, B, H, W
+         np_y_hat = np_y_hat.reshape(np_y_hat.shape[0], -1)  # F, BHW
+         np_y_hat = np_y_hat.T  # BHW, F
+         sampled_pixels = np_y_hat[:: np_y_hat.shape[0] // resample_size]
+         print("dim reduction fit..." + " " * 30, end="\r")
+         reducer = reducer.fit(sampled_pixels)
+         print("dim reduction transform..." + " " * 30, end="\r")
+         reducer.transform(np_y_hat[:10])  # warm-up call (numba-compiles UMAP's transform)
+         np_y_hat = reducer.transform(np_y_hat)  # BHW, 3
+         # revert back to original shape
+         y_hat2 = (
+             torch.from_numpy(
+                 np_y_hat.T.reshape(3, y_hat.shape[0], y_hat.shape[2], y_hat.shape[3])
+             )
+             .to(y_hat.device)
+             .permute(1, 0, 2, 3)
+         )
+         print("done" + " " * 30, end="\r")
+     else:
+         y_hat2 = y_hat
+
+     for i in range(min(len(x), 8)):
+         save_tensor_as_image(
+             x[i],
+             img_dstdir / f"input_{name}_{str(i).zfill(2)}",
+             global_step=global_step,
+         )
+         for c in range(y_hat.shape[1]):
+             save_tensor_as_image(
+                 y_hat[i, c : c + 1],
+                 out_dstdir / f"pred_channel_{name}_{str(i).zfill(2)}_{c}",
+                 global_step=global_step,
+             )
+         # log color image
+         assert len(y_hat2.shape) == 4, "should be B, F, H, W"
+         if reduce_dim:
+             save_tensor_as_image(
+                 y_hat2[i][:3],
+                 out_dstdir / f"pred_reduced_{name}_{str(i).zfill(2)}",
+                 global_step=global_step,
+             )
+         save_tensor_as_image(
+             y_hat[i][:3],
+             out_dstdir / f"pred_colorchs_{name}_{str(i).zfill(2)}",
+             global_step=global_step,
+         )
+
+
+ def check_for_nan(loss, model, batch):
+     try:
+         assert not torch.isnan(loss)
+     except Exception as e:
+         # print things useful for debugging
+         # does the batch contain nan?
+         print("img batch contains nan?", torch.isnan(batch[0]).any())
+         print("mask batch contains nan?", torch.isnan(batch[1]).any())
+         # do the model weights contain nan?
+         for name, param in model.named_parameters():
+             if torch.isnan(param).any():
+                 print(name, "contains nan")
+         # does the output contain nan?
+         print("output contains nan?", torch.isnan(model(batch[0])).any())
+         # now raise the error
+         raise e
+
+
+ def calculate_iou(pred, label):
+     intersection = ((label == 1) & (pred == 1)).sum()
+     union = ((label == 1) | (pred == 1)).sum()
+     if not union:
+         return 0
+     return intersection.item() / union.item()
+
+
+ def load_from_ckpt(net, ckpt_path, strict=True):
+     """Load network weights from a checkpoint file, if it exists."""
+     if ckpt_path and Path(ckpt_path).exists():
+         ckpt = torch.load(ckpt_path, map_location="cpu")
+         if "MODEL_STATE" in ckpt:
+             ckpt = ckpt["MODEL_STATE"]
+         elif "state_dict" in ckpt:
+             ckpt = ckpt["state_dict"]
+         net.load_state_dict(ckpt, strict=strict)
+         print("Loaded checkpoint from", ckpt_path)
+     return net
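
A quick worked example of calculate_iou (not part of the commit; importing utils also pulls in sklearn's PCA): the two masks below share 1 pixel and cover 3 in total, so the IoU is 1/3:

import torch
from utils import calculate_iou

pred = torch.tensor([1, 1, 0, 0])
label = torch.tensor([1, 0, 1, 0])
print(calculate_iou(pred, label))  # 0.333...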