diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..6bd4034cf3a1e77c58207b511e1d05e122c5c3e2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,131 @@ +.vscode + +# ignored files +version.py + +# ignored files with suffix +*.html +# *.png +# *.jpeg +# *.jpg +# *.gif +*.pt +*.pth +*.dat +*.zip + +# template + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# project +experiments_model/ +unreleased/ +results_eval/ +results/ +*debug* +*old* +*.sh \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..bb4d0f5aa33a710069e0b1f3675707d60e6d4a0b --- /dev/null +++ b/LICENSE @@ -0,0 +1,14 @@ +# S-Lab License 1.0 + +Copyright 2023 S-Lab + +Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work. + + +--- +For the commercial use of the code, please consult Prof. Chen Change Loy (ccloy@ntu.edu.sg) \ No newline at end of file diff --git a/RAFT/__init__.py b/RAFT/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e7179ea3ce4ad81425c619772d4bc47bc7ceea3a --- /dev/null +++ b/RAFT/__init__.py @@ -0,0 +1,2 @@ +# from .demo import RAFT_infer +from .raft import RAFT diff --git a/RAFT/corr.py b/RAFT/corr.py new file mode 100755 index 0000000000000000000000000000000000000000..449dbd963b8303eda242a65063ca857b95475721 --- /dev/null +++ b/RAFT/corr.py @@ -0,0 +1,111 @@ +import torch +import torch.nn.functional as F +from .utils.utils import bilinear_sampler, coords_grid + +try: + import alt_cuda_corr +except: + # alt_cuda_corr is not compiled + pass + + +class CorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + + # all pairs correlation + corr = CorrBlock.corr(fmap1, fmap2) + + batch, h1, w1, dim, h2, w2 = corr.shape + corr = corr.reshape(batch*h1*w1, dim, h2, w2) + + self.corr_pyramid.append(corr) + for i in range(self.num_levels-1): + corr = F.avg_pool2d(corr, 2, stride=2) + self.corr_pyramid.append(corr) + + def __call__(self, coords): + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + out_pyramid = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + dx = torch.linspace(-r, r, 2*r+1) + dy = torch.linspace(-r, r, 2*r+1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) + + centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl) + corr = corr.view(batch, h1, w1, -1) + out_pyramid.append(corr) + + out = torch.cat(out_pyramid, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht*wd) + fmap2 = fmap2.view(batch, dim, ht*wd) + + corr = torch.matmul(fmap1.transpose(1,2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) + + +class CorrLayer(torch.autograd.Function): + @staticmethod + def forward(ctx, fmap1, fmap2, coords, r): + fmap1 = fmap1.contiguous() + fmap2 = fmap2.contiguous() + coords = coords.contiguous() + ctx.save_for_backward(fmap1, fmap2, coords) + ctx.r = r + corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r) + return corr + + @staticmethod + def backward(ctx, grad_corr): + fmap1, fmap2, coords = ctx.saved_tensors + grad_corr = grad_corr.contiguous() + fmap1_grad, fmap2_grad, coords_grad = \ + correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r) + return fmap1_grad, fmap2_grad, coords_grad, None + + +class AlternateCorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + + self.pyramid = [(fmap1, fmap2)] + for i in range(self.num_levels): + fmap1 = F.avg_pool2d(fmap1, 2, stride=2) + fmap2 = F.avg_pool2d(fmap2, 2, stride=2) + self.pyramid.append((fmap1, fmap2)) + + def __call__(self, coords): + + coords = coords.permute(0, 2, 3, 1) + B, H, W, _ = coords.shape + + corr_list = [] + for i in range(self.num_levels): + r = self.radius + fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1) + fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1) + + coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() + corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r) + corr_list.append(corr.squeeze(1)) + + corr = torch.stack(corr_list, dim=1) + corr = corr.reshape(B, -1, H, W) + return corr / 16.0 diff --git a/RAFT/datasets.py b/RAFT/datasets.py new file mode 100755 index 0000000000000000000000000000000000000000..3411fdacfb900024005e8997d07c600e963a95ca --- /dev/null +++ b/RAFT/datasets.py @@ -0,0 +1,235 @@ +# Data loading based on https://github.com/NVIDIA/flownet2-pytorch + +import numpy as np +import torch +import torch.utils.data as data +import torch.nn.functional as F + +import os +import math +import random +from glob import glob +import os.path as osp + +from utils import frame_utils +from utils.augmentor import FlowAugmentor, SparseFlowAugmentor + + +class FlowDataset(data.Dataset): + def __init__(self, aug_params=None, sparse=False): + self.augmentor = None + self.sparse = sparse + if aug_params is not None: + if sparse: + self.augmentor = SparseFlowAugmentor(**aug_params) + else: + self.augmentor = FlowAugmentor(**aug_params) + + self.is_test = False + self.init_seed = False + self.flow_list = [] + self.image_list = [] + self.extra_info = [] + + def __getitem__(self, index): + + if self.is_test: + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + img1 = np.array(img1).astype(np.uint8)[..., :3] + img2 = np.array(img2).astype(np.uint8)[..., :3] + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + return img1, img2, self.extra_info[index] + + if not self.init_seed: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + torch.manual_seed(worker_info.id) + np.random.seed(worker_info.id) + random.seed(worker_info.id) + self.init_seed = True + + index = index % len(self.image_list) + valid = None + if self.sparse: + flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) + else: + flow = frame_utils.read_gen(self.flow_list[index]) + + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + + flow = np.array(flow).astype(np.float32) + img1 = np.array(img1).astype(np.uint8) + img2 = np.array(img2).astype(np.uint8) + + # grayscale images + if len(img1.shape) == 2: + img1 = np.tile(img1[...,None], (1, 1, 3)) + img2 = np.tile(img2[...,None], (1, 1, 3)) + else: + img1 = img1[..., :3] + img2 = img2[..., :3] + + if self.augmentor is not None: + if self.sparse: + img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) + else: + img1, img2, flow = self.augmentor(img1, img2, flow) + + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + flow = torch.from_numpy(flow).permute(2, 0, 1).float() + + if valid is not None: + valid = torch.from_numpy(valid) + else: + valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) + + return img1, img2, flow, valid.float() + + + def __rmul__(self, v): + self.flow_list = v * self.flow_list + self.image_list = v * self.image_list + return self + + def __len__(self): + return len(self.image_list) + + +class MpiSintel(FlowDataset): + def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'): + super(MpiSintel, self).__init__(aug_params) + flow_root = osp.join(root, split, 'flow') + image_root = osp.join(root, split, dstype) + + if split == 'test': + self.is_test = True + + for scene in os.listdir(image_root): + image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) + for i in range(len(image_list)-1): + self.image_list += [ [image_list[i], image_list[i+1]] ] + self.extra_info += [ (scene, i) ] # scene and frame_id + + if split != 'test': + self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) + + +class FlyingChairs(FlowDataset): + def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'): + super(FlyingChairs, self).__init__(aug_params) + + images = sorted(glob(osp.join(root, '*.ppm'))) + flows = sorted(glob(osp.join(root, '*.flo'))) + assert (len(images)//2 == len(flows)) + + split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) + for i in range(len(flows)): + xid = split_list[i] + if (split=='training' and xid==1) or (split=='validation' and xid==2): + self.flow_list += [ flows[i] ] + self.image_list += [ [images[2*i], images[2*i+1]] ] + + +class FlyingThings3D(FlowDataset): + def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'): + super(FlyingThings3D, self).__init__(aug_params) + + for cam in ['left']: + for direction in ['into_future', 'into_past']: + image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) + image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) + + flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) + flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) + + for idir, fdir in zip(image_dirs, flow_dirs): + images = sorted(glob(osp.join(idir, '*.png')) ) + flows = sorted(glob(osp.join(fdir, '*.pfm')) ) + for i in range(len(flows)-1): + if direction == 'into_future': + self.image_list += [ [images[i], images[i+1]] ] + self.flow_list += [ flows[i] ] + elif direction == 'into_past': + self.image_list += [ [images[i+1], images[i]] ] + self.flow_list += [ flows[i+1] ] + + +class KITTI(FlowDataset): + def __init__(self, aug_params=None, split='training', root='datasets/KITTI'): + super(KITTI, self).__init__(aug_params, sparse=True) + if split == 'testing': + self.is_test = True + + root = osp.join(root, split) + images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) + images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) + + for img1, img2 in zip(images1, images2): + frame_id = img1.split('/')[-1] + self.extra_info += [ [frame_id] ] + self.image_list += [ [img1, img2] ] + + if split == 'training': + self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) + + +class HD1K(FlowDataset): + def __init__(self, aug_params=None, root='datasets/HD1k'): + super(HD1K, self).__init__(aug_params, sparse=True) + + seq_ix = 0 + while 1: + flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) + images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) + + if len(flows) == 0: + break + + for i in range(len(flows)-1): + self.flow_list += [flows[i]] + self.image_list += [ [images[i], images[i+1]] ] + + seq_ix += 1 + + +def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): + """ Create the data loader for the corresponding trainign set """ + + if args.stage == 'chairs': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} + train_dataset = FlyingChairs(aug_params, split='training') + + elif args.stage == 'things': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} + clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') + final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') + train_dataset = clean_dataset + final_dataset + + elif args.stage == 'sintel': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} + things = FlyingThings3D(aug_params, dstype='frames_cleanpass') + sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') + sintel_final = MpiSintel(aug_params, split='training', dstype='final') + + if TRAIN_DS == 'C+T+K+S+H': + kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) + hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) + train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things + + elif TRAIN_DS == 'C+T+K/S': + train_dataset = 100*sintel_clean + 100*sintel_final + things + + elif args.stage == 'kitti': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} + train_dataset = KITTI(aug_params, split='training') + + train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, + pin_memory=False, shuffle=True, num_workers=4, drop_last=True) + + print('Training with %d image pairs' % len(train_dataset)) + return train_loader + diff --git a/RAFT/demo.py b/RAFT/demo.py new file mode 100755 index 0000000000000000000000000000000000000000..096963bdbb36aed3df673f131d6e044d8c6f95ea --- /dev/null +++ b/RAFT/demo.py @@ -0,0 +1,79 @@ +import sys +import argparse +import os +import cv2 +import glob +import numpy as np +import torch +from PIL import Image + +from .raft import RAFT +from .utils import flow_viz +from .utils.utils import InputPadder + + + +DEVICE = 'cuda' + +def load_image(imfile): + img = np.array(Image.open(imfile)).astype(np.uint8) + img = torch.from_numpy(img).permute(2, 0, 1).float() + return img + + +def load_image_list(image_files): + images = [] + for imfile in sorted(image_files): + images.append(load_image(imfile)) + + images = torch.stack(images, dim=0) + images = images.to(DEVICE) + + padder = InputPadder(images.shape) + return padder.pad(images)[0] + + +def viz(img, flo): + img = img[0].permute(1,2,0).cpu().numpy() + flo = flo[0].permute(1,2,0).cpu().numpy() + + # map flow to rgb image + flo = flow_viz.flow_to_image(flo) + # img_flo = np.concatenate([img, flo], axis=0) + img_flo = flo + + cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2,1,0]]) + # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0) + # cv2.waitKey() + + +def demo(args): + model = torch.nn.DataParallel(RAFT(args)) + model.load_state_dict(torch.load(args.model)) + + model = model.module + model.to(DEVICE) + model.eval() + + with torch.no_grad(): + images = glob.glob(os.path.join(args.path, '*.png')) + \ + glob.glob(os.path.join(args.path, '*.jpg')) + + images = load_image_list(images) + for i in range(images.shape[0]-1): + image1 = images[i,None] + image2 = images[i+1,None] + + flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) + viz(image1, flow_up) + + +def RAFT_infer(args): + model = torch.nn.DataParallel(RAFT(args)) + model.load_state_dict(torch.load(args.model)) + + model = model.module + model.to(DEVICE) + model.eval() + + return model diff --git a/RAFT/extractor.py b/RAFT/extractor.py new file mode 100755 index 0000000000000000000000000000000000000000..9a9c759d1243d4694e8656c2f6f8a37e53edd009 --- /dev/null +++ b/RAFT/extractor.py @@ -0,0 +1,267 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes//4) + self.norm2 = nn.BatchNorm2d(planes//4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes//4) + self.norm2 = nn.InstanceNorm2d(planes//4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x diff --git a/RAFT/raft.py b/RAFT/raft.py new file mode 100755 index 0000000000000000000000000000000000000000..829ef97b8d3e280aac59ebef7bb2eaf06274b62a --- /dev/null +++ b/RAFT/raft.py @@ -0,0 +1,146 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .update import BasicUpdateBlock, SmallUpdateBlock +from .extractor import BasicEncoder, SmallEncoder +from .corr import CorrBlock, AlternateCorrBlock +from .utils.utils import bilinear_sampler, coords_grid, upflow8 + +try: + autocast = torch.cuda.amp.autocast +except: + # dummy autocast for PyTorch < 1.6 + class autocast: + def __init__(self, enabled): + pass + def __enter__(self): + pass + def __exit__(self, *args): + pass + + +class RAFT(nn.Module): + def __init__(self, args): + super(RAFT, self).__init__() + self.args = args + + if args.small: + self.hidden_dim = hdim = 96 + self.context_dim = cdim = 64 + args.corr_levels = 4 + args.corr_radius = 3 + + else: + self.hidden_dim = hdim = 128 + self.context_dim = cdim = 128 + args.corr_levels = 4 + args.corr_radius = 4 + + if 'dropout' not in args._get_kwargs(): + args.dropout = 0 + + if 'alternate_corr' not in args._get_kwargs(): + args.alternate_corr = False + + # feature network, context network, and update block + if args.small: + self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) + self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout) + self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) + + else: + self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) + self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) + self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) + + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def initialize_flow(self, img): + """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" + N, C, H, W = img.shape + coords0 = coords_grid(N, H//8, W//8).to(img.device) + coords1 = coords_grid(N, H//8, W//8).to(img.device) + + # optical flow computed as difference: flow = coords1 - coords0 + return coords0, coords1 + + def upsample_flow(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3,3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8*H, 8*W) + + + def forward(self, image1, image2, iters=12, flow_init=None, test_mode=True): + """ Estimate optical flow between pair of frames """ + + # image1 = 2 * (image1 / 255.0) - 1.0 + # image2 = 2 * (image2 / 255.0) - 1.0 + + image1 = image1.contiguous() + image2 = image2.contiguous() + + hdim = self.hidden_dim + cdim = self.context_dim + + # run the feature network + with autocast(enabled=self.args.mixed_precision): + fmap1, fmap2 = self.fnet([image1, image2]) + + fmap1 = fmap1.float() + fmap2 = fmap2.float() + + if self.args.alternate_corr: + corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + else: + corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + + # run the context network + with autocast(enabled=self.args.mixed_precision): + cnet = self.cnet(image1) + net, inp = torch.split(cnet, [hdim, cdim], dim=1) + net = torch.tanh(net) + inp = torch.relu(inp) + + coords0, coords1 = self.initialize_flow(image1) + + if flow_init is not None: + coords1 = coords1 + flow_init + + flow_predictions = [] + for itr in range(iters): + coords1 = coords1.detach() + corr = corr_fn(coords1) # index correlation volume + + flow = coords1 - coords0 + with autocast(enabled=self.args.mixed_precision): + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) + + # F(t+1) = F(t) + \Delta(t) + coords1 = coords1 + delta_flow + + # upsample predictions + if up_mask is None: + flow_up = upflow8(coords1 - coords0) + else: + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + + flow_predictions.append(flow_up) + + if test_mode: + return coords1 - coords0, flow_up + + return flow_predictions diff --git a/RAFT/update.py b/RAFT/update.py new file mode 100755 index 0000000000000000000000000000000000000000..f940497f9b5eb1c12091574fe9a0223a1b196d50 --- /dev/null +++ b/RAFT/update.py @@ -0,0 +1,139 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) + + h = (1-z) * h + z * q + return h + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + + self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + return h + +class SmallMotionEncoder(nn.Module): + def __init__(self, args): + super(SmallMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 + self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) + self.convf1 = nn.Conv2d(2, 64, 7, padding=3) + self.convf2 = nn.Conv2d(64, 32, 3, padding=1) + self.conv = nn.Conv2d(128, 80, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class SmallUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=96): + super(SmallUpdateBlock, self).__init__() + self.encoder = SmallMotionEncoder(args) + self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) + self.flow_head = FlowHead(hidden_dim, hidden_dim=128) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + return net, None, delta_flow + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, delta_flow + + + diff --git a/RAFT/utils/__init__.py b/RAFT/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..0437149bfee42718973728158641020ccc1906ad --- /dev/null +++ b/RAFT/utils/__init__.py @@ -0,0 +1,2 @@ +from .flow_viz import flow_to_image +from .frame_utils import writeFlow diff --git a/RAFT/utils/augmentor.py b/RAFT/utils/augmentor.py new file mode 100755 index 0000000000000000000000000000000000000000..e81c4f2b5c16c31c0ae236d744f299d430228a04 --- /dev/null +++ b/RAFT/utils/augmentor.py @@ -0,0 +1,246 @@ +import numpy as np +import random +import math +from PIL import Image + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +import torch +from torchvision.transforms import ColorJitter +import torch.nn.functional as F + + +class FlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): + + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + """ Photometric augmentation """ + + # asymmetric + if np.random.rand() < self.asymmetric_color_aug_prob: + img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) + img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) + + # symmetric + else: + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + + return img1, img2 + + def eraser_transform(self, img1, img2, bounds=[50, 100]): + """ Occlusion augmentation """ + + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(bounds[0], bounds[1]) + dy = np.random.randint(bounds[0], bounds[1]) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def spatial_transform(self, img1, img2, flow): + # randomly sample scale + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 8) / float(ht), + (self.crop_size[1] + 8) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = scale + scale_y = scale + if np.random.rand() < self.stretch_prob: + scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + + scale_x = np.clip(scale_x, min_scale, None) + scale_y = np.clip(scale_y, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = flow * [scale_x, scale_y] + + if self.do_flip: + if np.random.rand() < self.h_flip_prob: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + + if np.random.rand() < self.v_flip_prob: # v-flip + img1 = img1[::-1, :] + img2 = img2[::-1, :] + flow = flow[::-1, :] * [1.0, -1.0] + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) + x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + + return img1, img2, flow + + def __call__(self, img1, img2, flow): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow = self.spatial_transform(img1, img2, flow) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + + return img1, img2, flow + +class SparseFlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + return img1, img2 + + def eraser_transform(self, img1, img2): + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(50, 100) + dy = np.random.randint(50, 100) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): + ht, wd = flow.shape[:2] + coords = np.meshgrid(np.arange(wd), np.arange(ht)) + coords = np.stack(coords, axis=-1) + + coords = coords.reshape(-1, 2).astype(np.float32) + flow = flow.reshape(-1, 2).astype(np.float32) + valid = valid.reshape(-1).astype(np.float32) + + coords0 = coords[valid>=1] + flow0 = flow[valid>=1] + + ht1 = int(round(ht * fy)) + wd1 = int(round(wd * fx)) + + coords1 = coords0 * [fx, fy] + flow1 = flow0 * [fx, fy] + + xx = np.round(coords1[:,0]).astype(np.int32) + yy = np.round(coords1[:,1]).astype(np.int32) + + v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) + xx = xx[v] + yy = yy[v] + flow1 = flow1[v] + + flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) + valid_img = np.zeros([ht1, wd1], dtype=np.int32) + + flow_img[yy, xx] = flow1 + valid_img[yy, xx] = 1 + + return flow_img, valid_img + + def spatial_transform(self, img1, img2, flow, valid): + # randomly sample scale + + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 1) / float(ht), + (self.crop_size[1] + 1) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = np.clip(scale, min_scale, None) + scale_y = np.clip(scale, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) + + if self.do_flip: + if np.random.rand() < 0.5: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + valid = valid[:, ::-1] + + margin_y = 20 + margin_x = 50 + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) + x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) + + y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) + x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + return img1, img2, flow, valid + + + def __call__(self, img1, img2, flow, valid): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + valid = np.ascontiguousarray(valid) + + return img1, img2, flow, valid diff --git a/RAFT/utils/flow_viz.py b/RAFT/utils/flow_viz.py new file mode 100755 index 0000000000000000000000000000000000000000..dcee65e89b91b07ee0496aeb4c7e7436abf99641 --- /dev/null +++ b/RAFT/utils/flow_viz.py @@ -0,0 +1,132 @@ +# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization + + +# MIT License +# +# Copyright (c) 2018 Tom Runia +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to conditions. +# +# Author: Tom Runia +# Date Created: 2018-08-03 + +import numpy as np + +def make_colorwheel(): + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + + Code follows the original C++ source code of Daniel Scharstein. + Code follows the the Matlab source code of Deqing Sun. + + Returns: + np.ndarray: Color wheel + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) + col = col+RY + # YG + colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) + colorwheel[col:col+YG, 1] = 255 + col = col+YG + # GC + colorwheel[col:col+GC, 1] = 255 + colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) + col = col+GC + # CB + colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) + colorwheel[col:col+CB, 2] = 255 + col = col+CB + # BM + colorwheel[col:col+BM, 2] = 255 + colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) + col = col+BM + # MR + colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) + colorwheel[col:col+MR, 0] = 255 + return colorwheel + + +def flow_uv_to_colors(u, v, convert_to_bgr=False): + """ + Applies the flow color wheel to (possibly clipped) flow components u and v. + + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + + Args: + u (np.ndarray): Input horizontal flow of shape [H,W] + v (np.ndarray): Input vertical flow of shape [H,W] + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u)/np.pi + fk = (a+1) / 2*(ncols-1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:,i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1-f)*col0 + f*col1 + idx = (rad <= 1) + col[idx] = 1 - rad[idx] * (1-col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range + # Note the 2-i => BGR instead of RGB + ch_idx = 2-i if convert_to_bgr else i + flow_image[:,:,ch_idx] = np.floor(255 * col) + return flow_image + + +def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): + """ + Expects a two dimensional flow image of shape. + + Args: + flow_uv (np.ndarray): Flow UV image of shape [H,W,2] + clip_flow (float, optional): Clip maximum of flow values. Defaults to None. + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + assert flow_uv.ndim == 3, 'input flow must have three dimensions' + assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + u = flow_uv[:,:,0] + v = flow_uv[:,:,1] + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + return flow_uv_to_colors(u, v, convert_to_bgr) \ No newline at end of file diff --git a/RAFT/utils/flow_viz_pt.py b/RAFT/utils/flow_viz_pt.py new file mode 100644 index 0000000000000000000000000000000000000000..12e666a40fa49c11592e311b141aa2a522e567fd --- /dev/null +++ b/RAFT/utils/flow_viz_pt.py @@ -0,0 +1,118 @@ +# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization +import torch +torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732 + +@torch.no_grad() +def flow_to_image(flow: torch.Tensor) -> torch.Tensor: + + """ + Converts a flow to an RGB image. + + Args: + flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float. + + Returns: + img (Tensor): Image Tensor of dtype uint8 where each color corresponds + to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input. + """ + + if flow.dtype != torch.float: + raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.") + + orig_shape = flow.shape + if flow.ndim == 3: + flow = flow[None] # Add batch dim + + if flow.ndim != 4 or flow.shape[1] != 2: + raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.") + + max_norm = torch.sum(flow**2, dim=1).sqrt().max() + epsilon = torch.finfo((flow).dtype).eps + normalized_flow = flow / (max_norm + epsilon) + img = _normalized_flow_to_image(normalized_flow) + + if len(orig_shape) == 3: + img = img[0] # Remove batch dim + return img + +@torch.no_grad() +def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor: + + """ + Converts a batch of normalized flow to an RGB image. + + Args: + normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W) + Returns: + img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8. + """ + + N, _, H, W = normalized_flow.shape + device = normalized_flow.device + flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device) + colorwheel = _make_colorwheel().to(device) # shape [55x3] + num_cols = colorwheel.shape[0] + norm = torch.sum(normalized_flow**2, dim=1).sqrt() + a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi + fk = (a + 1) / 2 * (num_cols - 1) + k0 = torch.floor(fk).to(torch.long) + k1 = k0 + 1 + k1[k1 == num_cols] = 0 + f = fk - k0 + + for c in range(colorwheel.shape[1]): + tmp = colorwheel[:, c] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1 - f) * col0 + f * col1 + col = 1 - norm * (1 - col) + flow_image[:, c, :, :] = torch.floor(255. * col) + return flow_image + + +@torch.no_grad() +def _make_colorwheel() -> torch.Tensor: + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf. + + Returns: + colorwheel (Tensor[55, 3]): Colorwheel Tensor. + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = torch.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = torch.floor(255. * torch.arange(0., RY) / RY) + col = col + RY + # YG + colorwheel[col : col + YG, 0] = 255 - torch.floor(255. * torch.arange(0., YG) / YG) + colorwheel[col : col + YG, 1] = 255 + col = col + YG + # GC + colorwheel[col : col + GC, 1] = 255 + colorwheel[col : col + GC, 2] = torch.floor(255. * torch.arange(0., GC) / GC) + col = col + GC + # CB + colorwheel[col : col + CB, 1] = 255 - torch.floor(255. * torch.arange(CB) / CB) + colorwheel[col : col + CB, 2] = 255 + col = col + CB + # BM + colorwheel[col : col + BM, 2] = 255 + colorwheel[col : col + BM, 0] = torch.floor(255. * torch.arange(0., BM) / BM) + col = col + BM + # MR + colorwheel[col : col + MR, 2] = 255 - torch.floor(255. * torch.arange(MR) / MR) + colorwheel[col : col + MR, 0] = 255 + return colorwheel diff --git a/RAFT/utils/frame_utils.py b/RAFT/utils/frame_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..6c491135efaffc25bd61ec3ecde99d236f5deb12 --- /dev/null +++ b/RAFT/utils/frame_utils.py @@ -0,0 +1,137 @@ +import numpy as np +from PIL import Image +from os.path import * +import re + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +TAG_CHAR = np.array([202021.25], np.float32) + +def readFlow(fn): + """ Read .flo file in Middlebury format""" + # Code adapted from: + # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy + + # WARNING: this will work on little-endian architectures (eg Intel x86) only! + # print 'fn = %s'%(fn) + with open(fn, 'rb') as f: + magic = np.fromfile(f, np.float32, count=1) + if 202021.25 != magic: + print('Magic number incorrect. Invalid .flo file') + return None + else: + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + # print 'Reading %d x %d flo file\n' % (w, h) + data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) + # Reshape data into 3D array (columns, rows, bands) + # The reshape here is for visualization, the original code is (w,h,2) + return np.resize(data, (int(h), int(w), 2)) + +def readPFM(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header == b'PF': + color = True + elif header == b'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data + +def writeFlow(filename,uv,v=None): + """ Write optical flow to file. + + If v is None, uv is assumed to contain both u and v channels, + stacked in depth. + Original code by Deqing Sun, adapted from Daniel Scharstein. + """ + nBands = 2 + + if v is None: + assert(uv.ndim == 3) + assert(uv.shape[2] == 2) + u = uv[:,:,0] + v = uv[:,:,1] + else: + u = uv + + assert(u.shape == v.shape) + height,width = u.shape + f = open(filename,'wb') + # write the header + f.write(TAG_CHAR) + np.array(width).astype(np.int32).tofile(f) + np.array(height).astype(np.int32).tofile(f) + # arrange into matrix form + tmp = np.zeros((height, width*nBands)) + tmp[:,np.arange(width)*2] = u + tmp[:,np.arange(width)*2 + 1] = v + tmp.astype(np.float32).tofile(f) + f.close() + + +def readFlowKITTI(filename): + flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) + flow = flow[:,:,::-1].astype(np.float32) + flow, valid = flow[:, :, :2], flow[:, :, 2] + flow = (flow - 2**15) / 64.0 + return flow, valid + +def readDispKITTI(filename): + disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 + valid = disp > 0.0 + flow = np.stack([-disp, np.zeros_like(disp)], -1) + return flow, valid + + +def writeFlowKITTI(filename, uv): + uv = 64.0 * uv + 2**15 + valid = np.ones([uv.shape[0], uv.shape[1], 1]) + uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) + cv2.imwrite(filename, uv[..., ::-1]) + + +def read_gen(file_name, pil=False): + ext = splitext(file_name)[-1] + if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': + return Image.open(file_name) + elif ext == '.bin' or ext == '.raw': + return np.load(file_name) + elif ext == '.flo': + return readFlow(file_name).astype(np.float32) + elif ext == '.pfm': + flow = readPFM(file_name).astype(np.float32) + if len(flow.shape) == 2: + return flow + else: + return flow[:, :, :-1] + return [] \ No newline at end of file diff --git a/RAFT/utils/utils.py b/RAFT/utils/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..5f32d281c1c46353a0a2bf36b0550adb74125c65 --- /dev/null +++ b/RAFT/utils/utils.py @@ -0,0 +1,82 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy import interpolate + + +class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ + def __init__(self, dims, mode='sintel'): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == 'sintel': + self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + else: + self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self,x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + +def forward_interpolate(flow): + flow = flow.detach().cpu().numpy() + dx, dy = flow[0], flow[1] + + ht, wd = dx.shape + x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) + + x1 = x0 + dx + y1 = y0 + dy + + x1 = x1.reshape(-1) + y1 = y1.reshape(-1) + dx = dx.reshape(-1) + dy = dy.reshape(-1) + + valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) + x1 = x1[valid] + y1 = y1[valid] + dx = dx[valid] + dy = dy[valid] + + flow_x = interpolate.griddata( + (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) + + flow_y = interpolate.griddata( + (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) + + flow = np.stack([flow_x, flow_y], axis=0) + return torch.from_numpy(flow).float() + + +def bilinear_sampler(img, coords, mode='bilinear', mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd): + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def upflow8(flow, mode='bilinear'): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) diff --git a/configs/train_flowcomp.json b/configs/train_flowcomp.json new file mode 100644 index 0000000000000000000000000000000000000000..c3c9ca8043b88611c1e0579ee2f469d3eee987b6 --- /dev/null +++ b/configs/train_flowcomp.json @@ -0,0 +1,40 @@ +{ + "seed": 2023, + "save_dir": "experiments_model/", + "train_data_loader": { + "name": "youtube-vos", + "video_root": "your_video_root", + "flow_root": "your_flow_root", + "w": 432, + "h": 240, + "num_local_frames": 10, + "num_ref_frames": 1, + "load_flow": 0 + }, + "losses": { + "flow_weight": 0.25 + }, + "model": { + "net": "recurrent_flow_completion" + }, + "trainer": { + "version": "trainer_flow_w_edge", + "type": "Adam", + "beta1": 0, + "beta2": 0.99, + "lr": 5e-5, + "batch_size": 8, + "num_workers": 4, + "num_prefetch_queue": 4, + "log_freq": 100, + "save_freq": 5e3, + "iterations": 700e3, + "scheduler": { + "type": "MultiStepLR", + "milestones": [ + 300e3, 400e3, 500e3, 600e3 + ], + "gamma": 0.2 + } + } +} \ No newline at end of file diff --git a/configs/train_propainter.json b/configs/train_propainter.json new file mode 100644 index 0000000000000000000000000000000000000000..c0c29ba7a6ad02d6983206d530f6256d8b120ec7 --- /dev/null +++ b/configs/train_propainter.json @@ -0,0 +1,48 @@ +{ + "seed": 2023, + "save_dir": "experiments_model/", + "train_data_loader": { + "name": "youtube-vos", + "video_root": "your_video_root", + "flow_root": "your_flow_root", + "w": 432, + "h": 240, + "num_local_frames": 10, + "num_ref_frames": 6, + "load_flow": 0 + }, + "losses": { + "hole_weight": 1, + "valid_weight": 1, + "flow_weight": 1, + "adversarial_weight": 0.01, + "GAN_LOSS": "hinge", + "perceptual_weight": 0 + }, + "model": { + "net": "propainter", + "no_dis": 0, + "load_d": 1, + "interp_mode": "nearest" + }, + "trainer": { + "version": "trainer", + "type": "Adam", + "beta1": 0, + "beta2": 0.99, + "lr": 1e-4, + "batch_size": 8, + "num_workers": 8, + "num_prefetch_queue": 8, + "log_freq": 100, + "save_freq": 1e4, + "iterations": 700e3, + "scheduler": { + "type": "MultiStepLR", + "milestones": [ + 400e3 + ], + "gamma": 0.1 + } + } +} \ No newline at end of file diff --git a/core/dataset.py b/core/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..27b135bb7716f0e89d9a3ec9fd4411dfe3eb94eb --- /dev/null +++ b/core/dataset.py @@ -0,0 +1,232 @@ +import os +import json +import random + +import cv2 +from PIL import Image +import numpy as np + +import torch +import torchvision.transforms as transforms + +from utils.file_client import FileClient +from utils.img_util import imfrombytes +from utils.flow_util import resize_flow, flowread +from core.utils import (create_random_shape_with_random_motion, Stack, + ToTorchFormatTensor, GroupRandomHorizontalFlip,GroupRandomHorizontalFlowFlip) + + +class TrainDataset(torch.utils.data.Dataset): + def __init__(self, args: dict): + self.args = args + self.video_root = args['video_root'] + self.flow_root = args['flow_root'] + self.num_local_frames = args['num_local_frames'] + self.num_ref_frames = args['num_ref_frames'] + self.size = self.w, self.h = (args['w'], args['h']) + + self.load_flow = args['load_flow'] + if self.load_flow: + assert os.path.exists(self.flow_root) + + json_path = os.path.join('./datasets', args['name'], 'train.json') + + with open(json_path, 'r') as f: + self.video_train_dict = json.load(f) + self.video_names = sorted(list(self.video_train_dict.keys())) + + # self.video_names = sorted(os.listdir(self.video_root)) + self.video_dict = {} + self.frame_dict = {} + + for v in self.video_names: + frame_list = sorted(os.listdir(os.path.join(self.video_root, v))) + v_len = len(frame_list) + if v_len > self.num_local_frames + self.num_ref_frames: + self.video_dict[v] = v_len + self.frame_dict[v] = frame_list + + + self.video_names = list(self.video_dict.keys()) # update names + + self._to_tensors = transforms.Compose([ + Stack(), + ToTorchFormatTensor(), + ]) + self.file_client = FileClient('disk') + + def __len__(self): + return len(self.video_names) + + def _sample_index(self, length, sample_length, num_ref_frame=3): + complete_idx_set = list(range(length)) + pivot = random.randint(0, length - sample_length) + local_idx = complete_idx_set[pivot:pivot + sample_length] + remain_idx = list(set(complete_idx_set) - set(local_idx)) + ref_index = sorted(random.sample(remain_idx, num_ref_frame)) + + return local_idx + ref_index + + def __getitem__(self, index): + video_name = self.video_names[index] + # create masks + all_masks = create_random_shape_with_random_motion( + self.video_dict[video_name], imageHeight=self.h, imageWidth=self.w) + + # create sample index + selected_index = self._sample_index(self.video_dict[video_name], + self.num_local_frames, + self.num_ref_frames) + + # read video frames + frames = [] + masks = [] + flows_f, flows_b = [], [] + for idx in selected_index: + frame_list = self.frame_dict[video_name] + img_path = os.path.join(self.video_root, video_name, frame_list[idx]) + img_bytes = self.file_client.get(img_path, 'img') + img = imfrombytes(img_bytes, float32=False) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR) + img = Image.fromarray(img) + + frames.append(img) + masks.append(all_masks[idx]) + + if len(frames) <= self.num_local_frames-1 and self.load_flow: + current_n = frame_list[idx][:-4] + next_n = frame_list[idx+1][:-4] + flow_f_path = os.path.join(self.flow_root, video_name, f'{current_n}_{next_n}_f.flo') + flow_b_path = os.path.join(self.flow_root, video_name, f'{next_n}_{current_n}_b.flo') + flow_f = flowread(flow_f_path, quantize=False) + flow_b = flowread(flow_b_path, quantize=False) + flow_f = resize_flow(flow_f, self.h, self.w) + flow_b = resize_flow(flow_b, self.h, self.w) + flows_f.append(flow_f) + flows_b.append(flow_b) + + if len(frames) == self.num_local_frames: # random reverse + if random.random() < 0.5: + frames.reverse() + masks.reverse() + if self.load_flow: + flows_f.reverse() + flows_b.reverse() + flows_ = flows_f + flows_f = flows_b + flows_b = flows_ + + if self.load_flow: + frames, flows_f, flows_b = GroupRandomHorizontalFlowFlip()(frames, flows_f, flows_b) + else: + frames = GroupRandomHorizontalFlip()(frames) + + # normalizate, to tensors + frame_tensors = self._to_tensors(frames) * 2.0 - 1.0 + mask_tensors = self._to_tensors(masks) + if self.load_flow: + flows_f = np.stack(flows_f, axis=-1) # H W 2 T-1 + flows_b = np.stack(flows_b, axis=-1) + flows_f = torch.from_numpy(flows_f).permute(3, 2, 0, 1).contiguous().float() + flows_b = torch.from_numpy(flows_b).permute(3, 2, 0, 1).contiguous().float() + + # img [-1,1] mask [0,1] + if self.load_flow: + return frame_tensors, mask_tensors, flows_f, flows_b, video_name + else: + return frame_tensors, mask_tensors, 'None', 'None', video_name + + +class TestDataset(torch.utils.data.Dataset): + def __init__(self, args): + self.args = args + self.size = self.w, self.h = args['size'] + + self.video_root = args['video_root'] + self.mask_root = args['mask_root'] + self.flow_root = args['flow_root'] + + self.load_flow = args['load_flow'] + if self.load_flow: + assert os.path.exists(self.flow_root) + self.video_names = sorted(os.listdir(self.mask_root)) + + self.video_dict = {} + self.frame_dict = {} + + for v in self.video_names: + frame_list = sorted(os.listdir(os.path.join(self.video_root, v))) + v_len = len(frame_list) + self.video_dict[v] = v_len + self.frame_dict[v] = frame_list + + self._to_tensors = transforms.Compose([ + Stack(), + ToTorchFormatTensor(), + ]) + self.file_client = FileClient('disk') + + def __len__(self): + return len(self.video_names) + + def __getitem__(self, index): + video_name = self.video_names[index] + selected_index = list(range(self.video_dict[video_name])) + + # read video frames + frames = [] + masks = [] + flows_f, flows_b = [], [] + for idx in selected_index: + frame_list = self.frame_dict[video_name] + frame_path = os.path.join(self.video_root, video_name, frame_list[idx]) + + img_bytes = self.file_client.get(frame_path, 'input') + img = imfrombytes(img_bytes, float32=False) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR) + img = Image.fromarray(img) + + frames.append(img) + + mask_path = os.path.join(self.mask_root, video_name, str(idx).zfill(5) + '.png') + mask = Image.open(mask_path).resize(self.size, Image.NEAREST).convert('L') + + # origin: 0 indicates missing. now: 1 indicates missing + mask = np.asarray(mask) + m = np.array(mask > 0).astype(np.uint8) + + m = cv2.dilate(m, + cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)), + iterations=4) + mask = Image.fromarray(m * 255) + masks.append(mask) + + if len(frames) <= len(selected_index)-1 and self.load_flow: + current_n = frame_list[idx][:-4] + next_n = frame_list[idx+1][:-4] + flow_f_path = os.path.join(self.flow_root, video_name, f'{current_n}_{next_n}_f.flo') + flow_b_path = os.path.join(self.flow_root, video_name, f'{next_n}_{current_n}_b.flo') + flow_f = flowread(flow_f_path, quantize=False) + flow_b = flowread(flow_b_path, quantize=False) + flow_f = resize_flow(flow_f, self.h, self.w) + flow_b = resize_flow(flow_b, self.h, self.w) + flows_f.append(flow_f) + flows_b.append(flow_b) + + # normalizate, to tensors + frames_PIL = [np.array(f).astype(np.uint8) for f in frames] + frame_tensors = self._to_tensors(frames) * 2.0 - 1.0 + mask_tensors = self._to_tensors(masks) + + if self.load_flow: + flows_f = np.stack(flows_f, axis=-1) # H W 2 T-1 + flows_b = np.stack(flows_b, axis=-1) + flows_f = torch.from_numpy(flows_f).permute(3, 2, 0, 1).contiguous().float() + flows_b = torch.from_numpy(flows_b).permute(3, 2, 0, 1).contiguous().float() + + if self.load_flow: + return frame_tensors, mask_tensors, flows_f, flows_b, video_name, frames_PIL + else: + return frame_tensors, mask_tensors, 'None', 'None', video_name \ No newline at end of file diff --git a/core/dist.py b/core/dist.py new file mode 100644 index 0000000000000000000000000000000000000000..4e4e9e670a3b853fac345618d3557d648d813902 --- /dev/null +++ b/core/dist.py @@ -0,0 +1,47 @@ +import os +import torch + + +def get_world_size(): + """Find OMPI world size without calling mpi functions + :rtype: int + """ + if os.environ.get('PMI_SIZE') is not None: + return int(os.environ.get('PMI_SIZE') or 1) + elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1) + else: + return torch.cuda.device_count() + + +def get_global_rank(): + """Find OMPI world rank without calling mpi functions + :rtype: int + """ + if os.environ.get('PMI_RANK') is not None: + return int(os.environ.get('PMI_RANK') or 0) + elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0) + else: + return 0 + + +def get_local_rank(): + """Find OMPI local rank without calling mpi functions + :rtype: int + """ + if os.environ.get('MPI_LOCALRANKID') is not None: + return int(os.environ.get('MPI_LOCALRANKID') or 0) + elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0) + else: + return 0 + + +def get_master_ip(): + if os.environ.get('AZ_BATCH_MASTER_NODE') is not None: + return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0] + elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None: + return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') + else: + return "127.0.0.1" diff --git a/core/loss.py b/core/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d94d0ce9433b66ce2dce7adb24acb16051e8da --- /dev/null +++ b/core/loss.py @@ -0,0 +1,180 @@ +import torch +import torch.nn as nn +import lpips +from model.vgg_arch import VGGFeatureExtractor + +class PerceptualLoss(nn.Module): + """Perceptual loss with commonly used style loss. + + Args: + layer_weights (dict): The weight for each layer of vgg feature. + Here is an example: {'conv5_4': 1.}, which means the conv5_4 + feature layer (before relu5_4) will be extracted with weight + 1.0 in calculting losses. + vgg_type (str): The type of vgg network used as feature extractor. + Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image in vgg. + Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + perceptual_weight (float): If `perceptual_weight > 0`, the perceptual + loss will be calculated and the loss will multiplied by the + weight. Default: 1.0. + style_weight (float): If `style_weight > 0`, the style loss will be + calculated and the loss will multiplied by the weight. + Default: 0. + criterion (str): Criterion used for perceptual loss. Default: 'l1'. + """ + + def __init__(self, + layer_weights, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + perceptual_weight=1.0, + style_weight=0., + criterion='l1'): + super(PerceptualLoss, self).__init__() + self.perceptual_weight = perceptual_weight + self.style_weight = style_weight + self.layer_weights = layer_weights + self.vgg = VGGFeatureExtractor( + layer_name_list=list(layer_weights.keys()), + vgg_type=vgg_type, + use_input_norm=use_input_norm, + range_norm=range_norm) + + self.criterion_type = criterion + if self.criterion_type == 'l1': + self.criterion = torch.nn.L1Loss() + elif self.criterion_type == 'l2': + self.criterion = torch.nn.L2loss() + elif self.criterion_type == 'mse': + self.criterion = torch.nn.MSELoss(reduction='mean') + elif self.criterion_type == 'fro': + self.criterion = None + else: + raise NotImplementedError(f'{criterion} criterion has not been supported.') + + def forward(self, x, gt): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + gt (Tensor): Ground-truth tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + # extract vgg features + x_features = self.vgg(x) + gt_features = self.vgg(gt.detach()) + + # calculate perceptual loss + if self.perceptual_weight > 0: + percep_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k] + else: + percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] + percep_loss *= self.perceptual_weight + else: + percep_loss = None + + # calculate style loss + if self.style_weight > 0: + style_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + style_loss += torch.norm( + self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k] + else: + style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat( + gt_features[k])) * self.layer_weights[k] + style_loss *= self.style_weight + else: + style_loss = None + + return percep_loss, style_loss + + def _gram_mat(self, x): + """Calculate Gram matrix. + + Args: + x (torch.Tensor): Tensor with shape of (n, c, h, w). + + Returns: + torch.Tensor: Gram matrix. + """ + n, c, h, w = x.size() + features = x.view(n, c, w * h) + features_t = features.transpose(1, 2) + gram = features.bmm(features_t) / (c * h * w) + return gram + +class LPIPSLoss(nn.Module): + def __init__(self, + loss_weight=1.0, + use_input_norm=True, + range_norm=False,): + super(LPIPSLoss, self).__init__() + self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval() + self.loss_weight = loss_weight + self.use_input_norm = use_input_norm + self.range_norm = range_norm + + if self.use_input_norm: + # the mean is for image with range [0, 1] + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + # the std is for image with range [0, 1] + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def forward(self, pred, target): + if self.range_norm: + pred = (pred + 1) / 2 + target = (target + 1) / 2 + if self.use_input_norm: + pred = (pred - self.mean) / self.std + target = (target - self.mean) / self.std + lpips_loss = self.perceptual(target.contiguous(), pred.contiguous()) + return self.loss_weight * lpips_loss.mean(), None + + +class AdversarialLoss(nn.Module): + r""" + Adversarial loss + https://arxiv.org/abs/1711.10337 + """ + def __init__(self, + type='nsgan', + target_real_label=1.0, + target_fake_label=0.0): + r""" + type = nsgan | lsgan | hinge + """ + super(AdversarialLoss, self).__init__() + self.type = type + self.register_buffer('real_label', torch.tensor(target_real_label)) + self.register_buffer('fake_label', torch.tensor(target_fake_label)) + + if type == 'nsgan': + self.criterion = nn.BCELoss() + elif type == 'lsgan': + self.criterion = nn.MSELoss() + elif type == 'hinge': + self.criterion = nn.ReLU() + + def __call__(self, outputs, is_real, is_disc=None): + if self.type == 'hinge': + if is_disc: + if is_real: + outputs = -outputs + return self.criterion(1 + outputs).mean() + else: + return (-outputs).mean() + else: + labels = (self.real_label + if is_real else self.fake_label).expand_as(outputs) + loss = self.criterion(outputs, labels) + return loss diff --git a/core/lr_scheduler.py b/core/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..1bd1341cdcc64aa1c2a416b837551590ded4a43d --- /dev/null +++ b/core/lr_scheduler.py @@ -0,0 +1,112 @@ +""" + LR scheduler from BasicSR https://github.com/xinntao/BasicSR +""" +import math +from collections import Counter +from torch.optim.lr_scheduler import _LRScheduler + + +class MultiStepRestartLR(_LRScheduler): + """ MultiStep with restarts learning rate scheme. + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + milestones (list): Iterations that will decrease learning rate. + gamma (float): Decrease ratio. Default: 0.1. + restarts (list): Restart iterations. Default: [0]. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + def __init__(self, + optimizer, + milestones, + gamma=0.1, + restarts=(0, ), + restart_weights=(1, ), + last_epoch=-1): + self.milestones = Counter(milestones) + self.gamma = gamma + self.restarts = restarts + self.restart_weights = restart_weights + assert len(self.restarts) == len( + self.restart_weights), 'restarts and their weights do not match.' + super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.restarts: + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [ + group['initial_lr'] * weight + for group in self.optimizer.param_groups + ] + if self.last_epoch not in self.milestones: + return [group['lr'] for group in self.optimizer.param_groups] + return [ + group['lr'] * self.gamma**self.milestones[self.last_epoch] + for group in self.optimizer.param_groups + ] + + +def get_position_from_periods(iteration, cumulative_period): + """Get the position from a period list. + It will return the index of the right-closest number in the period list. + For example, the cumulative_period = [100, 200, 300, 400], + if iteration == 50, return 0; + if iteration == 210, return 2; + if iteration == 300, return 2. + Args: + iteration (int): Current iteration. + cumulative_period (list[int]): Cumulative period list. + Returns: + int: The position of the right-closest number in the period list. + """ + for i, period in enumerate(cumulative_period): + if iteration <= period: + return i + + +class CosineAnnealingRestartLR(_LRScheduler): + """ Cosine annealing with restarts learning rate scheme. + An example of config: + periods = [10, 10, 10, 10] + restart_weights = [1, 0.5, 0.5, 0.5] + eta_min=1e-7 + It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the + scheduler will restart with the weights in restart_weights. + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + periods (list): Period for each cosine anneling cycle. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + eta_min (float): The mimimum lr. Default: 0. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + def __init__(self, + optimizer, + periods, + restart_weights=(1, ), + eta_min=1e-7, + last_epoch=-1): + self.periods = periods + self.restart_weights = restart_weights + self.eta_min = eta_min + assert (len(self.periods) == len(self.restart_weights) + ), 'periods and restart_weights should have the same length.' + self.cumulative_period = [ + sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) + ] + super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + idx = get_position_from_periods(self.last_epoch, + self.cumulative_period) + current_weight = self.restart_weights[idx] + nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] + current_period = self.periods[idx] + + return [ + self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * + (1 + math.cos(math.pi * ( + (self.last_epoch - nearest_restart) / current_period))) + for base_lr in self.base_lrs + ] diff --git a/core/metrics.py b/core/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..d0dfb73f1d09a249f801770eada5e133c8148df2 --- /dev/null +++ b/core/metrics.py @@ -0,0 +1,569 @@ +import numpy as np +from skimage import measure +from scipy import linalg + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from core.utils import to_tensors + + +def calculate_epe(flow1, flow2): + """Calculate End point errors.""" + + epe = torch.sum((flow1 - flow2)**2, dim=1).sqrt() + epe = epe.view(-1) + return epe.mean().item() + + +def calculate_psnr(img1, img2): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + Args: + img1 (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + Returns: + float: psnr result. + """ + + assert img1.shape == img2.shape, \ + (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') + + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20. * np.log10(255. / np.sqrt(mse)) + + +def calc_psnr_and_ssim(img1, img2): + """Calculate PSNR and SSIM for images. + img1: ndarray, range [0, 255] + img2: ndarray, range [0, 255] + """ + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + psnr = calculate_psnr(img1, img2) + ssim = measure.compare_ssim(img1, + img2, + data_range=255, + multichannel=True, + win_size=65) + + return psnr, ssim + + +########################### +# I3D models +########################### + + +def init_i3d_model(i3d_model_path): + print(f"[Loading I3D model from {i3d_model_path} for FID score ..]") + i3d_model = InceptionI3d(400, in_channels=3, final_endpoint='Logits') + i3d_model.load_state_dict(torch.load(i3d_model_path)) + i3d_model.to(torch.device('cuda:0')) + return i3d_model + + +def calculate_i3d_activations(video1, video2, i3d_model, device): + """Calculate VFID metric. + video1: list[PIL.Image] + video2: list[PIL.Image] + """ + video1 = to_tensors()(video1).unsqueeze(0).to(device) + video2 = to_tensors()(video2).unsqueeze(0).to(device) + video1_activations = get_i3d_activations( + video1, i3d_model).cpu().numpy().flatten() + video2_activations = get_i3d_activations( + video2, i3d_model).cpu().numpy().flatten() + + return video1_activations, video2_activations + + +def calculate_vfid(real_activations, fake_activations): + """ + Given two distribution of features, compute the FID score between them + Params: + real_activations: list[ndarray] + fake_activations: list[ndarray] + """ + m1 = np.mean(real_activations, axis=0) + m2 = np.mean(fake_activations, axis=0) + s1 = np.cov(real_activations, rowvar=False) + s2 = np.cov(fake_activations, rowvar=False) + return calculate_frechet_distance(m1, s1, m2, s2) + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representive data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representive data set. + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + # NOQA + np.trace(sigma2) - 2 * tr_covmean) + + +def get_i3d_activations(batched_video, + i3d_model, + target_endpoint='Logits', + flatten=True, + grad_enabled=False): + """ + Get features from i3d model and flatten them to 1d feature, + valid target endpoints are defined in InceptionI3d.VALID_ENDPOINTS + VALID_ENDPOINTS = ( + 'Conv3d_1a_7x7', + 'MaxPool3d_2a_3x3', + 'Conv3d_2b_1x1', + 'Conv3d_2c_3x3', + 'MaxPool3d_3a_3x3', + 'Mixed_3b', + 'Mixed_3c', + 'MaxPool3d_4a_3x3', + 'Mixed_4b', + 'Mixed_4c', + 'Mixed_4d', + 'Mixed_4e', + 'Mixed_4f', + 'MaxPool3d_5a_2x2', + 'Mixed_5b', + 'Mixed_5c', + 'Logits', + 'Predictions', + ) + """ + with torch.set_grad_enabled(grad_enabled): + feat = i3d_model.extract_features(batched_video.transpose(1, 2), + target_endpoint) + if flatten: + feat = feat.view(feat.size(0), -1) + + return feat + + +# This code is from https://github.com/piergiaj/pytorch-i3d/blob/master/pytorch_i3d.py +# I only fix flake8 errors and do some cleaning here + + +class MaxPool3dSamePadding(nn.MaxPool3d): + def compute_pad(self, dim, s): + if s % self.stride[dim] == 0: + return max(self.kernel_size[dim] - self.stride[dim], 0) + else: + return max(self.kernel_size[dim] - (s % self.stride[dim]), 0) + + def forward(self, x): + # compute 'same' padding + (batch, channel, t, h, w) = x.size() + pad_t = self.compute_pad(0, t) + pad_h = self.compute_pad(1, h) + pad_w = self.compute_pad(2, w) + + pad_t_f = pad_t // 2 + pad_t_b = pad_t - pad_t_f + pad_h_f = pad_h // 2 + pad_h_b = pad_h - pad_h_f + pad_w_f = pad_w // 2 + pad_w_b = pad_w - pad_w_f + + pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) + x = F.pad(x, pad) + return super(MaxPool3dSamePadding, self).forward(x) + + +class Unit3D(nn.Module): + def __init__(self, + in_channels, + output_channels, + kernel_shape=(1, 1, 1), + stride=(1, 1, 1), + padding=0, + activation_fn=F.relu, + use_batch_norm=True, + use_bias=False, + name='unit_3d'): + """Initializes Unit3D module.""" + super(Unit3D, self).__init__() + + self._output_channels = output_channels + self._kernel_shape = kernel_shape + self._stride = stride + self._use_batch_norm = use_batch_norm + self._activation_fn = activation_fn + self._use_bias = use_bias + self.name = name + self.padding = padding + + self.conv3d = nn.Conv3d( + in_channels=in_channels, + out_channels=self._output_channels, + kernel_size=self._kernel_shape, + stride=self._stride, + padding=0, # we always want padding to be 0 here. We will + # dynamically pad based on input size in forward function + bias=self._use_bias) + + if self._use_batch_norm: + self.bn = nn.BatchNorm3d(self._output_channels, + eps=0.001, + momentum=0.01) + + def compute_pad(self, dim, s): + if s % self._stride[dim] == 0: + return max(self._kernel_shape[dim] - self._stride[dim], 0) + else: + return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0) + + def forward(self, x): + # compute 'same' padding + (batch, channel, t, h, w) = x.size() + pad_t = self.compute_pad(0, t) + pad_h = self.compute_pad(1, h) + pad_w = self.compute_pad(2, w) + + pad_t_f = pad_t // 2 + pad_t_b = pad_t - pad_t_f + pad_h_f = pad_h // 2 + pad_h_b = pad_h - pad_h_f + pad_w_f = pad_w // 2 + pad_w_b = pad_w - pad_w_f + + pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) + x = F.pad(x, pad) + + x = self.conv3d(x) + if self._use_batch_norm: + x = self.bn(x) + if self._activation_fn is not None: + x = self._activation_fn(x) + return x + + +class InceptionModule(nn.Module): + def __init__(self, in_channels, out_channels, name): + super(InceptionModule, self).__init__() + + self.b0 = Unit3D(in_channels=in_channels, + output_channels=out_channels[0], + kernel_shape=[1, 1, 1], + padding=0, + name=name + '/Branch_0/Conv3d_0a_1x1') + self.b1a = Unit3D(in_channels=in_channels, + output_channels=out_channels[1], + kernel_shape=[1, 1, 1], + padding=0, + name=name + '/Branch_1/Conv3d_0a_1x1') + self.b1b = Unit3D(in_channels=out_channels[1], + output_channels=out_channels[2], + kernel_shape=[3, 3, 3], + name=name + '/Branch_1/Conv3d_0b_3x3') + self.b2a = Unit3D(in_channels=in_channels, + output_channels=out_channels[3], + kernel_shape=[1, 1, 1], + padding=0, + name=name + '/Branch_2/Conv3d_0a_1x1') + self.b2b = Unit3D(in_channels=out_channels[3], + output_channels=out_channels[4], + kernel_shape=[3, 3, 3], + name=name + '/Branch_2/Conv3d_0b_3x3') + self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3], + stride=(1, 1, 1), + padding=0) + self.b3b = Unit3D(in_channels=in_channels, + output_channels=out_channels[5], + kernel_shape=[1, 1, 1], + padding=0, + name=name + '/Branch_3/Conv3d_0b_1x1') + self.name = name + + def forward(self, x): + b0 = self.b0(x) + b1 = self.b1b(self.b1a(x)) + b2 = self.b2b(self.b2a(x)) + b3 = self.b3b(self.b3a(x)) + return torch.cat([b0, b1, b2, b3], dim=1) + + +class InceptionI3d(nn.Module): + """Inception-v1 I3D architecture. + The model is introduced in: + Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset + Joao Carreira, Andrew Zisserman + https://arxiv.org/pdf/1705.07750v1.pdf. + See also the Inception architecture, introduced in: + Going deeper with convolutions + Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, + Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich. + http://arxiv.org/pdf/1409.4842v1.pdf. + """ + + # Endpoints of the model in order. During construction, all the endpoints up + # to a designated `final_endpoint` are returned in a dictionary as the + # second return value. + VALID_ENDPOINTS = ( + 'Conv3d_1a_7x7', + 'MaxPool3d_2a_3x3', + 'Conv3d_2b_1x1', + 'Conv3d_2c_3x3', + 'MaxPool3d_3a_3x3', + 'Mixed_3b', + 'Mixed_3c', + 'MaxPool3d_4a_3x3', + 'Mixed_4b', + 'Mixed_4c', + 'Mixed_4d', + 'Mixed_4e', + 'Mixed_4f', + 'MaxPool3d_5a_2x2', + 'Mixed_5b', + 'Mixed_5c', + 'Logits', + 'Predictions', + ) + + def __init__(self, + num_classes=400, + spatial_squeeze=True, + final_endpoint='Logits', + name='inception_i3d', + in_channels=3, + dropout_keep_prob=0.5): + """Initializes I3D model instance. + Args: + num_classes: The number of outputs in the logit layer (default 400, which + matches the Kinetics dataset). + spatial_squeeze: Whether to squeeze the spatial dimensions for the logits + before returning (default True). + final_endpoint: The model contains many possible endpoints. + `final_endpoint` specifies the last endpoint for the model to be built + up to. In addition to the output at `final_endpoint`, all the outputs + at endpoints up to `final_endpoint` will also be returned, in a + dictionary. `final_endpoint` must be one of + InceptionI3d.VALID_ENDPOINTS (default 'Logits'). + name: A string (optional). The name of this module. + Raises: + ValueError: if `final_endpoint` is not recognized. + """ + + if final_endpoint not in self.VALID_ENDPOINTS: + raise ValueError('Unknown final endpoint %s' % final_endpoint) + + super(InceptionI3d, self).__init__() + self._num_classes = num_classes + self._spatial_squeeze = spatial_squeeze + self._final_endpoint = final_endpoint + self.logits = None + + if self._final_endpoint not in self.VALID_ENDPOINTS: + raise ValueError('Unknown final endpoint %s' % + self._final_endpoint) + + self.end_points = {} + end_point = 'Conv3d_1a_7x7' + self.end_points[end_point] = Unit3D(in_channels=in_channels, + output_channels=64, + kernel_shape=[7, 7, 7], + stride=(2, 2, 2), + padding=(3, 3, 3), + name=name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'MaxPool3d_2a_3x3' + self.end_points[end_point] = MaxPool3dSamePadding( + kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0) + if self._final_endpoint == end_point: + return + + end_point = 'Conv3d_2b_1x1' + self.end_points[end_point] = Unit3D(in_channels=64, + output_channels=64, + kernel_shape=[1, 1, 1], + padding=0, + name=name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'Conv3d_2c_3x3' + self.end_points[end_point] = Unit3D(in_channels=64, + output_channels=192, + kernel_shape=[3, 3, 3], + padding=1, + name=name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'MaxPool3d_3a_3x3' + self.end_points[end_point] = MaxPool3dSamePadding( + kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0) + if self._final_endpoint == end_point: + return + + end_point = 'Mixed_3b' + self.end_points[end_point] = InceptionModule(192, + [64, 96, 128, 16, 32, 32], + name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'Mixed_3c' + self.end_points[end_point] = InceptionModule( + 256, [128, 128, 192, 32, 96, 64], name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'MaxPool3d_4a_3x3' + self.end_points[end_point] = MaxPool3dSamePadding( + kernel_size=[3, 3, 3], stride=(2, 2, 2), padding=0) + if self._final_endpoint == end_point: + return + + end_point = 'Mixed_4b' + self.end_points[end_point] = InceptionModule( + 128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'Mixed_4c' + self.end_points[end_point] = InceptionModule( + 192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'Mixed_4d' + self.end_points[end_point] = InceptionModule( + 160 + 224 + 64 + 64, [128, 128, 256, 24, 64, 64], name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'Mixed_4e' + self.end_points[end_point] = InceptionModule( + 128 + 256 + 64 + 64, [112, 144, 288, 32, 64, 64], name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'Mixed_4f' + self.end_points[end_point] = InceptionModule( + 112 + 288 + 64 + 64, [256, 160, 320, 32, 128, 128], + name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'MaxPool3d_5a_2x2' + self.end_points[end_point] = MaxPool3dSamePadding( + kernel_size=[2, 2, 2], stride=(2, 2, 2), padding=0) + if self._final_endpoint == end_point: + return + + end_point = 'Mixed_5b' + self.end_points[end_point] = InceptionModule( + 256 + 320 + 128 + 128, [256, 160, 320, 32, 128, 128], + name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'Mixed_5c' + self.end_points[end_point] = InceptionModule( + 256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128], + name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'Logits' + self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7], stride=(1, 1, 1)) + self.dropout = nn.Dropout(dropout_keep_prob) + self.logits = Unit3D(in_channels=384 + 384 + 128 + 128, + output_channels=self._num_classes, + kernel_shape=[1, 1, 1], + padding=0, + activation_fn=None, + use_batch_norm=False, + use_bias=True, + name='logits') + + self.build() + + def replace_logits(self, num_classes): + self._num_classes = num_classes + self.logits = Unit3D(in_channels=384 + 384 + 128 + 128, + output_channels=self._num_classes, + kernel_shape=[1, 1, 1], + padding=0, + activation_fn=None, + use_batch_norm=False, + use_bias=True, + name='logits') + + def build(self): + for k in self.end_points.keys(): + self.add_module(k, self.end_points[k]) + + def forward(self, x): + for end_point in self.VALID_ENDPOINTS: + if end_point in self.end_points: + x = self._modules[end_point]( + x) # use _modules to work with dataparallel + + x = self.logits(self.dropout(self.avg_pool(x))) + if self._spatial_squeeze: + logits = x.squeeze(3).squeeze(3) + # logits is batch X time X classes, which is what we want to work with + return logits + + def extract_features(self, x, target_endpoint='Logits'): + for end_point in self.VALID_ENDPOINTS: + if end_point in self.end_points: + x = self._modules[end_point](x) + if end_point == target_endpoint: + break + if target_endpoint == 'Logits': + return x.mean(4).mean(3).mean(2) + else: + return x diff --git a/core/prefetch_dataloader.py b/core/prefetch_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..5088425050d4cc98114a9b93eb50ea60273f35a0 --- /dev/null +++ b/core/prefetch_dataloader.py @@ -0,0 +1,125 @@ +import queue as Queue +import threading +import torch +from torch.utils.data import DataLoader + + +class PrefetchGenerator(threading.Thread): + """A general prefetch generator. + + Ref: + https://stackoverflow.com/questions/7323664/python-generator-pre-fetch + + Args: + generator: Python generator. + num_prefetch_queue (int): Number of prefetch queue. + """ + + def __init__(self, generator, num_prefetch_queue): + threading.Thread.__init__(self) + self.queue = Queue.Queue(num_prefetch_queue) + self.generator = generator + self.daemon = True + self.start() + + def run(self): + for item in self.generator: + self.queue.put(item) + self.queue.put(None) + + def __next__(self): + next_item = self.queue.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class PrefetchDataLoader(DataLoader): + """Prefetch version of dataloader. + + Ref: + https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# + + TODO: + Need to test on single gpu and ddp (multi-gpu). There is a known issue in + ddp. + + Args: + num_prefetch_queue (int): Number of prefetch queue. + kwargs (dict): Other arguments for dataloader. + """ + + def __init__(self, num_prefetch_queue, **kwargs): + self.num_prefetch_queue = num_prefetch_queue + super(PrefetchDataLoader, self).__init__(**kwargs) + + def __iter__(self): + return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) + + +class CPUPrefetcher(): + """CPU prefetcher. + + Args: + loader: Dataloader. + """ + + def __init__(self, loader): + self.ori_loader = loader + self.loader = iter(loader) + + def next(self): + try: + return next(self.loader) + except StopIteration: + return None + + def reset(self): + self.loader = iter(self.ori_loader) + + +class CUDAPrefetcher(): + """CUDA prefetcher. + + Ref: + https://github.com/NVIDIA/apex/issues/304# + + It may consums more GPU memory. + + Args: + loader: Dataloader. + opt (dict): Options. + """ + + def __init__(self, loader, opt): + self.ori_loader = loader + self.loader = iter(loader) + self.opt = opt + self.stream = torch.cuda.Stream() + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.preload() + + def preload(self): + try: + self.batch = next(self.loader) # self.batch is a dict + except StopIteration: + self.batch = None + return None + # put tensors to gpu + with torch.cuda.stream(self.stream): + for k, v in self.batch.items(): + if torch.is_tensor(v): + self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) + + def next(self): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + self.preload() + return batch + + def reset(self): + self.loader = iter(self.ori_loader) + self.preload() diff --git a/core/trainer.py b/core/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e90ec8c66e70a0355191dcc374ccc2963e8ca2a8 --- /dev/null +++ b/core/trainer.py @@ -0,0 +1,509 @@ +import os +import glob +import logging +import importlib +from tqdm import tqdm + +import torch +import torch.nn as nn +import torch.nn.functional as F +from core.prefetch_dataloader import PrefetchDataLoader, CPUPrefetcher +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import DistributedDataParallel as DDP +import torchvision +from torch.utils.tensorboard import SummaryWriter + +from core.lr_scheduler import MultiStepRestartLR, CosineAnnealingRestartLR +from core.loss import AdversarialLoss, PerceptualLoss, LPIPSLoss +from core.dataset import TrainDataset + +from model.modules.flow_comp_raft import RAFT_bi, FlowLoss, EdgeLoss +from model.recurrent_flow_completion import RecurrentFlowCompleteNet + +from RAFT.utils.flow_viz_pt import flow_to_image + + +class Trainer: + def __init__(self, config): + self.config = config + self.epoch = 0 + self.iteration = 0 + self.num_local_frames = config['train_data_loader']['num_local_frames'] + self.num_ref_frames = config['train_data_loader']['num_ref_frames'] + + # setup data set and data loader + self.train_dataset = TrainDataset(config['train_data_loader']) + + self.train_sampler = None + self.train_args = config['trainer'] + if config['distributed']: + self.train_sampler = DistributedSampler( + self.train_dataset, + num_replicas=config['world_size'], + rank=config['global_rank']) + + dataloader_args = dict( + dataset=self.train_dataset, + batch_size=self.train_args['batch_size'] // config['world_size'], + shuffle=(self.train_sampler is None), + num_workers=self.train_args['num_workers'], + sampler=self.train_sampler, + drop_last=True) + + self.train_loader = PrefetchDataLoader(self.train_args['num_prefetch_queue'], **dataloader_args) + self.prefetcher = CPUPrefetcher(self.train_loader) + + # set loss functions + self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS']) + self.adversarial_loss = self.adversarial_loss.to(self.config['device']) + self.l1_loss = nn.L1Loss() + # self.perc_loss = PerceptualLoss( + # layer_weights={'conv3_4': 0.25, 'conv4_4': 0.25, 'conv5_4': 0.5}, + # use_input_norm=True, + # range_norm=True, + # criterion='l1' + # ).to(self.config['device']) + + if self.config['losses']['perceptual_weight'] > 0: + self.perc_loss = LPIPSLoss(use_input_norm=True, range_norm=True).to(self.config['device']) + + # self.flow_comp_loss = FlowCompletionLoss().to(self.config['device']) + # self.flow_comp_loss = FlowCompletionLoss(self.config['device']) + + # set raft + self.fix_raft = RAFT_bi(device = self.config['device']) + self.fix_flow_complete = RecurrentFlowCompleteNet('/mnt/lustre/sczhou/VQGANs/CodeMOVI/experiments_model/recurrent_flow_completion_v5_train_flowcomp_v5/gen_760000.pth') + for p in self.fix_flow_complete.parameters(): + p.requires_grad = False + self.fix_flow_complete.to(self.config['device']) + self.fix_flow_complete.eval() + + # self.flow_loss = FlowLoss() + + # setup models including generator and discriminator + net = importlib.import_module('model.' + config['model']['net']) + self.netG = net.InpaintGenerator() + # print(self.netG) + self.netG = self.netG.to(self.config['device']) + if not self.config['model'].get('no_dis', False): + if self.config['model'].get('dis_2d', False): + self.netD = net.Discriminator_2D( + in_channels=3, + use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge') + else: + self.netD = net.Discriminator( + in_channels=3, + use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge') + self.netD = self.netD.to(self.config['device']) + + self.interp_mode = self.config['model']['interp_mode'] + # setup optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + self.load() + + if config['distributed']: + self.netG = DDP(self.netG, + device_ids=[self.config['local_rank']], + output_device=self.config['local_rank'], + broadcast_buffers=True, + find_unused_parameters=True) + if not self.config['model']['no_dis']: + self.netD = DDP(self.netD, + device_ids=[self.config['local_rank']], + output_device=self.config['local_rank'], + broadcast_buffers=True, + find_unused_parameters=False) + + # set summary writer + self.dis_writer = None + self.gen_writer = None + self.summary = {} + if self.config['global_rank'] == 0 or (not config['distributed']): + if not self.config['model']['no_dis']: + self.dis_writer = SummaryWriter( + os.path.join(config['save_dir'], 'dis')) + self.gen_writer = SummaryWriter( + os.path.join(config['save_dir'], 'gen')) + + def setup_optimizers(self): + """Set up optimizers.""" + backbone_params = [] + for name, param in self.netG.named_parameters(): + if param.requires_grad: + backbone_params.append(param) + else: + print(f'Params {name} will not be optimized.') + + optim_params = [ + { + 'params': backbone_params, + 'lr': self.config['trainer']['lr'] + }, + ] + + self.optimG = torch.optim.Adam(optim_params, + betas=(self.config['trainer']['beta1'], + self.config['trainer']['beta2'])) + + if not self.config['model']['no_dis']: + self.optimD = torch.optim.Adam( + self.netD.parameters(), + lr=self.config['trainer']['lr'], + betas=(self.config['trainer']['beta1'], + self.config['trainer']['beta2'])) + + def setup_schedulers(self): + """Set up schedulers.""" + scheduler_opt = self.config['trainer']['scheduler'] + scheduler_type = scheduler_opt.pop('type') + + if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']: + self.scheG = MultiStepRestartLR( + self.optimG, + milestones=scheduler_opt['milestones'], + gamma=scheduler_opt['gamma']) + if not self.config['model']['no_dis']: + self.scheD = MultiStepRestartLR( + self.optimD, + milestones=scheduler_opt['milestones'], + gamma=scheduler_opt['gamma']) + elif scheduler_type == 'CosineAnnealingRestartLR': + self.scheG = CosineAnnealingRestartLR( + self.optimG, + periods=scheduler_opt['periods'], + restart_weights=scheduler_opt['restart_weights'], + eta_min=scheduler_opt['eta_min']) + if not self.config['model']['no_dis']: + self.scheD = CosineAnnealingRestartLR( + self.optimD, + periods=scheduler_opt['periods'], + restart_weights=scheduler_opt['restart_weights'], + eta_min=scheduler_opt['eta_min']) + else: + raise NotImplementedError( + f'Scheduler {scheduler_type} is not implemented yet.') + + def update_learning_rate(self): + """Update learning rate.""" + self.scheG.step() + if not self.config['model']['no_dis']: + self.scheD.step() + + def get_lr(self): + """Get current learning rate.""" + return self.optimG.param_groups[0]['lr'] + + def add_summary(self, writer, name, val): + """Add tensorboard summary.""" + if name not in self.summary: + self.summary[name] = 0 + self.summary[name] += val + n = self.train_args['log_freq'] + if writer is not None and self.iteration % n == 0: + writer.add_scalar(name, self.summary[name] / n, self.iteration) + self.summary[name] = 0 + + def load(self): + """Load netG (and netD).""" + # get the latest checkpoint + model_path = self.config['save_dir'] + # TODO: add resume name + if os.path.isfile(os.path.join(model_path, 'latest.ckpt')): + latest_epoch = open(os.path.join(model_path, 'latest.ckpt'), + 'r').read().splitlines()[-1] + else: + ckpts = [ + os.path.basename(i).split('.pth')[0] + for i in glob.glob(os.path.join(model_path, '*.pth')) + ] + ckpts.sort() + latest_epoch = ckpts[-1][4:] if len(ckpts) > 0 else None + + if latest_epoch is not None: + gen_path = os.path.join(model_path, + f'gen_{int(latest_epoch):06d}.pth') + dis_path = os.path.join(model_path, + f'dis_{int(latest_epoch):06d}.pth') + opt_path = os.path.join(model_path, + f'opt_{int(latest_epoch):06d}.pth') + + if self.config['global_rank'] == 0: + print(f'Loading model from {gen_path}...') + dataG = torch.load(gen_path, map_location=self.config['device']) + self.netG.load_state_dict(dataG) + if not self.config['model']['no_dis'] and self.config['model']['load_d']: + dataD = torch.load(dis_path, map_location=self.config['device']) + self.netD.load_state_dict(dataD) + + data_opt = torch.load(opt_path, map_location=self.config['device']) + self.optimG.load_state_dict(data_opt['optimG']) + # self.scheG.load_state_dict(data_opt['scheG']) + if not self.config['model']['no_dis'] and self.config['model']['load_d']: + self.optimD.load_state_dict(data_opt['optimD']) + # self.scheD.load_state_dict(data_opt['scheD']) + self.epoch = data_opt['epoch'] + self.iteration = data_opt['iteration'] + else: + gen_path = self.config['trainer'].get('gen_path', None) + dis_path = self.config['trainer'].get('dis_path', None) + opt_path = self.config['trainer'].get('opt_path', None) + if gen_path is not None: + if self.config['global_rank'] == 0: + print(f'Loading Gen-Net from {gen_path}...') + dataG = torch.load(gen_path, map_location=self.config['device']) + self.netG.load_state_dict(dataG) + + if dis_path is not None and not self.config['model']['no_dis'] and self.config['model']['load_d']: + if self.config['global_rank'] == 0: + print(f'Loading Dis-Net from {dis_path}...') + dataD = torch.load(dis_path, map_location=self.config['device']) + self.netD.load_state_dict(dataD) + if opt_path is not None: + data_opt = torch.load(opt_path, map_location=self.config['device']) + self.optimG.load_state_dict(data_opt['optimG']) + self.scheG.load_state_dict(data_opt['scheG']) + if not self.config['model']['no_dis'] and self.config['model']['load_d']: + self.optimD.load_state_dict(data_opt['optimD']) + self.scheD.load_state_dict(data_opt['scheD']) + else: + if self.config['global_rank'] == 0: + print('Warnning: There is no trained model found.' + 'An initialized model will be used.') + + def save(self, it): + """Save parameters every eval_epoch""" + if self.config['global_rank'] == 0: + # configure path + gen_path = os.path.join(self.config['save_dir'], + f'gen_{it:06d}.pth') + dis_path = os.path.join(self.config['save_dir'], + f'dis_{it:06d}.pth') + opt_path = os.path.join(self.config['save_dir'], + f'opt_{it:06d}.pth') + print(f'\nsaving model to {gen_path} ...') + + # remove .module for saving + if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP): + netG = self.netG.module + if not self.config['model']['no_dis']: + netD = self.netD.module + else: + netG = self.netG + if not self.config['model']['no_dis']: + netD = self.netD + + # save checkpoints + torch.save(netG.state_dict(), gen_path) + if not self.config['model']['no_dis']: + torch.save(netD.state_dict(), dis_path) + torch.save( + { + 'epoch': self.epoch, + 'iteration': self.iteration, + 'optimG': self.optimG.state_dict(), + 'optimD': self.optimD.state_dict(), + 'scheG': self.scheG.state_dict(), + 'scheD': self.scheD.state_dict() + }, opt_path) + else: + torch.save( + { + 'epoch': self.epoch, + 'iteration': self.iteration, + 'optimG': self.optimG.state_dict(), + 'scheG': self.scheG.state_dict() + }, opt_path) + + latest_path = os.path.join(self.config['save_dir'], 'latest.ckpt') + os.system(f"echo {it:06d} > {latest_path}") + + def train(self): + """training entry""" + pbar = range(int(self.train_args['iterations'])) + if self.config['global_rank'] == 0: + pbar = tqdm(pbar, + initial=self.iteration, + dynamic_ncols=True, + smoothing=0.01) + + os.makedirs('logs', exist_ok=True) + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(filename)s[line:%(lineno)d]" + "%(levelname)s %(message)s", + datefmt="%a, %d %b %Y %H:%M:%S", + filename=f"logs/{self.config['save_dir'].split('/')[-1]}.log", + filemode='w') + + while True: + self.epoch += 1 + self.prefetcher.reset() + if self.config['distributed']: + self.train_sampler.set_epoch(self.epoch) + self._train_epoch(pbar) + if self.iteration > self.train_args['iterations']: + break + print('\nEnd training....') + + def _train_epoch(self, pbar): + """Process input and calculate loss every training epoch""" + device = self.config['device'] + train_data = self.prefetcher.next() + while train_data is not None: + self.iteration += 1 + frames, masks, flows_f, flows_b, _ = train_data + frames, masks = frames.to(device), masks.to(device).float() + l_t = self.num_local_frames + b, t, c, h, w = frames.size() + gt_local_frames = frames[:, :l_t, ...] + local_masks = masks[:, :l_t, ...].contiguous() + + masked_frames = frames * (1 - masks) + masked_local_frames = masked_frames[:, :l_t, ...] + # get gt optical flow + if flows_f[0] == 'None' or flows_b[0] == 'None': + gt_flows_bi = self.fix_raft(gt_local_frames) + else: + gt_flows_bi = (flows_f.to(device), flows_b.to(device)) + + # ---- complete flow ---- + pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, local_masks) + pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, local_masks) + # pred_flows_bi = gt_flows_bi + + # ---- image propagation ---- + prop_imgs, updated_local_masks = self.netG.module.img_propagation(masked_local_frames, pred_flows_bi, local_masks, interpolation=self.interp_mode) + updated_masks = masks.clone() + updated_masks[:, :l_t, ...] = updated_local_masks.view(b, l_t, 1, h, w) + updated_frames = masked_frames.clone() + prop_local_frames = gt_local_frames * (1-local_masks) + prop_imgs.view(b, l_t, 3, h, w) * local_masks # merge + updated_frames[:, :l_t, ...] = prop_local_frames + + # ---- feature propagation + Transformer ---- + pred_imgs = self.netG(updated_frames, pred_flows_bi, masks, updated_masks, l_t) + pred_imgs = pred_imgs.view(b, -1, c, h, w) + + # get the local frames + pred_local_frames = pred_imgs[:, :l_t, ...] + comp_local_frames = gt_local_frames * (1. - local_masks) + pred_local_frames * local_masks + comp_imgs = frames * (1. - masks) + pred_imgs * masks + + gen_loss = 0 + dis_loss = 0 + # optimize net_g + if not self.config['model']['no_dis']: + for p in self.netD.parameters(): + p.requires_grad = False + + self.optimG.zero_grad() + + # generator l1 loss + hole_loss = self.l1_loss(pred_imgs * masks, frames * masks) + hole_loss = hole_loss / torch.mean(masks) * self.config['losses']['hole_weight'] + gen_loss += hole_loss + self.add_summary(self.gen_writer, 'loss/hole_loss', hole_loss.item()) + + valid_loss = self.l1_loss(pred_imgs * (1 - masks), frames * (1 - masks)) + valid_loss = valid_loss / torch.mean(1-masks) * self.config['losses']['valid_weight'] + gen_loss += valid_loss + self.add_summary(self.gen_writer, 'loss/valid_loss', valid_loss.item()) + + # perceptual loss + if self.config['losses']['perceptual_weight'] > 0: + perc_loss = self.perc_loss(pred_imgs.view(-1,3,h,w), frames.view(-1,3,h,w))[0] * self.config['losses']['perceptual_weight'] + gen_loss += perc_loss + self.add_summary(self.gen_writer, 'loss/perc_loss', perc_loss.item()) + + # gan loss + if not self.config['model']['no_dis']: + # generator adversarial loss + gen_clip = self.netD(comp_imgs) + gan_loss = self.adversarial_loss(gen_clip, True, False) + gan_loss = gan_loss * self.config['losses']['adversarial_weight'] + gen_loss += gan_loss + self.add_summary(self.gen_writer, 'loss/gan_loss', gan_loss.item()) + gen_loss.backward() + self.optimG.step() + + if not self.config['model']['no_dis']: + # optimize net_d + for p in self.netD.parameters(): + p.requires_grad = True + self.optimD.zero_grad() + + # discriminator adversarial loss + real_clip = self.netD(frames) + fake_clip = self.netD(comp_imgs.detach()) + dis_real_loss = self.adversarial_loss(real_clip, True, True) + dis_fake_loss = self.adversarial_loss(fake_clip, False, True) + dis_loss += (dis_real_loss + dis_fake_loss) / 2 + self.add_summary(self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item()) + self.add_summary(self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item()) + dis_loss.backward() + self.optimD.step() + + self.update_learning_rate() + + # write image to tensorboard + if self.iteration % 200 == 0: + # img to cpu + t = 0 + gt_local_frames_cpu = ((gt_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu() + masked_local_frames = ((masked_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu() + prop_local_frames_cpu = ((prop_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu() + pred_local_frames_cpu = ((pred_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu() + img_results = torch.cat([masked_local_frames[0][t], gt_local_frames_cpu[0][t], + prop_local_frames_cpu[0][t], pred_local_frames_cpu[0][t]], 1) + img_results = torchvision.utils.make_grid(img_results, nrow=1, normalize=True) + if self.gen_writer is not None: + self.gen_writer.add_image(f'img/img:inp-gt-res-{t}', img_results, self.iteration) + + t = 5 + if masked_local_frames.shape[1] > 5: + img_results = torch.cat([masked_local_frames[0][t], gt_local_frames_cpu[0][t], + prop_local_frames_cpu[0][t], pred_local_frames_cpu[0][t]], 1) + img_results = torchvision.utils.make_grid(img_results, nrow=1, normalize=True) + if self.gen_writer is not None: + self.gen_writer.add_image(f'img/img:inp-gt-res-{t}', img_results, self.iteration) + + # flow to cpu + gt_flows_forward_cpu = flow_to_image(gt_flows_bi[0][0]).cpu() + masked_flows_forward_cpu = (gt_flows_forward_cpu[0] * (1-local_masks[0][0].cpu())).to(gt_flows_forward_cpu) + pred_flows_forward_cpu = flow_to_image(pred_flows_bi[0][0]).cpu() + + flow_results = torch.cat([gt_flows_forward_cpu[0], masked_flows_forward_cpu, pred_flows_forward_cpu[0]], 1) + if self.gen_writer is not None: + self.gen_writer.add_image('img/flow:gt-pred', flow_results, self.iteration) + + # console logs + if self.config['global_rank'] == 0: + pbar.update(1) + if not self.config['model']['no_dis']: + pbar.set_description((f"d: {dis_loss.item():.3f}; " + f"hole: {hole_loss.item():.3f}; " + f"valid: {valid_loss.item():.3f}")) + else: + pbar.set_description((f"hole: {hole_loss.item():.3f}; " + f"valid: {valid_loss.item():.3f}")) + + if self.iteration % self.train_args['log_freq'] == 0: + if not self.config['model']['no_dis']: + logging.info(f"[Iter {self.iteration}] " + f"d: {dis_loss.item():.4f}; " + f"hole: {hole_loss.item():.4f}; " + f"valid: {valid_loss.item():.4f}") + else: + logging.info(f"[Iter {self.iteration}] " + f"hole: {hole_loss.item():.4f}; " + f"valid: {valid_loss.item():.4f}") + + # saving models + if self.iteration % self.train_args['save_freq'] == 0: + self.save(int(self.iteration)) + + if self.iteration > self.train_args['iterations']: + break + + train_data = self.prefetcher.next() \ No newline at end of file diff --git a/core/trainer_flow_w_edge.py b/core/trainer_flow_w_edge.py new file mode 100644 index 0000000000000000000000000000000000000000..d4eba04c8a5fa56bce3e335e6036bc0e0a1e848a --- /dev/null +++ b/core/trainer_flow_w_edge.py @@ -0,0 +1,380 @@ +import os +import glob +import logging +import importlib +from tqdm import tqdm + +import torch +import torch.nn as nn +import torch.nn.functional as F +from core.prefetch_dataloader import PrefetchDataLoader, CPUPrefetcher +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import DistributedDataParallel as DDP + +from torch.utils.tensorboard import SummaryWriter + +from core.lr_scheduler import MultiStepRestartLR, CosineAnnealingRestartLR +from core.dataset import TrainDataset + +from model.modules.flow_comp_raft import RAFT_bi, FlowLoss, EdgeLoss + +# from skimage.feature import canny +from model.canny.canny_filter import Canny +from RAFT.utils.flow_viz_pt import flow_to_image + + +class Trainer: + def __init__(self, config): + self.config = config + self.epoch = 0 + self.iteration = 0 + self.num_local_frames = config['train_data_loader']['num_local_frames'] + self.num_ref_frames = config['train_data_loader']['num_ref_frames'] + + # setup data set and data loader + self.train_dataset = TrainDataset(config['train_data_loader']) + + self.train_sampler = None + self.train_args = config['trainer'] + if config['distributed']: + self.train_sampler = DistributedSampler( + self.train_dataset, + num_replicas=config['world_size'], + rank=config['global_rank']) + + dataloader_args = dict( + dataset=self.train_dataset, + batch_size=self.train_args['batch_size'] // config['world_size'], + shuffle=(self.train_sampler is None), + num_workers=self.train_args['num_workers'], + sampler=self.train_sampler, + drop_last=True) + + self.train_loader = PrefetchDataLoader(self.train_args['num_prefetch_queue'], **dataloader_args) + self.prefetcher = CPUPrefetcher(self.train_loader) + + # set raft + self.fix_raft = RAFT_bi(device = self.config['device']) + self.flow_loss = FlowLoss() + self.edge_loss = EdgeLoss() + self.canny = Canny(sigma=(2,2), low_threshold=0.1, high_threshold=0.2) + + # setup models including generator and discriminator + net = importlib.import_module('model.' + config['model']['net']) + self.netG = net.RecurrentFlowCompleteNet() + # print(self.netG) + self.netG = self.netG.to(self.config['device']) + + # setup optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + self.load() + + if config['distributed']: + self.netG = DDP(self.netG, + device_ids=[self.config['local_rank']], + output_device=self.config['local_rank'], + broadcast_buffers=True, + find_unused_parameters=True) + + # set summary writer + self.dis_writer = None + self.gen_writer = None + self.summary = {} + if self.config['global_rank'] == 0 or (not config['distributed']): + self.gen_writer = SummaryWriter( + os.path.join(config['save_dir'], 'gen')) + + def setup_optimizers(self): + """Set up optimizers.""" + backbone_params = [] + for name, param in self.netG.named_parameters(): + if param.requires_grad: + backbone_params.append(param) + else: + print(f'Params {name} will not be optimized.') + + optim_params = [ + { + 'params': backbone_params, + 'lr': self.config['trainer']['lr'] + }, + ] + + self.optimG = torch.optim.Adam(optim_params, + betas=(self.config['trainer']['beta1'], + self.config['trainer']['beta2'])) + + + def setup_schedulers(self): + """Set up schedulers.""" + scheduler_opt = self.config['trainer']['scheduler'] + scheduler_type = scheduler_opt.pop('type') + + if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']: + self.scheG = MultiStepRestartLR( + self.optimG, + milestones=scheduler_opt['milestones'], + gamma=scheduler_opt['gamma']) + elif scheduler_type == 'CosineAnnealingRestartLR': + self.scheG = CosineAnnealingRestartLR( + self.optimG, + periods=scheduler_opt['periods'], + restart_weights=scheduler_opt['restart_weights']) + else: + raise NotImplementedError( + f'Scheduler {scheduler_type} is not implemented yet.') + + def update_learning_rate(self): + """Update learning rate.""" + self.scheG.step() + + def get_lr(self): + """Get current learning rate.""" + return self.optimG.param_groups[0]['lr'] + + def add_summary(self, writer, name, val): + """Add tensorboard summary.""" + if name not in self.summary: + self.summary[name] = 0 + self.summary[name] += val + n = self.train_args['log_freq'] + if writer is not None and self.iteration % n == 0: + writer.add_scalar(name, self.summary[name] / n, self.iteration) + self.summary[name] = 0 + + def load(self): + """Load netG.""" + # get the latest checkpoint + model_path = self.config['save_dir'] + if os.path.isfile(os.path.join(model_path, 'latest.ckpt')): + latest_epoch = open(os.path.join(model_path, 'latest.ckpt'), + 'r').read().splitlines()[-1] + else: + ckpts = [ + os.path.basename(i).split('.pth')[0] + for i in glob.glob(os.path.join(model_path, '*.pth')) + ] + ckpts.sort() + latest_epoch = ckpts[-1][4:] if len(ckpts) > 0 else None + + if latest_epoch is not None: + gen_path = os.path.join(model_path, f'gen_{int(latest_epoch):06d}.pth') + opt_path = os.path.join(model_path,f'opt_{int(latest_epoch):06d}.pth') + + if self.config['global_rank'] == 0: + print(f'Loading model from {gen_path}...') + dataG = torch.load(gen_path, map_location=self.config['device']) + self.netG.load_state_dict(dataG) + + + data_opt = torch.load(opt_path, map_location=self.config['device']) + self.optimG.load_state_dict(data_opt['optimG']) + self.scheG.load_state_dict(data_opt['scheG']) + + self.epoch = data_opt['epoch'] + self.iteration = data_opt['iteration'] + + else: + if self.config['global_rank'] == 0: + print('Warnning: There is no trained model found.' + 'An initialized model will be used.') + + def save(self, it): + """Save parameters every eval_epoch""" + if self.config['global_rank'] == 0: + # configure path + gen_path = os.path.join(self.config['save_dir'], + f'gen_{it:06d}.pth') + opt_path = os.path.join(self.config['save_dir'], + f'opt_{it:06d}.pth') + print(f'\nsaving model to {gen_path} ...') + + # remove .module for saving + if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP): + netG = self.netG.module + else: + netG = self.netG + + # save checkpoints + torch.save(netG.state_dict(), gen_path) + torch.save( + { + 'epoch': self.epoch, + 'iteration': self.iteration, + 'optimG': self.optimG.state_dict(), + 'scheG': self.scheG.state_dict() + }, opt_path) + + latest_path = os.path.join(self.config['save_dir'], 'latest.ckpt') + os.system(f"echo {it:06d} > {latest_path}") + + def train(self): + """training entry""" + pbar = range(int(self.train_args['iterations'])) + if self.config['global_rank'] == 0: + pbar = tqdm(pbar, + initial=self.iteration, + dynamic_ncols=True, + smoothing=0.01) + + os.makedirs('logs', exist_ok=True) + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(filename)s[line:%(lineno)d]" + "%(levelname)s %(message)s", + datefmt="%a, %d %b %Y %H:%M:%S", + filename=f"logs/{self.config['save_dir'].split('/')[-1]}.log", + filemode='w') + + while True: + self.epoch += 1 + self.prefetcher.reset() + if self.config['distributed']: + self.train_sampler.set_epoch(self.epoch) + self._train_epoch(pbar) + if self.iteration > self.train_args['iterations']: + break + print('\nEnd training....') + + # def get_edges(self, flows): # fgvc + # # (b, t, 2, H, W) + # b, t, _, h, w = flows.shape + # flows = flows.view(-1, 2, h, w) + # flows_list = flows.permute(0, 2, 3, 1).cpu().numpy() + # edges = [] + # for f in list(flows_list): + # flows_gray = (f[:, :, 0] ** 2 + f[:, :, 1] ** 2) ** 0.5 + # if flows_gray.max() < 1: + # flows_gray = flows_gray*0 + # else: + # flows_gray = flows_gray / flows_gray.max() + + # edge = canny(flows_gray, sigma=2, low_threshold=0.1, high_threshold=0.2) # fgvc + # edge = torch.from_numpy(edge).view(1, 1, h, w).float() + # edges.append(edge) + # edges = torch.stack(edges, dim=0).to(self.config['device']) + # edges = edges.view(b, t, 1, h, w) + # return edges + + def get_edges(self, flows): + # (b, t, 2, H, W) + b, t, _, h, w = flows.shape + flows = flows.view(-1, 2, h, w) + flows_gray = (flows[:, 0, None] ** 2 + flows[:, 1, None] ** 2) ** 0.5 + if flows_gray.max() < 1: + flows_gray = flows_gray*0 + else: + flows_gray = flows_gray / flows_gray.max() + + magnitude, edges = self.canny(flows_gray.float()) + edges = edges.view(b, t, 1, h, w) + return edges + + def _train_epoch(self, pbar): + """Process input and calculate loss every training epoch""" + device = self.config['device'] + train_data = self.prefetcher.next() + while train_data is not None: + self.iteration += 1 + frames, masks, flows_f, flows_b, _ = train_data + frames, masks = frames.to(device), masks.to(device) + masks = masks.float() + + l_t = self.num_local_frames + b, t, c, h, w = frames.size() + gt_local_frames = frames[:, :l_t, ...] + local_masks = masks[:, :l_t, ...].contiguous() + + # get gt optical flow + if flows_f[0] == 'None' or flows_b[0] == 'None': + gt_flows_bi = self.fix_raft(gt_local_frames) + else: + gt_flows_bi = (flows_f.to(device), flows_b.to(device)) + + # get gt edge + gt_edges_forward = self.get_edges(gt_flows_bi[0]) + gt_edges_backward = self.get_edges(gt_flows_bi[1]) + gt_edges_bi = [gt_edges_forward, gt_edges_backward] + + # complete flow + pred_flows_bi, pred_edges_bi = self.netG.module.forward_bidirect_flow(gt_flows_bi, local_masks) + + # optimize net_g + self.optimG.zero_grad() + + # compulte flow_loss + flow_loss, warp_loss = self.flow_loss(pred_flows_bi, gt_flows_bi, local_masks, gt_local_frames) + flow_loss = flow_loss * self.config['losses']['flow_weight'] + warp_loss = warp_loss * 0.01 + self.add_summary(self.gen_writer, 'loss/flow_loss', flow_loss.item()) + self.add_summary(self.gen_writer, 'loss/warp_loss', warp_loss.item()) + + # compute edge loss + edge_loss = self.edge_loss(pred_edges_bi, gt_edges_bi, local_masks) + edge_loss = edge_loss*1.0 + self.add_summary(self.gen_writer, 'loss/edge_loss', edge_loss.item()) + + loss = flow_loss + warp_loss + edge_loss + loss.backward() + self.optimG.step() + self.update_learning_rate() + + # write image to tensorboard + # if self.iteration % 200 == 0: + if self.iteration % 200 == 0 and self.gen_writer is not None: + t = 5 + # forward to cpu + gt_flows_forward_cpu = flow_to_image(gt_flows_bi[0][0]).cpu() + masked_flows_forward_cpu = (gt_flows_forward_cpu[t] * (1-local_masks[0][t].cpu())).to(gt_flows_forward_cpu) + pred_flows_forward_cpu = flow_to_image(pred_flows_bi[0][0]).cpu() + + flow_results = torch.cat([gt_flows_forward_cpu[t], masked_flows_forward_cpu, pred_flows_forward_cpu[t]], 1) + self.gen_writer.add_image('img/flow-f:gt-pred', flow_results, self.iteration) + + # backward to cpu + gt_flows_backward_cpu = flow_to_image(gt_flows_bi[1][0]).cpu() + masked_flows_backward_cpu = (gt_flows_backward_cpu[t] * (1-local_masks[0][t+1].cpu())).to(gt_flows_backward_cpu) + pred_flows_backward_cpu = flow_to_image(pred_flows_bi[1][0]).cpu() + + flow_results = torch.cat([gt_flows_backward_cpu[t], masked_flows_backward_cpu, pred_flows_backward_cpu[t]], 1) + self.gen_writer.add_image('img/flow-b:gt-pred', flow_results, self.iteration) + + # TODO: show edge + # forward + gt_edges_forward_cpu = gt_edges_bi[0][0].cpu() + masked_edges_forward_cpu = (gt_edges_forward_cpu[t] * (1-local_masks[0][t].cpu())).to(gt_edges_forward_cpu) + pred_edges_forward_cpu = pred_edges_bi[0][0].cpu() + + edge_results = torch.cat([gt_edges_forward_cpu[t], masked_edges_forward_cpu, pred_edges_forward_cpu[t]], 1) + self.gen_writer.add_image('img/edge-f:gt-pred', edge_results, self.iteration) + # backward + gt_edges_backward_cpu = gt_edges_bi[1][0].cpu() + masked_edges_backward_cpu = (gt_edges_backward_cpu[t] * (1-local_masks[0][t+1].cpu())).to(gt_edges_backward_cpu) + pred_edges_backward_cpu = pred_edges_bi[1][0].cpu() + + edge_results = torch.cat([gt_edges_backward_cpu[t], masked_edges_backward_cpu, pred_edges_backward_cpu[t]], 1) + self.gen_writer.add_image('img/edge-b:gt-pred', edge_results, self.iteration) + + # console logs + if self.config['global_rank'] == 0: + pbar.update(1) + pbar.set_description((f"flow: {flow_loss.item():.3f}; " + f"warp: {warp_loss.item():.3f}; " + f"edge: {edge_loss.item():.3f}; " + f"lr: {self.get_lr()}")) + + if self.iteration % self.train_args['log_freq'] == 0: + logging.info(f"[Iter {self.iteration}] " + f"flow: {flow_loss.item():.4f}; " + f"warp: {warp_loss.item():.4f}") + + # saving models + if self.iteration % self.train_args['save_freq'] == 0: + self.save(int(self.iteration)) + + if self.iteration > self.train_args['iterations']: + break + + train_data = self.prefetcher.next() \ No newline at end of file diff --git a/core/utils.py b/core/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..37dccb2d26e6916aacbd530ab03726a7c54f8ec8 --- /dev/null +++ b/core/utils.py @@ -0,0 +1,371 @@ +import os +import io +import cv2 +import random +import numpy as np +from PIL import Image, ImageOps +import zipfile +import math + +import torch +import matplotlib +import matplotlib.patches as patches +from matplotlib.path import Path +from matplotlib import pyplot as plt +from torchvision import transforms + +# matplotlib.use('agg') + +# ########################################################################### +# Directory IO +# ########################################################################### + + +def read_dirnames_under_root(root_dir): + dirnames = [ + name for i, name in enumerate(sorted(os.listdir(root_dir))) + if os.path.isdir(os.path.join(root_dir, name)) + ] + print(f'Reading directories under {root_dir}, num: {len(dirnames)}') + return dirnames + + +class TrainZipReader(object): + file_dict = dict() + + def __init__(self): + super(TrainZipReader, self).__init__() + + @staticmethod + def build_file_dict(path): + file_dict = TrainZipReader.file_dict + if path in file_dict: + return file_dict[path] + else: + file_handle = zipfile.ZipFile(path, 'r') + file_dict[path] = file_handle + return file_dict[path] + + @staticmethod + def imread(path, idx): + zfile = TrainZipReader.build_file_dict(path) + filelist = zfile.namelist() + filelist.sort() + data = zfile.read(filelist[idx]) + # + im = Image.open(io.BytesIO(data)) + return im + + +class TestZipReader(object): + file_dict = dict() + + def __init__(self): + super(TestZipReader, self).__init__() + + @staticmethod + def build_file_dict(path): + file_dict = TestZipReader.file_dict + if path in file_dict: + return file_dict[path] + else: + file_handle = zipfile.ZipFile(path, 'r') + file_dict[path] = file_handle + return file_dict[path] + + @staticmethod + def imread(path, idx): + zfile = TestZipReader.build_file_dict(path) + filelist = zfile.namelist() + filelist.sort() + data = zfile.read(filelist[idx]) + file_bytes = np.asarray(bytearray(data), dtype=np.uint8) + im = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR) + im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB)) + # im = Image.open(io.BytesIO(data)) + return im + + +# ########################################################################### +# Data augmentation +# ########################################################################### + + +def to_tensors(): + return transforms.Compose([Stack(), ToTorchFormatTensor()]) + + +class GroupRandomHorizontalFlowFlip(object): + """Randomly horizontally flips the given PIL.Image with a probability of 0.5 + """ + def __call__(self, img_group, flowF_group, flowB_group): + v = random.random() + if v < 0.5: + ret_img = [ + img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group + ] + ret_flowF = [ff[:, ::-1] * [-1.0, 1.0] for ff in flowF_group] + ret_flowB = [fb[:, ::-1] * [-1.0, 1.0] for fb in flowB_group] + return ret_img, ret_flowF, ret_flowB + else: + return img_group, flowF_group, flowB_group + + +class GroupRandomHorizontalFlip(object): + """Randomly horizontally flips the given PIL.Image with a probability of 0.5 + """ + def __call__(self, img_group, is_flow=False): + v = random.random() + if v < 0.5: + ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] + if is_flow: + for i in range(0, len(ret), 2): + # invert flow pixel values when flipping + ret[i] = ImageOps.invert(ret[i]) + return ret + else: + return img_group + + +class Stack(object): + def __init__(self, roll=False): + self.roll = roll + + def __call__(self, img_group): + mode = img_group[0].mode + if mode == '1': + img_group = [img.convert('L') for img in img_group] + mode = 'L' + if mode == 'L': + return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2) + elif mode == 'RGB': + if self.roll: + return np.stack([np.array(x)[:, :, ::-1] for x in img_group], + axis=2) + else: + return np.stack(img_group, axis=2) + else: + raise NotImplementedError(f"Image mode {mode}") + + +class ToTorchFormatTensor(object): + """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] + to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ + def __init__(self, div=True): + self.div = div + + def __call__(self, pic): + if isinstance(pic, np.ndarray): + # numpy img: [L, C, H, W] + img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous() + else: + # handle PIL Image + img = torch.ByteTensor(torch.ByteStorage.from_buffer( + pic.tobytes())) + img = img.view(pic.size[1], pic.size[0], len(pic.mode)) + # put it from HWC to CHW format + # yikes, this transpose takes 80% of the loading time/CPU + img = img.transpose(0, 1).transpose(0, 2).contiguous() + img = img.float().div(255) if self.div else img.float() + return img + + +# ########################################################################### +# Create masks with random shape +# ########################################################################### + + +def create_random_shape_with_random_motion(video_length, + imageHeight=240, + imageWidth=432): + # get a random shape + height = random.randint(imageHeight // 3, imageHeight - 1) + width = random.randint(imageWidth // 3, imageWidth - 1) + edge_num = random.randint(6, 8) + ratio = random.randint(6, 8) / 10 + + region = get_random_shape(edge_num=edge_num, + ratio=ratio, + height=height, + width=width) + region_width, region_height = region.size + # get random position + x, y = random.randint(0, imageHeight - region_height), random.randint( + 0, imageWidth - region_width) + velocity = get_random_velocity(max_speed=3) + m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8)) + m.paste(region, (y, x, y + region.size[0], x + region.size[1])) + masks = [m.convert('L')] + # return fixed masks + if random.uniform(0, 1) > 0.5: + return masks * video_length + # return moving masks + for _ in range(video_length - 1): + x, y, velocity = random_move_control_points(x, + y, + imageHeight, + imageWidth, + velocity, + region.size, + maxLineAcceleration=(3, + 0.5), + maxInitSpeed=3) + m = Image.fromarray( + np.zeros((imageHeight, imageWidth)).astype(np.uint8)) + m.paste(region, (y, x, y + region.size[0], x + region.size[1])) + masks.append(m.convert('L')) + return masks + + +def create_random_shape_with_random_motion_zoom_rotation(video_length, zoomin=0.9, zoomout=1.1, rotmin=1, rotmax=10, imageHeight=240, imageWidth=432): + # get a random shape + assert zoomin < 1, "Zoom-in parameter must be smaller than 1" + assert zoomout > 1, "Zoom-out parameter must be larger than 1" + assert rotmin < rotmax, "Minimum value of rotation must be smaller than maximun value !" + height = random.randint(imageHeight//3, imageHeight-1) + width = random.randint(imageWidth//3, imageWidth-1) + edge_num = random.randint(6, 8) + ratio = random.randint(6, 8)/10 + region = get_random_shape( + edge_num=edge_num, ratio=ratio, height=height, width=width) + region_width, region_height = region.size + # get random position + x, y = random.randint( + 0, imageHeight-region_height), random.randint(0, imageWidth-region_width) + velocity = get_random_velocity(max_speed=3) + m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8)) + m.paste(region, (y, x, y+region.size[0], x+region.size[1])) + masks = [m.convert('L')] + # return fixed masks + if random.uniform(0, 1) > 0.5: + return masks*video_length # -> directly copy all the base masks + # return moving masks + for _ in range(video_length-1): + x, y, velocity = random_move_control_points( + x, y, imageHeight, imageWidth, velocity, region.size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3) + m = Image.fromarray( + np.zeros((imageHeight, imageWidth)).astype(np.uint8)) + ### add by kaidong, to simulate zoon-in, zoom-out and rotation + extra_transform = random.uniform(0, 1) + # zoom in and zoom out + if extra_transform > 0.75: + resize_coefficient = random.uniform(zoomin, zoomout) + region = region.resize((math.ceil(region_width * resize_coefficient), math.ceil(region_height * resize_coefficient)), Image.NEAREST) + m.paste(region, (y, x, y + region.size[0], x + region.size[1])) + region_width, region_height = region.size + # rotation + elif extra_transform > 0.5: + m.paste(region, (y, x, y + region.size[0], x + region.size[1])) + m = m.rotate(random.randint(rotmin, rotmax)) + # region_width, region_height = region.size + ### end + else: + m.paste(region, (y, x, y+region.size[0], x+region.size[1])) + masks.append(m.convert('L')) + return masks + + +def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240): + ''' + There is the initial point and 3 points per cubic bezier curve. + Thus, the curve will only pass though n points, which will be the sharp edges. + The other 2 modify the shape of the bezier curve. + edge_num, Number of possibly sharp edges + points_num, number of points in the Path + ratio, (0, 1) magnitude of the perturbation from the unit circle, + ''' + points_num = edge_num*3 + 1 + angles = np.linspace(0, 2*np.pi, points_num) + codes = np.full(points_num, Path.CURVE4) + codes[0] = Path.MOVETO + # Using this instead of Path.CLOSEPOLY avoids an innecessary straight line + verts = np.stack((np.cos(angles), np.sin(angles))).T * \ + (2*ratio*np.random.random(points_num)+1-ratio)[:, None] + verts[-1, :] = verts[0, :] + path = Path(verts, codes) + # draw paths into images + fig = plt.figure() + ax = fig.add_subplot(111) + patch = patches.PathPatch(path, facecolor='black', lw=2) + ax.add_patch(patch) + ax.set_xlim(np.min(verts)*1.1, np.max(verts)*1.1) + ax.set_ylim(np.min(verts)*1.1, np.max(verts)*1.1) + ax.axis('off') # removes the axis to leave only the shape + fig.canvas.draw() + # convert plt images into numpy images + data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + data = data.reshape((fig.canvas.get_width_height()[::-1] + (3,))) + plt.close(fig) + # postprocess + data = cv2.resize(data, (width, height))[:, :, 0] + data = (1 - np.array(data > 0).astype(np.uint8))*255 + corrdinates = np.where(data > 0) + xmin, xmax, ymin, ymax = np.min(corrdinates[0]), np.max( + corrdinates[0]), np.min(corrdinates[1]), np.max(corrdinates[1]) + region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax)) + return region + + +def random_accelerate(velocity, maxAcceleration, dist='uniform'): + speed, angle = velocity + d_speed, d_angle = maxAcceleration + if dist == 'uniform': + speed += np.random.uniform(-d_speed, d_speed) + angle += np.random.uniform(-d_angle, d_angle) + elif dist == 'guassian': + speed += np.random.normal(0, d_speed / 2) + angle += np.random.normal(0, d_angle / 2) + else: + raise NotImplementedError( + f'Distribution type {dist} is not supported.') + return (speed, angle) + + +def get_random_velocity(max_speed=3, dist='uniform'): + if dist == 'uniform': + speed = np.random.uniform(max_speed) + elif dist == 'guassian': + speed = np.abs(np.random.normal(0, max_speed / 2)) + else: + raise NotImplementedError( + f'Distribution type {dist} is not supported.') + angle = np.random.uniform(0, 2 * np.pi) + return (speed, angle) + + +def random_move_control_points(X, + Y, + imageHeight, + imageWidth, + lineVelocity, + region_size, + maxLineAcceleration=(3, 0.5), + maxInitSpeed=3): + region_width, region_height = region_size + speed, angle = lineVelocity + X += int(speed * np.cos(angle)) + Y += int(speed * np.sin(angle)) + lineVelocity = random_accelerate(lineVelocity, + maxLineAcceleration, + dist='guassian') + if ((X > imageHeight - region_height) or (X < 0) + or (Y > imageWidth - region_width) or (Y < 0)): + lineVelocity = get_random_velocity(maxInitSpeed, dist='guassian') + new_X = np.clip(X, 0, imageHeight - region_height) + new_Y = np.clip(Y, 0, imageWidth - region_width) + return new_X, new_Y, lineVelocity + + +if __name__ == '__main__': + + trials = 10 + for _ in range(trials): + video_length = 10 + # The returned masks are either stationary (50%) or moving (50%) + masks = create_random_shape_with_random_motion(video_length, + imageHeight=240, + imageWidth=432) + + for m in masks: + cv2.imshow('mask', np.array(m)) + cv2.waitKey(500) diff --git a/datasets/davis/test.json b/datasets/davis/test.json new file mode 100644 index 0000000000000000000000000000000000000000..54875df42cba3451a6c3f2642706652ae087996a --- /dev/null +++ b/datasets/davis/test.json @@ -0,0 +1 @@ +{"bear": 82, "blackswan": 50, "bmx-bumps": 90, "bmx-trees": 80, "boat": 75, "breakdance": 84, "breakdance-flare": 71, "bus": 80, "camel": 90, "car-roundabout": 75, "car-shadow": 40, "car-turn": 80, "cows": 104, "dance-jump": 60, "dance-twirl": 90, "dog": 60, "dog-agility": 25, "drift-chicane": 52, "drift-straight": 50, "drift-turn": 64, "elephant": 80, "flamingo": 80, "goat": 90, "hike": 80, "hockey": 75, "horsejump-high": 50, "horsejump-low": 60, "kite-surf": 50, "kite-walk": 80, "libby": 49, "lucia": 70, "mallard-fly": 70, "mallard-water": 80, "motocross-bumps": 60, "motocross-jump": 40, "motorbike": 43, "paragliding": 70, "paragliding-launch": 80, "parkour": 100, "rhino": 90, "rollerblade": 35, "scooter-black": 43, "scooter-gray": 75, "soapbox": 99, "soccerball": 48, "stroller": 91, "surf": 55, "swing": 60, "tennis": 70, "train": 80} \ No newline at end of file diff --git a/datasets/davis/train.json b/datasets/davis/train.json new file mode 100644 index 0000000000000000000000000000000000000000..3f63b2d95553e8ab606d9c207a6a8ae56a28035c --- /dev/null +++ b/datasets/davis/train.json @@ -0,0 +1 @@ +{"baseball": 90, "basketball-game": 77, "bears-ball": 78, "bmx-rider": 85, "butterfly": 80, "car-competition": 66, "cat": 52, "chairlift": 99, "circus": 73, "city-ride": 70, "crafting": 45, "curling": 76, "dog-competition": 85, "dolphins-show": 74, "dribbling": 49, "drone-flying": 70, "ducks": 75, "elephant-hyenas": 55, "giraffes": 88, "gym-ball": 69, "helicopter-landing": 77, "horse-race": 80, "horses-kids": 78, "hurdles-race": 55, "ice-hockey": 52, "jet-ski": 83, "juggling-selfie": 78, "kayak-race": 63, "kids-robot": 75, "landing": 35, "luggage": 83, "mantaray": 73, "marbles": 70, "mascot": 78, "mermaid": 78, "monster-trucks": 99, "motorbike-indoors": 79, "motorbike-race": 88, "music-band": 87, "obstacles": 81, "obstacles-race": 48, "peacock": 75, "plane-exhibition": 73, "puppet": 100, "robot-battle": 85, "robotic-arm": 82, "rodeo": 85, "sea-turtle": 90, "skydiving-jumping": 75, "snowboard-race": 75, "snowboard-sand": 55, "surfer": 80, "swimmer": 86, "table-tennis": 70, "tram": 84, "trucks-race": 78, "twist-dance": 83, "volleyball-beach": 73, "water-slide": 88, "weightlifting": 90} \ No newline at end of file diff --git a/datasets/youtube-vos/test.json b/datasets/youtube-vos/test.json new file mode 100644 index 0000000000000000000000000000000000000000..c4d79d915bd9171830d7b10de53f433dc92ca81d --- /dev/null +++ b/datasets/youtube-vos/test.json @@ -0,0 +1 @@ +{"0070461469": 91, "00bd64cb00": 180, "00fef116ee": 96, "012257ffcf": 180, "01475d1fe7": 180, "0163b18674": 96, "017fa2adaa": 180, "0232ba85ed": 180, "02b1a46f42": 180, "02caec8ac0": 91, "047436c72c": 96, "0481e165b4": 150, "04f98557e7": 144, "05e73c3ecb": 96, "08f95ce1ff": 144, "0b6db1c6fd": 96, "0bd8c18197": 180, "0c6d13ee2c": 91, "0c7ba00455": 96, "0cba3e52eb": 91, "0d16524447": 150, "0d4827437d": 150, "0d62fa582a": 180, "0e1f91c0d7": 91, "0ef454b3f0": 91, "10e18fcf0c": 96, "11105e147e": 91, "11444b16da": 91, "11a4df37a4": 180, "11b3298d6a": 96, "13006c4c7e": 96, "1345523ba1": 180, "144a16eb12": 180, "15a6536e74": 180, "1616507c9e": 180, "1655f4782a": 92, "16608ccef6": 96, "16bc05b66c": 150, "16f1e1779b": 96, "17caf00e26": 96, "18f1e2f716": 91, "191a0bfcdf": 180, "19d4acf831": 91, "1a1dc21969": 96, "1a72d9fcea": 150, "1a92c81edd": 180, "1b2c2022a3": 96, "1d1601d079": 180, "1db7b25d1c": 180, "1dee5b7b5a": 150, "1e0c2e54f2": 96, "1e458b1539": 92, "1e6ac08c86": 91, "1e790eae99": 56, "1ed0c6ca5b": 96, "1edbdb6d18": 180, "1f2015e056": 96, "215ac56b15": 180, "2233485b49": 96, "224d171af6": 180, "237c6ebaf4": 91, "2462c51412": 96, "24bf968338": 180, "250d5953a0": 150, "25bcf222fb": 180, "25ea8feecf": 150, "25fc493839": 92, "262f69837e": 180, "264ca20298": 180, "26d8d48248": 51, "270f84c5e5": 91, "27889bc0fe": 180, "29b87846e7": 96, "29d2e79171": 180, "2a44411a3d": 180, "2b426fd330": 180, "2c4c4e2d5b": 180, "2c4c718eda": 180, "2c962c1bbe": 180, "2cc841341c": 92, "2cf6c4d17e": 91, "2d7ef0be04": 180, "2e5e52c6c8": 150, "2ef6fce8c6": 144, "3014e769bf": 180, "30d5f163b6": 180, "318df73d6a": 90, "31fbb9df3c": 96, "3255fcad2f": 180, "3303eea8e4": 91, "3447c30052": 150, "362722660c": 180, "37e0b4642b": 91, "383e51ed93": 180, "386b050bd0": 41, "3876ba3136": 180, "388ec2934c": 180, "38b45d9c6b": 96, "396680839c": 150, "39ffa3a4a4": 180, "3b0291b2be": 150, "3b333693f4": 180, "3bde1da2cf": 96, "3c5f4e6672": 91, "3c80682cc6": 92, "3ce634a1c1": 180, "3d6a761295": 96, "3da878c317": 91, "3db571b7ee": 96, "3e2336812c": 180, "3f16b04d6d": 96, "3fbbc75c5e": 180, "4015a1e1cc": 87, "406cd7bd48": 91, "407b87ba26": 91, "40a5628dcc": 91, "41af239f5e": 180, "42c671b285": 180, "42de37f462": 180, "4381c60a2f": 180, "4445dc0af5": 180, "44a3419d24": 180, "4566034eaf": 51, "45877fd086": 180, "4595935b88": 91, "4923010cfe": 96, "49b6d81ee8": 180, "4a39c34139": 180, "4a5a9fde01": 144, "4a90394892": 180, "4af10534e4": 180, "4af307f5bc": 180, "4be0ac97df": 91, "4be9025726": 91, "4c18a7bfab": 91, "4c269afea9": 91, "4c3db058db": 179, "4e1ef26a1e": 96, "50f4c0195b": 150, "50f89963c0": 96, "5105c5e4b8": 180, "51d60e4f93": 46, "51ee638399": 96, "522ea1a892": 180, "528e9f30e7": 91, "532efb206a": 180, "544b1486ac": 91, "5592eb680c": 180, "562fadda3a": 91, "568b30cf93": 150, "575f0e2d8e": 91, "5767fe466c": 150, "581c78d558": 180, "5a0ddcf128": 96, "5adf056317": 144, "5b33c701ce": 180, "5b8f636b33": 150, "5b9d26b1d7": 180, "5c24813a0b": 180, "5d0b35f30f": 46, "5e130392e1": 96, "5e41efe5bc": 180, "5e75de78ae": 91, "5fc34880f7": 180, "60912d6bab": 96, "612c96383d": 180, "61e5fd2205": 144, "620e350d23": 180, "62c27fcaaf": 180, "637c22d967": 91, "63eaebe4a2": 96, "63fd6b311e": 180, "64099f32ab": 180, "65643c4b34": 96, "660a88feb5": 180, "664b8d0c9f": 150, "665a7947b0": 180, "66affc2e86": 180, "673b1c03c9": 96, "67780f49c2": 91, "679a24b7bd": 180, "680d35b75b": 144, "68364a69ef": 180, "683bfaf498": 180, "68e883ff28": 180, "691f63f681": 180, "69f2d3146c": 96, "6c5c018237": 91, "6caa33f43a": 96, "6d2c7cc107": 180, "6d55effbbe": 144, "6d6b09b420": 51, "6d715acc3e": 180, "6e89b7359d": 96, "6e9428d555": 150, "6e9feafa2b": 91, "6eced45fee": 180, "6ef0b3282c": 96, "6f9019f0ea": 91, "6fe0ee9b7c": 180, "6ff74d4995": 180, "712b6ec68e": 96, "71680a627f": 96, "716aad4b56": 180, "721c2cda07": 180, "72218d52ac": 96, "7286b8aac9": 91, "728ba7998d": 91, "73b2b9af5f": 96, "7452941f4f": 180, "759d8249dd": 91, "75a55907dc": 150, "75f3a2a19e": 150, "77e7e4b1a1": 144, "7898e6542c": 180, "78e639c2c4": 91, "79091168f8": 180, "7ad5af3fe6": 180, "7b1a7dec16": 150, "7b36c4c3db": 180, "7b455d07cc": 150, "7bce4cfa48": 180, "7c064444d0": 144, "7c8014406a": 91, "7cb70182e5": 96, "7d04e540f5": 91, "7d5df020bf": 96, "7dfda4322c": 96, "7e6a27cc7c": 96, "7e9e344bf4": 180, "7eb9424a53": 180, "7ec8ea61f4": 91, "7fd2806fb0": 180, "8006501830": 150, "8014aeb412": 180, "80d1d22999": 180, "812f31be15": 144, "81312af68f": 92, "82843a1676": 150, "835aea9584": 36, "8366c67e9b": 180, "8467aa6c5c": 180, "8470ee5f48": 180, "8473ae2c60": 180, "8519765a65": 150, "851f73e4fc": 96, "85621c2c81": 150, "85b045995c": 180, "860c0a7cf8": 92, "861bd4b31e": 180, "8639adb930": 180, "8683e4d414": 150, "8687e892ff": 180, "86c5907811": 180, "870c197c8b": 180, "87de455fb7": 180, "87e1975888": 96, "87f5d4903c": 96, "883ede763d": 150, "88b84fe107": 91, "88ee198ce0": 91, "89d148a39f": 96, "89f3d789c5": 180, "8a22bb6c32": 180, "8a76048654": 180, "8a99d63296": 97, "8b0697f61a": 96, "8b722babfb": 180, "8ba5691030": 180, "8bdd52a66b": 150, "8c427b6a57": 180, "8cb68f36f6": 91, "8cbf0d6194": 180, "8d1ab4a2ed": 91, "8d55a5aebb": 180, "8d8c5906bd": 180, "8eb95e2e56": 150, "8f99788aa7": 180, "8fa5b3778f": 91, "9009ab4811": 91, "90c10e44cf": 91, "90c2c5c336": 96, "9124189275": 91, "91ee8300e7": 144, "9246556dfd": 91, "9323741e3b": 150, "94a33d3d20": 180, "9584210f86": 91, "9637e3b658": 51, "966c4c022e": 180, "9781e083b5": 180, "990d358980": 180, "995c087687": 150, "99a7d42674": 144, "99f056c109": 180, "9a29032b9c": 180, "9b07fc4cf6": 180, "9b5aa49509": 96, "9b5abb8108": 91, "9be210e984": 150, "9c3c28740e": 180, "9cace717c5": 180, "9d3ff7c1c1": 91, "9d8c66d92c": 150, "9eaa2f1fcc": 91, "9f1967f60f": 96, "9fa359e1cb": 150, "9fca469ddd": 96, "9ff11b620a": 180, "9ff655b9a3": 180, "a029b21901": 180, "a0c7eedeb8": 144, "a15e70486b": 180, "a35bef8bbf": 180, "a4309379a2": 91, "a51335af59": 96, "a5690fb3bf": 180, "a5b71f76fb": 86, "a5c8b1f945": 150, "a635426233": 150, "a73cc75b81": 144, "a7863d3903": 180, "a88f1fd4e3": 144, "aa2e90aa98": 144, "aab5ecf878": 91, "aafc5edf08": 96, "ab49400ffe": 180, "acd7b890f6": 91, "ad3ee9b86b": 180, "ad5fda372c": 144, "adb2040e5f": 91, "ae30aed29d": 180, "ae57b941a0": 180, "aeb9de8f66": 41, "af658a277c": 91, "af881cd801": 150, "b016a85236": 180, "b0313efe37": 96, "b19d6e149a": 120, "b19f091836": 180, "b2304e81df": 144, "b2d23dcf3a": 150, "b3cee57f31": 36, "b41a7ebfc6": 180, "b455f801b5": 46, "b47336c07b": 96, "b499ce791f": 180, "b52d26ddf9": 96, "b5c525cb08": 180, "b5d3b9be03": 91, "b6386bc3ce": 96, "b748b0f3be": 180, "b75e9ea782": 180, "b8237af453": 180, "b8a2104720": 96, "b8d6f92a65": 96, "b8f93a4094": 180, "bb0a1708ea": 180, "bb2245ab94": 180, "bb4ae8019f": 180, "bbdc38baa0": 76, "bbfe438d63": 96, "bc2be9fdc8": 96, "bcc00265f4": 96, "bd42cc48e4": 150, "bd43315417": 180, "bd85b04982": 51, "bda3146a46": 96, "be2b40d82a": 150, "c0f856e4de": 96, "c1bfacba4a": 91, "c1dcd30fb2": 96, "c285ede7f3": 180, "c2a6163d39": 150, "c3517ebed5": 86, "c3aabac30c": 180, "c3bb62a2f7": 144, "c454f19e90": 150, "c4c410ccd7": 180, "c5b94822e3": 180, "c64e9d1f7e": 91, "c682d1748f": 150, "c6d04b1ca3": 180, "c6dda81d86": 180, "c71623ab0c": 180, "c7db88a9db": 144, "c80ecb97d6": 150, "c8dd4de705": 180, "c915c8cbba": 150, "cb25a994d8": 144, "cba3e31e88": 91, "cc43a853e2": 180, "cc6c653874": 180, "cc718c7746": 180, "cc7e050f7f": 144, "cd14ed8653": 144, "cd5e4efaad": 46, "cddf78284d": 86, "cde37afe57": 144, "ce358eaf23": 150, "ce45145721": 91, "ce7d4af66d": 180, "ce9fb4bd8e": 91, "cec4db17a0": 180, "cecdd82d3c": 180, "ceea39e735": 180, "cf3e28c92a": 180, "cf8c671dab": 150, "cfd1e8166f": 96, "cfe7d98e50": 150, "cff0bbcba8": 96, "d1219663b7": 180, "d18ea7cd51": 180, "d1ed509b94": 91, "d22c5d5908": 81, "d2c6c7d8f6": 96, "d380084b7c": 91, "d3a2586e34": 180, "d3b1039c67": 180, "d3b25a44b3": 180, "d3f1d615b1": 180, "d7203fdab6": 96, "d76e963754": 96, "d7b3892660": 66, "d8b3e257da": 150, "d8b93e6bb1": 180, "d949468ad6": 180, "da553b619f": 180, "daac20af89": 180, "db8bf2430a": 180, "dbd729449a": 180, "dc0928b157": 91, "dc9aa0b8c0": 180, "dcc0637430": 180, "dcd3e1b53e": 86, "de1854f657": 101, "deb31e46cf": 96, "debccf2743": 150, "decf924833": 150, "e08b241b91": 180, "e0daa3b339": 180, "e1a52251b7": 180, "e1fc6d5237": 91, "e228ce16fd": 96, "e36dbb2ab7": 91, "e3dcf7a45e": 180, "e411e957af": 180, "e412e6a76b": 180, "e45a003b97": 179, "e60826ddf9": 91, "e6295c843b": 96, "e62c23b62b": 150, "e6b7a8fe73": 180, "e6f0e3131c": 180, "e7a3f8884e": 180, "e7c176739c": 180, "e965cd989b": 86, "e989440f7b": 150, "e98d115b9c": 81, "ea5f8c74d6": 180, "ea8a5b5a78": 96, "eaad295e8c": 150, "eaf4947f74": 180, "eb65451f4b": 92, "eb79c39e8e": 180, "eb92c92912": 96, "ebbb88e5f5": 180, "ec9b46eb6c": 180, "eca0be379d": 180, "ed33e8efb7": 66, "eda3a7bbb1": 150, "ee3ff10184": 180, "eec8403cc8": 91, "eee2db8829": 150, "ef22b8a227": 91, "ef8737ca22": 180, "eff7c1c098": 180, "f00dc892b2": 96, "f019c9ff98": 96, "f01edcbffb": 179, "f0866da89c": 180, "f12eb5256e": 180, "f1df2ea2dc": 180, "f29119c644": 180, "f3419f3a62": 150, "f35029f76d": 180, "f39dc2240d": 180, "f3aa63fa74": 150, "f3f3c201bd": 180, "f4865471b4": 96, "f505ae958c": 91, "f7605e73cd": 150, "f7917687d6": 180, "f7d310e219": 180, "f7e25f87b2": 180, "f94cd39525": 91, "f9f9aa431c": 180, "fa666fcc95": 66, "fb10740465": 180, "fb25b14e48": 91, "fb28ec1ba3": 150, "fbdda5ec7b": 96, "fbdf2180ee": 150, "fc0db37221": 91, "fd237cf4fb": 180, "fe36582e18": 180, "fef14bb2f2": 180, "ffe59ed1c1": 150} \ No newline at end of file diff --git a/datasets/youtube-vos/train.json b/datasets/youtube-vos/train.json new file mode 100644 index 0000000000000000000000000000000000000000..ac43202f1016619010595d602908690b2be9fddc --- /dev/null +++ b/datasets/youtube-vos/train.json @@ -0,0 +1 @@ +{"003234408d": 180, "0043f083b5": 96, "0044fa5fba": 87, "005a527edd": 144, "0065b171f9": 180, "00917dcfc4": 96, "00a23ccf53": 180, "00ad5016a4": 91, "01082ae388": 150, "011ac0a06f": 180, "013099c098": 91, "0155498c85": 180, "01694ad9c8": 91, "017ac35701": 180, "01b80e8e1a": 61, "01baa5a4e1": 150, "01c3111683": 180, "01c4cb5ffe": 180, "01c76f0a82": 96, "01c783268c": 180, "01e64dd36a": 91, "01ed275c6e": 96, "01ff60d1fa": 180, "020cd28cd2": 150, "02264db755": 180, "0248626d9a": 91, "02668dbffa": 150, "0274193026": 144, "02d28375aa": 180, "02f3a5c4df": 46, "031ccc99b1": 91, "0321b18c10": 92, "0348a45bca": 180, "0355e92655": 92, "0358b938c1": 91, "0368107cf1": 96, "0379ddf557": 180, "038b2cc71d": 91, "038c15a5dd": 178, "03a06cc98a": 96, "03a63e187f": 180, "03c95b4dae": 92, "03e2b57b0e": 150, "04194e1248": 180, "04259896e2": 180, "0444918a5f": 96, "04460a7a52": 180, "04474174a4": 180, "0450095513": 150, "045f00aed2": 180, "04667fabaa": 180, "04735c5030": 91, "04990d1915": 92, "04d62d9d98": 96, "04f21da964": 180, "04fbad476e": 180, "04fe256562": 96, "0503bf89c9": 150, "0536c9eed0": 92, "054acb238f": 180, "05579ca250": 150, "056c200404": 96, "05774f3a2c": 180, "058a7592c8": 96, "05a0a513df": 96, "05a569d8aa": 91, "05aa652648": 150, "05d7715782": 96, "05e0b0f28f": 150, "05fdbbdd7a": 66, "05ffcfed85": 180, "0630391881": 150, "06840b2bbe": 91, "068f7dce6f": 180, "0693719753": 150, "06ce2b51fb": 91, "06e224798e": 180, "06ee361788": 91, "06fbb3fa2c": 90, "0700264286": 96, "070c918ca7": 180, "07129e14a4": 180, "07177017e9": 86, "07238ffc58": 180, "07353b2a89": 150, "0738493cbf": 87, "075926c651": 87, "075c701292": 180, "0762ea9a30": 96, "07652ee4af": 150, "076f206928": 96, "077d32af19": 96, "079049275c": 144, "07913cdda7": 92, "07a11a35e8": 180, "07ac33b6df": 150, "07b6e8fda8": 46, "07c62c3d11": 180, "07cc1c7d74": 180, "080196ef01": 180, "081207976e": 96, "081ae4fa44": 150, "081d8250cb": 96, "082900c5d4": 96, "0860df21e2": 180, "0866d4c5e3": 91, "0891ac2eb6": 81, "08931bc458": 180, "08aa2705d5": 180, "08c8450db7": 96, "08d50b926c": 180, "08e1e4de15": 180, "08e48c1a48": 92, "08f561c65e": 180, "08feb87790": 96, "09049f6fe3": 150, "092e4ff450": 180, "09338adea8": 180, "093c335ccc": 144, "0970d28339": 180, "0974a213dc": 96, "097b471ed8": 96, "0990941758": 180, "09a348f4fa": 150, "09a6841288": 96, "09c5bad17b": 96, "09c9ce80c7": 180, "09ff54fef4": 150, "0a23765d15": 91, "0a275e7f12": 96, "0a2f2bd294": 96, "0a7a2514aa": 96, "0a7b27fde9": 180, "0a8c467cc3": 180, "0ac8c560ae": 96, "0b1627e896": 96, "0b285c47f6": 144, "0b34ec1d55": 180, "0b5b5e8e5a": 96, "0b68535614": 180, "0b6f9105fc": 180, "0b7dbfa3cb": 91, "0b9cea51ca": 180, "0b9d012be8": 180, "0bcfc4177d": 96, "0bd37b23c1": 96, "0bd864064c": 158, "0c11c6bf7b": 180, "0c26bc77ac": 180, "0c3a04798c": 96, "0c44a9d545": 180, "0c817cc390": 180, "0ca839ee9a": 180, "0cd7ac0ac0": 150, "0ce06e0121": 180, "0cfe974a89": 180, "0d2fcc0dcd": 96, "0d3aad05d2": 144, "0d40b015f4": 180, "0d97fba242": 91, "0d9cc80d7e": 51, "0dab85b6d3": 144, "0db5c427a5": 96, "0dbaf284f1": 97, "0de4923598": 97, "0df28a9101": 150, "0e04f636c4": 150, "0e05f0e232": 180, "0e0930474b": 91, "0e27472bea": 180, "0e30020549": 144, "0e621feb6c": 180, "0e803c7d73": 91, "0e9ebe4e3c": 92, "0e9f2785ec": 96, "0ea68d418b": 96, "0eb403a222": 96, "0ee92053d6": 97, "0eefca067f": 150, "0f17fa6fcb": 180, "0f1ac8e9a3": 180, "0f202e9852": 91, "0f2ab8b1ff": 180, "0f51a78756": 150, "0f5fbe16b0": 180, "0f6072077b": 91, "0f6b69b2f4": 180, "0f6c2163de": 144, "0f74ec5599": 180, "0f9683715b": 96, "0fa7b59356": 180, "0fb173695b": 96, "0fc958cde2": 150, "0fe7b1a621": 180, "0ffcdb491c": 96, "101caff7d4": 96, "1022fe8417": 96, "1032e80b37": 96, "103f501680": 180, "104e64565f": 96, "104f1ab997": 91, "106242403f": 96, "10b31f5431": 180, "10eced835e": 91, "110d26fa3a": 150, "1122c1d16a": 180, "1145b49a5f": 180, "11485838c2": 96, "114e7676ec": 180, "1157472b95": 180, "115ee1072c": 91, "1171141012": 150, "117757b4b8": 180, "1178932d2f": 180, "117cc76bda": 180, "1180cbf814": 180, "1187bbd0e3": 96, "1197e44b26": 180, "119cf20728": 180, "119dd54871": 180, "11a0c3b724": 91, "11a6ba8c94": 180, "11c722a456": 180, "11cbcb0b4d": 96, "11ccf5e99d": 96, "11ce6f452e": 91, "11e53de6f2": 46, "11feabe596": 150, "120cb9514d": 180, "12156b25b3": 180, "122896672d": 180, "1232b2f1d4": 36, "1233ac8596": 97, "1239c87234": 180, "1250423f7c": 96, "1257a1bc67": 180, "125d1b19dd": 180, "126d203967": 180, "1295e19071": 96, "12ad198c54": 144, "12bddb2bcb": 150, "12ec9b93ee": 180, "12eebedc35": 91, "132852e094": 180, "1329409f2a": 180, "13325cfa14": 96, "1336440745": 180, "134d06dbf9": 97, "135625b53d": 144, "13870016f9": 92, "13960b3c84": 96, "13adaad9d9": 180, "13ae097e20": 180, "13e3070469": 96, "13f6a8c20d": 144, "1416925cf2": 92, "142d2621f5": 91, "145d5d7c03": 180, "145fdc3ac5": 180, "1471274fa7": 76, "14a6b5a139": 180, "14c21cea0d": 180, "14dae0dc93": 96, "14f9bd22b5": 180, "14fd28ae99": 180, "15097d5d4e": 144, "150ea711f2": 180, "1514e3563f": 180, "152aaa3a9e": 180, "152b7d3bd7": 150, "15617297cc": 180, "15abbe0c52": 150, "15d1fb3de5": 180, "15f67b0fab": 180, "161eb59aad": 96, "16288ea47f": 180, "164410ce62": 91, "165c3c8cd4": 96, "165c42b41b": 91, "165ec9e22b": 144, "1669502269": 91, "16763cccbb": 150, "16adde065e": 96, "16af445362": 96, "16afd538ad": 150, "16c3fa4d5d": 96, "16d1d65c27": 180, "16e8599e94": 180, "16fe9fb444": 91, "1705796b02": 96, "1724db7671": 144, "17418e81ea": 180, "175169edbb": 144, "17622326fd": 180, "17656bae77": 91, "17b0d94172": 61, "17c220e4f6": 180, "17c7bcd146": 96, "17cb4afe89": 180, "17cd79a434": 180, "17d18604c3": 96, "17d8ca1a37": 150, "17e33f4330": 180, "17f7a6d805": 150, "180abc8378": 180, "183ba3d652": 96, "185bf64702": 96, "18913cc690": 91, "1892651815": 180, "189ac8208a": 91, "189b44e92c": 97, "18ac264b76": 150, "18b245ab49": 91, "18b5cebc34": 150, "18bad52083": 180, "18bb5144d5": 180, "18c6f205c5": 96, "1903f9ea15": 96, "1917b209f2": 91, "191e74c01d": 150, "19367bb94e": 180, "193ffaa217": 91, "19696b67d3": 96, "197f3ab6f3": 180, "1981e763cc": 180, "198afe39ae": 144, "19a6e62b9b": 150, "19b60d5335": 180, "19c00c11f9": 150, "19e061eb88": 91, "19e8bc6178": 86, "19ee80dac6": 180, "1a25a9170a": 180, "1a359a6c1a": 150, "1a3e87c566": 150, "1a5fe06b00": 91, "1a6c0fbd1e": 144, "1a6f3b5a4b": 96, "1a8afbad92": 92, "1a8bdc5842": 150, "1a95752aca": 150, "1a9c131cb7": 180, "1aa3da3ee3": 150, "1ab27ec7ea": 56, "1abf16d21d": 150, "1acd0f993b": 180, "1ad202e499": 180, "1af8d2395d": 180, "1afd39a1fa": 91, "1b2d31306f": 180, "1b3fa67f0e": 92, "1b43fa74b4": 150, "1b73ea9fc2": 92, "1b7e8bb255": 96, "1b8680f8cd": 180, "1b883843c0": 91, "1b8898785b": 180, "1b88ba1aa4": 180, "1b96a498e5": 150, "1bbc4c274f": 96, "1bd87fe9ab": 66, "1c4090c75b": 180, "1c41934f84": 96, "1c72b04b56": 180, "1c87955a3a": 150, "1c9f9eb792": 180, "1ca240fede": 96, "1ca5673803": 180, "1cada35274": 180, "1cb44b920d": 180, "1cd10e62be": 150, "1d3087d5e5": 180, "1d3685150a": 92, "1d6ff083aa": 96, "1d746352a6": 66, "1da256d146": 91, "1da4e956b1": 180, "1daf812218": 150, "1dba687bce": 180, "1dce57d05d": 86, "1de4a9e537": 97, "1dec5446c8": 180, "1dfbe6f586": 150, "1e1a18c45a": 180, "1e1e42529d": 76, "1e4be70796": 96, "1eb60959c8": 180, "1ec8b2566b": 180, "1ecdc2941c": 180, "1ee0ac70ff": 87, "1ef8e17def": 91, "1f1a2a9fc0": 86, "1f1beb8daa": 150, "1f2609ee13": 180, "1f3876f8d0": 144, "1f4ec0563d": 150, "1f64955634": 96, "1f7d31b5b2": 96, "1f8014b7fd": 96, "1f9c7d10f1": 180, "1fa350df76": 96, "1fc9538993": 180, "1fe2f0ec59": 150, "2000c02f9d": 180, "20142b2f05": 180, "201a8d75e5": 150, "2023b3ee4f": 180, "202b767bbc": 92, "203594a418": 180, "2038987336": 150, "2039c3aecb": 96, "204a90d81f": 150, "207bc6cf01": 144, "208833d1d1": 180, "20c6d8b362": 46, "20e3e52e0a": 96, "2117fa0c14": 180, "211bc5d102": 150, "2120d9c3c3": 150, "2125235a49": 180, "21386f5978": 92, "2142af8795": 150, "215dfc0f73": 96, "217bae91e5": 180, "217c0d44e4": 150, "219057c87b": 150, "21d0edbf81": 96, "21df87ad76": 96, "21f1d089f5": 96, "21f4019116": 180, "222597030f": 91, "222904eb5b": 92, "223a0e0657": 180, "223bd973ab": 92, "22472f7395": 150, "224e7c833e": 96, "225aba51d9": 86, "2261d421ea": 180, "2263a8782b": 180, "2268cb1ffd": 150, "2268e93b0a": 61, "2293c99f3f": 180, "22a1141970": 91, "22b13084b2": 180, "22d9f5ab0c": 180, "22f02efe3a": 144, "232c09b75b": 150, "2350d71b4b": 180, "2376440551": 180, "2383d8aafd": 144, "238b84e67f": 96, "238d4b86f6": 91, "238d947c6b": 46, "23993ce90d": 180, "23b0c8a9ab": 150, "23b3beafcc": 156, "23d80299fe": 92, "23f404a9fc": 96, "240118e58a": 178, "2431dec2fd": 180, "24440e0ac7": 97, "2457274dbc": 180, "2465bf515d": 91, "246b142c4d": 180, "247d729e36": 96, "2481ceafeb": 150, "24866b4e6a": 150, "2489d78320": 180, "24ab0b83e8": 180, "24b0868d92": 180, "24b5207cd9": 96, "24ddf05c03": 92, "250116161c": 71, "256ad2e3fc": 180, "256bd83d5e": 180, "256dcc8ab8": 180, "2589956baa": 150, "258b3b33c6": 91, "25ad437e29": 96, "25ae395636": 180, "25c750c6db": 150, "25d2c3fe5d": 180, "25dc80db7c": 96, "25f97e926f": 180, "26011bc28b": 150, "260846ffbe": 180, "260dd9ad33": 66, "267964ee57": 92, "2680861931": 96, "268ac7d3fc": 180, "26b895d91e": 71, "26bc786d4f": 91, "26ddd2ef12": 180, "26de3d18ca": 150, "26f7784762": 180, "2703e52a6a": 180, "270ed80c12": 180, "2719b742ab": 180, "272f4163d0": 180, "27303333e1": 96, "27659fa7d6": 180, "279214115d": 180, "27a5f92a9c": 97, "27cf2af1f3": 150, "27f0d5f8a2": 86, "28075f33c1": 180, "281629cb41": 96, "282b0d51f5": 96, "282fcab00b": 96, "28449fa0dc": 180, "28475208ca": 96, "285580b7c4": 180, "285b69e223": 150, "288c117201": 150, "28a8eb9623": 180, "28bf9c3cf3": 180, "28c6b8f86a": 180, "28c972dacd": 144, "28d9fa6016": 96, "28e392de91": 144, "28f4a45190": 150, "298c844fc9": 91, "29a0356a2b": 180, "29d779f9e3": 76, "29dde5f12b": 86, "29de7b6579": 150, "29e630bdd0": 144, "29f2332d30": 144, "2a18873352": 92, "2a3824ff31": 91, "2a559dd27f": 96, "2a5c09acbd": 76, "2a63eb1524": 96, "2a6a30a4ea": 150, "2a6d9099d1": 180, "2a821394e3": 81, "2a8c5b1342": 96, "2abc8d66d2": 96, "2ac9ef904a": 46, "2b08f37364": 150, "2b351bfd7d": 180, "2b659a49d7": 66, "2b69ee5c26": 96, "2b6c30bbbd": 180, "2b88561cf2": 144, "2b8b14954e": 180, "2ba621c750": 150, "2bab50f9a7": 180, "2bb00c2434": 91, "2bbde474ef": 92, "2bdd82fb86": 150, "2be06fb855": 96, "2bf545c2f5": 180, "2bffe4cf9a": 96, "2c04b887b7": 144, "2c05209105": 180, "2c0ad8cf39": 180, "2c11fedca8": 56, "2c1a94ebfb": 91, "2c1e8c8e2f": 180, "2c29fabcf1": 96, "2c2c076c01": 180, "2c3ea7ee7d": 92, "2c41fa0648": 87, "2c44bb6d1c": 96, "2c54cfbb78": 180, "2c5537eddf": 180, "2c6e63b7de": 150, "2cb10c6a7e": 180, "2cbcd5ccd1": 180, "2cc5d9c5f6": 180, "2cd01cf915": 180, "2cdbf5f0a7": 91, "2ce660f123": 96, "2cf114677e": 150, "2d01eef98e": 180, "2d03593bdc": 96, "2d183ac8c4": 180, "2d33ad3935": 96, "2d3991d83e": 150, "2d4333577b": 180, "2d4d015c64": 96, "2d8f5e5025": 144, "2d900bdb8e": 180, "2d9a1a1d49": 46, "2db0576a5c": 180, "2dc0838721": 180, "2dcc417f82": 150, "2df005b843": 180, "2df356de14": 180, "2e00393d96": 61, "2e03b8127a": 180, "2e0f886168": 96, "2e2bf37e6d": 180, "2e42410932": 87, "2ea78f46e4": 180, "2ebb017a26": 180, "2ee2edba2a": 96, "2efb07554a": 180, "2f17e4fc1e": 96, "2f2c65c2f3": 144, "2f2d9b33be": 150, "2f309c206b": 180, "2f53822e88": 144, "2f53998171": 96, "2f5b0c89b1": 180, "2f680909e6": 180, "2f710f66bd": 180, "2f724132b9": 91, "2f7e3517ae": 91, "2f96f5fc6f": 180, "2f97d9fecb": 96, "2fbfa431ec": 96, "2fc9520b53": 180, "2fcd9f4c62": 180, "2feb30f208": 87, "2ff7f5744f": 150, "30085a2cc6": 96, "30176e3615": 56, "301f72ee11": 92, "3026bb2f61": 180, "30318465dc": 150, "3054ca937d": 180, "306121e726": 92, "3064ad91e8": 180, "307444a47f": 180, "307bbb7409": 91, "30a20194ab": 144, "30c35c64a4": 150, "30dbdb2cd6": 91, "30fc77d72f": 150, "310021b58b": 96, "3113140ee8": 144, "3150b2ee57": 180, "31539918c4": 180, "318dfe2ce2": 144, "3193da4835": 91, "319f725ad9": 180, "31bbd0d793": 91, "322505c47f": 180, "322b237865": 92, "322da43910": 97, "3245e049fb": 66, "324c4c38f6": 180, "324e35111a": 150, "3252398f09": 150, "327dc4cabf": 180, "328d918c7d": 180, "3290c0de97": 96, "3299ae3116": 180, "32a7cd687b": 150, "33098cedb4": 92, "3332334ac4": 180, "334cb835ac": 180, "3355e056eb": 180, "33639a2847": 180, "3373891cdc": 180, "337975816b": 180, "33e29d7e91": 96, "34046fe4f2": 180, "3424f58959": 180, "34370a710f": 92, "343bc6a65a": 179, "3450382ef7": 144, "3454303a08": 180, "346aacf439": 180, "346e92ff37": 180, "34a5ece7dd": 144, "34b109755a": 180, "34d1b37101": 96, "34dd2c70a7": 180, "34efa703df": 180, "34fbee00a6": 150, "3504df2fda": 96, "35195a56a1": 150, "351c822748": 180, "351cfd6bc5": 180, "3543d8334c": 180, "35573455c7": 96, "35637a827f": 96, "357a710863": 92, "358bf16f9e": 96, "35ab34cc34": 180, "35c6235b8d": 91, "35d01a438a": 180, "3605019d3b": 96, "3609bc3f88": 92, "360e25da17": 97, "36299c687c": 96, "362c5bc56e": 180, "3649228783": 150, "365b0501ea": 92, "365f459863": 180, "369893f3ad": 180, "369c9977e1": 180, "369dde050a": 96, "36c7dac02f": 180, "36d5b1493b": 180, "36f5cc68fd": 91, "3735480d18": 180, "374b479880": 97, "375a49d38f": 180, "375a5c0e09": 180, "376bda9651": 144, "377db65f60": 144, "37c19d1087": 46, "37d4ae24fc": 96, "37ddce7f8b": 180, "37e10d33af": 180, "37e45c6247": 96, "37fa0001e8": 180, "3802d458c0": 150, "382caa3cb4": 91, "383bb93111": 91, "388843df90": 180, "38924f4a7f": 92, "38b00f93d7": 92, "38c197c10e": 96, "38c9c3d801": 180, "38eb2bf67f": 92, "38fe9b3ed1": 180, "390352cced": 180, "390c51b987": 96, "390ca6f1d6": 144, "392bc0f8a1": 96, "392ecb43bd": 92, "3935291688": 150, "3935e63b41": 180, "394454fa9c": 180, "394638fc8b": 96, "39545e20b7": 180, "397abeae8f": 180, "3988074b88": 91, "398f5d5f19": 174, "39bc49a28c": 180, "39befd99fb": 144, "39c3c7bf55": 180, "39d584b09f": 91, "39f6f6ffb1": 180, "3a079fb484": 180, "3a0d3a81b7": 150, "3a1d55d22b": 82, "3a20a7583e": 96, "3a2c1f66e5": 150, "3a33f4d225": 180, "3a3bf84b13": 144, "3a4565e5ec": 144, "3a4e32ed5e": 180, "3a7ad86ce0": 180, "3a7bdde9b8": 180, "3a98867cbe": 91, "3aa3f1c9e8": 150, "3aa7fce8b6": 91, "3aa876887d": 96, "3ab807ded6": 96, "3ab9b1a85a": 96, "3adac8d7da": 180, "3ae1a4016f": 96, "3ae2deaec2": 180, "3ae81609d6": 144, "3af847e62f": 92, "3b23792b84": 144, "3b3b0af2ee": 150, "3b512dad74": 144, "3b6c7988f6": 91, "3b6e983b5b": 180, "3b74a0fc20": 180, "3b7a50b80d": 180, "3b96d3492f": 180, "3b9ad0c5a9": 150, "3b9ba0894a": 180, "3bb4e10ed7": 144, "3bd9a9b515": 150, "3beef45388": 96, "3c019c0a24": 96, "3c090704aa": 96, "3c2784fc0d": 144, "3c47ab95f8": 150, "3c4db32d74": 91, "3c5ff93faf": 180, "3c700f073e": 180, "3c713cbf2f": 91, "3c8320669c": 180, "3c90d225ee": 180, "3cadbcc404": 96, "3cb9be84a5": 150, "3cc37fd487": 91, "3cc6f90cb2": 92, "3cd5e035ef": 180, "3cdf03531b": 178, "3cdf828f59": 180, "3d254b0bca": 180, "3d5aeac5ba": 180, "3d690473e1": 180, "3d69fed2fb": 96, "3d8997aeb6": 96, "3db0d6b07e": 96, "3db1ddb8cf": 180, "3db907ac77": 180, "3dcbc0635b": 150, "3dd48ed55f": 144, "3de4ac4ec4": 92, "3decd63d88": 180, "3e04a6be11": 180, "3e108fb65a": 96, "3e1448b01c": 150, "3e16c19634": 180, "3e2845307e": 61, "3e38336da5": 96, "3e3a819865": 180, "3e3e4be915": 96, "3e680622d7": 91, "3e7d2aeb07": 96, "3e7d8f363d": 180, "3e91f10205": 26, "3ea4c49bbe": 144, "3eb39d11ab": 180, "3ec273c8d5": 96, "3ed3f91271": 76, "3ee062a2fd": 180, "3eede9782c": 180, "3ef2fa99cb": 180, "3efc6e9892": 92, "3f0b0dfddd": 96, "3f0c860359": 91, "3f18728586": 180, "3f3b15f083": 96, "3f45a470ad": 46, "3f4f3bc803": 150, "3fd96c5267": 91, "3fea675fab": 91, "3fee8cbc9f": 96, "3fff16d112": 180, "401888b36c": 144, "4019231330": 150, "402316532d": 180, "402680df52": 180, "404d02e0c0": 150, "40709263a8": 81, "4083cfbe15": 150, "40a96c5cb1": 96, "40b8e50f82": 91, "40f4026bf5": 144, "4100b57a3a": 150, "41059fdd0b": 180, "41124e36de": 144, "4122aba5f9": 180, "413bab0f0d": 96, "4164faee0b": 180, "418035eec9": 180, "4182d51532": 96, "418bb97e10": 144, "41a34c20e7": 96, "41dab05200": 180, "41ff6d5e2a": 77, "420caf0859": 56, "42264230ba": 96, "425a0c96e0": 91, "42da96b87c": 180, "42eb5a5b0f": 180, "42f17cd14d": 91, "42f5c61c49": 180, "42ffdcdee9": 180, "432f9884f9": 91, "43326d9940": 150, "4350f3ab60": 144, "4399ffade3": 96, "43a6c21f37": 150, "43b5555faa": 180, "43d63b752a": 180, "4416bdd6ac": 92, "4444753edd": 76, "444aa274e7": 150, "444d4e0596": 150, "446b8b5f7a": 96, "4478f694bb": 91, "44b1da0d87": 92, "44b4dad8c9": 96, "44b5ece1b9": 180, "44d239b24e": 150, "44eaf8f51e": 180, "44f4f57099": 96, "44f7422af2": 180, "450787ac97": 180, "4523656564": 96, "4536c882e5": 180, "453b65daa4": 180, "454f227427": 91, "45636d806a": 180, "456fb9362e": 91, "457e717a14": 150, "45a89f35e1": 180, "45bf0e947d": 150, "45c36a9eab": 150, "45d9fc1357": 174, "45f8128b97": 180, "4607f6c03c": 91, "46146dfd39": 92, "4620e66b1e": 150, "4625f3f2d3": 96, "462b22f263": 96, "4634736113": 180, "463c0f4fdd": 180, "46565a75f8": 96, "46630b55ae": 56, "466839cb37": 91, "466ba4ae0c": 180, "4680236c9d": 180, "46bf4e8709": 91, "46e18e42f1": 150, "46f5093c59": 180, "47269e0499": 92, "472da1c484": 144, "47354fab09": 180, "4743bb84a7": 92, "474a796272": 180, "4783d2ab87": 96, "479cad5da3": 180, "479f5d7ef6": 96, "47a05fbd1d": 96, "4804ee2767": 97, "4810c3fbca": 180, "482fb439c2": 150, "48375af288": 96, "484ab44de4": 96, "485f3944cd": 96, "4867b84887": 150, "486a8ac57e": 180, "486e69c5bd": 180, "48812cf33e": 150, "4894b3b9ea": 180, "48bd66517d": 180, "48d83b48a4": 91, "49058178b8": 46, "4918d10ff0": 91, "4932911f80": 150, "49405b7900": 180, "49972c2d14": 150, "499bf07002": 96, "49b16e9377": 180, "49c104258e": 144, "49c879f82d": 96, "49e7326789": 180, "49ec3e406a": 91, "49fbf0c98a": 96, "4a0255c865": 180, "4a088fe99a": 96, "4a341402d0": 180, "4a3471bdf5": 96, "4a4b50571c": 144, "4a50f3d2e9": 96, "4a6e3faaa1": 180, "4a7191f08a": 150, "4a86fcfc30": 180, "4a885fa3ef": 144, "4a8af115de": 21, "4aa2e0f865": 180, "4aa9d6527f": 180, "4abb74bb52": 96, "4ae13de1cd": 91, "4af8cb323f": 97, "4b02c272b3": 180, "4b19c529fb": 96, "4b2974eff4": 180, "4b3154c159": 95, "4b54d2587f": 180, "4b556740ff": 144, "4b67aa9ef6": 178, "4b97cc7b8d": 96, "4baa1ed4aa": 91, "4bc8c676bb": 96, "4beaea4dbe": 180, "4bf5763d24": 96, "4bffa92b67": 138, "4c25dfa8ec": 96, "4c397b6fd4": 180, "4c51e75d66": 150, "4c7710908f": 180, "4c9b5017be": 180, "4ca2ffc361": 92, "4cad2e93bc": 150, "4cd427b535": 180, "4cd9a4b1ef": 180, "4cdfe3c2b2": 180, "4cef87b649": 96, "4cf208e9b3": 180, "4cf5bc3e60": 92, "4cfdd73249": 91, "4cff5c9e42": 180, "4d26d41091": 96, "4d5c23c554": 180, "4d67c59727": 150, "4d983cad9f": 180, "4da0d00b55": 144, "4daa179861": 91, "4dadd57153": 92, "4db117e6c5": 91, "4de4ce4dea": 180, "4dfaee19e5": 180, "4dfdd7fab0": 180, "4e3f346aa5": 92, "4e49c2a9c7": 56, "4e4e06a749": 180, "4e70279712": 96, "4e72856cc7": 91, "4e752f8075": 180, "4e7a28907f": 66, "4e824b9247": 180, "4e82b1df57": 180, "4e87a639bc": 180, "4ea77bfd15": 150, "4eb6fc23a2": 180, "4ec9da329e": 96, "4efb9a0720": 180, "4f062fbc63": 96, "4f35be0e0b": 96, "4f37e86797": 91, "4f414dd6e7": 180, "4f424abded": 180, "4f470cc3ae": 144, "4f601d255a": 150, "4f7386a1ab": 144, "4f824d3dcd": 91, "4f827b0751": 144, "4f8db33a13": 180, "4fa160f8a3": 180, "4fa9c30a45": 180, "4facd8f0e8": 96, "4fca07ad01": 91, "4fded94004": 180, "4fdfef4dea": 91, "4feb3ac01f": 92, "4fffec8479": 96, "500c835a86": 180, "50168342bf": 180, "50243cffdc": 180, "5031d5a036": 180, "504dd9c0fd": 96, "50568fbcfb": 180, "5069c7c5b3": 180, "508189ac91": 180, "50b6b3d4b7": 91, "50c6f4fe3e": 86, "50cce40173": 180, "50efbe152f": 180, "50f290b95d": 91, "5104aa1fea": 96, "5110dc72c0": 180, "511e8ecd7f": 150, "513aada14e": 92, "5158d6e985": 180, "5161e1fa57": 180, "51794ddd58": 96, "517d276725": 91, "51a597ee04": 51, "51b37b6d97": 96, "51b5dc30a0": 96, "51e85b347b": 180, "51eea1fdac": 150, "51eef778af": 91, "51f384721c": 76, "521cfadcb4": 180, "52355da42f": 96, "5247d4b160": 180, "524b470fd0": 180, "524cee1534": 96, "5252195e8a": 91, "5255c9ca97": 144, "525928f46f": 96, "526df007a7": 180, "529b12de78": 91, "52c7a3d653": 150, "52c8ec0373": 91, "52d225ed52": 96, "52ee406d9e": 180, "52ff1ccd4a": 96, "53143511e8": 180, "5316d11eb7": 96, "53253f2362": 180, "534a560609": 91, "5352c4a70e": 180, "536096501f": 92, "536b17bcea": 180, "5380eaabff": 144, "5390a43a54": 180, "53af427bb2": 91, "53bf5964ce": 180, "53c30110b5": 96, "53cad8e44a": 150, "53d9c45013": 91, "53e274f1b5": 150, "53e32d21ea": 96, "540850e1c7": 96, "540cb31cfe": 180, "541c4da30f": 91, "541d7935d7": 180, "545468262b": 180, "5458647306": 144, "54657855cd": 96, "547b3fb23b": 180, "5497dc3712": 150, "549c56f1d4": 96, "54a4260bb1": 150, "54b98b8d5e": 180, "54e1054b0f": 91, "54e8867b83": 180, "54ebe34f6e": 180, "5519b4ad13": 86, "551acbffd5": 150, "55341f42da": 180, "5566ab97e1": 91, "556c79bbf2": 144, "5589637cc4": 180, "558aa072f0": 180, "559824b6f6": 91, "55c1764e90": 180, "55eda6c77e": 180, "562d173565": 150, "5665c024cb": 96, "566cef4959": 91, "5675d78833": 144, "5678a91bd8": 180, "567a2b4bd0": 180, "569c282890": 86, "56cc449917": 150, "56e71f3e07": 150, "56f09b9d92": 180, "56fc0e8cf9": 144, "571ca79c71": 91, "57243657cf": 144, "57246af7d1": 91, "57427393e9": 96, "574b682c19": 180, "578f211b86": 180, "5790ac295d": 91, "579393912d": 180, "57a344ab1a": 180, "57bd3bcda4": 180, "57bfb7fa4c": 150, "57c010175e": 180, "57c457cc75": 180, "57c7fc2183": 150, "57d5289a01": 61, "58045fde85": 96, "58163c37cd": 150, "582d463e5c": 180, "5851739c15": 180, "585dd0f208": 66, "587250f3c3": 180, "589e4cc1de": 180, "589f65f5d5": 180, "58a07c17d5": 180, "58adc6d8b6": 76, "58b9bcf656": 96, "58c374917e": 96, "58fc75fd42": 87, "5914c30f05": 96, "59323787d5": 150, "5937b08d69": 96, "594065ddd7": 96, "595a0ceea6": 91, "59623ec40b": 91, "597ff7ef78": 150, "598935ef05": 46, "598c2ad3b2": 180, "59a6459751": 180, "59b175e138": 96, "59bf0a149f": 180, "59d53d1649": 180, "59e3e6fae7": 180, "59fe33e560": 180, "5a13a73fe5": 96, "5a25c22770": 150, "5a4a785006": 96, "5a50640995": 180, "5a75f7a1cf": 96, "5a841e59ad": 180, "5a91c5ab6d": 150, "5ab49d9de0": 96, "5aba1057fe": 180, "5abe46ba6d": 91, "5ac7c88d0c": 180, "5aeb95cc7d": 92, "5af15e4fc3": 91, "5afe381ae4": 96, "5b07b4229d": 51, "5b1001cc4f": 180, "5b1df237d2": 180, "5b263013bf": 91, "5b27d19f0b": 180, "5b48ae16c5": 96, "5b5babc719": 180, "5baaebdf00": 180, "5bab55cdbe": 180, "5bafef6e79": 96, "5bc77844da": 180, "5bd1f84545": 180, "5bddc3ba25": 180, "5bdf7c20d2": 180, "5bf23bc9d3": 180, "5c01f6171a": 144, "5c021681b7": 96, "5c185cff1d": 180, "5c42aba280": 180, "5c44bf8ab6": 180, "5c4c574894": 144, "5c52fa4662": 76, "5c6ea7dac3": 96, "5c74315dc2": 180, "5c7668855e": 92, "5c83e96778": 180, "5ca36173e4": 96, "5cac477371": 97, "5cb0cb1b2f": 96, "5cb0cfb98f": 144, "5cb49a19cf": 180, "5cbf7dc388": 180, "5d0e07d126": 96, "5d1e24b6e3": 81, "5d663000ff": 150, "5da6b2dc5d": 180, "5de9b90f24": 61, "5e08de0ed7": 180, "5e1011df9a": 87, "5e1ce354fd": 150, "5e35512dd7": 180, "5e418b25f9": 96, "5e4849935a": 144, "5e4ee19663": 96, "5e886ef78f": 96, "5e8d00b974": 180, "5e8d59dc31": 180, "5ed838bd5c": 96, "5edda6ee5a": 180, "5ede4d2f7a": 144, "5ede9767da": 144, "5ee23ca60e": 87, "5eec4d9fe5": 96, "5eecf07824": 180, "5eef7ed4f4": 91, "5ef5860ac6": 144, "5ef6573a99": 96, "5f1193e72b": 91, "5f29ced797": 96, "5f32cf521e": 150, "5f51876986": 96, "5f6ebe94a9": 86, "5f6f14977c": 91, "5f808d0d2d": 91, "5fb8aded6a": 180, "5fba90767d": 96, "5fd1c7a3df": 92, "5fd3da9f68": 91, "5fee2570ae": 180, "5ff66140d6": 180, "5ff8b85b53": 180, "600803c0f6": 180, "600be7f53e": 96, "6024888af8": 180, "603189a03c": 96, "6057307f6e": 180, "6061ddbb65": 96, "606c86c455": 180, "60c61cc2e5": 180, "60e51ff1ae": 150, "610e38b751": 150, "61344be2f6": 180, "6135e27185": 96, "614afe7975": 150, "614e571886": 180, "614e7078db": 96, "619812a1a7": 96, "61b481a78b": 96, "61c7172650": 180, "61cf7e40d2": 96, "61d08ef5a1": 46, "61da008958": 96, "61ed178ecb": 61, "61f5d1282c": 92, "61fd977e49": 144, "621584cffe": 180, "625817a927": 180, "625892cf0b": 96, "625b89d28a": 91, "629995af95": 150, "62a0840bb5": 180, "62ad6e121c": 87, "62d6ece152": 91, "62ede7b2da": 91, "62f025e1bc": 180, "6316faaebc": 97, "63281534dc": 150, "634058dda0": 144, "6353f09384": 180, "6363c87314": 180, "636e4872e0": 180, "637681cd6b": 180, "6376d49f31": 180, "6377809ec2": 180, "63936d7de5": 96, "639bddef11": 150, "63d37e9fd3": 180, "63d90c2bae": 96, "63e544a5d6": 180, "63ebbcf874": 96, "63fff40b31": 180, "6406c72e4d": 61, "64148128be": 96, "6419386729": 150, "643092bc41": 96, "644081b88d": 144, "64453cf61d": 180, "644bad9729": 96, "6454f548fd": 180, "645913b63a": 180, "64750b825f": 180, "64a43876b7": 96, "64dd6c83e3": 92, "64e05bf46e": 96, "64f55f1478": 150, "650b0165e4": 180, "651066ed39": 180, "652b67d960": 180, "653821d680": 180, "6538d00d73": 180, "65866dce22": 150, "6589565c8c": 150, "659832db64": 180, "65ab7e1d98": 180, "65b7dda462": 180, "65bd5eb4f5": 180, "65dcf115ab": 91, "65e9825801": 180, "65f9afe51c": 91, "65ff12bcb5": 180, "666b660284": 180, "6671643f31": 180, "668364b372": 96, "66852243cb": 96, "6693a52081": 180, "669b572898": 180, "66e98e78f5": 91, "670f12e88f": 180, "674c12c92d": 91, "675c27208a": 180, "675ed3e1ca": 144, "67741db50a": 96, "678a2357eb": 70, "67b0f4d562": 180, "67cfbff9b1": 180, "67e717d6bd": 91, "67ea169a3b": 92, "67ea809e0e": 180, "681249baa3": 180, "683de643d9": 180, "6846ac20df": 96, "6848e012ef": 96, "684bcd8812": 96, "684dc1c40c": 96, "685a1fa9cf": 91, "686dafaac9": 144, "68807d8601": 96, "6893778c77": 96, "6899d2dabe": 91, "68a2fad4ab": 180, "68cb45fda3": 180, "68cc4a1970": 96, "68dcb40675": 180, "68ea4a8c3d": 180, "68f6e7fbf0": 96, "68fa8300b4": 180, "69023db81f": 96, "6908ccf557": 91, "691a111e7c": 180, "6927723ba5": 180, "692ca0e1a2": 97, "692eb57b63": 180, "69340faa52": 96, "693cbf0c9d": 180, "6942f684ad": 96, "6944fc833b": 180, "69491c0ebf": 91, "695b61a2b0": 96, "6979b4d83f": 180, "697d4fdb02": 144, "69910460a4": 180, "6997636670": 180, "69a436750b": 96, "69aebf7669": 180, "69b8c17047": 180, "69c67f109f": 180, "69e0e7b868": 180, "69ea9c09d1": 180, "69f0af42a6": 97, "6a078cdcc7": 144, "6a37a91708": 71, "6a42176f2e": 180, "6a48e4aea8": 96, "6a5977be3a": 180, "6a5de0535f": 180, "6a80d2e2e5": 96, "6a96c8815d": 180, "6a986084e2": 96, "6aa8e50445": 92, "6ab9dce449": 150, "6abf0ba6b2": 180, "6acc6049d9": 96, "6adb31756c": 180, "6ade215eb0": 96, "6afb7d50e4": 144, "6afd692f1a": 180, "6b0b1044fe": 91, "6b17c67633": 180, "6b1b6ef28b": 92, "6b1e04d00d": 180, "6b2261888d": 96, "6b25d6528a": 144, "6b3a24395c": 150, "6b685eb75b": 96, "6b79be238c": 92, "6b928b7ba6": 96, "6b9c43c25a": 180, "6ba99cc41f": 91, "6bdab62bcd": 86, "6bf2e853b1": 180, "6bf584200f": 180, "6bf95df2b9": 150, "6c0949c51c": 180, "6c11a5f11f": 96, "6c23d89189": 61, "6c4387daf5": 96, "6c4ce479a4": 86, "6c5123e4bc": 96, "6c54265f16": 92, "6c56848429": 96, "6c623fac5f": 36, "6c81b014e9": 96, "6c99ea7c31": 92, "6c9d29d509": 91, "6c9e3b7d1a": 91, "6ca006e283": 96, "6caeb928d6": 180, "6cb2ee722a": 180, "6cbfd32c5e": 180, "6cc791250b": 150, "6cccc985e0": 96, "6d12e30c48": 180, "6d4bf200ad": 180, "6d6d2b8843": 91, "6d6eea5682": 180, "6d7a3d0c21": 96, "6d7efa9b9e": 180, "6da21f5c91": 180, "6da6adabc0": 150, "6dd2827fbb": 96, "6dd36705b9": 131, "6df3637557": 180, "6dfe55e9e5": 150, "6e1a21ba55": 96, "6e2f834767": 180, "6e36e4929a": 96, "6e4f460caf": 96, "6e618d26b6": 56, "6ead4670f7": 180, "6eaff19b9f": 180, "6eb2e1cd9e": 180, "6eb30b3b5a": 96, "6eca26c202": 91, "6ecad29e52": 96, "6ef0b44654": 96, "6efcfe9275": 180, "6f4789045c": 180, "6f49f522ef": 96, "6f67d7c4c4": 180, "6f96e91d81": 144, "6fc6fce380": 180, "6fc9b44c00": 96, "6fce7f3226": 150, "6fdf1ca888": 150, "702fd8b729": 180, "70405185d2": 180, "7053e4f41e": 180, "707bf4ce41": 87, "7082544248": 81, "708535b72a": 96, "7094ac0f60": 180, "70a6b875fa": 180, "70c3e97e41": 180, "7106b020ab": 91, "711dce6fe2": 96, "7136a4453f": 180, "7143fb084f": 180, "714d902095": 150, "7151c53b32": 150, "715357be94": 180, "7163b8085f": 150, "716df1aa59": 150, "71caded286": 150, "71d2665f35": 91, "71d67b9e19": 96, "71e06dda39": 180, "720b398b9c": 91, "720e3fa04c": 150, "720e7a5f1e": 91, "721bb6f2cb": 91, "722803f4f2": 92, "72552a07c9": 91, "726243a205": 96, "72690ef572": 46, "728cda9b65": 86, "728e81c319": 91, "72a810a799": 180, "72acb8cdf6": 180, "72b01281f9": 180, "72cac683e4": 91, "72cadebbce": 180, "72cae058a5": 180, "72d8dba870": 180, "72e8d1c1ff": 96, "72edc08285": 180, "72f04f1a38": 81, "731b825695": 144, "7320b49b13": 180, "732626383b": 87, "732df1eb05": 150, "73329902ab": 150, "733798921e": 150, "733824d431": 150, "734ea0d7fb": 91, "735a7cf7b9": 144, "7367a42892": 91, "7368d5c053": 180, "738e5a0a14": 180, "73c6ae7711": 96, "73e1852735": 150, "73e4e5cc74": 150, "73eac9156b": 180, "73f8441a88": 91, "7419e2ab3f": 91, "74267f68b9": 91, "7435690c8c": 46, "747c44785c": 81, "747f1b1f2f": 144, "748b2d5c01": 96, "74d4cee0a4": 91, "74ec2b3073": 91, "74ef677020": 96, "750be4c4d8": 96, "75172d4ac8": 96, "75285a7eb1": 180, "75504539c3": 91, "7550949b1d": 96, "7551cbd537": 150, "75595b453d": 91, "7559b4b0ec": 91, "755bd1fbeb": 96, "756f76f74d": 180, "7570ca7f3c": 180, "757a69746e": 180, "757cac96c6": 180, "7584129dc3": 144, "75a058dbcd": 91, "75b09ce005": 96, "75cae39a8f": 180, "75cee6caf0": 180, "75cf58fb2c": 91, "75d5c2f32a": 180, "75eaf5669d": 96, "75f7937438": 180, "75f99bd3b3": 96, "75fa586876": 92, "7613df1f84": 150, "762e1b3487": 96, "76379a3e69": 180, "764271f0f3": 92, "764503c499": 86, "7660005554": 46, "7666351b84": 96, "76693db153": 51, "767856368b": 92, "768671f652": 180, "768802b80d": 180, "76962c7ed2": 71, "76a75f4eee": 150, "76b90809f7": 180, "770a441457": 96, "772a0fa402": 180, "772f2ffc3e": 91, "774f6c2175": 180, "77610860e0": 56, "777e58ff3d": 96, "77920f1708": 150, "7799df28e7": 180, "779e847a9a": 81, "77ba4edc72": 96, "77c834dc43": 41, "77d8aa8691": 180, "77e7f38f4d": 144, "77eea6845e": 96, "7806308f33": 91, "78254660ea": 91, "7828af8bff": 180, "784398620a": 71, "784d201b12": 96, "78613981ed": 180, "78896c6baf": 92, "78aff3ebc0": 150, "78c7c03716": 91, "78d3676361": 91, "78e29dd4c3": 150, "78f1a1a54f": 91, "79208585cd": 180, "792218456c": 180, "7923bad550": 150, "794e6fc49f": 96, "796e6762ce": 180, "797cd21f71": 150, "79921b21c2": 150, "79a5778027": 180, "79bc006280": 180, "79bf95e624": 91, "79d9e00c55": 91, "79e20fc008": 96, "79e9db913e": 180, "79f014085e": 91, "79fcbb433a": 150, "7a13a5dfaa": 180, "7a14bc9a36": 96, "7a3c535f70": 96, "7a446a51e9": 91, "7a56e759c5": 91, "7a5f46198d": 86, "7a626ec98d": 92, "7a802264c4": 180, "7a8b5456ca": 180, "7abdff3086": 150, "7aecf9f7ac": 150, "7b0fd09c28": 96, "7b18b3db87": 180, "7b39fe7371": 144, "7b49e03d4c": 180, "7b5388c9f1": 180, "7b5cf7837f": 180, "7b733d31d8": 180, "7b74fd7b98": 180, "7b918ccb8a": 150, "7ba3ce3485": 96, "7bb0abc031": 180, "7bb5bb25cd": 180, "7bb7dac673": 92, "7bc7761b8c": 180, "7bf3820566": 96, "7c03a18ec1": 96, "7c078f211b": 150, "7c37d7991a": 71, "7c4ec17eff": 144, "7c649c2aaf": 180, "7c73340ab7": 91, "7c78a2266d": 180, "7c88ce3c5b": 180, "7ca6843a72": 180, "7cc9258dee": 96, "7cec7296ae": 46, "7d0ffa68a4": 96, "7d11b4450f": 81, "7d1333fcbe": 96, "7d18074fef": 91, "7d18c8c716": 96, "7d508fb027": 180, "7d55f791f0": 180, "7d74e3c2f6": 150, "7d783f67a9": 96, "7d83a5d854": 150, "7dd409947e": 180, "7de45f75e5": 150, "7e0cd25696": 150, "7e1922575c": 96, "7e1e3bbcc1": 180, "7e24023274": 180, "7e2f212fd3": 96, "7e6d1cc1f4": 180, "7e7cdcb284": 144, "7e9b6bef69": 66, "7ea5b49283": 92, "7eb2605d96": 91, "7eb26b8485": 180, "7ecd1f0c69": 96, "7f02b3cfe2": 180, "7f1723f0d5": 97, "7f21063c3a": 81, "7f3658460e": 91, "7f54132e48": 144, "7f559f9d4a": 144, "7f5faedf8b": 96, "7f838baf2b": 180, "7fa5f527e3": 96, "7ff84d66dd": 150, "802b45c8c4": 180, "804382b1ad": 180, "804c558adb": 96, "804f6338a4": 180, "8056117b89": 150, "806b6223ab": 96, "8088bda461": 46, "80b790703b": 180, "80c4a94706": 96, "80ce2e351b": 180, "80db581acd": 96, "80e12193df": 150, "80e41b608f": 180, "80f16b016d": 91, "81541b3725": 91, "8175486e6a": 96, "8179095000": 180, "8193671178": 180, "81a58d2c6b": 150, "81aa1286fb": 96, "81dffd30fb": 96, "8200245704": 41, "823e7a86e8": 46, "824973babb": 144, "824ca5538f": 180, "827171a845": 180, "8273a03530": 180, "827cf4f886": 91, "82b865c7dd": 180, "82c1517708": 91, "82d15514d6": 150, "82e117b900": 179, "82fec06574": 150, "832b5ef379": 97, "83424c9fbf": 180, "8345358fb8": 71, "834b50b31b": 180, "835e3b67d7": 97, "836ea92b15": 90, "837c618777": 144, "838eb3bd89": 180, "839381063f": 91, "839bc71489": 180, "83a8151377": 180, "83ae88d217": 180, "83ca8bcad0": 180, "83ce590d7f": 180, "83d3130ba0": 36, "83d40bcba5": 86, "83daba503a": 144, "83de906ec0": 180, "84044f37f3": 180, "84696b5a5e": 96, "84752191a3": 91, "847eeeb2e0": 180, "848e7835a0": 96, "84a4b29286": 180, "84a4bf147d": 66, "84be115c09": 144, "84d95c4350": 180, "84e0922cf7": 150, "84f0cfc665": 96, "8515f6db22": 180, "851f2f32c1": 91, "852a4d6067": 150, "854c48b02a": 96, "857a387c86": 180, "859633d56a": 96, "85a4f4a639": 144, "85ab85510c": 180, "85b1eda0d9": 92, "85dc1041c6": 96, "85e081f3c7": 150, "85f75187ad": 96, "8604bb2b75": 96, "860745b042": 150, "863b4049d7": 180, "8643de22d0": 180, "8647d06439": 46, "864ffce4fe": 180, "8662d9441a": 180, "8666521b13": 76, "868d6a0685": 91, "869fa45998": 91, "86a40b655d": 150, "86a8ae4223": 92, "86b2180703": 180, "86c85d27df": 180, "86d3755680": 180, "86e61829a1": 180, "871015806c": 91, "871e409c5c": 180, "8744b861ce": 96, "8749369ba0": 180, "878a299541": 144, "8792c193a0": 96, "8799ab0118": 96, "87d1f7d741": 180, "882b9e4500": 180, "885673ea17": 180, "8859dedf41": 96, "8873ab2806": 91, "887a93b198": 180, "8883e991a9": 86, "8891aa6dfa": 91, "8899d8cbcd": 91, "88b8274d67": 180, "88d3b80af6": 91, "88ede83da2": 180, "88f345941b": 180, "890976d6da": 91, "8909bde9ab": 91, "8929c7d5d9": 180, "89363acf76": 150, "89379487e0": 96, "8939db6354": 180, "893f658345": 144, "8953138465": 180, "895c96d671": 180, "895cbf96f9": 180, "895e8b29a7": 91, "898fa256c8": 180, "89986c60be": 180, "89b874547b": 180, "89bdb021d5": 144, "89c802ff9c": 96, "89d6336c2b": 180, "89ebb27334": 91, "8a27e2407c": 96, "8a31f7bca5": 96, "8a4a2fc105": 96, "8a5d6c619c": 96, "8a75ad7924": 180, "8aa817e4ed": 87, "8aad0591eb": 180, "8aca214360": 180, "8ae168c71b": 96, "8b0cfbab97": 21, "8b3645d826": 96, "8b3805dbd4": 180, "8b473f0f5d": 180, "8b4f6d1186": 180, "8b4fb018b7": 66, "8b518ee936": 92, "8b523bdfd6": 150, "8b52fb5fba": 91, "8b91036e5c": 144, "8b99a77ac5": 180, "8ba04b1e7b": 96, "8ba782192f": 180, "8bbeaad78b": 96, "8bd1b45776": 180, "8bd7a2dda6": 150, "8bdb091ccf": 180, "8be56f165d": 96, "8be950d00f": 96, "8bf84e7d45": 180, "8bffc4374b": 66, "8bfff50747": 180, "8c09867481": 144, "8c0a3251c3": 180, "8c3015cccb": 180, "8c469815cf": 96, "8c9ccfedc7": 91, "8ca1af9f3c": 150, "8ca3f6e6c1": 96, "8ca6a4f60f": 96, "8cac6900fe": 96, "8cba221a1e": 180, "8cbbe62ccd": 180, "8d064b29e2": 92, "8d167e7c08": 91, "8d4ab94e1c": 96, "8d81f6f899": 180, "8d87897d66": 91, "8dcccd2bd2": 180, "8dcfb878a8": 150, "8dd3ab71b9": 91, "8dda6bf10f": 96, "8ddd51ca94": 180, "8dea22c533": 180, "8def5bd3bf": 96, "8e1848197c": 91, "8e3a83cf2d": 91, "8e478e73f3": 91, "8e98ae3c84": 96, "8ea6687ab0": 180, "8eb0d315c1": 91, "8ec10891f9": 150, "8ec3065ec2": 180, "8ecf51a971": 150, "8eddbab9f7": 91, "8ee198467a": 180, "8ee2368f40": 180, "8ef595ce82": 150, "8f0a653ad7": 150, "8f1204a732": 150, "8f1600f7f6": 91, "8f16366707": 96, "8f1ce0a411": 92, "8f2e05e814": 91, "8f320d0e09": 96, "8f3b4a84ad": 91, "8f3fdad3da": 96, "8f5d3622d8": 96, "8f62a2c633": 180, "8f81c9405a": 97, "8f8c974d53": 120, "8f918598b6": 96, "8ff61619f6": 96, "9002761b41": 96, "90107941f3": 92, "90118a42ee": 96, "902bc16b37": 91, "903e87e0d6": 144, "9041a0f489": 96, "9047bf3222": 51, "9057bfa502": 150, "90617b0954": 92, "9076f4b6db": 180, "9077e69b08": 144, "909655b4a6": 96, "909c2eca88": 180, "909dbd1b76": 180, "90bc4a319a": 180, "90c7a87887": 96, "90cc785ddd": 96, "90d300f09b": 180, "9101ea9b1b": 96, "9108130458": 150, "911ac9979b": 150, "9151cad9b5": 97, "9153762797": 180, "91634ee0c9": 91, "916942666f": 76, "9198cfb4ea": 180, "919ac864d6": 180, "91b67d58d4": 180, "91bb8df281": 150, "91be106477": 91, "91c33b4290": 180, "91ca7dd9f3": 144, "91d095f869": 180, "91f107082e": 180, "920329dd5e": 180, "920c959958": 150, "92128fbf4b": 144, "9223dacb40": 150, "923137bb7f": 61, "9268e1f88a": 180, "927647fe08": 150, "9276f5ba47": 150, "92a28cd233": 71, "92b5c1fc6d": 144, "92c46be756": 180, "92dabbe3a0": 96, "92e3159361": 180, "92ebab216a": 180, "934bdc2893": 180, "9359174efc": 180, "935d97dd2f": 91, "935feaba1b": 96, "93901858ee": 150, "939378f6d6": 91, "939bdf742e": 96, "93a22bee7e": 96, "93da9aeddf": 91, "93e2feacce": 180, "93e6f1fdf9": 96, "93e811e393": 180, "93e85d8fd3": 180, "93f623d716": 180, "93ff35e801": 46, "94031f12f2": 96, "94091a4873": 180, "94125907e3": 87, "9418653742": 91, "941c870569": 101, "94209c86f0": 180, "9437c715eb": 76, "9445c3eca2": 91, "9467c8617c": 96, "946d71fb5d": 96, "948f3ae6fb": 180, "9498baa359": 96, "94a33abeab": 91, "94bf1af5e3": 144, "94cf3a8025": 96, "94db712ac8": 180, "94e4b66cff": 92, "94e76cbaf6": 180, "950be91db1": 180, "952058e2d0": 92, "952633c37f": 96, "952ec313fe": 87, "9533fc037c": 96, "9574b81269": 92, "9579b73761": 180, "957f7bc48b": 180, "958073d2b0": 150, "9582e0eb33": 71, "9584092d0b": 91, "95b58b8004": 150, "95bd88da55": 180, "95f74a9959": 180, "962781c601": 180, "962f045bf5": 91, "964ad23b44": 91, "967b90590e": 144, "967bffe201": 86, "96825c4714": 81, "968492136a": 96, "9684ef9d64": 86, "968c41829e": 91, "96a856ef9a": 180, "96dfc49961": 180, "96e1a5b4f8": 180, "96e6ff0917": 150, "96fb88e9d7": 96, "96fbe5fc23": 150, "96fc924050": 96, "9715cc83dc": 180, "9720eff40f": 180, "972c187c0d": 180, "97476eb38d": 180, "97659ed431": 180, "9773492949": 96, "97756b264f": 96, "977bff0d10": 96, "97ab569ff3": 96, "97ba838008": 180, "97d9d008c7": 150, "97e59f09fa": 96, "97eb642e56": 96, "98043e2d14": 96, "981ff580cf": 180, "983e66cbfc": 96, "984f0f1c36": 180, "98595f2bb4": 91, "985c3be474": 91, "9869a12362": 180, "986b5a5e18": 180, "9877af5063": 180, "98911292da": 180, "9893a3cf77": 97, "9893d9202d": 91, "98a8b06e7f": 91, "98ac6f93d9": 150, "98b6974d12": 96, "98ba3c9417": 180, "98c7c00a19": 96, "98d044f206": 96, "98e909f9d1": 150, "98fe7f0410": 150, "990f2742c7": 96, "992bd0779a": 180, "994b9b47ba": 150, "9955b76bf5": 91, "9966f3adac": 46, "997117a654": 180, "999d53d841": 150, "99c04108d3": 180, "99c4277aee": 96, "99c6b1acf2": 96, "99dc8bb20b": 180, "99fcba71e5": 150, "99fecd4efb": 92, "9a02c70ba2": 96, "9a08e7a6f8": 180, "9a2f2c0f86": 81, "9a3254a76e": 92, "9a3570a020": 180, "9a39112493": 180, "9a4e9fd399": 180, "9a50af4bfb": 180, "9a68631d24": 150, "9a72318dbf": 92, "9a767493b7": 180, "9a7fc1548b": 96, "9a84ccf6a7": 150, "9a9c0e15b7": 96, "9adf06d89b": 150, "9b22b54ee4": 91, "9b473fc8fe": 96, "9b4f081782": 180, "9b997664ba": 180, "9bc454e109": 180, "9bccfd04de": 96, "9bce4583a2": 96, "9bebf1b87f": 158, "9bfc50d261": 180, "9c166c86ff": 96, "9c293ef4d7": 144, "9c29c047b0": 91, "9c3bc2e2a7": 96, "9c3ce23bd1": 91, "9c404cac0c": 180, "9c5180d23a": 144, "9c7feca6e4": 144, "9caa49d3ff": 180, "9cb2f1b646": 180, "9ce6f765c3": 91, "9cfee34031": 180, "9d01f08ec6": 180, "9d04c280b8": 91, "9d12ceaddc": 180, "9d15f8cb3c": 180, "9d2101e9bf": 180, "9d407c3aeb": 96, "9ddefc6165": 180, "9df0b1e298": 96, "9e16f115d8": 144, "9e249b4982": 96, "9e29b1982c": 92, "9e493e4773": 180, "9e4c752cd0": 91, "9e4de40671": 96, "9e6319faeb": 96, "9e6ddbb52d": 91, "9eadcea74f": 180, "9ecec5f8ea": 46, "9efb47b595": 96, "9f30bfe61e": 72, "9f3734c3a4": 180, "9f5b858101": 180, "9f66640cda": 180, "9f913803e9": 180, "9f97bc74c8": 180, "9fbad86e20": 180, "9fc2bad316": 180, "9fc5c3af78": 150, "9fcb310255": 92, "9fcc256871": 91, "9fd2fd4d47": 180, "a0071ae316": 96, "a023141022": 56, "a046399a74": 96, "a066e739c1": 150, "a06722ba82": 96, "a07a15dd64": 180, "a07b47f694": 180, "a09c39472e": 144, "a0b208fe2e": 91, "a0b61c959e": 96, "a0bc6c611d": 180, "a0e6da5ba2": 91, "a1193d6490": 96, "a14ef483ff": 91, "a14f709908": 180, "a15ccc5658": 96, "a16062456f": 180, "a174e8d989": 91, "a177c2733c": 150, "a17c62e764": 92, "a18ad065fc": 150, "a1aaf63216": 96, "a1bb65fb91": 150, "a1bd8e5349": 91, "a1dfdd0cac": 180, "a2052e4f6c": 96, "a20fd34693": 96, "a21ffe4d81": 150, "a22349e647": 180, "a235d01ec1": 180, "a24f63e8a2": 180, "a2554c9f6d": 46, "a263ce8a87": 180, "a29bfc29ec": 91, "a2a80072d4": 150, "a2a800ab63": 180, "a2bcd10a33": 180, "a2bdaff3b0": 91, "a2c146ab0d": 91, "a2c996e429": 96, "a2dc51ebe8": 180, "a2e6608bfa": 180, "a2f2a55f01": 96, "a301869dea": 180, "a31fccd2cc": 180, "a34f440f33": 180, "a35e0206da": 180, "a36bdc4cab": 180, "a36e8c79d8": 71, "a378053b20": 144, "a37db3a2b3": 91, "a38950ebc2": 180, "a39a0eb433": 91, "a39c9bca52": 180, "a3a945dc8c": 91, "a3b40a0c1e": 150, "a3b8588550": 91, "a3c502bec3": 180, "a3f2878017": 180, "a3f4d58010": 180, "a3f51855c3": 150, "a402dc0dfe": 21, "a4065a7eda": 180, "a412bb2fef": 180, "a416b56b53": 96, "a41ec95906": 91, "a43299e362": 180, "a4757bd7af": 96, "a48c53c454": 180, "a49dcf9ad5": 150, "a4a506521f": 180, "a4ba7753d9": 180, "a4bac06849": 91, "a4f05d681c": 91, "a50c10060f": 150, "a50eb5a0ea": 150, "a5122c6ec6": 150, "a522b1aa79": 96, "a590915345": 180, "a5b5b59139": 96, "a5b77abe43": 180, "a5c2b2c3e1": 96, "a5cd17bb11": 180, "a5da03aef1": 180, "a5dd11de0d": 150, "a5ea2b93b6": 150, "a5eaeac80b": 180, "a5ec5b0265": 144, "a5f350a87e": 180, "a5f472caf4": 96, "a6027a53cf": 180, "a61715bb1b": 180, "a61cf4389d": 150, "a61d9bbd9b": 180, "a6470dbbf5": 150, "a64a40f3eb": 76, "a653d5c23b": 180, "a65bd23cb5": 150, "a66e0b7ad4": 180, "a66fc5053c": 91, "a68259572b": 180, "a6a810a92c": 150, "a6bc36937f": 91, "a6c3a374e9": 180, "a6d8a4228d": 180, "a6f4e0817f": 180, "a71e0481f5": 96, "a7203deb2d": 150, "a7392d4438": 150, "a73d3c3902": 180, "a7491f1578": 150, "a74b9ca19c": 180, "a77b7a91df": 150, "a78195a5f5": 150, "a78758d4ce": 180, "a7e6d6c29a": 96, "a800d85e88": 51, "a832fa8790": 180, "a83d06410d": 150, "a8999af004": 180, "a8f78125b9": 180, "a907b18df1": 150, "a919392446": 150, "a965504e88": 96, "a96b84b8d2": 96, "a973f239cd": 91, "a977126596": 180, "a9804f2a08": 91, "a984e56893": 96, "a99738f24c": 91, "a99bdd0079": 144, "a9c9c1517e": 178, "a9cbf9c41b": 150, "a9e42e3c0c": 150, "aa07b7c1c0": 180, "aa175e5ec7": 96, "aa1a338630": 96, "aa27d7b868": 96, "aa45f1caaf": 91, "aa49e46432": 96, "aa51934e1b": 180, "aa6287bb6c": 96, "aa6d999971": 180, "aa85278334": 96, "aab33f0e2a": 180, "aaba004362": 180, "aade4cf385": 180, "aae78feda4": 91, "aaed233bf3": 180, "aaff16c2db": 96, "ab199e8dfb": 96, "ab23b78715": 96, "ab2e1b5577": 180, "ab33a18ded": 96, "ab45078265": 180, "ab56201494": 180, "ab90f0d24b": 180, "abab2e6c20": 180, "abb50c8697": 92, "abbe2d15a0": 180, "abbe73cd21": 150, "abe61a11bb": 180, "abeae8ce21": 150, "ac2b431d5f": 150, "ac2cb1b9eb": 150, "ac31fcd6d0": 91, "ac3d3a126d": 180, "ac46bd8087": 180, "ac783ef388": 180, "acb73e4297": 150, "acbf581760": 180, "accafc3531": 96, "acf2c4b745": 96, "acf44293a2": 96, "acf736a27b": 90, "acff336758": 180, "ad1fe56886": 92, "ad28f9b9d9": 91, "ad2de9f80e": 180, "ad397527b2": 97, "ad3d1cfbcb": 86, "ad3fada9d9": 180, "ad4108ee8e": 180, "ad54468654": 66, "ad573f7d31": 96, "ad6255bc29": 180, "ad65ebaa07": 144, "ad97cc064a": 96, "adabbd1cc4": 180, "adb0b5a270": 180, "adc648f890": 150, "add21ee467": 180, "adfd15ceef": 180, "adfdd52eac": 96, "ae01cdab63": 180, "ae0b50ff4f": 96, "ae13ee3d70": 180, "ae1bcbd423": 180, "ae20d09dea": 180, "ae2cecf5f6": 56, "ae3bc4a0ef": 180, "ae499c7514": 92, "ae628f2cd4": 150, "ae8545d581": 86, "ae93214fe6": 150, "ae9cd16dbf": 46, "aeba9ac967": 180, "aebb242b5c": 150, "aed4e0b4c4": 86, "aedd71f125": 180, "aef3e2cb0e": 180, "af0b54cee3": 96, "af3de54c7a": 180, "af5fd24a36": 150, "af8826d084": 91, "af8ad72057": 180, "afb71e22c5": 92, "afcb331e1f": 96, "afe1a35c1e": 150, "b01080b5d3": 180, "b05ad0d345": 96, "b0623a6232": 91, "b064dbd4b7": 96, "b06ed37831": 96, "b06f5888e6": 92, "b08dcc490e": 91, "b0a68228dc": 92, "b0aece727f": 144, "b0b0731606": 96, "b0c7f11f9f": 180, "b0cca8b830": 180, "b0dd580a89": 180, "b0de66ca08": 180, "b0df7c5c5c": 96, "b0f5295608": 96, "b11099eb09": 180, "b132a53086": 91, "b1399fac64": 180, "b13abc0c69": 96, "b1457e3b5e": 180, "b15bf4453b": 91, "b179c4a82d": 96, "b17ee70e8c": 180, "b190b1aa65": 96, "b19b3e22c0": 180, "b19c561fab": 180, "b1d1cd2e6e": 92, "b1d7c03927": 91, "b1d7fe2753": 180, "b1f540a4bd": 96, "b1fc9c64e1": 96, "b1fcbb3ced": 180, "b220939e93": 96, "b22099b419": 180, "b241e95235": 96, "b2432ae86d": 180, "b2456267df": 180, "b247940d01": 150, "b24af1c35c": 180, "b24f600420": 97, "b24fe36b2a": 150, "b258fb0b7d": 180, "b26b219919": 96, "b26d9904de": 96, "b274456ce1": 180, "b27b28d581": 72, "b2a26bc912": 180, "b2a9c51e1b": 180, "b2b0baf470": 180, "b2b2756fe7": 96, "b2ce7699e3": 180, "b2edc76bd2": 150, "b2f6b52100": 180, "b30bf47bcd": 180, "b34105a4e9": 91, "b372a82edf": 150, "b3779a1962": 96, "b379ab4ff5": 46, "b37a1d69e3": 150, "b37c01396e": 180, "b382b09e25": 150, "b3996e4ba5": 180, "b3d9ca2aee": 180, "b3dde1e1e9": 180, "b3eb7f05eb": 86, "b40b25055c": 91, "b41e0f1f19": 91, "b44e32a42b": 91, "b4805ae9cd": 46, "b4807569a5": 97, "b48efceb3e": 150, "b493c25c7f": 180, "b4b565aba1": 150, "b4b715a15b": 180, "b4d0c90bf4": 91, "b4d84bc371": 180, "b4e5ad97aa": 180, "b4eaea9e6b": 150, "b50f4b90d5": 180, "b53f675641": 150, "b54278cd43": 180, "b554843889": 150, "b573c0677a": 180, "b58d853734": 180, "b5943b18ab": 180, "b5a09a83f3": 71, "b5aae1fe25": 91, "b5b9da5364": 97, "b5eb64d419": 91, "b5ebb1d000": 96, "b5f1c0c96a": 96, "b5f7fece90": 180, "b6070de1bb": 180, "b60a76fe73": 86, "b61f998772": 96, "b62c943664": 96, "b63094ba0c": 180, "b64fca8100": 96, "b673e7dcfb": 96, "b678b7db00": 180, "b68fc1b217": 180, "b69926d9fa": 96, "b6a1df3764": 180, "b6a4859528": 96, "b6b4738b78": 96, "b6b4f847b7": 150, "b6b8d502d4": 150, "b6bb00e366": 180, "b6d65a9eef": 180, "b6d79a0845": 180, "b6e9ec577f": 91, "b6ec609f7b": 163, "b6f92a308d": 180, "b70a2c0ab1": 46, "b70a5a0d50": 180, "b70c052f2f": 150, "b70d231781": 92, "b72ac6e10b": 180, "b7302d8226": 92, "b73867d769": 150, "b751e767f2": 180, "b76df6e059": 96, "b77e5eddef": 92, "b7a2c2c83c": 96, "b7bcbe6466": 180, "b7c2a469c4": 180, "b7d69da8f0": 144, "b7f31b7c36": 61, "b7f675fb98": 46, "b7fb871660": 51, "b82e5ad1c9": 91, "b841cfb932": 96, "b84b8ae665": 180, "b85b78ac2b": 180, "b86c17caa6": 180, "b86e50d82d": 96, "b871db031a": 66, "b87d56925a": 96, "b8aaa59b75": 92, "b8c03d1091": 180, "b8c3210036": 46, "b8e16df00b": 144, "b8f34cf72e": 91, "b8fb75864e": 150, "b9004db86c": 180, "b9166cbae9": 92, "b920b256a6": 180, "b938d79dff": 20, "b93963f214": 180, "b941aef1a0": 144, "b94d34d14e": 96, "b964c57da4": 96, "b96a95bc7a": 180, "b96c57d2c7": 144, "b9b6bdde0c": 180, "b9bcb3e0f2": 96, "b9d3b92169": 180, "b9dd4b306c": 180, "b9f43ef41e": 92, "ba1f03c811": 96, "ba3a775d7b": 180, "ba3c7f2a31": 150, "ba3fcd417d": 180, "ba5e1f4faa": 150, "ba795f3089": 96, "ba8a291e6a": 150, "ba98512f97": 92, "bac9db04f5": 180, "baedae3442": 180, "baff40d29d": 180, "bb04e28695": 96, "bb1b0ee89f": 96, "bb1c770fe7": 150, "bb1fc34f99": 150, "bb2d220506": 180, "bb334e5cdb": 91, "bb337f9830": 81, "bb721eb9aa": 96, "bb87ff58bd": 96, "bb89a6b18a": 87, "bbaa9a036a": 144, "bbb4302dda": 180, "bbd31510cf": 96, "bbe0256a75": 180, "bc141b9ad5": 91, "bc17ab8a99": 150, "bc318160de": 180, "bc3b9ee033": 91, "bc4240b43c": 96, "bc4ce49105": 91, "bc4f71372d": 96, "bc6b8d6371": 180, "bcaad44ad7": 150, "bcc241b081": 91, "bcc5d8095e": 96, "bcd1d39afb": 96, "bd0d849da4": 180, "bd0e9ed437": 150, "bd2c94730f": 180, "bd321d2be6": 61, "bd3ec46511": 91, "bd5b2e2848": 41, "bd7e02b139": 96, "bd96f9943a": 180, "bda224cb25": 91, "bda4a82837": 96, "bdb74e333f": 180, "bdccd69dde": 96, "bddcc15521": 180, "be116aab29": 150, "be15e18f1e": 150, "be1a284edb": 180, "be2a367a7b": 180, "be376082d0": 150, "be3e3cffbd": 51, "be5d1d89a0": 180, "be8b72fe37": 180, "be9b29e08e": 91, "bea1f6e62c": 97, "bea83281b5": 92, "beb921a4c9": 96, "bec5e9edcd": 180, "beeb8a3f92": 150, "bf2232b58d": 96, "bf28751739": 150, "bf443804e8": 180, "bf461df850": 150, "bf5374f122": 180, "bf551a6f60": 180, "bf8d0f5ada": 96, "bf961167a6": 92, "bfab1ad8f9": 150, "bfcb05d88d": 96, "bfd8f6e6c9": 92, "bfd91d0742": 150, "bfe262322f": 87, "c013f42ed7": 180, "c01878083f": 180, "c01faff1ed": 180, "c046fd0edb": 150, "c053e35f97": 91, "c079a6482d": 96, "c0847b521a": 96, "c0a1e06710": 180, "c0e8d4635c": 96, "c0e973ad85": 96, "c0f49c6579": 92, "c0f5b222d7": 96, "c10d07c90d": 180, "c1268d998c": 96, "c130c3fc0c": 180, "c14826ad5e": 180, "c15b922281": 180, "c16f09cb63": 180, "c18e19d922": 180, "c1c830a735": 96, "c1e8aeea45": 180, "c20a5ccc99": 180, "c20fd5e597": 180, "c219d6f8dc": 150, "c2406ae462": 96, "c26f7b5824": 180, "c279e641ee": 96, "c27adaeac5": 180, "c2a35c1cda": 96, "c2a9903b8b": 180, "c2b62567c1": 96, "c2b974ec8c": 150, "c2baaff7bf": 91, "c2be6900f2": 180, "c304dd44d5": 180, "c307f33da2": 96, "c30a7b62c9": 92, "c3128733ee": 180, "c31fa6c598": 180, "c325c8201e": 96, "c32d4aa5d1": 180, "c33f28249a": 144, "c34365e2d7": 180, "c3457af795": 96, "c34d120a88": 180, "c3509e728d": 96, "c35e4fa6c4": 180, "c36240d96f": 150, "c3641dfc5a": 92, "c37b17a4a9": 180, "c39559ddf6": 180, "c3b0c6e180": 96, "c3b3d82e6c": 180, "c3be369fdb": 91, "c3bf1e40c2": 97, "c3c760b015": 96, "c3dd38bf98": 150, "c3e4274614": 91, "c3edc48cbd": 180, "c41e6587f5": 96, "c4272227b0": 96, "c42917fe82": 86, "c438858117": 180, "c44676563f": 180, "c44beb7472": 180, "c45411dacb": 91, "c4571bedc8": 91, "c46deb2956": 180, "c479ee052e": 180, "c47d551843": 180, "c49f07d46d": 180, "c4cc40c1fc": 97, "c4f256f5d5": 144, "c4f5b1ddcc": 180, "c4ff9b4885": 150, "c52bce43db": 66, "c544da6854": 180, "c55784c766": 180, "c557b69fbf": 180, "c593a3f7ab": 92, "c598faa682": 180, "c5ab1f09c8": 180, "c5b6da8602": 96, "c5b9128d94": 96, "c5e845c6b7": 150, "c5fba7b341": 150, "c60897f093": 96, "c61fe6ed7c": 96, "c62188c536": 96, "c64035b2e2": 150, "c69689f177": 180, "c6a12c131f": 51, "c6bb6d2d5c": 180, "c6c18e860f": 150, "c6d9526e0d": 180, "c6e55c33f0": 96, "c7030b28bd": 96, "c70682c7cc": 180, "c70f9be8c5": 87, "c71f30d7b6": 180, "c73c8e747f": 180, "c760eeb8b3": 144, "c7637cab0a": 150, "c7a1a17308": 87, "c7bf937af5": 91, "c7c2860db3": 180, "c7cef4aee2": 91, "c7ebfc5d57": 180, "c813dcf13c": 91, "c82235a49a": 96, "c82a7619a1": 180, "c82ecb90cb": 180, "c844f03dc7": 96, "c8557963f3": 91, "c89147e6e8": 180, "c8a46ff0c8": 150, "c8ab107dd5": 97, "c8b869a04a": 96, "c8c7b306a6": 91, "c8c8b28781": 180, "c8d79e3163": 180, "c8edab0415": 150, "c8f494f416": 96, "c8f6cba9fd": 150, "c909ceea97": 92, "c9188f4980": 180, "c922365dd4": 96, "c92c8c3c75": 96, "c937eb0b83": 91, "c94b31b5e5": 180, "c95cd17749": 180, "c96379c03c": 180, "c96465ee65": 180, "c965afa713": 144, "c9734b451f": 92, "c9862d82dc": 180, "c98b6fe013": 180, "c9999b7c48": 180, "c99e92aaf0": 97, "c9b3a8fbda": 150, "c9bf64e965": 96, "c9c3cb3797": 91, "c9d1c60cd0": 144, "c9de9c22c4": 96, "ca1828fa54": 96, "ca346f17eb": 180, "ca3787d3d3": 150, "ca4b99cbac": 96, "ca91c69e3b": 71, "ca91e99105": 46, "caa8e97f81": 96, "caac5807f8": 180, "cabba242c2": 96, "cad5a656a9": 180, "cad673e375": 180, "cad8a85930": 150, "cae7b0a02b": 180, "cae7ef3184": 180, "caeb6b6cbb": 150, "caecf0a5db": 91, "cb15312003": 76, "cb2e35d610": 150, "cb35a87504": 150, "cb3f22b0cf": 96, "cbb410da64": 91, "cc8728052e": 150, "cc892997b8": 180, "cce03c2a9b": 144, "cd47a23e31": 92, "cd4dc03dc0": 180, "cd5ae611da": 96, "cd603bb9d1": 144, "cd8f49734c": 180, "cdc6b1c032": 92, "cdcfe008ad": 144, "cdd57027c2": 96, "ce1af99b4b": 150, "ce1bc5743a": 150, "ce25872021": 97, "ce2776f78f": 180, "ce49b1f474": 180, "ce4f0a266f": 180, "ce5641b195": 180, "ce6866aa19": 180, "ce712ed3c9": 91, "ce7d1c8117": 144, "ce7dbeaa88": 180, "ce9b015a5e": 180, "cea7697b25": 96, "cebbd826cf": 150, "cec3415361": 150, "cec41ad4f4": 180, "ced49d26df": 180, "ced7705ab2": 144, "cef824a1e1": 92, "cf13f5c95a": 144, "cf4376a52d": 180, "cf85ab28b5": 180, "cfc2e50b9d": 150, "cfcd571fff": 144, "cfd9d4ae47": 180, "cfda2dcce5": 150, "cff035928b": 91, "cff8191891": 46, "d01608c2a5": 96, "d01a8f1f83": 144, "d021d68bca": 180, "d04258ca14": 150, "d0483573dc": 150, "d04a90aaff": 180, "d05279c0bd": 180, "d0696bd5fc": 91, "d072fda75b": 178, "d0a83bcd9f": 150, "d0ab39112e": 180, "d0acde820f": 96, "d0b4442c71": 144, "d0c65e9e95": 180, "d0fb600c73": 150, "d107a1457c": 61, "d123d674c1": 66, "d14d1e9289": 96, "d154e3388e": 96, "d177e9878a": 96, "d1802f69f8": 150, "d182c4483a": 180, "d195d31128": 180, "d200838929": 180, "d205e3cff5": 180, "d247420c4c": 180, "d2484bff33": 66, "d26f6ed9b0": 150, "d280fcd1cb": 180, "d2857f0faa": 180, "d292a50c7f": 46, "d295ea2dc7": 96, "d2a58b4fa6": 91, "d2b026739a": 150, "d2ebe0890f": 180, "d2ede5d862": 91, "d301ca58cc": 150, "d3069da8bb": 91, "d343d4a77d": 150, "d355e634ef": 86, "d367fb5253": 91, "d36d16358e": 76, "d38bc77e2c": 101, "d38d1679e2": 144, "d3932ad4bd": 97, "d3987b2930": 180, "d39934abe3": 144, "d3ae1c3f4c": 92, "d3b088e593": 87, "d3e6e05e16": 150, "d3eefae7c5": 144, "d3f55f5ab8": 180, "d3f5c309cc": 61, "d4034a7fdf": 180, "d4193011f3": 144, "d429c67630": 180, "d42c0ff975": 180, "d44a764409": 180, "d44e6acd1d": 66, "d45158c175": 150, "d454e8444f": 150, "d45f62717e": 180, "d48ebdcf74": 180, "d49ab52a25": 86, "d4a607ad81": 92, "d4b063c7db": 144, "d4da13e9ba": 96, "d4dd1a7d00": 180, "d4f4f7c9c3": 96, "d521aba02e": 180, "d535bb1b97": 92, "d53b955f78": 96, "d55cb7a205": 92, "d55f247a45": 150, "d5695544d8": 180, "d5853d9b8b": 180, "d5b6c6d94a": 96, "d5cae12834": 150, "d5df027f0c": 144, "d5ee40e5d0": 180, "d600046f73": 144, "d632fd3510": 144, "d6476cad55": 180, "d65a7bae86": 150, "d664c89912": 150, "d689658f06": 180, "d6917db4be": 96, "d69967143e": 96, "d699d3d798": 91, "d69f757a3f": 180, "d6ac0e065c": 91, "d6c02bfda5": 96, "d6c1b5749e": 92, "d6e12ef6cc": 92, "d6eed152c4": 180, "d6faaaf726": 96, "d704766646": 180, "d708e1350c": 180, "d7135cf104": 180, "d7157a9f44": 46, "d719cf9316": 96, "d724134cfd": 144, "d73a60a244": 180, "d7411662da": 144, "d74875ea7c": 96, "d756f5a694": 91, "d7572b7d8a": 180, "d763bd6d96": 180, "d7697c8b13": 96, "d7797196b4": 150, "d79c834768": 180, "d7b34e5d73": 91, "d7bb6b37a7": 150, "d7c7e064a6": 180, "d7fbf545b3": 96, "d82a0aa15b": 180, "d847e24abd": 144, "d8596701b7": 144, "d86101499c": 144, "d87069ba86": 150, "d87160957b": 144, "d874654b52": 91, "d88a403092": 96, "d8aee40f3f": 144, "d8e77a222d": 91, "d8eb07c381": 180, "d9010348a1": 66, "d90e3cf281": 91, "d92532c7b2": 180, "d927fae122": 150, "d95707bca8": 91, "d973b31c00": 144, "d991cb471d": 180, "d992c69d37": 150, "d99d770820": 180, "d9b63abc11": 180, "d9db6f1983": 144, "d9e52be2d2": 96, "d9edc82650": 150, "da01070697": 96, "da070ea4b7": 180, "da080507b9": 150, "da0e944cc4": 180, "da28d94ff4": 96, "da5d78b9d1": 180, "da6003fc72": 150, "da690fee9f": 180, "da6c68708f": 180, "da7a816676": 144, "dac361e828": 180, "dac71659b8": 144, "dad980385d": 96, "daebc12b77": 150, "db0968cdd3": 150, "db231a7100": 92, "db59282ace": 91, "db7f267c3f": 180, "dba35b87fd": 96, "dbba735a50": 86, "dbca076acd": 180, "dbd66dc3ac": 180, "dbdc3c292b": 180, "dbf4a5b32b": 180, "dbfc417d28": 180, "dc1745e0a2": 91, "dc32a44804": 180, "dc34b35e30": 150, "dc504a4f79": 92, "dc704dd647": 180, "dc71bc6918": 92, "dc7771b3be": 180, "dcf8c93617": 96, "dd0f4c9fb9": 180, "dd415df125": 120, "dd601f9a3f": 144, "dd61d903df": 150, "dd77583736": 150, "dd8636bd8b": 180, "dd9fe6c6ac": 92, "ddb2da4c14": 180, "ddcd450d47": 144, "dde8e67fb4": 76, "ddfc3f04d3": 150, "de2ab79dfa": 180, "de2f35b2fd": 91, "de30990a51": 180, "de36b216da": 96, "de37403340": 180, "de46e4943b": 96, "de4ddbccb1": 180, "de5e480f05": 96, "de6a9382ca": 96, "de74a601d3": 180, "de827c510d": 92, "ded6069f7b": 180, "defb71c741": 96, "df01f277f1": 180, "df05214b82": 92, "df0638b0a0": 46, "df11931ffe": 180, "df1b0e4620": 180, "df20a8650d": 92, "df2bc56d7c": 180, "df365282c6": 180, "df39a0d9df": 96, "df3c430c24": 91, "df5536cfb9": 180, "df59cfd91d": 97, "df5e2152b3": 66, "df741313c9": 96, "df7626172f": 92, "df8ad5deb9": 180, "df96aa609a": 180, "df9705605c": 180, "df9c91c4da": 180, "dfc0d3d27a": 180, "dfdbf91a99": 180, "e00baaae9b": 180, "e0a938c6e7": 91, "e0b2ceee6f": 150, "e0bdb5dfae": 36, "e0be1f6e17": 96, "e0c478f775": 150, "e0de82caa7": 180, "e0f217dd59": 91, "e0f7208874": 180, "e0fb58395e": 180, "e1194c2e9d": 150, "e11adcd05d": 180, "e128124b9d": 87, "e1495354e4": 180, "e1561d6d4b": 180, "e158805399": 91, "e16945b951": 46, "e19edcd34b": 180, "e1a1544285": 180, "e1ab7957f4": 150, "e1d26d35be": 96, "e1e957085b": 96, "e1f14510fa": 180, "e214b160f4": 180, "e2167379b8": 150, "e21acb20ab": 180, "e221105579": 180, "e22ddf8a1b": 180, "e22de45950": 96, "e22ffc469b": 180, "e23cca5244": 96, "e252f46f0b": 180, "e25fa6cf39": 180, "e26e486026": 150, "e275760245": 96, "e27bbedbfe": 92, "e29e9868a8": 180, "e2b37ff8af": 96, "e2b608d309": 180, "e2bef4da9a": 96, "e2c87a6421": 96, "e2ea25542c": 144, "e2fb1d6497": 178, "e2fcc99117": 91, "e33c18412a": 71, "e348377191": 91, "e352cb59c8": 180, "e36ac982f0": 91, "e391bc981e": 96, "e39e3e0a06": 96, "e3bf38265f": 51, "e3d5b2cd21": 150, "e3d60e82d5": 46, "e3e3245492": 96, "e3e4134877": 150, "e3f4635e03": 180, "e4004ee048": 180, "e402d1afa5": 180, "e415093d27": 71, "e41ceb5d81": 180, "e424653b78": 96, "e42b6d3dbb": 96, "e42d60f0d4": 180, "e436d0ff1e": 180, "e43d7ae2c5": 92, "e4428801bc": 97, "e44e0b4917": 180, "e470345ede": 180, "e48e8b4263": 180, "e4922e3726": 180, "e4936852bb": 96, "e495f32c60": 41, "e499228f26": 150, "e4af66e163": 180, "e4b2095f58": 180, "e4d19c8283": 180, "e4d4872dab": 96, "e4e2983570": 41, "e4eaa63aab": 91, "e4ef0a3a34": 91, "e4f8e5f46e": 96, "e4ffb6d0dd": 71, "e53e21aa02": 180, "e57f4f668b": 180, "e588433c1e": 96, "e597442c99": 150, "e5abc0e96b": 91, "e5be628030": 180, "e5ce96a55d": 61, "e5d6b70a9f": 81, "e5fde1574c": 92, "e625e1d27b": 180, "e6261d2348": 91, "e6267d46bc": 96, "e6295f223f": 180, "e63463d8c6": 96, "e6387bd1e0": 180, "e653883384": 96, "e65f134e0b": 150, "e668ef5664": 180, "e672ccd250": 92, "e674510b20": 91, "e676107765": 150, "e699da0cdf": 180, "e6be243065": 46, "e6deab5e0b": 76, "e6f065f2b9": 96, "e71629e7b5": 96, "e72a7d7b0b": 150, "e72f6104e1": 92, "e75a466eea": 72, "e76c55933f": 150, "e7784ec8ad": 180, "e78922e5e6": 47, "e78d450a9c": 91, "e7c6354e77": 91, "e7c8de1fce": 150, "e7ea10db28": 150, "e803918710": 180, "e8073a140b": 180, "e828dd02db": 150, "e845994987": 150, "e8485a2615": 96, "e85c5118a7": 180, "e88b6736e4": 180, "e8962324e3": 91, "e8b3018d36": 91, "e8cee8bf0b": 150, "e8d97ebece": 144, "e8da49ea6a": 96, "e8ed1a3ccf": 180, "e8f7904326": 72, "e8f8341dec": 180, "e8fa21eb13": 180, "e90c10fc4c": 150, "e914b8cac8": 180, "e92b6bfea4": 46, "e92e1b7623": 150, "e93f83e512": 92, "e9422ad240": 46, "e9460b55f9": 180, "e9502628f6": 180, "e950befd5f": 180, "e9582bdd1b": 91, "e95e5afe0f": 96, "e97cfac475": 96, "e98d57d99c": 91, "e98eda8978": 92, "e99706b555": 41, "e9bc0760ba": 91, "e9d3c78bf3": 87, "e9ec1b7ea8": 144, "ea065cc205": 180, "ea138b6617": 150, "ea16d3fd48": 180, "ea2545d64b": 180, "ea286a581c": 150, "ea320da917": 96, "ea345f3627": 91, "ea3b94a591": 180, "ea444a37eb": 71, "ea4a01216b": 180, "ea5672ffa8": 81, "eaa99191cb": 150, "eaab4d746c": 91, "eac7a59bc1": 150, "ead5d3835a": 96, "eaec65cfa7": 180, "eaed1a87be": 180, "eb2f821c6f": 180, "eb383cb82e": 91, "eb6992fe02": 150, "eb6ac20a01": 92, "eb6d7ab39e": 96, "eb7921facd": 180, "eb8fce51a6": 180, "ebbb90e9f9": 91, "ebbf5c9ee1": 180, "ebc4ec32e6": 91, "ebe56e5ef8": 180, "ec1299aee4": 97, "ec139ff675": 180, "ec193e1a01": 180, "ec28252938": 150, "ec387be051": 180, "ec3d4fac00": 91, "ec4186ce12": 95, "ec579c2f96": 91, "ecae59b782": 180, "ecb33a0448": 180, "ece6bc9e92": 150, "ecfedd4035": 92, "ecfff22fd6": 180, "ed3291c3d6": 180, "ed3cd5308d": 180, "ed3e6fc1a5": 180, "ed72ae8825": 180, "ed7455da68": 92, "ed844e879f": 150, "ed8f814b2b": 92, "ed911a1f63": 180, "ed9ff4f649": 180, "eda8ab984b": 180, "edb8878849": 96, "edbfdfe1b4": 180, "edd22c46a2": 96, "edd663afa3": 180, "ede3552eae": 96, "edeab61ee0": 174, "ee07583fc0": 150, "ee316eaed6": 91, "ee3f509537": 150, "ee40a1e491": 92, "ee4bf100f1": 180, "ee6f9b01f9": 180, "ee947ed771": 96, "ee9706ac7f": 91, "ee9a7840ae": 180, "eeb90cb569": 180, "eebf45e5c5": 92, "eeed0c7d73": 87, "ef0061a309": 96, "ef07f1a655": 96, "ef0a8e8f35": 56, "ef232a2aed": 150, "ef308ad2e9": 180, "ef44945428": 96, "ef45ce3035": 180, "ef5dde449d": 180, "ef5e770988": 144, "ef6359cea3": 96, "ef65268834": 180, "ef6cb5eae0": 86, "ef78972bc2": 150, "ef8cfcfc4f": 82, "ef96501dd0": 150, "ef9a2e976b": 91, "efb24f950f": 180, "efce0c1868": 180, "efe5ac6901": 91, "efe828affa": 180, "efea4e0523": 144, "f0268aa627": 180, "f0483250c8": 180, "f04cf99ee6": 62, "f05b189097": 96, "f08928c6d3": 96, "f09d74856f": 150, "f0a7607d63": 180, "f0ad38da27": 71, "f0c34e1213": 92, "f0c7f86c29": 180, "f0dfa18ba7": 150, "f0eb3179f7": 180, "f119bab27d": 150, "f14409b6a3": 180, "f1489baff4": 86, "f14c18cf6a": 180, "f15c607b92": 180, "f1af214222": 97, "f1b77bd309": 180, "f1ba9e1a3e": 180, "f1d99239eb": 66, "f1dc710cf4": 180, "f1ec5c08fa": 97, "f22648fe12": 180, "f22d21f1f1": 144, "f233257395": 91, "f23e95dbe5": 96, "f2445b1572": 150, "f253b3486d": 144, "f277c7a6a4": 91, "f2ab2b84d6": 87, "f2b7c9b1f3": 150, "f2b83d5ce5": 180, "f2c276018f": 150, "f2cfd94d64": 150, "f2dd6e3add": 150, "f2e7653f16": 180, "f2f333ad06": 96, "f2f55d6713": 180, "f2fdb6abec": 180, "f305a56d9f": 46, "f3085d6570": 96, "f3325c3338": 180, "f3400f1204": 180, "f34497c932": 97, "f34a56525e": 91, "f36483c824": 96, "f3704d5663": 91, "f3734c4913": 150, "f38e5aa5b4": 86, "f3986fba44": 180, "f3a0ffc7d9": 180, "f3b24a7d28": 96, "f3e6c35ec3": 180, "f3fc0ea80b": 96, "f40a683fbe": 180, "f4207ca554": 180, "f4377499c2": 150, "f46184f393": 144, "f46c2d0a6d": 180, "f46c364dca": 180, "f46f7a0b63": 180, "f46fe141b0": 91, "f470b9aeb0": 180, "f47eb7437f": 96, "f48b535719": 92, "f49e4866ac": 180, "f4aa882cfd": 180, "f4daa3dbd5": 96, "f4dd51ac35": 91, "f507a1b9dc": 96, "f51c5ac84b": 86, "f52104164b": 180, "f54c67b9bb": 96, "f5966cadd2": 180, "f5bddf5598": 91, "f5d85cfd17": 92, "f5e2e7d6a0": 96, "f5f051e9b4": 180, "f5f8a93a76": 150, "f6283e8af5": 96, "f635e9568b": 180, "f6474735be": 144, "f659251be2": 150, "f66981af4e": 96, "f6708fa398": 87, "f697fe8e8f": 96, "f6adb12c42": 76, "f6c7906ca4": 180, "f6cd0a8016": 144, "f6d6f15ae7": 144, "f6e501892c": 96, "f6f59d986f": 180, "f6fe8c90a5": 180, "f714160545": 144, "f74c3888d7": 180, "f7782c430e": 150, "f7783ae5f2": 96, "f77ab47923": 97, "f788a98327": 91, "f7961ac1f0": 96, "f7a71e7574": 150, "f7a8521432": 180, "f7afbf4947": 150, "f7b7cd5f44": 81, "f7cf4b4a39": 92, "f7d49799ad": 150, "f7e0c9bb83": 180, "f7e5b84928": 96, "f7e6bd58be": 96, "f7f2a38ac6": 96, "f7f6cb2d6d": 150, "f83f19e796": 76, "f85796a921": 91, "f8603c26b2": 180, "f8819b42ec": 144, "f891f8eaa1": 96, "f89288d10c": 92, "f895ae8cc1": 180, "f8af30d4b6": 97, "f8b4ac12f1": 180, "f8c3fb2b01": 180, "f8c8de2764": 180, "f8db369b40": 92, "f8fcb6a78c": 180, "f94aafdeef": 180, "f95d217b70": 96, "f9681d5103": 92, "f9750192a4": 91, "f9823a32c2": 96, "f991ddb4c2": 96, "f99d535567": 96, "f9ae3d98b7": 144, "f9b6217959": 91, "f9bd1fabf5": 96, "f9c68eaa64": 180, "f9d3e04c4f": 92, "f9daf64494": 180, "f9e4cc5a0a": 96, "f9ea6b7f31": 96, "f9f3852526": 180, "fa04c615cf": 150, "fa08e00a56": 180, "fa4370d74d": 180, "fa67744af3": 180, "fa88d48a92": 150, "fa8b904cc9": 92, "fa9526bdf1": 150, "fa9b9d2426": 150, "fad633fbe1": 150, "faf5222dc3": 91, "faff0e15f1": 180, "fb08c64e8c": 180, "fb23455a7f": 150, "fb2e19fa6e": 180, "fb34dfbb77": 180, "fb47fcea1e": 96, "fb49738155": 180, "fb4cbc514b": 71, "fb4e6062f7": 180, "fb5ba7ad6e": 96, "fb63cd1236": 96, "fb81157a07": 180, "fb92abdaeb": 180, "fba22a6848": 92, "fbaca0c9df": 180, "fbc645f602": 96, "fbd77444cd": 96, "fbe53dc8e8": 96, "fbe541dd73": 97, "fbe8488798": 91, "fbfd25174f": 96, "fc28cb305e": 97, "fc33b1ffd6": 150, "fc6186f0bb": 180, "fc918e3a40": 150, "fc96cda9d8": 150, "fc9832eea4": 150, "fcb10d0f81": 180, "fcd20a2509": 180, "fcf637e3ab": 92, "fcfd81727f": 96, "fd31890379": 180, "fd33551c28": 144, "fd542da05e": 144, "fd6789b3fe": 180, "fd77828200": 180, "fd7af75f4d": 150, "fdb28d0fbb": 150, "fdb3d1fb1e": 82, "fdb8b04124": 96, "fdc6e3d581": 91, "fdfce7e6fc": 180, "fe0f76d41b": 180, "fe24b0677d": 180, "fe3c02699d": 144, "fe58b48235": 96, "fe6a5596b8": 91, "fe6c244f63": 66, "fe7afec086": 180, "fe985d510a": 144, "fe9db35d15": 96, "fea8ffcd36": 144, "feb1080388": 180, "fed208bfca": 180, "feda5ad1c2": 180, "feec95b386": 91, "ff15a5eff6": 144, "ff204daf4b": 96, "ff25f55852": 180, "ff2ada194f": 180, "ff2ce142e8": 96, "ff49d36d20": 180, "ff5a1ec4f3": 180, "ff66152b25": 180, "ff692fdc56": 180, "ff773b1a1e": 96, "ff97129478": 144, "ffb904207d": 180, "ffc43fc345": 150, "fffe5f8df6": 180} \ No newline at end of file diff --git a/inference_propainter.py b/inference_propainter.py new file mode 100644 index 0000000000000000000000000000000000000000..80d476d7467bac26c24e04b195fb0940b0f85442 --- /dev/null +++ b/inference_propainter.py @@ -0,0 +1,475 @@ +# -*- coding: utf-8 -*- +import os +import cv2 +import argparse +import imageio +import numpy as np +import scipy.ndimage +from PIL import Image +from tqdm import tqdm + +import torch +import torchvision + +from model.modules.flow_comp_raft import RAFT_bi +from model.recurrent_flow_completion import RecurrentFlowCompleteNet +from model.propainter import InpaintGenerator +from utils.download_util import load_file_from_url +from core.utils import to_tensors +from model.misc import get_device + +import warnings +warnings.filterwarnings("ignore") + +pretrain_model_url = 'https://github.com/sczhou/ProPainter/releases/download/v0.1.0/' + +def imwrite(img, file_path, params=None, auto_mkdir=True): + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + + +# resize frames +def resize_frames(frames, size=None): + if size is not None: + out_size = size + process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8) + frames = [f.resize(process_size) for f in frames] + else: + out_size = frames[0].size + process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8) + if not out_size == process_size: + frames = [f.resize(process_size) for f in frames] + + return frames, process_size, out_size + + +# read frames from video +def read_frame_from_videos(frame_root): + if frame_root.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path + video_name = os.path.basename(frame_root)[:-4] + vframes, aframes, info = torchvision.io.read_video(filename=frame_root, pts_unit='sec') # RGB + frames = list(vframes.numpy()) + frames = [Image.fromarray(f) for f in frames] + fps = info['video_fps'] + else: + video_name = os.path.basename(frame_root) + frames = [] + fr_lst = sorted(os.listdir(frame_root)) + for fr in fr_lst: + frame = cv2.imread(os.path.join(frame_root, fr)) + frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + frames.append(frame) + fps = None + size = frames[0].size + + return frames, fps, size, video_name + + +def binary_mask(mask, th=0.1): + mask[mask>th] = 1 + mask[mask<=th] = 0 + return mask + + +# read frame-wise masks +def read_mask(mpath, length, size, flow_mask_dilates=8, mask_dilates=5): + masks_img = [] + masks_dilated = [] + flow_masks = [] + + if mpath.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path + masks_img = [Image.open(mpath)] + else: + mnames = sorted(os.listdir(mpath)) + for mp in mnames: + masks_img.append(Image.open(os.path.join(mpath, mp))) + + for mask_img in masks_img: + if size is not None: + mask_img = mask_img.resize(size, Image.NEAREST) + mask_img = np.array(mask_img.convert('L')) + + # Dilate 8 pixel so that all known pixel is trustworthy + if flow_mask_dilates > 0: + flow_mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=flow_mask_dilates).astype(np.uint8) + else: + flow_mask_img = binary_mask(mask_img).astype(np.uint8) + # Close the small holes inside the foreground objects + # flow_mask_img = cv2.morphologyEx(flow_mask_img, cv2.MORPH_CLOSE, np.ones((21, 21),np.uint8)).astype(bool) + # flow_mask_img = scipy.ndimage.binary_fill_holes(flow_mask_img).astype(np.uint8) + flow_masks.append(Image.fromarray(flow_mask_img * 255)) + + if mask_dilates > 0: + mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=mask_dilates).astype(np.uint8) + else: + mask_img = binary_mask(mask_img).astype(np.uint8) + masks_dilated.append(Image.fromarray(mask_img * 255)) + + if len(masks_img) == 1: + flow_masks = flow_masks * length + masks_dilated = masks_dilated * length + + return flow_masks, masks_dilated + + +def extrapolation(video_ori, scale): + """Prepares the data for video outpainting. + """ + nFrame = len(video_ori) + imgW, imgH = video_ori[0].size + + # Defines new FOV. + imgH_extr = int(scale[0] * imgH) + imgW_extr = int(scale[1] * imgW) + imgH_extr = imgH_extr - imgH_extr % 8 + imgW_extr = imgW_extr - imgW_extr % 8 + H_start = int((imgH_extr - imgH) / 2) + W_start = int((imgW_extr - imgW) / 2) + + # Extrapolates the FOV for video. + frames = [] + for v in video_ori: + frame = np.zeros(((imgH_extr, imgW_extr, 3)), dtype=np.uint8) + frame[H_start: H_start + imgH, W_start: W_start + imgW, :] = v + frames.append(Image.fromarray(frame)) + + # Generates the mask for missing region. + masks_dilated = [] + flow_masks = [] + + dilate_h = 4 if H_start > 10 else 0 + dilate_w = 4 if W_start > 10 else 0 + mask = np.ones(((imgH_extr, imgW_extr)), dtype=np.uint8) + + mask[H_start+dilate_h: H_start+imgH-dilate_h, + W_start+dilate_w: W_start+imgW-dilate_w] = 0 + flow_masks.append(Image.fromarray(mask * 255)) + + mask[H_start: H_start+imgH, W_start: W_start+imgW] = 0 + masks_dilated.append(Image.fromarray(mask * 255)) + + flow_masks = flow_masks * nFrame + masks_dilated = masks_dilated * nFrame + + return frames, flow_masks, masks_dilated, (imgW_extr, imgH_extr) + + +def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=-1): + ref_index = [] + if ref_num == -1: + for i in range(0, length, ref_stride): + if i not in neighbor_ids: + ref_index.append(i) + else: + start_idx = max(0, mid_neighbor_id - ref_stride * (ref_num // 2)) + end_idx = min(length, mid_neighbor_id + ref_stride * (ref_num // 2)) + for i in range(start_idx, end_idx, ref_stride): + if i not in neighbor_ids: + if len(ref_index) > ref_num: + break + ref_index.append(i) + return ref_index + + + +if __name__ == '__main__': + # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = get_device() + + parser = argparse.ArgumentParser() + parser.add_argument( + '-i', '--video', type=str, default='inputs/object_removal/bmx-trees', help='Path of the input video or image folder.') + parser.add_argument( + '-m', '--mask', type=str, default='inputs/object_removal/bmx-trees_mask', help='Path of the mask(s) or mask folder.') + parser.add_argument( + '-o', '--output', type=str, default='results', help='Output folder. Default: results') + parser.add_argument( + "--resize_ratio", type=float, default=1.0, help='Resize scale for processing video.') + parser.add_argument( + '--height', type=int, default=-1, help='Height of the processing video.') + parser.add_argument( + '--width', type=int, default=-1, help='Width of the processing video.') + parser.add_argument( + '--mask_dilation', type=int, default=4, help='Mask dilation for video and flow masking.') + parser.add_argument( + "--ref_stride", type=int, default=10, help='Stride of global reference frames.') + parser.add_argument( + "--neighbor_length", type=int, default=10, help='Length of local neighboring frames.') + parser.add_argument( + "--subvideo_length", type=int, default=80, help='Length of sub-video for long video inference.') + parser.add_argument( + "--raft_iter", type=int, default=20, help='Iterations for RAFT inference.') + parser.add_argument( + '--mode', default='video_inpainting', choices=['video_inpainting', 'video_outpainting'], help="Modes: video_inpainting / video_outpainting") + parser.add_argument( + '--scale_h', type=float, default=1.0, help='Outpainting scale of height for video_outpainting mode.') + parser.add_argument( + '--scale_w', type=float, default=1.2, help='Outpainting scale of width for video_outpainting mode.') + parser.add_argument( + '--save_fps', type=int, default=24, help='Frame per second. Default: 24') + parser.add_argument( + '--save_frames', action='store_true', help='Save output frames. Default: False') + parser.add_argument( + '--fp16', action='store_true', help='Use fp16 (half precision) during inference. Default: fp32 (single precision).') + + args = parser.parse_args() + + # Use fp16 precision during inference to reduce running memory cost + use_half = True if args.fp16 else False + + + frames, fps, size, video_name = read_frame_from_videos(args.video) + if not args.width == -1 and not args.height == -1: + size = (args.width, args.height) + if not args.resize_ratio == 1.0: + size = (int(args.resize_ratio * size[0]), int(args.resize_ratio * size[1])) + + frames, size, out_size = resize_frames(frames, size) + + fps = args.save_fps if fps is None else fps + save_root = os.path.join(args.output, video_name) + if not os.path.exists(save_root): + os.makedirs(save_root, exist_ok=True) + + if args.mode == 'video_inpainting': + frames_len = len(frames) + flow_masks, masks_dilated = read_mask(args.mask, frames_len, size, + flow_mask_dilates=args.mask_dilation, + mask_dilates=args.mask_dilation) + w, h = size + elif args.mode == 'video_outpainting': + assert args.scale_h is not None and args.scale_w is not None, 'Please provide a outpainting scale (s_h, s_w).' + frames, flow_masks, masks_dilated, size = extrapolation(frames, (args.scale_h, args.scale_w)) + w, h = size + else: + raise NotImplementedError + + # for saving the masked frames or video + masked_frame_for_save = [] + for i in range(len(frames)): + mask_ = np.expand_dims(np.array(masks_dilated[i]),2).repeat(3, axis=2)/255. + img = np.array(frames[i]) + green = np.zeros([h, w, 3]) + green[:,:,1] = 255 + alpha = 0.6 + # alpha = 1.0 + fuse_img = (1-alpha)*img + alpha*green + fuse_img = mask_ * fuse_img + (1-mask_)*img + masked_frame_for_save.append(fuse_img.astype(np.uint8)) + + frames_inp = [np.array(f).astype(np.uint8) for f in frames] + frames = to_tensors()(frames).unsqueeze(0) * 2 - 1 + flow_masks = to_tensors()(flow_masks).unsqueeze(0) + masks_dilated = to_tensors()(masks_dilated).unsqueeze(0) + frames, flow_masks, masks_dilated = frames.to(device), flow_masks.to(device), masks_dilated.to(device) + + + ############################################## + # set up RAFT and flow competition model + ############################################## + ckpt_path = load_file_from_url(url=os.path.join(pretrain_model_url, 'raft-things.pth'), + model_dir='weights', progress=True, file_name=None) + fix_raft = RAFT_bi(ckpt_path, device) + + ckpt_path = load_file_from_url(url=os.path.join(pretrain_model_url, 'recurrent_flow_completion.pth'), + model_dir='weights', progress=True, file_name=None) + fix_flow_complete = RecurrentFlowCompleteNet(ckpt_path) + for p in fix_flow_complete.parameters(): + p.requires_grad = False + fix_flow_complete.to(device) + fix_flow_complete.eval() + + + ############################################## + # set up ProPainter model + ############################################## + ckpt_path = load_file_from_url(url=os.path.join(pretrain_model_url, 'ProPainter.pth'), + model_dir='weights', progress=True, file_name=None) + model = InpaintGenerator(model_path=ckpt_path).to(device) + model.eval() + + + ############################################## + # ProPainter inference + ############################################## + video_length = frames.size(1) + print(f'\nProcessing: {video_name} [{video_length} frames]...') + with torch.no_grad(): + # ---- compute flow ---- + if frames.size(-1) <= 640: + short_clip_len = 12 + elif frames.size(-1) <= 720: + short_clip_len = 8 + elif frames.size(-1) <= 1280: + short_clip_len = 4 + else: + short_clip_len = 2 + + # use fp32 for RAFT + if frames.size(1) > short_clip_len: + gt_flows_f_list, gt_flows_b_list = [], [] + for f in range(0, video_length, short_clip_len): + end_f = min(video_length, f + short_clip_len) + if f == 0: + flows_f, flows_b = fix_raft(frames[:,f:end_f], iters=args.raft_iter) + else: + flows_f, flows_b = fix_raft(frames[:,f-1:end_f], iters=args.raft_iter) + + gt_flows_f_list.append(flows_f) + gt_flows_b_list.append(flows_b) + torch.cuda.empty_cache() + + gt_flows_f = torch.cat(gt_flows_f_list, dim=1) + gt_flows_b = torch.cat(gt_flows_b_list, dim=1) + gt_flows_bi = (gt_flows_f, gt_flows_b) + else: + gt_flows_bi = fix_raft(frames, iters=args.raft_iter) + torch.cuda.empty_cache() + + + if use_half: + frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half() + gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half()) + fix_flow_complete = fix_flow_complete.half() + model = model.half() + + + # ---- complete flow ---- + flow_length = gt_flows_bi[0].size(1) + if flow_length > args.subvideo_length: + pred_flows_f, pred_flows_b = [], [] + pad_len = 5 + for f in range(0, flow_length, args.subvideo_length): + s_f = max(0, f - pad_len) + e_f = min(flow_length, f + args.subvideo_length + pad_len) + pad_len_s = max(0, f) - s_f + pad_len_e = e_f - min(flow_length, f + args.subvideo_length) + pred_flows_bi_sub, _ = fix_flow_complete.forward_bidirect_flow( + (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]), + flow_masks[:, s_f:e_f+1]) + pred_flows_bi_sub = fix_flow_complete.combine_flow( + (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]), + pred_flows_bi_sub, + flow_masks[:, s_f:e_f+1]) + + pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e]) + pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e]) + torch.cuda.empty_cache() + + pred_flows_f = torch.cat(pred_flows_f, dim=1) + pred_flows_b = torch.cat(pred_flows_b, dim=1) + pred_flows_bi = (pred_flows_f, pred_flows_b) + else: + pred_flows_bi, _ = fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks) + pred_flows_bi = fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks) + torch.cuda.empty_cache() + + + # ---- image propagation ---- + masked_frames = frames * (1 - masks_dilated) + subvideo_length_img_prop = min(100, args.subvideo_length) # ensure a minimum of 100 frames for image propagation + if video_length > subvideo_length_img_prop: + updated_frames, updated_masks = [], [] + pad_len = 10 + for f in range(0, video_length, subvideo_length_img_prop): + s_f = max(0, f - pad_len) + e_f = min(video_length, f + subvideo_length_img_prop + pad_len) + pad_len_s = max(0, f) - s_f + pad_len_e = e_f - min(video_length, f + subvideo_length_img_prop) + + b, t, _, _, _ = masks_dilated[:, s_f:e_f].size() + pred_flows_bi_sub = (pred_flows_bi[0][:, s_f:e_f-1], pred_flows_bi[1][:, s_f:e_f-1]) + prop_imgs_sub, updated_local_masks_sub = model.img_propagation(masked_frames[:, s_f:e_f], + pred_flows_bi_sub, + masks_dilated[:, s_f:e_f], + 'nearest') + updated_frames_sub = frames[:, s_f:e_f] * (1 - masks_dilated[:, s_f:e_f]) + \ + prop_imgs_sub.view(b, t, 3, h, w) * masks_dilated[:, s_f:e_f] + updated_masks_sub = updated_local_masks_sub.view(b, t, 1, h, w) + + updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e]) + updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e]) + torch.cuda.empty_cache() + + updated_frames = torch.cat(updated_frames, dim=1) + updated_masks = torch.cat(updated_masks, dim=1) + else: + b, t, _, _, _ = masks_dilated.size() + prop_imgs, updated_local_masks = model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, 'nearest') + updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated + updated_masks = updated_local_masks.view(b, t, 1, h, w) + torch.cuda.empty_cache() + + + ori_frames = frames_inp + comp_frames = [None] * video_length + + neighbor_stride = args.neighbor_length // 2 + if video_length > args.subvideo_length: + ref_num = args.subvideo_length // args.ref_stride + else: + ref_num = -1 + + # ---- feature propagation + transformer ---- + for f in tqdm(range(0, video_length, neighbor_stride)): + neighbor_ids = [ + i for i in range(max(0, f - neighbor_stride), + min(video_length, f + neighbor_stride + 1)) + ] + ref_ids = get_ref_index(f, neighbor_ids, video_length, args.ref_stride, ref_num) + selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :] + selected_masks = masks_dilated[:, neighbor_ids + ref_ids, :, :, :] + selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :] + selected_pred_flows_bi = (pred_flows_bi[0][:, neighbor_ids[:-1], :, :, :], pred_flows_bi[1][:, neighbor_ids[:-1], :, :, :]) + + with torch.no_grad(): + # 1.0 indicates mask + l_t = len(neighbor_ids) + + # pred_img = selected_imgs # results of image propagation + pred_img = model(selected_imgs, selected_pred_flows_bi, selected_masks, selected_update_masks, l_t) + + pred_img = pred_img.view(-1, 3, h, w) + + pred_img = (pred_img + 1) / 2 + pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255 + binary_masks = masks_dilated[0, neighbor_ids, :, :, :].cpu().permute( + 0, 2, 3, 1).numpy().astype(np.uint8) + for i in range(len(neighbor_ids)): + idx = neighbor_ids[i] + img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \ + + ori_frames[idx] * (1 - binary_masks[i]) + if comp_frames[idx] is None: + comp_frames[idx] = img + else: + comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5 + + comp_frames[idx] = comp_frames[idx].astype(np.uint8) + + torch.cuda.empty_cache() + + # save each frame + if args.save_frames: + for idx in range(video_length): + f = comp_frames[idx] + f = cv2.resize(f, out_size, interpolation = cv2.INTER_CUBIC) + f = cv2.cvtColor(f, cv2.COLOR_BGR2RGB) + img_save_root = os.path.join(save_root, 'frames', str(idx).zfill(4)+'.png') + imwrite(f, img_save_root) + + + # if args.mode == 'video_outpainting': + # comp_frames = [i[10:-10,10:-10] for i in comp_frames] + # masked_frame_for_save = [i[10:-10,10:-10] for i in masked_frame_for_save] + + # save videos frame + masked_frame_for_save = [cv2.resize(f, out_size) for f in masked_frame_for_save] + comp_frames = [cv2.resize(f, out_size) for f in comp_frames] + imageio.mimwrite(os.path.join(save_root, 'masked_in.mp4'), masked_frame_for_save, fps=fps, quality=7) + imageio.mimwrite(os.path.join(save_root, 'inpaint_out.mp4'), comp_frames, fps=fps, quality=7) + + print(f'\nAll results are saved in {save_root}') + + torch.cuda.empty_cache() \ No newline at end of file diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/model/__init__.py @@ -0,0 +1 @@ + diff --git a/model/canny/canny_filter.py b/model/canny/canny_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..3d16195c9355b506e22a2ba527006adb9c541a7c --- /dev/null +++ b/model/canny/canny_filter.py @@ -0,0 +1,256 @@ +import math +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .gaussian import gaussian_blur2d +from .kernels import get_canny_nms_kernel, get_hysteresis_kernel +from .sobel import spatial_gradient + +def rgb_to_grayscale(image, rgb_weights = None): + if len(image.shape) < 3 or image.shape[-3] != 3: + raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}") + + if rgb_weights is None: + # 8 bit images + if image.dtype == torch.uint8: + rgb_weights = torch.tensor([76, 150, 29], device=image.device, dtype=torch.uint8) + # floating point images + elif image.dtype in (torch.float16, torch.float32, torch.float64): + rgb_weights = torch.tensor([0.299, 0.587, 0.114], device=image.device, dtype=image.dtype) + else: + raise TypeError(f"Unknown data type: {image.dtype}") + else: + # is tensor that we make sure is in the same device/dtype + rgb_weights = rgb_weights.to(image) + + # unpack the color image channels with RGB order + r = image[..., 0:1, :, :] + g = image[..., 1:2, :, :] + b = image[..., 2:3, :, :] + + w_r, w_g, w_b = rgb_weights.unbind() + return w_r * r + w_g * g + w_b * b + + +def canny( + input: torch.Tensor, + low_threshold: float = 0.1, + high_threshold: float = 0.2, + kernel_size: Tuple[int, int] = (5, 5), + sigma: Tuple[float, float] = (1, 1), + hysteresis: bool = True, + eps: float = 1e-6, +) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Find edges of the input image and filters them using the Canny algorithm. + + .. image:: _static/img/canny.png + + Args: + input: input image tensor with shape :math:`(B,C,H,W)`. + low_threshold: lower threshold for the hysteresis procedure. + high_threshold: upper threshold for the hysteresis procedure. + kernel_size: the size of the kernel for the gaussian blur. + sigma: the standard deviation of the kernel for the gaussian blur. + hysteresis: if True, applies the hysteresis edge tracking. + Otherwise, the edges are divided between weak (0.5) and strong (1) edges. + eps: regularization number to avoid NaN during backprop. + + Returns: + - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`. + - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`. + + .. note:: + See a working example `here `__. + + Example: + >>> input = torch.rand(5, 3, 4, 4) + >>> magnitude, edges = canny(input) # 5x3x4x4 + >>> magnitude.shape + torch.Size([5, 1, 4, 4]) + >>> edges.shape + torch.Size([5, 1, 4, 4]) + """ + if not isinstance(input, torch.Tensor): + raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") + + if not len(input.shape) == 4: + raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}") + + if low_threshold > high_threshold: + raise ValueError( + "Invalid input thresholds. low_threshold should be smaller than the high_threshold. Got: {}>{}".format( + low_threshold, high_threshold + ) + ) + + if low_threshold < 0 and low_threshold > 1: + raise ValueError(f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}") + + if high_threshold < 0 and high_threshold > 1: + raise ValueError(f"Invalid input threshold. high_threshold should be in range (0,1). Got: {high_threshold}") + + device: torch.device = input.device + dtype: torch.dtype = input.dtype + + # To Grayscale + if input.shape[1] == 3: + input = rgb_to_grayscale(input) + + # Gaussian filter + blurred: torch.Tensor = gaussian_blur2d(input, kernel_size, sigma) + + # Compute the gradients + gradients: torch.Tensor = spatial_gradient(blurred, normalized=False) + + # Unpack the edges + gx: torch.Tensor = gradients[:, :, 0] + gy: torch.Tensor = gradients[:, :, 1] + + # Compute gradient magnitude and angle + magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps) + angle: torch.Tensor = torch.atan2(gy, gx) + + # Radians to Degrees + angle = 180.0 * angle / math.pi + + # Round angle to the nearest 45 degree + angle = torch.round(angle / 45) * 45 + + # Non-maximal suppression + nms_kernels: torch.Tensor = get_canny_nms_kernel(device, dtype) + nms_magnitude: torch.Tensor = F.conv2d(magnitude, nms_kernels, padding=nms_kernels.shape[-1] // 2) + + # Get the indices for both directions + positive_idx: torch.Tensor = (angle / 45) % 8 + positive_idx = positive_idx.long() + + negative_idx: torch.Tensor = ((angle / 45) + 4) % 8 + negative_idx = negative_idx.long() + + # Apply the non-maximum suppression to the different directions + channel_select_filtered_positive: torch.Tensor = torch.gather(nms_magnitude, 1, positive_idx) + channel_select_filtered_negative: torch.Tensor = torch.gather(nms_magnitude, 1, negative_idx) + + channel_select_filtered: torch.Tensor = torch.stack( + [channel_select_filtered_positive, channel_select_filtered_negative], 1 + ) + + is_max: torch.Tensor = channel_select_filtered.min(dim=1)[0] > 0.0 + + magnitude = magnitude * is_max + + # Threshold + edges: torch.Tensor = F.threshold(magnitude, low_threshold, 0.0) + + low: torch.Tensor = magnitude > low_threshold + high: torch.Tensor = magnitude > high_threshold + + edges = low * 0.5 + high * 0.5 + edges = edges.to(dtype) + + # Hysteresis + if hysteresis: + edges_old: torch.Tensor = -torch.ones(edges.shape, device=edges.device, dtype=dtype) + hysteresis_kernels: torch.Tensor = get_hysteresis_kernel(device, dtype) + + while ((edges_old - edges).abs() != 0).any(): + weak: torch.Tensor = (edges == 0.5).float() + strong: torch.Tensor = (edges == 1).float() + + hysteresis_magnitude: torch.Tensor = F.conv2d( + edges, hysteresis_kernels, padding=hysteresis_kernels.shape[-1] // 2 + ) + hysteresis_magnitude = (hysteresis_magnitude == 1).any(1, keepdim=True).to(dtype) + hysteresis_magnitude = hysteresis_magnitude * weak + strong + + edges_old = edges.clone() + edges = hysteresis_magnitude + (hysteresis_magnitude == 0) * weak * 0.5 + + edges = hysteresis_magnitude + + return magnitude, edges + + +class Canny(nn.Module): + r"""Module that finds edges of the input image and filters them using the Canny algorithm. + + Args: + input: input image tensor with shape :math:`(B,C,H,W)`. + low_threshold: lower threshold for the hysteresis procedure. + high_threshold: upper threshold for the hysteresis procedure. + kernel_size: the size of the kernel for the gaussian blur. + sigma: the standard deviation of the kernel for the gaussian blur. + hysteresis: if True, applies the hysteresis edge tracking. + Otherwise, the edges are divided between weak (0.5) and strong (1) edges. + eps: regularization number to avoid NaN during backprop. + + Returns: + - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`. + - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`. + + Example: + >>> input = torch.rand(5, 3, 4, 4) + >>> magnitude, edges = Canny()(input) # 5x3x4x4 + >>> magnitude.shape + torch.Size([5, 1, 4, 4]) + >>> edges.shape + torch.Size([5, 1, 4, 4]) + """ + + def __init__( + self, + low_threshold: float = 0.1, + high_threshold: float = 0.2, + kernel_size: Tuple[int, int] = (5, 5), + sigma: Tuple[float, float] = (1, 1), + hysteresis: bool = True, + eps: float = 1e-6, + ) -> None: + super().__init__() + + if low_threshold > high_threshold: + raise ValueError( + "Invalid input thresholds. low_threshold should be\ + smaller than the high_threshold. Got: {}>{}".format( + low_threshold, high_threshold + ) + ) + + if low_threshold < 0 or low_threshold > 1: + raise ValueError(f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}") + + if high_threshold < 0 or high_threshold > 1: + raise ValueError(f"Invalid input threshold. high_threshold should be in range (0,1). Got: {high_threshold}") + + # Gaussian blur parameters + self.kernel_size = kernel_size + self.sigma = sigma + + # Double threshold + self.low_threshold = low_threshold + self.high_threshold = high_threshold + + # Hysteresis + self.hysteresis = hysteresis + + self.eps: float = eps + + def __repr__(self) -> str: + return ''.join( + ( + f'{type(self).__name__}(', + ', '.join( + f'{name}={getattr(self, name)}' for name in sorted(self.__dict__) if not name.startswith('_') + ), + ')', + ) + ) + + def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return canny( + input, self.low_threshold, self.high_threshold, self.kernel_size, self.sigma, self.hysteresis, self.eps + ) \ No newline at end of file diff --git a/model/canny/filter.py b/model/canny/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..e39d44d67a067c56f994dc9a189f3cf98663bf68 --- /dev/null +++ b/model/canny/filter.py @@ -0,0 +1,288 @@ +from typing import List + +import torch +import torch.nn.functional as F + +from .kernels import normalize_kernel2d + + +def _compute_padding(kernel_size: List[int]) -> List[int]: + """Compute padding tuple.""" + # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) + # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad + if len(kernel_size) < 2: + raise AssertionError(kernel_size) + computed = [k - 1 for k in kernel_size] + + # for even kernels we need to do asymmetric padding :( + out_padding = 2 * len(kernel_size) * [0] + + for i in range(len(kernel_size)): + computed_tmp = computed[-(i + 1)] + + pad_front = computed_tmp // 2 + pad_rear = computed_tmp - pad_front + + out_padding[2 * i + 0] = pad_front + out_padding[2 * i + 1] = pad_rear + + return out_padding + + +def filter2d( + input: torch.Tensor, + kernel: torch.Tensor, + border_type: str = 'reflect', + normalized: bool = False, + padding: str = 'same', +) -> torch.Tensor: + r"""Convolve a tensor with a 2d kernel. + + The function applies a given kernel to a tensor. The kernel is applied + independently at each depth channel of the tensor. Before applying the + kernel, the function applies padding according to the specified mode so + that the output remains in the same shape. + + Args: + input: the input tensor with shape of + :math:`(B, C, H, W)`. + kernel: the kernel to be convolved with the input + tensor. The kernel shape must be :math:`(1, kH, kW)` or :math:`(B, kH, kW)`. + border_type: the padding mode to be applied before convolving. + The expected modes are: ``'constant'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. + normalized: If True, kernel will be L1 normalized. + padding: This defines the type of padding. + 2 modes available ``'same'`` or ``'valid'``. + + Return: + torch.Tensor: the convolved tensor of same size and numbers of channels + as the input with shape :math:`(B, C, H, W)`. + + Example: + >>> input = torch.tensor([[[ + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 5., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.],]]]) + >>> kernel = torch.ones(1, 3, 3) + >>> filter2d(input, kernel, padding='same') + tensor([[[[0., 0., 0., 0., 0.], + [0., 5., 5., 5., 0.], + [0., 5., 5., 5., 0.], + [0., 5., 5., 5., 0.], + [0., 0., 0., 0., 0.]]]]) + """ + if not isinstance(input, torch.Tensor): + raise TypeError(f"Input input is not torch.Tensor. Got {type(input)}") + + if not isinstance(kernel, torch.Tensor): + raise TypeError(f"Input kernel is not torch.Tensor. Got {type(kernel)}") + + if not isinstance(border_type, str): + raise TypeError(f"Input border_type is not string. Got {type(border_type)}") + + if border_type not in ['constant', 'reflect', 'replicate', 'circular']: + raise ValueError( + f"Invalid border type, we expect 'constant', \ + 'reflect', 'replicate', 'circular'. Got:{border_type}" + ) + + if not isinstance(padding, str): + raise TypeError(f"Input padding is not string. Got {type(padding)}") + + if padding not in ['valid', 'same']: + raise ValueError(f"Invalid padding mode, we expect 'valid' or 'same'. Got: {padding}") + + if not len(input.shape) == 4: + raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}") + + if (not len(kernel.shape) == 3) and not ((kernel.shape[0] == 0) or (kernel.shape[0] == input.shape[0])): + raise ValueError(f"Invalid kernel shape, we expect 1xHxW or BxHxW. Got: {kernel.shape}") + + # prepare kernel + b, c, h, w = input.shape + tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input) + + if normalized: + tmp_kernel = normalize_kernel2d(tmp_kernel) + + tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) + + height, width = tmp_kernel.shape[-2:] + + # pad the input tensor + if padding == 'same': + padding_shape: List[int] = _compute_padding([height, width]) + input = F.pad(input, padding_shape, mode=border_type) + + # kernel and input tensor reshape to align element-wise or batch-wise params + tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) + input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) + + # convolve the tensor with the kernel. + output = F.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) + + if padding == 'same': + out = output.view(b, c, h, w) + else: + out = output.view(b, c, h - height + 1, w - width + 1) + + return out + + +def filter2d_separable( + input: torch.Tensor, + kernel_x: torch.Tensor, + kernel_y: torch.Tensor, + border_type: str = 'reflect', + normalized: bool = False, + padding: str = 'same', +) -> torch.Tensor: + r"""Convolve a tensor with two 1d kernels, in x and y directions. + + The function applies a given kernel to a tensor. The kernel is applied + independently at each depth channel of the tensor. Before applying the + kernel, the function applies padding according to the specified mode so + that the output remains in the same shape. + + Args: + input: the input tensor with shape of + :math:`(B, C, H, W)`. + kernel_x: the kernel to be convolved with the input + tensor. The kernel shape must be :math:`(1, kW)` or :math:`(B, kW)`. + kernel_y: the kernel to be convolved with the input + tensor. The kernel shape must be :math:`(1, kH)` or :math:`(B, kH)`. + border_type: the padding mode to be applied before convolving. + The expected modes are: ``'constant'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. + normalized: If True, kernel will be L1 normalized. + padding: This defines the type of padding. + 2 modes available ``'same'`` or ``'valid'``. + + Return: + torch.Tensor: the convolved tensor of same size and numbers of channels + as the input with shape :math:`(B, C, H, W)`. + + Example: + >>> input = torch.tensor([[[ + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 5., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.],]]]) + >>> kernel = torch.ones(1, 3) + + >>> filter2d_separable(input, kernel, kernel, padding='same') + tensor([[[[0., 0., 0., 0., 0.], + [0., 5., 5., 5., 0.], + [0., 5., 5., 5., 0.], + [0., 5., 5., 5., 0.], + [0., 0., 0., 0., 0.]]]]) + """ + out_x = filter2d(input, kernel_x.unsqueeze(0), border_type, normalized, padding) + out = filter2d(out_x, kernel_y.unsqueeze(-1), border_type, normalized, padding) + return out + + +def filter3d( + input: torch.Tensor, kernel: torch.Tensor, border_type: str = 'replicate', normalized: bool = False +) -> torch.Tensor: + r"""Convolve a tensor with a 3d kernel. + + The function applies a given kernel to a tensor. The kernel is applied + independently at each depth channel of the tensor. Before applying the + kernel, the function applies padding according to the specified mode so + that the output remains in the same shape. + + Args: + input: the input tensor with shape of + :math:`(B, C, D, H, W)`. + kernel: the kernel to be convolved with the input + tensor. The kernel shape must be :math:`(1, kD, kH, kW)` or :math:`(B, kD, kH, kW)`. + border_type: the padding mode to be applied before convolving. + The expected modes are: ``'constant'``, + ``'replicate'`` or ``'circular'``. + normalized: If True, kernel will be L1 normalized. + + Return: + the convolved tensor of same size and numbers of channels + as the input with shape :math:`(B, C, D, H, W)`. + + Example: + >>> input = torch.tensor([[[ + ... [[0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.]], + ... [[0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 5., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.]], + ... [[0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.]] + ... ]]]) + >>> kernel = torch.ones(1, 3, 3, 3) + >>> filter3d(input, kernel) + tensor([[[[[0., 0., 0., 0., 0.], + [0., 5., 5., 5., 0.], + [0., 5., 5., 5., 0.], + [0., 5., 5., 5., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 0., 0., 0.], + [0., 5., 5., 5., 0.], + [0., 5., 5., 5., 0.], + [0., 5., 5., 5., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 0., 0., 0.], + [0., 5., 5., 5., 0.], + [0., 5., 5., 5., 0.], + [0., 5., 5., 5., 0.], + [0., 0., 0., 0., 0.]]]]]) + """ + if not isinstance(input, torch.Tensor): + raise TypeError(f"Input border_type is not torch.Tensor. Got {type(input)}") + + if not isinstance(kernel, torch.Tensor): + raise TypeError(f"Input border_type is not torch.Tensor. Got {type(kernel)}") + + if not isinstance(border_type, str): + raise TypeError(f"Input border_type is not string. Got {type(kernel)}") + + if not len(input.shape) == 5: + raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}") + + if not len(kernel.shape) == 4 and kernel.shape[0] != 1: + raise ValueError(f"Invalid kernel shape, we expect 1xDxHxW. Got: {kernel.shape}") + + # prepare kernel + b, c, d, h, w = input.shape + tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input) + + if normalized: + bk, dk, hk, wk = kernel.shape + tmp_kernel = normalize_kernel2d(tmp_kernel.view(bk, dk, hk * wk)).view_as(tmp_kernel) + + tmp_kernel = tmp_kernel.expand(-1, c, -1, -1, -1) + + # pad the input tensor + depth, height, width = tmp_kernel.shape[-3:] + padding_shape: List[int] = _compute_padding([depth, height, width]) + input_pad: torch.Tensor = F.pad(input, padding_shape, mode=border_type) + + # kernel and input tensor reshape to align element-wise or batch-wise params + tmp_kernel = tmp_kernel.reshape(-1, 1, depth, height, width) + input_pad = input_pad.view(-1, tmp_kernel.size(0), input_pad.size(-3), input_pad.size(-2), input_pad.size(-1)) + + # convolve the tensor with the kernel. + output = F.conv3d(input_pad, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) + + return output.view(b, c, d, h, w) \ No newline at end of file diff --git a/model/canny/gaussian.py b/model/canny/gaussian.py new file mode 100644 index 0000000000000000000000000000000000000000..182f05c5d7d297d97b3dd008287e053493350bb6 --- /dev/null +++ b/model/canny/gaussian.py @@ -0,0 +1,116 @@ +from typing import Tuple + +import torch +import torch.nn as nn + +from .filter import filter2d, filter2d_separable +from .kernels import get_gaussian_kernel1d, get_gaussian_kernel2d + + +def gaussian_blur2d( + input: torch.Tensor, + kernel_size: Tuple[int, int], + sigma: Tuple[float, float], + border_type: str = 'reflect', + separable: bool = True, +) -> torch.Tensor: + r"""Create an operator that blurs a tensor using a Gaussian filter. + + .. image:: _static/img/gaussian_blur2d.png + + The operator smooths the given tensor with a gaussian kernel by convolving + it to each channel. It supports batched operation. + + Arguments: + input: the input tensor with shape :math:`(B,C,H,W)`. + kernel_size: the size of the kernel. + sigma: the standard deviation of the kernel. + border_type: the padding mode to be applied before convolving. + The expected modes are: ``'constant'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'reflect'``. + separable: run as composition of two 1d-convolutions. + + Returns: + the blurred tensor with shape :math:`(B, C, H, W)`. + + .. note:: + See a working example `here `__. + + Examples: + >>> input = torch.rand(2, 4, 5, 5) + >>> output = gaussian_blur2d(input, (3, 3), (1.5, 1.5)) + >>> output.shape + torch.Size([2, 4, 5, 5]) + """ + if separable: + kernel_x: torch.Tensor = get_gaussian_kernel1d(kernel_size[1], sigma[1]) + kernel_y: torch.Tensor = get_gaussian_kernel1d(kernel_size[0], sigma[0]) + out = filter2d_separable(input, kernel_x[None], kernel_y[None], border_type) + else: + kernel: torch.Tensor = get_gaussian_kernel2d(kernel_size, sigma) + out = filter2d(input, kernel[None], border_type) + return out + + +class GaussianBlur2d(nn.Module): + r"""Create an operator that blurs a tensor using a Gaussian filter. + + The operator smooths the given tensor with a gaussian kernel by convolving + it to each channel. It supports batched operation. + + Arguments: + kernel_size: the size of the kernel. + sigma: the standard deviation of the kernel. + border_type: the padding mode to be applied before convolving. + The expected modes are: ``'constant'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'reflect'``. + separable: run as composition of two 1d-convolutions. + + Returns: + the blurred tensor. + + Shape: + - Input: :math:`(B, C, H, W)` + - Output: :math:`(B, C, H, W)` + + Examples:: + + >>> input = torch.rand(2, 4, 5, 5) + >>> gauss = GaussianBlur2d((3, 3), (1.5, 1.5)) + >>> output = gauss(input) # 2x4x5x5 + >>> output.shape + torch.Size([2, 4, 5, 5]) + """ + + def __init__( + self, + kernel_size: Tuple[int, int], + sigma: Tuple[float, float], + border_type: str = 'reflect', + separable: bool = True, + ) -> None: + super().__init__() + self.kernel_size: Tuple[int, int] = kernel_size + self.sigma: Tuple[float, float] = sigma + self.border_type = border_type + self.separable = separable + + def __repr__(self) -> str: + return ( + self.__class__.__name__ + + '(kernel_size=' + + str(self.kernel_size) + + ', ' + + 'sigma=' + + str(self.sigma) + + ', ' + + 'border_type=' + + self.border_type + + 'separable=' + + str(self.separable) + + ')' + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return gaussian_blur2d(input, self.kernel_size, self.sigma, self.border_type, self.separable) \ No newline at end of file diff --git a/model/canny/kernels.py b/model/canny/kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..ae1ee251b8363ba76c7b63c6925a1776c50b7f32 --- /dev/null +++ b/model/canny/kernels.py @@ -0,0 +1,690 @@ +import math +from math import sqrt +from typing import List, Optional, Tuple + +import torch + + +def normalize_kernel2d(input: torch.Tensor) -> torch.Tensor: + r"""Normalize both derivative and smoothing kernel.""" + if len(input.size()) < 2: + raise TypeError(f"input should be at least 2D tensor. Got {input.size()}") + norm: torch.Tensor = input.abs().sum(dim=-1).sum(dim=-1) + return input / (norm.unsqueeze(-1).unsqueeze(-1)) + + +def gaussian(window_size: int, sigma: float) -> torch.Tensor: + device, dtype = None, None + if isinstance(sigma, torch.Tensor): + device, dtype = sigma.device, sigma.dtype + x = torch.arange(window_size, device=device, dtype=dtype) - window_size // 2 + if window_size % 2 == 0: + x = x + 0.5 + + gauss = torch.exp((-x.pow(2.0) / (2 * sigma**2)).float()) + return gauss / gauss.sum() + + +def gaussian_discrete_erf(window_size: int, sigma) -> torch.Tensor: + r"""Discrete Gaussian by interpolating the error function. + + Adapted from: + https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py + """ + device = sigma.device if isinstance(sigma, torch.Tensor) else None + sigma = torch.as_tensor(sigma, dtype=torch.float, device=device) + x = torch.arange(window_size).float() - window_size // 2 + t = 0.70710678 / torch.abs(sigma) + gauss = 0.5 * ((t * (x + 0.5)).erf() - (t * (x - 0.5)).erf()) + gauss = gauss.clamp(min=0) + return gauss / gauss.sum() + + +def _modified_bessel_0(x: torch.Tensor) -> torch.Tensor: + r"""Adapted from: + + https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py + """ + if torch.abs(x) < 3.75: + y = (x / 3.75) * (x / 3.75) + return 1.0 + y * ( + 3.5156229 + y * (3.0899424 + y * (1.2067492 + y * (0.2659732 + y * (0.360768e-1 + y * 0.45813e-2)))) + ) + ax = torch.abs(x) + y = 3.75 / ax + ans = 0.916281e-2 + y * (-0.2057706e-1 + y * (0.2635537e-1 + y * (-0.1647633e-1 + y * 0.392377e-2))) + coef = 0.39894228 + y * (0.1328592e-1 + y * (0.225319e-2 + y * (-0.157565e-2 + y * ans))) + return (torch.exp(ax) / torch.sqrt(ax)) * coef + + +def _modified_bessel_1(x: torch.Tensor) -> torch.Tensor: + r"""adapted from: + + https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py + """ + if torch.abs(x) < 3.75: + y = (x / 3.75) * (x / 3.75) + ans = 0.51498869 + y * (0.15084934 + y * (0.2658733e-1 + y * (0.301532e-2 + y * 0.32411e-3))) + return torch.abs(x) * (0.5 + y * (0.87890594 + y * ans)) + ax = torch.abs(x) + y = 3.75 / ax + ans = 0.2282967e-1 + y * (-0.2895312e-1 + y * (0.1787654e-1 - y * 0.420059e-2)) + ans = 0.39894228 + y * (-0.3988024e-1 + y * (-0.362018e-2 + y * (0.163801e-2 + y * (-0.1031555e-1 + y * ans)))) + ans = ans * torch.exp(ax) / torch.sqrt(ax) + return -ans if x < 0.0 else ans + + +def _modified_bessel_i(n: int, x: torch.Tensor) -> torch.Tensor: + r"""adapted from: + + https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py + """ + if n < 2: + raise ValueError("n must be greater than 1.") + if x == 0.0: + return x + device = x.device + tox = 2.0 / torch.abs(x) + ans = torch.tensor(0.0, device=device) + bip = torch.tensor(0.0, device=device) + bi = torch.tensor(1.0, device=device) + m = int(2 * (n + int(sqrt(40.0 * n)))) + for j in range(m, 0, -1): + bim = bip + float(j) * tox * bi + bip = bi + bi = bim + if abs(bi) > 1.0e10: + ans = ans * 1.0e-10 + bi = bi * 1.0e-10 + bip = bip * 1.0e-10 + if j == n: + ans = bip + ans = ans * _modified_bessel_0(x) / bi + return -ans if x < 0.0 and (n % 2) == 1 else ans + + +def gaussian_discrete(window_size, sigma) -> torch.Tensor: + r"""Discrete Gaussian kernel based on the modified Bessel functions. + + Adapted from: + https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py + """ + device = sigma.device if isinstance(sigma, torch.Tensor) else None + sigma = torch.as_tensor(sigma, dtype=torch.float, device=device) + sigma2 = sigma * sigma + tail = int(window_size // 2) + out_pos: List[Optional[torch.Tensor]] = [None] * (tail + 1) + out_pos[0] = _modified_bessel_0(sigma2) + out_pos[1] = _modified_bessel_1(sigma2) + for k in range(2, len(out_pos)): + out_pos[k] = _modified_bessel_i(k, sigma2) + out = out_pos[:0:-1] + out.extend(out_pos) + out = torch.stack(out) * torch.exp(sigma2) # type: ignore + return out / out.sum() # type: ignore + + +def laplacian_1d(window_size) -> torch.Tensor: + r"""One could also use the Laplacian of Gaussian formula to design the filter.""" + + filter_1d = torch.ones(window_size) + filter_1d[window_size // 2] = 1 - window_size + laplacian_1d: torch.Tensor = filter_1d + return laplacian_1d + + +def get_box_kernel2d(kernel_size: Tuple[int, int]) -> torch.Tensor: + r"""Utility function that returns a box filter.""" + kx: float = float(kernel_size[0]) + ky: float = float(kernel_size[1]) + scale: torch.Tensor = torch.tensor(1.0) / torch.tensor([kx * ky]) + tmp_kernel: torch.Tensor = torch.ones(1, kernel_size[0], kernel_size[1]) + return scale.to(tmp_kernel.dtype) * tmp_kernel + + +def get_binary_kernel2d(window_size: Tuple[int, int]) -> torch.Tensor: + r"""Create a binary kernel to extract the patches. + + If the window size is HxW will create a (H*W)xHxW kernel. + """ + window_range: int = window_size[0] * window_size[1] + kernel: torch.Tensor = torch.zeros(window_range, window_range) + for i in range(window_range): + kernel[i, i] += 1.0 + return kernel.view(window_range, 1, window_size[0], window_size[1]) + + +def get_sobel_kernel_3x3() -> torch.Tensor: + """Utility function that returns a sobel kernel of 3x3.""" + return torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]) + + +def get_sobel_kernel_5x5_2nd_order() -> torch.Tensor: + """Utility function that returns a 2nd order sobel kernel of 5x5.""" + return torch.tensor( + [ + [-1.0, 0.0, 2.0, 0.0, -1.0], + [-4.0, 0.0, 8.0, 0.0, -4.0], + [-6.0, 0.0, 12.0, 0.0, -6.0], + [-4.0, 0.0, 8.0, 0.0, -4.0], + [-1.0, 0.0, 2.0, 0.0, -1.0], + ] + ) + + +def _get_sobel_kernel_5x5_2nd_order_xy() -> torch.Tensor: + """Utility function that returns a 2nd order sobel kernel of 5x5.""" + return torch.tensor( + [ + [-1.0, -2.0, 0.0, 2.0, 1.0], + [-2.0, -4.0, 0.0, 4.0, 2.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [2.0, 4.0, 0.0, -4.0, -2.0], + [1.0, 2.0, 0.0, -2.0, -1.0], + ] + ) + + +def get_diff_kernel_3x3() -> torch.Tensor: + """Utility function that returns a first order derivative kernel of 3x3.""" + return torch.tensor([[-0.0, 0.0, 0.0], [-1.0, 0.0, 1.0], [-0.0, 0.0, 0.0]]) + + +def get_diff_kernel3d(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor: + """Utility function that returns a first order derivative kernel of 3x3x3.""" + kernel: torch.Tensor = torch.tensor( + [ + [ + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [-0.5, 0.0, 0.5], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + ], + [ + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, -0.5, 0.0], [0.0, 0.0, 0.0], [0.0, 0.5, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + ], + [ + [[0.0, 0.0, 0.0], [0.0, -0.5, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.5, 0.0], [0.0, 0.0, 0.0]], + ], + ], + device=device, + dtype=dtype, + ) + return kernel.unsqueeze(1) + + +def get_diff_kernel3d_2nd_order(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor: + """Utility function that returns a first order derivative kernel of 3x3x3.""" + kernel: torch.Tensor = torch.tensor( + [ + [ + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [1.0, -2.0, 1.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + ], + [ + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 1.0, 0.0], [0.0, -2.0, 0.0], [0.0, 1.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + ], + [ + [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, -2.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + ], + [ + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[1.0, 0.0, -1.0], [0.0, 0.0, 0.0], [-1.0, 0.0, 1.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + ], + [ + [[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, -1.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, -1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]], + ], + [ + [[0.0, 0.0, 0.0], [1.0, 0.0, -1.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [-1.0, 0.0, 1.0], [0.0, 0.0, 0.0]], + ], + ], + device=device, + dtype=dtype, + ) + return kernel.unsqueeze(1) + + +def get_sobel_kernel2d() -> torch.Tensor: + kernel_x: torch.Tensor = get_sobel_kernel_3x3() + kernel_y: torch.Tensor = kernel_x.transpose(0, 1) + return torch.stack([kernel_x, kernel_y]) + + +def get_diff_kernel2d() -> torch.Tensor: + kernel_x: torch.Tensor = get_diff_kernel_3x3() + kernel_y: torch.Tensor = kernel_x.transpose(0, 1) + return torch.stack([kernel_x, kernel_y]) + + +def get_sobel_kernel2d_2nd_order() -> torch.Tensor: + gxx: torch.Tensor = get_sobel_kernel_5x5_2nd_order() + gyy: torch.Tensor = gxx.transpose(0, 1) + gxy: torch.Tensor = _get_sobel_kernel_5x5_2nd_order_xy() + return torch.stack([gxx, gxy, gyy]) + + +def get_diff_kernel2d_2nd_order() -> torch.Tensor: + gxx: torch.Tensor = torch.tensor([[0.0, 0.0, 0.0], [1.0, -2.0, 1.0], [0.0, 0.0, 0.0]]) + gyy: torch.Tensor = gxx.transpose(0, 1) + gxy: torch.Tensor = torch.tensor([[-1.0, 0.0, 1.0], [0.0, 0.0, 0.0], [1.0, 0.0, -1.0]]) + return torch.stack([gxx, gxy, gyy]) + + +def get_spatial_gradient_kernel2d(mode: str, order: int) -> torch.Tensor: + r"""Function that returns kernel for 1st or 2nd order image gradients, using one of the following operators: + + sobel, diff. + """ + if mode not in ['sobel', 'diff']: + raise TypeError( + "mode should be either sobel\ + or diff. Got {}".format( + mode + ) + ) + if order not in [1, 2]: + raise TypeError( + "order should be either 1 or 2\ + Got {}".format( + order + ) + ) + if mode == 'sobel' and order == 1: + kernel: torch.Tensor = get_sobel_kernel2d() + elif mode == 'sobel' and order == 2: + kernel = get_sobel_kernel2d_2nd_order() + elif mode == 'diff' and order == 1: + kernel = get_diff_kernel2d() + elif mode == 'diff' and order == 2: + kernel = get_diff_kernel2d_2nd_order() + else: + raise NotImplementedError("") + return kernel + + +def get_spatial_gradient_kernel3d(mode: str, order: int, device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor: + r"""Function that returns kernel for 1st or 2nd order scale pyramid gradients, using one of the following + operators: sobel, diff.""" + if mode not in ['sobel', 'diff']: + raise TypeError( + "mode should be either sobel\ + or diff. Got {}".format( + mode + ) + ) + if order not in [1, 2]: + raise TypeError( + "order should be either 1 or 2\ + Got {}".format( + order + ) + ) + if mode == 'sobel': + raise NotImplementedError("Sobel kernel for 3d gradient is not implemented yet") + if mode == 'diff' and order == 1: + kernel = get_diff_kernel3d(device, dtype) + elif mode == 'diff' and order == 2: + kernel = get_diff_kernel3d_2nd_order(device, dtype) + else: + raise NotImplementedError("") + return kernel + + +def get_gaussian_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) -> torch.Tensor: + r"""Function that returns Gaussian filter coefficients. + + Args: + kernel_size: filter size. It should be odd and positive. + sigma: gaussian standard deviation. + force_even: overrides requirement for odd kernel size. + + Returns: + 1D tensor with gaussian filter coefficients. + + Shape: + - Output: :math:`(\text{kernel_size})` + + Examples: + + >>> get_gaussian_kernel1d(3, 2.5) + tensor([0.3243, 0.3513, 0.3243]) + + >>> get_gaussian_kernel1d(5, 1.5) + tensor([0.1201, 0.2339, 0.2921, 0.2339, 0.1201]) + """ + if not isinstance(kernel_size, int) or ((kernel_size % 2 == 0) and not force_even) or (kernel_size <= 0): + raise TypeError("kernel_size must be an odd positive integer. " "Got {}".format(kernel_size)) + window_1d: torch.Tensor = gaussian(kernel_size, sigma) + return window_1d + + +def get_gaussian_discrete_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) -> torch.Tensor: + r"""Function that returns Gaussian filter coefficients based on the modified Bessel functions. Adapted from: + https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py. + + Args: + kernel_size: filter size. It should be odd and positive. + sigma: gaussian standard deviation. + force_even: overrides requirement for odd kernel size. + + Returns: + 1D tensor with gaussian filter coefficients. + + Shape: + - Output: :math:`(\text{kernel_size})` + + Examples: + + >>> get_gaussian_discrete_kernel1d(3, 2.5) + tensor([0.3235, 0.3531, 0.3235]) + + >>> get_gaussian_discrete_kernel1d(5, 1.5) + tensor([0.1096, 0.2323, 0.3161, 0.2323, 0.1096]) + """ + if not isinstance(kernel_size, int) or ((kernel_size % 2 == 0) and not force_even) or (kernel_size <= 0): + raise TypeError("kernel_size must be an odd positive integer. " "Got {}".format(kernel_size)) + window_1d = gaussian_discrete(kernel_size, sigma) + return window_1d + + +def get_gaussian_erf_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) -> torch.Tensor: + r"""Function that returns Gaussian filter coefficients by interpolating the error function, adapted from: + https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py. + + Args: + kernel_size: filter size. It should be odd and positive. + sigma: gaussian standard deviation. + force_even: overrides requirement for odd kernel size. + + Returns: + 1D tensor with gaussian filter coefficients. + + Shape: + - Output: :math:`(\text{kernel_size})` + + Examples: + + >>> get_gaussian_erf_kernel1d(3, 2.5) + tensor([0.3245, 0.3511, 0.3245]) + + >>> get_gaussian_erf_kernel1d(5, 1.5) + tensor([0.1226, 0.2331, 0.2887, 0.2331, 0.1226]) + """ + if not isinstance(kernel_size, int) or ((kernel_size % 2 == 0) and not force_even) or (kernel_size <= 0): + raise TypeError("kernel_size must be an odd positive integer. " "Got {}".format(kernel_size)) + window_1d = gaussian_discrete_erf(kernel_size, sigma) + return window_1d + + +def get_gaussian_kernel2d( + kernel_size: Tuple[int, int], sigma: Tuple[float, float], force_even: bool = False +) -> torch.Tensor: + r"""Function that returns Gaussian filter matrix coefficients. + + Args: + kernel_size: filter sizes in the x and y direction. + Sizes should be odd and positive. + sigma: gaussian standard deviation in the x and y + direction. + force_even: overrides requirement for odd kernel size. + + Returns: + 2D tensor with gaussian filter matrix coefficients. + + Shape: + - Output: :math:`(\text{kernel_size}_x, \text{kernel_size}_y)` + + Examples: + >>> get_gaussian_kernel2d((3, 3), (1.5, 1.5)) + tensor([[0.0947, 0.1183, 0.0947], + [0.1183, 0.1478, 0.1183], + [0.0947, 0.1183, 0.0947]]) + >>> get_gaussian_kernel2d((3, 5), (1.5, 1.5)) + tensor([[0.0370, 0.0720, 0.0899, 0.0720, 0.0370], + [0.0462, 0.0899, 0.1123, 0.0899, 0.0462], + [0.0370, 0.0720, 0.0899, 0.0720, 0.0370]]) + """ + if not isinstance(kernel_size, tuple) or len(kernel_size) != 2: + raise TypeError(f"kernel_size must be a tuple of length two. Got {kernel_size}") + if not isinstance(sigma, tuple) or len(sigma) != 2: + raise TypeError(f"sigma must be a tuple of length two. Got {sigma}") + ksize_x, ksize_y = kernel_size + sigma_x, sigma_y = sigma + kernel_x: torch.Tensor = get_gaussian_kernel1d(ksize_x, sigma_x, force_even) + kernel_y: torch.Tensor = get_gaussian_kernel1d(ksize_y, sigma_y, force_even) + kernel_2d: torch.Tensor = torch.matmul(kernel_x.unsqueeze(-1), kernel_y.unsqueeze(-1).t()) + return kernel_2d + + +def get_laplacian_kernel1d(kernel_size: int) -> torch.Tensor: + r"""Function that returns the coefficients of a 1D Laplacian filter. + + Args: + kernel_size: filter size. It should be odd and positive. + + Returns: + 1D tensor with laplacian filter coefficients. + + Shape: + - Output: math:`(\text{kernel_size})` + + Examples: + >>> get_laplacian_kernel1d(3) + tensor([ 1., -2., 1.]) + >>> get_laplacian_kernel1d(5) + tensor([ 1., 1., -4., 1., 1.]) + """ + if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or kernel_size <= 0: + raise TypeError(f"ksize must be an odd positive integer. Got {kernel_size}") + window_1d: torch.Tensor = laplacian_1d(kernel_size) + return window_1d + + +def get_laplacian_kernel2d(kernel_size: int) -> torch.Tensor: + r"""Function that returns Gaussian filter matrix coefficients. + + Args: + kernel_size: filter size should be odd. + + Returns: + 2D tensor with laplacian filter matrix coefficients. + + Shape: + - Output: :math:`(\text{kernel_size}_x, \text{kernel_size}_y)` + + Examples: + >>> get_laplacian_kernel2d(3) + tensor([[ 1., 1., 1.], + [ 1., -8., 1.], + [ 1., 1., 1.]]) + >>> get_laplacian_kernel2d(5) + tensor([[ 1., 1., 1., 1., 1.], + [ 1., 1., 1., 1., 1.], + [ 1., 1., -24., 1., 1.], + [ 1., 1., 1., 1., 1.], + [ 1., 1., 1., 1., 1.]]) + """ + if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or kernel_size <= 0: + raise TypeError(f"ksize must be an odd positive integer. Got {kernel_size}") + + kernel = torch.ones((kernel_size, kernel_size)) + mid = kernel_size // 2 + kernel[mid, mid] = 1 - kernel_size**2 + kernel_2d: torch.Tensor = kernel + return kernel_2d + + +def get_pascal_kernel_2d(kernel_size: int, norm: bool = True) -> torch.Tensor: + """Generate pascal filter kernel by kernel size. + + Args: + kernel_size: height and width of the kernel. + norm: if to normalize the kernel or not. Default: True. + + Returns: + kernel shaped as :math:`(kernel_size, kernel_size)` + + Examples: + >>> get_pascal_kernel_2d(1) + tensor([[1.]]) + >>> get_pascal_kernel_2d(4) + tensor([[0.0156, 0.0469, 0.0469, 0.0156], + [0.0469, 0.1406, 0.1406, 0.0469], + [0.0469, 0.1406, 0.1406, 0.0469], + [0.0156, 0.0469, 0.0469, 0.0156]]) + >>> get_pascal_kernel_2d(4, norm=False) + tensor([[1., 3., 3., 1.], + [3., 9., 9., 3.], + [3., 9., 9., 3.], + [1., 3., 3., 1.]]) + """ + a = get_pascal_kernel_1d(kernel_size) + + filt = a[:, None] * a[None, :] + if norm: + filt = filt / torch.sum(filt) + return filt + + +def get_pascal_kernel_1d(kernel_size: int, norm: bool = False) -> torch.Tensor: + """Generate Yang Hui triangle (Pascal's triangle) by a given number. + + Args: + kernel_size: height and width of the kernel. + norm: if to normalize the kernel or not. Default: False. + + Returns: + kernel shaped as :math:`(kernel_size,)` + + Examples: + >>> get_pascal_kernel_1d(1) + tensor([1.]) + >>> get_pascal_kernel_1d(2) + tensor([1., 1.]) + >>> get_pascal_kernel_1d(3) + tensor([1., 2., 1.]) + >>> get_pascal_kernel_1d(4) + tensor([1., 3., 3., 1.]) + >>> get_pascal_kernel_1d(5) + tensor([1., 4., 6., 4., 1.]) + >>> get_pascal_kernel_1d(6) + tensor([ 1., 5., 10., 10., 5., 1.]) + """ + pre: List[float] = [] + cur: List[float] = [] + for i in range(kernel_size): + cur = [1.0] * (i + 1) + + for j in range(1, i // 2 + 1): + value = pre[j - 1] + pre[j] + cur[j] = value + if i != 2 * j: + cur[-j - 1] = value + pre = cur + + out = torch.as_tensor(cur) + if norm: + out = out / torch.sum(out) + return out + + +def get_canny_nms_kernel(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor: + """Utility function that returns 3x3 kernels for the Canny Non-maximal suppression.""" + kernel: torch.Tensor = torch.tensor( + [ + [[0.0, 0.0, 0.0], [0.0, 1.0, -1.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]], + [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, -1.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, -1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, -1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + ], + device=device, + dtype=dtype, + ) + return kernel.unsqueeze(1) + + +def get_hysteresis_kernel(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor: + """Utility function that returns the 3x3 kernels for the Canny hysteresis.""" + kernel: torch.Tensor = torch.tensor( + [ + [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + ], + device=device, + dtype=dtype, + ) + return kernel.unsqueeze(1) + + +def get_hanning_kernel1d(kernel_size: int, device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor: + r"""Returns Hanning (also known as Hann) kernel, used in signal processing and KCF tracker. + + .. math:: w(n) = 0.5 - 0.5cos\\left(\\frac{2\\pi{n}}{M-1}\\right) + \\qquad 0 \\leq n \\leq M-1 + + See further in numpy docs https://numpy.org/doc/stable/reference/generated/numpy.hanning.html + + Args: + kernel_size: The size the of the kernel. It should be positive. + + Returns: + 1D tensor with Hanning filter coefficients. + .. math:: w(n) = 0.5 - 0.5cos\\left(\\frac{2\\pi{n}}{M-1}\\right) + + Shape: + - Output: math:`(\text{kernel_size})` + + Examples: + >>> get_hanning_kernel1d(4) + tensor([0.0000, 0.7500, 0.7500, 0.0000]) + """ + if not isinstance(kernel_size, int) or kernel_size <= 2: + raise TypeError(f"ksize must be an positive integer > 2. Got {kernel_size}") + + x: torch.Tensor = torch.arange(kernel_size, device=device, dtype=dtype) + x = 0.5 - 0.5 * torch.cos(2.0 * math.pi * x / float(kernel_size - 1)) + return x + + +def get_hanning_kernel2d(kernel_size: Tuple[int, int], device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor: + r"""Returns 2d Hanning kernel, used in signal processing and KCF tracker. + + Args: + kernel_size: The size of the kernel for the filter. It should be positive. + + Returns: + 2D tensor with Hanning filter coefficients. + .. math:: w(n) = 0.5 - 0.5cos\\left(\\frac{2\\pi{n}}{M-1}\\right) + + Shape: + - Output: math:`(\text{kernel_size[0], kernel_size[1]})` + """ + if kernel_size[0] <= 2 or kernel_size[1] <= 2: + raise TypeError(f"ksize must be an tuple of positive integers > 2. Got {kernel_size}") + ky: torch.Tensor = get_hanning_kernel1d(kernel_size[0], device, dtype)[None].T + kx: torch.Tensor = get_hanning_kernel1d(kernel_size[1], device, dtype)[None] + kernel2d = ky @ kx + return kernel2d \ No newline at end of file diff --git a/model/canny/sobel.py b/model/canny/sobel.py new file mode 100644 index 0000000000000000000000000000000000000000..d780c5c4a22bb6403122a292b6d30fa022f262e8 --- /dev/null +++ b/model/canny/sobel.py @@ -0,0 +1,263 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .kernels import get_spatial_gradient_kernel2d, get_spatial_gradient_kernel3d, normalize_kernel2d + + +def spatial_gradient(input: torch.Tensor, mode: str = 'sobel', order: int = 1, normalized: bool = True) -> torch.Tensor: + r"""Compute the first order image derivative in both x and y using a Sobel operator. + + .. image:: _static/img/spatial_gradient.png + + Args: + input: input image tensor with shape :math:`(B, C, H, W)`. + mode: derivatives modality, can be: `sobel` or `diff`. + order: the order of the derivatives. + normalized: whether the output is normalized. + + Return: + the derivatives of the input feature map. with shape :math:`(B, C, 2, H, W)`. + + .. note:: + See a working example `here `__. + + Examples: + >>> input = torch.rand(1, 3, 4, 4) + >>> output = spatial_gradient(input) # 1x3x2x4x4 + >>> output.shape + torch.Size([1, 3, 2, 4, 4]) + """ + if not isinstance(input, torch.Tensor): + raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") + + if not len(input.shape) == 4: + raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}") + # allocate kernel + kernel: torch.Tensor = get_spatial_gradient_kernel2d(mode, order) + if normalized: + kernel = normalize_kernel2d(kernel) + + # prepare kernel + b, c, h, w = input.shape + tmp_kernel: torch.Tensor = kernel.to(input).detach() + tmp_kernel = tmp_kernel.unsqueeze(1).unsqueeze(1) + + # convolve input tensor with sobel kernel + kernel_flip: torch.Tensor = tmp_kernel.flip(-3) + + # Pad with "replicate for spatial dims, but with zeros for channel + spatial_pad = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2] + out_channels: int = 3 if order == 2 else 2 + padded_inp: torch.Tensor = F.pad(input.reshape(b * c, 1, h, w), spatial_pad, 'replicate')[:, :, None] + + return F.conv3d(padded_inp, kernel_flip, padding=0).view(b, c, out_channels, h, w) + + +def spatial_gradient3d(input: torch.Tensor, mode: str = 'diff', order: int = 1) -> torch.Tensor: + r"""Compute the first and second order volume derivative in x, y and d using a diff operator. + + Args: + input: input features tensor with shape :math:`(B, C, D, H, W)`. + mode: derivatives modality, can be: `sobel` or `diff`. + order: the order of the derivatives. + + Return: + the spatial gradients of the input feature map with shape math:`(B, C, 3, D, H, W)` + or :math:`(B, C, 6, D, H, W)`. + + Examples: + >>> input = torch.rand(1, 4, 2, 4, 4) + >>> output = spatial_gradient3d(input) + >>> output.shape + torch.Size([1, 4, 3, 2, 4, 4]) + """ + if not isinstance(input, torch.Tensor): + raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") + + if not len(input.shape) == 5: + raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}") + b, c, d, h, w = input.shape + dev = input.device + dtype = input.dtype + if (mode == 'diff') and (order == 1): + # we go for the special case implementation due to conv3d bad speed + x: torch.Tensor = F.pad(input, 6 * [1], 'replicate') + center = slice(1, -1) + left = slice(0, -2) + right = slice(2, None) + out = torch.empty(b, c, 3, d, h, w, device=dev, dtype=dtype) + out[..., 0, :, :, :] = x[..., center, center, right] - x[..., center, center, left] + out[..., 1, :, :, :] = x[..., center, right, center] - x[..., center, left, center] + out[..., 2, :, :, :] = x[..., right, center, center] - x[..., left, center, center] + out = 0.5 * out + else: + # prepare kernel + # allocate kernel + kernel: torch.Tensor = get_spatial_gradient_kernel3d(mode, order) + + tmp_kernel: torch.Tensor = kernel.to(input).detach() + tmp_kernel = tmp_kernel.repeat(c, 1, 1, 1, 1) + + # convolve input tensor with grad kernel + kernel_flip: torch.Tensor = tmp_kernel.flip(-3) + + # Pad with "replicate for spatial dims, but with zeros for channel + spatial_pad = [ + kernel.size(2) // 2, + kernel.size(2) // 2, + kernel.size(3) // 2, + kernel.size(3) // 2, + kernel.size(4) // 2, + kernel.size(4) // 2, + ] + out_ch: int = 6 if order == 2 else 3 + out = F.conv3d(F.pad(input, spatial_pad, 'replicate'), kernel_flip, padding=0, groups=c).view( + b, c, out_ch, d, h, w + ) + return out + + +def sobel(input: torch.Tensor, normalized: bool = True, eps: float = 1e-6) -> torch.Tensor: + r"""Compute the Sobel operator and returns the magnitude per channel. + + .. image:: _static/img/sobel.png + + Args: + input: the input image with shape :math:`(B,C,H,W)`. + normalized: if True, L1 norm of the kernel is set to 1. + eps: regularization number to avoid NaN during backprop. + + Return: + the sobel edge gradient magnitudes map with shape :math:`(B,C,H,W)`. + + .. note:: + See a working example `here `__. + + Example: + >>> input = torch.rand(1, 3, 4, 4) + >>> output = sobel(input) # 1x3x4x4 + >>> output.shape + torch.Size([1, 3, 4, 4]) + """ + if not isinstance(input, torch.Tensor): + raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") + + if not len(input.shape) == 4: + raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}") + + # comput the x/y gradients + edges: torch.Tensor = spatial_gradient(input, normalized=normalized) + + # unpack the edges + gx: torch.Tensor = edges[:, :, 0] + gy: torch.Tensor = edges[:, :, 1] + + # compute gradient maginitude + magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps) + + return magnitude + + +class SpatialGradient(nn.Module): + r"""Compute the first order image derivative in both x and y using a Sobel operator. + + Args: + mode: derivatives modality, can be: `sobel` or `diff`. + order: the order of the derivatives. + normalized: whether the output is normalized. + + Return: + the sobel edges of the input feature map. + + Shape: + - Input: :math:`(B, C, H, W)` + - Output: :math:`(B, C, 2, H, W)` + + Examples: + >>> input = torch.rand(1, 3, 4, 4) + >>> output = SpatialGradient()(input) # 1x3x2x4x4 + """ + + def __init__(self, mode: str = 'sobel', order: int = 1, normalized: bool = True) -> None: + super().__init__() + self.normalized: bool = normalized + self.order: int = order + self.mode: str = mode + + def __repr__(self) -> str: + return ( + self.__class__.__name__ + '(' + 'order=' + str(self.order) + ', ' + 'normalized=' + str(self.normalized) + ', ' + 'mode=' + self.mode + ')' + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return spatial_gradient(input, self.mode, self.order, self.normalized) + + +class SpatialGradient3d(nn.Module): + r"""Compute the first and second order volume derivative in x, y and d using a diff operator. + + Args: + mode: derivatives modality, can be: `sobel` or `diff`. + order: the order of the derivatives. + + Return: + the spatial gradients of the input feature map. + + Shape: + - Input: :math:`(B, C, D, H, W)`. D, H, W are spatial dimensions, gradient is calculated w.r.t to them. + - Output: :math:`(B, C, 3, D, H, W)` or :math:`(B, C, 6, D, H, W)` + + Examples: + >>> input = torch.rand(1, 4, 2, 4, 4) + >>> output = SpatialGradient3d()(input) + >>> output.shape + torch.Size([1, 4, 3, 2, 4, 4]) + """ + + def __init__(self, mode: str = 'diff', order: int = 1) -> None: + super().__init__() + self.order: int = order + self.mode: str = mode + self.kernel = get_spatial_gradient_kernel3d(mode, order) + return + + def __repr__(self) -> str: + return self.__class__.__name__ + '(' 'order=' + str(self.order) + ', ' + 'mode=' + self.mode + ')' + + def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore + return spatial_gradient3d(input, self.mode, self.order) + + +class Sobel(nn.Module): + r"""Compute the Sobel operator and returns the magnitude per channel. + + Args: + normalized: if True, L1 norm of the kernel is set to 1. + eps: regularization number to avoid NaN during backprop. + + Return: + the sobel edge gradient magnitudes map. + + Shape: + - Input: :math:`(B, C, H, W)` + - Output: :math:`(B, C, H, W)` + + Examples: + >>> input = torch.rand(1, 3, 4, 4) + >>> output = Sobel()(input) # 1x3x4x4 + """ + + def __init__(self, normalized: bool = True, eps: float = 1e-6) -> None: + super().__init__() + self.normalized: bool = normalized + self.eps: float = eps + + def __repr__(self) -> str: + return self.__class__.__name__ + '(' 'normalized=' + str(self.normalized) + ')' + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return sobel(input, self.normalized, self.eps) \ No newline at end of file diff --git a/model/misc.py b/model/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..43b849902245dd338a36f4f4ff09e33425365af6 --- /dev/null +++ b/model/misc.py @@ -0,0 +1,131 @@ +import os +import re +import random +import time +import torch +import torch.nn as nn +import logging +import numpy as np +from os import path as osp + +def constant_init(module, val, bias=0): + if hasattr(module, 'weight') and module.weight is not None: + nn.init.constant_(module.weight, val) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + +initialized_logger = {} +def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): + """Get the root logger. + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. + Args: + logger_name (str): root logger name. Default: 'basicsr'. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + Returns: + logging.Logger: The root logger. + """ + logger = logging.getLogger(logger_name) + # if the logger has been initialized, just return it + if logger_name in initialized_logger: + return logger + + format_str = '%(asctime)s %(levelname)s: %(message)s' + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(logging.Formatter(format_str)) + logger.addHandler(stream_handler) + logger.propagate = False + + if log_file is not None: + logger.setLevel(log_level) + # add file handler + # file_handler = logging.FileHandler(log_file, 'w') + file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log + file_handler.setFormatter(logging.Formatter(format_str)) + file_handler.setLevel(log_level) + logger.addHandler(file_handler) + initialized_logger[logger_name] = True + return logger + + +IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\ + torch.__version__)[0][:3])] >= [1, 12, 0] + +def gpu_is_available(): + if IS_HIGH_VERSION: + if torch.backends.mps.is_available(): + return True + return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False + +def get_device(gpu_id=None): + if gpu_id is None: + gpu_str = '' + elif isinstance(gpu_id, int): + gpu_str = f':{gpu_id}' + else: + raise TypeError('Input should be int value.') + + if IS_HIGH_VERSION: + if torch.backends.mps.is_available(): + return torch.device('mps'+gpu_str) + return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu') + + +def set_random_seed(seed): + """Set random seeds.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_time_str(): + return time.strftime('%Y%m%d_%H%M%S', time.localtime()) + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative pathes. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) \ No newline at end of file diff --git a/model/modules/base_module.py b/model/modules/base_module.py new file mode 100644 index 0000000000000000000000000000000000000000..b28c094308dd4d1bbb62dd75e02e937e2c9ddf14 --- /dev/null +++ b/model/modules/base_module.py @@ -0,0 +1,131 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from functools import reduce + +class BaseNetwork(nn.Module): + def __init__(self): + super(BaseNetwork, self).__init__() + + def print_network(self): + if isinstance(self, list): + self = self[0] + num_params = 0 + for param in self.parameters(): + num_params += param.numel() + print( + 'Network [%s] was created. Total number of parameters: %.1f million. ' + 'To see the architecture, do print(network).' % + (type(self).__name__, num_params / 1000000)) + + def init_weights(self, init_type='normal', gain=0.02): + ''' + initialize network's weights + init_type: normal | xavier | kaiming | orthogonal + https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 + ''' + def init_func(m): + classname = m.__class__.__name__ + if classname.find('InstanceNorm2d') != -1: + if hasattr(m, 'weight') and m.weight is not None: + nn.init.constant_(m.weight.data, 1.0) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + elif hasattr(m, 'weight') and (classname.find('Conv') != -1 + or classname.find('Linear') != -1): + if init_type == 'normal': + nn.init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + nn.init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'xavier_uniform': + nn.init.xavier_uniform_(m.weight.data, gain=1.0) + elif init_type == 'kaiming': + nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + nn.init.orthogonal_(m.weight.data, gain=gain) + elif init_type == 'none': # uses pytorch's default init method + m.reset_parameters() + else: + raise NotImplementedError( + 'initialization method [%s] is not implemented' % + init_type) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + + self.apply(init_func) + + # propagate to children + for m in self.children(): + if hasattr(m, 'init_weights'): + m.init_weights(init_type, gain) + + +class Vec2Feat(nn.Module): + def __init__(self, channel, hidden, kernel_size, stride, padding): + super(Vec2Feat, self).__init__() + self.relu = nn.LeakyReLU(0.2, inplace=True) + c_out = reduce((lambda x, y: x * y), kernel_size) * channel + self.embedding = nn.Linear(hidden, c_out) + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.bias_conv = nn.Conv2d(channel, + channel, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t, output_size): + b_, _, _, _, c_ = x.shape + x = x.view(b_, -1, c_) + feat = self.embedding(x) + b, _, c = feat.size() + feat = feat.view(b * t, -1, c).permute(0, 2, 1) + feat = F.fold(feat, + output_size=output_size, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding) + feat = self.bias_conv(feat) + return feat + + +class FusionFeedForward(nn.Module): + def __init__(self, dim, hidden_dim=1960, t2t_params=None): + super(FusionFeedForward, self).__init__() + # We set hidden_dim as a default to 1960 + self.fc1 = nn.Sequential(nn.Linear(dim, hidden_dim)) + self.fc2 = nn.Sequential(nn.GELU(), nn.Linear(hidden_dim, dim)) + assert t2t_params is not None + self.t2t_params = t2t_params + self.kernel_shape = reduce((lambda x, y: x * y), t2t_params['kernel_size']) # 49 + + def forward(self, x, output_size): + n_vecs = 1 + for i, d in enumerate(self.t2t_params['kernel_size']): + n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] - + (d - 1) - 1) / self.t2t_params['stride'][i] + 1) + + x = self.fc1(x) + b, n, c = x.size() + normalizer = x.new_ones(b, n, self.kernel_shape).view(-1, n_vecs, self.kernel_shape).permute(0, 2, 1) + normalizer = F.fold(normalizer, + output_size=output_size, + kernel_size=self.t2t_params['kernel_size'], + padding=self.t2t_params['padding'], + stride=self.t2t_params['stride']) + + x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1), + output_size=output_size, + kernel_size=self.t2t_params['kernel_size'], + padding=self.t2t_params['padding'], + stride=self.t2t_params['stride']) + + x = F.unfold(x / normalizer, + kernel_size=self.t2t_params['kernel_size'], + padding=self.t2t_params['padding'], + stride=self.t2t_params['stride']).permute( + 0, 2, 1).contiguous().view(b, n, c) + x = self.fc2(x) + return x diff --git a/model/modules/deformconv.py b/model/modules/deformconv.py new file mode 100644 index 0000000000000000000000000000000000000000..89cb31b3d80bd69704a380930964db6fb29a6bbe --- /dev/null +++ b/model/modules/deformconv.py @@ -0,0 +1,54 @@ +import torch +import torch.nn as nn +from torch.nn import init as init +from torch.nn.modules.utils import _pair, _single +import math + +class ModulatedDeformConv2d(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deform_groups=1, + bias=True): + super(ModulatedDeformConv2d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.deform_groups = deform_groups + self.with_bias = bias + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.init_weights() + + def init_weights(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + if hasattr(self, 'conv_offset'): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x, offset, mask): + pass \ No newline at end of file diff --git a/model/modules/flow_comp_raft.py b/model/modules/flow_comp_raft.py new file mode 100644 index 0000000000000000000000000000000000000000..5add56c45b8fa7a58a83a9c94986560fe477fc6e --- /dev/null +++ b/model/modules/flow_comp_raft.py @@ -0,0 +1,265 @@ +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F + +from RAFT import RAFT +from model.modules.flow_loss_utils import flow_warp, ternary_loss2 + + +def initialize_RAFT(model_path='weights/raft-things.pth', device='cuda'): + """Initializes the RAFT model. + """ + args = argparse.ArgumentParser() + args.raft_model = model_path + args.small = False + args.mixed_precision = False + args.alternate_corr = False + model = torch.nn.DataParallel(RAFT(args)) + model.load_state_dict(torch.load(args.raft_model, map_location='cpu')) + model = model.module + + model.to(device) + + return model + + +class RAFT_bi(nn.Module): + """Flow completion loss""" + def __init__(self, model_path='weights/raft-things.pth', device='cuda'): + super().__init__() + self.fix_raft = initialize_RAFT(model_path, device=device) + + for p in self.fix_raft.parameters(): + p.requires_grad = False + + self.l1_criterion = nn.L1Loss() + self.eval() + + def forward(self, gt_local_frames, iters=20): + b, l_t, c, h, w = gt_local_frames.size() + # print(gt_local_frames.shape) + + with torch.no_grad(): + gtlf_1 = gt_local_frames[:, :-1, :, :, :].reshape(-1, c, h, w) + gtlf_2 = gt_local_frames[:, 1:, :, :, :].reshape(-1, c, h, w) + # print(gtlf_1.shape) + + _, gt_flows_forward = self.fix_raft(gtlf_1, gtlf_2, iters=iters, test_mode=True) + _, gt_flows_backward = self.fix_raft(gtlf_2, gtlf_1, iters=iters, test_mode=True) + + + gt_flows_forward = gt_flows_forward.view(b, l_t-1, 2, h, w) + gt_flows_backward = gt_flows_backward.view(b, l_t-1, 2, h, w) + + return gt_flows_forward, gt_flows_backward + + +################################################################################## +def smoothness_loss(flow, cmask): + delta_u, delta_v, mask = smoothness_deltas(flow) + loss_u = charbonnier_loss(delta_u, cmask) + loss_v = charbonnier_loss(delta_v, cmask) + return loss_u + loss_v + + +def smoothness_deltas(flow): + """ + flow: [b, c, h, w] + """ + mask_x = create_mask(flow, [[0, 0], [0, 1]]) + mask_y = create_mask(flow, [[0, 1], [0, 0]]) + mask = torch.cat((mask_x, mask_y), dim=1) + mask = mask.to(flow.device) + filter_x = torch.tensor([[0, 0, 0.], [0, 1, -1], [0, 0, 0]]) + filter_y = torch.tensor([[0, 0, 0.], [0, 1, 0], [0, -1, 0]]) + weights = torch.ones([2, 1, 3, 3]) + weights[0, 0] = filter_x + weights[1, 0] = filter_y + weights = weights.to(flow.device) + + flow_u, flow_v = torch.split(flow, split_size_or_sections=1, dim=1) + delta_u = F.conv2d(flow_u, weights, stride=1, padding=1) + delta_v = F.conv2d(flow_v, weights, stride=1, padding=1) + return delta_u, delta_v, mask + + +def second_order_loss(flow, cmask): + delta_u, delta_v, mask = second_order_deltas(flow) + loss_u = charbonnier_loss(delta_u, cmask) + loss_v = charbonnier_loss(delta_v, cmask) + return loss_u + loss_v + + +def charbonnier_loss(x, mask=None, truncate=None, alpha=0.45, beta=1.0, epsilon=0.001): + """ + Compute the generalized charbonnier loss of the difference tensor x + All positions where mask == 0 are not taken into account + x: a tensor of shape [b, c, h, w] + mask: a mask of shape [b, mc, h, w], where mask channels must be either 1 or the same as + the number of channels of x. Entries should be 0 or 1 + return: loss + """ + b, c, h, w = x.shape + norm = b * c * h * w + error = torch.pow(torch.square(x * beta) + torch.square(torch.tensor(epsilon)), alpha) + if mask is not None: + error = mask * error + if truncate is not None: + error = torch.min(error, truncate) + return torch.sum(error) / norm + + +def second_order_deltas(flow): + """ + consider the single flow first + flow shape: [b, c, h, w] + """ + # create mask + mask_x = create_mask(flow, [[0, 0], [1, 1]]) + mask_y = create_mask(flow, [[1, 1], [0, 0]]) + mask_diag = create_mask(flow, [[1, 1], [1, 1]]) + mask = torch.cat((mask_x, mask_y, mask_diag, mask_diag), dim=1) + mask = mask.to(flow.device) + + filter_x = torch.tensor([[0, 0, 0.], [1, -2, 1], [0, 0, 0]]) + filter_y = torch.tensor([[0, 1, 0.], [0, -2, 0], [0, 1, 0]]) + filter_diag1 = torch.tensor([[1, 0, 0.], [0, -2, 0], [0, 0, 1]]) + filter_diag2 = torch.tensor([[0, 0, 1.], [0, -2, 0], [1, 0, 0]]) + weights = torch.ones([4, 1, 3, 3]) + weights[0] = filter_x + weights[1] = filter_y + weights[2] = filter_diag1 + weights[3] = filter_diag2 + weights = weights.to(flow.device) + + # split the flow into flow_u and flow_v, conv them with the weights + flow_u, flow_v = torch.split(flow, split_size_or_sections=1, dim=1) + delta_u = F.conv2d(flow_u, weights, stride=1, padding=1) + delta_v = F.conv2d(flow_v, weights, stride=1, padding=1) + return delta_u, delta_v, mask + +def create_mask(tensor, paddings): + """ + tensor shape: [b, c, h, w] + paddings: [2 x 2] shape list, the first row indicates up and down paddings + the second row indicates left and right paddings + | | + | x | + | x * x | + | x | + | | + """ + shape = tensor.shape + inner_height = shape[2] - (paddings[0][0] + paddings[0][1]) + inner_width = shape[3] - (paddings[1][0] + paddings[1][1]) + inner = torch.ones([inner_height, inner_width]) + torch_paddings = [paddings[1][0], paddings[1][1], paddings[0][0], paddings[0][1]] # left, right, up and down + mask2d = F.pad(inner, pad=torch_paddings) + mask3d = mask2d.unsqueeze(0).repeat(shape[0], 1, 1) + mask4d = mask3d.unsqueeze(1) + return mask4d.detach() + +def ternary_loss(flow_comp, flow_gt, mask, current_frame, shift_frame, scale_factor=1): + if scale_factor != 1: + current_frame = F.interpolate(current_frame, scale_factor=1 / scale_factor, mode='bilinear') + shift_frame = F.interpolate(shift_frame, scale_factor=1 / scale_factor, mode='bilinear') + warped_sc = flow_warp(shift_frame, flow_gt.permute(0, 2, 3, 1)) + noc_mask = torch.exp(-50. * torch.sum(torch.abs(current_frame - warped_sc), dim=1).pow(2)).unsqueeze(1) + warped_comp_sc = flow_warp(shift_frame, flow_comp.permute(0, 2, 3, 1)) + loss = ternary_loss2(current_frame, warped_comp_sc, noc_mask, mask) + return loss + +class FlowLoss(nn.Module): + def __init__(self): + super().__init__() + self.l1_criterion = nn.L1Loss() + + def forward(self, pred_flows, gt_flows, masks, frames): + # pred_flows: b t-1 2 h w + loss = 0 + warp_loss = 0 + h, w = pred_flows[0].shape[-2:] + masks = [masks[:,:-1,...].contiguous(), masks[:, 1:, ...].contiguous()] + frames0 = frames[:,:-1,...] + frames1 = frames[:,1:,...] + current_frames = [frames0, frames1] + next_frames = [frames1, frames0] + for i in range(len(pred_flows)): + # print(pred_flows[i].shape) + combined_flow = pred_flows[i] * masks[i] + gt_flows[i] * (1-masks[i]) + l1_loss = self.l1_criterion(pred_flows[i] * masks[i], gt_flows[i] * masks[i]) / torch.mean(masks[i]) + l1_loss += self.l1_criterion(pred_flows[i] * (1-masks[i]), gt_flows[i] * (1-masks[i])) / torch.mean((1-masks[i])) + + smooth_loss = smoothness_loss(combined_flow.reshape(-1,2,h,w), masks[i].reshape(-1,1,h,w)) + smooth_loss2 = second_order_loss(combined_flow.reshape(-1,2,h,w), masks[i].reshape(-1,1,h,w)) + + warp_loss_i = ternary_loss(combined_flow.reshape(-1,2,h,w), gt_flows[i].reshape(-1,2,h,w), + masks[i].reshape(-1,1,h,w), current_frames[i].reshape(-1,3,h,w), next_frames[i].reshape(-1,3,h,w)) + + loss += l1_loss + smooth_loss + smooth_loss2 + + warp_loss += warp_loss_i + + return loss, warp_loss + + +def edgeLoss(preds_edges, edges): + """ + + Args: + preds_edges: with shape [b, c, h , w] + edges: with shape [b, c, h, w] + + Returns: Edge losses + + """ + mask = (edges > 0.5).float() + b, c, h, w = mask.shape + num_pos = torch.sum(mask, dim=[1, 2, 3]).float() # Shape: [b,]. + num_neg = c * h * w - num_pos # Shape: [b,]. + neg_weights = (num_neg / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2).unsqueeze(3) + pos_weights = (num_pos / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2).unsqueeze(3) + weight = neg_weights * mask + pos_weights * (1 - mask) # weight for debug + losses = F.binary_cross_entropy_with_logits(preds_edges.float(), edges.float(), weight=weight, reduction='none') + loss = torch.mean(losses) + return loss + +class EdgeLoss(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, pred_edges, gt_edges, masks): + # pred_flows: b t-1 1 h w + loss = 0 + h, w = pred_edges[0].shape[-2:] + masks = [masks[:,:-1,...].contiguous(), masks[:, 1:, ...].contiguous()] + for i in range(len(pred_edges)): + # print(f'edges_{i}', torch.sum(gt_edges[i])) # debug + combined_edge = pred_edges[i] * masks[i] + gt_edges[i] * (1-masks[i]) + edge_loss = (edgeLoss(pred_edges[i].reshape(-1,1,h,w), gt_edges[i].reshape(-1,1,h,w)) \ + + 5 * edgeLoss(combined_edge.reshape(-1,1,h,w), gt_edges[i].reshape(-1,1,h,w))) + loss += edge_loss + + return loss + + +class FlowSimpleLoss(nn.Module): + def __init__(self): + super().__init__() + self.l1_criterion = nn.L1Loss() + + def forward(self, pred_flows, gt_flows): + # pred_flows: b t-1 2 h w + loss = 0 + h, w = pred_flows[0].shape[-2:] + h_orig, w_orig = gt_flows[0].shape[-2:] + pred_flows = [f.view(-1, 2, h, w) for f in pred_flows] + gt_flows = [f.view(-1, 2, h_orig, w_orig) for f in gt_flows] + + ds_factor = 1.0*h/h_orig + gt_flows = [F.interpolate(f, scale_factor=ds_factor, mode='area') * ds_factor for f in gt_flows] + for i in range(len(pred_flows)): + loss += self.l1_criterion(pred_flows[i], gt_flows[i]) + + return loss \ No newline at end of file diff --git a/model/modules/flow_loss_utils.py b/model/modules/flow_loss_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..6e465c0605df760920b5cfc7f9079fadb74fbec1 --- /dev/null +++ b/model/modules/flow_loss_utils.py @@ -0,0 +1,142 @@ +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +def flow_warp(x, + flow, + interpolation='bilinear', + padding_mode='zeros', + align_corners=True): + """Warp an image or a feature map with optical flow. + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is + a two-channel, denoting the width and height relative offsets. + Note that the values are not normalized to [-1, 1]. + interpolation (str): Interpolation mode: 'nearest' or 'bilinear'. + Default: 'bilinear'. + padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Whether align corners. Default: True. + Returns: + Tensor: Warped image or feature map. + """ + if x.size()[-2:] != flow.size()[1:3]: + raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and ' + f'flow ({flow.size()[1:3]}) are not the same.') + _, _, h, w = x.size() + # create mesh grid + device = flow.device + grid_y, grid_x = torch.meshgrid(torch.arange(0, h, device=device), torch.arange(0, w, device=device)) + grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (w, h, 2) + grid.requires_grad = False + + grid_flow = grid + flow + # scale grid_flow to [-1,1] + grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0 + grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0 + grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3) + output = F.grid_sample(x, + grid_flow, + mode=interpolation, + padding_mode=padding_mode, + align_corners=align_corners) + return output + + +# def image_warp(image, flow): +# b, c, h, w = image.size() +# device = image.device +# flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0), flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1) # normalize to [-1~1](from upper left to lower right +# flow = flow.permute(0, 2, 3, 1) # if you wanna use grid_sample function, the channel(band) shape of show must be in the last dimension +# x = np.linspace(-1, 1, w) +# y = np.linspace(-1, 1, h) +# X, Y = np.meshgrid(x, y) +# grid = torch.cat((torch.from_numpy(X.astype('float32')).unsqueeze(0).unsqueeze(3), +# torch.from_numpy(Y.astype('float32')).unsqueeze(0).unsqueeze(3)), 3).to(device) +# output = torch.nn.functional.grid_sample(image, grid + flow, mode='bilinear', padding_mode='zeros') +# return output + + +def length_sq(x): + return torch.sum(torch.square(x), dim=1, keepdim=True) + + +def fbConsistencyCheck(flow_fw, flow_bw, alpha1=0.01, alpha2=0.5): + flow_bw_warped = flow_warp(flow_bw, flow_fw.permute(0, 2, 3, 1)) # wb(wf(x)) + flow_fw_warped = flow_warp(flow_fw, flow_bw.permute(0, 2, 3, 1)) # wf(wb(x)) + flow_diff_fw = flow_fw + flow_bw_warped # wf + wb(wf(x)) + flow_diff_bw = flow_bw + flow_fw_warped # wb + wf(wb(x)) + + mag_sq_fw = length_sq(flow_fw) + length_sq(flow_bw_warped) # |wf| + |wb(wf(x))| + mag_sq_bw = length_sq(flow_bw) + length_sq(flow_fw_warped) # |wb| + |wf(wb(x))| + occ_thresh_fw = alpha1 * mag_sq_fw + alpha2 + occ_thresh_bw = alpha1 * mag_sq_bw + alpha2 + + fb_occ_fw = (length_sq(flow_diff_fw) > occ_thresh_fw).float() + fb_occ_bw = (length_sq(flow_diff_bw) > occ_thresh_bw).float() + + return fb_occ_fw, fb_occ_bw # fb_occ_fw -> frame2 area occluded by frame1, fb_occ_bw -> frame1 area occluded by frame2 + + +def rgb2gray(image): + gray_image = image[:, 0] * 0.299 + image[:, 1] * 0.587 + 0.110 * image[:, 2] + gray_image = gray_image.unsqueeze(1) + return gray_image + + +def ternary_transform(image, max_distance=1): + device = image.device + patch_size = 2 * max_distance + 1 + intensities = rgb2gray(image) * 255 + out_channels = patch_size * patch_size + w = np.eye(out_channels).reshape(out_channels, 1, patch_size, patch_size) + weights = torch.from_numpy(w).float().to(device) + patches = F.conv2d(intensities, weights, stride=1, padding=1) + transf = patches - intensities + transf_norm = transf / torch.sqrt(0.81 + torch.square(transf)) + return transf_norm + + +def hamming_distance(t1, t2): + dist = torch.square(t1 - t2) + dist_norm = dist / (0.1 + dist) + dist_sum = torch.sum(dist_norm, dim=1, keepdim=True) + return dist_sum + + +def create_mask(mask, paddings): + """ + padding: [[top, bottom], [left, right]] + """ + shape = mask.shape + inner_height = shape[2] - (paddings[0][0] + paddings[0][1]) + inner_width = shape[3] - (paddings[1][0] + paddings[1][1]) + inner = torch.ones([inner_height, inner_width]) + + mask2d = F.pad(inner, pad=[paddings[1][0], paddings[1][1], paddings[0][0], paddings[0][1]]) + mask3d = mask2d.unsqueeze(0) + mask4d = mask3d.unsqueeze(0).repeat(shape[0], 1, 1, 1) + return mask4d.detach() + + +def ternary_loss2(frame1, warp_frame21, confMask, masks, max_distance=1): + """ + + Args: + frame1: torch tensor, with shape [b * t, c, h, w] + warp_frame21: torch tensor, with shape [b * t, c, h, w] + confMask: confidence mask, with shape [b * t, c, h, w] + masks: torch tensor, with shape [b * t, c, h, w] + max_distance: maximum distance. + + Returns: ternary loss + + """ + t1 = ternary_transform(frame1) + t21 = ternary_transform(warp_frame21) + dist = hamming_distance(t1, t21) + loss = torch.mean(dist * confMask * masks) / torch.mean(masks) + return loss + diff --git a/model/modules/sparse_transformer.py b/model/modules/sparse_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..11028ffe05a0f59e9d222b0a18f92b1fde12007b --- /dev/null +++ b/model/modules/sparse_transformer.py @@ -0,0 +1,344 @@ +import math +from functools import reduce +import torch +import torch.nn as nn +import torch.nn.functional as F + +class SoftSplit(nn.Module): + def __init__(self, channel, hidden, kernel_size, stride, padding): + super(SoftSplit, self).__init__() + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.t2t = nn.Unfold(kernel_size=kernel_size, + stride=stride, + padding=padding) + c_in = reduce((lambda x, y: x * y), kernel_size) * channel + self.embedding = nn.Linear(c_in, hidden) + + def forward(self, x, b, output_size): + f_h = int((output_size[0] + 2 * self.padding[0] - + (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + f_w = int((output_size[1] + 2 * self.padding[1] - + (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) + + feat = self.t2t(x) + feat = feat.permute(0, 2, 1) + # feat shape [b*t, num_vec, ks*ks*c] + feat = self.embedding(feat) + # feat shape after embedding [b, t*num_vec, hidden] + feat = feat.view(b, -1, f_h, f_w, feat.size(2)) + return feat + + +class SoftComp(nn.Module): + def __init__(self, channel, hidden, kernel_size, stride, padding): + super(SoftComp, self).__init__() + self.relu = nn.LeakyReLU(0.2, inplace=True) + c_out = reduce((lambda x, y: x * y), kernel_size) * channel + self.embedding = nn.Linear(hidden, c_out) + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.bias_conv = nn.Conv2d(channel, + channel, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t, output_size): + b_, _, _, _, c_ = x.shape + x = x.view(b_, -1, c_) + feat = self.embedding(x) + b, _, c = feat.size() + feat = feat.view(b * t, -1, c).permute(0, 2, 1) + feat = F.fold(feat, + output_size=output_size, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding) + feat = self.bias_conv(feat) + return feat + + +class FusionFeedForward(nn.Module): + def __init__(self, dim, hidden_dim=1960, t2t_params=None): + super(FusionFeedForward, self).__init__() + # We set hidden_dim as a default to 1960 + self.fc1 = nn.Sequential(nn.Linear(dim, hidden_dim)) + self.fc2 = nn.Sequential(nn.GELU(), nn.Linear(hidden_dim, dim)) + assert t2t_params is not None + self.t2t_params = t2t_params + self.kernel_shape = reduce((lambda x, y: x * y), t2t_params['kernel_size']) # 49 + + def forward(self, x, output_size): + n_vecs = 1 + for i, d in enumerate(self.t2t_params['kernel_size']): + n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] - + (d - 1) - 1) / self.t2t_params['stride'][i] + 1) + + x = self.fc1(x) + b, n, c = x.size() + normalizer = x.new_ones(b, n, self.kernel_shape).view(-1, n_vecs, self.kernel_shape).permute(0, 2, 1) + normalizer = F.fold(normalizer, + output_size=output_size, + kernel_size=self.t2t_params['kernel_size'], + padding=self.t2t_params['padding'], + stride=self.t2t_params['stride']) + + x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1), + output_size=output_size, + kernel_size=self.t2t_params['kernel_size'], + padding=self.t2t_params['padding'], + stride=self.t2t_params['stride']) + + x = F.unfold(x / normalizer, + kernel_size=self.t2t_params['kernel_size'], + padding=self.t2t_params['padding'], + stride=self.t2t_params['stride']).permute( + 0, 2, 1).contiguous().view(b, n, c) + x = self.fc2(x) + return x + + +def window_partition(x, window_size, n_head): + """ + Args: + x: shape is (B, T, H, W, C) + window_size (tuple[int]): window size + Returns: + windows: (B, num_windows_h, num_windows_w, n_head, T, window_size, window_size, C//n_head) + """ + B, T, H, W, C = x.shape + x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1], window_size[1], n_head, C//n_head) + windows = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous() + return windows + +class SparseWindowAttention(nn.Module): + def __init__(self, dim, n_head, window_size, pool_size=(4,4), qkv_bias=True, attn_drop=0., proj_drop=0., + pooling_token=True): + super().__init__() + assert dim % n_head == 0 + # key, query, value projections for all heads + self.key = nn.Linear(dim, dim, qkv_bias) + self.query = nn.Linear(dim, dim, qkv_bias) + self.value = nn.Linear(dim, dim, qkv_bias) + # regularization + self.attn_drop = nn.Dropout(attn_drop) + self.proj_drop = nn.Dropout(proj_drop) + # output projection + self.proj = nn.Linear(dim, dim) + self.n_head = n_head + self.window_size = window_size + self.pooling_token = pooling_token + if self.pooling_token: + ks, stride = pool_size, pool_size + self.pool_layer = nn.Conv2d(dim, dim, kernel_size=ks, stride=stride, padding=(0, 0), groups=dim) + self.pool_layer.weight.data.fill_(1. / (pool_size[0] * pool_size[1])) + self.pool_layer.bias.data.fill_(0) + # self.expand_size = tuple(i // 2 for i in window_size) + self.expand_size = tuple((i + 1) // 2 for i in window_size) + + if any(i > 0 for i in self.expand_size): + # get mask for rolled k and rolled v + mask_tl = torch.ones(self.window_size[0], self.window_size[1]) + mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0 + mask_tr = torch.ones(self.window_size[0], self.window_size[1]) + mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0 + mask_bl = torch.ones(self.window_size[0], self.window_size[1]) + mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0 + mask_br = torch.ones(self.window_size[0], self.window_size[1]) + mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0 + masrool_k = torch.stack((mask_tl, mask_tr, mask_bl, mask_br), 0).flatten(0) + self.register_buffer("valid_ind_rolled", masrool_k.nonzero(as_tuple=False).view(-1)) + + self.max_pool = nn.MaxPool2d(window_size, window_size, (0, 0)) + + + def forward(self, x, mask=None, T_ind=None, attn_mask=None): + b, t, h, w, c = x.shape # 20 36 + w_h, w_w = self.window_size[0], self.window_size[1] + c_head = c // self.n_head + n_wh = math.ceil(h / self.window_size[0]) + n_ww = math.ceil(w / self.window_size[1]) + new_h = n_wh * self.window_size[0] # 20 + new_w = n_ww * self.window_size[1] # 36 + pad_r = new_w - w + pad_b = new_h - h + # reverse order + if pad_r > 0 or pad_b > 0: + x = F.pad(x,(0, 0, 0, pad_r, 0, pad_b, 0, 0), mode='constant', value=0) + mask = F.pad(mask,(0, 0, 0, pad_r, 0, pad_b, 0, 0), mode='constant', value=0) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q = self.query(x) + k = self.key(x) + v = self.value(x) + win_q = window_partition(q.contiguous(), self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head) + win_k = window_partition(k.contiguous(), self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head) + win_v = window_partition(v.contiguous(), self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head) + # roll_k and roll_v + if any(i > 0 for i in self.expand_size): + (k_tl, v_tl) = map(lambda a: torch.roll(a, shifts=(-self.expand_size[0], -self.expand_size[1]), dims=(2, 3)), (k, v)) + (k_tr, v_tr) = map(lambda a: torch.roll(a, shifts=(-self.expand_size[0], self.expand_size[1]), dims=(2, 3)), (k, v)) + (k_bl, v_bl) = map(lambda a: torch.roll(a, shifts=(self.expand_size[0], -self.expand_size[1]), dims=(2, 3)), (k, v)) + (k_br, v_br) = map(lambda a: torch.roll(a, shifts=(self.expand_size[0], self.expand_size[1]), dims=(2, 3)), (k, v)) + + (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map( + lambda a: window_partition(a, self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head), + (k_tl, k_tr, k_bl, k_br)) + (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map( + lambda a: window_partition(a, self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head), + (v_tl, v_tr, v_bl, v_br)) + rool_k = torch.cat((k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows), 4).contiguous() + rool_v = torch.cat((v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows), 4).contiguous() # [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head] + # mask out tokens in current window + rool_k = rool_k[:, :, :, :, self.valid_ind_rolled] + rool_v = rool_v[:, :, :, :, self.valid_ind_rolled] + roll_N = rool_k.shape[4] + rool_k = rool_k.view(b, n_wh*n_ww, self.n_head, t, roll_N, c // self.n_head) + rool_v = rool_v.view(b, n_wh*n_ww, self.n_head, t, roll_N, c // self.n_head) + win_k = torch.cat((win_k, rool_k), dim=4) + win_v = torch.cat((win_v, rool_v), dim=4) + else: + win_k = win_k + win_v = win_v + + # pool_k and pool_v + if self.pooling_token: + pool_x = self.pool_layer(x.view(b*t, new_h, new_w, c).permute(0,3,1,2)) + _, _, p_h, p_w = pool_x.shape + pool_x = pool_x.permute(0,2,3,1).view(b, t, p_h, p_w, c) + # pool_k + pool_k = self.key(pool_x).unsqueeze(1).repeat(1, n_wh*n_ww, 1, 1, 1, 1) # [b, n_wh*n_ww, t, p_h, p_w, c] + pool_k = pool_k.view(b, n_wh*n_ww, t, p_h, p_w, self.n_head, c_head).permute(0,1,5,2,3,4,6) + pool_k = pool_k.contiguous().view(b, n_wh*n_ww, self.n_head, t, p_h*p_w, c_head) + win_k = torch.cat((win_k, pool_k), dim=4) + # pool_v + pool_v = self.value(pool_x).unsqueeze(1).repeat(1, n_wh*n_ww, 1, 1, 1, 1) # [b, n_wh*n_ww, t, p_h, p_w, c] + pool_v = pool_v.view(b, n_wh*n_ww, t, p_h, p_w, self.n_head, c_head).permute(0,1,5,2,3,4,6) + pool_v = pool_v.contiguous().view(b, n_wh*n_ww, self.n_head, t, p_h*p_w, c_head) + win_v = torch.cat((win_v, pool_v), dim=4) + + # [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head] + out = torch.zeros_like(win_q) + l_t = mask.size(1) + + mask = self.max_pool(mask.view(b * l_t, new_h, new_w)) + mask = mask.view(b, l_t, n_wh*n_ww) + mask = torch.sum(mask, dim=1) # [b, n_wh*n_ww] + for i in range(win_q.shape[0]): + ### For masked windows + mask_ind_i = mask[i].nonzero(as_tuple=False).view(-1) + # mask out quary in current window + # [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head] + mask_n = len(mask_ind_i) + if mask_n > 0: + win_q_t = win_q[i, mask_ind_i].view(mask_n, self.n_head, t*w_h*w_w, c_head) + win_k_t = win_k[i, mask_ind_i] + win_v_t = win_v[i, mask_ind_i] + # mask out key and value + if T_ind is not None: + # key [n_wh*n_ww, n_head, t, w_h*w_w, c_head] + win_k_t = win_k_t[:, :, T_ind.view(-1)].view(mask_n, self.n_head, -1, c_head) + # value + win_v_t = win_v_t[:, :, T_ind.view(-1)].view(mask_n, self.n_head, -1, c_head) + else: + win_k_t = win_k_t.view(n_wh*n_ww, self.n_head, t*w_h*w_w, c_head) + win_v_t = win_v_t.view(n_wh*n_ww, self.n_head, t*w_h*w_w, c_head) + + att_t = (win_q_t @ win_k_t.transpose(-2, -1)) * (1.0 / math.sqrt(win_q_t.size(-1))) + att_t = F.softmax(att_t, dim=-1) + att_t = self.attn_drop(att_t) + y_t = att_t @ win_v_t + + out[i, mask_ind_i] = y_t.view(-1, self.n_head, t, w_h*w_w, c_head) + + ### For unmasked windows + unmask_ind_i = (mask[i] == 0).nonzero(as_tuple=False).view(-1) + # mask out quary in current window + # [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head] + win_q_s = win_q[i, unmask_ind_i] + win_k_s = win_k[i, unmask_ind_i, :, :, :w_h*w_w] + win_v_s = win_v[i, unmask_ind_i, :, :, :w_h*w_w] + + att_s = (win_q_s @ win_k_s.transpose(-2, -1)) * (1.0 / math.sqrt(win_q_s.size(-1))) + att_s = F.softmax(att_s, dim=-1) + att_s = self.attn_drop(att_s) + y_s = att_s @ win_v_s + out[i, unmask_ind_i] = y_s + + # re-assemble all head outputs side by side + out = out.view(b, n_wh, n_ww, self.n_head, t, w_h, w_w, c_head) + out = out.permute(0, 4, 1, 5, 2, 6, 3, 7).contiguous().view(b, t, new_h, new_w, c) + + + if pad_r > 0 or pad_b > 0: + out = out[:, :, :h, :w, :] + + # output projection + out = self.proj_drop(self.proj(out)) + return out + + +class TemporalSparseTransformer(nn.Module): + def __init__(self, dim, n_head, window_size, pool_size, + norm_layer=nn.LayerNorm, t2t_params=None): + super().__init__() + self.window_size = window_size + self.attention = SparseWindowAttention(dim, n_head, window_size, pool_size) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + self.mlp = FusionFeedForward(dim, t2t_params=t2t_params) + + def forward(self, x, fold_x_size, mask=None, T_ind=None): + """ + Args: + x: image tokens, shape [B T H W C] + fold_x_size: fold feature size, shape [60 108] + mask: mask tokens, shape [B T H W 1] + Returns: + out_tokens: shape [B T H W C] + """ + B, T, H, W, C = x.shape # 20 36 + + shortcut = x + x = self.norm1(x) + att_x = self.attention(x, mask, T_ind) + + # FFN + x = shortcut + att_x + y = self.norm2(x) + x = x + self.mlp(y.view(B, T * H * W, C), fold_x_size).view(B, T, H, W, C) + + return x + + +class TemporalSparseTransformerBlock(nn.Module): + def __init__(self, dim, n_head, window_size, pool_size, depths, t2t_params=None): + super().__init__() + blocks = [] + for i in range(depths): + blocks.append( + TemporalSparseTransformer(dim, n_head, window_size, pool_size, t2t_params=t2t_params) + ) + self.transformer = nn.Sequential(*blocks) + self.depths = depths + + def forward(self, x, fold_x_size, l_mask=None, t_dilation=2): + """ + Args: + x: image tokens, shape [B T H W C] + fold_x_size: fold feature size, shape [60 108] + l_mask: local mask tokens, shape [B T H W 1] + Returns: + out_tokens: shape [B T H W C] + """ + assert self.depths % t_dilation == 0, 'wrong t_dilation input.' + T = x.size(1) + T_ind = [torch.arange(i, T, t_dilation) for i in range(t_dilation)] * (self.depths // t_dilation) + + for i in range(0, self.depths): + x = self.transformer[i](x, fold_x_size, l_mask, T_ind[i]) + + return x diff --git a/model/modules/spectral_norm.py b/model/modules/spectral_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..f38c34e98c03caa28ce0b15a4083215fb7d8e9af --- /dev/null +++ b/model/modules/spectral_norm.py @@ -0,0 +1,288 @@ +""" +Spectral Normalization from https://arxiv.org/abs/1802.05957 +""" +import torch +from torch.nn.functional import normalize + + +class SpectralNorm(object): + # Invariant before and after each forward call: + # u = normalize(W @ v) + # NB: At initialization, this invariant is not enforced + + _version = 1 + + # At version 1: + # made `W` not a buffer, + # added `v` as a buffer, and + # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`. + + def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12): + self.name = name + self.dim = dim + if n_power_iterations <= 0: + raise ValueError( + 'Expected n_power_iterations to be positive, but ' + 'got n_power_iterations={}'.format(n_power_iterations)) + self.n_power_iterations = n_power_iterations + self.eps = eps + + def reshape_weight_to_matrix(self, weight): + weight_mat = weight + if self.dim != 0: + # permute dim to front + weight_mat = weight_mat.permute( + self.dim, + *[d for d in range(weight_mat.dim()) if d != self.dim]) + height = weight_mat.size(0) + return weight_mat.reshape(height, -1) + + def compute_weight(self, module, do_power_iteration): + # NB: If `do_power_iteration` is set, the `u` and `v` vectors are + # updated in power iteration **in-place**. This is very important + # because in `DataParallel` forward, the vectors (being buffers) are + # broadcast from the parallelized module to each module replica, + # which is a new module object created on the fly. And each replica + # runs its own spectral norm power iteration. So simply assigning + # the updated vectors to the module this function runs on will cause + # the update to be lost forever. And the next time the parallelized + # module is replicated, the same randomly initialized vectors are + # broadcast and used! + # + # Therefore, to make the change propagate back, we rely on two + # important behaviors (also enforced via tests): + # 1. `DataParallel` doesn't clone storage if the broadcast tensor + # is already on correct device; and it makes sure that the + # parallelized module is already on `device[0]`. + # 2. If the out tensor in `out=` kwarg has correct shape, it will + # just fill in the values. + # Therefore, since the same power iteration is performed on all + # devices, simply updating the tensors in-place will make sure that + # the module replica on `device[0]` will update the _u vector on the + # parallized module (by shared storage). + # + # However, after we update `u` and `v` in-place, we need to **clone** + # them before using them to normalize the weight. This is to support + # backproping through two forward passes, e.g., the common pattern in + # GAN training: loss = D(real) - D(fake). Otherwise, engine will + # complain that variables needed to do backward for the first forward + # (i.e., the `u` and `v` vectors) are changed in the second forward. + weight = getattr(module, self.name + '_orig') + u = getattr(module, self.name + '_u') + v = getattr(module, self.name + '_v') + weight_mat = self.reshape_weight_to_matrix(weight) + + if do_power_iteration: + with torch.no_grad(): + for _ in range(self.n_power_iterations): + # Spectral norm of weight equals to `u^T W v`, where `u` and `v` + # are the first left and right singular vectors. + # This power iteration produces approximations of `u` and `v`. + v = normalize(torch.mv(weight_mat.t(), u), + dim=0, + eps=self.eps, + out=v) + u = normalize(torch.mv(weight_mat, v), + dim=0, + eps=self.eps, + out=u) + if self.n_power_iterations > 0: + # See above on why we need to clone + u = u.clone() + v = v.clone() + + sigma = torch.dot(u, torch.mv(weight_mat, v)) + weight = weight / sigma + return weight + + def remove(self, module): + with torch.no_grad(): + weight = self.compute_weight(module, do_power_iteration=False) + delattr(module, self.name) + delattr(module, self.name + '_u') + delattr(module, self.name + '_v') + delattr(module, self.name + '_orig') + module.register_parameter(self.name, + torch.nn.Parameter(weight.detach())) + + def __call__(self, module, inputs): + setattr( + module, self.name, + self.compute_weight(module, do_power_iteration=module.training)) + + def _solve_v_and_rescale(self, weight_mat, u, target_sigma): + # Tries to returns a vector `v` s.t. `u = normalize(W @ v)` + # (the invariant at top of this class) and `u @ W @ v = sigma`. + # This uses pinverse in case W^T W is not invertible. + v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(), + weight_mat.t(), u.unsqueeze(1)).squeeze(1) + return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v))) + + @staticmethod + def apply(module, name, n_power_iterations, dim, eps): + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, SpectralNorm) and hook.name == name: + raise RuntimeError( + "Cannot register two spectral_norm hooks on " + "the same parameter {}".format(name)) + + fn = SpectralNorm(name, n_power_iterations, dim, eps) + weight = module._parameters[name] + + with torch.no_grad(): + weight_mat = fn.reshape_weight_to_matrix(weight) + + h, w = weight_mat.size() + # randomly initialize `u` and `v` + u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps) + v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps) + + delattr(module, fn.name) + module.register_parameter(fn.name + "_orig", weight) + # We still need to assign weight back as fn.name because all sorts of + # things may assume that it exists, e.g., when initializing weights. + # However, we can't directly assign as it could be an nn.Parameter and + # gets added as a parameter. Instead, we register weight.data as a plain + # attribute. + setattr(module, fn.name, weight.data) + module.register_buffer(fn.name + "_u", u) + module.register_buffer(fn.name + "_v", v) + + module.register_forward_pre_hook(fn) + + module._register_state_dict_hook(SpectralNormStateDictHook(fn)) + module._register_load_state_dict_pre_hook( + SpectralNormLoadStateDictPreHook(fn)) + return fn + + +# This is a top level class because Py2 pickle doesn't like inner class nor an +# instancemethod. +class SpectralNormLoadStateDictPreHook(object): + # See docstring of SpectralNorm._version on the changes to spectral_norm. + def __init__(self, fn): + self.fn = fn + + # For state_dict with version None, (assuming that it has gone through at + # least one training forward), we have + # + # u = normalize(W_orig @ v) + # W = W_orig / sigma, where sigma = u @ W_orig @ v + # + # To compute `v`, we solve `W_orig @ x = u`, and let + # v = x / (u @ W_orig @ x) * (W / W_orig). + def __call__(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + fn = self.fn + version = local_metadata.get('spectral_norm', + {}).get(fn.name + '.version', None) + if version is None or version < 1: + with torch.no_grad(): + weight_orig = state_dict[prefix + fn.name + '_orig'] + # weight = state_dict.pop(prefix + fn.name) + # sigma = (weight_orig / weight).mean() + weight_mat = fn.reshape_weight_to_matrix(weight_orig) + u = state_dict[prefix + fn.name + '_u'] + # v = fn._solve_v_and_rescale(weight_mat, u, sigma) + # state_dict[prefix + fn.name + '_v'] = v + + +# This is a top level class because Py2 pickle doesn't like inner class nor an +# instancemethod. +class SpectralNormStateDictHook(object): + # See docstring of SpectralNorm._version on the changes to spectral_norm. + def __init__(self, fn): + self.fn = fn + + def __call__(self, module, state_dict, prefix, local_metadata): + if 'spectral_norm' not in local_metadata: + local_metadata['spectral_norm'] = {} + key = self.fn.name + '.version' + if key in local_metadata['spectral_norm']: + raise RuntimeError( + "Unexpected key in metadata['spectral_norm']: {}".format(key)) + local_metadata['spectral_norm'][key] = self.fn._version + + +def spectral_norm(module, + name='weight', + n_power_iterations=1, + eps=1e-12, + dim=None): + r"""Applies spectral normalization to a parameter in the given module. + + .. math:: + \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, + \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} + + Spectral normalization stabilizes the training of discriminators (critics) + in Generative Adversarial Networks (GANs) by rescaling the weight tensor + with spectral norm :math:`\sigma` of the weight matrix calculated using + power iteration method. If the dimension of the weight tensor is greater + than 2, it is reshaped to 2D in power iteration method to get spectral + norm. This is implemented via a hook that calculates spectral norm and + rescales weight before every :meth:`~Module.forward` call. + + See `Spectral Normalization for Generative Adversarial Networks`_ . + + .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 + + Args: + module (nn.Module): containing module + name (str, optional): name of weight parameter + n_power_iterations (int, optional): number of power iterations to + calculate spectral norm + eps (float, optional): epsilon for numerical stability in + calculating norms + dim (int, optional): dimension corresponding to number of outputs, + the default is ``0``, except for modules that are instances of + ConvTranspose{1,2,3}d, when it is ``1`` + + Returns: + The original module with the spectral norm hook + + Example:: + + >>> m = spectral_norm(nn.Linear(20, 40)) + >>> m + Linear(in_features=20, out_features=40, bias=True) + >>> m.weight_u.size() + torch.Size([40]) + + """ + if dim is None: + if isinstance(module, + (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d)): + dim = 1 + else: + dim = 0 + SpectralNorm.apply(module, name, n_power_iterations, dim, eps) + return module + + +def remove_spectral_norm(module, name='weight'): + r"""Removes the spectral normalization reparameterization from a module. + + Args: + module (Module): containing module + name (str, optional): name of weight parameter + + Example: + >>> m = spectral_norm(nn.Linear(40, 10)) + >>> remove_spectral_norm(m) + """ + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, SpectralNorm) and hook.name == name: + hook.remove(module) + del module._forward_pre_hooks[k] + return module + + raise ValueError("spectral_norm of '{}' not found in {}".format( + name, module)) + + +def use_spectral_norm(module, use_sn=False): + if use_sn: + return spectral_norm(module) + return module \ No newline at end of file diff --git a/model/propainter.py b/model/propainter.py new file mode 100644 index 0000000000000000000000000000000000000000..505a19fc17c047e0def8a309a2a32e9ecff759bf --- /dev/null +++ b/model/propainter.py @@ -0,0 +1,532 @@ +''' Towards An End-to-End Framework for Video Inpainting +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +from einops import rearrange + +from model.modules.base_module import BaseNetwork +from model.modules.sparse_transformer import TemporalSparseTransformerBlock, SoftSplit, SoftComp +from model.modules.spectral_norm import spectral_norm as _spectral_norm +from model.modules.flow_loss_utils import flow_warp +from model.modules.deformconv import ModulatedDeformConv2d + +from .misc import constant_init + +def length_sq(x): + return torch.sum(torch.square(x), dim=1, keepdim=True) + +def fbConsistencyCheck(flow_fw, flow_bw, alpha1=0.01, alpha2=0.5): + flow_bw_warped = flow_warp(flow_bw, flow_fw.permute(0, 2, 3, 1)) # wb(wf(x)) + flow_diff_fw = flow_fw + flow_bw_warped # wf + wb(wf(x)) + + mag_sq_fw = length_sq(flow_fw) + length_sq(flow_bw_warped) # |wf| + |wb(wf(x))| + occ_thresh_fw = alpha1 * mag_sq_fw + alpha2 + + # fb_valid_fw = (length_sq(flow_diff_fw) < occ_thresh_fw).float() + fb_valid_fw = (length_sq(flow_diff_fw) < occ_thresh_fw).to(flow_fw) + return fb_valid_fw + + +class DeformableAlignment(ModulatedDeformConv2d): + """Second-order deformable alignment module.""" + def __init__(self, *args, **kwargs): + # self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10) + self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 3) + + super(DeformableAlignment, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Sequential( + nn.Conv2d(2*self.out_channels + 2 + 1 + 2, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1), + ) + self.init_offset() + + def init_offset(self): + constant_init(self.conv_offset[-1], val=0, bias=0) + + def forward(self, x, cond_feat, flow): + out = self.conv_offset(cond_feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + + # offset + offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1)) + offset = offset + flow.flip(1).repeat(1, offset.size(1) // 2, 1, 1) + + # mask + mask = torch.sigmoid(mask) + + return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, + self.stride, self.padding, + self.dilation, mask) + + +class BidirectionalPropagation(nn.Module): + def __init__(self, channel, learnable=True): + super(BidirectionalPropagation, self).__init__() + self.deform_align = nn.ModuleDict() + self.backbone = nn.ModuleDict() + self.channel = channel + self.prop_list = ['backward_1', 'forward_1'] + self.learnable = learnable + + if self.learnable: + for i, module in enumerate(self.prop_list): + self.deform_align[module] = DeformableAlignment( + channel, channel, 3, padding=1, deform_groups=16) + + self.backbone[module] = nn.Sequential( + nn.Conv2d(2*channel+2, channel, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(channel, channel, 3, 1, 1), + ) + + self.fuse = nn.Sequential( + nn.Conv2d(2*channel+2, channel, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(channel, channel, 3, 1, 1), + ) + + def binary_mask(self, mask, th=0.1): + mask[mask>th] = 1 + mask[mask<=th] = 0 + # return mask.float() + return mask.to(mask) + + def forward(self, x, flows_forward, flows_backward, mask, interpolation='bilinear'): + """ + x shape : [b, t, c, h, w] + return [b, t, c, h, w] + """ + + # For backward warping + # pred_flows_forward for backward feature propagation + # pred_flows_backward for forward feature propagation + b, t, c, h, w = x.shape + feats, masks = {}, {} + feats['input'] = [x[:, i, :, :, :] for i in range(0, t)] + masks['input'] = [mask[:, i, :, :, :] for i in range(0, t)] + + prop_list = ['backward_1', 'forward_1'] + cache_list = ['input'] + prop_list + + for p_i, module_name in enumerate(prop_list): + feats[module_name] = [] + masks[module_name] = [] + + if 'backward' in module_name: + frame_idx = range(0, t) + frame_idx = frame_idx[::-1] + flow_idx = frame_idx + flows_for_prop = flows_forward + flows_for_check = flows_backward + else: + frame_idx = range(0, t) + flow_idx = range(-1, t - 1) + flows_for_prop = flows_backward + flows_for_check = flows_forward + + for i, idx in enumerate(frame_idx): + feat_current = feats[cache_list[p_i]][idx] + mask_current = masks[cache_list[p_i]][idx] + + if i == 0: + feat_prop = feat_current + mask_prop = mask_current + else: + flow_prop = flows_for_prop[:, flow_idx[i], :, :, :] + flow_check = flows_for_check[:, flow_idx[i], :, :, :] + flow_vaild_mask = fbConsistencyCheck(flow_prop, flow_check) + feat_warped = flow_warp(feat_prop, flow_prop.permute(0, 2, 3, 1), interpolation) + + if self.learnable: + cond = torch.cat([feat_current, feat_warped, flow_prop, flow_vaild_mask, mask_current], dim=1) + feat_prop = self.deform_align[module_name](feat_prop, cond, flow_prop) + mask_prop = mask_current + else: + mask_prop_valid = flow_warp(mask_prop, flow_prop.permute(0, 2, 3, 1)) + mask_prop_valid = self.binary_mask(mask_prop_valid) + + union_vaild_mask = self.binary_mask(mask_current*flow_vaild_mask*(1-mask_prop_valid)) + feat_prop = union_vaild_mask * feat_warped + (1-union_vaild_mask) * feat_current + # update mask + mask_prop = self.binary_mask(mask_current*(1-(flow_vaild_mask*(1-mask_prop_valid)))) + + # refine + if self.learnable: + feat = torch.cat([feat_current, feat_prop, mask_current], dim=1) + feat_prop = feat_prop + self.backbone[module_name](feat) + # feat_prop = self.backbone[module_name](feat_prop) + + feats[module_name].append(feat_prop) + masks[module_name].append(mask_prop) + + # end for + if 'backward' in module_name: + feats[module_name] = feats[module_name][::-1] + masks[module_name] = masks[module_name][::-1] + + outputs_b = torch.stack(feats['backward_1'], dim=1).view(-1, c, h, w) + outputs_f = torch.stack(feats['forward_1'], dim=1).view(-1, c, h, w) + + if self.learnable: + mask_in = mask.view(-1, 2, h, w) + masks_b, masks_f = None, None + outputs = self.fuse(torch.cat([outputs_b, outputs_f, mask_in], dim=1)) + x.view(-1, c, h, w) + else: + masks_b = torch.stack(masks['backward_1'], dim=1) + masks_f = torch.stack(masks['forward_1'], dim=1) + outputs = outputs_f + + return outputs_b.view(b, -1, c, h, w), outputs_f.view(b, -1, c, h, w), \ + outputs.view(b, -1, c, h, w), masks_f + + +class Encoder(nn.Module): + def __init__(self): + super(Encoder, self).__init__() + self.group = [1, 2, 4, 8, 1] + self.layers = nn.ModuleList([ + nn.Conv2d(5, 64, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1), + nn.LeakyReLU(0.2, inplace=True) + ]) + + def forward(self, x): + bt, c, _, _ = x.size() + # h, w = h//4, w//4 + out = x + for i, layer in enumerate(self.layers): + if i == 8: + x0 = out + _, _, h, w = x0.size() + if i > 8 and i % 2 == 0: + g = self.group[(i - 8) // 2] + x = x0.view(bt, g, -1, h, w) + o = out.view(bt, g, -1, h, w) + out = torch.cat([x, o], 2).view(bt, -1, h, w) + out = layer(out) + return out + + +class deconv(nn.Module): + def __init__(self, + input_channel, + output_channel, + kernel_size=3, + padding=0): + super().__init__() + self.conv = nn.Conv2d(input_channel, + output_channel, + kernel_size=kernel_size, + stride=1, + padding=padding) + + def forward(self, x): + x = F.interpolate(x, + scale_factor=2, + mode='bilinear', + align_corners=True) + return self.conv(x) + + +class InpaintGenerator(BaseNetwork): + def __init__(self, init_weights=True, model_path=None): + super(InpaintGenerator, self).__init__() + channel = 128 + hidden = 512 + + # encoder + self.encoder = Encoder() + + # decoder + self.decoder = nn.Sequential( + deconv(channel, 128, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + deconv(64, 64, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)) + + # soft split and soft composition + kernel_size = (7, 7) + padding = (3, 3) + stride = (3, 3) + t2t_params = { + 'kernel_size': kernel_size, + 'stride': stride, + 'padding': padding + } + self.ss = SoftSplit(channel, hidden, kernel_size, stride, padding) + self.sc = SoftComp(channel, hidden, kernel_size, stride, padding) + self.max_pool = nn.MaxPool2d(kernel_size, stride, padding) + + # feature propagation module + self.img_prop_module = BidirectionalPropagation(3, learnable=False) + self.feat_prop_module = BidirectionalPropagation(128, learnable=True) + + + depths = 8 + num_heads = 4 + window_size = (5, 9) + pool_size = (4, 4) + self.transformers = TemporalSparseTransformerBlock(dim=hidden, + n_head=num_heads, + window_size=window_size, + pool_size=pool_size, + depths=depths, + t2t_params=t2t_params) + if init_weights: + self.init_weights() + + + if model_path is not None: + print('Pretrained ProPainter has loaded...') + ckpt = torch.load(model_path, map_location='cpu') + self.load_state_dict(ckpt, strict=True) + + # print network parameter number + self.print_network() + + def img_propagation(self, masked_frames, completed_flows, masks, interpolation='nearest'): + _, _, prop_frames, updated_masks = self.img_prop_module(masked_frames, completed_flows[0], completed_flows[1], masks, interpolation) + return prop_frames, updated_masks + + def forward(self, masked_frames, completed_flows, masks_in, masks_updated, num_local_frames, interpolation='bilinear', t_dilation=2): + """ + Args: + masks_in: original mask + masks_updated: updated mask after image propagation + """ + + l_t = num_local_frames + b, t, _, ori_h, ori_w = masked_frames.size() + + # extracting features + enc_feat = self.encoder(torch.cat([masked_frames.view(b * t, 3, ori_h, ori_w), + masks_in.view(b * t, 1, ori_h, ori_w), + masks_updated.view(b * t, 1, ori_h, ori_w)], dim=1)) + _, c, h, w = enc_feat.size() + local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...] + ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...] + fold_feat_size = (h, w) + + ds_flows_f = F.interpolate(completed_flows[0].view(-1, 2, ori_h, ori_w), scale_factor=1/4, mode='bilinear', align_corners=False).view(b, l_t-1, 2, h, w)/4.0 + ds_flows_b = F.interpolate(completed_flows[1].view(-1, 2, ori_h, ori_w), scale_factor=1/4, mode='bilinear', align_corners=False).view(b, l_t-1, 2, h, w)/4.0 + ds_mask_in = F.interpolate(masks_in.reshape(-1, 1, ori_h, ori_w), scale_factor=1/4, mode='nearest').view(b, t, 1, h, w) + ds_mask_in_local = ds_mask_in[:, :l_t] + ds_mask_updated_local = F.interpolate(masks_updated[:,:l_t].reshape(-1, 1, ori_h, ori_w), scale_factor=1/4, mode='nearest').view(b, l_t, 1, h, w) + + + if self.training: + mask_pool_l = self.max_pool(ds_mask_in.view(-1, 1, h, w)) + mask_pool_l = mask_pool_l.view(b, t, 1, mask_pool_l.size(-2), mask_pool_l.size(-1)) + else: + mask_pool_l = self.max_pool(ds_mask_in_local.view(-1, 1, h, w)) + mask_pool_l = mask_pool_l.view(b, l_t, 1, mask_pool_l.size(-2), mask_pool_l.size(-1)) + + + prop_mask_in = torch.cat([ds_mask_in_local, ds_mask_updated_local], dim=2) + _, _, local_feat, _ = self.feat_prop_module(local_feat, ds_flows_f, ds_flows_b, prop_mask_in, interpolation) + enc_feat = torch.cat((local_feat, ref_feat), dim=1) + + trans_feat = self.ss(enc_feat.view(-1, c, h, w), b, fold_feat_size) + mask_pool_l = rearrange(mask_pool_l, 'b t c h w -> b t h w c').contiguous() + trans_feat = self.transformers(trans_feat, fold_feat_size, mask_pool_l, t_dilation=t_dilation) + trans_feat = self.sc(trans_feat, t, fold_feat_size) + trans_feat = trans_feat.view(b, t, -1, h, w) + + enc_feat = enc_feat + trans_feat + + if self.training: + output = self.decoder(enc_feat.view(-1, c, h, w)) + output = torch.tanh(output).view(b, t, 3, ori_h, ori_w) + else: + output = self.decoder(enc_feat[:, :l_t].view(-1, c, h, w)) + output = torch.tanh(output).view(b, l_t, 3, ori_h, ori_w) + + return output + + +# ###################################################################### +# Discriminator for Temporal Patch GAN +# ###################################################################### +class Discriminator(BaseNetwork): + def __init__(self, + in_channels=3, + use_sigmoid=False, + use_spectral_norm=True, + init_weights=True): + super(Discriminator, self).__init__() + self.use_sigmoid = use_sigmoid + nf = 32 + + self.conv = nn.Sequential( + spectral_norm( + nn.Conv3d(in_channels=in_channels, + out_channels=nf * 1, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=1, + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(64, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 1, + nf * 2, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(128, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 2, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 4, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 4, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv3d(nf * 4, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2))) + + if init_weights: + self.init_weights() + + def forward(self, xs): + # T, C, H, W = xs.shape (old) + # B, T, C, H, W (new) + xs_t = torch.transpose(xs, 1, 2) + feat = self.conv(xs_t) + if self.use_sigmoid: + feat = torch.sigmoid(feat) + out = torch.transpose(feat, 1, 2) # B, T, C, H, W + return out + + +class Discriminator_2D(BaseNetwork): + def __init__(self, + in_channels=3, + use_sigmoid=False, + use_spectral_norm=True, + init_weights=True): + super(Discriminator_2D, self).__init__() + self.use_sigmoid = use_sigmoid + nf = 32 + + self.conv = nn.Sequential( + spectral_norm( + nn.Conv3d(in_channels=in_channels, + out_channels=nf * 1, + kernel_size=(1, 5, 5), + stride=(1, 2, 2), + padding=(0, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(64, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 1, + nf * 2, + kernel_size=(1, 5, 5), + stride=(1, 2, 2), + padding=(0, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(128, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 2, + nf * 4, + kernel_size=(1, 5, 5), + stride=(1, 2, 2), + padding=(0, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 4, + nf * 4, + kernel_size=(1, 5, 5), + stride=(1, 2, 2), + padding=(0, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 4, + nf * 4, + kernel_size=(1, 5, 5), + stride=(1, 2, 2), + padding=(0, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv3d(nf * 4, + nf * 4, + kernel_size=(1, 5, 5), + stride=(1, 2, 2), + padding=(0, 2, 2))) + + if init_weights: + self.init_weights() + + def forward(self, xs): + # T, C, H, W = xs.shape (old) + # B, T, C, H, W (new) + xs_t = torch.transpose(xs, 1, 2) + feat = self.conv(xs_t) + if self.use_sigmoid: + feat = torch.sigmoid(feat) + out = torch.transpose(feat, 1, 2) # B, T, C, H, W + return out + +def spectral_norm(module, mode=True): + if mode: + return _spectral_norm(module) + return module diff --git a/model/recurrent_flow_completion.py b/model/recurrent_flow_completion.py new file mode 100644 index 0000000000000000000000000000000000000000..b002f125cca02948bf72a2482181ab9c627b752a --- /dev/null +++ b/model/recurrent_flow_completion.py @@ -0,0 +1,347 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +from model.modules.deformconv import ModulatedDeformConv2d +from .misc import constant_init + +class SecondOrderDeformableAlignment(ModulatedDeformConv2d): + """Second-order deformable alignment module.""" + def __init__(self, *args, **kwargs): + self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 5) + + super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Sequential( + nn.Conv2d(3 * self.out_channels, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1), + ) + self.init_offset() + + def init_offset(self): + constant_init(self.conv_offset[-1], val=0, bias=0) + + def forward(self, x, extra_feat): + out = self.conv_offset(extra_feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + + # offset + offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1)) + offset_1, offset_2 = torch.chunk(offset, 2, dim=1) + offset = torch.cat([offset_1, offset_2], dim=1) + + # mask + mask = torch.sigmoid(mask) + + return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, + self.stride, self.padding, + self.dilation, mask) + +class BidirectionalPropagation(nn.Module): + def __init__(self, channel): + super(BidirectionalPropagation, self).__init__() + modules = ['backward_', 'forward_'] + self.deform_align = nn.ModuleDict() + self.backbone = nn.ModuleDict() + self.channel = channel + + for i, module in enumerate(modules): + self.deform_align[module] = SecondOrderDeformableAlignment( + 2 * channel, channel, 3, padding=1, deform_groups=16) + + self.backbone[module] = nn.Sequential( + nn.Conv2d((2 + i) * channel, channel, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(channel, channel, 3, 1, 1), + ) + + self.fusion = nn.Conv2d(2 * channel, channel, 1, 1, 0) + + def forward(self, x): + """ + x shape : [b, t, c, h, w] + return [b, t, c, h, w] + """ + b, t, c, h, w = x.shape + feats = {} + feats['spatial'] = [x[:, i, :, :, :] for i in range(0, t)] + + for module_name in ['backward_', 'forward_']: + + feats[module_name] = [] + + frame_idx = range(0, t) + mapping_idx = list(range(0, len(feats['spatial']))) + mapping_idx += mapping_idx[::-1] + + if 'backward' in module_name: + frame_idx = frame_idx[::-1] + + feat_prop = x.new_zeros(b, self.channel, h, w) + for i, idx in enumerate(frame_idx): + feat_current = feats['spatial'][mapping_idx[idx]] + if i > 0: + cond_n1 = feat_prop + + # initialize second-order features + feat_n2 = torch.zeros_like(feat_prop) + cond_n2 = torch.zeros_like(cond_n1) + if i > 1: # second-order features + feat_n2 = feats[module_name][-2] + cond_n2 = feat_n2 + + cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1) # condition information, cond(flow warped 1st/2nd feature) + feat_prop = torch.cat([feat_prop, feat_n2], dim=1) # two order feat_prop -1 & -2 + feat_prop = self.deform_align[module_name](feat_prop, cond) + + # fuse current features + feat = [feat_current] + \ + [feats[k][idx] for k in feats if k not in ['spatial', module_name]] \ + + [feat_prop] + + feat = torch.cat(feat, dim=1) + # embed current features + feat_prop = feat_prop + self.backbone[module_name](feat) + + feats[module_name].append(feat_prop) + + # end for + if 'backward' in module_name: + feats[module_name] = feats[module_name][::-1] + + outputs = [] + for i in range(0, t): + align_feats = [feats[k].pop(0) for k in feats if k != 'spatial'] + align_feats = torch.cat(align_feats, dim=1) + outputs.append(self.fusion(align_feats)) + + return torch.stack(outputs, dim=1) + x + + +class deconv(nn.Module): + def __init__(self, + input_channel, + output_channel, + kernel_size=3, + padding=0): + super().__init__() + self.conv = nn.Conv2d(input_channel, + output_channel, + kernel_size=kernel_size, + stride=1, + padding=padding) + + def forward(self, x): + x = F.interpolate(x, + scale_factor=2, + mode='bilinear', + align_corners=True) + return self.conv(x) + + +class P3DBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, use_residual=0, bias=True): + super().__init__() + self.conv1 = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size=(1, kernel_size, kernel_size), + stride=(1, stride, stride), padding=(0, padding, padding), bias=bias), + nn.LeakyReLU(0.2, inplace=True) + ) + self.conv2 = nn.Sequential( + nn.Conv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1), + padding=(2, 0, 0), dilation=(2, 1, 1), bias=bias) + ) + self.use_residual = use_residual + + def forward(self, feats): + feat1 = self.conv1(feats) + feat2 = self.conv2(feat1) + if self.use_residual: + output = feats + feat2 + else: + output = feat2 + return output + + +class EdgeDetection(nn.Module): + def __init__(self, in_ch=2, out_ch=1, mid_ch=16): + super().__init__() + self.projection = nn.Sequential( + nn.Conv2d(in_ch, mid_ch, 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True) + ) + + self.mid_layer_1 = nn.Sequential( + nn.Conv2d(mid_ch, mid_ch, 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True) + ) + + self.mid_layer_2 = nn.Sequential( + nn.Conv2d(mid_ch, mid_ch, 3, 1, 1) + ) + + self.l_relu = nn.LeakyReLU(0.01, inplace=True) + + self.out_layer = nn.Conv2d(mid_ch, out_ch, 1, 1, 0) + + def forward(self, flow): + flow = self.projection(flow) + edge = self.mid_layer_1(flow) + edge = self.mid_layer_2(edge) + edge = self.l_relu(flow + edge) + edge = self.out_layer(edge) + edge = torch.sigmoid(edge) + return edge + + +class RecurrentFlowCompleteNet(nn.Module): + def __init__(self, model_path=None): + super().__init__() + self.downsample = nn.Sequential( + nn.Conv3d(3, 32, kernel_size=(1, 5, 5), stride=(1, 2, 2), + padding=(0, 2, 2), padding_mode='replicate'), + nn.LeakyReLU(0.2, inplace=True) + ) + + self.encoder1 = nn.Sequential( + P3DBlock(32, 32, 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + P3DBlock(32, 64, 3, 2, 1), + nn.LeakyReLU(0.2, inplace=True) + ) # 4x + + self.encoder2 = nn.Sequential( + P3DBlock(64, 64, 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + P3DBlock(64, 128, 3, 2, 1), + nn.LeakyReLU(0.2, inplace=True) + ) # 8x + + self.mid_dilation = nn.Sequential( + nn.Conv3d(128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 3, 3), dilation=(1, 3, 3)), # p = d*(k-1)/2 + nn.LeakyReLU(0.2, inplace=True), + nn.Conv3d(128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 2, 2), dilation=(1, 2, 2)), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv3d(128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 1, 1), dilation=(1, 1, 1)), + nn.LeakyReLU(0.2, inplace=True) + ) + + # feature propagation module + self.feat_prop_module = BidirectionalPropagation(128) + + self.decoder2 = nn.Sequential( + nn.Conv2d(128, 128, 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + deconv(128, 64, 3, 1), + nn.LeakyReLU(0.2, inplace=True) + ) # 4x + + self.decoder1 = nn.Sequential( + nn.Conv2d(64, 64, 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + deconv(64, 32, 3, 1), + nn.LeakyReLU(0.2, inplace=True) + ) # 2x + + self.upsample = nn.Sequential( + nn.Conv2d(32, 32, 3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + deconv(32, 2, 3, 1) + ) + + # edge loss + self.edgeDetector = EdgeDetection(in_ch=2, out_ch=1, mid_ch=16) + + # Need to initial the weights of MSDeformAttn specifically + for m in self.modules(): + if isinstance(m, SecondOrderDeformableAlignment): + m.init_offset() + + if model_path is not None: + print('Pretrained flow completion model has loaded...') + ckpt = torch.load(model_path, map_location='cpu') + self.load_state_dict(ckpt, strict=True) + + + def forward(self, masked_flows, masks): + # masked_flows: b t-1 2 h w + # masks: b t-1 2 h w + b, t, _, h, w = masked_flows.size() + masked_flows = masked_flows.permute(0,2,1,3,4) + masks = masks.permute(0,2,1,3,4) + + inputs = torch.cat((masked_flows, masks), dim=1) + + x = self.downsample(inputs) + + feat_e1 = self.encoder1(x) + feat_e2 = self.encoder2(feat_e1) # b c t h w + feat_mid = self.mid_dilation(feat_e2) # b c t h w + feat_mid = feat_mid.permute(0,2,1,3,4) # b t c h w + + feat_prop = self.feat_prop_module(feat_mid) + feat_prop = feat_prop.view(-1, 128, h//8, w//8) # b*t c h w + + _, c, _, h_f, w_f = feat_e1.shape + feat_e1 = feat_e1.permute(0,2,1,3,4).contiguous().view(-1, c, h_f, w_f) # b*t c h w + feat_d2 = self.decoder2(feat_prop) + feat_e1 + + _, c, _, h_f, w_f = x.shape + x = x.permute(0,2,1,3,4).contiguous().view(-1, c, h_f, w_f) # b*t c h w + + feat_d1 = self.decoder1(feat_d2) + + flow = self.upsample(feat_d1) + if self.training: + edge = self.edgeDetector(flow) + edge = edge.view(b, t, 1, h, w) + else: + edge = None + + flow = flow.view(b, t, 2, h, w) + + return flow, edge + + + def forward_bidirect_flow(self, masked_flows_bi, masks): + """ + Args: + masked_flows_bi: [masked_flows_f, masked_flows_b] | (b t-1 2 h w), (b t-1 2 h w) + masks: b t 1 h w + """ + masks_forward = masks[:, :-1, ...].contiguous() + masks_backward = masks[:, 1:, ...].contiguous() + + # mask flow + masked_flows_forward = masked_flows_bi[0] * (1-masks_forward) + masked_flows_backward = masked_flows_bi[1] * (1-masks_backward) + + # -- completion -- + # forward + pred_flows_forward, pred_edges_forward = self.forward(masked_flows_forward, masks_forward) + + # backward + masked_flows_backward = torch.flip(masked_flows_backward, dims=[1]) + masks_backward = torch.flip(masks_backward, dims=[1]) + pred_flows_backward, pred_edges_backward = self.forward(masked_flows_backward, masks_backward) + pred_flows_backward = torch.flip(pred_flows_backward, dims=[1]) + if self.training: + pred_edges_backward = torch.flip(pred_edges_backward, dims=[1]) + + return [pred_flows_forward, pred_flows_backward], [pred_edges_forward, pred_edges_backward] + + + def combine_flow(self, masked_flows_bi, pred_flows_bi, masks): + masks_forward = masks[:, :-1, ...].contiguous() + masks_backward = masks[:, 1:, ...].contiguous() + + pred_flows_forward = pred_flows_bi[0] * masks_forward + masked_flows_bi[0] * (1-masks_forward) + pred_flows_backward = pred_flows_bi[1] * masks_backward + masked_flows_bi[1] * (1-masks_backward) + + return pred_flows_forward, pred_flows_backward diff --git a/model/vgg_arch.py b/model/vgg_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..43fc2ff8bc1c73313d632c6ab326372d389a4772 --- /dev/null +++ b/model/vgg_arch.py @@ -0,0 +1,157 @@ +import os +import torch +from collections import OrderedDict +from torch import nn as nn +from torchvision.models import vgg as vgg + +VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' +NAMES = { + 'vgg11': [ + 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', + 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', + 'pool5' + ], + 'vgg13': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', + 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' + ], + 'vgg16': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', + 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', + 'pool5' + ], + 'vgg19': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', + 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', + 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5' + ] +} + + +def insert_bn(names): + """Insert bn layer after each conv. + + Args: + names (list): The list of layer names. + + Returns: + list: The list of layer names with bn layers. + """ + names_bn = [] + for name in names: + names_bn.append(name) + if 'conv' in name: + position = name.replace('conv', '') + names_bn.append('bn' + position) + return names_bn + +class VGGFeatureExtractor(nn.Module): + """VGG network for feature extraction. + + In this implementation, we allow users to choose whether use normalization + in the input feature and the type of vgg network. Note that the pretrained + path must fit the vgg type. + + Args: + layer_name_list (list[str]): Forward function returns the corresponding + features according to the layer_name_list. + Example: {'relu1_1', 'relu2_1', 'relu3_1'}. + vgg_type (str): Set the type of vgg network. Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image. Importantly, + the input feature must in the range [0, 1]. Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + requires_grad (bool): If true, the parameters of VGG network will be + optimized. Default: False. + remove_pooling (bool): If true, the max pooling operations in VGG net + will be removed. Default: False. + pooling_stride (int): The stride of max pooling operation. Default: 2. + """ + + def __init__(self, + layer_name_list, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + requires_grad=False, + remove_pooling=False, + pooling_stride=2): + super(VGGFeatureExtractor, self).__init__() + + self.layer_name_list = layer_name_list + self.use_input_norm = use_input_norm + self.range_norm = range_norm + + self.names = NAMES[vgg_type.replace('_bn', '')] + if 'bn' in vgg_type: + self.names = insert_bn(self.names) + + # only borrow layers that will be used to avoid unused params + max_idx = 0 + for v in layer_name_list: + idx = self.names.index(v) + if idx > max_idx: + max_idx = idx + + if os.path.exists(VGG_PRETRAIN_PATH): + vgg_net = getattr(vgg, vgg_type)(pretrained=False) + state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) + vgg_net.load_state_dict(state_dict) + else: + vgg_net = getattr(vgg, vgg_type)(pretrained=True) + + features = vgg_net.features[:max_idx + 1] + + modified_net = OrderedDict() + for k, v in zip(self.names, features): + if 'pool' in k: + # if remove_pooling is true, pooling operation will be removed + if remove_pooling: + continue + else: + # in some cases, we may want to change the default stride + modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) + else: + modified_net[k] = v + + self.vgg_net = nn.Sequential(modified_net) + + if not requires_grad: + self.vgg_net.eval() + for param in self.parameters(): + param.requires_grad = False + else: + self.vgg_net.train() + for param in self.parameters(): + param.requires_grad = True + + if self.use_input_norm: + # the mean is for image with range [0, 1] + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + # the std is for image with range [0, 1] + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + if self.range_norm: + x = (x + 1) / 2 + if self.use_input_norm: + x = (x - self.mean) / self.std + output = {} + + for key, layer in self.vgg_net._modules.items(): + x = layer(x) + if key in self.layer_name_list: + output[key] = x.clone() + + return output diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1f9bf0a5dd2787423a79f4292684c267893bad4c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,33 @@ +av +addict +einops +future +numpy +scipy +opencv-python +matplotlib +scikit-image +torch>=1.7.1 +torchvision>=0.8.2 +imageio-ffmpeg +pyyaml +requests +timm +yapf +progressbar2 +gdown +gitpython +git+https://github.com/cheind/py-thin-plate-spline +hickle +tensorboard +numpy +git+https://github.com/facebookresearch/segment-anything.git +gradio +opencv-python +matplotlib +pyyaml +av +openmim +tqdm +psutil +omegaconf \ No newline at end of file diff --git a/scripts/compute_flow.py b/scripts/compute_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..8596e4dc95c1969826adaf9c72a076584886ece2 --- /dev/null +++ b/scripts/compute_flow.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- +import sys +sys.path.append(".") + +import os +import cv2 +import argparse +from PIL import Image +import torch +import torch.nn.functional as F +from torchvision import transforms + +from RAFT import RAFT +from utils.flow_util import * + +def imwrite(img, file_path, params=None, auto_mkdir=True): + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + +def initialize_RAFT(model_path='weights/raft-things.pth', device='cuda'): + """Initializes the RAFT model. + """ + args = argparse.ArgumentParser() + args.raft_model = model_path + args.small = False + args.mixed_precision = False + args.alternate_corr = False + + model = torch.nn.DataParallel(RAFT(args)) + model.load_state_dict(torch.load(args.raft_model)) + + model = model.module + model.to(device) + model.eval() + + return model + + +if __name__ == '__main__': + device = 'cuda' + + parser = argparse.ArgumentParser() + parser.add_argument('-i', '--root_path', type=str, default='your_dataset_root/youtube-vos/JPEGImages') + parser.add_argument('-o', '--save_path', type=str, default='your_dataset_root/youtube-vos/Flows_flo') + parser.add_argument('--height', type=int, default=240) + parser.add_argument('--width', type=int, default=432) + + args = parser.parse_args() + + # Flow model + RAFT_model = initialize_RAFT(device=device) + + root_path = args.root_path + save_path = args.save_path + h_new, w_new = (args.height, args.width) + + file_list = sorted(os.listdir(root_path)) + for f in file_list: + print(f'Processing: {f} ...') + m_list = sorted(os.listdir(os.path.join(root_path, f))) + len_m = len(m_list) + for i in range(len_m-1): + img1_path = os.path.join(root_path, f, m_list[i]) + img2_path = os.path.join(root_path, f, m_list[i+1]) + img1 = Image.fromarray(cv2.imread(img1_path)) + img2 = Image.fromarray(cv2.imread(img2_path)) + + transform = transforms.Compose([transforms.ToTensor()]) + + img1 = transform(img1).unsqueeze(0).to(device)[:,[2,1,0],:,:] + img2 = transform(img2).unsqueeze(0).to(device)[:,[2,1,0],:,:] + + # upsize to a multiple of 16 + # h, w = img1.shape[2:4] + # w_new = w if (w % 16) == 0 else 16 * (w // 16 + 1) + # h_new = h if (h % 16) == 0 else 16 * (h // 16 + 1) + + + img1 = F.interpolate(input=img1, + size=(h_new, w_new), + mode='bilinear', + align_corners=False) + img2 = F.interpolate(input=img2, + size=(h_new, w_new), + mode='bilinear', + align_corners=False) + + with torch.no_grad(): + img1 = img1*2 - 1 + img2 = img2*2 - 1 + + _, flow_f = RAFT_model(img1, img2, iters=20, test_mode=True) + _, flow_b = RAFT_model(img2, img1, iters=20, test_mode=True) + + + flow_f = flow_f[0].permute(1,2,0).cpu().numpy() + flow_b = flow_b[0].permute(1,2,0).cpu().numpy() + + # flow_f = resize_flow(flow_f, w_new, h_new) + # flow_b = resize_flow(flow_b, w_new, h_new) + + save_flow_f = os.path.join(save_path, f, f'{m_list[i][:-4]}_{m_list[i+1][:-4]}_f.flo') + save_flow_b = os.path.join(save_path, f, f'{m_list[i+1][:-4]}_{m_list[i][:-4]}_b.flo') + + flowwrite(flow_f, save_flow_f, quantize=False) + flowwrite(flow_b, save_flow_b, quantize=False) diff --git a/scripts/evaluate_flow_completion.py b/scripts/evaluate_flow_completion.py new file mode 100644 index 0000000000000000000000000000000000000000..405586bfe3ca4410134c7dddc4de7fb04b60bf3c --- /dev/null +++ b/scripts/evaluate_flow_completion.py @@ -0,0 +1,197 @@ +# -*- coding: utf-8 -*- +import sys +sys.path.append(".") + +import cv2 +import os +import numpy as np +import argparse +from PIL import Image + +import torch +from torch.utils.data import DataLoader + +from core.dataset import TestDataset +from model.modules.flow_comp_raft import RAFT_bi +from model.recurrent_flow_completion import RecurrentFlowCompleteNet + +from RAFT.utils.flow_viz_pt import flow_to_image + +import cvbase +import imageio +from time import time + +import warnings +warnings.filterwarnings("ignore") + +def create_dir(dir): + """Creates a directory if not exist. + """ + if not os.path.exists(dir): + os.makedirs(dir) + +def save_flows(output, videoFlowF, videoFlowB): + # create_dir(os.path.join(output, 'forward_flo')) + # create_dir(os.path.join(output, 'backward_flo')) + create_dir(os.path.join(output, 'forward_png')) + create_dir(os.path.join(output, 'backward_png')) + N = videoFlowF.shape[-1] + for i in range(N): + forward_flow = videoFlowF[..., i] + backward_flow = videoFlowB[..., i] + forward_flow_vis = cvbase.flow2rgb(forward_flow) + backward_flow_vis = cvbase.flow2rgb(backward_flow) + # cvbase.write_flow(forward_flow, os.path.join(output, 'forward_flo', '{:05d}.flo'.format(i))) + # cvbase.write_flow(backward_flow, os.path.join(output, 'backward_flo', '{:05d}.flo'.format(i))) + forward_flow_vis = (forward_flow_vis*255.0).astype(np.uint8) + backward_flow_vis = (backward_flow_vis*255.0).astype(np.uint8) + imageio.imwrite(os.path.join(output, 'forward_png', '{:05d}.png'.format(i)), forward_flow_vis) + imageio.imwrite(os.path.join(output, 'backward_png', '{:05d}.png'.format(i)), backward_flow_vis) + +def tensor2np(array): + array = torch.stack(array, dim=-1).squeeze(0).permute(1, 2, 0, 3).cpu().numpy() + return array + +def main_worker(args): + # set up datasets and data loader + args.size = (args.width, args.height) + test_dataset = TestDataset(vars(args)) + + test_loader = DataLoader(test_dataset, + batch_size=1, + shuffle=False, + num_workers=args.num_workers) + + # set up models + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + fix_raft = RAFT_bi(args.raft_model_path, device) + + fix_flow_complete = RecurrentFlowCompleteNet(args.fc_model_path) + for p in fix_flow_complete.parameters(): + p.requires_grad = False + fix_flow_complete.to(device) + fix_flow_complete.eval() + + total_frame_epe = [] + time_all = [] + + print('Start evaluation...') + # create results directory + result_path = os.path.join('results_flow', f'{args.dataset}') + if not os.path.exists(result_path): + os.makedirs(result_path) + + eval_summary = open(os.path.join(result_path, f"{args.dataset}_metrics.txt"), "w") + + for index, items in enumerate(test_loader): + frames, masks, flows_f, flows_b, video_name, frames_PIL = items + local_masks = masks.float().to(device) + + video_length = frames.size(1) + + if args.load_flow: + gt_flows_bi = (flows_f.to(device), flows_b.to(device)) + else: + short_len = 60 + if frames.size(1) > short_len: + gt_flows_f_list, gt_flows_b_list = [], [] + for f in range(0, video_length, short_len): + end_f = min(video_length, f + short_len) + if f == 0: + flows_f, flows_b = fix_raft(frames[:,f:end_f], iters=args.raft_iter) + else: + flows_f, flows_b = fix_raft(frames[:,f-1:end_f], iters=args.raft_iter) + + gt_flows_f_list.append(flows_f) + gt_flows_b_list.append(flows_b) + gt_flows_f = torch.cat(gt_flows_f_list, dim=1) + gt_flows_b = torch.cat(gt_flows_b_list, dim=1) + gt_flows_bi = (gt_flows_f, gt_flows_b) + else: + gt_flows_bi = fix_raft(frames, iters=20) + + torch.cuda.synchronize() + time_start = time() + + # flow_length = flows_f.size(1) + # f_stride = 30 + # pred_flows_f = [] + # pred_flows_b = [] + # suffix = flow_length%f_stride + # last = flow_length//f_stride + # for f in range(0, flow_length, f_stride): + # gt_flows_bi_i = (gt_flows_bi[0][:,f:f+f_stride], gt_flows_bi[1][:,f:f+f_stride]) + # pred_flows_bi, _ = fix_flow_complete.forward_bidirect_flow(gt_flows_bi_i, local_masks[:,f:f+f_stride+1]) + # pred_flows_f_i, pred_flows_b_i = fix_flow_complete.combine_flow(gt_flows_bi_i, pred_flows_bi, local_masks[:,f:f+f_stride+1]) + # pred_flows_f.append(pred_flows_f_i) + # pred_flows_b.append(pred_flows_b_i) + # pred_flows_f = torch.cat(pred_flows_f, dim=1) + # pred_flows_b = torch.cat(pred_flows_b, dim=1) + # pred_flows_bi = (pred_flows_f, pred_flows_b) + + pred_flows_bi, _ = fix_flow_complete.forward_bidirect_flow(gt_flows_bi, local_masks) + pred_flows_bi = fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, local_masks) + + torch.cuda.synchronize() + time_i = time() - time_start + time_i = time_i*1.0/frames.size(1) + + time_all = time_all+[time_i]*frames.size(1) + + cur_video_epe = [] + + epe1 = torch.mean(torch.sum((flows_f - pred_flows_bi[0].cpu())**2, dim=2).sqrt()) + epe2 = torch.mean(torch.sum((flows_b - pred_flows_bi[1].cpu())**2, dim=2).sqrt()) + + cur_video_epe.append(epe1.numpy()) + cur_video_epe.append(epe2.numpy()) + + total_frame_epe = total_frame_epe+[epe1.numpy()]*flows_f.size(1) + total_frame_epe = total_frame_epe+[epe2.numpy()]*flows_f.size(1) + + cur_epe = sum(cur_video_epe) / len(cur_video_epe) + avg_time = sum(time_all) / len(time_all) + print( + f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | EPE: {cur_epe:.4f} | Time: {avg_time:.4f}' + ) + eval_summary.write( + f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | EPE: {cur_epe:.4f} | Time: {avg_time:.4f}\n' + ) + + # saving images for evaluating warpping errors + if args.save_results: + forward_flows = pred_flows_bi[0].cpu().permute(1,0,2,3,4) + backward_flows = pred_flows_bi[1].cpu().permute(1,0,2,3,4) + # forward_flows = flows_f.cpu().permute(1,0,2,3,4) + # backward_flows = flows_b.cpu().permute(1,0,2,3,4) + videoFlowF = list(forward_flows) + videoFlowB = list(backward_flows) + + videoFlowF = tensor2np(videoFlowF) + videoFlowB = tensor2np(videoFlowB) + + save_frame_path = os.path.join(result_path, video_name[0]) + save_flows(save_frame_path, videoFlowF, videoFlowB) + + avg_frame_epe = sum(total_frame_epe) / len(total_frame_epe) + + print(f'Finish evaluation... Average Frame EPE: {avg_frame_epe:.4f} | | Time: {avg_time:.4f}') + eval_summary.write(f'Finish evaluation... Average Frame EPE: {avg_frame_epe:.4f} | | Time: {avg_time:.4f}\n') + eval_summary.close() + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--height', type=int, default=240) + parser.add_argument('--width', type=int, default=432) + parser.add_argument('--raft_model_path', default='weights/raft-things.pth', type=str) + parser.add_argument('--fc_model_path', default='weights/recurrent_flow_completion.pth', type=str) + parser.add_argument('--dataset', choices=['davis', 'youtube-vos'], type=str) + parser.add_argument('--video_root', default='dataset_root', type=str) + parser.add_argument('--mask_root', default='mask_root', type=str) + parser.add_argument('--flow_root', default='flow_ground_truth_root', type=str) + parser.add_argument('--load_flow', default=False, type=bool) + parser.add_argument("--raft_iter", type=int, default=20) + parser.add_argument('--save_results', action='store_true') + parser.add_argument('--num_workers', default=4, type=int) + args = parser.parse_args() + main_worker(args) diff --git a/scripts/evaluate_propainter.py b/scripts/evaluate_propainter.py new file mode 100644 index 0000000000000000000000000000000000000000..51b65f8511ab1df1d01e1581f224067857c55e9b --- /dev/null +++ b/scripts/evaluate_propainter.py @@ -0,0 +1,280 @@ +# -*- coding: utf-8 -*- +import sys +sys.path.append(".") + +import os +import cv2 +import numpy as np +import argparse +from PIL import Image +import torch.nn.functional as F + +import torch +from torch.utils.data import DataLoader + +from model.modules.flow_comp_raft import RAFT_bi +from model.recurrent_flow_completion import RecurrentFlowCompleteNet +from model.propainter import InpaintGenerator + +# from core.dataset import TestDataset +from core.dataset import TestDataset +from core.metrics import calc_psnr_and_ssim, calculate_i3d_activations, calculate_vfid, init_i3d_model + +from time import time + +import warnings +warnings.filterwarnings("ignore") + +# sample reference frames from the whole video +def get_ref_index(neighbor_ids, length, ref_stride=10): + ref_index = [] + for i in range(0, length, ref_stride): + if i not in neighbor_ids: + ref_index.append(i) + return ref_index + + +def main_worker(args): + args.size = (args.width, args.height) + w, h = args.size + # set up datasets and data loader + assert (args.dataset == 'davis') or args.dataset == 'youtube-vos', \ + f"{args.dataset} dataset is not supported" + test_dataset = TestDataset(vars(args)) + + test_loader = DataLoader(test_dataset, + batch_size=1, + shuffle=False, + num_workers=args.num_workers) + + # set up models + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + fix_raft = RAFT_bi(args.raft_model_path, device) + + fix_flow_complete = RecurrentFlowCompleteNet(args.fc_model_path) + for p in fix_flow_complete.parameters(): + p.requires_grad = False + fix_flow_complete.to(device) + fix_flow_complete.eval() + + model = InpaintGenerator(model_path=args.propainter_model_path).to(device) + model.eval() + + time_all = [] + + + print('Start evaluation ...') + if args.task == 'video_completion': + result_path = os.path.join(f'results_eval', + f'{args.dataset}_rs_{args.ref_stride}_nl_{args.neighbor_length}_video_completion') + if not os.path.exists(result_path): + os.makedirs(result_path, exist_ok=True) + eval_summary = open(os.path.join(result_path, f"{args.dataset}_metrics.txt"),"w") + total_frame_psnr = [] + total_frame_ssim = [] + output_i3d_activations = [] + real_i3d_activations = [] + i3d_model = init_i3d_model('weights/i3d_rgb_imagenet.pt') + else: + result_path = os.path.join(f'results_eval', + f'{args.dataset}_rs_{args.ref_stride}_nl_{args.neighbor_length}_object_removal') + if not os.path.exists(result_path): + os.makedirs(result_path, exist_ok=True) + + if not os.path.exists(result_path): + os.makedirs(result_path) + + + for index, items in enumerate(test_loader): + torch.cuda.empty_cache() + + # frames, masks, video_name, frames_PIL = items + frames, masks, flows_f, flows_b, video_name, frames_PIL = items + video_name = video_name[0] + print('Processing:', video_name) + + video_length = frames.size(1) + frames, masks = frames.to(device), masks.to(device) + masked_frames = frames * (1 - masks) + + torch.cuda.synchronize() + time_start = time() + + with torch.no_grad(): + # ---- compute flow ---- + if args.load_flow: + gt_flows_bi = (flows_f.to(device), flows_b.to(device)) + else: + short_len = 60 + if frames.size(1) > short_len: + gt_flows_f_list, gt_flows_b_list = [], [] + for f in range(0, video_length, short_len): + end_f = min(video_length, f + short_len) + if f == 0: + flows_f, flows_b = fix_raft(frames[:,f:end_f], iters=args.raft_iter) + else: + flows_f, flows_b = fix_raft(frames[:,f-1:end_f], iters=args.raft_iter) + + gt_flows_f_list.append(flows_f) + gt_flows_b_list.append(flows_b) + gt_flows_f = torch.cat(gt_flows_f_list, dim=1) + gt_flows_b = torch.cat(gt_flows_b_list, dim=1) + gt_flows_bi = (gt_flows_f, gt_flows_b) + else: + gt_flows_bi = fix_raft(frames, iters=args.raft_iter) + + # ---- complete flow ---- + pred_flows_bi, _ = fix_flow_complete.forward_bidirect_flow(gt_flows_bi, masks) + pred_flows_bi = fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, masks) + + # ---- temporal propagation ---- + prop_imgs, updated_local_masks = model.img_propagation(masked_frames, pred_flows_bi, masks, 'nearest') + + b, t, _, _, _ = masks.size() + updated_masks = updated_local_masks.view(b, t, 1, h, w) + updated_frames = frames * (1-masks) + prop_imgs.view(b, t, 3, h, w) * masks # merge + + del gt_flows_bi, frames, updated_local_masks + if not args.load_flow: + torch.cuda.empty_cache() + + ori_frames = frames_PIL + ori_frames = [ + ori_frames[i].squeeze().cpu().numpy() for i in range(video_length) + ] + comp_frames = [None] * video_length + + # complete holes by our model + neighbor_stride = args.neighbor_length // 2 + for f in range(0, video_length, neighbor_stride): + neighbor_ids = [ + i for i in range(max(0, f - neighbor_stride), + min(video_length, f + neighbor_stride + 1)) + ] + ref_ids = get_ref_index(neighbor_ids, video_length, args.ref_stride) + selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :] + selected_masks = masks[:, neighbor_ids + ref_ids, :, :, :] + selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :] + selected_pred_flows_bi = (pred_flows_bi[0][:, neighbor_ids[:-1], :, :, :], pred_flows_bi[1][:, neighbor_ids[:-1], :, :, :]) + + with torch.no_grad(): + l_t = len(neighbor_ids) + pred_img = model(selected_imgs, selected_pred_flows_bi, selected_masks, selected_update_masks, l_t) + pred_img = pred_img.view(-1, 3, h, w) + + + pred_img = (pred_img + 1) / 2 + pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255 + binary_masks = masks[0, neighbor_ids, :, :, :].cpu().permute( + 0, 2, 3, 1).numpy().astype(np.uint8) + for i in range(len(neighbor_ids)): + idx = neighbor_ids[i] + img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \ + + ori_frames[idx] * (1 - binary_masks[i]) + if comp_frames[idx] is None: + comp_frames[idx] = img + else: + comp_frames[idx] = comp_frames[idx].astype( + np.float32) * 0.5 + img.astype(np.float32) * 0.5 + + + torch.cuda.synchronize() + time_i = time() - time_start + time_i = time_i*1.0/video_length + time_all.append(time_i) + + if args.task == 'video_completion': + # calculate metrics + cur_video_psnr = [] + cur_video_ssim = [] + comp_PIL = [] # to calculate VFID + frames_PIL = [] + for ori, comp in zip(ori_frames, comp_frames): + psnr, ssim = calc_psnr_and_ssim(ori, comp) + + cur_video_psnr.append(psnr) + cur_video_ssim.append(ssim) + + total_frame_psnr.append(psnr) + total_frame_ssim.append(ssim) + + frames_PIL.append(Image.fromarray(ori.astype(np.uint8))) + comp_PIL.append(Image.fromarray(comp.astype(np.uint8))) + + # saving i3d activations + frames_i3d, comp_i3d = calculate_i3d_activations(frames_PIL, + comp_PIL, + i3d_model, + device=device) + real_i3d_activations.append(frames_i3d) + output_i3d_activations.append(comp_i3d) + + cur_psnr = sum(cur_video_psnr) / len(cur_video_psnr) + cur_ssim = sum(cur_video_ssim) / len(cur_video_ssim) + + avg_psnr = sum(total_frame_psnr) / len(total_frame_psnr) + avg_ssim = sum(total_frame_ssim) / len(total_frame_ssim) + + avg_time = sum(time_all) / len(time_all) + print( + f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | PSNR/SSIM: {cur_psnr:.4f}/{cur_ssim:.4f} \ + | Avg PSNR/SSIM: {avg_psnr:.4f}/{avg_ssim:.4f} | Time: {avg_time:.4f}' + ) + eval_summary.write( + f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | PSNR/SSIM: {cur_psnr:.4f}/{cur_ssim:.4f} \ + | Avg PSNR/SSIM: {avg_psnr:.4f}/{avg_ssim:.4f} | Time: {avg_time:.4f}\n' + ) + else: + avg_time = sum(time_all) / len(time_all) + print( + f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | Time: {avg_time:.4f}' + ) + + # saving images for evaluating warpping errors + if args.save_results: + save_frame_path = os.path.join(result_path, video_name) + if not os.path.exists(save_frame_path): + os.makedirs(save_frame_path, exist_ok=False) + + for i, frame in enumerate(comp_frames): + cv2.imwrite( + os.path.join(save_frame_path, + str(i).zfill(5) + '.png'), + cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_RGB2BGR)) + + if args.task == 'video_completion': + avg_frame_psnr = sum(total_frame_psnr) / len(total_frame_psnr) + avg_frame_ssim = sum(total_frame_ssim) / len(total_frame_ssim) + + fid_score = calculate_vfid(real_i3d_activations, output_i3d_activations) + print('Finish evaluation... Average Frame PSNR/SSIM/VFID: ' + f'{avg_frame_psnr:.2f}/{avg_frame_ssim:.4f}/{fid_score:.3f} | Time: {avg_time:.4f}') + eval_summary.write( + 'Finish evaluation... Average Frame PSNR/SSIM/VFID: ' + f'{avg_frame_psnr:.2f}/{avg_frame_ssim:.4f}/{fid_score:.3f} | Time: {avg_time:.4f}') + eval_summary.close() + else: + print('Finish evaluation... Time: {avg_time:.4f}') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--height', type=int, default=240) + parser.add_argument('--width', type=int, default=432) + parser.add_argument("--ref_stride", type=int, default=10) + parser.add_argument("--neighbor_length", type=int, default=20) + parser.add_argument("--raft_iter", type=int, default=20) + parser.add_argument('--task', default='video_completion', choices=['object_removal', 'video_completion']) + parser.add_argument('--raft_model_path', default='weights/raft-things.pth', type=str) + parser.add_argument('--fc_model_path', default='weights/recurrent_flow_completion.pth', type=str) + parser.add_argument('--propainter_model_path', default='weights/ProPainter.pth', type=str) + parser.add_argument('--dataset', choices=['davis', 'youtube-vos'], type=str) + parser.add_argument('--video_root', default='dataset_root', type=str) + parser.add_argument('--mask_root', default='mask_root', type=str) + parser.add_argument('--flow_root', default='flow_ground_truth_root', type=str) + parser.add_argument('--load_flow', default=False, type=bool) + parser.add_argument('--save_results', action='store_true') + parser.add_argument('--num_workers', default=4, type=int) + + args = parser.parse_args() + main_worker(args) diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..56cc0ee28913b8e23fd2f37b1899d00d4386f9ce --- /dev/null +++ b/train.py @@ -0,0 +1,105 @@ +import os +import json +import argparse +import subprocess + +from shutil import copyfile +import torch.distributed as dist + +import torch +import torch.multiprocessing as mp + +import core +import core.trainer +import core.trainer_flow_w_edge + + +# import warnings +# warnings.filterwarnings("ignore") + +from core.dist import ( + get_world_size, + get_local_rank, + get_global_rank, + get_master_ip, +) + +parser = argparse.ArgumentParser() +parser.add_argument('-c', + '--config', + default='configs/train_propainter.json', + type=str) +parser.add_argument('-p', '--port', default='23490', type=str) +args = parser.parse_args() + + +def main_worker(rank, config): + if 'local_rank' not in config: + config['local_rank'] = config['global_rank'] = rank + if config['distributed']: + torch.cuda.set_device(int(config['local_rank'])) + torch.distributed.init_process_group(backend='nccl', + init_method=config['init_method'], + world_size=config['world_size'], + rank=config['global_rank'], + group_name='mtorch') + print('using GPU {}-{} for training'.format(int(config['global_rank']), + int(config['local_rank']))) + + + config['save_dir'] = os.path.join( + config['save_dir'], + '{}_{}'.format(config['model']['net'], + os.path.basename(args.config).split('.')[0])) + + config['save_metric_dir'] = os.path.join( + './scores', + '{}_{}'.format(config['model']['net'], + os.path.basename(args.config).split('.')[0])) + + if torch.cuda.is_available(): + config['device'] = torch.device("cuda:{}".format(config['local_rank'])) + else: + config['device'] = 'cpu' + + if (not config['distributed']) or config['global_rank'] == 0: + os.makedirs(config['save_dir'], exist_ok=True) + config_path = os.path.join(config['save_dir'], + args.config.split('/')[-1]) + if not os.path.isfile(config_path): + copyfile(args.config, config_path) + print('[**] create folder {}'.format(config['save_dir'])) + + trainer_version = config['trainer']['version'] + trainer = core.__dict__[trainer_version].__dict__['Trainer'](config) + # Trainer(config) + trainer.train() + + +if __name__ == "__main__": + + torch.backends.cudnn.benchmark = True + + mp.set_sharing_strategy('file_system') + + # loading configs + config = json.load(open(args.config)) + + # setting distributed configurations + # config['world_size'] = get_world_size() + config['world_size'] = torch.cuda.device_count() + config['init_method'] = f"tcp://{get_master_ip()}:{args.port}" + config['distributed'] = True if config['world_size'] > 1 else False + print('world_size:', config['world_size']) + # setup distributed parallel training environments + + # if get_master_ip() == "127.0.0.X": + # # manually launch distributed processes + # mp.spawn(main_worker, nprocs=config['world_size'], args=(config, )) + # else: + # # multiple processes have been launched by openmpi + # config['local_rank'] = get_local_rank() + # config['global_rank'] = get_global_rank() + # main_worker(-1, config) + + mp.spawn(main_worker, nprocs=torch.cuda.device_count(), args=(config, )) \ No newline at end of file diff --git a/utils/download_util.py b/utils/download_util.py new file mode 100644 index 0000000000000000000000000000000000000000..5e8fb1b00522309d0c0931f5396355011fb200e7 --- /dev/null +++ b/utils/download_util.py @@ -0,0 +1,109 @@ +import math +import os +import requests +from torch.hub import download_url_to_file, get_dir +from tqdm import tqdm +from urllib.parse import urlparse + +def sizeof_fmt(size, suffix='B'): + """Get human readable file size. + + Args: + size (int): File size. + suffix (str): Suffix. Default: 'B'. + + Return: + str: Formated file siz. + """ + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(size) < 1024.0: + return f'{size:3.1f} {unit}{suffix}' + size /= 1024.0 + return f'{size:3.1f} Y{suffix}' + + +def download_file_from_google_drive(file_id, save_path): + """Download files from google drive. + Ref: + https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 + Args: + file_id (str): File id. + save_path (str): Save path. + """ + + session = requests.Session() + URL = 'https://docs.google.com/uc?export=download' + params = {'id': file_id} + + response = session.get(URL, params=params, stream=True) + token = get_confirm_token(response) + if token: + params['confirm'] = token + response = session.get(URL, params=params, stream=True) + + # get file size + response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) + print(response_file_size) + if 'Content-Range' in response_file_size.headers: + file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) + else: + file_size = None + + save_response_content(response, save_path, file_size) + + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + +def save_response_content(response, destination, file_size=None, chunk_size=32768): + if file_size is not None: + pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') + + readable_file_size = sizeof_fmt(file_size) + else: + pbar = None + + with open(destination, 'wb') as f: + downloaded_size = 0 + for chunk in response.iter_content(chunk_size): + downloaded_size += chunk_size + if pbar is not None: + pbar.update(1) + pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if pbar is not None: + pbar.close() + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Load file form http url, will download models if necessary. + Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + Args: + url (str): URL to be downloaded. + model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. + Default: None. + progress (bool): Whether to show the download progress. Default: True. + file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. + Returns: + str: The path to the downloaded file. + """ + if model_dir is None: # use the pytorch hub_dir + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(model_dir, exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file \ No newline at end of file diff --git a/utils/file_client.py b/utils/file_client.py new file mode 100644 index 0000000000000000000000000000000000000000..c7578dec8304475b3906d5dcc734a5a58b56f2ad --- /dev/null +++ b/utils/file_client.py @@ -0,0 +1,166 @@ +from abc import ABCMeta, abstractmethod + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. + """ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. + + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str | None): Additional path to be appended to `sys.path`. + Default: None. + """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError('Please install memcached to enable MemcachedBackend.') + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) + # mc.pyvector servers as a point which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class HardDiskBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + + def get(self, filepath): + filepath = str(filepath) + with open(filepath, 'rb') as f: + value_buf = f.read() + return value_buf + + def get_text(self, filepath): + filepath = str(filepath) + with open(filepath, 'r') as f: + value_buf = f.read() + return value_buf + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_paths (str | list[str]): Lmdb database paths. + client_keys (str | list[str]): Lmdb client keys. Default: 'default'. + readonly (bool, optional): Lmdb environment parameter. If True, + disallow any write operations. Default: True. + lock (bool, optional): Lmdb environment parameter. If False, when + concurrent access occurs, do not lock the database. Default: False. + readahead (bool, optional): Lmdb environment parameter. If False, + disable the OS filesystem readahead mechanism, which may improve + random read performance when a database is larger than RAM. + Default: False. + + Attributes: + db_paths (list): Lmdb database path. + _client (list): A list of several lmdb envs. + """ + + def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): + try: + import lmdb + except ImportError: + raise ImportError('Please install lmdb to enable LmdbBackend.') + + if isinstance(client_keys, str): + client_keys = [client_keys] + + if isinstance(db_paths, list): + self.db_paths = [str(v) for v in db_paths] + elif isinstance(db_paths, str): + self.db_paths = [str(db_paths)] + assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' + f'but received {len(client_keys)} and {len(self.db_paths)}.') + + self._client = {} + for client, path in zip(client_keys, self.db_paths): + self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) + + def get(self, filepath, client_key): + """Get values according to the filepath from one lmdb named client_key. + + Args: + filepath (str | obj:`Path`): Here, filepath is the lmdb key. + client_key (str): Used for distinguishing differnet lmdb envs. + """ + filepath = str(filepath) + assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.') + client = self._client[client_key] + with client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class FileClient(object): + """A general file client to access files in different backend. + + The client loads a file or text in a specified backend from its path + and return it as a binary file. it can also register other backend + accessor with a given name and backend class. + + Attributes: + backend (str): The storage backend type. Options are "disk", + "memcached" and "lmdb". + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + 'disk': HardDiskBackend, + 'memcached': MemcachedBackend, + 'lmdb': LmdbBackend, + } + + def __init__(self, backend='disk', **kwargs): + if backend not in self._backends: + raise ValueError(f'Backend {backend} is not supported. Currently supported ones' + f' are {list(self._backends.keys())}') + self.backend = backend + self.client = self._backends[backend](**kwargs) + + def get(self, filepath, client_key='default'): + # client_key is used only for lmdb, where different fileclients have + # different lmdb environments. + if self.backend == 'lmdb': + return self.client.get(filepath, client_key) + else: + return self.client.get(filepath) + + def get_text(self, filepath): + return self.client.get_text(filepath) diff --git a/utils/flow_util.py b/utils/flow_util.py new file mode 100644 index 0000000000000000000000000000000000000000..f7e09bd0427cfdbaddbd54cd6f0a990c7a0bed20 --- /dev/null +++ b/utils/flow_util.py @@ -0,0 +1,196 @@ +import cv2 +import numpy as np +import os +import torch.nn.functional as F + +def resize_flow(flow, newh, neww): + oldh, oldw = flow.shape[0:2] + flow = cv2.resize(flow, (neww, newh), interpolation=cv2.INTER_LINEAR) + flow[:, :, 0] *= newh / oldh + flow[:, :, 1] *= neww / oldw + return flow + +def resize_flow_pytorch(flow, newh, neww): + oldh, oldw = flow.shape[-2:] + flow = F.interpolate(flow, (newh, neww), mode='bilinear') + flow[:, :, 0] *= newh / oldh + flow[:, :, 1] *= neww / oldw + return flow + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + + +def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): + """Read an optical flow map. + + Args: + flow_path (ndarray or str): Flow path. + quantize (bool): whether to read quantized pair, if set to True, + remaining args will be passed to :func:`dequantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. + + Returns: + ndarray: Optical flow represented as a (h, w, 2) numpy array + """ + if quantize: + assert concat_axis in [0, 1] + cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED) + if cat_flow.ndim != 2: + raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.') + assert cat_flow.shape[concat_axis] % 2 == 0 + dx, dy = np.split(cat_flow, 2, axis=concat_axis) + flow = dequantize_flow(dx, dy, *args, **kwargs) + else: + with open(flow_path, 'rb') as f: + try: + header = f.read(4).decode('utf-8') + except Exception: + raise IOError(f'Invalid flow file: {flow_path}') + else: + if header != 'PIEH': + raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH') + + w = np.fromfile(f, np.int32, 1).squeeze() + h = np.fromfile(f, np.int32, 1).squeeze() + # flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) + flow = np.fromfile(f, np.float16, w * h * 2).reshape((h, w, 2)) + + return flow.astype(np.float32) + + +def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): + """Write optical flow to file. + + If the flow is not quantized, it will be saved as a .flo file losslessly, + otherwise a jpeg image which is lossy but of much smaller size. (dx and dy + will be concatenated horizontally into a single image if quantize is True.) + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + filename (str): Output filepath. + quantize (bool): Whether to quantize the flow and save it to 2 jpeg + images. If set to True, remaining args will be passed to + :func:`quantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. + """ + dir_name = os.path.abspath(os.path.dirname(filename)) + os.makedirs(dir_name, exist_ok=True) + if not quantize: + with open(filename, 'wb') as f: + f.write('PIEH'.encode('utf-8')) + np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) + # flow = flow.astype(np.float32) + flow = flow.astype(np.float16) + flow.tofile(f) + f.flush() + else: + assert concat_axis in [0, 1] + dx, dy = quantize_flow(flow, *args, **kwargs) + dxdy = np.concatenate((dx, dy), axis=concat_axis) + # os.makedirs(os.path.dirname(filename), exist_ok=True) + cv2.imwrite(filename, dxdy) + # imwrite(dxdy, filename) + + +def quantize_flow(flow, max_val=0.02, norm=True): + """Quantize flow to [0, 255]. + + After this step, the size of flow will be much smaller, and can be + dumped as jpeg images. + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + max_val (float): Maximum value of flow, values beyond + [-max_val, max_val] will be truncated. + norm (bool): Whether to divide flow values by image width/height. + + Returns: + tuple[ndarray]: Quantized dx and dy. + """ + h, w, _ = flow.shape + dx = flow[..., 0] + dy = flow[..., 1] + if norm: + dx = dx / w # avoid inplace operations + dy = dy / h + # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. + flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]] + return tuple(flow_comps) + + +def dequantize_flow(dx, dy, max_val=0.02, denorm=True): + """Recover from quantized flow. + + Args: + dx (ndarray): Quantized dx. + dy (ndarray): Quantized dy. + max_val (float): Maximum value used when quantizing. + denorm (bool): Whether to multiply flow values with width/height. + + Returns: + ndarray: Dequantized flow. + """ + assert dx.shape == dy.shape + assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) + + dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] + + if denorm: + dx *= dx.shape[1] + dy *= dx.shape[0] + flow = np.dstack((dx, dy)) + return flow + + +def quantize(arr, min_val, max_val, levels, dtype=np.int64): + """Quantize an array of (-inf, inf) to [0, levels-1]. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. + dtype (np.type): The type of the quantized array. + + Returns: + tuple: Quantized array. + """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError(f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + arr = np.clip(arr, min_val, max_val) - min_val + quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) + + return quantized_arr + + +def dequantize(arr, min_val, max_val, levels, dtype=np.float64): + """Dequantize an array. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. + dtype (np.type): The type of the dequantized array. + + Returns: + tuple: Dequantized array. + """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError(f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val + + return dequantized_arr \ No newline at end of file diff --git a/utils/img_util.py b/utils/img_util.py new file mode 100644 index 0000000000000000000000000000000000000000..d409a132ff216e6943a276fb5d8cd5f410824883 --- /dev/null +++ b/utils/img_util.py @@ -0,0 +1,170 @@ +import cv2 +import math +import numpy as np +import os +import torch +from torchvision.utils import make_grid + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. + + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. + min_max (tuple[int]): min and max values for clamp. + + Returns: + (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of + shape (H x W). The channel order is BGR. + """ + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') + + if torch.is_tensor(tensor): + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + else: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}') + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. + img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1: + result = result[0] + return result + + +def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): + """This implementation is slightly faster than tensor2img. + It now only supports torch tensor with shape (1, c, h, w). + + Args: + tensor (Tensor): Now only support torch tensor with (1, c, h, w). + rgb2bgr (bool): Whether to change rgb to bgr. Default: True. + min_max (tuple[int]): min and max values for clamp. + """ + output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) + output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 + output = output.type(torch.uint8).cpu().numpy() + if rgb2bgr: + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + +def imfrombytes(content, flag='color', float32=False): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale` and `unchanged`. + float32 (bool): Whether to change to float32., If True, will also norm + to [0, 1]. Default: False. + + Returns: + ndarray: Loaded image array. + """ + img_np = np.frombuffer(content, np.uint8) + imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} + img = cv2.imdecode(img_np, imread_flags[flag]) + if float32: + img = img.astype(np.float32) / 255. + return img + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv's :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + + +def crop_border(imgs, crop_border): + """Crop borders of images. + + Args: + imgs (list[ndarray] | ndarray): Images with shape (h, w, c). + crop_border (int): Crop border for each end of height and weight. + + Returns: + list[ndarray]: Cropped images. + """ + if crop_border == 0: + return imgs + else: + if isinstance(imgs, list): + return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] + else: + return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] diff --git a/web-demos/hugging_face/.gitignore b/web-demos/hugging_face/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..cce2271ec487a5debb14ea9051db762035caa45a --- /dev/null +++ b/web-demos/hugging_face/.gitignore @@ -0,0 +1,8 @@ +__pycache__/ +.vscode/ +docs/ +debug_images/ +images/ +result/ +vots/ +vots.py diff --git a/web-demos/hugging_face/LICENSE b/web-demos/hugging_face/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0e658c9ff1a1850b4816271015634ed4a7cb11b3 --- /dev/null +++ b/web-demos/hugging_face/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Mingqi Gao + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/web-demos/hugging_face/README.md b/web-demos/hugging_face/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f5f075add318978857dfc6ea6a57522e05eb2b7e --- /dev/null +++ b/web-demos/hugging_face/README.md @@ -0,0 +1,57 @@ +## Get Started +1. Install ProPainter Dependencies +You can follow the [Dependencies and Installation](https://github.com/Luo-Yihang/ProPainter-pr/tree/dev_yihang#dependencies-and-installation) + +2. Install Demo Dependencies +```shell +cd web-demos/hugging_face + +# install python dependencies +pip3 install -r requirements.txt + +# Run the demo +python app.py +``` + +## Usage Guidance +* Step 1: Upload your video and click the `Get video info` button. +![Step 1](./assets/step1.png) + +* Step 2: + 1. *[Optional]* Specify the tracking period for the currently added mask by dragging the `Track start frame` or `Track end frame`. + 2. Click the image on the left to select the mask area. + 3. - Click `Add mask` if you are satisfied with the mask, or + - *[Optional]* Click `Clear clicks` if you want to reselect the mask area, or + - *[Optional]* Click `Remove mask` to remove all masks. + 4. *[Optional]* Go back to step 2.1 to add another mask. +![Step 2](./assets/step2.png) + +* Step 3: + 1. Click the `Tracking` button to track the masks for the whole video. + 2. *[Optional]* Select the ProPainter parameters if the `ProPainter Parameters` dropdown. + 2. Then click `Inpainting` to get the inpainting results. +![Step 3](./assets/step3.png) + +*You can always refer to the `Highlighted Text` box on the page for guidance on the next step!* + + +## Citation +If you find our repo useful for your research, please consider citing our paper: +```bibtex +@inproceedings{zhou2023propainter, + title={{ProPainter}: Improving Propagation and Transformer for Video Inpainting}, + author={Zhou, Shangchen and Li, Chongyi and Chan, Kelvin C.K and Loy, Chen Change}, + booktitle={Proceedings of IEEE International Conference on Computer Vision (ICCV)}, + year={2023} +} +``` + + +## License + +This project is licensed under NTU S-Lab License 1.0. Redistribution and use should follow this license. + + +## Acknowledgements + +The project harnesses the capabilities from [Track Anything](https://github.com/gaomingqi/Track-Anything), [Segment Anything](https://github.com/facebookresearch/segment-anything), [Cutie](https://github.com/hkchengrex/Cutie), and [E2FGVI](https://github.com/MCG-NKU/E2FGVI). Thanks for their awesome works. diff --git a/web-demos/hugging_face/app.py b/web-demos/hugging_face/app.py new file mode 100644 index 0000000000000000000000000000000000000000..2abe6655a77c2480c6550159387de565f8fa3ee0 --- /dev/null +++ b/web-demos/hugging_face/app.py @@ -0,0 +1,641 @@ +import sys +sys.path.append("../../") + +import os +import json +import time +import psutil +import argparse + +import cv2 +import torch +import torchvision +import numpy as np +import gradio as gr + +from tools.painter import mask_painter +from track_anything import TrackingAnything + +from model.misc import get_device +from utils.download_util import load_file_from_url + + +def parse_augment(): + parser = argparse.ArgumentParser() + parser.add_argument('--device', type=str, default=None) + parser.add_argument('--sam_model_type', type=str, default="vit_h") + parser.add_argument('--port', type=int, default=8000, help="only useful when running gradio applications") + parser.add_argument('--mask_save', default=False) + args = parser.parse_args() + + if not args.device: + args.device = str(get_device()) + + return args + +# convert points input to prompt state +def get_prompt(click_state, click_input): + inputs = json.loads(click_input) + points = click_state[0] + labels = click_state[1] + for input in inputs: + points.append(input[:2]) + labels.append(input[2]) + click_state[0] = points + click_state[1] = labels + prompt = { + "prompt_type":["click"], + "input_point":click_state[0], + "input_label":click_state[1], + "multimask_output":"True", + } + return prompt + +# extract frames from upload video +def get_frames_from_video(video_input, video_state): + """ + Args: + video_path:str + timestamp:float64 + Return + [[0:nearest_frame], [nearest_frame:], nearest_frame] + """ + video_path = video_input + frames = [] + user_name = time.time() + operation_log = [("",""),("Video uploaded! Try to click the image shown in step2 to add masks.","Normal")] + try: + cap = cv2.VideoCapture(video_path) + fps = cap.get(cv2.CAP_PROP_FPS) + while cap.isOpened(): + ret, frame = cap.read() + if ret == True: + current_memory_usage = psutil.virtual_memory().percent + frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + # if current_memory_usage > 90: + # operation_log = [("Memory usage is too high (>90%). Stop the video extraction. Please reduce the video resolution or frame rate.", "Error")] + # print("Memory usage is too high (>90%). Please reduce the video resolution or frame rate.") + # break + else: + break + except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e: + print("read_frame_source:{} error. {}\n".format(video_path, str(e))) + image_size = (frames[0].shape[0],frames[0].shape[1]) + # initialize video_state + video_state = { + "user_name": user_name, + "video_name": os.path.split(video_path)[-1], + "origin_images": frames, + "painted_images": frames.copy(), + "masks": [np.zeros((frames[0].shape[0],frames[0].shape[1]), np.uint8)]*len(frames), + "logits": [None]*len(frames), + "select_frame_number": 0, + "fps": fps + } + video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), len(frames), image_size) + model.samcontroler.sam_controler.reset_image() + model.samcontroler.sam_controler.set_image(video_state["origin_images"][0]) + return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \ + gr.update(visible=True), gr.update(visible=True), \ + gr.update(visible=True), gr.update(visible=True),\ + gr.update(visible=True), gr.update(visible=True), \ + gr.update(visible=True), gr.update(visible=True), \ + gr.update(visible=True), gr.update(visible=True), \ + gr.update(visible=True), gr.update(visible=True, choices=[], value=[]), \ + gr.update(visible=True, value=operation_log), gr.update(visible=True, value=operation_log) + +# get the select frame from gradio slider +def select_template(image_selection_slider, video_state, interactive_state, mask_dropdown): + + # images = video_state[1] + image_selection_slider -= 1 + video_state["select_frame_number"] = image_selection_slider + + # once select a new template frame, set the image in sam + + model.samcontroler.sam_controler.reset_image() + model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider]) + + operation_log = [("",""), ("Select tracking start frame {}. Try to click the image to add masks for tracking.".format(image_selection_slider),"Normal")] + + return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log, operation_log + +# set the tracking end frame +def get_end_number(track_pause_number_slider, video_state, interactive_state): + interactive_state["track_end_number"] = track_pause_number_slider + operation_log = [("",""),("Select tracking finish frame {}.Try to click the image to add masks for tracking.".format(track_pause_number_slider),"Normal")] + + return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log, operation_log + +# use sam to get the mask +def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData): + """ + Args: + template_frame: PIL.Image + point_prompt: flag for positive or negative button click + click_state: [[points], [labels]] + """ + if point_prompt == "Positive": + coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1]) + interactive_state["positive_click_times"] += 1 + else: + coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1]) + interactive_state["negative_click_times"] += 1 + + # prompt for sam model + model.samcontroler.sam_controler.reset_image() + model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]]) + prompt = get_prompt(click_state=click_state, click_input=coordinate) + + mask, logit, painted_image = model.first_frame_click( + image=video_state["origin_images"][video_state["select_frame_number"]], + points=np.array(prompt["input_point"]), + labels=np.array(prompt["input_label"]), + multimask=prompt["multimask_output"], + ) + video_state["masks"][video_state["select_frame_number"]] = mask + video_state["logits"][video_state["select_frame_number"]] = logit + video_state["painted_images"][video_state["select_frame_number"]] = painted_image + + operation_log = [("",""), ("You can try to add positive or negative points by clicking, click Clear clicks button to refresh the image, click Add mask button when you are satisfied with the segment, or click Remove mask button to remove all added masks.","Normal")] + return painted_image, video_state, interactive_state, operation_log, operation_log + +def add_multi_mask(video_state, interactive_state, mask_dropdown): + try: + mask = video_state["masks"][video_state["select_frame_number"]] + interactive_state["multi_mask"]["masks"].append(mask) + interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))) + mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))) + select_frame, _, _ = show_mask(video_state, interactive_state, mask_dropdown) + operation_log = [("",""),("Added a mask, use the mask select for target tracking or inpainting.","Normal")] + except: + operation_log = [("Please click the image in step2 to generate masks.", "Error"), ("","")] + return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log, operation_log + +def clear_click(video_state, click_state): + click_state = [[],[]] + template_frame = video_state["origin_images"][video_state["select_frame_number"]] + operation_log = [("",""), ("Cleared points history and refresh the image.","Normal")] + return template_frame, click_state, operation_log, operation_log + +def remove_multi_mask(interactive_state, mask_dropdown): + interactive_state["multi_mask"]["mask_names"]= [] + interactive_state["multi_mask"]["masks"] = [] + + operation_log = [("",""), ("Remove all masks. Try to add new masks","Normal")] + return interactive_state, gr.update(choices=[],value=[]), operation_log, operation_log + +def show_mask(video_state, interactive_state, mask_dropdown): + mask_dropdown.sort() + select_frame = video_state["origin_images"][video_state["select_frame_number"]] + for i in range(len(mask_dropdown)): + mask_number = int(mask_dropdown[i].split("_")[1]) - 1 + mask = interactive_state["multi_mask"]["masks"][mask_number] + select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2) + + operation_log = [("",""), ("Added masks {}. If you want to do the inpainting with current masks, please go to step3, and click the Tracking button first and then Inpainting button.".format(mask_dropdown),"Normal")] + return select_frame, operation_log, operation_log + +# tracking vos +def vos_tracking_video(video_state, interactive_state, mask_dropdown): + operation_log = [("",""), ("Tracking finished! Try to click the Inpainting button to get the inpainting result.","Normal")] + model.cutie.clear_memory() + if interactive_state["track_end_number"]: + following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] + else: + following_frames = video_state["origin_images"][video_state["select_frame_number"]:] + + if interactive_state["multi_mask"]["masks"]: + if len(mask_dropdown) == 0: + mask_dropdown = ["mask_001"] + mask_dropdown.sort() + template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1])) + for i in range(1,len(mask_dropdown)): + mask_number = int(mask_dropdown[i].split("_")[1]) - 1 + template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1) + video_state["masks"][video_state["select_frame_number"]]= template_mask + else: + template_mask = video_state["masks"][video_state["select_frame_number"]] + fps = video_state["fps"] + + # operation error + if len(np.unique(template_mask))==1: + template_mask[0][0]=1 + operation_log = [("Please add at least one mask to track by clicking the image in step2.","Error"), ("","")] + # return video_output, video_state, interactive_state, operation_error + masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask) + # clear GPU memory + model.cutie.clear_memory() + + if interactive_state["track_end_number"]: + video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks + video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits + video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = painted_images + else: + video_state["masks"][video_state["select_frame_number"]:] = masks + video_state["logits"][video_state["select_frame_number"]:] = logits + video_state["painted_images"][video_state["select_frame_number"]:] = painted_images + + video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/track/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video + interactive_state["inference_times"] += 1 + + print("For generating this tracking result, inference times: {}, click times: {}, positive: {}, negative: {}".format(interactive_state["inference_times"], + interactive_state["positive_click_times"]+interactive_state["negative_click_times"], + interactive_state["positive_click_times"], + interactive_state["negative_click_times"])) + + #### shanggao code for mask save + if interactive_state["mask_save"]: + if not os.path.exists('./result/mask/{}'.format(video_state["video_name"].split('.')[0])): + os.makedirs('./result/mask/{}'.format(video_state["video_name"].split('.')[0])) + i = 0 + print("save mask") + for mask in video_state["masks"]: + np.save(os.path.join('./result/mask/{}'.format(video_state["video_name"].split('.')[0]), '{:05d}.npy'.format(i)), mask) + i+=1 + # save_mask(video_state["masks"], video_state["video_name"]) + #### shanggao code for mask save + return video_output, video_state, interactive_state, operation_log, operation_log + +# inpaint +def inpaint_video(video_state, resize_ratio_number, dilate_radius_number, raft_iter_number, subvideo_length_number, neighbor_length_number, ref_stride_number, mask_dropdown): + operation_log = [("",""), ("Inpainting finished!","Normal")] + + frames = np.asarray(video_state["origin_images"]) + fps = video_state["fps"] + inpaint_masks = np.asarray(video_state["masks"]) + if len(mask_dropdown) == 0: + mask_dropdown = ["mask_001"] + mask_dropdown.sort() + # convert mask_dropdown to mask numbers + inpaint_mask_numbers = [int(mask_dropdown[i].split("_")[1]) for i in range(len(mask_dropdown))] + # interate through all masks and remove the masks that are not in mask_dropdown + unique_masks = np.unique(inpaint_masks) + num_masks = len(unique_masks) - 1 + for i in range(1, num_masks + 1): + if i in inpaint_mask_numbers: + continue + inpaint_masks[inpaint_masks==i] = 0 + + # inpaint for videos + inpainted_frames = model.baseinpainter.inpaint(frames, + inpaint_masks, + ratio=resize_ratio_number, + dilate_radius=dilate_radius_number, + raft_iter=raft_iter_number, + subvideo_length=subvideo_length_number, + neighbor_length=neighbor_length_number, + ref_stride=ref_stride_number) # numpy array, T, H, W, 3 + + video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video + + return video_output, operation_log, operation_log + +# generate video after vos inference +def generate_video_from_frames(frames, output_path, fps=30): + """ + Generates a video from a list of frames. + + Args: + frames (list of numpy arrays): The frames to include in the video. + output_path (str): The path to save the generated video. + fps (int, optional): The frame rate of the output video. Defaults to 30. + """ + frames = torch.from_numpy(np.asarray(frames)) + if not os.path.exists(os.path.dirname(output_path)): + os.makedirs(os.path.dirname(output_path)) + torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264") + return output_path + +def restart(): + operation_log = [("",""), ("Try to upload your video and click the Get video info button to get started!", "Normal")] + return { + "user_name": "", + "video_name": "", + "origin_images": None, + "painted_images": None, + "masks": None, + "inpaint_masks": None, + "logits": None, + "select_frame_number": 0, + "fps": 30 + }, { + "inference_times": 0, + "negative_click_times" : 0, + "positive_click_times": 0, + "mask_save": args.mask_save, + "multi_mask": { + "mask_names": [], + "masks": [] + }, + "track_end_number": None, + }, [[],[]], None, None, None, \ + gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),\ + gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ + gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ + gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), "", \ + gr.update(visible=True, value=operation_log), gr.update(visible=False, value=operation_log) + + +# args, defined in track_anything.py +args = parse_augment() +pretrain_model_url = 'https://github.com/sczhou/ProPainter/releases/download/v0.1.0/' +sam_checkpoint_url_dict = { + 'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", + 'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", + 'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" +} +checkpoint_fodler = os.path.join('..', '..', 'weights') + +sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type], checkpoint_fodler) +cutie_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'cutie-base-mega.pth'), checkpoint_fodler) +propainter_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'ProPainter.pth'), checkpoint_fodler) +raft_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'raft-things.pth'), checkpoint_fodler) +flow_completion_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'recurrent_flow_completion.pth'), checkpoint_fodler) + +# initialize sam, cutie, propainter models +model = TrackingAnything(sam_checkpoint, cutie_checkpoint, propainter_checkpoint, raft_checkpoint, flow_completion_checkpoint, args) + + +title = r"""

ProPainter: Improving Propagation and Transformer for Video Inpainting

""" + +description = r""" +
Propainter logo
+Official Gradio demo for Improving Propagation and Transformer for Video Inpainting (ICCV 2023).
+🔥 Propainter is a robust inpainting algorithm.
+🤗 Try to drop your video, add the masks and get the the inpainting results!
+""" +article = r""" +If ProPainter is helpful, please help to ⭐ the Github Repo. Thanks! +[![GitHub Stars](https://img.shields.io/github/stars/sczhou/ProPainter?style=social)](https://github.com/sczhou/ProPainter) + +--- + +📝 **Citation** +
+If our work is useful for your research, please consider citing: +```bibtex +@inproceedings{zhou2023propainter, + title={{ProPainter}: Improving Propagation and Transformer for Video Inpainting}, + author={Zhou, Shangchen and Li, Chongyi and Chan, Kelvin C.K and Loy, Chen Change}, + booktitle={Proceedings of IEEE International Conference on Computer Vision (ICCV)}, + year={2023} +} +``` + +📋 **License** +
+This project is licensed under S-Lab License 1.0. +Redistribution and use for non-commercial purposes should follow this license. + +📧 **Contact** +
+If you have any questions, please feel free to reach me out at shangchenzhou@gmail.com. +
+ 🤗 Find Me: + Twitter Follow + Github Follow +
+ +""" +css = """ +.gradio-container {width: 85% !important} +.gr-monochrome-group {border-radius: 5px !important; border: revert-layer !important; border-width: 2px !important; color: black !important} +button {border-radius: 8px !important;} +.add_button {background-color: #4CAF50 !important;} +.remove_button {background-color: #f44336 !important;} +.mask_button_group {gap: 10px !important;} +.video {height: 300px !important;} +.image {height: 300px !important;} +.video .wrap.svelte-lcpz3o {display: flex !important; align-items: center !important; justify-content: center !important;} +.video .wrap.svelte-lcpz3o > :first-child {height: 100% !important;} +.margin_center {width: 50% !important; margin: auto !important;} +.jc_center {justify-content: center !important;} +""" + +with gr.Blocks(theme=gr.themes.Monochrome(), css=css) as iface: + click_state = gr.State([[],[]]) + + interactive_state = gr.State({ + "inference_times": 0, + "negative_click_times" : 0, + "positive_click_times": 0, + "mask_save": args.mask_save, + "multi_mask": { + "mask_names": [], + "masks": [] + }, + "track_end_number": None, + } + ) + + video_state = gr.State( + { + "user_name": "", + "video_name": "", + "origin_images": None, + "painted_images": None, + "masks": None, + "inpaint_masks": None, + "logits": None, + "select_frame_number": 0, + "fps": 30 + } + ) + + gr.Markdown(title) + gr.Markdown(description) + + with gr.Group(elem_classes="gr-monochrome-group"): + with gr.Row(): + with gr.Accordion('ProPainter Parameters', open=False): + with gr.Row(): + resize_ratio_number = gr.Slider(label='Resize ratio', + minimum=0.01, + maximum=1.0, + step=0.01, + value=1.0) + raft_iter_number = gr.Slider(label='Iterations for RAFT inference.', + minimum=5, + maximum=20, + step=1, + value=20,) + with gr.Row(): + dilate_radius_number = gr.Slider(label='Mask dilation for video and flow masking.', + minimum=0, + maximum=10, + step=1, + value=8,) + + subvideo_length_number = gr.Slider(label='Length of sub-video for long video inference.', + minimum=40, + maximum=200, + step=1, + value=80,) + with gr.Row(): + neighbor_length_number = gr.Slider(label='Length of local neighboring frames.', + minimum=5, + maximum=20, + step=1, + value=10,) + + ref_stride_number = gr.Slider(label='Stride of global reference frames.', + minimum=5, + maximum=20, + step=1, + value=10,) + + with gr.Column(): + # input video + gr.Markdown("## Step1: Upload video") + with gr.Row(equal_height=True): + with gr.Column(scale=2): + video_input = gr.Video(elem_classes="video") + extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary") + with gr.Column(scale=2): + run_status = gr.HighlightedText(value=[("",""), ("Try to upload your video and click the Get svideo info button to get started!", "Normal")]) + video_info = gr.Textbox(label="Video Info") + + + # add masks + step2_title = gr.Markdown("---\n## Step2: Add masks", visible=False) + with gr.Row(equal_height=True): + with gr.Column(scale=2): + template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image") + image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track start frame", visible=False) + track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False) + with gr.Column(scale=2, elem_classes="jc_center"): + run_status2 = gr.HighlightedText(value=[("",""), ("Try to upload your video and click the Get svideo info button to get started!", "Normal")], visible=False) + with gr.Row(): + with gr.Column(scale=2, elem_classes="mask_button_group"): + clear_button_click = gr.Button(value="Clear clicks", interactive=True, visible=False) + remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False, elem_classes="remove_button") + Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False, elem_classes="add_button") + point_prompt = gr.Radio( + choices=["Positive", "Negative"], + value="Positive", + label="Point prompt", + interactive=True, + visible=False, + min_width=100, + scale=1) + mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False) + + # output video + step3_title = gr.Markdown("---\n## Step3: Track masks and get the inpainting result", visible=False) + with gr.Row(equal_height=True): + with gr.Column(scale=2): + tracking_video_output = gr.Video(visible=False, elem_classes="video") + tracking_video_predict_button = gr.Button(value="1. Tracking", visible=False, elem_classes="margin_center") + with gr.Column(scale=2): + inpaiting_video_output = gr.Video(visible=False, elem_classes="video") + inpaint_video_predict_button = gr.Button(value="2. Inpainting", visible=False, elem_classes="margin_center") + + # first step: get the video information + extract_frames_button.click( + fn=get_frames_from_video, + inputs=[ + video_input, video_state + ], + outputs=[video_state, video_info, template_frame, + image_selection_slider, track_pause_number_slider,point_prompt, clear_button_click, Add_mask_button, template_frame, + tracking_video_predict_button, tracking_video_output, inpaiting_video_output, remove_mask_button, inpaint_video_predict_button, step2_title, step3_title,mask_dropdown, run_status, run_status2] + ) + + # second step: select images from slider + image_selection_slider.release(fn=select_template, + inputs=[image_selection_slider, video_state, interactive_state], + outputs=[template_frame, video_state, interactive_state, run_status, run_status2], api_name="select_image") + track_pause_number_slider.release(fn=get_end_number, + inputs=[track_pause_number_slider, video_state, interactive_state], + outputs=[template_frame, interactive_state, run_status, run_status2], api_name="end_image") + + # click select image to get mask using sam + template_frame.select( + fn=sam_refine, + inputs=[video_state, point_prompt, click_state, interactive_state], + outputs=[template_frame, video_state, interactive_state, run_status, run_status2] + ) + + # add different mask + Add_mask_button.click( + fn=add_multi_mask, + inputs=[video_state, interactive_state, mask_dropdown], + outputs=[interactive_state, mask_dropdown, template_frame, click_state, run_status, run_status2] + ) + + remove_mask_button.click( + fn=remove_multi_mask, + inputs=[interactive_state, mask_dropdown], + outputs=[interactive_state, mask_dropdown, run_status, run_status2] + ) + + # tracking video from select image and mask + tracking_video_predict_button.click( + fn=vos_tracking_video, + inputs=[video_state, interactive_state, mask_dropdown], + outputs=[tracking_video_output, video_state, interactive_state, run_status, run_status2] + ) + + # inpaint video from select image and mask + inpaint_video_predict_button.click( + fn=inpaint_video, + inputs=[video_state, resize_ratio_number, dilate_radius_number, raft_iter_number, subvideo_length_number, neighbor_length_number, ref_stride_number, mask_dropdown], + outputs=[inpaiting_video_output, run_status, run_status2] + ) + + # click to get mask + mask_dropdown.change( + fn=show_mask, + inputs=[video_state, interactive_state, mask_dropdown], + outputs=[template_frame, run_status, run_status2] + ) + + # clear input + video_input.change( + fn=restart, + inputs=[], + outputs=[ + video_state, + interactive_state, + click_state, + tracking_video_output, inpaiting_video_output, + template_frame, + tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click, + Add_mask_button, template_frame, tracking_video_predict_button, tracking_video_output, inpaiting_video_output, remove_mask_button,inpaint_video_predict_button, step2_title, step3_title, mask_dropdown, video_info, run_status, run_status2 + ], + queue=False, + show_progress=False) + + video_input.clear( + fn=restart, + inputs=[], + outputs=[ + video_state, + interactive_state, + click_state, + tracking_video_output, inpaiting_video_output, + template_frame, + tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click, + Add_mask_button, template_frame, tracking_video_predict_button, tracking_video_output, inpaiting_video_output, remove_mask_button,inpaint_video_predict_button, step2_title, step3_title, mask_dropdown, video_info, run_status, run_status2 + ], + queue=False, + show_progress=False) + + # points clear + clear_button_click.click( + fn = clear_click, + inputs = [video_state, click_state,], + outputs = [template_frame,click_state, run_status, run_status2], + ) + + # set example + gr.Markdown("## Examples") + gr.Examples( + examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample0.mp4", "test-sample1.mp4", "test-sample2.mp4", "test-sample3.mp4", "test-sample4.mp4"]], + inputs=[video_input], + ) + gr.Markdown(article) + +iface.queue(concurrency_count=1) +iface.launch(debug=True) \ No newline at end of file diff --git a/web-demos/hugging_face/assets/step1.png b/web-demos/hugging_face/assets/step1.png new file mode 100644 index 0000000000000000000000000000000000000000..b5b389402c6853ec138e101719ec06096b5bb7a1 --- /dev/null +++ b/web-demos/hugging_face/assets/step1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c93010fa938c75ae671e3aa362205f3f2692783930f67a6623e0a438479e7326 +size 308524 diff --git a/web-demos/hugging_face/assets/step2.png b/web-demos/hugging_face/assets/step2.png new file mode 100644 index 0000000000000000000000000000000000000000..82d2179eaf9372fc964ce10e464575c414bf0381 --- /dev/null +++ b/web-demos/hugging_face/assets/step2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de449d6fa2c476d1db0e7dcf48e58f7e2b0a510564dbbdadcd9436c02b989f79 +size 751253 diff --git a/web-demos/hugging_face/assets/step3.png b/web-demos/hugging_face/assets/step3.png new file mode 100644 index 0000000000000000000000000000000000000000..b5f980a93e575887ee438c285ca41deae7e38f84 --- /dev/null +++ b/web-demos/hugging_face/assets/step3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d8273104c2e558cb0d1edfe91f5e4ca27483815f5e31aad153f199a351c87b12 +size 1143298 diff --git a/web-demos/hugging_face/inpainter/base_inpainter.py b/web-demos/hugging_face/inpainter/base_inpainter.py new file mode 100644 index 0000000000000000000000000000000000000000..b2408c48315e363ca7ed09efa261f7519d36979a --- /dev/null +++ b/web-demos/hugging_face/inpainter/base_inpainter.py @@ -0,0 +1,372 @@ +# -*- coding: utf-8 -*- +import os +import sys +import cv2 +import numpy as np +import scipy.ndimage +from PIL import Image +from tqdm import tqdm + +import torch +import torchvision + +from model.modules.flow_comp_raft import RAFT_bi +from model.recurrent_flow_completion import RecurrentFlowCompleteNet +from model.propainter import InpaintGenerator +from core.utils import to_tensors + +import warnings +warnings.filterwarnings("ignore") + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + + +def resize_frames(frames, size=None): + if size is not None: + out_size = size + process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8) + frames = [f.resize(process_size) for f in frames] + else: + out_size = frames[0].size + process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8) + if not out_size == process_size: + frames = [f.resize(process_size) for f in frames] + + return frames, process_size, out_size + + +def read_frame_from_videos(frame_root): + if frame_root.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path + video_name = os.path.basename(frame_root)[:-4] + vframes, aframes, info = torchvision.io.read_video(filename=frame_root, pts_unit='sec') # RGB + frames = list(vframes.numpy()) + frames = [Image.fromarray(f) for f in frames] + fps = info['video_fps'] + else: + video_name = os.path.basename(frame_root) + frames = [] + fr_lst = sorted(os.listdir(frame_root)) + for fr in fr_lst: + frame = cv2.imread(os.path.join(frame_root, fr)) + frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + frames.append(frame) + fps = None + size = frames[0].size + + return frames, fps, size, video_name + + +def binary_mask(mask, th=0.1): + mask[mask>th] = 1 + mask[mask<=th] = 0 + return mask + + +def extrapolation(video_ori, scale): + """Prepares the data for video outpainting. + """ + nFrame = len(video_ori) + imgW, imgH = video_ori[0].size + + # Defines new FOV. + imgH_extr = int(scale[0] * imgH) + imgW_extr = int(scale[1] * imgW) + imgH_extr = imgH_extr - imgH_extr % 8 + imgW_extr = imgW_extr - imgW_extr % 8 + H_start = int((imgH_extr - imgH) / 2) + W_start = int((imgW_extr - imgW) / 2) + + # Extrapolates the FOV for video. + frames = [] + for v in video_ori: + frame = np.zeros(((imgH_extr, imgW_extr, 3)), dtype=np.uint8) + frame[H_start: H_start + imgH, W_start: W_start + imgW, :] = v + frames.append(Image.fromarray(frame)) + + # Generates the mask for missing region. + masks_dilated = [] + flow_masks = [] + + dilate_h = 4 if H_start > 10 else 0 + dilate_w = 4 if W_start > 10 else 0 + mask = np.ones(((imgH_extr, imgW_extr)), dtype=np.uint8) + + mask[H_start+dilate_h: H_start+imgH-dilate_h, + W_start+dilate_w: W_start+imgW-dilate_w] = 0 + flow_masks.append(Image.fromarray(mask * 255)) + + mask[H_start: H_start+imgH, W_start: W_start+imgW] = 0 + masks_dilated.append(Image.fromarray(mask * 255)) + + flow_masks = flow_masks * nFrame + masks_dilated = masks_dilated * nFrame + + return frames, flow_masks, masks_dilated, (imgW_extr, imgH_extr) + + +def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=-1): + ref_index = [] + if ref_num == -1: + for i in range(0, length, ref_stride): + if i not in neighbor_ids: + ref_index.append(i) + else: + start_idx = max(0, mid_neighbor_id - ref_stride * (ref_num // 2)) + end_idx = min(length, mid_neighbor_id + ref_stride * (ref_num // 2)) + for i in range(start_idx, end_idx, ref_stride): + if i not in neighbor_ids: + if len(ref_index) > ref_num: + break + ref_index.append(i) + return ref_index + + +def read_mask_demo(masks, length, size, flow_mask_dilates=8, mask_dilates=5): + masks_img = [] + masks_dilated = [] + flow_masks = [] + + for mp in masks: + masks_img.append(Image.fromarray(mp.astype('uint8'))) + + for mask_img in masks_img: + if size is not None: + mask_img = mask_img.resize(size, Image.NEAREST) + mask_img = np.array(mask_img.convert('L')) + + # Dilate 8 pixel so that all known pixel is trustworthy + if flow_mask_dilates > 0: + flow_mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=flow_mask_dilates).astype(np.uint8) + else: + flow_mask_img = binary_mask(mask_img).astype(np.uint8) + + flow_masks.append(Image.fromarray(flow_mask_img * 255)) + + if mask_dilates > 0: + mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=mask_dilates).astype(np.uint8) + else: + mask_img = binary_mask(mask_img).astype(np.uint8) + masks_dilated.append(Image.fromarray(mask_img * 255)) + + if len(masks_img) == 1: + flow_masks = flow_masks * length + masks_dilated = masks_dilated * length + + return flow_masks, masks_dilated + + +class ProInpainter: + def __init__(self, propainter_checkpoint, raft_checkpoint, flow_completion_checkpoint, device="cuda:0", use_half=True): + self.device = device + self.use_half = use_half + + ############################################## + # set up RAFT and flow competition model + ############################################## + self.fix_raft = RAFT_bi(raft_checkpoint, self.device) + + self.fix_flow_complete = RecurrentFlowCompleteNet(flow_completion_checkpoint) + for p in self.fix_flow_complete.parameters(): + p.requires_grad = False + self.fix_flow_complete.to(self.device) + self.fix_flow_complete.eval() + + ############################################## + # set up ProPainter model + ############################################## + self.model = InpaintGenerator(model_path=propainter_checkpoint).to(self.device) + self.model.eval() + + if self.use_half: + self.fix_flow_complete = self.fix_flow_complete.half() + self.model = self.model.half() + + def inpaint(self, npframes, masks, ratio=1.0, dilate_radius=4, raft_iter=20, subvideo_length=80, neighbor_length=10, ref_stride=10): + """ + Perform Inpainting for video subsets + + Output: + inpainted_frames: numpy array, T, H, W, 3 + """ + + frames = [] + for i in range(len(npframes)): + frames.append(Image.fromarray(npframes[i].astype('uint8'), mode="RGB")) + del npframes + + size = frames[0].size + # The ouput size should be divided by 2 so that it can encoded by libx264 + size = (int(ratio*size[0])//2*2, int(ratio*size[1])//2*2) + + frames_len = len(frames) + frames, size, out_size = resize_frames(frames, size) + flow_masks, masks_dilated = read_mask_demo(masks, frames_len, size, dilate_radius, dilate_radius) + w, h = size + + frames_inp = [np.array(f).astype(np.uint8) for f in frames] + frames = to_tensors()(frames).unsqueeze(0) * 2 - 1 + flow_masks = to_tensors()(flow_masks).unsqueeze(0) + masks_dilated = to_tensors()(masks_dilated).unsqueeze(0) + frames, flow_masks, masks_dilated = frames.to(self.device), flow_masks.to(self.device), masks_dilated.to(self.device) + + ############################################## + # ProPainter inference + ############################################## + video_length = frames.size(1) + with torch.no_grad(): + # ---- compute flow ---- + if frames.size(-1) <= 640: + short_clip_len = 12 + elif frames.size(-1) <= 720: + short_clip_len = 8 + elif frames.size(-1) <= 1280: + short_clip_len = 4 + else: + short_clip_len = 2 + + # use fp32 for RAFT + if frames.size(1) > short_clip_len: + gt_flows_f_list, gt_flows_b_list = [], [] + for f in range(0, video_length, short_clip_len): + end_f = min(video_length, f + short_clip_len) + if f == 0: + flows_f, flows_b = self.fix_raft(frames[:,f:end_f], iters=raft_iter) + else: + flows_f, flows_b = self.fix_raft(frames[:,f-1:end_f], iters=raft_iter) + + gt_flows_f_list.append(flows_f) + gt_flows_b_list.append(flows_b) + torch.cuda.empty_cache() + + gt_flows_f = torch.cat(gt_flows_f_list, dim=1) + gt_flows_b = torch.cat(gt_flows_b_list, dim=1) + gt_flows_bi = (gt_flows_f, gt_flows_b) + else: + gt_flows_bi = self.fix_raft(frames, iters=raft_iter) + torch.cuda.empty_cache() + + if self.use_half: + frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half() + gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half()) + + # ---- complete flow ---- + flow_length = gt_flows_bi[0].size(1) + if flow_length > subvideo_length: + pred_flows_f, pred_flows_b = [], [] + pad_len = 5 + for f in range(0, flow_length, subvideo_length): + s_f = max(0, f - pad_len) + e_f = min(flow_length, f + subvideo_length + pad_len) + pad_len_s = max(0, f) - s_f + pad_len_e = e_f - min(flow_length, f + subvideo_length) + pred_flows_bi_sub, _ = self.fix_flow_complete.forward_bidirect_flow( + (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]), + flow_masks[:, s_f:e_f+1]) + pred_flows_bi_sub = self.fix_flow_complete.combine_flow( + (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]), + pred_flows_bi_sub, + flow_masks[:, s_f:e_f+1]) + + pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e]) + pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e]) + torch.cuda.empty_cache() + + pred_flows_f = torch.cat(pred_flows_f, dim=1) + pred_flows_b = torch.cat(pred_flows_b, dim=1) + pred_flows_bi = (pred_flows_f, pred_flows_b) + else: + pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks) + pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks) + torch.cuda.empty_cache() + + # ---- image propagation ---- + masked_frames = frames * (1 - masks_dilated) + subvideo_length_img_prop = min(100, subvideo_length) # ensure a minimum of 100 frames for image propagation + if video_length > subvideo_length_img_prop: + updated_frames, updated_masks = [], [] + pad_len = 10 + for f in range(0, video_length, subvideo_length_img_prop): + s_f = max(0, f - pad_len) + e_f = min(video_length, f + subvideo_length_img_prop + pad_len) + pad_len_s = max(0, f) - s_f + pad_len_e = e_f - min(video_length, f + subvideo_length_img_prop) + + b, t, _, _, _ = masks_dilated[:, s_f:e_f].size() + pred_flows_bi_sub = (pred_flows_bi[0][:, s_f:e_f-1], pred_flows_bi[1][:, s_f:e_f-1]) + prop_imgs_sub, updated_local_masks_sub = self.model.img_propagation(masked_frames[:, s_f:e_f], + pred_flows_bi_sub, + masks_dilated[:, s_f:e_f], + 'nearest') + updated_frames_sub = frames[:, s_f:e_f] * (1 - masks_dilated[:, s_f:e_f]) + \ + prop_imgs_sub.view(b, t, 3, h, w) * masks_dilated[:, s_f:e_f] + updated_masks_sub = updated_local_masks_sub.view(b, t, 1, h, w) + + updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e]) + updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e]) + torch.cuda.empty_cache() + + updated_frames = torch.cat(updated_frames, dim=1) + updated_masks = torch.cat(updated_masks, dim=1) + else: + b, t, _, _, _ = masks_dilated.size() + prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, 'nearest') + updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated + updated_masks = updated_local_masks.view(b, t, 1, h, w) + torch.cuda.empty_cache() + + ori_frames = frames_inp + comp_frames = [None] * video_length + + neighbor_stride = neighbor_length // 2 + if video_length > subvideo_length: + ref_num = subvideo_length // ref_stride + else: + ref_num = -1 + + # ---- feature propagation + transformer ---- + for f in tqdm(range(0, video_length, neighbor_stride)): + neighbor_ids = [ + i for i in range(max(0, f - neighbor_stride), + min(video_length, f + neighbor_stride + 1)) + ] + ref_ids = get_ref_index(f, neighbor_ids, video_length, ref_stride, ref_num) + selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :] + selected_masks = masks_dilated[:, neighbor_ids + ref_ids, :, :, :] + selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :] + selected_pred_flows_bi = (pred_flows_bi[0][:, neighbor_ids[:-1], :, :, :], pred_flows_bi[1][:, neighbor_ids[:-1], :, :, :]) + + with torch.no_grad(): + # 1.0 indicates mask + l_t = len(neighbor_ids) + + # pred_img = selected_imgs # results of image propagation + pred_img = self.model(selected_imgs, selected_pred_flows_bi, selected_masks, selected_update_masks, l_t) + + pred_img = pred_img.view(-1, 3, h, w) + + pred_img = (pred_img + 1) / 2 + pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255 + binary_masks = masks_dilated[0, neighbor_ids, :, :, :].cpu().permute( + 0, 2, 3, 1).numpy().astype(np.uint8) + for i in range(len(neighbor_ids)): + idx = neighbor_ids[i] + img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \ + + ori_frames[idx] * (1 - binary_masks[i]) + if comp_frames[idx] is None: + comp_frames[idx] = img + else: + comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5 + + comp_frames[idx] = comp_frames[idx].astype(np.uint8) + + torch.cuda.empty_cache() + + # need to return numpy array, T, H, W, 3 + comp_frames = [cv2.resize(f, out_size) for f in comp_frames] + + return comp_frames diff --git a/web-demos/hugging_face/requirements.txt b/web-demos/hugging_face/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..dc6195ace6f2979ca8d30d87a27f474b256146b5 --- /dev/null +++ b/web-demos/hugging_face/requirements.txt @@ -0,0 +1,16 @@ +progressbar2 +gdown +gitpython +git+https://github.com/cheind/py-thin-plate-spline +hickle +tensorboard +numpy +git+https://github.com/facebookresearch/segment-anything.git +gradio +opencv-python +matplotlib +pyyaml +av +openmim +tqdm +psutil \ No newline at end of file diff --git a/web-demos/hugging_face/test_sample/test-sample0.mp4 b/web-demos/hugging_face/test_sample/test-sample0.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..44401e257335372f1cda43498366cdae88c613d1 --- /dev/null +++ b/web-demos/hugging_face/test_sample/test-sample0.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cb31e022c1bce2725aba61140eefc945b46c3565d39809a4ce1ea7151a7bf5d7 +size 336842 diff --git a/web-demos/hugging_face/test_sample/test-sample1.mp4 b/web-demos/hugging_face/test_sample/test-sample1.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..335850916b8c816bbe8b0533c594deb62daa5cd7 --- /dev/null +++ b/web-demos/hugging_face/test_sample/test-sample1.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0856abcba55f21e41830a58bde1681c6be5790a0dd1431d844e29687b36d891 +size 1163525 diff --git a/web-demos/hugging_face/test_sample/test-sample2.mp4 b/web-demos/hugging_face/test_sample/test-sample2.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..8f79c29bbb8c60d5f5342daadf0ea798ba7086d8 --- /dev/null +++ b/web-demos/hugging_face/test_sample/test-sample2.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3cab99f0dfb741937f4c0c53fbbff5b040bb158d6186fe275e6c9d36c675f419 +size 316597 diff --git a/web-demos/hugging_face/test_sample/test-sample3.mp4 b/web-demos/hugging_face/test_sample/test-sample3.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..57b915e4b7f1e4c8d4517420323f3e05a053ccac --- /dev/null +++ b/web-demos/hugging_face/test_sample/test-sample3.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:279cdd4416a6cb8bd037369c34516ac37ce46b304f3f974884ead2194087fc5e +size 345185 diff --git a/web-demos/hugging_face/test_sample/test-sample4.mp4 b/web-demos/hugging_face/test_sample/test-sample4.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..ac3b3c35d55b9b51325bc4e793aa7adcff8b5839 --- /dev/null +++ b/web-demos/hugging_face/test_sample/test-sample4.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:896c1a10230caaf913978e480378343d4992d80a2508eae8abe4f23f56ba7feb +size 1163525 diff --git a/web-demos/hugging_face/tools/__init__.py b/web-demos/hugging_face/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/web-demos/hugging_face/tools/base_segmenter.py b/web-demos/hugging_face/tools/base_segmenter.py new file mode 100644 index 0000000000000000000000000000000000000000..2b975bb779b47485f9e6ba7435646b4db40a2c6a --- /dev/null +++ b/web-demos/hugging_face/tools/base_segmenter.py @@ -0,0 +1,129 @@ +import time +import torch +import cv2 +from PIL import Image, ImageDraw, ImageOps +import numpy as np +from typing import Union +from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator +import matplotlib.pyplot as plt +import PIL +from .mask_painter import mask_painter + + +class BaseSegmenter: + def __init__(self, SAM_checkpoint, model_type, device='cuda:0'): + """ + device: model device + SAM_checkpoint: path of SAM checkpoint + model_type: vit_b, vit_l, vit_h + """ + print(f"Initializing BaseSegmenter to {device}") + assert model_type in ['vit_b', 'vit_l', 'vit_h'], 'model_type must be vit_b, vit_l, or vit_h' + + self.device = device + self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32 + self.model = sam_model_registry[model_type](checkpoint=SAM_checkpoint) + self.model.to(device=self.device) + self.predictor = SamPredictor(self.model) + self.embedded = False + + @torch.no_grad() + def set_image(self, image: np.ndarray): + # PIL.open(image_path) 3channel: RGB + # image embedding: avoid encode the same image multiple times + self.orignal_image = image + if self.embedded: + print('repeat embedding, please reset_image.') + return + self.predictor.set_image(image) + self.embedded = True + return + + @torch.no_grad() + def reset_image(self): + # reset image embeding + self.predictor.reset_image() + self.embedded = False + + def predict(self, prompts, mode, multimask=True): + """ + image: numpy array, h, w, 3 + prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input' + prompts['point_coords']: numpy array [N,2] + prompts['point_labels']: numpy array [1,N] + prompts['mask_input']: numpy array [1,256,256] + mode: 'point' (points only), 'mask' (mask only), 'both' (consider both) + mask_outputs: True (return 3 masks), False (return 1 mask only) + whem mask_outputs=True, mask_input=logits[np.argmax(scores), :, :][None, :, :] + """ + assert self.embedded, 'prediction is called before set_image (feature embedding).' + assert mode in ['point', 'mask', 'both'], 'mode must be point, mask, or both' + + if mode == 'point': + masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'], + point_labels=prompts['point_labels'], + multimask_output=multimask) + elif mode == 'mask': + masks, scores, logits = self.predictor.predict(mask_input=prompts['mask_input'], + multimask_output=multimask) + elif mode == 'both': # both + masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'], + point_labels=prompts['point_labels'], + mask_input=prompts['mask_input'], + multimask_output=multimask) + else: + raise("Not implement now!") + # masks (n, h, w), scores (n,), logits (n, 256, 256) + return masks, scores, logits + + +if __name__ == "__main__": + # load and show an image + image = cv2.imread('/hhd3/gaoshang/truck.jpg') + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # numpy array (h, w, 3) + + # initialise BaseSegmenter + SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth' + model_type = 'vit_h' + device = "cuda:4" + base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device) + + # image embedding (once embedded, multiple prompts can be applied) + base_segmenter.set_image(image) + + # examples + # point only ------------------------ + mode = 'point' + prompts = { + 'point_coords': np.array([[500, 375], [1125, 625]]), + 'point_labels': np.array([1, 1]), + } + masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False) # masks (n, h, w), scores (n,), logits (n, 256, 256) + painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8) + painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) + cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image) + + # both ------------------------ + mode = 'both' + mask_input = logits[np.argmax(scores), :, :] + prompts = {'mask_input': mask_input [None, :, :]} + prompts = { + 'point_coords': np.array([[500, 375], [1125, 625]]), + 'point_labels': np.array([1, 0]), + 'mask_input': mask_input[None, :, :] + } + masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256) + painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8) + painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) + cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image) + + # mask only ------------------------ + mode = 'mask' + mask_input = logits[np.argmax(scores), :, :] + + prompts = {'mask_input': mask_input[None, :, :]} + + masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256) + painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8) + painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) + cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image) diff --git a/web-demos/hugging_face/tools/interact_tools.py b/web-demos/hugging_face/tools/interact_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..c70b8c40c6c57f19242b42bfcde60f378b1ce7ba --- /dev/null +++ b/web-demos/hugging_face/tools/interact_tools.py @@ -0,0 +1,99 @@ +import time +import torch +import cv2 +from PIL import Image, ImageDraw, ImageOps +import numpy as np +from typing import Union +from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator +import matplotlib.pyplot as plt +import PIL +from .mask_painter import mask_painter as mask_painter2 +from .base_segmenter import BaseSegmenter +from .painter import mask_painter, point_painter +import os +import requests +import sys + + +mask_color = 3 +mask_alpha = 0.7 +contour_color = 1 +contour_width = 5 +point_color_ne = 8 +point_color_ps = 50 +point_alpha = 0.9 +point_radius = 15 +contour_color = 2 +contour_width = 5 + + +class SamControler(): + def __init__(self, SAM_checkpoint, model_type, device): + ''' + initialize sam controler + ''' + self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device) + + + # def seg_again(self, image: np.ndarray): + # ''' + # it is used when interact in video + # ''' + # self.sam_controler.reset_image() + # self.sam_controler.set_image(image) + # return + + + def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True,mask_color=3): + ''' + it is used in first frame in video + return: mask, logit, painted image(mask+point) + ''' + # self.sam_controler.set_image(image) + origal_image = self.sam_controler.orignal_image + neg_flag = labels[-1] + if neg_flag==1: + #find neg + prompts = { + 'point_coords': points, + 'point_labels': labels, + } + masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask) + mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] + prompts = { + 'point_coords': points, + 'point_labels': labels, + 'mask_input': logit[None, :, :] + } + masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask) + mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] + else: + #find positive + prompts = { + 'point_coords': points, + 'point_labels': labels, + } + masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask) + mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] + + + assert len(points)==len(labels) + + painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width) + painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width) + painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width) + painted_image = Image.fromarray(painted_image) + + return mask, logit, painted_image + + + + + + + + + + + + \ No newline at end of file diff --git a/web-demos/hugging_face/tools/mask_painter.py b/web-demos/hugging_face/tools/mask_painter.py new file mode 100644 index 0000000000000000000000000000000000000000..f471ea0116d656e2cc236832893b07c6d7be1643 --- /dev/null +++ b/web-demos/hugging_face/tools/mask_painter.py @@ -0,0 +1,288 @@ +import cv2 +import torch +import numpy as np +from PIL import Image +import copy +import time + + +def colormap(rgb=True): + color_list = np.array( + [ + 0.000, 0.000, 0.000, + 1.000, 1.000, 1.000, + 1.000, 0.498, 0.313, + 0.392, 0.581, 0.929, + 0.000, 0.447, 0.741, + 0.850, 0.325, 0.098, + 0.929, 0.694, 0.125, + 0.494, 0.184, 0.556, + 0.466, 0.674, 0.188, + 0.301, 0.745, 0.933, + 0.635, 0.078, 0.184, + 0.300, 0.300, 0.300, + 0.600, 0.600, 0.600, + 1.000, 0.000, 0.000, + 1.000, 0.500, 0.000, + 0.749, 0.749, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 1.000, + 0.667, 0.000, 1.000, + 0.333, 0.333, 0.000, + 0.333, 0.667, 0.000, + 0.333, 1.000, 0.000, + 0.667, 0.333, 0.000, + 0.667, 0.667, 0.000, + 0.667, 1.000, 0.000, + 1.000, 0.333, 0.000, + 1.000, 0.667, 0.000, + 1.000, 1.000, 0.000, + 0.000, 0.333, 0.500, + 0.000, 0.667, 0.500, + 0.000, 1.000, 0.500, + 0.333, 0.000, 0.500, + 0.333, 0.333, 0.500, + 0.333, 0.667, 0.500, + 0.333, 1.000, 0.500, + 0.667, 0.000, 0.500, + 0.667, 0.333, 0.500, + 0.667, 0.667, 0.500, + 0.667, 1.000, 0.500, + 1.000, 0.000, 0.500, + 1.000, 0.333, 0.500, + 1.000, 0.667, 0.500, + 1.000, 1.000, 0.500, + 0.000, 0.333, 1.000, + 0.000, 0.667, 1.000, + 0.000, 1.000, 1.000, + 0.333, 0.000, 1.000, + 0.333, 0.333, 1.000, + 0.333, 0.667, 1.000, + 0.333, 1.000, 1.000, + 0.667, 0.000, 1.000, + 0.667, 0.333, 1.000, + 0.667, 0.667, 1.000, + 0.667, 1.000, 1.000, + 1.000, 0.000, 1.000, + 1.000, 0.333, 1.000, + 1.000, 0.667, 1.000, + 0.167, 0.000, 0.000, + 0.333, 0.000, 0.000, + 0.500, 0.000, 0.000, + 0.667, 0.000, 0.000, + 0.833, 0.000, 0.000, + 1.000, 0.000, 0.000, + 0.000, 0.167, 0.000, + 0.000, 0.333, 0.000, + 0.000, 0.500, 0.000, + 0.000, 0.667, 0.000, + 0.000, 0.833, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 0.167, + 0.000, 0.000, 0.333, + 0.000, 0.000, 0.500, + 0.000, 0.000, 0.667, + 0.000, 0.000, 0.833, + 0.000, 0.000, 1.000, + 0.143, 0.143, 0.143, + 0.286, 0.286, 0.286, + 0.429, 0.429, 0.429, + 0.571, 0.571, 0.571, + 0.714, 0.714, 0.714, + 0.857, 0.857, 0.857 + ] + ).astype(np.float32) + color_list = color_list.reshape((-1, 3)) * 255 + if not rgb: + color_list = color_list[:, ::-1] + return color_list + + +color_list = colormap() +color_list = color_list.astype('uint8').tolist() + + +def vis_add_mask(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha): + background_color = np.array(background_color) + contour_color = np.array(contour_color) + + # background_mask = 1 - background_mask + # contour_mask = 1 - contour_mask + + for i in range(3): + image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \ + + background_color[i] * (background_alpha-background_mask*background_alpha) + + image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \ + + contour_color[i] * (contour_alpha-contour_mask*contour_alpha) + + return image.astype('uint8') + + +def mask_generator_00(mask, background_radius, contour_radius): + # no background width when '00' + # distance map + dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + contour_mask[contour_mask>0.5] = 1. + + return mask, contour_mask + + +def mask_generator_01(mask, background_radius, contour_radius): + # no background width when '00' + # distance map + dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + return mask, contour_mask + + +def mask_generator_10(mask, background_radius, contour_radius): + # distance map + dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # .....:::::!!!!! + background_mask = np.clip(dist_map, -background_radius, background_radius) + background_mask = (background_mask - np.min(background_mask)) + background_mask = background_mask / np.max(background_mask) + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + contour_mask[contour_mask>0.5] = 1. + return background_mask, contour_mask + + +def mask_generator_11(mask, background_radius, contour_radius): + # distance map + dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # .....:::::!!!!! + background_mask = np.clip(dist_map, -background_radius, background_radius) + background_mask = (background_mask - np.min(background_mask)) + background_mask = background_mask / np.max(background_mask) + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + return background_mask, contour_mask + + +def mask_painter(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'): + """ + Input: + input_image: numpy array + input_mask: numpy array + background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing + background_blur_radius: radius of background blur, must be odd number + contour_width: width of mask contour, must be odd number + contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others + contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted + mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both + + Output: + painted_image: numpy array + """ + assert input_image.shape[:2] == input_mask.shape, 'different shape' + assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD' + assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11' + + # downsample input image and mask + width, height = input_image.shape[0], input_image.shape[1] + res = 1024 + ratio = min(1.0 * res / max(width, height), 1.0) + input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio))) + input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio))) + + # 0: background, 1: foreground + msk = np.clip(input_mask, 0, 1) + + # generate masks for background and contour pixels + background_radius = (background_blur_radius - 1) // 2 + contour_radius = (contour_width - 1) // 2 + generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11} + background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius) + + # paint + painted_image = vis_add_mask\ + (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background + + return painted_image + + +if __name__ == '__main__': + + background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing + background_blur_radius = 31 # radius of background blur, must be odd number + contour_width = 11 # contour width, must be odd number + contour_color = 3 # id in color map, 0: black, 1: white, >1: others + contour_alpha = 1 # transparency of background, 0: no contour highlighted + + # load input image and mask + input_image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB')) + input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P')) + + # paint + overall_time_1 = 0 + overall_time_2 = 0 + overall_time_3 = 0 + overall_time_4 = 0 + overall_time_5 = 0 + + for i in range(50): + t2 = time.time() + painted_image_00 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00') + e2 = time.time() + + t3 = time.time() + painted_image_10 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10') + e3 = time.time() + + t1 = time.time() + painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha) + e1 = time.time() + + t4 = time.time() + painted_image_01 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01') + e4 = time.time() + + t5 = time.time() + painted_image_11 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11') + e5 = time.time() + + overall_time_1 += (e1 - t1) + overall_time_2 += (e2 - t2) + overall_time_3 += (e3 - t3) + overall_time_4 += (e4 - t4) + overall_time_5 += (e5 - t5) + + print(f'average time w gaussian: {overall_time_1/50}') + print(f'average time w/o gaussian00: {overall_time_2/50}') + print(f'average time w/o gaussian10: {overall_time_3/50}') + print(f'average time w/o gaussian01: {overall_time_4/50}') + print(f'average time w/o gaussian11: {overall_time_5/50}') + + # save + painted_image_00 = Image.fromarray(painted_image_00) + painted_image_00.save('./test_img/painter_output_image_00.png') + + painted_image_10 = Image.fromarray(painted_image_10) + painted_image_10.save('./test_img/painter_output_image_10.png') + + painted_image_01 = Image.fromarray(painted_image_01) + painted_image_01.save('./test_img/painter_output_image_01.png') + + painted_image_11 = Image.fromarray(painted_image_11) + painted_image_11.save('./test_img/painter_output_image_11.png') diff --git a/web-demos/hugging_face/tools/painter.py b/web-demos/hugging_face/tools/painter.py new file mode 100644 index 0000000000000000000000000000000000000000..0e711d35aa8348d15cdad9d1cd413da41ea4f1ab --- /dev/null +++ b/web-demos/hugging_face/tools/painter.py @@ -0,0 +1,215 @@ +# paint masks, contours, or points on images, with specified colors +import cv2 +import torch +import numpy as np +from PIL import Image +import copy +import time + + +def colormap(rgb=True): + color_list = np.array( + [ + 0.000, 0.000, 0.000, + 1.000, 1.000, 1.000, + 1.000, 0.498, 0.313, + 0.392, 0.581, 0.929, + 0.000, 0.447, 0.741, + 0.850, 0.325, 0.098, + 0.929, 0.694, 0.125, + 0.494, 0.184, 0.556, + 0.466, 0.674, 0.188, + 0.301, 0.745, 0.933, + 0.635, 0.078, 0.184, + 0.300, 0.300, 0.300, + 0.600, 0.600, 0.600, + 1.000, 0.000, 0.000, + 1.000, 0.500, 0.000, + 0.749, 0.749, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 1.000, + 0.667, 0.000, 1.000, + 0.333, 0.333, 0.000, + 0.333, 0.667, 0.000, + 0.333, 1.000, 0.000, + 0.667, 0.333, 0.000, + 0.667, 0.667, 0.000, + 0.667, 1.000, 0.000, + 1.000, 0.333, 0.000, + 1.000, 0.667, 0.000, + 1.000, 1.000, 0.000, + 0.000, 0.333, 0.500, + 0.000, 0.667, 0.500, + 0.000, 1.000, 0.500, + 0.333, 0.000, 0.500, + 0.333, 0.333, 0.500, + 0.333, 0.667, 0.500, + 0.333, 1.000, 0.500, + 0.667, 0.000, 0.500, + 0.667, 0.333, 0.500, + 0.667, 0.667, 0.500, + 0.667, 1.000, 0.500, + 1.000, 0.000, 0.500, + 1.000, 0.333, 0.500, + 1.000, 0.667, 0.500, + 1.000, 1.000, 0.500, + 0.000, 0.333, 1.000, + 0.000, 0.667, 1.000, + 0.000, 1.000, 1.000, + 0.333, 0.000, 1.000, + 0.333, 0.333, 1.000, + 0.333, 0.667, 1.000, + 0.333, 1.000, 1.000, + 0.667, 0.000, 1.000, + 0.667, 0.333, 1.000, + 0.667, 0.667, 1.000, + 0.667, 1.000, 1.000, + 1.000, 0.000, 1.000, + 1.000, 0.333, 1.000, + 1.000, 0.667, 1.000, + 0.167, 0.000, 0.000, + 0.333, 0.000, 0.000, + 0.500, 0.000, 0.000, + 0.667, 0.000, 0.000, + 0.833, 0.000, 0.000, + 1.000, 0.000, 0.000, + 0.000, 0.167, 0.000, + 0.000, 0.333, 0.000, + 0.000, 0.500, 0.000, + 0.000, 0.667, 0.000, + 0.000, 0.833, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 0.167, + 0.000, 0.000, 0.333, + 0.000, 0.000, 0.500, + 0.000, 0.000, 0.667, + 0.000, 0.000, 0.833, + 0.000, 0.000, 1.000, + 0.143, 0.143, 0.143, + 0.286, 0.286, 0.286, + 0.429, 0.429, 0.429, + 0.571, 0.571, 0.571, + 0.714, 0.714, 0.714, + 0.857, 0.857, 0.857 + ] + ).astype(np.float32) + color_list = color_list.reshape((-1, 3)) * 255 + if not rgb: + color_list = color_list[:, ::-1] + return color_list + + +color_list = colormap() +color_list = color_list.astype('uint8').tolist() + + +def vis_add_mask(image, mask, color, alpha): + color = np.array(color_list[color]) + mask = mask > 0.5 + image[mask] = image[mask] * (1-alpha) + color * alpha + return image.astype('uint8') + +def point_painter(input_image, input_points, point_color=5, point_alpha=0.9, point_radius=15, contour_color=2, contour_width=5): + h, w = input_image.shape[:2] + point_mask = np.zeros((h, w)).astype('uint8') + for point in input_points: + point_mask[point[1], point[0]] = 1 + + kernel = cv2.getStructuringElement(2, (point_radius, point_radius)) + point_mask = cv2.dilate(point_mask, kernel) + + contour_radius = (contour_width - 1) // 2 + dist_transform_fore = cv2.distanceTransform(point_mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-point_mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + contour_mask[contour_mask>0.5] = 1. + + # paint mask + painted_image = vis_add_mask(input_image.copy(), point_mask, point_color, point_alpha) + # paint contour + painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1) + return painted_image + +def mask_painter(input_image, input_mask, mask_color=5, mask_alpha=0.7, contour_color=1, contour_width=3): + assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask' + # 0: background, 1: foreground + mask = np.clip(input_mask, 0, 1) + contour_radius = (contour_width - 1) // 2 + + dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) + dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) + dist_map = dist_transform_fore - dist_transform_back + # ...:::!!!:::... + contour_radius += 2 + contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) + contour_mask = contour_mask / np.max(contour_mask) + contour_mask[contour_mask>0.5] = 1. + + # paint mask + painted_image = vis_add_mask(input_image.copy(), mask.copy(), mask_color, mask_alpha) + # paint contour + painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1) + + return painted_image + +def background_remover(input_image, input_mask): + """ + input_image: H, W, 3, np.array + input_mask: H, W, np.array + + image_wo_background: PIL.Image + """ + assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask' + # 0: background, 1: foreground + mask = np.expand_dims(np.clip(input_mask, 0, 1), axis=2)*255 + image_wo_background = np.concatenate([input_image, mask], axis=2) # H, W, 4 + image_wo_background = Image.fromarray(image_wo_background).convert('RGBA') + + return image_wo_background + +if __name__ == '__main__': + input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB')) + input_mask = np.array(Image.open('images/painter_input_mask.jpg').convert('P')) + + # example of mask painter + mask_color = 3 + mask_alpha = 0.7 + contour_color = 1 + contour_width = 5 + + # save + painted_image = Image.fromarray(input_image) + painted_image.save('images/original.png') + + painted_image = mask_painter(input_image, input_mask, mask_color, mask_alpha, contour_color, contour_width) + # save + painted_image = Image.fromarray(input_image) + painted_image.save('images/original1.png') + + # example of point painter + input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB')) + input_points = np.array([[500, 375], [70, 600]]) # x, y + point_color = 5 + point_alpha = 0.9 + point_radius = 15 + contour_color = 2 + contour_width = 5 + painted_image_1 = point_painter(input_image, input_points, point_color, point_alpha, point_radius, contour_color, contour_width) + # save + painted_image = Image.fromarray(painted_image_1) + painted_image.save('images/point_painter_1.png') + + input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB')) + painted_image_2 = point_painter(input_image, input_points, point_color=9, point_radius=20, contour_color=29) + # save + painted_image = Image.fromarray(painted_image_2) + painted_image.save('images/point_painter_2.png') + + # example of background remover + input_image = np.array(Image.open('images/original.png').convert('RGB')) + image_wo_background = background_remover(input_image, input_mask) # return PIL.Image + image_wo_background.save('images/image_wo_background.png') diff --git a/web-demos/hugging_face/track_anything.py b/web-demos/hugging_face/track_anything.py new file mode 100644 index 0000000000000000000000000000000000000000..deeb88e647bec7260db1a6de55aa86f32723b72e --- /dev/null +++ b/web-demos/hugging_face/track_anything.py @@ -0,0 +1,40 @@ +import numpy as np +from tqdm import tqdm + +from tools.interact_tools import SamControler +from tracker.base_tracker import BaseTracker +from inpainter.base_inpainter import ProInpainter + + +class TrackingAnything(): + def __init__(self, sam_checkpoint, cutie_checkpoint, propainter_checkpoint, raft_checkpoint, flow_completion_checkpoint, args): + self.args = args + self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device) + self.cutie = BaseTracker(cutie_checkpoint, device=args.device) + self.baseinpainter = ProInpainter(propainter_checkpoint, raft_checkpoint, flow_completion_checkpoint, args.device) + + def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True): + mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask) + return mask, logit, painted_image + + def generator(self, images: list, template_mask:np.ndarray): + masks = [] + logits = [] + painted_images = [] + for i in tqdm(range(len(images)), desc="Tracking image"): + if i==0: + mask, logit, painted_image = self.cutie.track(images[i], template_mask) + masks.append(mask) + logits.append(logit) + painted_images.append(painted_image) + else: + mask, logit, painted_image = self.cutie.track(images[i]) + masks.append(mask) + logits.append(logit) + painted_images.append(painted_image) + return masks, logits, painted_images + + + + + \ No newline at end of file diff --git a/web-demos/hugging_face/tracker/base_tracker.py b/web-demos/hugging_face/tracker/base_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..7ec527326c1df4b4bfb9fdd90bc925460cca39f2 --- /dev/null +++ b/web-demos/hugging_face/tracker/base_tracker.py @@ -0,0 +1,103 @@ +import numpy as np +import torch +import torch.nn.functional as F +from omegaconf import OmegaConf + +import sys +sys.path.append('../') + +from tracker.config import CONFIG +from tracker.model.cutie import CUTIE +from tracker.inference.inference_core import InferenceCore +from tracker.utils.mask_mapper import MaskMapper + +from tools.painter import mask_painter + + +class BaseTracker: + def __init__(self, cutie_checkpoint, device) -> None: + """ + device: model device + cutie_checkpoint: checkpoint of XMem model + """ + config = OmegaConf.create(CONFIG) + + # initialise XMem + network = CUTIE(config).to(device).eval() + model_weights = torch.load(cutie_checkpoint, map_location=device) + network.load_weights(model_weights) + + # initialise IncerenceCore + self.tracker = InferenceCore(network, config) + self.device = device + + # changable properties + self.mapper = MaskMapper() + self.initialised = False + + @torch.no_grad() + def resize_mask(self, mask): + # mask transform is applied AFTER mapper, so we need to post-process it in eval.py + h, w = mask.shape[-2:] + min_hw = min(h, w) + return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)), + mode='nearest') + + @torch.no_grad() + def image_to_torch(self, frame: np.ndarray, device: str = 'cuda'): + # frame: H*W*3 numpy array + frame = frame.transpose(2, 0, 1) + frame = torch.from_numpy(frame).float().to(device, non_blocking=True) / 255 + return frame + + @torch.no_grad() + def track(self, frame, first_frame_annotation=None): + """ + Input: + frames: numpy arrays (H, W, 3) + logit: numpy array (H, W), logit + + Output: + mask: numpy arrays (H, W) + logit: numpy arrays, probability map (H, W) + painted_image: numpy array (H, W, 3) + """ + + if first_frame_annotation is not None: # first frame mask + # initialisation + mask, labels = self.mapper.convert_mask(first_frame_annotation) + mask = torch.Tensor(mask).to(self.device) + else: + mask = None + labels = None + + # prepare inputs + frame_tensor = self.image_to_torch(frame, self.device) + + # track one frame + probs = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W + + # convert to mask + out_mask = torch.argmax(probs, dim=0) + out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8) + + final_mask = np.zeros_like(out_mask) + + # map back + for k, v in self.mapper.remappings.items(): + final_mask[out_mask == v] = k + + num_objs = final_mask.max() + painted_image = frame + for obj in range(1, num_objs+1): + if np.max(final_mask==obj) == 0: + continue + painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1) + + return final_mask, final_mask, painted_image + + @torch.no_grad() + def clear_memory(self): + self.tracker.clear_memory() + self.mapper.clear_labels() + torch.cuda.empty_cache() \ No newline at end of file diff --git a/web-demos/hugging_face/tracker/config/__init__.py b/web-demos/hugging_face/tracker/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a26a5c2f36fabe3c9870cf0c1ba3c4414e15cd42 --- /dev/null +++ b/web-demos/hugging_face/tracker/config/__init__.py @@ -0,0 +1 @@ +CONFIG = {'exp_id': 'default', 'dataset': 'd17-val', 'amp': False, 'output_dir': None, 'flip_aug': False, 'max_internal_size': -1, 'image_directory': None, 'mask_directory': None, 'json_directory': None, 'size': None, 'save_all': None, 'use_all_masks': None, 'use_long_term': None, 'mem_every': 5, 'max_mem_frames': 5, 'long_term': {'count_usage': True, 'max_mem_frames': 10, 'min_mem_frames': 5, 'num_prototypes': 128, 'max_num_tokens': 10000, 'buffer_tokens': 2000}, 'top_k': 30, 'stagger_updates': 5, 'chunk_size': -1, 'save_scores': False, 'save_aux': False, 'visualize': False, 'model': {'pixel_mean': [0.485, 0.456, 0.406], 'pixel_std': [0.229, 0.224, 0.225], 'pixel_dim': 256, 'key_dim': 64, 'value_dim': 256, 'sensory_dim': 256, 'embed_dim': 256, 'pixel_encoder': {'type': 'resnet50', 'ms_dims': [1024, 512, 256]}, 'mask_encoder': {'type': 'resnet18', 'final_dim': 256}, 'pixel_pe_scale': 32, 'pixel_pe_temperature': 128, 'object_transformer': {'embed_dim': '${model.embed_dim}', 'ff_dim': 2048, 'num_heads': 8, 'num_blocks': 3, 'num_queries': 16, 'read_from_pixel': {'input_norm': False, 'input_add_pe': False, 'add_pe_to_qkv': [True, True, False]}, 'read_from_past': {'add_pe_to_qkv': [True, True, False]}, 'read_from_memory': {'add_pe_to_qkv': [True, True, False]}, 'read_from_query': {'add_pe_to_qkv': [True, True, False], 'output_norm': False}, 'query_self_attention': {'add_pe_to_qkv': [True, True, False]}, 'pixel_self_attention': {'add_pe_to_qkv': [True, True, False]}}, 'object_summarizer': {'embed_dim': '${model.object_transformer.embed_dim}', 'num_summaries': '${model.object_transformer.num_queries}', 'add_pe': True}, 'aux_loss': {'sensory': {'enabled': True, 'weight': 0.01}, 'query': {'enabled': True, 'weight': 0.01}}, 'mask_decoder': {'up_dims': [256, 128, 128]}}} \ No newline at end of file diff --git a/web-demos/hugging_face/tracker/inference/__init__.py b/web-demos/hugging_face/tracker/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/web-demos/hugging_face/tracker/inference/image_feature_store.py b/web-demos/hugging_face/tracker/inference/image_feature_store.py new file mode 100644 index 0000000000000000000000000000000000000000..1d02a634b009bb06b935e49036adb121325f5816 --- /dev/null +++ b/web-demos/hugging_face/tracker/inference/image_feature_store.py @@ -0,0 +1,49 @@ +import warnings +from typing import Iterable +import torch +from tracker.model.cutie import CUTIE + + +class ImageFeatureStore: + """ + A cache for image features. + These features might be reused at different parts of the inference pipeline. + This class provide an interface for reusing these features. + It is the user's responsibility to delete redundant features. + + Feature of a frame should be associated with a unique index -- typically the frame id. + """ + def __init__(self, network: CUTIE, no_warning: bool = False): + self.network = network + self._store = {} + self.no_warning = no_warning + + def _encode_feature(self, index: int, image: torch.Tensor) -> None: + ms_features, pix_feat = self.network.encode_image(image) + key, shrinkage, selection = self.network.transform_key(ms_features[0]) + self._store[index] = (ms_features, pix_feat, key, shrinkage, selection) + + def get_features(self, index: int, + image: torch.Tensor) -> (Iterable[torch.Tensor], torch.Tensor): + if index not in self._store: + self._encode_feature(index, image) + + return self._store[index][:2] + + def get_key(self, index: int, + image: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor): + if index not in self._store: + self._encode_feature(index, image) + + return self._store[index][2:] + + def delete(self, index: int) -> None: + if index in self._store: + del self._store[index] + + def __len__(self): + return len(self._store) + + def __del__(self): + if len(self._store) > 0 and not self.no_warning: + warnings.warn(f'Leaking {self._store.keys()} in the image feature store') diff --git a/web-demos/hugging_face/tracker/inference/inference_core.py b/web-demos/hugging_face/tracker/inference/inference_core.py new file mode 100644 index 0000000000000000000000000000000000000000..d7d5b24cf9aaed19656a194b1d6dfc0aaffe91c6 --- /dev/null +++ b/web-demos/hugging_face/tracker/inference/inference_core.py @@ -0,0 +1,381 @@ +from typing import List, Optional, Iterable, Dict +import logging +from omegaconf import DictConfig + +import numpy as np +import torch +import torch.nn.functional as F + +from tracker.inference.memory_manager import MemoryManager +from tracker.inference.object_manager import ObjectManager +from tracker.inference.image_feature_store import ImageFeatureStore +from tracker.model.cutie import CUTIE +from tracker.utils.tensor_utils import pad_divide_by, unpad, aggregate + +log = logging.getLogger() + + +class InferenceCore: + def __init__(self, + network: CUTIE, + cfg: DictConfig, + *, + image_feature_store: ImageFeatureStore = None): + self.network = network + self.cfg = cfg + self.mem_every = cfg.mem_every + stagger_updates = cfg.stagger_updates + self.chunk_size = cfg.chunk_size + self.save_aux = cfg.save_aux + self.max_internal_size = cfg.max_internal_size + self.flip_aug = cfg.flip_aug + + self.curr_ti = -1 + self.last_mem_ti = 0 + # at which time indices should we update the sensory memory + if stagger_updates >= self.mem_every: + self.stagger_ti = set(range(1, self.mem_every + 1)) + else: + self.stagger_ti = set( + np.round(np.linspace(1, self.mem_every, stagger_updates)).astype(int)) + self.object_manager = ObjectManager() + self.memory = MemoryManager(cfg=cfg, object_manager=self.object_manager) + + if image_feature_store is None: + self.image_feature_store = ImageFeatureStore(self.network) + else: + self.image_feature_store = image_feature_store + + self.last_mask = None + + def clear_memory(self): + self.curr_ti = -1 + self.last_mem_ti = 0 + self.memory = MemoryManager(cfg=self.cfg, object_manager=self.object_manager) + + def clear_non_permanent_memory(self): + self.curr_ti = -1 + self.last_mem_ti = 0 + self.memory.clear_non_permanent_memory() + + def clear_sensory_memory(self): + self.curr_ti = -1 + self.last_mem_ti = 0 + self.memory.clear_sensory_memory() + + def update_config(self, cfg): + self.mem_every = cfg['mem_every'] + self.memory.update_config(cfg) + + def _add_memory(self, + image: torch.Tensor, + pix_feat: torch.Tensor, + prob: torch.Tensor, + key: torch.Tensor, + shrinkage: torch.Tensor, + selection: torch.Tensor, + *, + is_deep_update: bool = True, + force_permanent: bool = False) -> None: + """ + Memorize the given segmentation in all memory stores. + + The batch dimension is 1 if flip augmentation is not used. + image: RGB image, (1/2)*3*H*W + pix_feat: from the key encoder, (1/2)*_*H*W + prob: (1/2)*num_objects*H*W, in [0, 1] + key/shrinkage/selection: for anisotropic l2, (1/2)*_*H*W + selection can be None if not using long-term memory + is_deep_update: whether to use deep update (e.g. with the mask encoder) + force_permanent: whether to force the memory to be permanent + """ + if prob.shape[1] == 0: + # nothing to add + log.warn('Trying to add an empty object mask to memory!') + return + + if force_permanent: + as_permanent = 'all' + else: + as_permanent = 'first' + + self.memory.initialize_sensory_if_needed(key, self.object_manager.all_obj_ids) + msk_value, sensory, obj_value, self.obj_logits = self.network.encode_mask( + image, + pix_feat, + self.memory.get_sensory(self.object_manager.all_obj_ids), + prob, + deep_update=is_deep_update, + chunk_size=self.chunk_size, + need_weights=self.save_aux) + self.memory.add_memory(key, + shrinkage, + msk_value, + obj_value, + self.object_manager.all_obj_ids, + selection=selection, + as_permanent=as_permanent) + self.last_mem_ti = self.curr_ti + if is_deep_update: + self.memory.update_sensory(sensory, self.object_manager.all_obj_ids) + + def _segment(self, + key: torch.Tensor, + selection: torch.Tensor, + pix_feat: torch.Tensor, + ms_features: Iterable[torch.Tensor], + update_sensory: bool = True) -> torch.Tensor: + """ + Produce a segmentation using the given features and the memory + + The batch dimension is 1 if flip augmentation is not used. + key/selection: for anisotropic l2: (1/2) * _ * H * W + pix_feat: from the key encoder, (1/2) * _ * H * W + ms_features: an iterable of multiscale features from the encoder, each is (1/2)*_*H*W + with strides 16, 8, and 4 respectively + update_sensory: whether to update the sensory memory + + Returns: (num_objects+1)*H*W normalized probability; the first channel is the background + """ + bs = key.shape[0] + if self.flip_aug: + assert bs == 2 + else: + assert bs == 1 + + if not self.memory.engaged: + log.warn('Trying to segment without any memory!') + return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16), + device=key.device, + dtype=key.dtype) + + memory_readout = self.memory.read(pix_feat, key, selection, self.last_mask, self.network) + memory_readout = self.object_manager.realize_dict(memory_readout) + sensory, _, pred_prob_with_bg = self.network.segment(ms_features, + memory_readout, + self.memory.get_sensory( + self.object_manager.all_obj_ids), + chunk_size=self.chunk_size, + update_sensory=update_sensory) + # remove batch dim + if self.flip_aug: + # average predictions of the non-flipped and flipped version + pred_prob_with_bg = (pred_prob_with_bg[0] + + torch.flip(pred_prob_with_bg[1], dims=[-1])) / 2 + else: + pred_prob_with_bg = pred_prob_with_bg[0] + if update_sensory: + self.memory.update_sensory(sensory, self.object_manager.all_obj_ids) + return pred_prob_with_bg + + def step(self, + image: torch.Tensor, + mask: Optional[torch.Tensor] = None, + objects: Optional[List[int]] = None, + *, + idx_mask: bool = True, + end: bool = False, + delete_buffer: bool = True, + force_permanent: bool = False) -> torch.Tensor: + """ + Take a step with a new incoming image. + If there is an incoming mask with new objects, we will memorize them. + If there is no incoming mask, we will segment the image using the memory. + In both cases, we will update the memory and return a segmentation. + + image: 3*H*W + mask: H*W (if idx mask) or len(objects)*H*W or None + objects: list of object ids that are valid in the mask Tensor. + The ids themselves do not need to be consecutive/in order, but they need to be + in the same position in the list as the corresponding mask + in the tensor in non-idx-mask mode. + objects is ignored if the mask is None. + If idx_mask is False and objects is None, we sequentially infer the object ids. + idx_mask: if True, mask is expected to contain an object id at every pixel. + If False, mask should have multiple channels with each channel representing one object. + end: if we are at the end of the sequence, we do not need to update memory + if unsure just set it to False + delete_buffer: whether to delete the image feature buffer after this step + force_permanent: the memory recorded this frame will be added to the permanent memory + """ + if objects is None and mask is not None: + assert not idx_mask + objects = list(range(1, mask.shape[0] + 1)) + + # resize input if needed -- currently only used for the GUI + resize_needed = False + if self.max_internal_size > 0: + h, w = image.shape[-2:] + min_side = min(h, w) + if min_side > self.max_internal_size: + resize_needed = True + new_h = int(h / min_side * self.max_internal_size) + new_w = int(w / min_side * self.max_internal_size) + image = F.interpolate(image.unsqueeze(0), + size=(new_h, new_w), + mode='bilinear', + align_corners=False)[0] + if mask is not None: + if idx_mask: + mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(), + size=(new_h, new_w), + mode='nearest', + align_corners=False)[0, 0].round().long() + else: + mask = F.interpolate(mask.unsqueeze(0), + size=(new_h, new_w), + mode='bilinear', + align_corners=False)[0] + + self.curr_ti += 1 + + image, self.pad = pad_divide_by(image, 16) + image = image.unsqueeze(0) # add the batch dimension + if self.flip_aug: + image = torch.cat([image, torch.flip(image, dims=[-1])], dim=0) + + # whether to update the working memory + is_mem_frame = ((self.curr_ti - self.last_mem_ti >= self.mem_every) or + (mask is not None)) and (not end) + # segment when there is no input mask or when the input mask is incomplete + need_segment = (mask is None) or (self.object_manager.num_obj > 0 + and not self.object_manager.has_all(objects)) + update_sensory = ((self.curr_ti - self.last_mem_ti) in self.stagger_ti) and (not end) + + # encoding the image + ms_feat, pix_feat = self.image_feature_store.get_features(self.curr_ti, image) + key, shrinkage, selection = self.image_feature_store.get_key(self.curr_ti, image) + + # segmentation from memory if needed + if need_segment: + pred_prob_with_bg = self._segment(key, + selection, + pix_feat, + ms_feat, + update_sensory=update_sensory) + + # use the input mask if provided + if mask is not None: + # inform the manager of the new objects, and get a list of temporary id + # temporary ids -- indicates the position of objects in the tensor + # (starts with 1 due to the background channel) + corresponding_tmp_ids, _ = self.object_manager.add_new_objects(objects) + + mask, _ = pad_divide_by(mask, 16) + if need_segment: + # merge predicted mask with the incomplete input mask + pred_prob_no_bg = pred_prob_with_bg[1:] + # use the mutual exclusivity of segmentation + if idx_mask: + pred_prob_no_bg[:, mask > 0] = 0 + else: + pred_prob_no_bg[:, mask.max(0) > 0.5] = 0 + + new_masks = [] + for mask_id, tmp_id in enumerate(corresponding_tmp_ids): + if idx_mask: + this_mask = (mask == objects[mask_id]).type_as(pred_prob_no_bg) + else: + this_mask = mask[tmp_id] + if tmp_id >= pred_prob_no_bg.shape[0]: + new_masks.append(this_mask.unsqueeze(0)) + else: + # +1 for padding the background channel + pred_prob_no_bg[tmp_id + 1] = this_mask + # new_masks are always in the order of tmp_id + mask = torch.cat([pred_prob_no_bg, *new_masks], dim=0) + elif idx_mask: + # simply convert cls to one-hot representation + if len(objects) == 0: + if delete_buffer: + self.image_feature_store.delete(self.curr_ti) + log.warn('Trying to insert an empty mask as memory!') + return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16), + device=key.device, + dtype=key.dtype) + mask = torch.stack( + [mask == objects[mask_id] for mask_id, _ in enumerate(corresponding_tmp_ids)], + dim=0) + pred_prob_with_bg = aggregate(mask, dim=0) + pred_prob_with_bg = torch.softmax(pred_prob_with_bg, dim=0) + + self.last_mask = pred_prob_with_bg[1:].unsqueeze(0) + if self.flip_aug: + self.last_mask = torch.cat( + [self.last_mask, torch.flip(self.last_mask, dims=[-1])], dim=0) + + # save as memory if needed + if is_mem_frame or force_permanent: + self._add_memory(image, + pix_feat, + self.last_mask, + key, + shrinkage, + selection, + force_permanent=force_permanent) + + if delete_buffer: + self.image_feature_store.delete(self.curr_ti) + + output_prob = unpad(pred_prob_with_bg, self.pad) + if resize_needed: + # restore output to the original size + output_prob = F.interpolate(output_prob.unsqueeze(0), + size=(h, w), + mode='bilinear', + align_corners=False)[0] + + return output_prob + + def get_aux_outputs(self, image: torch.Tensor) -> Dict[str, torch.Tensor]: + image, pads = pad_divide_by(image, 16) + image = image.unsqueeze(0) # add the batch dimension + _, pix_feat = self.image_feature_store.get_features(self.curr_ti, image) + + aux_inputs = self.memory.aux + aux_outputs = self.network.compute_aux(pix_feat, aux_inputs, selector=None) + aux_outputs['q_weights'] = aux_inputs['q_weights'] + aux_outputs['p_weights'] = aux_inputs['p_weights'] + + for k, v in aux_outputs.items(): + if len(v.shape) == 5: + aux_outputs[k] = F.interpolate(v[0], + size=image.shape[-2:], + mode='bilinear', + align_corners=False) + elif 'weights' in k: + b, num_objects, num_heads, num_queries, h, w = v.shape + v = v.view(num_objects * num_heads, num_queries, h, w) + v = F.interpolate(v, size=image.shape[-2:], mode='bilinear', align_corners=False) + aux_outputs[k] = v.view(num_objects, num_heads, num_queries, *image.shape[-2:]) + else: + aux_outputs[k] = F.interpolate(v, + size=image.shape[-2:], + mode='bilinear', + align_corners=False)[0] + aux_outputs[k] = unpad(aux_outputs[k], pads) + if 'weights' in k: + weights = aux_outputs[k] + weights = weights / (weights.max(-1, keepdim=True)[0].max(-2, keepdim=True)[0] + + 1e-8) + aux_outputs[k] = (weights * 255).cpu().numpy() + else: + aux_outputs[k] = (aux_outputs[k].softmax(dim=0) * 255).cpu().numpy() + + self.image_feature_store.delete(self.curr_ti) + return aux_outputs + + def get_aux_object_weights(self, image: torch.Tensor) -> np.ndarray: + image, pads = pad_divide_by(image, 16) + # B*num_objects*H*W*num_queries -> num_objects*num_queries*H*W + # weights = F.softmax(self.obj_logits, dim=-1)[0] + weights = F.sigmoid(self.obj_logits)[0] + weights = weights.permute(0, 3, 1, 2).contiguous() + weights = F.interpolate(weights, + size=image.shape[-2:], + mode='bilinear', + align_corners=False) + # weights = weights / (weights.max(-1, keepdim=True)[0].max(-2, keepdim=True)[0]) + weights = unpad(weights, pads) + weights = (weights * 255).cpu().numpy() + return weights diff --git a/web-demos/hugging_face/tracker/inference/kv_memory_store.py b/web-demos/hugging_face/tracker/inference/kv_memory_store.py new file mode 100644 index 0000000000000000000000000000000000000000..e50b794dc227a772e8a7478d26d662749c0b1c6c --- /dev/null +++ b/web-demos/hugging_face/tracker/inference/kv_memory_store.py @@ -0,0 +1,348 @@ +from typing import Dict, List, Optional, Literal +from collections import defaultdict +import torch + + +def _add_last_dim(dictionary, key, new_value, prepend=False): + # append/prepend a new value to the last dimension of a tensor in a dictionary + # if the key does not exist, put the new value in + # append by default + if key in dictionary: + dictionary[key] = torch.cat([dictionary[key], new_value], -1) + else: + dictionary[key] = new_value + + +class KeyValueMemoryStore: + """ + Works for key/value pairs type storage + e.g., working and long-term memory + """ + def __init__(self, save_selection: bool = False, save_usage: bool = False): + """ + We store keys and values of objects that first appear in the same frame in a bucket. + Each bucket contains a set of object ids. + Each bucket is associated with a single key tensor + and a dictionary of value tensors indexed by object id. + + The keys and values are stored as the concatenation of a permanent part and a temporary part. + """ + self.save_selection = save_selection + self.save_usage = save_usage + + self.global_bucket_id = 0 # does not reduce even if buckets are removed + self.buckets: Dict[int, List[int]] = {} # indexed by bucket id + self.k: Dict[int, torch.Tensor] = {} # indexed by bucket id + self.v: Dict[int, torch.Tensor] = {} # indexed by object id + + # indexed by bucket id; the end point of permanent memory + self.perm_end_pt: Dict[int, int] = defaultdict(int) + + # shrinkage and selection are just like the keys + self.s = {} + if self.save_selection: + self.e = {} # does not contain the permanent memory part + + # usage + if self.save_usage: + self.use_cnt = {} # indexed by bucket id, does not contain the permanent memory part + self.life_cnt = {} # indexed by bucket id, does not contain the permanent memory part + + def add(self, + key: torch.Tensor, + values: Dict[int, torch.Tensor], + shrinkage: torch.Tensor, + selection: torch.Tensor, + supposed_bucket_id: int = -1, + as_permanent: Literal['no', 'first', 'all'] = 'no') -> None: + """ + key: (1/2)*C*N + values: dict of values ((1/2)*C*N), object ids are used as keys + shrinkage: (1/2)*1*N + selection: (1/2)*C*N + + supposed_bucket_id: used to sync the bucket id between working and long-term memory + if provided, the input should all be in a single bucket indexed by this id + as_permanent: whether to store the input as permanent memory + 'no': don't + 'first': only store it as permanent memory if the bucket is empty + 'all': always store it as permanent memory + """ + bs = key.shape[0] + ne = key.shape[-1] + assert len(key.shape) == 3 + assert len(shrinkage.shape) == 3 + assert not self.save_selection or len(selection.shape) == 3 + assert as_permanent in ['no', 'first', 'all'] + + # add the value and create new buckets if necessary + if supposed_bucket_id >= 0: + enabled_buckets = [supposed_bucket_id] + bucket_exist = supposed_bucket_id in self.buckets + for obj, value in values.items(): + if bucket_exist: + assert obj in self.v + assert obj in self.buckets[supposed_bucket_id] + _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all')) + else: + assert obj not in self.v + self.v[obj] = value + self.buckets[supposed_bucket_id] = list(values.keys()) + else: + new_bucket_id = None + enabled_buckets = set() + for obj, value in values.items(): + assert len(value.shape) == 3 + if obj in self.v: + _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all')) + bucket_used = [ + bucket_id for bucket_id, object_ids in self.buckets.items() + if obj in object_ids + ] + assert len(bucket_used) == 1 # each object should only be in one bucket + enabled_buckets.add(bucket_used[0]) + else: + self.v[obj] = value + if new_bucket_id is None: + # create new bucket + new_bucket_id = self.global_bucket_id + self.global_bucket_id += 1 + self.buckets[new_bucket_id] = [] + # put the new object into the corresponding bucket + self.buckets[new_bucket_id].append(obj) + enabled_buckets.add(new_bucket_id) + + # increment the permanent size if necessary + add_as_permanent = {} # indexed by bucket id + for bucket_id in enabled_buckets: + add_as_permanent[bucket_id] = False + if as_permanent == 'all': + self.perm_end_pt[bucket_id] += ne + add_as_permanent[bucket_id] = True + elif as_permanent == 'first': + if self.perm_end_pt[bucket_id] == 0: + self.perm_end_pt[bucket_id] = ne + add_as_permanent[bucket_id] = True + + # create new counters for usage if necessary + if self.save_usage and as_permanent != 'all': + new_count = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) + new_life = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) + 1e-7 + + # add the key to every bucket + for bucket_id in self.buckets: + if bucket_id not in enabled_buckets: + # if we are not adding new values to a bucket, we should skip it + continue + + _add_last_dim(self.k, bucket_id, key, prepend=add_as_permanent[bucket_id]) + _add_last_dim(self.s, bucket_id, shrinkage, prepend=add_as_permanent[bucket_id]) + if not add_as_permanent[bucket_id]: + if self.save_selection: + _add_last_dim(self.e, bucket_id, selection) + if self.save_usage: + _add_last_dim(self.use_cnt, bucket_id, new_count) + _add_last_dim(self.life_cnt, bucket_id, new_life) + + def update_bucket_usage(self, bucket_id: int, usage: torch.Tensor) -> None: + # increase all life count by 1 + # increase use of indexed elements + if not self.save_usage: + return + + usage = usage[:, self.perm_end_pt[bucket_id]:] + if usage.shape[-1] == 0: + # if there is no temporary memory, we don't need to update + return + self.use_cnt[bucket_id] += usage.view_as(self.use_cnt[bucket_id]) + self.life_cnt[bucket_id] += 1 + + def sieve_by_range(self, bucket_id: int, start: int, end: int, min_size: int) -> None: + # keep only the temporary elements *outside* of this range (with some boundary conditions) + # the permanent elements are ignored in this computation + # i.e., concat (a[:start], a[end:]) + # bucket with size <= min_size are not modified + + assert start >= 0 + assert end <= 0 + + object_ids = self.buckets[bucket_id] + bucket_num_elements = self.k[bucket_id].shape[-1] - self.perm_end_pt[bucket_id] + if bucket_num_elements <= min_size: + return + + if end == 0: + # negative 0 would not work as the end index! + # effectively make the second part an empty slice + end = self.k[bucket_id].shape[-1] + 1 + + p_size = self.perm_end_pt[bucket_id] + start = start + p_size + + k = self.k[bucket_id] + s = self.s[bucket_id] + if self.save_selection: + e = self.e[bucket_id] + if self.save_usage: + use_cnt = self.use_cnt[bucket_id] + life_cnt = self.life_cnt[bucket_id] + + self.k[bucket_id] = torch.cat([k[:, :, :start], k[:, :, end:]], -1) + self.s[bucket_id] = torch.cat([s[:, :, :start], s[:, :, end:]], -1) + if self.save_selection: + self.e[bucket_id] = torch.cat([e[:, :, :start - p_size], e[:, :, end:]], -1) + if self.save_usage: + self.use_cnt[bucket_id] = torch.cat([use_cnt[:, :start - p_size], use_cnt[:, end:]], -1) + self.life_cnt[bucket_id] = torch.cat([life_cnt[:, :start - p_size], life_cnt[:, end:]], + -1) + for obj_id in object_ids: + v = self.v[obj_id] + self.v[obj_id] = torch.cat([v[:, :, :start], v[:, :, end:]], -1) + + def remove_old_memory(self, bucket_id: int, max_len: int) -> None: + self.sieve_by_range(bucket_id, 0, -max_len, max_len) + + def remove_obsolete_features(self, bucket_id: int, max_size: int) -> None: + # for long-term memory only + object_ids = self.buckets[bucket_id] + + assert self.perm_end_pt[bucket_id] == 0 # permanent memory should be empty in LT memory + + # normalize with life duration + usage = self.get_usage(bucket_id) + bs = usage.shape[0] + + survivals = [] + + for bi in range(bs): + _, survived = torch.topk(usage[bi], k=max_size) + survivals.append(survived.flatten()) + assert survived.shape[-1] == survivals[0].shape[-1] + + self.k[bucket_id] = torch.stack( + [self.k[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) + self.s[bucket_id] = torch.stack( + [self.s[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) + + if self.save_selection: + # Long-term memory does not store selection so this should not be needed + self.e[bucket_id] = torch.stack( + [self.e[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) + for obj_id in object_ids: + self.v[obj_id] = torch.stack( + [self.v[obj_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) + + self.use_cnt[bucket_id] = torch.stack( + [self.use_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0) + self.life_cnt[bucket_id] = torch.stack( + [self.life_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0) + + def get_usage(self, bucket_id: int) -> torch.Tensor: + # return normalized usage + if not self.save_usage: + raise RuntimeError('I did not count usage!') + else: + usage = self.use_cnt[bucket_id] / self.life_cnt[bucket_id] + return usage + + def get_all_sliced( + self, bucket_id: int, start: int, end: int + ) -> (torch.Tensor, torch.Tensor, torch.Tensor, Dict[int, torch.Tensor], torch.Tensor): + # return k, sk, ek, value, normalized usage in order, sliced by start and end + # this only queries the temporary memory + + assert start >= 0 + assert end <= 0 + + p_size = self.perm_end_pt[bucket_id] + start = start + p_size + + if end == 0: + # negative 0 would not work as the end index! + k = self.k[bucket_id][:, :, start:] + sk = self.s[bucket_id][:, :, start:] + ek = self.e[bucket_id][:, :, start - p_size:] if self.save_selection else None + value = {obj_id: self.v[obj_id][:, :, start:] for obj_id in self.buckets[bucket_id]} + usage = self.get_usage(bucket_id)[:, start - p_size:] if self.save_usage else None + else: + k = self.k[bucket_id][:, :, start:end] + sk = self.s[bucket_id][:, :, start:end] + ek = self.e[bucket_id][:, :, start - p_size:end] if self.save_selection else None + value = {obj_id: self.v[obj_id][:, :, start:end] for obj_id in self.buckets[bucket_id]} + usage = self.get_usage(bucket_id)[:, start - p_size:end] if self.save_usage else None + + return k, sk, ek, value, usage + + def purge_except(self, obj_keep_idx: List[int]): + # purge certain objects from the memory except the one listed + obj_keep_idx = set(obj_keep_idx) + + # remove objects that are not in the keep list from the buckets + buckets_to_remove = [] + for bucket_id, object_ids in self.buckets.items(): + self.buckets[bucket_id] = [obj_id for obj_id in object_ids if obj_id in obj_keep_idx] + if len(self.buckets[bucket_id]) == 0: + buckets_to_remove.append(bucket_id) + + # remove object values that are not in the keep list + self.v = {k: v for k, v in self.v.items() if k in obj_keep_idx} + + # remove buckets that are empty + for bucket_id in buckets_to_remove: + del self.buckets[bucket_id] + del self.k[bucket_id] + del self.s[bucket_id] + if self.save_selection: + del self.e[bucket_id] + if self.save_usage: + del self.use_cnt[bucket_id] + del self.life_cnt[bucket_id] + + def clear_non_permanent_memory(self): + # clear all non-permanent memory + for bucket_id in self.buckets: + self.sieve_by_range(bucket_id, 0, 0, 0) + + def get_v_size(self, obj_id: int) -> int: + return self.v[obj_id].shape[-1] + + def size(self, bucket_id: int) -> int: + if bucket_id not in self.k: + return 0 + else: + return self.k[bucket_id].shape[-1] + + def perm_size(self, bucket_id: int) -> int: + return self.perm_end_pt[bucket_id] + + def non_perm_size(self, bucket_id: int) -> int: + return self.size(bucket_id) - self.perm_size(bucket_id) + + def engaged(self, bucket_id: Optional[int] = None) -> bool: + if bucket_id is None: + return len(self.buckets) > 0 + else: + return bucket_id in self.buckets + + @property + def num_objects(self) -> int: + return len(self.v) + + @property + def key(self) -> Dict[int, torch.Tensor]: + return self.k + + @property + def value(self) -> Dict[int, torch.Tensor]: + return self.v + + @property + def shrinkage(self) -> Dict[int, torch.Tensor]: + return self.s + + @property + def selection(self) -> Dict[int, torch.Tensor]: + return self.e + + def __contains__(self, key): + return key in self.v diff --git a/web-demos/hugging_face/tracker/inference/memory_manager.py b/web-demos/hugging_face/tracker/inference/memory_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..53995a0f3c0998191858b955415a549dfcef248e --- /dev/null +++ b/web-demos/hugging_face/tracker/inference/memory_manager.py @@ -0,0 +1,378 @@ +import logging +from omegaconf import DictConfig +from typing import List, Dict +import torch + +from tracker.inference.object_manager import ObjectManager +from tracker.inference.kv_memory_store import KeyValueMemoryStore +from tracker.model.cutie import CUTIE +from tracker.model.utils.memory_utils import * + +log = logging.getLogger() + + +class MemoryManager: + """ + Manages all three memory stores and the transition between working/long-term memory + """ + def __init__(self, cfg: DictConfig, object_manager: ObjectManager): + self.object_manager = object_manager + self.sensory_dim = cfg.model.sensory_dim + self.top_k = cfg.top_k + self.chunk_size = cfg.chunk_size + + self.save_aux = cfg.save_aux + + self.use_long_term = cfg.use_long_term + self.count_long_term_usage = cfg.long_term.count_usage + # subtract 1 because the first-frame is now counted as "permanent memory" + # and is not counted towards max_mem_frames + # but we want to keep the hyperparameters consistent as before for the same behavior + if self.use_long_term: + self.max_mem_frames = cfg.long_term.max_mem_frames - 1 + self.min_mem_frames = cfg.long_term.min_mem_frames - 1 + self.num_prototypes = cfg.long_term.num_prototypes + self.max_long_tokens = cfg.long_term.max_num_tokens + self.buffer_tokens = cfg.long_term.buffer_tokens + else: + self.max_mem_frames = cfg.max_mem_frames - 1 + + # dimensions will be inferred from input later + self.CK = self.CV = None + self.H = self.W = None + + # The sensory memory is stored as a dictionary indexed by object ids + # each of shape bs * C^h * H * W + self.sensory = {} + + # a dictionary indexed by object ids, each of shape bs * T * Q * C + self.obj_v = {} + + self.work_mem = KeyValueMemoryStore(save_selection=self.use_long_term, + save_usage=self.use_long_term) + if self.use_long_term: + self.long_mem = KeyValueMemoryStore(save_usage=self.count_long_term_usage) + + self.config_stale = True + self.engaged = False + + def update_config(self, cfg: DictConfig) -> None: + self.config_stale = True + self.top_k = cfg['top_k'] + + assert self.use_long_term == cfg.use_long_term, 'cannot update this' + assert self.count_long_term_usage == cfg.long_term.count_usage, 'cannot update this' + + self.use_long_term = cfg.use_long_term + self.count_long_term_usage = cfg.long_term.count_usage + if self.use_long_term: + self.max_mem_frames = cfg.long_term.max_mem_frames - 1 + self.min_mem_frames = cfg.long_term.min_mem_frames - 1 + self.num_prototypes = cfg.long_term.num_prototypes + self.max_long_tokens = cfg.long_term.max_num_tokens + self.buffer_tokens = cfg.long_term.buffer_tokens + else: + self.max_mem_frames = cfg.max_mem_frames - 1 + + def _readout(self, affinity, v) -> torch.Tensor: + # affinity: bs*N*HW + # v: bs*C*N or bs*num_objects*C*N + # returns bs*C*HW or bs*num_objects*C*HW + if len(v.shape) == 3: + # single object + return v @ affinity + else: + bs, num_objects, C, N = v.shape + v = v.view(bs, num_objects * C, N) + out = v @ affinity + return out.view(bs, num_objects, C, -1) + + def _get_mask_by_ids(self, mask: torch.Tensor, obj_ids: List[int]) -> torch.Tensor: + # -1 because the mask does not contain the background channel + return mask[:, [self.object_manager.find_tmp_by_id(obj) - 1 for obj in obj_ids]] + + def _get_sensory_by_ids(self, obj_ids: List[int]) -> torch.Tensor: + return torch.stack([self.sensory[obj] for obj in obj_ids], dim=1) + + def _get_object_mem_by_ids(self, obj_ids: List[int]) -> torch.Tensor: + return torch.stack([self.obj_v[obj] for obj in obj_ids], dim=1) + + def _get_visual_values_by_ids(self, obj_ids: List[int]) -> torch.Tensor: + # All the values that the object ids refer to should have the same shape + value = torch.stack([self.work_mem.value[obj] for obj in obj_ids], dim=1) + if self.use_long_term and obj_ids[0] in self.long_mem.value: + lt_value = torch.stack([self.long_mem.value[obj] for obj in obj_ids], dim=1) + value = torch.cat([lt_value, value], dim=-1) + + return value + + def read(self, pix_feat: torch.Tensor, query_key: torch.Tensor, selection: torch.Tensor, + last_mask: torch.Tensor, network: CUTIE) -> Dict[int, torch.Tensor]: + """ + Read from all memory stores and returns a single memory readout tensor for each object + + pix_feat: (1/2) x C x H x W + query_key: (1/2) x C^k x H x W + selection: (1/2) x C^k x H x W + last_mask: (1/2) x num_objects x H x W (at stride 16) + return a dict of memory readouts, indexed by object indices. Each readout is C*H*W + """ + h, w = pix_feat.shape[-2:] + bs = pix_feat.shape[0] + assert query_key.shape[0] == bs + assert selection.shape[0] == bs + assert last_mask.shape[0] == bs + + query_key = query_key.flatten(start_dim=2) # bs*C^k*HW + selection = selection.flatten(start_dim=2) # bs*C^k*HW + """ + Compute affinity and perform readout + """ + all_readout_mem = {} + buckets = self.work_mem.buckets + for bucket_id, bucket in buckets.items(): + if self.use_long_term and self.long_mem.engaged(bucket_id): + # Use long-term memory + long_mem_size = self.long_mem.size(bucket_id) + memory_key = torch.cat([self.long_mem.key[bucket_id], self.work_mem.key[bucket_id]], + -1) + shrinkage = torch.cat( + [self.long_mem.shrinkage[bucket_id], self.work_mem.shrinkage[bucket_id]], -1) + + similarity = get_similarity(memory_key, shrinkage, query_key, selection) + affinity, usage = do_softmax(similarity, + top_k=self.top_k, + inplace=True, + return_usage=True) + """ + Record memory usage for working and long-term memory + """ + # ignore the index return for long-term memory + work_usage = usage[:, long_mem_size:] + self.work_mem.update_bucket_usage(bucket_id, work_usage) + + if self.count_long_term_usage: + # ignore the index return for working memory + long_usage = usage[:, :long_mem_size] + self.long_mem.update_bucket_usage(bucket_id, long_usage) + else: + # no long-term memory + memory_key = self.work_mem.key[bucket_id] + shrinkage = self.work_mem.shrinkage[bucket_id] + similarity = get_similarity(memory_key, shrinkage, query_key, selection) + + if self.use_long_term: + affinity, usage = do_softmax(similarity, + top_k=self.top_k, + inplace=True, + return_usage=True) + self.work_mem.update_bucket_usage(bucket_id, usage) + else: + affinity = do_softmax(similarity, top_k=self.top_k, inplace=True) + + if self.chunk_size < 1: + object_chunks = [bucket] + else: + object_chunks = [ + bucket[i:i + self.chunk_size] for i in range(0, len(bucket), self.chunk_size) + ] + + for objects in object_chunks: + this_sensory = self._get_sensory_by_ids(objects) + this_last_mask = self._get_mask_by_ids(last_mask, objects) + this_msk_value = self._get_visual_values_by_ids(objects) # (1/2)*num_objects*C*N + visual_readout = self._readout(affinity, + this_msk_value).view(bs, len(objects), self.CV, h, w) + pixel_readout = network.pixel_fusion(pix_feat, visual_readout, this_sensory, + this_last_mask) + this_obj_mem = self._get_object_mem_by_ids(objects).unsqueeze(2) + readout_memory, aux_features = network.readout_query(pixel_readout, this_obj_mem) + for i, obj in enumerate(objects): + all_readout_mem[obj] = readout_memory[:, i] + + if self.save_aux: + aux_output = { + 'sensory': this_sensory, + 'pixel_readout': pixel_readout, + 'q_logits': aux_features['logits'] if aux_features else None, + 'q_weights': aux_features['q_weights'] if aux_features else None, + 'p_weights': aux_features['p_weights'] if aux_features else None, + 'attn_mask': aux_features['attn_mask'].float() if aux_features else None, + } + self.aux = aux_output + + return all_readout_mem + + def add_memory(self, + key: torch.Tensor, + shrinkage: torch.Tensor, + msk_value: torch.Tensor, + obj_value: torch.Tensor, + objects: List[int], + selection: torch.Tensor = None, + *, + as_permanent: bool = False) -> None: + # key: (1/2)*C*H*W + # msk_value: (1/2)*num_objects*C*H*W + # obj_value: (1/2)*num_objects*Q*C + # objects contains a list of object ids corresponding to the objects in msk_value/obj_value + bs = key.shape[0] + assert shrinkage.shape[0] == bs + assert msk_value.shape[0] == bs + assert obj_value.shape[0] == bs + + self.engaged = True + if self.H is None or self.config_stale: + self.config_stale = False + self.H, self.W = msk_value.shape[-2:] + self.HW = self.H * self.W + # convert from num. frames to num. tokens + self.max_work_tokens = self.max_mem_frames * self.HW + if self.use_long_term: + self.min_work_tokens = self.min_mem_frames * self.HW + + # key: bs*C*N + # value: bs*num_objects*C*N + key = key.flatten(start_dim=2) + shrinkage = shrinkage.flatten(start_dim=2) + self.CK = key.shape[1] + + msk_value = msk_value.flatten(start_dim=3) + self.CV = msk_value.shape[2] + + if selection is not None: + # not used in non-long-term mode + selection = selection.flatten(start_dim=2) + + # insert object values into object memory + for obj_id, obj in enumerate(objects): + if obj in self.obj_v: + """streaming average + each self.obj_v[obj] is (1/2)*num_summaries*(embed_dim+1) + first embed_dim keeps track of the sum of embeddings + the last dim keeps the total count + averaging in done inside the object transformer + + incoming obj_value is (1/2)*num_objects*num_summaries*(embed_dim+1) + self.obj_v[obj] = torch.cat([self.obj_v[obj], obj_value[:, obj_id]], dim=0) + """ + last_acc = self.obj_v[obj][:, :, -1] + new_acc = last_acc + obj_value[:, obj_id, :, -1] + + self.obj_v[obj][:, :, :-1] = (self.obj_v[obj][:, :, :-1] + + obj_value[:, obj_id, :, :-1]) + self.obj_v[obj][:, :, -1] = new_acc + else: + self.obj_v[obj] = obj_value[:, obj_id] + + # convert mask value tensor into a dict for insertion + msk_values = {obj: msk_value[:, obj_id] for obj_id, obj in enumerate(objects)} + self.work_mem.add(key, + msk_values, + shrinkage, + selection=selection, + as_permanent=as_permanent) + + for bucket_id in self.work_mem.buckets.keys(): + # long-term memory cleanup + if self.use_long_term: + # Do memory compressed if needed + if self.work_mem.non_perm_size(bucket_id) >= self.max_work_tokens: + # Remove obsolete features if needed + if self.long_mem.non_perm_size(bucket_id) >= (self.max_long_tokens - + self.num_prototypes): + self.long_mem.remove_obsolete_features( + bucket_id, + self.max_long_tokens - self.num_prototypes - self.buffer_tokens) + + self.compress_features(bucket_id) + else: + # FIFO + self.work_mem.remove_old_memory(bucket_id, self.max_work_tokens) + + def purge_except(self, obj_keep_idx: List[int]) -> None: + # purge certain objects from the memory except the one listed + self.work_mem.purge_except(obj_keep_idx) + if self.use_long_term and self.long_mem.engaged(): + self.long_mem.purge_except(obj_keep_idx) + self.sensory = {k: v for k, v in self.sensory.items() if k in obj_keep_idx} + + if not self.work_mem.engaged(): + # everything is removed! + self.engaged = False + + def compress_features(self, bucket_id: int) -> None: + HW = self.HW + + # perform memory consolidation + prototype_key, prototype_value, prototype_shrinkage = self.consolidation( + *self.work_mem.get_all_sliced(bucket_id, 0, -self.min_work_tokens)) + + # remove consolidated working memory + self.work_mem.sieve_by_range(bucket_id, + 0, + -self.min_work_tokens, + min_size=self.min_work_tokens) + + # add to long-term memory + self.long_mem.add(prototype_key, + prototype_value, + prototype_shrinkage, + selection=None, + supposed_bucket_id=bucket_id) + + def consolidation(self, candidate_key: torch.Tensor, candidate_shrinkage: torch.Tensor, + candidate_selection: torch.Tensor, candidate_value: Dict[int, torch.Tensor], + usage: torch.Tensor) -> (torch.Tensor, Dict[int, torch.Tensor], torch.Tensor): + # find the indices with max usage + bs = candidate_key.shape[0] + assert bs in [1, 2] + + prototype_key = [] + prototype_selection = [] + for bi in range(bs): + _, max_usage_indices = torch.topk(usage[bi], k=self.num_prototypes, dim=-1, sorted=True) + prototype_indices = max_usage_indices.flatten() + prototype_key.append(candidate_key[bi, :, prototype_indices]) + prototype_selection.append(candidate_selection[bi, :, prototype_indices]) + prototype_key = torch.stack(prototype_key, dim=0) + prototype_selection = torch.stack(prototype_selection, dim=0) + """ + Potentiation step + """ + similarity = get_similarity(candidate_key, candidate_shrinkage, prototype_key, + prototype_selection) + affinity = do_softmax(similarity) + + # readout the values + prototype_value = {k: self._readout(affinity, v) for k, v in candidate_value.items()} + + # readout the shrinkage term + prototype_shrinkage = self._readout(affinity, candidate_shrinkage) + + return prototype_key, prototype_value, prototype_shrinkage + + def initialize_sensory_if_needed(self, sample_key: torch.Tensor, ids: List[int]): + for obj in ids: + if obj not in self.sensory: + # also initializes the sensory memory + bs, _, h, w = sample_key.shape + self.sensory[obj] = torch.zeros((bs, self.sensory_dim, h, w), + device=sample_key.device) + + def update_sensory(self, sensory: torch.Tensor, ids: List[int]): + # sensory: 1*num_objects*C*H*W + for obj_id, obj in enumerate(ids): + self.sensory[obj] = sensory[:, obj_id] + + def get_sensory(self, ids: List[int]): + # returns (1/2)*num_objects*C*H*W + return self._get_sensory_by_ids(ids) + + def clear_non_permanent_memory(self): + self.work_mem.clear_non_permanent_memory() + if self.use_long_term: + self.long_mem.clear_non_permanent_memory() + + def clear_sensory_memory(self): + self.sensory = {} diff --git a/web-demos/hugging_face/tracker/inference/object_info.py b/web-demos/hugging_face/tracker/inference/object_info.py new file mode 100644 index 0000000000000000000000000000000000000000..b0e0bd45b10d0361c3ebc19783155e9ab29c8ad0 --- /dev/null +++ b/web-demos/hugging_face/tracker/inference/object_info.py @@ -0,0 +1,24 @@ +class ObjectInfo: + """ + Store meta information for an object + """ + def __init__(self, id: int): + self.id = id + self.poke_count = 0 # count number of detections missed + + def poke(self) -> None: + self.poke_count += 1 + + def unpoke(self) -> None: + self.poke_count = 0 + + def __hash__(self): + return hash(self.id) + + def __eq__(self, other): + if type(other) == int: + return self.id == other + return self.id == other.id + + def __repr__(self): + return f'(ID: {self.id})' diff --git a/web-demos/hugging_face/tracker/inference/object_manager.py b/web-demos/hugging_face/tracker/inference/object_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..a9e5a8533d3f4e5c72150abe7e79d05fcd2f7bd9 --- /dev/null +++ b/web-demos/hugging_face/tracker/inference/object_manager.py @@ -0,0 +1,148 @@ +from typing import Union, List, Dict + +import torch +from tracker.inference.object_info import ObjectInfo + + +class ObjectManager: + """ + Object IDs are immutable. The same ID always represent the same object. + Temporary IDs are the positions of each object in the tensor. It changes as objects get removed. + Temporary IDs start from 1. + """ + def __init__(self): + self.obj_to_tmp_id: Dict[ObjectInfo, int] = {} + self.tmp_id_to_obj: Dict[int, ObjectInfo] = {} + self.obj_id_to_obj: Dict[int, ObjectInfo] = {} + + self.all_historical_object_ids: List[int] = [] + + def _recompute_obj_id_to_obj_mapping(self) -> None: + self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id} + + def add_new_objects( + self, objects: Union[List[ObjectInfo], ObjectInfo, + List[int]]) -> (List[int], List[int]): + if not isinstance(objects, list): + objects = [objects] + + corresponding_tmp_ids = [] + corresponding_obj_ids = [] + for obj in objects: + if isinstance(obj, int): + obj = ObjectInfo(id=obj) + + if obj in self.obj_to_tmp_id: + # old object + corresponding_tmp_ids.append(self.obj_to_tmp_id[obj]) + corresponding_obj_ids.append(obj.id) + else: + # new object + new_obj = ObjectInfo(id=obj) + + # new object + new_tmp_id = len(self.obj_to_tmp_id) + 1 + self.obj_to_tmp_id[new_obj] = new_tmp_id + self.tmp_id_to_obj[new_tmp_id] = new_obj + self.all_historical_object_ids.append(new_obj.id) + corresponding_tmp_ids.append(new_tmp_id) + corresponding_obj_ids.append(new_obj.id) + + self._recompute_obj_id_to_obj_mapping() + assert corresponding_tmp_ids == sorted(corresponding_tmp_ids) + return corresponding_tmp_ids, corresponding_obj_ids + + def delete_object(self, obj_ids_to_remove: Union[int, List[int]]) -> None: + # delete an object or a list of objects + # re-sort the tmp ids + if isinstance(obj_ids_to_remove, int): + obj_ids_to_remove = [obj_ids_to_remove] + + new_tmp_id = 1 + total_num_id = len(self.obj_to_tmp_id) + + local_obj_to_tmp_id = {} + local_tmp_to_obj_id = {} + + for tmp_iter in range(1, total_num_id + 1): + obj = self.tmp_id_to_obj[tmp_iter] + if obj.id not in obj_ids_to_remove: + local_obj_to_tmp_id[obj] = new_tmp_id + local_tmp_to_obj_id[new_tmp_id] = obj + new_tmp_id += 1 + + self.obj_to_tmp_id = local_obj_to_tmp_id + self.tmp_id_to_obj = local_tmp_to_obj_id + self._recompute_obj_id_to_obj_mapping() + + def purge_inactive_objects(self, + max_missed_detection_count: int) -> (bool, List[int], List[int]): + # remove tmp ids of objects that are removed + obj_id_to_be_deleted = [] + tmp_id_to_be_deleted = [] + tmp_id_to_keep = [] + obj_id_to_keep = [] + + for obj in self.obj_to_tmp_id: + if obj.poke_count > max_missed_detection_count: + obj_id_to_be_deleted.append(obj.id) + tmp_id_to_be_deleted.append(self.obj_to_tmp_id[obj]) + else: + tmp_id_to_keep.append(self.obj_to_tmp_id[obj]) + obj_id_to_keep.append(obj.id) + + purge_activated = len(obj_id_to_be_deleted) > 0 + if purge_activated: + self.delete_object(obj_id_to_be_deleted) + return purge_activated, tmp_id_to_keep, obj_id_to_keep + + def tmp_to_obj_cls(self, mask) -> torch.Tensor: + # remap tmp id cls representation to the true object id representation + new_mask = torch.zeros_like(mask) + for tmp_id, obj in self.tmp_id_to_obj.items(): + new_mask[mask == tmp_id] = obj.id + return new_mask + + def get_tmp_to_obj_mapping(self) -> Dict[int, ObjectInfo]: + # returns the mapping in a dict format for saving it with pickle + return {obj.id: tmp_id for obj, tmp_id in self.tmp_id_to_obj.items()} + + def realize_dict(self, obj_dict, dim=1) -> torch.Tensor: + # turns a dict indexed by obj id into a tensor, ordered by tmp IDs + output = [] + for _, obj in self.tmp_id_to_obj.items(): + if obj.id not in obj_dict: + raise NotImplementedError + output.append(obj_dict[obj.id]) + output = torch.stack(output, dim=dim) + return output + + def make_one_hot(self, cls_mask) -> torch.Tensor: + output = [] + for _, obj in self.tmp_id_to_obj.items(): + output.append(cls_mask == obj.id) + if len(output) == 0: + output = torch.zeros((0, *cls_mask.shape), dtype=torch.bool, device=cls_mask.device) + else: + output = torch.stack(output, dim=0) + return output + + @property + def all_obj_ids(self) -> List[int]: + return [k.id for k in self.obj_to_tmp_id] + + @property + def num_obj(self) -> int: + return len(self.obj_to_tmp_id) + + def has_all(self, objects: List[int]) -> bool: + for obj in objects: + if obj not in self.obj_to_tmp_id: + return False + return True + + def find_object_by_id(self, obj_id) -> ObjectInfo: + return self.obj_id_to_obj[obj_id] + + def find_tmp_by_id(self, obj_id) -> int: + return self.obj_to_tmp_id[self.obj_id_to_obj[obj_id]] diff --git a/web-demos/hugging_face/tracker/inference/utils/__init__.py b/web-demos/hugging_face/tracker/inference/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/web-demos/hugging_face/tracker/inference/utils/args_utils.py b/web-demos/hugging_face/tracker/inference/utils/args_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a771ccaa080af2acd9757c7139c60c24652a1442 --- /dev/null +++ b/web-demos/hugging_face/tracker/inference/utils/args_utils.py @@ -0,0 +1,30 @@ +import logging +from omegaconf import DictConfig + +log = logging.getLogger() + + +def get_dataset_cfg(cfg: DictConfig): + dataset_name = cfg.dataset + data_cfg = cfg.datasets[dataset_name] + + potential_overrides = [ + 'image_directory', + 'mask_directory', + 'json_directory', + 'size', + 'save_all', + 'use_all_masks', + 'use_long_term', + 'mem_every', + ] + + for override in potential_overrides: + if cfg[override] is not None: + log.info(f'Overriding config {override} from {data_cfg[override]} to {cfg[override]}') + data_cfg[override] = cfg[override] + # escalte all potential overrides to the top-level config + if override in data_cfg: + cfg[override] = data_cfg[override] + + return data_cfg diff --git a/web-demos/hugging_face/tracker/inference/utils/burst_utils.py b/web-demos/hugging_face/tracker/inference/utils/burst_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..570442848c83378f8562485aa7cca3502910440c --- /dev/null +++ b/web-demos/hugging_face/tracker/inference/utils/burst_utils.py @@ -0,0 +1,19 @@ +from os import path +import copy +import json + + +class BURSTResultHandler: + def __init__(self, dataset_json): + self.dataset_json = copy.deepcopy(dataset_json) + + # get rid of the segmentations while keeping the metadata + self.dataset_json['sequences'] = [] + + def add_sequence(self, sequence_json): + self.dataset_json['sequences'].append(sequence_json) + + def dump(self, root): + json_path = path.join(root, 'predictions.json') + with open(json_path, 'w') as f: + json.dump(self.dataset_json, f) \ No newline at end of file diff --git a/web-demos/hugging_face/tracker/inference/utils/frame_utils.py b/web-demos/hugging_face/tracker/inference/utils/frame_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dc18f35ed7c9ee4ff1aa41fa1ef988d372e91039 --- /dev/null +++ b/web-demos/hugging_face/tracker/inference/utils/frame_utils.py @@ -0,0 +1,26 @@ +from typing import Dict, List, Tuple +import torch + +from inference.object_info import ObjectInfo + + +class FrameInfo: + def __init__(self, image: torch.Tensor, mask: torch.Tensor, segments_info: List[ObjectInfo], + ti: int, info: Dict): + self.image = image + self.mask = mask + self.segments_info = segments_info + self.ti = ti + self.info = info + + @property + def name(self) -> str: + return self.info['frame'] + + @property + def shape(self) -> Tuple(int): + return self.info['shape'] + + @property + def need_save(self) -> bool: + return self.info['save'] diff --git a/web-demos/hugging_face/tracker/inference/utils/results_utils.py b/web-demos/hugging_face/tracker/inference/utils/results_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..db970ffea9ff89df45bdb4932a5972373a4f23f2 --- /dev/null +++ b/web-demos/hugging_face/tracker/inference/utils/results_utils.py @@ -0,0 +1,256 @@ +from typing import Tuple, Optional, Dict +import logging +import os +import shutil +from os import path +from PIL import Image +import torch +import torch.nn.functional as F +import numpy as np + +import pycocotools.mask as mask_util +from threading import Thread +from queue import Queue +from dataclasses import dataclass +import copy + +from tracker.utils.pano_utils import ID2RGBConverter +from tracker.utils.palette import davis_palette_np +from tracker.inference.object_manager import ObjectManager +from tracker.inference.object_info import ObjectInfo + +log = logging.getLogger() + +try: + import hickle as hkl +except ImportError: + log.warning('Failed to import hickle. Fine if not using multi-scale testing.') + + +class ResultSaver: + def __init__(self, + output_root, + video_name, + *, + dataset, + object_manager: ObjectManager, + use_long_id, + palette=None, + save_mask=True, + save_scores=False, + score_output_root=None, + visualize_output_root=None, + visualize=False, + init_json=None): + self.output_root = output_root + self.video_name = video_name + self.dataset = dataset.lower() + self.use_long_id = use_long_id + self.palette = palette + self.object_manager = object_manager + self.save_mask = save_mask + self.save_scores = save_scores + self.score_output_root = score_output_root + self.visualize_output_root = visualize_output_root + self.visualize = visualize + + if self.visualize: + if self.palette is not None: + self.colors = np.array(self.palette, dtype=np.uint8).reshape(-1, 3) + else: + self.colors = davis_palette_np + + self.need_remapping = True + self.json_style = None + self.id2rgb_converter = ID2RGBConverter() + + if 'burst' in self.dataset: + assert init_json is not None + self.input_segmentations = init_json['segmentations'] + self.segmentations = [{} for _ in init_json['segmentations']] + self.annotated_frames = init_json['annotated_image_paths'] + self.video_json = {k: v for k, v in init_json.items() if k != 'segmentations'} + self.video_json['segmentations'] = self.segmentations + self.json_style = 'burst' + + self.queue = Queue(maxsize=10) + self.thread = Thread(target=save_result, args=(self.queue, )) + self.thread.daemon = True + self.thread.start() + + def process(self, + prob: torch.Tensor, + frame_name: str, + resize_needed: bool = False, + shape: Optional[Tuple[int, int]] = None, + last_frame: bool = False, + path_to_image: str = None): + + if resize_needed: + prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:, + 0] + # Probability mask -> index mask + mask = torch.argmax(prob, dim=0) + if self.save_scores: + # also need to pass prob + prob = prob.cpu() + else: + prob = None + + # remap indices + if self.need_remapping: + new_mask = torch.zeros_like(mask) + for tmp_id, obj in self.object_manager.tmp_id_to_obj.items(): + new_mask[mask == tmp_id] = obj.id + mask = new_mask + + args = ResultArgs(saver=self, + prob=prob, + mask=mask.cpu(), + frame_name=frame_name, + path_to_image=path_to_image, + tmp_id_to_obj=copy.deepcopy(self.object_manager.tmp_id_to_obj), + obj_to_tmp_id=copy.deepcopy(self.object_manager.obj_to_tmp_id), + last_frame=last_frame) + + self.queue.put(args) + + def end(self): + self.queue.put(None) + self.queue.join() + self.thread.join() + + +@dataclass +class ResultArgs: + saver: ResultSaver + prob: torch.Tensor + mask: torch.Tensor + frame_name: str + path_to_image: str + tmp_id_to_obj: Dict[int, ObjectInfo] + obj_to_tmp_id: Dict[ObjectInfo, int] + last_frame: bool + + +def save_result(queue: Queue): + while True: + args: ResultArgs = queue.get() + if args is None: + queue.task_done() + break + + saver = args.saver + prob = args.prob + mask = args.mask + frame_name = args.frame_name + path_to_image = args.path_to_image + tmp_id_to_obj = args.tmp_id_to_obj + obj_to_tmp_id = args.obj_to_tmp_id + last_frame = args.last_frame + all_obj_ids = [k.id for k in obj_to_tmp_id] + + # record output in the json file + if saver.json_style == 'burst': + if frame_name in saver.annotated_frames: + frame_index = saver.annotated_frames.index(frame_name) + input_segments = saver.input_segmentations[frame_index] + frame_segments = saver.segmentations[frame_index] + + for id in all_obj_ids: + if id in input_segments: + # if this frame has been given as input, just copy + frame_segments[id] = input_segments[id] + continue + + segment = {} + segment_mask = (mask == id) + if segment_mask.sum() > 0: + coco_mask = mask_util.encode(np.asfortranarray(segment_mask.numpy())) + segment['rle'] = coco_mask['counts'].decode('utf-8') + frame_segments[id] = segment + + # save the mask to disk + if saver.save_mask: + if saver.use_long_id: + out_mask = mask.numpy().astype(np.uint32) + rgb_mask = np.zeros((*out_mask.shape[-2:], 3), dtype=np.uint8) + for id in all_obj_ids: + _, image = saver.id2rgb_converter.convert(id) + obj_mask = (out_mask == id) + rgb_mask[obj_mask] = image + out_img = Image.fromarray(rgb_mask) + else: + rgb_mask = None + out_mask = mask.numpy().astype(np.uint8) + out_img = Image.fromarray(out_mask) + if saver.palette is not None: + out_img.putpalette(saver.palette) + + this_out_path = path.join(saver.output_root, saver.video_name) + os.makedirs(this_out_path, exist_ok=True) + out_img.save(os.path.join(this_out_path, frame_name[:-4] + '.png')) + + # save scores for multi-scale testing + if saver.save_scores: + this_out_path = path.join(saver.score_output_root, saver.video_name) + os.makedirs(this_out_path, exist_ok=True) + + prob = (prob.detach().numpy() * 255).astype(np.uint8) + + if last_frame: + tmp_to_obj_mapping = {obj.id: tmp_id for obj, tmp_id in tmp_id_to_obj.items()} + hkl.dump(tmp_to_obj_mapping, path.join(this_out_path, f'backward.hkl'), mode='w') + + hkl.dump(prob, + path.join(this_out_path, f'{frame_name[:-4]}.hkl'), + mode='w', + compression='lzf') + + if saver.visualize: + if path_to_image is not None: + image_np = np.array(Image.open(path_to_image)) + else: + raise ValueError('Cannot visualize without path_to_image') + + if rgb_mask is None: + # we need to apply a palette + rgb_mask = np.zeros((*out_mask.shape, 3), dtype=np.uint8) + for id in all_obj_ids: + image = saver.colors[id] + obj_mask = (out_mask == id) + rgb_mask[obj_mask] = image + + alpha = (out_mask == 0).astype(np.float32) * 0.5 + 0.5 + alpha = alpha[:, :, None] + blend = (image_np * alpha + rgb_mask * (1 - alpha)).astype(np.uint8) + + # find a place to save the visualization + this_vis_path = path.join(saver.visualize_output_root, saver.video_name) + os.makedirs(this_vis_path, exist_ok=True) + Image.fromarray(blend).save(path.join(this_vis_path, frame_name[:-4] + '.jpg')) + + queue.task_done() + + +def make_zip(dataset, run_dir, exp_id, mask_output_root): + if dataset.startswith('y'): + # YoutubeVOS + log.info('Making zip for YouTubeVOS...') + shutil.make_archive(path.join(run_dir, f'{exp_id}_{dataset}'), 'zip', run_dir, + 'Annotations') + elif dataset == 'd17-test-dev': + # DAVIS 2017 test-dev -- zip from within the Annotation folder + log.info('Making zip for DAVIS test-dev...') + shutil.make_archive(path.join(run_dir, f'{exp_id}_{dataset}'), 'zip', mask_output_root) + elif dataset == 'mose-val': + # MOSE validation -- same as DAVIS test-dev + log.info('Making zip for MOSE validation...') + shutil.make_archive(path.join(run_dir, f'{exp_id}_{dataset}'), 'zip', mask_output_root) + elif dataset == 'lvos-test': + # LVOS test -- same as YouTubeVOS + log.info('Making zip for LVOS test...') + shutil.make_archive(path.join(run_dir, f'{exp_id}_{dataset}'), 'zip', run_dir, + 'Annotations') + else: + log.info(f'Not making zip for {dataset}.') diff --git a/web-demos/hugging_face/tracker/model/__init__.py b/web-demos/hugging_face/tracker/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/web-demos/hugging_face/tracker/model/aux_modules.py b/web-demos/hugging_face/tracker/model/aux_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..bf4a8ef3913dc5f057b29ad5917cb7bc5541d004 --- /dev/null +++ b/web-demos/hugging_face/tracker/model/aux_modules.py @@ -0,0 +1,80 @@ +""" +For computing auxiliary outputs for auxiliary losses +""" +from typing import Dict +from omegaconf import DictConfig +import torch +import torch.nn as nn + +from tracker.model.group_modules import GConv2d +from tracker.utils.tensor_utils import aggregate + + +class LinearPredictor(nn.Module): + def __init__(self, x_dim: int, pix_dim: int): + super().__init__() + self.projection = GConv2d(x_dim, pix_dim + 1, kernel_size=1) + + def forward(self, pix_feat: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + # pixel_feat: B*pix_dim*H*W + # x: B*num_objects*x_dim*H*W + num_objects = x.shape[1] + x = self.projection(x) + + pix_feat = pix_feat.unsqueeze(1).expand(-1, num_objects, -1, -1, -1) + logits = (pix_feat * x[:, :, :-1]).sum(dim=2) + x[:, :, -1] + return logits + + +class DirectPredictor(nn.Module): + def __init__(self, x_dim: int): + super().__init__() + self.projection = GConv2d(x_dim, 1, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: B*num_objects*x_dim*H*W + logits = self.projection(x).squeeze(2) + return logits + + +class AuxComputer(nn.Module): + def __init__(self, cfg: DictConfig): + super().__init__() + + use_sensory_aux = cfg.model.aux_loss.sensory.enabled + self.use_query_aux = cfg.model.aux_loss.query.enabled + + sensory_dim = cfg.model.sensory_dim + embed_dim = cfg.model.embed_dim + + if use_sensory_aux: + self.sensory_aux = LinearPredictor(sensory_dim, embed_dim) + else: + self.sensory_aux = None + + def _aggregate_with_selector(self, logits: torch.Tensor, selector: torch.Tensor) -> torch.Tensor: + prob = torch.sigmoid(logits) + if selector is not None: + prob = prob * selector + logits = aggregate(prob, dim=1) + return logits + + def forward(self, pix_feat: torch.Tensor, aux_input: Dict[str, torch.Tensor], + selector: torch.Tensor) -> Dict[str, torch.Tensor]: + sensory = aux_input['sensory'] + q_logits = aux_input['q_logits'] + + aux_output = {} + aux_output['attn_mask'] = aux_input['attn_mask'] + + if self.sensory_aux is not None: + # B*num_objects*H*W + logits = self.sensory_aux(pix_feat, sensory) + aux_output['sensory_logits'] = self._aggregate_with_selector(logits, selector) + if self.use_query_aux: + # B*num_objects*num_levels*H*W + aux_output['q_logits'] = self._aggregate_with_selector( + torch.stack(q_logits, dim=2), + selector.unsqueeze(2) if selector is not None else None) + + return aux_output \ No newline at end of file diff --git a/web-demos/hugging_face/tracker/model/big_modules.py b/web-demos/hugging_face/tracker/model/big_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..cc1daaf0d72811694922476e63e018cefa6c5656 --- /dev/null +++ b/web-demos/hugging_face/tracker/model/big_modules.py @@ -0,0 +1,304 @@ +""" +big_modules.py - This file stores higher-level network blocks. + +x - usually denotes features that are shared between objects. +g - usually denotes features that are not shared between objects + with an extra "num_objects" dimension (batch_size * num_objects * num_channels * H * W). + +The trailing number of a variable usually denotes the stride +""" + +from omegaconf import DictConfig +import torch +import torch.nn as nn +import torch.nn.functional as F + +from tracker.model.group_modules import * +from tracker.model.utils import resnet +from tracker.model.modules import * + + +class PixelEncoder(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + + self.is_resnet = 'resnet' in model_cfg.pixel_encoder.type + if self.is_resnet: + if model_cfg.pixel_encoder.type == 'resnet18': + network = resnet.resnet18(pretrained=True) + elif model_cfg.pixel_encoder.type == 'resnet50': + network = resnet.resnet50(pretrained=True) + else: + raise NotImplementedError + self.conv1 = network.conv1 + self.bn1 = network.bn1 + self.relu = network.relu + self.maxpool = network.maxpool + + self.res2 = network.layer1 + self.layer2 = network.layer2 + self.layer3 = network.layer3 + else: + raise NotImplementedError + + def forward(self, x: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + f4 = self.res2(x) + f8 = self.layer2(f4) + f16 = self.layer3(f8) + + return f16, f8, f4 + + # override the default train() to freeze BN statistics + def train(self, mode=True): + self.training = False + for module in self.children(): + module.train(False) + return self + + +class KeyProjection(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + in_dim = model_cfg.pixel_encoder.ms_dims[0] + mid_dim = model_cfg.pixel_dim + key_dim = model_cfg.key_dim + + self.pix_feat_proj = nn.Conv2d(in_dim, mid_dim, kernel_size=1) + self.key_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1) + # shrinkage + self.d_proj = nn.Conv2d(mid_dim, 1, kernel_size=3, padding=1) + # selection + self.e_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1) + + nn.init.orthogonal_(self.key_proj.weight.data) + nn.init.zeros_(self.key_proj.bias.data) + + def forward(self, x: torch.Tensor, *, need_s: bool, + need_e: bool) -> (torch.Tensor, torch.Tensor, torch.Tensor): + x = self.pix_feat_proj(x) + shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None + selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None + + return self.key_proj(x), shrinkage, selection + + +class MaskEncoder(nn.Module): + def __init__(self, model_cfg: DictConfig, single_object=False): + super().__init__() + pixel_dim = model_cfg.pixel_dim + value_dim = model_cfg.value_dim + sensory_dim = model_cfg.sensory_dim + final_dim = model_cfg.mask_encoder.final_dim + + self.single_object = single_object + extra_dim = 1 if single_object else 2 + + if model_cfg.mask_encoder.type == 'resnet18': + network = resnet.resnet18(pretrained=True, extra_dim=extra_dim) + elif model_cfg.mask_encoder.type == 'resnet50': + network = resnet.resnet50(pretrained=True, extra_dim=extra_dim) + else: + raise NotImplementedError + self.conv1 = network.conv1 + self.bn1 = network.bn1 + self.relu = network.relu + self.maxpool = network.maxpool + + self.layer1 = network.layer1 + self.layer2 = network.layer2 + self.layer3 = network.layer3 + + self.distributor = MainToGroupDistributor() + self.fuser = GroupFeatureFusionBlock(pixel_dim, final_dim, value_dim) + + self.sensory_update = SensoryDeepUpdater(value_dim, sensory_dim) + + def forward(self, + image: torch.Tensor, + pix_feat: torch.Tensor, + sensory: torch.Tensor, + masks: torch.Tensor, + others: torch.Tensor, + *, + deep_update: bool = True, + chunk_size: int = -1) -> (torch.Tensor, torch.Tensor): + # ms_features are from the key encoder + # we only use the first one (lowest resolution), following XMem + if self.single_object: + g = masks.unsqueeze(2) + else: + g = torch.stack([masks, others], dim=2) + + g = self.distributor(image, g) + + batch_size, num_objects = g.shape[:2] + if chunk_size < 1 or chunk_size >= num_objects: + chunk_size = num_objects + fast_path = True + new_sensory = sensory + else: + if deep_update: + new_sensory = torch.empty_like(sensory) + else: + new_sensory = sensory + fast_path = False + + # chunk-by-chunk inference + all_g = [] + for i in range(0, num_objects, chunk_size): + if fast_path: + g_chunk = g + else: + g_chunk = g[:, i:i + chunk_size] + actual_chunk_size = g_chunk.shape[1] + g_chunk = g_chunk.flatten(start_dim=0, end_dim=1) + + g_chunk = self.conv1(g_chunk) + g_chunk = self.bn1(g_chunk) # 1/2, 64 + g_chunk = self.maxpool(g_chunk) # 1/4, 64 + g_chunk = self.relu(g_chunk) + + g_chunk = self.layer1(g_chunk) # 1/4 + g_chunk = self.layer2(g_chunk) # 1/8 + g_chunk = self.layer3(g_chunk) # 1/16 + + g_chunk = g_chunk.view(batch_size, actual_chunk_size, *g_chunk.shape[1:]) + g_chunk = self.fuser(pix_feat, g_chunk) + all_g.append(g_chunk) + if deep_update: + if fast_path: + new_sensory = self.sensory_update(g_chunk, sensory) + else: + new_sensory[:, i:i + chunk_size] = self.sensory_update( + g_chunk, sensory[:, i:i + chunk_size]) + g = torch.cat(all_g, dim=1) + + return g, new_sensory + + # override the default train() to freeze BN statistics + def train(self, mode=True): + self.training = False + for module in self.children(): + module.train(False) + return self + + +class PixelFeatureFuser(nn.Module): + def __init__(self, model_cfg: DictConfig, single_object=False): + super().__init__() + value_dim = model_cfg.value_dim + sensory_dim = model_cfg.sensory_dim + pixel_dim = model_cfg.pixel_dim + embed_dim = model_cfg.embed_dim + self.single_object = single_object + + self.fuser = GroupFeatureFusionBlock(pixel_dim, value_dim, embed_dim) + if self.single_object: + self.sensory_compress = GConv2d(sensory_dim + 1, value_dim, kernel_size=1) + else: + self.sensory_compress = GConv2d(sensory_dim + 2, value_dim, kernel_size=1) + + def forward(self, + pix_feat: torch.Tensor, + pixel_memory: torch.Tensor, + sensory_memory: torch.Tensor, + last_mask: torch.Tensor, + last_others: torch.Tensor, + *, + chunk_size: int = -1) -> torch.Tensor: + batch_size, num_objects = pixel_memory.shape[:2] + + if self.single_object: + last_mask = last_mask.unsqueeze(2) + else: + last_mask = torch.stack([last_mask, last_others], dim=2) + + if chunk_size < 1: + chunk_size = num_objects + + # chunk-by-chunk inference + all_p16 = [] + for i in range(0, num_objects, chunk_size): + sensory_readout = self.sensory_compress( + torch.cat([sensory_memory[:, i:i + chunk_size], last_mask[:, i:i + chunk_size]], 2)) + p16 = pixel_memory[:, i:i + chunk_size] + sensory_readout + p16 = self.fuser(pix_feat, p16) + all_p16.append(p16) + p16 = torch.cat(all_p16, dim=1) + + return p16 + + +class MaskDecoder(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + embed_dim = model_cfg.embed_dim + sensory_dim = model_cfg.sensory_dim + ms_image_dims = model_cfg.pixel_encoder.ms_dims + up_dims = model_cfg.mask_decoder.up_dims + + assert embed_dim == up_dims[0] + + self.sensory_update = SensoryUpdater([up_dims[0], up_dims[1], up_dims[2] + 1], sensory_dim, + sensory_dim) + + self.decoder_feat_proc = DecoderFeatureProcessor(ms_image_dims[1:], up_dims[:-1]) + self.up_16_8 = MaskUpsampleBlock(up_dims[0], up_dims[1]) + self.up_8_4 = MaskUpsampleBlock(up_dims[1], up_dims[2]) + + self.pred = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1) + + def forward(self, + ms_image_feat: Iterable[torch.Tensor], + memory_readout: torch.Tensor, + sensory: torch.Tensor, + *, + chunk_size: int = -1, + update_sensory: bool = True) -> (torch.Tensor, torch.Tensor): + + batch_size, num_objects = memory_readout.shape[:2] + f8, f4 = self.decoder_feat_proc(ms_image_feat[1:]) + if chunk_size < 1 or chunk_size >= num_objects: + chunk_size = num_objects + fast_path = True + new_sensory = sensory + else: + if update_sensory: + new_sensory = torch.empty_like(sensory) + else: + new_sensory = sensory + fast_path = False + + # chunk-by-chunk inference + all_logits = [] + for i in range(0, num_objects, chunk_size): + if fast_path: + p16 = memory_readout + else: + p16 = memory_readout[:, i:i + chunk_size] + actual_chunk_size = p16.shape[1] + + p8 = self.up_16_8(p16, f8) + p4 = self.up_8_4(p8, f4) + with torch.cuda.amp.autocast(enabled=False): + logits = self.pred(F.relu(p4.flatten(start_dim=0, end_dim=1).float())) + + if update_sensory: + p4 = torch.cat( + [p4, logits.view(batch_size, actual_chunk_size, 1, *logits.shape[-2:])], 2) + if fast_path: + new_sensory = self.sensory_update([p16, p8, p4], sensory) + else: + new_sensory[:, + i:i + chunk_size] = self.sensory_update([p16, p8, p4], + sensory[:, + i:i + chunk_size]) + all_logits.append(logits) + logits = torch.cat(all_logits, dim=0) + logits = logits.view(batch_size, num_objects, *logits.shape[-2:]) + + return new_sensory, logits diff --git a/web-demos/hugging_face/tracker/model/channel_attn.py b/web-demos/hugging_face/tracker/model/channel_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..a2096c1c4b4745a3ea2060bb25af3b19ff9cf3ec --- /dev/null +++ b/web-demos/hugging_face/tracker/model/channel_attn.py @@ -0,0 +1,39 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class CAResBlock(nn.Module): + def __init__(self, in_dim: int, out_dim: int, residual: bool = True): + super().__init__() + self.residual = residual + self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1) + + t = int((abs(math.log2(out_dim)) + 1) // 2) + k = t if t % 2 else t + 1 + self.pool = nn.AdaptiveAvgPool2d(1) + self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False) + + if self.residual: + if in_dim == out_dim: + self.downsample = nn.Identity() + else: + self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r = x + x = self.conv1(F.relu(x)) + x = self.conv2(F.relu(x)) + + b, c = x.shape[:2] + w = self.pool(x).view(b, 1, c) + w = self.conv(w).transpose(-1, -2).unsqueeze(-1).sigmoid() # B*C*1*1 + + if self.residual: + x = x * w + self.downsample(r) + else: + x = x * w + + return x diff --git a/web-demos/hugging_face/tracker/model/cutie.py b/web-demos/hugging_face/tracker/model/cutie.py new file mode 100644 index 0000000000000000000000000000000000000000..82c5652a3f3d657ab71ed208cd11ca2322608d7a --- /dev/null +++ b/web-demos/hugging_face/tracker/model/cutie.py @@ -0,0 +1,249 @@ +from typing import List, Dict +import logging +from omegaconf import DictConfig +import torch +import torch.nn as nn + +from tracker.model.modules import * +from tracker.model.big_modules import * +from tracker.model.aux_modules import AuxComputer +from tracker.model.utils.memory_utils import * +from tracker.model.transformer.object_transformer import QueryTransformer +from tracker.model.transformer.object_summarizer import ObjectSummarizer +from tracker.utils.tensor_utils import aggregate + +log = logging.getLogger() + + +class CUTIE(nn.Module): + def __init__(self, cfg: DictConfig, *, single_object=False): + super().__init__() + model_cfg = cfg.model + self.ms_dims = model_cfg.pixel_encoder.ms_dims + self.key_dim = model_cfg.key_dim + self.value_dim = model_cfg.value_dim + self.sensory_dim = model_cfg.sensory_dim + self.pixel_dim = model_cfg.pixel_dim + self.embed_dim = model_cfg.embed_dim + self.single_object = single_object + + log.info(f'Single object: {self.single_object}') + + self.pixel_encoder = PixelEncoder(model_cfg) + self.pix_feat_proj = nn.Conv2d(self.ms_dims[0], self.pixel_dim, kernel_size=1) + self.key_proj = KeyProjection(model_cfg) + self.mask_encoder = MaskEncoder(model_cfg, single_object=single_object) + self.mask_decoder = MaskDecoder(model_cfg) + self.pixel_fuser = PixelFeatureFuser(model_cfg, single_object=single_object) + self.object_transformer = QueryTransformer(model_cfg) + self.object_summarizer = ObjectSummarizer(model_cfg) + self.aux_computer = AuxComputer(cfg) + + self.register_buffer("pixel_mean", torch.Tensor(model_cfg.pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(model_cfg.pixel_std).view(-1, 1, 1), False) + + def _get_others(self, masks: torch.Tensor) -> torch.Tensor: + # for each object, return the sum of masks of all other objects + if self.single_object: + return None + + num_objects = masks.shape[1] + if num_objects >= 1: + others = (masks.sum(dim=1, keepdim=True) - masks).clamp(0, 1) + else: + others = torch.zeros_like(masks) + return others + + def encode_image(self, image: torch.Tensor) -> (Iterable[torch.Tensor], torch.Tensor): + image = (image - self.pixel_mean) / self.pixel_std + ms_image_feat = self.pixel_encoder(image) + return ms_image_feat, self.pix_feat_proj(ms_image_feat[0]) + + def encode_mask( + self, + image: torch.Tensor, + ms_features: List[torch.Tensor], + sensory: torch.Tensor, + masks: torch.Tensor, + *, + deep_update: bool = True, + chunk_size: int = -1, + need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor): + image = (image - self.pixel_mean) / self.pixel_std + others = self._get_others(masks) + mask_value, new_sensory = self.mask_encoder(image, + ms_features, + sensory, + masks, + others, + deep_update=deep_update, + chunk_size=chunk_size) + object_summaries, object_logits = self.object_summarizer(masks, mask_value, need_weights) + return mask_value, new_sensory, object_summaries, object_logits + + def transform_key(self, + final_pix_feat: torch.Tensor, + *, + need_sk: bool = True, + need_ek: bool = True) -> (torch.Tensor, torch.Tensor, torch.Tensor): + key, shrinkage, selection = self.key_proj(final_pix_feat, need_s=need_sk, need_e=need_ek) + return key, shrinkage, selection + + # Used in training only. + # This step is replaced by MemoryManager in test time + def read_memory(self, query_key: torch.Tensor, query_selection: torch.Tensor, + memory_key: torch.Tensor, memory_shrinkage: torch.Tensor, + msk_value: torch.Tensor, obj_memory: torch.Tensor, pix_feat: torch.Tensor, + sensory: torch.Tensor, last_mask: torch.Tensor, + selector: torch.Tensor) -> (torch.Tensor, Dict[str, torch.Tensor]): + """ + query_key : B * CK * H * W + query_selection : B * CK * H * W + memory_key : B * CK * T * H * W + memory_shrinkage: B * 1 * T * H * W + msk_value : B * num_objects * CV * T * H * W + obj_memory : B * num_objects * T * num_summaries * C + pixel_feature : B * C * H * W + """ + batch_size, num_objects = msk_value.shape[:2] + + # read using visual attention + with torch.cuda.amp.autocast(enabled=False): + affinity = get_affinity(memory_key.float(), memory_shrinkage.float(), query_key.float(), + query_selection.float()) + + msk_value = msk_value.flatten(start_dim=1, end_dim=2).float() + + # B * (num_objects*CV) * H * W + pixel_readout = readout(affinity, msk_value) + pixel_readout = pixel_readout.view(batch_size, num_objects, self.value_dim, + *pixel_readout.shape[-2:]) + pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask) + + # read from query transformer + mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector) + + aux_output = { + 'sensory': sensory, + 'q_logits': aux_features['logits'] if aux_features else None, + 'attn_mask': aux_features['attn_mask'] if aux_features else None, + } + + return mem_readout, aux_output + + def pixel_fusion(self, + pix_feat: torch.Tensor, + pixel: torch.Tensor, + sensory: torch.Tensor, + last_mask: torch.Tensor, + *, + chunk_size: int = -1) -> torch.Tensor: + last_mask = F.interpolate(last_mask, size=sensory.shape[-2:], mode='area') + last_others = self._get_others(last_mask) + fused = self.pixel_fuser(pix_feat, + pixel, + sensory, + last_mask, + last_others, + chunk_size=chunk_size) + return fused + + def readout_query(self, + pixel_readout, + obj_memory, + *, + selector=None, + need_weights=False) -> (torch.Tensor, Dict[str, torch.Tensor]): + return self.object_transformer(pixel_readout, + obj_memory, + selector=selector, + need_weights=need_weights) + + def segment(self, + ms_image_feat: List[torch.Tensor], + memory_readout: torch.Tensor, + sensory: torch.Tensor, + *, + selector: bool = None, + chunk_size: int = -1, + update_sensory: bool = True) -> (torch.Tensor, torch.Tensor, torch.Tensor): + """ + multi_scale_features is from the key encoder for skip-connection + memory_readout is from working/long-term memory + sensory is the sensory memory + last_mask is the mask from the last frame, supplementing sensory memory + selector is 1 if an object exists, and 0 otherwise. We use it to filter padded objects + during training. + """ + sensory, logits = self.mask_decoder(ms_image_feat, + memory_readout, + sensory, + chunk_size=chunk_size, + update_sensory=update_sensory) + + prob = torch.sigmoid(logits) + if selector is not None: + prob = prob * selector + + # Softmax over all objects[] + logits = aggregate(prob, dim=1) + logits = F.interpolate(logits, scale_factor=4, mode='bilinear', align_corners=False) + prob = F.softmax(logits, dim=1) + + return sensory, logits, prob + + def compute_aux(self, pix_feat: torch.Tensor, aux_inputs: Dict[str, torch.Tensor], + selector: torch.Tensor) -> Dict[str, torch.Tensor]: + return self.aux_computer(pix_feat, aux_inputs, selector) + + def forward(self, *args, **kwargs): + raise NotImplementedError + + def load_weights(self, src_dict, init_as_zero_if_needed=False) -> None: + if not self.single_object: + # Map single-object weight to multi-object weight (4->5 out channels in conv1) + for k in list(src_dict.keys()): + if k == 'mask_encoder.conv1.weight': + if src_dict[k].shape[1] == 4: + log.info(f'Converting {k} from single object to multiple objects.') + pads = torch.zeros((64, 1, 7, 7), device=src_dict[k].device) + if not init_as_zero_if_needed: + nn.init.orthogonal_(pads) + log.info(f'Randomly initialized padding for {k}.') + else: + log.info(f'Zero-initialized padding for {k}.') + src_dict[k] = torch.cat([src_dict[k], pads], 1) + elif k == 'pixel_fuser.sensory_compress.weight': + if src_dict[k].shape[1] == self.sensory_dim + 1: + log.info(f'Converting {k} from single object to multiple objects.') + pads = torch.zeros((self.value_dim, 1, 1, 1), device=src_dict[k].device) + if not init_as_zero_if_needed: + nn.init.orthogonal_(pads) + log.info(f'Randomly initialized padding for {k}.') + else: + log.info(f'Zero-initialized padding for {k}.') + src_dict[k] = torch.cat([src_dict[k], pads], 1) + elif self.single_object: + """ + If the model is multiple-object and we are training in single-object, + we strip the last channel of conv1. + This is not supposed to happen in standard training except when users are trying to + finetune a trained model with single object datasets. + """ + if src_dict['mask_encoder.conv1.weight'].shape[1] == 5: + log.warning(f'Converting {k} from multiple objects to single object.' + 'This is not supposed to happen in standard training.') + src_dict[k] = src_dict[k][:, :-1] + + for k in src_dict: + if k not in self.state_dict(): + log.info(f'Key {k} found in src_dict but not in self.state_dict()!!!') + for k in self.state_dict(): + if k not in src_dict: + log.info(f'Key {k} found in self.state_dict() but not in src_dict!!!') + + self.load_state_dict(src_dict, strict=False) + + @property + def device(self) -> torch.device: + return self.pixel_mean.device \ No newline at end of file diff --git a/web-demos/hugging_face/tracker/model/group_modules.py b/web-demos/hugging_face/tracker/model/group_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..a9bf64c51613be619705bee2da9b7508378cbb46 --- /dev/null +++ b/web-demos/hugging_face/tracker/model/group_modules.py @@ -0,0 +1,127 @@ +from typing import Optional +import torch +import torch.nn as nn +import torch.nn.functional as F +from tracker.model.channel_attn import CAResBlock + + +def interpolate_groups(g: torch.Tensor, ratio: float, mode: str, + align_corners: bool) -> torch.Tensor: + batch_size, num_objects = g.shape[:2] + g = F.interpolate(g.flatten(start_dim=0, end_dim=1), + scale_factor=ratio, + mode=mode, + align_corners=align_corners) + g = g.view(batch_size, num_objects, *g.shape[1:]) + return g + + +def upsample_groups(g: torch.Tensor, + ratio: float = 2, + mode: str = 'bilinear', + align_corners: bool = False) -> torch.Tensor: + return interpolate_groups(g, ratio, mode, align_corners) + + +def downsample_groups(g: torch.Tensor, + ratio: float = 1 / 2, + mode: str = 'area', + align_corners: bool = None) -> torch.Tensor: + return interpolate_groups(g, ratio, mode, align_corners) + + +class GConv2d(nn.Conv2d): + def forward(self, g: torch.Tensor) -> torch.Tensor: + batch_size, num_objects = g.shape[:2] + g = super().forward(g.flatten(start_dim=0, end_dim=1)) + return g.view(batch_size, num_objects, *g.shape[1:]) + + +class GroupResBlock(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + + if in_dim == out_dim: + self.downsample = nn.Identity() + else: + self.downsample = GConv2d(in_dim, out_dim, kernel_size=1) + + self.conv1 = GConv2d(in_dim, out_dim, kernel_size=3, padding=1) + self.conv2 = GConv2d(out_dim, out_dim, kernel_size=3, padding=1) + + def forward(self, g: torch.Tensor) -> torch.Tensor: + out_g = self.conv1(F.relu(g)) + out_g = self.conv2(F.relu(out_g)) + + g = self.downsample(g) + + return out_g + g + + +class MainToGroupDistributor(nn.Module): + def __init__(self, + x_transform: Optional[nn.Module] = None, + g_transform: Optional[nn.Module] = None, + method: str = 'cat', + reverse_order: bool = False): + super().__init__() + + self.x_transform = x_transform + self.g_transform = g_transform + self.method = method + self.reverse_order = reverse_order + + def forward(self, x: torch.Tensor, g: torch.Tensor, skip_expand: bool = False) -> torch.Tensor: + num_objects = g.shape[1] + + if self.x_transform is not None: + x = self.x_transform(x) + + if self.g_transform is not None: + g = self.g_transform(g) + + if not skip_expand: + x = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1) + if self.method == 'cat': + if self.reverse_order: + g = torch.cat([g, x], 2) + else: + g = torch.cat([x, g], 2) + elif self.method == 'add': + g = x + g + elif self.method == 'mulcat': + g = torch.cat([x * g, g], dim=2) + elif self.method == 'muladd': + g = x * g + g + else: + raise NotImplementedError + + return g + + +class GroupFeatureFusionBlock(nn.Module): + def __init__(self, x_in_dim: int, g_in_dim: int, out_dim: int): + super().__init__() + + x_transform = nn.Conv2d(x_in_dim, out_dim, kernel_size=1) + g_transform = GConv2d(g_in_dim, out_dim, kernel_size=1) + + self.distributor = MainToGroupDistributor(x_transform=x_transform, + g_transform=g_transform, + method='add') + self.block1 = CAResBlock(out_dim, out_dim) + self.block2 = CAResBlock(out_dim, out_dim) + + def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor: + batch_size, num_objects = g.shape[:2] + + g = self.distributor(x, g) + + g = g.flatten(start_dim=0, end_dim=1) + + g = self.block1(g) + g = self.block2(g) + + g = g.view(batch_size, num_objects, *g.shape[1:]) + + return g \ No newline at end of file diff --git a/web-demos/hugging_face/tracker/model/losses.py b/web-demos/hugging_face/tracker/model/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..cc9c7c5c3bbd7e7bd909a82aaa8a3b5c8046d7ee --- /dev/null +++ b/web-demos/hugging_face/tracker/model/losses.py @@ -0,0 +1,97 @@ +from typing import List, Dict +from omegaconf import DictConfig +from collections import defaultdict +import torch +import torch.nn.functional as F + +from tracker.utils.point_features import calculate_uncertainty, point_sample, get_uncertain_point_coords_with_randomness +from tracker.utils.tensor_utils import cls_to_one_hot + + +@torch.jit.script +def ce_loss(logits: torch.Tensor, soft_gt: torch.Tensor) -> torch.Tensor: + # logits: T*C*num_points + loss = F.cross_entropy(logits, soft_gt, reduction='none') + # sum over temporal dimension + return loss.sum(0).mean() + + +@torch.jit.script +def dice_loss(mask: torch.Tensor, soft_gt: torch.Tensor) -> torch.Tensor: + # mask: T*C*num_points + # soft_gt: T*C*num_points + # ignores the background + mask = mask[:, 1:].flatten(start_dim=2) + gt = soft_gt[:, 1:].float().flatten(start_dim=2) + numerator = 2 * (mask * gt).sum(-1) + denominator = mask.sum(-1) + gt.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum(0).mean() + + +class LossComputer: + def __init__(self, cfg: DictConfig, stage_cfg: DictConfig): + super().__init__() + self.point_supervision = stage_cfg.point_supervision + self.num_points = stage_cfg.train_num_points + self.oversample_ratio = stage_cfg.oversample_ratio + self.importance_sample_ratio = stage_cfg.importance_sample_ratio + + self.sensory_weight = cfg.model.aux_loss.sensory.weight + self.query_weight = cfg.model.aux_loss.query.weight + + def mask_loss(self, logits: torch.Tensor, + soft_gt: torch.Tensor) -> (torch.Tensor, torch.Tensor): + assert self.point_supervision + + with torch.no_grad(): + # sample point_coords + point_coords = get_uncertain_point_coords_with_randomness( + logits, lambda x: calculate_uncertainty(x), self.num_points, self.oversample_ratio, + self.importance_sample_ratio) + # get gt labels + point_labels = point_sample(soft_gt, point_coords, align_corners=False) + point_logits = point_sample(logits, point_coords, align_corners=False) + # point_labels and point_logits: B*C*num_points + + loss_ce = ce_loss(point_logits, point_labels) + loss_dice = dice_loss(point_logits.softmax(dim=1), point_labels) + + return loss_ce, loss_dice + + def compute(self, data: Dict[str, torch.Tensor], + num_objects: List[int]) -> Dict[str, torch.Tensor]: + batch_size, num_frames = data['rgb'].shape[:2] + losses = defaultdict(float) + t_range = range(1, num_frames) + + for bi in range(batch_size): + logits = torch.stack([data[f'logits_{ti}'][bi, :num_objects[bi] + 1] for ti in t_range], + dim=0) + cls_gt = data['cls_gt'][bi, 1:] # remove gt for the first frame + soft_gt = cls_to_one_hot(cls_gt, num_objects[bi]) + + loss_ce, loss_dice = self.mask_loss(logits, soft_gt) + losses['loss_ce'] += loss_ce / batch_size + losses['loss_dice'] += loss_dice / batch_size + + aux = [data[f'aux_{ti}'] for ti in t_range] + if 'sensory_logits' in aux[0]: + sensory_log = torch.stack( + [a['sensory_logits'][bi, :num_objects[bi] + 1] for a in aux], dim=0) + loss_ce, loss_dice = self.mask_loss(sensory_log, soft_gt) + losses['aux_sensory_ce'] += loss_ce / batch_size * self.sensory_weight + losses['aux_sensory_dice'] += loss_dice / batch_size * self.sensory_weight + if 'q_logits' in aux[0]: + num_levels = aux[0]['q_logits'].shape[2] + + for l in range(num_levels): + query_log = torch.stack( + [a['q_logits'][bi, :num_objects[bi] + 1, l] for a in aux], dim=0) + loss_ce, loss_dice = self.mask_loss(query_log, soft_gt) + losses[f'aux_query_ce_l{l}'] += loss_ce / batch_size * self.query_weight + losses[f'aux_query_dice_l{l}'] += loss_dice / batch_size * self.query_weight + + losses['total_loss'] = sum(losses.values()) + + return losses diff --git a/web-demos/hugging_face/tracker/model/modules.py b/web-demos/hugging_face/tracker/model/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..59c4170af5abfecf7b85ed7804fc390285e0194d --- /dev/null +++ b/web-demos/hugging_face/tracker/model/modules.py @@ -0,0 +1,85 @@ +from typing import List, Iterable +import torch +import torch.nn as nn + +from tracker.model.group_modules import * + + +class MaskUpsampleBlock(nn.Module): + def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2): + super().__init__() + self.distributor = MainToGroupDistributor(method='add') + self.out_conv = GroupResBlock(in_dim, out_dim) + self.scale_factor = scale_factor + + def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor: + g = upsample_groups(in_g, ratio=self.scale_factor) + g = self.distributor(skip_f, g) + g = self.out_conv(g) + return g + + +class DecoderFeatureProcessor(nn.Module): + def __init__(self, decoder_dims: List[int], out_dims: List[int]): + super().__init__() + self.transforms = nn.ModuleList([ + nn.Conv2d(d_dim, p_dim, kernel_size=1) for d_dim, p_dim in zip(decoder_dims, out_dims) + ]) + + def forward(self, multi_scale_features: Iterable[torch.Tensor]) -> List[torch.Tensor]: + outputs = [func(x) for x, func in zip(multi_scale_features, self.transforms)] + return outputs + + +# @torch.jit.script +def _recurrent_update(h: torch.Tensor, values: torch.Tensor) -> torch.Tensor: + # h: batch_size * num_objects * hidden_dim * h * w + # values: batch_size * num_objects * (hidden_dim*3) * h * w + dim = values.shape[2] // 3 + forget_gate = torch.sigmoid(values[:, :, :dim]) + update_gate = torch.sigmoid(values[:, :, dim:dim * 2]) + new_value = torch.tanh(values[:, :, dim * 2:]) + new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value + return new_h + + +class SensoryUpdater(nn.Module): + # Used in the decoder, multi-scale feature + GRU + def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int): + super().__init__() + self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1) + self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1) + self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1) + + self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1) + + nn.init.xavier_normal_(self.transform.weight) + + def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor: + g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \ + self.g4_conv(downsample_groups(g[2], ratio=1/4)) + + with torch.cuda.amp.autocast(enabled=False): + g = g.float() + h = h.float() + values = self.transform(torch.cat([g, h], dim=2)) + new_h = _recurrent_update(h, values) + + return new_h + + +class SensoryDeepUpdater(nn.Module): + def __init__(self, f_dim: int, sensory_dim: int): + super().__init__() + self.transform = GConv2d(f_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1) + + nn.init.xavier_normal_(self.transform.weight) + + def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor: + with torch.cuda.amp.autocast(enabled=False): + g = g.float() + h = h.float() + values = self.transform(torch.cat([g, h], dim=2)) + new_h = _recurrent_update(h, values) + + return new_h diff --git a/web-demos/hugging_face/tracker/model/transformer/__init__.py b/web-demos/hugging_face/tracker/model/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/web-demos/hugging_face/tracker/model/transformer/object_summarizer.py b/web-demos/hugging_face/tracker/model/transformer/object_summarizer.py new file mode 100644 index 0000000000000000000000000000000000000000..42ee1b5385d607f34145e25b0362678f196064a2 --- /dev/null +++ b/web-demos/hugging_face/tracker/model/transformer/object_summarizer.py @@ -0,0 +1,89 @@ +from typing import List, Dict, Optional +from omegaconf import DictConfig + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tracker.model.transformer.positional_encoding import PositionalEncoding + + +# @torch.jit.script +def _weighted_pooling(masks: torch.Tensor, value: torch.Tensor, + logits: torch.Tensor) -> (torch.Tensor, torch.Tensor): + # value: B*num_objects*H*W*value_dim + # logits: B*num_objects*H*W*num_summaries + # masks: B*num_objects*H*W*num_summaries: 1 if allowed + weights = logits.sigmoid() * masks + # B*num_objects*num_summaries*value_dim + sums = torch.einsum('bkhwq,bkhwc->bkqc', weights, value) + # B*num_objects*H*W*num_summaries -> B*num_objects*num_summaries*1 + area = weights.flatten(start_dim=2, end_dim=3).sum(2).unsqueeze(-1) + + # B*num_objects*num_summaries*value_dim + return sums, area + + +class ObjectSummarizer(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + + this_cfg = model_cfg.object_summarizer + self.value_dim = model_cfg.value_dim + self.embed_dim = this_cfg.embed_dim + self.num_summaries = this_cfg.num_summaries + self.add_pe = this_cfg.add_pe + self.pixel_pe_scale = model_cfg.pixel_pe_scale + self.pixel_pe_temperature = model_cfg.pixel_pe_temperature + + if self.add_pe: + self.pos_enc = PositionalEncoding(self.embed_dim, + scale=self.pixel_pe_scale, + temperature=self.pixel_pe_temperature) + + self.input_proj = nn.Linear(self.value_dim, self.embed_dim) + self.feature_pred = nn.Sequential( + nn.Linear(self.embed_dim, self.embed_dim), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dim, self.embed_dim), + ) + self.weights_pred = nn.Sequential( + nn.Linear(self.embed_dim, self.embed_dim), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dim, self.num_summaries), + ) + + def forward(self, + masks: torch.Tensor, + value: torch.Tensor, + need_weights: bool = False) -> (torch.Tensor, Optional[torch.Tensor]): + # masks: B*num_objects*(H0)*(W0) + # value: B*num_objects*value_dim*H*W + # -> B*num_objects*H*W*value_dim + h, w = value.shape[-2:] + masks = F.interpolate(masks, size=(h, w), mode='area') + masks = masks.unsqueeze(-1) + inv_masks = 1 - masks + repeated_masks = torch.cat([ + masks.expand(-1, -1, -1, -1, self.num_summaries // 2), + inv_masks.expand(-1, -1, -1, -1, self.num_summaries // 2), + ], + dim=-1) + + value = value.permute(0, 1, 3, 4, 2) + value = self.input_proj(value) + if self.add_pe: + pe = self.pos_enc(value) + value = value + pe + + with torch.cuda.amp.autocast(enabled=False): + value = value.float() + feature = self.feature_pred(value) + logits = self.weights_pred(value) + sums, area = _weighted_pooling(repeated_masks, feature, logits) + + summaries = torch.cat([sums, area], dim=-1) + + if need_weights: + return summaries, logits + else: + return summaries, None \ No newline at end of file diff --git a/web-demos/hugging_face/tracker/model/transformer/object_transformer.py b/web-demos/hugging_face/tracker/model/transformer/object_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..71f0830247495df161f6126dd40ea9ff7f30b9f2 --- /dev/null +++ b/web-demos/hugging_face/tracker/model/transformer/object_transformer.py @@ -0,0 +1,205 @@ +from typing import Dict, Optional +from omegaconf import DictConfig + +import torch +import torch.nn as nn +from tracker.model.group_modules import GConv2d +from tracker.utils.tensor_utils import aggregate +from tracker.model.transformer.positional_encoding import PositionalEncoding +from tracker.model.transformer.transformer_layers import * + + +class QueryTransformerBlock(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + + this_cfg = model_cfg.object_transformer + self.embed_dim = this_cfg.embed_dim + self.num_heads = this_cfg.num_heads + self.num_queries = this_cfg.num_queries + self.ff_dim = this_cfg.ff_dim + + self.read_from_pixel = CrossAttention(self.embed_dim, + self.num_heads, + add_pe_to_qkv=this_cfg.read_from_pixel.add_pe_to_qkv) + self.self_attn = SelfAttention(self.embed_dim, + self.num_heads, + add_pe_to_qkv=this_cfg.query_self_attention.add_pe_to_qkv) + self.ffn = FFN(self.embed_dim, self.ff_dim) + self.read_from_query = CrossAttention(self.embed_dim, + self.num_heads, + add_pe_to_qkv=this_cfg.read_from_query.add_pe_to_qkv, + norm=this_cfg.read_from_query.output_norm) + self.pixel_ffn = PixelFFN(self.embed_dim) + + def forward( + self, + x: torch.Tensor, + pixel: torch.Tensor, + query_pe: torch.Tensor, + pixel_pe: torch.Tensor, + attn_mask: torch.Tensor, + need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor): + # x: (bs*num_objects)*num_queries*embed_dim + # pixel: bs*num_objects*C*H*W + # query_pe: (bs*num_objects)*num_queries*embed_dim + # pixel_pe: (bs*num_objects)*(H*W)*C + # attn_mask: (bs*num_objects*num_heads)*num_queries*(H*W) + + # bs*num_objects*C*H*W -> (bs*num_objects)*(H*W)*C + pixel_flat = pixel.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous() + x, q_weights = self.read_from_pixel(x, + pixel_flat, + query_pe, + pixel_pe, + attn_mask=attn_mask, + need_weights=need_weights) + x = self.self_attn(x, query_pe) + x = self.ffn(x) + + pixel_flat, p_weights = self.read_from_query(pixel_flat, + x, + pixel_pe, + query_pe, + need_weights=need_weights) + pixel = self.pixel_ffn(pixel, pixel_flat) + + if need_weights: + bs, num_objects, _, h, w = pixel.shape + q_weights = q_weights.view(bs, num_objects, self.num_heads, self.num_queries, h, w) + p_weights = p_weights.transpose(2, 3).view(bs, num_objects, self.num_heads, + self.num_queries, h, w) + + return x, pixel, q_weights, p_weights + + +class QueryTransformer(nn.Module): + def __init__(self, model_cfg: DictConfig): + super().__init__() + + this_cfg = model_cfg.object_transformer + self.value_dim = model_cfg.value_dim + self.embed_dim = this_cfg.embed_dim + self.num_heads = this_cfg.num_heads + self.num_queries = this_cfg.num_queries + + # query initialization and embedding + self.query_init = nn.Embedding(self.num_queries, self.embed_dim) + self.query_emb = nn.Embedding(self.num_queries, self.embed_dim) + + # projection from object summaries to query initialization and embedding + self.summary_to_query_init = nn.Linear(self.embed_dim, self.embed_dim) + self.summary_to_query_emb = nn.Linear(self.embed_dim, self.embed_dim) + + self.pixel_pe_scale = model_cfg.pixel_pe_scale + self.pixel_pe_temperature = model_cfg.pixel_pe_temperature + self.pixel_init_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1) + self.pixel_emb_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1) + self.spatial_pe = PositionalEncoding(self.embed_dim, + scale=self.pixel_pe_scale, + temperature=self.pixel_pe_temperature, + channel_last=False, + transpose_output=True) + + # transformer blocks + self.num_blocks = this_cfg.num_blocks + self.blocks = nn.ModuleList( + QueryTransformerBlock(model_cfg) for _ in range(self.num_blocks)) + self.mask_pred = nn.ModuleList( + nn.Sequential(nn.ReLU(), GConv2d(self.embed_dim, 1, kernel_size=1)) + for _ in range(self.num_blocks + 1)) + + self.act = nn.ReLU(inplace=True) + + def forward(self, + pixel: torch.Tensor, + obj_summaries: torch.Tensor, + selector: Optional[torch.Tensor] = None, + need_weights: bool = False) -> (torch.Tensor, Dict[str, torch.Tensor]): + # pixel: B*num_objects*embed_dim*H*W + # obj_summaries: B*num_objects*T*num_queries*embed_dim + T = obj_summaries.shape[2] + bs, num_objects, _, H, W = pixel.shape + + # normalize object values + # the last channel is the cumulative area of the object + obj_summaries = obj_summaries.view(bs * num_objects, T, self.num_queries, + self.embed_dim + 1) + # sum over time + # during inference, T=1 as we already did streaming average in memory_manager + obj_sums = obj_summaries[:, :, :, :-1].sum(dim=1) + obj_area = obj_summaries[:, :, :, -1:].sum(dim=1) + obj_values = obj_sums / (obj_area + 1e-4) + obj_init = self.summary_to_query_init(obj_values) + obj_emb = self.summary_to_query_emb(obj_values) + + # positional embeddings for object queries + query = self.query_init.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_init + query_emb = self.query_emb.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_emb + + # positional embeddings for pixel features + pixel_init = self.pixel_init_proj(pixel) + pixel_emb = self.pixel_emb_proj(pixel) + pixel_pe = self.spatial_pe(pixel.flatten(0, 1)) + pixel_emb = pixel_emb.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous() + pixel_pe = pixel_pe.flatten(1, 2) + pixel_emb + + pixel = pixel_init + + # run the transformer + aux_features = {'logits': []} + + # first aux output + aux_logits = self.mask_pred[0](pixel).squeeze(2) + attn_mask = self._get_aux_mask(aux_logits, selector) + aux_features['logits'].append(aux_logits) + for i in range(self.num_blocks): + query, pixel, q_weights, p_weights = self.blocks[i](query, + pixel, + query_emb, + pixel_pe, + attn_mask, + need_weights=need_weights) + + if self.training or i <= self.num_blocks - 1 or need_weights: + aux_logits = self.mask_pred[i + 1](pixel).squeeze(2) + attn_mask = self._get_aux_mask(aux_logits, selector) + aux_features['logits'].append(aux_logits) + + aux_features['q_weights'] = q_weights # last layer only + aux_features['p_weights'] = p_weights # last layer only + + if self.training: + # no need to save all heads + aux_features['attn_mask'] = attn_mask.view(bs, num_objects, self.num_heads, + self.num_queries, H, W)[:, :, 0] + + return pixel, aux_features + + def _get_aux_mask(self, logits: torch.Tensor, selector: torch.Tensor) -> torch.Tensor: + # logits: batch_size*num_objects*H*W + # selector: batch_size*num_objects*1*1 + # returns a mask of shape (batch_size*num_objects*num_heads)*num_queries*(H*W) + # where True means the attention is blocked + + if selector is None: + prob = logits.sigmoid() + else: + prob = logits.sigmoid() * selector + logits = aggregate(prob, dim=1) + + is_foreground = (logits[:, 1:] >= logits.max(dim=1, keepdim=True)[0]) + foreground_mask = is_foreground.bool().flatten(start_dim=2) + inv_foreground_mask = ~foreground_mask + inv_background_mask = foreground_mask + + aux_foreground_mask = inv_foreground_mask.unsqueeze(2).unsqueeze(2).repeat( + 1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2) + aux_background_mask = inv_background_mask.unsqueeze(2).unsqueeze(2).repeat( + 1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2) + + aux_mask = torch.cat([aux_foreground_mask, aux_background_mask], dim=1) + + aux_mask[torch.where(aux_mask.sum(-1) == aux_mask.shape[-1])] = False + + return aux_mask \ No newline at end of file diff --git a/web-demos/hugging_face/tracker/model/transformer/positional_encoding.py b/web-demos/hugging_face/tracker/model/transformer/positional_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..6c15bb73784d3e5fcb1a5d2f9713069e7a933f34 --- /dev/null +++ b/web-demos/hugging_face/tracker/model/transformer/positional_encoding.py @@ -0,0 +1,108 @@ +# Reference: +# https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/transformer_decoder/position_encoding.py +# https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py + +import math + +import numpy as np +import torch +from torch import nn + + +def get_emb(sin_inp: torch.Tensor) -> torch.Tensor: + """ + Gets a base embedding for one dimension with sin and cos intertwined + """ + emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1) + return torch.flatten(emb, -2, -1) + + +class PositionalEncoding(nn.Module): + def __init__(self, + dim: int, + scale: float = math.pi * 2, + temperature: float = 10000, + normalize: bool = True, + channel_last: bool = True, + transpose_output: bool = False): + super().__init__() + dim = int(np.ceil(dim / 4) * 2) + self.dim = dim + inv_freq = 1.0 / (temperature**(torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + self.normalize = normalize + self.scale = scale + self.eps = 1e-6 + self.channel_last = channel_last + self.transpose_output = transpose_output + + self.cached_penc = None # the cache is irrespective of the number of objects + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + """ + :param tensor: A 4/5d tensor of size + channel_last=True: (batch_size, h, w, c) or (batch_size, k, h, w, c) + channel_last=False: (batch_size, c, h, w) or (batch_size, k, c, h, w) + :return: positional encoding tensor that has the same shape as the input if the input is 4d + if the input is 5d, the output is broadcastable along the k-dimension + """ + if len(tensor.shape) != 4 and len(tensor.shape) != 5: + raise RuntimeError(f'The input tensor has to be 4/5d, got {tensor.shape}!') + + if len(tensor.shape) == 5: + # take a sample from the k dimension + num_objects = tensor.shape[1] + tensor = tensor[:, 0] + else: + num_objects = None + + if self.channel_last: + batch_size, h, w, c = tensor.shape + else: + batch_size, c, h, w = tensor.shape + + if self.cached_penc is not None and self.cached_penc.shape == tensor.shape: + if num_objects is None: + return self.cached_penc + else: + return self.cached_penc.unsqueeze(1) + + self.cached_penc = None + + pos_y = torch.arange(h, device=tensor.device, dtype=self.inv_freq.dtype) + pos_x = torch.arange(w, device=tensor.device, dtype=self.inv_freq.dtype) + if self.normalize: + pos_y = pos_y / (pos_y[-1] + self.eps) * self.scale + pos_x = pos_x / (pos_x[-1] + self.eps) * self.scale + + sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) + sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) + emb_y = get_emb(sin_inp_y).unsqueeze(1) + emb_x = get_emb(sin_inp_x) + + emb = torch.zeros((h, w, self.dim * 2), device=tensor.device, dtype=tensor.dtype) + emb[:, :, :self.dim] = emb_x + emb[:, :, self.dim:] = emb_y + + if not self.channel_last and self.transpose_output: + # cancelled out + pass + elif (not self.channel_last) or (self.transpose_output): + emb = emb.permute(2, 0, 1) + + self.cached_penc = emb.unsqueeze(0).repeat(batch_size, 1, 1, 1) + if num_objects is None: + return self.cached_penc + else: + return self.cached_penc.unsqueeze(1) + + +if __name__ == '__main__': + pe = PositionalEncoding(8).cuda() + input = torch.ones((1, 8, 8, 8)).cuda() + output = pe(input) + # print(output) + print(output[0, :, 0, 0]) + print(output[0, :, 0, 5]) + print(output[0, 0, :, 0]) + print(output[0, 0, 0, :]) diff --git a/web-demos/hugging_face/tracker/model/transformer/transformer_layers.py b/web-demos/hugging_face/tracker/model/transformer/transformer_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..80cf522dad092e8282f43e7f0f0dc05cfd15aa9b --- /dev/null +++ b/web-demos/hugging_face/tracker/model/transformer/transformer_layers.py @@ -0,0 +1,161 @@ +# Modified from PyTorch nn.Transformer + +from typing import List, Callable + +import torch +from torch import Tensor +import torch.nn as nn +import torch.nn.functional as F +from tracker.model.channel_attn import CAResBlock + + +class SelfAttention(nn.Module): + def __init__(self, + dim: int, + nhead: int, + dropout: float = 0.0, + batch_first: bool = True, + add_pe_to_qkv: List[bool] = [True, True, False]): + super().__init__() + self.self_attn = nn.MultiheadAttention(dim, nhead, dropout=dropout, batch_first=batch_first) + self.norm = nn.LayerNorm(dim) + self.dropout = nn.Dropout(dropout) + self.add_pe_to_qkv = add_pe_to_qkv + + def forward(self, + x: torch.Tensor, + pe: torch.Tensor, + attn_mask: bool = None, + key_padding_mask: bool = None) -> torch.Tensor: + x = self.norm(x) + if any(self.add_pe_to_qkv): + x_with_pe = x + pe + q = x_with_pe if self.add_pe_to_qkv[0] else x + k = x_with_pe if self.add_pe_to_qkv[1] else x + v = x_with_pe if self.add_pe_to_qkv[2] else x + else: + q = k = v = x + + r = x + x = self.self_attn(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0] + return r + self.dropout(x) + + +# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention +class CrossAttention(nn.Module): + def __init__(self, + dim: int, + nhead: int, + dropout: float = 0.0, + batch_first: bool = True, + add_pe_to_qkv: List[bool] = [True, True, False], + residual: bool = True, + norm: bool = True): + super().__init__() + self.cross_attn = nn.MultiheadAttention(dim, + nhead, + dropout=dropout, + batch_first=batch_first) + if norm: + self.norm = nn.LayerNorm(dim) + else: + self.norm = nn.Identity() + self.dropout = nn.Dropout(dropout) + self.add_pe_to_qkv = add_pe_to_qkv + self.residual = residual + + def forward(self, + x: torch.Tensor, + mem: torch.Tensor, + x_pe: torch.Tensor, + mem_pe: torch.Tensor, + attn_mask: bool = None, + *, + need_weights: bool = False) -> (torch.Tensor, torch.Tensor): + x = self.norm(x) + if self.add_pe_to_qkv[0]: + q = x + x_pe + else: + q = x + + if any(self.add_pe_to_qkv[1:]): + mem_with_pe = mem + mem_pe + k = mem_with_pe if self.add_pe_to_qkv[1] else mem + v = mem_with_pe if self.add_pe_to_qkv[2] else mem + else: + k = v = mem + r = x + x, weights = self.cross_attn(q, + k, + v, + attn_mask=attn_mask, + need_weights=need_weights, + average_attn_weights=False) + + if self.residual: + return r + self.dropout(x), weights + else: + return self.dropout(x), weights + + +class FFN(nn.Module): + def __init__(self, dim_in: int, dim_ff: int, activation=F.relu): + super().__init__() + self.linear1 = nn.Linear(dim_in, dim_ff) + self.linear2 = nn.Linear(dim_ff, dim_in) + self.norm = nn.LayerNorm(dim_in) + + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r = x + x = self.norm(x) + x = self.linear2(self.activation(self.linear1(x))) + x = r + x + return x + + +class PixelFFN(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + self.conv = CAResBlock(dim, dim) + + def forward(self, pixel: torch.Tensor, pixel_flat: torch.Tensor) -> torch.Tensor: + # pixel: batch_size * num_objects * dim * H * W + # pixel_flat: (batch_size*num_objects) * (H*W) * dim + bs, num_objects, _, h, w = pixel.shape + pixel_flat = pixel_flat.view(bs * num_objects, h, w, self.dim) + pixel_flat = pixel_flat.permute(0, 3, 1, 2).contiguous() + + x = self.conv(pixel_flat) + x = x.view(bs, num_objects, self.dim, h, w) + return x + + +class OutputFFN(nn.Module): + def __init__(self, dim_in: int, dim_out: int, activation=F.relu): + super().__init__() + self.linear1 = nn.Linear(dim_in, dim_out) + self.linear2 = nn.Linear(dim_out, dim_out) + + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear2(self.activation(self.linear1(x))) + return x + + +def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) diff --git a/web-demos/hugging_face/tracker/model/utils/__init__.py b/web-demos/hugging_face/tracker/model/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/web-demos/hugging_face/tracker/model/utils/memory_utils.py b/web-demos/hugging_face/tracker/model/utils/memory_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7b8726ce6bca51cbd76fff814db9cf03544397cd --- /dev/null +++ b/web-demos/hugging_face/tracker/model/utils/memory_utils.py @@ -0,0 +1,95 @@ +import math +import torch +from typing import Optional, Union, Tuple + + +# @torch.jit.script +def get_similarity(mk: torch.Tensor, + ms: torch.Tensor, + qk: torch.Tensor, + qe: torch.Tensor, + add_batch_dim: bool = False) -> torch.Tensor: + # used for training/inference and memory reading/memory potentiation + # mk: B x CK x [N] - Memory keys + # ms: B x 1 x [N] - Memory shrinkage + # qk: B x CK x [HW/P] - Query keys + # qe: B x CK x [HW/P] - Query selection + # Dimensions in [] are flattened + if add_batch_dim: + mk, ms = mk.unsqueeze(0), ms.unsqueeze(0) + qk, qe = qk.unsqueeze(0), qe.unsqueeze(0) + + CK = mk.shape[1] + mk = mk.flatten(start_dim=2) + ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None + qk = qk.flatten(start_dim=2) + qe = qe.flatten(start_dim=2) if qe is not None else None + + if qe is not None: + # See XMem's appendix for derivation + mk = mk.transpose(1, 2) + a_sq = (mk.pow(2) @ qe) + two_ab = 2 * (mk @ (qk * qe)) + b_sq = (qe * qk.pow(2)).sum(1, keepdim=True) + similarity = (-a_sq + two_ab - b_sq) + else: + # similar to STCN if we don't have the selection term + a_sq = mk.pow(2).sum(1).unsqueeze(2) + two_ab = 2 * (mk.transpose(1, 2) @ qk) + similarity = (-a_sq + two_ab) + + if ms is not None: + similarity = similarity * ms / math.sqrt(CK) # B*N*HW + else: + similarity = similarity / math.sqrt(CK) # B*N*HW + + return similarity + + +def do_softmax( + similarity: torch.Tensor, + top_k: Optional[int] = None, + inplace: bool = False, + return_usage: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + # normalize similarity with top-k softmax + # similarity: B x N x [HW/P] + # use inplace with care + if top_k is not None: + values, indices = torch.topk(similarity, k=top_k, dim=1) + + x_exp = values.exp_() + x_exp /= torch.sum(x_exp, dim=1, keepdim=True) + if inplace: + similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW + affinity = similarity + else: + affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp) # B*N*HW + else: + maxes = torch.max(similarity, dim=1, keepdim=True)[0] + x_exp = torch.exp(similarity - maxes) + x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True) + affinity = x_exp / x_exp_sum + indices = None + + if return_usage: + return affinity, affinity.sum(dim=2) + + return affinity + + +def get_affinity(mk: torch.Tensor, ms: torch.Tensor, qk: torch.Tensor, + qe: torch.Tensor) -> torch.Tensor: + # shorthand used in training with no top-k + similarity = get_similarity(mk, ms, qk, qe) + affinity = do_softmax(similarity) + return affinity + + +def readout(affinity: torch.Tensor, mv: torch.Tensor) -> torch.Tensor: + B, CV, T, H, W = mv.shape + + mo = mv.view(B, CV, T * H * W) + mem = torch.bmm(mo, affinity) + mem = mem.view(B, CV, H, W) + + return mem diff --git a/web-demos/hugging_face/tracker/model/utils/parameter_groups.py b/web-demos/hugging_face/tracker/model/utils/parameter_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..177866af48de5e6d8795bdf6734b0dccb5a1947b --- /dev/null +++ b/web-demos/hugging_face/tracker/model/utils/parameter_groups.py @@ -0,0 +1,72 @@ +import logging + +log = logging.getLogger() + + +def get_parameter_groups(model, stage_cfg, print_log=False): + """ + Assign different weight decays and learning rates to different parameters. + Returns a parameter group which can be passed to the optimizer. + """ + weight_decay = stage_cfg.weight_decay + embed_weight_decay = stage_cfg.embed_weight_decay + backbone_lr_ratio = stage_cfg.backbone_lr_ratio + base_lr = stage_cfg.learning_rate + + backbone_params = [] + embed_params = [] + other_params = [] + + embedding_names = ['summary_pos', 'query_init', 'query_emb', 'obj_pe'] + embedding_names = [e + '.weight' for e in embedding_names] + + # inspired by detectron2 + memo = set() + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + # Avoid duplicating parameters + if param in memo: + continue + memo.add(param) + + if name.startswith('module'): + name = name[7:] + + inserted = False + if name.startswith('pixel_encoder.'): + backbone_params.append(param) + inserted = True + if print_log: + log.info(f'{name} counted as a backbone parameter.') + else: + for e in embedding_names: + if name.endswith(e): + embed_params.append(param) + inserted = True + if print_log: + log.info(f'{name} counted as an embedding parameter.') + break + + if not inserted: + other_params.append(param) + + parameter_groups = [ + { + 'params': backbone_params, + 'lr': base_lr * backbone_lr_ratio, + 'weight_decay': weight_decay + }, + { + 'params': embed_params, + 'lr': base_lr, + 'weight_decay': embed_weight_decay + }, + { + 'params': other_params, + 'lr': base_lr, + 'weight_decay': weight_decay + }, + ] + + return parameter_groups \ No newline at end of file diff --git a/web-demos/hugging_face/tracker/model/utils/resnet.py b/web-demos/hugging_face/tracker/model/utils/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..3a07d2fd12da0b951ed6c724f97aa2f203877e7e --- /dev/null +++ b/web-demos/hugging_face/tracker/model/utils/resnet.py @@ -0,0 +1,179 @@ +""" +resnet.py - A modified ResNet structure +We append extra channels to the first conv by some network surgery +""" + +from collections import OrderedDict +import math + +import torch +import torch.nn as nn +from torch.utils import model_zoo + + +def load_weights_add_extra_dim(target, source_state, extra_dim=1): + new_dict = OrderedDict() + + for k1, v1 in target.state_dict().items(): + if not 'num_batches_tracked' in k1: + if k1 in source_state: + tar_v = source_state[k1] + + if v1.shape != tar_v.shape: + # Init the new segmentation channel with zeros + # print(v1.shape, tar_v.shape) + c, _, w, h = v1.shape + pads = torch.zeros((c, extra_dim, w, h), device=tar_v.device) + nn.init.orthogonal_(pads) + tar_v = torch.cat([tar_v, pads], 1) + + new_dict[k1] = tar_v + + target.load_state_dict(new_dict) + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, dilation=1): + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=stride, + dilation=dilation, + padding=dilation, + bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(3 + extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1, dilation=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [block(self.inplanes, planes, stride, downsample)] + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, dilation=dilation)) + + return nn.Sequential(*layers) + + +def resnet18(pretrained=True, extra_dim=0): + model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim) + if pretrained: + load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet18']), extra_dim) + return model + + +def resnet50(pretrained=True, extra_dim=0): + model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim) + if pretrained: + load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet50']), extra_dim) + return model diff --git a/web-demos/hugging_face/tracker/utils/__init__.py b/web-demos/hugging_face/tracker/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/web-demos/hugging_face/tracker/utils/image_saver.py b/web-demos/hugging_face/tracker/utils/image_saver.py new file mode 100644 index 0000000000000000000000000000000000000000..c3edfa96e60fea0e5ec8fd087da85d2efaa6444c --- /dev/null +++ b/web-demos/hugging_face/tracker/utils/image_saver.py @@ -0,0 +1,230 @@ +import cv2 +import numpy as np + +import torch +from collections import defaultdict + + +def tensor_to_numpy(image): + image_np = (image.numpy() * 255).astype('uint8') + return image_np + + +def tensor_to_np_float(image): + image_np = image.numpy().astype('float32') + return image_np + + +def detach_to_cpu(x): + return x.detach().cpu() + + +def transpose_np(x): + return np.transpose(x, [1, 2, 0]) + + +def tensor_to_gray_im(x): + x = detach_to_cpu(x) + x = tensor_to_numpy(x) + x = transpose_np(x) + return x + + +def tensor_to_im(x): + x = detach_to_cpu(x).clamp(0, 1) + x = tensor_to_numpy(x) + x = transpose_np(x) + return x + + +# Predefined key <-> caption dict +key_captions = { + 'im': 'Image', + 'gt': 'GT', +} +""" +Return an image array with captions +keys in dictionary will be used as caption if not provided +values should contain lists of cv2 images +""" + + +def get_image_array(images, grid_shape, captions={}): + h, w = grid_shape + cate_counts = len(images) + rows_counts = len(next(iter(images.values()))) + + font = cv2.FONT_HERSHEY_SIMPLEX + + output_image = np.zeros([w * cate_counts, h * (rows_counts + 1), 3], dtype=np.uint8) + col_cnt = 0 + for k, v in images.items(): + + # Default as key value itself + caption = captions.get(k, k) + + # Handles new line character + dy = 40 + for i, line in enumerate(caption.split('\n')): + cv2.putText(output_image, line, (10, col_cnt * w + 100 + i * dy), font, 0.8, + (255, 255, 255), 2, cv2.LINE_AA) + + # Put images + for row_cnt, img in enumerate(v): + im_shape = img.shape + if len(im_shape) == 2: + img = img[..., np.newaxis] + + img = (img * 255).astype('uint8') + + output_image[(col_cnt + 0) * w:(col_cnt + 1) * w, + (row_cnt + 1) * h:(row_cnt + 2) * h, :] = img + + col_cnt += 1 + + return output_image + + +def base_transform(im, size): + im = tensor_to_np_float(im) + if len(im.shape) == 3: + im = im.transpose((1, 2, 0)) + else: + im = im[:, :, None] + + # Resize + if im.shape[1] != size: + im = cv2.resize(im, size, interpolation=cv2.INTER_NEAREST) + + return im.clip(0, 1) + + +def im_transform(im, size): + return base_transform(detach_to_cpu(im), size=size) + + +def mask_transform(mask, size): + return base_transform(detach_to_cpu(mask), size=size) + + +def logits_transform(mask, size): + return base_transform(detach_to_cpu(torch.sigmoid(mask)), size=size) + + +def add_attention(mask, pos): + mask = mask[:, :, None].repeat(3, axis=2) + pos = (pos + 1) / 2 + for i in range(pos.shape[0]): + y = int(pos[i][0] * mask.shape[0]) + x = int(pos[i][1] * mask.shape[1]) + y = max(min(y, mask.shape[0] - 1), 0) + x = max(min(x, mask.shape[1] - 1), 0) + # mask[y, x, :] = (255, 0, 0) + cv2.circle(mask, (x, y), 5, (1, 0, 0), -1) + return mask + + +def vis(images, size, num_objects): + req_images = defaultdict(list) + + b, t = images['rgb'].shape[:2] + + # limit the number of images saved + b = min(2, b) + + # find max num objects + max_num_objects = max(num_objects[:b]) + + GT_suffix = '' + for bi in range(b): + GT_suffix += ' \n%s' % images['info']['name'][bi][-25:-4] + + for bi in range(b): + for ti in range(t): + req_images['RGB'].append(im_transform(images['rgb'][bi, ti], size)) + aux = images[f'aux_{max(ti, 1)}'] # no aux_0, use aux_1 for shape + if 'sensory_logits' in aux: + sensory_aux = aux['sensory_logits'][bi].softmax(dim=0) + # batch_size * num_objects * num_levels * H * W + q_mask_aux = aux['q_logits'][bi].softmax(dim=0) + num_levels = q_mask_aux.shape[1] + + for oi in range(max_num_objects): + if ti == 0 or oi >= num_objects[bi]: + req_images[f'Mask_{oi}'].append( + mask_transform(images['first_frame_gt'][bi][0, oi], size)) + req_images[f'S-Aux_{oi}'].append( + mask_transform(images['first_frame_gt'][bi][0, oi], size)) + for l in range(num_levels): + req_images[f'Q-Aux-L{l}_{oi}'].append( + mask_transform(images['first_frame_gt'][bi][0, oi], size)) + else: + mask = mask_transform(images[f'masks_{ti}'][bi][oi], size) + req_images[f'Mask_{oi}'].append(mask) + if 'sensory_logits' in aux: + req_images[f'S-Aux_{oi}'].append(mask_transform(sensory_aux[oi + 1], size)) + + for l in range(num_levels): + mask = mask_transform(q_mask_aux[oi + 1, l], size) + req_images[f'Q-Aux-L{l}_{oi}'].append(mask) + + req_images[f'GT_{oi}_{GT_suffix}'].append( + mask_transform(images['cls_gt'][bi, ti, 0] == (oi + 1), size)) + + return get_image_array(req_images, size, key_captions) + + +def vis_debug(images, size, num_objects): + req_images = defaultdict(list) + + b, t = images['rgb'].shape[:2] + + # limit the number of images saved + b = min(2, b) + + # find max num objects + max_num_objects = max(num_objects[:b]) + + GT_suffix = '' + for bi in range(b): + GT_suffix += ' \n%s' % images['info']['name'][bi][-25:-4] + + for bi in range(b): + for ti in range(t): + req_images['RGB'].append(im_transform(images['rgb'][bi, ti], size)) + aux = images[f'aux_{max(ti, 1)}'] # no aux_0, use aux_1 for shape + sensory_aux = aux['sensory_logits'][bi].softmax(dim=0) + # batch_size * num_objects * num_levels * H * W + q_mask_aux = aux['q_logits'][bi].softmax(dim=0) + attn_mask = aux['attn_mask'][bi] + num_levels = q_mask_aux.shape[1] + num_queries = attn_mask.shape[1] + + for oi in range(max_num_objects): + if ti == 0 or oi >= num_objects[bi]: + req_images[f'Mask_{oi}'].append( + mask_transform(images['first_frame_gt'][bi][0, oi], size)) + req_images[f'S-Aux_{oi}'].append( + mask_transform(images['first_frame_gt'][bi][0, oi], size)) + for l in range(num_levels): + req_images[f'Q-Aux-L{l}_{oi}'].append( + mask_transform(images['first_frame_gt'][bi][0, oi], size)) + for q in range(num_queries): + req_images[f'Attn-Mask-Q{q}_{oi}'].append( + mask_transform(images['first_frame_gt'][bi][0, oi], size)) + else: + mask = mask_transform(images[f'masks_{ti}'][bi][oi], size) + req_images[f'Mask_{oi}'].append(mask) + req_images[f'S-Aux_{oi}'].append(mask_transform(sensory_aux[oi + 1], size)) + + for l in range(num_levels): + mask = mask_transform(q_mask_aux[oi + 1, l], size) + req_images[f'Q-Aux-L{l}_{oi}'].append(mask) + for q in range(num_queries): + mask = mask_transform(1 - attn_mask[oi, q].float(), size) + req_images[f'Attn-Mask-Q{q}_{oi}'].append(mask) + + req_images[f'GT_{oi}_{GT_suffix}'].append( + mask_transform(images['cls_gt'][bi, ti, 0] == (oi + 1), size)) + + return get_image_array(req_images, size, key_captions) \ No newline at end of file diff --git a/web-demos/hugging_face/tracker/utils/load_subset.py b/web-demos/hugging_face/tracker/utils/load_subset.py new file mode 100644 index 0000000000000000000000000000000000000000..c16ed0391ae745a736290bb7b956c98539e087ca --- /dev/null +++ b/web-demos/hugging_face/tracker/utils/load_subset.py @@ -0,0 +1,13 @@ +import json + + +def load_subset(path): + with open(path, mode='r') as f: + subset = set(f.read().splitlines()) + return subset + + +def load_empty_masks(path): + with open(path, mode='r') as f: + empty_masks = json.load(f) + return empty_masks diff --git a/web-demos/hugging_face/tracker/utils/log_integrator.py b/web-demos/hugging_face/tracker/utils/log_integrator.py new file mode 100644 index 0000000000000000000000000000000000000000..f2e12ebb393cfd6b734859fad3c243850ba60ea7 --- /dev/null +++ b/web-demos/hugging_face/tracker/utils/log_integrator.py @@ -0,0 +1,84 @@ +""" +Integrate numerical values for some iterations +Typically used for loss computation / logging to tensorboard +Call finalize and create a new Integrator when you want to display/log +""" +from typing import Dict, Callable, Tuple +import torch +from tracker.utils.logger import TensorboardLogger + + +class Integrator: + def __init__(self, logger: TensorboardLogger, distributed: bool = True): + self.values = {} + self.counts = {} + self.hooks = [] # List is used here to maintain insertion order + + self.logger = logger + + self.distributed = distributed + self.local_rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + + def add_tensor(self, key: str, tensor: torch.Tensor): + if key not in self.values: + self.counts[key] = 1 + if type(tensor) == float or type(tensor) == int: + self.values[key] = tensor + else: + self.values[key] = tensor.mean().item() + else: + self.counts[key] += 1 + if type(tensor) == float or type(tensor) == int: + self.values[key] += tensor + else: + self.values[key] += tensor.mean().item() + + def add_dict(self, tensor_dict: Dict[str, torch.Tensor]): + for k, v in tensor_dict.items(): + self.add_tensor(k, v) + + def add_hook(self, hook: Callable[[torch.Tensor], Tuple[str, torch.Tensor]]): + """ + Adds a custom hook, i.e. compute new metrics using values in the dict + The hook takes the dict as argument, and returns a (k, v) tuple + e.g. for computing IoU + """ + if type(hook) == list: + self.hooks.extend(hook) + else: + self.hooks.append(hook) + + def reset_except_hooks(self): + self.values = {} + self.counts = {} + + # Average and output the metrics + def finalize(self, exp_id: str, prefix: str, it: int) -> None: + + for hook in self.hooks: + k, v = hook(self.values) + self.add_tensor(k, v) + + outputs = {} + for k, v in self.values.items(): + + if k[:4] == 'hide': + continue + + avg = v / self.counts[k] + + if self.distributed: + # Inplace operation + avg = torch.tensor(avg).cuda() + torch.distributed.reduce(avg, dst=0) + + if self.local_rank == 0: + avg = (avg / self.world_size).cpu().item() + outputs[k] = avg + else: + # Simple does it + outputs[k] = avg + + if (not self.distributed) or (self.local_rank == 0): + self.logger.log_metrics(exp_id, prefix, outputs, it) diff --git a/web-demos/hugging_face/tracker/utils/logger.py b/web-demos/hugging_face/tracker/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..399f76d685d2f0fa8a68b69222c1400815a0a2e3 --- /dev/null +++ b/web-demos/hugging_face/tracker/utils/logger.py @@ -0,0 +1,107 @@ +""" +Dumps things to tensorboard and console +""" + +import os +import logging +import datetime +from typing import Dict +import numpy as np +from PIL import Image + +from torch.utils.tensorboard import SummaryWriter +from tracker.utils.time_estimator import TimeEstimator + + +def tensor_to_numpy(image): + image_np = (image.numpy() * 255).astype('uint8') + return image_np + + +def detach_to_cpu(x): + return x.detach().cpu() + + +def fix_width_trunc(x): + return ('{:.9s}'.format('{:0.9f}'.format(x))) + + +class TensorboardLogger: + def __init__(self, run_dir, py_logger: logging.Logger, *, enabled_tb): + self.run_dir = run_dir + self.py_log = py_logger + if enabled_tb: + self.tb_log = SummaryWriter(run_dir) + else: + self.tb_log = None + + # Get current git info for logging + try: + import git + repo = git.Repo(".") + git_info = str(repo.active_branch) + ' ' + str(repo.head.commit.hexsha) + except (ImportError, RuntimeError): + print('Failed to fetch git info. Defaulting to None') + git_info = 'None' + + self.log_string('git', git_info) + + # used when logging metrics + self.time_estimator: TimeEstimator = None + + def log_scalar(self, tag, x, it): + if self.tb_log is None: + return + self.tb_log.add_scalar(tag, x, it) + + def log_metrics(self, exp_id, prefix, metrics: Dict, it): + msg = f'{exp_id}-{prefix} - it {it:6d}: ' + metrics_msg = '' + for k, v in sorted(metrics.items()): + self.log_scalar(f'{prefix}/{k}', v, it) + metrics_msg += f'{k: >10}:{v:.7f},\t' + + if self.time_estimator is not None: + self.time_estimator.update() + avg_time = self.time_estimator.get_and_reset_avg_time() + est = self.time_estimator.get_est_remaining(it) + est = datetime.timedelta(seconds=est) + if est.days > 0: + remaining_str = f'{est.days}d {est.seconds // 3600}h' + else: + remaining_str = f'{est.seconds // 3600}h {(est.seconds%3600) // 60}m' + eta = datetime.datetime.now() + est + eta_str = eta.strftime('%Y-%m-%d %H:%M:%S') + time_msg = f'avg_time:{avg_time:.3f},remaining:{remaining_str},eta:{eta_str},\t' + msg = f'{msg} {time_msg}' + + msg = f'{msg} {metrics_msg}' + self.py_log.info(msg) + + def log_image(self, stage_name, tag, image, it): + image_dir = os.path.join(self.run_dir, f'{stage_name}_images') + os.makedirs(image_dir, exist_ok=True) + + image = Image.fromarray(image) + image.save(os.path.join(image_dir, f'{tag}_{it}.png')) + + def log_string(self, tag, x): + self.py_log.info(f'{tag} - {x}') + if self.tb_log is None: + return + self.tb_log.add_text(tag, x) + + def debug(self, x): + self.py_log.debug(x) + + def info(self, x): + self.py_log.info(x) + + def warning(self, x): + self.py_log.warning(x) + + def error(self, x): + self.py_log.error(x) + + def critical(self, x): + self.py_log.critical(x) diff --git a/web-demos/hugging_face/tracker/utils/mask_mapper.py b/web-demos/hugging_face/tracker/utils/mask_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..617af0c712d20f429a05274438b77a5afc88d2db --- /dev/null +++ b/web-demos/hugging_face/tracker/utils/mask_mapper.py @@ -0,0 +1,78 @@ +import numpy as np +import torch + +def all_to_onehot(masks, labels): + if len(masks.shape) == 3: + Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8) + else: + Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8) + + for ni, l in enumerate(labels): + Ms[ni] = (masks == l).astype(np.uint8) + + return Ms + +class MaskMapper: + """ + This class is used to convert a indexed-mask to a one-hot representation. + It also takes care of remapping non-continuous indices + It has two modes: + 1. Default. Only masks with new indices are supposed to go into the remapper. + This is also the case for YouTubeVOS. + i.e., regions with index 0 are not "background", but "don't care". + + 2. Exhaustive. Regions with index 0 are considered "background". + Every single pixel is considered to be "labeled". + """ + def __init__(self): + self.labels = [] + self.remappings = {} + + # if coherent, no mapping is required + self.coherent = True + + def clear_labels(self): + self.labels = [] + self.remappings = {} + # if coherent, no mapping is required + self.coherent = True + + def convert_mask(self, mask, exhaustive=False): + # mask is in index representation, H*W numpy array + labels = np.unique(mask).astype(np.uint8) + labels = labels[labels!=0].tolist() + + new_labels = list(set(labels) - set(self.labels)) + if not exhaustive: + assert len(new_labels) == len(labels), 'Old labels found in non-exhaustive mode' + + # add new remappings + for i, l in enumerate(new_labels): + self.remappings[l] = i+len(self.labels)+1 + if self.coherent and i+len(self.labels)+1 != l: + self.coherent = False + + if exhaustive: + new_mapped_labels = range(1, len(self.labels)+len(new_labels)+1) + else: + if self.coherent: + new_mapped_labels = new_labels + else: + new_mapped_labels = range(len(self.labels)+1, len(self.labels)+len(new_labels)+1) + + self.labels.extend(new_labels) + # mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float() + mask = torch.from_numpy(mask).float() + # mask num_objects*H*W + return mask, new_mapped_labels + + + def remap_index_mask(self, mask): + # mask is in index representation, H*W numpy array + if self.coherent: + return mask + + new_mask = np.zeros_like(mask) + for l, i in self.remappings.items(): + new_mask[mask==i] = l + return new_mask \ No newline at end of file diff --git a/web-demos/hugging_face/tracker/utils/palette.py b/web-demos/hugging_face/tracker/utils/palette.py new file mode 100644 index 0000000000000000000000000000000000000000..26a773c88bdd15fdb372fa9f552602a751625fc4 --- /dev/null +++ b/web-demos/hugging_face/tracker/utils/palette.py @@ -0,0 +1,9 @@ +import numpy as np + +davis_palette = b'\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0' + +youtube_palette = b'\x00\x00\x00\xec_g\xf9\x91W\xfa\xc8c\x99\xc7\x94b\xb3\xb2f\x99\xcc\xc5\x94\xc5\xabyg\xff\xff\xffes~\x0b\x0b\x0b\x0c\x0c\x0c\r\r\r\x0e\x0e\x0e\x0f\x0f\x0f' + +davis_palette_np = np.frombuffer(davis_palette, dtype=np.uint8).reshape(-1, 3) + +youtube_palette_np = np.frombuffer(youtube_palette, dtype=np.uint8).reshape(-1, 3) \ No newline at end of file diff --git a/web-demos/hugging_face/tracker/utils/pano_utils.py b/web-demos/hugging_face/tracker/utils/pano_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5d0e2fe9448ec977f0f177c2b5dc9aaccf38250d --- /dev/null +++ b/web-demos/hugging_face/tracker/utils/pano_utils.py @@ -0,0 +1,30 @@ +import numpy as np +from threading import Lock + + +class ID2RGBConverter: + def __init__(self): + self.all_id = [] + self.obj_to_id = {} + self.lock = Lock() + + def _id_to_rgb(self, id: int): + rgb = np.zeros((3, ), dtype=np.uint8) + for i in range(3): + rgb[i] = id % 256 + id = id // 256 + return rgb + + def convert(self, obj: int): + with self.lock: + if obj in self.obj_to_id: + id = self.obj_to_id[obj] + else: + while True: + id = np.random.randint(255, 256**3) + if id not in self.all_id: + break + self.obj_to_id[obj] = id + self.all_id.append(id) + + return id, self._id_to_rgb(id) diff --git a/web-demos/hugging_face/tracker/utils/point_features.py b/web-demos/hugging_face/tracker/utils/point_features.py new file mode 100644 index 0000000000000000000000000000000000000000..87b794ef23c856bba022215c84581ba38e6d030b --- /dev/null +++ b/web-demos/hugging_face/tracker/utils/point_features.py @@ -0,0 +1,111 @@ +# This file is copied from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py +# such that users do not need to install detectron2 just for these two functions +# Copyright (c) Facebook, Inc. and its affiliates. + +from typing import List +import torch +from torch.nn import functional as F + + +def cat(tensors: List[torch.Tensor], dim: int = 0): + """ + Efficient version of torch.cat that avoids a copy if there is only a single element in a list + """ + assert isinstance(tensors, (list, tuple)) + if len(tensors) == 1: + return tensors[0] + return torch.cat(tensors, dim) + + +def calculate_uncertainty(sem_seg_logits): + """ + For each location of the prediction `sem_seg_logits` we estimate uncerainty as the + difference between top first and top second predicted logits. + Args: + mask_logits (Tensor): A tensor of shape (N, C, ...), where N is the minibatch size and + C is the number of foreground classes. The values are logits. + Returns: + scores (Tensor): A tensor of shape (N, 1, ...) that contains uncertainty scores with + the most uncertain locations having the highest uncertainty score. + """ + if sem_seg_logits.shape[1] == 2: + # binary segmentation + return -(torch.abs(sem_seg_logits[:, 1:2])) + top2_scores = torch.topk(sem_seg_logits, k=2, dim=1)[0] + return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1) + + +def point_sample(input, point_coords, **kwargs): + """ + A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. + Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside + [0, 1] x [0, 1] square. + Args: + input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid. + point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains + [0, 1] x [0, 1] normalized point coordinates. + Returns: + output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains + features for points in `point_coords`. The features are obtained via bilinear + interpolation from `input` the same way as :function:`torch.nn.functional.grid_sample`. + """ + add_dim = False + if point_coords.dim() == 3: + add_dim = True + point_coords = point_coords.unsqueeze(2) + output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs) + if add_dim: + output = output.squeeze(3) + return output + + +def get_uncertain_point_coords_with_randomness(coarse_logits, uncertainty_func, num_points, + oversample_ratio, importance_sample_ratio): + """ + Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The uncertainties + are calculated for each point using 'uncertainty_func' function that takes point's logit + prediction as input. + See PointRend paper for details. + Args: + coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for + class-specific or class-agnostic prediction. + uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that + contains logit predictions for P points and returns their uncertainties as a Tensor of + shape (N, 1, P). + num_points (int): The number of points P to sample. + oversample_ratio (int): Oversampling parameter. + importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling. + Returns: + point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P + sampled points. + """ + assert oversample_ratio >= 1 + assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0 + num_boxes = coarse_logits.shape[0] + num_sampled = int(num_points * oversample_ratio) + point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device) + point_logits = point_sample(coarse_logits, point_coords, align_corners=False) + # It is crucial to calculate uncertainty based on the sampled prediction value for the points. + # Calculating uncertainties of the coarse predictions first and sampling them for points leads + # to incorrect results. + # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between + # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value. + # However, if we calculate uncertainties for the coarse predictions first, + # both will have -1 uncertainty, and the sampled point will get -1 uncertainty. + point_uncertainties = uncertainty_func(point_logits) + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device) + idx += shift[:, None] + point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, + 2) + if num_random_points > 0: + point_coords = cat( + [ + point_coords, + torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device), + ], + dim=1, + ) + return point_coords \ No newline at end of file diff --git a/web-demos/hugging_face/tracker/utils/range_transform.py b/web-demos/hugging_face/tracker/utils/range_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..ae1b0b3b2a01a061b9b2220a93cdf7f7a6357bfb --- /dev/null +++ b/web-demos/hugging_face/tracker/utils/range_transform.py @@ -0,0 +1,12 @@ +import torchvision.transforms as transforms + +im_mean = (124, 116, 104) + +im_normalization = transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ) + +inv_im_trans = transforms.Normalize( + mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], + std=[1/0.229, 1/0.224, 1/0.225]) diff --git a/web-demos/hugging_face/tracker/utils/tensor_utils.py b/web-demos/hugging_face/tracker/utils/tensor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c8e33c8936134ec4efdcc412945469853edc6498 --- /dev/null +++ b/web-demos/hugging_face/tracker/utils/tensor_utils.py @@ -0,0 +1,62 @@ +from typing import List, Iterable +import torch +import torch.nn.functional as F + + +# STM +def pad_divide_by(in_img: torch.Tensor, d: int) -> (torch.Tensor, Iterable[int]): + h, w = in_img.shape[-2:] + + if h % d > 0: + new_h = h + d - h % d + else: + new_h = h + if w % d > 0: + new_w = w + d - w % d + else: + new_w = w + lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2) + lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2) + pad_array = (int(lw), int(uw), int(lh), int(uh)) + out = F.pad(in_img, pad_array) + return out, pad_array + + +def unpad(img: torch.Tensor, pad: Iterable[int]) -> torch.Tensor: + if len(img.shape) == 4: + if pad[2] + pad[3] > 0: + img = img[:, :, pad[2]:-pad[3], :] + if pad[0] + pad[1] > 0: + img = img[:, :, :, pad[0]:-pad[1]] + elif len(img.shape) == 3: + if pad[2] + pad[3] > 0: + img = img[:, pad[2]:-pad[3], :] + if pad[0] + pad[1] > 0: + img = img[:, :, pad[0]:-pad[1]] + elif len(img.shape) == 5: + if pad[2] + pad[3] > 0: + img = img[:, :, :, pad[2]:-pad[3], :] + if pad[0] + pad[1] > 0: + img = img[:, :, :, :, pad[0]:-pad[1]] + else: + raise NotImplementedError + return img + + +# @torch.jit.script +def aggregate(prob: torch.Tensor, dim: int) -> torch.Tensor: + with torch.cuda.amp.autocast(enabled=False): + prob = prob.float() + new_prob = torch.cat([torch.prod(1 - prob, dim=dim, keepdim=True), prob], + dim).clamp(1e-7, 1 - 1e-7) + logits = torch.log((new_prob / (1 - new_prob))) + + return logits + + +# @torch.jit.script +def cls_to_one_hot(cls_gt: torch.Tensor, num_objects: int) -> torch.Tensor: + # cls_gt: B*1*H*W + B, _, H, W = cls_gt.shape + one_hot = torch.zeros(B, num_objects + 1, H, W, device=cls_gt.device).scatter_(1, cls_gt, 1) + return one_hot \ No newline at end of file diff --git a/web-demos/hugging_face/tracker/utils/time_estimator.py b/web-demos/hugging_face/tracker/utils/time_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..8d429b0404b641472ff84701305c570cc26280b7 --- /dev/null +++ b/web-demos/hugging_face/tracker/utils/time_estimator.py @@ -0,0 +1,43 @@ +import time + + +class TimeEstimator: + def __init__(self, total_iter, step_size): + self.avg_time_window = [] # window-based average + self.exp_avg_time = None # exponential moving average + self.alpha = 0.7 # for exponential moving average + + self.last_time = time.time() # would not be accurate for the first iteration but well + self.total_iter = total_iter + self.step_size = step_size + + self.buffering_exp = True + + # call this at a fixed interval + # does not have to be every step + def update(self): + curr_time = time.time() + time_per_iter = curr_time - self.last_time + self.last_time = curr_time + + self.avg_time_window.append(time_per_iter) + + if self.buffering_exp: + if self.exp_avg_time is not None: + # discard the first iteration call to not pollute the ema + self.buffering_exp = False + self.exp_avg_time = time_per_iter + else: + self.exp_avg_time = self.alpha * self.exp_avg_time + (1 - self.alpha) * time_per_iter + + def get_est_remaining(self, it): + if self.exp_avg_time is None: + return 0 + + remaining_iter = self.total_iter - it + return remaining_iter * self.exp_avg_time / self.step_size + + def get_and_reset_avg_time(self): + avg = sum(self.avg_time_window) / len(self.avg_time_window) / self.step_size + self.avg_time_window = [] + return avg diff --git a/weights/README.md b/weights/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5380157a9e3fe0f4a05f93006d7edcd83c0df8bb --- /dev/null +++ b/weights/README.md @@ -0,0 +1,13 @@ +# Weights + +Put the downloaded pre-trained models to this folder. + +The directory structure will be arranged as: +``` +weights + |- raft-things.pth + |- recurrent_flow_completion.pth + |- ProPainter.pth + |- i3d_rgb_imagenet.pt (for evaluating VFID metric) + |- README.md +``` \ No newline at end of file