# Copyright (c) OpenMMLab. All rights reserved.
# from mmflow
import re
from io import BytesIO
from typing import BinaryIO, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import mmcv
import numpy as np
from numpy import ndarray


def read_flow(name: str) -> np.ndarray:
    """Read flow file with the suffix '.flo'.

    This function is modified from
    https://lmb.informatik.uni-freiburg.de/resources/datasets/IO.py
    Copyright (c) 2011, LMB, University of Freiburg.

    The file layout read here is: a 4-byte ``PIEH`` magic, an int32 width,
    an int32 height, then ``height * width * 2`` float32 values.

    Args:
        name (str): Optical flow file path.

    Raises:
        Exception: If the file does not start with the ``PIEH`` magic.

    Returns:
        ndarray: Optical flow with the shape (H, W, 2).
    """
    with open(name, 'rb') as f:
        header = f.read(4)
        # Compare raw bytes: decoding first would raise UnicodeDecodeError
        # (instead of the documented error) on non-UTF-8 garbage.
        if header != b'PIEH':
            raise Exception('Flow file header does not contain PIEH')
        width = np.fromfile(f, np.int32, 1).squeeze()
        height = np.fromfile(f, np.int32, 1).squeeze()
        flow = np.fromfile(f, np.float32, width * height * 2).reshape(
            (height, width, 2))
    return flow


def write_flow(flow: np.ndarray, flow_file: str) -> None:
    """Write the flow to disk in the '.flo' format.

    This function is modified from
    https://lmb.informatik.uni-freiburg.de/resources/datasets/IO.py
    Copyright (c) 2011, LMB, University of Freiburg.

    Args:
        flow (ndarray): The optical flow that will be saved. Presumably of
            shape (H, W, 2), since ``shape[1]`` and ``shape[0]`` are written
            as width and height.
        flow_file (str): The file for saving optical flow.
    """
    with open(flow_file, 'wb') as f:
        f.write(b'PIEH')
        # The header stores width (shape[1]) before height (shape[0]).
        np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
        flow.astype(np.float32).tofile(f)


def visualize_flow(flow: np.ndarray,
                   save_file: Optional[str] = None) -> np.ndarray:
    """Flow visualization function.

    Args:
        flow (ndarray): The flow to render.
        save_file (str, optional): If given, the rendered image is also
            saved to this path with ``plt.imsave``. Defaults to None.

    Returns:
        ndarray: uint8 flow map image with RGB order.
    """
    # return value from mmcv.flow2rgb is [0, 1.] with type np.float32
    flow_map = np.uint8(mmcv.flow2rgb(flow) * 255.)
    if save_file:
        plt.imsave(save_file, flow_map)
    return flow_map


def render_color_wheel(save_file: str = 'color_wheel.png') -> np.ndarray:
    """Render color wheel.

    The wheel is the visualization of a synthetic 151x151 flow field whose
    vector at each pixel points away from the center pixel (75, 75).

    Args:
        save_file (str): The saved file name. Defaults to 'color_wheel.png'.

    Returns:
        ndarray: color wheel image.
    """
    x0, y0 = 75, 75
    height, width = 151, 151
    # grid_y[i, j] == i (row index), grid_x[i, j] == j (column index).
    grid_y, grid_x = np.mgrid[0:height, 0:width]
    flow = np.stack((grid_x - x0, grid_y - y0), axis=-1).astype(np.float32)
    return visualize_flow(flow, save_file)


