Arnaudding001 committed · Commit 8e2d158 · 1 Parent(s): aba05b0

Update raft_core_raft.py

Files changed (1): raft_core_raft.py (+144 -0)
raft_core_raft.py
CHANGED
@@ -0,0 +1,144 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from raft_core_update import BasicUpdateBlock, SmallUpdateBlock
from raft_core_extractor import BasicEncoder, SmallEncoder
from raft_core_corr import CorrBlock, AlternateCorrBlock
from raft_core_utils_utils import bilinear_sampler, coords_grid, upflow8

try:
    autocast = torch.cuda.amp.autocast
except:
    # dummy autocast for PyTorch < 1.6
    class autocast:
        def __init__(self, enabled):
            pass
        def __enter__(self):
            pass
        def __exit__(self, *args):
            pass


class RAFT(nn.Module):
    def __init__(self, args):
        super(RAFT, self).__init__()
        self.args = args

        if args.small:
            self.hidden_dim = hdim = 96
            self.context_dim = cdim = 64
            args.corr_levels = 4
            args.corr_radius = 3

        else:
            self.hidden_dim = hdim = 128
            self.context_dim = cdim = 128
            args.corr_levels = 4
            args.corr_radius = 4

        if 'dropout' not in self.args:
            self.args.dropout = 0

        if 'alternate_corr' not in self.args:
            self.args.alternate_corr = False

        # feature network, context network, and update block
        if args.small:
            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
            self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)

        else:
            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
            self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

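    # freeze_bn is typically called when fine-tuning: putting BatchNorm2d layers in
    # eval mode keeps their running statistics fixed instead of updating them.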
    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

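    # Both coordinate grids start out identical at 1/8 of the input resolution,
    # so the initial flow estimate (coords1 - coords0) is zero everywhere.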
    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H//8, W//8, device=img.device)
        coords1 = coords_grid(N, H//8, W//8, device=img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

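    # Convex upsampling: for every coarse cell the mask supplies 8x8 sets of 9 softmax
    # weights, and each full-resolution flow vector is a convex combination of its 3x3
    # coarse neighborhood (the flow is multiplied by 8 to convert it to fine-pixel units).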
    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        up_flow = F.unfold(8 * flow, [3,3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8*H, 8*W)

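    # forward expects a pair of images with values in [0, 255]; they are rescaled to
    # [-1, 1] below, and the 1/8-resolution grids assume H and W are divisible by 8.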
    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between pair of frames """

        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast(enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])

        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
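        # CorrBlock precomputes the all-pairs correlation pyramid; in the original RAFT
        # repository, AlternateCorrBlock is the memory-efficient variant that computes
        # correlations on demand (it relies on a compiled CUDA extension).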
        if self.args.alternate_corr:
            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
        else:
            corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # run the context network
        with autocast(enabled=self.args.mixed_precision):
            cnet = self.cnet(image1)
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
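        # Iterative refinement: each iteration looks up correlation features at the
        # current estimate, and the update block predicts a residual flow together with
        # a new hidden state and (for the basic model) a convex-upsampling mask.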
        for itr in range(iters):
            coords1 = coords1.detach()
            corr = corr_fn(coords1)  # index correlation volume

            flow = coords1 - coords0
            with autocast(enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            # upsample predictions
            if up_mask is None:
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)

            flow_predictions.append(flow_up)

        if test_mode:
            return coords1 - coords0, flow_up

        return flow_predictions
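For reference, a minimal usage sketch, not part of the committed file: it assumes the companion raft_core_* modules of this Space are importable and passes an argparse-style Namespace carrying the attributes the model reads (small and mixed_precision; dropout and alternate_corr fall back to the defaults set in __init__). Input sizes and values are illustrative.

# Hypothetical usage sketch (names and shapes chosen for illustration).
from argparse import Namespace

import torch

from raft_core_raft import RAFT

args = Namespace(small=False, mixed_precision=False)  # dropout / alternate_corr default inside __init__
model = RAFT(args).eval()

# Two frames with values in [0, 255]; H and W divisible by 8.
image1 = torch.randint(0, 256, (1, 3, 368, 496)).float()
image2 = torch.randint(0, 256, (1, 3, 368, 496)).float()

with torch.no_grad():
    flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)

print(flow_up.shape)  # torch.Size([1, 2, 368, 496])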