File size: 17,216 Bytes
e574f5a
 
 
 
 
 
 
 
471b97b
e574f5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471b97b
e574f5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471b97b
e574f5a
471b97b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
import torch
import gradio as gr
from gradio.themes.utils import colors, fonts, sizes
import argparse
from omegaconf import OmegaConf
import os
from models import get_models
from diffusers.utils.import_utils import is_xformers_available
from vlogger.STEB.model_transform import tca_transform_model, ip_scale_set, ip_transform_model
from diffusers.models import AutoencoderKL
from models.clip import TextEmbedder
from datasets import video_transforms
from torchvision import transforms
from utils import mask_generation_before
from backend import auto_inpainting
from einops import rearrange
import torchvision
import sys
from PIL import Image
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from transformers.image_transforms import convert_to_rgb
try:
    import utils

    from diffusion import create_diffusion
    from download import find_model
except:
    sys.path.append(os.path.split(sys.path[0])[0])
    
    import utils

    from diffusion import create_diffusion
    from download import find_model
    
    
def auto_inpainting(video_input, masked_video, mask, prompt, image, vae, text_encoder, image_encoder, diffusion, model, device, cfg_scale, img_cfg_scale, negative_prompt=""):
    """Sample a clip with standard classifier-free guidance and decode it.

    The conditional branch (``prompt`` plus the reference-image embedding, if
    any) and the unconditional branch (``negative_prompt`` plus zeroed image
    embeddings) are stacked along the batch axis, conditional first, and
    combined by ``model.forward_with_cfg`` during DDIM sampling.

    Args:
        video_input: Pixel-space clip; only its shape is read, unpacked as
            ``(b, f, c, h, w)``. The latent grid is ``(h // 8, w // 8)``.
        masked_video: Pixel-space clip in the same layout; VAE-encoded and
            passed to the sampler as ``x_start``.
        mask: Mask tensor; assumed to share ``video_input``'s
            ``(b, f, c, h, w)`` layout (TODO confirm). Index 0 of axis 2 is
            resized to the latent grid.
        prompt: Text prompt; ``None`` is treated as ``""``.
        image: Optional reference image, preprocessed with the module-level
            ``clip_image_processor``.
        vae, text_encoder, image_encoder, diffusion, model: Loaded components.
        device: Device for the noise and CLIP input tensors.
        cfg_scale: Text guidance scale forwarded to the model.
        img_cfg_scale: Reference-image scale applied via ``ip_scale_set``.
        negative_prompt: Prompt for the unconditional branch.

    Returns:
        The VAE-decoded frames, laid out as ``(f, c, H, W)``.
    """
    image_prompt_embeds = None
    if prompt is None:
        prompt = ""
    if image is not None:
        clip_image = clip_image_processor(images=image, return_tensors="pt").pixel_values.to(device)
        if use_fp16:
            # The image encoder is cast to fp16 when use_fp16 is set; cast the
            # fp32 pixels explicitly so the input dtype matches (some
            # transformers versions do not cast internally).
            clip_image = clip_image.to(dtype=torch.float16)
        clip_image_embeds = image_encoder(clip_image).image_embeds
        # Zero embeddings serve as the "no reference image" branch.
        uncond_clip_image_embeds = torch.zeros_like(clip_image_embeds)
        image_prompt_embeds = torch.cat([clip_image_embeds, uncond_clip_image_embeds], dim=0)
        image_prompt_embeds = rearrange(image_prompt_embeds, '(b n) c -> b n c', b=2).contiguous()
        model = ip_scale_set(model, img_cfg_scale)
        if use_fp16:
            image_prompt_embeds = image_prompt_embeds.to(dtype=torch.float16)

    b, f, _, h, w = video_input.shape
    latent_h = h // 8
    latent_w = w // 8

    # Initial noise in (b, c, f, h, w) latent layout for a single clip; the
    # frame count follows the input instead of being hard-coded.
    if use_fp16:
        z = torch.randn(1, 4, f, latent_h, latent_w, dtype=torch.float16, device=device)
        masked_video = masked_video.to(dtype=torch.float16)
        mask = mask.to(dtype=torch.float16)
    else:
        z = torch.randn(1, 4, f, latent_h, latent_w, device=device)

    # Encode every frame with the VAE (0.18215 is the latent scaling factor
    # also used when decoding below), then restore (b, c, f, h, w).
    masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
    masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215)
    masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
    # Drop axis 2, resize to the latent grid, then add a singleton axis at 1.
    mask = torch.nn.functional.interpolate(mask[:, :, 0, :], size=(latent_h, latent_w)).unsqueeze(1)

    # Duplicate every input for the [conditional, unconditional] batch.
    masked_video = torch.cat([masked_video] * 2)
    mask = torch.cat([mask] * 2)
    z = torch.cat([z] * 2)
    prompt_all = [prompt, negative_prompt]

    text_prompt = text_encoder(text_prompts=prompt_all, train=False)
    model_kwargs = dict(encoder_hidden_states=text_prompt,
                        class_labels=None,
                        cfg_scale=cfg_scale,
                        use_fp16=use_fp16,
                        ip_hidden_states=image_prompt_embeds)

    samples = diffusion.ddim_sample_loop(
        model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device,
        mask=mask, x_start=masked_video, use_concat=True
    )
    # Keep the conditional half of the stacked batch.
    samples, _ = samples.chunk(2, dim=0)
    if use_fp16:
        samples = samples.to(dtype=torch.float16)

    # (c, f, h, w) -> (f, c, h, w) so the VAE decodes the frames as a batch.
    video_clip = samples[0].permute(1, 0, 2, 3).contiguous()
    video_clip = vae.decode(video_clip / 0.18215).sample
    return video_clip


