# Upstream: Realcat, commit 63f3cf2 ("fix: eloftr"), 3.88 kB
# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File pram -> segnet
@IDE PyCharm
@Author fx221@cam.ac.uk
@Date 29/01/2024 14:46
=================================================='''
import torch
import torch.nn as nn
import torch.nn.functional as F
from nets.layers import MLP, KeypointEncoder
from nets.layers import AttentionalPropagation
from nets.utils import normalize_keypoints
class SegGNN(nn.Module):
    """Stack of ``AttentionalPropagation`` layers applied with residual updates.

    Every layer receives the same descriptor tensor as both of its inputs, and
    its output is added back onto the descriptors before the next layer runs.
    """

    def __init__(self, feature_dim: int, n_layers: int, ac_fn: str = 'relu', norm_fn: str = 'bn', **kwargs):
        """Build ``n_layers`` propagation layers.

        Args:
            feature_dim: Descriptor channel dimension passed to every layer.
            n_layers: Number of stacked layers.
            ac_fn: Activation name forwarded to every layer.
            norm_fn: Normalization name forwarded to every layer.
            **kwargs: Accepted for call-site compatibility; ignored.
        """
        super().__init__()
        # NOTE(review): the literal 4 is the second positional argument of
        # AttentionalPropagation -- presumably the number of attention heads;
        # confirm against nets.layers.
        self.layers = nn.ModuleList([
            AttentionalPropagation(feature_dim, 4, ac_fn=ac_fn, norm_fn=norm_fn)
            for _ in range(n_layers)
        ])

    def forward(self, desc):
        """Apply each layer as ``desc = desc + layer(desc, desc)``.

        Args:
            desc: Descriptor tensor. SegNet passes ``[B, D, N]`` here; other
                callers presumably use the same layout -- confirm.

        Returns:
            The residually updated descriptors (same shape as ``desc``,
            assuming each layer returns a tensor shaped like its input).
        """
        for layer in self.layers:
            # Both inputs are the same tensor (self-attention style update).
            desc = desc + layer(desc, desc)
        return desc
class SegNet(nn.Module):
    """Per-keypoint prediction network over descriptors and keypoint positions.

    ``forward`` transposes the input descriptors to channel-first layout, adds
    a keypoint-position encoding (``KeypointEncoder``), refines the sum with
    ``SegGNN`` and feeds it to an ``MLP`` head whose last layer has
    ``n_class`` channels. When ``with_sc`` is set, a second ``MLP`` head with
    3 output channels is also evaluated on the refined descriptors.
    """

    default_config = {
        'descriptor_dim': 256,
        'output_dim': 1024,
        'n_class': 512,
        'keypoint_encoder': [32, 64, 128, 256],
        'n_layers': 9,
        'ac_fn': 'relu',
        'norm_fn': 'in',
        'with_score': False,
        # 'with_global': False,
        'with_cls': False,
        'with_sc': False,
    }

    def __init__(self, config=None):
        """Build the sub-modules from ``default_config`` merged with ``config``.

        Args:
            config: Optional mapping of overrides applied on top of
                ``default_config``. ``None`` (the default) uses the defaults.
        """
        super().__init__()
        # ``None`` instead of a ``{}`` default avoids a shared mutable default.
        if config is None:
            config = {}
        self.config = {**self.default_config, **config}
        # NOTE(review): with_cls is stored but not read anywhere in this class.
        self.with_cls = self.config['with_cls']
        self.with_sc = self.config['with_sc']
        self.n_layers = self.config['n_layers']

        self.gnn = SegGNN(
            feature_dim=self.config['descriptor_dim'],
            n_layers=self.config['n_layers'],
            ac_fn=self.config['ac_fn'],
            norm_fn=self.config['norm_fn'],
        )

        self.with_score = self.config['with_score']
        # Keypoint input is (x, y), plus the detection score when with_score.
        self.kenc = KeypointEncoder(
            input_dim=3 if self.with_score else 2,
            feature_dim=self.config['descriptor_dim'],
            layers=self.config['keypoint_encoder'],
            ac_fn=self.config['ac_fn'],
            norm_fn=self.config['norm_fn'],
        )

        self.seg = MLP(
            channels=[self.config['descriptor_dim'],
                      self.config['output_dim'],
                      self.config['n_class']],
            ac_fn=self.config['ac_fn'],
            norm_fn=self.config['norm_fn'],
        )

        if self.with_sc:
            self.sc = MLP(
                channels=[self.config['descriptor_dim'],
                          self.config['output_dim'],
                          3],
                ac_fn=self.config['ac_fn'],
                norm_fn=self.config['norm_fn'],
            )

    def preprocess(self, data):
        """Return channel-first descriptors and the keypoint encoding.

        Args:
            data: Dict holding ``'seg_descriptors'`` (``[B, N, D]``) and either
                ``'norm_keypoints'`` (used as-is) or both ``'keypoints'`` and
                ``'image'`` (its ``.shape`` drives the normalization).
                ``'scores'`` is also read when ``with_score`` is enabled.

        Returns:
            ``(desc, enc)``: descriptors transposed to ``[B, D, N]`` and the
            output of ``self.kenc``.

        Raises:
            ValueError: If neither ``'norm_keypoints'`` nor ``'image'`` is in
                ``data``.
        """
        desc0 = data['seg_descriptors'].transpose(1, 2)  # [B, N, D] -> [B, D, N]

        if 'norm_keypoints' in data:
            norm_kpts0 = data['norm_keypoints']
        elif 'image' in data:
            norm_kpts0 = normalize_keypoints(data['keypoints'], data['image'].shape)
        else:
            raise ValueError('Require image shape for keypoint coordinate normalization')

        # Keypoint MLP encoder.
        scores0 = data['scores'] if self.with_score else None
        enc0 = self.kenc(norm_kpts0, scores0)
        return desc0, enc0

    def forward(self, data):
        """Run the network on ``data`` (see ``preprocess`` for required keys).

        Returns:
            Dict with ``'prediction'``: the ``seg`` head output with its last
            two axes swapped (``[B, C, N]`` -> ``[B, N, C]``) and made
            contiguous. When ``with_sc`` is set it also holds ``'sc'``: the
            ``sc`` head output exactly as returned, i.e. NOT transposed,
            unlike ``'prediction'``.
        """
        desc, enc = self.preprocess(data=data)
        desc = self.gnn(desc + enc)
        cls_output = self.seg(desc)  # [B, C, N]
        output = {
            'prediction': cls_output.transpose(-1, -2).contiguous(),
        }
        if self.with_sc:
            output['sc'] = self.sc(desc)
        return output