Spaces:
Runtime error
Runtime error
File size: 713 Bytes
1d30073 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import logging
import jax.numpy as jnp
import flax.linen as nn
logger = logging.getLogger(__name__)
class Encoder(nn.Module):
'''
Converts N hidden tokens into N seperate latent codes.
'''
latent_token_size: int
n_latent_tokens: int
@nn.compact
def __call__(self, encoding):
latent_tokens = nn.Dense(self.latent_token_size)(encoding)
raw_latent_code = latent_tokens[:, : self.n_latent_tokens, :]
# TODO does this just apply tanh to each latent token? Or across the whole batch
latent_code = jnp.tanh(raw_latent_code)
return latent_code # (batch, latent_tokens_per_sequence, latent_token_dim)
VAE_ENCODER_MODELS = {
'': Encoder,
}
|