def auto_inpainting_temp_split(video_input, masked_video, mask, prompt, image, vae, text_encoder, image_encoder, diffusion, model, device, scfg_scale, tcfg_scale, img_cfg_scale, negative_prompt=""):
    """Sample a clip with separate spatial/temporal text guidance and decode it.

    Three branches are stacked along the batch axis:

    0. spatial text ``prompt``, temporal text ``prompt``, reference image;
    1. spatial text ``prompt``, empty temporal text, reference image;
    2. spatial text ``negative_prompt``, empty temporal text, zeroed image.

    They are combined by ``model.forward_with_cfg_temp_split`` using
    ``scfg_scale`` and ``tcfg_scale``.

    Args:
        video_input: Pixel-space clip; only its shape is read, unpacked as
            ``(b, f, c, h, w)``. The latent grid is ``(h // 8, w // 8)``.
        masked_video: Pixel-space clip in the same layout; VAE-encoded and
            passed to the sampler as ``x_start``.
        mask: Mask tensor; assumed to share ``video_input``'s
            ``(b, f, c, h, w)`` layout (TODO confirm). Index 0 of axis 2 is
            resized to the latent grid.
        prompt: Text prompt; ``None`` is treated as ``""``.
        image: Optional reference image, preprocessed with the module-level
            ``clip_image_processor``.
        vae, text_encoder, image_encoder, diffusion, model: Loaded components.
        device: Device for the noise and CLIP input tensors.
        scfg_scale: Spatial text guidance scale.
        tcfg_scale: Temporal text guidance scale.
        img_cfg_scale: Reference-image scale applied via ``ip_scale_set``.
        negative_prompt: Spatial prompt for the unconditional branch.

    Returns:
        The VAE-decoded frames, laid out as ``(f, c, H, W)``.
    """
    image_prompt_embeds = None
    if prompt is None:
        prompt = ""
    if image is not None:
        clip_image = clip_image_processor(images=image, return_tensors="pt").pixel_values.to(device)
        if use_fp16:
            # The image encoder is cast to fp16 when use_fp16 is set; cast the
            # fp32 pixels explicitly so the input dtype matches (some
            # transformers versions do not cast internally).
            clip_image = clip_image.to(dtype=torch.float16)
        clip_image_embeds = image_encoder(clip_image).image_embeds
        # Zero embeddings serve as the "no reference image" branch.
        uncond_clip_image_embeds = torch.zeros_like(clip_image_embeds)
        image_prompt_embeds = torch.cat([clip_image_embeds, clip_image_embeds, uncond_clip_image_embeds], dim=0)
        image_prompt_embeds = rearrange(image_prompt_embeds, '(b n) c -> b n c', b=3).contiguous()
        model = ip_scale_set(model, img_cfg_scale)
        if use_fp16:
            image_prompt_embeds = image_prompt_embeds.to(dtype=torch.float16)

    b, f, _, h, w = video_input.shape
    latent_h = h // 8
    latent_w = w // 8

    # Initial noise in (b, c, f, h, w) latent layout for a single clip; the
    # frame count follows the input instead of being hard-coded.
    if use_fp16:
        z = torch.randn(1, 4, f, latent_h, latent_w, dtype=torch.float16, device=device)
        masked_video = masked_video.to(dtype=torch.float16)
        mask = mask.to(dtype=torch.float16)
    else:
        z = torch.randn(1, 4, f, latent_h, latent_w, device=device)

    # Encode every frame with the VAE (0.18215 is the latent scaling factor
    # also used when decoding below), then restore (b, c, f, h, w).
    masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
    masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215)
    masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
    # Drop axis 2, resize to the latent grid, then add a singleton axis at 1.
    mask = torch.nn.functional.interpolate(mask[:, :, 0, :], size=(latent_h, latent_w)).unsqueeze(1)

    # Triplicate every input for the three guidance branches.
    masked_video = torch.cat([masked_video] * 3)
    mask = torch.cat([mask] * 3)
    z = torch.cat([z] * 3)
    prompt_all = [prompt, prompt, negative_prompt]
    prompt_temp = [prompt, "", ""]

    text_prompt = text_encoder(text_prompts=prompt_all, train=False)
    temporal_text_prompt = text_encoder(text_prompts=prompt_temp, train=False)
    model_kwargs = dict(encoder_hidden_states=text_prompt,
                        class_labels=None,
                        scfg_scale=scfg_scale,
                        tcfg_scale=tcfg_scale,
                        use_fp16=use_fp16,
                        ip_hidden_states=image_prompt_embeds,
                        encoder_temporal_hidden_states=temporal_text_prompt)

    samples = diffusion.ddim_sample_loop(
        model.forward_with_cfg_temp_split, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device,
        mask=mask, x_start=masked_video, use_concat=True
    )
    # The batch holds three branches; keep only the first (fully conditional)
    # one. The former chunk(2) split a size-3 batch into 2 + 1.
    samples = samples[:1]
    if use_fp16:
        samples = samples.to(dtype=torch.float16)

    # (c, f, h, w) -> (f, c, h, w) so the VAE decodes the frames as a batch.
    video_clip = samples[0].permute(1, 0, 2, 3).contiguous()
    video_clip = vae.decode(video_clip / 0.18215).sample
    return video_clip


