File size: 496 Bytes
9b2bdf6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
from torch import nn
class LatentCodesDiscriminator(nn.Module):
    """MLP that maps a ``style_dim``-feature vector to a single scalar score.

    The network has ``n_mlp - 1`` hidden blocks, each being
    ``Linear(style_dim, style_dim)`` followed by ``LeakyReLU(0.2)``. A final
    ``Linear(style_dim, 1)`` head produces the score. No activation is
    applied to the output.

    NOTE(review): the class name suggests the input is a latent/style code
    and the output is a real/fake logit. The code itself only fixes the
    feature dimension, so confirm the intended semantics with callers.
    """

    def __init__(self, style_dim: int, n_mlp: int):
        """Build the MLP.

        Args:
            style_dim: Number of input features, which is also the width of
                every hidden layer.
            n_mlp: Total number of ``Linear`` layers, counting the output
                head. Values <= 1 produce only the output head.
        """
        super().__init__()
        self.style_dim = style_dim

        layers = []
        for _ in range(n_mlp - 1):
            layers.append(nn.Linear(style_dim, style_dim))
            layers.append(nn.LeakyReLU(0.2))
        # The head must consume ``style_dim`` features. The previous
        # hard-coded 512 only worked when style_dim happened to be 512.
        layers.append(nn.Linear(style_dim, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, w):
        """Score ``w``.

        ``w`` must have ``style_dim`` features in its last dimension, as
        required by ``nn.Linear``. The result has the same leading
        dimensions and a trailing dimension of size 1.
        """
        return self.mlp(w)
|