Realcat
fix: eloftr
63f3cf2
raw
history blame
24.7 kB
# %BANNER_BEGIN%
# ---------------------------------------------------------------------
# %COPYRIGHT_BEGIN%
#
# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
#
# Unpublished Copyright (c) 2020
# Magic Leap, Inc., All Rights Reserved.
#
# NOTICE: All information contained herein is, and remains the property
# of COMPANY. The intellectual and technical concepts contained herein
# are proprietary to COMPANY and may be covered by U.S. and Foreign
# Patents, patents in process, and are protected by trade secret or
# copyright law. Dissemination of this information or reproduction of
# this material is strictly forbidden unless prior written permission is
# obtained from COMPANY. Access to the source code contained herein is
# hereby forbidden to anyone except current COMPANY employees, managers
# or contractors who have executed Confidentiality and Non-disclosure
# agreements explicitly covering such access.
#
# The copyright notice above does not evidence any actual or intended
# publication or disclosure of this source code, which includes
# information that is confidential and/or proprietary, and is a trade
# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
#
# %COPYRIGHT_END%
# ----------------------------------------------------------------------
# %AUTHORS_BEGIN%
#
# Originating Authors: Paul-Edouard Sarlin
#
# %AUTHORS_END%
# --------------------------------------------------------------------*/
# %BANNER_END%
from pathlib import Path
import torch
from torch import nn
import numpy as np
import cv2
import torch.nn.functional as F
def simple_nms(scores, nms_radius: int):
""" Fast Non-maximum suppression to remove nearby points """
assert (nms_radius >= 0)
def max_pool(x):
return torch.nn.functional.max_pool2d(
x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius)
zeros = torch.zeros_like(scores)
max_mask = scores == max_pool(scores)
for _ in range(2):
supp_mask = max_pool(max_mask.float()) > 0
supp_scores = torch.where(supp_mask, zeros, scores)
new_max_mask = supp_scores == max_pool(supp_scores)
max_mask = max_mask | (new_max_mask & (~supp_mask))
return torch.where(max_mask, scores, zeros)
def remove_borders(keypoints, scores, border: int, height: int, width: int):
""" Removes keypoints too close to the border """
mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border))
mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border))
mask = mask_h & mask_w
return keypoints[mask], scores[mask]
def top_k_keypoints(keypoints, scores, k: int):
if k >= len(keypoints):
return keypoints, scores
scores, indices = torch.topk(scores, k, dim=0)
return keypoints[indices], scores
def sample_descriptors(keypoints, descriptors, s: int = 8):
""" Interpolate descriptors at keypoint locations """
b, c, h, w = descriptors.shape
keypoints = keypoints - s / 2 + 0.5
keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
).to(keypoints)[None]
keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {}
descriptors = torch.nn.functional.grid_sample(
descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args)
descriptors = torch.nn.functional.normalize(
descriptors.reshape(b, c, -1), p=2, dim=1)
return descriptors
class SuperPoint(nn.Module):
"""SuperPoint Convolutional Detector and Descriptor
SuperPoint: Self-Supervised Interest Point Detection and
Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew
Rabinovich. In CVPRW, 2019. https://arxiv.org/abs/1712.07629
"""
default_config = {
'descriptor_dim': 256,
'nms_radius': 3,
'keypoint_threshold': 0.001,
'max_keypoints': -1,
'min_keypoints': 32,
'remove_borders': 4,
}
def __init__(self, config):
super().__init__()
self.config = {**self.default_config, **config}
self.relu = nn.ReLU(inplace=True)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) # 64
self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) # 64
self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) # 128
self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) # 128
self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) # 256
self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) # 256
self.convDb = nn.Conv2d(
c5, self.config['descriptor_dim'],
kernel_size=1, stride=1, padding=0)
# path = Path(__file__).parent / 'weights/superpoint_v1.pth'
path = config['weight_path']
self.load_state_dict(torch.load(str(path), map_location='cpu'), strict=True)
mk = self.config['max_keypoints']
if mk == 0 or mk < -1:
raise ValueError('\"max_keypoints\" must be positive or \"-1\"')
print('Loaded SuperPoint model')
def extract_global(self, data):
# Shared Encoder
x0 = self.relu(self.conv1a(data['image']))
x0 = self.relu(self.conv1b(x0))
x0 = self.pool(x0)
x1 = self.relu(self.conv2a(x0))
x1 = self.relu(self.conv2b(x1))
x1 = self.pool(x1)
x2 = self.relu(self.conv3a(x1))
x2 = self.relu(self.conv3b(x2))
x2 = self.pool(x2)
x3 = self.relu(self.conv4a(x2))
x3 = self.relu(self.conv4b(x3))
x4 = self.relu(self.convDa(x3))
# print('ex_g: ', x0.shape, x1.shape, x2.shape, x3.shape, x4.shape)
return [x0, x1, x2, x3, x4]
def extract_local_global(self, data):
# Shared Encoder
b, ic, ih, iw = data['image'].shape
x0 = self.relu(self.conv1a(data['image']))
x0 = self.relu(self.conv1b(x0))
x0 = self.pool(x0)
x1 = self.relu(self.conv2a(x0))
x1 = self.relu(self.conv2b(x1))
x1 = self.pool(x1)
x2 = self.relu(self.conv3a(x1))
x2 = self.relu(self.conv3b(x2))
x2 = self.pool(x2)
x3 = self.relu(self.conv4a(x2))
x3 = self.relu(self.conv4b(x3))
# Compute the dense keypoint scores
cPa = self.relu(self.convPa(x3))
score = self.convPb(cPa)
score = torch.nn.functional.softmax(score, 1)[:, :-1]
# print(scores.shape)
b, _, h, w = score.shape
score = score.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
score = score.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
score = torch.nn.functional.interpolate(score.unsqueeze(1), size=(ih, iw), align_corners=True,
mode='bilinear')
score = score.squeeze(1)
# extract kpts
nms_scores = simple_nms(scores=score, nms_radius=self.config['nms_radius'])
keypoints = [
torch.nonzero(s >= self.config['keypoint_threshold'])
for s in nms_scores]
scores = [s[tuple(k.t())] for s, k in zip(nms_scores, keypoints)]
if len(scores[0]) <= self.config['min_keypoints']:
keypoints = [
torch.nonzero(s >= self.config['keypoint_threshold'] * 0.5)
for s in nms_scores]
scores = [s[tuple(k.t())] for s, k in zip(nms_scores, keypoints)]
# Discard keypoints near the image borders
keypoints, scores = list(zip(*[
remove_borders(k, s, self.config['remove_borders'], ih, iw)
for k, s in zip(keypoints, scores)]))
# Keep the k keypoints with the highest score
if self.config['max_keypoints'] >= 0:
keypoints, scores = list(zip(*[
top_k_keypoints(k, s, self.config['max_keypoints'])
for k, s in zip(keypoints, scores)]))
# Convert (h, w) to (x, y)
keypoints = [torch.flip(k, [1]).float() for k in keypoints]
# Compute the dense descriptors
cDa = self.relu(self.convDa(x3))
desc_map = self.convDb(cDa)
desc_map = torch.nn.functional.normalize(desc_map, p=2, dim=1)
descriptors = [sample_descriptors(k[None], d[None], 8)[0]
for k, d in zip(keypoints, desc_map)]
return {
'score_map': score,
'desc_map': desc_map,
'mid_features': cDa, # 256
'global_descriptors': [x0, x1, x2, x3, cDa],
'keypoints': keypoints,
'scores': scores,
'descriptors': descriptors,
}
def sample(self, score_map, semi_descs, kpts, s=8, norm_desc=True):
# print('sample: ', score_map.shape, semi_descs.shape, kpts.shape)
b, c, h, w = semi_descs.shape
norm_kpts = kpts - s / 2 + 0.5
norm_kpts = norm_kpts / torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
).to(norm_kpts)[None]
norm_kpts = norm_kpts * 2 - 1
# args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {}
descriptors = torch.nn.functional.grid_sample(
semi_descs, norm_kpts.view(b, 1, -1, 2), mode='bilinear', align_corners=True)
if norm_desc:
descriptors = torch.nn.functional.normalize(
descriptors.reshape(b, c, -1), p=2, dim=1)
else:
descriptors = descriptors.reshape(b, c, -1)
# print('max: ', torch.min(kpts[:, 1].long()), torch.max(kpts[:, 1].long()), torch.min(kpts[:, 0].long()),
# torch.max(kpts[:, 0].long()))
scores = score_map[0, kpts[:, 1].long(), kpts[:, 0].long()]
return scores, descriptors.squeeze(0)
def extract(self, data):
""" Compute keypoints, scores, descriptors for image """
# Shared Encoder
x = self.relu(self.conv1a(data['image']))
x = self.relu(self.conv1b(x))
x = self.pool(x)
x = self.relu(self.conv2a(x))
x = self.relu(self.conv2b(x))
x = self.pool(x)
x = self.relu(self.conv3a(x))
x = self.relu(self.conv3b(x))
x = self.pool(x)
x = self.relu(self.conv4a(x))
x = self.relu(self.conv4b(x))
# Compute the dense keypoint scores
cPa = self.relu(self.convPa(x))
scores = self.convPb(cPa)
scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
b, _, h, w = scores.shape
scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
# Compute the dense descriptors
cDa = self.relu(self.convDa(x))
descriptors = self.convDb(cDa)
descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
return scores, descriptors
def det(self, image):
""" Compute keypoints, scores, descriptors for image """
# Shared Encoder
x = self.relu(self.conv1a(image))
x = self.relu(self.conv1b(x))
x = self.pool(x)
x = self.relu(self.conv2a(x))
x = self.relu(self.conv2b(x))
x = self.pool(x)
x = self.relu(self.conv3a(x))
x = self.relu(self.conv3b(x))
x = self.pool(x)
x = self.relu(self.conv4a(x))
x = self.relu(self.conv4b(x))
# Compute the dense keypoint scores
cPa = self.relu(self.convPa(x))
scores = self.convPb(cPa)
scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
# print(scores.shape)
b, _, h, w = scores.shape
scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
# Compute the dense descriptors
cDa = self.relu(self.convDa(x))
descriptors = self.convDb(cDa)
descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
return scores, descriptors
def forward(self, data):
""" Compute keypoints, scores, descriptors for image """
# Shared Encoder
x = self.relu(self.conv1a(data['image']))
x = self.relu(self.conv1b(x))
x = self.pool(x)
x = self.relu(self.conv2a(x))
x = self.relu(self.conv2b(x))
x = self.pool(x)
x = self.relu(self.conv3a(x))
x = self.relu(self.conv3b(x))
x = self.pool(x)
x = self.relu(self.conv4a(x))
x = self.relu(self.conv4b(x))
# Compute the dense keypoint scores
cPa = self.relu(self.convPa(x))
scores = self.convPb(cPa)
scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
# print(scores.shape)
b, _, h, w = scores.shape
scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
scores = simple_nms(scores, self.config['nms_radius'])
# Extract keypoints
keypoints = [
torch.nonzero(s > self.config['keypoint_threshold'])
for s in scores]
scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
# Discard keypoints near the image borders
keypoints, scores = list(zip(*[
remove_borders(k, s, self.config['remove_borders'], h * 8, w * 8)
for k, s in zip(keypoints, scores)]))
# Keep the k keypoints with highest score
if self.config['max_keypoints'] >= 0:
keypoints, scores = list(zip(*[
top_k_keypoints(k, s, self.config['max_keypoints'])
for k, s in zip(keypoints, scores)]))
# Convert (h, w) to (x, y)
keypoints = [torch.flip(k, [1]).float() for k in keypoints]
# Compute the dense descriptors
cDa = self.relu(self.convDa(x))
descriptors = self.convDb(cDa)
descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
# Extract descriptors
# print(keypoints[0].shape)
descriptors = [sample_descriptors(k[None], d[None], 8)[0]
for k, d in zip(keypoints, descriptors)]
return {
'keypoints': keypoints,
'scores': scores,
'descriptors': descriptors,
'global_descriptor': x,
}
def extract_descriptor(sample_pts, coarse_desc, H, W):
'''
:param samplt_pts:
:param coarse_desc:
:return:
'''
with torch.no_grad():
norm_sample_pts = torch.zeros_like(sample_pts)
norm_sample_pts[0, :] = (sample_pts[0, :] / (float(W) / 2.)) - 1. # x
norm_sample_pts[1, :] = (sample_pts[1, :] / (float(H) / 2.)) - 1. # y
norm_sample_pts = norm_sample_pts.transpose(0, 1).contiguous()
norm_sample_pts = norm_sample_pts.view(1, 1, -1, 2).float()
sample_desc = torch.nn.functional.grid_sample(coarse_desc[None], norm_sample_pts, mode='bilinear',
align_corners=False)
sample_desc = torch.nn.functional.normalize(sample_desc, dim=1).squeeze(2).squeeze(0)
return sample_desc
def extract_sp_return(model, img, conf_th=0.005,
mask=None,
topK=-1,
**kwargs):
old_bm = torch.backends.cudnn.benchmark
torch.backends.cudnn.benchmark = False # speedup
# print(img.shape)
img = img.cuda()
# if len(img.shape) == 3: # gray image
# img = img[None]
B, one, H, W = img.shape
all_pts = []
all_descs = []
if 'scales' in kwargs.keys():
scales = kwargs.get('scales')
else:
scales = [1.0]
for s in scales:
if s == 1.0:
new_img = img
else:
nh = int(H * s)
nw = int(W * s)
new_img = F.interpolate(img, size=(nh, nw), mode='bilinear', align_corners=True)
nh, nw = new_img.shape[2:]
with torch.no_grad():
heatmap, coarse_desc = model.det(new_img)
# print("nh, nw, heatmap, desc: ", nh, nw, heatmap.shape, coarse_desc.shape)
if len(heatmap.size()) == 3:
heatmap = heatmap.unsqueeze(1)
if len(heatmap.size()) == 2:
heatmap = heatmap.unsqueeze(0)
heatmap = heatmap.unsqueeze(1)
# print(heatmap.shape)
if heatmap.size(2) != nh or heatmap.size(3) != nw:
heatmap = F.interpolate(heatmap, size=[nh, nw], mode='bilinear', align_corners=True)
conf_thresh = conf_th
nms_dist = 4
border_remove = 4
scores = simple_nms(heatmap, nms_radius=nms_dist)
keypoints = [
torch.nonzero(s > conf_thresh)
for s in scores]
scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
# print(keypoints[0].shape)
keypoints = [torch.flip(k, [1]).float() for k in keypoints]
scores = scores[0].data.cpu().numpy().squeeze()
keypoints = keypoints[0].data.cpu().numpy().squeeze()
pts = keypoints.transpose()
pts[2, :] = scores
inds = np.argsort(pts[2, :])
pts = pts[:, inds[::-1]] # Sort by confidence.
# Remove points along border.
bord = border_remove
toremoveW = np.logical_or(pts[0, :] < bord, pts[0, :] >= (W - bord))
toremoveH = np.logical_or(pts[1, :] < bord, pts[1, :] >= (H - bord))
toremove = np.logical_or(toremoveW, toremoveH)
pts = pts[:, ~toremove]
# valid_idex = heatmap > conf_thresh
# valid_score = heatmap[valid_idex]
# """
# --- Process descriptor.
# coarse_desc = coarse_desc.data.cpu().numpy().squeeze()
D = coarse_desc.size(1)
if pts.shape[1] == 0:
desc = np.zeros((D, 0))
else:
if coarse_desc.size(2) == nh and coarse_desc.size(3) == nw:
desc = coarse_desc[:, :, pts[1, :], pts[0, :]]
desc = desc.data.cpu().numpy().reshape(D, -1)
else:
# Interpolate into descriptor map using 2D point locations.
samp_pts = torch.from_numpy(pts[:2, :].copy())
samp_pts[0, :] = (samp_pts[0, :] / (float(nw) / 2.)) - 1.
samp_pts[1, :] = (samp_pts[1, :] / (float(nh) / 2.)) - 1.
samp_pts = samp_pts.transpose(0, 1).contiguous()
samp_pts = samp_pts.view(1, 1, -1, 2)
samp_pts = samp_pts.float()
samp_pts = samp_pts.cuda()
desc = torch.nn.functional.grid_sample(coarse_desc, samp_pts, mode='bilinear', align_corners=True)
desc = desc.data.cpu().numpy().reshape(D, -1)
desc /= np.linalg.norm(desc, axis=0)[np.newaxis, :]
if pts.shape[1] == 0:
continue
# print(pts.shape, heatmap.shape, new_img.shape, img.shape, nw, nh, W, H)
pts[0, :] = pts[0, :] * W / nw
pts[1, :] = pts[1, :] * H / nh
all_pts.append(np.transpose(pts, [1, 0]))
all_descs.append(np.transpose(desc, [1, 0]))
all_pts = np.vstack(all_pts)
all_descs = np.vstack(all_descs)
torch.backends.cudnn.benchmark = old_bm
if all_pts.shape[0] == 0:
return None, None, None
keypoints = all_pts[:, 0:2]
scores = all_pts[:, 2]
descriptors = all_descs
if mask is not None:
# cv2.imshow("mask", mask)
# cv2.waitKey(0)
labels = []
others = []
keypoints_with_labels = []
scores_with_labels = []
descriptors_with_labels = []
keypoints_without_labels = []
scores_without_labels = []
descriptors_without_labels = []
id_img = np.int32(mask[:, :, 2]) * 256 * 256 + np.int32(mask[:, :, 1]) * 256 + np.int32(mask[:, :, 0])
# print(img.shape, id_img.shape)
for i in range(keypoints.shape[0]):
x = keypoints[i, 0]
y = keypoints[i, 1]
# print("x-y", x, y, int(x), int(y))
gid = id_img[int(y), int(x)]
if gid == 0:
keypoints_without_labels.append(keypoints[i])
scores_without_labels.append(scores[i])
descriptors_without_labels.append(descriptors[i])
others.append(0)
else:
keypoints_with_labels.append(keypoints[i])
scores_with_labels.append(scores[i])
descriptors_with_labels.append(descriptors[i])
labels.append(gid)
if topK > 0:
if topK <= len(keypoints_with_labels):
idxes = np.array(scores_with_labels, float).argsort()[::-1][:topK]
keypoints = np.array(keypoints_with_labels, float)[idxes]
scores = np.array(scores_with_labels, float)[idxes]
labels = np.array(labels, np.int32)[idxes]
descriptors = np.array(descriptors_with_labels, float)[idxes]
elif topK >= len(keypoints_with_labels) + len(keypoints_without_labels):
# keypoints = np.vstack([keypoints_with_labels, keypoints_without_labels])
# scores = np.vstack([scorescc_with_labels, scores_without_labels])
# descriptors = np.vstack([descriptors_with_labels, descriptors_without_labels])
# labels = np.vstack([labels, others])
keypoints = keypoints_with_labels
scores = scores_with_labels
descriptors = descriptors_with_labels
for i in range(len(others)):
keypoints.append(keypoints_without_labels[i])
scores.append(scores_without_labels[i])
descriptors.append(descriptors_without_labels[i])
labels.append(others[i])
else:
n = topK - len(keypoints_with_labels)
idxes = np.array(scores_without_labels, float).argsort()[::-1][:n]
keypoints = keypoints_with_labels
scores = scores_with_labels
descriptors = descriptors_with_labels
for i in idxes:
keypoints.append(keypoints_without_labels[i])
scores.append(scores_without_labels[i])
descriptors.append(descriptors_without_labels[i])
labels.append(others[i])
keypoints = np.array(keypoints, float)
descriptors = np.array(descriptors, float)
# print(keypoints.shape, descriptors.shape)
return {"keypoints": np.array(keypoints, float),
"descriptors": np.array(descriptors, float),
"scores": np.array(scores, float),
"labels": np.array(labels, np.int32),
}
else:
# print(topK)
if topK > 0:
idxes = np.array(scores, dtype=float).argsort()[::-1][:topK]
keypoints = np.array(keypoints[idxes], dtype=float)
scores = np.array(scores[idxes], dtype=float)
descriptors = np.array(descriptors[idxes], dtype=float)
keypoints = np.array(keypoints, dtype=float)
scores = np.array(scores, dtype=float)
descriptors = np.array(descriptors, dtype=float)
# print(keypoints.shape, descriptors.shape)
return {"keypoints": np.array(keypoints, dtype=float),
"descriptors": descriptors,
"scores": scores,
}