# -*- coding: UTF-8 -*- '''================================================= @Project -> File pram -> segnet @IDE PyCharm @Author fx221@cam.ac.uk @Date 29/01/2024 14:46 ==================================================''' import torch import torch.nn as nn import torch.nn.functional as F from nets.layers import MLP, KeypointEncoder from nets.layers import AttentionalPropagation from nets.utils import normalize_keypoints class SegGNN(nn.Module): def __init__(self, feature_dim: int, n_layers: int, ac_fn: str = 'relu', norm_fn: str = 'bn', **kwargs): super().__init__() self.layers = nn.ModuleList([ AttentionalPropagation(feature_dim, 4, ac_fn=ac_fn, norm_fn=norm_fn) for _ in range(n_layers) ]) def forward(self, desc): for i, layer in enumerate(self.layers): delta = layer(desc, desc) desc = desc + delta return desc class SegNet(nn.Module): default_config = { 'descriptor_dim': 256, 'output_dim': 1024, 'n_class': 512, 'keypoint_encoder': [32, 64, 128, 256], 'n_layers': 9, 'ac_fn': 'relu', 'norm_fn': 'in', 'with_score': False, # 'with_global': False, 'with_cls': False, 'with_sc': False, } def __init__(self, config={}): super().__init__() self.config = {**self.default_config, **config} self.with_cls = self.config['with_cls'] self.with_sc = self.config['with_sc'] self.n_layers = self.config['n_layers'] self.gnn = SegGNN( feature_dim=self.config['descriptor_dim'], n_layers=self.config['n_layers'], ac_fn=self.config['ac_fn'], norm_fn=self.config['norm_fn'], ) self.with_score = self.config['with_score'] self.kenc = KeypointEncoder( input_dim=3 if self.with_score else 2, feature_dim=self.config['descriptor_dim'], layers=self.config['keypoint_encoder'], ac_fn=self.config['ac_fn'], norm_fn=self.config['norm_fn'] ) self.seg = MLP(channels=[self.config['descriptor_dim'], self.config['output_dim'], self.config['n_class']], ac_fn=self.config['ac_fn'], norm_fn=self.config['norm_fn'] ) if self.with_sc: self.sc = MLP(channels=[self.config['descriptor_dim'], self.config['output_dim'], 3], ac_fn=self.config['ac_fn'], norm_fn=self.config['norm_fn'] ) def preprocess(self, data): desc0 = data['seg_descriptors'] desc0 = desc0.transpose(1, 2) # [B, N, D] - > [B, D, N] if 'norm_keypoints' in data.keys(): norm_kpts0 = data['norm_keypoints'] elif 'image' in data.keys(): kpts0 = data['keypoints'] norm_kpts0 = normalize_keypoints(kpts0, data['image'].shape) else: raise ValueError('Require image shape for keypoint coordinate normalization') # Keypoint MLP encoder. if self.with_score: scores0 = data['scores'] else: scores0 = None enc0 = self.kenc(norm_kpts0, scores0) return desc0, enc0 def forward(self, data): desc, enc = self.preprocess(data=data) desc = desc + enc desc = self.gnn(desc) cls_output = self.seg(desc) # [B, C, N] output = { 'prediction': cls_output.transpose(-1, -2).contiguous(), } if self.with_sc: sc_output = self.sc(desc) output['sc'] = sc_output return output