import math
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
import torch.nn as nn

from diffusers.models.activations import get_activation
from einops import rearrange


def get_1d_sincos_pos_embed(
    embed_dim,
    num_frames,
    cls_token=False,
    extra_tokens=0,
):
    """
    Build a 1D sin/cos positional embedding table for frame indices 0..num_frames-1.

    embed_dim: output dimension for each position (must be even)
    num_frames: number of positions to encode
    cls_token, extra_tokens: when cls_token is truthy and extra_tokens > 0,
        `extra_tokens` all-zero rows are prepended to the table
    return: numpy array of shape [num_frames, embed_dim]
        (or [extra_tokens + num_frames, embed_dim] with the zero rows)
    """
    t = np.arange(num_frames, dtype=np.float32)
    pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, t)  # (T, D)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
    """
    Build a 2D sin/cos positional embedding table.

    embed_dim: output dimension for each position (must be divisible by 2)
    grid_size: int (square grid) or a 2-element sequence; element 0 sizes the
        `grid_h` axis and element 1 sizes the `grid_w` axis
    cls_token, extra_tokens: when cls_token is truthy and extra_tokens > 0,
        `extra_tokens` all-zero rows are prepended to the table
    interpolation_scale, base_size: coordinates along an axis of length n are
        computed as arange(n) / (n / base_size) / interpolation_scale
    return: numpy array of shape [grid_size[0]*grid_size[1], embed_dim]
        (or [extra_tokens + grid_size[0]*grid_size[1], embed_dim] with the zero rows)
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a stacked pair of coordinate grids, half of the channels per grid.

    embed_dim: output dimension for each position (must be divisible by 2)
    grid: array whose index 0 and index 1 along the first axis are the two
        coordinate grids; each is flattened before encoding
    return: numpy array of shape (M, embed_dim), M = number of grid positions
    raises: ValueError if embed_dim is odd
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # Half of the channels encode grid[0], the other half grid[1].
    # NOTE(review): with the meshgrid call in get_2d_sincos_pos_embed, grid[0]
    # is the first meshgrid output (built from grid_w); the emb_h / emb_w names
    # are kept as in the original code.
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position (must be divisible by 2)
    pos: an array of positions to be encoded; flattened to size (M,)
    out: (M, D), sine features in the first D/2 columns, cosine in the last D/2
    raises: ValueError if embed_dim is odd
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param embedding_dim: the dimension of the output. An odd value is zero-padded by one column.
    :param flip_sin_to_cos: if True, the output is ordered [cos, sin] instead of [sin, cos].
    :param downscale_freq_shift: subtracted from half_dim in the exponent's denominator.
    :param scale: multiplier applied to the arguments of sin/cos.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


class Timesteps(nn.Module):
    """Parameter-free module wrapping `get_timestep_embedding` with fixed settings."""

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        """Return the [N, num_channels] sinusoidal embedding of a 1-D `timesteps` tensor."""
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return t_emb


class TimestepEmbedding(nn.Module):
    """
    Two-layer MLP (linear -> activation -> linear) applied to a timestep projection.

    `out_dim` and `post_act_fn` are accepted in the signature but are not used
    by this implementation: both linear layers output `time_embed_dim` and no
    post-activation is applied.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        sample_proj_bias=True,
    ):
        super().__init__()
        # The third positional argument of nn.Linear is `bias`.
        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
        self.act = get_activation(act_fn)
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, sample_proj_bias)

    def forward(self, sample):
        sample = self.linear_1(sample)
        sample = self.act(sample)
        sample = self.linear_2(sample)
        return sample


class TextProjection(nn.Module):
    """Two-layer MLP (linear -> activation -> linear) mapping `in_features` to `hidden_size`."""

    def __init__(self, in_features, hidden_size, act_fn="silu"):
        super().__init__()
        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        self.act_1 = get_activation(act_fn)
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class CombinedTimestepConditionEmbeddings(nn.Module):
    """Sum of a timestep embedding and a projected pooled conditioning vector."""

    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = TextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, pooled_projection):
        """Return `timestep_embedder(time_proj(timestep)) + text_embedder(pooled_projection)`."""
        timesteps_proj = self.time_proj(timestep)
        # The sinusoidal projection is cast to the dtype of the pooled projection before the MLP.
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
        pooled_projections = self.text_embedder(pooled_projection)
        conditioning = timesteps_emb + pooled_projections
        return conditioning


class CombinedTimestepEmbeddings(nn.Module):
    """Timestep-only embedding: sinusoidal projection (256 channels) followed by an MLP."""

    def __init__(self, embedding_dim):
        super().__init__()
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timestep):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj)  # (N, D)
        return timesteps_emb


