Arnaudding001 committed on
Commit
b7bf749
1 Parent(s): e9f92a9

Create raft_train.py

Files changed (1)
  1. raft_train.py +247 -0
raft_train.py ADDED
@@ -0,0 +1,247 @@
+ from __future__ import print_function, division
+ import sys
+ sys.path.append('core')
+
+ import argparse
+ import os
+ import cv2
+ import time
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import torch.nn.functional as F
+
+ from torch.utils.data import DataLoader
+ from raft import RAFT
+ import evaluate
+ import datasets
+
+ from torch.utils.tensorboard import SummaryWriter
+
+ try:
+     from torch.cuda.amp import GradScaler
+ except ImportError:
+     # dummy GradScaler for PyTorch < 1.6
+     class GradScaler:
+         def __init__(self, enabled=False):  # accept the same kwarg as torch.cuda.amp.GradScaler
+             pass
+         def scale(self, loss):
+             return loss
+         def unscale_(self, optimizer):
+             pass
+         def step(self, optimizer):
+             optimizer.step()
+         def update(self):
+             pass
+
+
+ # exclude extremely large displacements
+ MAX_FLOW = 400
+ SUM_FREQ = 100
+ VAL_FREQ = 5000
+
+
+ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
+     """ Loss function defined over sequence of flow predictions """
+
+     n_predictions = len(flow_preds)
+     flow_loss = 0.0
+
+     # exclude invalid pixels and extremely large displacements
+     mag = torch.sum(flow_gt**2, dim=1).sqrt()
+     valid = (valid >= 0.5) & (mag < max_flow)
+
+     for i in range(n_predictions):
+         i_weight = gamma**(n_predictions - i - 1)
+         i_loss = (flow_preds[i] - flow_gt).abs()
+         flow_loss += i_weight * (valid[:, None] * i_loss).mean()
+
+     epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()
+     epe = epe.view(-1)[valid.view(-1)]
+
+     metrics = {
+         'epe': epe.mean().item(),
+         '1px': (epe < 1).float().mean().item(),
+         '3px': (epe < 3).float().mean().item(),
+         '5px': (epe < 5).float().mean().item(),
+     }
+
+     return flow_loss, metrics
+
+
+ def count_parameters(model):
+     return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+ def fetch_optimizer(args, model):
+     """ Create the optimizer and learning rate scheduler """
+     optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
+
+     scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
+         pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')
+
+     return optimizer, scheduler
+
+
+ class Logger:
+     def __init__(self, model, scheduler):
+         self.model = model
+         self.scheduler = scheduler
+         self.total_steps = 0
+         self.running_loss = {}
+         self.writer = None
+
+     def _print_training_status(self):
+         metrics_data = [self.running_loss[k]/SUM_FREQ for k in sorted(self.running_loss.keys())]
+         training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0])
+         metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)
+
+         # print the training status
+         print(training_str + metrics_str)
+
+         if self.writer is None:
+             self.writer = SummaryWriter()
+
+         for k in self.running_loss:
+             self.writer.add_scalar(k, self.running_loss[k]/SUM_FREQ, self.total_steps)
+             self.running_loss[k] = 0.0
+
+     def push(self, metrics):
+         self.total_steps += 1
+
+         for key in metrics:
+             if key not in self.running_loss:
+                 self.running_loss[key] = 0.0
+
+             self.running_loss[key] += metrics[key]
+
+         if self.total_steps % SUM_FREQ == SUM_FREQ-1:
+             self._print_training_status()
+             self.running_loss = {}
+
+     def write_dict(self, results):
+         if self.writer is None:
+             self.writer = SummaryWriter()
+
+         for key in results:
+             self.writer.add_scalar(key, results[key], self.total_steps)
+
+     def close(self):
+         self.writer.close()
+
+
+ def train(args):
+
+     model = nn.DataParallel(RAFT(args), device_ids=args.gpus)
+     print("Parameter Count: %d" % count_parameters(model))
+
+     if args.restore_ckpt is not None:
+         model.load_state_dict(torch.load(args.restore_ckpt), strict=False)
+
+     model.cuda()
+     model.train()
+
+     if args.stage != 'chairs':
+         model.module.freeze_bn()
+
+     train_loader = datasets.fetch_dataloader(args)
+     optimizer, scheduler = fetch_optimizer(args, model)
+
+     total_steps = 0
+     scaler = GradScaler(enabled=args.mixed_precision)
+     logger = Logger(model, scheduler)
+
+     VAL_FREQ = 5000
+     add_noise = True
+
+     should_keep_training = True
+     while should_keep_training:
+
+         for i_batch, data_blob in enumerate(train_loader):
+             optimizer.zero_grad()
+             image1, image2, flow, valid = [x.cuda() for x in data_blob]
+
+             if args.add_noise:
+                 stdv = np.random.uniform(0.0, 5.0)
+                 image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
+                 image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)
+
+             flow_predictions = model(image1, image2, iters=args.iters)
+
+             loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)
+             scaler.scale(loss).backward()
+             scaler.unscale_(optimizer)
+             torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
+
+             scaler.step(optimizer)
+             scheduler.step()
+             scaler.update()
+
+             logger.push(metrics)
+
+             if total_steps % VAL_FREQ == VAL_FREQ - 1:
+                 PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name)
+                 torch.save(model.state_dict(), PATH)
+
+                 results = {}
+                 for val_dataset in args.validation:
+                     if val_dataset == 'chairs':
+                         results.update(evaluate.validate_chairs(model.module))
+                     elif val_dataset == 'sintel':
+                         results.update(evaluate.validate_sintel(model.module))
+                     elif val_dataset == 'kitti':
+                         results.update(evaluate.validate_kitti(model.module))
+
+                 logger.write_dict(results)
+
+                 model.train()
+                 if args.stage != 'chairs':
+                     model.module.freeze_bn()
+
+             total_steps += 1
+
+             if total_steps > args.num_steps:
+                 should_keep_training = False
+                 break
+
+     logger.close()
+     PATH = 'checkpoints/%s.pth' % args.name
+     torch.save(model.state_dict(), PATH)
+
+     return PATH
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--name', default='raft', help="name your experiment")
+     parser.add_argument('--stage', help="determines which dataset to use for training")
+     parser.add_argument('--restore_ckpt', help="restore checkpoint")
+     parser.add_argument('--small', action='store_true', help='use small model')
+     parser.add_argument('--validation', type=str, nargs='+')
+
+     parser.add_argument('--lr', type=float, default=0.00002)
+     parser.add_argument('--num_steps', type=int, default=100000)
+     parser.add_argument('--batch_size', type=int, default=6)
+     parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])
+     parser.add_argument('--gpus', type=int, nargs='+', default=[0,1])
+     parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
+
+     parser.add_argument('--iters', type=int, default=12)
+     parser.add_argument('--wdecay', type=float, default=.00005)
+     parser.add_argument('--epsilon', type=float, default=1e-8)
+     parser.add_argument('--clip', type=float, default=1.0)
+     parser.add_argument('--dropout', type=float, default=0.0)
+     parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
+     parser.add_argument('--add_noise', action='store_true')
+     args = parser.parse_args()
+
+     torch.manual_seed(1234)
+     np.random.seed(1234)
+
+     if not os.path.isdir('checkpoints'):
+         os.mkdir('checkpoints')
+
+     train(args)
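
Example invocation (hypothetical, for reference only: the stage and validation names mirror the branches handled in train(), --gpus 0 assumes a single-GPU machine, and the remaining values are simply the parser defaults above, not settings prescribed by this commit):

python raft_train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --batch_size 6 --num_steps 100000 --mixed_precision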