File size: 9,505 Bytes
149cc2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255

from typing import Tuple

import jax
import jax.numpy as jnp
import flax.linen as nn

from .flax_attention_pseudo3d import TransformerPseudo3DModel
from .flax_resnet_pseudo3d import ResnetBlockPseudo3D, DownsamplePseudo3D, UpsamplePseudo3D


class UNetMidBlockPseudo3DCrossAttn(nn.Module):
    """Mid block: one resnet followed by ``num_layers`` (attention, resnet) pairs.

    Every sub-module is built with ``in_channels`` as both its input and its
    output channel count.

    Attributes:
        in_channels: channel count given to every resnet and attention block.
        num_layers: number of (attention, resnet) pairs after the first resnet.
        attn_num_head_channels: forwarded as ``num_attention_heads``; the
            per-head dimension is ``in_channels // attn_num_head_channels``.
        use_memory_efficient_attention: forwarded to each attention block.
        dtype: forwarded to every sub-module.
    """
    in_channels: int
    num_layers: int = 1
    attn_num_head_channels: int = 1
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        """Create the leading resnet and the (attention, resnet) pairs."""
        # There is always one more resnet than attention blocks: the first
        # resnet runs before any attention is applied.
        resnets = [
                ResnetBlockPseudo3D(
                        in_channels = self.in_channels,
                        out_channels = self.in_channels,
                        dtype = self.dtype
                )
        ]
        attentions = []
        for _ in range(self.num_layers):
            attn_block = TransformerPseudo3DModel(
                    in_channels = self.in_channels,
                    num_attention_heads = self.attn_num_head_channels,
                    attention_head_dim = self.in_channels // self.attn_num_head_channels,
                    num_layers = 1,
                    use_memory_efficient_attention = self.use_memory_efficient_attention,
                    dtype = self.dtype
            )
            attentions.append(attn_block)
            res_block = ResnetBlockPseudo3D(
                    in_channels = self.in_channels,
                    out_channels = self.in_channels,
                    dtype = self.dtype
            )
            resnets.append(res_block)
        self.attentions = attentions
        self.resnets = resnets

    def __call__(self,
            hidden_states: jax.Array,
            temb: jax.Array,
            encoder_hidden_states: jax.Array
    ) -> jax.Array:
        """Run the first resnet, then each attention block followed by its resnet.

        Args:
            hidden_states: input passed through the resnets and attentions.
            temb: second argument handed to every resnet.
            encoder_hidden_states: second argument handed to every attention
                block.

        Returns:
            The output of the last resnet.
        """
        hidden_states = self.resnets[0](hidden_states, temb)
        # resnets[0] was consumed above; pair each attention with the resnet
        # that follows it.
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states, encoder_hidden_states)
            hidden_states = resnet(hidden_states, temb)
        return hidden_states


class CrossAttnDownBlockPseudo3D(nn.Module):
    """Down block: ``num_layers`` (resnet, attention) pairs plus an optional downsampler.

    Attributes:
        in_channels: input channel count of the first resnet.
        out_channels: output channel count of every resnet, and the channel
            count given to every attention block and to the downsampler.
        num_layers: number of (resnet, attention) pairs.
        attn_num_head_channels: forwarded as ``num_attention_heads``; the
            per-head dimension is ``out_channels // attn_num_head_channels``.
        add_downsample: whether a ``DownsamplePseudo3D`` is created and applied.
        use_memory_efficient_attention: forwarded to each attention block.
        dtype: forwarded to every sub-module.
    """
    in_channels: int
    out_channels: int
    num_layers: int = 1
    attn_num_head_channels: int = 1
    add_downsample: bool = True
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        """Create the resnet/attention pairs and, if enabled, the downsampler."""
        attentions = []
        resnets = []
        for i in range(self.num_layers):
            # Only the first resnet sees the block's input width; later ones
            # consume the previous layer's ``out_channels`` output.
            in_channels = self.in_channels if i == 0 else self.out_channels
            res_block = ResnetBlockPseudo3D(
                    in_channels = in_channels,
                    out_channels = self.out_channels,
                    dtype = self.dtype
            )
            resnets.append(res_block)
            attn_block = TransformerPseudo3DModel(
                    in_channels = self.out_channels,
                    num_attention_heads = self.attn_num_head_channels,
                    attention_head_dim = self.out_channels // self.attn_num_head_channels,
                    num_layers = 1,
                    use_memory_efficient_attention = self.use_memory_efficient_attention,
                    dtype = self.dtype
            )
            attentions.append(attn_block)
        self.resnets = resnets
        self.attentions = attentions

        if self.add_downsample:
            self.downsamplers_0 = DownsamplePseudo3D(
                    out_channels = self.out_channels,
                    dtype = self.dtype
            )
        else:
            self.downsamplers_0 = None

    def __call__(self,
            hidden_states: jax.Array,
            temb: jax.Array,
            encoder_hidden_states: jax.Array
    ) -> Tuple[jax.Array, Tuple[jax.Array, ...]]:
        """Apply each (resnet, attention) pair, then the optional downsampler.

        Args:
            hidden_states: input passed through the layers.
            temb: second argument handed to every resnet.
            encoder_hidden_states: second argument handed to every attention
                block.

        Returns:
            ``(hidden_states, output_states)``: the final output, and a tuple
            holding the output after every (resnet, attention) pair followed
            by the downsampled output when ``add_downsample`` is set.
        """
        output_states = ()
        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states)
            output_states += (hidden_states, )
        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            output_states += (hidden_states, )
        return hidden_states, output_states


