# makeavid-sd-jax / makeavid_sd / flax_impl / flax_attention_pseudo3d.py
# lopho's picture
# forgot about the nested package structure
# b2f876f
from typing import Optional
import jax
import jax.numpy as jnp
import flax.linen as nn
import einops
#from flax_memory_efficient_attention import jax_memory_efficient_attention
#from flax_attention import FlaxAttention
from diffusers.models.attention_flax import FlaxAttention
class TransformerPseudo3DModel(nn.Module):
    """Transformer stage operating on channels-last image or video features.

    A 4D input is unpacked as ``(batch, height, width, channels)``. A 5D
    input is unpacked as ``(batch, frames, height, width, channels)``; its
    frames are folded into the batch axis for the spatial layers and the
    frame count is handed to every transformer block. The output has the
    same shape as the input.

    Attributes:
        in_channels: channel count of the input; the output projection maps
            back to this width so the residual connection is well-formed.
        num_attention_heads: number of attention heads in every block.
        attention_head_dim: width of one attention head. The transformer
            width is ``num_attention_heads * attention_head_dim``.
        num_layers: number of stacked transformer blocks.
        use_memory_efficient_attention: forwarded to every block.
        dtype: dtype forwarded to the convolutions and the blocks.
    """
    in_channels: int
    num_attention_heads: int
    attention_head_dim: int
    num_layers: int = 1
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        """Create the norm, the 1x1 projections and the transformer blocks."""
        inner_dim = self.num_attention_heads * self.attention_head_dim
        # No dtype is passed to the norm, so it keeps the library default.
        self.norm = nn.GroupNorm(
            num_groups = 32,
            epsilon = 1e-5
        )
        self.proj_in = nn.Conv(
            inner_dim,
            kernel_size = (1, 1),
            strides = (1, 1),
            padding = 'VALID',
            dtype = self.dtype
        )
        # NOTE: the block class is used unwrapped; gradient checkpointing
        # (nn.checkpoint around the block class) is not enabled here.
        self.transformer_blocks = [
            BasicTransformerBlockPseudo3D(
                dim = inner_dim,
                num_attention_heads = self.num_attention_heads,
                attention_head_dim = self.attention_head_dim,
                use_memory_efficient_attention = self.use_memory_efficient_attention,
                dtype = self.dtype
            )
            for _ in range(self.num_layers)
        ]
        # Project back to the input width (not inner_dim) so that adding the
        # residual also works when in_channels != inner_dim. When the two
        # are equal the parameter shapes are the same as before.
        self.proj_out = nn.Conv(
            self.in_channels,
            kernel_size = (1, 1),
            strides = (1, 1),
            padding = 'VALID',
            dtype = self.dtype
        )

    def __call__(self,
            hidden_states: jax.Array,
            encoder_hidden_states: Optional[jax.Array] = None
    ) -> jax.Array:
        """Run the transformer stack and add the input back as a residual.

        Args:
            hidden_states: channels-last features, 4D (image) or 5D (video).
            encoder_hidden_states: optional conditioning passed to every
                block as its context.

        Returns:
            An array with the same shape as ``hidden_states``.
        """
        is_video = hidden_states.ndim == 5
        frames: Optional[int] = None
        if is_video:
            # jax is channels last: b f h w c (not b c f h w)
            frames = hidden_states.shape[1]
            hidden_states = einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')
        batch, height, width, _ = hidden_states.shape
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        hidden_states = self.proj_in(hidden_states)
        # After proj_in the channel axis has the transformer width, which
        # may differ from the input channel count.
        inner_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.reshape(batch, height * width, inner_dim)
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states,
                frames,
                height,
                width
            )
        hidden_states = hidden_states.reshape(batch, height, width, inner_dim)
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states + residual
        if is_video:
            # Unfold the frames from the batch axis again.
            hidden_states = einops.rearrange(
                hidden_states,
                '(b f) h w c -> b f h w c',
                f = frames
            )
        return hidden_states
