# NOTE: page-header residue from the hosting site, not part of the program.
# tsqn's picture
# chore: ZeroGPU
# b9a7927
import gc
import os
import tempfile

import torch
import spaces
import gradio as gr
from diffusers import LattePipeline
from transformers import T5EncoderModel, BitsAndBytesConfig
import imageio
from torchvision.utils import save_image
def flush():
    """Reclaim memory between pipeline stages.

    Runs Python's garbage collector first, then asks PyTorch to empty its
    CUDA cache. The order is kept: collect garbage, then empty the cache.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
def bytes_to_giga_bytes(bytes):
    """Convert a byte count to gibibytes by dividing by 1024 three times.

    Args:
        bytes: Number of bytes (the name shadows the builtin; it is kept
            because it is part of the function's public signature).

    Returns:
        The value expressed in units of 1024**3 bytes.
    """
    kib = bytes / 1024
    mib = kib / 1024
    return mib / 1024
def initialize_pipeline():
    """Build the Latte-1 pipeline used for prompt encoding.

    Loads the ``text_encoder`` subfolder of the checkpoint as a T5 encoder
    with a 4-bit bitsandbytes configuration (float16 compute dtype), then
    assembles a ``LattePipeline`` around that encoder with
    ``transformer=None``.

    Returns:
        A ``(pipe, text_encoder)`` tuple.
    """
    checkpoint = "maxin-cn/Latte-1"

    four_bit_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
    encoder = T5EncoderModel.from_pretrained(
        checkpoint,
        subfolder="text_encoder",
        quantization_config=four_bit_config,
        device_map="auto",
    )

    # transformer=None: this pipeline instance is built without the
    # transformer component; only the encoder passed in is attached here.
    encoding_pipe = LattePipeline.from_pretrained(
        checkpoint,
        text_encoder=encoder,
        transformer=None,
        device_map="balanced",
    )
    return encoding_pipe, encoder
@spaces.GPU(duration=120)
def generate_video(
    prompt: str,
    negative_prompt: str = "",
    video_length: int = 16,
    num_inference_steps: int = 50,
    progress=gr.Progress()
) -> str:
    """Generate a video for ``prompt`` with Latte-1 and return the MP4 path.

    The work is done in two stages so that the two pipelines are never held
    at the same time:

    1. Build the encoding pipeline via ``initialize_pipeline()``, compute the
       prompt / negative-prompt embeddings, then delete it and call
       ``flush()``.
    2. Load a second ``LattePipeline`` without a text encoder in float16 on
       CUDA, run it on the precomputed embeddings, and write the first video
       of the batch to an MP4 file.

    Args:
        prompt: Text description of the desired video.
        negative_prompt: Text describing what should be avoided.
        video_length: Number of frames passed to the pipeline.
        num_inference_steps: Number of inference steps passed to the pipeline.
        progress: Gradio progress reporter; called with a fraction and a
            description at each stage.

    Returns:
        Path of the MP4 file. Each call writes to its own newly created
        temporary file, so overlapping or queued requests cannot overwrite
        one another's output. The file is intentionally not deleted here
        because the caller still has to read it after this function returns.
    """
    # Set random seed for reproducibility
    torch.manual_seed(0)

    # Stage 1: encoding pipeline.
    progress(0, desc="Initializing pipeline...")
    pipe, text_encoder = initialize_pipeline()

    progress(0.2, desc="Encoding prompt...")
    with torch.no_grad():
        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
            prompt,
            negative_prompt=negative_prompt,
        )

    # Drop every reference to the first pipeline before loading the second
    # one, then collect garbage and empty the CUDA cache.
    progress(0.3, desc="Cleaning up...")
    del text_encoder
    del pipe
    flush()

    # Stage 2: generation pipeline (no text encoder; embeddings are reused).
    progress(0.4, desc="Initializing generation pipeline...")
    pipe = LattePipeline.from_pretrained(
        "maxin-cn/Latte-1",
        text_encoder=None,
        torch_dtype=torch.float16,
    ).to("cuda")

    progress(0.5, desc="Generating video...")
    videos = pipe(
        # NOTE(review): UI sliders may hand back floats depending on the
        # Gradio version; coerce defensively since these are counts.
        video_length=int(video_length),
        num_inference_steps=int(num_inference_steps),
        negative_prompt=None,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        output_type="pt",
    ).frames.cpu()

    progress(0.8, desc="Post-processing...")
    # Scale values clamped to [0, 1] up to 8-bit integers.
    videos = (videos.clamp(0, 1) * 255).to(dtype=torch.uint8)
    # First video of the batch, with axis 1 moved to the end; converted to a
    # NumPy array so the writer does not depend on implicit tensor conversion.
    frames = videos[0].permute(0, 2, 3, 1).numpy()

    # Reserve a unique file name per call instead of a shared fixed path.
    # The handle is closed immediately; only the name is needed.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as handle:
        temp_output = handle.name
    imageio.mimwrite(
        temp_output,
        frames,
        fps=8,
        quality=5,
    )

    progress(0.9, desc="Cleaning up...")
    del pipe
    flush()

    progress(1.0, desc="Done!")
    return temp_output
def create_demo():
    """Assemble the Gradio Blocks interface for Latte video generation.

    The layout is one row with two columns. The first column holds the
    prompt box, the negative-prompt box, the frame-count slider, the
    inference-step slider, and the generate button; the second column holds
    the output video. Clicking the button calls ``generate_video`` with the
    four input values and routes its result to the video component.

    Returns:
        The constructed ``gr.Blocks`` object (not queued or launched here).
    """
    heading_markdown = (
        "\n"
        "# Latte Video Generation\n"
        "Generate short videos using the Latte-1 model.\n"
    )

    with gr.Blocks() as demo:
        gr.Markdown(heading_markdown)

        with gr.Row():
            # Input controls.
            with gr.Column():
                prompt_box = gr.Textbox(
                    value="a cat wearing sunglasses and working as a lifeguard at pool.",
                    label="Prompt",
                    info="Describe what you want to generate",
                )
                negative_box = gr.Textbox(
                    value="",
                    label="Negative Prompt",
                    info="What you don't want to see in the generation",
                )
                frames_slider = gr.Slider(
                    label="Video Length (frames)",
                    minimum=8,
                    maximum=32,
                    step=8,
                    value=16,
                )
                steps_slider = gr.Slider(
                    label="Number of Inference Steps",
                    minimum=20,
                    maximum=100,
                    step=10,
                    value=50,
                )
                run_button = gr.Button("Generate Video")

            # Result display.
            with gr.Column():
                video_player = gr.Video(label="Generated Video")

        # Input order matches generate_video's positional parameters.
        run_button.click(
            fn=generate_video,
            inputs=[prompt_box, negative_box, frames_slider, steps_slider],
            outputs=video_player,
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, enable the request queue, then serve
    # it without requesting a public share link.
    demo = create_demo()
    demo.queue().launch(share=False)