# Provenance (captured from a hosted-repository file view): uploaded by
# faceplugin in commit 901e379 ("initial commit"); raw file size 23 kB.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# Module-level weight dictionary. irn50_pytorch.__init__ rebinds it to the
# result of load_weights(), and the class's private layer builders read it.
_weights_dict = {}
def load_weights(weight_file):
    """Load a pickled weight dictionary previously written with ``numpy.save``.

    Args:
        weight_file: path to a file readable by ``numpy.load`` that holds a
            single pickled object (the weight dictionary), or ``None``.

    Returns:
        The unpickled object (``.item()`` of the loaded array), or ``None``
        when ``weight_file`` is ``None``.

    Raises:
        Whatever ``numpy.load`` raises for a missing/unreadable file, and
        ``ValueError`` if the file does not hold exactly one object.
    """
    if weight_file is None:
        return None
    # SECURITY: allow_pickle=True unpickles arbitrary objects, which can run
    # code. Only load weight files from a trusted source.
    try:
        weights_dict = np.load(weight_file, allow_pickle=True).item()
    except UnicodeError:
        # numpy.load raises UnicodeError when the pickle contains byte strings
        # the default ASCII decoding rejects (typically Python 2 pickles);
        # retry keeping such strings as raw bytes.
        weights_dict = np.load(weight_file, allow_pickle=True, encoding='bytes').item()
    return weights_dict
