LittleNyima commited on
Commit
a651084
1 Parent(s): 7699ada

Initial commit

Browse files
Files changed (3) hide show
  1. README.md +111 -0
  2. config.json +51 -0
  3. diffusion_pytorch_model.safetensors +3 -0
README.md ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Abstract
2
+
3
+ **DDPM** model trained on [huggan/anime-faces](https://huggingface.co/datasets/huggan/anime-faces) dataset.
4
+
5
+ ## Training Arguments
6
+
7
+ | Argument | Value |
8
+ | :-------------------------: | :----: |
9
+ | image_size | 64 |
10
+ | train_batch_size | 16 |
11
+ | eval_batch_size | 16 |
12
+ | num_epochs | 50 |
13
+ | gradient_accumulation_steps | 1 |
14
+ | learning_rate | 1e-4 |
15
+ | lr_warmup_steps | 500 |
16
+ | mixed_precision | "fp16" |
17
+
18
+ For training code, please refer to [this link](https://github.com/LittleNyima/code-snippets/blob/master/ddpm-tutorial/ddpm_training.py).
19
+
20
+ # Inference
21
+
22
+ This project aims to implement DDPM from scratch, so `DDPMScheduler` is not used. Instead, I use only `UNet2DModel` and implement a simple scheduler myself. The inference code is:
23
+
24
+ ```python
25
+ import torch
26
+ from tqdm import tqdm
27
+ from diffusers import UNet2DModel
28
+
29
+ class DDPM:
30
+ def __init__(
31
+ self,
32
+ num_train_timesteps:int = 1000,
33
+ beta_start: float = 0.0001,
34
+ beta_end: float = 0.02,
35
+ ):
36
+ self.num_train_timesteps = num_train_timesteps
37
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
38
+ self.alphas = 1.0 - self.betas
39
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
40
+ self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1)
41
+
42
+ def add_noise(
43
+ self,
44
+ original_samples: torch.Tensor,
45
+ noise: torch.Tensor,
46
+ timesteps: torch.Tensor,
47
+ ):
48
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device ,dtype=original_samples.dtype)
49
+ noise = noise.to(original_samples.device)
50
+ timesteps = timesteps.to(original_samples.device)
51
+
52
+ # \sqrt{\bar\alpha_t}
53
+ sqrt_alpha_prod = alphas_cumprod[timesteps].flatten() ** 0.5
54
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
55
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
56
+
57
+ # \sqrt{1 - \bar\alpha_t}
58
+ sqrt_one_minus_alpha_prod = (1.0 - alphas_cumprod[timesteps]).flatten() ** 0.5
59
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
60
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
61
+
62
+ return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
63
+
64
+ @torch.no_grad()
65
+ def sample(
66
+ self,
67
+ unet: UNet2DModel,
68
+ batch_size: int,
69
+ in_channels: int,
70
+ sample_size: int,
71
+ ):
72
+ betas = self.betas.to(unet.device)
73
+ alphas = self.alphas.to(unet.device)
74
+ alphas_cumprod = self.alphas_cumprod.to(unet.device)
75
+ timesteps = self.timesteps.to(unet.device)
76
+ images = torch.randn((batch_size, in_channels, sample_size, sample_size), device=unet.device)
77
+ for timestep in tqdm(timesteps, desc='Sampling'):
78
+ pred_noise: torch.Tensor = unet(images, timestep).sample
79
+
80
+ # mean of q(x_{t-1}|x_t)
81
+ alpha_t = alphas[timestep]
82
+ alpha_cumprod_t = alphas_cumprod[timestep]
83
+ sqrt_alpha_t = alpha_t ** 0.5
84
+ one_minus_alpha_t = 1.0 - alpha_t
85
+ sqrt_one_minus_alpha_cumprod_t = (1 - alpha_cumprod_t) ** 0.5
86
+ mean = (images - one_minus_alpha_t / sqrt_one_minus_alpha_cumprod_t * pred_noise) / sqrt_alpha_t
87
+
88
+ # variance of q(x_{t-1}|x_t)
89
+ if timestep > 1:
90
+ beta_t = betas[timestep]
91
+ one_minus_alpha_cumprod_t_minus_one = 1.0 - alphas_cumprod[timestep - 1]
92
+ one_divided_by_sigma_square = alpha_t / beta_t + 1.0 / one_minus_alpha_cumprod_t_minus_one
93
+ variance = (1.0 / one_divided_by_sigma_square) ** 0.5
94
+ else:
95
+ variance = torch.zeros_like(timestep)
96
+
97
+ epsilon = torch.randn_like(images)
98
+ images = mean + variance * epsilon
99
+ images = (images / 2.0 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()
100
+ return images
101
+
102
+ model = UNet2DModel.from_pretrained('ddpm-animefaces-64').cuda()
103
+ ddpm = DDPM()
104
+ images = ddpm.sample(model, 32, 3, 64)
105
+
106
+ from diffusers.utils import make_image_grid, numpy_to_pil
107
+ image_grid = make_image_grid(numpy_to_pil(images), rows=4, cols=8)
108
+ image_grid.save('ddpm-sample-results.png')
109
+ ```
110
+
111
+ This can also be found in [this link](https://github.com/LittleNyima/code-snippets/blob/master/ddpm-tutorial/ddpm_sampling.py).
config.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "UNet2DModel",
3
+ "_diffusers_version": "0.28.1",
4
+ "act_fn": "silu",
5
+ "add_attention": true,
6
+ "attention_head_dim": 8,
7
+ "attn_norm_num_groups": null,
8
+ "block_out_channels": [
9
+ 128,
10
+ 128,
11
+ 256,
12
+ 256,
13
+ 512,
14
+ 512
15
+ ],
16
+ "center_input_sample": false,
17
+ "class_embed_type": null,
18
+ "down_block_types": [
19
+ "DownBlock2D",
20
+ "DownBlock2D",
21
+ "DownBlock2D",
22
+ "DownBlock2D",
23
+ "AttnDownBlock2D",
24
+ "DownBlock2D"
25
+ ],
26
+ "downsample_padding": 1,
27
+ "downsample_type": "conv",
28
+ "dropout": 0.0,
29
+ "flip_sin_to_cos": true,
30
+ "freq_shift": 0,
31
+ "in_channels": 3,
32
+ "layers_per_block": 2,
33
+ "mid_block_scale_factor": 1,
34
+ "norm_eps": 1e-05,
35
+ "norm_num_groups": 32,
36
+ "num_class_embeds": null,
37
+ "num_train_timesteps": null,
38
+ "out_channels": 3,
39
+ "resnet_time_scale_shift": "default",
40
+ "sample_size": 64,
41
+ "time_embedding_type": "positional",
42
+ "up_block_types": [
43
+ "UpBlock2D",
44
+ "AttnUpBlock2D",
45
+ "UpBlock2D",
46
+ "UpBlock2D",
47
+ "UpBlock2D",
48
+ "UpBlock2D"
49
+ ],
50
+ "upsample_type": "conv"
51
+ }
diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e75355c71e36f8e36d802f61b3c7f04cad455e591278e5e8a9a1ad89c4e2f990
3
+ size 454741108