# lopho's picture
# forgot about the nested package structure
# b2f876f
import jax
import jax.numpy as jnp
import flax.linen as nn
def get_sinusoidal_embeddings(
    timesteps: jax.Array,
    embedding_dim: int,
    freq_shift: float = 1,
    min_timescale: float = 1,
    max_timescale: float = 1.0e4,
    flip_sin_to_cos: bool = False,
    scale: float = 1.0,
    dtype: jnp.dtype = jnp.float32
) -> jax.Array:
    """Return sinusoidal embeddings for a 1-D array of timesteps.

    A table of ``embedding_dim // 2`` inverse timescales is built as
    ``min_timescale * exp(-i * log(max_timescale / min_timescale) / (embedding_dim // 2 - freq_shift))``
    for ``i = 0 .. embedding_dim // 2 - 1``. Every timestep is multiplied by
    every inverse timescale, the product is multiplied by ``scale``, and the
    sine and cosine of the result are concatenated along the last axis.

    Args:
        timesteps: 1-D array of timestep values.
        embedding_dim: Width of the returned embedding; must be even.
        freq_shift: Value subtracted from ``embedding_dim // 2`` in the
            denominator of the log-timescale increment.
        min_timescale: Multiplier of the inverse-timescale table and the
            divisor inside the logarithm.
        max_timescale: Numerator inside the logarithm.
        flip_sin_to_cos: If true the output is ``[cos, sin]``, otherwise
            ``[sin, cos]``.
        scale: Factor applied to the timestep/frequency product before the
            sine and cosine are taken.
        dtype: dtype of the ``arange`` from which the inverse-timescale table
            is computed.

    Returns:
        Array of shape ``(timesteps.shape[0], embedding_dim)``.

    Raises:
        AssertionError: If ``timesteps`` is not 1-D or ``embedding_dim`` is odd.
        ValueError: If ``embedding_dim // 2 == freq_shift``, which would make
            the log-timescale increment a division by zero.
    """
    assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
    assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
    num_timescales = float(embedding_dim // 2)
    # Without this guard the increment below is inf, the first table entry is
    # 0 * -inf = NaN, and the function silently returns an all-NaN embedding.
    if num_timescales == freq_shift:
        raise ValueError(
            f"Embedding dimension {embedding_dim} with freq_shift {freq_shift} makes the "
            "timescale denominator (embedding_dim // 2 - freq_shift) zero, "
            "which would produce NaN embeddings"
        )
    log_timescale_increment = jnp.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
    inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=dtype) * -log_timescale_increment)
    # Outer product: one row per timestep, one column per inverse timescale.
    emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
    # scale embeddings
    scaled_time = scale * emb
    if flip_sin_to_cos:
        signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
    else:
        signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
    signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
    return signal
class TimestepEmbedding(nn.Module):
    """Two dense layers with a SiLU activation in between.

    Both layers have ``time_embed_dim`` output features and are created with
    ``dtype``. They are registered as ``linear_1`` and ``linear_2``.
    """

    # Number of output features of each dense layer.
    time_embed_dim: int = 32
    # dtype handed to both dense layers.
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, temb: jax.Array) -> jax.Array:
        """Apply ``linear_1``, SiLU, then ``linear_2`` to ``temb``."""
        first_projection = nn.Dense(
            features=self.time_embed_dim,
            dtype=self.dtype,
            name="linear_1",
        )
        hidden = nn.silu(first_projection(temb))
        second_projection = nn.Dense(
            features=self.time_embed_dim,
            dtype=self.dtype,
            name="linear_2",
        )
        return second_projection(hidden)
class Timesteps(nn.Module):
    """Module wrapper around ``get_sinusoidal_embeddings``.

    The module's attributes are forwarded as the function's ``embedding_dim``,
    ``flip_sin_to_cos``, ``freq_shift`` and ``dtype`` arguments; every other
    argument of the function keeps its default.
    """

    # Forwarded as ``embedding_dim``.
    dim: int = 32
    # Forwarded as ``flip_sin_to_cos``.
    flip_sin_to_cos: bool = False
    # Forwarded as ``freq_shift``.
    freq_shift: float = 1
    # Forwarded as ``dtype``.
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, timesteps: jax.Array) -> jax.Array:
        """Return the sinusoidal embedding of ``timesteps``."""
        embedding_options = {
            "embedding_dim": self.dim,
            "flip_sin_to_cos": self.flip_sin_to_cos,
            "freq_shift": self.freq_shift,
            "dtype": self.dtype,
        }
        return get_sinusoidal_embeddings(timesteps=timesteps, **embedding_options)