Spaces:
Runtime error
Runtime error
import functools | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.nn import init | |
from torch.nn import Parameter as P | |
from diffusers import ModelMixin | |
from diffusers.configuration_utils import (ConfigMixin, | |
register_to_config) | |
def proj(x, y):
    """Return ``(y @ x.T) * y / (y @ y.T)``.

    For single-row ``x`` and ``y`` this is the projection of ``x`` onto ``y``.
    Both arguments must be 2-D because ``Tensor.mm`` is used.
    """
    along = y.mm(x.t()) * y
    return along / y.mm(y.t())
def gram_schmidt(x, ys):
    """Subtract from ``x`` its ``proj`` onto each tensor of ``ys``, in order.

    Each subtraction uses the already-reduced residual, i.e. this is the
    sequential (modified) Gram-Schmidt update.
    """
    residual = x
    for basis_vec in ys:
        residual = residual - proj(residual, basis_vec)
    return residual
def power_iteration(W, u_, update=True, eps=1e-12):
    """Run one power-iteration step per vector in ``u_`` to estimate singular values of ``W``.

    For every ``u`` in ``u_`` the right vector ``normalize(u @ W)`` and the new
    left vector ``normalize(v @ W.T)`` are computed, each orthogonalised (via
    ``gram_schmidt``) against the vectors already produced in this call.
    Shapes are presumably ``(1, W.size(0))`` row vectors -- verify against callers.

    Args:
        W: 2-D weight matrix.
        u_: list of left-vector tensors; overwritten in place when ``update``.
        update: copy the new left vectors back into ``u_``.
        eps: forwarded to ``F.normalize`` as its lower bound on the norm.

    Returns:
        ``(svs, us, vs)``: squeezed singular-value estimates ``v @ W.T @ u.T``,
        the new left vectors and the right vectors.
    """
    left_vecs, right_vecs, sigmas = [], [], []
    for idx, u_prev in enumerate(u_):
        # Vector updates are bookkeeping only: no gradient flows through them.
        with torch.no_grad():
            right = F.normalize(gram_schmidt(u_prev @ W, right_vecs), eps=eps)
            right_vecs.append(right)
            left = F.normalize(gram_schmidt(right @ W.t(), left_vecs), eps=eps)
            left_vecs.append(left)
            if update:
                u_[idx][:] = left
        # Computed outside no_grad so the estimate stays differentiable w.r.t. W.
        sigmas.append((right @ W.t() @ left.t()).squeeze())
    return sigmas, left_vecs, right_vecs
