"""Gradio demo for LTX-Video: text-to-video and image-to-video generation.

Korean prompts are translated to English before generation, and prompts can
optionally be rewritten by an OpenAI chat model before being submitted.
"""

import gc
import json
import os
import re
import tempfile
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import safetensors.torch
import torch
from gradio_toggle import Toggle
from huggingface_hub import snapshot_download
from openai import OpenAI
from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer, pipeline

from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
from xora.models.transformers.transformer3d import Transformer3DModel
from xora.pipelines.pipeline_xora_video import XoraVideoPipeline
from xora.schedulers.rf import RectifiedFlowScheduler
from xora.utils.conditioning_method import ConditioningMethod

# ---------------------------------------------------------------------------
# System prompts for the optional prompt-enhancement step
# ---------------------------------------------------------------------------
# Both prompts share the same instructions; only the opening sentence differs.
_SYSTEM_PROMPT_INSTRUCTIONS = (
    "주어진 프롬프트를 다음 구조에 맞게 개선해주세요: "
    "1. 주요 동작을 명확한 한 문장으로 시작 "
    "2. 구체적인 동작과 제스처를 시간 순서대로 설명 "
    "3. 캐릭터/객체의 외모를 상세히 묘사 "
    "4. 배경과 환경 세부 사항을 구체적으로 포함 "
    "5. 카메라 각도와 움직임을 명시 "
    "6. 조명과 색상을 자세히 설명 "
    "7. 변화나 갑작스러운 사건을 자연스럽게 포함 "
    "모든 설명은 하나의 자연스러운 문단으로 작성하고, "
    "촬영 감독이 촬영 목록을 설명하는 것처럼 구체적이고 시각적으로 작성하세요. "
    "200단어를 넘지 않도록 하되, 최대한 상세하게 작성하세요."
)
system_prompt_t2v = (
    "당신은 비디오 생성을 위한 프롬프트 전문가입니다. " + _SYSTEM_PROMPT_INSTRUCTIONS
)
system_prompt_i2v = (
    "당신은 이미지 기반 비디오 생성을 위한 프롬프트 전문가입니다. "
    + _SYSTEM_PROMPT_INSTRUCTIONS
)

DEFAULT_NEGATIVE_PROMPT = (
    "low quality, worst quality, deformed, distorted, warped, motion smear, "
    "motion artifacts, fused fingers, incorrect anatomy, strange hands, unattractive"
)
MIN_PROMPT_LENGTH = 50

# ---------------------------------------------------------------------------
# Credentials and external clients
# ---------------------------------------------------------------------------
hf_token = os.getenv("HF_TOKEN")
openai_api_key = os.getenv("OPENAI_API_KEY")
# Prompt enhancement is optional: constructing the client without a key raises,
# which would take the whole app down at import time, so only build it when a
# key is actually configured.
client = OpenAI(api_key=openai_api_key) if openai_api_key else None

# Korean -> English translation pipeline.
# NOTE: the name `pipeline` is rebound further down to the video pipeline, so
# the translator must be created here, before that rebinding.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

# Hangul consonant jamo, vowel jamo and precomposed syllables.
_KOREAN_PATTERN = re.compile("[ㄱ-ㅎㅏ-ㅣ가-힣]")


def contains_korean(text):
    """Return True if *text* contains at least one Hangul character."""
    return bool(_KOREAN_PATTERN.search(text))


def translate_korean_prompt(prompt):
    """Translate *prompt* to English if it contains Korean; else return it as is."""
    if not contains_korean(prompt):
        return prompt
    translated = translator(prompt)[0]["translation_text"]
    print(f"Original Korean prompt: {prompt}")
    print(f"Translated English prompt: {translated}")
    return translated


def enhance_prompt(prompt, type="t2v"):
    """Rewrite *prompt* with the OpenAI chat model.

    Args:
        prompt: The user's prompt text.
        type: "t2v" selects the text-to-video system prompt; any other value
            selects the image-to-video one. (The name shadows the builtin but
            is kept for backward compatibility.)

    Returns:
        The enhanced prompt, or the original prompt unchanged when no OpenAI
        client is configured or the request fails.
    """
    if client is None:
        print("Prompt enhancement skipped: OPENAI_API_KEY is not set.")
        return prompt

    system_prompt = system_prompt_t2v if type == "t2v" else system_prompt_i2v
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    try:
        response = client.chat.completions.create(
            model="gpt-4-1106-preview",
            messages=messages,
            max_tokens=2000,
        )
        enhanced_prompt = response.choices[0].message.content.strip()
    except Exception as e:
        # Best effort: enhancement must never block video generation.
        print(f"Error during prompt enhancement: {e}")
        return prompt

    print("\n=== 프롬프트 증강 결과 ===")
    print("Original Prompt:")
    print(prompt)
    print("\nEnhanced Prompt:")
    print(enhanced_prompt)
    print("========================\n")
    return enhanced_prompt


