# makeavid-sd-jax: makeavid_sd/flax_impl/flax_unet_pseudo3d_blocks.py
# Revision b2f876f by lopho: "forgot about the nested package structure" (9.51 kB)
from typing import Tuple
import jax
import jax.numpy as jnp
import flax.linen as nn
from .flax_attention_pseudo3d import TransformerPseudo3DModel
from .flax_resnet_pseudo3d import ResnetBlockPseudo3D, DownsamplePseudo3D, UpsamplePseudo3D
class UNetMidBlockPseudo3DCrossAttn(nn.Module):
    """Bottleneck block: one resnet, then ``num_layers`` x (transformer -> resnet).

    Every submodule is configured to map ``in_channels`` to ``in_channels``,
    so the block keeps the channel count of its input.

    Attributes:
        in_channels: Channel count of both the input and the output.
        num_layers: Number of (transformer, resnet) pairs applied after the
            leading resnet; ``num_layers + 1`` resnets are created in total.
        attn_num_head_channels: Despite its name, forwarded to each
            transformer as ``num_attention_heads``; the per-head size is
            ``in_channels // attn_num_head_channels``.
        use_memory_efficient_attention: Forwarded to every transformer.
        dtype: Dtype forwarded to every submodule.
    """

    in_channels: int
    num_layers: int = 1
    attn_num_head_channels: int = 1
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        # The leading resnet runs before any attention; each loop iteration
        # then adds one transformer and the resnet that follows it.
        resnets = [
            ResnetBlockPseudo3D(
                in_channels = self.in_channels,
                out_channels = self.in_channels,
                dtype = self.dtype
            )
        ]
        attentions = []
        for _ in range(self.num_layers):
            attn_block = TransformerPseudo3DModel(
                in_channels = self.in_channels,
                num_attention_heads = self.attn_num_head_channels,
                attention_head_dim = self.in_channels // self.attn_num_head_channels,
                num_layers = 1,
                use_memory_efficient_attention = self.use_memory_efficient_attention,
                dtype = self.dtype
            )
            attentions.append(attn_block)
            res_block = ResnetBlockPseudo3D(
                in_channels = self.in_channels,
                out_channels = self.in_channels,
                dtype = self.dtype
            )
            resnets.append(res_block)
        # Flax names setup-assigned submodules after these attributes, so the
        # attribute names are part of the parameter-tree layout.
        self.attentions = attentions
        self.resnets = resnets

    def __call__(
        self,
        hidden_states: jax.Array,
        temb: jax.Array,
        encoder_hidden_states: jax.Array,
    ) -> jax.Array:
        """Run the bottleneck.

        Args:
            hidden_states: Input features, expected to have ``in_channels``
                channels.
            temb: Passed unchanged to every resnet (presumably the timestep
                embedding).
            encoder_hidden_states: Passed unchanged to every transformer as
                the cross-attention context. Required: it was previously
                written with ``=`` instead of ``:``, which made the
                ``jax.Array`` class itself a bogus default value.

        Returns:
            The features after the final resnet.
        """
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states, encoder_hidden_states)
            hidden_states = resnet(hidden_states, temb)
        return hidden_states
