"""Optical-flow helpers built on GMFlow.

Provides backward warping by a flow field, a forward/backward consistency
(occlusion) check, and a small ``FlowCalc`` wrapper around a GMFlow model.
Importing this module downloads the GMFlow Sintel checkpoint from the
Hugging Face Hub and constructs a module-level ``flow_calc`` instance.
"""
import os
import sys

import numpy as np
import torch
import torch.nn.functional as F

# Make the vendored GMFlow package importable before importing from it.
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
gmflow_dir = os.path.join(parent_dir, 'gmflow_module')
sys.path.insert(0, gmflow_dir)

from gmflow.gmflow import GMFlow  # noqa: E702 E402 F401
from utils.utils import InputPadder  # noqa: E702 E402

import huggingface_hub  # noqa: E402

repo_name = 'Anonymous-sub/Rerender'

global_device = 'cuda' if torch.cuda.is_available() else 'cpu'
# hf_hub_download mirrors the repo-relative filename inside local_dir and
# returns the resulting local path; load the checkpoint through this path.
gmflow_path = huggingface_hub.hf_hub_download(
    repo_name, 'models/gmflow_sintel-0c07dcb3.pth', local_dir='models')


def coords_grid(b, h, w, homogeneous=False, device=None):
    """Return a pixel-coordinate grid.

    Args:
        b: batch size.
        h, w: grid height and width.
        homogeneous: if True, append a channel of ones.
        device: optional device to place the grid on.

    Returns:
        Float tensor of shape [B, 2, H, W] ([B, 3, H, W] if ``homogeneous``).
        Channel 0 holds the x (column) index and channel 1 the y (row) index.
    """
    # Broadcast row/column indices. This is identical to torch.meshgrid with
    # 'ij' indexing, but avoids meshgrid's missing-`indexing` warning.
    y = torch.arange(h).view(h, 1).expand(h, w)  # [H, W]
    x = torch.arange(w).view(1, w).expand(h, w)  # [H, W]
    stacks = [x, y]
    if homogeneous:
        stacks.append(torch.ones_like(x))  # [H, W]
    grid = torch.stack(stacks, dim=0).float()  # [2, H, W] or [3, H, W]
    grid = grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W] or [B, 3, H, W]
    if device is not None:
        # Honour the requested device rather than forcing global_device.
        grid = grid.to(device)
    return grid


def bilinear_sample(img,
                    sample_coords,
                    mode='bilinear',
                    padding_mode='zeros',
                    return_mask=False):
    """Sample ``img`` at absolute pixel coordinates via ``F.grid_sample``.

    Args:
        img: [B, C, H_in, W_in] tensor to sample from.
        sample_coords: pixel coordinates (x, y), either [B, 2, H, W] or
            [B, H, W, 2]. They are normalised using the H and W of
            ``sample_coords`` itself.
        mode, padding_mode: forwarded to ``F.grid_sample``.
        return_mask: also return a mask of in-bounds sample locations.

    Returns:
        The sampled [B, C, H, W] tensor. When ``return_mask`` is True, returns
        ``(sampled, mask)``, where ``mask`` is a [B, H, W] bool tensor that is
        True where the normalised coordinates lie within [-1, 1].
    """
    if sample_coords.size(1) != 2:  # [B, H, W, 2]
        sample_coords = sample_coords.permute(0, 3, 1, 2)

    _, _, h, w = sample_coords.shape

    # Normalize to [-1, 1] to match align_corners=True below.
    x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
    y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1

    grid = torch.stack([x_grid, y_grid], dim=-1)  # [B, H, W, 2]

    img = F.grid_sample(img,
                        grid,
                        mode=mode,
                        padding_mode=padding_mode,
                        align_corners=True)

    if return_mask:
        mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (
            y_grid <= 1)  # [B, H, W]

        return img, mask

    return img


def flow_warp(feature,
              flow,
              mask=False,
              mode='bilinear',
              padding_mode='zeros'):
    """Backward-warp ``feature`` by ``flow``.

    Output pixel p is sampled from ``feature`` at p + flow(p).

    Args:
        feature: [B, C, H, W] tensor.
        flow: [B, 2, H, W] flow (x, y) in pixels, on the sampling device.
        mask: if True, also return the in-bounds mask from bilinear_sample.
        mode, padding_mode: forwarded to ``F.grid_sample``.
    """
    b, _, h, w = feature.size()
    assert flow.size(1) == 2

    grid = coords_grid(b, h, w, device=flow.device) + flow  # [B, 2, H, W]

    return bilinear_sample(feature,
                           grid,
                           mode=mode,
                           padding_mode=padding_mode,
                           return_mask=mask)


def forward_backward_consistency_check(fwd_flow,
                                       bwd_flow,
                                       alpha=0.01,
                                       beta=0.5):
    """Flag pixels whose forward and backward flows disagree.

    A pixel is marked occluded when |f + warped(b)| exceeds
    ``alpha * (|f| + |b|) + beta``.

    Args:
        fwd_flow, bwd_flow: [B, 2, H, W] flow tensors.
        alpha, beta: threshold constants. The defaults follow UnFlow
            (https://arxiv.org/abs/1711.07837).

    Returns:
        ``(fwd_occ, bwd_occ)``: [B, H, W] float tensors, 1.0 where occluded.
    """
    assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
    assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
    flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow,
                                                        dim=1)  # [B, H, W]

    warped_bwd_flow = flow_warp(bwd_flow, fwd_flow)  # [B, 2, H, W]
    warped_fwd_flow = flow_warp(fwd_flow, bwd_flow)  # [B, 2, H, W]

    diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1)  # [B, H, W]
    diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)

    threshold = alpha * flow_mag + beta

    fwd_occ = (diff_fwd > threshold).float()  # [B, H, W]
    bwd_occ = (diff_bwd > threshold).float()

    return fwd_occ, bwd_occ