class DownBlockPseudo3D(nn.Module):
    """Down block: ``num_layers`` resnets plus an optional downsampler.

    Attributes:
        in_channels: input channel count of the first resnet.
        out_channels: output channel count of every resnet and the channel
            count given to the downsampler.
        num_layers: number of resnets.
        add_downsample: whether a ``DownsamplePseudo3D`` is created and applied.
        dtype: forwarded to every sub-module.
    """
    in_channels: int
    out_channels: int
    num_layers: int = 1
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        """Create the resnets and, if enabled, the downsampler."""
        resnets = []
        for i in range(self.num_layers):
            # Only the first resnet sees the block's input width; later ones
            # consume the previous resnet's ``out_channels`` output.
            in_channels = self.in_channels if i == 0 else self.out_channels
            res_block = ResnetBlockPseudo3D(
                    in_channels = in_channels,
                    out_channels = self.out_channels,
                    dtype = self.dtype
            )
            resnets.append(res_block)
        self.resnets = resnets
        if self.add_downsample:
            self.downsamplers_0 = DownsamplePseudo3D(
                    out_channels = self.out_channels,
                    dtype = self.dtype
            )
        else:
            self.downsamplers_0 = None

    def __call__(self,
            hidden_states: jax.Array,
            temb: jax.Array
    ) -> Tuple[jax.Array, Tuple[jax.Array, ...]]:
        """Apply each resnet, then the optional downsampler.

        Args:
            hidden_states: input passed through the layers.
            temb: second argument handed to every resnet.

        Returns:
            ``(hidden_states, output_states)``: the final output, and a tuple
            holding the output after every resnet followed by the downsampled
            output when ``add_downsample`` is set.
        """
        output_states = ()
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            output_states += (hidden_states, )
        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            output_states += (hidden_states, )
        return hidden_states, output_states