# ========================================
#             Model Initialization
# ========================================
# Module-level state populated by init_model() and read by the generation
# functions below.
device = None                 # "cuda" when available, else "cpu"
output_path = None            # directory for rendered videos (config: save_img_path)
use_fp16 = False              # True once models are cast to half precision
model = None                  # video diffusion model
vae = None                    # AutoencoderKL used to encode/decode frames
text_encoder = None           # TextEmbedder for prompts
image_encoder = None          # CLIP vision encoder for reference images
clip_image_processor = None   # CLIPImageProcessor for reference images


def init_model():
    """Load the config and every model component into the module globals.

    Reads ``--config`` (an OmegaConf YAML path) from the command line, builds
    the video model with its TCA and IP transforms, loads the checkpoint, and
    loads the VAE, text encoder and CLIP image encoder onto ``device``.
    Optional steps, driven by the config: xformers attention, fp16 casting
    and ``torch.compile``.

    Raises:
        ValueError: if xformers attention is requested but unavailable.
    """
    global device
    global output_path
    global use_fp16
    global model
    global vae
    global text_encoder
    global image_encoder
    global clip_image_processor
    print('Initializing ShowMaker', flush=True)
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/with_mask_ref_sample.yaml")
    args = OmegaConf.load(parser.parse_args().config)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    output_path = args.save_img_path

    # Derived sizes consumed by get_models; latents are 8x smaller per side.
    args.image_h, args.image_w = args.image_size[0], args.image_size[1]
    args.latent_h = args.image_h // 8
    args.latent_w = args.image_w // 8

    print('loading model')
    model = get_models(True, args).to(device)
    model = tca_transform_model(model).to(device)
    model = ip_transform_model(model).to(device)
    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            model.enable_xformers_memory_efficient_attention()
            print("xformer!")
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")
    # Load weights into the plain module: a torch.compile wrapper would
    # expect every state_dict key to carry an "_orig_mod." prefix.
    state_dict = find_model(args.ckpt)
    model.load_state_dict(state_dict)
    print('loading succeed')
    model.eval()  # important!

    pretrained_model_path = args.pretrained_model_path
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
    # os.path.join works whether or not the configured path ends with "/".
    text_encoder = TextEmbedder(tokenizer_path=os.path.join(pretrained_model_path, "tokenizer"),
                                encoder_path=os.path.join(pretrained_model_path, "text_encoder")).to(device)
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path).to(device)
    clip_image_processor = CLIPImageProcessor()
    if args.use_fp16:
        print('Warnning: using half percision for inferencing!')
        vae.to(dtype=torch.float16)
        model.to(dtype=torch.float16)
        text_encoder.to(dtype=torch.float16)
        image_encoder.to(dtype=torch.float16)
        use_fp16 = True
    # Compile last, after weights and dtype are final.
    if args.use_compile:
        model = torch.compile(model)
    print('Initialization Finished')


