import torch
from vae.causal_video_autoencoder import CausalVideoAutoencoder
from transformer.transformer3d import Transformer3D
from patchify.symmetric import SymmetricPatchifier
# NOTE: the original snippet used these two classes without importing them;
# the module paths below are assumptions following the layout above.
from scheduler.rf import RectifiedFlowScheduler
from pipeline.video_pixart_alpha import VideoPixArtAlphaPipeline

model_name_or_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
vae_path = "/opt/models/checkpoints/vae_training/causal_vvae_32x32x8_420m_cont_32/step_2296000"
dtype = torch.float32

# Load the causal video VAE from the local checkpoint.
vae = CausalVideoAutoencoder.from_pretrained(
    pretrained_model_name_or_path=vae_path,
    revision=None,
    torch_dtype=torch.bfloat16,
    load_in_8bit=False,
)

# Build the 3D transformer from its config, then load the trained weights.
transformer_config_path = "/opt/txt2img/txt2img/config/transformer3d/xora_v1.2-L.json"
transformer_config = Transformer3D.load_config(transformer_config_path)
transformer = Transformer3D.from_config(transformer_config)
transformer_local_path = "/opt/models/logs/v1.2-vae-mf-medHR-mr-cvae-nl/ckpt/01760000/model.p"
transformer_ckpt_state_dict = torch.load(transformer_local_path)
transformer.load_state_dict(transformer_ckpt_state_dict, strict=True)
unet = transformer

# Rectified-flow scheduler (SD3-style shifted schedule).
scheduler_config_path = "/opt/txt2img/txt2img/config/scheduler/RF_SD3_shifted.json"
scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
scheduler = RectifiedFlowScheduler.from_config(scheduler_config)

patchifier = SymmetricPatchifier(patch_size=1)

# The original snippet passed **submodel_dict without defining it; the keys
# below are assumed from the components built above.
submodel_dict = {
    "unet": unet,
    "vae": vae,
    "scheduler": scheduler,
    "patchifier": patchifier,
}

pipeline = VideoPixArtAlphaPipeline.from_pretrained(
    model_name_or_path,
    safety_checker=None,
    revision=None,
    torch_dtype=dtype,
    **submodel_dict,
)

num_inference_steps = 20
num_images_per_prompt = 2
guidance_scale = 3
height = 512
width = 768
num_frames = 57
frame_rate = 25

# Text conditioning; fill these with real embeddings before calling the
# pipeline (see the sketch at the end of this file).
sample = {
    "prompt_embeds": None,  # (B, L, E)
    "prompt_attention_mask": None,  # (B, L)
    "negative_prompt_embeds": None,  # (B, L, E)
    "negative_prompt": None,
    "negative_prompt_attention_mask": None,  # (B, L)
}

images = pipeline(
    num_inference_steps=num_inference_steps,
    num_images_per_prompt=num_images_per_prompt,
    guidance_scale=guidance_scale,
    generator=None,
    output_type="pt",
    callback_on_step_end=None,
    height=height,
    width=width,
    num_frames=num_frames,
    frame_rate=frame_rate,
    **sample,
    is_video=True,
    vae_per_channel_normalize=True,
).images
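
# --- Sketch: populating the sample dict --------------------------------------
# The script above passes prompt_embeds=None, which leaves the pipeline with
# nothing to condition on. Below is a minimal sketch of producing the
# embeddings, assuming the PixArt checkpoint ships its T5 tokenizer and text
# encoder in the standard "tokenizer" / "text_encoder" subfolders; the prompt
# text and max_length are placeholders, not values from the original example.
from transformers import T5EncoderModel, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained(model_name_or_path, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(
    model_name_or_path, subfolder="text_encoder", torch_dtype=dtype
)

prompt = "A dog chasing a ball across a sunlit meadow"  # placeholder prompt
tokens = tokenizer(
    prompt,
    padding="max_length",
    max_length=128,  # assumed; match the sequence length used in training
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    prompt_embeds = text_encoder(
        tokens.input_ids, attention_mask=tokens.attention_mask
    ).last_hidden_state  # (B, L, E)

sample["prompt_embeds"] = prompt_embeds
sample["prompt_attention_mask"] = tokens.attention_mask  # (B, L)
# Negative embeddings (e.g. from an empty string) can be produced the same way.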