Arnaudding001 committed on
Commit
8e2d158
1 Parent(s): aba05b0

Update raft_core_raft.py

Browse files
Files changed (1) hide show
  1. raft_core_raft.py +144 -0
raft_core_raft.py CHANGED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from raft_core_update import BasicUpdateBlock, SmallUpdateBlock
7
+ from raft_core_extractor import BasicEncoder, SmallEncoder
8
+ from raft_core_corr import CorrBlock, AlternateCorrBlock
9
+ from raft_core_utils_utils import bilinear_sampler, coords_grid, upflow8
10
+
11
try:
    autocast = torch.cuda.amp.autocast
except AttributeError:
    # torch.cuda.amp is missing on PyTorch < 1.6; only the attribute lookup
    # above can fail, so catch AttributeError rather than everything
    # (a bare except would also swallow KeyboardInterrupt / SystemExit).
    class autocast:
        """No-op stand-in for ``torch.cuda.amp.autocast``.

        Accepts the ``enabled`` flag so call sites can be written the same
        way whether or not real mixed-precision support is available, but
        ignores it: entering and leaving the context does nothing.
        """

        def __init__(self, enabled=True):
            pass

        def __enter__(self):
            pass

        def __exit__(self, *args):
            # returns None, so exceptions raised inside the block propagate
            pass
22
+
23
+
24
class RAFT(nn.Module):
    """Optical-flow estimator built from a feature encoder, a context
    encoder and an iterative update block.

    ``args`` is an attribute-style options object.  It must provide
    ``small``; ``dropout``, ``alternate_corr`` and ``mixed_precision`` are
    optional and are filled in with defaults (0, False, False) when absent.
    The constructor also writes ``corr_levels`` and ``corr_radius`` onto
    ``args``.
    """

    @staticmethod
    def _has_option(args, name):
        """ Return True if ``args`` already carries the option ``name`` """
        try:
            # containers such as argparse.Namespace support `in` directly
            return name in args
        except TypeError:
            # plain objects (e.g. types.SimpleNamespace) do not support `in`
            return hasattr(args, name)

    def __init__(self, args):
        super().__init__()
        self.args = args

        if args.small:
            self.hidden_dim = hdim = 96
            self.context_dim = cdim = 64
            args.corr_levels = 4
            args.corr_radius = 3

        else:
            self.hidden_dim = hdim = 128
            self.context_dim = cdim = 128
            args.corr_levels = 4
            args.corr_radius = 4

        # fill in optional settings so the rest of the class can read them
        # unconditionally
        if not self._has_option(self.args, 'dropout'):
            self.args.dropout = 0

        if not self._has_option(self.args, 'alternate_corr'):
            self.args.alternate_corr = False

        # forward() reads this flag on every call; default it like the others
        if not self._has_option(self.args, 'mixed_precision'):
            self.args.mixed_precision = False

        # feature network, context network, and update block
        if args.small:
            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
            self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)

        else:
            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
            self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def freeze_bn(self):
        """ Put every nn.BatchNorm2d submodule into eval mode; other modules are untouched """
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0

        Builds two grids with coords_grid(N, H//8, W//8) on img's device,
        where (N, C, H, W) is img.shape, and returns them as (coords0, coords1).
        """
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H//8, W//8, device=img.device)
        coords1 = coords_grid(N, H//8, W//8, device=img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination

        flow: tensor of shape (N, 2, H, W) at the coarse resolution.
        mask: tensor that can be viewed as (N, 1, 9, 8, 8, H, W), i.e. 9
              logits for each of the 8x8 fine positions of every coarse cell.
        Returns a tensor of shape (N, 2, 8*H, 8*W).
        """
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        # softmax over the 9 neighbours makes the weights non-negative and sum to 1
        mask = torch.softmax(mask, dim=2)

        # gather each cell's 3x3 neighbourhood; flow is scaled by 8 because
        # displacements grow with the 8x larger resolution
        up_flow = F.unfold(8 * flow, [3,3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        # (N, 2, 8, 8, H, W) -> (N, 2, H, 8, W, 8) so the reshape interleaves
        # the 8x8 sub-positions with the coarse grid
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8*H, 8*W)

    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between pair of frames

        image1, image2: input frames; they are rescaled with 2 * (x / 255) - 1,
            so values are presumably expected in [0, 255] (confirm with caller).
        iters: number of update-block iterations.
        flow_init: optional initial flow added to the starting coordinates.
        upsample: accepted for interface compatibility; not read by this method.
        test_mode: if True, return (coarse flow, last upsampled flow) instead
            of the list of per-iteration upsampled flows.  iters must be >= 1
            in this mode because the upsampled flow is only produced inside
            the iteration loop.
        """

        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast(enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])

        # the correlation blocks are fed float32 feature maps
        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        if self.args.alternate_corr:
            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
        else:
            corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # run the context network
        with autocast(enabled=self.args.mixed_precision):
            cnet = self.cnet(image1)
            # first hdim channels seed the recurrent state, the remaining
            # cdim channels are the context input of the update block
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        for itr in range(iters):
            # cut the graph so gradients do not flow through previous coordinates
            coords1 = coords1.detach()
            corr = corr_fn(coords1) # index correlation volume

            flow = coords1 - coords0
            with autocast(enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            # upsample predictions
            if up_mask is None:
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)

            flow_predictions.append(flow_up)

        if test_mode:
            return coords1 - coords0, flow_up

        return flow_predictions