Arnaudding001 commited on
Commit
9986e48
1 Parent(s): efb54b3

Create raft_core_datasets.py

Browse files
Files changed (1) hide show
  1. raft_core_datasets.py +234 -0
raft_core_datasets.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Data loading based on https://github.com/NVIDIA/flownet2-pytorch
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.utils.data as data
6
+ import torch.nn.functional as F
7
+
8
+ import os
9
+ import math
10
+ import random
11
+ from glob import glob
12
+ import os.path as osp
13
+
14
+ from raft_core_utils import frame_utils
15
+ from raft_core_utils_augmentor import FlowAugmentor, SparseFlowAugmentor
16
+
17
+
18
+ class FlowDataset(data.Dataset):
19
+ def __init__(self, aug_params=None, sparse=False):
20
+ self.augmentor = None
21
+ self.sparse = sparse
22
+ if aug_params is not None:
23
+ if sparse:
24
+ self.augmentor = SparseFlowAugmentor(**aug_params)
25
+ else:
26
+ self.augmentor = FlowAugmentor(**aug_params)
27
+
28
+ self.is_test = False
29
+ self.init_seed = False
30
+ self.flow_list = []
31
+ self.image_list = []
32
+ self.extra_info = []
33
+
34
+ def __getitem__(self, index):
35
+
36
+ if self.is_test:
37
+ img1 = frame_utils.read_gen(self.image_list[index][0])
38
+ img2 = frame_utils.read_gen(self.image_list[index][1])
39
+ img1 = np.array(img1).astype(np.uint8)[..., :3]
40
+ img2 = np.array(img2).astype(np.uint8)[..., :3]
41
+ img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
42
+ img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
43
+ return img1, img2, self.extra_info[index]
44
+
45
+ if not self.init_seed:
46
+ worker_info = torch.utils.data.get_worker_info()
47
+ if worker_info is not None:
48
+ torch.manual_seed(worker_info.id)
49
+ np.random.seed(worker_info.id)
50
+ random.seed(worker_info.id)
51
+ self.init_seed = True
52
+
53
+ index = index % len(self.image_list)
54
+ valid = None
55
+ if self.sparse:
56
+ flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
57
+ else:
58
+ flow = frame_utils.read_gen(self.flow_list[index])
59
+
60
+ img1 = frame_utils.read_gen(self.image_list[index][0])
61
+ img2 = frame_utils.read_gen(self.image_list[index][1])
62
+
63
+ flow = np.array(flow).astype(np.float32)
64
+ img1 = np.array(img1).astype(np.uint8)
65
+ img2 = np.array(img2).astype(np.uint8)
66
+
67
+ # grayscale images
68
+ if len(img1.shape) == 2:
69
+ img1 = np.tile(img1[...,None], (1, 1, 3))
70
+ img2 = np.tile(img2[...,None], (1, 1, 3))
71
+ else:
72
+ img1 = img1[..., :3]
73
+ img2 = img2[..., :3]
74
+
75
+ if self.augmentor is not None:
76
+ if self.sparse:
77
+ img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
78
+ else:
79
+ img1, img2, flow = self.augmentor(img1, img2, flow)
80
+
81
+ img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
82
+ img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
83
+ flow = torch.from_numpy(flow).permute(2, 0, 1).float()
84
+
85
+ if valid is not None:
86
+ valid = torch.from_numpy(valid)
87
+ else:
88
+ valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
89
+
90
+ return img1, img2, flow, valid.float()
91
+
92
+
93
+ def __rmul__(self, v):
94
+ self.flow_list = v * self.flow_list
95
+ self.image_list = v * self.image_list
96
+ return self
97
+
98
+ def __len__(self):
99
+ return len(self.image_list)
100
+
101
+
102
+ class MpiSintel(FlowDataset):
103
+ def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
104
+ super(MpiSintel, self).__init__(aug_params)
105
+ flow_root = osp.join(root, split, 'flow')
106
+ image_root = osp.join(root, split, dstype)
107
+
108
+ if split == 'test':
109
+ self.is_test = True
110
+
111
+ for scene in os.listdir(image_root):
112
+ image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
113
+ for i in range(len(image_list)-1):
114
+ self.image_list += [ [image_list[i], image_list[i+1]] ]
115
+ self.extra_info += [ (scene, i) ] # scene and frame_id
116
+
117
+ if split != 'test':
118
+ self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
119
+
120
+
121
+ class FlyingChairs(FlowDataset):
122
+ def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
123
+ super(FlyingChairs, self).__init__(aug_params)
124
+
125
+ images = sorted(glob(osp.join(root, '*.ppm')))
126
+ flows = sorted(glob(osp.join(root, '*.flo')))
127
+ assert (len(images)//2 == len(flows))
128
+
129
+ split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
130
+ for i in range(len(flows)):
131
+ xid = split_list[i]
132
+ if (split=='training' and xid==1) or (split=='validation' and xid==2):
133
+ self.flow_list += [ flows[i] ]
134
+ self.image_list += [ [images[2*i], images[2*i+1]] ]
135
+
136
+
137
+ class FlyingThings3D(FlowDataset):
138
+ def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
139
+ super(FlyingThings3D, self).__init__(aug_params)
140
+
141
+ for cam in ['left']:
142
+ for direction in ['into_future', 'into_past']:
143
+ image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
144
+ image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
145
+
146
+ flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
147
+ flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
148
+
149
+ for idir, fdir in zip(image_dirs, flow_dirs):
150
+ images = sorted(glob(osp.join(idir, '*.png')) )
151
+ flows = sorted(glob(osp.join(fdir, '*.pfm')) )
152
+ for i in range(len(flows)-1):
153
+ if direction == 'into_future':
154
+ self.image_list += [ [images[i], images[i+1]] ]
155
+ self.flow_list += [ flows[i] ]
156
+ elif direction == 'into_past':
157
+ self.image_list += [ [images[i+1], images[i]] ]
158
+ self.flow_list += [ flows[i+1] ]
159
+
160
+
161
+ class KITTI(FlowDataset):
162
+ def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
163
+ super(KITTI, self).__init__(aug_params, sparse=True)
164
+ if split == 'testing':
165
+ self.is_test = True
166
+
167
+ root = osp.join(root, split)
168
+ images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
169
+ images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
170
+
171
+ for img1, img2 in zip(images1, images2):
172
+ frame_id = img1.split('/')[-1]
173
+ self.extra_info += [ [frame_id] ]
174
+ self.image_list += [ [img1, img2] ]
175
+
176
+ if split == 'training':
177
+ self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
178
+
179
+
180
+ class HD1K(FlowDataset):
181
+ def __init__(self, aug_params=None, root='datasets/HD1k'):
182
+ super(HD1K, self).__init__(aug_params, sparse=True)
183
+
184
+ seq_ix = 0
185
+ while 1:
186
+ flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
187
+ images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
188
+
189
+ if len(flows) == 0:
190
+ break
191
+
192
+ for i in range(len(flows)-1):
193
+ self.flow_list += [flows[i]]
194
+ self.image_list += [ [images[i], images[i+1]] ]
195
+
196
+ seq_ix += 1
197
+
198
+
199
+ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
200
+ """ Create the data loader for the corresponding trainign set """
201
+
202
+ if args.stage == 'chairs':
203
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
204
+ train_dataset = FlyingChairs(aug_params, split='training')
205
+
206
+ elif args.stage == 'things':
207
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
208
+ clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
209
+ final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
210
+ train_dataset = clean_dataset + final_dataset
211
+
212
+ elif args.stage == 'sintel':
213
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
214
+ things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
215
+ sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
216
+ sintel_final = MpiSintel(aug_params, split='training', dstype='final')
217
+
218
+ if TRAIN_DS == 'C+T+K+S+H':
219
+ kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
220
+ hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
221
+ train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
222
+
223
+ elif TRAIN_DS == 'C+T+K/S':
224
+ train_dataset = 100*sintel_clean + 100*sintel_final + things
225
+
226
+ elif args.stage == 'kitti':
227
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
228
+ train_dataset = KITTI(aug_params, split='training')
229
+
230
+ train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
231
+ pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
232
+
233
+ print('Training with %d image pairs' % len(train_dataset))
234
+ return train_loader