"""StyleGAN2 discriminator implemented in JAX/Flax."""
import numpy as np | |
import jax | |
from jax import random | |
import jax.numpy as jnp | |
import flax.linen as nn | |
from typing import Any, Tuple, List, Callable | |
import h5py | |
from . import ops | |
from stylegan2 import utils | |
# Per-dataset configuration of the pretrained StyleGAN2 discriminators.
# Each entry maps a dataset name to:
#   (download URL, input resolution, label dimensionality,
#    architecture, minibatch-stddev group size).
_PRETRAINED = {
    'afhqcat': ('https://www.dropbox.com/s/qygbjkefyqyu9k9/stylegan2_discriminator_afhqcat.h5?dl=1', 512, 0, 'resnet', None),
    'afhqdog': ('https://www.dropbox.com/s/kmoxbp33qswz64p/stylegan2_discriminator_afhqdog.h5?dl=1', 512, 0, 'resnet', None),
    'afhqwild': ('https://www.dropbox.com/s/jz1hpsyt3isj6e7/stylegan2_discriminator_afhqwild.h5?dl=1', 512, 0, 'resnet', None),
    'brecahad': ('https://www.dropbox.com/s/h0cb89hruo6pmyj/stylegan2_discriminator_brecahad.h5?dl=1', 512, 0, 'resnet', None),
    'car': ('https://www.dropbox.com/s/2ghjrmxih7cic76/stylegan2_discriminator_car.h5?dl=1', 512, 0, 'resnet', None),
    'cat': ('https://www.dropbox.com/s/zfhjsvlsny5qixd/stylegan2_discriminator_cat.h5?dl=1', 256, 0, 'resnet', None),
    'church': ('https://www.dropbox.com/s/jlno7zeivkjtk8g/stylegan2_discriminator_church.h5?dl=1', 256, 0, 'resnet', None),
    'cifar10': ('https://www.dropbox.com/s/eldpubfkl4c6rur/stylegan2_discriminator_cifar10.h5?dl=1', 32, 10, 'orig', 32),
    'ffhq': ('https://www.dropbox.com/s/m42qy9951b7lq1s/stylegan2_discriminator_ffhq.h5?dl=1', 1024, 0, 'resnet', None),
    'horse': ('https://www.dropbox.com/s/19f5pxrcdh2g8cw/stylegan2_discriminator_horse.h5?dl=1', 256, 0, 'resnet', None),
    'metfaces': ('https://www.dropbox.com/s/xnokaunql12glkd/stylegan2_discriminator_metfaces.h5?dl=1', 1024, 0, 'resnet', None),
}

# Public lookup tables, one per configuration column (kept for backward compatibility).
URLS = {name: cfg[0] for name, cfg in _PRETRAINED.items()}
RESOLUTION = {name: cfg[1] for name, cfg in _PRETRAINED.items()}
C_DIM = {name: cfg[2] for name, cfg in _PRETRAINED.items()}
ARCHITECTURE = {name: cfg[3] for name, cfg in _PRETRAINED.items()}
MBSTD_GROUP_SIZE = {name: cfg[4] for name, cfg in _PRETRAINED.items()}
class FromRGBLayer(nn.Module):
    """
    From RGB Layer.

    Projects an RGB image into a feature map and, when given, adds the result
    to an incoming feature tensor of the same channel width.

    Attributes:
        fmaps (int): Number of output channels of the convolution.
        kernel (int): Kernel size of the convolution.
        lr_multiplier (float): Learning rate multiplier.
        activation (str): Activation function: 'relu', 'lrelu', etc.
        param_dict (h5py.Group): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
        clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
        dtype (str): Data dtype.
        rng (jax.random.PRNGKey): PRNG for initialization.
    """
    fmaps: int
    kernel: int=1
    lr_multiplier: float=1
    activation: str='leaky_relu'
    param_dict: h5py.Group=None
    clip_conv: float=None
    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    def __call__(self, x, y):
        """
        Run From RGB Layer.

        Args:
            x (tensor): Input image of shape [N, H, W, num_channels].
            y (tensor): Input tensor of shape [N, H, W, out_channels].

        Returns:
            (tensor): Output tensor of shape [N, H, W, out_channels].
        """
        # Initialize (or load pretrained) convolution weight and bias.
        weight_shape = [self.kernel, self.kernel, x.shape[3], self.fmaps]
        init_w, init_b = ops.get_weight(weight_shape, self.lr_multiplier, True, self.param_dict, 'fromrgb', self.rng)
        weight = self.param(name='weight', init_fn=lambda *_: init_w)
        bias = self.param(name='bias', init_fn=lambda *_: init_b)
        # Apply runtime equalized-learning-rate scaling to both parameters.
        weight = ops.equalize_lr_weight(weight, self.lr_multiplier)
        bias = ops.equalize_lr_bias(bias, self.lr_multiplier)

        out = x.astype(self.dtype)
        out = ops.conv2d(out, weight.astype(out.dtype)) + bias.astype(out.dtype)
        out = ops.apply_activation(out, activation=self.activation)
        if self.clip_conv is not None:
            out = jnp.clip(out, -self.clip_conv, self.clip_conv)
        # Merge with the incoming feature branch, if any.
        return out if y is None else out + y
