# Pseudo-3D UNet building blocks (mid / down / up) for a Flax (JAX) diffusion UNet.
from typing import Tuple | |
import jax | |
import jax.numpy as jnp | |
import flax.linen as nn | |
from .flax_attention_pseudo3d import TransformerPseudo3DModel | |
from .flax_resnet_pseudo3d import ResnetBlockPseudo3D, DownsamplePseudo3D, UpsamplePseudo3D | |
class UNetMidBlockPseudo3DCrossAttn(nn.Module):
    """UNet middle block at a fixed channel width.

    Structure: one leading resnet, then ``num_layers`` pairs of
    (cross-attention transformer, resnet). Input and output both have
    ``in_channels`` channels.
    """
    in_channels: int                              # channel width (input == output)
    num_layers: int = 1                           # number of (attention, resnet) pairs
    attn_num_head_channels: int = 1               # number of attention heads
    use_memory_efficient_attention: bool = False  # forwarded to the transformer blocks
    dtype: jnp.dtype = jnp.float32                # parameter/compute dtype

    def setup(self) -> None:
        # First resnet runs before any attention; each subsequent resnet
        # follows its paired transformer block.
        resnets = [
            ResnetBlockPseudo3D(
                in_channels = self.in_channels,
                out_channels = self.in_channels,
                dtype = self.dtype
            )
        ]
        attentions = []
        for _ in range(self.num_layers):
            attn_block = TransformerPseudo3DModel(
                in_channels = self.in_channels,
                num_attention_heads = self.attn_num_head_channels,
                # head dim chosen so heads * head_dim == in_channels
                attention_head_dim = self.in_channels // self.attn_num_head_channels,
                num_layers = 1,
                use_memory_efficient_attention = self.use_memory_efficient_attention,
                dtype = self.dtype
            )
            attentions.append(attn_block)
            res_block = ResnetBlockPseudo3D(
                in_channels = self.in_channels,
                out_channels = self.in_channels,
                dtype = self.dtype
            )
            resnets.append(res_block)
        self.attentions = attentions
        self.resnets = resnets

    def __call__(self,
            hidden_states: jax.Array,
            temb: jax.Array,
            # FIX: original read `encoder_hidden_states = jax.Array`, which made
            # the *type object* the default value instead of annotating the
            # parameter; a call omitting it would feed the type into attention.
            encoder_hidden_states: jax.Array
    ) -> jax.Array:
        """Apply resnet -> (attn, resnet) * num_layers and return the result.

        temb is the time embedding consumed by each resnet;
        encoder_hidden_states is the cross-attention context.
        """
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states, encoder_hidden_states)
            hidden_states = resnet(hidden_states, temb)
        return hidden_states
class CrossAttnDownBlockPseudo3D(nn.Module):
    """Downsampling UNet block with cross-attention.

    Runs ``num_layers`` (resnet, transformer) pairs, collecting each
    intermediate output for skip connections, then optionally downsamples.
    """
    in_channels: int                              # channels of the incoming features
    out_channels: int                             # channels produced by every layer
    num_layers: int = 1                           # number of (resnet, attention) pairs
    attn_num_head_channels: int = 1               # number of attention heads
    add_downsample: bool = True                   # append a downsampler at the end
    use_memory_efficient_attention: bool = False  # forwarded to the transformer blocks
    dtype: jnp.dtype = jnp.float32                # parameter/compute dtype

    def setup(self) -> None:
        attentions = []
        resnets = []
        for i in range(self.num_layers):
            # Only the first resnet sees in_channels; the rest stay at out_channels.
            in_channels = self.in_channels if i == 0 else self.out_channels
            res_block = ResnetBlockPseudo3D(
                in_channels = in_channels,
                out_channels = self.out_channels,
                dtype = self.dtype
            )
            resnets.append(res_block)
            attn_block = TransformerPseudo3DModel(
                in_channels = self.out_channels,
                num_attention_heads = self.attn_num_head_channels,
                attention_head_dim = self.out_channels // self.attn_num_head_channels,
                num_layers = 1,
                use_memory_efficient_attention = self.use_memory_efficient_attention,
                dtype = self.dtype
            )
            attentions.append(attn_block)
        self.resnets = resnets
        self.attentions = attentions
        if self.add_downsample:
            self.downsamplers_0 = DownsamplePseudo3D(
                out_channels = self.out_channels,
                dtype = self.dtype
            )
        else:
            self.downsamplers_0 = None

    def __call__(self,
            hidden_states: jax.Array,
            temb: jax.Array,
            encoder_hidden_states: jax.Array
    # FIX: second element of the return is a tuple of per-layer skip
    # activations, not a single array as the old annotation claimed.
    ) -> Tuple[jax.Array, Tuple[jax.Array, ...]]:
        """Return (final hidden states, tuple of skip-connection states)."""
        output_states = ()
        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states)
            output_states += (hidden_states, )
        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            # Downsampled output is also exposed as a skip state.
            output_states += (hidden_states, )
        return hidden_states, output_states
