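"""
@Date: 2021/11/06
@description: Helpers for deriving edge directions, edge angles and angle
gradients from a depth tensor, and for rendering them as colour strips.
"""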
|
import cv2 |
|
import numpy as np |
|
import torch |
|
import matplotlib.pyplot as plt |
|
|
|
from utils.conversion import depth2xyz |
|
|
|
|
|
def convert_img(value, h, need_nor=True, cmap=None):
    """
    Render a tensor of values as a colour strip image.

    :param value: torch tensor of values; presumably 1-D [patch_num] (TODO confirm
                  against callers). It is cloned, so the input is never modified.
    :param h: number of times the row of values is repeated along axis 0 (cast to int)
    :param need_nor: if True, min-max normalise the values to [0, 1] first
    :param cmap: None -> grey image (value copied into 3 channels);
                 cv2.COLORMAP_PLASMA -> OpenCV plasma colour map, returned as RGB;
                 'HSV' -> value mapped to hue with saturation and brightness at 255;
                 any other value -> the repeated strip is returned with no channel axis
    :return: numpy array; for 1-D input and a supported cmap the shape is [h, patch_num, 3]
    """
    value = value.clone().detach().cpu().numpy()[None]
    if need_nor:
        # numpy rejects in-place true division on integer (and bool) arrays.
        if not np.issubdtype(value.dtype, np.floating):
            value = value.astype(np.float32)
        value -= value.min()
        # The minimum is 0 after the shift, so the maximum equals the full range.
        span = value.max()
        # A constant input has zero range: skip the division instead of producing 0/0 = NaN.
        if span > 0:
            value /= span
    grad_img = value.repeat(int(h), axis=0)

    if cmap is None:
        grad_img = grad_img[..., np.newaxis].repeat(3, axis=-1)
    elif cmap == cv2.COLORMAP_PLASMA:
        grad_img = cv2.applyColorMap((grad_img * 255).astype(np.uint8), colormap=cmap)
        # Reverse the channel axis (OpenCV colour maps come back as BGR).
        grad_img = grad_img[..., ::-1]
        grad_img = grad_img.astype(np.float32) / 255.0
    elif cmap == 'HSV':
        # Quantise to 3 decimals before building the HSV triplet.
        grad_img = np.round(grad_img * 1000) / 1000.0
        grad_img = grad_img[..., np.newaxis].repeat(3, axis=-1)
        # Channel 0 is hue, scaled by 180; channels 1 and 2 are fixed at 255.
        grad_img[..., 0] = grad_img[..., 0] * 180
        grad_img[..., 1] = 255
        grad_img[..., 2] = 255
        grad_img = grad_img.astype(np.uint8)
        grad_img = cv2.cvtColor(grad_img, cv2.COLOR_HSV2RGB)
        grad_img = grad_img.astype(np.float32) / 255.0
    return grad_img
|
|
|
|
|
def show_grad(depth, grad_conv, h=5, show=False):
    """
    Build a stacked strip image of depth, edge angle and angle gradient.

    :param depth: [patch_num]
    :param grad_conv: callable forwarded to get_all to compute the angle gradient
    :param h: number of rows in each of the three strips
    :param show: if True, display the image with matplotlib
    :return: the depth, angle and gradient strips concatenated along axis 0
    """
    # get_all works on batches, so add a leading axis; the direction output is unused here.
    _, angle, grad = get_all(depth[None], grad_conv)

    strips = [
        convert_img(depth, h, cmap=cv2.COLORMAP_PLASMA),
        convert_img(angle[0], h, cmap='HSV'),
        convert_img(grad[0], h),
    ]
    img = np.concatenate(strips, axis=0)

    if show:
        plt.imshow(img)
        plt.show()
    return img
|
|
|
|
|
def get_grad(direction):
    """
    Angle between the following and the preceding direction at every position.

    Neighbours are taken with a circular shift along dim 1, so the first and
    last positions are treated as adjacent.

    :param direction: [b patch_num 2] (the last axis is indexed at 0 and 1)
    :return: [b patch_num] result of acos on the clamped dot product
    """
    following = direction.roll(-1, dims=1)
    preceding = direction.roll(1, dims=1)
    dot = following[..., 0] * preceding[..., 0] + following[..., 1] * preceding[..., 1]
    # Keep the value strictly inside (-1, 1) before acos.
    lower, upper = -1 + 1e-6, 1 - 1e-6
    return torch.acos(torch.clamp(dot, lower, upper))
|
|
|
|
|
def get_grad2(angle, grad_conv):
    """
    Apply grad_conv to the circularly padded, shifted sine of the angles.

    :param angle: [b patch_num]
    :param grad_conv: callable applied to a [b 1 patch_num+2] tensor
                      (presumably a 1-D convolution with kernel size 3 -- TODO confirm)
    :return: grad_conv output flattened to [b -1]
    """
    # sin(angle) + 1 moves the signal into the range [0, 2].
    signal = torch.sin(angle) + 1

    # Wrap one element around on each side of the last axis.
    last, first = signal[..., -1:], signal[..., :1]
    padded = torch.cat([last, signal, first], dim=-1)

    # Insert a channel axis at dim 1 before handing over to grad_conv.
    grad = grad_conv(padded.unsqueeze(1))
    return grad.reshape(padded.shape[0], -1)
|
|
|
|
|
def get_edge_angle(direction):
    """
    Convert 2-D direction vectors to angles with atan2.

    :param direction: [b patch_num 2]
    :return: atan2(direction[..., 1], direction[..., 0]), one value per vector
    """
    first = direction[..., 0]
    second = direction[..., 1]
    return torch.atan2(second, first)
|
|
|
|
|
def get_edge_direction(depth):
    """
    Unit vector from each point to its successor along dim 1.

    :param depth: tensor passed to depth2xyz; presumably [b patch_num] (TODO confirm)
    :return: difference between each point and the next one (circular along
             dim 1), divided by its L2 norm over the last axis
    """
    # Keep every second component of the depth2xyz output (indices 0, 2, ...).
    points = depth2xyz(depth)[..., ::2]
    edge = points.roll(-1, dims=1) - points
    length = edge.norm(p=2, dim=-1, keepdim=True)
    return edge / length
|
|
|
|
|
def get_all(depth, grad_conv):
    """
    Run the full chain: edge direction -> edge angle -> angle gradient.

    :param depth: [b patch_num]
    :param grad_conv: callable forwarded to get_grad2
    :return: tuple (direction, angle, angle_grad)
    """
    direction = get_edge_direction(depth)
    angle = get_edge_angle(direction)
    return direction, angle, get_grad2(angle, grad_conv)
|
|