Arnaudding001 commited on
Commit
d0ad51e
1 Parent(s): 8ae9e74

Create encoder_encorders_psp_encoders.py

Browse files
Files changed (1) hide show
  1. encoder_encorders_psp_encoders.py +186 -0
encoder_encorders_psp_encoders.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ from torch.nn import Linear, Conv2d, BatchNorm2d, PReLU, Sequential, Module
6
+
7
+ from model.encoder.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE
8
+ from model.stylegan.model import EqualLinear
9
+
10
+
11
+ class GradualStyleBlock(Module):
12
+ def __init__(self, in_c, out_c, spatial):
13
+ super(GradualStyleBlock, self).__init__()
14
+ self.out_c = out_c
15
+ self.spatial = spatial
16
+ num_pools = int(np.log2(spatial))
17
+ modules = []
18
+ modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
19
+ nn.LeakyReLU()]
20
+ for i in range(num_pools - 1):
21
+ modules += [
22
+ Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
23
+ nn.LeakyReLU()
24
+ ]
25
+ self.convs = nn.Sequential(*modules)
26
+ self.linear = EqualLinear(out_c, out_c, lr_mul=1)
27
+
28
+ def forward(self, x):
29
+ x = self.convs(x)
30
+ x = x.view(-1, self.out_c)
31
+ x = self.linear(x)
32
+ return x
33
+
34
+
35
+ class GradualStyleEncoder(Module):
36
+ def __init__(self, num_layers, mode='ir', opts=None):
37
+ super(GradualStyleEncoder, self).__init__()
38
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
39
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
40
+ blocks = get_blocks(num_layers)
41
+ if mode == 'ir':
42
+ unit_module = bottleneck_IR
43
+ elif mode == 'ir_se':
44
+ unit_module = bottleneck_IR_SE
45
+ self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
46
+ BatchNorm2d(64),
47
+ PReLU(64))
48
+ modules = []
49
+ for block in blocks:
50
+ for bottleneck in block:
51
+ modules.append(unit_module(bottleneck.in_channel,
52
+ bottleneck.depth,
53
+ bottleneck.stride))
54
+ self.body = Sequential(*modules)
55
+
56
+ self.styles = nn.ModuleList()
57
+ self.style_count = opts.n_styles
58
+ self.coarse_ind = 3
59
+ self.middle_ind = 7
60
+ for i in range(self.style_count):
61
+ if i < self.coarse_ind:
62
+ style = GradualStyleBlock(512, 512, 16)
63
+ elif i < self.middle_ind:
64
+ style = GradualStyleBlock(512, 512, 32)
65
+ else:
66
+ style = GradualStyleBlock(512, 512, 64)
67
+ self.styles.append(style)
68
+ self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
69
+ self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
70
+
71
+ def _upsample_add(self, x, y):
72
+ '''Upsample and add two feature maps.
73
+ Args:
74
+ x: (Variable) top feature map to be upsampled.
75
+ y: (Variable) lateral feature map.
76
+ Returns:
77
+ (Variable) added feature map.
78
+ Note in PyTorch, when input size is odd, the upsampled feature map
79
+ with `F.upsample(..., scale_factor=2, mode='nearest')`
80
+ maybe not equal to the lateral feature map size.
81
+ e.g.
82
+ original input size: [N,_,15,15] ->
83
+ conv2d feature map size: [N,_,8,8] ->
84
+ upsampled feature map size: [N,_,16,16]
85
+ So we choose bilinear upsample which supports arbitrary output sizes.
86
+ '''
87
+ _, _, H, W = y.size()
88
+ return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
89
+
90
+ def forward(self, x):
91
+ x = self.input_layer(x)
92
+
93
+ latents = []
94
+ modulelist = list(self.body._modules.values())
95
+ for i, l in enumerate(modulelist):
96
+ x = l(x)
97
+ if i == 6:
98
+ c1 = x
99
+ elif i == 20:
100
+ c2 = x
101
+ elif i == 23:
102
+ c3 = x
103
+
104
+ for j in range(self.coarse_ind):
105
+ latents.append(self.styles[j](c3))
106
+
107
+ p2 = self._upsample_add(c3, self.latlayer1(c2))
108
+ for j in range(self.coarse_ind, self.middle_ind):
109
+ latents.append(self.styles[j](p2))
110
+
111
+ p1 = self._upsample_add(p2, self.latlayer2(c1))
112
+ for j in range(self.middle_ind, self.style_count):
113
+ latents.append(self.styles[j](p1))
114
+
115
+ out = torch.stack(latents, dim=1)
116
+ return out
117
+
118
+
119
+ class BackboneEncoderUsingLastLayerIntoW(Module):
120
+ def __init__(self, num_layers, mode='ir', opts=None):
121
+ super(BackboneEncoderUsingLastLayerIntoW, self).__init__()
122
+ print('Using BackboneEncoderUsingLastLayerIntoW')
123
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
124
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
125
+ blocks = get_blocks(num_layers)
126
+ if mode == 'ir':
127
+ unit_module = bottleneck_IR
128
+ elif mode == 'ir_se':
129
+ unit_module = bottleneck_IR_SE
130
+ self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
131
+ BatchNorm2d(64),
132
+ PReLU(64))
133
+ self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
134
+ self.linear = EqualLinear(512, 512, lr_mul=1)
135
+ modules = []
136
+ for block in blocks:
137
+ for bottleneck in block:
138
+ modules.append(unit_module(bottleneck.in_channel,
139
+ bottleneck.depth,
140
+ bottleneck.stride))
141
+ self.body = Sequential(*modules)
142
+
143
+ def forward(self, x):
144
+ x = self.input_layer(x)
145
+ x = self.body(x)
146
+ x = self.output_pool(x)
147
+ x = x.view(-1, 512)
148
+ x = self.linear(x)
149
+ return x
150
+
151
+
152
+ class BackboneEncoderUsingLastLayerIntoWPlus(Module):
153
+ def __init__(self, num_layers, mode='ir', opts=None):
154
+ super(BackboneEncoderUsingLastLayerIntoWPlus, self).__init__()
155
+ print('Using BackboneEncoderUsingLastLayerIntoWPlus')
156
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
157
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
158
+ blocks = get_blocks(num_layers)
159
+ if mode == 'ir':
160
+ unit_module = bottleneck_IR
161
+ elif mode == 'ir_se':
162
+ unit_module = bottleneck_IR_SE
163
+ self.n_styles = opts.n_styles
164
+ self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
165
+ BatchNorm2d(64),
166
+ PReLU(64))
167
+ self.output_layer_2 = Sequential(BatchNorm2d(512),
168
+ torch.nn.AdaptiveAvgPool2d((7, 7)),
169
+ Flatten(),
170
+ Linear(512 * 7 * 7, 512))
171
+ self.linear = EqualLinear(512, 512 * self.n_styles, lr_mul=1)
172
+ modules = []
173
+ for block in blocks:
174
+ for bottleneck in block:
175
+ modules.append(unit_module(bottleneck.in_channel,
176
+ bottleneck.depth,
177
+ bottleneck.stride))
178
+ self.body = Sequential(*modules)
179
+
180
+ def forward(self, x):
181
+ x = self.input_layer(x)
182
+ x = self.body(x)
183
+ x = self.output_layer_2(x)
184
+ x = self.linear(x)
185
+ x = x.view(-1, self.n_styles, 512)
186
+ return x