# Stable-X
# Update code
# 9dfa4de
# raw
# history blame
# 7.81 kB
# A reimplemented version in public environments by Xiao Fu and Mu Hu
import pickle
import os
import h5py
import numpy as np
import cv2
import torch
import torch.nn as nn
import glob
def init_image_coor(height, width, device=None):
    """Build pixel-coordinate grids measured from the image centre.

    Args:
        height: number of image rows.
        width: number of image columns.
        device: torch device on which the grids are created. ``None``
            (the default) selects CUDA when it is available, which is what
            this function always did, and falls back to the CPU otherwise
            instead of raising.

    Returns:
        Tuple ``(u_u0, v_v0)`` of float32 tensors with shape
        ``[1, height, width]`` where ``u_u0[0, r, c] == c - width / 2`` and
        ``v_v0[0, r, c] == r - height / 2``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Create the index ramps directly on the target device; this avoids the
    # NumPy allocation and the host-to-device copy.
    col_offsets = torch.arange(width, dtype=torch.float32, device=device) - width / 2.0
    row_offsets = torch.arange(height, dtype=torch.float32, device=device) - height / 2.0

    u_u0 = col_offsets.repeat(height, 1).unsqueeze(0)                # [1, h, w]
    v_v0 = row_offsets.unsqueeze(1).repeat(1, width).unsqueeze(0)    # [1, h, w]
    return u_u0, v_v0
def depth_to_xyz(depth, focal_length):
    """Back-project a depth map into per-pixel 3-D points.

    Pixel coordinates are taken relative to the image centre (see
    ``init_image_coor``), i.e. the principal point is assumed to be at
    ``(width / 2, height / 2)``.

    Args:
        depth: depth tensor unpacked as ``[b, c, h, w]``.
        focal_length: indexable; entry 0 divides the horizontal coordinate and
            entry 1 the vertical one (presumably ``(fx, fy)`` - confirm with
            the caller). Each entry must broadcast against ``depth``.

    Returns:
        Tensor ``[b, h, w, 3 * c]`` holding ``x, y, z`` per pixel
        (``[b, h, w, 3]`` for single-channel depth).
    """
    b, c, h, w = depth.shape
    u_u0, v_v0 = init_image_coor(h, w)
    # The grids are created on a device chosen by init_image_coor; align them
    # with `depth` so CPU inputs and non-default GPUs do not hit a
    # device-mismatch error in the products below.
    u_u0 = u_u0.to(depth.device)
    v_v0 = v_v0.to(depth.device)

    x = u_u0 * depth / focal_length[0]
    y = v_v0 * depth / focal_length[1]
    z = depth
    pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1)  # [b, h, w, c]
    return pw
def get_surface_normal(xyz, patch_size=5):
    """Estimate per-pixel surface normals by least-squares plane fitting.

    For every pixel, the points inside a ``patch_size`` x ``patch_size`` window
    (zero padded at the image border) are fitted with a plane ``n . p = 1`` by
    solving the regularised normal equations
    ``(sum(p p^T) + 1e-6 * I) n = sum(p)``. The solution is scaled to unit
    length and flipped so that ``n . p <= 0`` at the pixel's own point.

    Args:
        xyz: point map of shape ``[1, h, w, 3]``. Computation runs on the
            device and in the dtype of this tensor.
        patch_size: side length of the square window.
            NOTE(review): the ``patch_size // 2`` padding keeps the ``[h, w]``
            size only for odd values; even values are presumably unsupported.

    Returns:
        Tensor ``[h, w, 3, 1]`` of normals. A zero-length solution is divided
        by zero (no epsilon is applied), so such pixels come out as NaN.
    """
    # Each coordinate plane as a [1, 1, h, w] image for conv2d.
    px, py, pz = (axis.unsqueeze(0) for axis in torch.unbind(xyz, dim=3))

    box = torch.ones((1, 1, patch_size, patch_size), dtype=xyz.dtype, device=xyz.device)
    pad = patch_size // 2

    def window_sum(plane):
        # Sum of `plane` over the window around every pixel -> [h, w].
        return nn.functional.conv2d(plane, weight=box, padding=pad)[0, 0]

    xx, xy, xz = window_sum(px * px), window_sum(px * py), window_sum(px * pz)
    yy, yz, zz = window_sum(py * py), window_sum(py * pz), window_sum(pz * pz)

    # A^T A per pixel, [h, w, 3, 3], plus a small ridge term so that
    # degenerate windows stay solvable.
    ata = torch.stack([xx, xy, xz, xy, yy, yz, xz, yz, zz], dim=-1)
    ata = ata.reshape(ata.size(0), ata.size(1), 3, 3)
    ata = ata + 1e-6 * torch.eye(3, device=ata.device, dtype=ata.dtype)

    # A^T 1 per pixel, [h, w, 3, 1].
    at1 = torch.stack([window_sum(px), window_sum(py), window_sum(pz)], dim=-1).unsqueeze(3)

    # The systems are solved tile by tile on a patch_num x patch_num grid
    # (presumably to bound the batch handed to the solver - TODO confirm).
    # Every pixel's 3x3 system is independent, so tiles need no overlap.
    patch_num = 4

    def tile_bounds(length):
        # The last tile runs to the end so no trailing rows/columns are
        # skipped when `length` is not a multiple of patch_num.
        step = length // patch_num
        return [
            (i * step, (i + 1) * step if i < patch_num - 1 else length)
            for i in range(patch_num)
        ]

    n_img = torch.zeros_like(at1)
    for top, bottom in tile_bounds(at1.size(0)):
        for left, right in tile_bounds(at1.size(1)):
            if top == bottom or left == right:
                continue  # empty tile (image smaller than the tile grid)
            n_img[top:bottom, left:right] = torch.linalg.solve(
                ata[top:bottom, left:right], at1[top:bottom, left:right]
            )

    n_img_L2 = torch.sqrt(torch.sum(n_img ** 2, dim=2, keepdim=True))
    n_img_norm = n_img / n_img_L2

    # re-orient normals consistently: flip those with a positive dot product
    # against their own 3-D point.
    orient_mask = torch.sum(n_img_norm[..., 0] * xyz[0], dim=2) > 0
    n_img_norm[orient_mask] *= -1
    return n_img_norm
def get_surface_normalv2(xyz, patch_size=5):
    """
    Estimate surface normals from cross products of point differences.

    xyz: xyz coordinates, unpacked as [b, h, w, c] (the cross product needs c == 3)
    patch: [p1, p2, p3,
            p4, p5, p6,
            p7, p8, p9]
    Two normals are formed per pixel as (p4 - p6) x (p2 - p8), once from the
    outermost taps of the zero-padded patch and once with the left/top taps
    moved to padded index 1. Both are oriented, normalised, summed, then
    normalised and oriented again.
    return: normal [h, w, 3, b]
    """
    batch, rows, cols, chans = xyz.shape
    radius = patch_size // 2
    far = patch_size - 1

    # Zero-padded copy so border pixels have taps on every side.
    padded = torch.zeros(
        (batch, rows + far, cols + far, chans), dtype=xyz.dtype, device=xyz.device
    )
    padded[:, radius:-radius, radius:-radius, :] = xyz

    centre_rows = slice(radius, radius + rows)
    centre_cols = slice(radius, radius + cols)

    def flip_against_points(vec):
        # re-orient normals consistently: negate vectors whose dot product
        # with the pixel's own point is positive (in place).
        same_side = torch.sum(vec * xyz, dim=3) > 0
        vec[same_side] *= -1
        return vec

    def to_unit(vec):
        length = torch.sqrt(torch.sum(vec ** 2, dim=3, keepdim=True))
        return vec / (length + 1e-8)

    # Outer taps: p4 - p6 and p2 - p8 at the patch extremes.
    outer_h = padded[:, centre_rows, :cols, :] - padded[:, centre_rows, -cols:, :]
    outer_v = padded[:, :rows, centre_cols, :] - padded[:, -rows:, centre_cols, :]

    # "Inner" taps: left/top start at padded index 1.
    # NOTE(review): the right/bottom taps start at patch_size - 1, the same
    # slices as the outer right/bottom taps - confirm this is intended.
    inner_h = padded[:, centre_rows, 1:cols + 1, :] - padded[:, centre_rows, far:far + cols, :]
    inner_v = padded[:, 1:rows + 1, centre_cols, :] - padded[:, far:far + rows, centre_cols, :]

    normal_inner = flip_against_points(torch.cross(inner_h, inner_v, dim=3))
    normal_outer = flip_against_points(torch.cross(outer_h, outer_v, dim=3))

    # average 2 norms
    blended = to_unit(to_unit(normal_inner) + to_unit(normal_outer))
    blended = flip_against_points(blended)

    return blended.permute((1, 2, 3, 0))  # [h, w, c, b]
def surface_normal_from_depth(depth, focal_length, valid_mask=None):
    """Compute a surface-normal map from a batch of depth maps.

    The depth is smoothed with a 3x3 average pool, back-projected with
    ``depth_to_xyz`` and passed sample by sample to ``get_surface_normal``.

    Args:
        depth: depth map, ``[b, c, h, w]``.
        focal_length: 1-D tensor; it is reshaped to ``[n, 1, 1, 1]`` and
            ``depth_to_xyz`` then reads entries 0 and 1 (presumably
            ``(fx, fy)`` shared by the whole batch - TODO confirm).
        valid_mask: optional boolean tensor, True where the depth is valid.
            It is inverted and repeated 3x along dim 1, so it is presumably
            ``[b, 1, h, w]``. Normals at invalid pixels are set to 0.

    Returns:
        Tensor ``[b, 3, h, w]`` of surface normals.
    """
    # Unpacking four names also rejects inputs that are not 4-D.
    b, _, _, _ = depth.shape
    focal_length = focal_length[:, None, None, None]

    depth_filter = nn.functional.avg_pool2d(depth, kernel_size=3, stride=1, padding=1)
    xyz = depth_to_xyz(depth_filter, focal_length)

    # get_surface_normal handles one [1, h, w, 3] sample and returns [h, w, 3, 1].
    sn_batch = [get_surface_normal(xyz[i:i + 1]) for i in range(b)]
    sn_batch = torch.cat(sn_batch, dim=3).permute((3, 2, 0, 1))  # [b, c, h, w]

    # Identity test: avoids relying on how a tensor compares against None.
    if valid_mask is not None:
        mask_invalid = (~valid_mask).repeat(1, 3, 1, 1)
        sn_batch[mask_invalid] = 0.0
    return sn_batch