import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# Module-level cache of the weight dictionary; the static layer builders in
# irn50_pytorch read from it while the network is being constructed.
_weights_dict = dict()


def load_weights(weight_file):
    """Load a pickled weight dictionary saved with ``numpy.save``.

    Parameters
    ----------
    weight_file : str or None
        Path to a ``.npy`` file containing a pickled dict that maps layer
        names to their parameter arrays.  ``None`` skips loading.

    Returns
    -------
    dict or None
        The weight dictionary, or ``None`` when ``weight_file`` is ``None``.
    """
    if weight_file is None:
        return None
    try:
        weights_dict = np.load(weight_file, allow_pickle=True).item()
    except Exception:
        # Fallback for pickles written under Python 2: they need an explicit
        # byte encoding to be unpickled on Python 3.
        weights_dict = np.load(weight_file, allow_pickle=True, encoding='bytes').item()
    return weights_dict
class irn50_pytorch(nn.Module):
    """Auto-generated residual CNN ("irn50") with converted weights.

    The layer names and the statement-per-op style of ``forward`` mirror a
    serialized graph from another framework (converter-generated code);
    parameters are copied out of the dict returned by ``load_weights``.

    The network is a stem of six conv+BN blocks followed by pre-activation
    bottleneck residual stages (conv2_*, conv3_*, conv4_*), global feature
    pooling, a 512-unit linear layer, and an elementwise max over the two
    256-dim halves of that embedding (maxout-style reduction), returning a
    256-dim feature vector per input image.
    """

    def __init__(self, weight_file):
        """Build all layers and load their weights.

        Parameters
        ----------
        weight_file : str
            Path to the converted ``.npy`` weight dictionary.
            NOTE(review): passing ``None`` leaves ``_weights_dict`` as
            ``None`` and the ``_weights_dict[name]`` lookups below will
            fail — a valid weight file appears to be required.
        """
        super(irn50_pytorch, self).__init__()
        global _weights_dict
        _weights_dict = load_weights(weight_file)
        # --- Stem: 6 conv+BN blocks (eps is float32 1e-5). ---
        self.Convolution1 = self.__conv(2, name='Convolution1', in_channels=3, out_channels=32, kernel_size=(3, 3), stride=(2, 2), groups=1, bias=False)
        self.BatchNorm1 = self.__batch_normalization(2, 'BatchNorm1', num_features=32, eps=9.999999747378752e-06, momentum=0.0)
        self.Convolution2 = self.__conv(2, name='Convolution2', in_channels=32, out_channels=32, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.BatchNorm2 = self.__batch_normalization(2, 'BatchNorm2', num_features=32, eps=9.999999747378752e-06, momentum=0.0)
        self.Convolution3 = self.__conv(2, name='Convolution3', in_channels=32, out_channels=64, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.BatchNorm3 = self.__batch_normalization(2, 'BatchNorm3', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
        self.Convolution4 = self.__conv(2, name='Convolution4', in_channels=64, out_channels=80, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.BatchNorm4 = self.__batch_normalization(2, 'BatchNorm4', num_features=80, eps=9.999999747378752e-06, momentum=0.0)
        self.Convolution5 = self.__conv(2, name='Convolution5', in_channels=80, out_channels=192, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.BatchNorm5 = self.__batch_normalization(2, 'BatchNorm5', num_features=192, eps=9.999999747378752e-06, momentum=0.0)
        self.Convolution6 = self.__conv(2, name='Convolution6', in_channels=192, out_channels=256, kernel_size=(3, 3), stride=(2, 2), groups=1, bias=False)
        self.BatchNorm6 = self.__batch_normalization(2, 'BatchNorm6', num_features=256, eps=9.999999747378752e-06, momentum=0.0)
        # --- Stage conv2: three bottleneck residual units at 256 channels. ---
        self.conv2_res1_proj = self.__conv(2, name='conv2_res1_proj', in_channels=256, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv2_res1_conv1 = self.__conv(2, name='conv2_res1_conv1', in_channels=256, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv2_res1_conv1_bn = self.__batch_normalization(2, 'conv2_res1_conv1_bn', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_res1_conv2 = self.__conv(2, name='conv2_res1_conv2', in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.conv2_res1_conv2_bn = self.__batch_normalization(2, 'conv2_res1_conv2_bn', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_res1_conv3 = self.__conv(2, name='conv2_res1_conv3', in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv2_res2_pre_bn = self.__batch_normalization(2, 'conv2_res2_pre_bn', num_features=256, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_res2_conv1 = self.__conv(2, name='conv2_res2_conv1', in_channels=256, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv2_res2_conv1_bn = self.__batch_normalization(2, 'conv2_res2_conv1_bn', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_res2_conv2 = self.__conv(2, name='conv2_res2_conv2', in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.conv2_res2_conv2_bn = self.__batch_normalization(2, 'conv2_res2_conv2_bn', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_res2_conv3 = self.__conv(2, name='conv2_res2_conv3', in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv2_res3_pre_bn = self.__batch_normalization(2, 'conv2_res3_pre_bn', num_features=256, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_res3_conv1 = self.__conv(2, name='conv2_res3_conv1', in_channels=256, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv2_res3_conv1_bn = self.__batch_normalization(2, 'conv2_res3_conv1_bn', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_res3_conv2 = self.__conv(2, name='conv2_res3_conv2', in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.conv2_res3_conv2_bn = self.__batch_normalization(2, 'conv2_res3_conv2_bn', num_features=64, eps=9.999999747378752e-06, momentum=0.0)
        self.conv2_res3_conv3 = self.__conv(2, name='conv2_res3_conv3', in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        # --- Stage conv3: four bottleneck units at 512 channels (res1 downsamples, stride 2). ---
        self.conv3_res1_pre_bn = self.__batch_normalization(2, 'conv3_res1_pre_bn', num_features=256, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res1_proj = self.__conv(2, name='conv3_res1_proj', in_channels=256, out_channels=512, kernel_size=(1, 1), stride=(2, 2), groups=1, bias=False)
        self.conv3_res1_conv1 = self.__conv(2, name='conv3_res1_conv1', in_channels=256, out_channels=128, kernel_size=(1, 1), stride=(2, 2), groups=1, bias=False)
        self.conv3_res1_conv1_bn = self.__batch_normalization(2, 'conv3_res1_conv1_bn', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res1_conv2 = self.__conv(2, name='conv3_res1_conv2', in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.conv3_res1_conv2_bn = self.__batch_normalization(2, 'conv3_res1_conv2_bn', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res1_conv3 = self.__conv(2, name='conv3_res1_conv3', in_channels=128, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv3_res2_pre_bn = self.__batch_normalization(2, 'conv3_res2_pre_bn', num_features=512, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res2_conv1 = self.__conv(2, name='conv3_res2_conv1', in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv3_res2_conv1_bn = self.__batch_normalization(2, 'conv3_res2_conv1_bn', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res2_conv2 = self.__conv(2, name='conv3_res2_conv2', in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.conv3_res2_conv2_bn = self.__batch_normalization(2, 'conv3_res2_conv2_bn', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res2_conv3 = self.__conv(2, name='conv3_res2_conv3', in_channels=128, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv3_res3_pre_bn = self.__batch_normalization(2, 'conv3_res3_pre_bn', num_features=512, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res3_conv1 = self.__conv(2, name='conv3_res3_conv1', in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv3_res3_conv1_bn = self.__batch_normalization(2, 'conv3_res3_conv1_bn', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res3_conv2 = self.__conv(2, name='conv3_res3_conv2', in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.conv3_res3_conv2_bn = self.__batch_normalization(2, 'conv3_res3_conv2_bn', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res3_conv3 = self.__conv(2, name='conv3_res3_conv3', in_channels=128, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv3_res4_pre_bn = self.__batch_normalization(2, 'conv3_res4_pre_bn', num_features=512, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res4_conv1 = self.__conv(2, name='conv3_res4_conv1', in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv3_res4_conv1_bn = self.__batch_normalization(2, 'conv3_res4_conv1_bn', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res4_conv2 = self.__conv(2, name='conv3_res4_conv2', in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.conv3_res4_conv2_bn = self.__batch_normalization(2, 'conv3_res4_conv2_bn', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
        self.conv3_res4_conv3 = self.__conv(2, name='conv3_res4_conv3', in_channels=128, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        # --- Stage conv4: res1 is a two-conv unit (3x3 then 1x1); res2 widens to 1024 via its own projection. ---
        self.conv4_res1_pre_bn = self.__batch_normalization(2, 'conv4_res1_pre_bn', num_features=512, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_res1_proj = self.__conv(2, name='conv4_res1_proj', in_channels=512, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_res1_conv1 = self.__conv(2, name='conv4_res1_conv1', in_channels=512, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.conv4_res1_conv1_bn = self.__batch_normalization(2, 'conv4_res1_conv1_bn', num_features=128, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_res1_conv2 = self.__conv(2, name='conv4_res1_conv2', in_channels=128, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_res2_pre_bn = self.__batch_normalization(2, 'conv4_res2_pre_bn', num_features=512, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_res2_conv1_proj = self.__conv(2, name='conv4_res2_conv1_proj', in_channels=512, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_res2_conv1 = self.__conv(2, name='conv4_res2_conv1', in_channels=512, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_res2_conv1_bn = self.__batch_normalization(2, 'conv4_res2_conv1_bn', num_features=256, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_res2_conv2 = self.__conv(2, name='conv4_res2_conv2', in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.conv4_res2_conv2_bn = self.__batch_normalization(2, 'conv4_res2_conv2_bn', num_features=256, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_res2_conv3 = self.__conv(2, name='conv4_res2_conv3', in_channels=256, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_res3_pre_bn = self.__batch_normalization(2, 'conv4_res3_pre_bn', num_features=1024, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_res3_conv1 = self.__conv(2, name='conv4_res3_conv1', in_channels=1024, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        self.conv4_res3_conv1_bn = self.__batch_normalization(2, 'conv4_res3_conv1_bn', num_features=256, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_res3_conv2 = self.__conv(2, name='conv4_res3_conv2', in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False)
        self.conv4_res3_conv2_bn = self.__batch_normalization(2, 'conv4_res3_conv2_bn', num_features=256, eps=9.999999747378752e-06, momentum=0.0)
        self.conv4_res3_conv3 = self.__conv(2, name='conv4_res3_conv3', in_channels=256, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False)
        # --- Head: final BN, pooling (in forward), 512-d embedding + BN. ---
        # fc1_1 expects a flattened 16384 = 1024 channels x 4x4 spatial input,
        # which fixes the spatial size the network must reach at conv5.
        self.conv5_bn = self.__batch_normalization(2, 'conv5_bn', num_features=1024, eps=9.999999747378752e-06, momentum=0.0)
        self.fc1_1 = self.__dense(name = 'fc1_1', in_features = 16384, out_features = 512, bias = False)
        self.bn_fc1 = self.__batch_normalization(1, 'bn_fc1', num_features=512, eps=9.999999747378752e-06, momentum=0.0)

    def forward(self, x):
        """Run the network on a batch of images.

        Parameters
        ----------
        x : torch.Tensor
            Input batch, NCHW with 3 channels.  The spatial size must be
            such that the flattened conv5 output is 16384 features
            (see ``fc1_1``) — TODO confirm the intended input resolution
            against the original training pipeline.

        Returns
        -------
        torch.Tensor
            A (batch, 256) feature embedding.
        """
        # Stem: conv/BN/ReLU chain.  Explicit F.pad calls reproduce the
        # source framework's padding, so the conv layers themselves use
        # padding=0.
        Convolution1 = self.Convolution1(x)
        BatchNorm1 = self.BatchNorm1(Convolution1)
        ReLU1 = F.relu(BatchNorm1)
        Convolution2 = self.Convolution2(ReLU1)
        BatchNorm2 = self.BatchNorm2(Convolution2)
        ReLU2 = F.relu(BatchNorm2)
        Convolution3_pad = F.pad(ReLU2, (1, 1, 1, 1))
        Convolution3 = self.Convolution3(Convolution3_pad)
        BatchNorm3 = self.BatchNorm3(Convolution3)
        ReLU3 = F.relu(BatchNorm3)
        # Asymmetric (right/bottom only) -inf pad emulates the original
        # framework's "same"-style max pooling at stride 2.
        Pooling1_pad = F.pad(ReLU3, (0, 1, 0, 1), value=float('-inf'))
        Pooling1, Pooling1_idx = F.max_pool2d(Pooling1_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False, return_indices=True)
        Convolution4 = self.Convolution4(Pooling1)
        BatchNorm4 = self.BatchNorm4(Convolution4)
        ReLU4 = F.relu(BatchNorm4)
        Convolution5 = self.Convolution5(ReLU4)
        BatchNorm5 = self.BatchNorm5(Convolution5)
        ReLU5 = F.relu(BatchNorm5)
        Convolution6_pad = F.pad(ReLU5, (1, 1, 1, 1))
        Convolution6 = self.Convolution6(Convolution6_pad)
        BatchNorm6 = self.BatchNorm6(Convolution6)
        ReLU6 = F.relu(BatchNorm6)
        # Stage conv2, unit 1: projection shortcut + 1x1/3x3/1x1 bottleneck.
        conv2_res1_proj = self.conv2_res1_proj(ReLU6)
        conv2_res1_conv1 = self.conv2_res1_conv1(ReLU6)
        conv2_res1_conv1_bn = self.conv2_res1_conv1_bn(conv2_res1_conv1)
        conv2_res1_conv1_relu = F.relu(conv2_res1_conv1_bn)
        conv2_res1_conv2_pad = F.pad(conv2_res1_conv1_relu, (1, 1, 1, 1))
        conv2_res1_conv2 = self.conv2_res1_conv2(conv2_res1_conv2_pad)
        conv2_res1_conv2_bn = self.conv2_res1_conv2_bn(conv2_res1_conv2)
        conv2_res1_conv2_relu = F.relu(conv2_res1_conv2_bn)
        conv2_res1_conv3 = self.conv2_res1_conv3(conv2_res1_conv2_relu)
        conv2_res1 = conv2_res1_proj + conv2_res1_conv3
        # Stage conv2, unit 2: pre-activation (BN+ReLU before the convs),
        # identity shortcut.
        conv2_res2_pre_bn = self.conv2_res2_pre_bn(conv2_res1)
        conv2_res2_pre_relu = F.relu(conv2_res2_pre_bn)
        conv2_res2_conv1 = self.conv2_res2_conv1(conv2_res2_pre_relu)
        conv2_res2_conv1_bn = self.conv2_res2_conv1_bn(conv2_res2_conv1)
        conv2_res2_conv1_relu = F.relu(conv2_res2_conv1_bn)
        conv2_res2_conv2_pad = F.pad(conv2_res2_conv1_relu, (1, 1, 1, 1))
        conv2_res2_conv2 = self.conv2_res2_conv2(conv2_res2_conv2_pad)
        conv2_res2_conv2_bn = self.conv2_res2_conv2_bn(conv2_res2_conv2)
        conv2_res2_conv2_relu = F.relu(conv2_res2_conv2_bn)
        conv2_res2_conv3 = self.conv2_res2_conv3(conv2_res2_conv2_relu)
        conv2_res2 = conv2_res1 + conv2_res2_conv3
        # Stage conv2, unit 3: same pre-activation bottleneck pattern.
        conv2_res3_pre_bn = self.conv2_res3_pre_bn(conv2_res2)
        conv2_res3_pre_relu = F.relu(conv2_res3_pre_bn)
        conv2_res3_conv1 = self.conv2_res3_conv1(conv2_res3_pre_relu)
        conv2_res3_conv1_bn = self.conv2_res3_conv1_bn(conv2_res3_conv1)
        conv2_res3_conv1_relu = F.relu(conv2_res3_conv1_bn)
        conv2_res3_conv2_pad = F.pad(conv2_res3_conv1_relu, (1, 1, 1, 1))
        conv2_res3_conv2 = self.conv2_res3_conv2(conv2_res3_conv2_pad)
        conv2_res3_conv2_bn = self.conv2_res3_conv2_bn(conv2_res3_conv2)
        conv2_res3_conv2_relu = F.relu(conv2_res3_conv2_bn)
        conv2_res3_conv3 = self.conv2_res3_conv3(conv2_res3_conv2_relu)
        conv2_res3 = conv2_res2 + conv2_res3_conv3
        # Stage conv3, unit 1: downsampling unit — both the projection and
        # the first bottleneck conv use stride 2.
        conv3_res1_pre_bn = self.conv3_res1_pre_bn(conv2_res3)
        conv3_res1_pre_relu = F.relu(conv3_res1_pre_bn)
        conv3_res1_proj = self.conv3_res1_proj(conv3_res1_pre_relu)
        conv3_res1_conv1 = self.conv3_res1_conv1(conv3_res1_pre_relu)
        conv3_res1_conv1_bn = self.conv3_res1_conv1_bn(conv3_res1_conv1)
        conv3_res1_conv1_relu = F.relu(conv3_res1_conv1_bn)
        conv3_res1_conv2_pad = F.pad(conv3_res1_conv1_relu, (1, 1, 1, 1))
        conv3_res1_conv2 = self.conv3_res1_conv2(conv3_res1_conv2_pad)
        conv3_res1_conv2_bn = self.conv3_res1_conv2_bn(conv3_res1_conv2)
        conv3_res1_conv2_relu = F.relu(conv3_res1_conv2_bn)
        conv3_res1_conv3 = self.conv3_res1_conv3(conv3_res1_conv2_relu)
        conv3_res1 = conv3_res1_proj + conv3_res1_conv3
        # Stage conv3, units 2-4: identity-shortcut bottlenecks at 512 ch.
        conv3_res2_pre_bn = self.conv3_res2_pre_bn(conv3_res1)
        conv3_res2_pre_relu = F.relu(conv3_res2_pre_bn)
        conv3_res2_conv1 = self.conv3_res2_conv1(conv3_res2_pre_relu)
        conv3_res2_conv1_bn = self.conv3_res2_conv1_bn(conv3_res2_conv1)
        conv3_res2_conv1_relu = F.relu(conv3_res2_conv1_bn)
        conv3_res2_conv2_pad = F.pad(conv3_res2_conv1_relu, (1, 1, 1, 1))
        conv3_res2_conv2 = self.conv3_res2_conv2(conv3_res2_conv2_pad)
        conv3_res2_conv2_bn = self.conv3_res2_conv2_bn(conv3_res2_conv2)
        conv3_res2_conv2_relu = F.relu(conv3_res2_conv2_bn)
        conv3_res2_conv3 = self.conv3_res2_conv3(conv3_res2_conv2_relu)
        conv3_res2 = conv3_res1 + conv3_res2_conv3
        conv3_res3_pre_bn = self.conv3_res3_pre_bn(conv3_res2)
        conv3_res3_pre_relu = F.relu(conv3_res3_pre_bn)
        conv3_res3_conv1 = self.conv3_res3_conv1(conv3_res3_pre_relu)
        conv3_res3_conv1_bn = self.conv3_res3_conv1_bn(conv3_res3_conv1)
        conv3_res3_conv1_relu = F.relu(conv3_res3_conv1_bn)
        conv3_res3_conv2_pad = F.pad(conv3_res3_conv1_relu, (1, 1, 1, 1))
        conv3_res3_conv2 = self.conv3_res3_conv2(conv3_res3_conv2_pad)
        conv3_res3_conv2_bn = self.conv3_res3_conv2_bn(conv3_res3_conv2)
        conv3_res3_conv2_relu = F.relu(conv3_res3_conv2_bn)
        conv3_res3_conv3 = self.conv3_res3_conv3(conv3_res3_conv2_relu)
        conv3_res3 = conv3_res2 + conv3_res3_conv3
        conv3_res4_pre_bn = self.conv3_res4_pre_bn(conv3_res3)
        conv3_res4_pre_relu = F.relu(conv3_res4_pre_bn)
        conv3_res4_conv1 = self.conv3_res4_conv1(conv3_res4_pre_relu)
        conv3_res4_conv1_bn = self.conv3_res4_conv1_bn(conv3_res4_conv1)
        conv3_res4_conv1_relu = F.relu(conv3_res4_conv1_bn)
        conv3_res4_conv2_pad = F.pad(conv3_res4_conv1_relu, (1, 1, 1, 1))
        conv3_res4_conv2 = self.conv3_res4_conv2(conv3_res4_conv2_pad)
        conv3_res4_conv2_bn = self.conv3_res4_conv2_bn(conv3_res4_conv2)
        conv3_res4_conv2_relu = F.relu(conv3_res4_conv2_bn)
        conv3_res4_conv3 = self.conv3_res4_conv3(conv3_res4_conv2_relu)
        conv3_res4 = conv3_res3 + conv3_res4_conv3
        # Stage conv4, unit 1: two-conv unit (3x3 then 1x1) with projection.
        conv4_res1_pre_bn = self.conv4_res1_pre_bn(conv3_res4)
        conv4_res1_pre_relu = F.relu(conv4_res1_pre_bn)
        conv4_res1_proj = self.conv4_res1_proj(conv4_res1_pre_relu)
        conv4_res1_conv1_pad = F.pad(conv4_res1_pre_relu, (1, 1, 1, 1))
        conv4_res1_conv1 = self.conv4_res1_conv1(conv4_res1_conv1_pad)
        conv4_res1_conv1_bn = self.conv4_res1_conv1_bn(conv4_res1_conv1)
        conv4_res1_conv1_relu = F.relu(conv4_res1_conv1_bn)
        conv4_res1_conv2 = self.conv4_res1_conv2(conv4_res1_conv1_relu)
        conv4_res1 = conv4_res1_proj + conv4_res1_conv2
        # Stage conv4, unit 2: widens 512 -> 1024 via its own 1x1 projection.
        conv4_res2_pre_bn = self.conv4_res2_pre_bn(conv4_res1)
        conv4_res2_pre_relu = F.relu(conv4_res2_pre_bn)
        conv4_res2_conv1_proj = self.conv4_res2_conv1_proj(conv4_res2_pre_relu)
        conv4_res2_conv1 = self.conv4_res2_conv1(conv4_res2_pre_relu)
        conv4_res2_conv1_bn = self.conv4_res2_conv1_bn(conv4_res2_conv1)
        conv4_res2_conv1_relu = F.relu(conv4_res2_conv1_bn)
        conv4_res2_conv2_pad = F.pad(conv4_res2_conv1_relu, (1, 1, 1, 1))
        conv4_res2_conv2 = self.conv4_res2_conv2(conv4_res2_conv2_pad)
        conv4_res2_conv2_bn = self.conv4_res2_conv2_bn(conv4_res2_conv2)
        conv4_res2_conv2_relu = F.relu(conv4_res2_conv2_bn)
        conv4_res2_conv3 = self.conv4_res2_conv3(conv4_res2_conv2_relu)
        conv4_res2 = conv4_res2_conv1_proj + conv4_res2_conv3
        # Stage conv4, unit 3: identity-shortcut bottleneck at 1024 ch.
        conv4_res3_pre_bn = self.conv4_res3_pre_bn(conv4_res2)
        conv4_res3_pre_relu = F.relu(conv4_res3_pre_bn)
        conv4_res3_conv1 = self.conv4_res3_conv1(conv4_res3_pre_relu)
        conv4_res3_conv1_bn = self.conv4_res3_conv1_bn(conv4_res3_conv1)
        conv4_res3_conv1_relu = F.relu(conv4_res3_conv1_bn)
        conv4_res3_conv2_pad = F.pad(conv4_res3_conv1_relu, (1, 1, 1, 1))
        conv4_res3_conv2 = self.conv4_res3_conv2(conv4_res3_conv2_pad)
        conv4_res3_conv2_bn = self.conv4_res3_conv2_bn(conv4_res3_conv2)
        conv4_res3_conv2_relu = F.relu(conv4_res3_conv2_bn)
        conv4_res3_conv3 = self.conv4_res3_conv3(conv4_res3_conv2_relu)
        conv4_res3 = conv4_res2 + conv4_res3_conv3
        # Head: BN + ReLU, 4x4 average pool, flatten, 512-d linear + 1-d BN.
        conv5_bn = self.conv5_bn(conv4_res3)
        conv5_relu = F.relu(conv5_bn)
        pool5 = F.avg_pool2d(conv5_relu, kernel_size=(4, 4), stride=(1, 1), padding=(0,), ceil_mode=False, count_include_pad=False)
        fc1_0 = pool5.view(pool5.size(0), -1)
        fc1_1 = self.fc1_1(fc1_0)
        bn_fc1 = self.bn_fc1(fc1_1)
        #return bn_fc1  # (earlier variant returned the full 512-d embedding)
        # Split the 512-d embedding into two 256-d halves and take the
        # elementwise max (maxout-style reduction) as the final feature.
        bn_fc1 = bn_fc1.reshape(bn_fc1.size()[0], bn_fc1.size()[1])
        slice_fc1, slice_fc2 = bn_fc1[:, :256], bn_fc1[:, 256:]
        eltwise_fc1 = torch.max(slice_fc1, slice_fc2)
        return eltwise_fc1

    @staticmethod
    def __conv(dim, name, **kwargs):
        """Create an nn.ConvNd layer and copy its weights from _weights_dict.

        `dim` selects Conv1d/2d/3d; `name` is the key into _weights_dict.
        NOTE(review): copying through state_dict() mutates the parameter
        tensors in place, outside autograd — fine for inference loading.
        """
        if dim == 1: layer = nn.Conv1d(**kwargs)
        elif dim == 2: layer = nn.Conv2d(**kwargs)
        elif dim == 3: layer = nn.Conv3d(**kwargs)
        else: raise NotImplementedError()
        layer.state_dict()['weight'].copy_(torch.from_numpy(_weights_dict[name]['weights']))
        if 'bias' in _weights_dict[name]:
            layer.state_dict()['bias'].copy_(torch.from_numpy(_weights_dict[name]['bias']))
        return layer

    @staticmethod
    def __batch_normalization(dim, name, **kwargs):
        """Create an nn.BatchNormNd layer and load scale/bias/mean/var.

        Missing 'scale' defaults the affine weight to 1, missing 'bias' to 0;
        running mean/var are always expected in _weights_dict[name].
        """
        if dim == 0 or dim == 1: layer = nn.BatchNorm1d(**kwargs)
        elif dim == 2: layer = nn.BatchNorm2d(**kwargs)
        elif dim == 3: layer = nn.BatchNorm3d(**kwargs)
        else: raise NotImplementedError()
        if 'scale' in _weights_dict[name]:
            layer.state_dict()['weight'].copy_(torch.from_numpy(_weights_dict[name]['scale']))
        else:
            layer.weight.data.fill_(1)
        if 'bias' in _weights_dict[name]:
            layer.state_dict()['bias'].copy_(torch.from_numpy(_weights_dict[name]['bias']))
        else:
            layer.bias.data.fill_(0)
        layer.state_dict()['running_mean'].copy_(torch.from_numpy(_weights_dict[name]['mean']))
        layer.state_dict()['running_var'].copy_(torch.from_numpy(_weights_dict[name]['var']))
        return layer

    @staticmethod
    def __dense(name, **kwargs):
        """Create an nn.Linear layer and copy 'weights' (and optional 'bias')
        from _weights_dict[name]."""
        layer = nn.Linear(**kwargs)
        layer.state_dict()['weight'].copy_(torch.from_numpy(_weights_dict[name]['weights']))
        if 'bias' in _weights_dict[name]:
            layer.state_dict()['bias'].copy_(torch.from_numpy(_weights_dict[name]['bias']))
        return layer