File size: 1,129 Bytes
1d30073
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import jax.numpy as jnp
import flax.linen as nn

from t5_vae_flax_alt.src.encoders import VAE_ENCODER_MODELS
from t5_vae_flax_alt.src.decoders import VAE_DECODER_MODELS
from t5_vae_flax_alt.src.config import T5VaeConfig


class VAE(nn.Module):
    """Latent bottleneck placed between an encoder-decoder model's encoder and decoder.

    Encodes all token encodings into a single latent made of ``n_latent_tokens``
    latent tokens, then decodes that latent back into ``d_model``-sized token
    encodings. The module is intended as the bottleneck of an MMD-VAE; the MMD
    regularisation term is not computed here.

    The concrete encoder/decoder architectures are looked up by name:
    ``VAE_ENCODER_MODELS[config.vae_encoder_model]`` and
    ``VAE_DECODER_MODELS[config.vae_decoder_model]``.

    Module pattern: https://github.com/google/flax#what-does-flax-look-like
    """

    config: T5VaeConfig
    # Intended dtype of the computation.
    # NOTE(review): not currently forwarded to the encoder/decoder — confirm
    # whether they should receive it.
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Instantiate the encoder/decoder sub-modules selected by ``config``."""
        # Both sub-module factories are called positionally; presumably the
        # arguments are (output token size, number of latent tokens).
        # TODO confirm against the encoder/decoder definitions.
        encoder_cls = VAE_ENCODER_MODELS[self.config.vae_encoder_model]
        decoder_cls = VAE_DECODER_MODELS[self.config.vae_decoder_model]
        self.encoder = encoder_cls(self.config.latent_token_size, self.config.n_latent_tokens)
        self.decoder = decoder_cls(self.config.t5.d_model, self.config.n_latent_tokens)

    def __call__(self, encoding=None, latent_codes=None):
        """Autoencode ``encoding``, or decode caller-supplied ``latent_codes``.

        Args:
            encoding: token encodings to compress. Ignored when
                ``latent_codes`` is provided.
            latent_codes: optional precomputed latent codes. When given, the
                encoder is skipped and these codes are decoded directly.

        Returns:
            ``(reconstruction, latent_codes)``: the decoder output and the
            latent codes it was decoded from.

        Raises:
            ValueError: if neither ``encoding`` nor ``latent_codes`` is given.

        Note:
            Flax creates sub-module parameters lazily, on first call. Run
            ``init`` with ``encoding`` (and without ``latent_codes``) so that
            the encoder's parameters are created as well as the decoder's.
        """
        if latent_codes is None:
            if encoding is None:
                raise ValueError("VAE requires either `encoding` or `latent_codes`.")
            latent_codes = self.encode(encoding)
        return self.decode(latent_codes), latent_codes

    def encode(self, encoding):
        """Map token encodings to latent codes using the configured encoder."""
        return self.encoder(encoding)

    def decode(self, latent):
        """Map latent codes back to token encodings using the configured decoder."""
        return self.decoder(latent)