def update_prompt(prompt, enhance_toggle, type="t2v"):
    """Return the enhanced prompt when *enhance_toggle* is on, else *prompt*."""
    if enhance_toggle:
        return enhance_prompt(prompt, type)
    return prompt


def update_prompt_t2v(prompt, enhance_toggle):
    """Toggle handler for the text-to-video tab."""
    return update_prompt(prompt, enhance_toggle, "t2v")


def update_prompt_i2v(prompt, enhance_toggle):
    """Toggle handler for the image-to-video tab."""
    return update_prompt(prompt, enhance_toggle, "i2v")


# ---------------------------------------------------------------------------
# Model download and loading
# ---------------------------------------------------------------------------
# Model download directory (inside the Hugging Face Space working directory).
model_path = "asset"
if not os.path.exists(model_path):
    snapshot_download(
        "Lightricks/LTX-Video",
        local_dir=model_path,
        repo_type="model",
        token=hf_token,
    )

vae_dir = Path(model_path) / "vae"
unet_dir = Path(model_path) / "unet"
scheduler_dir = Path(model_path) / "scheduler"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_vae(vae_dir):
    """Build the video VAE from its config and safetensors checkpoint in *vae_dir*."""
    vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
    vae_config_path = vae_dir / "config.json"
    with open(vae_config_path, "r") as f:
        vae_config = json.load(f)
    vae = CausalVideoAutoencoder.from_config(vae_config)
    vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
    vae.load_state_dict(vae_state_dict)
    return vae.to(device=device, dtype=torch.bfloat16)


def load_unet(unet_dir):
    """Build the 3D transformer from its config and checkpoint in *unet_dir*."""
    unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
    unet_config_path = unet_dir / "config.json"
    transformer_config = Transformer3DModel.load_config(unet_config_path)
    transformer = Transformer3DModel.from_config(transformer_config)
    unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
    transformer.load_state_dict(unet_state_dict, strict=True)
    return transformer.to(device=device, dtype=torch.bfloat16)


def load_scheduler(scheduler_dir):
    """Build the rectified-flow scheduler from its config in *scheduler_dir*."""
    scheduler_config_path = scheduler_dir / "scheduler_config.json"
    scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
    return RectifiedFlowScheduler.from_config(scheduler_config)


# ---------------------------------------------------------------------------
# Image helpers
# ---------------------------------------------------------------------------
def center_crop_and_resize(frame, target_height, target_width):
    """Center-crop *frame* (H, W, C array) to the target aspect ratio, then resize."""
    h, w, _ = frame.shape
    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = w / h
    if aspect_ratio_frame > aspect_ratio_target:
        # Frame is too wide: trim equal margins left and right.
        new_width = int(h * aspect_ratio_target)
        x_start = (w - new_width) // 2
        frame_cropped = frame[:, x_start : x_start + new_width]
    else:
        # Frame is too tall (or already matching): trim top and bottom.
        new_height = int(w / aspect_ratio_target)
        y_start = (h - new_height) // 2
        frame_cropped = frame[y_start : y_start + new_height, :]
    return cv2.resize(frame_cropped, (target_width, target_height))


def load_image_to_tensor_with_resize(image_path, target_height=512, target_width=768):
    """Load an image as a float tensor scaled to [-1, 1].

    The image is converted to RGB, center-cropped and resized, then returned
    with shape (1, 3, 1, target_height, target_width).
    """
    image = Image.open(image_path).convert("RGB")
    image_np = np.array(image)
    frame_resized = center_crop_and_resize(image_np, target_height, target_width)
    frame_tensor = torch.tensor(frame_resized).permute(2, 0, 1).float()
    frame_tensor = (frame_tensor / 127.5) - 1.0  # uint8 [0, 255] -> [-1, 1]
    return frame_tensor.unsqueeze(0).unsqueeze(2)