class irn50_pytorch(nn.Module):
    """Residual CNN mapping an image batch to 256-dimensional output vectors.

    Layer layout (as built in ``__init__`` and wired in ``forward``):

    * stem: six conv -> BN -> ReLU layers, with a 3x3 / stride-2 max-pool
      after the third;
    * ``conv2_*``: three bottleneck residual units at 256 channels;
    * ``conv3_*``: four bottleneck units at 512 channels (the first one
      strides by 2);
    * ``conv4_*``: one 3x3 + 1x1 unit at 512 channels, then two bottleneck
      units at 1024 channels;
    * head: BN -> ReLU -> 4x4 average pool -> 16384->512 linear -> BN, then an
      element-wise max of the two 256-wide halves.

    Every residual unit except ``conv2_res1`` starts with BN + ReLU
    (``*_pre_bn``). Units that own a ``*proj`` 1x1 convolution add it (applied
    to the activated input) to the branch output; the others add the branch
    output to the unit's raw input.

    The linear layer expects 1024 * 4 * 4 = 16384 features, i.e. a 7x7 map
    before the 4x4 average pool. Working through the strides, that means each
    input spatial side must be 113-128 px (e.g. 128x128).
    NOTE(review): confirm the intended input size / normalisation with callers.

    Parameters come from the dictionary returned by ``load_weights``. Each
    entry is keyed by layer name and holds:

    * conv/linear layers: ``'weights'`` and an optional ``'bias'``;
    * BN layers: ``'mean'``, ``'var'``, and optional ``'scale'`` / ``'bias'``.
    """

    def __init__(self, weight_file):
        """Build every layer and copy its pretrained parameters in.

        Args:
            weight_file: path handed to ``load_weights``. ``None`` skips weight
                loading and keeps PyTorch's default initialisation, e.g. when
                a ``state_dict`` will be loaded afterwards.
        """
        super(irn50_pytorch, self).__init__()
        # The private layer builders read this module-level dict, so it must be
        # populated before any layer is constructed.
        global _weights_dict
        _weights_dict = load_weights(weight_file)
        # All BN layers use eps = float32(1e-5) and momentum=0.0. With momentum
        # 0, training-mode passes never update the loaded running statistics.
        # Stem.
        self.Convolution1 = self.__conv(2, name='Convolution1', in_channels=3, out_channels=32, kernel_size=(3, 3), stride=(2, 2), groups=1, bias=False)
        self.BatchNorm1 = self.__batch_normalization(2, 'BatchNorm1', num_features=32, eps=9.999999747378752e-06, momentum=0.0)
        self.Convolution2 = self.__conv(2, name='Convolution2', in_channels=32, out_channels=32, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.BatchNorm2 = self.__batch_normalization(2, 'BatchNorm2', num_features=32, eps=9.999999747378752e-06, momentum=0.0)
        self.Convolution3 = self.__conv(2, name='Convolution3', in_channels=32, out_channels=64, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.BatchNorm3 = self.__batch_normalization(2, 'BatchNorm3', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
        self.Convolution4 = self.__conv(2, name='Convolution4', in_channels=64, out_channels=80, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.BatchNorm4 = self.__batch_normalization(2, 'BatchNorm4', num_features=80, eps=9.999999747378752e-06, momentum=0.0)
        self.Convolution5 = self.__conv(2, name='Convolution5', in_channels=80, out_channels=192, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.BatchNorm5 = self.__batch_normalization(2, 'BatchNorm5', num_features=192, eps=9.999999747378752e-06, momentum=0.0)
        self.Convolution6 = self.__conv(2, name='Convolution6', in_channels=192, out_channels=256, kernel_size=(3, 3), stride=(2, 2), groups=1, bias=False)
        self.BatchNorm6 = self.__batch_normalization(2, 'BatchNorm6', num_features=256, eps=9.999999747378752e-06, momentum=0.0)
        # Stage conv2 (256 channels).
        self.conv2_res1_proj = self.__conv(2, name='conv2_res1_proj', in_channels=256, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv2_res1_conv1 = self.__conv(2, name='conv2_res1_conv1', in_channels=256, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv2_res1_conv1_bn = self.__batch_normalization(2, 'conv2_res1_conv1_bn', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_res1_conv2 = self.__conv(2, name='conv2_res1_conv2', in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.conv2_res1_conv2_bn = self.__batch_normalization(2, 'conv2_res1_conv2_bn', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_res1_conv3 = self.__conv(2, name='conv2_res1_conv3', in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv2_res2_pre_bn = self.__batch_normalization(2, 'conv2_res2_pre_bn', num_features=256, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_res2_conv1 = self.__conv(2, name='conv2_res2_conv1', in_channels=256, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv2_res2_conv1_bn = self.__batch_normalization(2, 'conv2_res2_conv1_bn', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_res2_conv2 = self.__conv(2, name='conv2_res2_conv2', in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.conv2_res2_conv2_bn = self.__batch_normalization(2, 'conv2_res2_conv2_bn', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_res2_conv3 = self.__conv(2, name='conv2_res2_conv3', in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv2_res3_pre_bn = self.__batch_normalization(2, 'conv2_res3_pre_bn', num_features=256, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_res3_conv1 = self.__conv(2, name='conv2_res3_conv1', in_channels=256, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv2_res3_conv1_bn = self.__batch_normalization(2, 'conv2_res3_conv1_bn', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_res3_conv2 = self.__conv(2, name='conv2_res3_conv2', in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.conv2_res3_conv2_bn = self.__batch_normalization(2, 'conv2_res3_conv2_bn', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_res3_conv3 = self.__conv(2, name='conv2_res3_conv3', in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        # Stage conv3 (512 channels; the first unit downsamples by 2).
        self.conv3_res1_pre_bn = self.__batch_normalization(2, 'conv3_res1_pre_bn', num_features=256, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res1_proj = self.__conv(2, name='conv3_res1_proj', in_channels=256, out_channels=512, kernel_size=(1, 1), stride=(2, 2), groups=1, bias=False)
        self.conv3_res1_conv1 = self.__conv(2, name='conv3_res1_conv1', in_channels=256, out_channels=128, kernel_size=(1, 1), stride=(2, 2), groups=1, bias=False)
        self.conv3_res1_conv1_bn = self.__batch_normalization(2, 'conv3_res1_conv1_bn', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res1_conv2 = self.__conv(2, name='conv3_res1_conv2', in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.conv3_res1_conv2_bn = self.__batch_normalization(2, 'conv3_res1_conv2_bn', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res1_conv3 = self.__conv(2, name='conv3_res1_conv3', in_channels=128, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv3_res2_pre_bn = self.__batch_normalization(2, 'conv3_res2_pre_bn', num_features=512, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res2_conv1 = self.__conv(2, name='conv3_res2_conv1', in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv3_res2_conv1_bn = self.__batch_normalization(2, 'conv3_res2_conv1_bn', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res2_conv2 = self.__conv(2, name='conv3_res2_conv2', in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.conv3_res2_conv2_bn = self.__batch_normalization(2, 'conv3_res2_conv2_bn', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res2_conv3 = self.__conv(2, name='conv3_res2_conv3', in_channels=128, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv3_res3_pre_bn = self.__batch_normalization(2, 'conv3_res3_pre_bn', num_features=512, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res3_conv1 = self.__conv(2, name='conv3_res3_conv1', in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv3_res3_conv1_bn = self.__batch_normalization(2, 'conv3_res3_conv1_bn', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res3_conv2 = self.__conv(2, name='conv3_res3_conv2', in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.conv3_res3_conv2_bn = self.__batch_normalization(2, 'conv3_res3_conv2_bn', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res3_conv3 = self.__conv(2, name='conv3_res3_conv3', in_channels=128, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv3_res4_pre_bn = self.__batch_normalization(2, 'conv3_res4_pre_bn', num_features=512, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res4_conv1 = self.__conv(2, name='conv3_res4_conv1', in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv3_res4_conv1_bn = self.__batch_normalization(2, 'conv3_res4_conv1_bn', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res4_conv2 = self.__conv(2, name='conv3_res4_conv2', in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.conv3_res4_conv2_bn = self.__batch_normalization(2, 'conv3_res4_conv2_bn', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res4_conv3 = self.__conv(2, name='conv3_res4_conv3', in_channels=128, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        # Stage conv4 (512 -> 1024 channels).
        self.conv4_res1_pre_bn = self.__batch_normalization(2, 'conv4_res1_pre_bn', num_features=512, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_res1_proj = self.__conv(2, name='conv4_res1_proj', in_channels=512, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_res1_conv1 = self.__conv(2, name='conv4_res1_conv1', in_channels=512, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.conv4_res1_conv1_bn = self.__batch_normalization(2, 'conv4_res1_conv1_bn', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_res1_conv2 = self.__conv(2, name='conv4_res1_conv2', in_channels=128, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_res2_pre_bn = self.__batch_normalization(2, 'conv4_res2_pre_bn', num_features=512, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_res2_conv1_proj = self.__conv(2, name='conv4_res2_conv1_proj', in_channels=512, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_res2_conv1 = self.__conv(2, name='conv4_res2_conv1', in_channels=512, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_res2_conv1_bn = self.__batch_normalization(2, 'conv4_res2_conv1_bn', num_features=256, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_res2_conv2 = self.__conv(2, name='conv4_res2_conv2', in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.conv4_res2_conv2_bn = self.__batch_normalization(2, 'conv4_res2_conv2_bn', num_features=256, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_res2_conv3 = self.__conv(2, name='conv4_res2_conv3', in_channels=256, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_res3_pre_bn = self.__batch_normalization(2, 'conv4_res3_pre_bn', num_features=1024, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_res3_conv1 = self.__conv(2, name='conv4_res3_conv1', in_channels=1024, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_res3_conv1_bn = self.__batch_normalization(2, 'conv4_res3_conv1_bn', num_features=256, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_res3_conv2 = self.__conv(2, name='conv4_res3_conv2', in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.conv4_res3_conv2_bn = self.__batch_normalization(2, 'conv4_res3_conv2_bn', num_features=256, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_res3_conv3 = self.__conv(2, name='conv4_res3_conv3', in_channels=256, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        # Head.
        self.conv5_bn = self.__batch_normalization(2, 'conv5_bn', num_features=1024, eps=9.999999747378752e-06, momentum=0.0)
        self.fc1_1 = self.__dense(name = 'fc1_1', in_features = 16384, out_features = 512, bias = False)
        self.bn_fc1 = self.__batch_normalization(1, 'bn_fc1', num_features=512, eps=9.999999747378752e-06, momentum=0.0)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: tensor of shape (N, 3, H, W). See the class docstring for the
                H/W range implied by the linear layer's input size.

        Returns:
            Tensor of shape (N, 256): the element-wise max of the two halves of
            the 512-d batch-normalised linear output.
        """
        # --- Stem. 3x3 convolutions with spatial padding are preceded by an
        # explicit symmetric zero pad of 1 (the conv layers themselves do not pad).
        Convolution1 = self.Convolution1(x)
        BatchNorm1 = self.BatchNorm1(Convolution1)
        ReLU1 = F.relu(BatchNorm1)
        Convolution2 = self.Convolution2(ReLU1)
        BatchNorm2 = self.BatchNorm2(Convolution2)
        ReLU2 = F.relu(BatchNorm2)
        Convolution3_pad = F.pad(ReLU2, (1, 1, 1, 1))
        Convolution3 = self.Convolution3(Convolution3_pad)
        BatchNorm3 = self.BatchNorm3(Convolution3)
        ReLU3 = F.relu(BatchNorm3)
        # Pad right/bottom only, with -inf so padded cells never win the max.
        Pooling1_pad = F.pad(ReLU3, (0, 1, 0, 1), value=float('-inf'))
        Pooling1 = F.max_pool2d(Pooling1_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False)
        Convolution4 = self.Convolution4(Pooling1)
        BatchNorm4 = self.BatchNorm4(Convolution4)
        ReLU4 = F.relu(BatchNorm4)
        Convolution5 = self.Convolution5(ReLU4)
        BatchNorm5 = self.BatchNorm5(Convolution5)
        ReLU5 = F.relu(BatchNorm5)
        Convolution6_pad = F.pad(ReLU5, (1, 1, 1, 1))
        Convolution6 = self.Convolution6(Convolution6_pad)
        BatchNorm6 = self.BatchNorm6(Convolution6)
        ReLU6 = F.relu(BatchNorm6)

        # --- Stage conv2. res1: projection shortcut on the (already activated)
        # stem output.
        conv2_res1_proj = self.conv2_res1_proj(ReLU6)
        conv2_res1_conv1 = self.conv2_res1_conv1(ReLU6)
        conv2_res1_conv1_bn = self.conv2_res1_conv1_bn(conv2_res1_conv1)
        conv2_res1_conv1_relu = F.relu(conv2_res1_conv1_bn)
        conv2_res1_conv2_pad = F.pad(conv2_res1_conv1_relu, (1, 1, 1, 1))
        conv2_res1_conv2 = self.conv2_res1_conv2(conv2_res1_conv2_pad)
        conv2_res1_conv2_bn = self.conv2_res1_conv2_bn(conv2_res1_conv2)
        conv2_res1_conv2_relu = F.relu(conv2_res1_conv2_bn)
        conv2_res1_conv3 = self.conv2_res1_conv3(conv2_res1_conv2_relu)
        conv2_res1 = conv2_res1_proj + conv2_res1_conv3
        # res2/res3: pre-activation units with identity shortcuts.
        conv2_res2_pre_bn = self.conv2_res2_pre_bn(conv2_res1)
        conv2_res2_pre_relu = F.relu(conv2_res2_pre_bn)
        conv2_res2_conv1 = self.conv2_res2_conv1(conv2_res2_pre_relu)
        conv2_res2_conv1_bn = self.conv2_res2_conv1_bn(conv2_res2_conv1)
        conv2_res2_conv1_relu = F.relu(conv2_res2_conv1_bn)
        conv2_res2_conv2_pad = F.pad(conv2_res2_conv1_relu, (1, 1, 1, 1))
        conv2_res2_conv2 = self.conv2_res2_conv2(conv2_res2_conv2_pad)
        conv2_res2_conv2_bn = self.conv2_res2_conv2_bn(conv2_res2_conv2)
        conv2_res2_conv2_relu = F.relu(conv2_res2_conv2_bn)
        conv2_res2_conv3 = self.conv2_res2_conv3(conv2_res2_conv2_relu)
        conv2_res2 = conv2_res1 + conv2_res2_conv3
        conv2_res3_pre_bn = self.conv2_res3_pre_bn(conv2_res2)
        conv2_res3_pre_relu = F.relu(conv2_res3_pre_bn)
        conv2_res3_conv1 = self.conv2_res3_conv1(conv2_res3_pre_relu)
        conv2_res3_conv1_bn = self.conv2_res3_conv1_bn(conv2_res3_conv1)
        conv2_res3_conv1_relu = F.relu(conv2_res3_conv1_bn)
        conv2_res3_conv2_pad = F.pad(conv2_res3_conv1_relu, (1, 1, 1, 1))
        conv2_res3_conv2 = self.conv2_res3_conv2(conv2_res3_conv2_pad)
        conv2_res3_conv2_bn = self.conv2_res3_conv2_bn(conv2_res3_conv2)
        conv2_res3_conv2_relu = F.relu(conv2_res3_conv2_bn)
        conv2_res3_conv3 = self.conv2_res3_conv3(conv2_res3_conv2_relu)
        conv2_res3 = conv2_res2 + conv2_res3_conv3

        # --- Stage conv3. res1 downsamples by 2 in both the projection and the
        # branch.
        conv3_res1_pre_bn = self.conv3_res1_pre_bn(conv2_res3)
        conv3_res1_pre_relu = F.relu(conv3_res1_pre_bn)
        conv3_res1_proj = self.conv3_res1_proj(conv3_res1_pre_relu)
        conv3_res1_conv1 = self.conv3_res1_conv1(conv3_res1_pre_relu)
        conv3_res1_conv1_bn = self.conv3_res1_conv1_bn(conv3_res1_conv1)
        conv3_res1_conv1_relu = F.relu(conv3_res1_conv1_bn)
        conv3_res1_conv2_pad = F.pad(conv3_res1_conv1_relu, (1, 1, 1, 1))
        conv3_res1_conv2 = self.conv3_res1_conv2(conv3_res1_conv2_pad)
        conv3_res1_conv2_bn = self.conv3_res1_conv2_bn(conv3_res1_conv2)
        conv3_res1_conv2_relu = F.relu(conv3_res1_conv2_bn)
        conv3_res1_conv3 = self.conv3_res1_conv3(conv3_res1_conv2_relu)
        conv3_res1 = conv3_res1_proj + conv3_res1_conv3
        conv3_res2_pre_bn = self.conv3_res2_pre_bn(conv3_res1)
        conv3_res2_pre_relu = F.relu(conv3_res2_pre_bn)
        conv3_res2_conv1 = self.conv3_res2_conv1(conv3_res2_pre_relu)
        conv3_res2_conv1_bn = self.conv3_res2_conv1_bn(conv3_res2_conv1)
        conv3_res2_conv1_relu = F.relu(conv3_res2_conv1_bn)
        conv3_res2_conv2_pad = F.pad(conv3_res2_conv1_relu, (1, 1, 1, 1))
        conv3_res2_conv2 = self.conv3_res2_conv2(conv3_res2_conv2_pad)
        conv3_res2_conv2_bn = self.conv3_res2_conv2_bn(conv3_res2_conv2)
        conv3_res2_conv2_relu = F.relu(conv3_res2_conv2_bn)
        conv3_res2_conv3 = self.conv3_res2_conv3(conv3_res2_conv2_relu)
        conv3_res2 = conv3_res1 + conv3_res2_conv3
        conv3_res3_pre_bn = self.conv3_res3_pre_bn(conv3_res2)
        conv3_res3_pre_relu = F.relu(conv3_res3_pre_bn)
        conv3_res3_conv1 = self.conv3_res3_conv1(conv3_res3_pre_relu)
        conv3_res3_conv1_bn = self.conv3_res3_conv1_bn(conv3_res3_conv1)
        conv3_res3_conv1_relu = F.relu(conv3_res3_conv1_bn)
        conv3_res3_conv2_pad = F.pad(conv3_res3_conv1_relu, (1, 1, 1, 1))
        conv3_res3_conv2 = self.conv3_res3_conv2(conv3_res3_conv2_pad)
        conv3_res3_conv2_bn = self.conv3_res3_conv2_bn(conv3_res3_conv2)
        conv3_res3_conv2_relu = F.relu(conv3_res3_conv2_bn)
        conv3_res3_conv3 = self.conv3_res3_conv3(conv3_res3_conv2_relu)
        conv3_res3 = conv3_res2 + conv3_res3_conv3
        conv3_res4_pre_bn = self.conv3_res4_pre_bn(conv3_res3)
        conv3_res4_pre_relu = F.relu(conv3_res4_pre_bn)
        conv3_res4_conv1 = self.conv3_res4_conv1(conv3_res4_pre_relu)
        conv3_res4_conv1_bn = self.conv3_res4_conv1_bn(conv3_res4_conv1)
        conv3_res4_conv1_relu = F.relu(conv3_res4_conv1_bn)
        conv3_res4_conv2_pad = F.pad(conv3_res4_conv1_relu, (1, 1, 1, 1))
        conv3_res4_conv2 = self.conv3_res4_conv2(conv3_res4_conv2_pad)
        conv3_res4_conv2_bn = self.conv3_res4_conv2_bn(conv3_res4_conv2)
        conv3_res4_conv2_relu = F.relu(conv3_res4_conv2_bn)
        conv3_res4_conv3 = self.conv3_res4_conv3(conv3_res4_conv2_relu)
        conv3_res4 = conv3_res3 + conv3_res4_conv3

        # --- Stage conv4. res1 is a two-convolution unit (3x3 then 1x1).
        conv4_res1_pre_bn = self.conv4_res1_pre_bn(conv3_res4)
        conv4_res1_pre_relu = F.relu(conv4_res1_pre_bn)
        conv4_res1_proj = self.conv4_res1_proj(conv4_res1_pre_relu)
        conv4_res1_conv1_pad = F.pad(conv4_res1_pre_relu, (1, 1, 1, 1))
        conv4_res1_conv1 = self.conv4_res1_conv1(conv4_res1_conv1_pad)
        conv4_res1_conv1_bn = self.conv4_res1_conv1_bn(conv4_res1_conv1)
        conv4_res1_conv1_relu = F.relu(conv4_res1_conv1_bn)
        conv4_res1_conv2 = self.conv4_res1_conv2(conv4_res1_conv1_relu)
        conv4_res1 = conv4_res1_proj + conv4_res1_conv2
        # res2 widens 512 -> 1024 through a projection shortcut.
        conv4_res2_pre_bn = self.conv4_res2_pre_bn(conv4_res1)
        conv4_res2_pre_relu = F.relu(conv4_res2_pre_bn)
        conv4_res2_conv1_proj = self.conv4_res2_conv1_proj(conv4_res2_pre_relu)
        conv4_res2_conv1 = self.conv4_res2_conv1(conv4_res2_pre_relu)
        conv4_res2_conv1_bn = self.conv4_res2_conv1_bn(conv4_res2_conv1)
        conv4_res2_conv1_relu = F.relu(conv4_res2_conv1_bn)
        conv4_res2_conv2_pad = F.pad(conv4_res2_conv1_relu, (1, 1, 1, 1))
        conv4_res2_conv2 = self.conv4_res2_conv2(conv4_res2_conv2_pad)
        conv4_res2_conv2_bn = self.conv4_res2_conv2_bn(conv4_res2_conv2)
        conv4_res2_conv2_relu = F.relu(conv4_res2_conv2_bn)
        conv4_res2_conv3 = self.conv4_res2_conv3(conv4_res2_conv2_relu)
        conv4_res2 = conv4_res2_conv1_proj + conv4_res2_conv3
        conv4_res3_pre_bn = self.conv4_res3_pre_bn(conv4_res2)
        conv4_res3_pre_relu = F.relu(conv4_res3_pre_bn)
        conv4_res3_conv1 = self.conv4_res3_conv1(conv4_res3_pre_relu)
        conv4_res3_conv1_bn = self.conv4_res3_conv1_bn(conv4_res3_conv1)
        conv4_res3_conv1_relu = F.relu(conv4_res3_conv1_bn)
        conv4_res3_conv2_pad = F.pad(conv4_res3_conv1_relu, (1, 1, 1, 1))
        conv4_res3_conv2 = self.conv4_res3_conv2(conv4_res3_conv2_pad)
        conv4_res3_conv2_bn = self.conv4_res3_conv2_bn(conv4_res3_conv2)
        conv4_res3_conv2_relu = F.relu(conv4_res3_conv2_bn)
        conv4_res3_conv3 = self.conv4_res3_conv3(conv4_res3_conv2_relu)
        conv4_res3 = conv4_res2 + conv4_res3_conv3

        # --- Head.
        conv5_bn = self.conv5_bn(conv4_res3)
        conv5_relu = F.relu(conv5_bn)
        pool5 = F.avg_pool2d(conv5_relu, kernel_size=(4, 4), stride=(1, 1), padding=0, ceil_mode=False, count_include_pad=False)
        fc1_0 = pool5.view(pool5.size(0), -1)
        fc1_1 = self.fc1_1(fc1_0)
        bn_fc1 = self.bn_fc1(fc1_1)  # already (N, 512); no reshape needed
        # Element-wise max of the two 256-wide halves of the 512-d vector.
        slice_fc1, slice_fc2 = bn_fc1[:, :256], bn_fc1[:, 256:]
        eltwise_fc1 = torch.max(slice_fc1, slice_fc2)
        return eltwise_fc1

    @staticmethod
    def __conv(dim, name, **kwargs):
        """Create a 1-D/2-D/3-D convolution and load entry ``name`` into it.

        ``kwargs`` go straight to the ``nn.ConvNd`` constructor. When no weight
        file was given (``_weights_dict is None``) the layer keeps PyTorch's
        default initialisation.

        Raises:
            NotImplementedError: for ``dim`` outside 1-3.
            KeyError: if ``name`` is missing from a loaded weight dict.
        """
        if dim == 1:
            layer = nn.Conv1d(**kwargs)
        elif dim == 2:
            layer = nn.Conv2d(**kwargs)
        elif dim == 3:
            layer = nn.Conv3d(**kwargs)
        else:
            raise NotImplementedError()
        if _weights_dict is None:
            return layer
        entry = _weights_dict[name]
        # state_dict() returns detached tensors that share storage with the
        # parameters, so copy_ writes the values into the layer in place.
        params = layer.state_dict()
        params['weight'].copy_(torch.from_numpy(entry['weights']))
        if 'bias' in entry:
            params['bias'].copy_(torch.from_numpy(entry['bias']))
        return layer

    @staticmethod
    def __batch_normalization(dim, name, **kwargs):
        """Create a BatchNorm layer and load entry ``name`` into it.

        ``dim`` 0 or 1 selects ``BatchNorm1d``, 2 ``BatchNorm2d``, 3
        ``BatchNorm3d``. A missing ``'scale'`` / ``'bias'`` sets the affine
        weight to 1 / bias to 0. ``'mean'`` and ``'var'`` are required. When no
        weight file was given the layer keeps PyTorch's defaults.

        Raises:
            NotImplementedError: for ``dim`` outside 0-3.
            KeyError: if ``name`` (or its ``'mean'`` / ``'var'``) is missing.
        """
        if dim == 0 or dim == 1:
            layer = nn.BatchNorm1d(**kwargs)
        elif dim == 2:
            layer = nn.BatchNorm2d(**kwargs)
        elif dim == 3:
            layer = nn.BatchNorm3d(**kwargs)
        else:
            raise NotImplementedError()
        if _weights_dict is None:
            return layer
        entry = _weights_dict[name]
        params = layer.state_dict()
        if 'scale' in entry:
            params['weight'].copy_(torch.from_numpy(entry['scale']))
        else:
            layer.weight.data.fill_(1)
        if 'bias' in entry:
            params['bias'].copy_(torch.from_numpy(entry['bias']))
        else:
            layer.bias.data.fill_(0)
        params['running_mean'].copy_(torch.from_numpy(entry['mean']))
        params['running_var'].copy_(torch.from_numpy(entry['var']))
        return layer

    @staticmethod
    def __dense(name, **kwargs):
        """Create an ``nn.Linear`` from ``kwargs`` and load entry ``name`` into it.

        When no weight file was given the layer keeps PyTorch's default
        initialisation.

        Raises:
            KeyError: if ``name`` is missing from a loaded weight dict.
        """
        layer = nn.Linear(**kwargs)
        if _weights_dict is None:
            return layer
        entry = _weights_dict[name]
        params = layer.state_dict()
        params['weight'].copy_(torch.from_numpy(entry['weights']))
        if 'bias' in entry:
            params['bias'].copy_(torch.from_numpy(entry['bias']))
        return layer