"""Demo script: run a RAFT model on consecutive image pairs and show the flow.

Every ``*.png`` / ``*.jpg`` file in ``--path`` is collected and sorted by
name; each adjacent pair is passed through the model and the first image is
displayed stacked above a colour rendering of the predicted flow.
"""
import sys
#sys.path.append('core')

import argparse
import os
import cv2
import glob
import numpy as np
import torch
from PIL import Image

from raft_core_raft import RAFT
from raft_core_utils import flow_viz
from raft_core_utils_utils import InputPadder


# Use the GPU when one is present, otherwise fall back to the CPU so the
# demo still runs on machines without CUDA.
# NOTE(review): options such as --alternate_corr may still require CUDA --
# confirm against the RAFT implementation.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def load_image(imfile: str) -> torch.Tensor:
    """Load an image file as a float tensor on ``DEVICE``.

    Args:
        imfile: Path of the image to read.

    Returns:
        A float tensor with a leading batch dimension of size 1 and the
        channel axis moved in front of the two spatial axes.
    """
    with Image.open(imfile) as pil_img:
        # Force exactly three channels: grayscale/palette images would
        # otherwise decode to a 2-D array (breaking the permute below) and
        # RGBA images would carry a fourth channel.
        img = np.array(pil_img.convert('RGB'), dtype=np.uint8)
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    return img[None].to(DEVICE)


def viz(img: torch.Tensor, flo: torch.Tensor) -> None:
    """Display an image stacked above the colour-coded flow field.

    Only the first element of each batch is shown. The call blocks until a
    key is pressed in the OpenCV window.

    Args:
        img: Batched image tensor, channel axis before the spatial axes.
        flo: Batched flow tensor with the same layout as ``img``.
    """
    img = img[0].permute(1,2,0).cpu().numpy()
    flo = flo[0].permute(1,2,0).cpu().numpy()

    # map flow to rgb image
    flo = flow_viz.flow_to_image(flo)
    img_flo = np.concatenate([img, flo], axis=0)

    # import matplotlib.pyplot as plt
    # plt.imshow(img_flo / 255.0)
    # plt.show()

    # Reverse the channel order (OpenCV displays BGR) and scale to [0, 1].
    cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
    cv2.waitKey()


def demo(args: argparse.Namespace) -> None:
    """Run the model on each consecutive image pair found in ``args.path``.

    Args:
        args: Parsed command-line options. ``args.model`` is the checkpoint
            path, ``args.path`` the image directory; the whole namespace is
            also forwarded to the ``RAFT`` constructor.
    """
    images = glob.glob(os.path.join(args.path, '*.png')) + \
             glob.glob(os.path.join(args.path, '*.jpg'))
    images = sorted(images)

    # A flow estimate needs a pair of frames; say so instead of silently
    # doing nothing (and skip the checkpoint load).
    if len(images) < 2:
        print('Need at least two .png/.jpg images in %r, found %d'
              % (args.path, len(images)), file=sys.stderr)
        return

    # The checkpoint is loaded through a DataParallel wrapper and the bare
    # model is then taken back out via ``.module``.
    # NOTE(review): presumably the saved keys carry DataParallel's
    # ``module.`` prefix -- confirm against the training script.
    model = torch.nn.DataParallel(RAFT(args))
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(args.model, map_location=DEVICE))

    model = model.module
    model.to(DEVICE)
    model.eval()

    with torch.no_grad():
        for imfile1, imfile2 in zip(images[:-1], images[1:]):
            image1 = load_image(imfile1)
            image2 = load_image(imfile2)

            # Both frames are padded with a padder built from the first
            # frame's shape before being fed to the model.
            padder = InputPadder(image1.shape)
            image1, image2 = padder.pad(image1, image2)

            # Only the second output (flow_up) is visualised.
            _flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
            viz(image1, flow_up)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Both paths are mandatory: without them the script would only fail
    # later with an unrelated-looking error.
    parser.add_argument('--model', required=True, help="restore checkpoint")
    parser.add_argument('--path', required=True, help="dataset for evaluation")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
    args = parser.parse_args()

    demo(args)