# ---------------------------------------------------------------------------
# Load models
# ---------------------------------------------------------------------------
vae = load_vae(vae_dir)
unet = load_unet(unet_dir)
scheduler = load_scheduler(scheduler_dir)
patchifier = SymmetricPatchifier(patch_size=1)
text_encoder = T5EncoderModel.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
).to(device)
tokenizer = T5Tokenizer.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
)

# NOTE: this rebinds `pipeline` (previously transformers.pipeline).
pipeline = XoraVideoPipeline(
    transformer=unet,
    patchifier=patchifier,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    scheduler=scheduler,
    vae=vae,
).to(device)

# ---------------------------------------------------------------------------
# Resolution / frame-count presets (durations assume 25 FPS)
# ---------------------------------------------------------------------------
preset_options = [
    {"label": "1216x704, 1.6초", "width": 1216, "height": 704, "num_frames": 41},
    {"label": "1088x704, 2.0초", "width": 1088, "height": 704, "num_frames": 49},
    {"label": "1056x640, 2.3초", "width": 1056, "height": 640, "num_frames": 57},
    {"label": "992x608, 2.6초", "width": 992, "height": 608, "num_frames": 65},
    {"label": "896x608, 2.9초", "width": 896, "height": 608, "num_frames": 73},
    {"label": "896x544, 3.2초", "width": 896, "height": 544, "num_frames": 81},
    {"label": "832x544, 3.6초", "width": 832, "height": 544, "num_frames": 89},
    {"label": "800x512, 3.9초", "width": 800, "height": 512, "num_frames": 97},
    {"label": "768x512, 3.9초", "width": 768, "height": 512, "num_frames": 97},
    {"label": "800x480, 4.2초", "width": 800, "height": 480, "num_frames": 105},
    {"label": "736x480, 4.5초", "width": 736, "height": 480, "num_frames": 113},
    {"label": "704x480, 4.8초", "width": 704, "height": 480, "num_frames": 121},
    {"label": "704x448, 5.2초", "width": 704, "height": 448, "num_frames": 129},
    {"label": "672x448, 5.5초", "width": 672, "height": 448, "num_frames": 137},
    {"label": "640x416, 6.1초", "width": 640, "height": 416, "num_frames": 153},
    {"label": "672x384, 6.4초", "width": 672, "height": 384, "num_frames": 161},
    {"label": "640x384, 6.8초", "width": 640, "height": 384, "num_frames": 169},
    {"label": "608x384, 7.1초", "width": 608, "height": 384, "num_frames": 177},
    {"label": "576x384, 7.4초", "width": 576, "height": 384, "num_frames": 185},
    {"label": "608x352, 7.7초", "width": 608, "height": 352, "num_frames": 193},
    {"label": "576x352, 8.0초", "width": 576, "height": 352, "num_frames": 201},
    {"label": "544x352, 8.4초", "width": 544, "height": 352, "num_frames": 209},
    {"label": "512x352, 9.3초", "width": 512, "height": 352, "num_frames": 233},
    {"label": "544x320, 9.6초", "width": 544, "height": 320, "num_frames": 241},
    {"label": "512x320, 10.3초", "width": 512, "height": 320, "num_frames": 257},
]

# Preset shown in both dropdowns when the page first loads.
DEFAULT_PRESET_LABEL = "512x320, 10.3초"


def _preset_by_label(label):
    """Return the entry of ``preset_options`` whose label equals *label*."""
    return next(item for item in preset_options if item["label"] == label)


def preset_changed(preset):
    """Dropdown handler: push the chosen preset into the per-tab state.

    Returns a 6-tuple ``(height, width, num_frames, *slider_updates)``. For a
    named preset the three manual sliders are hidden; for "Custom" the state
    values are cleared (None) and the sliders are shown.
    """
    if preset != "Custom":
        selected = _preset_by_label(preset)
        # Update the height / width / num_frames state values.
        return (
            selected["height"],
            selected["width"],
            selected["num_frames"],
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        )
    return (
        None,
        None,
        None,
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
    )


# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------
def _require_detailed_prompt(prompt):
    """Raise a user-facing error when *prompt* is shorter than the minimum."""
    if len(prompt.strip()) < MIN_PROMPT_LENGTH:
        raise gr.Error(
            "프롬프트는 최소 50자 이상이어야 합니다. 더 자세한 설명을 제공해주세요.",
            duration=5,
        )