class CrossAttnDownBlockPseudo3D(nn.Module):
    """Encoder stage: ``num_layers`` x (resnet -> transformer), then optional downsampling.

    Attributes:
        in_channels: Channel count of the block input; only the first resnet
            consumes it, later resnets take ``out_channels``.
        out_channels: Output channel count of every resnet; also passed to
            each transformer and to the downsampler.
        num_layers: Number of (resnet, transformer) pairs.
        attn_num_head_channels: Forwarded to each transformer as
            ``num_attention_heads``; the per-head size is
            ``out_channels // attn_num_head_channels``.
        add_downsample: Whether a ``DownsamplePseudo3D`` runs after the last
            pair.
        use_memory_efficient_attention: Forwarded to every transformer.
        dtype: Dtype forwarded to every submodule.
    """

    in_channels: int
    out_channels: int
    num_layers: int = 1
    attn_num_head_channels: int = 1
    add_downsample: bool = True
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        attentions = []
        resnets = []
        for i in range(self.num_layers):
            # Only the first resnet changes the channel count.
            in_channels = self.in_channels if i == 0 else self.out_channels
            res_block = ResnetBlockPseudo3D(
                in_channels = in_channels,
                out_channels = self.out_channels,
                dtype = self.dtype
            )
            resnets.append(res_block)
            attn_block = TransformerPseudo3DModel(
                in_channels = self.out_channels,
                num_attention_heads = self.attn_num_head_channels,
                attention_head_dim = self.out_channels // self.attn_num_head_channels,
                num_layers = 1,
                use_memory_efficient_attention = self.use_memory_efficient_attention,
                dtype = self.dtype
            )
            attentions.append(attn_block)
        self.resnets = resnets
        self.attentions = attentions
        if self.add_downsample:
            self.downsamplers_0 = DownsamplePseudo3D(
                out_channels = self.out_channels,
                dtype = self.dtype
            )
        else:
            self.downsamplers_0 = None

    def __call__(
        self,
        hidden_states: jax.Array,
        temb: jax.Array,
        encoder_hidden_states: jax.Array,
    ) -> Tuple[jax.Array, Tuple[jax.Array, ...]]:
        """Run the encoder stage.

        Args:
            hidden_states: Input features, expected to have ``in_channels``
                channels.
            temb: Passed unchanged to every resnet (presumably the timestep
                embedding).
            encoder_hidden_states: Passed unchanged to every transformer as
                the cross-attention context.

        Returns:
            A pair ``(hidden_states, output_states)``. ``output_states``
            holds, in order, the output of each (resnet, transformer) pair,
            followed by the downsampled result when ``add_downsample`` is
            set. When it is non-empty, its last entry is the returned
            ``hidden_states``. These are presumably the skip connections
            consumed by the matching up block.
        """
        output_states = ()
        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states)
            output_states += (hidden_states, )
        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            output_states += (hidden_states, )
        return hidden_states, output_states
class DownBlockPseudo3D(nn.Module):
    """Encoder stage without attention: ``num_layers`` resnets, then optional downsampling.

    Attributes:
        in_channels: Channel count of the block input; only the first resnet
            consumes it, later resnets take ``out_channels``.
        out_channels: Output channel count of every resnet; also passed to
            the downsampler.
        num_layers: Number of resnets.
        add_downsample: Whether a ``DownsamplePseudo3D`` runs after the last
            resnet.
        dtype: Dtype forwarded to every submodule.
    """

    in_channels: int
    out_channels: int
    num_layers: int = 1
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        resnets = []
        for i in range(self.num_layers):
            # Only the first resnet changes the channel count.
            in_channels = self.in_channels if i == 0 else self.out_channels
            res_block = ResnetBlockPseudo3D(
                in_channels = in_channels,
                out_channels = self.out_channels,
                dtype = self.dtype
            )
            resnets.append(res_block)
        self.resnets = resnets
        if self.add_downsample:
            self.downsamplers_0 = DownsamplePseudo3D(
                out_channels = self.out_channels,
                dtype = self.dtype
            )
        else:
            self.downsamplers_0 = None

    def __call__(
        self,
        hidden_states: jax.Array,
        temb: jax.Array,
    ) -> Tuple[jax.Array, Tuple[jax.Array, ...]]:
        """Run the encoder stage.

        Args:
            hidden_states: Input features, expected to have ``in_channels``
                channels.
            temb: Passed unchanged to every resnet (presumably the timestep
                embedding).

        Returns:
            A pair ``(hidden_states, output_states)``. ``output_states``
            holds, in order, the output of each resnet, followed by the
            downsampled result when ``add_downsample`` is set. When it is
            non-empty, its last entry is the returned ``hidden_states``.
        """
        output_states = ()
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            output_states += (hidden_states, )
        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            output_states += (hidden_states, )
        return hidden_states, output_states