class BasicTransformerBlockPseudo3D(nn.Module):
    """One transformer block with an optional attention pass over frames.

    The block applies, in order and each with a LayerNorm in front and a
    residual connection around it: attention without context, attention
    with ``context``, attention along the frame axis (only when
    ``frames_length`` is given) and a feed-forward network.

    Attributes:
        dim: feature width of the tokens.
        num_attention_heads: number of heads of every attention layer.
        attention_head_dim: width of a single attention head.
        use_memory_efficient_attention: forwarded to every attention layer.
        dtype: dtype forwarded to all sub-layers.
    """
    dim: int
    num_attention_heads: int
    attention_head_dim: int
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        """Create the three attention layers, the feed-forward and the norms."""
        # All three attention layers are configured identically.
        attention_config = dict(
            query_dim = self.dim,
            heads = self.num_attention_heads,
            dim_head = self.attention_head_dim,
            use_memory_efficient_attention = self.use_memory_efficient_attention,
            dtype = self.dtype
        )
        norm_config = dict(epsilon = 1e-5, dtype = self.dtype)

        self.attn1 = FlaxAttention(**attention_config)
        self.ff = FeedForward(dim = self.dim, dtype = self.dtype)
        self.attn2 = FlaxAttention(**attention_config)
        self.attn_temporal = FlaxAttention(**attention_config)

        self.norm1 = nn.LayerNorm(**norm_config)
        self.norm2 = nn.LayerNorm(**norm_config)
        self.norm_temporal = nn.LayerNorm(**norm_config)
        self.norm3 = nn.LayerNorm(**norm_config)

    def __call__(self,
            hidden_states: jax.Array,
            context: Optional[jax.Array] = None,
            frames_length: Optional[int] = None,
            height: Optional[int] = None,
            width: Optional[int] = None
    ) -> jax.Array:
        """Apply the block to a sequence of tokens.

        Args:
            hidden_states: tokens laid out as ``(b f) (h w) c`` when
                ``frames_length`` is given.
            context: optional conditioning for the second attention layer.
                When ``frames_length`` is also given it is repeated that
                many times along axis 0.
            frames_length: number of frames folded into the leading axis, or
                ``None`` to skip the temporal attention.
            height: spatial height used to unflatten the token axis.
            width: spatial width used to unflatten the token axis.

        Returns:
            The transformed tokens in the same layout as the input.
        """
        has_frames = frames_length is not None
        if has_frames and context is not None:
            # One copy of each context entry per frame of that sample.
            context = context.repeat(frames_length, axis = 0)

        # self attention
        normed = self.norm1(hidden_states)
        hidden_states = self.attn1(normed) + hidden_states

        # cross attention
        normed = self.norm2(hidden_states)
        attended = self.attn2(normed, context = context)
        hidden_states = attended + hidden_states

        # temporal attention
        if has_frames:
            sizes = dict(f = frames_length, h = height, w = width)
            # Make the frame axis the sequence axis: every spatial position
            # attends across its own frames.
            hidden_states = einops.rearrange(
                hidden_states,
                '(b f) (h w) c -> (b h w) f c',
                **sizes
            )
            normed = self.norm_temporal(hidden_states)
            hidden_states = self.attn_temporal(normed) + hidden_states
            # Restore the spatial token layout.
            hidden_states = einops.rearrange(
                hidden_states,
                '(b h w) f c -> (b f) (h w) c',
                **sizes
            )

        # feed forward
        normed = self.norm3(hidden_states)
        return self.ff(normed) + hidden_states
class FeedForward(nn.Module):
    """Feed-forward network: a GEGLU layer followed by a dense projection.

    Attributes:
        dim: output width of the final dense layer; also passed to GEGLU.
        dtype: dtype forwarded to both sub-layers.
    """
    dim: int
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        """Create the gated layer and the output projection."""
        self.net_0 = GEGLU(dim = self.dim, dtype = self.dtype)
        self.net_2 = nn.Dense(features = self.dim, dtype = self.dtype)

    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        """Return ``net_2(net_0(hidden_states))``."""
        gated = self.net_0(hidden_states)
        return self.net_2(gated)
class GEGLU(nn.Module):
    """Gated linear unit with a GELU gate.

    The input is projected to ``2 * 4 * dim`` features, split into two
    halves along the feature axis, and one half is multiplied by the GELU of
    the other.

    Attributes:
        dim: base width; the output has ``4 * dim`` features.
        dtype: dtype forwarded to the dense projection.
    """
    dim: int
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        """Create the single dense layer that yields both halves."""
        inner_dim = self.dim * 4
        self.proj = nn.Dense(inner_dim * 2, dtype = self.dtype)

    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        """Project, split into value and gate halves, and gate the value.

        Args:
            hidden_states: array whose last axis holds the features.

        Returns:
            An array with ``4 * dim`` features on the last axis.
        """
        hidden_states = self.proj(hidden_states)
        # Split along the last axis (the one the dense layer produces)
        # instead of a hard-coded axis 2, so inputs of any rank work.
        # For rank-3 input this is the same axis as before.
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis = -1)
        return hidden_linear * nn.gelu(hidden_gelu)