init_model()


# ========================================
#             Video Generation
# ========================================
def video_generation(text, image, scfg_scale, tcfg_scale, img_cfg_scale, diffusion):
    """Generate a clip from text (and an optional reference image) and save it.

    A 16-frame all-zero 320x512 placeholder is masked with the "all" mode of
    ``mask_generation_before`` and sampled. Equal spatial and temporal text
    scales use plain CFG (``auto_inpainting``); otherwise the split-guidance
    path (``auto_inpainting_temp_split``) is used.

    Args:
        text: Text prompt.
        image: Optional reference image (``None`` to skip).
        scfg_scale: Spatial text guidance scale.
        tcfg_scale: Temporal text guidance scale.
        img_cfg_scale: Reference-image guidance scale.
        diffusion: Diffusion sampler built by ``create_diffusion``.

    Returns:
        Path of the MP4 (8 fps) written to ``output_path``.
    """
    with torch.no_grad():
        print("begin generation", flush=True)
        transform_video = transforms.Compose([
            video_transforms.ToTensorVideo(),  # TCHW
            video_transforms.WebVideo320512((320, 512)),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
        ])
        # Placeholder frames; presumably fully hidden by the "all" mask below
        # (NOTE(review): confirm in mask_generation_before).
        video_frames = torch.zeros(16, 3, 320, 512, dtype=torch.uint8)
        video_frames = transform_video(video_frames)
        video_input = video_frames.to(device).unsqueeze(0)  # b,f,c,h,w
        mask = mask_generation_before("all", video_input.shape, video_input.dtype, device)
        # Keep only the pixels where the mask is 0.
        masked_video = video_input * (mask == 0)
        if image is not None:
            print(image.shape, flush=True)
        if scfg_scale == tcfg_scale:
            video_clip = auto_inpainting(video_input, masked_video, mask, text, image, vae, text_encoder, image_encoder, diffusion, model, device, scfg_scale, img_cfg_scale)
        else:
            video_clip = auto_inpainting_temp_split(video_input, masked_video, mask, text, image, vae, text_encoder, image_encoder, diffusion, model, device, scfg_scale, tcfg_scale, img_cfg_scale)
        # Invert the mean/std=0.5 normalization, round to uint8 and reorder
        # to (f, h, w, c) as expected by write_video.
        video_clip = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
        # write_video does not create missing directories.
        os.makedirs(output_path, exist_ok=True)
        # NOTE: the same file name is reused (overwritten) on every request.
        video_path = os.path.join(output_path, 'video.mp4')
        torchvision.io.write_video(video_path, video_clip, fps=8)
        return video_path
    
    
