Spaces:
Running
Running
# -*- coding: UTF-8 -*- | |
'''================================================= | |
@Project -> File pram -> gm | |
@IDE PyCharm | |
@Author fx221@cam.ac.uk | |
@Date 07/02/2024 10:47 | |
==================================================''' | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from nets.layers import KeypointEncoder, AttentionalPropagation | |
from nets.utils import normalize_keypoints, arange_like | |
eps = 1e-8 | |
def dual_softmax(M, dustbin): | |
M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1) | |
M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2) | |
score = torch.log_softmax(M, dim=-1) + torch.log_softmax(M, dim=1) | |
return torch.exp(score) | |
def sinkhorn(M, r, c, iteration): | |
p = torch.softmax(M, dim=-1) | |
u = torch.ones_like(r) | |
v = torch.ones_like(c) | |
for _ in range(iteration): | |
u = r / ((p * v.unsqueeze(-2)).sum(-1) + eps) | |
v = c / ((p * u.unsqueeze(-1)).sum(-2) + eps) | |
p = p * u.unsqueeze(-1) * v.unsqueeze(-2) | |
return p | |
def sink_algorithm(M, dustbin, iteration): | |
M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1) | |
M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2) | |
r = torch.ones([M.shape[0], M.shape[1] - 1], device='cuda') | |
r = torch.cat([r, torch.ones([M.shape[0], 1], device='cuda') * M.shape[1]], dim=-1) | |
c = torch.ones([M.shape[0], M.shape[2] - 1], device='cuda') | |
c = torch.cat([c, torch.ones([M.shape[0], 1], device='cuda') * M.shape[2]], dim=-1) | |
p = sinkhorn(M, r, c, iteration) | |
return p | |
class AttentionalGNN(nn.Module): | |
def __init__(self, feature_dim: int, layer_names: list, hidden_dim: int = 256, ac_fn: str = 'relu', | |
norm_fn: str = 'bn'): | |
super().__init__() | |
self.layers = nn.ModuleList([ | |
AttentionalPropagation(feature_dim=feature_dim, num_heads=4, hidden_dim=hidden_dim, ac_fn=ac_fn, | |
norm_fn=norm_fn) | |
for _ in range(len(layer_names))]) | |
self.names = layer_names | |
def forward(self, desc0, desc1): | |
# desc0s = [] | |
# desc1s = [] | |
for i, (layer, name) in enumerate(zip(self.layers, self.names)): | |
if name == 'cross': | |
src0, src1 = desc1, desc0 | |
else: | |
src0, src1 = desc0, desc1 | |
delta0 = layer(desc0, src0) | |
# prob0 = layer.attn.prob | |
delta1 = layer(desc1, src1) | |
# prob1 = layer.attn.prob | |
desc0, desc1 = (desc0 + delta0), (desc1 + delta1) | |
# if name == 'cross': | |
# desc0s.append(desc0) | |
# desc1s.append(desc1) | |
return [desc0], [desc1] | |
def predict(self, desc0, desc1, n_it=-1): | |
for i, (layer, name) in enumerate(zip(self.layers, self.names)): | |
if name == 'cross': | |
src0, src1 = desc1, desc0 | |
else: | |
src0, src1 = desc0, desc1 | |
delta0 = layer(desc0, src0) | |
# prob0 = layer.attn.prob | |
delta1 = layer(desc1, src1) | |
# prob1 = layer.attn.prob | |
desc0, desc1 = (desc0 + delta0), (desc1 + delta1) | |
if name == 'cross' and i == n_it: | |
break | |
return [desc0], [desc1] | |
class GM(nn.Module): | |
default_config = { | |
'descriptor_dim': 128, | |
'hidden_dim': 256, | |
'keypoint_encoder': [32, 64, 128, 256], | |
'GNN_layers': ['self', 'cross'] * 9, # [self, cross, self, cross, ...] 9 in total | |
'sinkhorn_iterations': 20, | |
'match_threshold': 0.2, | |
'with_pose': False, | |
'n_layers': 9, | |
'n_min_tokens': 256, | |
'with_sinkhorn': True, | |
'ac_fn': 'relu', | |
'norm_fn': 'bn', | |
'weight_path': None, | |
} | |
required_inputs = [ | |
'image0', 'keypoints0', 'scores0', 'descriptors0', | |
'image1', 'keypoints1', 'scores1', 'descriptors1', | |
] | |
def __init__(self, config): | |
super().__init__() | |
self.config = {**self.default_config, **config} | |
print('gm: ', self.config) | |
self.n_layers = self.config['n_layers'] | |
self.with_sinkhorn = self.config['with_sinkhorn'] | |
self.match_threshold = self.config['match_threshold'] | |
self.sinkhorn_iterations = self.config['sinkhorn_iterations'] | |
self.kenc = KeypointEncoder( | |
self.config['descriptor_dim'] if self.config['descriptor_dim'] > 0 else 128, | |
self.config['keypoint_encoder'], | |
ac_fn=self.config['ac_fn'], | |
norm_fn=self.config['norm_fn']) | |
self.gnn = AttentionalGNN( | |
feature_dim=self.config['descriptor_dim'] if self.config['descriptor_dim'] > 0 else 128, | |
hidden_dim=self.config['hidden_dim'], | |
layer_names=self.config['GNN_layers'], | |
ac_fn=self.config['ac_fn'], | |
norm_fn=self.config['norm_fn'], | |
) | |
self.final_proj = nn.ModuleList([nn.Conv1d( | |
self.config['descriptor_dim'] if self.config['descriptor_dim'] > 0 else 128, | |
self.config['descriptor_dim'] if self.config['descriptor_dim'] > 0 else 128, | |
kernel_size=1, bias=True) for _ in range(self.n_layers)]) | |
bin_score = torch.nn.Parameter(torch.tensor(1.)) | |
self.register_parameter('bin_score', bin_score) | |
self.match_net = None # GraphLoss(config=self.config) | |
self.self_prob0 = None | |
self.self_prob1 = None | |
self.cross_prob0 = None | |
self.cross_prob1 = None | |
self.desc_compressor = None | |
def forward_train(self, data): | |
pass | |
def produce_matches(self, data, p=0.2, n_it=-1, **kwargs): | |
kpts0, kpts1 = data['keypoints0'], data['keypoints1'] | |
scores0, scores1 = data['scores0'], data['scores1'] | |
if kpts0.shape[1] == 0 or kpts1.shape[1] == 0: # no keypoints | |
shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1] | |
return { | |
'matches0': kpts0.new_full(shape0, -1, dtype=torch.int)[0], | |
'matches1': kpts1.new_full(shape1, -1, dtype=torch.int)[0], | |
'matching_scores0': kpts0.new_zeros(shape0)[0], | |
'matching_scores1': kpts1.new_zeros(shape1)[0], | |
'skip_train': True | |
} | |
if 'norm_keypoints0' in data.keys() and 'norm_keypoints1' in data.keys(): | |
norm_kpts0 = data['norm_keypoints0'] | |
norm_kpts1 = data['norm_keypoints1'] | |
elif 'image0' in data.keys() and 'image1' in data.keys(): | |
norm_kpts0 = normalize_keypoints(kpts0, data['image0'].shape) | |
norm_kpts1 = normalize_keypoints(kpts1, data['image1'].shape) | |
elif 'image_shape0' in data.keys() and 'image_shape1' in data.keys(): | |
norm_kpts0 = normalize_keypoints(kpts0, data['image_shape0']) | |
norm_kpts1 = normalize_keypoints(kpts1, data['image_shape1']) | |
else: | |
raise ValueError('Require image shape for keypoint coordinate normalization') | |
# Keypoint MLP encoder. | |
enc0, enc1 = self.encode_keypoint(norm_kpts0=norm_kpts0, norm_kpts1=norm_kpts1, scores0=scores0, | |
scores1=scores1) | |
if self.config['descriptor_dim'] > 0: | |
desc0, desc1 = data['descriptors0'], data['descriptors1'] | |
desc0 = desc0.transpose(0, 2, 1) # [B, N, D ] -> [B, D, N] | |
desc1 = desc1.transpose(0, 2, 1) # [B, N, D ] -> [B, D, N] | |
with torch.no_grad(): | |
if desc0.shape[1] != self.config['descriptor_dim']: | |
desc0 = self.desc_compressor(desc0) | |
if desc1.shape[1] != self.config['descriptor_dim']: | |
desc1 = self.desc_compressor(desc1) | |
desc0 = desc0 + enc0 | |
desc1 = desc1 + enc1 | |
else: | |
desc0 = enc0 | |
desc1 = enc1 | |
desc0s, desc1s = self.gnn.predict(desc0, desc1, n_it=n_it) | |
mdescs0 = self.final_proj[n_it](desc0s[-1]) | |
mdescs1 = self.final_proj[n_it](desc1s[-1]) | |
dist = torch.einsum('bdn,bdm->bnm', mdescs0, mdescs1) | |
if self.config['descriptor_dim'] > 0: | |
dist = dist / self.config['descriptor_dim'] ** .5 | |
else: | |
dist = dist / 128 ** .5 | |
score = self.compute_score(dist=dist, dustbin=self.bin_score, iteration=self.sinkhorn_iterations) | |
indices0, indices1, mscores0, mscores1 = self.compute_matches(scores=score, p=p) | |
output = { | |
'matches0': indices0, # use -1 for invalid match | |
'matches1': indices1, # use -1 for invalid match | |
'matching_scores0': mscores0, | |
'matching_scores1': mscores1, | |
} | |
return output | |
def forward(self, data, mode=0): | |
if not self.training: | |
return self.produce_matches(data=data, n_it=-1) | |
return self.forward_train(data=data) | |
def encode_keypoint(self, norm_kpts0, norm_kpts1, scores0, scores1): | |
return self.kenc(norm_kpts0, scores0), self.kenc(norm_kpts1, scores1) | |
def compute_distance(self, desc0, desc1, layer_id=-1): | |
mdesc0 = self.final_proj[layer_id](desc0) | |
mdesc1 = self.final_proj[layer_id](desc1) | |
dist = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1) | |
dist = dist / self.config['descriptor_dim'] ** .5 | |
return dist | |
def compute_score(self, dist, dustbin, iteration): | |
if self.with_sinkhorn: | |
score = sink_algorithm(M=dist, dustbin=dustbin, | |
iteration=iteration) # [nI * nB, N, M] | |
else: | |
score = dual_softmax(M=dist, dustbin=dustbin) | |
return score | |
def compute_matches(self, scores, p=0.2): | |
max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) | |
indices0, indices1 = max0.indices, max1.indices | |
mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0) | |
mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1) | |
zero = scores.new_tensor(0) | |
# mscores0 = torch.where(mutual0, max0.values.exp(), zero) | |
mscores0 = torch.where(mutual0, max0.values, zero) | |
mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero) | |
# valid0 = mutual0 & (mscores0 > self.config['match_threshold']) | |
valid0 = mutual0 & (mscores0 > p) | |
valid1 = mutual1 & valid0.gather(1, indices1) | |
indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) | |
indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) | |
return indices0, indices1, mscores0, mscores1 | |