ansel3911 committed
Commit d6439c9
1 Parent(s): d97058c

Upload 3 files

Files changed (3)
  1. app.py +97 -0
  2. requirements.txt +0 -0
  3. utils.py +7 -0
app.py ADDED
@@ -0,0 +1,97 @@
+ import gradio as gr
+ import torch
+ import torchvision
+ from diffusers import I2VGenXLPipeline, DiffusionPipeline
+ from torchvision.transforms.functional import to_tensor
+ from PIL import Image
+ from utils import create_progress_updater
+
+ if gr.NO_RELOAD:
+     n_sdxl_steps = 50
+     n_i2v_steps = 50
+     high_noise_frac = 0.8
+     negative_prompt = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"
+     generator = torch.manual_seed(8888)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     total_steps = n_sdxl_steps + n_i2v_steps
+     print("Device:", device)
+
+     base = DiffusionPipeline.from_pretrained(
+         "stabilityai/stable-diffusion-xl-base-1.0",
+         torch_dtype=torch.float16,
+         variant="fp16",
+         use_safetensors=True,
+     )
+     # refiner = DiffusionPipeline.from_pretrained(
+     #     "stabilityai/stable-diffusion-xl-refiner-1.0",
+     #     text_encoder_2=base.text_encoder_2,
+     #     vae=base.vae,
+     #     torch_dtype=torch.float16,
+     #     use_safetensors=True,
+     #     variant="fp16",
+     # )
+     pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
+
+     # base.to("cuda")
+     # refiner.to("cuda")
+     # pipeline.to("cuda")
+
+     # base.unet = torch.compile(base.unet, mode="reduce-overhead", fullgraph=True)
+     # refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
+     # pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
+     base.enable_model_cpu_offload()
+     pipeline.enable_model_cpu_offload()
+     pipeline.unet.enable_forward_chunking()
+
+ def generate(prompt: str, progress=gr.Progress()):
+     progress((0, 100), desc="Generating first frame...")
+     image = base(
+         prompt=prompt,
+         num_inference_steps=n_sdxl_steps,
+         callback_on_step_end=create_progress_updater(
+             start=0,
+             total=total_steps,
+             desc="Generating first frame...",
+             progress=progress,
+         ),
+     ).images[0]
+     # progress((n_sdxl_steps * high_noise_frac, total_steps), desc="Refining first frame...")
+     # image = refiner(
+     #     prompt=prompt,
+     #     num_inference_steps=n_sdxl_steps,
+     #     denoising_start=high_noise_frac,
+     #     image=image,
+     #     callback_on_step_end=create_progress_updater(
+     #         start=n_sdxl_steps * high_noise_frac,
+     #         total=total_steps,
+     #         desc="Refining first frame...",
+     #         progress=progress,
+     #     ),
+     # ).images[0]
+     image = to_tensor(image)
+     progress((n_sdxl_steps, total_steps), desc="Generating video...")
+     frames: list[Image.Image] = pipeline(
+         prompt=prompt,
+         image=image,
+         num_inference_steps=50,
+         negative_prompt=negative_prompt,
+         guidance_scale=9.0,
+         generator=generator,
+         decode_chunk_size=2,
+         num_frames=32,
+     ).frames[0]
+     progress((total_steps - 1, total_steps), desc="Finalizing...")
+     frames = [to_tensor(frame.convert("RGB")).mul(255).byte().permute(1, 2, 0) for frame in frames]
+     frames = torch.stack(frames)
+     torchvision.io.write_video("video.mp4", frames, fps=8)
+     return "video.mp4"
+
+ app = gr.Interface(
+     fn=generate,
+     inputs=["text"],
+     outputs=gr.Video()
+ )
+
+ if __name__ == "__main__":
+     app.launch()
+
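A note on the frame-export step at the end of generate: diffusers also ships a helper, diffusers.utils.export_to_video, that writes a list of PIL frames straight to disk, which could stand in for the manual tensor conversion plus torchvision.io.write_video. A minimal sketch, assuming the pinned diffusers version includes that helper (requirements.txt is binary here, so the pin is not visible); the save_video name is illustrative, not part of the commit:

from diffusers.utils import export_to_video

def save_video(frames, path="video.mp4"):
    # frames: list[PIL.Image.Image], e.g. pipeline(...).frames[0] in app.py
    export_to_video(frames, path, fps=8)
    return path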
requirements.txt ADDED
Binary file (2.99 kB).
 
utils.py ADDED
@@ -0,0 +1,7 @@
+ from gradio import Progress
+
+ def create_progress_updater(start: int, total: int, desc: str, progress: Progress):
+     def updater(pipe, step, timestep, callback_kwargs):
+         progress((step + start, total), desc=desc)
+         return callback_kwargs
+     return updater
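For context on how utils.py drives the single progress bar: create_progress_updater closes over a start offset so each pipeline stage's local step index lands in the right slice of the combined 100-step bar. A minimal sketch with a stand-in for gr.Progress so it runs outside a Gradio event; fake_progress and the second, offset callback are illustrative assumptions (the uploaded app.py only attaches the callback to the SDXL stage):

from utils import create_progress_updater

def fake_progress(done_total, desc=""):
    # Stand-in for gr.Progress: the helper calls it as progress((done, total), desc=...).
    done, total = done_total
    print(f"{desc} {done}/{total}")

total_steps = 100  # n_sdxl_steps + n_i2v_steps in app.py

# SDXL stage: local steps 0..49 map to 0..49 of the combined bar.
sdxl_cb = create_progress_updater(start=0, total=total_steps, desc="Generating first frame...", progress=fake_progress)
# A second stage could reuse the helper with start=50 so its steps map to 50..99.
i2v_cb = create_progress_updater(start=50, total=total_steps, desc="Generating video...", progress=fake_progress)

sdxl_cb(pipe=None, step=10, timestep=None, callback_kwargs={})  # prints "Generating first frame... 10/100"
i2v_cb(pipe=None, step=10, timestep=None, callback_kwargs={})   # prints "Generating video... 60/100"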