from abc import ABC, abstractmethod
from typing import Tuple

import torch
from diffusers.configuration_utils import ConfigMixin
from einops import rearrange
from torch import Tensor

from txt2img.common.torch_utils import append_dims
from txt2img.config.diffusion_parts import PatchifierConfig, PatchifierName


def pixart_alpha_patchify(
    latents: Tensor,
    patch_size: Tuple[int, int, int],
) -> Tensor:
    """Flatten a 5-D latent volume into a sequence of patch tokens.

    Each non-overlapping ``(p_f, p_h, p_w)`` patch is folded into the channel
    axis, and the patch grid is flattened into the sequence axis
    (frame-major, then row, then column).

    Args:
        latents: Tensor of shape ``(b, c, F, H, W)``. ``F``, ``H`` and ``W``
            must be divisible by ``patch_size[0]``, ``patch_size[1]`` and
            ``patch_size[2]`` respectively (einops raises otherwise).
        patch_size: ``(p_f, p_h, p_w)`` patch extents along frames, height
            and width.

    Returns:
        Tensor of shape ``(b, (F/p_f) * (H/p_h) * (W/p_w), c * p_f * p_h * p_w)``.
    """
    return rearrange(
        latents,
        "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
        p1=patch_size[0],
        p2=patch_size[1],
        p3=patch_size[2],
    )


# NOTE(review): `Patchifier` is neither defined nor imported in this module as
# shown; confirm it is provided (presumably an ABC/ConfigMixin base that sets
# `self._patch_size`), otherwise this class statement raises NameError.
class SymmetricPatchifier(Patchifier):
    """Patchifier whose ``unpatchify`` exactly inverts ``patchify``.

    Reads ``self._patch_size`` as a ``(p_f, p_h, p_w)`` sequence; it is
    presumably set by the ``Patchifier`` base class (TODO confirm).
    """

    def patchify(
        self,
        latents: Tensor,
    ) -> Tensor:
        """Turn ``(b, c, F, H, W)`` latents into ``(b, tokens, c*p_f*p_h*p_w)``.

        See ``pixart_alpha_patchify`` for the exact layout.
        """
        return pixart_alpha_patchify(latents, self._patch_size)

    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        output_num_frames: int,
        out_channels: int,
    ) -> Tensor:
        """Invert ``patchify``: ``(b, tokens, c*p_f*p_h*p_w)`` -> ``(b, c, F, H, W)``.

        Args:
            latents: Token sequence produced by ``patchify``.
            output_height: Target ``H`` (before patching); divided by ``p_h``
                to get the token-grid height.
            output_width: Target ``W`` (before patching); divided by ``p_w``.
            output_num_frames: Target ``F`` (before patching); divided by
                ``p_f``. When ``p_f == 1`` this is a no-op.
            out_channels: Not used by this implementation; the channel count
                is inferred from ``latents``.

        Returns:
            Tensor of shape ``(b, c, F, H, W)``.
        """
        p_f = self._patch_size[0]
        p_h = self._patch_size[1]
        p_w = self._patch_size[2]
        # Unfold the temporal patch extent too, so this mirrors patchify's
        # "(c p1 p2 p3)" layout; previously p_f was silently left in the
        # channel axis, which is only correct when p_f == 1.
        return rearrange(
            latents,
            "b (f h w) (c pf ph pw) -> b c (f pf) (h ph) (w pw)",
            f=output_num_frames // p_f,
            h=output_height // p_h,
            w=output_width // p_w,
            pf=p_f,
            ph=p_h,
            pw=p_w,
        )