class LinearBlock(nn.Module):
    """Fully connected layer optionally followed by normalisation and an activation.

    Args:
        in_dim: input feature size of the ``nn.Linear`` layer.
        out_dim: output feature size (also the normalisation width).
        norm: ``'bn'`` (BatchNorm1d), ``'in'`` (InstanceNorm1d) or ``'none'``.
        act: ``'relu'`` (in-place), ``'lrelu'`` (slope 0.2, in-place),
            ``'tanh'`` or ``'none'``.
        use_sn: wrap the linear layer with ``nn.utils.spectral_norm``.

    Raises:
        ValueError: if ``norm`` or ``act`` is not one of the names above.
            (Previously an ``assert``, which is stripped under ``python -O``.)
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        norm='none',
        act='relu',
        use_sn=False
    ):
        super(LinearBlock, self).__init__()
        self.fc = nn.Linear(in_dim, out_dim, bias=True)
        if use_sn:
            self.fc = nn.utils.spectral_norm(self.fc)

        # Normalisation applied to the fc output.
        if norm == 'bn':
            self.norm = nn.BatchNorm1d(out_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm1d(out_dim)
        elif norm == 'none':
            self.norm = None
        else:
            raise ValueError("Unsupported normalization: {}".format(norm))

        # Activation applied last.
        if act == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif act == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif act == 'tanh':
            self.activation = nn.Tanh()
        elif act == 'none':
            self.activation = None
        else:
            raise ValueError("Unsupported activation: {}".format(act))

    def forward(self, x):
        out = self.fc(x)
        if self.norm is not None:
            out = self.norm(out)
        if self.activation is not None:
            out = self.activation(out)
        return out
class MLP(nn.Module):
    """Stack of ``LinearBlock``s applied to the per-sample flattened input.

    Layout: ``nf_in -> nf_mlp`` (with ``norm``/``act``), then ``num_blocks - 2``
    hidden ``nf_mlp -> nf_mlp`` blocks, then an ``nf_mlp -> nf_out`` block with
    no normalisation or activation. ``num_blocks`` below 2 still yields the
    first and last blocks. ``use_sn`` is forwarded to every block.
    """

    def __init__(
        self,
        nf_in,
        nf_out,
        nf_mlp,
        num_blocks,
        norm,
        act,
        use_sn=False
    ):
        super(MLP, self).__init__()
        blocks = [LinearBlock(nf_in, nf_mlp, norm=norm, act=act, use_sn=use_sn)]
        for _ in range(num_blocks - 2):
            blocks.append(LinearBlock(nf_mlp, nf_mlp, norm=norm, act=act, use_sn=use_sn))
        blocks.append(LinearBlock(nf_mlp, nf_out, norm='none', act='none', use_sn=use_sn))
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        # reshape (not view) so non-contiguous inputs such as transposed or
        # permuted tensors can also be flattened.
        return self.model(x.reshape(x.size(0), -1))
class SN(object):
    """Mixin adding BigGAN-style spectral normalisation to a weight-bearing module.

    Intended to be combined with an ``nn.Module`` subclass that defines
    ``self.weight``; ``SN.__init__`` must run after the module's own
    ``__init__`` because it calls ``register_buffer``.

    Args:
        num_svs: number of singular vectors/values tracked.
        num_itrs: power-iteration passes per ``W_()`` call (must be >= 1,
            otherwise ``W_`` has no estimate to divide by).
        num_outputs: length of each left-vector buffer ``u<i>`` (shape
            ``(1, num_outputs)``).
        transpose: transpose the flattened weight before power iteration.
        eps: forwarded to ``power_iteration``.
    """

    def __init__(
        self,
        num_svs,
        num_itrs,
        num_outputs,
        transpose=False,
        eps=1e-12
    ):
        self.num_itrs = num_itrs
        self.num_svs = num_svs
        self.transpose = transpose
        self.eps = eps
        for i in range(self.num_svs):
            self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
            self.register_buffer('sv%d' % i, torch.ones(1))

    # These must be properties: W_ iterates ``self.u`` and indexes ``self.sv``.
    @property
    def u(self):
        """List of the left-vector buffers ``u0 .. u<num_svs-1>``."""
        return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]

    @property
    def sv(self):
        """List of the singular-value buffers ``sv0 .. sv<num_svs-1>``."""
        return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]

    def W_(self):
        """Return ``self.weight`` divided by its leading singular-value estimate.

        In training mode the ``u`` buffers are updated by ``power_iteration``
        and the ``sv`` buffers record the latest estimates.
        """
        W_mat = self.weight.view(self.weight.size(0), -1)
        if self.transpose:
            W_mat = W_mat.t()
        for _ in range(self.num_itrs):
            svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
        if self.training:
            with torch.no_grad():
                for i, sv in enumerate(svs):
                    self.sv[i][:] = sv
        return self.weight / svs[0]
class SNConv2d(nn.Conv2d, SN):
    """``nn.Conv2d`` whose weight is divided by its spectral-norm estimate in ``forward``.

    ``forward_wo_sn`` runs the same convolution with the raw weight.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Conv2d.__init__(
            self, in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            groups=groups, bias=bias,
        )
        # One left vector per output channel row of the flattened weight.
        SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)

    def _conv(self, x, weight):
        """Apply this layer's convolution settings with the given weight."""
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, x):
        return self._conv(x, self.W_())

    def forward_wo_sn(self, x):
        return self._conv(x, self.weight)
class SNLinear(nn.Linear, SN):
    """``nn.Linear`` whose weight is divided by its spectral-norm estimate in ``forward``."""

    def __init__(self, in_features, out_features, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Linear.__init__(self, in_features, out_features, bias=bias)
        SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)

    def forward(self, x):
        normalized_weight = self.W_()
        return F.linear(x, normalized_weight, self.bias)