class DiscriminatorLayer(nn.Module):
    """
    Discriminator Layer.

    A single (optionally downsampling) convolution with equalized learning
    rate, optional bias, an activation, and optional output clipping.

    Attributes:
        fmaps (int): Number of output channels of the convolution.
        kernel (int): Kernel size of the convolution.
        use_bias (bool): If True, use bias.
        down (bool): If True, downsample the spatial resolution.
        resample_kernel (Tuple): Kernel that is used for FIR filter.
        activation (str): Activation function: 'relu', 'lrelu', etc.
        layer_name (str): Layer name.
        param_dict (h5py.Group): Parameter dict with pretrained parameters.
        lr_multiplier (float): Learning rate multiplier.
        clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
        dtype (str): Data dtype.
        rng (jax.random.PRNGKey): PRNG for initialization.
    """
    fmaps: int
    kernel: int=3
    use_bias: bool=True
    down: bool=False
    resample_kernel: Tuple=None
    activation: str='leaky_relu'
    layer_name: str=None
    param_dict: h5py.Group=None
    lr_multiplier: float=1
    clip_conv: float=None
    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    def __call__(self, x):
        """
        Run Discriminator Layer.

        Args:
            x (tensor): Input tensor of shape [N, H, W, C].

        Returns:
            (tensor): Output tensor of shape [N, H, W, fmaps].
        """
        # Initialize (or load pretrained) weight, and bias if requested.
        w_shape = [self.kernel, self.kernel, x.shape[3], self.fmaps]
        if self.use_bias:
            w, b = ops.get_weight(w_shape, self.lr_multiplier, self.use_bias, self.param_dict, self.layer_name, self.rng)
        else:
            w = ops.get_weight(w_shape, self.lr_multiplier, self.use_bias, self.param_dict, self.layer_name, self.rng)

        w = self.param(name='weight', init_fn=lambda *_ : w)
        w = ops.equalize_lr_weight(w, self.lr_multiplier)
        if self.use_bias:
            b = self.param(name='bias', init_fn=lambda *_ : b)
            b = ops.equalize_lr_bias(b, self.lr_multiplier)

        x = x.astype(self.dtype)
        # Cast the weight to the activation dtype (consistent with FromRGBLayer)
        # so reduced-precision layers do not silently promote the convolution.
        x = ops.conv2d(x, w.astype(x.dtype), down=self.down, resample_kernel=self.resample_kernel)
        if self.use_bias:
            x += b.astype(x.dtype)
        x = ops.apply_activation(x, activation=self.activation)
        if self.clip_conv is not None:
            x = jnp.clip(x, -self.clip_conv, self.clip_conv)
        return x
class DiscriminatorBlock(nn.Module):
    """
    Discriminator Block.

    Two convolutional layers (the second one downsampling by 2), optionally
    merged with a downsampled 1x1 skip branch when architecture == 'resnet'.

    Attributes:
        res (int): Log2 of the block's input resolution (callers name blocks '2**res x 2**res').
        kernel (int): Kernel size of the convolutions.
        resample_kernel (Tuple): Kernel that is used for FIR filter.
        activation (str): Activation function: 'relu', 'lrelu', etc.
        param_dict (h5py.Group): Parameter dict with pretrained parameters.
        lr_multiplier (float): Learning rate multiplier.
        architecture (str): Architecture: 'orig', 'resnet'.
        nf (Callable): Callable that returns the number of feature maps for a given layer.
        clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
        dtype (str): Data dtype.
        rng (jax.random.PRNGKey): Random seed for initialization.
    """
    res: int
    kernel: int=3
    resample_kernel: Tuple=(1, 3, 3, 1)
    activation: str='leaky_relu'
    param_dict: Any=None
    lr_multiplier: float=1
    architecture: str='resnet'
    nf: Callable=None
    clip_conv: float=None
    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    def __call__(self, x):
        """
        Run Discriminator Block.

        Args:
            x (tensor): Input tensor of shape [N, H, W, C].

        Returns:
            (tensor): Output tensor of shape [N, H, W, fmaps].
        """
        key = self.rng
        x = x.astype(self.dtype)
        shortcut = x

        # conv0 keeps the resolution; conv1 downsamples by 2.
        for layer_idx in (0, 1):
            key, subkey = random.split(key)
            downsampling = layer_idx == 1
            x = DiscriminatorLayer(fmaps=self.nf(self.res - (layer_idx + 1)),
                                   kernel=self.kernel,
                                   down=downsampling,
                                   resample_kernel=self.resample_kernel if downsampling else None,
                                   activation=self.activation,
                                   layer_name=f'conv{layer_idx}',
                                   param_dict=self.param_dict,
                                   lr_multiplier=self.lr_multiplier,
                                   clip_conv=self.clip_conv,
                                   dtype=self.dtype,
                                   rng=subkey)(x)

        if self.architecture == 'resnet':
            # Downsampled linear 1x1 projection of the block input, merged
            # with the main path and rescaled to preserve variance.
            key, subkey = random.split(key)
            shortcut = DiscriminatorLayer(fmaps=self.nf(self.res - 2),
                                          kernel=1,
                                          use_bias=False,
                                          down=True,
                                          resample_kernel=self.resample_kernel,
                                          activation='linear',
                                          layer_name='skip',
                                          param_dict=self.param_dict,
                                          lr_multiplier=self.lr_multiplier,
                                          dtype=self.dtype,
                                          rng=subkey)(shortcut)
            x = (x + shortcut) * np.sqrt(0.5, dtype=x.dtype)
        return x
