Koke_Cacao committed
Commit c4ac4f3
1 Parent(s): ad07f98

:bug: fix direct load from pipeline, thanks @reynoldscem

scripts/attention.py CHANGED
@@ -12,8 +12,8 @@ from typing import Optional, Any
 from util import checkpoint
 
 try:
-    import xformers
-    import xformers.ops
+    import xformers # type: ignore
+    import xformers.ops # type: ignore
     XFORMERS_IS_AVAILBLE = True
 except:
     XFORMERS_IS_AVAILBLE = False
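
Note: the only change in attention.py is the "# type: ignore" annotation on the optional xformers import; the try/except guard and the XFORMERS_IS_AVAILBLE flag are unchanged. As an illustrative, hedged sketch (this function is not part of the repository), such a flag is typically consumed by branching between the memory-efficient kernel and a plain PyTorch fallback:

    # Illustrative sketch only -- not code from this repository.
    import torch

    try:
        import xformers.ops  # type: ignore
        XFORMERS_IS_AVAILBLE = True
    except ImportError:
        XFORMERS_IS_AVAILBLE = False

    def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # q, k, v: (batch, seq_len, num_heads, head_dim)
        if XFORMERS_IS_AVAILBLE:
            return xformers.ops.memory_efficient_attention(q, k, v)
        # PyTorch fallback; scaled_dot_product_attention expects heads before seq_len
        out = torch.nn.functional.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        )
        return out.transpose(1, 2)
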
scripts/convert_mvdream_to_diffusers.py CHANGED
@@ -9,13 +9,13 @@ sys.path.insert(0, '../')
 from diffusers.models import (
     AutoencoderKL,
 )
+from omegaconf import OmegaConf
 from diffusers.schedulers import DDIMScheduler
 from diffusers.utils import logging
-
+from typing import Any
 from accelerate import init_empty_weights
 from accelerate.utils import set_module_tensor_to_device
-from rich import print, print_json
-from models import MultiViewUNetModel, MultiViewUNetWrapperModel
+from models import MultiViewUNetWrapperModel
 from pipeline_mvdream import MVDreamStableDiffusionPipeline
 from transformers import CLIPTokenizer, CLIPTextModel
 
@@ -259,14 +259,14 @@ def conv_attn_to_linear(checkpoint):
         if checkpoint[key].ndim > 2:
             checkpoint[key] = checkpoint[key][:, :, 0]
 
+def create_unet_config(original_config) -> Any:
+    return OmegaConf.to_container(original_config.model.params.unet_config.params, resolve=True)
 
 def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, device):
     checkpoint = torch.load(checkpoint_path, map_location=device)
     # print(f"Checkpoint: {checkpoint.keys()}")
     torch.cuda.empty_cache()
 
-    from omegaconf import OmegaConf
-
     original_config = OmegaConf.load(original_config_file)
     # print(f"Original Config: {original_config}")
     prediction_type = "epsilon"
@@ -296,11 +296,13 @@ def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, de
     # checkpoint, unet_config, path=None, extract_ema=extract_ema
     # )
     # print(f"Unet Config: {original_config.model.params.unet_config.params}")
-    unet: MultiViewUNetWrapperModel = MultiViewUNetWrapperModel(**original_config.model.params.unet_config.params)
+    unet_config = create_unet_config(original_config)
+    unet: MultiViewUNetWrapperModel = MultiViewUNetWrapperModel(**unet_config)
+    unet.register_to_config(**unet_config)
     # print(f"Unet State Dict: {unet.state_dict().keys()}")
     unet.load_state_dict({key.replace("model.diffusion_model.", "unet."): value for key, value in checkpoint.items() if key.replace("model.diffusion_model.", "unet.") in unet.state_dict()})
     for param_name, param in unet.state_dict().items():
-        set_module_tensor_to_device(unet, param_name, "cuda:0", value=param)
+        set_module_tensor_to_device(unet, param_name, device=device, value=param)
 
     # Convert the VAE model.
     vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
@@ -316,18 +318,18 @@ def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, de
     with init_empty_weights():
         vae = AutoencoderKL(**vae_config)
 
+    for param_name, param in converted_vae_checkpoint.items():
+        set_module_tensor_to_device(vae, param_name, device=device, value=param)
+
     if original_config.model.params.unet_config.params.context_dim == 768:
         tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-        text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device=torch.device("cuda:0")) # type: ignore
+        text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device=device) # type: ignore
     elif original_config.model.params.unet_config.params.context_dim == 1024:
         tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="tokenizer")
-        text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="text_encoder").to(device=torch.device("cuda:0")) # type: ignore
+        text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="text_encoder").to(device=device) # type: ignore
     else:
         raise ValueError(f"Unknown context_dim: {original_config.model.paams.unet_config.params.context_dim}")
 
-    for param_name, param in converted_vae_checkpoint.items():
-        set_module_tensor_to_device(vae, param_name, "cuda:0", value=param)
-
     pipe = MVDreamStableDiffusionPipeline(
         vae=vae,
         unet=unet,
@@ -359,6 +361,8 @@ if __name__ == "__main__":
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
     parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
     args = parser.parse_args()
+
+    args.device = torch.device(args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu")
 
     pipe = convert_from_original_mvdream_ckpt(
         checkpoint_path=args.checkpoint_path,
@@ -369,15 +373,36 @@ if __name__ == "__main__":
     if args.half:
         pipe.to(torch_dtype=torch.float16)
 
-    if args.test:
-        images = pipe(
-            prompt="Head of Hatsune Miku",
-            negative_prompt="painting, bad quality, flat",
-            output_type="pil",
-            guidance_scale=7.5,
-            num_inference_steps=50,
-        )
-        for i, image in enumerate(images):
-            image.save(f"image_{i}.png") # type: ignore
-
+    print(f"Saving pipeline to {args.dump_path}...")
     pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
+
+    if args.test:
+        try:
+            print(f"Testing each subcomponent of the pipeline...")
+            images = pipe(
+                prompt="Head of Hatsune Miku",
+                negative_prompt="painting, bad quality, flat",
+                output_type="pil",
+                guidance_scale=7.5,
+                num_inference_steps=50,
+                device=args.device,
+            )
+            for i, image in enumerate(images):
+                image.save(f"image_{i}.png") # type: ignore
+
+            print(f"Testing entire pipeline...")
+            loaded_pipe: MVDreamStableDiffusionPipeline = MVDreamStableDiffusionPipeline.from_pretrained(args.dump_path, safe_serialization=args.to_safetensors) # type: ignore
+            images = loaded_pipe(
+                prompt="Head of Hatsune Miku",
+                negative_prompt="painting, bad quality, flat",
+                output_type="pil",
+                guidance_scale=7.5,
+                num_inference_steps=50,
+                device=args.device,
+            )
+            for i, image in enumerate(images):
+                image.save(f"image_{i}.png") # type: ignore
+        except Exception as e:
+            print(f"Failed to test inference: {e}")
+            raise e from e
+        print("Inference test passed!")
scripts/models.py CHANGED
@@ -25,9 +25,72 @@ from torch import Tensor
 
 class MultiViewUNetWrapperModel(ModelMixin, ConfigMixin):
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self,
+                 image_size,
+                 in_channels,
+                 model_channels,
+                 out_channels,
+                 num_res_blocks,
+                 attention_resolutions,
+                 dropout=0,
+                 channel_mult=(1, 2, 4, 8),
+                 conv_resample=True,
+                 dims=2,
+                 num_classes=None,
+                 use_checkpoint=False,
+                 use_fp16=False,
+                 use_bf16=False,
+                 num_heads=-1,
+                 num_head_channels=-1,
+                 num_heads_upsample=-1,
+                 use_scale_shift_norm=False,
+                 resblock_updown=False,
+                 use_new_attention_order=False,
+                 use_spatial_transformer=False, # custom transformer support
+                 transformer_depth=1, # custom transformer support
+                 context_dim=None, # custom transformer support
+                 n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+                 legacy=True,
+                 disable_self_attentions=None,
+                 num_attention_blocks=None,
+                 disable_middle_self_attn=False,
+                 use_linear_in_transformer=False,
+                 adm_in_channels=None,
+                 camera_dim=None,):
         super().__init__()
-        self.unet: MultiViewUNetModel = MultiViewUNetModel(*args, **kwargs)
+        self.unet: MultiViewUNetModel = MultiViewUNetModel(
+            image_size=image_size,
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            num_classes=num_classes,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            use_bf16=use_bf16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+            use_new_attention_order=use_new_attention_order,
+            use_spatial_transformer=use_spatial_transformer,
+            transformer_depth=transformer_depth,
+            context_dim=context_dim,
+            n_embed=n_embed,
+            legacy=legacy,
+            disable_self_attentions=disable_self_attentions,
+            num_attention_blocks=num_attention_blocks,
+            disable_middle_self_attn=disable_middle_self_attn,
+            use_linear_in_transformer=use_linear_in_transformer,
+            adm_in_channels=adm_in_channels,
+            camera_dim=camera_dim,
+        )
 
     def forward(self, *args, **kwargs):
         return self.unet(*args, **kwargs)
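
Replacing the *args/**kwargs constructor with an explicit parameter list is what makes the wrapper serializable: ConfigMixin can only record arguments it sees by name, so the explicit signature (together with the register_to_config call in the conversion script) lets save_pretrained write a complete unet/config.json that the model can later be rebuilt from. A hedged sketch of the write side, mirroring convert_mvdream_to_diffusers.py, with the config dict abridged from unet/config.json in this commit:

    # Hedged sketch -- mirrors how the conversion script builds the wrapper;
    # the dict is abridged from unet/config.json below.
    from models import MultiViewUNetWrapperModel

    unet_config = {
        "image_size": 32,
        "in_channels": 4,
        "model_channels": 320,
        "out_channels": 4,
        "num_res_blocks": 2,
        "attention_resolutions": [4, 2, 1],
        "channel_mult": [1, 2, 4, 4],
        "num_heads": 8,
        "use_spatial_transformer": True,
        "transformer_depth": 1,
        "context_dim": 768,
        "legacy": False,
        "camera_dim": 16,
    }
    unet = MultiViewUNetWrapperModel(**unet_config)
    unet.register_to_config(**unet_config)  # lets save_pretrained() persist these values
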
scripts/pipeline_mvdream.py CHANGED
@@ -1,16 +1,14 @@
 import torch
 import numpy as np
 import inspect
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Callable, List, Optional, Union
 from transformers import CLIPTextModel, CLIPTokenizer
 from diffusers import AutoencoderKL, DiffusionPipeline
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import (
     deprecate,
     is_accelerate_available,
     is_accelerate_version,
     logging,
-    replace_example_docstring,
 )
 from diffusers.configuration_utils import FrozenDict
 from diffusers.schedulers import DDIMScheduler
@@ -20,6 +18,7 @@ except ImportError:
     from diffusers.utils.torch_utils import randn_tensor # new import # type: ignore
 
 from models import MultiViewUNetWrapperModel
+from accelerate.utils import set_module_tensor_to_device
 
 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
@@ -391,9 +390,13 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
         output_type: Optional[str] = "pil",
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
+        batch_size: int = 4,
+        device = torch.device("cuda:0"),
     ):
-        batch_size = 4
-        device = torch.device("cuda:0")
+        self.unet = self.unet.to(device=device)
+        self.vae = self.vae.to(device=device)
+
+        self.text_encoder = self.text_encoder.to(device=device)
 
         camera = get_camera(batch_size).to(device=device)
 
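
In the pipeline itself, the previously hard-coded batch_size = 4 and device = torch.device("cuda:0") inside __call__ become keyword arguments with the old values as defaults, and the UNet, VAE, and text encoder are moved to the requested device at call time. A hedged sketch of the new call signature (pipe is assumed to be an already constructed MVDreamStableDiffusionPipeline, as in the conversion script's --test branch):

    # Hedged sketch; "pipe" is assumed to be a loaded MVDreamStableDiffusionPipeline.
    import torch

    images = pipe(
        prompt="Head of Hatsune Miku",
        negative_prompt="painting, bad quality, flat",
        output_type="pil",
        guidance_scale=7.5,
        num_inference_steps=50,
        batch_size=4,                   # previously hard-coded inside __call__
        device=torch.device("cuda:0"),  # previously hard-coded, now overridable
    )
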
unet/config.json CHANGED
@@ -1,4 +1,27 @@
 {
   "_class_name": "MultiViewUNetWrapperModel",
-  "_diffusers_version": "0.21.4"
+  "_diffusers_version": "0.21.4",
+  "attention_resolutions": [
+    4,
+    2,
+    1
+  ],
+  "camera_dim": 16,
+  "channel_mult": [
+    1,
+    2,
+    4,
+    4
+  ],
+  "context_dim": 768,
+  "image_size": 32,
+  "in_channels": 4,
+  "legacy": false,
+  "model_channels": 320,
+  "num_heads": 8,
+  "num_res_blocks": 2,
+  "out_channels": 4,
+  "transformer_depth": 1,
+  "use_checkpoint": false,
+  "use_spatial_transformer": true
 }
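
The config now records the full set of UNet hyperparameters (written by save_pretrained after register_to_config), so a pipeline loaded directly from the dump can re-instantiate MultiViewUNetWrapperModel without the original OmegaConf YAML. A simplified, hedged approximation of the read side (the path is a placeholder; diffusers actually goes through ConfigMixin rather than raw JSON):

    # Hedged approximation only -- diffusers' loader uses ConfigMixin, but the net
    # effect is to feed these recorded values back into the explicit __init__.
    import json
    from models import MultiViewUNetWrapperModel

    with open("./mvdream-diffusers/unet/config.json") as f:
        cfg = json.load(f)

    init_kwargs = {k: v for k, v in cfg.items() if not k.startswith("_")}  # drop bookkeeping keys
    unet = MultiViewUNetWrapperModel(**init_kwargs)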