class CrossAttnUpBlockPseudo3D(nn.Module):
    """Up block: ``num_layers`` (resnet, attention) pairs plus an optional upsampler.

    Before each resnet the running activation is concatenated, along the last
    axis, with one entry taken from the end of the skip-connection tuple; each
    resnet's ``in_channels`` is therefore the sum of two channel counts.

    Attributes:
        in_channels: skip-connection channel count used for the last resnet.
        out_channels: output channel count of every resnet, and the channel
            count given to every attention block and to the upsampler.
        prev_output_channels: channel count of the activation entering the
            first resnet (before concatenation).
        num_layers: number of (resnet, attention) pairs.
        attn_num_head_channels: forwarded as ``num_attention_heads``; the
            per-head dimension is ``out_channels // attn_num_head_channels``.
        add_upsample: whether an ``UpsamplePseudo3D`` is created and applied.
        use_memory_efficient_attention: forwarded to each attention block.
        dtype: forwarded to every sub-module.
    """
    in_channels: int
    out_channels: int
    prev_output_channels: int
    num_layers: int = 1
    attn_num_head_channels: int = 1
    add_upsample: bool = True
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        """Create the resnet/attention pairs and, if enabled, the upsampler."""
        final_idx = self.num_layers - 1
        res_blocks = []
        attn_blocks = []
        for idx in range(self.num_layers):
            # First resnet is fed by the previous block; the rest by this
            # block's own ``out_channels`` output.
            if idx == 0:
                main_channels = self.prev_output_channels
            else:
                main_channels = self.out_channels
            # The last resnet is paired with a skip of ``in_channels`` width;
            # every earlier one with a skip of ``out_channels`` width.
            if idx == final_idx:
                skip_channels = self.in_channels
            else:
                skip_channels = self.out_channels
            res_blocks.append(ResnetBlockPseudo3D(
                in_channels=main_channels + skip_channels,
                out_channels=self.out_channels,
                dtype=self.dtype,
            ))
            attn_blocks.append(TransformerPseudo3DModel(
                in_channels=self.out_channels,
                num_attention_heads=self.attn_num_head_channels,
                attention_head_dim=self.out_channels // self.attn_num_head_channels,
                num_layers=1,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                dtype=self.dtype,
            ))
        self.resnets = res_blocks
        self.attentions = attn_blocks
        self.upsamplers_0 = (
            UpsamplePseudo3D(out_channels=self.out_channels, dtype=self.dtype)
            if self.add_upsample
            else None
        )

    def __call__(self,
            hidden_states: jax.Array,
            res_hidden_states_tuple: Tuple[jax.Array, ...],
            temb: jax.Array,
            encoder_hidden_states: jax.Array
    ) -> jax.Array:
        """Merge skip connections, apply each (resnet, attention) pair, then upsample.

        Args:
            hidden_states: activation entering the block.
            res_hidden_states_tuple: skip connections; consumed from the end,
                one entry per layer.
            temb: second argument handed to every resnet.
            encoder_hidden_states: second argument handed to every attention
                block.

        Returns:
            The output of the last layer, upsampled when ``add_upsample`` is set.
        """
        layers = zip(self.resnets, self.attentions)
        for depth, (resnet, attn) in enumerate(layers, start=1):
            # Layer ``depth`` takes the ``depth``-th entry counted from the end.
            skip = res_hidden_states_tuple[-depth]
            merged = jnp.concatenate((hidden_states, skip), axis=-1)
            hidden_states = attn(resnet(merged, temb), encoder_hidden_states)
        if not self.add_upsample:
            return hidden_states
        return self.upsamplers_0(hidden_states)


class UpBlockPseudo3D(nn.Module):
    """Up block: ``num_layers`` resnets plus an optional upsampler.

    Before each resnet the running activation is concatenated, along the last
    axis, with one entry taken from the end of the skip-connection tuple; each
    resnet's ``in_channels`` is therefore the sum of two channel counts.

    Attributes:
        in_channels: skip-connection channel count used for the last resnet.
        out_channels: output channel count of every resnet and the channel
            count given to the upsampler.
        prev_output_channels: channel count of the activation entering the
            first resnet (before concatenation).
        num_layers: number of resnets.
        add_upsample: whether an ``UpsamplePseudo3D`` is created and applied.
        dtype: forwarded to every sub-module.
    """
    in_channels: int
    out_channels: int
    prev_output_channels: int
    num_layers: int = 1
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        """Create the resnets and, if enabled, the upsampler."""
        final_idx = self.num_layers - 1
        res_blocks = []
        for idx in range(self.num_layers):
            # First resnet is fed by the previous block; the rest by this
            # block's own ``out_channels`` output.
            if idx == 0:
                main_channels = self.prev_output_channels
            else:
                main_channels = self.out_channels
            # The last resnet is paired with a skip of ``in_channels`` width;
            # every earlier one with a skip of ``out_channels`` width.
            if idx == final_idx:
                skip_channels = self.in_channels
            else:
                skip_channels = self.out_channels
            res_blocks.append(ResnetBlockPseudo3D(
                in_channels=main_channels + skip_channels,
                out_channels=self.out_channels,
                dtype=self.dtype,
            ))
        self.resnets = res_blocks
        self.upsamplers_0 = (
            UpsamplePseudo3D(out_channels=self.out_channels, dtype=self.dtype)
            if self.add_upsample
            else None
        )

    def __call__(self,
            hidden_states: jax.Array,
            res_hidden_states_tuple: Tuple[jax.Array, ...],
            temb: jax.Array
    ) -> jax.Array:
        """Merge skip connections, apply each resnet, then upsample.

        Args:
            hidden_states: activation entering the block.
            res_hidden_states_tuple: skip connections; consumed from the end,
                one entry per resnet.
            temb: second argument handed to every resnet.

        Returns:
            The output of the last resnet, upsampled when ``add_upsample`` is set.
        """
        for depth, resnet in enumerate(self.resnets, start=1):
            # Resnet ``depth`` takes the ``depth``-th entry counted from the end.
            skip = res_hidden_states_tuple[-depth]
            merged = jnp.concatenate([hidden_states, skip], axis=-1)
            hidden_states = resnet(merged, temb)
        if not self.add_upsample:
            return hidden_states
        return self.upsamplers_0(hidden_states)