# Source page residue: Realcat, commit 63f3cf2 "fix: eloftr", 10.6 kB
# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File pram -> gm
@IDE PyCharm
@Author fx221@cam.ac.uk
@Date 07/02/2024 10:47
=================================================='''
import torch
import torch.nn as nn
import torch.nn.functional as F
from nets.layers import KeypointEncoder, AttentionalPropagation
from nets.utils import normalize_keypoints, arange_like
# Small positive constant added to denominators in `sinkhorn` to avoid division by zero.
eps = 1.0e-8
def dual_softmax(M, dustbin):
    """Dual-softmax assignment matrix with a dustbin row and column.

    ``M`` is padded with ``dustbin`` as an extra last column and then an extra
    last row. The result is exp(log_softmax over the last axis + log_softmax
    over axis 1), i.e. the element-wise product of the row-wise and
    column-wise softmaxes of the padded matrix.

    Args:
        M: score tensor of shape [B, N, M].
        dustbin: tensor expandable to the padding shape (e.g. a scalar parameter).

    Returns:
        Tensor of shape [B, N + 1, M + 1].
    """
    # Append the dustbin column first, then a dustbin row spanning the widened matrix.
    with_col = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1)
    padded = torch.cat([with_col, dustbin.expand([with_col.shape[0], 1, with_col.shape[2]])], dim=-2)
    log_p_rows = torch.log_softmax(padded, dim=-1)
    log_p_cols = torch.log_softmax(padded, dim=1)
    return (log_p_rows + log_p_cols).exp()
def sinkhorn(M, r, c, iteration):
    """Alternately rescale softmax(M) towards row sums ``r`` and column sums ``c``.

    Args:
        M: scores of shape [B, N, M]; softmax is first taken over the last axis.
        r: target row sums, broadcastable to [B, N].
        c: target column sums, broadcastable to [B, M].
        iteration: number of rescaling rounds (0 returns softmax(M, dim=-1)).

    Returns:
        The rescaled matrix, same shape as ``M``.

    NOTE(review): ``plan`` is rescaled in place every round, and the previous
    round's column factors are re-applied when computing the next row factors.
    This differs from textbook Sinkhorn, which keeps the kernel fixed.
    Presumably the trained weights rely on it; confirm before changing.
    """
    plan = torch.softmax(M, dim=-1)
    col_scale = torch.ones_like(c)
    for _ in range(iteration):
        # `eps` keeps the denominators away from zero.
        row_scale = r / ((plan * col_scale.unsqueeze(-2)).sum(-1) + eps)
        col_scale = c / ((plan * row_scale.unsqueeze(-1)).sum(-2) + eps)
        plan = plan * row_scale.unsqueeze(-1) * col_scale.unsqueeze(-2)
    return plan
def sink_algorithm(M, dustbin, iteration):
    """Pad ``M`` with a dustbin row/column and run ``sinkhorn`` on the result.

    Args:
        M: match scores of shape [B, N, M].
        dustbin: value used for the extra last column and row (expanded to fit).
        iteration: number of Sinkhorn rounds.

    Returns:
        Tensor of shape [B, N + 1, M + 1].

    Every real row/column gets unit mass. The dustbin row gets mass N + 1 and
    the dustbin column gets M + 1 (the padded sizes).
    NOTE(review): the row total (2N + 1) and column total (2M + 1) only agree
    when N == M; presumably this matches how the weights were trained. Confirm
    before changing.
    """
    M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1)
    M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2)
    batch, n_rows, n_cols = M.shape
    # Create the marginals on M's device. They were previously hard-coded to
    # 'cuda', which failed on CPU and when M lived on a non-default GPU.
    r = torch.ones([batch, n_rows], device=M.device)
    r[:, -1] = n_rows
    c = torch.ones([batch, n_cols], device=M.device)
    c[:, -1] = n_cols
    return sinkhorn(M, r, c, iteration)
class AttentionalGNN(nn.Module):
    """Stack of ``AttentionalPropagation`` layers applied to two descriptor sets.

    One layer is built per entry of ``layer_names``. For a layer named
    ``'cross'`` each set receives the *other* set as its source; for any other
    name each set is its own source. Every layer's output is added residually
    to the descriptors.
    """

    def __init__(self, feature_dim: int, layer_names: list, hidden_dim: int = 256, ac_fn: str = 'relu',
                 norm_fn: str = 'bn'):
        super().__init__()
        propagation_layers = []
        for _ in layer_names:
            propagation_layers.append(
                AttentionalPropagation(feature_dim=feature_dim, num_heads=4, hidden_dim=hidden_dim,
                                       ac_fn=ac_fn, norm_fn=norm_fn))
        self.layers = nn.ModuleList(propagation_layers)
        self.names = layer_names

    def _propagate(self, desc0, desc1, stop_at=None):
        """Run the layers in order, stopping after the 'cross' layer whose index equals ``stop_at``.

        Returns a pair of one-element lists holding the final descriptors.
        """
        for idx, (layer, kind) in enumerate(zip(self.layers, self.names)):
            is_cross = kind == 'cross'
            source0 = desc1 if is_cross else desc0
            source1 = desc0 if is_cross else desc1
            # Both updates are computed from the pre-layer descriptors.
            update0 = layer(desc0, source0)
            update1 = layer(desc1, source1)
            desc0 = desc0 + update0
            desc1 = desc1 + update1
            if is_cross and idx == stop_at:
                break
        return [desc0], [desc1]

    def forward(self, desc0, desc1):
        """Apply every layer; returns ([desc0], [desc1])."""
        return self._propagate(desc0, desc1)

    def predict(self, desc0, desc1, n_it=-1):
        """Like ``forward``, but stop early after the 'cross' layer at index ``n_it``.

        ``n_it`` is compared against the layer index, so the default -1 never
        matches and all layers run.
        """
        return self._propagate(desc0, desc1, stop_at=n_it)
class GM(nn.Module):
    """Keypoint graph matcher.

    Keypoints are normalized and encoded with ``KeypointEncoder``. When
    ``descriptor_dim > 0`` the encodings are added to the input descriptors.
    The result is refined with ``AttentionalGNN`` and projected with
    ``final_proj``. All pairs are then scored with Sinkhorn
    (``with_sinkhorn=True``) or dual-softmax, both padded with the learnable
    dustbin ``bin_score``. Matches are mutual best candidates whose score
    exceeds a threshold.
    """

    default_config = {
        'descriptor_dim': 128,
        'hidden_dim': 256,
        'keypoint_encoder': [32, 64, 128, 256],
        'GNN_layers': ['self', 'cross'] * 9,  # 9 (self, cross) pairs = 18 layers
        'sinkhorn_iterations': 20,
        'match_threshold': 0.2,
        'with_pose': False,
        'n_layers': 9,
        'n_min_tokens': 256,
        'with_sinkhorn': True,
        'ac_fn': 'relu',
        'norm_fn': 'bn',
        'weight_path': None,
    }

    required_inputs = [
        'image0', 'keypoints0', 'scores0', 'descriptors0',
        'image1', 'keypoints1', 'scores1', 'descriptors1',
    ]

    def __init__(self, config):
        super().__init__()
        self.config = {**self.default_config, **config}
        print('gm: ', self.config)
        self.n_layers = self.config['n_layers']
        self.with_sinkhorn = self.config['with_sinkhorn']
        self.match_threshold = self.config['match_threshold']
        self.sinkhorn_iterations = self.config['sinkhorn_iterations']

        feature_dim = self._feature_dim()
        self.kenc = KeypointEncoder(
            feature_dim,
            self.config['keypoint_encoder'],
            ac_fn=self.config['ac_fn'],
            norm_fn=self.config['norm_fn'])
        self.gnn = AttentionalGNN(
            feature_dim=feature_dim,
            hidden_dim=self.config['hidden_dim'],
            layer_names=self.config['GNN_layers'],
            ac_fn=self.config['ac_fn'],
            norm_fn=self.config['norm_fn'],
        )
        # n_layers 1x1 projections; `compute_distance` selects one by index.
        self.final_proj = nn.ModuleList([
            nn.Conv1d(feature_dim, feature_dim, kernel_size=1, bias=True)
            for _ in range(self.n_layers)])
        bin_score = torch.nn.Parameter(torch.tensor(1.))
        self.register_parameter('bin_score', bin_score)
        self.match_net = None  # GraphLoss(config=self.config)
        self.self_prob0 = None
        self.self_prob1 = None
        self.cross_prob0 = None
        self.cross_prob1 = None
        # Optional module mapping input descriptors to `descriptor_dim` channels; unset here.
        self.desc_compressor = None

    def _feature_dim(self):
        """Return the working feature width: ``descriptor_dim`` if positive, else 128."""
        dim = self.config['descriptor_dim']
        return dim if dim > 0 else 128

    def _prepare_descriptors(self, desc):
        """Turn [B, N, D] descriptors into [B, D, N], compressing channels if needed.

        Raises:
            ValueError: if D differs from ``descriptor_dim`` and no
                ``desc_compressor`` is set.
        """
        # torch.Tensor.transpose swaps exactly two dims (the NumPy-style
        # `transpose(0, 2, 1)` raises TypeError on tensors).
        desc = desc.transpose(1, 2)
        if desc.shape[1] != self.config['descriptor_dim']:
            if self.desc_compressor is None:
                raise ValueError(
                    'descriptor dim %d does not match descriptor_dim=%d and no desc_compressor is set'
                    % (desc.shape[1], self.config['descriptor_dim']))
            with torch.no_grad():
                desc = self.desc_compressor(desc)
        return desc

    def forward_train(self, data):
        """Training-mode forward pass; not implemented here (returns None)."""
        return None

    def produce_matches(self, data, p=0.2, n_it=-1, **kwargs):
        """Match the two keypoint sets described by ``data``.

        Args:
            data: dict with 'keypoints0/1' and 'scores0/1'. It also needs
                'descriptors0/1' laid out [B, N, D] when ``descriptor_dim > 0``.
                Keypoints are normalized using 'norm_keypoints0/1' if present,
                else the shapes of 'image0/1', else 'image_shape0/1'.
            p: minimum matching score for a mutual match to be kept.
            n_it: forwarded to ``gnn.predict`` and used as the ``final_proj`` index.
                NOTE(review): ``gnn.predict`` compares it with the GNN layer
                index (len(GNN_layers) layers), while ``final_proj`` has only
                ``n_layers`` entries. For n_it >= 0 the two may not line up;
                confirm the intended semantics.
            **kwargs: ignored.

        Returns:
            dict with 'matches0' / 'matches1' (-1 marks an unmatched keypoint)
            and 'matching_scores0' / 'matching_scores1'.

        Raises:
            ValueError: if no image-shape information is available, or if the
                descriptors need compressing but ``desc_compressor`` is unset.
        """
        kpts0, kpts1 = data['keypoints0'], data['keypoints1']
        scores0, scores1 = data['scores0'], data['scores1']
        if kpts0.shape[1] == 0 or kpts1.shape[1] == 0:  # no keypoints
            # NOTE(review): `[0]` drops the batch dim here, unlike the normal
            # path which returns [B, ...] tensors; confirm callers expect this.
            shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]
            return {
                'matches0': kpts0.new_full(shape0, -1, dtype=torch.int)[0],
                'matches1': kpts1.new_full(shape1, -1, dtype=torch.int)[0],
                'matching_scores0': kpts0.new_zeros(shape0)[0],
                'matching_scores1': kpts1.new_zeros(shape1)[0],
                'skip_train': True
            }

        if 'norm_keypoints0' in data and 'norm_keypoints1' in data:
            norm_kpts0 = data['norm_keypoints0']
            norm_kpts1 = data['norm_keypoints1']
        elif 'image0' in data and 'image1' in data:
            norm_kpts0 = normalize_keypoints(kpts0, data['image0'].shape)
            norm_kpts1 = normalize_keypoints(kpts1, data['image1'].shape)
        elif 'image_shape0' in data and 'image_shape1' in data:
            norm_kpts0 = normalize_keypoints(kpts0, data['image_shape0'])
            norm_kpts1 = normalize_keypoints(kpts1, data['image_shape1'])
        else:
            raise ValueError('Require image shape for keypoint coordinate normalization')

        # Keypoint MLP encoder.
        enc0, enc1 = self.encode_keypoint(norm_kpts0=norm_kpts0, norm_kpts1=norm_kpts1, scores0=scores0,
                                          scores1=scores1)
        if self.config['descriptor_dim'] > 0:
            desc0 = self._prepare_descriptors(data['descriptors0']) + enc0
            desc1 = self._prepare_descriptors(data['descriptors1']) + enc1
        else:
            desc0, desc1 = enc0, enc1

        desc0s, desc1s = self.gnn.predict(desc0, desc1, n_it=n_it)
        dist = self.compute_distance(desc0s[-1], desc1s[-1], layer_id=n_it)
        score = self.compute_score(dist=dist, dustbin=self.bin_score, iteration=self.sinkhorn_iterations)
        indices0, indices1, mscores0, mscores1 = self.compute_matches(scores=score, p=p)
        return {
            'matches0': indices0,  # use -1 for invalid match
            'matches1': indices1,  # use -1 for invalid match
            'matching_scores0': mscores0,
            'matching_scores1': mscores1,
        }

    def forward(self, data, mode=0):
        """Run inference in eval mode, or ``forward_train`` in training mode.

        The configured ``match_threshold`` is used as the score threshold.
        ``mode`` is accepted for compatibility and ignored.
        """
        if not self.training:
            return self.produce_matches(data=data, p=self.match_threshold, n_it=-1)
        return self.forward_train(data=data)

    def encode_keypoint(self, norm_kpts0, norm_kpts1, scores0, scores1):
        """Encode each normalized keypoint set together with its scores via ``kenc``."""
        return self.kenc(norm_kpts0, scores0), self.kenc(norm_kpts1, scores1)

    def compute_distance(self, desc0, desc1, layer_id=-1):
        """Scaled dot-product similarity between projected descriptors.

        Args:
            desc0: descriptors of shape [B, D, N].
            desc1: descriptors of shape [B, D, M].
            layer_id: index into ``final_proj`` selecting the projection.

        Returns:
            [B, N, M] tensor of dot products divided by sqrt(feature width).
            The feature width is ``descriptor_dim``, or 128 when
            ``descriptor_dim <= 0``, matching the width the layers are built with.
        """
        mdesc0 = self.final_proj[layer_id](desc0)
        mdesc1 = self.final_proj[layer_id](desc1)
        dist = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)
        return dist / self._feature_dim() ** .5

    def compute_score(self, dist, dustbin, iteration):
        """Turn [B, N, M] similarities into a dustbin-padded [B, N + 1, M + 1] score matrix.

        Uses ``sink_algorithm`` when ``with_sinkhorn`` is set, else ``dual_softmax``
        (``iteration`` is only used by the former).
        """
        if self.with_sinkhorn:
            return sink_algorithm(M=dist, dustbin=dustbin, iteration=iteration)
        return dual_softmax(M=dist, dustbin=dustbin)

    def compute_matches(self, scores, p=0.2):
        """Extract mutual best matches from a dustbin-padded score matrix.

        The dustbin row/column (last entries) is excluded. A pair (n, m) is
        mutual when m is n's best column and n is m's best row.

        Args:
            scores: [B, N + 1, M + 1] score matrix.
            p: a mutual match is kept only if its score is strictly greater than ``p``.

        Returns:
            indices0 ([B, N]) and indices1 ([B, M]): matched index, or -1 if none.
            mscores0 / mscores1: score of the mutual best match (0 when not
            mutual). These are not thresholded by ``p``.
        """
        max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
        indices0, indices1 = max0.indices, max1.indices
        mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
        mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
        zero = scores.new_tensor(0)
        mscores0 = torch.where(mutual0, max0.values, zero)
        mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
        valid0 = mutual0 & (mscores0 > p)
        valid1 = mutual1 & valid0.gather(1, indices1)
        indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
        indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
        return indices0, indices1, mscores0, mscores1