class CrossAttnUpBlockPseudo3D(nn.Module):
    """Decoder stage: ``num_layers`` x (skip concat -> resnet -> transformer), then optional upsampling.

    Attributes:
        in_channels: Channel count of the skip tensor consumed by the last
            layer.
        out_channels: Output channel count of every resnet; also the channel
            count of the skip tensors used by the other layers, and the
            value passed to each transformer and to the upsampler.
        prev_output_channels: Channel count of ``hidden_states`` entering the
            first layer.
        num_layers: Number of (resnet, transformer) pairs.
        attn_num_head_channels: Forwarded to each transformer as
            ``num_attention_heads``; the per-head size is
            ``out_channels // attn_num_head_channels``.
        add_upsample: Whether an ``UpsamplePseudo3D`` runs after the last
            pair.
        use_memory_efficient_attention: Forwarded to every transformer.
        dtype: Dtype forwarded to every submodule.
    """

    in_channels: int
    out_channels: int
    prev_output_channels: int
    num_layers: int = 1
    attn_num_head_channels: int = 1
    add_upsample: bool = True
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        final_layer = self.num_layers - 1
        # Each resnet receives the running features concatenated with one
        # skip tensor. Layer 0's running features come from the previous
        # stage; the last layer's skip tensor has ``in_channels``.
        self.resnets = [
            ResnetBlockPseudo3D(
                in_channels = (
                    (self.out_channels if layer else self.prev_output_channels)
                    + (self.in_channels if layer == final_layer else self.out_channels)
                ),
                out_channels = self.out_channels,
                dtype = self.dtype
            )
            for layer in range(self.num_layers)
        ]
        self.attentions = [
            TransformerPseudo3DModel(
                in_channels = self.out_channels,
                num_attention_heads = self.attn_num_head_channels,
                attention_head_dim = self.out_channels // self.attn_num_head_channels,
                num_layers = 1,
                use_memory_efficient_attention = self.use_memory_efficient_attention,
                dtype = self.dtype
            )
            for _ in range(self.num_layers)
        ]
        self.upsamplers_0 = (
            UpsamplePseudo3D(out_channels = self.out_channels, dtype = self.dtype)
            if self.add_upsample
            else None
        )

    def __call__(
        self,
        hidden_states: jax.Array,
        res_hidden_states_tuple: Tuple[jax.Array, ...],
        temb: jax.Array,
        encoder_hidden_states: jax.Array,
    ) -> jax.Array:
        """Run the decoder stage.

        Args:
            hidden_states: Running features from the previous stage.
            res_hidden_states_tuple: Skip tensors, consumed from the end: the
                layer at position ``k`` uses entry ``-(k + 1)``. Too few
                entries raises ``IndexError``.
            temb: Passed unchanged to every resnet (presumably the timestep
                embedding).
            encoder_hidden_states: Passed unchanged to every transformer as
                the cross-attention context.

        Returns:
            The features after the last pair, upsampled when
            ``add_upsample`` is set.
        """
        pairs = zip(self.resnets, self.attentions)
        for depth, (resnet, attn) in enumerate(pairs, start = 1):
            skip = res_hidden_states_tuple[-depth]
            # Join along the last axis (presumably channels-last layout).
            merged = jnp.concatenate([hidden_states, skip], axis = -1)
            hidden_states = attn(resnet(merged, temb), encoder_hidden_states)
        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)
        return hidden_states
class UpBlockPseudo3D(nn.Module):
    """Decoder stage without attention: ``num_layers`` x (skip concat -> resnet), then optional upsampling.

    Attributes:
        in_channels: Channel count of the skip tensor consumed by the last
            layer.
        out_channels: Output channel count of every resnet; also the channel
            count of the skip tensors used by the other layers, and the
            value passed to the upsampler.
        prev_output_channels: Channel count of ``hidden_states`` entering the
            first layer.
        num_layers: Number of resnets.
        add_upsample: Whether an ``UpsamplePseudo3D`` runs after the last
            resnet.
        dtype: Dtype forwarded to every submodule.
    """

    in_channels: int
    out_channels: int
    prev_output_channels: int
    num_layers: int = 1
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        final_layer = self.num_layers - 1
        # Input width = running features (previous stage for layer 0) plus
        # one skip tensor (``in_channels`` wide for the last layer).
        self.resnets = [
            ResnetBlockPseudo3D(
                in_channels = (
                    (self.out_channels if layer else self.prev_output_channels)
                    + (self.in_channels if layer == final_layer else self.out_channels)
                ),
                out_channels = self.out_channels,
                dtype = self.dtype
            )
            for layer in range(self.num_layers)
        ]
        self.upsamplers_0 = (
            UpsamplePseudo3D(out_channels = self.out_channels, dtype = self.dtype)
            if self.add_upsample
            else None
        )

    def __call__(
        self,
        hidden_states: jax.Array,
        res_hidden_states_tuple: Tuple[jax.Array, ...],
        temb: jax.Array,
    ) -> jax.Array:
        """Run the decoder stage.

        Args:
            hidden_states: Running features from the previous stage.
            res_hidden_states_tuple: Skip tensors, consumed from the end: the
                resnet at position ``k`` uses entry ``-(k + 1)``. Too few
                entries raises ``IndexError``.
            temb: Passed unchanged to every resnet (presumably the timestep
                embedding).

        Returns:
            The features after the last resnet, upsampled when
            ``add_upsample`` is set.
        """
        for depth, resnet in enumerate(self.resnets, start = 1):
            skip = res_hidden_states_tuple[-depth]
            # Join along the last axis (presumably channels-last layout).
            merged = jnp.concatenate((hidden_states, skip), axis = -1)
            hidden_states = resnet(merged, temb)
        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)
        return hidden_states