File size: 577 Bytes
1d30073
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import logging
import flax.linen as nn

# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)


class Decoder(nn.Module):
    '''
        Maps a latent code to a transformer encoding: a dense projection to
        ``dim_model`` features followed by layer normalisation.

        Attributes:
            dim_model: output feature size of the dense projection.
            n_latent_tokens: declared on the module but not read by
                ``__call__``.
    '''
    dim_model: int
    n_latent_tokens: int

    @nn.compact
    def __call__(self, latent_code):
        # latent_code: (batch, latent_tokens_per_sequence, latent_token_dim)
        # Both submodules are left unnamed, so Flax picks their names
        # automatically.
        to_model_width = nn.Dense(self.dim_model)
        normalise = nn.LayerNorm()
        # result: (batch, latent_tokens_per_sequence, dim_model)
        return normalise(to_model_width(latent_code))


# Lookup table from a model-name string to a decoder class. The empty-string
# key is the only entry (presumably the default choice; confirm with callers).
VAE_DECODER_MODELS = {'': Decoder}