Spaces:
Paused
Paused
CausalVideoAutoencoder: made neater load_ckpt.
Browse files
xora/examples/image_to_video.py
CHANGED
@@ -19,12 +19,12 @@ vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
|
|
19 |
vae_config_path = vae_dir / "config.json"
|
20 |
with open(vae_config_path, 'r') as f:
|
21 |
vae_config = json.load(f)
|
|
|
22 |
vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
|
23 |
-
vae
|
24 |
-
config=vae_config,
|
25 |
state_dict=vae_state_dict,
|
26 |
-
|
27 |
-
|
28 |
|
29 |
# Load UNet (Transformer) from separate mode
|
30 |
unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
|
|
|
19 |
vae_config_path = vae_dir / "config.json"
|
20 |
with open(vae_config_path, 'r') as f:
|
21 |
vae_config = json.load(f)
|
22 |
+
vae = CausalVideoAutoencoder.from_config(vae_config)
|
23 |
vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
|
24 |
+
vae.load_state_dict(
|
|
|
25 |
state_dict=vae_state_dict,
|
26 |
+
)
|
27 |
+
vae = vae.cuda().to(torch.bfloat16)
|
28 |
|
29 |
# Load UNet (Transformer) from separate mode
|
30 |
unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
|
xora/examples/text_to_video.py
CHANGED
@@ -10,7 +10,7 @@ import safetensors.torch
|
|
10 |
import json
|
11 |
|
12 |
# Paths for the separate mode directories
|
13 |
-
separate_dir = Path("/opt/models/xora-
|
14 |
unet_dir = separate_dir / 'unet'
|
15 |
vae_dir = separate_dir / 'vae'
|
16 |
scheduler_dir = separate_dir / 'scheduler'
|
@@ -20,12 +20,12 @@ vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
|
|
20 |
vae_config_path = vae_dir / "config.json"
|
21 |
with open(vae_config_path, 'r') as f:
|
22 |
vae_config = json.load(f)
|
|
|
23 |
vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
|
24 |
-
vae
|
25 |
-
config=vae_config,
|
26 |
state_dict=vae_state_dict,
|
27 |
-
|
28 |
-
|
29 |
|
30 |
# Load UNet (Transformer) from separate mode
|
31 |
unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
|
|
|
10 |
import json
|
11 |
|
12 |
# Paths for the separate mode directories
|
13 |
+
separate_dir = Path("/opt/models/xora-img2video")
|
14 |
unet_dir = separate_dir / 'unet'
|
15 |
vae_dir = separate_dir / 'vae'
|
16 |
scheduler_dir = separate_dir / 'scheduler'
|
|
|
20 |
vae_config_path = vae_dir / "config.json"
|
21 |
with open(vae_config_path, 'r') as f:
|
22 |
vae_config = json.load(f)
|
23 |
+
vae = CausalVideoAutoencoder.from_config(vae_config)
|
24 |
vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
|
25 |
+
vae.load_state_dict(
|
|
|
26 |
state_dict=vae_state_dict,
|
27 |
+
)
|
28 |
+
vae = vae.cuda().to(torch.bfloat16)
|
29 |
|
30 |
# Load UNet (Transformer) from separate mode
|
31 |
unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
|
xora/models/autoencoders/causal_video_autoencoder.py
CHANGED
@@ -41,35 +41,6 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
41 |
|
42 |
return video_vae
|
43 |
|
44 |
-
@classmethod
|
45 |
-
def from_pretrained_conf(cls, config, state_dict, torch_dtype=torch.float32):
|
46 |
-
video_vae = cls.from_config(config)
|
47 |
-
video_vae.to(torch_dtype)
|
48 |
-
|
49 |
-
per_channel_statistics_prefix = "per_channel_statistics."
|
50 |
-
ckpt_state_dict = {
|
51 |
-
key: value
|
52 |
-
for key, value in state_dict.items()
|
53 |
-
if not key.startswith(per_channel_statistics_prefix)
|
54 |
-
}
|
55 |
-
video_vae.load_state_dict(ckpt_state_dict)
|
56 |
-
|
57 |
-
data_dict = {
|
58 |
-
key.removeprefix(per_channel_statistics_prefix): value
|
59 |
-
for key, value in state_dict.items()
|
60 |
-
if key.startswith(per_channel_statistics_prefix)
|
61 |
-
}
|
62 |
-
if len(data_dict) > 0:
|
63 |
-
video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
|
64 |
-
video_vae.register_buffer(
|
65 |
-
"mean_of_means",
|
66 |
-
data_dict.get(
|
67 |
-
"mean-of-means", torch.zeros_like(data_dict["std-of-means"])
|
68 |
-
),
|
69 |
-
)
|
70 |
-
|
71 |
-
return video_vae
|
72 |
-
|
73 |
@staticmethod
|
74 |
def from_config(config):
|
75 |
assert config["_class_name"] == "CausalVideoAutoencoder", "config must have _class_name=CausalVideoAutoencoder"
|
@@ -155,6 +126,13 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
155 |
return json.dumps(self.config.__dict__)
|
156 |
|
157 |
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
model_keys = set(name for name, _ in self.named_parameters())
|
159 |
|
160 |
key_mapping = {
|
@@ -162,9 +140,8 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
162 |
"downsamplers.0": "downsample",
|
163 |
"upsamplers.0": "upsample",
|
164 |
}
|
165 |
-
|
166 |
converted_state_dict = {}
|
167 |
-
for key, value in
|
168 |
for k, v in key_mapping.items():
|
169 |
key = key.replace(k, v)
|
170 |
|
@@ -176,6 +153,20 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
176 |
|
177 |
super().load_state_dict(converted_state_dict, strict=strict)
|
178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
def last_layer(self):
|
180 |
if hasattr(self.decoder, "conv_out"):
|
181 |
if isinstance(self.decoder.conv_out, nn.Sequential):
|
|
|
41 |
|
42 |
return video_vae
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
@staticmethod
|
45 |
def from_config(config):
|
46 |
assert config["_class_name"] == "CausalVideoAutoencoder", "config must have _class_name=CausalVideoAutoencoder"
|
|
|
126 |
return json.dumps(self.config.__dict__)
|
127 |
|
128 |
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
129 |
+
per_channel_statistics_prefix = "per_channel_statistics."
|
130 |
+
ckpt_state_dict = {
|
131 |
+
key: value
|
132 |
+
for key, value in state_dict.items()
|
133 |
+
if not key.startswith(per_channel_statistics_prefix)
|
134 |
+
}
|
135 |
+
|
136 |
model_keys = set(name for name, _ in self.named_parameters())
|
137 |
|
138 |
key_mapping = {
|
|
|
140 |
"downsamplers.0": "downsample",
|
141 |
"upsamplers.0": "upsample",
|
142 |
}
|
|
|
143 |
converted_state_dict = {}
|
144 |
+
for key, value in ckpt_state_dict.items():
|
145 |
for k, v in key_mapping.items():
|
146 |
key = key.replace(k, v)
|
147 |
|
|
|
153 |
|
154 |
super().load_state_dict(converted_state_dict, strict=strict)
|
155 |
|
156 |
+
data_dict = {
|
157 |
+
key.removeprefix(per_channel_statistics_prefix): value
|
158 |
+
for key, value in state_dict.items()
|
159 |
+
if key.startswith(per_channel_statistics_prefix)
|
160 |
+
}
|
161 |
+
if len(data_dict) > 0:
|
162 |
+
self.register_buffer("std_of_means", data_dict["std-of-means"])
|
163 |
+
self.register_buffer(
|
164 |
+
"mean_of_means",
|
165 |
+
data_dict.get(
|
166 |
+
"mean-of-means", torch.zeros_like(data_dict["std-of-means"])
|
167 |
+
),
|
168 |
+
)
|
169 |
+
|
170 |
def last_layer(self):
|
171 |
if hasattr(self.decoder, "conv_out"):
|
172 |
if isinstance(self.decoder.conv_out, nn.Sequential):
|