Spaces:
Sleeping
Sleeping
File size: 7,319 Bytes
251e479 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from .backbone import CNNEncoder
from .transformer import FeatureTransformer, FeatureFlowAttention
from .matching import global_correlation_softmax, local_correlation_softmax
from .geometry import flow_warp
from .utils import normalize_img, feature_add_position
class GMFlow(nn.Module):
    """Flow estimation model built from a CNN backbone, a feature Transformer,
    correlation + softmax matching, attention-based flow propagation and a
    learned convex upsampler.

    The forward pass walks the feature pyramid from the lowest to the highest
    resolution; at every scale after the first, the previous flow is upsampled
    by 2 and the newly predicted flow is added to it as a residual.
    """

    def __init__(self,
                 num_scales=1,
                 upsample_factor=8,
                 feature_channels=128,
                 attention_type='swin',
                 num_transformer_layers=6,
                 ffn_dim_expansion=4,
                 num_head=1,
                 **kwargs,
                 ):
        """Build the sub-modules.

        Args:
            num_scales: number of feature scales requested from the backbone
                and iterated over in ``forward``.
            upsample_factor: factor K used by the convex upsampler; it fixes
                the number of output channels of ``self.upsampler`` (K*K*9).
            feature_channels: channel width of backbone/Transformer features.
            attention_type: forwarded to ``FeatureTransformer``.
            num_transformer_layers: forwarded to ``FeatureTransformer``.
            ffn_dim_expansion: forwarded to ``FeatureTransformer``.
            num_head: forwarded to ``FeatureTransformer`` as ``nhead``.
            **kwargs: accepted and ignored.
        """
        super().__init__()

        self.num_scales = num_scales
        self.feature_channels = feature_channels
        self.upsample_factor = upsample_factor
        self.attention_type = attention_type
        self.num_transformer_layers = num_transformer_layers

        # CNN backbone
        self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)

        # Transformer
        self.transformer = FeatureTransformer(num_layers=num_transformer_layers,
                                              d_model=feature_channels,
                                              nhead=num_head,
                                              attention_type=attention_type,
                                              ffn_dim_expansion=ffn_dim_expansion,
                                              )

        # flow propagation with self-attn
        self.feature_flow_attn = FeatureFlowAttention(in_channels=feature_channels)

        # convex upsampling: concat feature0 and flow as input
        self.upsampler = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
                                       nn.ReLU(inplace=True),
                                       nn.Conv2d(256, upsample_factor ** 2 * 9, 1, 1, 0))

    def extract_feature(self, img0, img1):
        """Run the backbone on both images in one batched call.

        Returns:
            Two lists (one per image) of feature maps, ordered from low to
            high resolution.
        """
        concat = torch.cat((img0, img1), dim=0)  # [2B, C, H, W]
        features = self.backbone(concat)  # list of [2B, C, H, W], resolution from high to low

        feature0, feature1 = [], []
        # reverse: resolution from low to high
        for feature in features[::-1]:
            # split the stacked batch back into its img0 / img1 halves
            chunks = torch.chunk(feature, 2, 0)  # tuple
            feature0.append(chunks[0])
            feature1.append(chunks[1])

        return feature0, feature1

    def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8,
                      ):
        """Upsample ``flow`` and scale its values by the upsampling factor.

        Args:
            flow: flow tensor of shape [B, C, H, W].
            feature: feature map concatenated with ``flow`` along dim 1 as the
                input of the convex upsampler; unused when ``bilinear`` is True.
            bilinear: if True use bilinear interpolation, otherwise the learned
                convex upsampling.
            upsample_factor: factor for the bilinear branch only. The convex
                branch always uses ``self.upsample_factor``, because the
                upsampler's output channel count was fixed from it.

        Returns:
            The upsampled flow tensor.
        """
        if bilinear:
            up_flow = F.interpolate(flow, scale_factor=upsample_factor,
                                    mode='bilinear', align_corners=True) * upsample_factor
        else:
            # convex upsampling
            concat = torch.cat((flow, feature), dim=1)

            mask = self.upsampler(concat)
            b, flow_channel, h, w = flow.shape
            mask = mask.view(b, 1, 9, self.upsample_factor, self.upsample_factor, h, w)  # [B, 1, 9, K, K, H, W]
            # normalise the weights over the 9 taps of each 3x3 neighbourhood
            mask = torch.softmax(mask, dim=2)

            # gather the 3x3 neighbourhood of every (already rescaled) flow vector
            up_flow = F.unfold(self.upsample_factor * flow, [3, 3], padding=1)
            up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w)  # [B, 2, 9, 1, 1, H, W]

            up_flow = torch.sum(mask * up_flow, dim=2)  # [B, 2, K, K, H, W]
            up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)  # [B, 2, K, H, K, W]
            up_flow = up_flow.reshape(b, flow_channel, self.upsample_factor * h,
                                      self.upsample_factor * w)  # [B, 2, K*H, K*W]

        return up_flow

    def forward(self, img0, img1,
                attn_splits_list=None,
                corr_radius_list=None,
                prop_radius_list=None,
                pred_bidir_flow=False,
                **kwargs,
                ):
        """Predict flow between ``img0`` and ``img1``.

        Args:
            img0, img1: input images, passed through ``normalize_img``.
            attn_splits_list: per-scale attention split counts.
            corr_radius_list: per-scale correlation radius; ``-1`` selects
                global matching, any other value local matching.
            prop_radius_list: per-scale propagation radius; local-window
                attention is used for propagation when the value is > 0.
            pred_bidir_flow: whether to predict flow in both directions.
            **kwargs: accepted and ignored.

        Returns:
            ``{'flow_preds': [...]}``. In training mode the list holds the
            bilinearly upsampled intermediate predictions followed by the
            final convex-upsampled flow; otherwise only the final flow.

        Raises:
            ValueError: if any per-scale list is missing or its length differs
                from ``self.num_scales``.
        """
        # Validate up front with real exceptions: an `assert` is stripped under
        # `python -O`, and a `None` default would otherwise surface as an
        # opaque `len()` TypeError after the backbone has already run.
        per_scale_args = (('attn_splits_list', attn_splits_list),
                          ('corr_radius_list', corr_radius_list),
                          ('prop_radius_list', prop_radius_list))
        missing = [name for name, value in per_scale_args if value is None]
        if missing:
            raise ValueError('missing per-scale argument(s): ' + ', '.join(missing))
        if not (len(attn_splits_list) == len(corr_radius_list)
                == len(prop_radius_list) == self.num_scales):
            raise ValueError(
                'attn_splits_list, corr_radius_list and prop_radius_list must each have '
                'num_scales=%d entries, got %d, %d and %d'
                % (self.num_scales, len(attn_splits_list),
                   len(corr_radius_list), len(prop_radius_list)))

        results_dict = {}
        flow_preds = []

        img0, img1 = normalize_img(img0, img1)  # [B, 3, H, W]

        # resolution low to high
        feature0_list, feature1_list = self.extract_feature(img0, img1)  # list of features

        flow = None

        for scale_idx in range(self.num_scales):
            feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]

            if pred_bidir_flow and scale_idx > 0:
                # predicting bidirectional flow with refinement
                feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0)

            # factor that brings this scale's flow back to the input resolution
            upsample_factor = self.upsample_factor * (2 ** (self.num_scales - 1 - scale_idx))

            if scale_idx > 0:
                flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=True) * 2

            if flow is not None:
                # no gradient flows back through the coarser estimate
                flow = flow.detach()
                feature1 = flow_warp(feature1, flow)  # [B, C, H, W]

            attn_splits = attn_splits_list[scale_idx]
            corr_radius = corr_radius_list[scale_idx]
            prop_radius = prop_radius_list[scale_idx]

            # add position to features
            feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels)

            # Transformer
            feature0, feature1 = self.transformer(feature0, feature1, attn_num_splits=attn_splits)

            # correlation and softmax
            if corr_radius == -1:  # global matching
                flow_pred = global_correlation_softmax(feature0, feature1, pred_bidir_flow)[0]
            else:  # local matching
                flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0]

            # flow or residual flow
            flow = flow + flow_pred if flow is not None else flow_pred

            # upsample to the original resolution for supervision
            if self.training:  # only need to upsample intermediate flow predictions at training time
                flow_bilinear = self.upsample_flow(flow, None, bilinear=True, upsample_factor=upsample_factor)
                flow_preds.append(flow_bilinear)

            # flow propagation with self-attn
            if pred_bidir_flow and scale_idx == 0:
                feature0 = torch.cat((feature0, feature1), dim=0)  # [2*B, C, H, W] for propagation
            flow = self.feature_flow_attn(feature0, flow.detach(),
                                          local_window_attn=prop_radius > 0,
                                          local_window_radius=prop_radius)

            # bilinear upsampling at training time except the last one
            if self.training and scale_idx < self.num_scales - 1:
                flow_up = self.upsample_flow(flow, feature0, bilinear=True, upsample_factor=upsample_factor)
                flow_preds.append(flow_up)

            if scale_idx == self.num_scales - 1:
                # final prediction uses the learned convex upsampler
                flow_up = self.upsample_flow(flow, feature0)
                flow_preds.append(flow_up)

        results_dict.update({'flow_preds': flow_preds})

        return results_dict
|