Commit c4ac4f3 by Koke_Cacao (parent: ad07f98)

:bug: fix direct load from pipeline, thanks @reynoldscem

Files changed:
- scripts/attention.py +2 -2
- scripts/convert_mvdream_to_diffusers.py +48 -23
- scripts/models.py +65 -2
- scripts/pipeline_mvdream.py +8 -5
- unet/config.json +24 -1
scripts/attention.py CHANGED
@@ -12,8 +12,8 @@ from typing import Optional, Any
 from util import checkpoint
 
 try:
-    import xformers
-    import xformers.ops
+    import xformers # type: ignore
+    import xformers.ops # type: ignore
     XFORMERS_IS_AVAILBLE = True
 except:
     XFORMERS_IS_AVAILBLE = False
scripts/convert_mvdream_to_diffusers.py CHANGED
@@ -9,13 +9,13 @@ sys.path.insert(0, '../')
 from diffusers.models import (
     AutoencoderKL,
 )
+from omegaconf import OmegaConf
 from diffusers.schedulers import DDIMScheduler
 from diffusers.utils import logging
-
+from typing import Any
 from accelerate import init_empty_weights
 from accelerate.utils import set_module_tensor_to_device
-from
-from models import MultiViewUNetModel, MultiViewUNetWrapperModel
+from models import MultiViewUNetWrapperModel
 from pipeline_mvdream import MVDreamStableDiffusionPipeline
 from transformers import CLIPTokenizer, CLIPTextModel
 
@@ -259,14 +259,14 @@ def conv_attn_to_linear(checkpoint):
         if checkpoint[key].ndim > 2:
             checkpoint[key] = checkpoint[key][:, :, 0]
 
+def create_unet_config(original_config) -> Any:
+    return OmegaConf.to_container(original_config.model.params.unet_config.params, resolve=True)
 
 def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, device):
     checkpoint = torch.load(checkpoint_path, map_location=device)
     # print(f"Checkpoint: {checkpoint.keys()}")
     torch.cuda.empty_cache()
 
-    from omegaconf import OmegaConf
-
     original_config = OmegaConf.load(original_config_file)
     # print(f"Original Config: {original_config}")
     prediction_type = "epsilon"
@@ -296,11 +296,13 @@ def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, device):
     # checkpoint, unet_config, path=None, extract_ema=extract_ema
     # )
     # print(f"Unet Config: {original_config.model.params.unet_config.params}")
-
+    unet_config = create_unet_config(original_config)
+    unet: MultiViewUNetWrapperModel = MultiViewUNetWrapperModel(**unet_config)
+    unet.register_to_config(**unet_config)
     # print(f"Unet State Dict: {unet.state_dict().keys()}")
     unet.load_state_dict({key.replace("model.diffusion_model.", "unet."): value for key, value in checkpoint.items() if key.replace("model.diffusion_model.", "unet.") in unet.state_dict()})
     for param_name, param in unet.state_dict().items():
-        set_module_tensor_to_device(unet, param_name,
+        set_module_tensor_to_device(unet, param_name, device=device, value=param)
 
     # Convert the VAE model.
     vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
@@ -316,18 +318,18 @@ def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, device):
     with init_empty_weights():
         vae = AutoencoderKL(**vae_config)
 
+    for param_name, param in converted_vae_checkpoint.items():
+        set_module_tensor_to_device(vae, param_name, device=device, value=param)
+
     if original_config.model.params.unet_config.params.context_dim == 768:
         tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-        text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device=
+        text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device=device) # type: ignore
     elif original_config.model.params.unet_config.params.context_dim == 1024:
         tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="tokenizer")
-        text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="text_encoder").to(device=
+        text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="text_encoder").to(device=device) # type: ignore
     else:
         raise ValueError(f"Unknown context_dim: {original_config.model.paams.unet_config.params.context_dim}")
 
-    for param_name, param in converted_vae_checkpoint.items():
-        set_module_tensor_to_device(vae, param_name, "cuda:0", value=param)
-
     pipe = MVDreamStableDiffusionPipeline(
         vae=vae,
         unet=unet,
@@ -359,6 +361,8 @@ if __name__ == "__main__":
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
     parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
     args = parser.parse_args()
+
+    args.device = torch.device(args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu")
 
     pipe = convert_from_original_mvdream_ckpt(
         checkpoint_path=args.checkpoint_path,
@@ -369,15 +373,36 @@ if __name__ == "__main__":
     if args.half:
         pipe.to(torch_dtype=torch.float16)
 
-
-    images = pipe(
-        prompt="Head of Hatsune Miku",
-        negative_prompt="painting, bad quality, flat",
-        output_type="pil",
-        guidance_scale=7.5,
-        num_inference_steps=50,
-    )
-    for i, image in enumerate(images):
-        image.save(f"image_{i}.png") # type: ignore
-
+    print(f"Saving pipeline to {args.dump_path}...")
     pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
+
+    if args.test:
+        try:
+            print(f"Testing each subcomponent of the pipeline...")
+            images = pipe(
+                prompt="Head of Hatsune Miku",
+                negative_prompt="painting, bad quality, flat",
+                output_type="pil",
+                guidance_scale=7.5,
+                num_inference_steps=50,
+                device=args.device,
+            )
+            for i, image in enumerate(images):
+                image.save(f"image_{i}.png") # type: ignore
+
+            print(f"Testing entire pipeline...")
+            loaded_pipe: MVDreamStableDiffusionPipeline = MVDreamStableDiffusionPipeline.from_pretrained(args.dump_path, safe_serialization=args.to_safetensors) # type: ignore
+            images = loaded_pipe(
+                prompt="Head of Hatsune Miku",
+                negative_prompt="painting, bad quality, flat",
+                output_type="pil",
+                guidance_scale=7.5,
+                num_inference_steps=50,
+                device=args.device,
+            )
+            for i, image in enumerate(images):
+                image.save(f"image_{i}.png") # type: ignore
+        except Exception as e:
+            print(f"Failed to test inference: {e}")
+            raise e from e
+        print("Inference test passed!")
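For orientation, here is a minimal sketch of driving the converter from Python rather than the command line; the checkpoint name, YAML name, and output directory below are placeholders, not paths from this repository:

import torch
from convert_mvdream_to_diffusers import convert_from_original_mvdream_ckpt

# Placeholder inputs: the original MVDream weights and their matching OmegaConf YAML.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pipe = convert_from_original_mvdream_ckpt(
    checkpoint_path="mvdream-sd-v2.1.ckpt",
    original_config_file="mvdream-sd-v2.1.yaml",
    device=device,
)

# Persist in diffusers layout so it can later be loaded with from_pretrained.
pipe.save_pretrained("./mvdream-diffusers")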
scripts/models.py CHANGED
@@ -25,9 +25,72 @@ from torch import Tensor
 
 class MultiViewUNetWrapperModel(ModelMixin, ConfigMixin):
 
-    def __init__(self,
+    def __init__(self,
+                 image_size,
+                 in_channels,
+                 model_channels,
+                 out_channels,
+                 num_res_blocks,
+                 attention_resolutions,
+                 dropout=0,
+                 channel_mult=(1, 2, 4, 8),
+                 conv_resample=True,
+                 dims=2,
+                 num_classes=None,
+                 use_checkpoint=False,
+                 use_fp16=False,
+                 use_bf16=False,
+                 num_heads=-1,
+                 num_head_channels=-1,
+                 num_heads_upsample=-1,
+                 use_scale_shift_norm=False,
+                 resblock_updown=False,
+                 use_new_attention_order=False,
+                 use_spatial_transformer=False, # custom transformer support
+                 transformer_depth=1, # custom transformer support
+                 context_dim=None, # custom transformer support
+                 n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+                 legacy=True,
+                 disable_self_attentions=None,
+                 num_attention_blocks=None,
+                 disable_middle_self_attn=False,
+                 use_linear_in_transformer=False,
+                 adm_in_channels=None,
+                 camera_dim=None,):
         super().__init__()
-        self.unet: MultiViewUNetModel = MultiViewUNetModel(
+        self.unet: MultiViewUNetModel = MultiViewUNetModel(
+            image_size=image_size,
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            num_classes=num_classes,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            use_bf16=use_bf16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+            use_new_attention_order=use_new_attention_order,
+            use_spatial_transformer=use_spatial_transformer,
+            transformer_depth=transformer_depth,
+            context_dim=context_dim,
+            n_embed=n_embed,
+            legacy=legacy,
+            disable_self_attentions=disable_self_attentions,
+            num_attention_blocks=num_attention_blocks,
+            disable_middle_self_attn=disable_middle_self_attn,
+            use_linear_in_transformer=use_linear_in_transformer,
+            adm_in_channels=adm_in_channels,
+            camera_dim=camera_dim,
+        )
 
     def forward(self, *args, **kwargs):
         return self.unet(*args, **kwargs)
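Spelling out the keyword arguments lets diffusers' ConfigMixin record the UNet hyperparameters that end up in unet/config.json. A rough sketch of rebuilding the wrapper by hand from that file; the path is an assumption, relative to the converted pipeline directory:

import json
from models import MultiViewUNetWrapperModel

# Assumed location of the config written out by save_pretrained / register_to_config.
with open("unet/config.json") as f:
    cfg = json.load(f)

# Strip diffusers bookkeeping keys and re-create the wrapper from named
# hyperparameters alone; the weights would still have to be loaded separately.
kwargs = {k: v for k, v in cfg.items() if not k.startswith("_")}
unet = MultiViewUNetWrapperModel(**kwargs)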
scripts/pipeline_mvdream.py CHANGED
@@ -1,16 +1,14 @@
 import torch
 import numpy as np
 import inspect
-from typing
+from typing import Callable, List, Optional, Union
 from transformers import CLIPTextModel, CLIPTokenizer
 from diffusers import AutoencoderKL, DiffusionPipeline
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import (
     deprecate,
     is_accelerate_available,
     is_accelerate_version,
     logging,
-    replace_example_docstring,
 )
 from diffusers.configuration_utils import FrozenDict
 from diffusers.schedulers import DDIMScheduler
@@ -20,6 +18,7 @@ except ImportError:
     from diffusers.utils.torch_utils import randn_tensor # new import # type: ignore
 
 from models import MultiViewUNetWrapperModel
+from accelerate.utils import set_module_tensor_to_device
 
 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
@@ -391,9 +390,13 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
         output_type: Optional[str] = "pil",
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
+        batch_size: int = 4,
+        device = torch.device("cuda:0"),
     ):
-
-
+        self.unet = self.unet.to(device=device)
+        self.vae = self.vae.to(device=device)
+
+        self.text_encoder = self.text_encoder.to(device=device)
 
         camera = get_camera(batch_size).to(device=device)
 
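With the new batch_size and device arguments, calling a converted pipeline might look like the sketch below, which mirrors the test block in convert_mvdream_to_diffusers.py; the directory name is a placeholder and the comments reflect my reading of the diff:

import torch
from pipeline_mvdream import MVDreamStableDiffusionPipeline

pipe = MVDreamStableDiffusionPipeline.from_pretrained("./mvdream-diffusers")  # placeholder path
images = pipe(
    prompt="Head of Hatsune Miku",
    negative_prompt="painting, bad quality, flat",
    output_type="pil",
    guidance_scale=7.5,
    num_inference_steps=50,
    batch_size=4,                   # forwarded to get_camera(), presumably one image per view
    device=torch.device("cuda:0"),  # __call__ now moves unet, vae and text_encoder here
)
for i, image in enumerate(images):
    image.save(f"image_{i}.png")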
unet/config.json CHANGED
@@ -1,4 +1,27 @@
 {
   "_class_name": "MultiViewUNetWrapperModel",
-  "_diffusers_version": "0.21.4"
+  "_diffusers_version": "0.21.4",
+  "attention_resolutions": [
+    4,
+    2,
+    1
+  ],
+  "camera_dim": 16,
+  "channel_mult": [
+    1,
+    2,
+    4,
+    4
+  ],
+  "context_dim": 768,
+  "image_size": 32,
+  "in_channels": 4,
+  "legacy": false,
+  "model_channels": 320,
+  "num_heads": 8,
+  "num_res_blocks": 2,
+  "out_channels": 4,
+  "transformer_depth": 1,
+  "use_checkpoint": false,
+  "use_spatial_transformer": true
 }
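These serialized hyperparameters are what the commit title's "direct load" relies on: a saved pipeline can be re-instantiated from config plus weights without re-running the converter. A brief sketch of rebuilding just the UNet component through ConfigMixin, assuming the pipeline was saved to ./mvdream-diffusers:

from models import MultiViewUNetWrapperModel

# load_config reads unet/config.json from the saved pipeline directory;
# from_config re-creates the architecture from those keyword arguments
# (randomly initialized until the saved weights are loaded on top).
config = MultiViewUNetWrapperModel.load_config("./mvdream-diffusers/unet")
unet = MultiViewUNetWrapperModel.from_config(config)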