def read_flow_kitti(name: str) -> Tuple[np.ndarray, np.ndarray]:
    """Read sparse flow file from KITTI dataset.

    This function is modified from
    https://github.com/princeton-vl/RAFT/blob/master/core/utils/frame_utils.py.
    Copyright (c) 2020, princeton-vl
    Licensed under the BSD 3-Clause License

    Args:
        name (str): The flow file.

    Raises:
        IOError: If OpenCV cannot read the file.

    Returns:
        Tuple[ndarray, ndarray]: flow of shape (H, W, 2) and valid map of
        shape (H, W), both float32.
    """
    # to specify not to change the image depth (16bit)
    flow = cv2.imread(name, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
    if flow is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise IOError(f'Failed to read KITTI flow file: {name}')
    # OpenCV returns channels in BGR order; reverse them so channels 0/1 are
    # the flow components and channel 2 is the valid mask.
    flow = flow[:, :, ::-1].astype(np.float32)
    # flow shape (H, W, 2) valid shape (H, W)
    flow, valid = flow[:, :, :2], flow[:, :, 2]
    # Undo the 16-bit encoding used by write_flow_kitti: 64 * flow + 2**15.
    flow = (flow - 2**15) / 64.0
    return flow, valid


def write_flow_kitti(uv: np.ndarray, filename: str) -> None:
    """Write the flow to disk in the KITTI 16-bit PNG format.

    This function is modified from
    https://github.com/princeton-vl/RAFT/blob/master/core/utils/frame_utils.py.
    Copyright (c) 2020, princeton-vl
    Licensed under the BSD 3-Clause License

    Args:
        uv (ndarray): The optical flow that will be saved, shape (H, W, 2).
        filename (str): The file for saving optical flow.
    """
    # Encode as 64 * flow + 2**15. Clip to the uint16 range first: casting
    # out-of-range floats to an unsigned integer type is undefined in NumPy.
    uv = np.clip(64.0 * uv + 2**15, 0, 2**16 - 1)
    # Every pixel is marked valid in the third channel.
    valid = np.ones([uv.shape[0], uv.shape[1], 1])
    uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
    # Reverse channel order because cv2.imwrite expects BGR.
    cv2.imwrite(filename, uv[..., ::-1])


def flow_from_bytes(content: bytes, suffix: str = 'flo') -> ndarray:
    """Read dense optical flow from bytes.

    .. note::
        This function works for the FlyingChairs, FlyingThings3D, Sintel
        and FlyingChairsOcc datasets, but cannot load the data from
        ChairsSDHom.

    Args:
        content (bytes): Optical flow bytes got from files or other streams.
        suffix (str): Format of ``content``, either 'flo' or 'pfm'.
            Defaults to 'flo'.

    Returns:
        ndarray: Loaded optical flow with the shape (H, W, 2).
    """
    assert suffix in ('flo', 'pfm'), ('suffix of flow file must be `flo` '
                                      f'or `pfm`, but got {suffix}')
    if suffix == 'flo':
        return flo_from_bytes(content)
    return pfm_from_bytes(content)


def flo_from_bytes(content: bytes) -> np.ndarray:
    """Decode bytes based on flo file.

    Args:
        content (bytes): Optical flow bytes got from files or other streams.

    Raises:
        Exception: If the content does not start with the ``PIEH`` magic.

    Returns:
        ndarray: Loaded optical flow with the shape (H, W, 2). The array is
        read-only because it is created with ``np.frombuffer``.
    """
    # header in first 4 bytes
    header = content[:4]
    if header != b'PIEH':
        raise Exception('Flow file header does not contain PIEH')
    # width in second 4 bytes
    width = np.frombuffer(content[4:], np.int32, 1).squeeze()
    # height in third 4 bytes
    height = np.frombuffer(content[8:], np.int32, 1).squeeze()
    # after first 12 bytes, all bytes are flow
    flow = np.frombuffer(content[12:], np.float32, width * height * 2).reshape(
        (height, width, 2))
    return flow


def _read_pfm_header(stream: BinaryIO) -> Tuple[bool, int, int, str]:
    """Parse the three PFM header lines from a binary stream.

    Leaves ``stream`` positioned at the first byte of pixel data.

    Raises:
        Exception: If the magic line is not 'PF'/'Pf' or the dimension line
            is malformed.

    Returns:
        Tuple[bool, int, int, str]: ``(color, width, height, endian)`` where
        ``color`` is True for 3-channel 'PF' files and ``endian`` is a NumPy
        byte-order character ('<' or '>').
    """
    header = stream.readline().rstrip()
    if header == b'PF':
        color = True
    elif header == b'Pf':
        color = False
    else:
        raise Exception('Not a PFM file.')

    dim_match = re.match(rb'^(\d+)\s(\d+)\s$', stream.readline())
    if not dim_match:
        raise Exception('Malformed PFM header.')
    width, height = map(int, dim_match.groups())

    # Only the sign of the scale is used: negative means little-endian.
    scale = float(stream.readline().rstrip())
    endian = '<' if scale < 0 else '>'
    return color, width, height, endian


def _reshape_pfm_data(data: np.ndarray, color: bool, width: int,
                      height: int) -> np.ndarray:
    """Reshape flat PFM samples into an image, flipped top-to-bottom.

    Color data of shape (H, W, 3) has its last channel dropped, giving
    (H, W, 2). Grayscale data is returned as (H, W).
    """
    shape = (height, width, 3) if color else (height, width)
    # PFM scanlines are stored bottom-to-top, so flip rows vertically.
    data = np.flipud(np.reshape(data, shape))
    # Grayscale data is 2-D; slicing a third axis would raise IndexError.
    return data[:, :, :-1] if color else data


def pfm_from_bytes(content: bytes) -> np.ndarray:
    """Load the content of a file with the suffix '.pfm'.

    Args:
        content (bytes): Optical flow bytes got from files or other streams.

    Raises:
        Exception: If the PFM header is invalid.

    Returns:
        ndarray: The loaded data: (H, W, 2) for color 'PF' content, (H, W)
        for grayscale 'Pf' content. The array is read-only because it is
        created with ``np.frombuffer``.
    """
    stream = BytesIO(content)
    color, width, height, endian = _read_pfm_header(stream)
    data = np.frombuffer(stream.read(), endian + 'f')
    return _reshape_pfm_data(data, color, width, height)


def read_pfm(file: str) -> np.ndarray:
    """Load the file with the suffix '.pfm'.

    This function is modified from
    https://lmb.informatik.uni-freiburg.de/resources/datasets/IO.py
    Copyright (c) 2011, LMB, University of Freiburg.

    Args:
        file (str): The file name will be loaded.

    Raises:
        Exception: If the PFM header is invalid.

    Returns:
        ndarray: The loaded data: (H, W, 2) for color 'PF' files, (H, W) for
        grayscale 'Pf' files.
    """
    # The context manager closes the file; the original leaked the handle.
    with open(file, 'rb') as f:
        color, width, height, endian = _read_pfm_header(f)
        data = np.fromfile(f, endian + 'f')
    return _reshape_pfm_data(data, color, width, height)