def _save_video(images, frame_rate):
    """Encode the pipeline output tensor as an mp4 file and return its path."""
    video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
    video_np = (video_np * 255).astype(np.uint8)
    frame_height, frame_width = video_np.shape[1:3]

    # NamedTemporaryFile creates the file atomically; mktemp() only returns a
    # name and is deprecated because another process can claim it first.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        output_path = tmp.name

    writer = cv2.VideoWriter(
        output_path,
        cv2.VideoWriter_fourcc(*"mp4v"),
        frame_rate,
        (frame_width, frame_height),
    )
    try:
        # Reverse the channel axis: frames are RGB, OpenCV expects BGR.
        for frame in video_np[..., ::-1]:
            writer.write(frame)
    finally:
        writer.release()
    return output_path


def _generate_video(
    sample,
    conditioning_method,
    *,
    frame_rate,
    seed,
    num_inference_steps,
    guidance_scale,
    height,
    width,
    num_frames,
    progress,
):
    """Run the video pipeline for *sample* and return the path of the mp4.

    Any failure is re-raised as ``gr.Error`` so it is shown in the UI; cached
    GPU memory is released whether or not generation succeeds.
    """
    generator = torch.Generator(device="cpu").manual_seed(seed)

    def gradio_progress_callback(self, step, timestep, kwargs):
        progress((step + 1) / num_inference_steps)

    try:
        with torch.no_grad():
            images = pipeline(
                num_inference_steps=num_inference_steps,
                num_images_per_prompt=1,
                guidance_scale=guidance_scale,
                generator=generator,
                output_type="pt",
                height=height,
                width=width,
                num_frames=num_frames,
                frame_rate=frame_rate,
                **sample,
                is_video=True,
                vae_per_channel_normalize=True,
                conditioning_method=conditioning_method,
                mixed_precision=True,
                callback_on_step_end=gradio_progress_callback,
            ).images
        print(images.shape)
        output_path = _save_video(images, frame_rate)
        # Drop the tensor reference before the cache is emptied below.
        del images
        return output_path
    except Exception as e:
        raise gr.Error(
            f"비디오 생성 중 오류가 발생했습니다. 다시 시도해주세요. 오류: {e}",
            duration=5,
        ) from e
    finally:
        torch.cuda.empty_cache()
        gc.collect()


def generate_video_from_text(
    prompt="",
    enhance_prompt_toggle=False,
    negative_prompt=DEFAULT_NEGATIVE_PROMPT,
    frame_rate=25,
    seed=171198,
    num_inference_steps=41,
    guidance_scale=4,
    height=512,
    width=320,
    num_frames=257,
    progress=gr.Progress(),
):
    """Generate a video from a text prompt and return the path of the mp4.

    ``enhance_prompt_toggle`` is accepted to match the UI inputs but is not
    used here; enhancement happens in the toggle's own change handler.

    Raises:
        gr.Error: If the prompt is shorter than 50 characters or generation fails.
    """
    _require_detailed_prompt(prompt)

    # Translate Korean prompts to English.
    prompt = translate_korean_prompt(prompt)
    negative_prompt = translate_korean_prompt(negative_prompt)

    sample = {
        "prompt": prompt,
        "prompt_attention_mask": None,
        "negative_prompt": negative_prompt,
        "negative_prompt_attention_mask": None,
        "media_items": None,
    }
    return _generate_video(
        sample,
        ConditioningMethod.UNCONDITIONAL,
        frame_rate=frame_rate,
        seed=seed,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        height=height,
        width=width,
        num_frames=num_frames,
        progress=progress,
    )


