kevinwang676's picture
Upload folder using huggingface_hub
6755a2d verified
# Copyright (c) OpenMMLab. All rights reserved.
# from mmflow
import re
from io import BytesIO
from typing import Tuple
import cv2
import matplotlib.pyplot as plt
import mmcv
import numpy as np
from numpy import ndarray
def read_flow(name: str) -> np.ndarray:
    """Read flow file with the suffix '.flo'.

    This function is modified from
    https://lmb.informatik.uni-freiburg.de/resources/datasets/IO.py
    Copyright (c) 2011, LMB, University of Freiburg.

    The layout read here is a 4-byte ``PIEH`` tag, then the width and the
    height as two int32 values, then ``height * width * 2`` float32 values.

    Args:
        name (str): Optical flow file path.

    Returns:
        ndarray: Optical flow with the shape (H, W, 2) and dtype float32.

    Raises:
        Exception: If the first 4 bytes of the file are not ``PIEH``.
        ValueError: If the file ends before the size fields or the flow
            values are complete.
    """
    with open(name, 'rb') as f:
        # Compare raw bytes: decoding an arbitrary binary header as UTF-8
        # could raise UnicodeDecodeError instead of the intended error.
        if f.read(4) != b'PIEH':
            raise Exception('Flow file header does not contain PIEH')

        dims = np.fromfile(f, np.int32, 2)
        if dims.size != 2:
            raise ValueError(
                f'Flow file {name} is truncated: missing width/height')
        # Plain Python ints keep the element count from overflowing int32.
        width, height = int(dims[0]), int(dims[1])
        count = width * height * 2

        flow = np.fromfile(f, np.float32, count)
        if flow.size != count:
            raise ValueError(
                f'Flow file {name} is truncated: expected {count} values, '
                f'got {flow.size}')
    return flow.reshape((height, width, 2))
def write_flow(flow: np.ndarray, flow_file: str) -> None:
    """Save an optical flow array to disk in the '.flo' layout.

    This function is modified from
    https://lmb.informatik.uni-freiburg.de/resources/datasets/IO.py
    Copyright (c) 2011, LMB, University of Freiburg.

    The file consists of the 4-byte ``PIEH`` tag, the width and height as
    int32 values, and the flow values converted to float32.

    Args:
        flow (ndarray): The optical flow that will be saved. Its first two
            axes are written as height and width respectively.
        flow_file (str): The file for saving optical flow.
    """
    with open(flow_file, 'wb') as out:
        # magic tag identifying the format
        out.write(b'PIEH')
        # size header: width (axis 1) comes before height (axis 0)
        width = flow.shape[1]
        height = flow.shape[0]
        size_header = np.array([width, height], dtype=np.int32)
        size_header.tofile(out)
        # payload is always stored as float32, whatever the input dtype
        flow.astype(np.float32).tofile(out)
def visualize_flow(flow: np.ndarray, save_file: str = None) -> np.ndarray:
    """Render an optical flow array as a color-coded image.

    Args:
        flow (ndarray): The optical flow to be rendered.
        save_file (str, optional): If given (and non-empty), the rendered
            image is also written to this path with ``plt.imsave``.
            Defaults to None.

    Returns:
        ndarray: uint8 flow map image with RGB order.
    """
    # mmcv.flow2rgb returns values in [0, 1.] with type np.float32, so the
    # result is rescaled to the 8-bit range before the uint8 conversion
    rgb_unit = mmcv.flow2rgb(flow)
    flow_map = np.uint8(rgb_unit * 255.)
    if not save_file:
        return flow_map
    plt.imsave(save_file, flow_map)
    return flow_map
def render_color_wheel(save_file: str = 'color_wheel.png') -> np.ndarray:
    """Render the color wheel used to visualize optical flow.

    A 151 x 151 flow field is built in which every pixel stores its (x, y)
    offset from the center pixel (75, 75); that field is then rendered with
    ``visualize_flow``.

    Args:
        save_file (str): The saved file name. Defaults to 'color_wheel.png'.

    Returns:
        ndarray: color wheel image.
    """
    center_x, center_y = 75, 75
    height, width = 151, 151

    # per-column and per-row offsets from the center pixel
    dx = np.arange(width) - center_x
    dy = np.arange(height) - center_y

    flow = np.zeros((height, width, 2), dtype=np.float32)
    # broadcasting repeats the row of x offsets down every row, and the
    # column of y offsets across every column
    flow[:, :, 0] = dx[np.newaxis, :]
    flow[:, :, 1] = dy[:, np.newaxis]
    return visualize_flow(flow, save_file)