class PatchEmbed3D(nn.Module):
    """
    Support the 3D Tensor input.

    Each frame of a `(b, c, t, h, w)` latent is patchified with a 2D convolution;
    depending on `pos_embed_type`, a cropped 2D sin/cos spatial embedding (and
    optionally a 1D sin/cos temporal embedding) is added to the tokens.

    Note: the constructor accepts `pos_embed_type=None`, but `forward_func`
    only supports "sincos" and "rope" and asserts otherwise.
    """

    def __init__(
        self,
        height=128,
        width=128,
        patch_size=2,
        in_channels=16,
        embed_dim=1536,
        layer_norm=False,
        bias=True,
        interpolation_scale=1,
        pos_embed_type="sincos",
        temp_pos_embed_type='rope',
        pos_embed_max_size=192,  # For SD3 cropping
        max_num_frames=64,
        add_temp_pos_embed=False,
        interp_condition_pos=False,
    ):
        super().__init__()

        num_patches = (height // patch_size) * (width // patch_size)
        self.layer_norm = layer_norm
        self.pos_embed_max_size = pos_embed_max_size

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale
        self.add_temp_pos_embed = add_temp_pos_embed

        # Calculate positional embeddings based on max size or default
        if pos_embed_max_size:
            grid_size = pos_embed_max_size
        else:
            grid_size = int(num_patches**0.5)

        if pos_embed_type is None:
            self.pos_embed = None

        elif pos_embed_type == "sincos":
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
            )
            # The buffer is persistent (saved in the state dict) only when a max size is set.
            persistent = bool(pos_embed_max_size)
            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)

            # The temporal table is only built together with the sincos spatial table.
            if add_temp_pos_embed and temp_pos_embed_type == 'sincos':
                time_pos_embed = get_1d_sincos_pos_embed(embed_dim, max_num_frames)
                self.register_buffer("temp_pos_embed", torch.from_numpy(time_pos_embed).float().unsqueeze(0), persistent=True)

        elif pos_embed_type == "rope":
            print("Using the rotary position embedding")

        else:
            raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")

        self.pos_embed_type = pos_embed_type
        self.temp_pos_embed_type = temp_pos_embed_type
        self.interp_condition_pos = interp_condition_pos

    def cropped_pos_embed(self, height, width, ori_height, ori_width):
        """
        Crops positional embeddings for SD3 compatibility.

        All four sizes are given before patchification; they are divided by
        `patch_size` here. A centered window is cut from the
        (pos_embed_max_size x pos_embed_max_size) table:

        * `interp_condition_pos` falsy: the window is (height, width).
        * `interp_condition_pos` truthy: the window is (ori_height, ori_width)
          and is bilinearly resized to (height, width) when the sizes differ.

        `ori_height` / `ori_width` may be None, in which case they default to
        `height` / `width`.

        Returns a tensor of shape (1, height*width, C) in patch units.
        Raises ValueError if `pos_embed_max_size` is None or a requested size
        exceeds it; asserts that the original size is not smaller than the
        requested size.
        """
        if self.pos_embed_max_size is None:
            raise ValueError("`pos_embed_max_size` must be set for cropping.")

        # Callers that have no separate "original" resolution (e.g. a plain
        # tensor passed to `forward`) leave these as None.
        if ori_height is None:
            ori_height = height
        if ori_width is None:
            ori_width = width

        height = height // self.patch_size
        width = width // self.patch_size
        ori_height = ori_height // self.patch_size
        ori_width = ori_width // self.patch_size

        assert ori_height >= height, "The ori_height needs >= height"
        assert ori_width >= width, "The ori_width needs >= width"

        if height > self.pos_embed_max_size:
            raise ValueError(
                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )
        if width > self.pos_embed_max_size:
            raise ValueError(
                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )

        if self.interp_condition_pos:
            # The crop window is the original size here, so it must fit in the
            # table too; otherwise the offsets go negative and the slice
            # silently selects the wrong region.
            if ori_height > self.pos_embed_max_size or ori_width > self.pos_embed_max_size:
                raise ValueError(
                    f"Original size ({ori_height}, {ori_width}) cannot be greater than "
                    f"`pos_embed_max_size`: {self.pos_embed_max_size}."
                )
            top = (self.pos_embed_max_size - ori_height) // 2
            left = (self.pos_embed_max_size - ori_width) // 2
            spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
            spatial_pos_embed = spatial_pos_embed[:, top : top + ori_height, left : left + ori_width, :]  # [b h w c]
            if ori_height != height or ori_width != width:
                # interpolate expects channels-first: [b h w c] -> [b c h w] and back
                spatial_pos_embed = spatial_pos_embed.permute(0, 3, 1, 2)
                spatial_pos_embed = torch.nn.functional.interpolate(spatial_pos_embed, size=(height, width), mode='bilinear')
                spatial_pos_embed = spatial_pos_embed.permute(0, 2, 3, 1)
        else:
            top = (self.pos_embed_max_size - height) // 2
            left = (self.pos_embed_max_size - width) // 2
            spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
            spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]

        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])

        return spatial_pos_embed

    def forward_func(self, latent, time_index=0, ori_height=None, ori_width=None):
        """
        Patchify one `(b, c, t, h, w)` latent and add positional embeddings.

        Arguments:
            latent: tensor laid out as `b c t h w`
            time_index: index of this latent's first frame in the temporal
                embedding table (used only with the sincos temporal embedding)
            ori_height, ori_width: un-patchified spatial size forwarded to
                `cropped_pos_embed`; None means "same as this latent"

        Returns a tensor laid out as `b t n c`, n being the tokens per frame.
        """
        if self.pos_embed_max_size is not None:
            # Un-patchified size: cropped_pos_embed divides by patch_size itself.
            height, width = latent.shape[-2:]
        else:
            height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size

        bs = latent.shape[0]
        temp = latent.shape[2]

        # Fold frames into the batch so the 2D conv patchifies every frame.
        latent = rearrange(latent, 'b c t h w -> (b t) c h w')
        latent = self.proj(latent)
        latent = latent.flatten(2).transpose(1, 2)  # (BT)CHW -> (BT)NC

        if self.layer_norm:
            latent = self.norm(latent)

        if self.pos_embed_type == 'sincos':
            # Spatial position embedding: crop (and optionally interpolate) from the max-size table
            if self.pos_embed_max_size:
                pos_embed = self.cropped_pos_embed(height, width, ori_height, ori_width)
            else:
                raise NotImplementedError("Not implemented sincos pos embed without sd3 max pos crop")

            if self.add_temp_pos_embed and self.temp_pos_embed_type == 'sincos':
                # A slice running past the table would come back short and could
                # still broadcast (length 0 or 1), so check the range explicitly.
                max_frames = self.temp_pos_embed.shape[1]
                if time_index + temp > max_frames:
                    raise ValueError(
                        f"Frames [{time_index}, {time_index + temp}) exceed the temporal embedding table "
                        f"of {max_frames} frames."
                    )
                latent_dtype = latent.dtype
                latent = latent + pos_embed
                # Move time to the sequence axis so the temporal table broadcasts over (b n).
                latent = rearrange(latent, '(b t) n c -> (b n) t c', t=temp)
                latent = latent + self.temp_pos_embed[:, time_index:time_index + temp, :]
                latent = latent.to(latent_dtype)
                latent = rearrange(latent, '(b n) t c -> b t n c', b=bs)
            else:
                latent = (latent + pos_embed).to(latent.dtype)
                latent = rearrange(latent, '(b t) n c -> b t n c', b=bs, t=temp)

        else:
            assert self.pos_embed_type == "rope", "Only supporting the sincos and rope embedding"
            latent = rearrange(latent, '(b t) n c -> b t n c', b=bs, t=temp)

        return latent

    def forward(self, latent):
        """
        Arguments:
            latent: either a single `(b, c, t, h, w)` tensor, or a list whose
                items are such tensors or lists of such tensors. Within one
                list item the clips are embedded with consecutive time indices,
                the last clip's spatial size is used as the original size for
                all of them, and their tokens are concatenated along dim 1.

        Returns a `(b, t*n, c)` tensor for tensor input, or a list of such
        tensors (one per list item) for list input.
        """
        if isinstance(latent, list):
            output_list = []

            for latent_ in latent:
                if not isinstance(latent_, list):
                    latent_ = [latent_]

                output_latent = []
                time_index = 0
                ori_height, ori_width = latent_[-1].shape[-2:]
                for each_latent in latent_:
                    hidden_state = self.forward_func(each_latent, time_index=time_index, ori_height=ori_height, ori_width=ori_width)
                    time_index += each_latent.shape[2]
                    hidden_state = rearrange(hidden_state, "b t n c -> b (t n) c")
                    output_latent.append(hidden_state)

                output_latent = torch.cat(output_latent, dim=1)
                output_list.append(output_latent)

            return output_list

        hidden_states = self.forward_func(latent)
        hidden_states = rearrange(hidden_states, "b t n c -> b (t n) c")
        return hidden_states