File size: 2,533 Bytes
6755a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from typing import Callable

import numpy as np
import torch
from torch import nn
import cv2


def torch_wrap(img, flow, mode='bilinear', padding_mode='zeros', align_corners=True, outlier_func: Callable=None):
    """Warp an image/tensor backward according to an optical flow.

    Output pixel (x, y) is sampled from ``img`` at
    ``(x + flow[:, 0, y, x], y + flow[:, 1, y, x])`` with
    :func:`torch.nn.functional.grid_sample`.

    References:
        1. https://github.com/safwankdb/ReCoNet-PyTorch/blob/master/utilities.py
        2. https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html

    Args:
        img: [B, C, H, W] floating-point tensor, values in [0, 1] or [0, 255].
        flow: [B, 2, H, W] displacement in pixels. Channel 0 is the x (width)
            offset and channel 1 is the y (height) offset. It is cast to
            ``img``'s dtype/device before use.
        mode, padding_mode, align_corners: forwarded to ``grid_sample``. The
            pixel -> [-1, 1] normalization matches the chosen ``align_corners``.
        outlier_func: optional callable. It is called with the output shape
            and must return a tensor. Its values are used wherever the mask
            from ``find_outlier`` is 0; where the mask is 1, the warped value
            is kept. NOTE(review): confirm the mask polarity of
            ``find_outlier``.

    Returns:
        [B, C, H, W] tensor with ``img``'s dtype.
    """
    B, C, H, W = img.size()
    # Pixel-coordinate mesh grid: channel 0 = column index (x), channel 1 = row index (y).
    xx = torch.arange(0, W).view(1, 1, 1, W).expand(B, 1, H, W)
    yy = torch.arange(0, H).view(1, 1, H, 1).expand(B, 1, H, W)
    grid = torch.cat((xx, yy), 1).to(img.device, dtype=img.dtype)
    # grid_sample requires the grid to share the input's dtype, so bring flow
    # onto img's dtype/device instead of letting type promotion decide.
    vgrid = grid + flow.to(img.device, dtype=img.dtype)
    if outlier_func is not None:
        from ..utils.torch_util import find_outlier
        # Evaluated on absolute pixel coordinates, before normalization.
        mask = find_outlier(vgrid).to(img.device)
        mask = mask.unsqueeze(dim=1).repeat(1, C, 1, 1)
    # Map pixel coordinates to [-1, 1] using the inverse of the mapping that
    # grid_sample applies for the selected align_corners value.
    gx = vgrid[:, 0, :, :]
    gy = vgrid[:, 1, :, :]
    if align_corners:
        # -1 / 1 are the centres of the corner pixels.
        gx = 2.0 * gx / max(W - 1, 1) - 1.0
        gy = 2.0 * gy / max(H - 1, 1) - 1.0
    else:
        # -1 / 1 are the outer edges of the corner pixels (pixel centre i -> (2i + 1) / size - 1).
        gx = (2.0 * gx + 1.0) / W - 1.0
        gy = (2.0 * gy + 1.0) / H - 1.0
    sample_grid = torch.stack((gx, gy), dim=-1)  # [B, H, W, 2], (x, y) order as grid_sample expects
    output = nn.functional.grid_sample(img, sample_grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners)
    if outlier_func is not None:
        outlier = outlier_func(output.shape).to(img.device)
        output = mask * output + (1 - mask) * outlier
    output = output.to(dtype=img.dtype)
    return output


def opencv_wrap(img: np.ndarray, flow: np.ndarray, outlier_func: Callable=None) -> np.ndarray:
    """Warp ``img`` backward with ``flow`` using ``cv2.remap`` (bilinear).

    Output pixel (y, x) is sampled from ``img`` at
    ``(x + flow[y, x, 0], y + flow[y, x, 1])``.

    Args:
        img (np.ndarray): source image, HxWx3, [0-255].
        flow (np.ndarray): per-pixel displacement in pixels, HxWx2. Channel 0
            is the x offset and channel 1 is the y offset. It is NOT modified;
            a float32 copy is used.
        outlier_func (Callable, optional): called with the output shape. Its
            values are used wherever the mask from ``find_outlier`` is 0.
            NOTE(review): confirm the mask polarity of ``find_outlier``.

    Returns:
        np.ndarray: warped image with the same height/width/channels as ``img``.
    """
    h, w = flow.shape[:2]
    # Absolute sampling map on a float32 copy: cv2.remap needs a CV_32FC2 map,
    # and the caller's flow array must not be changed in place.
    coords = flow.astype(np.float32)
    coords[:, :, 0] += np.arange(w, dtype=np.float32)
    coords[:, :, 1] += np.arange(h, dtype=np.float32)[:, np.newaxis]
    output = cv2.remap(img, coords, None, cv2.INTER_LINEAR)
    if outlier_func is not None:
        # Imported lazily so the plain warp does not depend on this helper.
        from ..utils.vision_util import find_outlier
        outlier = outlier_func(output.shape)
        # NOTE(review): (1, 2, 0) turns the HxWx2 map into Wx2xH; a channel-first
        # 2xHxW layout would be (2, 0, 1). Kept unchanged — confirm what
        # find_outlier expects.
        mask = find_outlier(np.transpose(coords, (1, 2, 0)))
        if output.ndim == 3:
            # Broadcast the per-pixel HxW mask over the image's channels
            # (not over flow's 2 channels).
            mask = mask[:, :, np.newaxis]
        output = mask * output + (1 - mask) * outlier
    return output