# ========================================
#             Video Prediction
# ========================================
def video_prediction(text, image, scfg_scale, tcfg_scale, img_cfg_scale, preframe, diffusion):
    """Continue a clip from a given first frame and save it.

    The first frame is followed by 15 zero frames. The stack is resized to
    fit inside 320x512 with its aspect ratio preserved, masked with the
    "first1" mode of ``mask_generation_before``, and sampled. Equal spatial
    and temporal text scales use plain CFG (``auto_inpainting``); otherwise
    the split-guidance path (``auto_inpainting_temp_split``) is used.

    Args:
        text: Text prompt.
        image: Optional reference image (``None`` to skip).
        scfg_scale: Spatial text guidance scale.
        tcfg_scale: Temporal text guidance scale.
        img_cfg_scale: Reference-image guidance scale.
        preframe: First frame in (H, W, C) layout (it is permuted to CHW
            below); presumably the uint8 array produced by ``gr.Image``.
        diffusion: Diffusion sampler built by ``create_diffusion``.

    Returns:
        Path of the MP4 (8 fps) written to ``output_path``.
    """
    with torch.no_grad():
        print("begin generation", flush=True)
        transform_video = transforms.Compose([
            video_transforms.ToTensorVideo(),  # TCHW
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
        ])
        preframe = torch.as_tensor(convert_to_rgb(preframe)).unsqueeze(0)
        zeros = torch.zeros_like(preframe)
        video_frames = torch.cat([preframe] + [zeros] * 15, dim=0).permute(0, 3, 1, 2)

        # Fit inside 320x512 keeping the aspect ratio. Pass an explicit size:
        # with scale_factor, the output size is floor(size * scale), which can
        # land one pixel short of the 320/512 bound due to float rounding.
        src_h, src_w = video_frames.shape[2], video_frames.shape[3]
        scale_ = min(320 / src_h, 512 / src_w)
        target_size = (max(1, round(src_h * scale_)), max(1, round(src_w * scale_)))
        video_frames = torch.nn.functional.interpolate(video_frames, size=target_size, mode="bilinear", align_corners=False)

        video_frames = transform_video(video_frames)
        video_input = video_frames.to(device).unsqueeze(0)  # b,f,c,h,w
        mask = mask_generation_before("first1", video_input.shape, video_input.dtype, device)
        # Keep only the pixels where the mask is 0.
        masked_video = video_input * (mask == 0)
        if image is not None:
            print(image.shape, flush=True)
        if scfg_scale == tcfg_scale:
            video_clip = auto_inpainting(video_input, masked_video, mask, text, image, vae, text_encoder, image_encoder, diffusion, model, device, scfg_scale, img_cfg_scale)
        else:
            video_clip = auto_inpainting_temp_split(video_input, masked_video, mask, text, image, vae, text_encoder, image_encoder, diffusion, model, device, scfg_scale, tcfg_scale, img_cfg_scale)
        # Invert the mean/std=0.5 normalization, round to uint8 and reorder
        # to (f, h, w, c) as expected by write_video.
        video_clip = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
        # write_video does not create missing directories.
        os.makedirs(output_path, exist_ok=True)
        # NOTE: the same file name is reused (overwritten) on every request.
        video_path = os.path.join(output_path, 'video.mp4')
        torchvision.io.write_video(video_path, video_clip, fps=8)
        return video_path


