|
import functools
|
|
|
|
import torch.nn as nn
|
|
|
|
from ..util import ActNorm
|
|
|
|
|
|
def weights_init(m):
    """Initialize a single module's parameters in place, keyed on its class name.

    Takes one module as its only argument, so it can be handed to
    ``nn.Module.apply`` to visit every submodule of a network.

    * class name contains ``"Conv"``      -> weight ~ N(0.0, 0.02)
    * class name contains ``"BatchNorm"`` -> weight ~ N(1.0, 0.02), bias = 0

    Modules that match by name but do not carry the corresponding parameter
    (e.g. ``BatchNorm2d(affine=False)``, or a container class whose name
    merely contains "Conv") are left untouched instead of raising.

    Parameters:
        m (nn.Module) -- the module to initialize
    """
    classname = m.__class__.__name__
    # Parameters may be missing entirely or registered as None; look them up
    # defensively so name-matched modules without them are skipped.
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)
|
|
|
|
|
|
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
        --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py

    The network is a plain ``nn.Sequential`` stored in ``self.main``:
    a stride-2 conv + LeakyReLU stem, ``n_layers - 1`` stride-2
    conv/norm/LeakyReLU stages, one stride-1 conv/norm/LeakyReLU stage, and a
    final stride-1 conv producing a single output channel.
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)     -- the number of channels in input images
            ndf (int)          -- the number of filters produced by the first conv
                                  layer; later stages use ndf * min(2**k, 8)
            n_layers (int)     -- the number of stride-2 conv layers (for
                                  n_layers >= 1); two stride-1 convs follow them
            use_actnorm (bool) -- if True, normalize with ActNorm (imported from
                                  ..util) instead of nn.BatchNorm2d
        """
        super().__init__()
        norm_layer = ActNorm if use_actnorm else nn.BatchNorm2d
        # The conv bias is dropped only when the conv is followed by
        # BatchNorm2d; with any other norm layer the bias is kept.
        use_bias = norm_layer is not nn.BatchNorm2d

        kw = 4
        padw = 1
        # Stem: no normalization, bias left at the Conv2d default.
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]

        # Downsampling stages: channel multiplier doubles each stage, capped at 8.
        nf_mult = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(
                    ndf * nf_mult_prev,
                    ndf * nf_mult,
                    kernel_size=kw,
                    stride=2,
                    padding=padw,
                    bias=use_bias,
                ),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]

        # One more conv/norm/activation stage, this time without downsampling.
        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(
                ndf * nf_mult_prev,
                ndf * nf_mult,
                kernel_size=kw,
                stride=1,
                padding=padw,
                bias=use_bias,
            ),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]

        # Final projection to a 1-channel prediction map.
        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
        ]
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward: run ``input`` through ``self.main`` and return the result."""
        # The parameter name shadows the builtin but is kept so existing
        # keyword callers (``forward(input=...)``) continue to work.
        return self.main(input)
|
|
|