def read_flow_kitti(name: str) -> Tuple[np.ndarray, np.ndarray]:
    """Read sparse flow file from KITTI dataset.

    This function is modified from
    https://github.com/princeton-vl/RAFT/blob/master/core/utils/frame_utils.py.
    Copyright (c) 2020, princeton-vl
    Licensed under the BSD 3-Clause License

    Args:
        name (str): The flow file

    Returns:
        Tuple[ndarray, ndarray]: flow with the shape (H, W, 2) and valid
            map with the shape (H, W), both float32.

    Raises:
        OSError: If OpenCV cannot read ``name`` as an image.
    """
    # IMREAD_ANYDEPTH tells OpenCV not to change the image depth (16bit)
    flow = cv2.imread(name, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
    if flow is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the path rather than on the subscript below
        raise OSError(f'Failed to read KITTI flow file: {name}')
    # reverse the channel order of the decoded image
    flow = flow[:, :, ::-1].astype(np.float32)
    # flow shape (H, W, 2) valid shape (H, W)
    flow, valid = flow[:, :, :2], flow[:, :, 2]
    # undo the fixed-point encoding: stored value = 64 * flow + 2**15
    flow = (flow - 2**15) / 64.0
    return flow, valid
def write_flow_kitti(uv: np.ndarray, filename: str):
    """Save optical flow to disk as a 16-bit, 3-channel image.

    This function is modified from
    https://github.com/princeton-vl/RAFT/blob/master/core/utils/frame_utils.py.
    Copyright (c) 2020, princeton-vl
    Licensed under the BSD 3-Clause License

    Args:
        uv (ndarray): The optical flow that will be saved.
        filename (str): The file for saving optical flow.
    """
    # fixed-point encoding: scale by 64 and shift by 2**15 (the inverse of
    # the decoding in read_flow_kitti)
    encoded = 64.0 * uv + 2**15
    rows, cols = encoded.shape[0], encoded.shape[1]
    # third channel is filled with ones for every pixel
    valid = np.ones([rows, cols, 1])
    packed = np.concatenate([encoded, valid], axis=-1).astype(np.uint16)
    # reverse the channel order before handing the image to OpenCV
    cv2.imwrite(filename, packed[..., ::-1])
def flow_from_bytes(content: bytes, suffix: str = 'flo') -> ndarray:
    """Decode dense optical flow from an in-memory byte string.

    .. note::
        This load optical flow function works for FlyingChairs,
        FlyingThings3D, Sintel, FlyingChairsOcc datasets, but cannot load
        the data from ChairsSDHom.

    Args:
        content (bytes): Optical flow bytes got from files or other streams.
        suffix (str): Format of ``content``, either 'flo' or 'pfm'.
            Defaults to 'flo'.

    Returns:
        ndarray: Loaded optical flow with the shape (H, W, 2).
    """
    assert suffix in ('flo', 'pfm'), (
        'suffix of flow file must be `flo` '
        f'or `pfm`, but got {suffix}')

    # pick the decoder matching the declared format
    decoder = flo_from_bytes if suffix == 'flo' else pfm_from_bytes
    return decoder(content)
def flo_from_bytes(content: bytes) -> np.ndarray:
    """Decode bytes based on flo file.

    The layout is a 4-byte ``PIEH`` tag, the width and height as two int32
    values, then ``height * width * 2`` float32 flow values. Bytes after
    the flow values are ignored.

    Args:
        content (bytes): Optical flow bytes got from files or other streams.

    Returns:
        ndarray: Loaded optical flow with the shape (H, W, 2). The array
            is a view over ``content`` (no copy is made), so it is
            read-only when ``content`` is an immutable ``bytes`` object.

    Raises:
        Exception: If the first 4 bytes are not ``PIEH``.
        ValueError: If the size fields are negative or ``content`` is too
            short for the declared size.
    """
    # header in first 4 bytes
    if content[:4] != b'PIEH':
        raise Exception('Flow file header does not contain PIEH')

    # width and height are the two int32 values at bytes 4-12; ``offset``
    # reads them in place instead of slicing (and copying) the buffer
    dims = np.frombuffer(content, np.int32, 2, offset=4)
    # plain Python ints keep the element count from overflowing int32
    width, height = int(dims[0]), int(dims[1])
    if width < 0 or height < 0:
        # a negative count would make frombuffer read the whole buffer
        raise ValueError(
            f'Invalid flow size in header: width={width}, height={height}')

    # after first 12 bytes, all bytes are flow
    flow = np.frombuffer(content, np.float32, width * height * 2, offset=12)
    return flow.reshape((height, width, 2))
def pfm_from_bytes(content: bytes) -> np.ndarray:
    """Load the file with the suffix '.pfm'.

    The header consists of three text lines: the type tag (``PF`` for a
    3-channel image, ``Pf`` for a single-channel one), ``width height``,
    and a scale whose sign selects the byte order of the float32 data that
    follows (negative: little-endian, otherwise big-endian).

    Args:
        content (bytes): Optical flow bytes got from files or other streams.

    Returns:
        ndarray: The loaded data, flipped vertically. For a ``PF`` file the
            last channel is dropped, giving the shape (H, W, 2); for a
            ``Pf`` file the shape is (H, W).

    Raises:
        Exception: If the type tag is neither ``PF`` nor ``Pf``, or the
            dimension line is malformed.
    """
    file = BytesIO(content)

    header = file.readline().rstrip()
    if header == b'PF':
        color = True
    elif header == b'Pf':
        color = False
    else:
        raise Exception('Not a PFM file.')

    dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
    if not dim_match:
        raise Exception('Malformed PFM header.')
    width, height = map(int, dim_match.groups())

    # only the sign of the scale is used: it encodes the byte order
    scale = float(file.readline().rstrip())
    endian = '<' if scale < 0 else '>'

    data = np.frombuffer(file.read(), endian + 'f')
    shape = (height, width, 3) if color else (height, width)
    data = np.flipud(np.reshape(data, shape))
    # Single-channel data has no channel axis to slice; indexing it with
    # three indices would raise IndexError, so return it as is.
    return data[:, :, :-1] if color else data
def read_pfm(file: str) -> np.ndarray:
    """Load the file with the suffix '.pfm'.

    This function is modified from
    https://lmb.informatik.uni-freiburg.de/resources/datasets/IO.py
    Copyright (c) 2011, LMB, University of Freiburg.

    The header consists of three text lines: the type tag (``PF`` for a
    3-channel image, ``Pf`` for a single-channel one), ``width height``,
    and a scale whose sign selects the byte order of the float32 data that
    follows (negative: little-endian, otherwise big-endian).

    Args:
        file (str): The file name will be loaded

    Returns:
        ndarray: The loaded data, flipped vertically. For a ``PF`` file the
            last channel is dropped, giving the shape (H, W, 2); for a
            ``Pf`` file the shape is (H, W).

    Raises:
        Exception: If the type tag is neither ``PF`` nor ``Pf``, or the
            dimension line is malformed.
    """
    # the context manager closes the handle on every path, including the
    # error paths below
    with open(file, 'rb') as f:
        # Header lines are compared as raw bytes so that a non-ASCII
        # (i.e. non-PFM) file is reported as such rather than failing
        # with UnicodeDecodeError.
        header = f.readline().rstrip()
        if header == b'PF':
            color = True
        elif header == b'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        dim_match = re.match(rb'^(\d+)\s(\d+)\s$', f.readline())
        if not dim_match:
            raise Exception('Malformed PFM header.')
        width, height = map(int, dim_match.groups())

        # only the sign of the scale is used: it encodes the byte order
        scale = float(f.readline().rstrip())
        endian = '<' if scale < 0 else '>'

        data = np.fromfile(f, endian + 'f')

    shape = (height, width, 3) if color else (height, width)
    data = np.flipud(np.reshape(data, shape))
    # Single-channel data has no channel axis to slice; indexing it with
    # three indices would raise IndexError, so return it as is.
    return data[:, :, :-1] if color else data