from typing import Optional

import jax
import jax.numpy as jnp
import flax.linen as nn
import einops

#from flax_memory_efficient_attention import jax_memory_efficient_attention
#from flax_attention import FlaxAttention
from diffusers.models.attention_flax import FlaxAttention


class TransformerPseudo3DModel(nn.Module):
    in_channels: int
    num_attention_heads: int
    attention_head_dim: int
    num_layers: int = 1
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        inner_dim = self.num_attention_heads * self.attention_head_dim
        self.norm = nn.GroupNorm(
                num_groups = 32,
                epsilon = 1e-5
        )
        self.proj_in = nn.Conv(
                inner_dim,
                kernel_size = (1, 1),
                strides = (1, 1),
                padding = 'VALID',
                dtype = self.dtype
        )
        transformer_blocks = []
        #CheckpointTransformerBlock = nn.checkpoint(
        #        BasicTransformerBlockPseudo3D,
        #        static_argnums = (2, 3, 4)
        #        #prevent_cse = False
        #)
        CheckpointTransformerBlock = BasicTransformerBlockPseudo3D
        for _ in range(self.num_layers):
            transformer_blocks.append(CheckpointTransformerBlock(
                    dim = inner_dim,
                    num_attention_heads = self.num_attention_heads,
                    attention_head_dim = self.attention_head_dim,
                    use_memory_efficient_attention = self.use_memory_efficient_attention,
                    dtype = self.dtype
            ))
        self.transformer_blocks = transformer_blocks
        self.proj_out = nn.Conv(
                inner_dim,
                kernel_size = (1, 1),
                strides = (1, 1),
                padding = 'VALID',
                dtype = self.dtype
        )

    def __call__(self,
            hidden_states: jax.Array,
            encoder_hidden_states: Optional[jax.Array] = None
    ) -> jax.Array:
        is_video = hidden_states.ndim == 5
        f: Optional[int] = None
        if is_video:
            # jax is channels last
            # b,c,f,h,w WRONG
            # b,f,h,w,c CORRECT
            #b, c, f, h, w = hidden_states.shape
            #hidden_states = einops.rearrange(hidden_states, 'b c f h w -> (b f) c h w')
            b, f, h, w, c = hidden_states.shape
            hidden_states = einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')
        batch, height, width, channels = hidden_states.shape
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        hidden_states = self.proj_in(hidden_states)
        hidden_states = hidden_states.reshape(batch, height * width, channels)
        for block in self.transformer_blocks:
            hidden_states = block(
                    hidden_states,
                    encoder_hidden_states,
                    f,
                    height,
                    width
            )
        hidden_states = hidden_states.reshape(batch, height, width, channels)
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states + residual
        if is_video:
            hidden_states = einops.rearrange(hidden_states, '(b f) h w c -> b f h w c', b = b)
        return hidden_states


class BasicTransformerBlockPseudo3D(nn.Module):
    dim: int
    num_attention_heads: int
    attention_head_dim: int
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.attn1 = FlaxAttention(
                query_dim = self.dim,
                heads = self.num_attention_heads,
                dim_head = self.attention_head_dim,
                use_memory_efficient_attention = self.use_memory_efficient_attention,
                dtype = self.dtype
        )
        self.ff = FeedForward(dim = self.dim, dtype = self.dtype)
        self.attn2 = FlaxAttention(
                query_dim = self.dim,
                heads = self.num_attention_heads,
                dim_head = self.attention_head_dim,
                use_memory_efficient_attention = self.use_memory_efficient_attention,
                dtype = self.dtype
        )
        self.attn_temporal = FlaxAttention(
                query_dim = self.dim,
                heads = self.num_attention_heads,
                dim_head = self.attention_head_dim,
                use_memory_efficient_attention = self.use_memory_efficient_attention,
                dtype = self.dtype
        )
        self.norm1 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
        self.norm2 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
        self.norm_temporal = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
        self.norm3 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)

    def __call__(self,
            hidden_states: jax.Array,
            context: Optional[jax.Array] = None,
            frames_length: Optional[int] = None,
            height: Optional[int] = None,
            width: Optional[int] = None
    ) -> jax.Array:
        if context is not None and frames_length is not None:
            context = context.repeat(frames_length, axis = 0)
        # self attention
        norm_hidden_states = self.norm1(hidden_states)
        hidden_states = self.attn1(norm_hidden_states) + hidden_states
        # cross attention
        norm_hidden_states = self.norm2(hidden_states)
        hidden_states = self.attn2(
                norm_hidden_states,
                context = context
        ) + hidden_states
        # temporal attention
        if frames_length is not None:
            #bf, hw, c = hidden_states.shape
            # (b f) (h w) c -> b f (h w) c
            #hidden_states = hidden_states.reshape(bf // frames_length, frames_length, hw, c)
            #b, f, hw, c = hidden_states.shape
            # b f (h w) c -> b (h w) f c
            #hidden_states = hidden_states.transpose(0, 2, 1, 3)
            # b (h w) f c -> (b h w) f c
            #hidden_states = hidden_states.reshape(b * hw, frames_length, c)
            hidden_states = einops.rearrange(
                    hidden_states,
                    '(b f) (h w) c -> (b h w) f c',
                    f = frames_length,
                    h = height,
                    w = width
            )
            norm_hidden_states = self.norm_temporal(hidden_states)
            hidden_states = self.attn_temporal(norm_hidden_states) + hidden_states
            # (b h w) f c -> b (h w) f c
            #hidden_states = hidden_states.reshape(b, hw, f, c)
            # b (h w) f c -> b f (h w) c
            #hidden_states = hidden_states.transpose(0, 2, 1, 3)
            # b f h w c -> (b f) (h w) c
            #hidden_states = hidden_states.reshape(bf, hw, c)
            hidden_states = einops.rearrange(
                    hidden_states,
                    '(b h w) f c -> (b f) (h w) c',
                    f = frames_length,
                    h = height,
                    w = width
            )
        norm_hidden_states = self.norm3(hidden_states)
        hidden_states = self.ff(norm_hidden_states) + hidden_states
        return hidden_states


class FeedForward(nn.Module):
    dim: int
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.net_0 = GEGLU(self.dim, self.dtype)
        self.net_2 = nn.Dense(self.dim, dtype = self.dtype)

    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        hidden_states = self.net_0(hidden_states)
        hidden_states = self.net_2(hidden_states)
        return hidden_states


class GEGLU(nn.Module):
    dim: int
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        inner_dim = self.dim * 4
        self.proj = nn.Dense(inner_dim * 2, dtype = self.dtype)

    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        hidden_states = self.proj(hidden_states)
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis = 2)
        return hidden_linear * nn.gelu(hidden_gelu)
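

# Minimal usage sketch (not part of the original module): shows the expected
# channels-last layouts for a video batch. All dimensions below are assumptions
# chosen only for illustration; the conditioning tensor stands in for whatever
# text-encoder output the surrounding pipeline provides.
if __name__ == '__main__':
    model = TransformerPseudo3DModel(
            in_channels = 320,
            num_attention_heads = 8,
            attention_head_dim = 40
    )
    rng = jax.random.PRNGKey(0)
    # video latents: (batch, frames, height, width, channels), channels last
    sample = jnp.zeros((1, 4, 8, 8, 320), dtype = jnp.float32)
    # hypothetical conditioning: (batch, tokens, embedding dim)
    encoder_hidden_states = jnp.zeros((1, 77, 768), dtype = jnp.float32)
    params = model.init(rng, sample, encoder_hidden_states)
    out = model.apply(params, sample, encoder_hidden_states)
    print(out.shape)  # (1, 4, 8, 8, 320), same layout as the input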