class Discriminator(nn.Module):
    """
    Discriminator.

    Attributes:
        resolution (int): Input resolution. Overridden based on dataset.
        num_channels (int): Number of input color channels. Overridden based on dataset.
        c_dim (int): Dimensionality of the labels (c), 0 if no labels. Overridden based on dataset.
        fmap_base (int): Overall multiplier for the number of feature maps.
        fmap_decay (int): Log2 feature map reduction when doubling the resolution.
        fmap_min (int): Minimum number of feature maps in any layer.
        fmap_max (int): Maximum number of feature maps in any layer.
        mapping_layers (int): Number of additional mapping layers for the conditioning labels.
        mapping_fmaps (int): Number of activations in the mapping layers, None = default.
        mapping_lr_multiplier (float): Learning rate multiplier for the mapping layers.
        architecture (str): Architecture: 'orig', 'resnet'.
        activation (str): Activation function: 'relu', 'leaky_relu', etc.
        mbstd_group_size (int): Group size for the minibatch standard deviation layer, None = entire minibatch.
        mbstd_num_features (int): Number of features for the minibatch standard deviation layer, 0 = disable.
        resample_kernel (Tuple): Low-pass filter to apply when resampling activations, None = box filter.
        num_fp16_res (int): Use float16 for the 'num_fp16_res' highest resolutions.
        clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
        pretrained (str): Use pretrained model, None for random initialization.
        ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used.
        dtype (str): Data type.
        rng (jax.random.PRNGKey): PRNG for initialization.
    """
    # Input dimensions.
    resolution: int=1024
    num_channels: int=3
    c_dim: int=0
    # Capacity.
    fmap_base: int=16384
    fmap_decay: int=1
    fmap_min: int=1
    fmap_max: int=512
    # Internal details.
    mapping_layers: int=0
    mapping_fmaps: int=None
    mapping_lr_multiplier: float=0.1
    architecture: str='resnet'
    activation: str='leaky_relu'
    mbstd_group_size: int=None
    mbstd_num_features: int=1
    resample_kernel: Tuple=(1, 3, 3, 1)
    num_fp16_res: int=0
    clip_conv: float=None
    # Pretraining
    pretrained: str=None
    ckpt_dir: str=None
    dtype: str='float32'
    rng: Any=random.PRNGKey(0)

    def setup(self):
        """Resolve the effective configuration, optionally loading a pretrained checkpoint."""
        self.resolution_ = self.resolution
        self.c_dim_ = self.c_dim
        self.architecture_ = self.architecture
        self.mbstd_group_size_ = self.mbstd_group_size
        self.param_dict = None
        if self.pretrained is not None:
            assert self.pretrained in URLS.keys(), f'Pretrained model not available: {self.pretrained}'
            ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])
            self.param_dict = h5py.File(ckpt_file, 'r')['discriminator']
            # A pretrained checkpoint fixes the architecture hyperparameters.
            self.resolution_ = RESOLUTION[self.pretrained]
            self.architecture_ = ARCHITECTURE[self.pretrained]
            self.mbstd_group_size_ = MBSTD_GROUP_SIZE[self.pretrained]
            self.c_dim_ = C_DIM[self.pretrained]
        # Validate the *effective* architecture (it may have been overridden above).
        assert self.architecture_ in ['orig', 'resnet']

    def __call__(self, x, c=None):
        """
        Run Discriminator.

        Args:
            x (tensor): Input image of shape [N, H, W, num_channels].
            c (tensor): Input labels, shape [N, c_dim].

        Returns:
            (tensor): Output tensor of shape [N, 1].
        """
        resolution_log2 = int(np.log2(self.resolution_))
        assert self.resolution_ == 2**resolution_log2 and self.resolution_ >= 4

        # Number of feature maps for a given layer stage, clamped to [fmap_min, fmap_max].
        def nf(stage): return np.clip(int(self.fmap_base / (2.0 ** (stage * self.fmap_decay))), self.fmap_min, self.fmap_max)
        if self.mapping_fmaps is None:
            mapping_fmaps = nf(0)
        else:
            mapping_fmaps = self.mapping_fmaps

        init_rng = self.rng

        # Label embedding and mapping.
        if self.c_dim_ > 0:
            init_rng, init_key = random.split(init_rng)  # fresh subkey; init_rng is split again below
            c = ops.LinearLayer(in_features=self.c_dim_,
                                out_features=mapping_fmaps,
                                lr_multiplier=self.mapping_lr_multiplier,
                                param_dict=self.param_dict,
                                layer_name='label_embedding',
                                dtype=self.dtype,
                                rng=init_key)(c)
            c = ops.normalize_2nd_moment(c)
            for i in range(self.mapping_layers):
                init_rng, init_key = random.split(init_rng)
                # After the embedding, c has mapping_fmaps features (not c_dim).
                c = ops.LinearLayer(in_features=mapping_fmaps,
                                    out_features=mapping_fmaps,
                                    lr_multiplier=self.mapping_lr_multiplier,
                                    param_dict=self.param_dict,
                                    layer_name=f'fc{i}',
                                    dtype=self.dtype,
                                    rng=init_key)(c)

        # Layers for >=8x8 resolutions.
        y = None
        for res in range(resolution_log2, 2, -1):
            res_str = f'block_{2**res}x{2**res}'
            # The 'num_fp16_res' highest resolutions may run in reduced precision.
            block_dtype = self.dtype if res >= resolution_log2 + 1 - self.num_fp16_res else 'float32'
            if res == resolution_log2:
                init_rng, init_key = random.split(init_rng)
                x = FromRGBLayer(fmaps=nf(res - 1),
                                 kernel=1,
                                 activation=self.activation,
                                 param_dict=self.param_dict[res_str] if self.param_dict is not None else None,
                                 clip_conv=self.clip_conv,
                                 dtype=block_dtype,
                                 rng=init_key)(x, y)

            init_rng, init_key = random.split(init_rng)
            x = DiscriminatorBlock(res=res,
                                   kernel=3,
                                   resample_kernel=self.resample_kernel,
                                   activation=self.activation,
                                   param_dict=self.param_dict[res_str] if self.param_dict is not None else None,
                                   architecture=self.architecture_,
                                   nf=nf,
                                   clip_conv=self.clip_conv,
                                   dtype=block_dtype,
                                   rng=init_key)(x)

        # Layers for 4x4 resolution; the head always runs in float32.
        dtype = jnp.float32
        x = x.astype(dtype)
        if self.mbstd_num_features > 0:
            x = ops.minibatch_stddev_layer(x, self.mbstd_group_size_, self.mbstd_num_features)
        init_rng, init_key = random.split(init_rng)
        # Use the fresh subkey (init_rng itself is split again below; reusing it
        # here would correlate conv0's initialization with the later layers).
        x = DiscriminatorLayer(fmaps=nf(1),
                               kernel=3,
                               use_bias=True,
                               activation=self.activation,
                               layer_name='conv0',
                               param_dict=self.param_dict['block_4x4'] if self.param_dict is not None else None,
                               clip_conv=self.clip_conv,
                               dtype=dtype,
                               rng=init_key)(x)

        # Switch to NCHW so that the pretrained weights still work after reshaping
        x = jnp.transpose(x, axes=(0, 3, 1, 2))
        x = jnp.reshape(x, newshape=(-1, x.shape[1] * x.shape[2] * x.shape[3]))
        init_rng, init_key = random.split(init_rng)
        x = ops.LinearLayer(in_features=x.shape[1],
                            out_features=nf(0),
                            activation=self.activation,
                            param_dict=self.param_dict['block_4x4'] if self.param_dict is not None else None,
                            layer_name='fc0',
                            dtype=dtype,
                            rng=init_key)(x)

        # Output layer.
        init_rng, init_key = random.split(init_rng)
        x = ops.LinearLayer(in_features=x.shape[1],
                            out_features=1 if self.c_dim_ == 0 else mapping_fmaps,
                            param_dict=self.param_dict,
                            layer_name='output',
                            dtype=dtype,
                            rng=init_key)(x)
        if self.c_dim_ > 0:
            # Conditional case: project the output onto the mapped label embedding.
            x = jnp.sum(x * c, axis=1, keepdims=True) / jnp.sqrt(mapping_fmaps)
        return x