Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import numpy as np | |
def get_surface_normalv2(xyz, patch_size=5): | |
""" | |
xyz: xyz coordinates | |
patch: [p1, p2, p3, | |
p4, p5, p6, | |
p7, p8, p9] | |
surface_normal = [(p9-p1) x (p3-p7)] + [(p6-p4) - (p8-p2)] | |
return: normal [h, w, 3, b] | |
""" | |
b, h, w, c = xyz.shape | |
half_patch = patch_size // 2 | |
xyz_pad = torch.zeros((b, h + patch_size - 1, w + patch_size - 1, c), dtype=xyz.dtype, device=xyz.device) | |
xyz_pad[:, half_patch:-half_patch, half_patch:-half_patch, :] = xyz | |
# xyz_left_top = xyz_pad[:, :h, :w, :] # p1 | |
# xyz_right_bottom = xyz_pad[:, -h:, -w:, :]# p9 | |
# xyz_left_bottom = xyz_pad[:, -h:, :w, :] # p7 | |
# xyz_right_top = xyz_pad[:, :h, -w:, :] # p3 | |
# xyz_cross1 = xyz_left_top - xyz_right_bottom # p1p9 | |
# xyz_cross2 = xyz_left_bottom - xyz_right_top # p7p3 | |
xyz_left = xyz_pad[:, half_patch:half_patch + h, :w, :] # p4 | |
xyz_right = xyz_pad[:, half_patch:half_patch + h, -w:, :] # p6 | |
xyz_top = xyz_pad[:, :h, half_patch:half_patch + w, :] # p2 | |
xyz_bottom = xyz_pad[:, -h:, half_patch:half_patch + w, :] # p8 | |
xyz_horizon = xyz_left - xyz_right # p4p6 | |
xyz_vertical = xyz_top - xyz_bottom # p2p8 | |
xyz_left_in = xyz_pad[:, half_patch:half_patch + h, 1:w+1, :] # p4 | |
xyz_right_in = xyz_pad[:, half_patch:half_patch + h, patch_size-1:patch_size-1+w, :] # p6 | |
xyz_top_in = xyz_pad[:, 1:h+1, half_patch:half_patch + w, :] # p2 | |
xyz_bottom_in = xyz_pad[:, patch_size-1:patch_size-1+h, half_patch:half_patch + w, :] # p8 | |
xyz_horizon_in = xyz_left_in - xyz_right_in # p4p6 | |
xyz_vertical_in = xyz_top_in - xyz_bottom_in # p2p8 | |
n_img_1 = torch.cross(xyz_horizon_in, xyz_vertical_in, dim=3) | |
n_img_2 = torch.cross(xyz_horizon, xyz_vertical, dim=3) | |
# re-orient normals consistently | |
orient_mask = torch.sum(n_img_1 * xyz, dim=3) > 0 | |
n_img_1[orient_mask] *= -1 | |
orient_mask = torch.sum(n_img_2 * xyz, dim=3) > 0 | |
n_img_2[orient_mask] *= -1 | |
n_img1_L2 = torch.sqrt(torch.sum(n_img_1 ** 2, dim=3, keepdim=True)) | |
n_img1_norm = n_img_1 / (n_img1_L2 + 1e-8) | |
n_img2_L2 = torch.sqrt(torch.sum(n_img_2 ** 2, dim=3, keepdim=True)) | |
n_img2_norm = n_img_2 / (n_img2_L2 + 1e-8) | |
# average 2 norms | |
n_img_aver = n_img1_norm + n_img2_norm | |
n_img_aver_L2 = torch.sqrt(torch.sum(n_img_aver ** 2, dim=3, keepdim=True)) | |
n_img_aver_norm = n_img_aver / (n_img_aver_L2 + 1e-8) | |
# re-orient normals consistently | |
orient_mask = torch.sum(n_img_aver_norm * xyz, dim=3) > 0 | |
n_img_aver_norm[orient_mask] *= -1 | |
n_img_aver_norm_out = n_img_aver_norm.permute((1, 2, 3, 0)) # [h, w, c, b] | |
# a = torch.sum(n_img1_norm_out*n_img2_norm_out, dim=2).cpu().numpy().squeeze() | |
# plt.imshow(np.abs(a), cmap='rainbow') | |
# plt.show() | |
return n_img_aver_norm_out#n_img1_norm.permute((1, 2, 3, 0)) | |
def init_image_coor(height, width): | |
x_row = np.arange(0, width) | |
x = np.tile(x_row, (height, 1)) | |
x = x[np.newaxis, :, :] | |
x = x.astype(np.float32) | |
x = torch.from_numpy(x.copy()).cuda() | |
u_u0 = x - width/2.0 | |
y_col = np.arange(0, height) # y_col = np.arange(0, height) | |
y = np.tile(y_col, (width, 1)).T | |
y = y[np.newaxis, :, :] | |
y = y.astype(np.float32) | |
y = torch.from_numpy(y.copy()).cuda() | |
v_v0 = y - height/2.0 | |
return u_u0, v_v0 | |
def depth_to_xyz(depth, focal_length): | |
b, c, h, w = depth.shape | |
u_u0, v_v0 = init_image_coor(h, w) | |
x = u_u0 * depth / focal_length | |
y = v_v0 * depth / focal_length | |
z = depth | |
pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1) # [b, h, w, c] | |
return pw | |
def surface_normal_from_depth(depth, focal_length, valid_mask=None): | |
# para depth: depth map, [b, c, h, w] | |
b, c, h, w = depth.shape | |
focal_length = focal_length[:, None, None, None] | |
depth_filter = torch.nn.functional.avg_pool2d(depth, kernel_size=3, stride=1, padding=1) | |
depth_filter = torch.nn.functional.avg_pool2d(depth_filter, kernel_size=3, stride=1, padding=1) | |
xyz = depth_to_xyz(depth_filter, focal_length) | |
sn_batch = [] | |
for i in range(b): | |
xyz_i = xyz[i, :][None, :, :, :] | |
normal = get_surface_normalv2(xyz_i) | |
sn_batch.append(normal) | |
sn_batch = torch.cat(sn_batch, dim=3).permute((3, 2, 0, 1)) # [b, c, h, w] | |
mask_invalid = (~valid_mask).repeat(1, 3, 1, 1) | |
sn_batch[mask_invalid] = 0.0 | |
return | |
########### | |
# EDGE-GUIDED SAMPLING | |
# input: | |
# inputs[i,:], targets[i, :], masks[i, :], edges_img[i], thetas_img[i], masks[i, :], h, w | |
# return: | |
# inputs_A, inputs_B, targets_A, targets_B, masks_A, masks_B | |
########### | |
def ind2sub(idx, cols): | |
r = idx / cols | |
c = idx - r * cols | |
return r, c | |
def sub2ind(r, c, cols): | |
idx = r * cols + c | |
return idx | |
def edgeGuidedSampling(inputs, targets, edges_img, thetas_img, masks, h, w): | |
# find edges | |
edges_max = edges_img.max() | |
edges_min = edges_img.min() | |
edges_mask = edges_img.ge(edges_max*0.1) | |
edges_loc = edges_mask.nonzero() | |
thetas_edge = torch.masked_select(thetas_img, edges_mask) | |
minlen = thetas_edge.size()[0] | |
# find anchor points (i.e, edge points) | |
sample_num = minlen | |
index_anchors = torch.randint(0, minlen, (sample_num,), dtype=torch.long).cuda() | |
theta_anchors = torch.gather(thetas_edge, 0, index_anchors) | |
row_anchors, col_anchors = ind2sub(edges_loc[index_anchors].squeeze(1), w) | |
## compute the coordinates of 4-points, distances are from [2, 30] | |
distance_matrix = torch.randint(3, 20, (4,sample_num)).cuda() | |
pos_or_neg = torch.ones(4,sample_num).cuda() | |
pos_or_neg[:2,:] = -pos_or_neg[:2,:] | |
distance_matrix = distance_matrix.float() * pos_or_neg | |
col = col_anchors.unsqueeze(0).expand(4, sample_num).long() + torch.round(distance_matrix.double() * torch.cos(theta_anchors).unsqueeze(0)).long() | |
row = row_anchors.unsqueeze(0).expand(4, sample_num).long() + torch.round(distance_matrix.double() * torch.sin(theta_anchors).unsqueeze(0)).long() | |
# constrain 0=<c<=w, 0<=r<=h | |
# Note: index should minus 1 | |
col[col<0] = 0 | |
col[col>w-1] = w-1 | |
row[row<0] = 0 | |
row[row>h-1] = h-1 | |
# a-b, b-c, c-d | |
a = sub2ind(row[0,:], col[0,:], w) | |
b = sub2ind(row[1,:], col[1,:], w) | |
c = sub2ind(row[2,:], col[2,:], w) | |
d = sub2ind(row[3,:], col[3,:], w) | |
A = torch.cat((a,b,c), 0) | |
B = torch.cat((b,c,d), 0) | |
inputs_A = inputs[:, A] | |
inputs_B = inputs[:, B] | |
targets_A = targets[:, A] | |
targets_B = targets[:, B] | |
masks_A = torch.gather(masks, 0, A.long()) | |
masks_B = torch.gather(masks, 0, B.long()) | |
return inputs_A, inputs_B, targets_A, targets_B, masks_A, masks_B, sample_num, row, col | |
########### | |
# RANDOM SAMPLING | |
# input: | |
# inputs[i,:], targets[i, :], masks[i, :], self.mask_value, self.point_pairs | |
# return: | |
# inputs_A, inputs_B, targets_A, targets_B, consistent_masks_A, consistent_masks_B | |
########### | |
def randomSamplingNormal(inputs, targets, masks, sample_num): | |
# find A-B point pairs from predictions | |
num_effect_pixels = torch.sum(masks) | |
shuffle_effect_pixels = torch.randperm(num_effect_pixels).cuda() | |
valid_inputs = inputs[:, masks] | |
valid_targes = targets[:, masks] | |
inputs_A = valid_inputs[:, shuffle_effect_pixels[0:sample_num*2:2]] | |
inputs_B = valid_inputs[:, shuffle_effect_pixels[1:sample_num*2:2]] | |
# find corresponding pairs from GT | |
targets_A = valid_targes[:, shuffle_effect_pixels[0:sample_num*2:2]] | |
targets_B = valid_targes[:, shuffle_effect_pixels[1:sample_num*2:2]] | |
if inputs_A.shape[1] != inputs_B.shape[1]: | |
num_min = min(targets_A.shape[1], targets_B.shape[1]) | |
inputs_A = inputs_A[:, :num_min] | |
inputs_B = inputs_B[:, :num_min] | |
targets_A = targets_A[:, :num_min] | |
targets_B = targets_B[:, :num_min] | |
return inputs_A, inputs_B, targets_A, targets_B | |
class EdgeguidedNormalRegressionLoss(nn.Module): | |
def __init__(self, point_pairs=10000, cos_theta1=0.3, cos_theta2=0.95, cos_theta3=0.5, cos_theta4=0.86, mask_value=-1e-8, max_threshold=10.1): | |
super(EdgeguidedNormalRegressionLoss, self).__init__() | |
self.point_pairs = point_pairs # number of point pairs | |
self.mask_value = mask_value | |
self.max_threshold = max_threshold | |
self.cos_theta1 = cos_theta1 # 75 degree | |
self.cos_theta2 = cos_theta2 # 10 degree | |
self.cos_theta3 = cos_theta3 # 60 degree | |
self.cos_theta4 = cos_theta4 # 30 degree | |
self.kernel = torch.tensor(np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]], dtype=np.float32), requires_grad=False)[None, None, :, :].cuda() | |
def scale_shift_pred_depth(self, pred, gt): | |
b, c, h, w = pred.shape | |
mask = (gt > self.mask_value) & (gt < self.max_threshold) # [b, c, h, w] | |
EPS = 1e-6 * torch.eye(2, dtype=pred.dtype, device=pred.device) | |
scale_shift_batch = [] | |
ones_img = torch.ones((1, h, w), dtype=pred.dtype, device=pred.device) | |
for i in range(b): | |
mask_i = mask[i, ...] | |
pred_valid_i = pred[i, ...][mask_i] | |
ones_i = ones_img[mask_i] | |
pred_valid_ones_i = torch.stack((pred_valid_i, ones_i), dim=0) # [c+1, n] | |
A_i = torch.matmul(pred_valid_ones_i, pred_valid_ones_i.permute(1, 0)) # [2, 2] | |
A_inverse = torch.inverse(A_i + EPS) | |
gt_i = gt[i, ...][mask_i] | |
B_i = torch.matmul(pred_valid_ones_i, gt_i)[:, None] # [2, 1] | |
scale_shift_i = torch.matmul(A_inverse, B_i) # [2, 1] | |
scale_shift_batch.append(scale_shift_i) | |
scale_shift_batch = torch.stack(scale_shift_batch, dim=0) # [b, 2, 1] | |
ones = torch.ones_like(pred) | |
pred_ones = torch.cat((pred, ones), dim=1) # [b, 2, h, w] | |
pred_scale_shift = torch.matmul(pred_ones.permute(0, 2, 3, 1).reshape(b, h * w, 2), scale_shift_batch) # [b, h*w, 1] | |
pred_scale_shift = pred_scale_shift.permute(0, 2, 1).reshape((b, c, h, w)) | |
return pred_scale_shift | |
def getEdge(self, images): | |
n,c,h,w = images.size() | |
a = torch.Tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).cuda().view((1,1,3,3)).repeat(1, 1, 1, 1) | |
b = torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).cuda().view((1,1,3,3)).repeat(1, 1, 1, 1) | |
if c == 3: | |
gradient_x = F.conv2d(images[:,0,:,:].unsqueeze(1), a) | |
gradient_y = F.conv2d(images[:,0,:,:].unsqueeze(1), b) | |
else: | |
gradient_x = F.conv2d(images, a) | |
gradient_y = F.conv2d(images, b) | |
edges = torch.sqrt(torch.pow(gradient_x,2)+ torch.pow(gradient_y,2)) | |
edges = F.pad(edges, (1,1,1,1), "constant", 0) | |
thetas = torch.atan2(gradient_y, gradient_x) | |
thetas = F.pad(thetas, (1,1,1,1), "constant", 0) | |
return edges, thetas | |
def getNormalEdge(self, normals): | |
n,c,h,w = normals.size() | |
a = torch.Tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).cuda().view((1,1,3,3)).repeat(3, 1, 1, 1) | |
b = torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).cuda().view((1,1,3,3)).repeat(3, 1, 1, 1) | |
gradient_x = torch.abs(F.conv2d(normals, a, groups=c)) | |
gradient_y = torch.abs(F.conv2d(normals, b, groups=c)) | |
gradient_x = gradient_x.mean(dim=1, keepdim=True) | |
gradient_y = gradient_y.mean(dim=1, keepdim=True) | |
edges = torch.sqrt(torch.pow(gradient_x,2)+ torch.pow(gradient_y,2)) | |
edges = F.pad(edges, (1,1,1,1), "constant", 0) | |
thetas = torch.atan2(gradient_y, gradient_x) | |
thetas = F.pad(thetas, (1,1,1,1), "constant", 0) | |
return edges, thetas | |
def forward(self, pred_depths, gt_depths, images, focal_length): | |
""" | |
inputs and targets: surface normal image | |
images: rgb images | |
""" | |
masks = gt_depths > self.mask_value | |
#pred_depths_ss = self.scale_shift_pred_depth(pred_depths, gt_depths) | |
inputs = surface_normal_from_depth(pred_depths, focal_length, valid_mask=masks) | |
targets = surface_normal_from_depth(gt_depths, focal_length, valid_mask=masks) | |
# find edges from RGB | |
edges_img, thetas_img = self.getEdge(images) | |
# find edges from normals | |
edges_normal, thetas_normal = self.getNormalEdge(targets) | |
mask_img_border = torch.ones_like(edges_normal) # normals on the borders | |
mask_img_border[:, :, 5:-5, 5:-5] = 0 | |
edges_normal[mask_img_border.bool()] = 0 | |
# find edges from depth | |
edges_depth, _ = self.getEdge(gt_depths) | |
edges_depth_mask = edges_depth.ge(edges_depth.max() * 0.1) | |
edges_mask_dilate = torch.clamp(torch.nn.functional.conv2d(edges_depth_mask.float(), self.kernel, padding=(1, 1)), 0, | |
1).bool() | |
edges_normal[edges_mask_dilate] = 0 | |
edges_img[edges_mask_dilate] = 0 | |
#============================= | |
n,c,h,w = targets.size() | |
inputs = inputs.contiguous().view(n, c, -1).double() | |
targets = targets.contiguous().view(n, c, -1).double() | |
masks = masks.contiguous().view(n, -1) | |
edges_img = edges_img.contiguous().view(n, -1).double() | |
thetas_img = thetas_img.contiguous().view(n, -1).double() | |
edges_normal = edges_normal.view(n, -1).double() | |
thetas_normal = thetas_normal.view(n, -1).double() | |
# initialization | |
loss = torch.DoubleTensor([0.0]).cuda() | |
for i in range(n): | |
# Edge-Guided sampling | |
inputs_A, inputs_B, targets_A, targets_B, masks_A, masks_B, sample_num, row_img, col_img = edgeGuidedSampling(inputs[i,:], targets[i, :], edges_img[i], thetas_img[i], masks[i, :], h, w) | |
normal_inputs_A, normal_inputs_B, normal_targets_A, normal_targets_B, normal_masks_A, normal_masks_B, normal_sample_num, row_normal, col_normal = edgeGuidedSampling(inputs[i,:], targets[i, :], edges_normal[i], thetas_normal[i], masks[i, :], h, w) | |
# Combine EGS + EGNS | |
inputs_A = torch.cat((inputs_A, normal_inputs_A), 1) | |
inputs_B = torch.cat((inputs_B, normal_inputs_B), 1) | |
targets_A = torch.cat((targets_A, normal_targets_A), 1) | |
targets_B = torch.cat((targets_B, normal_targets_B), 1) | |
masks_A = torch.cat((masks_A, normal_masks_A), 0) | |
masks_B = torch.cat((masks_B, normal_masks_B), 0) | |
# consider forward-backward consistency checking, i.e, only compute losses of point pairs with valid GT | |
consistency_mask = masks_A & masks_B | |
#GT ordinal relationship | |
target_cos = torch.abs(torch.sum(targets_A * targets_B, dim=0)) | |
input_cos = torch.abs(torch.sum(inputs_A * inputs_B, dim=0)) | |
# ranking regression | |
#loss += torch.mean(torch.abs(target_cos[consistency_mask] - input_cos[consistency_mask])) | |
# Ranking for samples | |
mask_cos75 = target_cos < self.cos_theta1 | |
mask_cos10 = target_cos > self.cos_theta2 | |
# Regression for samples | |
loss += torch.sum(torch.abs(target_cos[mask_cos75 & consistency_mask] - input_cos[mask_cos75 & consistency_mask])) / (torch.sum(mask_cos75 & consistency_mask)+1e-8) | |
loss += torch.sum(torch.abs(target_cos[mask_cos10 & consistency_mask] - input_cos[mask_cos10 & consistency_mask])) / (torch.sum(mask_cos10 & consistency_mask)+1e-8) | |
# Random Sampling regression | |
random_sample_num = torch.sum(mask_cos10 & consistency_mask) + torch.sum(torch.sum(mask_cos75 & consistency_mask)) | |
random_inputs_A, random_inputs_B, random_targets_A, random_targets_B = randomSamplingNormal(inputs[i,:], targets[i, :], masks[i, :], random_sample_num) | |
#GT ordinal relationship | |
random_target_cos = torch.abs(torch.sum(random_targets_A * random_targets_B, dim=0)) | |
random_input_cos = torch.abs(torch.sum(random_inputs_A * random_inputs_B, dim=0)) | |
loss += torch.sum(torch.abs(random_target_cos - random_input_cos)) / (random_target_cos.shape[0] + 1e-8) | |
if loss[0] != 0: | |
return loss[0].float() / n | |
else: | |
return pred_depths.sum() * 0.0 | |