Spaces:
Runtime error
Runtime error
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler | |
from src import (ContentEncoder, | |
StyleEncoder, | |
UNet) | |
def build_unet(args):
    """Construct the project's ``UNet`` from parsed arguments.

    The block layout, activation, normalisation settings and ``reduction``
    are fixed here; only the values listed below are read from ``args``.

    Args:
        args: Object exposing ``resolution``, ``unet_channels``,
            ``style_start_channel``, ``channel_attn``,
            ``content_encoder_downsample_size`` and
            ``content_start_channel``.

    Returns:
        The newly created ``UNet`` instance.
    """
    # Encoder path: plain blocks on the outside, MCA blocks in the middle.
    down_blocks = (
        'DownBlock2D',
        'MCADownBlock2D',
        'MCADownBlock2D',
        'DownBlock2D',
    )
    # Decoder path: plain blocks on the outside, StyleRSI blocks in the middle.
    up_blocks = (
        'UpBlock2D',
        'StyleRSIUpBlock2D',
        'StyleRSIUpBlock2D',
        'UpBlock2D',
    )

    # Entries are listed in the order they are passed to UNet, so the
    # attributes of ``args`` are read in that same order.
    unet_kwargs = {
        'sample_size': args.resolution,
        'in_channels': 3,
        'out_channels': 3,
        'flip_sin_to_cos': True,
        'freq_shift': 0,
        'down_block_types': down_blocks,
        'up_block_types': up_blocks,
        'block_out_channels': args.unet_channels,
        'layers_per_block': 2,
        'downsample_padding': 1,
        'mid_block_scale_factor': 1,
        'act_fn': 'silu',
        'norm_num_groups': 32,
        'norm_eps': 1e-05,
        # Cross-attention width is 16x the style encoder's start channel
        # (presumably the style feature width -- confirm in StyleEncoder).
        'cross_attention_dim': args.style_start_channel * 16,
        'attention_head_dim': 1,
        'channel_attn': args.channel_attn,
        'content_encoder_downsample_size': args.content_encoder_downsample_size,
        'content_start_channel': args.content_start_channel,
        'reduction': 32,
    }
    return UNet(**unet_kwargs)
def build_style_encoder(args):
    """Create the ``StyleEncoder`` and announce it on stdout.

    Args:
        args: Object exposing ``style_start_channel`` and an indexable
            ``style_image_size``; only element 0 of the latter is used.

    Returns:
        The newly created ``StyleEncoder`` instance.
    """
    start_channel = args.style_start_channel
    # Only the first entry of the size is forwarded as the resolution.
    resolution = args.style_image_size[0]
    encoder = StyleEncoder(G_ch=start_channel, resolution=resolution)
    print("Get CG-GAN Style Encoder!")
    return encoder
def build_content_encoder(args):
    """Create the ``ContentEncoder`` and announce it on stdout.

    Args:
        args: Object exposing ``content_start_channel`` and an indexable
            ``content_image_size``; only element 0 of the latter is used.

    Returns:
        The newly created ``ContentEncoder`` instance.
    """
    start_channel = args.content_start_channel
    # Only the first entry of the size is forwarded as the resolution.
    resolution = args.content_image_size[0]
    encoder = ContentEncoder(G_ch=start_channel, resolution=resolution)
    print("Get CG-GAN Content Encoder!")
    return encoder
def build_ddpm_scheduler(args):
    """Create the ``DDPMScheduler`` used by this project.

    Every setting is fixed here except the beta schedule name, which is
    read from ``args.beta_scheduler``.

    Args:
        args: Object exposing ``beta_scheduler``.

    Returns:
        The newly created ``DDPMScheduler`` instance.
    """
    scheduler_kwargs = {
        'num_train_timesteps': 1000,
        'beta_start': 0.0001,
        'beta_end': 0.02,
        'beta_schedule': args.beta_scheduler,
        'trained_betas': None,
        'variance_type': "fixed_small",
        'clip_sample': True,
    }
    return DDPMScheduler(**scheduler_kwargs)