maxin-cn committed on
Commit 6cf5463
1 Parent(s): fb85fbd

Upload folder using huggingface_hub

Files changed (3)
  1. demo.py +8 -3
  2. sample_videos/temp.mp4 +0 -0
  3. temp.py +309 -0
demo.py CHANGED
@@ -232,6 +232,13 @@ with gr.Blocks() as demo:
                 generate_button = gr.Button(value="Generate", variant='primary')

             with gr.Accordion("Advanced options", open=False):
+                gr.Markdown(
+                    """
+                    - Input image can be specified using the "Input Image URL" text box or uploaded by clicking or dragging the image to the "Input Image" box.
+                    - Input image will be resized and/or center cropped to a given resolution (320 x 512) automatically.
+                    - After setting the input image path, press the "Preview" button to visualize the resized input image.
+                    """
+                )
                 with gr.Column():
                     with gr.Row():
                         input_image_path = gr.Textbox(label="Input Image URL", lines=1, scale=10, info="Press Enter or the Preview button to confirm the input image.")
@@ -299,6 +306,4 @@ with gr.Blocks() as demo:
         outputs=[result_video]
     )

-    demo.launch(debug=False, share=True)
-
-    # demo.launch(server_name="0.0.0.0", server_port=10034, enable_queue=True)
+    demo.launch(debug=False, share=True)
 
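The bullets added to the "Advanced options" accordion describe an aspect-preserving resize followed by a center crop to the demo's fixed 320 x 512 resolution. A minimal standalone sketch of that behaviour (hypothetical helper; it mirrors update_and_resize_image in temp.py below rather than being the committed code itself):

# Sketch only: scale the image so it covers the 320 x 512 box, then center-crop the overflow.
from PIL import Image

def resize_and_center_crop(path, target_h=320, target_w=512):
    img = Image.open(path).convert("RGB")
    w, h = img.size
    scale = max(target_h / h, target_w / w)   # cover the target box
    img = img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
    new_w, new_h = img.size
    left = (new_w - target_w) // 2
    top = (new_h - target_h) // 2
    return img.crop((left, top, left + target_w, top + target_h))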
sample_videos/temp.mp4 CHANGED
Binary files a/sample_videos/temp.mp4 and b/sample_videos/temp.mp4 differ
 
temp.py ADDED
@@ -0,0 +1,309 @@
+import gradio as gr
+import os
+import torch
+import argparse
+import torchvision
+
+
+from pipelines.pipeline_videogen import VideoGenPipeline
+from diffusers.schedulers import DDIMScheduler
+from diffusers.models import AutoencoderKL
+from diffusers.models import AutoencoderKLTemporalDecoder
+from transformers import CLIPTokenizer, CLIPTextModel
+from omegaconf import OmegaConf
+
+import os, sys
+sys.path.append(os.path.split(sys.path[0])[0])
+from models import get_models
+import imageio
+from PIL import Image
+import numpy as np
+from datasets import video_transforms
+from torchvision import transforms
+from einops import rearrange, repeat
+from utils import dct_low_pass_filter, exchanged_mixed_dct_freq
+from copy import deepcopy
+import spaces
+import requests
+from datetime import datetime
+import random
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--config", type=str, default="./configs/sample.yaml")
+args = parser.parse_args()
+args = OmegaConf.load(args.config)
+
+torch.set_grad_enabled(False)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+dtype = torch.float16 # torch.float16
+
+unet = get_models(args).to(device, dtype=dtype)
+
+if args.enable_vae_temporal_decoder:
+    if args.use_dct:
+        vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float64).to(device)
+    else:
+        vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
+    vae = deepcopy(vae_for_base_content).to(dtype=dtype)
+else:
+    vae_for_base_content = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae",).to(device, dtype=torch.float64)
+    vae = deepcopy(vae_for_base_content).to(dtype=dtype)
+tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
+text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device) # huge
+
+# set eval mode
+unet.eval()
+vae.eval()
+text_encoder.eval()
+
+basedir = os.getcwd()
+savedir = os.path.join(basedir, "samples/Gradio", datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
+savedir_sample = os.path.join(savedir, "sample")
+os.makedirs(savedir, exist_ok=True)
+
+def update_and_resize_image(input_image_path, height_slider, width_slider):
+    if input_image_path.startswith("http://") or input_image_path.startswith("https://"):
+        pil_image = Image.open(requests.get(input_image_path, stream=True).raw).convert('RGB')
+    else:
+        pil_image = Image.open(input_image_path).convert('RGB')
+
+    original_width, original_height = pil_image.size
+
+    if original_height == height_slider and original_width == width_slider:
+        return gr.Image(value=np.array(pil_image))
+
+    ratio1 = height_slider / original_height
+    ratio2 = width_slider / original_width
+
+    if ratio1 > ratio2:
+        new_width = int(original_width * ratio1)
+        new_height = int(original_height * ratio1)
+    else:
+        new_width = int(original_width * ratio2)
+        new_height = int(original_height * ratio2)
+
+    pil_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
+
+    left = (new_width - width_slider) / 2
+    top = (new_height - height_slider) / 2
+    right = left + width_slider
+    bottom = top + height_slider
+
+    pil_image = pil_image.crop((left, top, right, bottom))
+
+    return gr.Image(value=np.array(pil_image))
+
+
+def update_textbox_and_save_image(input_image, height_slider, width_slider):
+    pil_image = Image.fromarray(input_image.astype(np.uint8)).convert("RGB")
+
+    original_width, original_height = pil_image.size
+
+    ratio1 = height_slider / original_height
+    ratio2 = width_slider / original_width
+
+    if ratio1 > ratio2:
+        new_width = int(original_width * ratio1)
+        new_height = int(original_height * ratio1)
+    else:
+        new_width = int(original_width * ratio2)
+        new_height = int(original_height * ratio2)
+
+    pil_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
+
+    left = (new_width - width_slider) / 2
+    top = (new_height - height_slider) / 2
+    right = left + width_slider
+    bottom = top + height_slider
+
+    pil_image = pil_image.crop((left, top, right, bottom))
+
+    img_path = os.path.join(savedir, "input_image.png")
+    pil_image.save(img_path)
+
+    return gr.Textbox(value=img_path), gr.Image(value=np.array(pil_image))
+
+def prepare_image(image, vae, transform_video, device, dtype=torch.float16):
+    image = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0).permute(0, 3, 1, 2)
+    image = transform_video(image)
+    image = vae.encode(image.to(dtype=dtype, device=device)).latent_dist.sample().mul_(vae.config.scaling_factor)
+    image = image.unsqueeze(2)
+    return image
+
+
+@spaces.GPU
+def gen_video(input_image, prompt, negative_prompt, diffusion_step, height, width, scfg_scale, use_dctinit, dct_coefficients, noise_level, motion_bucket_id, seed):
+
+    torch.manual_seed(seed)
+
+    scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path,
+                                              subfolder="scheduler",
+                                              beta_start=args.beta_start,
+                                              beta_end=args.beta_end,
+                                              beta_schedule=args.beta_schedule)
+
+    videogen_pipeline = VideoGenPipeline(vae=vae,
+                                         text_encoder=text_encoder,
+                                         tokenizer=tokenizer,
+                                         scheduler=scheduler,
+                                         unet=unet).to(device)
+    # videogen_pipeline.enable_xformers_memory_efficient_attention()
+
+    transform_video = transforms.Compose([
+        video_transforms.ToTensorVideo(),
+        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+    ])
+
+    if args.use_dct:
+        base_content = prepare_image(input_image, vae_for_base_content, transform_video, device, dtype=torch.float64).to(device)
+    else:
+        base_content = prepare_image(input_image, vae_for_base_content, transform_video, device, dtype=torch.float16).to(device)
+
+    if use_dctinit:
+        # filter params
+        print("Using DCT!")
+        base_content_repeat = repeat(base_content, 'b c f h w -> b c (f r) h w', r=15).contiguous()
+
+        # define filter
+        freq_filter = dct_low_pass_filter(dct_coefficients=base_content, percentage=dct_coefficients)
+
+        noise = torch.randn(1, 4, 15, 40, 64).to(device)
+
+        # add noise to base_content
+        diffuse_timesteps = torch.full((1,), int(noise_level))
+        diffuse_timesteps = diffuse_timesteps.long()
+
+        # 3d content
+        base_content_noise = scheduler.add_noise(
+            original_samples=base_content_repeat.to(device),
+            noise=noise,
+            timesteps=diffuse_timesteps.to(device))
+
+        # 3d content
+        latents = exchanged_mixed_dct_freq(noise=noise,
+                                           base_content=base_content_noise,
+                                           LPF_3d=freq_filter).to(dtype=torch.float16)
+
+    base_content = base_content.to(dtype=torch.float16)
+
+    videos = videogen_pipeline(prompt,
+                               negative_prompt=negative_prompt,
+                               latents=latents if use_dctinit else None,
+                               base_content=base_content,
+                               video_length=15,
+                               height=height,
+                               width=width,
+                               num_inference_steps=diffusion_step,
+                               guidance_scale=scfg_scale,
+                               motion_bucket_id=100-motion_bucket_id,
+                               enable_vae_temporal_decoder=args.enable_vae_temporal_decoder).video
+
+    save_path = args.save_img_path + 'temp' + '.mp4'
+    # torchvision.io.write_video(save_path, videos[0], fps=8, video_codec='h264', options={'crf': '10'})
+    imageio.mimwrite(save_path, videos[0], fps=8, quality=7)
+    return save_path
+
+
+if not os.path.exists(args.save_img_path):
+    os.makedirs(args.save_img_path)
+
+
+with gr.Blocks() as demo:
+
+    gr.Markdown("<font color=red size=6.5><center>Cinemo: Consistent and Controllable Image Animation with Motion Diffusion Models</center></font>")
+    gr.Markdown(
+        """<div style="display: flex;align-items: center;justify-content: center">
+        [<a href="https://arxiv.org/abs/2407.15642">Arxiv Report</a>] | [<a href="https://maxin-cn.github.io/cinemo_project/">Project Page</a>] | [<a href="https://github.com/maxin-cn/Cinemo">Github</a>]</div>
+        """
+    )
+
+
+    with gr.Column(variant="panel"):
+        with gr.Row():
+            prompt_textbox = gr.Textbox(label="Prompt", lines=1)
+            negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=1)
+
+        with gr.Row(equal_height=False):
+            with gr.Column():
+                with gr.Row():
+                    input_image = gr.Image(label="Input Image", interactive=True)
+                    result_video = gr.Video(label="Generated Animation", interactive=False, autoplay=True)
+
+                generate_button = gr.Button(value="Generate", variant='primary')
+
+            with gr.Accordion("Advanced options", open=False):
+                gr.Markdown(
+                    """
+                    - Input image can be specified using the "Input Image URL" text box or uploaded by clicking or dragging the image to the "Input Image" box.
+                    - Input image will be resized and/or center cropped to a given resolution (320 x 512) automatically.
+                    - After setting the input image path, press the "Preview" button to visualize the resized input image.
+                    """
+                )
+                with gr.Column():
+                    with gr.Row():
+                        input_image_path = gr.Textbox(label="Input Image URL", lines=1, scale=10, info="Press Enter or the Preview button to confirm the input image.")
+                        preview_button = gr.Button(value="Preview")
+
+                    with gr.Row():
+                        sample_step_slider = gr.Slider(label="Sampling steps", value=50, minimum=10, maximum=250, step=1)
+
+                    with gr.Row():
+                        seed_textbox = gr.Slider(label="Seed", value=100, minimum=1, maximum=int(1e8), step=1, interactive=True)
+                        # seed_textbox = gr.Textbox(label="Seed", value=100)
+                        # seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
+                        # seed_button.click(fn=lambda: gr.Textbox(value=random.randint(1, int(1e8))), inputs=[], outputs=[seed_textbox])
+
+                    with gr.Row():
+                        height = gr.Slider(label="Height", value=320, minimum=0, maximum=512, step=16, interactive=False)
+                        width = gr.Slider(label="Width", value=512, minimum=0, maximum=512, step=16, interactive=False)
+                    with gr.Row():
+                        txt_cfg_scale = gr.Slider(label="CFG Scale", value=7.5, minimum=1.0, maximum=20.0, step=0.1, interactive=True)
+                        motion_bucket_id = gr.Slider(label="Motion Intensity", value=10, minimum=1, maximum=20, step=1, interactive=True)
+
+                    with gr.Row():
+                        use_dctinit = gr.Checkbox(label="Enable DCTInit", value=True)
+                        dct_coefficients = gr.Slider(label="DCT Coefficients", value=0.23, minimum=0, maximum=1, step=0.01, interactive=True)
+                        noise_level = gr.Slider(label="Noise Level", value=985, minimum=1, maximum=999, step=1, interactive=True)
+
+    input_image.upload(fn=update_textbox_and_save_image, inputs=[input_image, height, width], outputs=[input_image_path, input_image])
+    preview_button.click(fn=update_and_resize_image, inputs=[input_image_path, height, width], outputs=[input_image])
+    input_image_path.submit(fn=update_and_resize_image, inputs=[input_image_path, height, width], outputs=[input_image])
+
+    EXAMPLES = [
+        ["./example/aircrafts_flying/0.jpg", "aircrafts flying", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
+        ["./example/fireworks/0.jpg", "fireworks", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
+        ["./example/flowers_swaying/0.jpg", "flowers swaying", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
+        ["./example/girl_walking_on_the_beach/0.jpg", "girl walking on the beach", 50, 320, 512, 7.5, True, 0.23, 985, 10, 200],
+        ["./example/house_rotating/0.jpg", "house rotating", 50, 320, 512, 7.5, True, 0.23, 985, 10, 100],
+        ["./example/people_runing/0.jpg", "people runing", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
+    ]
+
+    examples = gr.Examples(
+        examples = EXAMPLES,
+        fn = gen_video,
+        inputs=[input_image, prompt_textbox, sample_step_slider, height, width, txt_cfg_scale, use_dctinit, dct_coefficients, noise_level, motion_bucket_id, seed_textbox],
+        outputs=[result_video],
+        # cache_examples=True,
+        cache_examples="lazy",
+    )
+
+    generate_button.click(
+        fn=gen_video,
+        inputs=[
+            input_image,
+            prompt_textbox,
+            negative_prompt_textbox,
+            sample_step_slider,
+            height,
+            width,
+            txt_cfg_scale,
+            use_dctinit,
+            dct_coefficients,
+            noise_level,
+            motion_bucket_id,
+            seed_textbox,
+        ],
+        outputs=[result_video]
+    )
+
+    demo.launch(debug=False, share=True)
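
The script reads its configuration from the path given by --config (default ./configs/sample.yaml). The commit also drops the commented-out local-server launch from demo.py; if a public Gradio share link is not wanted, the final line could instead bind to a local address, for example (the port value here is illustrative):

# Local-only alternative to share=True (sketch; any free port works)
demo.launch(server_name="0.0.0.0", server_port=7860)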