# NPRC24/SCBC/CPNet_model.py
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F

# Local helper modules; the path entries must be inserted before the imports
# that depend on them.
sys.path.insert(0, '.')
sys.path.insert(0, '../utils/')

import networks as N
from model_module import *  # provides Conv2d, init_He, pad_divide_by, ...
class LiteISPNet(nn.Module):
    def __init__(self):
super(LiteISPNet, self).__init__()
ch_1 = 64
ch_2 = 128
ch_3 = 128
n_blocks = 4
self.head = N.seq(
N.conv(3, ch_1, mode='C')
        )  # shape: (N, ch_1, H, W)
self.down1 = N.seq(
N.conv(ch_1, ch_1, mode='C'),
N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
N.conv(ch_1, ch_1, mode='C'),
N.DWTForward(ch_1)
        )  # shape: (N, ch_1*4, H/2, W/2)
self.down2 = N.seq(
N.conv(ch_1*4, ch_1, mode='C'),
N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
N.DWTForward(ch_1)
        )  # shape: (N, ch_1*4, H/4, W/4)
self.down3 = N.seq(
N.conv(ch_1*4, ch_2, mode='C'),
N.RCAGroup(in_channels=ch_2, out_channels=ch_2, nb=n_blocks),
N.DWTForward(ch_2)
        )  # shape: (N, ch_2*4, H/8, W/8)
self.middle = N.seq(
N.conv(ch_2*4, ch_3, mode='C'),
N.RCAGroup(in_channels=ch_3, out_channels=ch_3, nb=n_blocks),
N.RCAGroup(in_channels=ch_3, out_channels=ch_3, nb=n_blocks),
N.conv(ch_3, ch_2*4, mode='C')
        )  # shape: (N, ch_2*4, H/8, W/8)
self.up3 = N.seq(
N.DWTInverse(ch_2*4),
N.RCAGroup(in_channels=ch_2, out_channels=ch_2, nb=n_blocks),
N.conv(ch_2, ch_1*4, mode='C')
        )  # shape: (N, ch_1*4, H/4, W/4)
self.up2 = N.seq(
N.DWTInverse(ch_1*4),
N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
N.conv(ch_1, ch_1*4, mode='C')
        )  # shape: (N, ch_1*4, H/2, W/2)
self.up1 = N.seq(
N.DWTInverse(ch_1*4),
N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
N.conv(ch_1, ch_1, mode='C')
        )  # shape: (N, ch_1, H, W)
self.tail = N.seq(
#N.conv(ch_1, ch_1*4, mode='C'),
#nn.PixelShuffle(upscale_factor=2),
N.conv(ch_1, 3, mode='C')
) # shape: (N, 3, H, W)
    def forward(self, raw):
        x = torch.pow(raw, 1 / 2.2)  # gamma-encode the linear raw input
        h = self.head(x)
        d1 = self.down1(h)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        m = self.middle(d3) + d3
        u3 = self.up3(m) + d2
        u2 = self.up2(u3) + d1
        u1 = self.up1(u2) + h
        out = self.tail(u1)
        return out
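
# A minimal shape check for LiteISPNet (an illustrative sketch, not part of
# the original file; it assumes networks.py provides the blocks used above).
# Each DWTForward halves the resolution and each DWTInverse restores it, so
# the network is resolution-preserving; H and W should be divisible by 8 so
# the three wavelet stages divide evenly.
def _liteispnet_shape_check():
    net = LiteISPNet()
    x = torch.rand(1, 3, 64, 64)  # (N, C, H, W) with H, W divisible by 8
    assert net(x).shape == x.shape
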
class LiteAWBISPNet(nn.Module):
    def __init__(self):
super(LiteAWBISPNet, self).__init__()
ch_1 = 64
ch_2 = 128
ch_3 = 128
n_blocks = 4
self.head = N.seq(
N.conv(3, ch_1, mode='C')
        )  # shape: (N, ch_1, H, W)
self.down1 = N.seq(
N.conv(ch_1, ch_1, mode='C'),
N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
N.conv(ch_1, ch_1, mode='C'),
N.DWTForward(ch_1)
        )  # shape: (N, ch_1*4, H/2, W/2)
self.down2 = N.seq(
N.conv(ch_1*4, ch_1, mode='C'),
N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
N.DWTForward(ch_1)
        )  # shape: (N, ch_1*4, H/4, W/4)
self.down3 = N.seq(
N.conv(ch_1*4, ch_2, mode='C'),
N.RCAGroup(in_channels=ch_2, out_channels=ch_2, nb=n_blocks),
N.DWTForward(ch_2)
        )  # shape: (N, ch_2*4, H/8, W/8)
self.middle = N.seq(
N.conv(ch_2*4, ch_3, mode='C'),
N.RCAGroup(in_channels=ch_3, out_channels=ch_3, nb=n_blocks),
N.RCAGroup(in_channels=ch_3, out_channels=ch_3, nb=n_blocks),
N.conv(ch_3, ch_2*4, mode='C')
        )  # shape: (N, ch_2*4, H/8, W/8)
self.up3 = N.seq(
N.DWTInverse(ch_2*4),
N.RCAGroup(in_channels=ch_2, out_channels=ch_2, nb=n_blocks),
N.conv(ch_2, ch_1*4, mode='C')
        )  # shape: (N, ch_1*4, H/4, W/4)
self.up2 = N.seq(
N.DWTInverse(ch_1*4),
N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
N.conv(ch_1, ch_1*4, mode='C')
        )  # shape: (N, ch_1*4, H/2, W/2)
self.up1 = N.seq(
N.DWTInverse(ch_1*4),
N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
N.conv(ch_1, ch_1, mode='C')
        )  # shape: (N, ch_1, H, W)
self.tail = N.seq(
#N.conv(ch_1, ch_1*4, mode='C'),
#nn.PixelShuffle(upscale_factor=2),
N.conv(ch_1, 3, mode='C')
) # shape: (N, 3, H, W)
    def forward(self, raw):
        # AWB variant: identical to LiteISPNet, but the input is used as-is
        # (no 1/2.2 gamma encoding)
        h = self.head(raw)
        d1 = self.down1(h)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        m = self.middle(d3) + d3
        u3 = self.up3(m) + d2
        u2 = self.up2(u3) + d1
        u1 = self.up1(u2) + h
        out = self.tail(u1)
        return out
# Alignment Encoder
class A_Encoder(nn.Module):
def __init__(self):
super(A_Encoder, self).__init__()
self.conv12 = Conv2d(3, 64, kernel_size=5, stride=2, padding=2, activation=nn.ReLU()) # 2
self.conv2 = Conv2d(64, 64, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 2
self.conv23 = Conv2d(64, 128, kernel_size=3, stride=2, padding=1, activation=nn.ReLU()) # 4
self.conv3 = Conv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 4
self.conv34 = Conv2d(128, 256, kernel_size=3, stride=2, padding=1, activation=nn.ReLU()) # 8
self.conv4a = Conv2d(256, 256, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 8
self.conv4b = Conv2d(256, 256, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 8
init_He(self)
self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1))
self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1,3,1,1))
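    # Note: inputs are normalized with ImageNet statistics and resized to a
    # fixed 224x224 before encoding, so the alignment features are independent
    # of the input frame size.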
    def forward(self, in_f):
        x = (in_f - self.mean) / self.std
        # F.upsample is deprecated; F.interpolate is the current equivalent
        x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
        x = self.conv12(x)
        x = self.conv2(x)
        x = self.conv23(x)
        x = self.conv3(x)
        x = self.conv34(x)
        x = self.conv4a(x)
        x = self.conv4b(x)
        return x
# Alignment Regressor
class A_Regressor(nn.Module):
def __init__(self):
super(A_Regressor, self).__init__()
self.conv45 = Conv2d(512, 512, kernel_size=3, stride=2, padding=1, activation=nn.ReLU()) # 16
self.conv5a = Conv2d(512, 512, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 16
self.conv5b = Conv2d(512, 512, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 16
self.conv56 = Conv2d(512, 512, kernel_size=3, stride=2, padding=1, activation=nn.ReLU()) # 32
self.conv6a = Conv2d(512, 512, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 32
self.conv6b = Conv2d(512, 512, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 32
init_He(self)
        self.fc = nn.Linear(512, 6)
        # zero weights plus an identity affine bias [1, 0, 0, 0, 1, 0]: the
        # regressor starts out predicting the identity transformation
        self.fc.weight.data.zero_()
        self.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float32))
def forward(self, feat1, feat2):
x = torch.cat([feat1, feat2], dim=1)
x = self.conv45(x)
x = self.conv5a(x)
x = self.conv5b(x)
x = self.conv56(x)
        x = self.conv6a(x)
        x = self.conv6b(x)
x = F.avg_pool2d(x, x.shape[2])
x = x.view(-1, x.shape[1])
theta = self.fc(x)
theta = theta.view(-1, 2, 3)
return theta
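
# Usage sketch (illustrative, not part of the original file): the predicted
# 2x3 affine matrices drive a spatial-transformer-style warp, mirroring how
# CPNet.forward uses A_Regressor below.
def _affine_warp_example(regressor, feat_t, feat_r, frame_r):
    theta = regressor(feat_t, feat_r)  # (B, 2, 3)
    grid = F.affine_grid(theta, frame_r.size(), align_corners=False)
    return F.grid_sample(frame_r, grid, align_corners=False)
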
# Encoder (Copy network)
class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()
self.conv12 = Conv2d(4, 64, kernel_size=5, stride=2, padding=2, activation=nn.ReLU()) # 2
self.conv2 = Conv2d(64, 64, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 2
self.conv23 = Conv2d(64, 128, kernel_size=3, stride=2, padding=1, activation=nn.ReLU()) # 4
self.conv3 = Conv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 4
self.value3 = Conv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=None) # 4
init_He(self)
self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1))
self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1,3,1,1))
def forward(self, in_f, in_v):
f = (in_f - self.mean) / self.std
x = torch.cat([f, in_v], dim=1)
x = self.conv12(x)
x = self.conv2(x)
x = self.conv23(x)
x = self.conv3(x)
v = self.value3(x)
return v
# Decoder (Paste network)
class Decoder(nn.Module):
def __init__(self):
super(Decoder, self).__init__()
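        # 257 input channels = 128 (target features) + 128 (context-matched
        # copy features) + 1 (c_mask), i.e. the output of CM_Module below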
self.conv4 = Conv2d(257, 257, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
self.conv5_1 = Conv2d(257, 257, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
self.conv5_2 = Conv2d(257, 257, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
# dilated convolution blocks
self.convA4_1 = Conv2d(257, 257, kernel_size=3, stride=1, padding=2, D=2, activation=nn.ReLU())
self.convA4_2 = Conv2d(257, 257, kernel_size=3, stride=1, padding=4, D=4, activation=nn.ReLU())
self.convA4_3 = Conv2d(257, 257, kernel_size=3, stride=1, padding=8, D=8, activation=nn.ReLU())
self.convA4_4 = Conv2d(257, 257, kernel_size=3, stride=1, padding=16, D=16,activation=nn.ReLU())
self.conv3c = Conv2d(257, 257, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 4
self.conv3b = Conv2d(257, 128, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 4
self.conv3a = Conv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 4
self.conv32 = Conv2d(128, 64, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 2
self.conv2 = Conv2d(64, 64, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()) # 2
self.conv21 = Conv2d(64, 3, kernel_size=5, stride=1, padding=2, activation=None) # 1
self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1))
self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1,3,1,1))
def forward(self, x):
x = self.conv4(x)
x = self.conv5_1(x)
x = self.conv5_2(x)
x = self.convA4_1(x)
x = self.convA4_2(x)
x = self.convA4_3(x)
x = self.convA4_4(x)
x = self.conv3c(x)
x = self.conv3b(x)
x = self.conv3a(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')  # 2
x = self.conv32(x)
x = self.conv2(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')  # 2
x = self.conv21(x)
        p = (x * self.std) + self.mean
return p
# Context Matching Module
class CM_Module(nn.Module):
def __init__(self):
super(CM_Module, self).__init__()
def masked_softmax(self, vec, mask, dim):
masked_vec = vec * mask.float()
max_vec = torch.max(masked_vec, dim=dim, keepdim=True)[0]
exps = torch.exp(masked_vec-max_vec)
masked_exps = exps * mask.float()
masked_sums = masked_exps.sum(dim, keepdim=True)
        zeros = (masked_sums < 1e-4)
        masked_sums += zeros.float()
        return masked_exps / masked_sums
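    # Worked example (sketch): for vec = [1, 2, 3] with mask = [1, 1, 0] along
    # dim=0, masked entries are zeroed before the max subtraction, so the
    # result is softmax([1, 2]) on the visible entries and exactly 0 at the
    # masked one: [0.269, 0.731, 0.0].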
def forward(self, values, tvmap, rvmaps):
B, C, T, H, W = values.size()
# t_feat: target feature
t_feat = values[:, :, 0]
        # r_feats: reference features
r_feats = values[:, :, 1:]
        B, Cv, T, H, W = r_feats.size()  # T now counts the reference frames only
# vmap: visibility map
# tvmap: target visibility map
# rvmap: reference visibility map
# gs: cosine similarity
# c_m: c_match
        gs_, vmap_ = [], []
        tvmap_t = (F.interpolate(tvmap, size=(H, W), mode='bilinear', align_corners=False) > 0.5).float()
        for r in range(T):
            rvmap_t = (F.interpolate(rvmaps[:, :, r], size=(H, W), mode='bilinear', align_corners=False) > 0.5).float()
            # vmap: overlap of target and reference visibility
            vmap = tvmap_t * rvmap_t
            gs = (vmap * t_feat * r_feats[:, :, r]).sum(-1).sum(-1).sum(-1)
            # valid sum
            v_sum = vmap[:, 0].sum(-1).sum(-1)
            zeros = (v_sum < 1e-4)
            gs[zeros] = 0
            v_sum += zeros.float()
            gs = gs / v_sum / C
            # broadcast the per-frame similarity score to the feature map shape
            gs = torch.ones_like(t_feat) * gs.view(B, 1, 1, 1)
            gs_.append(gs)
            vmap_.append(rvmap_t)
        gss = torch.stack(gs_, dim=2)
        vmaps = torch.stack(vmap_, dim=2)
        # weighted pixelwise masked softmax over the reference frames
        c_match = self.masked_softmax(gss, vmaps, dim=2)
c_out = torch.sum(r_feats * c_match, dim=2)
# c_mask
c_mask = (c_match * vmaps)
c_mask = torch.sum(c_mask,2)
c_mask = 1. - (torch.mean(c_mask, 1, keepdim=True))
return torch.cat([t_feat, c_out, c_mask], dim=1), c_mask
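
# Shape sketch (illustrative): given values of shape (B, 128, T, H, W), a
# target visibility map tvmap of shape (B, 1, h, w), and reference maps
# rvmaps of shape (B, 1, T-1, h, w), CM_Module returns a (B, 257, H, W)
# tensor (target features, context-matched copy, confidence mask) that feeds
# Decoder, together with the (B, 1, H, W) mask c_mask.
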
class GCMModel(nn.Module):
    def __init__(self):
        super(GCMModel, self).__init__()
        self.ch_1 = 16
        self.ch_2 = 32
        # coordinate-channel input is disabled in this variant
        self.gcm_coord = False
        guide_input_channels = 3
        align_input_channels = 3
self.guide_net = N.seq(
N.conv(guide_input_channels, self.ch_1, 7, stride=2, padding=0, mode='CR'),
N.conv(self.ch_1, self.ch_1, kernel_size=3, stride=1, padding=1, mode='CRC'),
nn.AdaptiveAvgPool2d(1),
N.conv(self.ch_1, self.ch_2, 1, stride=1, padding=0, mode='C')
)
self.align_head = N.conv(align_input_channels, self.ch_2, 1, padding=0, mode='CR')
self.align_base = N.seq(
N.conv(self.ch_2, self.ch_2, kernel_size=1, stride=1, padding=0, mode='CRCRCRCRCR')
)
self.align_tail = N.seq(
N.conv(self.ch_2, 3, 1, padding=0, mode='C')
)
def forward(self, demosaic_raw):
        demosaic_raw = torch.pow(demosaic_raw, 1 / 2.2)
        guide_input = demosaic_raw
        base_input = demosaic_raw
        guide = self.guide_net(guide_input)
        out = self.align_head(base_input)
        out = guide * out + out
        out = self.align_base(out)
        out = self.align_tail(out) + demosaic_raw
return out
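
# Usage sketch (illustrative, not part of the original file): GCMModel is a
# lightweight global color-mapping network; a pooled guide vector modulates
# per-pixel 1x1-conv features, and the result is added back residually.
def _gcm_example():
    gcm = GCMModel()
    x = torch.rand(1, 3, 64, 64)  # demosaiced raw in [0, 1]
    return gcm(x)                 # (1, 3, 64, 64)
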
class Fusion(nn.Module):
    def __init__(self):
        super(Fusion, self).__init__()
        self.ch_1 = 16
        self.ch_2 = 32
        # coordinate-channel input is disabled in this variant
        self.gcm_coord = False
        guide_input_channels = 9
        align_input_channels = 9
self.guide_net = N.seq(
N.conv(guide_input_channels, self.ch_1, 7, stride=2, padding=0, mode='CR'),
N.conv(self.ch_1, self.ch_1, kernel_size=3, stride=1, padding=1, mode='CRC'),
nn.AdaptiveAvgPool2d(1),
N.conv(self.ch_1, self.ch_2, 1, stride=1, padding=0, mode='C')
)
self.align_head = N.conv(align_input_channels, self.ch_2, 1, padding=0, mode='CR')
self.align_base = N.seq(
N.conv(self.ch_2, self.ch_2, kernel_size=1, stride=1, padding=0, mode='CRCRCR')
)
self.align_tail = N.seq(
N.conv(self.ch_2, 3, 1, padding=0, mode='C')
)
def forward(self, demosaic_raw):
#demosaic_raw = torch.pow(demosaic_raw, 1 / 2.2)
guide_input = demosaic_raw
        base_input = demosaic_raw
guide = self.guide_net(guide_input)
out = self.align_head(base_input)
out = guide * out + out
out = self.align_base(out)
out = self.align_tail(out)
return out
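
# Note: Fusion mirrors GCMModel but takes a 9-channel input, presumably three
# RGB candidates concatenated along the channel dimension, and predicts a
# single fused 3-channel output (no residual skip, no gamma encoding).
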
class CPNet(nn.Module):
def __init__(self, mode='Train'):
super(CPNet, self).__init__()
self.A_Encoder = A_Encoder() # Align
self.A_Regressor = A_Regressor() # output: alignment network
self.GCMModel = GCMModel()
self.Encoder = Encoder() # Merge
self.CM_Module = CM_Module()
self.Decoder = Decoder()
self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1))
self.register_buffer('mean3d', torch.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1,1))
    def encoding(self, frames, holes):
        batch_size, _, num_frames, height, width = frames.size()
        # pad so the spatial dimensions are divisible by 8
        (frames, holes), pad = pad_divide_by([frames, holes], 8, (frames.size()[3], frames.size()[4]))
        feat_ = []
        for t in range(num_frames):
            # the A_Encoder defined above takes a single frame tensor
            feat = self.A_Encoder(frames[:, :, t])
            feat_.append(feat)
        feats = torch.stack(feat_, dim=2)
        return feats
def inpainting(self, rfeats, rframes, rholes, frame, hole, gt):
batch_size, _, height, width = frame.size() # B C H W
num_r = rfeats.size()[2] # # of reference frames
# padding
(rframes, rholes, frame, hole, gt), pad = pad_divide_by([rframes, rholes, frame, hole, gt], 8, (height, width))
# Target embedding
        tfeat = self.A_Encoder(frame)
# c_feat: Encoder(Copy Network) features
c_feat_ = [self.Encoder(frame, hole)]
# aligned_r: aligned reference frames
aligned_r_ = []
# rvmap: aligned reference frames valid maps
rvmap_ = []
        for r in range(num_r):
            theta_rt = self.A_Regressor(tfeat, rfeats[:, :, r])
            grid_rt = F.affine_grid(theta_rt, frame.size(), align_corners=False)
            # aligned_r: reference frame warped onto the target (affine transform)
            aligned_r = F.grid_sample(rframes[:, :, r], grid_rt, align_corners=False)
            # aligned_v: aligned reference visibility map (mask warped the same way)
            aligned_v = F.grid_sample(1 - rholes[:, :, r], grid_rt, align_corners=False)
            aligned_v = (aligned_v > 0.5).float()
            aligned_r_.append(aligned_r)
            # compare the aligned frame against the target frame
            c_feat_.append(self.Encoder(aligned_r, aligned_v))
            rvmap_.append(aligned_v)
        aligned_rs = torch.stack(aligned_r_, 2)
        c_feats = torch.stack(c_feat_, dim=2)
        rvmaps = torch.stack(rvmap_, dim=2)
# p_in: paste network input(target features + c_out + c_mask)
p_in, c_mask = self.CM_Module(c_feats, 1-hole, rvmaps)
pred = self.Decoder(p_in)
_, _, _, H, W = aligned_rs.shape
        c_mask = (F.interpolate(c_mask, size=(H, W), mode='bilinear', align_corners=False)).detach()
        comp = pred * hole + gt * (1. - hole)
if pad[2]+pad[3] > 0:
comp = comp[:,:,pad[2]:-pad[3],:]
if pad[0]+pad[1] > 0:
comp = comp[:,:,:,pad[0]:-pad[1]]
comp = torch.clamp(comp, 0, 1)
return comp
    def forward(self, Source, Target):
        feat_target = self.A_Encoder(Target)
        feat_source = self.A_Encoder(Source)
        theta = self.A_Regressor(feat_target, feat_source)
        grid_rt = F.affine_grid(theta, Target.size(), align_corners=False)
        aligned = F.grid_sample(Source, grid_rt, align_corners=False)
        # validity mask: 1 where the warped source lands inside the frame
        mask = torch.ones_like(Source)
        mask = F.grid_sample(mask, grid_rt, align_corners=False)
        return aligned, mask
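
# Usage sketch (illustrative, not part of the original file): in this variant
# CPNet.forward only performs global alignment, returning the source warped
# onto the target together with a validity mask marking in-bounds pixels.
def _cpnet_align_example():
    net = CPNet()
    src, tgt = torch.rand(1, 3, 128, 128), torch.rand(1, 3, 128, 128)
    aligned, mask = net(src, tgt)
    return aligned, mask
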
class AC(nn.Module):
    def __init__(self):
        super(AC, self).__init__()
        self.ch_1 = 32
        self.ch_2 = 64
        # coordinate-channel input is disabled, so the guide network sees
        # 6 channels (raw + dslr) and the align head sees 3 (raw only)
        self.gcm_coord = False
        guide_input_channels = 6
        align_input_channels = 3
self.guide_net = N.seq(
N.conv(guide_input_channels, self.ch_1, 7, stride=2, padding=0, mode='CR'),
N.conv(self.ch_1, self.ch_1, kernel_size=3, stride=1, padding=1, mode='CRC'),
nn.AdaptiveAvgPool2d(1),
N.conv(self.ch_1, self.ch_2, 1, stride=1, padding=0, mode='C')
)
self.align_head = N.conv(align_input_channels, self.ch_2, 1, padding=0, mode='CR')
self.align_base = N.seq(
N.conv(self.ch_2, self.ch_2, kernel_size=1, stride=1, padding=0, mode='CRCRCR')
)
self.align_tail = N.seq(
N.conv(self.ch_2, 3, 1, padding=0, mode='C')
)
    def forward(self, demosaic_raw, dslr, coord=None):
        # small constant offset avoids the unbounded gradient of pow() near
        # zero; the input is then gamma-encoded and rescaled
        demosaic_raw = demosaic_raw + 0.01 * torch.ones_like(demosaic_raw)
        demosaic_raw = torch.pow(demosaic_raw, 1 / 2.2)
        demosaic_raw = demosaic_raw / 2
        if self.gcm_coord:
            guide_input = torch.cat((demosaic_raw, dslr, coord), 1)
            base_input = torch.cat((demosaic_raw, coord), 1)
        else:
            guide_input = torch.cat((demosaic_raw, dslr), 1)
            base_input = demosaic_raw
        guide = self.guide_net(guide_input)
        out = self.align_head(base_input)
        out = guide * out + out
        out = self.align_base(out)
        out = self.align_tail(out) + demosaic_raw
        return out
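
# Usage sketch (illustrative, not part of the original file): AC predicts a
# color-corrected image from a demosaiced raw input, guided by a reference
# (e.g. DSLR) image of the same size; the output is a residual on top of the
# preprocessed raw.
def _ac_example():
    ac = AC()
    raw = torch.rand(1, 3, 64, 64)
    ref = torch.rand(1, 3, 64, 64)
    return ac(raw, ref)  # (1, 3, 64, 64)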