|
from __future__ import division |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.nn.init as init |
|
import torch.utils.model_zoo as model_zoo |
|
from torchvision import models |
|
from torchvision import transforms |
|
import cv2 |
|
import matplotlib.pyplot as plt |
|
from PIL import Image |
|
import numpy as np |
|
import math |
|
import time |
|
import tqdm |
|
import os |
|
import argparse |
|
import copy |
|
import sys |
|
import networks as N |
|
from model_module import * |
|
sys.path.insert(0, '.') |
|
|
|
sys.path.insert(0, '../utils/') |
|
|
|
|
|
class LiteISPNet(nn.Module):
    """U-Net-style ISP network operating on gamma-compressed raw input.

    Encoder/decoder with discrete-wavelet-transform down/up-sampling
    (``N.DWTForward`` / ``N.DWTInverse``) and residual channel-attention
    groups (``N.RCAGroup``) at every scale, plus long skip connections.
    """

    def __init__(self):
        super(LiteISPNet, self).__init__()

        conv, seq, rcag = N.conv, N.seq, N.RCAGroup
        ch_1, ch_2, ch_3, n_blocks = 64, 128, 128, 4

        # 3-channel input -> feature space
        self.head = seq(conv(3, ch_1, mode='C'))

        # each down stage quadruples channels via the DWT (2x spatial reduction)
        self.down1 = seq(
            conv(ch_1, ch_1, mode='C'),
            rcag(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
            conv(ch_1, ch_1, mode='C'),
            N.DWTForward(ch_1),
        )

        self.down2 = seq(
            conv(ch_1 * 4, ch_1, mode='C'),
            rcag(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
            N.DWTForward(ch_1),
        )

        self.down3 = seq(
            conv(ch_1 * 4, ch_2, mode='C'),
            rcag(in_channels=ch_2, out_channels=ch_2, nb=n_blocks),
            N.DWTForward(ch_2),
        )

        # bottleneck: two attention groups, residual add happens in forward()
        self.middle = seq(
            conv(ch_2 * 4, ch_3, mode='C'),
            rcag(in_channels=ch_3, out_channels=ch_3, nb=n_blocks),
            rcag(in_channels=ch_3, out_channels=ch_3, nb=n_blocks),
            conv(ch_3, ch_2 * 4, mode='C'),
        )

        # decoder mirrors the encoder; inverse DWT restores resolution
        self.up3 = seq(
            N.DWTInverse(ch_2 * 4),
            rcag(in_channels=ch_2, out_channels=ch_2, nb=n_blocks),
            conv(ch_2, ch_1 * 4, mode='C'),
        )

        self.up2 = seq(
            N.DWTInverse(ch_1 * 4),
            rcag(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
            conv(ch_1, ch_1 * 4, mode='C'),
        )

        self.up1 = seq(
            N.DWTInverse(ch_1 * 4),
            rcag(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
            conv(ch_1, ch_1, mode='C'),
        )

        # feature space -> 3-channel output
        self.tail = seq(conv(ch_1, 3, mode='C'))

    def forward(self, raw):
        """Map a linear raw image to an RGB-like output.

        The input is gamma-compressed (gamma = 2.2) before processing.
        """
        x = torch.pow(raw, 1 / 2.2)

        skip0 = self.head(x)
        skip1 = self.down1(skip0)
        skip2 = self.down2(skip1)
        skip3 = self.down3(skip2)

        # residual bottleneck plus symmetric skip connections on the way up
        mid = self.middle(skip3) + skip3
        x = self.up3(mid) + skip2
        x = self.up2(x) + skip1
        x = self.up1(x) + skip0
        return self.tail(x)
|
|
|
|
|
class LiteAWBISPNet(nn.Module):
    """Variant of LiteISPNet that consumes its input directly (no gamma).

    Architecture is identical to LiteISPNet: DWT-based U-Net with residual
    channel-attention groups and long skip connections; only the input
    pre-processing differs (the raw tensor is used as-is).
    """

    def __init__(self):
        super(LiteAWBISPNet, self).__init__()

        conv, seq, rcag = N.conv, N.seq, N.RCAGroup
        ch_1, ch_2, ch_3, n_blocks = 64, 128, 128, 4

        # 3-channel input -> feature space
        self.head = seq(conv(3, ch_1, mode='C'))

        # each down stage quadruples channels via the DWT (2x spatial reduction)
        self.down1 = seq(
            conv(ch_1, ch_1, mode='C'),
            rcag(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
            conv(ch_1, ch_1, mode='C'),
            N.DWTForward(ch_1),
        )

        self.down2 = seq(
            conv(ch_1 * 4, ch_1, mode='C'),
            rcag(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
            N.DWTForward(ch_1),
        )

        self.down3 = seq(
            conv(ch_1 * 4, ch_2, mode='C'),
            rcag(in_channels=ch_2, out_channels=ch_2, nb=n_blocks),
            N.DWTForward(ch_2),
        )

        # bottleneck: two attention groups, residual add happens in forward()
        self.middle = seq(
            conv(ch_2 * 4, ch_3, mode='C'),
            rcag(in_channels=ch_3, out_channels=ch_3, nb=n_blocks),
            rcag(in_channels=ch_3, out_channels=ch_3, nb=n_blocks),
            conv(ch_3, ch_2 * 4, mode='C'),
        )

        # decoder mirrors the encoder; inverse DWT restores resolution
        self.up3 = seq(
            N.DWTInverse(ch_2 * 4),
            rcag(in_channels=ch_2, out_channels=ch_2, nb=n_blocks),
            conv(ch_2, ch_1 * 4, mode='C'),
        )

        self.up2 = seq(
            N.DWTInverse(ch_1 * 4),
            rcag(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
            conv(ch_1, ch_1 * 4, mode='C'),
        )

        self.up1 = seq(
            N.DWTInverse(ch_1 * 4),
            rcag(in_channels=ch_1, out_channels=ch_1, nb=n_blocks),
            conv(ch_1, ch_1, mode='C'),
        )

        # feature space -> 3-channel output
        self.tail = seq(conv(ch_1, 3, mode='C'))

    def forward(self, raw):
        """Map the input image to an RGB-like output (no gamma compression)."""
        skip0 = self.head(raw)
        skip1 = self.down1(skip0)
        skip2 = self.down2(skip1)
        skip3 = self.down3(skip2)

        # residual bottleneck plus symmetric skip connections on the way up
        mid = self.middle(skip3) + skip3
        x = self.up3(mid) + skip2
        x = self.up2(x) + skip1
        x = self.up1(x) + skip0
        return self.tail(x)
|
|
|
|
|
|
|
class A_Encoder(nn.Module):
    """Alignment encoder: RGB frame -> 256-channel feature map.

    The input is ImageNet-normalized, resized to 224x224, and downsampled
    8x by three stride-2 convolutions (224 -> 112 -> 56 -> 28).

    Note: some call sites in this file (``CPNet.encoding`` /
    ``CPNet.inpainting``) pass a hole mask as a second argument even though
    ``conv12`` only consumes the 3 RGB channels; ``in_v`` is therefore
    accepted and ignored so those calls do not raise a ``TypeError``.
    """

    def __init__(self):
        super(A_Encoder, self).__init__()
        self.conv12 = Conv2d(3, 64, kernel_size=5, stride=2, padding=2, activation=nn.ReLU())    # 224 -> 112
        self.conv2 = Conv2d(64, 64, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv23 = Conv2d(64, 128, kernel_size=3, stride=2, padding=1, activation=nn.ReLU())  # 112 -> 56
        self.conv3 = Conv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv34 = Conv2d(128, 256, kernel_size=3, stride=2, padding=1, activation=nn.ReLU()) # 56 -> 28
        self.conv4a = Conv2d(256, 256, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv4b = Conv2d(256, 256, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        init_He(self)
        # ImageNet normalization statistics
        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1))
        self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1,3,1,1))

    def forward(self, in_f, in_v=None):
        """Encode ``in_f`` (B, 3, H, W); ``in_v`` is unused (see class docstring)."""
        f = (in_f - self.mean) / self.std
        x = f
        # F.upsample is deprecated; F.interpolate is the supported equivalent
        x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
        x = self.conv12(x)
        x = self.conv2(x)
        x = self.conv23(x)
        x = self.conv3(x)
        x = self.conv34(x)
        x = self.conv4a(x)
        x = self.conv4b(x)
        return x
|
|
|
|
|
class A_Regressor(nn.Module):
    """Regresses a 2x3 affine matrix from a pair of concatenated feature maps.

    Takes two 256-channel feature maps (concatenated to 512 channels),
    downsamples twice more, global-average-pools, and predicts the affine
    parameters ``theta`` of shape (B, 2, 3) for ``F.affine_grid``.
    """

    def __init__(self):
        super(A_Regressor, self).__init__()
        self.conv45 = Conv2d(512, 512, kernel_size=3, stride=2, padding=1, activation=nn.ReLU())
        self.conv5a = Conv2d(512, 512, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv5b = Conv2d(512, 512, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv56 = Conv2d(512, 512, kernel_size=3, stride=2, padding=1, activation=nn.ReLU())
        self.conv6a = Conv2d(512, 512, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv6b = Conv2d(512, 512, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        init_He(self)

        # FC layer initialized so the initial prediction is the identity
        # transform [1 0 0; 0 1 0]
        self.fc = nn.Linear(512, 6)
        self.fc.weight.data.zero_()
        self.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float32))

    def forward(self, feat1, feat2):
        """Predict theta aligning the frame behind ``feat2`` to ``feat1``."""
        x = torch.cat([feat1, feat2], dim=1)
        x = self.conv45(x)
        x = self.conv5a(x)
        x = self.conv5b(x)
        x = self.conv56(x)
        # Bug fix: conv6a/conv6b were defined (and He-initialized) but never
        # used; the forward pass previously re-applied conv5a/conv5b here.
        x = self.conv6a(x)
        x = self.conv6b(x)

        # global average pool to a (B, 512) vector
        x = F.avg_pool2d(x, x.shape[2])
        x = x.view(-1, x.shape[1])

        theta = self.fc(x)
        theta = theta.view(-1, 2, 3)

        return theta
|
|
|
|
|
class Encoder(nn.Module):
    """Value encoder: RGB frame + 1-channel visibility map -> 128-channel features.

    The RGB channels are ImageNet-normalized, concatenated with the map to a
    4-channel input, and downsampled 4x by two stride-2 convolutions.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        self.conv12 = Conv2d(4, 64, kernel_size=5, stride=2, padding=2, activation=nn.ReLU())
        self.conv2 = Conv2d(64, 64, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv23 = Conv2d(64, 128, kernel_size=3, stride=2, padding=1, activation=nn.ReLU())
        self.conv3 = Conv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        # final value projection has no activation
        self.value3 = Conv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=None)
        init_He(self)
        # ImageNet normalization statistics
        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1))
        self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1,3,1,1))

    def forward(self, in_f, in_v):
        """Encode frame ``in_f`` (B, 3, H, W) with visibility map ``in_v`` (B, 1, H, W)."""
        normed = (in_f - self.mean) / self.std
        feat = torch.cat([normed, in_v], dim=1)
        for layer in (self.conv12, self.conv2, self.conv23, self.conv3):
            feat = layer(feat)
        return self.value3(feat)
|
|
|
|
|
class Decoder(nn.Module):
    """Decodes the 257-channel copy-matched feature map back to an RGB image.

    The 257 input channels come from ``CM_Module`` (128 target features +
    128 copied features + 1 context mask). A stack of dilated convolutions
    widens the receptive field, then two nearest-neighbour 2x upsamplings
    restore the encoder's 4x downsampling. The output is de-normalized with
    the ImageNet statistics.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.conv4 = Conv2d(257, 257, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv5_1 = Conv2d(257, 257, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv5_2 = Conv2d(257, 257, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())

        # dilated convolutions enlarge the receptive field at fixed resolution
        self.convA4_1 = Conv2d(257, 257, kernel_size=3, stride=1, padding=2, D=2, activation=nn.ReLU())
        self.convA4_2 = Conv2d(257, 257, kernel_size=3, stride=1, padding=4, D=4, activation=nn.ReLU())
        self.convA4_3 = Conv2d(257, 257, kernel_size=3, stride=1, padding=8, D=8, activation=nn.ReLU())
        self.convA4_4 = Conv2d(257, 257, kernel_size=3, stride=1, padding=16, D=16, activation=nn.ReLU())

        self.conv3c = Conv2d(257, 257, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv3b = Conv2d(257, 128, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv3a = Conv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv32 = Conv2d(128, 64, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv2 = Conv2d(64, 64, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        # final projection to RGB has no activation
        self.conv21 = Conv2d(64, 3, kernel_size=5, stride=1, padding=2, activation=None)

        # ImageNet statistics used to de-normalize the output
        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1))
        self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1,3,1,1))

    def forward(self, x):
        """Decode a (B, 257, H/4, W/4) feature map to a (B, 3, H, W) image."""
        x = self.conv4(x)
        x = self.conv5_1(x)
        x = self.conv5_2(x)

        x = self.convA4_1(x)
        x = self.convA4_2(x)
        x = self.convA4_3(x)
        x = self.convA4_4(x)

        x = self.conv3c(x)
        x = self.conv3b(x)
        x = self.conv3a(x)
        # F.upsample is deprecated; F.interpolate is the supported replacement
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv32(x)
        x = self.conv2(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv21(x)

        # de-normalize back to image space
        p = (x * self.std) + self.mean
        return p
|
|
|
|
|
|
|
class CM_Module(nn.Module):
    """Copy-and-match module.

    Softly copies features from the visible regions of aligned reference
    frames into the target frame, weighting each reference by a global
    similarity score computed over co-visible pixels.
    """

    def __init__(self):
        super(CM_Module, self).__init__()

    def masked_softmax(self, vec, mask, dim):
        """Softmax over ``dim`` restricted to positions where ``mask`` is set.

        Fully-masked slices yield zeros instead of NaNs: the denominator is
        bumped to 1 wherever the masked sum is (near) zero.
        """
        masked_vec = vec * mask.float()
        # subtract the per-slice max for numerical stability
        max_vec = torch.max(masked_vec, dim=dim, keepdim=True)[0]
        exps = torch.exp(masked_vec - max_vec)
        masked_exps = exps * mask.float()
        masked_sums = masked_exps.sum(dim, keepdim=True)
        # guard against division by ~0 where everything is masked out
        zeros = (masked_sums < 1e-4)
        masked_sums += zeros.float()
        return masked_exps / masked_sums

    def forward(self, values, tvmap, rvmaps):
        """Match reference features to the target.

        Args:
            values: (B, C, T, H, W) features; index 0 along dim 2 is the
                target frame, the remaining T-1 are aligned references.
            tvmap: target visibility map, resized here to (H, W).
            rvmaps: per-reference visibility maps, indexed on dim 2.

        Returns:
            (cat([target_feats, copied_feats, context_mask], dim=1), context_mask)
            where context_mask is 1 where no reference contributed.
        """
        B, C, T, H, W = values.size()

        # split target features from reference features
        t_feat = values[:, :, 0]
        r_feats = values[:, :, 1:]
        _, _, num_refs, H, W = r_feats.size()

        gs_, vmap_ = [], []
        # F.upsample is deprecated; F.interpolate is the supported replacement
        tvmap_t = (F.interpolate(tvmap, size=(H, W), mode='bilinear', align_corners=False) > 0.5).float()
        for r in range(num_refs):
            rvmap_t = (F.interpolate(rvmaps[:, :, r], size=(H, W), mode='bilinear',
                                     align_corners=False) > 0.5).float()

            # global similarity between target and reference over co-visible pixels
            vmap = tvmap_t * rvmap_t
            gs = (vmap * t_feat * r_feats[:, :, r]).sum(-1).sum(-1).sum(-1)

            # normalize by the co-visible area; zero out empty overlaps
            v_sum = vmap[:, 0].sum(-1).sum(-1)
            zeros = (v_sum < 1e-4)
            gs[zeros] = 0
            v_sum += zeros.float()
            gs = gs / v_sum / C
            # broadcast the per-batch scalar to the feature-map shape.
            # ones_like keeps device and dtype — the previous hard-coded
            # .cuda() broke CPU execution.
            gs = torch.ones_like(t_feat) * gs.view(B, 1, 1, 1)
            gs_.append(gs)
            vmap_.append(rvmap_t)

        gss = torch.stack(gs_, dim=2)
        vmaps = torch.stack(vmap_, dim=2)

        # soft attention over references, restricted to visible positions
        c_match = self.masked_softmax(gss, vmaps, dim=2)
        c_out = torch.sum(r_feats * c_match, dim=2)

        # context mask: 1 where no reference contributed visible content
        c_mask = (c_match * vmaps)
        c_mask = torch.sum(c_mask, 2)
        c_mask = 1. - (torch.mean(c_mask, 1, keepdim=True))

        return torch.cat([t_feat, c_out, c_mask], dim=1), c_mask
|
|
|
|
|
class GCMModel(nn.Module):
    """Global color-mapping model.

    A small per-image guidance vector (global average pooled) channel-wise
    modulates a stack of 1x1 convolutions, with a residual connection back
    to the gamma-compressed input.
    """

    def __init__(self):
        super(GCMModel, self).__init__()
        self.ch_1 = 16
        self.ch_2 = 32
        self.gcm_coord = None
        # coordinate conditioning is disabled, so both branches see plain RGB
        guide_input_channels = 3
        align_input_channels = 3

        self.guide_net = N.seq(
            N.conv(guide_input_channels, self.ch_1, 7, stride=2, padding=0, mode='CR'),
            N.conv(self.ch_1, self.ch_1, kernel_size=3, stride=1, padding=1, mode='CRC'),
            nn.AdaptiveAvgPool2d(1),
            N.conv(self.ch_1, self.ch_2, 1, stride=1, padding=0, mode='C')
        )

        self.align_head = N.conv(align_input_channels, self.ch_2, 1, padding=0, mode='CR')

        self.align_base = N.seq(
            N.conv(self.ch_2, self.ch_2, kernel_size=1, stride=1, padding=0, mode='CRCRCRCRCR')
        )
        self.align_tail = N.seq(
            N.conv(self.ch_2, 3, 1, padding=0, mode='C')
        )

    def forward(self, demosaic_raw):
        """Color-map the gamma-compressed raw image; returns a 3-channel result."""
        # gamma-compress the linear raw input (gamma = 2.2)
        x = torch.pow(demosaic_raw, 1 / 2.2)

        # per-image guidance vector modulates the features channel-wise
        guide = self.guide_net(x)
        feat = self.align_head(x)
        feat = feat + guide * feat
        feat = self.align_base(feat)

        # residual connection back to the gamma-compressed input
        return self.align_tail(feat) + x
|
|
|
class Fusion(nn.Module):
    """Fuses a 9-channel input (e.g. three stacked RGB images) to 3 channels.

    Same guide-modulated 1x1-conv design as GCMModel, but with 9 input
    channels, a shorter base stack, no gamma pre-processing and no residual
    connection on the output.
    """

    def __init__(self):
        super(Fusion, self).__init__()
        self.ch_1 = 16
        self.ch_2 = 32
        self.gcm_coord = None
        # coordinate conditioning is disabled, so both branches see the raw 9-channel stack
        guide_input_channels = 9
        align_input_channels = 9

        self.guide_net = N.seq(
            N.conv(guide_input_channels, self.ch_1, 7, stride=2, padding=0, mode='CR'),
            N.conv(self.ch_1, self.ch_1, kernel_size=3, stride=1, padding=1, mode='CRC'),
            nn.AdaptiveAvgPool2d(1),
            N.conv(self.ch_1, self.ch_2, 1, stride=1, padding=0, mode='C')
        )

        self.align_head = N.conv(align_input_channels, self.ch_2, 1, padding=0, mode='CR')

        self.align_base = N.seq(
            N.conv(self.ch_2, self.ch_2, kernel_size=1, stride=1, padding=0, mode='CRCRCR')
        )
        self.align_tail = N.seq(
            N.conv(self.ch_2, 3, 1, padding=0, mode='C')
        )

    def forward(self, demosaic_raw):
        """Fuse the 9-channel input into a 3-channel output."""
        # per-image guidance vector modulates the features channel-wise
        guide = self.guide_net(demosaic_raw)
        feat = self.align_head(demosaic_raw)
        feat = feat + guide * feat
        feat = self.align_base(feat)
        return self.align_tail(feat)
|
|
|
|
|
|
|
|
|
class CPNet(nn.Module):
    """Copy-and-paste alignment/inpainting network.

    Combines an alignment encoder/regressor (affine warps of reference
    frames onto the target), a value encoder, a copy-match module and a
    decoder. ``forward`` only performs source-to-target alignment;
    ``encoding``/``inpainting`` implement the full pipeline.
    """

    def __init__(self, mode='Train'):
        # ``mode`` is accepted for interface compatibility but currently unused.
        super(CPNet, self).__init__()
        self.A_Encoder = A_Encoder()
        self.A_Regressor = A_Regressor()
        self.GCMModel = GCMModel()
        self.Encoder = Encoder()
        self.CM_Module = CM_Module()

        self.Decoder = Decoder()

        # ImageNet mean in 2D and 3D (video) broadcast shapes
        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1))
        self.register_buffer('mean3d', torch.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1,1))

    def encoding(self, frames, holes):
        """Encode every frame of a (B, C, T, H, W) clip with the alignment encoder.

        Returns features stacked along a new time dimension: (B, 256, T, H/8, W/8).
        """
        batch_size, _, num_frames, height, width = frames.size()

        # pad so H and W are divisible by 8 (the encoder downsamples 8x)
        (frames, holes), pad = pad_divide_by([frames, holes], 8, (frames.size()[3], frames.size()[4]))

        feat_ = []
        for t in range(num_frames):
            # Bug fix: A_Encoder takes only the RGB frame; the hole mask was
            # previously passed as an unsupported second argument.
            feat = self.A_Encoder(frames[:, :, t])
            feat_.append(feat)
        feats = torch.stack(feat_, dim=2)
        return feats

    def inpainting(self, rfeats, rframes, rholes, frame, hole, gt):
        """Align references to the target, copy matched content, and decode.

        Returns the composited frame: prediction inside the hole, ground
        truth outside, clamped to [0, 1] and cropped back to the input size.
        """
        batch_size, _, height, width = frame.size()
        num_r = rfeats.size()[2]  # number of reference frames

        (rframes, rholes, frame, hole, gt), pad = pad_divide_by(
            [rframes, rholes, frame, hole, gt], 8, (height, width))

        # target features for alignment regression (RGB only — see A_Encoder)
        tfeat = self.A_Encoder(frame)

        # context features: target first, aligned references appended below
        c_feat_ = [self.Encoder(frame, hole)]

        aligned_r_ = []
        rvmap_ = []

        for r in range(num_r):
            # regress the affine transform from reference r to the target
            theta_rt = self.A_Regressor(tfeat, rfeats[:, :, r])
            grid_rt = F.affine_grid(theta_rt, frame.size())

            # warp the reference frame onto the target
            aligned_r = F.grid_sample(rframes[:, :, r], grid_rt)

            # validity map: visible (non-hole) pixels that survive the warp
            aligned_v = F.grid_sample(1 - rholes[:, :, r], grid_rt)
            aligned_v = (aligned_v > 0.5).float()

            aligned_r_.append(aligned_r)

            c_feat_.append(self.Encoder(aligned_r, aligned_v))

            rvmap_.append(aligned_v)

        aligned_rs = torch.stack(aligned_r_, 2)

        c_feats = torch.stack(c_feat_, dim=2)
        rvmaps = torch.stack(rvmap_, dim=2)

        # copy matched reference content into the target feature map
        p_in, c_mask = self.CM_Module(c_feats, 1 - hole, rvmaps)

        pred = self.Decoder(p_in)

        _, _, _, H, W = aligned_rs.shape
        # F.upsample is deprecated; F.interpolate is the supported replacement
        c_mask = (F.interpolate(c_mask, size=(H, W), mode='bilinear', align_corners=False)).detach()

        # paste the prediction into the hole, keep ground truth elsewhere
        comp = pred * (hole) + gt * (1. - hole)

        # undo the divisible-by-8 padding
        if pad[2] + pad[3] > 0:
            comp = comp[:, :, pad[2]:-pad[3], :]

        if pad[0] + pad[1] > 0:
            comp = comp[:, :, :, pad[0]:-pad[1]]

        comp = torch.clamp(comp, 0, 1)

        return comp

    def forward(self, Source, Target):
        """Affine-align ``Source`` to ``Target``.

        Returns (aligned_source, validity_mask) where the mask marks output
        pixels that were sampled from inside ``Source``.
        """
        feat_target = self.A_Encoder(Target)
        feat_source = self.A_Encoder(Source)

        theta = self.A_Regressor(feat_target, feat_source)
        grid_rt = F.affine_grid(theta, Target.size())
        aligned = F.grid_sample(Source, grid_rt)

        # warp an all-ones tensor to mark which pixels came from Source
        mask = torch.ones_like(Source)
        mask = F.grid_sample(mask, grid_rt)

        return aligned, mask
|
|
|
|
|
class AC(nn.Module):
    """Alignment/color module conditioned on a DSLR reference image.

    The raw input is offset, gamma-compressed and halved; a guidance vector
    computed from raw+DSLR modulates a 1x1-conv stack, with a residual
    connection back to the pre-processed raw input.
    """

    def __init__(self):
        super(AC, self).__init__()
        self.ch_1 = 32
        self.ch_2 = 64
        self.gcm_coord = None
        # coordinate conditioning is disabled: the guide branch sees raw+dslr
        # (6 channels), the align branch sees raw only (3 channels)
        guide_input_channels = 6
        align_input_channels = 3

        self.guide_net = N.seq(
            N.conv(guide_input_channels, self.ch_1, 7, stride=2, padding=0, mode='CR'),
            N.conv(self.ch_1, self.ch_1, kernel_size=3, stride=1, padding=1, mode='CRC'),
            nn.AdaptiveAvgPool2d(1),
            N.conv(self.ch_1, self.ch_2, 1, stride=1, padding=0, mode='C')
        )

        self.align_head = N.conv(align_input_channels, self.ch_2, 1, padding=0, mode='CR')

        self.align_base = N.seq(
            N.conv(self.ch_2, self.ch_2, kernel_size=1, stride=1, padding=0, mode='CRCRCR')
        )
        self.align_tail = N.seq(
            N.conv(self.ch_2, 3, 1, padding=0, mode='C')
        )

    def forward(self, demosaic_raw, dslr, coord=None):
        """Process the raw image, guided by ``dslr``; returns a 3-channel output."""
        # small offset avoids pow() at exactly zero, then gamma-compress and halve
        x = torch.pow(demosaic_raw + 0.01, 1 / 2.2) / 2

        if self.gcm_coord:
            guide_input = torch.cat((x, dslr, coord), 1)
            base_input = torch.cat((x, coord), 1)
        else:
            guide_input = torch.cat((x, dslr), 1)
            base_input = x

        # per-image guidance vector modulates the features channel-wise
        guide = self.guide_net(guide_input)
        feat = self.align_head(base_input)
        feat = feat + guide * feat
        feat = self.align_base(feat)

        # residual connection back to the pre-processed raw input
        return self.align_tail(feat) + x