@torch.no_grad()
def get_warped_and_mask(flow_model,
                        image1,
                        image2,
                        image3=None,
                        pixel_consistency=False):
    """Estimate flow between two frames and warp ``image3`` into frame 2.

    Args:
        flow_model: a GMFlow model.
        image1, image2: unbatched image tensors (presumably [C, H, W]).
        image3: batched tensor to warp. Defaults to ``image1``, batched and
            moved to ``global_device``.
        pixel_consistency: additionally mark pixels whose warped ``image1``
            colour differs from ``image2`` by more than 255 * 0.25 on
            average over channels.

    Returns:
        ``(warped_image3, bwd_occ, bwd_flow)``. ``bwd_flow`` is [1, 2, H, W].
        ``bwd_occ`` is [1, H, W], or [1, 1, H, W] when ``pixel_consistency``
        is True (that shape is kept for existing callers).
    """
    # Unpadded, batched, on-device copies. These are used for the pixel check
    # and the default warp source, so their size matches the unpadded flow.
    src1 = image1[None].to(global_device)
    src2 = image2[None].to(global_device)
    if image3 is None:
        image3 = src1
    padder = InputPadder(image1.shape, padding_factor=8)
    pad1, pad2 = padder.pad(src1, src2)
    results_dict = flow_model(pad1,
                              pad2,
                              attn_splits_list=[2],
                              corr_radius_list=[-1],
                              prop_radius_list=[-1],
                              pred_bidir_flow=True)
    flow_pr = results_dict['flow_preds'][-1]  # [B, 2, H, W]
    fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0)  # [1, 2, H, W]
    bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0)  # [1, 2, H, W]
    _, bwd_occ = forward_backward_consistency_check(
        fwd_flow, bwd_flow)  # [1, H, W] float
    if pixel_consistency:
        warped_image1 = flow_warp(src1, bwd_flow)
        color_occ = (torch.abs(src2 - warped_image1).mean(dim=1) >
                     255 * 0.25).float()
        bwd_occ = torch.clamp(bwd_occ + color_occ, 0, 1).unsqueeze(0)
    warped_results = flow_warp(image3, bwd_flow)
    return warped_results, bwd_occ, bwd_flow


class FlowCalc():
    """Thin wrapper that owns a GMFlow model and computes/applies flows."""

    def __init__(self, model_path='./models/gmflow_sintel-0c07dcb3.pth'):
        """Build GMFlow on ``global_device`` and load weights from ``model_path``."""
        flow_model = GMFlow(
            feature_channels=128,
            num_scales=1,
            upsample_factor=8,
            num_head=1,
            attention_type='swin',
            ffn_dim_expansion=4,
            num_transformer_layers=6,
        ).to(global_device)

        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        # Checkpoints may wrap the state dict under a 'model' key.
        weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
        flow_model.load_state_dict(weights, strict=False)
        flow_model.eval()
        self.model = flow_model

    @torch.no_grad()
    def get_flow(self, image1, image2, save_path=None):
        """Return the backward flow between two HWC numpy images as [1, 2, H, W].

        If ``save_path`` exists, the flow is loaded from it instead (as a CPU
        tensor). A freshly computed flow is on ``global_device``. When
        ``save_path`` is given, a fresh flow is also written with np.save.
        NOTE: np.save appends '.npy' when the path lacks it. In that case the
        exists() check never hits, so pass a path ending in '.npy'.
        """
        if save_path is not None and os.path.exists(save_path):
            bwd_flow = read_flow(save_path)
            return bwd_flow

        image1 = torch.from_numpy(image1).permute(2, 0, 1).float()
        image2 = torch.from_numpy(image2).permute(2, 0, 1).float()
        padder = InputPadder(image1.shape, padding_factor=8)
        image1, image2 = padder.pad(image1[None].to(global_device),
                                    image2[None].to(global_device))
        results_dict = self.model(image1,
                                  image2,
                                  attn_splits_list=[2],
                                  corr_radius_list=[-1],
                                  prop_radius_list=[-1],
                                  pred_bidir_flow=True)
        flow_pr = results_dict['flow_preds'][-1]  # [B, 2, H, W]
        bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0)  # [1, 2, H, W]
        if save_path is not None:
            flow_np = bwd_flow.cpu().numpy()
            np.save(save_path, flow_np)

        return bwd_flow

    def warp(self, img, flow, mode='bilinear'):
        """Warp an HWC (or HW) numpy image by ``flow`` [1, 2, H, W].

        Returns a numpy array with the same dtype and dimensionality as
        ``img``.
        """
        expand = False
        if len(img.shape) == 2:
            expand = True
            img = np.expand_dims(img, 2)

        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
        dtype = img.dtype
        # Sample on the flow's device. A freshly computed flow lives on
        # global_device, while the image starts on the CPU.
        img = img.to(device=flow.device, dtype=torch.float)
        res = flow_warp(img, flow, mode=mode)
        res = res.to(dtype)
        res = res[0].cpu().permute(1, 2, 0).numpy()
        if expand:
            res = res[:, :, 0]
        return res


def read_flow(save_path):
    """Load a flow saved by ``FlowCalc.get_flow`` as a CPU tensor."""
    flow_np = np.load(save_path)
    bwd_flow = torch.from_numpy(flow_np)
    return bwd_flow


# Load from the path hf_hub_download actually wrote to, not the class default.
flow_calc = FlowCalc(model_path=gmflow_path)