File size: 3,710 Bytes
320e465
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# -*- coding: utf-8 -*-
import sys 
sys.path.append(".") 

import os
import cv2
import argparse
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision import transforms

from RAFT import RAFT
from utils.flow_util import *

def imwrite(img, file_path, params=None, auto_mkdir=True):
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(file_path))
        os.makedirs(dir_name, exist_ok=True)
    return cv2.imwrite(file_path, img, params)

def initialize_RAFT(model_path='weights/raft-things.pth', device='cuda'):
    """Initializes the RAFT model.
    """
    args = argparse.ArgumentParser()
    args.raft_model = model_path
    args.small = False
    args.mixed_precision = False
    args.alternate_corr = False

    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.raft_model))

    model = model.module
    model.to(device)
    model.eval()

    return model


if __name__ == '__main__':
    device = 'cuda'

    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--root_path', type=str, default='your_dataset_root/youtube-vos/JPEGImages')
    parser.add_argument('-o', '--save_path', type=str, default='your_dataset_root/youtube-vos/Flows_flo')
    parser.add_argument('--height', type=int, default=240)
    parser.add_argument('--width', type=int, default=432)
    
    args = parser.parse_args()
    
    # Flow model
    RAFT_model = initialize_RAFT(device=device)  
    
    root_path = args.root_path
    save_path = args.save_path
    h_new, w_new = (args.height, args.width)
    
    file_list = sorted(os.listdir(root_path))
    for f in file_list:
        print(f'Processing: {f} ...')
        m_list = sorted(os.listdir(os.path.join(root_path, f)))
        len_m = len(m_list)
        for i in range(len_m-1):
            img1_path = os.path.join(root_path, f, m_list[i])
            img2_path = os.path.join(root_path, f, m_list[i+1])
            img1 = Image.fromarray(cv2.imread(img1_path))
            img2 = Image.fromarray(cv2.imread(img2_path))

            transform = transforms.Compose([transforms.ToTensor()])

            img1 = transform(img1).unsqueeze(0).to(device)[:,[2,1,0],:,:]
            img2 = transform(img2).unsqueeze(0).to(device)[:,[2,1,0],:,:]

            # upsize to a multiple of 16
            # h, w = img1.shape[2:4]
            # w_new = w if (w % 16) == 0 else 16 * (w // 16 + 1)
            # h_new = h if (h % 16) == 0 else 16 * (h // 16 + 1)


            img1 = F.interpolate(input=img1,
                                size=(h_new, w_new),
                                mode='bilinear',
                                align_corners=False)
            img2 = F.interpolate(input=img2,
                                size=(h_new, w_new),
                                mode='bilinear',
                                align_corners=False)

            with torch.no_grad():
              img1 = img1*2 - 1
              img2 = img2*2 - 1

              _, flow_f = RAFT_model(img1, img2, iters=20, test_mode=True)
              _, flow_b = RAFT_model(img2, img1, iters=20, test_mode=True)


            flow_f = flow_f[0].permute(1,2,0).cpu().numpy()
            flow_b = flow_b[0].permute(1,2,0).cpu().numpy()

            # flow_f = resize_flow(flow_f, w_new, h_new)
            # flow_b = resize_flow(flow_b, w_new, h_new)

            save_flow_f = os.path.join(save_path, f, f'{m_list[i][:-4]}_{m_list[i+1][:-4]}_f.flo')
            save_flow_b = os.path.join(save_path, f, f'{m_list[i+1][:-4]}_{m_list[i][:-4]}_b.flo')
            
            flowwrite(flow_f, save_flow_f, quantize=False)
            flowwrite(flow_b, save_flow_b, quantize=False)