class DownBlockPseudo3D(nn.Module):
    """Downsampling UNet block without attention.

    Runs ``num_layers`` resnets, collecting each intermediate output for
    skip connections, then optionally downsamples.
    """
    in_channels: int                # channels of the incoming features
    out_channels: int               # channels produced by every layer
    num_layers: int = 1             # number of resnet blocks
    add_downsample: bool = True     # append a downsampler at the end
    dtype: jnp.dtype = jnp.float32  # parameter/compute dtype

    def setup(self) -> None:
        resnets = []
        for i in range(self.num_layers):
            # Only the first resnet sees in_channels; the rest stay at out_channels.
            in_channels = self.in_channels if i == 0 else self.out_channels
            res_block = ResnetBlockPseudo3D(
                in_channels = in_channels,
                out_channels = self.out_channels,
                dtype = self.dtype
            )
            resnets.append(res_block)
        self.resnets = resnets
        if self.add_downsample:
            self.downsamplers_0 = DownsamplePseudo3D(
                out_channels = self.out_channels,
                dtype = self.dtype
            )
        else:
            self.downsamplers_0 = None

    def __call__(self,
            hidden_states: jax.Array,
            temb: jax.Array
    # FIX: second element of the return is a tuple of per-layer skip
    # activations, not a single array as the old annotation claimed.
    ) -> Tuple[jax.Array, Tuple[jax.Array, ...]]:
        """Return (final hidden states, tuple of skip-connection states)."""
        output_states = ()
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            output_states += (hidden_states, )
        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            # Downsampled output is also exposed as a skip state.
            output_states += (hidden_states, )
        return hidden_states, output_states
class CrossAttnUpBlockPseudo3D(nn.Module):
    """Upsampling UNet block with cross-attention.

    Each of the ``num_layers`` steps pops one skip connection, concatenates
    it onto the features along the channel axis, then applies a resnet and a
    cross-attention transformer. Optionally upsamples at the end.
    """
    in_channels: int                              # channels of the matching down block
    out_channels: int                             # channels produced by every layer
    prev_output_channels: int                     # channels arriving from the block below
    num_layers: int = 1                           # number of (resnet, attention) pairs
    attn_num_head_channels: int = 1               # number of attention heads
    add_upsample: bool = True                     # append an upsampler at the end
    use_memory_efficient_attention: bool = False  # forwarded to the transformer blocks
    dtype: jnp.dtype = jnp.float32                # parameter/compute dtype

    def setup(self) -> None:
        resnet_layers = []
        attention_layers = []
        last = self.num_layers - 1
        for idx in range(self.num_layers):
            # The final skip comes from before the matching down block, so it
            # carries in_channels; earlier skips carry out_channels. The first
            # step's features arrive from the previous (deeper) block.
            skip_width = self.in_channels if idx == last else self.out_channels
            feat_width = self.prev_output_channels if idx == 0 else self.out_channels
            resnet_layers.append(ResnetBlockPseudo3D(
                in_channels = feat_width + skip_width,
                out_channels = self.out_channels,
                dtype = self.dtype
            ))
            attention_layers.append(TransformerPseudo3DModel(
                in_channels = self.out_channels,
                num_attention_heads = self.attn_num_head_channels,
                attention_head_dim = self.out_channels // self.attn_num_head_channels,
                num_layers = 1,
                use_memory_efficient_attention = self.use_memory_efficient_attention,
                dtype = self.dtype
            ))
        self.resnets = resnet_layers
        self.attentions = attention_layers
        self.upsamplers_0 = (
            UpsamplePseudo3D(out_channels = self.out_channels, dtype = self.dtype)
            if self.add_upsample
            else None
        )

    def __call__(self,
            hidden_states: jax.Array,
            res_hidden_states_tuple: Tuple[jax.Array, ...],
            temb: jax.Array,
            encoder_hidden_states: jax.Array
    ) -> jax.Array:
        """Fuse skip connections (consumed last-first) and return the result."""
        for resnet, attn in zip(self.resnets, self.attentions):
            # Pop the most recent skip state off the back of the tuple.
            skip = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = jnp.concatenate((hidden_states, skip), axis = -1)
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states)
        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)
        return hidden_states
class UpBlockPseudo3D(nn.Module):
    """Upsampling UNet block without attention.

    Each of the ``num_layers`` steps pops one skip connection, concatenates
    it onto the features along the channel axis, then applies a resnet.
    Optionally upsamples at the end.
    """
    in_channels: int                # channels of the matching down block
    out_channels: int               # channels produced by every layer
    prev_output_channels: int       # channels arriving from the block below
    num_layers: int = 1             # number of resnet blocks
    add_upsample: bool = True       # append an upsampler at the end
    dtype: jnp.dtype = jnp.float32  # parameter/compute dtype

    def setup(self) -> None:
        resnet_layers = []
        last = self.num_layers - 1
        for idx in range(self.num_layers):
            # The final skip carries in_channels; earlier ones out_channels.
            # The first step's features come from the previous (deeper) block.
            skip_width = self.in_channels if idx == last else self.out_channels
            feat_width = self.prev_output_channels if idx == 0 else self.out_channels
            resnet_layers.append(ResnetBlockPseudo3D(
                in_channels = feat_width + skip_width,
                out_channels = self.out_channels,
                dtype = self.dtype
            ))
        self.resnets = resnet_layers
        self.upsamplers_0 = (
            UpsamplePseudo3D(out_channels = self.out_channels, dtype = self.dtype)
            if self.add_upsample
            else None
        )

    def __call__(self,
            hidden_states: jax.Array,
            res_hidden_states_tuple: Tuple[jax.Array, ...],
            temb: jax.Array
    ) -> jax.Array:
        """Fuse skip connections (consumed last-first) and return the result."""
        for resnet in self.resnets:
            # Pop the most recent skip state off the back of the tuple.
            skip = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = jnp.concatenate([hidden_states, skip], axis = -1)
            hidden_states = resnet(hidden_states, temb)
        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)
        return hidden_states