class Attention(nn.Module):
    """SAGAN/BigGAN-style self-attention over the spatial positions of a feature map.

    Queries come from every position; keys and values are 2x2 max-pooled, so
    H and W are presumably expected to be even. The attended output is added
    back to the input scaled by ``gamma``, which starts at 0. ``name`` is
    accepted but unused.
    """

    def __init__(
        self,
        ch,
        which_conv=SNConv2d,
        name='attention'
    ):
        super(Attention, self).__init__()
        self.ch = ch
        self.which_conv = which_conv

        def conv1x1(c_in, c_out):
            return self.which_conv(c_in, c_out, kernel_size=1, padding=0, bias=False)

        self.theta = conv1x1(ch, ch // 8)
        self.phi = conv1x1(ch, ch // 8)
        self.g = conv1x1(ch, ch // 2)
        self.o = conv1x1(ch // 2, ch)
        # Learnable gain on the attention branch.
        self.gamma = P(torch.tensor(0.), requires_grad=True)

    def forward(self, x, y=None):
        height, width = x.shape[2], x.shape[3]
        n_pix = height * width
        key_ch, val_ch = self.ch // 8, self.ch // 2
        query = self.theta(x)
        key = F.max_pool2d(self.phi(x), [2, 2])
        value = F.max_pool2d(self.g(x), [2, 2])
        query = query.view(-1, key_ch, n_pix)
        key = key.view(-1, key_ch, n_pix // 4)
        value = value.view(-1, val_ch, n_pix // 4)
        attn = F.softmax(torch.bmm(query.transpose(1, 2), key), -1)
        attended = torch.bmm(value, attn.transpose(1, 2)).view(-1, val_ch, height, width)
        return self.gamma * self.o(attended) + x
class DBlock(nn.Module):
    """BigGAN-style residual downsampling block.

    Main path: optional ReLU (when ``preactivation``), ``conv1``,
    ``activation``, ``conv2``, then ``downsample`` if given. The shortcut adds
    a 1x1 ``which_conv`` whenever channels change or ``downsample`` is set.
    ``which_conv`` is called as ``which_conv(c_in, c_out)`` for the main path,
    so it must already carry kernel size and padding (e.g. a
    ``functools.partial``).
    """

    def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
                 preactivation=False, activation=None, downsample=None,):
        super(DBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # "wide" blocks use the output width for the hidden layer.
        self.hidden_channels = out_channels if wide else in_channels
        self.which_conv = which_conv
        self.preactivation = preactivation
        self.activation = activation
        self.downsample = downsample
        self.conv1 = which_conv(in_channels, self.hidden_channels)
        self.conv2 = which_conv(self.hidden_channels, out_channels)
        self.learnable_sc = bool((in_channels != out_channels) or downsample)
        if self.learnable_sc:
            self.conv_sc = which_conv(in_channels, out_channels, kernel_size=1, padding=0)

    def shortcut(self, x):
        """Skip path: pre-activation blocks project then pool; others pool then project."""
        if not self.preactivation:
            if self.downsample:
                x = self.downsample(x)
            if self.learnable_sc:
                x = self.conv_sc(x)
            return x
        if self.learnable_sc:
            x = self.conv_sc(x)
        if self.downsample:
            x = self.downsample(x)
        return x

    def forward(self, x):
        # Out-of-place F.relu so that x itself is left intact for the shortcut.
        h = F.relu(x) if self.preactivation else x
        h = self.conv2(self.activation(self.conv1(h)))
        if self.downsample:
            h = self.downsample(h)
        return h + self.shortcut(x)
class GBlock(nn.Module):
    """BigGAN-style residual upsampling block with normalisation layers.

    ``forward`` returns ``conv2(act(bn2(conv1(up(act(bn1(x))))))) + sc(up(x))``,
    where ``up`` is applied only when ``upsample`` is truthy and ``sc`` is a 1x1
    ``which_conv`` created when channels change or ``upsample`` is given.

    ``which_conv`` is called as ``which_conv(c_in, c_out)`` for the main path,
    so it must already carry kernel size/padding (e.g. a ``functools.partial``).
    ``activation`` must be callable; the ``None`` default cannot be used.
    """

    def __init__(self, in_channels, out_channels,
                 which_conv=nn.Conv2d, which_bn=nn.BatchNorm2d, activation=None,
                 upsample=None):
        super(GBlock, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.which_conv, self.which_bn = which_conv, which_bn
        self.activation = activation
        self.upsample = upsample
        # Conv layers
        self.conv1 = self.which_conv(self.in_channels, self.out_channels)
        self.conv2 = self.which_conv(self.out_channels, self.out_channels)
        # bool(): a bare ``... or upsample`` would store the upsample module
        # itself here, registering it a second time as a child "learnable_sc".
        self.learnable_sc = bool(in_channels != out_channels or upsample)
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)
        # Normalisation layers
        self.bn1 = self.which_bn(in_channels)
        self.bn2 = self.which_bn(out_channels)

    def forward(self, x):
        h = self.activation(self.bn1(x))
        if self.upsample:
            h = self.upsample(h)
            x = self.upsample(x)
        h = self.conv1(h)
        h = self.activation(self.bn2(h))
        h = self.conv2(h)
        if self.learnable_sc:
            x = self.conv_sc(x)
        return h + x
class GBlock2(nn.Module):
    """Residual upsampling block without normalisation layers.

    ``forward`` computes ``h = conv2(act(conv1(up(act(x)))))`` and, when
    ``skip_connection`` is true, returns ``h + sc(up(x))``; otherwise ``h``.
    ``up`` is applied only when ``upsample`` is truthy, and ``sc`` is a 1x1
    ``which_conv`` created when channels change or ``upsample`` is given.

    ``which_conv`` is called as ``which_conv(c_in, c_out)`` for the main path,
    so it must already carry kernel size/padding. ``activation`` must be
    callable. NOTE(review): it is applied directly to ``x``, so an in-place
    activation (e.g. ``nn.ReLU(inplace=True)``) would also alter the skip path.
    """

    def __init__(self, in_channels, out_channels,
                 which_conv=nn.Conv2d, activation=None,
                 upsample=None, skip_connection=True):
        super(GBlock2, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.which_conv = which_conv
        self.activation = activation
        self.upsample = upsample
        # Conv layers
        self.conv1 = self.which_conv(self.in_channels, self.out_channels)
        self.conv2 = self.which_conv(self.out_channels, self.out_channels)
        # bool(): a bare ``... or upsample`` would store the upsample module
        # itself here, registering it a second time as a child "learnable_sc".
        self.learnable_sc = bool(in_channels != out_channels or upsample)
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)
        self.skip_connection = skip_connection

    def forward(self, x):
        h = self.activation(x)
        if self.upsample:
            h = self.upsample(h)
            x = self.upsample(x)
        h = self.conv1(h)
        h = self.activation(h)
        h = self.conv2(h)
        if self.learnable_sc:
            x = self.conv_sc(x)
        if self.skip_connection:
            return h + x
        return h
def content_encoder_arch(ch=64, out_channel_multiplier=1, input_nc=3):
    """Return the per-resolution block layout tables for ``ContentEncoder``.

    Args:
        ch: base channel width; every ``out_channels`` entry, and every
            ``in_channels`` entry after the first, is a multiple of it.
        out_channel_multiplier: accepted for interface compatibility; not used.
        input_nc: channel count of the encoder input (first ``in_channels``).

    Returns:
        dict keyed by input resolution (80, 96, 128, 256). Each value holds
        ``'in_channels'`` and ``'out_channels'`` (one entry per block) and
        ``'resolution'`` (sizes halving from ``resolution / 2``; presumably the
        feature-map size after each block).
    """
    def widths(multiples):
        return [ch * m for m in multiples]

    return {
        80: {'in_channels': [input_nc] + widths([1, 2]),
             'out_channels': widths([1, 2, 4]),
             'resolution': [40, 20, 10]},
        96: {'in_channels': [input_nc] + widths([1, 2]),
             'out_channels': widths([1, 2, 4]),
             'resolution': [48, 24, 12]},
        128: {'in_channels': [input_nc] + widths([1, 2, 4, 8]),
              'out_channels': widths([1, 2, 4, 8, 16]),
              'resolution': [64, 32, 16, 8, 4]},
        256: {'in_channels': [input_nc] + widths([1, 2, 4, 8, 8]),
              'out_channels': widths([1, 2, 4, 8, 8, 16]),
              'resolution': [128, 64, 32, 16, 8, 4]},
    }
class ContentEncoder(ModelMixin, ConfigMixin):
    """BigGAN-style residual encoder returning final and intermediate features.

    The input passes through one ``DBlock`` per entry of
    ``content_encoder_arch(G_ch, ..., input_nc)[resolution]``; every block
    ends with ``nn.AvgPool2d(2)`` and all but the first use pre-activation.

    Only ``G_ch``, ``G_wide``, ``resolution``, ``num_G_SVs``, ``num_G_SV_itrs``,
    ``G_activation``, ``SN_eps``, ``G_init``, ``G_param`` and ``input_nc``
    affect the network; the other arguments are stored as attributes or
    ignored, and all of them are recorded in ``self.config``.

    Args:
        resolution: one of 80, 96, 128, 256; anything else raises ``KeyError``.
        G_activation: activation module shared by all blocks; ``None`` (the
            default) builds ``nn.ReLU(inplace=False)`` for this instance.
        G_param: ``'SN'`` for spectral-normalised convs, anything else for
            plain ``nn.Conv2d``.
        G_init: ``'ortho'``, ``'N02'`` (normal, std 0.02), ``'glorot'`` or
            ``'xavier'``; other values only print a warning.
    """

    @register_to_config
    def __init__(self, G_ch=64, G_wide=True, resolution=128,
                 G_kernel_size=3, G_attn='64_32_16_8', n_classes=1000,
                 num_G_SVs=1, num_G_SV_itrs=1, G_activation=None,
                 SN_eps=1e-12, output_dim=1, G_fp16=False,
                 G_init='N02', G_param='SN', nf_mlp=512, nEmbedding=256,
                 input_nc=3, output_nc=3):
        super(ContentEncoder, self).__init__()
        if G_activation is None:
            # Built per instance rather than as a shared default-argument module.
            G_activation = nn.ReLU(inplace=False)
        self.ch = G_ch
        self.G_wide = G_wide
        self.resolution = resolution
        self.kernel_size = G_kernel_size
        self.attention = G_attn
        self.n_classes = n_classes
        self.activation = G_activation
        self.init = G_init
        self.G_param = G_param
        self.SN_eps = SN_eps
        self.fp16 = G_fp16
        # Block indices whose outputs may be saved in forward (the attribute's
        # spelling is kept for compatibility with existing callers).
        self.save_featrues = {
            80: [0, 1, 2, 3, 4],
            96: [0, 1, 2, 3, 4],
            128: [0, 1, 2, 3, 4],
            256: [0, 1, 2, 3, 4, 5],
        }[resolution]
        self.out_channel_nultipiler = 1
        self.arch = content_encoder_arch(self.ch, self.out_channel_nultipiler, input_nc)[resolution]

        if self.G_param == 'SN':
            self.which_conv = functools.partial(SNConv2d,
                                                kernel_size=3, padding=1,
                                                num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(SNLinear,
                                                  num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                                  eps=self.SN_eps)
        else:
            # Plain convolutions for any other G_param value.
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear

        # Each stage is a one-element ModuleList, keeping state_dict keys of the
        # form ``blocks.<i>.0.<...>``.
        self.blocks = nn.ModuleList([
            nn.ModuleList([DBlock(in_channels=self.arch['in_channels'][index],
                                  out_channels=self.arch['out_channels'][index],
                                  which_conv=self.which_conv,
                                  wide=self.G_wide,
                                  activation=self.activation,
                                  preactivation=(index > 0),
                                  downsample=nn.AvgPool2d(2))])
            for index in range(len(self.arch['out_channels']))
        ])
        self.init_weights()

    def init_weights(self):
        """Re-initialise every Conv2d/Linear/Embedding weight per ``self.init``.

        Also stores the total parameter count of those layers in
        ``self.param_count`` and prints it.
        """
        self.param_count = 0
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                if self.init == 'ortho':
                    init.orthogonal_(module.weight)
                elif self.init == 'N02':
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ['glorot', 'xavier']:
                    init.xavier_uniform_(module.weight)
                else:
                    print('Init style not recognized...')
                self.param_count += sum(p.data.nelement() for p in module.parameters())
        print("Param count for ContentEncoder's initialized parameters: %d" % self.param_count)

    def forward(self, x):
        """Encode ``x``.

        Returns:
            ``(h, residual_features)``: ``h`` is the last block's output;
            ``residual_features`` starts with ``x`` followed by the output of
            every block whose index is in ``self.save_featrues[:-1]``. For
            resolutions 80/96 (three blocks) that includes the final block,
            so ``h`` is also the last list entry; for 128/256 it is not.
            NOTE(review): confirm this asymmetry is intended by the consumer.
        """
        saved_indices = self.save_featrues[:-1]
        h = x
        residual_features = [h]
        for index, blocklist in enumerate(self.blocks):
            for block in blocklist:
                h = block(h)
            if index in saved_indices:
                residual_features.append(h)
        return h, residual_features