def generate_video_from_image(
    image_path,
    prompt="",
    enhance_prompt_toggle=False,
    negative_prompt=DEFAULT_NEGATIVE_PROMPT,
    frame_rate=25,
    seed=171198,
    num_inference_steps=50,
    guidance_scale=4,
    height=512,
    width=768,
    num_frames=121,
    progress=gr.Progress(),
):
    """Generate a video conditioned on a first-frame image; return the mp4 path.

    ``enhance_prompt_toggle`` is accepted to match the UI inputs but is not
    used here; enhancement happens in the toggle's own change handler.

    Raises:
        gr.Error: If the prompt is shorter than 50 characters, no image was
            provided, or generation fails.
    """
    print("Height: ", height)
    print("Width: ", width)
    print("Num Frames: ", num_frames)

    _require_detailed_prompt(prompt)
    if not image_path:
        raise gr.Error("입력 이미지를 제공해주세요.", duration=5)

    # Translate Korean prompts to English.
    prompt = translate_korean_prompt(prompt)
    negative_prompt = translate_korean_prompt(negative_prompt)

    media_items = (
        load_image_to_tensor_with_resize(image_path, height, width).to(device).detach()
    )

    sample = {
        "prompt": prompt,
        "prompt_attention_mask": None,
        "negative_prompt": negative_prompt,
        "negative_prompt_attention_mask": None,
        "media_items": media_items,
    }
    return _generate_video(
        sample,
        ConditioningMethod.FIRST_FRAME,
        frame_rate=frame_rate,
        seed=seed,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        height=height,
        width=width,
        num_frames=num_frames,
        progress=progress,
    )


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
def create_advanced_options():
    """Create the "Advanced Options" accordion.

    Returns:
        ``[seed, inference_steps, guidance_scale, height_slider, width_slider,
        num_frames_slider]``. Only the seed slider is visible initially.
    """
    with gr.Accordion("Step 4: Advanced Options (Optional)", open=False):
        seed = gr.Slider(
            label="Seed", minimum=0, maximum=1000000, step=1, value=171198
        )
        inference_steps = gr.Slider(
            label="4.2 Inference Steps",
            minimum=1,
            maximum=50,
            step=1,
            value=50,
            visible=False,
        )
        guidance_scale = gr.Slider(
            label="4.3 Guidance Scale",
            minimum=1.0,
            maximum=5.0,
            step=0.1,
            value=4.0,
            visible=False,
        )
        height_slider = gr.Slider(
            label="4.4 Height",
            minimum=256,
            maximum=1024,
            step=64,
            value=512,
            visible=False,
        )
        width_slider = gr.Slider(
            label="4.5 Width",
            minimum=256,
            maximum=1024,
            step=64,
            value=768,
            visible=False,
        )
        num_frames_slider = gr.Slider(
            label="4.5 Number of Frames",
            minimum=1,
            maximum=200,
            step=1,
            value=121,
            visible=False,
        )
    return [
        seed,
        inference_steps,
        guidance_scale,
        height_slider,
        width_slider,
        num_frames_slider,
    ]


# The dropdown's change event does not fire on page load, so the state values
# must start out equal to the preset that the dropdown displays.
_default_preset = _preset_by_label(DEFAULT_PRESET_LABEL)

