forgot about the nested package structure
- makeavid_sd/README.md +0 -1
- makeavid_sd/{makeavid_sd/__init__.py → __init__.py} +0 -0
- makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/__init__.py +0 -0
- makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/dataset.py +0 -0
- makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_attention_pseudo3d.py +0 -0
- makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_embeddings.py +0 -0
- makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_resnet_pseudo3d.py +0 -0
- makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_trainer.py +0 -0
- makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_unet_pseudo3d_blocks.py +0 -0
- makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_unet_pseudo3d_condition.py +0 -0
- makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/train.py +0 -0
- makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/train.sh +0 -0
- makeavid_sd/{makeavid_sd/inference.py → inference.py} +0 -0
- makeavid_sd/requirements.txt +0 -2
- makeavid_sd/setup.py +0 -11
- makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/__init__.py +0 -0
- makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_attention_pseudo3d.py +0 -0
- makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_cross_attention.py +0 -0
- makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_embeddings.py +0 -0
- makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_resnet_pseudo3d.py +0 -0
- makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_unet_pseudo3d_blocks.py +0 -0
- makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_unet_pseudo3d_condition.py +0 -0
- makeavid_sd/trainer_xla.py +0 -104
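The renames above flatten the accidental makeavid_sd/makeavid_sd/ nesting by one level, which is what the commit message refers to. A minimal sketch of what this means for import paths (module names are taken from the file list above; the exact call sites are illustrative assumptions):

```python
# Before this commit the package was nested one level too deep,
# so modules only resolved under a doubled package name:
#   from makeavid_sd.makeavid_sd.flax_impl import flax_trainer
# After flattening, the intended paths work:
from makeavid_sd.flax_impl import flax_trainer        # Flax implementation
from makeavid_sd.torch_impl import torch_embeddings   # Torch implementation
from makeavid_sd import inference                     # top-level module
```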
makeavid_sd/README.md
DELETED
@@ -1 +0,0 @@
-# makeavid-sd-tpu
makeavid_sd/{makeavid_sd/__init__.py → __init__.py}
RENAMED
File without changes

makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/__init__.py
RENAMED
File without changes

makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/dataset.py
RENAMED
File without changes

makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_attention_pseudo3d.py
RENAMED
File without changes

makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_embeddings.py
RENAMED
File without changes

makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_resnet_pseudo3d.py
RENAMED
File without changes

makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_trainer.py
RENAMED
File without changes

makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_unet_pseudo3d_blocks.py
RENAMED
File without changes

makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/flax_unet_pseudo3d_condition.py
RENAMED
File without changes

makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/train.py
RENAMED
File without changes

makeavid_sd/{makeavid_sd/flax_impl → flax_impl}/train.sh
RENAMED
File without changes

makeavid_sd/{makeavid_sd/inference.py → inference.py}
RENAMED
File without changes
makeavid_sd/requirements.txt
DELETED
@@ -1,2 +0,0 @@
-torch
-torch_xla
makeavid_sd/setup.py
DELETED
@@ -1,11 +0,0 @@
-from setuptools import setup
-setup(
-    name = 'makeavid_sd',
-    version = '0.1.0',
-    description = 'makeavid sd',
-    author = 'Lopho',
-    author_email = 'contact@lopho.org',
-    platforms = ['any'],
-    license = 'GNU Affero General Public License v3',
-    url = 'http://github.com/lopho/makeavid-sd-tpu'
-)
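Worth noting: the deleted setup.py passed no `packages` argument to `setup()`, so an installed distribution would not have exposed any modules regardless of the nesting. A minimal sketch of how a packaged layout could declare them, assuming the flattened structure from this commit (the `find_packages()` usage here is an illustration, not part of this repo):

```python
from setuptools import setup, find_packages

setup(
    name = 'makeavid_sd',
    version = '0.1.0',
    # find_packages() walks the source tree and registers makeavid_sd,
    # makeavid_sd.flax_impl and makeavid_sd.torch_impl -- package names
    # that only resolve correctly once the nesting is flattened.
    packages = find_packages(),
)
```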
makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/__init__.py
RENAMED
File without changes

makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_attention_pseudo3d.py
RENAMED
File without changes

makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_cross_attention.py
RENAMED
File without changes

makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_embeddings.py
RENAMED
File without changes

makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_resnet_pseudo3d.py
RENAMED
File without changes

makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_unet_pseudo3d_blocks.py
RENAMED
File without changes

makeavid_sd/{makeavid_sd/torch_impl → torch_impl}/torch_unet_pseudo3d_condition.py
RENAMED
File without changes
makeavid_sd/trainer_xla.py
DELETED
@@ -1,104 +0,0 @@
-import os
-os.environ['PJRT_DEVICE'] = 'TPU'
-
-from tqdm.auto import tqdm
-import torch
-from torch.utils.data import DataLoader
-from torch_xla.core import xla_model
-from diffusers import UNetPseudo3DConditionModel
-from dataset import load_dataset
-
-
-class TempoTrainerXLA:
-    def __init__(self,
-            pretrained: str = 'lxj616/make-a-stable-diffusion-video-timelapse',
-            lr: float = 1e-4,
-            dtype: torch.dtype = torch.float32,
-    ) -> None:
-        self.dtype = dtype
-        self.device: torch.device = xla_model.xla_device(0)
-        unet: UNetPseudo3DConditionModel = UNetPseudo3DConditionModel.from_pretrained(
-            pretrained,
-            subfolder = 'unet'
-        ).to(dtype = dtype, memory_format = torch.contiguous_format)
-        unfreeze_all: bool = False
-        unet = unet.train()
-        if not unfreeze_all:
-            unet.requires_grad_(False)
-            for name, param in unet.named_parameters():
-                if 'temporal_conv' in name:
-                    param.requires_grad_(True)
-            for block in [*unet.down_blocks, unet.mid_block, *unet.up_blocks]:
-                if hasattr(block, 'attentions') and block.attentions is not None:
-                    for attn_block in block.attentions:
-                        for transformer_block in attn_block.transformer_blocks:
-                            transformer_block.requires_grad_(False)
-                            transformer_block.attn_temporal.requires_grad_(True)
-                            transformer_block.norm_temporal.requires_grad_(True)
-        else:
-            unet.requires_grad_(True)
-        self.model: UNetPseudo3DConditionModel = unet.to(device = self.device)
-        #self.model = torch.compile(self.model, backend = 'aot_torchxla_trace_once')
-        self.params = lambda: filter(lambda p: p.requires_grad, self.model.parameters())
-        self.optim: torch.optim.Optimizer = torch.optim.AdamW(self.params(), lr = lr)
-        def lr_warmup(warmup_steps: int = 0):
-            def lambda_lr(step: int) -> float:
-                if step < warmup_steps:
-                    return step / warmup_steps
-                else:
-                    return 1.0
-            return lambda_lr
-        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optim, lr_lambda = lr_warmup(warmup_steps = 60), last_epoch = -1)
-
-    @torch.no_grad()
-    def train(self, dataloader: DataLoader, epochs: int = 1, log_every: int = 1, save_every: int = 1000) -> None:
-        # 'latent_model_input'
-        # 'encoder_hidden_states'
-        # 'timesteps'
-        # 'noise'
-        global_step: int = 0
-        for epoch in range(epochs):
-            pbar = tqdm(dataloader, dynamic_ncols = True, smoothing = 0.01)
-            for b in pbar:
-                latent_model_input: torch.Tensor = b['latent_model_input'].to(device = self.device)
-                encoder_hidden_states: torch.Tensor = b['encoder_hidden_states'].to(device = self.device)
-                timesteps: torch.Tensor = b['timesteps'].to(device = self.device)
-                noise: torch.Tensor = b['noise'].to(device = self.device)
-                with torch.enable_grad():
-                    self.optim.zero_grad(set_to_none = True)
-                    y = self.model(latent_model_input, timesteps, encoder_hidden_states).sample
-                    loss = torch.nn.functional.mse_loss(noise, y)
-                    loss.backward()
-                    self.optim.step()
-                    self.scheduler.step()
-                xla_model.mark_step()
-                if global_step % log_every == 0:
-                    pbar.set_postfix({ 'loss': loss.detach().item(), 'epoch': epoch })
-
-def main():
-    pretrained: str = 'lxj616/make-a-stable-diffusion-video-timelapse'
-    dataset_path: str = './storage/dataset/tempofunk'
-    dtype: torch.dtype = torch.bfloat16
-    trainer = TempoTrainerXLA(
-        pretrained = pretrained,
-        lr = 1e-5,
-        dtype = dtype
-    )
-    dataloader: DataLoader = load_dataset(
-        dataset_path = dataset_path,
-        pretrained = pretrained,
-        batch_size = 1,
-        num_frames = 10,
-        num_workers = 1,
-        dtype = dtype
-    )
-    trainer.train(
-        dataloader = dataloader,
-        epochs = 1000,
-        log_every = 1,
-        save_every = 1000
-    )
-
-if __name__ == '__main__':
-    main()
-
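The deleted trainer shows the core torch_xla idiom: XLA tensors are lazy, and `xla_model.mark_step()` at the end of each iteration cuts the traced graph so it is compiled and executed on the TPU. A minimal sketch of that step pattern in isolation (the model, dataloader and loss below are placeholders, not code from this repo):

```python
import torch
import torch_xla.core.xla_model as xm

def train_steps(model, dataloader, optimizer):
    device = xm.xla_device()              # lazy XLA device handle
    model = model.to(device).train()
    for batch, target in dataloader:
        batch, target = batch.to(device), target.to(device)
        optimizer.zero_grad(set_to_none = True)
        loss = torch.nn.functional.mse_loss(model(batch), target)
        loss.backward()
        # xm.optimizer_step applies the update and, on multi-core
        # setups, all-reduces gradients across replicas first.
        xm.optimizer_step(optimizer)
        # mark_step cuts the lazily traced graph here, so everything
        # recorded this iteration is compiled and run as one XLA graph.
        xm.mark_step()
```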