# ========================================
#      Judge Generation or Prediction
# ========================================
def gen_or_pre(text_input, image_input, scfg_scale, tcfg_scale, img_cfg_scale, preframe_input, diffusion_step):
    """Route a "Send" click to first-frame prediction or to plain generation.

    The requested step count is replaced by the closest entry of a fixed,
    ascending list of step counts; on a tie, the smaller entry wins.
    A sampler is built for that count, and the request goes to
    ``video_prediction`` when a first frame was supplied, otherwise to
    ``video_generation``.

    Returns:
        Whatever the chosen generator returns (the written video's path).
    """
    allowed_steps = [25, 40, 50, 100, 125, 200, 250]
    # min() keeps the first minimal candidate, i.e. the smaller one on ties.
    nearest = min(allowed_steps, key=lambda steps: abs(steps - diffusion_step))
    sampler = create_diffusion(str(nearest))
    if preframe_input is not None:
        return video_prediction(text_input, image_input, scfg_scale, tcfg_scale, img_cfg_scale, preframe_input, sampler)
    return video_generation(text_input, image_input, scfg_scale, tcfg_scale, img_cfg_scale, sampler)


def _reset_ui():
    """Return blank values for the prompt, both image inputs and the output video."""
    return "", None, None, None


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(visible=True) as input_raws:
            with gr.Row():
                with gr.Column(scale=1.0):
                    text_input = gr.Textbox(show_label=True, interactive=True, label="Text prompt").style(container=False)
            with gr.Row():
                with gr.Column(scale=0.5):
                    image_input = gr.Image(show_label=True, interactive=True, label="Reference image").style(container=False)
                with gr.Column(scale=0.5):
                    preframe_input = gr.Image(show_label=True, interactive=True, label="First frame").style(container=False)
            with gr.Row():
                with gr.Column(scale=1.0):
                    scfg_scale = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=8,
                        step=0.1,
                        interactive=True,
                        label="Spatial Text Guidence Scale",
                    )
            with gr.Row():
                with gr.Column(scale=1.0):
                    tcfg_scale = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=6.5,
                        step=0.1,
                        interactive=True,
                        label="Temporal Text Guidence Scale",
                    )
            with gr.Row():
                with gr.Column(scale=1.0):
                    img_cfg_scale = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.3,
                        step=0.005,
                        interactive=True,
                        label="Image Guidence Scale",
                    )
            with gr.Row():
                with gr.Column(scale=1.0):
                    # gen_or_pre snaps this to the nearest supported step count.
                    diffusion_step = gr.Slider(
                        minimum=20,
                        maximum=250,
                        value=100,
                        step=1,
                        interactive=True,
                        label="Diffusion Step",
                    )
            with gr.Row():
                with gr.Column(scale=0.5, min_width=0):
                    run = gr.Button("💭Send")
                with gr.Column(scale=0.5, min_width=0):
                    clear = gr.Button("🔄Clear")
        with gr.Column(scale=0.5, visible=True) as video_upload:
            # elem_id is Chinese for "output video".
            output_video = gr.Video(interactive=False, include_audio=True, elem_id="输出的视频")
            restart = gr.Button("Restart")
    run.click(gen_or_pre, [text_input, image_input, scfg_scale, tcfg_scale, img_cfg_scale, preframe_input, diffusion_step], [output_video])
    # Both buttons were previously created without any handler (and the
    # second rebound the first's name); wire them to reset the form.
    reset_targets = [text_input, image_input, preframe_input, output_video]
    clear.click(_reset_ui, None, reset_targets)
    restart.click(_reset_ui, None, reset_targets)

demo.launch(share=True, enable_queue=True)

# demo.launch(server_name="0.0.0.0", server_port=10034, enable_queue=True)