with gr.Blocks(theme=gr.themes.Soft()) as iface:
    with gr.Tabs():
        # Text to Video tab
        with gr.TabItem("텍스트로 비디오 만들기"):
            with gr.Row():
                with gr.Column():
                    txt2vid_prompt = gr.Textbox(
                        label="Step 1: 프롬프트 입력",
                        placeholder="생성하고 싶은 비디오를 설명하세요 (최소 50자)...",
                        value="귀여운 고양이",
                        lines=5,
                    )
                    txt2vid_enhance_toggle = Toggle(
                        label="프롬프트 개선",
                        value=False,
                        interactive=True,
                    )
                    txt2vid_negative_prompt = gr.Textbox(
                        label="Step 2: 네거티브 프롬프트 입력",
                        placeholder="비디오에서 원하지 않는 요소를 설명하세요...",
                        value=DEFAULT_NEGATIVE_PROMPT,
                        lines=2,
                        visible=False,
                    )

                    # State holding the currently selected values.
                    txt2vid_current_height = gr.State(value=_default_preset["height"])
                    txt2vid_current_width = gr.State(value=_default_preset["width"])
                    txt2vid_current_num_frames = gr.State(
                        value=_default_preset["num_frames"]
                    )

                    txt2vid_preset = gr.Dropdown(
                        choices=[p["label"] for p in preset_options],
                        value=DEFAULT_PRESET_LABEL,
                        label="Step 2: 해상도 프리셋 선택",
                    )
                    txt2vid_frame_rate = gr.Slider(
                        label="Step 3: 프레임 레이트",
                        minimum=21,
                        maximum=30,
                        step=1,
                        value=25,
                        visible=False,
                    )
                    txt2vid_advanced = create_advanced_options()
                    txt2vid_generate = gr.Button(
                        "Step 3: 비디오 생성",
                        variant="primary",
                        size="lg",
                    )

                with gr.Column():
                    txt2vid_output = gr.Video(label="생성된 비디오")

        # Image to Video tab
        with gr.TabItem("이미지로 비디오 만들기"):
            with gr.Row():
                with gr.Column():
                    img2vid_image = gr.Image(
                        type="filepath",
                        label="Step 1: 입력 이미지 업로드",
                        elem_id="image_upload",
                    )
                    img2vid_prompt = gr.Textbox(
                        label="Step 2: 프롬프트 입력",
                        placeholder="이미지를 어떻게 애니메이션화할지 설명하세요 (최소 50자)...",
                        value="귀여운 고양이",
                        lines=5,
                    )
                    img2vid_enhance_toggle = Toggle(
                        label="프롬프트 증강",
                        value=False,
                        interactive=True,
                    )
                    img2vid_negative_prompt = gr.Textbox(
                        label="Step 3: 네거티브 프롬프트 입력",
                        placeholder="비디오에서 원하지 않는 요소를 설명하세요...",
                        value=DEFAULT_NEGATIVE_PROMPT,
                        lines=2,
                        visible=False,
                    )

                    # State holding the currently selected values.
                    img2vid_current_height = gr.State(value=_default_preset["height"])
                    img2vid_current_width = gr.State(value=_default_preset["width"])
                    img2vid_current_num_frames = gr.State(
                        value=_default_preset["num_frames"]
                    )

                    img2vid_preset = gr.Dropdown(
                        choices=[p["label"] for p in preset_options],
                        value=DEFAULT_PRESET_LABEL,
                        label="Step 3: 해상도 프리셋 선택",
                    )
                    img2vid_frame_rate = gr.Slider(
                        label="Step 4: 프레임 레이트",
                        minimum=21,
                        maximum=30,
                        step=1,
                        value=25,
                        visible=False,
                    )
                    img2vid_advanced = create_advanced_options()
                    img2vid_generate = gr.Button(
                        "Step 4: 비디오 생성",
                        variant="primary",
                        size="lg",
                    )

                with gr.Column():
                    img2vid_output = gr.Video(label="생성된 비디오")

    # Event handlers
    txt2vid_preset.change(
        fn=preset_changed,
        inputs=[txt2vid_preset],
        outputs=[
            txt2vid_current_height,
            txt2vid_current_width,
            txt2vid_current_num_frames,
            *txt2vid_advanced[3:],  # height, width and num_frames sliders
        ],
    )

    txt2vid_enhance_toggle.change(
        fn=update_prompt_t2v,
        inputs=[txt2vid_prompt, txt2vid_enhance_toggle],
        outputs=txt2vid_prompt,
    )

    txt2vid_generate.click(
        fn=generate_video_from_text,
        inputs=[
            txt2vid_prompt,
            txt2vid_enhance_toggle,
            txt2vid_negative_prompt,
            txt2vid_frame_rate,
            *txt2vid_advanced[:3],  # seed, inference_steps, guidance_scale
            txt2vid_current_height,
            txt2vid_current_width,
            txt2vid_current_num_frames,
        ],
        outputs=txt2vid_output,
        concurrency_limit=1,
        concurrency_id="generate_video",
        queue=True,
    )

    img2vid_preset.change(
        fn=preset_changed,
        inputs=[img2vid_preset],
        outputs=[
            img2vid_current_height,
            img2vid_current_width,
            img2vid_current_num_frames,
            *img2vid_advanced[3:],  # height, width and num_frames sliders
        ],
    )

    img2vid_enhance_toggle.change(
        fn=update_prompt_i2v,
        inputs=[img2vid_prompt, img2vid_enhance_toggle],
        outputs=img2vid_prompt,
    )

    img2vid_generate.click(
        fn=generate_video_from_image,
        inputs=[
            img2vid_image,
            img2vid_prompt,
            img2vid_enhance_toggle,
            img2vid_negative_prompt,
            img2vid_frame_rate,
            *img2vid_advanced[:3],  # seed, inference_steps, guidance_scale
            img2vid_current_height,
            img2vid_current_width,
            img2vid_current_num_frames,
        ],
        outputs=img2vid_output,
        concurrency_limit=1,
        concurrency_id="generate_video",
        queue=True,
    )

if __name__ == "__main__":
    iface.queue(max_size=64, default_concurrency_limit=1, api_open=False).launch(
        share=True, show_api=False
    )