from typing import Tuple, Union

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.core.frozen_dict import FrozenDict
from diffusers.configuration_utils import ConfigMixin, flax_register_to_config
from diffusers.models.modeling_flax_utils import FlaxModelMixin
from diffusers.models.embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from diffusers.utils import BaseOutput

from .flax_unet_pseudo3d_blocks import (
    CrossAttnDownBlockPseudo3D,
    CrossAttnUpBlockPseudo3D,
    DownBlockPseudo3D,
    UpBlockPseudo3D,
    UNetMidBlockPseudo3DCrossAttn
)
from .flax_resnet_pseudo3d import ConvPseudo3D


class UNetPseudo3DConditionOutput(BaseOutput):
    sample: jax.Array


@flax_register_to_config
class UNetPseudo3DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
    sample_size: Union[int, Tuple[int, int]] = (64, 64)
    in_channels: int = 4
    out_channels: int = 4
    down_block_types: Tuple[str, ...] = (
        "CrossAttnDownBlockPseudo3D",
        "CrossAttnDownBlockPseudo3D",
        "CrossAttnDownBlockPseudo3D",
        "DownBlockPseudo3D"
    )
    up_block_types: Tuple[str, ...] = (
        "UpBlockPseudo3D",
        "CrossAttnUpBlockPseudo3D",
        "CrossAttnUpBlockPseudo3D",
        "CrossAttnUpBlockPseudo3D"
    )
    block_out_channels: Tuple[int, ...] = (
        320,
        640,
        1280,
        1280
    )
    layers_per_block: int = 2
    attention_head_dim: Union[int, Tuple[int, ...]] = 8
    cross_attention_dim: int = 768
    flip_sin_to_cos: bool = True
    freq_shift: int = 0
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    param_dtype: str = 'float32'

    def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
        if self.param_dtype == 'bfloat16':
            param_dtype = jnp.bfloat16
        elif self.param_dtype == 'float16':
            param_dtype = jnp.float16
        elif self.param_dtype == 'float32':
            param_dtype = jnp.float32
        else:
            raise ValueError(f'unknown parameter type: {self.param_dtype}')
        sample_size = self.sample_size
        if isinstance(sample_size, int):
            sample_size = (sample_size, sample_size)
        # dummy inputs in b,c,f,h,w layout with a single frame
        sample_shape = (1, self.in_channels, 1, *sample_size)
        sample = jnp.zeros(sample_shape, dtype = param_dtype)
        timesteps = jnp.ones((1, ), dtype = jnp.int32)
        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype = param_dtype)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {
            "params": params_rng,
            "dropout": dropout_rng
        }
        return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]

    def setup(self) -> None:
        if isinstance(self.attention_head_dim, int):
            attention_head_dim = (self.attention_head_dim, ) * len(self.down_block_types)
        else:
            attention_head_dim = self.attention_head_dim
        time_embed_dim = self.block_out_channels[0] * 4
        self.conv_in = ConvPseudo3D(
            features = self.block_out_channels[0],
            kernel_size = (3, 3),
            strides = (1, 1),
            padding = ((1, 1), (1, 1)),
            dtype = self.dtype
        )
        self.time_proj = FlaxTimesteps(
            dim = self.block_out_channels[0],
            flip_sin_to_cos = self.flip_sin_to_cos,
            freq_shift = self.freq_shift
        )
        self.time_embedding = FlaxTimestepEmbedding(
            time_embed_dim = time_embed_dim,
            dtype = self.dtype
        )
        down_blocks = []
        output_channels = self.block_out_channels[0]
        for i, down_block_type in enumerate(self.down_block_types):
            input_channels = output_channels
            output_channels = self.block_out_channels[i]
            is_final_block = i == len(self.block_out_channels) - 1
            # allows loading 3d models with old layer type names in their configs,
            # e.g. 2D instead of Pseudo3D, like lxj's timelapse model
            if down_block_type in ['CrossAttnDownBlockPseudo3D', 'CrossAttnDownBlock2D']:
                down_block = CrossAttnDownBlockPseudo3D(
                    in_channels = input_channels,
                    out_channels = output_channels,
                    num_layers = self.layers_per_block,
                    attn_num_head_channels = attention_head_dim[i],
                    add_downsample = not is_final_block,
                    use_memory_efficient_attention = self.use_memory_efficient_attention,
                    dtype = self.dtype
                )
            elif down_block_type in ['DownBlockPseudo3D', 'DownBlock2D']:
                down_block = DownBlockPseudo3D(
                    in_channels = input_channels,
                    out_channels = output_channels,
                    num_layers = self.layers_per_block,
                    add_downsample = not is_final_block,
                    dtype = self.dtype
                )
            else:
                raise NotImplementedError(f'Unimplemented down block type: {down_block_type}')
            down_blocks.append(down_block)
        self.down_blocks = down_blocks
        self.mid_block = UNetMidBlockPseudo3DCrossAttn(
            in_channels = self.block_out_channels[-1],
            attn_num_head_channels = attention_head_dim[-1],
            use_memory_efficient_attention = self.use_memory_efficient_attention,
            dtype = self.dtype
        )
        up_blocks = []
        reversed_block_out_channels = list(reversed(self.block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        output_channels = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(self.up_block_types):
            prev_output_channels = output_channels
            output_channels = reversed_block_out_channels[i]
            input_channels = reversed_block_out_channels[min(i + 1, len(self.block_out_channels) - 1)]
            is_final_block = i == len(self.block_out_channels) - 1
            if up_block_type in ['CrossAttnUpBlockPseudo3D', 'CrossAttnUpBlock2D']:
                up_block = CrossAttnUpBlockPseudo3D(
                    in_channels = input_channels,
                    out_channels = output_channels,
                    prev_output_channels = prev_output_channels,
                    num_layers = self.layers_per_block + 1,
                    attn_num_head_channels = reversed_attention_head_dim[i],
                    add_upsample = not is_final_block,
                    use_memory_efficient_attention = self.use_memory_efficient_attention,
                    dtype = self.dtype
                )
            elif up_block_type in ['UpBlockPseudo3D', 'UpBlock2D']:
                up_block = UpBlockPseudo3D(
                    in_channels = input_channels,
                    out_channels = output_channels,
                    prev_output_channels = prev_output_channels,
                    num_layers = self.layers_per_block + 1,
                    add_upsample = not is_final_block,
                    dtype = self.dtype
                )
            else:
                raise NotImplementedError(f'Unimplemented up block type: {up_block_type}')
            up_blocks.append(up_block)
        self.up_blocks = up_blocks
        self.conv_norm_out = nn.GroupNorm(
            num_groups = 32,
            epsilon = 1e-5
        )
        self.conv_out = ConvPseudo3D(
            features = self.out_channels,
            kernel_size = (3, 3),
            strides = (1, 1),
            padding = ((1, 1), (1, 1)),
            dtype = self.dtype
        )

    def __call__(self,
            sample: jax.Array,
            timesteps: jax.Array,
            encoder_hidden_states: jax.Array,
            return_dict: bool = True
    ) -> Union[UNetPseudo3DConditionOutput, Tuple[jax.Array]]:
        if timesteps.dtype != jnp.float32:
            timesteps = timesteps.astype(dtype = jnp.float32)
        if len(timesteps.shape) == 0:
            timesteps = jnp.expand_dims(timesteps, 0)
        # b,c,f,h,w -> b,f,h,w,c
        sample = sample.transpose((0, 2, 3, 4, 1))
        t_emb = self.time_proj(timesteps)
        t_emb = self.time_embedding(t_emb)
        sample = self.conv_in(sample)
        # collect residuals from every down block for the skip connections
        down_block_res_samples = (sample, )
        for down_block in self.down_blocks:
            if isinstance(down_block, CrossAttnDownBlockPseudo3D):
                sample, res_samples = down_block(
                    hidden_states = sample,
                    temb = t_emb,
                    encoder_hidden_states = encoder_hidden_states
                )
            elif isinstance(down_block, DownBlockPseudo3D):
                sample, res_samples = down_block(
                    hidden_states = sample,
                    temb = t_emb
                )
            else:
                raise NotImplementedError(f'Unimplemented down block type: {down_block.__class__.__name__}')
            down_block_res_samples += res_samples
        sample = self.mid_block(
            hidden_states = sample,
            temb = t_emb,
            encoder_hidden_states = encoder_hidden_states
        )
        for up_block in self.up_blocks:
            # pop the residuals for this resolution level off the stack
            res_samples = down_block_res_samples[-(self.layers_per_block + 1):]
            down_block_res_samples = down_block_res_samples[:-(self.layers_per_block + 1)]
            if isinstance(up_block, CrossAttnUpBlockPseudo3D):
                sample = up_block(
                    hidden_states = sample,
                    temb = t_emb,
                    encoder_hidden_states = encoder_hidden_states,
                    res_hidden_states_tuple = res_samples
                )
            elif isinstance(up_block, UpBlockPseudo3D):
                sample = up_block(
                    hidden_states = sample,
                    temb = t_emb,
                    res_hidden_states_tuple = res_samples
                )
            else:
                raise NotImplementedError(f'Unimplemented up block type: {up_block.__class__.__name__}')
        sample = self.conv_norm_out(sample)
        sample = nn.silu(sample)
        sample = self.conv_out(sample)
        # b,f,h,w,c -> b,c,f,h,w
        sample = sample.transpose((0, 4, 1, 2, 3))
        if not return_dict:
            return (sample, )
        return UNetPseudo3DConditionOutput(sample = sample)
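

# Minimal usage sketch of init and a forward pass. This is illustrative, not
# part of the model's API: the 77-token text embedding and the dropout rng are
# assumptions, and because this file uses relative imports it must be run as a
# module from the package root (python -m <package>.<this_module>).
if __name__ == '__main__':
    model = UNetPseudo3DConditionModel()
    params = model.init_weights(jax.random.PRNGKey(0))
    # inputs follow the b,c,f,h,w layout expected by __call__ (one frame here)
    sample = jnp.zeros((1, 4, 1, 64, 64), dtype = jnp.float32)
    timesteps = jnp.ones((1, ), dtype = jnp.int32)
    # hypothetical CLIP-sized text embedding: (batch, tokens, cross_attention_dim)
    encoder_hidden_states = jnp.zeros((1, 77, 768), dtype = jnp.float32)
    out = model.apply(
        {'params': params},
        sample,
        timesteps,
        encoder_hidden_states,
        rngs = {'dropout': jax.random.PRNGKey(1)}  # in case a sub-block samples dropout
    )
    print(out.sample.shape)  # (1, 4, 1, 64, 64): output layout matches the input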