DmitrMakeev committed
Commit: 98c5805
Parent(s): 2e9004e
Upload 9 files
Files changed:
- models/__init__.py  +0  -0
- models/anchor_gen.py  +107  -0
- models/basic.py  +504  -0
- models/clusterkit.py  +291  -0
- models/loss.py  +222  -0
- models/model.py  +196  -0
- models/network.py  +352  -0
- models/position_encoding.py  +86  -0
- models/transformer2d.py  +229  -0
models/__init__.py
ADDED
File without changes
models/anchor_gen.py
ADDED
@@ -0,0 +1,107 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from models import basic, clusterkit
import pdb

class AnchorAnalysis:
    def __init__(self, mode, colorLabeler):
        ## anchor generating mode: 1.random; 2.clustering
        self.mode = mode
        self.colorLabeler = colorLabeler

    def _detect_correlation(self, data_tensors, color_probs, hint_masks, thres=0.1):
        N,C,H,W = data_tensors.shape
        ## (N,C,HW)
        data_vecs = data_tensors.flatten(2)
        prob_vecs = color_probs.flatten(2)
        mask_vecs = hint_masks.flatten(2)
        #anchor_data = torch.masked_select(data_vecs, mask_vecs.bool()).view(N,C,-1)
        #anchor_prob = torch.masked_select(prob_vecs, mask_vecs.bool()).view(N,313,-1)
        #_,_,K = anchor_data.shape
        anchor_mask = torch.matmul(mask_vecs.permute(0,2,1), mask_vecs)
        cosine_sim = True
        ## non-similarity matrix
        if cosine_sim:
            norm_data = F.normalize(data_vecs, p=2, dim=1)
            ## (N,HW,HW) = (N,HW,C) X (N,C,HW)
            corr_matrix = torch.matmul(norm_data.permute(0,2,1), norm_data)
            ## remapping: [-1.0,1.0] to [0.0,1.0], and convert into dis-similarity
            dist_matrix = 1.0 - 0.5*(corr_matrix + 1.0)
        else:
            ## (N,HW,HW) = (N,HW,C) X (N,C,HW)
            XtX = torch.matmul(data_vecs.permute(0,2,1), data_vecs)
            diag_vec = torch.diagonal(XtX, dim1=-2, dim2=-1)
            A = diag_vec.unsqueeze(1).repeat(1,H*W,1)
            At = diag_vec.unsqueeze(2).repeat(1,1,H*W)
            dist_matrix = A - 2*XtX + At
        #dist_matrix = dist_matrix + 1e7*torch.eye(K).to(data_tensors.device).repeat(N,1,1)
        ## for debug use
        K = 8
        anchor_adj_matrix = torch.masked_select(dist_matrix, anchor_mask.bool()).view(N,K,K)
        ## detect connected nodes
        adj_matrix = torch.where((dist_matrix < thres) & (anchor_mask > 0), torch.ones_like(dist_matrix), torch.zeros_like(dist_matrix))
        adj_matrix = torch.matmul(adj_matrix, adj_matrix)
        adj_matrix = adj_matrix / (1e-7+adj_matrix)
        ## merge nodes
        ## (N,K,C) = (N,K,K) X (N,K,C)
        anchor_prob = torch.matmul(adj_matrix, prob_vecs.permute(0,2,1)) / torch.sum(adj_matrix, dim=2, keepdim=True)
        updated_prob_vecs = anchor_prob.permute(0,2,1) * mask_vecs + (1-mask_vecs) * prob_vecs
        color_probs = updated_prob_vecs.view(N,313,H,W)
        return color_probs, anchor_adj_matrix

    def _sample_anchor_colors(self, pred_prob, hint_mask, T=0):
        N,C,H,W = pred_prob.shape
        topk = 10
        assert T < topk
        sorted_probs, batch_indexs = torch.sort(pred_prob, dim=1, descending=True)
        ## (N,topk,H,W,1)
        topk_probs = torch.softmax(sorted_probs[:,:topk,:,:], dim=1).unsqueeze(4)
        topk_indexs = batch_indexs[:,:topk,:,:]
        topk_ABs = torch.stack([self.colorLabeler.q_to_ab.index_select(0, q_i.flatten()).reshape(topk,H,W,2)
                                for q_i in topk_indexs])
        ## (N,topk,H,W,2)
        topk_ABs = topk_ABs / 110.0
        ## choose the most distinctive 3 colors for each anchor
        if T == 0:
            sampled_ABs = topk_ABs[:,0,:,:,:]
        elif T == 1:
            sampled_AB0 = topk_ABs[:,[0],:,:,:]
            internal_diff = torch.norm(topk_ABs-sampled_AB0, p=2, dim=4, keepdim=True)
            _, batch_indexs = torch.sort(internal_diff, dim=1, descending=True)
            ## (N,1,H,W,2)
            selected_index = batch_indexs[:,[0],:,:,:].expand([-1,-1,-1,-1,2])
            sampled_ABs = torch.gather(topk_ABs, 1, selected_index)
            sampled_ABs = sampled_ABs.squeeze(1)
        else:
            sampled_AB0 = topk_ABs[:,[0],:,:,:]
            internal_diff = torch.norm(topk_ABs-sampled_AB0, p=2, dim=4, keepdim=True)
            _, batch_indexs = torch.sort(internal_diff, dim=1, descending=True)
            selected_index = batch_indexs[:,[0],:,:,:].expand([-1,-1,-1,-1,2])
            sampled_AB1 = torch.gather(topk_ABs, 1, selected_index)
            internal_diff2 = torch.norm(topk_ABs-sampled_AB1, p=2, dim=4, keepdim=True)
            _, batch_indexs = torch.sort(internal_diff+internal_diff2, dim=1, descending=True)
            ## (N,1,H,W,2)
            selected_index = batch_indexs[:,[T-2],:,:,:].expand([-1,-1,-1,-1,2])
            sampled_ABs = torch.gather(topk_ABs, 1, selected_index)
            sampled_ABs = sampled_ABs.squeeze(1)

        return sampled_ABs.permute(0,3,1,2)

    def __call__(self, data_tensors, n_anchors, spixel_sizes, use_sklearn_kmeans=False):
        N,C,H,W = data_tensors.shape
        if self.mode == 'clustering':
            ## clusters map: (N,K,H,W)
            cluster_mask = clusterkit.batch_kmeans_pytorch(data_tensors, n_anchors, 'euclidean', use_sklearn_kmeans)
            #noises = torch.rand(N,1,H,W).to(cluster_mask.device)
            perturb_factors = spixel_sizes
            cluster_prob = cluster_mask + perturb_factors * 0.01
            hint_mask_layers = F.one_hot(torch.argmax(cluster_prob.flatten(2), dim=-1), num_classes=H*W).float()
            hint_mask = torch.sum(hint_mask_layers, dim=1, keepdim=True).view(N,1,H,W)
        else:
            #print('----------hello, random!')
            cluster_mask = torch.zeros(N,n_anchors,H,W).to(data_tensors.device)
            binary_mask = basic.get_random_mask(N, H, W, minNum=n_anchors, maxNum=n_anchors)
            hint_mask = torch.from_numpy(binary_mask).to(data_tensors.device)
        return hint_mask, cluster_mask
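
A minimal usage sketch of the anchor generator above, in its 'random' mode where the colorLabeler is not used. This assumes the repository root (including its utils package, which is imported by models/basic.py but is not part of this commit) is on PYTHONPATH; the toy tensor shapes are hypothetical.

import torch
from models.anchor_gen import AnchorAnalysis

feats = torch.rand(1, 16, 8, 8)          # (N,C,H,W) toy feature map
spixel_sizes = torch.ones(1, 1, 8, 8)    # per-pixel superpixel-size factor

anchor_gen = AnchorAnalysis(mode='random', colorLabeler=None)  # colorLabeler unused in 'random' mode
hint_mask, cluster_mask = anchor_gen(feats, n_anchors=8, spixel_sizes=spixel_sizes)
print(hint_mask.shape, cluster_mask.shape)  # (1,1,8,8) and (1,8,8,8)

The 'clustering' mode additionally runs batch k-means from models/clusterkit.py and requires a real colorLabeler (e.g. basic.ColorLabel) for _sample_anchor_colors.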
models/basic.py
ADDED
@@ -0,0 +1,504 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as spectral_norm
from torch.autograd import Function
from utils import util, cielab
import cv2, math, random

def tensor2array(tensors):
    arrays = tensors.detach().to("cpu").numpy()
    return np.transpose(arrays, (0, 2, 3, 1))


def rgb2gray(color_batch):
    #! gray = 0.299*R+0.587*G+0.114*B
    gray_batch = color_batch[:, 0, ...] * 0.299 + color_batch[:, 1, ...] * 0.587 + color_batch[:, 2, ...] * 0.114
    gray_batch = gray_batch.unsqueeze_(1)
    return gray_batch


def getParamsAmount(model):
    params = list(model.parameters())
    count = 0
    for var in params:
        l = 1
        for j in var.size():
            l *= j
        count += l
    return count


def checkAverageGradient(model):
    meanGrad, cnt = 0.0, 0
    for name, parms in model.named_parameters():
        if parms.requires_grad:
            meanGrad += torch.mean(torch.abs(parms.grad))
            cnt += 1
    return meanGrad.item() / cnt


def get_random_mask(N, H, W, minNum, maxNum):
    binary_maps = np.zeros((N, H*W), np.float32)
    for i in range(N):
        locs = random.sample(range(0, H*W), random.randint(minNum,maxNum))
        binary_maps[i, locs] = 1
    return binary_maps.reshape(N,1,H,W)


def io_user_control(hint_mask, spix_colors, output=True):
    cache_dir = '/apdcephfs/private_richardxia'
    if output:
        print('--- data saving')
        mask_imgs = tensor2array(hint_mask) * 2.0 - 1.0
        util.save_images_from_batch(mask_imgs, cache_dir, ['mask.png'], -1)
        fake_gray = torch.zeros_like(spix_colors[:,[0],:,:])
        spix_labs = torch.cat((fake_gray,spix_colors), dim=1)
        spix_imgs = tensor2array(spix_labs)
        util.save_normLabs_from_batch(spix_imgs, cache_dir, ['color.png'], -1)
        return hint_mask, spix_colors
    else:
        print('--- data loading')
        mask_img = cv2.imread(cache_dir+'/mask.png', cv2.IMREAD_GRAYSCALE)
        mask_img = np.expand_dims(mask_img, axis=2) / 255.
        hint_mask = torch.from_numpy(mask_img.transpose((2, 0, 1)))
        hint_mask = hint_mask.unsqueeze(0).cuda()
        bgr_img = cv2.imread(cache_dir+'/color.png', cv2.IMREAD_COLOR)
        rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
        rgb_img = np.array(rgb_img / 255., np.float32)
        lab_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2LAB)
        lab_img = torch.from_numpy(lab_img.transpose((2, 0, 1)))
        ab_chans = lab_img[1:3,:,:] / 110.
        spix_colors = ab_chans.unsqueeze(0).cuda()
        return hint_mask.float(), spix_colors.float()


class Quantize(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        y = x.round()
        return y

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        inputX = ctx.saved_tensors
        return grad_output


def mark_color_hints(input_grays, target_ABs, gate_maps, kernel_size=3, base_ABs=None):
    ## to highlight the seeds with 1-pixel margin
    binary_map = torch.where(gate_maps>0.7, torch.ones_like(gate_maps), torch.zeros_like(gate_maps))
    center_mask = dilate_seeds(binary_map, kernel_size=kernel_size)
    margin_mask = dilate_seeds(binary_map, kernel_size=kernel_size+2) - center_mask
    ## drop colors
    dilated_seeds = dilate_seeds(gate_maps, kernel_size=kernel_size+2)
    marked_grays = torch.where(margin_mask > 1e-5, torch.ones_like(gate_maps), input_grays)
    if base_ABs is None:
        marked_ABs = torch.where(center_mask < 1e-5, torch.zeros_like(target_ABs), target_ABs)
    else:
        marked_ABs = torch.where(margin_mask > 1e-5, torch.zeros_like(base_ABs), base_ABs)
        marked_ABs = torch.where(center_mask > 1e-5, target_ABs, marked_ABs)
    return torch.cat((marked_grays,marked_ABs), dim=1)

def dilate_seeds(gate_maps, kernel_size=3):
    N,C,H,W = gate_maps.shape
    input_unf = F.unfold(gate_maps, kernel_size, padding=kernel_size//2)
    #! Notice: differentiable? just like max pooling?
    dilated_seeds, _ = torch.max(input_unf, dim=1, keepdim=True)
    output = F.fold(dilated_seeds, output_size=(H,W), kernel_size=1)
    #print('-------', input_unf.shape)
    return output


class RebalanceLoss(Function):
    @staticmethod
    def forward(ctx, data_input, weights):
        ctx.save_for_backward(weights)
        return data_input.clone()

    @staticmethod
    def backward(ctx, grad_output):
        weights, = ctx.saved_tensors
        # reweigh gradient pixelwise so that rare colors get a chance to
        # contribute
        grad_input = grad_output * weights
        # second return value is None since we are not interested in the
        # gradient with respect to the weights
        return grad_input, None


class GetClassWeights:
    def __init__(self, cielab, lambda_=0.5, device='cuda'):
        prior = torch.from_numpy(cielab.gamut.prior).cuda()
        uniform = torch.zeros_like(prior)
        uniform[prior > 0] = 1 / (prior > 0).sum().type_as(uniform)
        self.weights = 1 / ((1 - lambda_) * prior + lambda_ * uniform)
        self.weights /= torch.sum(prior * self.weights)

    def __call__(self, ab_actual):
        return self.weights[ab_actual.argmax(dim=1, keepdim=True)]


class ColorLabel:
    def __init__(self, lambda_=0.5, device='cuda'):
        self.cielab = cielab.CIELAB()
        self.q_to_ab = torch.from_numpy(self.cielab.q_to_ab).to(device)
        prior = torch.from_numpy(self.cielab.gamut.prior).to(device)
        uniform = torch.zeros_like(prior)
        uniform[prior>0] = 1 / (prior>0).sum().type_as(uniform)
        self.weights = 1 / ((1-lambda_) * prior + lambda_ * uniform)
        self.weights /= torch.sum(prior * self.weights)

    def visualize_label(self, step=3):
        height, width = 200, 313*step
        label_lab = np.ones((height,width,3), np.float32)
        for x in range(313):
            ab = self.cielab.q_to_ab[x,:]
            label_lab[:,step*x:step*(x+1),1:] = ab / 110.
        label_lab[:,:,0] = np.zeros((height,width), np.float32)
        return label_lab

    @staticmethod
    def _gauss_eval(x, mu, sigma):
        norm = 1 / (2 * math.pi * sigma)
        return norm * torch.exp(-torch.sum((x - mu)**2, dim=0) / (2 * sigma**2))

    def get_classweights(self, batch_gt_indx):
        #return self.weights[batch_gt_q.argmax(dim=1, keepdim=True)]
        return self.weights[batch_gt_indx]

    def encode_ab2ind(self, batch_ab, neighbours=5, sigma=5.0):
        batch_ab = batch_ab * 110.
        n, _, h, w = batch_ab.shape
        m = n * h * w
        # find nearest neighbours
        ab_ = batch_ab.permute(1, 0, 2, 3).reshape(2, -1)  # (2, n*h*w)
        cdist = torch.cdist(self.q_to_ab, ab_.t())
        nns = cdist.argsort(dim=0)[:neighbours, :]
        # gaussian weighting
        nn_gauss = batch_ab.new_zeros(neighbours, m)
        for i in range(neighbours):
            nn_gauss[i, :] = self._gauss_eval(self.q_to_ab[nns[i, :], :].t(), ab_, sigma)
        nn_gauss /= nn_gauss.sum(dim=0, keepdim=True)
        # expand
        bins = self.cielab.gamut.EXPECTED_SIZE
        q = batch_ab.new_zeros(bins, m)
        q[nns, torch.arange(m).repeat(neighbours, 1)] = nn_gauss
        return q.reshape(bins, n, h, w).permute(1, 0, 2, 3)

    def decode_ind2ab(self, batch_q, T=0.38):
        _, _, h, w = batch_q.shape
        batch_q = F.softmax(batch_q, dim=1)
        if T%1 == 0:
            # take the T-st probable index
            sorted_probs, batch_indexs = torch.sort(batch_q, dim=1, descending=True)
            #print('checking [index]', batch_indexs[:,0:5,5,5])
            #print('checking [probs]', sorted_probs[:,0:5,5,5])
            batch_indexs = batch_indexs[:,T:T+1,:,:]
            #batch_indexs = torch.where(sorted_probs[:,T:T+1,:,:] > 0.25, batch_indexs[:,T:T+1,:,:], batch_indexs[:,0:1,:,:])
            ab = torch.stack([
                self.q_to_ab.index_select(0, q_i.flatten()).reshape(h,w,2).permute(2,0,1)
                for q_i in batch_indexs])
        else:
            batch_q = torch.exp(batch_q / T)
            batch_q /= batch_q.sum(dim=1, keepdim=True)
            a = torch.tensordot(batch_q, self.q_to_ab[:,0], dims=((1,), (0,)))
            a = a.unsqueeze(dim=1)
            b = torch.tensordot(batch_q, self.q_to_ab[:,1], dims=((1,), (0,)))
            b = b.unsqueeze(dim=1)
            ab = torch.cat((a, b), dim=1)
        ab = ab / 110.
        return ab.type(batch_q.dtype)


def init_spixel_grid(img_height, img_width, spixel_size=16):
    # get spixel id for the final assignment
    n_spixl_h = int(np.floor(img_height/spixel_size))
    n_spixl_w = int(np.floor(img_width/spixel_size))
    spixel_height = int(img_height / (1. * n_spixl_h))
    spixel_width = int(img_width / (1. * n_spixl_w))
    spix_values = np.int32(np.arange(0, n_spixl_w * n_spixl_h).reshape((n_spixl_h, n_spixl_w)))

    def shift9pos(input, h_shift_unit=1, w_shift_unit=1):
        # input should be padding as (c, 1+ height+1, 1+width+1)
        input_pd = np.pad(input, ((h_shift_unit, h_shift_unit), (w_shift_unit, w_shift_unit)), mode='edge')
        input_pd = np.expand_dims(input_pd, axis=0)
        # assign to ...
        top = input_pd[:, :-2 * h_shift_unit, w_shift_unit:-w_shift_unit]
        bottom = input_pd[:, 2 * h_shift_unit:, w_shift_unit:-w_shift_unit]
        left = input_pd[:, h_shift_unit:-h_shift_unit, :-2 * w_shift_unit]
        right = input_pd[:, h_shift_unit:-h_shift_unit, 2 * w_shift_unit:]
        center = input_pd[:,h_shift_unit:-h_shift_unit,w_shift_unit:-w_shift_unit]
        bottom_right = input_pd[:, 2 * h_shift_unit:, 2 * w_shift_unit:]
        bottom_left = input_pd[:, 2 * h_shift_unit:, :-2 * w_shift_unit]
        top_right = input_pd[:, :-2 * h_shift_unit, 2 * w_shift_unit:]
        top_left = input_pd[:, :-2 * h_shift_unit, :-2 * w_shift_unit]
        shift_tensor = np.concatenate([top_left, top, top_right,
                                       left, center, right,
                                       bottom_left, bottom, bottom_right], axis=0)
        return shift_tensor

    spix_idx_tensor_ = shift9pos(spix_values)
    spix_idx_tensor = np.repeat(
        np.repeat(spix_idx_tensor_, spixel_height, axis=1), spixel_width, axis=2)
    spixel_id_tensor = torch.from_numpy(spix_idx_tensor).type(torch.float)

    #! pixel coord feature maps
    all_h_coords = np.arange(0, img_height, 1)
    all_w_coords = np.arange(0, img_width, 1)
    curr_pxl_coord = np.array(np.meshgrid(all_h_coords, all_w_coords, indexing='ij'))
    coord_feat_tensor = np.concatenate([curr_pxl_coord[1:2, :, :], curr_pxl_coord[:1, :, :]])
    coord_feat_tensor = torch.from_numpy(coord_feat_tensor).type(torch.float)

    return spixel_id_tensor, coord_feat_tensor


def split_spixels(assign_map, spixel_ids):
    N,C,H,W = assign_map.shape
    spixel_id_map = spixel_ids.expand(N,-1,-1,-1)
    assig_max,_ = torch.max(assign_map, dim=1, keepdim=True)
    assignment_ = torch.where(assign_map == assig_max, torch.ones(assign_map.shape).cuda(), torch.zeros(assign_map.shape).cuda())
    ## winner take all
    new_spixl_map_ = spixel_id_map * assignment_
    new_spixl_map = torch.sum(new_spixl_map_,dim=1,keepdim=True).type(torch.int)
    return new_spixl_map


def poolfeat(input, prob, sp_h=2, sp_w=2, need_entry_prob=False):
    def feat_prob_sum(feat_sum, prob_sum, shift_feat):
        feat_sum += shift_feat[:, :-1, :, :]
        prob_sum += shift_feat[:, -1:, :, :]
        return feat_sum, prob_sum

    b, _, h, w = input.shape
    h_shift_unit = 1
    w_shift_unit = 1
    p2d = (w_shift_unit, w_shift_unit, h_shift_unit, h_shift_unit)
    feat_ = torch.cat([input, torch.ones([b, 1, h, w], device=input.device)], dim=1)  # b* (n+1) *h*w
    prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 0, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w))  # b * (n+1) * h * w
    send_to_top_left = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, 2 * h_shift_unit:, 2 * w_shift_unit:]
    feat_sum = send_to_top_left[:, :-1, :, :].clone()
    prob_sum = send_to_top_left[:, -1:, :, :].clone()

    prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 1, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w))  # b * (n+1) * h * w
    top = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, 2 * h_shift_unit:, w_shift_unit:-w_shift_unit]
    feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, top)

    prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 2, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w))  # b * (n+1) * h * w
    top_right = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, 2 * h_shift_unit:, :-2 * w_shift_unit]
    feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, top_right)

    prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 3, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w))  # b * (n+1) * h * w
    left = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, h_shift_unit:-h_shift_unit, 2 * w_shift_unit:]
    feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, left)

    prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 4, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w))  # b * (n+1) * h * w
    center = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, h_shift_unit:-h_shift_unit, w_shift_unit:-w_shift_unit]
    feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, center)

    prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 5, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w))  # b * (n+1) * h * w
    right = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, h_shift_unit:-h_shift_unit, :-2 * w_shift_unit]
    feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, right)

    prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 6, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w))  # b * (n+1) * h * w
    bottom_left = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, :-2 * h_shift_unit, 2 * w_shift_unit:]
    feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, bottom_left)

    prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 7, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w))  # b * (n+1) * h * w
    bottom = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, :-2 * h_shift_unit, w_shift_unit:-w_shift_unit]
    feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, bottom)

    prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 8, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w))  # b * (n+1) * h * w
    bottom_right = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, :-2 * h_shift_unit, :-2 * w_shift_unit]
    feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, bottom_right)
    pooled_feat = feat_sum / (prob_sum + 1e-8)
    if need_entry_prob:
        return pooled_feat, prob_sum
    return pooled_feat


def get_spixel_size(affinity_map, sp_h=2, sp_w=2, elem_thres=25):
    N,C,H,W = affinity_map.shape
    device = affinity_map.device
    assign_max,_ = torch.max(affinity_map, dim=1, keepdim=True)
    assign_map = torch.where(affinity_map==assign_max, torch.ones(affinity_map.shape, device=device), torch.zeros(affinity_map.shape, device=device))
    ## one_map = (N,1,H,W)
    _, elem_num_maps = poolfeat(torch.ones(assign_max.shape, device=device), assign_map, sp_h, sp_w, True)
    #all_one_map = torch.ones(elem_num_maps.shape).cuda()
    #empty_mask = torch.where(elem_num_maps < elem_thres/256, all_one_map, 1-all_one_map)
    return elem_num_maps


def upfeat(input, prob, up_h=2, up_w=2):
    # input b*n*H*W downsampled
    # prob b*9*h*w
    b, c, h, w = input.shape

    h_shift = 1
    w_shift = 1

    p2d = (w_shift, w_shift, h_shift, h_shift)
    feat_pd = F.pad(input, p2d, mode='constant', value=0)

    gt_frm_top_left = F.interpolate(feat_pd[:, :, :-2 * h_shift, :-2 * w_shift], size=(h * up_h, w * up_w), mode='nearest')
    feat_sum = gt_frm_top_left * prob.narrow(1,0,1)

    top = F.interpolate(feat_pd[:, :, :-2 * h_shift, w_shift:-w_shift], size=(h * up_h, w * up_w), mode='nearest')
    feat_sum += top * prob.narrow(1, 1, 1)

    top_right = F.interpolate(feat_pd[:, :, :-2 * h_shift, 2 * w_shift:], size=(h * up_h, w * up_w), mode='nearest')
    feat_sum += top_right * prob.narrow(1,2,1)

    left = F.interpolate(feat_pd[:, :, h_shift:-w_shift, :-2 * w_shift], size=(h * up_h, w * up_w), mode='nearest')
    feat_sum += left * prob.narrow(1, 3, 1)

    center = F.interpolate(input, (h * up_h, w * up_w), mode='nearest')
    feat_sum += center * prob.narrow(1, 4, 1)

    right = F.interpolate(feat_pd[:, :, h_shift:-w_shift, 2 * w_shift:], size=(h * up_h, w * up_w), mode='nearest')
    feat_sum += right * prob.narrow(1, 5, 1)

    bottom_left = F.interpolate(feat_pd[:, :, 2 * h_shift:, :-2 * w_shift], size=(h * up_h, w * up_w), mode='nearest')
    feat_sum += bottom_left * prob.narrow(1, 6, 1)

    bottom = F.interpolate(feat_pd[:, :, 2 * h_shift:, w_shift:-w_shift], size=(h * up_h, w * up_w), mode='nearest')
    feat_sum += bottom * prob.narrow(1, 7, 1)

    bottom_right = F.interpolate(feat_pd[:, :, 2 * h_shift:, 2 * w_shift:], size=(h * up_h, w * up_w), mode='nearest')
    feat_sum += bottom_right * prob.narrow(1, 8, 1)

    return feat_sum


def suck_and_spread(self, base_maps, seg_layers):
    N,S,H,W = seg_layers.shape
    base_maps = base_maps.unsqueeze(1)
    seg_layers = seg_layers.unsqueeze(2)
    ## (N,S,C,1,1) = (N,1,C,H,W) * (N,S,1,H,W)
    mean_val_layers = (base_maps * seg_layers).sum(dim=(3,4), keepdim=True) / (1e-5 + seg_layers.sum(dim=(3,4), keepdim=True))
    ## normalized to be sum one
    weight_layers = seg_layers / (1e-5 + torch.sum(seg_layers, dim=1, keepdim=True))
    ## (N,S,C,H,W) = (N,S,C,1,1) * (N,S,1,H,W)
    recon_maps = mean_val_layers * weight_layers
    return recon_maps.sum(dim=1)


#! copied from Richard Zhang [SIGGRAPH2017]
# RGB grid points map to Lab range: L[0,100], a[-86.183, 98.233], b[-107.857, 94.478]
#------------------------------------------------------------------------------
def rgb2xyz(rgb):  # rgb from [0,1]
    # xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423],
    #                          [0.212671, 0.715160, 0.072169],
    #                          [0.019334, 0.119193, 0.950227]])
    mask = (rgb > .04045).type(torch.FloatTensor)
    if(rgb.is_cuda):
        mask = mask.cuda()
    rgb = (((rgb+.055)/1.055)**2.4)*mask + rgb/12.92*(1-mask)
    x = .412453*rgb[:,0,:,:]+.357580*rgb[:,1,:,:]+.180423*rgb[:,2,:,:]
    y = .212671*rgb[:,0,:,:]+.715160*rgb[:,1,:,:]+.072169*rgb[:,2,:,:]
    z = .019334*rgb[:,0,:,:]+.119193*rgb[:,1,:,:]+.950227*rgb[:,2,:,:]
    out = torch.cat((x[:,None,:,:],y[:,None,:,:],z[:,None,:,:]),dim=1)
    return out

def xyz2rgb(xyz):
    # array([[ 3.24048134, -1.53715152, -0.49853633],
    #        [-0.96925495,  1.87599   ,  0.04155593],
    #        [ 0.05564664, -0.20404134,  1.05731107]])
    r = 3.24048134*xyz[:,0,:,:]-1.53715152*xyz[:,1,:,:]-0.49853633*xyz[:,2,:,:]
    g = -0.96925495*xyz[:,0,:,:]+1.87599*xyz[:,1,:,:]+.04155593*xyz[:,2,:,:]
    b = .05564664*xyz[:,0,:,:]-.20404134*xyz[:,1,:,:]+1.05731107*xyz[:,2,:,:]
    rgb = torch.cat((r[:,None,:,:],g[:,None,:,:],b[:,None,:,:]),dim=1)
    #! sometimes reaches a small negative number, which causes NaNs
    rgb = torch.max(rgb,torch.zeros_like(rgb))
    mask = (rgb > .0031308).type(torch.FloatTensor)
    if(rgb.is_cuda):
        mask = mask.cuda()
    rgb = (1.055*(rgb**(1./2.4)) - 0.055)*mask + 12.92*rgb*(1-mask)
    return rgb

def xyz2lab(xyz):
    # 0.95047, 1., 1.08883 # white
    sc = torch.Tensor((0.95047, 1., 1.08883))[None,:,None,None]
    if(xyz.is_cuda):
        sc = sc.cuda()
    xyz_scale = xyz/sc
    mask = (xyz_scale > .008856).type(torch.FloatTensor)
    if(xyz_scale.is_cuda):
        mask = mask.cuda()
    xyz_int = xyz_scale**(1/3.)*mask + (7.787*xyz_scale + 16./116.)*(1-mask)
    L = 116.*xyz_int[:,1,:,:]-16.
    a = 500.*(xyz_int[:,0,:,:]-xyz_int[:,1,:,:])
    b = 200.*(xyz_int[:,1,:,:]-xyz_int[:,2,:,:])
    out = torch.cat((L[:,None,:,:],a[:,None,:,:],b[:,None,:,:]),dim=1)
    return out

def lab2xyz(lab):
    y_int = (lab[:,0,:,:]+16.)/116.
    x_int = (lab[:,1,:,:]/500.) + y_int
    z_int = y_int - (lab[:,2,:,:]/200.)
    if(z_int.is_cuda):
        z_int = torch.max(torch.Tensor((0,)).cuda(), z_int)
    else:
        z_int = torch.max(torch.Tensor((0,)), z_int)
    out = torch.cat((x_int[:,None,:,:],y_int[:,None,:,:],z_int[:,None,:,:]),dim=1)
    mask = (out > .2068966).type(torch.FloatTensor)
    if(out.is_cuda):
        mask = mask.cuda()
    out = (out**3.)*mask + (out - 16./116.)/7.787*(1-mask)
    sc = torch.Tensor((0.95047, 1., 1.08883))[None,:,None,None]
    sc = sc.to(out.device)
    out = out*sc
    return out

def rgb2lab(rgb, l_mean=50, l_norm=50, ab_norm=110):
    #! input rgb: [0,1]
    #! output lab: [-1,1]
    lab = xyz2lab(rgb2xyz(rgb))
    l_rs = (lab[:,[0],:,:]-l_mean) / l_norm
    ab_rs = lab[:,1:,:,:] / ab_norm
    out = torch.cat((l_rs,ab_rs),dim=1)
    return out

def lab2rgb(lab_rs, l_mean=50, l_norm=50, ab_norm=110):
    #! input lab: [-1,1]
    #! output rgb: [0,1]
    l_ = lab_rs[:,[0],:,:] * l_norm + l_mean
    ab = lab_rs[:,1:,:,:] * ab_norm
    lab = torch.cat((l_,ab), dim=1)
    out = xyz2rgb(lab2xyz(lab))
    return out


if __name__ == '__main__':
    minL, minA, minB = 999., 999., 999.
    maxL, maxA, maxB = 0., 0., 0.
    for r in range(256):
        print('h', r)
        for g in range(256):
            for b in range(256):
                rgb = np.array([r,g,b], np.float32).reshape(1,1,-1) / 255.0
                #lab_img = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)
                rgb = torch.from_numpy(rgb.transpose((2, 0, 1)))
                rgb = rgb.reshape(1,3,1,1)
                lab = rgb2lab(rgb)
                lab[:,[0],:,:] = lab[:,[0],:,:] * 50 + 50
                lab[:,1:,:,:] = lab[:,1:,:,:] * 110
                lab = lab.squeeze()
                lab_float = lab.numpy()
                #print('zhang vs. cv2:', lab_float, lab_img.squeeze())
                minL = min(lab_float[0], minL)
                minA = min(lab_float[1], minA)
                minB = min(lab_float[2], minB)
                maxL = max(lab_float[0], maxL)
                maxA = max(lab_float[1], maxA)
                maxB = max(lab_float[2], maxB)
    print('L:', minL, maxL)
    print('A:', minA, maxA)
    print('B:', minB, maxB)
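
The Lab helpers at the end of basic.py are plain tensor ops, so a quick round-trip sanity check is easy to sketch. This assumes the repository's utils/cielab modules, which basic.py imports at the top but which are not part of this commit, are importable from the repo root.

import torch
from models import basic

rgb = torch.rand(2, 3, 64, 64)        # RGB batch in [0,1]
lab = basic.rgb2lab(rgb)              # L and ab normalized to roughly [-1,1]
rgb_back = basic.lab2rgb(lab)         # back to [0,1] (small clamping error possible)
print((rgb - rgb_back).abs().max().item())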
models/clusterkit.py
ADDED
@@ -0,0 +1,291 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import partial
import numpy as np
import torch
from tqdm import tqdm
import math, random
#from sklearn.cluster import KMeans, kmeans_plusplus, MeanShift, estimate_bandwidth


def tensor_kmeans_sklearn(data_vecs, n_clusters=7, metric='euclidean', need_layer_masks=False, max_iters=20):
    N,C,H,W = data_vecs.shape
    assert N == 1, 'only support single image tensor'
    ## (1,C,H,W) -> (HW,C)
    data_vecs = data_vecs.permute(0,2,3,1).view(-1,C)
    ## convert tensor to array
    data_vecs_np = data_vecs.squeeze().detach().to("cpu").numpy()
    km = KMeans(n_clusters=n_clusters, init='k-means++', n_init=10, max_iter=300)
    pred = km.fit_predict(data_vecs_np)
    cluster_ids_x = torch.from_numpy(km.labels_).to(data_vecs.device)
    id_maps = cluster_ids_x.reshape(1,1,H,W).long()
    if need_layer_masks:
        one_hot_labels = F.one_hot(id_maps.squeeze(1), num_classes=n_clusters).float()
        cluster_mask = one_hot_labels.permute(0,3,1,2)
        return cluster_mask
    return id_maps


def tensor_kmeans_pytorch(data_vecs, n_clusters=7, metric='euclidean', need_layer_masks=False, max_iters=20):
    N,C,H,W = data_vecs.shape
    assert N == 1, 'only support single image tensor'

    ## (1,C,H,W) -> (HW,C)
    data_vecs = data_vecs.permute(0,2,3,1).view(-1,C)
    ## cosine | euclidean
    #cluster_ids_x, cluster_centers = kmeans(X=data_vecs, num_clusters=n_clusters, distance=metric, device=data_vecs.device)
    cluster_ids_x, cluster_centers = kmeans(X=data_vecs, num_clusters=n_clusters, distance=metric,\
                                            tqdm_flag=False, iter_limit=max_iters, device=data_vecs.device)
    id_maps = cluster_ids_x.reshape(1,1,H,W)
    if need_layer_masks:
        one_hot_labels = F.one_hot(id_maps.squeeze(1), num_classes=n_clusters).float()
        cluster_mask = one_hot_labels.permute(0,3,1,2)
        return cluster_mask
    return id_maps


def batch_kmeans_pytorch(data_vecs, n_clusters=7, metric='euclidean', use_sklearn_kmeans=False):
    N,C,H,W = data_vecs.shape
    sample_list = []
    for idx in range(N):
        if use_sklearn_kmeans:
            cluster_mask = tensor_kmeans_sklearn(data_vecs[idx:idx+1,:,:,:], n_clusters, metric, True)
        else:
            cluster_mask = tensor_kmeans_pytorch(data_vecs[idx:idx+1,:,:,:], n_clusters, metric, True)
        sample_list.append(cluster_mask)
    return torch.cat(sample_list, dim=0)


def get_centroid_candidates(data_vecs, n_clusters=7, metric='euclidean', max_iters=20):
    N,C,H,W = data_vecs.shape
    data_vecs = data_vecs.permute(0,2,3,1).view(-1,C)
    cluster_ids_x, cluster_centers = kmeans(X=data_vecs, num_clusters=n_clusters, distance=metric,\
                                            tqdm_flag=False, iter_limit=max_iters, device=data_vecs.device)
    return cluster_centers


def find_distinctive_elements(data_tensor, n_clusters=7, topk=3, metric='euclidean'):
    N,C,H,W = data_tensor.shape
    centroid_list = []
    for idx in range(N):
        cluster_centers = get_centroid_candidates(data_tensor[idx:idx+1,:,:,:], n_clusters, metric)
        centroid_list.append(cluster_centers)

    batch_centroids = torch.stack(centroid_list, dim=0)
    data_vecs = data_tensor.flatten(2)
    ## distance matrix: (N,K,HW) = (N,K,C) x (N,C,HW)
    AtB = torch.matmul(batch_centroids, data_vecs)
    AtA = torch.matmul(batch_centroids, batch_centroids.permute(0,2,1))
    BtB = torch.matmul(data_vecs.permute(0,2,1), data_vecs)
    diag_A = torch.diagonal(AtA, dim1=-2, dim2=-1)
    diag_B = torch.diagonal(BtB, dim1=-2, dim2=-1)
    A2 = diag_A.unsqueeze(2).repeat(1,1,H*W)
    B2 = diag_B.unsqueeze(1).repeat(1,n_clusters,1)
    distance_map = A2 - 2*AtB + B2
    values, indices = distance_map.topk(topk, dim=2, largest=False, sorted=True)
    cluster_mask = torch.where(distance_map <= values[:,:,topk-1:], torch.ones_like(distance_map), torch.zeros_like(distance_map))
    cluster_mask = cluster_mask.view(N,n_clusters,H,W)
    return cluster_mask


##---------------------------------------------------------------------------------
'''
resource from github: https://github.com/subhadarship/kmeans_pytorch
'''
##---------------------------------------------------------------------------------

def initialize(X, num_clusters):
    """
    initialize cluster centers
    :param X: (torch.tensor) matrix
    :param num_clusters: (int) number of clusters
    :return: (np.array) initial state
    """
    np.random.seed(1)
    num_samples = len(X)
    indices = np.random.choice(num_samples, num_clusters, replace=False)
    initial_state = X[indices]
    return initial_state


def kmeans(
        X,
        num_clusters,
        distance='euclidean',
        cluster_centers=[],
        tol=1e-4,
        tqdm_flag=True,
        iter_limit=0,
        device=torch.device('cpu'),
        gamma_for_soft_dtw=0.001
):
    """
    perform kmeans
    :param X: (torch.tensor) matrix
    :param num_clusters: (int) number of clusters
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param tol: (float) threshold [default: 0.0001]
    :param device: (torch.device) device [default: cpu]
    :param tqdm_flag: Allows to turn logs on and off
    :param iter_limit: hard limit for max number of iterations
    :param gamma_for_soft_dtw: approaches to (hard) DTW as gamma -> 0
    :return: (torch.tensor, torch.tensor) cluster ids, cluster centers
    """
    if tqdm_flag:
        print(f'running k-means on {device}..')

    if distance == 'euclidean':
        pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
    elif distance == 'cosine':
        pairwise_distance_function = partial(pairwise_cosine, device=device)
    else:
        raise NotImplementedError

    # convert to float
    X = X.float()

    # transfer to device
    X = X.to(device)

    # initialize
    if type(cluster_centers) == list:  # ToDo: make this less annoyingly weird
        initial_state = initialize(X, num_clusters)
    else:
        if tqdm_flag:
            print('resuming')
        # find data point closest to the initial cluster center
        initial_state = cluster_centers
        dis = pairwise_distance_function(X, initial_state)
        choice_points = torch.argmin(dis, dim=0)
        initial_state = X[choice_points]
        initial_state = initial_state.to(device)

    iteration = 0
    if tqdm_flag:
        tqdm_meter = tqdm(desc='[running kmeans]')
    while True:

        dis = pairwise_distance_function(X, initial_state)

        choice_cluster = torch.argmin(dis, dim=1)

        initial_state_pre = initial_state.clone()

        for index in range(num_clusters):
            selected = torch.nonzero(choice_cluster == index).squeeze().to(device)

            selected = torch.index_select(X, 0, selected)

            # https://github.com/subhadarship/kmeans_pytorch/issues/16
            if selected.shape[0] == 0:
                selected = X[torch.randint(len(X), (1,))]

            initial_state[index] = selected.mean(dim=0)

        center_shift = torch.sum(
            torch.sqrt(
                torch.sum((initial_state - initial_state_pre) ** 2, dim=1)
            ))

        # increment iteration
        iteration = iteration + 1

        # update tqdm meter
        if tqdm_flag:
            tqdm_meter.set_postfix(
                iteration=f'{iteration}',
                center_shift=f'{center_shift ** 2:0.6f}',
                tol=f'{tol:0.6f}'
            )
            tqdm_meter.update()
        if center_shift ** 2 < tol:
            break
        if iter_limit != 0 and iteration >= iter_limit:
            #print('hello, there!')
            break

    return choice_cluster.to(device), initial_state.to(device)


def kmeans_predict(
        X,
        cluster_centers,
        distance='euclidean',
        device=torch.device('cpu'),
        gamma_for_soft_dtw=0.001,
        tqdm_flag=True
):
    """
    predict using cluster centers
    :param X: (torch.tensor) matrix
    :param cluster_centers: (torch.tensor) cluster centers
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param device: (torch.device) device [default: 'cpu']
    :param gamma_for_soft_dtw: approaches to (hard) DTW as gamma -> 0
    :return: (torch.tensor) cluster ids
    """
    if tqdm_flag:
        print(f'predicting on {device}..')

    if distance == 'euclidean':
        pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
    elif distance == 'cosine':
        pairwise_distance_function = partial(pairwise_cosine, device=device)
    elif distance == 'soft_dtw':
        sdtw = SoftDTW(use_cuda=device.type == 'cuda', gamma=gamma_for_soft_dtw)
        pairwise_distance_function = partial(pairwise_soft_dtw, sdtw=sdtw, device=device)
    else:
        raise NotImplementedError

    # convert to float
    X = X.float()

    # transfer to device
    X = X.to(device)

    dis = pairwise_distance_function(X, cluster_centers)
    choice_cluster = torch.argmin(dis, dim=1)

    return choice_cluster.cpu()


def pairwise_distance(data1, data2, device=torch.device('cpu'), tqdm_flag=True):
    if tqdm_flag:
        print(f'device is :{device}')

    # transfer to device
    data1, data2 = data1.to(device), data2.to(device)

    # N*1*M
    A = data1.unsqueeze(dim=1)

    # 1*N*M
    B = data2.unsqueeze(dim=0)

    dis = (A - B) ** 2.0
    # return N*N matrix for pairwise distance
    dis = dis.sum(dim=-1).squeeze()
    return dis


def pairwise_cosine(data1, data2, device=torch.device('cpu')):
    # transfer to device
    data1, data2 = data1.to(device), data2.to(device)

    # N*1*M
    A = data1.unsqueeze(dim=1)

    # 1*N*M
    B = data2.unsqueeze(dim=0)

    # normalize the points | [0.3, 0.4] -> [0.3/sqrt(0.09 + 0.16), 0.4/sqrt(0.09 + 0.16)] = [0.3/0.5, 0.4/0.5]
    A_normalized = A / A.norm(dim=-1, keepdim=True)
    B_normalized = B / B.norm(dim=-1, keepdim=True)

    cosine = A_normalized * B_normalized

    # return N*N matrix for pairwise distance
    cosine_dis = 1 - cosine.sum(dim=-1).squeeze()
    return cosine_dis
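
A small sketch of how the k-means helpers above are meant to be called, either per image via batch_kmeans_pytorch or directly on an (num_samples, dim) matrix via kmeans. Note that tensor_kmeans_sklearn relies on sklearn's KMeans, whose import is commented out at the top of the file, so only the PyTorch path is exercised here; the toy tensors are hypothetical.

import torch
from models import clusterkit

feats = torch.rand(2, 16, 32, 32)                              # (N,C,H,W) feature maps
masks = clusterkit.batch_kmeans_pytorch(feats, n_clusters=8)   # (N,8,H,W) one-hot cluster layers

points = feats[0].flatten(1).t()                               # (H*W, C) sample matrix
labels, centers = clusterkit.kmeans(points, num_clusters=8, tqdm_flag=False, iter_limit=20)
print(masks.shape, labels.shape, centers.shape)                # (2,8,32,32), (1024,), (8,16)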
models/loss.py
ADDED
@@ -0,0 +1,222 @@
1 |
+
from __future__ import division
|
2 |
+
import os, glob, shutil, math, random, json
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import torchvision
|
7 |
+
import basic
|
8 |
+
from utils import util
|
9 |
+
|
10 |
+
eps = 0.0000001
|
11 |
+
|
12 |
+
class SPixelLoss:
|
13 |
+
def __init__(self, psize=8, mpdist=False, gpu_no=0):
|
14 |
+
self.mpdist = mpdist
|
15 |
+
self.gpu_no = gpu_no
|
16 |
+
self.sp_size = psize
|
17 |
+
|
18 |
+
def __call__(self, data, epoch_no):
|
19 |
+
kernel_size = self.sp_size
|
20 |
+
#pos_weight = 0.003
|
21 |
+
prob = data['pred_prob']
|
22 |
+
labxy_feat = data['target_feat']
|
23 |
+
N,C,H,W = labxy_feat.shape
|
24 |
+
pooled_labxy = basic.poolfeat(labxy_feat, prob, kernel_size, kernel_size)
|
25 |
+
reconstr_feat = basic.upfeat(pooled_labxy, prob, kernel_size, kernel_size)
|
26 |
+
loss_map = reconstr_feat[:,:,:,:] - labxy_feat[:,:,:,:]
|
27 |
+
featLoss_idx = torch.norm(loss_map[:,:-2,:,:], p=2, dim=1).mean()
|
28 |
+
posLoss_idx = torch.norm(loss_map[:,-2:,:,:], p=2, dim=1).mean() / kernel_size
|
29 |
+
totalLoss_idx = 10*featLoss_idx + 0.003*posLoss_idx
|
30 |
+
return {'totalLoss':totalLoss_idx, 'featLoss':featLoss_idx, 'posLoss':posLoss_idx}
|
31 |
+
|
32 |
+
|
33 |
+
class AnchorColorProbLoss:
|
34 |
+
def __init__(self, hint2regress=False, enhanced=False, with_grad=False, mpdist=False, gpu_no=0):
|
35 |
+
self.mpdist = mpdist
|
36 |
+
self.gpu_no = gpu_no
|
37 |
+
self.hint2regress = hint2regress
|
38 |
+
self.enhanced = enhanced
|
39 |
+
self.with_grad = with_grad
|
40 |
+
self.rebalance_gradient = basic.RebalanceLoss.apply
|
41 |
+
self.entropy_loss = nn.CrossEntropyLoss(ignore_index=-1)
|
42 |
+
if self.enhanced:
|
43 |
+
self.VGGLoss = VGG19Loss(gpu_no=gpu_no, is_ddp=mpdist)
|
44 |
+
|
45 |
+
def _perceptual_loss(self, input_grays, input_colors, pred_colors):
|
46 |
+
input_RGBs = basic.lab2rgb(torch.cat([input_grays,input_colors], dim=1))
|
47 |
+
pred_RGBs = basic.lab2rgb(torch.cat([input_grays,pred_colors], dim=1))
|
48 |
+
## the output of "lab2rgb" just matches the input of "VGGLoss": [0,1]
|
49 |
+
return self.VGGLoss(input_RGBs, pred_RGBs)
|
50 |
+
|
51 |
+
def _laplace_gradient(self, pred_AB, target_AB):
|
52 |
+
N,C,H,W = pred_AB.shape
|
53 |
+
kernel = torch.tensor([[1, 1, 1], [1, -8, 1], [1, 1, 1]], device=pred_AB.get_device()).float()
|
54 |
+
kernel = kernel.view(1, 1, *kernel.size()).repeat(C,1,1,1)
|
55 |
+
grad_pred = F.conv2d(pred_AB, kernel, groups=C)
|
56 |
+
grad_trg = F.conv2d(target_AB, kernel, groups=C)
|
57 |
+
return l1_loss(grad_trg, grad_pred)
|
58 |
+
|
59 |
+
def __call__(self, data, epoch_no):
|
60 |
+
N,C,H,W = data['target_label'].shape
|
61 |
+
pal_probs = self.rebalance_gradient(data['pal_prob'], data['class_weight'])
|
62 |
+
#ref_probs = data['ref_prob']
|
63 |
+
pal_probs = pal_probs.permute(0,2,3,1).contiguous().view(N*H*W, -1)
|
64 |
+
gt_labels = data['target_label'].permute(0,2,3,1).contiguous().view(N*H*W, -1)
|
65 |
+
'''
|
66 |
+
igored_mask = data['empty_entries'].permute(0,2,3,1).contiguous().view(N*H*W, -1)
|
67 |
+
gt_labels[igored_mask] = -1
|
68 |
+
gt_labels = gt_probs.squeeze()
|
69 |
+
'''
|
70 |
+
palLoss_idx = self.entropy_loss(pal_probs, gt_labels.squeeze(dim=1))
|
71 |
+
if self.hint2regress:
|
72 |
+
ref_probs = data['ref_prob']
|
73 |
+
refLoss_idx = 50 * l2_loss(data['spix_color'], ref_probs)
|
74 |
+
else:
|
75 |
+
ref_probs = self.rebalance_gradient(data['ref_prob'], data['class_weight'])
|
76 |
+
ref_probs = ref_probs.permute(0,2,3,1).contiguous().view(N*H*W, -1)
|
77 |
+
refLoss_idx = self.entropy_loss(ref_probs, gt_labels.squeeze(dim=1))
|
78 |
+
reconLoss_idx = torch.zeros_like(palLoss_idx)
|
79 |
+
if self.enhanced:
|
80 |
+
scalar = 1.0 if self.hint2regress else 5.0
|
81 |
+
reconLoss_idx = scalar * self._perceptual_loss(data['input_gray'], data['pred_color'], data['input_color'])
|
82 |
+
if self.with_grad:
|
83 |
+
gradient_loss = self._laplace_gradient(data['pred_color'], data['input_color'])
|
84 |
+
reconLoss_idx += gradient_loss
|
85 |
+
totalLoss_idx = palLoss_idx + refLoss_idx + reconLoss_idx
|
86 |
+
#print("loss terms:", palLoss_idx.item(), refLoss_idx.item(), reconLoss_idx.item())
|
87 |
+
return {'totalLoss':totalLoss_idx, 'palLoss':palLoss_idx, 'refLoss':refLoss_idx, 'recLoss':reconLoss_idx}
|
88 |
+
|
89 |
+
|
90 |
+
def compute_affinity_pos_loss(prob_in, labxy_feat, pos_weight=0.003, kernel_size=16):
|
91 |
+
S = kernel_size
|
92 |
+
m = pos_weight
|
93 |
+
prob = prob_in.clone()
|
94 |
+
N,C,H,W = labxy_feat.shape
|
95 |
+
pooled_labxy = basic.poolfeat(labxy_feat, prob, kernel_size, kernel_size)
|
96 |
+
reconstr_feat = basic.upfeat(pooled_labxy, prob, kernel_size, kernel_size)
|
97 |
+
loss_map = reconstr_feat[:,:,:,:] - labxy_feat[:,:,:,:]
|
98 |
+
loss_feat = torch.norm(loss_map[:,:-2,:,:], p=2, dim=1).mean()
|
99 |
+
loss_pos = torch.norm(loss_map[:,-2:,:,:], p=2, dim=1).mean() * m / S
|
100 |
+
loss_affinity = loss_feat + loss_pos
|
101 |
+
return loss_affinity
|
102 |
+
|
103 |
+
|
104 |
+
def l2_loss(y_input, y_target, weight_map=None):
|
105 |
+
if weight_map is None:
|
106 |
+
return F.mse_loss(y_input, y_target)
|
107 |
+
else:
|
108 |
+
diff_map = torch.mean(torch.abs(y_input-y_target), dim=1, keepdim=True)
|
109 |
+
batch_dev = torch.sum(diff_map*diff_map*weight_map, dim=(1,2,3)) / (eps+torch.sum(weight_map, dim=(1,2,3)))
|
110 |
+
return batch_dev.mean()
|
111 |
+
|
112 |
+
|
113 |
+
def l1_loss(y_input, y_target, weight_map=None):
|
114 |
+
if weight_map is None:
|
115 |
+
return F.l1_loss(y_input, y_target)
|
116 |
+
else:
|
117 |
+
diff_map = torch.mean(torch.abs(y_input-y_target), dim=1, keepdim=True)
|
118 |
+
batch_dev = torch.sum(diff_map*weight_map, dim=(1,2,3)) / (eps+torch.sum(weight_map, dim=(1,2,3)))
|
119 |
+
return batch_dev.mean()
|
120 |
+
|
121 |
+
|
122 |
+
def masked_l1_loss(y_input, y_target, outlier_mask):
|
123 |
+
one = torch.tensor([1.0]).cuda(y_input.get_device())
|
124 |
+
weight_map = torch.where(outlier_mask, one * 0.0, one * 1.0)
|
125 |
+
return l1_loss(y_input, y_target, weight_map)
|
126 |
+
|
127 |
+
|
128 |
+
def huber_loss(y_input, y_target, delta=0.01):
|
129 |
+
mask = torch.zeros_like(y_input)
|
130 |
+
    mann = torch.abs(y_input - y_target)
    eucl = 0.5 * (mann**2)
    mask[...] = mann < delta
    loss = eucl * mask / delta + (mann - 0.5 * delta) * (1 - mask)
    return torch.mean(loss)


## Perceptual loss that uses a pretrained VGG network
class VGG19Loss(nn.Module):
    def __init__(self, feat_type='liu', gpu_no=0, is_ddp=False, requires_grad=False):
        super(VGG19Loss, self).__init__()
        os.environ['TORCH_HOME'] = '/apdcephfs/share_1290939/richardxia/Saved/Checkpoints/VGG19'
        ## data requirement: (N,C,H,W) in RGB format, [0,1] range, and resolution >= 224x224
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.feat_type = feat_type

        vgg_model = torchvision.models.vgg19(pretrained=True)
        ## AssertionError: DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient
        '''
        if is_ddp:
            vgg_model = vgg_model.cuda(gpu_no)
            vgg_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(vgg_model)
            vgg_model = torch.nn.parallel.DistributedDataParallel(vgg_model, device_ids=[gpu_no], find_unused_parameters=True)
        else:
            vgg_model = vgg_model.cuda(gpu_no)
        '''
        vgg_model = vgg_model.cuda(gpu_no)
        if self.feat_type == 'liu':
            ## conv1_1, conv2_1, conv3_1, conv4_1, conv5_1
            self.slice1 = nn.Sequential(*list(vgg_model.features)[:2]).eval()
            self.slice2 = nn.Sequential(*list(vgg_model.features)[2:7]).eval()
            self.slice3 = nn.Sequential(*list(vgg_model.features)[7:12]).eval()
            self.slice4 = nn.Sequential(*list(vgg_model.features)[12:21]).eval()
            self.slice5 = nn.Sequential(*list(vgg_model.features)[21:30]).eval()
            self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
        elif self.feat_type == 'lei':
            ## conv1_2, conv2_2, conv3_2, conv4_2, conv5_2
            self.slice1 = nn.Sequential(*list(vgg_model.features)[:4]).eval()
            self.slice2 = nn.Sequential(*list(vgg_model.features)[4:9]).eval()
            self.slice3 = nn.Sequential(*list(vgg_model.features)[9:14]).eval()
            self.slice4 = nn.Sequential(*list(vgg_model.features)[14:23]).eval()
            self.slice5 = nn.Sequential(*list(vgg_model.features)[23:32]).eval()
            self.weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10.0/1.5]
        else:
            ## maxpool after conv4_4
            self.featureExactor = nn.Sequential(*list(vgg_model.features)[:28]).eval()
        '''
        for x in range(2):
            self.slice1.add_module(str(x), pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), pretrained_features[x])
        '''
        self.criterion = nn.L1Loss()

        ## fixed parameters
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False
            self.eval()
        print('[*] VGG19Loss init!')

    def normalize(self, tensor):
        tensor = tensor.clone()
        mean = torch.as_tensor(self.mean, dtype=torch.float32, device=tensor.device)
        std = torch.as_tensor(self.std, dtype=torch.float32, device=tensor.device)
        tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
        return tensor

    def forward(self, x, y):
        norm_x, norm_y = self.normalize(x), self.normalize(y)
        ## feature extract
        if self.feat_type == 'liu' or self.feat_type == 'lei':
            x_relu1, y_relu1 = self.slice1(norm_x), self.slice1(norm_y)
            x_relu2, y_relu2 = self.slice2(x_relu1), self.slice2(y_relu1)
            x_relu3, y_relu3 = self.slice3(x_relu2), self.slice3(y_relu2)
            x_relu4, y_relu4 = self.slice4(x_relu3), self.slice4(y_relu3)
            x_relu5, y_relu5 = self.slice5(x_relu4), self.slice5(y_relu4)
            x_vgg = [x_relu1, x_relu2, x_relu3, x_relu4, x_relu5]
            y_vgg = [y_relu1, y_relu2, y_relu3, y_relu4, y_relu5]
            loss = 0
            for i in range(len(x_vgg)):
                loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        else:
            x_vgg, y_vgg = self.featureExactor(norm_x), self.featureExactor(norm_y)
            loss = self.criterion(x_vgg, y_vgg.detach())
        return loss
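For reference, a minimal usage sketch of VGG19Loss as defined above (not part of the uploaded file; the batch shape, device index, and feat_type are illustrative assumptions):

## usage sketch (assumed shapes/device); inputs must be RGB in [0,1] and at least 224x224
if __name__ == '__main__':
    perceptual = VGG19Loss(feat_type='liu', gpu_no=0)    # conv*_1 slices, weights 1/32 ... 1.0
    pred = torch.rand(2, 3, 256, 256).cuda(0)            # fake prediction batch
    target = torch.rand(2, 3, 256, 256).cuda(0)          # fake ground-truth batch
    loss = perceptual(pred, target)                      # weighted L1 over the five feature levels
    print(loss.item())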
models/model.py
ADDED
@@ -0,0 +1,196 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.network import HourGlass2, SpixelNet, ColorProbNet
from models.transformer2d import EncoderLayer, DecoderLayer, TransformerEncoder, TransformerDecoder
from models.position_encoding import build_position_encoding
from models import basic, clusterkit, anchor_gen
from collections import OrderedDict
from utils import util, cielab


class SpixelSeg(nn.Module):
    def __init__(self, inChannel=1, outChannel=9, batchNorm=True):
        super(SpixelSeg, self).__init__()
        self.net = SpixelNet(inChannel=inChannel, outChannel=outChannel, batchNorm=batchNorm)

    def get_trainable_params(self, lr=1.0):
        #print('=> [optimizer] finetune backbone with smaller lr')
        params = []
        for name, param in self.named_parameters():
            if 'xxx' in name:
                params.append({'params': param, 'lr': lr})
            else:
                params.append({'params': param})
        return params

    def forward(self, input_grays):
        pred_probs = self.net(input_grays)
        return pred_probs


class AnchorColorProb(nn.Module):
    def __init__(self, inChannel=1, outChannel=313, sp_size=16, d_model=64, use_dense_pos=True, spix_pos=False, learning_pos=False, \
                random_hint=False, hint2regress=False, enhanced=False, use_mask=False, rank=0, colorLabeler=None):
        super(AnchorColorProb, self).__init__()
        self.sp_size = sp_size
        self.spix_pos = spix_pos
        self.use_token_mask = use_mask
        self.hint2regress = hint2regress
        self.segnet = SpixelSeg(inChannel=1, outChannel=9, batchNorm=True)
        self.repnet = ColorProbNet(inChannel=inChannel, outChannel=64)
        self.enhanced = enhanced
        if self.enhanced:
            self.enhanceNet = HourGlass2(inChannel=64+1, outChannel=2, resNum=3, normLayer=nn.BatchNorm2d)

        ## transformer architecture
        self.n_vocab = 313
        d_model, dim_feedforward, nhead = d_model, 4*d_model, 8
        dropout, activation = 0.1, "relu"
        n_enc_layers, n_dec_layers = 6, 6
        enc_layer = EncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, use_dense_pos)
        self.wildpath = TransformerEncoder(enc_layer, n_enc_layers, use_dense_pos)
        self.hintpath = TransformerEncoder(enc_layer, n_enc_layers, use_dense_pos)
        if self.spix_pos:
            n_pos_x, n_pos_y = 256, 256
        else:
            n_pos_x, n_pos_y = 256//sp_size, 16//sp_size
        self.pos_enc = build_position_encoding(d_model//2, n_pos_x, n_pos_y, is_learned=False)

        self.mid_word_prj = nn.Linear(d_model, self.n_vocab, bias=False)
        if self.hint2regress:
            self.trg_word_emb = nn.Linear(d_model+2+1, d_model, bias=False)
            self.trg_word_prj = nn.Linear(d_model, 2, bias=False)
        else:
            self.trg_word_emb = nn.Linear(d_model+self.n_vocab+1, d_model, bias=False)
            self.trg_word_prj = nn.Linear(d_model, self.n_vocab, bias=False)

        self.colorLabeler = colorLabeler
        anchor_mode = 'random' if random_hint else 'clustering'
        self.anchorGen = anchor_gen.AnchorAnalysis(mode=anchor_mode, colorLabeler=self.colorLabeler)
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def load_and_froze_weight(self, checkpt_path):
        data_dict = torch.load(checkpt_path, map_location=torch.device('cpu'))
        '''
        for param_tensor in data_dict['state_dict']:
            print(param_tensor,'\t',data_dict['state_dict'][param_tensor].size())
        '''
        self.segnet.load_state_dict(data_dict['state_dict'])
        for name, param in self.segnet.named_parameters():
            param.requires_grad = False
        self.segnet.eval()

    def set_train(self):
        ## running mode only affects certain modules, e.g. Dropout, BN, etc.
        self.repnet.train()
        self.wildpath.train()
        self.hintpath.train()
        if self.enhanced:
            self.enhanceNet.train()

    def get_entry_mask(self, mask_tensor):
        if mask_tensor is None:
            return None
        ## flatten (N,1,H,W) to (N,HW)
        return mask_tensor.flatten(1)

    def forward(self, input_grays, input_colors, n_anchors=8, sampled_T=0):
        '''
        Notice: this function was customized for inference only
        '''
        affinity_map = self.segnet(input_grays)
        pred_feats = self.repnet(input_grays)
        if self.spix_pos:
            full_pos_feats = self.pos_enc(pred_feats)
            proxy_feats = torch.cat([pred_feats, input_colors, full_pos_feats], dim=1)
            pooled_proxy_feats, conf_sum = basic.poolfeat(proxy_feats, affinity_map, self.sp_size, self.sp_size, True)
            feat_tokens = pooled_proxy_feats[:,:64,:,:]
            spix_colors = pooled_proxy_feats[:,64:66,:,:]
            pos_feats = pooled_proxy_feats[:,66:,:,:]
        else:
            proxy_feats = torch.cat([pred_feats, input_colors], dim=1)
            pooled_proxy_feats, conf_sum = basic.poolfeat(proxy_feats, affinity_map, self.sp_size, self.sp_size, True)
            feat_tokens = pooled_proxy_feats[:,:64,:,:]
            spix_colors = pooled_proxy_feats[:,64:,:,:]
            pos_feats = self.pos_enc(feat_tokens)

        token_labels = torch.max(self.colorLabeler.encode_ab2ind(spix_colors), dim=1, keepdim=True)[1]
        spixel_sizes = basic.get_spixel_size(affinity_map, self.sp_size, self.sp_size)
        all_one_map = torch.ones(spixel_sizes.shape, device=input_grays.device)
        empty_entries = torch.where(spixel_sizes < 25/(self.sp_size**2), all_one_map, 1-all_one_map)
        src_pad_mask = self.get_entry_mask(empty_entries) if self.use_token_mask else None
        trg_pad_mask = src_pad_mask

        ## parallel prob
        N,C,H,W = feat_tokens.shape
        ## (N,C,H,W) -> (HW,N,C)
        src_pos_seq = pos_feats.flatten(2).permute(2, 0, 1)
        src_seq = feat_tokens.flatten(2).permute(2, 0, 1)
        ## color prob branch
        enc_out, _ = self.wildpath(src_seq, src_pos_seq, src_pad_mask)
        pal_logit = self.mid_word_prj(enc_out)
        pal_logit = pal_logit.permute(1, 2, 0).view(N,self.n_vocab,H,W)

        ## seed prob branch
        ## mask(N,1,H,W): sample anchors at clustering layers
        color_feat = enc_out.permute(1, 2, 0).view(N,C,H,W)
        hint_mask, cluster_mask = self.anchorGen(color_feat, n_anchors, spixel_sizes, use_sklearn_kmeans=False)
        pred_prob = torch.softmax(pal_logit, dim=1)
        color_feat2 = src_seq.permute(1, 2, 0).view(N,C,H,W)
        #pred_prob, adj_matrix = self.anchorGen._detect_correlation(color_feat, pred_prob, hint_mask, thres=0.1)
        if sampled_T < 0:
            ## GT anchor colors
            sampled_spix_colors = spix_colors
        elif sampled_T > 0:
            top1_spix_colors = self.anchorGen._sample_anchor_colors(pred_prob, hint_mask, T=0)
            top2_spix_colors = self.anchorGen._sample_anchor_colors(pred_prob, hint_mask, T=1)
            top3_spix_colors = self.anchorGen._sample_anchor_colors(pred_prob, hint_mask, T=2)
            ## duplicate meta tensors
            sampled_spix_colors = torch.cat((top1_spix_colors,top2_spix_colors,top3_spix_colors), dim=0)
            N = 3*N
            input_grays = input_grays.expand(N,-1,-1,-1)
            hint_mask = hint_mask.expand(N,-1,-1,-1)
            affinity_map = affinity_map.expand(N,-1,-1,-1)
            src_seq = src_seq.expand(-1, N,-1)
            src_pos_seq = src_pos_seq.expand(-1, N,-1)
        else:
            sampled_spix_colors = self.anchorGen._sample_anchor_colors(pred_prob, hint_mask, T=sampled_T)
        ## debug: controllable
        if False:
            hint_mask, sampled_spix_colors = basic.io_user_control(hint_mask, spix_colors, output=False)

        sampled_token_labels = torch.max(self.colorLabeler.encode_ab2ind(sampled_spix_colors), dim=1, keepdim=True)[1]

        ## hint based prediction
        ## (N,C,H,W) -> (HW,N,C)
        mask_seq = hint_mask.flatten(2).permute(2, 0, 1)
        if self.hint2regress:
            spix_colors_ = sampled_spix_colors
            gt_seq = spix_colors_.flatten(2).permute(2, 0, 1)
            hint_seq = self.trg_word_emb(torch.cat([src_seq, mask_seq * gt_seq, mask_seq], dim=2))
            dec_out, _ = self.hintpath(hint_seq, src_pos_seq, src_pad_mask)
        else:
            token_labels_ = sampled_token_labels
            label_map = F.one_hot(token_labels_, num_classes=313).squeeze(1).float()
            label_seq = label_map.permute(0, 3, 1, 2).flatten(2).permute(2, 0, 1)
            hint_seq = self.trg_word_emb(torch.cat([src_seq, mask_seq * label_seq, mask_seq], dim=2))
            dec_out, _ = self.hintpath(hint_seq, src_pos_seq, src_pad_mask)
        ref_logit = self.trg_word_prj(dec_out)
        Ct = 2 if self.hint2regress else self.n_vocab
        ref_logit = ref_logit.permute(1, 2, 0).view(N,Ct,H,W)

        ## pixelwise enhancement
        pred_colors = None
        if self.enhanced:
            proc_feats = dec_out.permute(1, 2, 0).view(N,64,H,W)
            full_feats = basic.upfeat(proc_feats, affinity_map, self.sp_size, self.sp_size)
            pred_colors = self.enhanceNet(torch.cat((input_grays,full_feats), dim=1))
            pred_colors = torch.tanh(pred_colors)

        return pal_logit, ref_logit, pred_colors, affinity_map, spix_colors, hint_mask
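A rough, hypothetical sketch of how AnchorColorProb might be exercised at inference time. The color labeler is left unconstructed because its class lives in utils.cielab (not shown in this upload), and the checkpoint path and input shapes are assumptions:

## hypothetical driver; `color_labeler` must expose encode_ab2ind() as used in forward()
color_labeler = ...                                         # instance from utils.cielab (construction omitted)
model = AnchorColorProb(inChannel=1, outChannel=313, sp_size=16, d_model=64,
                        hint2regress=True, enhanced=True, colorLabeler=color_labeler)
model.load_and_froze_weight('checkpoints/spixel_seg.pth')   # assumed SpixelSeg checkpoint path
model.eval()

input_grays = torch.rand(1, 1, 256, 256)                    # L channel
input_colors = torch.zeros(1, 2, 256, 256)                  # ab channels, empty at test time
with torch.no_grad():
    pal_logit, ref_logit, pred_colors, affinity_map, spix_colors, hint_mask = \
        model(input_grays, input_colors, n_anchors=8, sampled_T=0)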
models/network.py
ADDED
@@ -0,0 +1,352 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.nn import init
|
5 |
+
import torchvision
|
6 |
+
import torch.nn.utils.spectral_norm as spectral_norm
|
7 |
+
import math
|
8 |
+
|
9 |
+
|
10 |
+
class ConvBlock(nn.Module):
|
11 |
+
def __init__(self, inChannels, outChannels, convNum, normLayer=None):
|
12 |
+
super(ConvBlock, self).__init__()
|
13 |
+
self.inConv = nn.Sequential(
|
14 |
+
nn.Conv2d(inChannels, outChannels, kernel_size=3, padding=1),
|
15 |
+
nn.ReLU(inplace=True)
|
16 |
+
)
|
17 |
+
layers = []
|
18 |
+
for _ in range(convNum - 1):
|
19 |
+
layers.append(nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1))
|
20 |
+
layers.append(nn.ReLU(inplace=True))
|
21 |
+
if not (normLayer is None):
|
22 |
+
layers.append(normLayer(outChannels))
|
23 |
+
self.conv = nn.Sequential(*layers)
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
x = self.inConv(x)
|
27 |
+
x = self.conv(x)
|
28 |
+
return x
|
29 |
+
|
30 |
+
|
31 |
+
class ResidualBlock(nn.Module):
|
32 |
+
def __init__(self, channels, normLayer=None):
|
33 |
+
super(ResidualBlock, self).__init__()
|
34 |
+
layers = []
|
35 |
+
layers.append(nn.Conv2d(channels, channels, kernel_size=3, padding=1))
|
36 |
+
layers.append(spectral_norm(nn.Conv2d(channels, channels, kernel_size=3, padding=1)))
|
37 |
+
if not (normLayer is None):
|
38 |
+
layers.append(normLayer(channels))
|
39 |
+
layers.append(nn.ReLU(inplace=True))
|
40 |
+
layers.append(nn.Conv2d(channels, channels, kernel_size=3, padding=1))
|
41 |
+
if not (normLayer is None):
|
42 |
+
layers.append(normLayer(channels))
|
43 |
+
self.conv = nn.Sequential(*layers)
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
residual = self.conv(x)
|
47 |
+
return F.relu(x + residual, inplace=True)
|
48 |
+
|
49 |
+
|
50 |
+
class ResidualBlockSN(nn.Module):
|
51 |
+
def __init__(self, channels, normLayer=None):
|
52 |
+
super(ResidualBlockSN, self).__init__()
|
53 |
+
layers = []
|
54 |
+
layers.append(spectral_norm(nn.Conv2d(channels, channels, kernel_size=3, padding=1)))
|
55 |
+
layers.append(nn.LeakyReLU(0.2, True))
|
56 |
+
layers.append(spectral_norm(nn.Conv2d(channels, channels, kernel_size=3, padding=1)))
|
57 |
+
if not (normLayer is None):
|
58 |
+
layers.append(normLayer(channels))
|
59 |
+
self.conv = nn.Sequential(*layers)
|
60 |
+
|
61 |
+
def forward(self, x):
|
62 |
+
residual = self.conv(x)
|
63 |
+
return F.leaky_relu(x + residual, 2e-1, inplace=True)
|
64 |
+
|
65 |
+
|
66 |
+
class DownsampleBlock(nn.Module):
|
67 |
+
def __init__(self, inChannels, outChannels, convNum=2, normLayer=None):
|
68 |
+
super(DownsampleBlock, self).__init__()
|
69 |
+
layers = []
|
70 |
+
layers.append(nn.Conv2d(inChannels, outChannels, kernel_size=3, padding=1, stride=2))
|
71 |
+
layers.append(nn.ReLU(inplace=True))
|
72 |
+
for _ in range(convNum - 1):
|
73 |
+
layers.append(nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1))
|
74 |
+
layers.append(nn.ReLU(inplace=True))
|
75 |
+
if not (normLayer is None):
|
76 |
+
layers.append(normLayer(outChannels))
|
77 |
+
self.conv = nn.Sequential(*layers)
|
78 |
+
|
79 |
+
def forward(self, x):
|
80 |
+
return self.conv(x)
|
81 |
+
|
82 |
+
|
83 |
+
class UpsampleBlock(nn.Module):
|
84 |
+
def __init__(self, inChannels, outChannels, convNum=2, normLayer=None):
|
85 |
+
super(UpsampleBlock, self).__init__()
|
86 |
+
self.conv1 = nn.Conv2d(inChannels, outChannels, kernel_size=3, padding=1, stride=1)
|
87 |
+
self.combine = nn.Conv2d(2 * outChannels, outChannels, kernel_size=3, padding=1)
|
88 |
+
layers = []
|
89 |
+
for _ in range(convNum - 1):
|
90 |
+
layers.append(nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1))
|
91 |
+
layers.append(nn.ReLU(inplace=True))
|
92 |
+
if not (normLayer is None):
|
93 |
+
layers.append(normLayer(outChannels))
|
94 |
+
self.conv2 = nn.Sequential(*layers)
|
95 |
+
|
96 |
+
def forward(self, x, x0):
|
97 |
+
x = self.conv1(x)
|
98 |
+
x = F.interpolate(x, scale_factor=2, mode='nearest')
|
99 |
+
x = self.combine(torch.cat((x, x0), 1))
|
100 |
+
x = F.relu(x)
|
101 |
+
return self.conv2(x)
|
102 |
+
|
103 |
+
|
104 |
+
class UpsampleBlockSN(nn.Module):
|
105 |
+
def __init__(self, inChannels, outChannels, convNum=2, normLayer=None):
|
106 |
+
super(UpsampleBlockSN, self).__init__()
|
107 |
+
self.conv1 = spectral_norm(nn.Conv2d(inChannels, outChannels, kernel_size=3, stride=1, padding=1))
|
108 |
+
self.shortcut = spectral_norm(nn.Conv2d(outChannels, outChannels, kernel_size=3, stride=1, padding=1))
|
109 |
+
layers = []
|
110 |
+
for _ in range(convNum - 1):
|
111 |
+
layers.append(spectral_norm(nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1)))
|
112 |
+
layers.append(nn.LeakyReLU(0.2, True))
|
113 |
+
if not (normLayer is None):
|
114 |
+
layers.append(normLayer(outChannels))
|
115 |
+
self.conv2 = nn.Sequential(*layers)
|
116 |
+
|
117 |
+
def forward(self, x, x0):
|
118 |
+
x = self.conv1(x)
|
119 |
+
x = F.interpolate(x, scale_factor=2, mode='nearest')
|
120 |
+
x = x + self.shortcut(x0)
|
121 |
+
x = F.leaky_relu(x, 2e-1)
|
122 |
+
return self.conv2(x)
|
123 |
+
|
124 |
+
|
125 |
+
class HourGlass2(nn.Module):
|
126 |
+
def __init__(self, inChannel=3, outChannel=1, resNum=3, normLayer=None):
|
127 |
+
super(HourGlass2, self).__init__()
|
128 |
+
self.inConv = ConvBlock(inChannel, 64, convNum=2, normLayer=normLayer)
|
129 |
+
self.down1 = DownsampleBlock(64, 128, convNum=2, normLayer=normLayer)
|
130 |
+
self.down2 = DownsampleBlock(128, 256, convNum=2, normLayer=normLayer)
|
131 |
+
self.residual = nn.Sequential(*[ResidualBlock(256) for _ in range(resNum)])
|
132 |
+
self.up2 = UpsampleBlock(256, 128, convNum=3, normLayer=normLayer)
|
133 |
+
self.up1 = UpsampleBlock(128, 64, convNum=3, normLayer=normLayer)
|
134 |
+
self.outConv = nn.Conv2d(64, outChannel, kernel_size=3, padding=1)
|
135 |
+
|
136 |
+
def forward(self, x):
|
137 |
+
f1 = self.inConv(x)
|
138 |
+
f2 = self.down1(f1)
|
139 |
+
f3 = self.down2(f2)
|
140 |
+
r3 = self.residual(f3)
|
141 |
+
r2 = self.up2(r3, f2)
|
142 |
+
r1 = self.up1(r2, f1)
|
143 |
+
y = self.outConv(r1)
|
144 |
+
return y
|
145 |
+
|
146 |
+
|
147 |
+
class ColorProbNet(nn.Module):
|
148 |
+
def __init__(self, inChannel=1, outChannel=2, with_SA=False):
|
149 |
+
super(ColorProbNet, self).__init__()
|
150 |
+
BNFunc = nn.BatchNorm2d
|
151 |
+
# conv1: 256
|
152 |
+
conv1_2 = [spectral_norm(nn.Conv2d(inChannel, 64, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
|
153 |
+
conv1_2 += [spectral_norm(nn.Conv2d(64, 64, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
|
154 |
+
conv1_2 += [BNFunc(64, affine=True)]
|
155 |
+
# conv2: 128
|
156 |
+
conv2_3 = [spectral_norm(nn.Conv2d(64, 128, 3, stride=2, padding=1)), nn.LeakyReLU(0.2, True),]
|
157 |
+
conv2_3 += [spectral_norm(nn.Conv2d(128, 128, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
|
158 |
+
conv2_3 += [spectral_norm(nn.Conv2d(128, 128, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
|
159 |
+
conv2_3 += [BNFunc(128, affine=True)]
|
160 |
+
# conv3: 64
|
161 |
+
conv3_3 = [spectral_norm(nn.Conv2d(128, 256, 3, stride=2, padding=1)), nn.LeakyReLU(0.2, True),]
|
162 |
+
conv3_3 += [spectral_norm(nn.Conv2d(256, 256, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
|
163 |
+
conv3_3 += [spectral_norm(nn.Conv2d(256, 256, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
|
164 |
+
conv3_3 += [BNFunc(256, affine=True)]
|
165 |
+
# conv4: 32
|
166 |
+
conv4_3 = [spectral_norm(nn.Conv2d(256, 512, 3, stride=2, padding=1)), nn.LeakyReLU(0.2, True),]
|
167 |
+
conv4_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
|
168 |
+
conv4_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
|
169 |
+
conv4_3 += [BNFunc(512, affine=True)]
|
170 |
+
# conv5: 32
|
171 |
+
conv5_3 = [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
|
172 |
+
conv5_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
|
173 |
+
conv5_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
|
174 |
+
conv5_3 += [BNFunc(512, affine=True)]
|
175 |
+
# conv6: 32
|
176 |
+
conv6_3 = [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
|
177 |
+
conv6_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
|
178 |
+
conv6_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
|
179 |
+
conv6_3 += [BNFunc(512, affine=True),]
|
180 |
+
if with_SA:
|
181 |
+
conv6_3 += [Self_Attn(512)]
|
182 |
+
# conv7: 32
|
183 |
+
conv7_3 = [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
|
184 |
+
conv7_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
|
185 |
+
conv7_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
|
186 |
+
conv7_3 += [BNFunc(512, affine=True)]
|
187 |
+
# conv8: 64
|
188 |
+
conv8up = [nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(512, 256, 3, stride=1, padding=1),]
|
189 |
+
conv3short8 = [nn.Conv2d(256, 256, 3, stride=1, padding=1),]
|
190 |
+
conv8_3 = [nn.ReLU(True),]
|
191 |
+
conv8_3 += [nn.Conv2d(256, 256, 3, stride=1, padding=1), nn.ReLU(True),]
|
192 |
+
conv8_3 += [nn.Conv2d(256, 256, 3, stride=1, padding=1), nn.ReLU(True),]
|
193 |
+
conv8_3 += [BNFunc(256, affine=True),]
|
194 |
+
# conv9: 128
|
195 |
+
conv9up = [nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(256, 128, 3, stride=1, padding=1),]
|
196 |
+
conv9_2 = [nn.Conv2d(128, 128, 3, stride=1, padding=1), nn.ReLU(True),]
|
197 |
+
conv9_2 += [BNFunc(128, affine=True)]
|
198 |
+
# conv10: 64
|
199 |
+
conv10up = [nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(128, 64, 3, stride=1, padding=1),]
|
200 |
+
conv10_2 = [nn.ReLU(True),]
|
201 |
+
conv10_2 += [nn.Conv2d(64, outChannel, 3, stride=1, padding=1), nn.ReLU(True),]
|
202 |
+
|
203 |
+
self.conv1_2 = nn.Sequential(*conv1_2)
|
204 |
+
self.conv2_3 = nn.Sequential(*conv2_3)
|
205 |
+
self.conv3_3 = nn.Sequential(*conv3_3)
|
206 |
+
self.conv4_3 = nn.Sequential(*conv4_3)
|
207 |
+
self.conv5_3 = nn.Sequential(*conv5_3)
|
208 |
+
self.conv6_3 = nn.Sequential(*conv6_3)
|
209 |
+
self.conv7_3 = nn.Sequential(*conv7_3)
|
210 |
+
self.conv8up = nn.Sequential(*conv8up)
|
211 |
+
self.conv3short8 = nn.Sequential(*conv3short8)
|
212 |
+
self.conv8_3 = nn.Sequential(*conv8_3)
|
213 |
+
self.conv9up = nn.Sequential(*conv9up)
|
214 |
+
self.conv9_2 = nn.Sequential(*conv9_2)
|
215 |
+
self.conv10up = nn.Sequential(*conv10up)
|
216 |
+
self.conv10_2 = nn.Sequential(*conv10_2)
|
217 |
+
# classification output
|
218 |
+
#self.model_class = nn.Sequential(*[nn.Conv2d(256, 313, kernel_size=1, padding=0, stride=1),])
|
219 |
+
|
220 |
+
def forward(self, input_grays):
|
221 |
+
f1_2 = self.conv1_2(input_grays)
|
222 |
+
f2_3 = self.conv2_3(f1_2)
|
223 |
+
f3_3 = self.conv3_3(f2_3)
|
224 |
+
f4_3 = self.conv4_3(f3_3)
|
225 |
+
f5_3 = self.conv5_3(f4_3)
|
226 |
+
f6_3 = self.conv6_3(f5_3)
|
227 |
+
f7_3 = self.conv7_3(f6_3)
|
228 |
+
f8_up = self.conv8up(f7_3) + self.conv3short8(f3_3)
|
229 |
+
f8_3 = self.conv8_3(f8_up)
|
230 |
+
f9_up = self.conv9up(f8_3)
|
231 |
+
f9_2 = self.conv9_2(f9_up)
|
232 |
+
f10_up = self.conv10up(f9_2)
|
233 |
+
f10_2 = self.conv10_2(f10_up)
|
234 |
+
out_feats = f10_2
|
235 |
+
#out_probs = self.model_class(f8_3)
|
236 |
+
return out_feats
|
237 |
+
|
238 |
+
|
239 |
+
|
240 |
+
def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1):
|
241 |
+
if batchNorm:
|
242 |
+
return nn.Sequential(
|
243 |
+
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=False),
|
244 |
+
nn.BatchNorm2d(out_planes),
|
245 |
+
nn.LeakyReLU(0.1)
|
246 |
+
)
|
247 |
+
else:
|
248 |
+
return nn.Sequential(
|
249 |
+
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True),
|
250 |
+
nn.LeakyReLU(0.1)
|
251 |
+
)
|
252 |
+
|
253 |
+
|
254 |
+
def deconv(in_planes, out_planes):
|
255 |
+
return nn.Sequential(
|
256 |
+
nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True),
|
257 |
+
nn.LeakyReLU(0.1)
|
258 |
+
)
|
259 |
+
|
260 |
+
class SpixelNet(nn.Module):
|
261 |
+
def __init__(self, inChannel=3, outChannel=9, batchNorm=True):
|
262 |
+
super(SpixelNet,self).__init__()
|
263 |
+
self.batchNorm = batchNorm
|
264 |
+
self.conv0a = conv(self.batchNorm, inChannel, 16, kernel_size=3)
|
265 |
+
self.conv0b = conv(self.batchNorm, 16, 16, kernel_size=3)
|
266 |
+
self.conv1a = conv(self.batchNorm, 16, 32, kernel_size=3, stride=2)
|
267 |
+
self.conv1b = conv(self.batchNorm, 32, 32, kernel_size=3)
|
268 |
+
self.conv2a = conv(self.batchNorm, 32, 64, kernel_size=3, stride=2)
|
269 |
+
self.conv2b = conv(self.batchNorm, 64, 64, kernel_size=3)
|
270 |
+
self.conv3a = conv(self.batchNorm, 64, 128, kernel_size=3, stride=2)
|
271 |
+
self.conv3b = conv(self.batchNorm, 128, 128, kernel_size=3)
|
272 |
+
self.conv4a = conv(self.batchNorm, 128, 256, kernel_size=3, stride=2)
|
273 |
+
self.conv4b = conv(self.batchNorm, 256, 256, kernel_size=3)
|
274 |
+
self.deconv3 = deconv(256, 128)
|
275 |
+
self.conv3_1 = conv(self.batchNorm, 256, 128)
|
276 |
+
self.deconv2 = deconv(128, 64)
|
277 |
+
self.conv2_1 = conv(self.batchNorm, 128, 64)
|
278 |
+
self.deconv1 = deconv(64, 32)
|
279 |
+
self.conv1_1 = conv(self.batchNorm, 64, 32)
|
280 |
+
self.deconv0 = deconv(32, 16)
|
281 |
+
self.conv0_1 = conv(self.batchNorm, 32, 16)
|
282 |
+
self.pred_mask0 = nn.Conv2d(16, outChannel, kernel_size=3, stride=1, padding=1, bias=True)
|
283 |
+
self.softmax = nn.Softmax(1)
|
284 |
+
for m in self.modules():
|
285 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
|
286 |
+
init.kaiming_normal_(m.weight, 0.1)
|
287 |
+
if m.bias is not None:
|
288 |
+
init.constant_(m.bias, 0)
|
289 |
+
elif isinstance(m, nn.BatchNorm2d):
|
290 |
+
init.constant_(m.weight, 1)
|
291 |
+
init.constant_(m.bias, 0)
|
292 |
+
|
293 |
+
def forward(self, x):
|
294 |
+
out1 = self.conv0b(self.conv0a(x)) #5*5
|
295 |
+
out2 = self.conv1b(self.conv1a(out1)) #11*11
|
296 |
+
out3 = self.conv2b(self.conv2a(out2)) #23*23
|
297 |
+
out4 = self.conv3b(self.conv3a(out3)) #47*47
|
298 |
+
out5 = self.conv4b(self.conv4a(out4)) #95*95
|
299 |
+
out_deconv3 = self.deconv3(out5)
|
300 |
+
concat3 = torch.cat((out4, out_deconv3), 1)
|
301 |
+
out_conv3_1 = self.conv3_1(concat3)
|
302 |
+
out_deconv2 = self.deconv2(out_conv3_1)
|
303 |
+
concat2 = torch.cat((out3, out_deconv2), 1)
|
304 |
+
out_conv2_1 = self.conv2_1(concat2)
|
305 |
+
out_deconv1 = self.deconv1(out_conv2_1)
|
306 |
+
concat1 = torch.cat((out2, out_deconv1), 1)
|
307 |
+
out_conv1_1 = self.conv1_1(concat1)
|
308 |
+
out_deconv0 = self.deconv0(out_conv1_1)
|
309 |
+
concat0 = torch.cat((out1, out_deconv0), 1)
|
310 |
+
out_conv0_1 = self.conv0_1(concat0)
|
311 |
+
mask0 = self.pred_mask0(out_conv0_1)
|
312 |
+
prob0 = self.softmax(mask0)
|
313 |
+
return prob0
|
314 |
+
|
315 |
+
|
316 |
+
|
317 |
+
## VGG architecture, used for the perceptual loss with a pretrained VGG network
|
318 |
+
class VGG19(torch.nn.Module):
|
319 |
+
def __init__(self, requires_grad=False, local_pretrained_path='checkpoints/vgg19.pth'):
|
320 |
+
super().__init__()
|
321 |
+
#vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
|
322 |
+
model = torchvision.models.vgg19()
|
323 |
+
model.load_state_dict(torch.load(local_pretrained_path))
|
324 |
+
vgg_pretrained_features = model.features
|
325 |
+
|
326 |
+
self.slice1 = torch.nn.Sequential()
|
327 |
+
self.slice2 = torch.nn.Sequential()
|
328 |
+
self.slice3 = torch.nn.Sequential()
|
329 |
+
self.slice4 = torch.nn.Sequential()
|
330 |
+
self.slice5 = torch.nn.Sequential()
|
331 |
+
for x in range(2):
|
332 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
333 |
+
for x in range(2, 7):
|
334 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
335 |
+
for x in range(7, 12):
|
336 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
337 |
+
for x in range(12, 21):
|
338 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
339 |
+
for x in range(21, 30):
|
340 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
341 |
+
if not requires_grad:
|
342 |
+
for param in self.parameters():
|
343 |
+
param.requires_grad = False
|
344 |
+
|
345 |
+
def forward(self, X):
|
346 |
+
h_relu1 = self.slice1(X)
|
347 |
+
h_relu2 = self.slice2(h_relu1)
|
348 |
+
h_relu3 = self.slice3(h_relu2)
|
349 |
+
h_relu4 = self.slice4(h_relu3)
|
350 |
+
h_relu5 = self.slice5(h_relu4)
|
351 |
+
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
|
352 |
+
return out
|
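A quick shape sanity check for the three backbones defined above (illustrative only; the 256x256 single-channel input is an assumption consistent with how model.py drives them):

## illustrative shape check (not part of the uploaded file)
if __name__ == '__main__':
    x = torch.rand(2, 1, 256, 256)
    affinities = SpixelNet(inChannel=1, outChannel=9)(x)     # (2, 9, 256, 256) soft pixel-to-superpixel affinities
    feats = ColorProbNet(inChannel=1, outChannel=64)(x)      # (2, 64, 256, 256) per-pixel features
    ab = HourGlass2(inChannel=1, outChannel=2, resNum=3)(x)  # (2, 2, 256, 256) chrominance prediction
    print(affinities.shape, feats.shape, ab.shape)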
models/position_encoding.py
ADDED
@@ -0,0 +1,86 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, token_tensors):
        ## input: (B,C,H,W)
        x = token_tensors
        h, w = x.shape[-2:]
        identity_map = torch.ones((h,w), device=x.device)
        y_embed = identity_map.cumsum(0, dtype=torch.float32)
        x_embed = identity_map.cumsum(1, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[-1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, None] / dim_t
        pos_y = y_embed[:, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
        pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
        pos = torch.cat((pos_y, pos_x), dim=2).permute(2, 0, 1)
        batch_pos = pos.unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return batch_pos


class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """
    def __init__(self, n_pos_x=16, n_pos_y=16, num_pos_feats=64):
        super().__init__()
        self.row_embed = nn.Embedding(n_pos_y, num_pos_feats)
        self.col_embed = nn.Embedding(n_pos_x, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, token_tensors):
        ## input: (B,C,H,W)
        x = token_tensors
        h, w = x.shape[-2:]
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)
        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)
        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),
            y_emb.unsqueeze(1).repeat(1, w, 1),
        ], dim=-1).permute(2, 0, 1)
        batch_pos = pos.unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return batch_pos


def build_position_encoding(num_pos_feats=64, n_pos_x=16, n_pos_y=16, is_learned=False):
    if is_learned:
        position_embedding = PositionEmbeddingLearned(n_pos_x, n_pos_y, num_pos_feats)
    else:
        position_embedding = PositionEmbeddingSine(num_pos_feats, normalize=True)

    return position_embedding
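An illustrative check of the two encodings via build_position_encoding (the 16x16 token grid and 64-channel features are assumptions matching how model.py calls it with d_model//2 positional features):

## illustrative check (not part of the uploaded file)
if __name__ == '__main__':
    tokens = torch.rand(2, 64, 16, 16)                                # (B, C, H, W) token features
    sine_pos = build_position_encoding(32, 16, 16, is_learned=False)(tokens)
    learned_pos = build_position_encoding(32, 16, 16, is_learned=True)(tokens)
    print(sine_pos.shape, learned_pos.shape)                          # both (2, 64, 16, 16): y/x halves concatenated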
models/transformer2d.py
ADDED
@@ -0,0 +1,229 @@
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import nn
|
4 |
+
import copy, math
|
5 |
+
from models.position_encoding import build_position_encoding
|
6 |
+
|
7 |
+
|
8 |
+
class TransformerEncoder(nn.Module):
|
9 |
+
|
10 |
+
def __init__(self, enc_layer, num_layers, use_dense_pos=False):
|
11 |
+
super().__init__()
|
12 |
+
self.layers = nn.ModuleList([copy.deepcopy(enc_layer) for i in range(num_layers)])
|
13 |
+
self.num_layers = num_layers
|
14 |
+
self.use_dense_pos = use_dense_pos
|
15 |
+
|
16 |
+
def forward(self, src, pos, padding_mask=None):
|
17 |
+
if self.use_dense_pos:
|
18 |
+
## pos encoding at each MH-Attention block (q,k)
|
19 |
+
output, pos_enc = src, pos
|
20 |
+
for layer in self.layers:
|
21 |
+
output, att_map = layer(output, pos_enc, padding_mask)
|
22 |
+
else:
|
23 |
+
## pos encoding at input only (q,k,v)
|
24 |
+
output, pos_enc = src + pos, None
|
25 |
+
for layer in self.layers:
|
26 |
+
output, att_map = layer(output, pos_enc, padding_mask)
|
27 |
+
return output, att_map
|
28 |
+
|
29 |
+
|
30 |
+
class EncoderLayer(nn.Module):
|
31 |
+
|
32 |
+
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
|
33 |
+
use_dense_pos=False):
|
34 |
+
super().__init__()
|
35 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
36 |
+
# Implementation of Feedforward model
|
37 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
38 |
+
self.dropout = nn.Dropout(dropout)
|
39 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
40 |
+
|
41 |
+
self.norm1 = nn.LayerNorm(d_model)
|
42 |
+
self.norm2 = nn.LayerNorm(d_model)
|
43 |
+
self.dropout1 = nn.Dropout(dropout)
|
44 |
+
self.dropout2 = nn.Dropout(dropout)
|
45 |
+
|
46 |
+
self.activation = _get_activation_fn(activation)
|
47 |
+
|
48 |
+
def with_pos_embed(self, tensor, pos):
|
49 |
+
return tensor if pos is None else tensor + pos
|
50 |
+
|
51 |
+
def forward(self, src, pos, padding_mask):
|
52 |
+
q = k = self.with_pos_embed(src, pos)
|
53 |
+
src2, attn = self.self_attn(q, k, value=src, key_padding_mask=padding_mask)
|
54 |
+
src = src + self.dropout1(src2)
|
55 |
+
src = self.norm1(src)
|
56 |
+
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
57 |
+
src = src + self.dropout2(src2)
|
58 |
+
src = self.norm2(src)
|
59 |
+
return src, attn
|
60 |
+
|
61 |
+
|
62 |
+
class TransformerDecoder(nn.Module):
|
63 |
+
|
64 |
+
def __init__(self, dec_layer, num_layers, use_dense_pos=False, return_intermediate=False):
|
65 |
+
super().__init__()
|
66 |
+
self.layers = nn.ModuleList([copy.deepcopy(dec_layer) for i in range(num_layers)])
|
67 |
+
self.num_layers = num_layers
|
68 |
+
self.use_dense_pos = use_dense_pos
|
69 |
+
self.return_intermediate = return_intermediate
|
70 |
+
|
71 |
+
def forward(self, tgt, tgt_pos, memory, memory_pos,
|
72 |
+
tgt_padding_mask, src_padding_mask, tgt_attn_mask=None):
|
73 |
+
intermediate = []
|
74 |
+
if self.use_dense_pos:
|
75 |
+
## pos encoding at each MH-Attention block (q,k)
|
76 |
+
output = tgt
|
77 |
+
tgt_pos_enc, memory_pos_enc = tgt_pos, memory_pos
|
78 |
+
for layer in self.layers:
|
79 |
+
output, att_map = layer(output, tgt_pos_enc, memory, memory_pos_enc,
|
80 |
+
tgt_padding_mask, src_padding_mask, tgt_attn_mask)
|
81 |
+
if self.return_intermediate:
|
82 |
+
intermediate.append(output)
|
83 |
+
else:
|
84 |
+
## pos encoding at input only (q,k,v)
|
85 |
+
output = tgt + tgt_pos
|
86 |
+
tgt_pos_enc, memory_pos_enc = None, None
|
87 |
+
for layer in self.layers:
|
88 |
+
output, att_map = layer(output, tgt_pos_enc, memory, memory_pos_enc,
|
89 |
+
tgt_padding_mask, src_padding_mask, tgt_attn_mask)
|
90 |
+
if self.return_intermediate:
|
91 |
+
intermediate.append(output)
|
92 |
+
|
93 |
+
if self.return_intermediate:
|
94 |
+
return torch.stack(intermediate)
|
95 |
+
return output, att_map
|
96 |
+
|
97 |
+
|
98 |
+
class DecoderLayer(nn.Module):
|
99 |
+
|
100 |
+
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
|
101 |
+
use_dense_pos=False):
|
102 |
+
super().__init__()
|
103 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
104 |
+
self.corr_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
105 |
+
# Implementation of Feedforward model
|
106 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
107 |
+
self.dropout = nn.Dropout(dropout)
|
108 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
109 |
+
|
110 |
+
self.norm1 = nn.LayerNorm(d_model)
|
111 |
+
self.norm2 = nn.LayerNorm(d_model)
|
112 |
+
self.norm3 = nn.LayerNorm(d_model)
|
113 |
+
self.dropout1 = nn.Dropout(dropout)
|
114 |
+
self.dropout2 = nn.Dropout(dropout)
|
115 |
+
self.dropout3 = nn.Dropout(dropout)
|
116 |
+
|
117 |
+
self.activation = _get_activation_fn(activation)
|
118 |
+
|
119 |
+
def with_pos_embed(self, tensor, pos):
|
120 |
+
return tensor if pos is None else tensor + pos
|
121 |
+
|
122 |
+
def forward(self, tgt, tgt_pos, memory, memory_pos,
|
123 |
+
tgt_padding_mask, memory_padding_mask, tgt_attn_mask):
|
124 |
+
q = k = self.with_pos_embed(tgt, tgt_pos)
|
125 |
+
tgt2, attn = self.self_attn(q, k, value=tgt, key_padding_mask=tgt_padding_mask,
|
126 |
+
attn_mask=tgt_attn_mask)
|
127 |
+
tgt = tgt + self.dropout1(tgt2)
|
128 |
+
tgt = self.norm1(tgt)
|
129 |
+
tgt2, attn = self.corr_attn(query=self.with_pos_embed(tgt, tgt_pos),
|
130 |
+
key=self.with_pos_embed(memory, memory_pos),
|
131 |
+
value=memory, key_padding_mask=memory_padding_mask)
|
132 |
+
tgt = tgt + self.dropout2(tgt2)
|
133 |
+
tgt = self.norm2(tgt)
|
134 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
135 |
+
tgt = tgt + self.dropout3(tgt2)
|
136 |
+
tgt = self.norm3(tgt)
|
137 |
+
return tgt, attn
|
138 |
+
|
139 |
+
|
140 |
+
def _get_activation_fn(activation):
|
141 |
+
"""Return an activation function given a string"""
|
142 |
+
if activation == "relu":
|
143 |
+
return F.relu
|
144 |
+
if activation == "gelu":
|
145 |
+
return F.gelu
|
146 |
+
if activation == "glu":
|
147 |
+
return F.glu
|
148 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
149 |
+
|
150 |
+
|
151 |
+
|
152 |
+
#-----------------------------------------------------------------------------------
|
153 |
+
'''
|
154 |
+
copied from the implementation of "attention-is-all-you-need-pytorch-master" by Yu-Hsiang Huang
|
155 |
+
'''
|
156 |
+
|
157 |
+
class MultiHeadAttention(nn.Module):
|
158 |
+
''' Multi-Head Attention module '''
|
159 |
+
|
160 |
+
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
|
161 |
+
super().__init__()
|
162 |
+
|
163 |
+
self.n_head = n_head
|
164 |
+
self.d_k = d_k
|
165 |
+
self.d_v = d_v
|
166 |
+
|
167 |
+
self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
|
168 |
+
self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
|
169 |
+
self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
|
170 |
+
self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
|
171 |
+
|
172 |
+
self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
|
173 |
+
|
174 |
+
self.dropout = nn.Dropout(dropout)
|
175 |
+
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
|
176 |
+
|
177 |
+
|
178 |
+
def forward(self, q, k, v, mask=None):
|
179 |
+
|
180 |
+
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
|
181 |
+
sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
|
182 |
+
|
183 |
+
residual = q
|
184 |
+
|
185 |
+
# Pass through the pre-attention projection: b x lq x (n*dv)
|
186 |
+
# Separate different heads: b x lq x n x dv
|
187 |
+
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
|
188 |
+
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
|
189 |
+
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
|
190 |
+
|
191 |
+
# Transpose for attention dot product: b x n x lq x dv
|
192 |
+
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
|
193 |
+
|
194 |
+
if mask is not None:
|
195 |
+
mask = mask.unsqueeze(1) # For head axis broadcasting.
|
196 |
+
|
197 |
+
q, attn = self.attention(q, k, v, mask=mask)
|
198 |
+
|
199 |
+
# Transpose to move the head dimension back: b x lq x n x dv
|
200 |
+
# Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
|
201 |
+
q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
|
202 |
+
q = self.dropout(self.fc(q))
|
203 |
+
q += residual
|
204 |
+
|
205 |
+
q = self.layer_norm(q)
|
206 |
+
|
207 |
+
return q, attn
|
208 |
+
|
209 |
+
|
210 |
+
|
211 |
+
class ScaledDotProductAttention(nn.Module):
|
212 |
+
''' Scaled Dot-Product Attention '''
|
213 |
+
|
214 |
+
def __init__(self, temperature, attn_dropout=0.1):
|
215 |
+
super().__init__()
|
216 |
+
self.temperature = temperature
|
217 |
+
self.dropout = nn.Dropout(attn_dropout)
|
218 |
+
|
219 |
+
def forward(self, q, k, v, mask=None):
|
220 |
+
|
221 |
+
attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
|
222 |
+
|
223 |
+
if mask is not None:
|
224 |
+
attn = attn.masked_fill(mask == 0, -1e9)
|
225 |
+
|
226 |
+
attn = self.dropout(F.softmax(attn, dim=-1))
|
227 |
+
output = torch.matmul(attn, v)
|
228 |
+
|
229 |
+
return output, attn
|
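Finally, an illustrative wiring of the encoder stack defined above, using the sequence-first (HW, N, C) layout that model.py feeds it (all sizes below are assumptions):

## illustrative wiring (not part of the uploaded file)
if __name__ == '__main__':
    d_model, nhead = 64, 8
    layer = EncoderLayer(d_model, nhead, dim_feedforward=256, dropout=0.1,
                         activation="relu", use_dense_pos=True)
    encoder = TransformerEncoder(layer, num_layers=6, use_dense_pos=True)
    src = torch.rand(256, 2, d_model)          # 16x16 tokens, batch of 2
    pos = torch.rand(256, 2, d_model)          # positional encoding in the same layout
    out, attn = encoder(src, pos, padding_mask=None)
    print(out.shape, attn.shape)               # (256, 2, 64) and (2, 256, 256)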