dikdimon commited on
Commit
702b253
1 Parent(s): 0593fe4

Upload 142 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. ldm_patched/contrib/external.py +1977 -0
  2. ldm_patched/contrib/external_canny.py +303 -0
  3. ldm_patched/contrib/external_clip_sdxl.py +60 -0
  4. ldm_patched/contrib/external_compositing.py +206 -0
  5. ldm_patched/contrib/external_custom_sampler.py +299 -0
  6. ldm_patched/contrib/external_freelunch.py +115 -0
  7. ldm_patched/contrib/external_hypernetwork.py +123 -0
  8. ldm_patched/contrib/external_hypertile.py +85 -0
  9. ldm_patched/contrib/external_images.py +179 -0
  10. ldm_patched/contrib/external_latent.py +159 -0
  11. ldm_patched/contrib/external_mask.py +367 -0
  12. ldm_patched/contrib/external_model_advanced.py +179 -0
  13. ldm_patched/contrib/external_model_downscale.py +57 -0
  14. ldm_patched/contrib/external_model_merging.py +288 -0
  15. ldm_patched/contrib/external_perpneg.py +59 -0
  16. ldm_patched/contrib/external_photomaker.py +189 -0
  17. ldm_patched/contrib/external_post_processing.py +280 -0
  18. ldm_patched/contrib/external_rebatch.py +142 -0
  19. ldm_patched/contrib/external_sag.py +172 -0
  20. ldm_patched/contrib/external_sdupscale.py +51 -0
  21. ldm_patched/contrib/external_stable3d.py +104 -0
  22. ldm_patched/contrib/external_tomesd.py +164 -0
  23. ldm_patched/contrib/external_upscale_model.py +70 -0
  24. ldm_patched/contrib/external_video_model.py +110 -0
  25. ldm_patched/controlnet/cldm.py +312 -0
  26. ldm_patched/k_diffusion/sampling.py +814 -0
  27. ldm_patched/k_diffusion/utils.py +317 -0
  28. ldm_patched/ldm/models/autoencoder.py +235 -0
  29. ldm_patched/ldm/modules/attention.py +788 -0
  30. ldm_patched/ldm/modules/diffusionmodules/__init__.py +0 -0
  31. ldm_patched/ldm/modules/diffusionmodules/model.py +657 -0
  32. ldm_patched/ldm/modules/diffusionmodules/openaimodel.py +933 -0
  33. ldm_patched/ldm/modules/diffusionmodules/upscaling.py +93 -0
  34. ldm_patched/ldm/modules/diffusionmodules/util.py +303 -0
  35. ldm_patched/ldm/modules/distributions/__init__.py +0 -0
  36. ldm_patched/ldm/modules/distributions/distributions.py +96 -0
  37. ldm_patched/ldm/modules/ema.py +89 -0
  38. ldm_patched/ldm/modules/encoders/__init__.py +0 -0
  39. ldm_patched/ldm/modules/encoders/noise_aug_modules.py +39 -0
  40. ldm_patched/ldm/modules/sub_quadratic_attention.py +273 -0
  41. ldm_patched/ldm/modules/temporal_ae.py +252 -0
  42. ldm_patched/ldm/util.py +201 -0
  43. ldm_patched/licenses-3rd/chainer +20 -0
  44. ldm_patched/licenses-3rd/comfyui +674 -0
  45. ldm_patched/licenses-3rd/diffusers +201 -0
  46. ldm_patched/licenses-3rd/kdiffusion +19 -0
  47. ldm_patched/licenses-3rd/ldm +21 -0
  48. ldm_patched/licenses-3rd/taesd +21 -0
  49. ldm_patched/licenses-3rd/transformers +203 -0
  50. ldm_patched/modules/args_parser.py +131 -0
ldm_patched/contrib/external.py ADDED
@@ -0,0 +1,1977 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ import torch
6
+
7
+ import os
8
+ import sys
9
+ import json
10
+ import hashlib
11
+ import traceback
12
+ import math
13
+ import time
14
+ import random
15
+
16
+ from PIL import Image, ImageOps, ImageSequence
17
+ from PIL.PngImagePlugin import PngInfo
18
+ import numpy as np
19
+ import safetensors.torch
20
+
21
+ pass # sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "ldm_patched"))
22
+
23
+
24
+ import ldm_patched.modules.diffusers_load
25
+ import ldm_patched.modules.samplers
26
+ import ldm_patched.modules.sample
27
+ import ldm_patched.modules.sd
28
+ import ldm_patched.modules.utils
29
+ import ldm_patched.modules.controlnet
30
+
31
+ import ldm_patched.modules.clip_vision
32
+
33
+ import ldm_patched.modules.model_management
34
+ from ldm_patched.modules.args_parser import args
35
+
36
+ import importlib
37
+
38
+ import ldm_patched.utils.path_utils
39
+ import ldm_patched.utils.latent_visualization
40
+
41
+ def before_node_execution():
42
+ ldm_patched.modules.model_management.throw_exception_if_processing_interrupted()
43
+
44
+ def interrupt_processing(value=True):
45
+ ldm_patched.modules.model_management.interrupt_current_processing(value)
46
+
47
+ MAX_RESOLUTION=8192
48
+
49
+ class CLIPTextEncode:
50
+ @classmethod
51
+ def INPUT_TYPES(s):
52
+ return {"required": {"text": ("STRING", {"multiline": True}), "clip": ("CLIP", )}}
53
+ RETURN_TYPES = ("CONDITIONING",)
54
+ FUNCTION = "encode"
55
+
56
+ CATEGORY = "conditioning"
57
+
58
+ def encode(self, clip, text):
59
+ tokens = clip.tokenize(text)
60
+ cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
61
+ return ([[cond, {"pooled_output": pooled}]], )
62
+
63
+ class ConditioningCombine:
64
+ @classmethod
65
+ def INPUT_TYPES(s):
66
+ return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}}
67
+ RETURN_TYPES = ("CONDITIONING",)
68
+ FUNCTION = "combine"
69
+
70
+ CATEGORY = "conditioning"
71
+
72
+ def combine(self, conditioning_1, conditioning_2):
73
+ return (conditioning_1 + conditioning_2, )
74
+
75
+ class ConditioningAverage :
76
+ @classmethod
77
+ def INPUT_TYPES(s):
78
+ return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ),
79
+ "conditioning_to_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
80
+ }}
81
+ RETURN_TYPES = ("CONDITIONING",)
82
+ FUNCTION = "addWeighted"
83
+
84
+ CATEGORY = "conditioning"
85
+
86
+ def addWeighted(self, conditioning_to, conditioning_from, conditioning_to_strength):
87
+ out = []
88
+
89
+ if len(conditioning_from) > 1:
90
+ print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
91
+
92
+ cond_from = conditioning_from[0][0]
93
+ pooled_output_from = conditioning_from[0][1].get("pooled_output", None)
94
+
95
+ for i in range(len(conditioning_to)):
96
+ t1 = conditioning_to[i][0]
97
+ pooled_output_to = conditioning_to[i][1].get("pooled_output", pooled_output_from)
98
+ t0 = cond_from[:,:t1.shape[1]]
99
+ if t0.shape[1] < t1.shape[1]:
100
+ t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1)
101
+
102
+ tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength))
103
+ t_to = conditioning_to[i][1].copy()
104
+ if pooled_output_from is not None and pooled_output_to is not None:
105
+ t_to["pooled_output"] = torch.mul(pooled_output_to, conditioning_to_strength) + torch.mul(pooled_output_from, (1.0 - conditioning_to_strength))
106
+ elif pooled_output_from is not None:
107
+ t_to["pooled_output"] = pooled_output_from
108
+
109
+ n = [tw, t_to]
110
+ out.append(n)
111
+ return (out, )
112
+
113
+ class ConditioningConcat:
114
+ @classmethod
115
+ def INPUT_TYPES(s):
116
+ return {"required": {
117
+ "conditioning_to": ("CONDITIONING",),
118
+ "conditioning_from": ("CONDITIONING",),
119
+ }}
120
+ RETURN_TYPES = ("CONDITIONING",)
121
+ FUNCTION = "concat"
122
+
123
+ CATEGORY = "conditioning"
124
+
125
+ def concat(self, conditioning_to, conditioning_from):
126
+ out = []
127
+
128
+ if len(conditioning_from) > 1:
129
+ print("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
130
+
131
+ cond_from = conditioning_from[0][0]
132
+
133
+ for i in range(len(conditioning_to)):
134
+ t1 = conditioning_to[i][0]
135
+ tw = torch.cat((t1, cond_from),1)
136
+ n = [tw, conditioning_to[i][1].copy()]
137
+ out.append(n)
138
+
139
+ return (out, )
140
+
141
+ class ConditioningSetArea:
142
+ @classmethod
143
+ def INPUT_TYPES(s):
144
+ return {"required": {"conditioning": ("CONDITIONING", ),
145
+ "width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
146
+ "height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
147
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
148
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
149
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
150
+ }}
151
+ RETURN_TYPES = ("CONDITIONING",)
152
+ FUNCTION = "append"
153
+
154
+ CATEGORY = "conditioning"
155
+
156
+ def append(self, conditioning, width, height, x, y, strength):
157
+ c = []
158
+ for t in conditioning:
159
+ n = [t[0], t[1].copy()]
160
+ n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
161
+ n[1]['strength'] = strength
162
+ n[1]['set_area_to_bounds'] = False
163
+ c.append(n)
164
+ return (c, )
165
+
166
+ class ConditioningSetAreaPercentage:
167
+ @classmethod
168
+ def INPUT_TYPES(s):
169
+ return {"required": {"conditioning": ("CONDITIONING", ),
170
+ "width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
171
+ "height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
172
+ "x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
173
+ "y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
174
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
175
+ }}
176
+ RETURN_TYPES = ("CONDITIONING",)
177
+ FUNCTION = "append"
178
+
179
+ CATEGORY = "conditioning"
180
+
181
+ def append(self, conditioning, width, height, x, y, strength):
182
+ c = []
183
+ for t in conditioning:
184
+ n = [t[0], t[1].copy()]
185
+ n[1]['area'] = ("percentage", height, width, y, x)
186
+ n[1]['strength'] = strength
187
+ n[1]['set_area_to_bounds'] = False
188
+ c.append(n)
189
+ return (c, )
190
+
191
+ class ConditioningSetAreaStrength:
192
+ @classmethod
193
+ def INPUT_TYPES(s):
194
+ return {"required": {"conditioning": ("CONDITIONING", ),
195
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
196
+ }}
197
+ RETURN_TYPES = ("CONDITIONING",)
198
+ FUNCTION = "append"
199
+
200
+ CATEGORY = "conditioning"
201
+
202
+ def append(self, conditioning, strength):
203
+ c = []
204
+ for t in conditioning:
205
+ n = [t[0], t[1].copy()]
206
+ n[1]['strength'] = strength
207
+ c.append(n)
208
+ return (c, )
209
+
210
+
211
+ class ConditioningSetMask:
212
+ @classmethod
213
+ def INPUT_TYPES(s):
214
+ return {"required": {"conditioning": ("CONDITIONING", ),
215
+ "mask": ("MASK", ),
216
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
217
+ "set_cond_area": (["default", "mask bounds"],),
218
+ }}
219
+ RETURN_TYPES = ("CONDITIONING",)
220
+ FUNCTION = "append"
221
+
222
+ CATEGORY = "conditioning"
223
+
224
+ def append(self, conditioning, mask, set_cond_area, strength):
225
+ c = []
226
+ set_area_to_bounds = False
227
+ if set_cond_area != "default":
228
+ set_area_to_bounds = True
229
+ if len(mask.shape) < 3:
230
+ mask = mask.unsqueeze(0)
231
+ for t in conditioning:
232
+ n = [t[0], t[1].copy()]
233
+ _, h, w = mask.shape
234
+ n[1]['mask'] = mask
235
+ n[1]['set_area_to_bounds'] = set_area_to_bounds
236
+ n[1]['mask_strength'] = strength
237
+ c.append(n)
238
+ return (c, )
239
+
240
+ class ConditioningZeroOut:
241
+ @classmethod
242
+ def INPUT_TYPES(s):
243
+ return {"required": {"conditioning": ("CONDITIONING", )}}
244
+ RETURN_TYPES = ("CONDITIONING",)
245
+ FUNCTION = "zero_out"
246
+
247
+ CATEGORY = "advanced/conditioning"
248
+
249
+ def zero_out(self, conditioning):
250
+ c = []
251
+ for t in conditioning:
252
+ d = t[1].copy()
253
+ if "pooled_output" in d:
254
+ d["pooled_output"] = torch.zeros_like(d["pooled_output"])
255
+ n = [torch.zeros_like(t[0]), d]
256
+ c.append(n)
257
+ return (c, )
258
+
259
+ class ConditioningSetTimestepRange:
260
+ @classmethod
261
+ def INPUT_TYPES(s):
262
+ return {"required": {"conditioning": ("CONDITIONING", ),
263
+ "start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
264
+ "end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
265
+ }}
266
+ RETURN_TYPES = ("CONDITIONING",)
267
+ FUNCTION = "set_range"
268
+
269
+ CATEGORY = "advanced/conditioning"
270
+
271
+ def set_range(self, conditioning, start, end):
272
+ c = []
273
+ for t in conditioning:
274
+ d = t[1].copy()
275
+ d['start_percent'] = start
276
+ d['end_percent'] = end
277
+ n = [t[0], d]
278
+ c.append(n)
279
+ return (c, )
280
+
281
+ class VAEDecode:
282
+ @classmethod
283
+ def INPUT_TYPES(s):
284
+ return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
285
+ RETURN_TYPES = ("IMAGE",)
286
+ FUNCTION = "decode"
287
+
288
+ CATEGORY = "latent"
289
+
290
+ def decode(self, vae, samples):
291
+ return (vae.decode(samples["samples"]), )
292
+
293
+ class VAEDecodeTiled:
294
+ @classmethod
295
+ def INPUT_TYPES(s):
296
+ return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
297
+ "tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
298
+ }}
299
+ RETURN_TYPES = ("IMAGE",)
300
+ FUNCTION = "decode"
301
+
302
+ CATEGORY = "_for_testing"
303
+
304
+ def decode(self, vae, samples, tile_size):
305
+ return (vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, ), )
306
+
307
+ class VAEEncode:
308
+ @classmethod
309
+ def INPUT_TYPES(s):
310
+ return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
311
+ RETURN_TYPES = ("LATENT",)
312
+ FUNCTION = "encode"
313
+
314
+ CATEGORY = "latent"
315
+
316
+ @staticmethod
317
+ def vae_encode_crop_pixels(pixels):
318
+ x = (pixels.shape[1] // 8) * 8
319
+ y = (pixels.shape[2] // 8) * 8
320
+ if pixels.shape[1] != x or pixels.shape[2] != y:
321
+ x_offset = (pixels.shape[1] % 8) // 2
322
+ y_offset = (pixels.shape[2] % 8) // 2
323
+ pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
324
+ return pixels
325
+
326
+ def encode(self, vae, pixels):
327
+ pixels = self.vae_encode_crop_pixels(pixels)
328
+ t = vae.encode(pixels[:,:,:,:3])
329
+ return ({"samples":t}, )
330
+
331
+ class VAEEncodeTiled:
332
+ @classmethod
333
+ def INPUT_TYPES(s):
334
+ return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
335
+ "tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
336
+ }}
337
+ RETURN_TYPES = ("LATENT",)
338
+ FUNCTION = "encode"
339
+
340
+ CATEGORY = "_for_testing"
341
+
342
+ def encode(self, vae, pixels, tile_size):
343
+ pixels = VAEEncode.vae_encode_crop_pixels(pixels)
344
+ t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, )
345
+ return ({"samples":t}, )
346
+
347
+ class VAEEncodeForInpaint:
348
+ @classmethod
349
+ def INPUT_TYPES(s):
350
+ return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}}
351
+ RETURN_TYPES = ("LATENT",)
352
+ FUNCTION = "encode"
353
+
354
+ CATEGORY = "latent/inpaint"
355
+
356
+ def encode(self, vae, pixels, mask, grow_mask_by=6):
357
+ x = (pixels.shape[1] // 8) * 8
358
+ y = (pixels.shape[2] // 8) * 8
359
+ mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
360
+
361
+ pixels = pixels.clone()
362
+ if pixels.shape[1] != x or pixels.shape[2] != y:
363
+ x_offset = (pixels.shape[1] % 8) // 2
364
+ y_offset = (pixels.shape[2] % 8) // 2
365
+ pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
366
+ mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
367
+
368
+ #grow mask by a few pixels to keep things seamless in latent space
369
+ if grow_mask_by == 0:
370
+ mask_erosion = mask
371
+ else:
372
+ kernel_tensor = torch.ones((1, 1, grow_mask_by, grow_mask_by))
373
+ padding = math.ceil((grow_mask_by - 1) / 2)
374
+
375
+ mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=padding), 0, 1)
376
+
377
+ m = (1.0 - mask.round()).squeeze(1)
378
+ for i in range(3):
379
+ pixels[:,:,:,i] -= 0.5
380
+ pixels[:,:,:,i] *= m
381
+ pixels[:,:,:,i] += 0.5
382
+ t = vae.encode(pixels)
383
+
384
+ return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
385
+
386
+
387
+ class InpaintModelConditioning:
388
+ @classmethod
389
+ def INPUT_TYPES(s):
390
+ return {"required": {"positive": ("CONDITIONING", ),
391
+ "negative": ("CONDITIONING", ),
392
+ "vae": ("VAE", ),
393
+ "pixels": ("IMAGE", ),
394
+ "mask": ("MASK", ),
395
+ }}
396
+
397
+ RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
398
+ RETURN_NAMES = ("positive", "negative", "latent")
399
+ FUNCTION = "encode"
400
+
401
+ CATEGORY = "conditioning/inpaint"
402
+
403
+ def encode(self, positive, negative, pixels, vae, mask):
404
+ x = (pixels.shape[1] // 8) * 8
405
+ y = (pixels.shape[2] // 8) * 8
406
+ mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
407
+
408
+ orig_pixels = pixels
409
+ pixels = orig_pixels.clone()
410
+ if pixels.shape[1] != x or pixels.shape[2] != y:
411
+ x_offset = (pixels.shape[1] % 8) // 2
412
+ y_offset = (pixels.shape[2] % 8) // 2
413
+ pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
414
+ mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
415
+
416
+ m = (1.0 - mask.round()).squeeze(1)
417
+ for i in range(3):
418
+ pixels[:,:,:,i] -= 0.5
419
+ pixels[:,:,:,i] *= m
420
+ pixels[:,:,:,i] += 0.5
421
+ concat_latent = vae.encode(pixels)
422
+ orig_latent = vae.encode(orig_pixels)
423
+
424
+ out_latent = {}
425
+
426
+ out_latent["samples"] = orig_latent
427
+ out_latent["noise_mask"] = mask
428
+
429
+ out = []
430
+ for conditioning in [positive, negative]:
431
+ c = []
432
+ for t in conditioning:
433
+ d = t[1].copy()
434
+ d["concat_latent_image"] = concat_latent
435
+ d["concat_mask"] = mask
436
+ n = [t[0], d]
437
+ c.append(n)
438
+ out.append(c)
439
+ return (out[0], out[1], out_latent)
440
+
441
+
442
+ class SaveLatent:
443
+ def __init__(self):
444
+ self.output_dir = ldm_patched.utils.path_utils.get_output_directory()
445
+
446
+ @classmethod
447
+ def INPUT_TYPES(s):
448
+ return {"required": { "samples": ("LATENT", ),
449
+ "filename_prefix": ("STRING", {"default": "latents/ldm_patched"})},
450
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
451
+ }
452
+ RETURN_TYPES = ()
453
+ FUNCTION = "save"
454
+
455
+ OUTPUT_NODE = True
456
+
457
+ CATEGORY = "_for_testing"
458
+
459
+ def save(self, samples, filename_prefix="ldm_patched", prompt=None, extra_pnginfo=None):
460
+ full_output_folder, filename, counter, subfolder, filename_prefix = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix, self.output_dir)
461
+
462
+ # support save metadata for latent sharing
463
+ prompt_info = ""
464
+ if prompt is not None:
465
+ prompt_info = json.dumps(prompt)
466
+
467
+ metadata = None
468
+ if not args.disable_server_info:
469
+ metadata = {"prompt": prompt_info}
470
+ if extra_pnginfo is not None:
471
+ for x in extra_pnginfo:
472
+ metadata[x] = json.dumps(extra_pnginfo[x])
473
+
474
+ file = f"{filename}_{counter:05}_.latent"
475
+
476
+ results = list()
477
+ results.append({
478
+ "filename": file,
479
+ "subfolder": subfolder,
480
+ "type": "output"
481
+ })
482
+
483
+ file = os.path.join(full_output_folder, file)
484
+
485
+ output = {}
486
+ output["latent_tensor"] = samples["samples"]
487
+ output["latent_format_version_0"] = torch.tensor([])
488
+
489
+ ldm_patched.modules.utils.save_torch_file(output, file, metadata=metadata)
490
+ return { "ui": { "latents": results } }
491
+
492
+
493
+ class LoadLatent:
494
+ @classmethod
495
+ def INPUT_TYPES(s):
496
+ input_dir = ldm_patched.utils.path_utils.get_input_directory()
497
+ files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
498
+ return {"required": {"latent": [sorted(files), ]}, }
499
+
500
+ CATEGORY = "_for_testing"
501
+
502
+ RETURN_TYPES = ("LATENT", )
503
+ FUNCTION = "load"
504
+
505
+ def load(self, latent):
506
+ latent_path = ldm_patched.utils.path_utils.get_annotated_filepath(latent)
507
+ latent = safetensors.torch.load_file(latent_path, device="cpu")
508
+ multiplier = 1.0
509
+ if "latent_format_version_0" not in latent:
510
+ multiplier = 1.0 / 0.18215
511
+ samples = {"samples": latent["latent_tensor"].float() * multiplier}
512
+ return (samples, )
513
+
514
+ @classmethod
515
+ def IS_CHANGED(s, latent):
516
+ image_path = ldm_patched.utils.path_utils.get_annotated_filepath(latent)
517
+ m = hashlib.sha256()
518
+ with open(image_path, 'rb') as f:
519
+ m.update(f.read())
520
+ return m.digest().hex()
521
+
522
+ @classmethod
523
+ def VALIDATE_INPUTS(s, latent):
524
+ if not ldm_patched.utils.path_utils.exists_annotated_filepath(latent):
525
+ return "Invalid latent file: {}".format(latent)
526
+ return True
527
+
528
+
529
+ class CheckpointLoader:
530
+ @classmethod
531
+ def INPUT_TYPES(s):
532
+ return {"required": { "config_name": (ldm_patched.utils.path_utils.get_filename_list("configs"), ),
533
+ "ckpt_name": (ldm_patched.utils.path_utils.get_filename_list("checkpoints"), )}}
534
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE")
535
+ FUNCTION = "load_checkpoint"
536
+
537
+ CATEGORY = "advanced/loaders"
538
+
539
+ def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
540
+ config_path = ldm_patched.utils.path_utils.get_full_path("configs", config_name)
541
+ ckpt_path = ldm_patched.utils.path_utils.get_full_path("checkpoints", ckpt_name)
542
+ return ldm_patched.modules.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=ldm_patched.utils.path_utils.get_folder_paths("embeddings"))
543
+
544
+ class CheckpointLoaderSimple:
545
+ @classmethod
546
+ def INPUT_TYPES(s):
547
+ return {"required": { "ckpt_name": (ldm_patched.utils.path_utils.get_filename_list("checkpoints"), ),
548
+ }}
549
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE")
550
+ FUNCTION = "load_checkpoint"
551
+
552
+ CATEGORY = "loaders"
553
+
554
+ def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
555
+ ckpt_path = ldm_patched.utils.path_utils.get_full_path("checkpoints", ckpt_name)
556
+ out = ldm_patched.modules.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=ldm_patched.utils.path_utils.get_folder_paths("embeddings"))
557
+ return out[:3]
558
+
559
+ class DiffusersLoader:
560
+ @classmethod
561
+ def INPUT_TYPES(cls):
562
+ paths = []
563
+ for search_path in ldm_patched.utils.path_utils.get_folder_paths("diffusers"):
564
+ if os.path.exists(search_path):
565
+ for root, subdir, files in os.walk(search_path, followlinks=True):
566
+ if "model_index.json" in files:
567
+ paths.append(os.path.relpath(root, start=search_path))
568
+
569
+ return {"required": {"model_path": (paths,), }}
570
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE")
571
+ FUNCTION = "load_checkpoint"
572
+
573
+ CATEGORY = "advanced/loaders/deprecated"
574
+
575
+ def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
576
+ for search_path in ldm_patched.utils.path_utils.get_folder_paths("diffusers"):
577
+ if os.path.exists(search_path):
578
+ path = os.path.join(search_path, model_path)
579
+ if os.path.exists(path):
580
+ model_path = path
581
+ break
582
+
583
+ return ldm_patched.modules.diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=ldm_patched.utils.path_utils.get_folder_paths("embeddings"))
584
+
585
+
586
+ class unCLIPCheckpointLoader:
587
+ @classmethod
588
+ def INPUT_TYPES(s):
589
+ return {"required": { "ckpt_name": (ldm_patched.utils.path_utils.get_filename_list("checkpoints"), ),
590
+ }}
591
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION")
592
+ FUNCTION = "load_checkpoint"
593
+
594
+ CATEGORY = "loaders"
595
+
596
+ def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
597
+ ckpt_path = ldm_patched.utils.path_utils.get_full_path("checkpoints", ckpt_name)
598
+ out = ldm_patched.modules.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=ldm_patched.utils.path_utils.get_folder_paths("embeddings"))
599
+ return out
600
+
601
+ class CLIPSetLastLayer:
602
+ @classmethod
603
+ def INPUT_TYPES(s):
604
+ return {"required": { "clip": ("CLIP", ),
605
+ "stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}),
606
+ }}
607
+ RETURN_TYPES = ("CLIP",)
608
+ FUNCTION = "set_last_layer"
609
+
610
+ CATEGORY = "conditioning"
611
+
612
+ def set_last_layer(self, clip, stop_at_clip_layer):
613
+ clip = clip.clone()
614
+ clip.clip_layer(stop_at_clip_layer)
615
+ return (clip,)
616
+
617
+ class LoraLoader:
618
+ def __init__(self):
619
+ self.loaded_lora = None
620
+
621
+ @classmethod
622
+ def INPUT_TYPES(s):
623
+ return {"required": { "model": ("MODEL",),
624
+ "clip": ("CLIP", ),
625
+ "lora_name": (ldm_patched.utils.path_utils.get_filename_list("loras"), ),
626
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
627
+ "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
628
+ }}
629
+ RETURN_TYPES = ("MODEL", "CLIP")
630
+ FUNCTION = "load_lora"
631
+
632
+ CATEGORY = "loaders"
633
+
634
+ def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
635
+ if strength_model == 0 and strength_clip == 0:
636
+ return (model, clip)
637
+
638
+ lora_path = ldm_patched.utils.path_utils.get_full_path("loras", lora_name)
639
+ lora = None
640
+ if self.loaded_lora is not None:
641
+ if self.loaded_lora[0] == lora_path:
642
+ lora = self.loaded_lora[1]
643
+ else:
644
+ temp = self.loaded_lora
645
+ self.loaded_lora = None
646
+ del temp
647
+
648
+ if lora is None:
649
+ lora = ldm_patched.modules.utils.load_torch_file(lora_path, safe_load=True)
650
+ self.loaded_lora = (lora_path, lora)
651
+
652
+ model_lora, clip_lora = ldm_patched.modules.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
653
+ return (model_lora, clip_lora)
654
+
655
+ class LoraLoaderModelOnly(LoraLoader):
656
+ @classmethod
657
+ def INPUT_TYPES(s):
658
+ return {"required": { "model": ("MODEL",),
659
+ "lora_name": (ldm_patched.utils.path_utils.get_filename_list("loras"), ),
660
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
661
+ }}
662
+ RETURN_TYPES = ("MODEL",)
663
+ FUNCTION = "load_lora_model_only"
664
+
665
+ def load_lora_model_only(self, model, lora_name, strength_model):
666
+ return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
667
+
668
+ class VAELoader:
669
+ @staticmethod
670
+ def vae_list():
671
+ vaes = ldm_patched.utils.path_utils.get_filename_list("vae")
672
+ approx_vaes = ldm_patched.utils.path_utils.get_filename_list("vae_approx")
673
+ sdxl_taesd_enc = False
674
+ sdxl_taesd_dec = False
675
+ sd1_taesd_enc = False
676
+ sd1_taesd_dec = False
677
+
678
+ for v in approx_vaes:
679
+ if v.startswith("taesd_decoder."):
680
+ sd1_taesd_dec = True
681
+ elif v.startswith("taesd_encoder."):
682
+ sd1_taesd_enc = True
683
+ elif v.startswith("taesdxl_decoder."):
684
+ sdxl_taesd_dec = True
685
+ elif v.startswith("taesdxl_encoder."):
686
+ sdxl_taesd_enc = True
687
+ if sd1_taesd_dec and sd1_taesd_enc:
688
+ vaes.append("taesd")
689
+ if sdxl_taesd_dec and sdxl_taesd_enc:
690
+ vaes.append("taesdxl")
691
+ return vaes
692
+
693
+ @staticmethod
694
+ def load_taesd(name):
695
+ sd = {}
696
+ approx_vaes = ldm_patched.utils.path_utils.get_filename_list("vae_approx")
697
+
698
+ encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
699
+ decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
700
+
701
+ enc = ldm_patched.modules.utils.load_torch_file(ldm_patched.utils.path_utils.get_full_path("vae_approx", encoder))
702
+ for k in enc:
703
+ sd["taesd_encoder.{}".format(k)] = enc[k]
704
+
705
+ dec = ldm_patched.modules.utils.load_torch_file(ldm_patched.utils.path_utils.get_full_path("vae_approx", decoder))
706
+ for k in dec:
707
+ sd["taesd_decoder.{}".format(k)] = dec[k]
708
+
709
+ if name == "taesd":
710
+ sd["vae_scale"] = torch.tensor(0.18215)
711
+ elif name == "taesdxl":
712
+ sd["vae_scale"] = torch.tensor(0.13025)
713
+ return sd
714
+
715
+ @classmethod
716
+ def INPUT_TYPES(s):
717
+ return {"required": { "vae_name": (s.vae_list(), )}}
718
+ RETURN_TYPES = ("VAE",)
719
+ FUNCTION = "load_vae"
720
+
721
+ CATEGORY = "loaders"
722
+
723
+ #TODO: scale factor?
724
+ def load_vae(self, vae_name):
725
+ if vae_name in ["taesd", "taesdxl"]:
726
+ sd = self.load_taesd(vae_name)
727
+ else:
728
+ vae_path = ldm_patched.utils.path_utils.get_full_path("vae", vae_name)
729
+ sd = ldm_patched.modules.utils.load_torch_file(vae_path)
730
+ vae = ldm_patched.modules.sd.VAE(sd=sd)
731
+ return (vae,)
732
+
733
+ class ControlNetLoader:
734
+ @classmethod
735
+ def INPUT_TYPES(s):
736
+ return {"required": { "control_net_name": (ldm_patched.utils.path_utils.get_filename_list("controlnet"), )}}
737
+
738
+ RETURN_TYPES = ("CONTROL_NET",)
739
+ FUNCTION = "load_controlnet"
740
+
741
+ CATEGORY = "loaders"
742
+
743
+ def load_controlnet(self, control_net_name):
744
+ controlnet_path = ldm_patched.utils.path_utils.get_full_path("controlnet", control_net_name)
745
+ controlnet = ldm_patched.modules.controlnet.load_controlnet(controlnet_path)
746
+ return (controlnet,)
747
+
748
+ class DiffControlNetLoader:
749
+ @classmethod
750
+ def INPUT_TYPES(s):
751
+ return {"required": { "model": ("MODEL",),
752
+ "control_net_name": (ldm_patched.utils.path_utils.get_filename_list("controlnet"), )}}
753
+
754
+ RETURN_TYPES = ("CONTROL_NET",)
755
+ FUNCTION = "load_controlnet"
756
+
757
+ CATEGORY = "loaders"
758
+
759
+ def load_controlnet(self, model, control_net_name):
760
+ controlnet_path = ldm_patched.utils.path_utils.get_full_path("controlnet", control_net_name)
761
+ controlnet = ldm_patched.modules.controlnet.load_controlnet(controlnet_path, model)
762
+ return (controlnet,)
763
+
764
+
765
+ class ControlNetApply:
766
+ @classmethod
767
+ def INPUT_TYPES(s):
768
+ return {"required": {"conditioning": ("CONDITIONING", ),
769
+ "control_net": ("CONTROL_NET", ),
770
+ "image": ("IMAGE", ),
771
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01})
772
+ }}
773
+ RETURN_TYPES = ("CONDITIONING",)
774
+ FUNCTION = "apply_controlnet"
775
+
776
+ CATEGORY = "conditioning"
777
+
778
+ def apply_controlnet(self, conditioning, control_net, image, strength):
779
+ if strength == 0:
780
+ return (conditioning, )
781
+
782
+ c = []
783
+ control_hint = image.movedim(-1,1)
784
+ for t in conditioning:
785
+ n = [t[0], t[1].copy()]
786
+ c_net = control_net.copy().set_cond_hint(control_hint, strength)
787
+ if 'control' in t[1]:
788
+ c_net.set_previous_controlnet(t[1]['control'])
789
+ n[1]['control'] = c_net
790
+ n[1]['control_apply_to_uncond'] = True
791
+ c.append(n)
792
+ return (c, )
793
+
794
+
795
+ class ControlNetApplyAdvanced:
796
+ @classmethod
797
+ def INPUT_TYPES(s):
798
+ return {"required": {"positive": ("CONDITIONING", ),
799
+ "negative": ("CONDITIONING", ),
800
+ "control_net": ("CONTROL_NET", ),
801
+ "image": ("IMAGE", ),
802
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
803
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
804
+ "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
805
+ }}
806
+
807
+ RETURN_TYPES = ("CONDITIONING","CONDITIONING")
808
+ RETURN_NAMES = ("positive", "negative")
809
+ FUNCTION = "apply_controlnet"
810
+
811
+ CATEGORY = "conditioning"
812
+
813
+ def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent):
814
+ if strength == 0:
815
+ return (positive, negative)
816
+
817
+ control_hint = image.movedim(-1,1)
818
+ cnets = {}
819
+
820
+ out = []
821
+ for conditioning in [positive, negative]:
822
+ c = []
823
+ for t in conditioning:
824
+ d = t[1].copy()
825
+
826
+ prev_cnet = d.get('control', None)
827
+ if prev_cnet in cnets:
828
+ c_net = cnets[prev_cnet]
829
+ else:
830
+ c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent))
831
+ c_net.set_previous_controlnet(prev_cnet)
832
+ cnets[prev_cnet] = c_net
833
+
834
+ d['control'] = c_net
835
+ d['control_apply_to_uncond'] = False
836
+ n = [t[0], d]
837
+ c.append(n)
838
+ out.append(c)
839
+ return (out[0], out[1])
840
+
841
+
842
+ class UNETLoader:
843
+ @classmethod
844
+ def INPUT_TYPES(s):
845
+ return {"required": { "unet_name": (ldm_patched.utils.path_utils.get_filename_list("unet"), ),
846
+ }}
847
+ RETURN_TYPES = ("MODEL",)
848
+ FUNCTION = "load_unet"
849
+
850
+ CATEGORY = "advanced/loaders"
851
+
852
+ def load_unet(self, unet_name):
853
+ unet_path = ldm_patched.utils.path_utils.get_full_path("unet", unet_name)
854
+ model = ldm_patched.modules.sd.load_unet(unet_path)
855
+ return (model,)
856
+
857
+ class CLIPLoader:
858
+ @classmethod
859
+ def INPUT_TYPES(s):
860
+ return {"required": { "clip_name": (ldm_patched.utils.path_utils.get_filename_list("clip"), ),
861
+ }}
862
+ RETURN_TYPES = ("CLIP",)
863
+ FUNCTION = "load_clip"
864
+
865
+ CATEGORY = "advanced/loaders"
866
+
867
+ def load_clip(self, clip_name):
868
+ clip_path = ldm_patched.utils.path_utils.get_full_path("clip", clip_name)
869
+ clip = ldm_patched.modules.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=ldm_patched.utils.path_utils.get_folder_paths("embeddings"))
870
+ return (clip,)
871
+
872
+ class DualCLIPLoader:
873
+ @classmethod
874
+ def INPUT_TYPES(s):
875
+ return {"required": { "clip_name1": (ldm_patched.utils.path_utils.get_filename_list("clip"), ), "clip_name2": (ldm_patched.utils.path_utils.get_filename_list("clip"), ),
876
+ }}
877
+ RETURN_TYPES = ("CLIP",)
878
+ FUNCTION = "load_clip"
879
+
880
+ CATEGORY = "advanced/loaders"
881
+
882
+ def load_clip(self, clip_name1, clip_name2):
883
+ clip_path1 = ldm_patched.utils.path_utils.get_full_path("clip", clip_name1)
884
+ clip_path2 = ldm_patched.utils.path_utils.get_full_path("clip", clip_name2)
885
+ clip = ldm_patched.modules.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=ldm_patched.utils.path_utils.get_folder_paths("embeddings"))
886
+ return (clip,)
887
+
888
+ class CLIPVisionLoader:
889
+ @classmethod
890
+ def INPUT_TYPES(s):
891
+ return {"required": { "clip_name": (ldm_patched.utils.path_utils.get_filename_list("clip_vision"), ),
892
+ }}
893
+ RETURN_TYPES = ("CLIP_VISION",)
894
+ FUNCTION = "load_clip"
895
+
896
+ CATEGORY = "loaders"
897
+
898
+ def load_clip(self, clip_name):
899
+ clip_path = ldm_patched.utils.path_utils.get_full_path("clip_vision", clip_name)
900
+ clip_vision = ldm_patched.modules.clip_vision.load(clip_path)
901
+ return (clip_vision,)
902
+
903
+ class CLIPVisionEncode:
904
+ @classmethod
905
+ def INPUT_TYPES(s):
906
+ return {"required": { "clip_vision": ("CLIP_VISION",),
907
+ "image": ("IMAGE",)
908
+ }}
909
+ RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
910
+ FUNCTION = "encode"
911
+
912
+ CATEGORY = "conditioning"
913
+
914
+ def encode(self, clip_vision, image):
915
+ output = clip_vision.encode_image(image)
916
+ return (output,)
917
+
918
+ class StyleModelLoader:
919
+ @classmethod
920
+ def INPUT_TYPES(s):
921
+ return {"required": { "style_model_name": (ldm_patched.utils.path_utils.get_filename_list("style_models"), )}}
922
+
923
+ RETURN_TYPES = ("STYLE_MODEL",)
924
+ FUNCTION = "load_style_model"
925
+
926
+ CATEGORY = "loaders"
927
+
928
+ def load_style_model(self, style_model_name):
929
+ style_model_path = ldm_patched.utils.path_utils.get_full_path("style_models", style_model_name)
930
+ style_model = ldm_patched.modules.sd.load_style_model(style_model_path)
931
+ return (style_model,)
932
+
933
+
934
+ class StyleModelApply:
935
+ @classmethod
936
+ def INPUT_TYPES(s):
937
+ return {"required": {"conditioning": ("CONDITIONING", ),
938
+ "style_model": ("STYLE_MODEL", ),
939
+ "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
940
+ }}
941
+ RETURN_TYPES = ("CONDITIONING",)
942
+ FUNCTION = "apply_stylemodel"
943
+
944
+ CATEGORY = "conditioning/style_model"
945
+
946
+ def apply_stylemodel(self, clip_vision_output, style_model, conditioning):
947
+ cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
948
+ c = []
949
+ for t in conditioning:
950
+ n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
951
+ c.append(n)
952
+ return (c, )
953
+
954
+ class unCLIPConditioning:
955
+ @classmethod
956
+ def INPUT_TYPES(s):
957
+ return {"required": {"conditioning": ("CONDITIONING", ),
958
+ "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
959
+ "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
960
+ "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
961
+ }}
962
+ RETURN_TYPES = ("CONDITIONING",)
963
+ FUNCTION = "apply_adm"
964
+
965
+ CATEGORY = "conditioning"
966
+
967
+ def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
968
+ if strength == 0:
969
+ return (conditioning, )
970
+
971
+ c = []
972
+ for t in conditioning:
973
+ o = t[1].copy()
974
+ x = {"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation}
975
+ if "unclip_conditioning" in o:
976
+ o["unclip_conditioning"] = o["unclip_conditioning"][:] + [x]
977
+ else:
978
+ o["unclip_conditioning"] = [x]
979
+ n = [t[0], o]
980
+ c.append(n)
981
+ return (c, )
982
+
983
+ class GLIGENLoader:
984
+ @classmethod
985
+ def INPUT_TYPES(s):
986
+ return {"required": { "gligen_name": (ldm_patched.utils.path_utils.get_filename_list("gligen"), )}}
987
+
988
+ RETURN_TYPES = ("GLIGEN",)
989
+ FUNCTION = "load_gligen"
990
+
991
+ CATEGORY = "loaders"
992
+
993
+ def load_gligen(self, gligen_name):
994
+ gligen_path = ldm_patched.utils.path_utils.get_full_path("gligen", gligen_name)
995
+ gligen = ldm_patched.modules.sd.load_gligen(gligen_path)
996
+ return (gligen,)
997
+
998
+ class GLIGENTextBoxApply:
999
+ @classmethod
1000
+ def INPUT_TYPES(s):
1001
+ return {"required": {"conditioning_to": ("CONDITIONING", ),
1002
+ "clip": ("CLIP", ),
1003
+ "gligen_textbox_model": ("GLIGEN", ),
1004
+ "text": ("STRING", {"multiline": True}),
1005
+ "width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
1006
+ "height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
1007
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1008
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1009
+ }}
1010
+ RETURN_TYPES = ("CONDITIONING",)
1011
+ FUNCTION = "append"
1012
+
1013
+ CATEGORY = "conditioning/gligen"
1014
+
1015
+ def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
1016
+ c = []
1017
+ cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True)
1018
+ for t in conditioning_to:
1019
+ n = [t[0], t[1].copy()]
1020
+ position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]
1021
+ prev = []
1022
+ if "gligen" in n[1]:
1023
+ prev = n[1]['gligen'][2]
1024
+
1025
+ n[1]['gligen'] = ("position", gligen_textbox_model, prev + position_params)
1026
+ c.append(n)
1027
+ return (c, )
1028
+
1029
+ class EmptyLatentImage:
1030
+ def __init__(self):
1031
+ self.device = ldm_patched.modules.model_management.intermediate_device()
1032
+
1033
+ @classmethod
1034
+ def INPUT_TYPES(s):
1035
+ return {"required": { "width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
1036
+ "height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
1037
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
1038
+ RETURN_TYPES = ("LATENT",)
1039
+ FUNCTION = "generate"
1040
+
1041
+ CATEGORY = "latent"
1042
+
1043
+ def generate(self, width, height, batch_size=1):
1044
+ latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
1045
+ return ({"samples":latent}, )
1046
+
1047
+
1048
+ class LatentFromBatch:
1049
+ @classmethod
1050
+ def INPUT_TYPES(s):
1051
+ return {"required": { "samples": ("LATENT",),
1052
+ "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
1053
+ "length": ("INT", {"default": 1, "min": 1, "max": 64}),
1054
+ }}
1055
+ RETURN_TYPES = ("LATENT",)
1056
+ FUNCTION = "frombatch"
1057
+
1058
+ CATEGORY = "latent/batch"
1059
+
1060
+ def frombatch(self, samples, batch_index, length):
1061
+ s = samples.copy()
1062
+ s_in = samples["samples"]
1063
+ batch_index = min(s_in.shape[0] - 1, batch_index)
1064
+ length = min(s_in.shape[0] - batch_index, length)
1065
+ s["samples"] = s_in[batch_index:batch_index + length].clone()
1066
+ if "noise_mask" in samples:
1067
+ masks = samples["noise_mask"]
1068
+ if masks.shape[0] == 1:
1069
+ s["noise_mask"] = masks.clone()
1070
+ else:
1071
+ if masks.shape[0] < s_in.shape[0]:
1072
+ masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
1073
+ s["noise_mask"] = masks[batch_index:batch_index + length].clone()
1074
+ if "batch_index" not in s:
1075
+ s["batch_index"] = [x for x in range(batch_index, batch_index+length)]
1076
+ else:
1077
+ s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
1078
+ return (s,)
1079
+
1080
+ class RepeatLatentBatch:
1081
+ @classmethod
1082
+ def INPUT_TYPES(s):
1083
+ return {"required": { "samples": ("LATENT",),
1084
+ "amount": ("INT", {"default": 1, "min": 1, "max": 64}),
1085
+ }}
1086
+ RETURN_TYPES = ("LATENT",)
1087
+ FUNCTION = "repeat"
1088
+
1089
+ CATEGORY = "latent/batch"
1090
+
1091
+ def repeat(self, samples, amount):
1092
+ s = samples.copy()
1093
+ s_in = samples["samples"]
1094
+
1095
+ s["samples"] = s_in.repeat((amount, 1,1,1))
1096
+ if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
1097
+ masks = samples["noise_mask"]
1098
+ if masks.shape[0] < s_in.shape[0]:
1099
+ masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
1100
+ s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1))
1101
+ if "batch_index" in s:
1102
+ offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
1103
+ s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
1104
+ return (s,)
1105
+
1106
+ class LatentUpscale:
1107
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
1108
+ crop_methods = ["disabled", "center"]
1109
+
1110
+ @classmethod
1111
+ def INPUT_TYPES(s):
1112
+ return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
1113
+ "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1114
+ "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1115
+ "crop": (s.crop_methods,)}}
1116
+ RETURN_TYPES = ("LATENT",)
1117
+ FUNCTION = "upscale"
1118
+
1119
+ CATEGORY = "latent"
1120
+
1121
+ def upscale(self, samples, upscale_method, width, height, crop):
1122
+ if width == 0 and height == 0:
1123
+ s = samples
1124
+ else:
1125
+ s = samples.copy()
1126
+
1127
+ if width == 0:
1128
+ height = max(64, height)
1129
+ width = max(64, round(samples["samples"].shape[3] * height / samples["samples"].shape[2]))
1130
+ elif height == 0:
1131
+ width = max(64, width)
1132
+ height = max(64, round(samples["samples"].shape[2] * width / samples["samples"].shape[3]))
1133
+ else:
1134
+ width = max(64, width)
1135
+ height = max(64, height)
1136
+
1137
+ s["samples"] = ldm_patched.modules.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
1138
+ return (s,)
1139
+
1140
+ class LatentUpscaleBy:
1141
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
1142
+
1143
+ @classmethod
1144
+ def INPUT_TYPES(s):
1145
+ return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
1146
+ "scale_by": ("FLOAT", {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}),}}
1147
+ RETURN_TYPES = ("LATENT",)
1148
+ FUNCTION = "upscale"
1149
+
1150
+ CATEGORY = "latent"
1151
+
1152
+ def upscale(self, samples, upscale_method, scale_by):
1153
+ s = samples.copy()
1154
+ width = round(samples["samples"].shape[3] * scale_by)
1155
+ height = round(samples["samples"].shape[2] * scale_by)
1156
+ s["samples"] = ldm_patched.modules.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled")
1157
+ return (s,)
1158
+
1159
+ class LatentRotate:
1160
+ @classmethod
1161
+ def INPUT_TYPES(s):
1162
+ return {"required": { "samples": ("LATENT",),
1163
+ "rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],),
1164
+ }}
1165
+ RETURN_TYPES = ("LATENT",)
1166
+ FUNCTION = "rotate"
1167
+
1168
+ CATEGORY = "latent/transform"
1169
+
1170
+ def rotate(self, samples, rotation):
1171
+ s = samples.copy()
1172
+ rotate_by = 0
1173
+ if rotation.startswith("90"):
1174
+ rotate_by = 1
1175
+ elif rotation.startswith("180"):
1176
+ rotate_by = 2
1177
+ elif rotation.startswith("270"):
1178
+ rotate_by = 3
1179
+
1180
+ s["samples"] = torch.rot90(samples["samples"], k=rotate_by, dims=[3, 2])
1181
+ return (s,)
1182
+
1183
+ class LatentFlip:
1184
+ @classmethod
1185
+ def INPUT_TYPES(s):
1186
+ return {"required": { "samples": ("LATENT",),
1187
+ "flip_method": (["x-axis: vertically", "y-axis: horizontally"],),
1188
+ }}
1189
+ RETURN_TYPES = ("LATENT",)
1190
+ FUNCTION = "flip"
1191
+
1192
+ CATEGORY = "latent/transform"
1193
+
1194
+ def flip(self, samples, flip_method):
1195
+ s = samples.copy()
1196
+ if flip_method.startswith("x"):
1197
+ s["samples"] = torch.flip(samples["samples"], dims=[2])
1198
+ elif flip_method.startswith("y"):
1199
+ s["samples"] = torch.flip(samples["samples"], dims=[3])
1200
+
1201
+ return (s,)
1202
+
1203
+ class LatentComposite:
1204
+ @classmethod
1205
+ def INPUT_TYPES(s):
1206
+ return {"required": { "samples_to": ("LATENT",),
1207
+ "samples_from": ("LATENT",),
1208
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1209
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1210
+ "feather": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1211
+ }}
1212
+ RETURN_TYPES = ("LATENT",)
1213
+ FUNCTION = "composite"
1214
+
1215
+ CATEGORY = "latent"
1216
+
1217
+ def composite(self, samples_to, samples_from, x, y, composite_method="normal", feather=0):
1218
+ x = x // 8
1219
+ y = y // 8
1220
+ feather = feather // 8
1221
+ samples_out = samples_to.copy()
1222
+ s = samples_to["samples"].clone()
1223
+ samples_to = samples_to["samples"]
1224
+ samples_from = samples_from["samples"]
1225
+ if feather == 0:
1226
+ s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
1227
+ else:
1228
+ samples_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
1229
+ mask = torch.ones_like(samples_from)
1230
+ for t in range(feather):
1231
+ if y != 0:
1232
+ mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
1233
+
1234
+ if y + samples_from.shape[2] < samples_to.shape[2]:
1235
+ mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1))
1236
+ if x != 0:
1237
+ mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1))
1238
+ if x + samples_from.shape[3] < samples_to.shape[3]:
1239
+ mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
1240
+ rev_mask = torch.ones_like(mask) - mask
1241
+ s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask
1242
+ samples_out["samples"] = s
1243
+ return (samples_out,)
1244
+
1245
+ class LatentBlend:
1246
+ @classmethod
1247
+ def INPUT_TYPES(s):
1248
+ return {"required": {
1249
+ "samples1": ("LATENT",),
1250
+ "samples2": ("LATENT",),
1251
+ "blend_factor": ("FLOAT", {
1252
+ "default": 0.5,
1253
+ "min": 0,
1254
+ "max": 1,
1255
+ "step": 0.01
1256
+ }),
1257
+ }}
1258
+
1259
+ RETURN_TYPES = ("LATENT",)
1260
+ FUNCTION = "blend"
1261
+
1262
+ CATEGORY = "_for_testing"
1263
+
1264
+ def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"):
1265
+
1266
+ samples_out = samples1.copy()
1267
+ samples1 = samples1["samples"]
1268
+ samples2 = samples2["samples"]
1269
+
1270
+ if samples1.shape != samples2.shape:
1271
+ samples2.permute(0, 3, 1, 2)
1272
+ samples2 = ldm_patched.modules.utils.common_upscale(samples2, samples1.shape[3], samples1.shape[2], 'bicubic', crop='center')
1273
+ samples2.permute(0, 2, 3, 1)
1274
+
1275
+ samples_blended = self.blend_mode(samples1, samples2, blend_mode)
1276
+ samples_blended = samples1 * blend_factor + samples_blended * (1 - blend_factor)
1277
+ samples_out["samples"] = samples_blended
1278
+ return (samples_out,)
1279
+
1280
+ def blend_mode(self, img1, img2, mode):
1281
+ if mode == "normal":
1282
+ return img2
1283
+ else:
1284
+ raise ValueError(f"Unsupported blend mode: {mode}")
1285
+
1286
+ class LatentCrop:
1287
+ @classmethod
1288
+ def INPUT_TYPES(s):
1289
+ return {"required": { "samples": ("LATENT",),
1290
+ "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
1291
+ "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
1292
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1293
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1294
+ }}
1295
+ RETURN_TYPES = ("LATENT",)
1296
+ FUNCTION = "crop"
1297
+
1298
+ CATEGORY = "latent/transform"
1299
+
1300
+ def crop(self, samples, width, height, x, y):
1301
+ s = samples.copy()
1302
+ samples = samples['samples']
1303
+ x = x // 8
1304
+ y = y // 8
1305
+
1306
+ #enfonce minimum size of 64
1307
+ if x > (samples.shape[3] - 8):
1308
+ x = samples.shape[3] - 8
1309
+ if y > (samples.shape[2] - 8):
1310
+ y = samples.shape[2] - 8
1311
+
1312
+ new_height = height // 8
1313
+ new_width = width // 8
1314
+ to_x = new_width + x
1315
+ to_y = new_height + y
1316
+ s['samples'] = samples[:,:,y:to_y, x:to_x]
1317
+ return (s,)
1318
+
1319
+ class SetLatentNoiseMask:
1320
+ @classmethod
1321
+ def INPUT_TYPES(s):
1322
+ return {"required": { "samples": ("LATENT",),
1323
+ "mask": ("MASK",),
1324
+ }}
1325
+ RETURN_TYPES = ("LATENT",)
1326
+ FUNCTION = "set_mask"
1327
+
1328
+ CATEGORY = "latent/inpaint"
1329
+
1330
+ def set_mask(self, samples, mask):
1331
+ s = samples.copy()
1332
+ s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
1333
+ return (s,)
1334
+
1335
+ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
1336
+ latent_image = latent["samples"]
1337
+ if disable_noise:
1338
+ noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
1339
+ else:
1340
+ batch_inds = latent["batch_index"] if "batch_index" in latent else None
1341
+ noise = ldm_patched.modules.sample.prepare_noise(latent_image, seed, batch_inds)
1342
+
1343
+ noise_mask = None
1344
+ if "noise_mask" in latent:
1345
+ noise_mask = latent["noise_mask"]
1346
+
1347
+ callback = ldm_patched.utils.latent_visualization.prepare_callback(model, steps)
1348
+ disable_pbar = not ldm_patched.modules.utils.PROGRESS_BAR_ENABLED
1349
+ samples = ldm_patched.modules.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
1350
+ denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
1351
+ force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
1352
+ out = latent.copy()
1353
+ out["samples"] = samples
1354
+ return (out, )
1355
+
1356
+ class KSampler:
1357
+ @classmethod
1358
+ def INPUT_TYPES(s):
1359
+ return {"required":
1360
+ {"model": ("MODEL",),
1361
+ "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
1362
+ "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
1363
+ "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
1364
+ "sampler_name": (ldm_patched.modules.samplers.KSampler.SAMPLERS, ),
1365
+ "scheduler": (ldm_patched.modules.samplers.KSampler.SCHEDULERS, ),
1366
+ "positive": ("CONDITIONING", ),
1367
+ "negative": ("CONDITIONING", ),
1368
+ "latent_image": ("LATENT", ),
1369
+ "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
1370
+ }
1371
+ }
1372
+
1373
+ RETURN_TYPES = ("LATENT",)
1374
+ FUNCTION = "sample"
1375
+
1376
+ CATEGORY = "sampling"
1377
+
1378
+ def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
1379
+ return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
1380
+
1381
+ class KSamplerAdvanced:
1382
+ @classmethod
1383
+ def INPUT_TYPES(s):
1384
+ return {"required":
1385
+ {"model": ("MODEL",),
1386
+ "add_noise": (["enable", "disable"], ),
1387
+ "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
1388
+ "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
1389
+ "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
1390
+ "sampler_name": (ldm_patched.modules.samplers.KSampler.SAMPLERS, ),
1391
+ "scheduler": (ldm_patched.modules.samplers.KSampler.SCHEDULERS, ),
1392
+ "positive": ("CONDITIONING", ),
1393
+ "negative": ("CONDITIONING", ),
1394
+ "latent_image": ("LATENT", ),
1395
+ "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
1396
+ "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
1397
+ "return_with_leftover_noise": (["disable", "enable"], ),
1398
+ }
1399
+ }
1400
+
1401
+ RETURN_TYPES = ("LATENT",)
1402
+ FUNCTION = "sample"
1403
+
1404
+ CATEGORY = "sampling"
1405
+
1406
+ def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0):
1407
+ force_full_denoise = True
1408
+ if return_with_leftover_noise == "enable":
1409
+ force_full_denoise = False
1410
+ disable_noise = False
1411
+ if add_noise == "disable":
1412
+ disable_noise = True
1413
+ return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)
1414
+
1415
+ class SaveImage:
1416
+ def __init__(self):
1417
+ self.output_dir = ldm_patched.utils.path_utils.get_output_directory()
1418
+ self.type = "output"
1419
+ self.prefix_append = ""
1420
+ self.compress_level = 4
1421
+
1422
+ @classmethod
1423
+ def INPUT_TYPES(s):
1424
+ return {"required":
1425
+ {"images": ("IMAGE", ),
1426
+ "filename_prefix": ("STRING", {"default": "ldm_patched"})},
1427
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
1428
+ }
1429
+
1430
+ RETURN_TYPES = ()
1431
+ FUNCTION = "save_images"
1432
+
1433
+ OUTPUT_NODE = True
1434
+
1435
+ CATEGORY = "image"
1436
+
1437
+ def save_images(self, images, filename_prefix="ldm_patched", prompt=None, extra_pnginfo=None):
1438
+ filename_prefix += self.prefix_append
1439
+ full_output_folder, filename, counter, subfolder, filename_prefix = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
1440
+ results = list()
1441
+ for image in images:
1442
+ i = 255. * image.cpu().numpy()
1443
+ img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
1444
+ metadata = None
1445
+ if not args.disable_server_info:
1446
+ metadata = PngInfo()
1447
+ if prompt is not None:
1448
+ metadata.add_text("prompt", json.dumps(prompt))
1449
+ if extra_pnginfo is not None:
1450
+ for x in extra_pnginfo:
1451
+ metadata.add_text(x, json.dumps(extra_pnginfo[x]))
1452
+
1453
+ file = f"{filename}_{counter:05}_.png"
1454
+ img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
1455
+ results.append({
1456
+ "filename": file,
1457
+ "subfolder": subfolder,
1458
+ "type": self.type
1459
+ })
1460
+ counter += 1
1461
+
1462
+ return { "ui": { "images": results } }
1463
+
1464
+ class PreviewImage(SaveImage):
1465
+ def __init__(self):
1466
+ self.output_dir = ldm_patched.utils.path_utils.get_temp_directory()
1467
+ self.type = "temp"
1468
+ self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
1469
+ self.compress_level = 1
1470
+
1471
+ @classmethod
1472
+ def INPUT_TYPES(s):
1473
+ return {"required":
1474
+ {"images": ("IMAGE", ), },
1475
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
1476
+ }
1477
+
1478
+ class LoadImage:
1479
+ @classmethod
1480
+ def INPUT_TYPES(s):
1481
+ input_dir = ldm_patched.utils.path_utils.get_input_directory()
1482
+ files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
1483
+ return {"required":
1484
+ {"image": (sorted(files), {"image_upload": True})},
1485
+ }
1486
+
1487
+ CATEGORY = "image"
1488
+
1489
+ RETURN_TYPES = ("IMAGE", "MASK")
1490
+ FUNCTION = "load_image"
1491
+ def load_image(self, image):
1492
+ image_path = ldm_patched.utils.path_utils.get_annotated_filepath(image)
1493
+ img = Image.open(image_path)
1494
+ output_images = []
1495
+ output_masks = []
1496
+ for i in ImageSequence.Iterator(img):
1497
+ i = ImageOps.exif_transpose(i)
1498
+ if i.mode == 'I':
1499
+ i = i.point(lambda i: i * (1 / 255))
1500
+ image = i.convert("RGB")
1501
+ image = np.array(image).astype(np.float32) / 255.0
1502
+ image = torch.from_numpy(image)[None,]
1503
+ if 'A' in i.getbands():
1504
+ mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
1505
+ mask = 1. - torch.from_numpy(mask)
1506
+ else:
1507
+ mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
1508
+ output_images.append(image)
1509
+ output_masks.append(mask.unsqueeze(0))
1510
+
1511
+ if len(output_images) > 1:
1512
+ output_image = torch.cat(output_images, dim=0)
1513
+ output_mask = torch.cat(output_masks, dim=0)
1514
+ else:
1515
+ output_image = output_images[0]
1516
+ output_mask = output_masks[0]
1517
+
1518
+ return (output_image, output_mask)
1519
+
1520
+ @classmethod
1521
+ def IS_CHANGED(s, image):
1522
+ image_path = ldm_patched.utils.path_utils.get_annotated_filepath(image)
1523
+ m = hashlib.sha256()
1524
+ with open(image_path, 'rb') as f:
1525
+ m.update(f.read())
1526
+ return m.digest().hex()
1527
+
1528
+ @classmethod
1529
+ def VALIDATE_INPUTS(s, image):
1530
+ if not ldm_patched.utils.path_utils.exists_annotated_filepath(image):
1531
+ return "Invalid image file: {}".format(image)
1532
+
1533
+ return True
1534
+
1535
+ class LoadImageMask:
1536
+ _color_channels = ["alpha", "red", "green", "blue"]
1537
+ @classmethod
1538
+ def INPUT_TYPES(s):
1539
+ input_dir = ldm_patched.utils.path_utils.get_input_directory()
1540
+ files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
1541
+ return {"required":
1542
+ {"image": (sorted(files), {"image_upload": True}),
1543
+ "channel": (s._color_channels, ), }
1544
+ }
1545
+
1546
+ CATEGORY = "mask"
1547
+
1548
+ RETURN_TYPES = ("MASK",)
1549
+ FUNCTION = "load_image"
1550
+ def load_image(self, image, channel):
1551
+ image_path = ldm_patched.utils.path_utils.get_annotated_filepath(image)
1552
+ i = Image.open(image_path)
1553
+ i = ImageOps.exif_transpose(i)
1554
+ if i.getbands() != ("R", "G", "B", "A"):
1555
+ if i.mode == 'I':
1556
+ i = i.point(lambda i: i * (1 / 255))
1557
+ i = i.convert("RGBA")
1558
+ mask = None
1559
+ c = channel[0].upper()
1560
+ if c in i.getbands():
1561
+ mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
1562
+ mask = torch.from_numpy(mask)
1563
+ if c == 'A':
1564
+ mask = 1. - mask
1565
+ else:
1566
+ mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
1567
+ return (mask.unsqueeze(0),)
1568
+
1569
+ @classmethod
1570
+ def IS_CHANGED(s, image, channel):
1571
+ image_path = ldm_patched.utils.path_utils.get_annotated_filepath(image)
1572
+ m = hashlib.sha256()
1573
+ with open(image_path, 'rb') as f:
1574
+ m.update(f.read())
1575
+ return m.digest().hex()
1576
+
1577
+ @classmethod
1578
+ def VALIDATE_INPUTS(s, image):
1579
+ if not ldm_patched.utils.path_utils.exists_annotated_filepath(image):
1580
+ return "Invalid image file: {}".format(image)
1581
+
1582
+ return True
1583
+
1584
+ class ImageScale:
1585
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
1586
+ crop_methods = ["disabled", "center"]
1587
+
1588
+ @classmethod
1589
+ def INPUT_TYPES(s):
1590
+ return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
1591
+ "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
1592
+ "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
1593
+ "crop": (s.crop_methods,)}}
1594
+ RETURN_TYPES = ("IMAGE",)
1595
+ FUNCTION = "upscale"
1596
+
1597
+ CATEGORY = "image/upscaling"
1598
+
1599
+ def upscale(self, image, upscale_method, width, height, crop):
1600
+ if width == 0 and height == 0:
1601
+ s = image
1602
+ else:
1603
+ samples = image.movedim(-1,1)
1604
+
1605
+ if width == 0:
1606
+ width = max(1, round(samples.shape[3] * height / samples.shape[2]))
1607
+ elif height == 0:
1608
+ height = max(1, round(samples.shape[2] * width / samples.shape[3]))
1609
+
1610
+ s = ldm_patched.modules.utils.common_upscale(samples, width, height, upscale_method, crop)
1611
+ s = s.movedim(1,-1)
1612
+ return (s,)
1613
+
1614
+ class ImageScaleBy:
1615
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
1616
+
1617
+ @classmethod
1618
+ def INPUT_TYPES(s):
1619
+ return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
1620
+ "scale_by": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 8.0, "step": 0.01}),}}
1621
+ RETURN_TYPES = ("IMAGE",)
1622
+ FUNCTION = "upscale"
1623
+
1624
+ CATEGORY = "image/upscaling"
1625
+
1626
+ def upscale(self, image, upscale_method, scale_by):
1627
+ samples = image.movedim(-1,1)
1628
+ width = round(samples.shape[3] * scale_by)
1629
+ height = round(samples.shape[2] * scale_by)
1630
+ s = ldm_patched.modules.utils.common_upscale(samples, width, height, upscale_method, "disabled")
1631
+ s = s.movedim(1,-1)
1632
+ return (s,)
1633
+
1634
+ class ImageInvert:
1635
+
1636
+ @classmethod
1637
+ def INPUT_TYPES(s):
1638
+ return {"required": { "image": ("IMAGE",)}}
1639
+
1640
+ RETURN_TYPES = ("IMAGE",)
1641
+ FUNCTION = "invert"
1642
+
1643
+ CATEGORY = "image"
1644
+
1645
+ def invert(self, image):
1646
+ s = 1.0 - image
1647
+ return (s,)
1648
+
1649
+ class ImageBatch:
1650
+
1651
+ @classmethod
1652
+ def INPUT_TYPES(s):
1653
+ return {"required": { "image1": ("IMAGE",), "image2": ("IMAGE",)}}
1654
+
1655
+ RETURN_TYPES = ("IMAGE",)
1656
+ FUNCTION = "batch"
1657
+
1658
+ CATEGORY = "image"
1659
+
1660
+ def batch(self, image1, image2):
1661
+ if image1.shape[1:] != image2.shape[1:]:
1662
+ image2 = ldm_patched.modules.utils.common_upscale(image2.movedim(-1,1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1,-1)
1663
+ s = torch.cat((image1, image2), dim=0)
1664
+ return (s,)
1665
+
1666
+ class EmptyImage:
1667
+ def __init__(self, device="cpu"):
1668
+ self.device = device
1669
+
1670
+ @classmethod
1671
+ def INPUT_TYPES(s):
1672
+ return {"required": { "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
1673
+ "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
1674
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
1675
+ "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
1676
+ }}
1677
+ RETURN_TYPES = ("IMAGE",)
1678
+ FUNCTION = "generate"
1679
+
1680
+ CATEGORY = "image"
1681
+
1682
+ def generate(self, width, height, batch_size=1, color=0):
1683
+ r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF)
1684
+ g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF)
1685
+ b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF)
1686
+ return (torch.cat((r, g, b), dim=-1), )
1687
+
1688
+ class ImagePadForOutpaint:
1689
+
1690
+ @classmethod
1691
+ def INPUT_TYPES(s):
1692
+ return {
1693
+ "required": {
1694
+ "image": ("IMAGE",),
1695
+ "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1696
+ "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1697
+ "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1698
+ "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1699
+ "feathering": ("INT", {"default": 40, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
1700
+ }
1701
+ }
1702
+
1703
+ RETURN_TYPES = ("IMAGE", "MASK")
1704
+ FUNCTION = "expand_image"
1705
+
1706
+ CATEGORY = "image"
1707
+
1708
+ def expand_image(self, image, left, top, right, bottom, feathering):
1709
+ d1, d2, d3, d4 = image.size()
1710
+
1711
+ new_image = torch.ones(
1712
+ (d1, d2 + top + bottom, d3 + left + right, d4),
1713
+ dtype=torch.float32,
1714
+ ) * 0.5
1715
+
1716
+ new_image[:, top:top + d2, left:left + d3, :] = image
1717
+
1718
+ mask = torch.ones(
1719
+ (d2 + top + bottom, d3 + left + right),
1720
+ dtype=torch.float32,
1721
+ )
1722
+
1723
+ t = torch.zeros(
1724
+ (d2, d3),
1725
+ dtype=torch.float32
1726
+ )
1727
+
1728
+ if feathering > 0 and feathering * 2 < d2 and feathering * 2 < d3:
1729
+
1730
+ for i in range(d2):
1731
+ for j in range(d3):
1732
+ dt = i if top != 0 else d2
1733
+ db = d2 - i if bottom != 0 else d2
1734
+
1735
+ dl = j if left != 0 else d3
1736
+ dr = d3 - j if right != 0 else d3
1737
+
1738
+ d = min(dt, db, dl, dr)
1739
+
1740
+ if d >= feathering:
1741
+ continue
1742
+
1743
+ v = (feathering - d) / feathering
1744
+
1745
+ t[i, j] = v * v
1746
+
1747
+ mask[top:top + d2, left:left + d3] = t
1748
+
1749
+ return (new_image, mask)
1750
+
1751
+
1752
+ NODE_CLASS_MAPPINGS = {
1753
+ "KSampler": KSampler,
1754
+ "CheckpointLoaderSimple": CheckpointLoaderSimple,
1755
+ "CLIPTextEncode": CLIPTextEncode,
1756
+ "CLIPSetLastLayer": CLIPSetLastLayer,
1757
+ "VAEDecode": VAEDecode,
1758
+ "VAEEncode": VAEEncode,
1759
+ "VAEEncodeForInpaint": VAEEncodeForInpaint,
1760
+ "VAELoader": VAELoader,
1761
+ "EmptyLatentImage": EmptyLatentImage,
1762
+ "LatentUpscale": LatentUpscale,
1763
+ "LatentUpscaleBy": LatentUpscaleBy,
1764
+ "LatentFromBatch": LatentFromBatch,
1765
+ "RepeatLatentBatch": RepeatLatentBatch,
1766
+ "SaveImage": SaveImage,
1767
+ "PreviewImage": PreviewImage,
1768
+ "LoadImage": LoadImage,
1769
+ "LoadImageMask": LoadImageMask,
1770
+ "ImageScale": ImageScale,
1771
+ "ImageScaleBy": ImageScaleBy,
1772
+ "ImageInvert": ImageInvert,
1773
+ "ImageBatch": ImageBatch,
1774
+ "ImagePadForOutpaint": ImagePadForOutpaint,
1775
+ "EmptyImage": EmptyImage,
1776
+ "ConditioningAverage": ConditioningAverage ,
1777
+ "ConditioningCombine": ConditioningCombine,
1778
+ "ConditioningConcat": ConditioningConcat,
1779
+ "ConditioningSetArea": ConditioningSetArea,
1780
+ "ConditioningSetAreaPercentage": ConditioningSetAreaPercentage,
1781
+ "ConditioningSetAreaStrength": ConditioningSetAreaStrength,
1782
+ "ConditioningSetMask": ConditioningSetMask,
1783
+ "KSamplerAdvanced": KSamplerAdvanced,
1784
+ "SetLatentNoiseMask": SetLatentNoiseMask,
1785
+ "LatentComposite": LatentComposite,
1786
+ "LatentBlend": LatentBlend,
1787
+ "LatentRotate": LatentRotate,
1788
+ "LatentFlip": LatentFlip,
1789
+ "LatentCrop": LatentCrop,
1790
+ "LoraLoader": LoraLoader,
1791
+ "CLIPLoader": CLIPLoader,
1792
+ "UNETLoader": UNETLoader,
1793
+ "DualCLIPLoader": DualCLIPLoader,
1794
+ "CLIPVisionEncode": CLIPVisionEncode,
1795
+ "StyleModelApply": StyleModelApply,
1796
+ "unCLIPConditioning": unCLIPConditioning,
1797
+ "ControlNetApply": ControlNetApply,
1798
+ "ControlNetApplyAdvanced": ControlNetApplyAdvanced,
1799
+ "ControlNetLoader": ControlNetLoader,
1800
+ "DiffControlNetLoader": DiffControlNetLoader,
1801
+ "StyleModelLoader": StyleModelLoader,
1802
+ "CLIPVisionLoader": CLIPVisionLoader,
1803
+ "VAEDecodeTiled": VAEDecodeTiled,
1804
+ "VAEEncodeTiled": VAEEncodeTiled,
1805
+ "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
1806
+ "GLIGENLoader": GLIGENLoader,
1807
+ "GLIGENTextBoxApply": GLIGENTextBoxApply,
1808
+ "InpaintModelConditioning": InpaintModelConditioning,
1809
+
1810
+ "CheckpointLoader": CheckpointLoader,
1811
+ "DiffusersLoader": DiffusersLoader,
1812
+
1813
+ "LoadLatent": LoadLatent,
1814
+ "SaveLatent": SaveLatent,
1815
+
1816
+ "ConditioningZeroOut": ConditioningZeroOut,
1817
+ "ConditioningSetTimestepRange": ConditioningSetTimestepRange,
1818
+ "LoraLoaderModelOnly": LoraLoaderModelOnly,
1819
+ }
1820
+
1821
+ NODE_DISPLAY_NAME_MAPPINGS = {
1822
+ # Sampling
1823
+ "KSampler": "KSampler",
1824
+ "KSamplerAdvanced": "KSampler (Advanced)",
1825
+ # Loaders
1826
+ "CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
1827
+ "CheckpointLoaderSimple": "Load Checkpoint",
1828
+ "VAELoader": "Load VAE",
1829
+ "LoraLoader": "Load LoRA",
1830
+ "CLIPLoader": "Load CLIP",
1831
+ "ControlNetLoader": "Load ControlNet Model",
1832
+ "DiffControlNetLoader": "Load ControlNet Model (diff)",
1833
+ "StyleModelLoader": "Load Style Model",
1834
+ "CLIPVisionLoader": "Load CLIP Vision",
1835
+ "UpscaleModelLoader": "Load Upscale Model",
1836
+ # Conditioning
1837
+ "CLIPVisionEncode": "CLIP Vision Encode",
1838
+ "StyleModelApply": "Apply Style Model",
1839
+ "CLIPTextEncode": "CLIP Text Encode (Prompt)",
1840
+ "CLIPSetLastLayer": "CLIP Set Last Layer",
1841
+ "ConditioningCombine": "Conditioning (Combine)",
1842
+ "ConditioningAverage ": "Conditioning (Average)",
1843
+ "ConditioningConcat": "Conditioning (Concat)",
1844
+ "ConditioningSetArea": "Conditioning (Set Area)",
1845
+ "ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
1846
+ "ConditioningSetMask": "Conditioning (Set Mask)",
1847
+ "ControlNetApply": "Apply ControlNet",
1848
+ "ControlNetApplyAdvanced": "Apply ControlNet (Advanced)",
1849
+ # Latent
1850
+ "VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
1851
+ "SetLatentNoiseMask": "Set Latent Noise Mask",
1852
+ "VAEDecode": "VAE Decode",
1853
+ "VAEEncode": "VAE Encode",
1854
+ "LatentRotate": "Rotate Latent",
1855
+ "LatentFlip": "Flip Latent",
1856
+ "LatentCrop": "Crop Latent",
1857
+ "EmptyLatentImage": "Empty Latent Image",
1858
+ "LatentUpscale": "Upscale Latent",
1859
+ "LatentUpscaleBy": "Upscale Latent By",
1860
+ "LatentComposite": "Latent Composite",
1861
+ "LatentBlend": "Latent Blend",
1862
+ "LatentFromBatch" : "Latent From Batch",
1863
+ "RepeatLatentBatch": "Repeat Latent Batch",
1864
+ # Image
1865
+ "SaveImage": "Save Image",
1866
+ "PreviewImage": "Preview Image",
1867
+ "LoadImage": "Load Image",
1868
+ "LoadImageMask": "Load Image (as Mask)",
1869
+ "ImageScale": "Upscale Image",
1870
+ "ImageScaleBy": "Upscale Image By",
1871
+ "ImageUpscaleWithModel": "Upscale Image (using Model)",
1872
+ "ImageInvert": "Invert Image",
1873
+ "ImagePadForOutpaint": "Pad Image for Outpainting",
1874
+ "ImageBatch": "Batch Images",
1875
+ # _for_testing
1876
+ "VAEDecodeTiled": "VAE Decode (Tiled)",
1877
+ "VAEEncodeTiled": "VAE Encode (Tiled)",
1878
+ }
1879
+
1880
+ EXTENSION_WEB_DIRS = {}
1881
+
1882
+ def load_custom_node(module_path, ignore=set()):
1883
+ module_name = os.path.basename(module_path)
1884
+ if os.path.isfile(module_path):
1885
+ sp = os.path.splitext(module_path)
1886
+ module_name = sp[0]
1887
+ try:
1888
+ if os.path.isfile(module_path):
1889
+ module_spec = importlib.util.spec_from_file_location(module_name, module_path)
1890
+ module_dir = os.path.split(module_path)[0]
1891
+ else:
1892
+ module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))
1893
+ module_dir = module_path
1894
+
1895
+ module = importlib.util.module_from_spec(module_spec)
1896
+ sys.modules[module_name] = module
1897
+ module_spec.loader.exec_module(module)
1898
+
1899
+ if hasattr(module, "WEB_DIRECTORY") and getattr(module, "WEB_DIRECTORY") is not None:
1900
+ web_dir = os.path.abspath(os.path.join(module_dir, getattr(module, "WEB_DIRECTORY")))
1901
+ if os.path.isdir(web_dir):
1902
+ EXTENSION_WEB_DIRS[module_name] = web_dir
1903
+
1904
+ if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
1905
+ for name in module.NODE_CLASS_MAPPINGS:
1906
+ if name not in ignore:
1907
+ NODE_CLASS_MAPPINGS[name] = module.NODE_CLASS_MAPPINGS[name]
1908
+ if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
1909
+ NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
1910
+ return True
1911
+ else:
1912
+ print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
1913
+ return False
1914
+ except Exception as e:
1915
+ print(traceback.format_exc())
1916
+ print(f"Cannot import {module_path} module for custom nodes:", e)
1917
+ return False
1918
+
1919
+ def load_custom_nodes():
1920
+ base_node_names = set(NODE_CLASS_MAPPINGS.keys())
1921
+ node_paths = ldm_patched.utils.path_utils.get_folder_paths("custom_nodes")
1922
+ node_import_times = []
1923
+ for custom_node_path in node_paths:
1924
+ possible_modules = os.listdir(os.path.realpath(custom_node_path))
1925
+ if "__pycache__" in possible_modules:
1926
+ possible_modules.remove("__pycache__")
1927
+
1928
+ for possible_module in possible_modules:
1929
+ module_path = os.path.join(custom_node_path, possible_module)
1930
+ if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
1931
+ if module_path.endswith(".disabled"): continue
1932
+ time_before = time.perf_counter()
1933
+ success = load_custom_node(module_path, base_node_names)
1934
+ node_import_times.append((time.perf_counter() - time_before, module_path, success))
1935
+
1936
+ if len(node_import_times) > 0:
1937
+ print("\nImport times for custom nodes:")
1938
+ for n in sorted(node_import_times):
1939
+ if n[2]:
1940
+ import_message = ""
1941
+ else:
1942
+ import_message = " (IMPORT FAILED)"
1943
+ print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
1944
+ print()
1945
+
1946
+ def init_custom_nodes():
1947
+ extras_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ldm_patched_extras")
1948
+ extras_files = [
1949
+ "nodes_latent.py",
1950
+ "nodes_hypernetwork.py",
1951
+ "nodes_upscale_model.py",
1952
+ "nodes_post_processing.py",
1953
+ "nodes_mask.py",
1954
+ "nodes_compositing.py",
1955
+ "nodes_rebatch.py",
1956
+ "nodes_model_merging.py",
1957
+ "nodes_tomesd.py",
1958
+ "nodes_clip_sdxl.py",
1959
+ "nodes_canny.py",
1960
+ "nodes_freelunch.py",
1961
+ "nodes_custom_sampler.py",
1962
+ "nodes_hypertile.py",
1963
+ "nodes_model_advanced.py",
1964
+ "nodes_model_downscale.py",
1965
+ "nodes_images.py",
1966
+ "nodes_video_model.py",
1967
+ "nodes_sag.py",
1968
+ "nodes_perpneg.py",
1969
+ "nodes_stable3d.py",
1970
+ "nodes_sdupscale.py",
1971
+ "nodes_photomaker.py",
1972
+ ]
1973
+
1974
+ for node_file in extras_files:
1975
+ load_custom_node(os.path.join(extras_dir, node_file))
1976
+
1977
+ load_custom_nodes()
ldm_patched/contrib/external_canny.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ #From https://github.com/kornia/kornia
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import ldm_patched.modules.model_management
11
+
12
+ def get_canny_nms_kernel(device=None, dtype=None):
13
+ """Utility function that returns 3x3 kernels for the Canny Non-maximal suppression."""
14
+ return torch.tensor(
15
+ [
16
+ [[[0.0, 0.0, 0.0], [0.0, 1.0, -1.0], [0.0, 0.0, 0.0]]],
17
+ [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]]],
18
+ [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, -1.0, 0.0]]],
19
+ [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]]],
20
+ [[[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 0.0]]],
21
+ [[[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]],
22
+ [[[0.0, -1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]],
23
+ [[[0.0, 0.0, -1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]],
24
+ ],
25
+ device=device,
26
+ dtype=dtype,
27
+ )
28
+
29
+
30
+ def get_hysteresis_kernel(device=None, dtype=None):
31
+ """Utility function that returns the 3x3 kernels for the Canny hysteresis."""
32
+ return torch.tensor(
33
+ [
34
+ [[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]]],
35
+ [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]],
36
+ [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]],
37
+ [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]],
38
+ [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
39
+ [[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
40
+ [[[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
41
+ [[[0.0, 0.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
42
+ ],
43
+ device=device,
44
+ dtype=dtype,
45
+ )
46
+
47
+ def gaussian_blur_2d(img, kernel_size, sigma):
48
+ ksize_half = (kernel_size - 1) * 0.5
49
+
50
+ x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
51
+
52
+ pdf = torch.exp(-0.5 * (x / sigma).pow(2))
53
+
54
+ x_kernel = pdf / pdf.sum()
55
+ x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
56
+
57
+ kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
58
+ kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
59
+
60
+ padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
61
+
62
+ img = torch.nn.functional.pad(img, padding, mode="reflect")
63
+ img = torch.nn.functional.conv2d(img, kernel2d, groups=img.shape[-3])
64
+
65
+ return img
66
+
67
+ def get_sobel_kernel2d(device=None, dtype=None):
68
+ kernel_x = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]], device=device, dtype=dtype)
69
+ kernel_y = kernel_x.transpose(0, 1)
70
+ return torch.stack([kernel_x, kernel_y])
71
+
72
+ def spatial_gradient(input, normalized: bool = True):
73
+ r"""Compute the first order image derivative in both x and y using a Sobel operator.
74
+ .. image:: _static/img/spatial_gradient.png
75
+ Args:
76
+ input: input image tensor with shape :math:`(B, C, H, W)`.
77
+ mode: derivatives modality, can be: `sobel` or `diff`.
78
+ order: the order of the derivatives.
79
+ normalized: whether the output is normalized.
80
+ Return:
81
+ the derivatives of the input feature map. with shape :math:`(B, C, 2, H, W)`.
82
+ .. note::
83
+ See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
84
+ filtering_edges.html>`__.
85
+ Examples:
86
+ >>> input = torch.rand(1, 3, 4, 4)
87
+ >>> output = spatial_gradient(input) # 1x3x2x4x4
88
+ >>> output.shape
89
+ torch.Size([1, 3, 2, 4, 4])
90
+ """
91
+ # KORNIA_CHECK_IS_TENSOR(input)
92
+ # KORNIA_CHECK_SHAPE(input, ['B', 'C', 'H', 'W'])
93
+
94
+ # allocate kernel
95
+ kernel = get_sobel_kernel2d(device=input.device, dtype=input.dtype)
96
+ if normalized:
97
+ kernel = normalize_kernel2d(kernel)
98
+
99
+ # prepare kernel
100
+ b, c, h, w = input.shape
101
+ tmp_kernel = kernel[:, None, ...]
102
+
103
+ # Pad with "replicate for spatial dims, but with zeros for channel
104
+ spatial_pad = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2]
105
+ out_channels: int = 2
106
+ padded_inp = torch.nn.functional.pad(input.reshape(b * c, 1, h, w), spatial_pad, 'replicate')
107
+ out = F.conv2d(padded_inp, tmp_kernel, groups=1, padding=0, stride=1)
108
+ return out.reshape(b, c, out_channels, h, w)
109
+
110
+ def rgb_to_grayscale(image, rgb_weights = None):
111
+ r"""Convert a RGB image to grayscale version of image.
112
+
113
+ .. image:: _static/img/rgb_to_grayscale.png
114
+
115
+ The image data is assumed to be in the range of (0, 1).
116
+
117
+ Args:
118
+ image: RGB image to be converted to grayscale with shape :math:`(*,3,H,W)`.
119
+ rgb_weights: Weights that will be applied on each channel (RGB).
120
+ The sum of the weights should add up to one.
121
+ Returns:
122
+ grayscale version of the image with shape :math:`(*,1,H,W)`.
123
+
124
+ .. note::
125
+ See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
126
+ color_conversions.html>`__.
127
+
128
+ Example:
129
+ >>> input = torch.rand(2, 3, 4, 5)
130
+ >>> gray = rgb_to_grayscale(input) # 2x1x4x5
131
+ """
132
+
133
+ if len(image.shape) < 3 or image.shape[-3] != 3:
134
+ raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")
135
+
136
+ if rgb_weights is None:
137
+ # 8 bit images
138
+ if image.dtype == torch.uint8:
139
+ rgb_weights = torch.tensor([76, 150, 29], device=image.device, dtype=torch.uint8)
140
+ # floating point images
141
+ elif image.dtype in (torch.float16, torch.float32, torch.float64):
142
+ rgb_weights = torch.tensor([0.299, 0.587, 0.114], device=image.device, dtype=image.dtype)
143
+ else:
144
+ raise TypeError(f"Unknown data type: {image.dtype}")
145
+ else:
146
+ # is tensor that we make sure is in the same device/dtype
147
+ rgb_weights = rgb_weights.to(image)
148
+
149
+ # unpack the color image channels with RGB order
150
+ r: Tensor = image[..., 0:1, :, :]
151
+ g: Tensor = image[..., 1:2, :, :]
152
+ b: Tensor = image[..., 2:3, :, :]
153
+
154
+ w_r, w_g, w_b = rgb_weights.unbind()
155
+ return w_r * r + w_g * g + w_b * b
156
+
157
+ def canny(
158
+ input,
159
+ low_threshold = 0.1,
160
+ high_threshold = 0.2,
161
+ kernel_size = 5,
162
+ sigma = 1,
163
+ hysteresis = True,
164
+ eps = 1e-6,
165
+ ):
166
+ r"""Find edges of the input image and filters them using the Canny algorithm.
167
+ .. image:: _static/img/canny.png
168
+ Args:
169
+ input: input image tensor with shape :math:`(B,C,H,W)`.
170
+ low_threshold: lower threshold for the hysteresis procedure.
171
+ high_threshold: upper threshold for the hysteresis procedure.
172
+ kernel_size: the size of the kernel for the gaussian blur.
173
+ sigma: the standard deviation of the kernel for the gaussian blur.
174
+ hysteresis: if True, applies the hysteresis edge tracking.
175
+ Otherwise, the edges are divided between weak (0.5) and strong (1) edges.
176
+ eps: regularization number to avoid NaN during backprop.
177
+ Returns:
178
+ - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`.
179
+ - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`.
180
+ .. note::
181
+ See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
182
+ canny.html>`__.
183
+ Example:
184
+ >>> input = torch.rand(5, 3, 4, 4)
185
+ >>> magnitude, edges = canny(input) # 5x3x4x4
186
+ >>> magnitude.shape
187
+ torch.Size([5, 1, 4, 4])
188
+ >>> edges.shape
189
+ torch.Size([5, 1, 4, 4])
190
+ """
191
+ # KORNIA_CHECK_IS_TENSOR(input)
192
+ # KORNIA_CHECK_SHAPE(input, ['B', 'C', 'H', 'W'])
193
+ # KORNIA_CHECK(
194
+ # low_threshold <= high_threshold,
195
+ # "Invalid input thresholds. low_threshold should be smaller than the high_threshold. Got: "
196
+ # f"{low_threshold}>{high_threshold}",
197
+ # )
198
+ # KORNIA_CHECK(0 < low_threshold < 1, f'Invalid low threshold. Should be in range (0, 1). Got: {low_threshold}')
199
+ # KORNIA_CHECK(0 < high_threshold < 1, f'Invalid high threshold. Should be in range (0, 1). Got: {high_threshold}')
200
+
201
+ device = input.device
202
+ dtype = input.dtype
203
+
204
+ # To Grayscale
205
+ if input.shape[1] == 3:
206
+ input = rgb_to_grayscale(input)
207
+
208
+ # Gaussian filter
209
+ blurred: Tensor = gaussian_blur_2d(input, kernel_size, sigma)
210
+
211
+ # Compute the gradients
212
+ gradients: Tensor = spatial_gradient(blurred, normalized=False)
213
+
214
+ # Unpack the edges
215
+ gx: Tensor = gradients[:, :, 0]
216
+ gy: Tensor = gradients[:, :, 1]
217
+
218
+ # Compute gradient magnitude and angle
219
+ magnitude: Tensor = torch.sqrt(gx * gx + gy * gy + eps)
220
+ angle: Tensor = torch.atan2(gy, gx)
221
+
222
+ # Radians to Degrees
223
+ angle = 180.0 * angle / math.pi
224
+
225
+ # Round angle to the nearest 45 degree
226
+ angle = torch.round(angle / 45) * 45
227
+
228
+ # Non-maximal suppression
229
+ nms_kernels: Tensor = get_canny_nms_kernel(device, dtype)
230
+ nms_magnitude: Tensor = F.conv2d(magnitude, nms_kernels, padding=nms_kernels.shape[-1] // 2)
231
+
232
+ # Get the indices for both directions
233
+ positive_idx: Tensor = (angle / 45) % 8
234
+ positive_idx = positive_idx.long()
235
+
236
+ negative_idx: Tensor = ((angle / 45) + 4) % 8
237
+ negative_idx = negative_idx.long()
238
+
239
+ # Apply the non-maximum suppression to the different directions
240
+ channel_select_filtered_positive: Tensor = torch.gather(nms_magnitude, 1, positive_idx)
241
+ channel_select_filtered_negative: Tensor = torch.gather(nms_magnitude, 1, negative_idx)
242
+
243
+ channel_select_filtered: Tensor = torch.stack(
244
+ [channel_select_filtered_positive, channel_select_filtered_negative], 1
245
+ )
246
+
247
+ is_max: Tensor = channel_select_filtered.min(dim=1)[0] > 0.0
248
+
249
+ magnitude = magnitude * is_max
250
+
251
+ # Threshold
252
+ edges: Tensor = F.threshold(magnitude, low_threshold, 0.0)
253
+
254
+ low: Tensor = magnitude > low_threshold
255
+ high: Tensor = magnitude > high_threshold
256
+
257
+ edges = low * 0.5 + high * 0.5
258
+ edges = edges.to(dtype)
259
+
260
+ # Hysteresis
261
+ if hysteresis:
262
+ edges_old: Tensor = -torch.ones(edges.shape, device=edges.device, dtype=dtype)
263
+ hysteresis_kernels: Tensor = get_hysteresis_kernel(device, dtype)
264
+
265
+ while ((edges_old - edges).abs() != 0).any():
266
+ weak: Tensor = (edges == 0.5).float()
267
+ strong: Tensor = (edges == 1).float()
268
+
269
+ hysteresis_magnitude: Tensor = F.conv2d(
270
+ edges, hysteresis_kernels, padding=hysteresis_kernels.shape[-1] // 2
271
+ )
272
+ hysteresis_magnitude = (hysteresis_magnitude == 1).any(1, keepdim=True).to(dtype)
273
+ hysteresis_magnitude = hysteresis_magnitude * weak + strong
274
+
275
+ edges_old = edges.clone()
276
+ edges = hysteresis_magnitude + (hysteresis_magnitude == 0) * weak * 0.5
277
+
278
+ edges = hysteresis_magnitude
279
+
280
+ return magnitude, edges
281
+
282
+
283
+ class Canny:
284
+ @classmethod
285
+ def INPUT_TYPES(s):
286
+ return {"required": {"image": ("IMAGE",),
287
+ "low_threshold": ("FLOAT", {"default": 0.4, "min": 0.01, "max": 0.99, "step": 0.01}),
288
+ "high_threshold": ("FLOAT", {"default": 0.8, "min": 0.01, "max": 0.99, "step": 0.01})
289
+ }}
290
+
291
+ RETURN_TYPES = ("IMAGE",)
292
+ FUNCTION = "detect_edge"
293
+
294
+ CATEGORY = "image/preprocessors"
295
+
296
+ def detect_edge(self, image, low_threshold, high_threshold):
297
+ output = canny(image.to(ldm_patched.modules.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
298
+ img_out = output[1].to(ldm_patched.modules.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
299
+ return (img_out,)
300
+
301
+ NODE_CLASS_MAPPINGS = {
302
+ "Canny": Canny,
303
+ }
ldm_patched/contrib/external_clip_sdxl.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ import torch
6
+ from ldm_patched.contrib.external import MAX_RESOLUTION
7
+
8
+ class CLIPTextEncodeSDXLRefiner:
9
+ @classmethod
10
+ def INPUT_TYPES(s):
11
+ return {"required": {
12
+ "ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
13
+ "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
14
+ "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
15
+ "text": ("STRING", {"multiline": True}), "clip": ("CLIP", ),
16
+ }}
17
+ RETURN_TYPES = ("CONDITIONING",)
18
+ FUNCTION = "encode"
19
+
20
+ CATEGORY = "advanced/conditioning"
21
+
22
+ def encode(self, clip, ascore, width, height, text):
23
+ tokens = clip.tokenize(text)
24
+ cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
25
+ return ([[cond, {"pooled_output": pooled, "aesthetic_score": ascore, "width": width,"height": height}]], )
26
+
27
+ class CLIPTextEncodeSDXL:
28
+ @classmethod
29
+ def INPUT_TYPES(s):
30
+ return {"required": {
31
+ "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
32
+ "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
33
+ "crop_w": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
34
+ "crop_h": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
35
+ "target_width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
36
+ "target_height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
37
+ "text_g": ("STRING", {"multiline": True, "default": "CLIP_G"}), "clip": ("CLIP", ),
38
+ "text_l": ("STRING", {"multiline": True, "default": "CLIP_L"}), "clip": ("CLIP", ),
39
+ }}
40
+ RETURN_TYPES = ("CONDITIONING",)
41
+ FUNCTION = "encode"
42
+
43
+ CATEGORY = "advanced/conditioning"
44
+
45
+ def encode(self, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l):
46
+ tokens = clip.tokenize(text_g)
47
+ tokens["l"] = clip.tokenize(text_l)["l"]
48
+ if len(tokens["l"]) != len(tokens["g"]):
49
+ empty = clip.tokenize("")
50
+ while len(tokens["l"]) < len(tokens["g"]):
51
+ tokens["l"] += empty["l"]
52
+ while len(tokens["l"]) > len(tokens["g"]):
53
+ tokens["g"] += empty["g"]
54
+ cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
55
+ return ([[cond, {"pooled_output": pooled, "width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}]], )
56
+
57
+ NODE_CLASS_MAPPINGS = {
58
+ "CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner,
59
+ "CLIPTextEncodeSDXL": CLIPTextEncodeSDXL,
60
+ }
ldm_patched/contrib/external_compositing.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ import numpy as np
6
+ import torch
7
+ import ldm_patched.modules.utils
8
+ from enum import Enum
9
+
10
+ def resize_mask(mask, shape):
11
+ return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
12
+
13
+ class PorterDuffMode(Enum):
14
+ ADD = 0
15
+ CLEAR = 1
16
+ DARKEN = 2
17
+ DST = 3
18
+ DST_ATOP = 4
19
+ DST_IN = 5
20
+ DST_OUT = 6
21
+ DST_OVER = 7
22
+ LIGHTEN = 8
23
+ MULTIPLY = 9
24
+ OVERLAY = 10
25
+ SCREEN = 11
26
+ SRC = 12
27
+ SRC_ATOP = 13
28
+ SRC_IN = 14
29
+ SRC_OUT = 15
30
+ SRC_OVER = 16
31
+ XOR = 17
32
+
33
+
34
+ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode):
35
+ if mode == PorterDuffMode.ADD:
36
+ out_alpha = torch.clamp(src_alpha + dst_alpha, 0, 1)
37
+ out_image = torch.clamp(src_image + dst_image, 0, 1)
38
+ elif mode == PorterDuffMode.CLEAR:
39
+ out_alpha = torch.zeros_like(dst_alpha)
40
+ out_image = torch.zeros_like(dst_image)
41
+ elif mode == PorterDuffMode.DARKEN:
42
+ out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
43
+ out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.min(src_image, dst_image)
44
+ elif mode == PorterDuffMode.DST:
45
+ out_alpha = dst_alpha
46
+ out_image = dst_image
47
+ elif mode == PorterDuffMode.DST_ATOP:
48
+ out_alpha = src_alpha
49
+ out_image = src_alpha * dst_image + (1 - dst_alpha) * src_image
50
+ elif mode == PorterDuffMode.DST_IN:
51
+ out_alpha = src_alpha * dst_alpha
52
+ out_image = dst_image * src_alpha
53
+ elif mode == PorterDuffMode.DST_OUT:
54
+ out_alpha = (1 - src_alpha) * dst_alpha
55
+ out_image = (1 - src_alpha) * dst_image
56
+ elif mode == PorterDuffMode.DST_OVER:
57
+ out_alpha = dst_alpha + (1 - dst_alpha) * src_alpha
58
+ out_image = dst_image + (1 - dst_alpha) * src_image
59
+ elif mode == PorterDuffMode.LIGHTEN:
60
+ out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
61
+ out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.max(src_image, dst_image)
62
+ elif mode == PorterDuffMode.MULTIPLY:
63
+ out_alpha = src_alpha * dst_alpha
64
+ out_image = src_image * dst_image
65
+ elif mode == PorterDuffMode.OVERLAY:
66
+ out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
67
+ out_image = torch.where(2 * dst_image < dst_alpha, 2 * src_image * dst_image,
68
+ src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image))
69
+ elif mode == PorterDuffMode.SCREEN:
70
+ out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
71
+ out_image = src_image + dst_image - src_image * dst_image
72
+ elif mode == PorterDuffMode.SRC:
73
+ out_alpha = src_alpha
74
+ out_image = src_image
75
+ elif mode == PorterDuffMode.SRC_ATOP:
76
+ out_alpha = dst_alpha
77
+ out_image = dst_alpha * src_image + (1 - src_alpha) * dst_image
78
+ elif mode == PorterDuffMode.SRC_IN:
79
+ out_alpha = src_alpha * dst_alpha
80
+ out_image = src_image * dst_alpha
81
+ elif mode == PorterDuffMode.SRC_OUT:
82
+ out_alpha = (1 - dst_alpha) * src_alpha
83
+ out_image = (1 - dst_alpha) * src_image
84
+ elif mode == PorterDuffMode.SRC_OVER:
85
+ out_alpha = src_alpha + (1 - src_alpha) * dst_alpha
86
+ out_image = src_image + (1 - src_alpha) * dst_image
87
+ elif mode == PorterDuffMode.XOR:
88
+ out_alpha = (1 - dst_alpha) * src_alpha + (1 - src_alpha) * dst_alpha
89
+ out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image
90
+ else:
91
+ out_alpha = None
92
+ out_image = None
93
+ return out_image, out_alpha
94
+
95
+
96
+ class PorterDuffImageComposite:
97
+ @classmethod
98
+ def INPUT_TYPES(s):
99
+ return {
100
+ "required": {
101
+ "source": ("IMAGE",),
102
+ "source_alpha": ("MASK",),
103
+ "destination": ("IMAGE",),
104
+ "destination_alpha": ("MASK",),
105
+ "mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
106
+ },
107
+ }
108
+
109
+ RETURN_TYPES = ("IMAGE", "MASK")
110
+ FUNCTION = "composite"
111
+ CATEGORY = "mask/compositing"
112
+
113
+ def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
114
+ batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
115
+ out_images = []
116
+ out_alphas = []
117
+
118
+ for i in range(batch_size):
119
+ src_image = source[i]
120
+ dst_image = destination[i]
121
+
122
+ assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels
123
+
124
+ src_alpha = source_alpha[i].unsqueeze(2)
125
+ dst_alpha = destination_alpha[i].unsqueeze(2)
126
+
127
+ if dst_alpha.shape[:2] != dst_image.shape[:2]:
128
+ upscale_input = dst_alpha.unsqueeze(0).permute(0, 3, 1, 2)
129
+ upscale_output = ldm_patched.modules.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center')
130
+ dst_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
131
+ if src_image.shape != dst_image.shape:
132
+ upscale_input = src_image.unsqueeze(0).permute(0, 3, 1, 2)
133
+ upscale_output = ldm_patched.modules.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center')
134
+ src_image = upscale_output.permute(0, 2, 3, 1).squeeze(0)
135
+ if src_alpha.shape != dst_alpha.shape:
136
+ upscale_input = src_alpha.unsqueeze(0).permute(0, 3, 1, 2)
137
+ upscale_output = ldm_patched.modules.utils.common_upscale(upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center')
138
+ src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
139
+
140
+ out_image, out_alpha = porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode[mode])
141
+
142
+ out_images.append(out_image)
143
+ out_alphas.append(out_alpha.squeeze(2))
144
+
145
+ result = (torch.stack(out_images), torch.stack(out_alphas))
146
+ return result
147
+
148
+
149
+ class SplitImageWithAlpha:
150
+ @classmethod
151
+ def INPUT_TYPES(s):
152
+ return {
153
+ "required": {
154
+ "image": ("IMAGE",),
155
+ }
156
+ }
157
+
158
+ CATEGORY = "mask/compositing"
159
+ RETURN_TYPES = ("IMAGE", "MASK")
160
+ FUNCTION = "split_image_with_alpha"
161
+
162
+ def split_image_with_alpha(self, image: torch.Tensor):
163
+ out_images = [i[:,:,:3] for i in image]
164
+ out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
165
+ result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
166
+ return result
167
+
168
+
169
+ class JoinImageWithAlpha:
170
+ @classmethod
171
+ def INPUT_TYPES(s):
172
+ return {
173
+ "required": {
174
+ "image": ("IMAGE",),
175
+ "alpha": ("MASK",),
176
+ }
177
+ }
178
+
179
+ CATEGORY = "mask/compositing"
180
+ RETURN_TYPES = ("IMAGE",)
181
+ FUNCTION = "join_image_with_alpha"
182
+
183
+ def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
184
+ batch_size = min(len(image), len(alpha))
185
+ out_images = []
186
+
187
+ alpha = 1.0 - resize_mask(alpha, image.shape[1:])
188
+ for i in range(batch_size):
189
+ out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
190
+
191
+ result = (torch.stack(out_images),)
192
+ return result
193
+
194
+
195
+ NODE_CLASS_MAPPINGS = {
196
+ "PorterDuffImageComposite": PorterDuffImageComposite,
197
+ "SplitImageWithAlpha": SplitImageWithAlpha,
198
+ "JoinImageWithAlpha": JoinImageWithAlpha,
199
+ }
200
+
201
+
202
+ NODE_DISPLAY_NAME_MAPPINGS = {
203
+ "PorterDuffImageComposite": "Porter-Duff Image Composite",
204
+ "SplitImageWithAlpha": "Split Image with Alpha",
205
+ "JoinImageWithAlpha": "Join Image with Alpha",
206
+ }
ldm_patched/contrib/external_custom_sampler.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ import ldm_patched.modules.samplers
6
+ import ldm_patched.modules.sample
7
+ from ldm_patched.k_diffusion import sampling as k_diffusion_sampling
8
+ import ldm_patched.utils.latent_visualization
9
+ import torch
10
+ import ldm_patched.modules.utils
11
+
12
+
13
+ class BasicScheduler:
14
+ @classmethod
15
+ def INPUT_TYPES(s):
16
+ return {"required":
17
+ {"model": ("MODEL",),
18
+ "scheduler": (ldm_patched.modules.samplers.SCHEDULER_NAMES, ),
19
+ "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
20
+ "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
21
+ }
22
+ }
23
+ RETURN_TYPES = ("SIGMAS",)
24
+ CATEGORY = "sampling/custom_sampling/schedulers"
25
+
26
+ FUNCTION = "get_sigmas"
27
+
28
+ def get_sigmas(self, model, scheduler, steps, denoise):
29
+ total_steps = steps
30
+ if denoise < 1.0:
31
+ total_steps = int(steps/denoise)
32
+
33
+ ldm_patched.modules.model_management.load_models_gpu([model])
34
+ sigmas = ldm_patched.modules.samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu()
35
+ sigmas = sigmas[-(steps + 1):]
36
+ return (sigmas, )
37
+
38
+
39
+ class KarrasScheduler:
40
+ @classmethod
41
+ def INPUT_TYPES(s):
42
+ return {"required":
43
+ {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
44
+ "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
45
+ "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
46
+ "rho": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
47
+ }
48
+ }
49
+ RETURN_TYPES = ("SIGMAS",)
50
+ CATEGORY = "sampling/custom_sampling/schedulers"
51
+
52
+ FUNCTION = "get_sigmas"
53
+
54
+ def get_sigmas(self, steps, sigma_max, sigma_min, rho):
55
+ sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
56
+ return (sigmas, )
57
+
58
+ class ExponentialScheduler:
59
+ @classmethod
60
+ def INPUT_TYPES(s):
61
+ return {"required":
62
+ {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
63
+ "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
64
+ "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
65
+ }
66
+ }
67
+ RETURN_TYPES = ("SIGMAS",)
68
+ CATEGORY = "sampling/custom_sampling/schedulers"
69
+
70
+ FUNCTION = "get_sigmas"
71
+
72
+ def get_sigmas(self, steps, sigma_max, sigma_min):
73
+ sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max)
74
+ return (sigmas, )
75
+
76
+ class PolyexponentialScheduler:
77
+ @classmethod
78
+ def INPUT_TYPES(s):
79
+ return {"required":
80
+ {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
81
+ "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
82
+ "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
83
+ "rho": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
84
+ }
85
+ }
86
+ RETURN_TYPES = ("SIGMAS",)
87
+ CATEGORY = "sampling/custom_sampling/schedulers"
88
+
89
+ FUNCTION = "get_sigmas"
90
+
91
+ def get_sigmas(self, steps, sigma_max, sigma_min, rho):
92
+ sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
93
+ return (sigmas, )
94
+
95
+ class SDTurboScheduler:
96
+ @classmethod
97
+ def INPUT_TYPES(s):
98
+ return {"required":
99
+ {"model": ("MODEL",),
100
+ "steps": ("INT", {"default": 1, "min": 1, "max": 10}),
101
+ "denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
102
+ }
103
+ }
104
+ RETURN_TYPES = ("SIGMAS",)
105
+ CATEGORY = "sampling/custom_sampling/schedulers"
106
+
107
+ FUNCTION = "get_sigmas"
108
+
109
+ def get_sigmas(self, model, steps, denoise):
110
+ start_step = 10 - int(10 * denoise)
111
+ timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
112
+ ldm_patched.modules.model_management.load_models_gpu([model])
113
+ sigmas = model.model.model_sampling.sigma(timesteps)
114
+ sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
115
+ return (sigmas, )
116
+
117
+ class VPScheduler:
118
+ @classmethod
119
+ def INPUT_TYPES(s):
120
+ return {"required":
121
+ {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
122
+ "beta_d": ("FLOAT", {"default": 19.9, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}), #TODO: fix default values
123
+ "beta_min": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
124
+ "eps_s": ("FLOAT", {"default": 0.001, "min": 0.0, "max": 1.0, "step":0.0001, "round": False}),
125
+ }
126
+ }
127
+ RETURN_TYPES = ("SIGMAS",)
128
+ CATEGORY = "sampling/custom_sampling/schedulers"
129
+
130
+ FUNCTION = "get_sigmas"
131
+
132
+ def get_sigmas(self, steps, beta_d, beta_min, eps_s):
133
+ sigmas = k_diffusion_sampling.get_sigmas_vp(n=steps, beta_d=beta_d, beta_min=beta_min, eps_s=eps_s)
134
+ return (sigmas, )
135
+
136
+ class SplitSigmas:
137
+ @classmethod
138
+ def INPUT_TYPES(s):
139
+ return {"required":
140
+ {"sigmas": ("SIGMAS", ),
141
+ "step": ("INT", {"default": 0, "min": 0, "max": 10000}),
142
+ }
143
+ }
144
+ RETURN_TYPES = ("SIGMAS","SIGMAS")
145
+ CATEGORY = "sampling/custom_sampling/sigmas"
146
+
147
+ FUNCTION = "get_sigmas"
148
+
149
+ def get_sigmas(self, sigmas, step):
150
+ sigmas1 = sigmas[:step + 1]
151
+ sigmas2 = sigmas[step:]
152
+ return (sigmas1, sigmas2)
153
+
154
+ class FlipSigmas:
155
+ @classmethod
156
+ def INPUT_TYPES(s):
157
+ return {"required":
158
+ {"sigmas": ("SIGMAS", ),
159
+ }
160
+ }
161
+ RETURN_TYPES = ("SIGMAS",)
162
+ CATEGORY = "sampling/custom_sampling/sigmas"
163
+
164
+ FUNCTION = "get_sigmas"
165
+
166
+ def get_sigmas(self, sigmas):
167
+ sigmas = sigmas.flip(0)
168
+ if sigmas[0] == 0:
169
+ sigmas[0] = 0.0001
170
+ return (sigmas,)
171
+
172
+ class KSamplerSelect:
173
+ @classmethod
174
+ def INPUT_TYPES(s):
175
+ return {"required":
176
+ {"sampler_name": (ldm_patched.modules.samplers.SAMPLER_NAMES, ),
177
+ }
178
+ }
179
+ RETURN_TYPES = ("SAMPLER",)
180
+ CATEGORY = "sampling/custom_sampling/samplers"
181
+
182
+ FUNCTION = "get_sampler"
183
+
184
+ def get_sampler(self, sampler_name):
185
+ sampler = ldm_patched.modules.samplers.sampler_object(sampler_name)
186
+ return (sampler, )
187
+
188
+ class SamplerDPMPP_2M_SDE:
189
+ @classmethod
190
+ def INPUT_TYPES(s):
191
+ return {"required":
192
+ {"solver_type": (['midpoint', 'heun'], ),
193
+ "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
194
+ "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
195
+ "noise_device": (['gpu', 'cpu'], ),
196
+ }
197
+ }
198
+ RETURN_TYPES = ("SAMPLER",)
199
+ CATEGORY = "sampling/custom_sampling/samplers"
200
+
201
+ FUNCTION = "get_sampler"
202
+
203
+ def get_sampler(self, solver_type, eta, s_noise, noise_device):
204
+ if noise_device == 'cpu':
205
+ sampler_name = "dpmpp_2m_sde"
206
+ else:
207
+ sampler_name = "dpmpp_2m_sde_gpu"
208
+ sampler = ldm_patched.modules.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
209
+ return (sampler, )
210
+
211
+
212
+ class SamplerDPMPP_SDE:
213
+ @classmethod
214
+ def INPUT_TYPES(s):
215
+ return {"required":
216
+ {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
217
+ "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
218
+ "r": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
219
+ "noise_device": (['gpu', 'cpu'], ),
220
+ }
221
+ }
222
+ RETURN_TYPES = ("SAMPLER",)
223
+ CATEGORY = "sampling/custom_sampling/samplers"
224
+
225
+ FUNCTION = "get_sampler"
226
+
227
+ def get_sampler(self, eta, s_noise, r, noise_device):
228
+ if noise_device == 'cpu':
229
+ sampler_name = "dpmpp_sde"
230
+ else:
231
+ sampler_name = "dpmpp_sde_gpu"
232
+ sampler = ldm_patched.modules.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
233
+ return (sampler, )
234
+
235
+ class SamplerCustom:
236
+ @classmethod
237
+ def INPUT_TYPES(s):
238
+ return {"required":
239
+ {"model": ("MODEL",),
240
+ "add_noise": ("BOOLEAN", {"default": True}),
241
+ "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
242
+ "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
243
+ "positive": ("CONDITIONING", ),
244
+ "negative": ("CONDITIONING", ),
245
+ "sampler": ("SAMPLER", ),
246
+ "sigmas": ("SIGMAS", ),
247
+ "latent_image": ("LATENT", ),
248
+ }
249
+ }
250
+
251
+ RETURN_TYPES = ("LATENT","LATENT")
252
+ RETURN_NAMES = ("output", "denoised_output")
253
+
254
+ FUNCTION = "sample"
255
+
256
+ CATEGORY = "sampling/custom_sampling"
257
+
258
+ def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image):
259
+ latent = latent_image
260
+ latent_image = latent["samples"]
261
+ if not add_noise:
262
+ noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
263
+ else:
264
+ batch_inds = latent["batch_index"] if "batch_index" in latent else None
265
+ noise = ldm_patched.modules.sample.prepare_noise(latent_image, noise_seed, batch_inds)
266
+
267
+ noise_mask = None
268
+ if "noise_mask" in latent:
269
+ noise_mask = latent["noise_mask"]
270
+
271
+ x0_output = {}
272
+ callback = ldm_patched.utils.latent_visualization.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
273
+
274
+ disable_pbar = not ldm_patched.modules.utils.PROGRESS_BAR_ENABLED
275
+ samples = ldm_patched.modules.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
276
+
277
+ out = latent.copy()
278
+ out["samples"] = samples
279
+ if "x0" in x0_output:
280
+ out_denoised = latent.copy()
281
+ out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu())
282
+ else:
283
+ out_denoised = out
284
+ return (out, out_denoised)
285
+
286
+ NODE_CLASS_MAPPINGS = {
287
+ "SamplerCustom": SamplerCustom,
288
+ "BasicScheduler": BasicScheduler,
289
+ "KarrasScheduler": KarrasScheduler,
290
+ "ExponentialScheduler": ExponentialScheduler,
291
+ "PolyexponentialScheduler": PolyexponentialScheduler,
292
+ "VPScheduler": VPScheduler,
293
+ "SDTurboScheduler": SDTurboScheduler,
294
+ "KSamplerSelect": KSamplerSelect,
295
+ "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
296
+ "SamplerDPMPP_SDE": SamplerDPMPP_SDE,
297
+ "SplitSigmas": SplitSigmas,
298
+ "FlipSigmas": FlipSigmas,
299
+ }
ldm_patched/contrib/external_freelunch.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ #code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License)
4
+
5
+ import torch
6
+
7
+
8
+ def Fourier_filter(x, threshold, scale):
9
+ # FFT
10
+ x_freq = torch.fft.fftn(x.float(), dim=(-2, -1))
11
+ x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))
12
+
13
+ B, C, H, W = x_freq.shape
14
+ mask = torch.ones((B, C, H, W), device=x.device)
15
+
16
+ crow, ccol = H // 2, W //2
17
+ mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
18
+ x_freq = x_freq * mask
19
+
20
+ # IFFT
21
+ x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1))
22
+ x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real
23
+
24
+ return x_filtered.to(x.dtype)
25
+
26
+
27
+ class FreeU:
28
+ @classmethod
29
+ def INPUT_TYPES(s):
30
+ return {"required": { "model": ("MODEL",),
31
+ "b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.01}),
32
+ "b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.01}),
33
+ "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
34
+ "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
35
+ }}
36
+ RETURN_TYPES = ("MODEL",)
37
+ FUNCTION = "patch"
38
+
39
+ CATEGORY = "model_patches"
40
+
41
+ def patch(self, model, b1, b2, s1, s2):
42
+ model_channels = model.model.model_config.unet_config["model_channels"]
43
+ scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
44
+ on_cpu_devices = {}
45
+
46
+ def output_block_patch(h, hsp, transformer_options):
47
+ scale = scale_dict.get(h.shape[1], None)
48
+ if scale is not None:
49
+ h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0]
50
+ if hsp.device not in on_cpu_devices:
51
+ try:
52
+ hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
53
+ except:
54
+ print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.")
55
+ on_cpu_devices[hsp.device] = True
56
+ hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
57
+ else:
58
+ hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
59
+
60
+ return h, hsp
61
+
62
+ m = model.clone()
63
+ m.set_model_output_block_patch(output_block_patch)
64
+ return (m, )
65
+
66
+ class FreeU_V2:
67
+ @classmethod
68
+ def INPUT_TYPES(s):
69
+ return {"required": { "model": ("MODEL",),
70
+ "b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}),
71
+ "b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}),
72
+ "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
73
+ "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
74
+ }}
75
+ RETURN_TYPES = ("MODEL",)
76
+ FUNCTION = "patch"
77
+
78
+ CATEGORY = "model_patches"
79
+
80
+ def patch(self, model, b1, b2, s1, s2):
81
+ model_channels = model.model.model_config.unet_config["model_channels"]
82
+ scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
83
+ on_cpu_devices = {}
84
+
85
+ def output_block_patch(h, hsp, transformer_options):
86
+ scale = scale_dict.get(h.shape[1], None)
87
+ if scale is not None:
88
+ hidden_mean = h.mean(1).unsqueeze(1)
89
+ B = hidden_mean.shape[0]
90
+ hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
91
+ hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
92
+ hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
93
+
94
+ h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * ((scale[0] - 1 ) * hidden_mean + 1)
95
+
96
+ if hsp.device not in on_cpu_devices:
97
+ try:
98
+ hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
99
+ except:
100
+ print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.")
101
+ on_cpu_devices[hsp.device] = True
102
+ hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
103
+ else:
104
+ hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
105
+
106
+ return h, hsp
107
+
108
+ m = model.clone()
109
+ m.set_model_output_block_patch(output_block_patch)
110
+ return (m, )
111
+
112
+ NODE_CLASS_MAPPINGS = {
113
+ "FreeU": FreeU,
114
+ "FreeU_V2": FreeU_V2,
115
+ }
ldm_patched/contrib/external_hypernetwork.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ import ldm_patched.modules.utils
6
+ import ldm_patched.utils.path_utils
7
+ import torch
8
+
9
+ def load_hypernetwork_patch(path, strength):
10
+ sd = ldm_patched.modules.utils.load_torch_file(path, safe_load=True)
11
+ activation_func = sd.get('activation_func', 'linear')
12
+ is_layer_norm = sd.get('is_layer_norm', False)
13
+ use_dropout = sd.get('use_dropout', False)
14
+ activate_output = sd.get('activate_output', False)
15
+ last_layer_dropout = sd.get('last_layer_dropout', False)
16
+
17
+ valid_activation = {
18
+ "linear": torch.nn.Identity,
19
+ "relu": torch.nn.ReLU,
20
+ "leakyrelu": torch.nn.LeakyReLU,
21
+ "elu": torch.nn.ELU,
22
+ "swish": torch.nn.Hardswish,
23
+ "tanh": torch.nn.Tanh,
24
+ "sigmoid": torch.nn.Sigmoid,
25
+ "softsign": torch.nn.Softsign,
26
+ "mish": torch.nn.Mish,
27
+ }
28
+
29
+ if activation_func not in valid_activation:
30
+ print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)
31
+ return None
32
+
33
+ out = {}
34
+
35
+ for d in sd:
36
+ try:
37
+ dim = int(d)
38
+ except:
39
+ continue
40
+
41
+ output = []
42
+ for index in [0, 1]:
43
+ attn_weights = sd[dim][index]
44
+ keys = attn_weights.keys()
45
+
46
+ linears = filter(lambda a: a.endswith(".weight"), keys)
47
+ linears = list(map(lambda a: a[:-len(".weight")], linears))
48
+ layers = []
49
+
50
+ i = 0
51
+ while i < len(linears):
52
+ lin_name = linears[i]
53
+ last_layer = (i == (len(linears) - 1))
54
+ penultimate_layer = (i == (len(linears) - 2))
55
+
56
+ lin_weight = attn_weights['{}.weight'.format(lin_name)]
57
+ lin_bias = attn_weights['{}.bias'.format(lin_name)]
58
+ layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
59
+ layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
60
+ layers.append(layer)
61
+ if activation_func != "linear":
62
+ if (not last_layer) or (activate_output):
63
+ layers.append(valid_activation[activation_func]())
64
+ if is_layer_norm:
65
+ i += 1
66
+ ln_name = linears[i]
67
+ ln_weight = attn_weights['{}.weight'.format(ln_name)]
68
+ ln_bias = attn_weights['{}.bias'.format(ln_name)]
69
+ ln = torch.nn.LayerNorm(ln_weight.shape[0])
70
+ ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
71
+ layers.append(ln)
72
+ if use_dropout:
73
+ if (not last_layer) and (not penultimate_layer or last_layer_dropout):
74
+ layers.append(torch.nn.Dropout(p=0.3))
75
+ i += 1
76
+
77
+ output.append(torch.nn.Sequential(*layers))
78
+ out[dim] = torch.nn.ModuleList(output)
79
+
80
+ class hypernetwork_patch:
81
+ def __init__(self, hypernet, strength):
82
+ self.hypernet = hypernet
83
+ self.strength = strength
84
+ def __call__(self, q, k, v, extra_options):
85
+ dim = k.shape[-1]
86
+ if dim in self.hypernet:
87
+ hn = self.hypernet[dim]
88
+ k = k + hn[0](k) * self.strength
89
+ v = v + hn[1](v) * self.strength
90
+
91
+ return q, k, v
92
+
93
+ def to(self, device):
94
+ for d in self.hypernet.keys():
95
+ self.hypernet[d] = self.hypernet[d].to(device)
96
+ return self
97
+
98
+ return hypernetwork_patch(out, strength)
99
+
100
+ class HypernetworkLoader:
101
+ @classmethod
102
+ def INPUT_TYPES(s):
103
+ return {"required": { "model": ("MODEL",),
104
+ "hypernetwork_name": (ldm_patched.utils.path_utils.get_filename_list("hypernetworks"), ),
105
+ "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
106
+ }}
107
+ RETURN_TYPES = ("MODEL",)
108
+ FUNCTION = "load_hypernetwork"
109
+
110
+ CATEGORY = "loaders"
111
+
112
+ def load_hypernetwork(self, model, hypernetwork_name, strength):
113
+ hypernetwork_path = ldm_patched.utils.path_utils.get_full_path("hypernetworks", hypernetwork_name)
114
+ model_hypernetwork = model.clone()
115
+ patch = load_hypernetwork_patch(hypernetwork_path, strength)
116
+ if patch is not None:
117
+ model_hypernetwork.set_model_attn1_patch(patch)
118
+ model_hypernetwork.set_model_attn2_patch(patch)
119
+ return (model_hypernetwork,)
120
+
121
+ NODE_CLASS_MAPPINGS = {
122
+ "HypernetworkLoader": HypernetworkLoader
123
+ }
ldm_patched/contrib/external_hypertile.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ #Taken from: https://github.com/tfernd/HyperTile/
4
+
5
+ import math
6
+ from einops import rearrange
7
+ # Use torch rng for consistency across generations
8
+ from torch import randint
9
+
10
+ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
11
+ min_value = min(min_value, value)
12
+
13
+ # All big divisors of value (inclusive)
14
+ divisors = [i for i in range(min_value, value + 1) if value % i == 0]
15
+
16
+ ns = [value // i for i in divisors[:max_options]] # has at least 1 element
17
+
18
+ if len(ns) - 1 > 0:
19
+ idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()
20
+ else:
21
+ idx = 0
22
+
23
+ return ns[idx]
24
+
25
+ class HyperTile:
26
+ @classmethod
27
+ def INPUT_TYPES(s):
28
+ return {"required": { "model": ("MODEL",),
29
+ "tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}),
30
+ "swap_size": ("INT", {"default": 2, "min": 1, "max": 128}),
31
+ "max_depth": ("INT", {"default": 0, "min": 0, "max": 10}),
32
+ "scale_depth": ("BOOLEAN", {"default": False}),
33
+ }}
34
+ RETURN_TYPES = ("MODEL",)
35
+ FUNCTION = "patch"
36
+
37
+ CATEGORY = "model_patches"
38
+
39
+ def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
40
+ model_channels = model.model.model_config.unet_config["model_channels"]
41
+
42
+ latent_tile_size = max(32, tile_size) // 8
43
+ self.temp = None
44
+
45
+ def hypertile_in(q, k, v, extra_options):
46
+ model_chans = q.shape[-2]
47
+ orig_shape = extra_options['original_shape']
48
+ apply_to = []
49
+ for i in range(max_depth + 1):
50
+ apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i)))
51
+
52
+ if model_chans in apply_to:
53
+ shape = extra_options["original_shape"]
54
+ aspect_ratio = shape[-1] / shape[-2]
55
+
56
+ hw = q.size(1)
57
+ h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
58
+
59
+ factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1
60
+ nh = random_divisor(h, latent_tile_size * factor, swap_size)
61
+ nw = random_divisor(w, latent_tile_size * factor, swap_size)
62
+
63
+ if nh * nw > 1:
64
+ q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
65
+ self.temp = (nh, nw, h, w)
66
+ return q, k, v
67
+
68
+ return q, k, v
69
+ def hypertile_out(out, extra_options):
70
+ if self.temp is not None:
71
+ nh, nw, h, w = self.temp
72
+ self.temp = None
73
+ out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
74
+ out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
75
+ return out
76
+
77
+
78
+ m = model.clone()
79
+ m.set_model_attn1_patch(hypertile_in)
80
+ m.set_model_attn1_output_patch(hypertile_out)
81
+ return (m, )
82
+
83
+ NODE_CLASS_MAPPINGS = {
84
+ "HyperTile": HyperTile,
85
+ }
ldm_patched/contrib/external_images.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ import ldm_patched.contrib.external
6
+ import ldm_patched.utils.path_utils
7
+ from ldm_patched.modules.args_parser import args
8
+
9
+ from PIL import Image
10
+ from PIL.PngImagePlugin import PngInfo
11
+
12
+ import numpy as np
13
+ import json
14
+ import os
15
+
16
+ MAX_RESOLUTION = ldm_patched.contrib.external.MAX_RESOLUTION
17
+
18
+ class ImageCrop:
19
+ @classmethod
20
+ def INPUT_TYPES(s):
21
+ return {"required": { "image": ("IMAGE",),
22
+ "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
23
+ "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
24
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
25
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
26
+ }}
27
+ RETURN_TYPES = ("IMAGE",)
28
+ FUNCTION = "crop"
29
+
30
+ CATEGORY = "image/transform"
31
+
32
+ def crop(self, image, width, height, x, y):
33
+ x = min(x, image.shape[2] - 1)
34
+ y = min(y, image.shape[1] - 1)
35
+ to_x = width + x
36
+ to_y = height + y
37
+ img = image[:,y:to_y, x:to_x, :]
38
+ return (img,)
39
+
40
+ class RepeatImageBatch:
41
+ @classmethod
42
+ def INPUT_TYPES(s):
43
+ return {"required": { "image": ("IMAGE",),
44
+ "amount": ("INT", {"default": 1, "min": 1, "max": 64}),
45
+ }}
46
+ RETURN_TYPES = ("IMAGE",)
47
+ FUNCTION = "repeat"
48
+
49
+ CATEGORY = "image/batch"
50
+
51
+ def repeat(self, image, amount):
52
+ s = image.repeat((amount, 1,1,1))
53
+ return (s,)
54
+
55
+ class SaveAnimatedWEBP:
56
+ def __init__(self):
57
+ self.output_dir = ldm_patched.utils.path_utils.get_output_directory()
58
+ self.type = "output"
59
+ self.prefix_append = ""
60
+
61
+ methods = {"default": 4, "fastest": 0, "slowest": 6}
62
+ @classmethod
63
+ def INPUT_TYPES(s):
64
+ return {"required":
65
+ {"images": ("IMAGE", ),
66
+ "filename_prefix": ("STRING", {"default": "ldm_patched"}),
67
+ "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
68
+ "lossless": ("BOOLEAN", {"default": True}),
69
+ "quality": ("INT", {"default": 80, "min": 0, "max": 100}),
70
+ "method": (list(s.methods.keys()),),
71
+ # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
72
+ },
73
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
74
+ }
75
+
76
+ RETURN_TYPES = ()
77
+ FUNCTION = "save_images"
78
+
79
+ OUTPUT_NODE = True
80
+
81
+ CATEGORY = "image/animation"
82
+
83
+ def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
84
+ method = self.methods.get(method)
85
+ filename_prefix += self.prefix_append
86
+ full_output_folder, filename, counter, subfolder, filename_prefix = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
87
+ results = list()
88
+ pil_images = []
89
+ for image in images:
90
+ i = 255. * image.cpu().numpy()
91
+ img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
92
+ pil_images.append(img)
93
+
94
+ metadata = pil_images[0].getexif()
95
+ if not args.disable_server_info:
96
+ if prompt is not None:
97
+ metadata[0x0110] = "prompt:{}".format(json.dumps(prompt))
98
+ if extra_pnginfo is not None:
99
+ inital_exif = 0x010f
100
+ for x in extra_pnginfo:
101
+ metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x]))
102
+ inital_exif -= 1
103
+
104
+ if num_frames == 0:
105
+ num_frames = len(pil_images)
106
+
107
+ c = len(pil_images)
108
+ for i in range(0, c, num_frames):
109
+ file = f"{filename}_{counter:05}_.webp"
110
+ pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
111
+ results.append({
112
+ "filename": file,
113
+ "subfolder": subfolder,
114
+ "type": self.type
115
+ })
116
+ counter += 1
117
+
118
+ animated = num_frames != 1
119
+ return { "ui": { "images": results, "animated": (animated,) } }
120
+
121
+ class SaveAnimatedPNG:
122
+ def __init__(self):
123
+ self.output_dir = ldm_patched.utils.path_utils.get_output_directory()
124
+ self.type = "output"
125
+ self.prefix_append = ""
126
+
127
+ @classmethod
128
+ def INPUT_TYPES(s):
129
+ return {"required":
130
+ {"images": ("IMAGE", ),
131
+ "filename_prefix": ("STRING", {"default": "ldm_patched"}),
132
+ "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
133
+ "compress_level": ("INT", {"default": 4, "min": 0, "max": 9})
134
+ },
135
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
136
+ }
137
+
138
+ RETURN_TYPES = ()
139
+ FUNCTION = "save_images"
140
+
141
+ OUTPUT_NODE = True
142
+
143
+ CATEGORY = "image/animation"
144
+
145
+ def save_images(self, images, fps, compress_level, filename_prefix="ldm_patched", prompt=None, extra_pnginfo=None):
146
+ filename_prefix += self.prefix_append
147
+ full_output_folder, filename, counter, subfolder, filename_prefix = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
148
+ results = list()
149
+ pil_images = []
150
+ for image in images:
151
+ i = 255. * image.cpu().numpy()
152
+ img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
153
+ pil_images.append(img)
154
+
155
+ metadata = None
156
+ if not args.disable_server_info:
157
+ metadata = PngInfo()
158
+ if prompt is not None:
159
+ metadata.add(b"ldm_patched", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True)
160
+ if extra_pnginfo is not None:
161
+ for x in extra_pnginfo:
162
+ metadata.add(b"ldm_patched", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True)
163
+
164
+ file = f"{filename}_{counter:05}_.png"
165
+ pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:])
166
+ results.append({
167
+ "filename": file,
168
+ "subfolder": subfolder,
169
+ "type": self.type
170
+ })
171
+
172
+ return { "ui": { "images": results, "animated": (True,)} }
173
+
174
+ NODE_CLASS_MAPPINGS = {
175
+ "ImageCrop": ImageCrop,
176
+ "RepeatImageBatch": RepeatImageBatch,
177
+ "SaveAnimatedWEBP": SaveAnimatedWEBP,
178
+ "SaveAnimatedPNG": SaveAnimatedPNG,
179
+ }
ldm_patched/contrib/external_latent.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ import ldm_patched.modules.utils
6
+ import torch
7
+
8
+ def reshape_latent_to(target_shape, latent):
9
+ if latent.shape[1:] != target_shape[1:]:
10
+ latent = ldm_patched.modules.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center")
11
+ return ldm_patched.modules.utils.repeat_to_batch_size(latent, target_shape[0])
12
+
13
+
14
+ class LatentAdd:
15
+ @classmethod
16
+ def INPUT_TYPES(s):
17
+ return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
18
+
19
+ RETURN_TYPES = ("LATENT",)
20
+ FUNCTION = "op"
21
+
22
+ CATEGORY = "latent/advanced"
23
+
24
+ def op(self, samples1, samples2):
25
+ samples_out = samples1.copy()
26
+
27
+ s1 = samples1["samples"]
28
+ s2 = samples2["samples"]
29
+
30
+ s2 = reshape_latent_to(s1.shape, s2)
31
+ samples_out["samples"] = s1 + s2
32
+ return (samples_out,)
33
+
34
+ class LatentSubtract:
35
+ @classmethod
36
+ def INPUT_TYPES(s):
37
+ return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
38
+
39
+ RETURN_TYPES = ("LATENT",)
40
+ FUNCTION = "op"
41
+
42
+ CATEGORY = "latent/advanced"
43
+
44
+ def op(self, samples1, samples2):
45
+ samples_out = samples1.copy()
46
+
47
+ s1 = samples1["samples"]
48
+ s2 = samples2["samples"]
49
+
50
+ s2 = reshape_latent_to(s1.shape, s2)
51
+ samples_out["samples"] = s1 - s2
52
+ return (samples_out,)
53
+
54
+ class LatentMultiply:
55
+ @classmethod
56
+ def INPUT_TYPES(s):
57
+ return {"required": { "samples": ("LATENT",),
58
+ "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
59
+ }}
60
+
61
+ RETURN_TYPES = ("LATENT",)
62
+ FUNCTION = "op"
63
+
64
+ CATEGORY = "latent/advanced"
65
+
66
+ def op(self, samples, multiplier):
67
+ samples_out = samples.copy()
68
+
69
+ s1 = samples["samples"]
70
+ samples_out["samples"] = s1 * multiplier
71
+ return (samples_out,)
72
+
73
+ class LatentInterpolate:
74
+ @classmethod
75
+ def INPUT_TYPES(s):
76
+ return {"required": { "samples1": ("LATENT",),
77
+ "samples2": ("LATENT",),
78
+ "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
79
+ }}
80
+
81
+ RETURN_TYPES = ("LATENT",)
82
+ FUNCTION = "op"
83
+
84
+ CATEGORY = "latent/advanced"
85
+
86
+ def op(self, samples1, samples2, ratio):
87
+ samples_out = samples1.copy()
88
+
89
+ s1 = samples1["samples"]
90
+ s2 = samples2["samples"]
91
+
92
+ s2 = reshape_latent_to(s1.shape, s2)
93
+
94
+ m1 = torch.linalg.vector_norm(s1, dim=(1))
95
+ m2 = torch.linalg.vector_norm(s2, dim=(1))
96
+
97
+ s1 = torch.nan_to_num(s1 / m1)
98
+ s2 = torch.nan_to_num(s2 / m2)
99
+
100
+ t = (s1 * ratio + s2 * (1.0 - ratio))
101
+ mt = torch.linalg.vector_norm(t, dim=(1))
102
+ st = torch.nan_to_num(t / mt)
103
+
104
+ samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
105
+ return (samples_out,)
106
+
107
+ class LatentBatch:
108
+ @classmethod
109
+ def INPUT_TYPES(s):
110
+ return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
111
+
112
+ RETURN_TYPES = ("LATENT",)
113
+ FUNCTION = "batch"
114
+
115
+ CATEGORY = "latent/batch"
116
+
117
+ def batch(self, samples1, samples2):
118
+ samples_out = samples1.copy()
119
+ s1 = samples1["samples"]
120
+ s2 = samples2["samples"]
121
+
122
+ if s1.shape[1:] != s2.shape[1:]:
123
+ s2 = ldm_patched.modules.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center")
124
+ s = torch.cat((s1, s2), dim=0)
125
+ samples_out["samples"] = s
126
+ samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
127
+ return (samples_out,)
128
+
129
+ class LatentBatchSeedBehavior:
130
+ @classmethod
131
+ def INPUT_TYPES(s):
132
+ return {"required": { "samples": ("LATENT",),
133
+ "seed_behavior": (["random", "fixed"],{"default": "fixed"}),}}
134
+
135
+ RETURN_TYPES = ("LATENT",)
136
+ FUNCTION = "op"
137
+
138
+ CATEGORY = "latent/advanced"
139
+
140
+ def op(self, samples, seed_behavior):
141
+ samples_out = samples.copy()
142
+ latent = samples["samples"]
143
+ if seed_behavior == "random":
144
+ if 'batch_index' in samples_out:
145
+ samples_out.pop('batch_index')
146
+ elif seed_behavior == "fixed":
147
+ batch_number = samples_out.get("batch_index", [0])[0]
148
+ samples_out["batch_index"] = [batch_number] * latent.shape[0]
149
+
150
+ return (samples_out,)
151
+
152
+ NODE_CLASS_MAPPINGS = {
153
+ "LatentAdd": LatentAdd,
154
+ "LatentSubtract": LatentSubtract,
155
+ "LatentMultiply": LatentMultiply,
156
+ "LatentInterpolate": LatentInterpolate,
157
+ "LatentBatch": LatentBatch,
158
+ "LatentBatchSeedBehavior": LatentBatchSeedBehavior,
159
+ }
ldm_patched/contrib/external_mask.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ import numpy as np
6
+ import scipy.ndimage
7
+ import torch
8
+ import ldm_patched.modules.utils
9
+
10
+ from ldm_patched.contrib.external import MAX_RESOLUTION
11
+
12
+ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
13
+ source = source.to(destination.device)
14
+ if resize_source:
15
+ source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
16
+
17
+ source = ldm_patched.modules.utils.repeat_to_batch_size(source, destination.shape[0])
18
+
19
+ x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
20
+ y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))
21
+
22
+ left, top = (x // multiplier, y // multiplier)
23
+ right, bottom = (left + source.shape[3], top + source.shape[2],)
24
+
25
+ if mask is None:
26
+ mask = torch.ones_like(source)
27
+ else:
28
+ mask = mask.to(destination.device, copy=True)
29
+ mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
30
+ mask = ldm_patched.modules.utils.repeat_to_batch_size(mask, source.shape[0])
31
+
32
+ # calculate the bounds of the source that will be overlapping the destination
33
+ # this prevents the source trying to overwrite latent pixels that are out of bounds
34
+ # of the destination
35
+ visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),)
36
+
37
+ mask = mask[:, :, :visible_height, :visible_width]
38
+ inverse_mask = torch.ones_like(mask) - mask
39
+
40
+ source_portion = mask * source[:, :, :visible_height, :visible_width]
41
+ destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]
42
+
43
+ destination[:, :, top:bottom, left:right] = source_portion + destination_portion
44
+ return destination
45
+
46
+ class LatentCompositeMasked:
47
+ @classmethod
48
+ def INPUT_TYPES(s):
49
+ return {
50
+ "required": {
51
+ "destination": ("LATENT",),
52
+ "source": ("LATENT",),
53
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
54
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
55
+ "resize_source": ("BOOLEAN", {"default": False}),
56
+ },
57
+ "optional": {
58
+ "mask": ("MASK",),
59
+ }
60
+ }
61
+ RETURN_TYPES = ("LATENT",)
62
+ FUNCTION = "composite"
63
+
64
+ CATEGORY = "latent"
65
+
66
+ def composite(self, destination, source, x, y, resize_source, mask = None):
67
+ output = destination.copy()
68
+ destination = destination["samples"].clone()
69
+ source = source["samples"]
70
+ output["samples"] = composite(destination, source, x, y, mask, 8, resize_source)
71
+ return (output,)
72
+
73
+ class ImageCompositeMasked:
74
+ @classmethod
75
+ def INPUT_TYPES(s):
76
+ return {
77
+ "required": {
78
+ "destination": ("IMAGE",),
79
+ "source": ("IMAGE",),
80
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
81
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
82
+ "resize_source": ("BOOLEAN", {"default": False}),
83
+ },
84
+ "optional": {
85
+ "mask": ("MASK",),
86
+ }
87
+ }
88
+ RETURN_TYPES = ("IMAGE",)
89
+ FUNCTION = "composite"
90
+
91
+ CATEGORY = "image"
92
+
93
+ def composite(self, destination, source, x, y, resize_source, mask = None):
94
+ destination = destination.clone().movedim(-1, 1)
95
+ output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
96
+ return (output,)
97
+
98
+ class MaskToImage:
99
+ @classmethod
100
+ def INPUT_TYPES(s):
101
+ return {
102
+ "required": {
103
+ "mask": ("MASK",),
104
+ }
105
+ }
106
+
107
+ CATEGORY = "mask"
108
+
109
+ RETURN_TYPES = ("IMAGE",)
110
+ FUNCTION = "mask_to_image"
111
+
112
+ def mask_to_image(self, mask):
113
+ result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
114
+ return (result,)
115
+
116
+ class ImageToMask:
117
+ @classmethod
118
+ def INPUT_TYPES(s):
119
+ return {
120
+ "required": {
121
+ "image": ("IMAGE",),
122
+ "channel": (["red", "green", "blue", "alpha"],),
123
+ }
124
+ }
125
+
126
+ CATEGORY = "mask"
127
+
128
+ RETURN_TYPES = ("MASK",)
129
+ FUNCTION = "image_to_mask"
130
+
131
+ def image_to_mask(self, image, channel):
132
+ channels = ["red", "green", "blue", "alpha"]
133
+ mask = image[:, :, :, channels.index(channel)]
134
+ return (mask,)
135
+
136
+ class ImageColorToMask:
137
+ @classmethod
138
+ def INPUT_TYPES(s):
139
+ return {
140
+ "required": {
141
+ "image": ("IMAGE",),
142
+ "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
143
+ }
144
+ }
145
+
146
+ CATEGORY = "mask"
147
+
148
+ RETURN_TYPES = ("MASK",)
149
+ FUNCTION = "image_to_mask"
150
+
151
+ def image_to_mask(self, image, color):
152
+ temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
153
+ temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2]
154
+ mask = torch.where(temp == color, 255, 0).float()
155
+ return (mask,)
156
+
157
+ class SolidMask:
158
+ @classmethod
159
+ def INPUT_TYPES(cls):
160
+ return {
161
+ "required": {
162
+ "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
163
+ "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
164
+ "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
165
+ }
166
+ }
167
+
168
+ CATEGORY = "mask"
169
+
170
+ RETURN_TYPES = ("MASK",)
171
+
172
+ FUNCTION = "solid"
173
+
174
+ def solid(self, value, width, height):
175
+ out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu")
176
+ return (out,)
177
+
178
+ class InvertMask:
179
+ @classmethod
180
+ def INPUT_TYPES(cls):
181
+ return {
182
+ "required": {
183
+ "mask": ("MASK",),
184
+ }
185
+ }
186
+
187
+ CATEGORY = "mask"
188
+
189
+ RETURN_TYPES = ("MASK",)
190
+
191
+ FUNCTION = "invert"
192
+
193
+ def invert(self, mask):
194
+ out = 1.0 - mask
195
+ return (out,)
196
+
197
+ class CropMask:
198
+ @classmethod
199
+ def INPUT_TYPES(cls):
200
+ return {
201
+ "required": {
202
+ "mask": ("MASK",),
203
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
204
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
205
+ "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
206
+ "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
207
+ }
208
+ }
209
+
210
+ CATEGORY = "mask"
211
+
212
+ RETURN_TYPES = ("MASK",)
213
+
214
+ FUNCTION = "crop"
215
+
216
+ def crop(self, mask, x, y, width, height):
217
+ mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
218
+ out = mask[:, y:y + height, x:x + width]
219
+ return (out,)
220
+
221
+ class MaskComposite:
222
+ @classmethod
223
+ def INPUT_TYPES(cls):
224
+ return {
225
+ "required": {
226
+ "destination": ("MASK",),
227
+ "source": ("MASK",),
228
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
229
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
230
+ "operation": (["multiply", "add", "subtract", "and", "or", "xor"],),
231
+ }
232
+ }
233
+
234
+ CATEGORY = "mask"
235
+
236
+ RETURN_TYPES = ("MASK",)
237
+
238
+ FUNCTION = "combine"
239
+
240
+ def combine(self, destination, source, x, y, operation):
241
+ output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
242
+ source = source.reshape((-1, source.shape[-2], source.shape[-1]))
243
+
244
+ left, top = (x, y,)
245
+ right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2]))
246
+ visible_width, visible_height = (right - left, bottom - top,)
247
+
248
+ source_portion = source[:, :visible_height, :visible_width]
249
+ destination_portion = destination[:, top:bottom, left:right]
250
+
251
+ if operation == "multiply":
252
+ output[:, top:bottom, left:right] = destination_portion * source_portion
253
+ elif operation == "add":
254
+ output[:, top:bottom, left:right] = destination_portion + source_portion
255
+ elif operation == "subtract":
256
+ output[:, top:bottom, left:right] = destination_portion - source_portion
257
+ elif operation == "and":
258
+ output[:, top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float()
259
+ elif operation == "or":
260
+ output[:, top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float()
261
+ elif operation == "xor":
262
+ output[:, top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float()
263
+
264
+ output = torch.clamp(output, 0.0, 1.0)
265
+
266
+ return (output,)
267
+
268
+ class FeatherMask:
269
+ @classmethod
270
+ def INPUT_TYPES(cls):
271
+ return {
272
+ "required": {
273
+ "mask": ("MASK",),
274
+ "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
275
+ "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
276
+ "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
277
+ "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
278
+ }
279
+ }
280
+
281
+ CATEGORY = "mask"
282
+
283
+ RETURN_TYPES = ("MASK",)
284
+
285
+ FUNCTION = "feather"
286
+
287
+ def feather(self, mask, left, top, right, bottom):
288
+ output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
289
+
290
+ left = min(left, output.shape[-1])
291
+ right = min(right, output.shape[-1])
292
+ top = min(top, output.shape[-2])
293
+ bottom = min(bottom, output.shape[-2])
294
+
295
+ for x in range(left):
296
+ feather_rate = (x + 1.0) / left
297
+ output[:, :, x] *= feather_rate
298
+
299
+ for x in range(right):
300
+ feather_rate = (x + 1) / right
301
+ output[:, :, -x] *= feather_rate
302
+
303
+ for y in range(top):
304
+ feather_rate = (y + 1) / top
305
+ output[:, y, :] *= feather_rate
306
+
307
+ for y in range(bottom):
308
+ feather_rate = (y + 1) / bottom
309
+ output[:, -y, :] *= feather_rate
310
+
311
+ return (output,)
312
+
313
+ class GrowMask:
314
+ @classmethod
315
+ def INPUT_TYPES(cls):
316
+ return {
317
+ "required": {
318
+ "mask": ("MASK",),
319
+ "expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}),
320
+ "tapered_corners": ("BOOLEAN", {"default": True}),
321
+ },
322
+ }
323
+
324
+ CATEGORY = "mask"
325
+
326
+ RETURN_TYPES = ("MASK",)
327
+
328
+ FUNCTION = "expand_mask"
329
+
330
+ def expand_mask(self, mask, expand, tapered_corners):
331
+ c = 0 if tapered_corners else 1
332
+ kernel = np.array([[c, 1, c],
333
+ [1, 1, 1],
334
+ [c, 1, c]])
335
+ mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
336
+ out = []
337
+ for m in mask:
338
+ output = m.numpy()
339
+ for _ in range(abs(expand)):
340
+ if expand < 0:
341
+ output = scipy.ndimage.grey_erosion(output, footprint=kernel)
342
+ else:
343
+ output = scipy.ndimage.grey_dilation(output, footprint=kernel)
344
+ output = torch.from_numpy(output)
345
+ out.append(output)
346
+ return (torch.stack(out, dim=0),)
347
+
348
+
349
+
350
+ NODE_CLASS_MAPPINGS = {
351
+ "LatentCompositeMasked": LatentCompositeMasked,
352
+ "ImageCompositeMasked": ImageCompositeMasked,
353
+ "MaskToImage": MaskToImage,
354
+ "ImageToMask": ImageToMask,
355
+ "ImageColorToMask": ImageColorToMask,
356
+ "SolidMask": SolidMask,
357
+ "InvertMask": InvertMask,
358
+ "CropMask": CropMask,
359
+ "MaskComposite": MaskComposite,
360
+ "FeatherMask": FeatherMask,
361
+ "GrowMask": GrowMask,
362
+ }
363
+
364
+ NODE_DISPLAY_NAME_MAPPINGS = {
365
+ "ImageToMask": "Convert Image to Mask",
366
+ "MaskToImage": "Convert Mask to Image",
367
+ }
ldm_patched/contrib/external_model_advanced.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ import ldm_patched.utils.path_utils
6
+ import ldm_patched.modules.sd
7
+ import ldm_patched.modules.model_sampling
8
+ import torch
9
+
10
+ class LCM(ldm_patched.modules.model_sampling.EPS):
11
+ def calculate_denoised(self, sigma, model_output, model_input):
12
+ timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
13
+ sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
14
+ x0 = model_input - model_output * sigma
15
+
16
+ sigma_data = 0.5
17
+ scaled_timestep = timestep * 10.0 #timestep_scaling
18
+
19
+ c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
20
+ c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
21
+
22
+ return c_out * x0 + c_skip * model_input
23
+
24
+ class ModelSamplingDiscreteDistilled(ldm_patched.modules.model_sampling.ModelSamplingDiscrete):
25
+ original_timesteps = 50
26
+
27
+ def __init__(self, model_config=None):
28
+ super().__init__(model_config)
29
+
30
+ self.skip_steps = self.num_timesteps // self.original_timesteps
31
+
32
+ sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
33
+ for x in range(self.original_timesteps):
34
+ sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps]
35
+
36
+ self.set_sigmas(sigmas_valid)
37
+
38
+ def timestep(self, sigma):
39
+ log_sigma = sigma.log()
40
+ dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
41
+ return (dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)).to(sigma.device)
42
+
43
+ def sigma(self, timestep):
44
+ t = torch.clamp(((timestep.float().to(self.log_sigmas.device) - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
45
+ low_idx = t.floor().long()
46
+ high_idx = t.ceil().long()
47
+ w = t.frac()
48
+ log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
49
+ return log_sigma.exp().to(timestep.device)
50
+
51
+
52
+ def rescale_zero_terminal_snr_sigmas(sigmas):
53
+ alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
54
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
55
+
56
+ # Store old values.
57
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
58
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
59
+
60
+ # Shift so the last timestep is zero.
61
+ alphas_bar_sqrt -= (alphas_bar_sqrt_T)
62
+
63
+ # Scale so the first timestep is back to the old value.
64
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
65
+
66
+ # Convert alphas_bar_sqrt to betas
67
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
68
+ alphas_bar[-1] = 4.8973451890853435e-08
69
+ return ((1 - alphas_bar) / alphas_bar) ** 0.5
70
+
71
+ class ModelSamplingDiscrete:
72
+ @classmethod
73
+ def INPUT_TYPES(s):
74
+ return {"required": { "model": ("MODEL",),
75
+ "sampling": (["eps", "v_prediction", "lcm"],),
76
+ "zsnr": ("BOOLEAN", {"default": False}),
77
+ }}
78
+
79
+ RETURN_TYPES = ("MODEL",)
80
+ FUNCTION = "patch"
81
+
82
+ CATEGORY = "advanced/model"
83
+
84
+ def patch(self, model, sampling, zsnr):
85
+ m = model.clone()
86
+
87
+ sampling_base = ldm_patched.modules.model_sampling.ModelSamplingDiscrete
88
+ if sampling == "eps":
89
+ sampling_type = ldm_patched.modules.model_sampling.EPS
90
+ elif sampling == "v_prediction":
91
+ sampling_type = ldm_patched.modules.model_sampling.V_PREDICTION
92
+ elif sampling == "lcm":
93
+ sampling_type = LCM
94
+ sampling_base = ModelSamplingDiscreteDistilled
95
+
96
+ class ModelSamplingAdvanced(sampling_base, sampling_type):
97
+ pass
98
+
99
+ model_sampling = ModelSamplingAdvanced(model.model.model_config)
100
+ if zsnr:
101
+ model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
102
+
103
+ m.add_object_patch("model_sampling", model_sampling)
104
+ return (m, )
105
+
106
+ class ModelSamplingContinuousEDM:
107
+ @classmethod
108
+ def INPUT_TYPES(s):
109
+ return {"required": { "model": ("MODEL",),
110
+ "sampling": (["v_prediction", "eps"],),
111
+ "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
112
+ "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
113
+ }}
114
+
115
+ RETURN_TYPES = ("MODEL",)
116
+ FUNCTION = "patch"
117
+
118
+ CATEGORY = "advanced/model"
119
+
120
+ def patch(self, model, sampling, sigma_max, sigma_min):
121
+ m = model.clone()
122
+
123
+ if sampling == "eps":
124
+ sampling_type = ldm_patched.modules.model_sampling.EPS
125
+ elif sampling == "v_prediction":
126
+ sampling_type = ldm_patched.modules.model_sampling.V_PREDICTION
127
+
128
+ class ModelSamplingAdvanced(ldm_patched.modules.model_sampling.ModelSamplingContinuousEDM, sampling_type):
129
+ pass
130
+
131
+ model_sampling = ModelSamplingAdvanced(model.model.model_config)
132
+ model_sampling.set_sigma_range(sigma_min, sigma_max)
133
+ m.add_object_patch("model_sampling", model_sampling)
134
+ return (m, )
135
+
136
+ class RescaleCFG:
137
+ @classmethod
138
+ def INPUT_TYPES(s):
139
+ return {"required": { "model": ("MODEL",),
140
+ "multiplier": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
141
+ }}
142
+ RETURN_TYPES = ("MODEL",)
143
+ FUNCTION = "patch"
144
+
145
+ CATEGORY = "advanced/model"
146
+
147
+ def patch(self, model, multiplier):
148
+ def rescale_cfg(args):
149
+ cond = args["cond"]
150
+ uncond = args["uncond"]
151
+ cond_scale = args["cond_scale"]
152
+ sigma = args["sigma"]
153
+ sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))
154
+ x_orig = args["input"]
155
+
156
+ #rescale cfg has to be done on v-pred model output
157
+ x = x_orig / (sigma * sigma + 1.0)
158
+ cond = ((x - (x_orig - cond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
159
+ uncond = ((x - (x_orig - uncond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
160
+
161
+ #rescalecfg
162
+ x_cfg = uncond + cond_scale * (cond - uncond)
163
+ ro_pos = torch.std(cond, dim=(1,2,3), keepdim=True)
164
+ ro_cfg = torch.std(x_cfg, dim=(1,2,3), keepdim=True)
165
+
166
+ x_rescaled = x_cfg * (ro_pos / ro_cfg)
167
+ x_final = multiplier * x_rescaled + (1.0 - multiplier) * x_cfg
168
+
169
+ return x_orig - (x - x_final * sigma / (sigma * sigma + 1.0) ** 0.5)
170
+
171
+ m = model.clone()
172
+ m.set_model_sampler_cfg_function(rescale_cfg)
173
+ return (m, )
174
+
175
+ NODE_CLASS_MAPPINGS = {
176
+ "ModelSamplingDiscrete": ModelSamplingDiscrete,
177
+ "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
178
+ "RescaleCFG": RescaleCFG,
179
+ }
ldm_patched/contrib/external_model_downscale.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ import torch
6
+ import ldm_patched.modules.utils
7
+
8
+ class PatchModelAddDownscale:
9
+ upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
10
+ @classmethod
11
+ def INPUT_TYPES(s):
12
+ return {"required": { "model": ("MODEL",),
13
+ "block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
14
+ "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
15
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
16
+ "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
17
+ "downscale_after_skip": ("BOOLEAN", {"default": True}),
18
+ "downscale_method": (s.upscale_methods,),
19
+ "upscale_method": (s.upscale_methods,),
20
+ }}
21
+ RETURN_TYPES = ("MODEL",)
22
+ FUNCTION = "patch"
23
+
24
+ CATEGORY = "_for_testing"
25
+
26
+ def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
27
+ sigma_start = model.model.model_sampling.percent_to_sigma(start_percent)
28
+ sigma_end = model.model.model_sampling.percent_to_sigma(end_percent)
29
+
30
+ def input_block_patch(h, transformer_options):
31
+ if transformer_options["block"][1] == block_number:
32
+ sigma = transformer_options["sigmas"][0].item()
33
+ if sigma <= sigma_start and sigma >= sigma_end:
34
+ h = ldm_patched.modules.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
35
+ return h
36
+
37
+ def output_block_patch(h, hsp, transformer_options):
38
+ if h.shape[2] != hsp.shape[2]:
39
+ h = ldm_patched.modules.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
40
+ return h, hsp
41
+
42
+ m = model.clone()
43
+ if downscale_after_skip:
44
+ m.set_model_input_block_patch_after_skip(input_block_patch)
45
+ else:
46
+ m.set_model_input_block_patch(input_block_patch)
47
+ m.set_model_output_block_patch(output_block_patch)
48
+ return (m, )
49
+
50
+ NODE_CLASS_MAPPINGS = {
51
+ "PatchModelAddDownscale": PatchModelAddDownscale,
52
+ }
53
+
54
+ NODE_DISPLAY_NAME_MAPPINGS = {
55
+ # Sampling
56
+ "PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)",
57
+ }
ldm_patched/contrib/external_model_merging.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ import ldm_patched.modules.sd
6
+ import ldm_patched.modules.utils
7
+ import ldm_patched.modules.model_base
8
+ import ldm_patched.modules.model_management
9
+
10
+ import ldm_patched.utils.path_utils
11
+ import json
12
+ import os
13
+
14
+ from ldm_patched.modules.args_parser import args
15
+
16
+ class ModelMergeSimple:
17
+ @classmethod
18
+ def INPUT_TYPES(s):
19
+ return {"required": { "model1": ("MODEL",),
20
+ "model2": ("MODEL",),
21
+ "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
22
+ }}
23
+ RETURN_TYPES = ("MODEL",)
24
+ FUNCTION = "merge"
25
+
26
+ CATEGORY = "advanced/model_merging"
27
+
28
+ def merge(self, model1, model2, ratio):
29
+ m = model1.clone()
30
+ kp = model2.get_key_patches("diffusion_model.")
31
+ for k in kp:
32
+ m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
33
+ return (m, )
34
+
35
+ class ModelSubtract:
36
+ @classmethod
37
+ def INPUT_TYPES(s):
38
+ return {"required": { "model1": ("MODEL",),
39
+ "model2": ("MODEL",),
40
+ "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
41
+ }}
42
+ RETURN_TYPES = ("MODEL",)
43
+ FUNCTION = "merge"
44
+
45
+ CATEGORY = "advanced/model_merging"
46
+
47
+ def merge(self, model1, model2, multiplier):
48
+ m = model1.clone()
49
+ kp = model2.get_key_patches("diffusion_model.")
50
+ for k in kp:
51
+ m.add_patches({k: kp[k]}, - multiplier, multiplier)
52
+ return (m, )
53
+
54
+ class ModelAdd:
55
+ @classmethod
56
+ def INPUT_TYPES(s):
57
+ return {"required": { "model1": ("MODEL",),
58
+ "model2": ("MODEL",),
59
+ }}
60
+ RETURN_TYPES = ("MODEL",)
61
+ FUNCTION = "merge"
62
+
63
+ CATEGORY = "advanced/model_merging"
64
+
65
+ def merge(self, model1, model2):
66
+ m = model1.clone()
67
+ kp = model2.get_key_patches("diffusion_model.")
68
+ for k in kp:
69
+ m.add_patches({k: kp[k]}, 1.0, 1.0)
70
+ return (m, )
71
+
72
+
73
+ class CLIPMergeSimple:
74
+ @classmethod
75
+ def INPUT_TYPES(s):
76
+ return {"required": { "clip1": ("CLIP",),
77
+ "clip2": ("CLIP",),
78
+ "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
79
+ }}
80
+ RETURN_TYPES = ("CLIP",)
81
+ FUNCTION = "merge"
82
+
83
+ CATEGORY = "advanced/model_merging"
84
+
85
+ def merge(self, clip1, clip2, ratio):
86
+ m = clip1.clone()
87
+ kp = clip2.get_key_patches()
88
+ for k in kp:
89
+ if k.endswith(".position_ids") or k.endswith(".logit_scale"):
90
+ continue
91
+ m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
92
+ return (m, )
93
+
94
+ class ModelMergeBlocks:
95
+ @classmethod
96
+ def INPUT_TYPES(s):
97
+ return {"required": { "model1": ("MODEL",),
98
+ "model2": ("MODEL",),
99
+ "input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
100
+ "middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
101
+ "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
102
+ }}
103
+ RETURN_TYPES = ("MODEL",)
104
+ FUNCTION = "merge"
105
+
106
+ CATEGORY = "advanced/model_merging"
107
+
108
+ def merge(self, model1, model2, **kwargs):
109
+ m = model1.clone()
110
+ kp = model2.get_key_patches("diffusion_model.")
111
+ default_ratio = next(iter(kwargs.values()))
112
+
113
+ for k in kp:
114
+ ratio = default_ratio
115
+ k_unet = k[len("diffusion_model."):]
116
+
117
+ last_arg_size = 0
118
+ for arg in kwargs:
119
+ if k_unet.startswith(arg) and last_arg_size < len(arg):
120
+ ratio = kwargs[arg]
121
+ last_arg_size = len(arg)
122
+
123
+ m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
124
+ return (m, )
125
+
126
+ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
127
+ full_output_folder, filename, counter, subfolder, filename_prefix = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix, output_dir)
128
+ prompt_info = ""
129
+ if prompt is not None:
130
+ prompt_info = json.dumps(prompt)
131
+
132
+ metadata = {}
133
+
134
+ enable_modelspec = True
135
+ if isinstance(model.model, ldm_patched.modules.model_base.SDXL):
136
+ metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
137
+ elif isinstance(model.model, ldm_patched.modules.model_base.SDXLRefiner):
138
+ metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
139
+ else:
140
+ enable_modelspec = False
141
+
142
+ if enable_modelspec:
143
+ metadata["modelspec.sai_model_spec"] = "1.0.0"
144
+ metadata["modelspec.implementation"] = "sgm"
145
+ metadata["modelspec.title"] = "{} {}".format(filename, counter)
146
+
147
+ #TODO:
148
+ # "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
149
+ # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
150
+ # "v2-inpainting"
151
+
152
+ if model.model.model_type == ldm_patched.modules.model_base.ModelType.EPS:
153
+ metadata["modelspec.predict_key"] = "epsilon"
154
+ elif model.model.model_type == ldm_patched.modules.model_base.ModelType.V_PREDICTION:
155
+ metadata["modelspec.predict_key"] = "v"
156
+
157
+ if not args.disable_server_info:
158
+ metadata["prompt"] = prompt_info
159
+ if extra_pnginfo is not None:
160
+ for x in extra_pnginfo:
161
+ metadata[x] = json.dumps(extra_pnginfo[x])
162
+
163
+ output_checkpoint = f"{filename}_{counter:05}_.safetensors"
164
+ output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
165
+
166
+ ldm_patched.modules.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata)
167
+
168
+ class CheckpointSave:
169
+ def __init__(self):
170
+ self.output_dir = ldm_patched.utils.path_utils.get_output_directory()
171
+
172
+ @classmethod
173
+ def INPUT_TYPES(s):
174
+ return {"required": { "model": ("MODEL",),
175
+ "clip": ("CLIP",),
176
+ "vae": ("VAE",),
177
+ "filename_prefix": ("STRING", {"default": "checkpoints/ldm_patched"}),},
178
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
179
+ RETURN_TYPES = ()
180
+ FUNCTION = "save"
181
+ OUTPUT_NODE = True
182
+
183
+ CATEGORY = "advanced/model_merging"
184
+
185
+ def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
186
+ save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
187
+ return {}
188
+
189
+ class CLIPSave:
190
+ def __init__(self):
191
+ self.output_dir = ldm_patched.utils.path_utils.get_output_directory()
192
+
193
+ @classmethod
194
+ def INPUT_TYPES(s):
195
+ return {"required": { "clip": ("CLIP",),
196
+ "filename_prefix": ("STRING", {"default": "clip/ldm_patched"}),},
197
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
198
+ RETURN_TYPES = ()
199
+ FUNCTION = "save"
200
+ OUTPUT_NODE = True
201
+
202
+ CATEGORY = "advanced/model_merging"
203
+
204
+ def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
205
+ prompt_info = ""
206
+ if prompt is not None:
207
+ prompt_info = json.dumps(prompt)
208
+
209
+ metadata = {}
210
+ if not args.disable_server_info:
211
+ metadata["prompt"] = prompt_info
212
+ if extra_pnginfo is not None:
213
+ for x in extra_pnginfo:
214
+ metadata[x] = json.dumps(extra_pnginfo[x])
215
+
216
+ ldm_patched.modules.model_management.load_models_gpu([clip.load_model()])
217
+ clip_sd = clip.get_sd()
218
+
219
+ for prefix in ["clip_l.", "clip_g.", ""]:
220
+ k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
221
+ current_clip_sd = {}
222
+ for x in k:
223
+ current_clip_sd[x] = clip_sd.pop(x)
224
+ if len(current_clip_sd) == 0:
225
+ continue
226
+
227
+ p = prefix[:-1]
228
+ replace_prefix = {}
229
+ filename_prefix_ = filename_prefix
230
+ if len(p) > 0:
231
+ filename_prefix_ = "{}_{}".format(filename_prefix_, p)
232
+ replace_prefix[prefix] = ""
233
+ replace_prefix["transformer."] = ""
234
+
235
+ full_output_folder, filename, counter, subfolder, filename_prefix_ = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix_, self.output_dir)
236
+
237
+ output_checkpoint = f"{filename}_{counter:05}_.safetensors"
238
+ output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
239
+
240
+ current_clip_sd = ldm_patched.modules.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
241
+
242
+ ldm_patched.modules.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
243
+ return {}
244
+
245
+ class VAESave:
246
+ def __init__(self):
247
+ self.output_dir = ldm_patched.utils.path_utils.get_output_directory()
248
+
249
+ @classmethod
250
+ def INPUT_TYPES(s):
251
+ return {"required": { "vae": ("VAE",),
252
+ "filename_prefix": ("STRING", {"default": "vae/ldm_patched_vae"}),},
253
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
254
+ RETURN_TYPES = ()
255
+ FUNCTION = "save"
256
+ OUTPUT_NODE = True
257
+
258
+ CATEGORY = "advanced/model_merging"
259
+
260
+ def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None):
261
+ full_output_folder, filename, counter, subfolder, filename_prefix = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix, self.output_dir)
262
+ prompt_info = ""
263
+ if prompt is not None:
264
+ prompt_info = json.dumps(prompt)
265
+
266
+ metadata = {}
267
+ if not args.disable_server_info:
268
+ metadata["prompt"] = prompt_info
269
+ if extra_pnginfo is not None:
270
+ for x in extra_pnginfo:
271
+ metadata[x] = json.dumps(extra_pnginfo[x])
272
+
273
+ output_checkpoint = f"{filename}_{counter:05}_.safetensors"
274
+ output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
275
+
276
+ ldm_patched.modules.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
277
+ return {}
278
+
279
+ NODE_CLASS_MAPPINGS = {
280
+ "ModelMergeSimple": ModelMergeSimple,
281
+ "ModelMergeBlocks": ModelMergeBlocks,
282
+ "ModelMergeSubtract": ModelSubtract,
283
+ "ModelMergeAdd": ModelAdd,
284
+ "CheckpointSave": CheckpointSave,
285
+ "CLIPMergeSimple": CLIPMergeSimple,
286
+ "CLIPSave": CLIPSave,
287
+ "VAESave": VAESave,
288
+ }
ldm_patched/contrib/external_perpneg.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ import torch
6
+ import ldm_patched.modules.model_management
7
+ import ldm_patched.modules.sample
8
+ import ldm_patched.modules.samplers
9
+ import ldm_patched.modules.utils
10
+
11
+
12
+ class PerpNeg:
13
+ @classmethod
14
+ def INPUT_TYPES(s):
15
+ return {"required": {"model": ("MODEL", ),
16
+ "empty_conditioning": ("CONDITIONING", ),
17
+ "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}),
18
+ }}
19
+ RETURN_TYPES = ("MODEL",)
20
+ FUNCTION = "patch"
21
+
22
+ CATEGORY = "_for_testing"
23
+
24
+ def patch(self, model, empty_conditioning, neg_scale):
25
+ m = model.clone()
26
+ nocond = ldm_patched.modules.sample.convert_cond(empty_conditioning)
27
+
28
+ def cfg_function(args):
29
+ model = args["model"]
30
+ noise_pred_pos = args["cond_denoised"]
31
+ noise_pred_neg = args["uncond_denoised"]
32
+ cond_scale = args["cond_scale"]
33
+ x = args["input"]
34
+ sigma = args["sigma"]
35
+ model_options = args["model_options"]
36
+ nocond_processed = ldm_patched.modules.samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative")
37
+
38
+ (noise_pred_nocond, _) = ldm_patched.modules.samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options)
39
+
40
+ pos = noise_pred_pos - noise_pred_nocond
41
+ neg = noise_pred_neg - noise_pred_nocond
42
+ perp = ((torch.mul(pos, neg).sum())/(torch.norm(neg)**2)) * neg
43
+ perp_neg = perp * neg_scale
44
+ cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg)
45
+ cfg_result = x - cfg_result
46
+ return cfg_result
47
+
48
+ m.set_model_sampler_cfg_function(cfg_function)
49
+
50
+ return (m, )
51
+
52
+
53
+ NODE_CLASS_MAPPINGS = {
54
+ "PerpNeg": PerpNeg,
55
+ }
56
+
57
+ NODE_DISPLAY_NAME_MAPPINGS = {
58
+ "PerpNeg": "Perp-Neg",
59
+ }
ldm_patched/contrib/external_photomaker.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import ldm_patched.utils.path_utils
6
+ import ldm_patched.modules.clip_model
7
+ import ldm_patched.modules.clip_vision
8
+ import ldm_patched.modules.ops
9
+
10
+ # code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0
11
+ VISION_CONFIG_DICT = {
12
+ "hidden_size": 1024,
13
+ "image_size": 224,
14
+ "intermediate_size": 4096,
15
+ "num_attention_heads": 16,
16
+ "num_channels": 3,
17
+ "num_hidden_layers": 24,
18
+ "patch_size": 14,
19
+ "projection_dim": 768,
20
+ "hidden_act": "quick_gelu",
21
+ }
22
+
23
+ class MLP(nn.Module):
24
+ def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True, operations=ldm_patched.modules.ops):
25
+ super().__init__()
26
+ if use_residual:
27
+ assert in_dim == out_dim
28
+ self.layernorm = operations.LayerNorm(in_dim)
29
+ self.fc1 = operations.Linear(in_dim, hidden_dim)
30
+ self.fc2 = operations.Linear(hidden_dim, out_dim)
31
+ self.use_residual = use_residual
32
+ self.act_fn = nn.GELU()
33
+
34
+ def forward(self, x):
35
+ residual = x
36
+ x = self.layernorm(x)
37
+ x = self.fc1(x)
38
+ x = self.act_fn(x)
39
+ x = self.fc2(x)
40
+ if self.use_residual:
41
+ x = x + residual
42
+ return x
43
+
44
+
45
+ class FuseModule(nn.Module):
46
+ def __init__(self, embed_dim, operations):
47
+ super().__init__()
48
+ self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False, operations=operations)
49
+ self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True, operations=operations)
50
+ self.layer_norm = operations.LayerNorm(embed_dim)
51
+
52
+ def fuse_fn(self, prompt_embeds, id_embeds):
53
+ stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
54
+ stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
55
+ stacked_id_embeds = self.mlp2(stacked_id_embeds)
56
+ stacked_id_embeds = self.layer_norm(stacked_id_embeds)
57
+ return stacked_id_embeds
58
+
59
+ def forward(
60
+ self,
61
+ prompt_embeds,
62
+ id_embeds,
63
+ class_tokens_mask,
64
+ ) -> torch.Tensor:
65
+ # id_embeds shape: [b, max_num_inputs, 1, 2048]
66
+ id_embeds = id_embeds.to(prompt_embeds.dtype)
67
+ num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case
68
+ batch_size, max_num_inputs = id_embeds.shape[:2]
69
+ # seq_length: 77
70
+ seq_length = prompt_embeds.shape[1]
71
+ # flat_id_embeds shape: [b*max_num_inputs, 1, 2048]
72
+ flat_id_embeds = id_embeds.view(
73
+ -1, id_embeds.shape[-2], id_embeds.shape[-1]
74
+ )
75
+ # valid_id_mask [b*max_num_inputs]
76
+ valid_id_mask = (
77
+ torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :]
78
+ < num_inputs[:, None]
79
+ )
80
+ valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]
81
+
82
+ prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1])
83
+ class_tokens_mask = class_tokens_mask.view(-1)
84
+ valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1])
85
+ # slice out the image token embeddings
86
+ image_token_embeds = prompt_embeds[class_tokens_mask]
87
+ stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
88
+ assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
89
+ prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
90
+ updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1)
91
+ return updated_prompt_embeds
92
+
93
+ class PhotoMakerIDEncoder(ldm_patched.modules.clip_model.CLIPVisionModelProjection):
94
+ def __init__(self):
95
+ self.load_device = ldm_patched.modules.model_management.text_encoder_device()
96
+ offload_device = ldm_patched.modules.model_management.text_encoder_offload_device()
97
+ dtype = ldm_patched.modules.model_management.text_encoder_dtype(self.load_device)
98
+
99
+ super().__init__(VISION_CONFIG_DICT, dtype, offload_device, ldm_patched.modules.ops.manual_cast)
100
+ self.visual_projection_2 = ldm_patched.modules.ops.manual_cast.Linear(1024, 1280, bias=False)
101
+ self.fuse_module = FuseModule(2048, ldm_patched.modules.ops.manual_cast)
102
+
103
+ def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask):
104
+ b, num_inputs, c, h, w = id_pixel_values.shape
105
+ id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
106
+
107
+ shared_id_embeds = self.vision_model(id_pixel_values)[2]
108
+ id_embeds = self.visual_projection(shared_id_embeds)
109
+ id_embeds_2 = self.visual_projection_2(shared_id_embeds)
110
+
111
+ id_embeds = id_embeds.view(b, num_inputs, 1, -1)
112
+ id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
113
+
114
+ id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
115
+ updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
116
+
117
+ return updated_prompt_embeds
118
+
119
+
120
+ class PhotoMakerLoader:
121
+ @classmethod
122
+ def INPUT_TYPES(s):
123
+ return {"required": { "photomaker_model_name": (ldm_patched.utils.path_utils.get_filename_list("photomaker"), )}}
124
+
125
+ RETURN_TYPES = ("PHOTOMAKER",)
126
+ FUNCTION = "load_photomaker_model"
127
+
128
+ CATEGORY = "_for_testing/photomaker"
129
+
130
+ def load_photomaker_model(self, photomaker_model_name):
131
+ photomaker_model_path = ldm_patched.utils.path_utils.get_full_path("photomaker", photomaker_model_name)
132
+ photomaker_model = PhotoMakerIDEncoder()
133
+ data = ldm_patched.modules.utils.load_torch_file(photomaker_model_path, safe_load=True)
134
+ if "id_encoder" in data:
135
+ data = data["id_encoder"]
136
+ photomaker_model.load_state_dict(data)
137
+ return (photomaker_model,)
138
+
139
+
140
+ class PhotoMakerEncode:
141
+ @classmethod
142
+ def INPUT_TYPES(s):
143
+ return {"required": { "photomaker": ("PHOTOMAKER",),
144
+ "image": ("IMAGE",),
145
+ "clip": ("CLIP", ),
146
+ "text": ("STRING", {"multiline": True, "default": "photograph of photomaker"}),
147
+ }}
148
+
149
+ RETURN_TYPES = ("CONDITIONING",)
150
+ FUNCTION = "apply_photomaker"
151
+
152
+ CATEGORY = "_for_testing/photomaker"
153
+
154
+ def apply_photomaker(self, photomaker, image, clip, text):
155
+ special_token = "photomaker"
156
+ pixel_values = ldm_patched.modules.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float()
157
+ try:
158
+ index = text.split(" ").index(special_token) + 1
159
+ except ValueError:
160
+ index = -1
161
+ tokens = clip.tokenize(text, return_word_ids=True)
162
+ out_tokens = {}
163
+ for k in tokens:
164
+ out_tokens[k] = []
165
+ for t in tokens[k]:
166
+ f = list(filter(lambda x: x[2] != index, t))
167
+ while len(f) < len(t):
168
+ f.append(t[-1])
169
+ out_tokens[k].append(f)
170
+
171
+ cond, pooled = clip.encode_from_tokens(out_tokens, return_pooled=True)
172
+
173
+ if index > 0:
174
+ token_index = index - 1
175
+ num_id_images = 1
176
+ class_tokens_mask = [True if token_index <= i < token_index+num_id_images else False for i in range(77)]
177
+ out = photomaker(id_pixel_values=pixel_values.unsqueeze(0), prompt_embeds=cond.to(photomaker.load_device),
178
+ class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0))
179
+ else:
180
+ out = cond
181
+
182
+ return ([[out, {"pooled_output": pooled}]], )
183
+
184
+
185
+ NODE_CLASS_MAPPINGS = {
186
+ "PhotoMakerLoader": PhotoMakerLoader,
187
+ "PhotoMakerEncode": PhotoMakerEncode,
188
+ }
189
+
ldm_patched/contrib/external_post_processing.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from PIL import Image
9
+ import math
10
+
11
+ import ldm_patched.modules.utils
12
+
13
+
14
+ class Blend:
15
+ def __init__(self):
16
+ pass
17
+
18
+ @classmethod
19
+ def INPUT_TYPES(s):
20
+ return {
21
+ "required": {
22
+ "image1": ("IMAGE",),
23
+ "image2": ("IMAGE",),
24
+ "blend_factor": ("FLOAT", {
25
+ "default": 0.5,
26
+ "min": 0.0,
27
+ "max": 1.0,
28
+ "step": 0.01
29
+ }),
30
+ "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light", "difference"],),
31
+ },
32
+ }
33
+
34
+ RETURN_TYPES = ("IMAGE",)
35
+ FUNCTION = "blend_images"
36
+
37
+ CATEGORY = "image/postprocessing"
38
+
39
+ def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
40
+ image2 = image2.to(image1.device)
41
+ if image1.shape != image2.shape:
42
+ image2 = image2.permute(0, 3, 1, 2)
43
+ image2 = ldm_patched.modules.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')
44
+ image2 = image2.permute(0, 2, 3, 1)
45
+
46
+ blended_image = self.blend_mode(image1, image2, blend_mode)
47
+ blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
48
+ blended_image = torch.clamp(blended_image, 0, 1)
49
+ return (blended_image,)
50
+
51
+ def blend_mode(self, img1, img2, mode):
52
+ if mode == "normal":
53
+ return img2
54
+ elif mode == "multiply":
55
+ return img1 * img2
56
+ elif mode == "screen":
57
+ return 1 - (1 - img1) * (1 - img2)
58
+ elif mode == "overlay":
59
+ return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
60
+ elif mode == "soft_light":
61
+ return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
62
+ elif mode == "difference":
63
+ return img1 - img2
64
+ else:
65
+ raise ValueError(f"Unsupported blend mode: {mode}")
66
+
67
+ def g(self, x):
68
+ return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
69
+
70
+ def gaussian_kernel(kernel_size: int, sigma: float, device=None):
71
+ x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
72
+ d = torch.sqrt(x * x + y * y)
73
+ g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
74
+ return g / g.sum()
75
+
76
+ class Blur:
77
+ def __init__(self):
78
+ pass
79
+
80
+ @classmethod
81
+ def INPUT_TYPES(s):
82
+ return {
83
+ "required": {
84
+ "image": ("IMAGE",),
85
+ "blur_radius": ("INT", {
86
+ "default": 1,
87
+ "min": 1,
88
+ "max": 31,
89
+ "step": 1
90
+ }),
91
+ "sigma": ("FLOAT", {
92
+ "default": 1.0,
93
+ "min": 0.1,
94
+ "max": 10.0,
95
+ "step": 0.1
96
+ }),
97
+ },
98
+ }
99
+
100
+ RETURN_TYPES = ("IMAGE",)
101
+ FUNCTION = "blur"
102
+
103
+ CATEGORY = "image/postprocessing"
104
+
105
+ def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
106
+ if blur_radius == 0:
107
+ return (image,)
108
+
109
+ batch_size, height, width, channels = image.shape
110
+
111
+ kernel_size = blur_radius * 2 + 1
112
+ kernel = gaussian_kernel(kernel_size, sigma, device=image.device).repeat(channels, 1, 1).unsqueeze(1)
113
+
114
+ image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
115
+ padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
116
+ blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius]
117
+ blurred = blurred.permute(0, 2, 3, 1)
118
+
119
+ return (blurred,)
120
+
121
+ class Quantize:
122
+ def __init__(self):
123
+ pass
124
+
125
+ @classmethod
126
+ def INPUT_TYPES(s):
127
+ return {
128
+ "required": {
129
+ "image": ("IMAGE",),
130
+ "colors": ("INT", {
131
+ "default": 256,
132
+ "min": 1,
133
+ "max": 256,
134
+ "step": 1
135
+ }),
136
+ "dither": (["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"],),
137
+ },
138
+ }
139
+
140
+ RETURN_TYPES = ("IMAGE",)
141
+ FUNCTION = "quantize"
142
+
143
+ CATEGORY = "image/postprocessing"
144
+
145
+ def bayer(im, pal_im, order):
146
+ def normalized_bayer_matrix(n):
147
+ if n == 0:
148
+ return np.zeros((1,1), "float32")
149
+ else:
150
+ q = 4 ** n
151
+ m = q * normalized_bayer_matrix(n - 1)
152
+ return np.bmat(((m-1.5, m+0.5), (m+1.5, m-0.5))) / q
153
+
154
+ num_colors = len(pal_im.getpalette()) // 3
155
+ spread = 2 * 256 / num_colors
156
+ bayer_n = int(math.log2(order))
157
+ bayer_matrix = torch.from_numpy(spread * normalized_bayer_matrix(bayer_n) + 0.5)
158
+
159
+ result = torch.from_numpy(np.array(im).astype(np.float32))
160
+ tw = math.ceil(result.shape[0] / bayer_matrix.shape[0])
161
+ th = math.ceil(result.shape[1] / bayer_matrix.shape[1])
162
+ tiled_matrix = bayer_matrix.tile(tw, th).unsqueeze(-1)
163
+ result.add_(tiled_matrix[:result.shape[0],:result.shape[1]]).clamp_(0, 255)
164
+ result = result.to(dtype=torch.uint8)
165
+
166
+ im = Image.fromarray(result.cpu().numpy())
167
+ im = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
168
+ return im
169
+
170
+ def quantize(self, image: torch.Tensor, colors: int, dither: str):
171
+ batch_size, height, width, _ = image.shape
172
+ result = torch.zeros_like(image)
173
+
174
+ for b in range(batch_size):
175
+ im = Image.fromarray((image[b] * 255).to(torch.uint8).numpy(), mode='RGB')
176
+
177
+ pal_im = im.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
178
+
179
+ if dither == "none":
180
+ quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
181
+ elif dither == "floyd-steinberg":
182
+ quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.FLOYDSTEINBERG)
183
+ elif dither.startswith("bayer"):
184
+ order = int(dither.split('-')[-1])
185
+ quantized_image = Quantize.bayer(im, pal_im, order)
186
+
187
+ quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
188
+ result[b] = quantized_array
189
+
190
+ return (result,)
191
+
192
+ class Sharpen:
193
+ def __init__(self):
194
+ pass
195
+
196
+ @classmethod
197
+ def INPUT_TYPES(s):
198
+ return {
199
+ "required": {
200
+ "image": ("IMAGE",),
201
+ "sharpen_radius": ("INT", {
202
+ "default": 1,
203
+ "min": 1,
204
+ "max": 31,
205
+ "step": 1
206
+ }),
207
+ "sigma": ("FLOAT", {
208
+ "default": 1.0,
209
+ "min": 0.1,
210
+ "max": 10.0,
211
+ "step": 0.1
212
+ }),
213
+ "alpha": ("FLOAT", {
214
+ "default": 1.0,
215
+ "min": 0.0,
216
+ "max": 5.0,
217
+ "step": 0.1
218
+ }),
219
+ },
220
+ }
221
+
222
+ RETURN_TYPES = ("IMAGE",)
223
+ FUNCTION = "sharpen"
224
+
225
+ CATEGORY = "image/postprocessing"
226
+
227
+ def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float):
228
+ if sharpen_radius == 0:
229
+ return (image,)
230
+
231
+ batch_size, height, width, channels = image.shape
232
+
233
+ kernel_size = sharpen_radius * 2 + 1
234
+ kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
235
+ center = kernel_size // 2
236
+ kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
237
+ kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
238
+
239
+ tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
240
+ tensor_image = F.pad(tensor_image, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
241
+ sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
242
+ sharpened = sharpened.permute(0, 2, 3, 1)
243
+
244
+ result = torch.clamp(sharpened, 0, 1)
245
+
246
+ return (result,)
247
+
248
+ class ImageScaleToTotalPixels:
249
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
250
+ crop_methods = ["disabled", "center"]
251
+
252
+ @classmethod
253
+ def INPUT_TYPES(s):
254
+ return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
255
+ "megapixels": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 16.0, "step": 0.01}),
256
+ }}
257
+ RETURN_TYPES = ("IMAGE",)
258
+ FUNCTION = "upscale"
259
+
260
+ CATEGORY = "image/upscaling"
261
+
262
+ def upscale(self, image, upscale_method, megapixels):
263
+ samples = image.movedim(-1,1)
264
+ total = int(megapixels * 1024 * 1024)
265
+
266
+ scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
267
+ width = round(samples.shape[3] * scale_by)
268
+ height = round(samples.shape[2] * scale_by)
269
+
270
+ s = ldm_patched.modules.utils.common_upscale(samples, width, height, upscale_method, "disabled")
271
+ s = s.movedim(1,-1)
272
+ return (s,)
273
+
274
+ NODE_CLASS_MAPPINGS = {
275
+ "ImageBlend": Blend,
276
+ "ImageBlur": Blur,
277
+ "ImageQuantize": Quantize,
278
+ "ImageSharpen": Sharpen,
279
+ "ImageScaleToTotalPixels": ImageScaleToTotalPixels,
280
+ }
ldm_patched/contrib/external_rebatch.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ import torch
6
+
7
+ class LatentRebatch:
8
+ @classmethod
9
+ def INPUT_TYPES(s):
10
+ return {"required": { "latents": ("LATENT",),
11
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
12
+ }}
13
+ RETURN_TYPES = ("LATENT",)
14
+ INPUT_IS_LIST = True
15
+ OUTPUT_IS_LIST = (True, )
16
+
17
+ FUNCTION = "rebatch"
18
+
19
+ CATEGORY = "latent/batch"
20
+
21
+ @staticmethod
22
+ def get_batch(latents, list_ind, offset):
23
+ '''prepare a batch out of the list of latents'''
24
+ samples = latents[list_ind]['samples']
25
+ shape = samples.shape
26
+ mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
27
+ if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]:
28
+ torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")
29
+ if mask.shape[0] < samples.shape[0]:
30
+ mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]
31
+ if 'batch_index' in latents[list_ind]:
32
+ batch_inds = latents[list_ind]['batch_index']
33
+ else:
34
+ batch_inds = [x+offset for x in range(shape[0])]
35
+ return samples, mask, batch_inds
36
+
37
+ @staticmethod
38
+ def get_slices(indexable, num, batch_size):
39
+ '''divides an indexable object into num slices of length batch_size, and a remainder'''
40
+ slices = []
41
+ for i in range(num):
42
+ slices.append(indexable[i*batch_size:(i+1)*batch_size])
43
+ if num * batch_size < len(indexable):
44
+ return slices, indexable[num * batch_size:]
45
+ else:
46
+ return slices, None
47
+
48
+ @staticmethod
49
+ def slice_batch(batch, num, batch_size):
50
+ result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
51
+ return list(zip(*result))
52
+
53
+ @staticmethod
54
+ def cat_batch(batch1, batch2):
55
+ if batch1[0] is None:
56
+ return batch2
57
+ result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
58
+ return result
59
+
60
+ def rebatch(self, latents, batch_size):
61
+ batch_size = batch_size[0]
62
+
63
+ output_list = []
64
+ current_batch = (None, None, None)
65
+ processed = 0
66
+
67
+ for i in range(len(latents)):
68
+ # fetch new entry of list
69
+ #samples, masks, indices = self.get_batch(latents, i)
70
+ next_batch = self.get_batch(latents, i, processed)
71
+ processed += len(next_batch[2])
72
+ # set to current if current is None
73
+ if current_batch[0] is None:
74
+ current_batch = next_batch
75
+ # add previous to list if dimensions do not match
76
+ elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
77
+ sliced, _ = self.slice_batch(current_batch, 1, batch_size)
78
+ output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
79
+ current_batch = next_batch
80
+ # cat if everything checks out
81
+ else:
82
+ current_batch = self.cat_batch(current_batch, next_batch)
83
+
84
+ # add to list if dimensions gone above target batch size
85
+ if current_batch[0].shape[0] > batch_size:
86
+ num = current_batch[0].shape[0] // batch_size
87
+ sliced, remainder = self.slice_batch(current_batch, num, batch_size)
88
+
89
+ for i in range(num):
90
+ output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
91
+
92
+ current_batch = remainder
93
+
94
+ #add remainder
95
+ if current_batch[0] is not None:
96
+ sliced, _ = self.slice_batch(current_batch, 1, batch_size)
97
+ output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
98
+
99
+ #get rid of empty masks
100
+ for s in output_list:
101
+ if s['noise_mask'].mean() == 1.0:
102
+ del s['noise_mask']
103
+
104
+ return (output_list,)
105
+
106
+ class ImageRebatch:
107
+ @classmethod
108
+ def INPUT_TYPES(s):
109
+ return {"required": { "images": ("IMAGE",),
110
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
111
+ }}
112
+ RETURN_TYPES = ("IMAGE",)
113
+ INPUT_IS_LIST = True
114
+ OUTPUT_IS_LIST = (True, )
115
+
116
+ FUNCTION = "rebatch"
117
+
118
+ CATEGORY = "image/batch"
119
+
120
+ def rebatch(self, images, batch_size):
121
+ batch_size = batch_size[0]
122
+
123
+ output_list = []
124
+ all_images = []
125
+ for img in images:
126
+ for i in range(img.shape[0]):
127
+ all_images.append(img[i:i+1])
128
+
129
+ for i in range(0, len(all_images), batch_size):
130
+ output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
131
+
132
+ return (output_list,)
133
+
134
+ NODE_CLASS_MAPPINGS = {
135
+ "RebatchLatents": LatentRebatch,
136
+ "RebatchImages": ImageRebatch,
137
+ }
138
+
139
+ NODE_DISPLAY_NAME_MAPPINGS = {
140
+ "RebatchLatents": "Rebatch Latents",
141
+ "RebatchImages": "Rebatch Images",
142
+ }
ldm_patched/contrib/external_sag.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import torch
4
+ from torch import einsum
5
+ import torch.nn.functional as F
6
+ import math
7
+
8
+ from einops import rearrange, repeat
9
+ import os
10
+ from ldm_patched.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
11
+ import ldm_patched.modules.samplers
12
+
13
+ # from ldm_patched.modules/ldm/modules/attention.py
14
+ # but modified to return attention scores as well as output
15
+ def attention_basic_with_sim(q, k, v, heads, mask=None):
16
+ b, _, dim_head = q.shape
17
+ dim_head //= heads
18
+ scale = dim_head ** -0.5
19
+
20
+ h = heads
21
+ q, k, v = map(
22
+ lambda t: t.unsqueeze(3)
23
+ .reshape(b, -1, heads, dim_head)
24
+ .permute(0, 2, 1, 3)
25
+ .reshape(b * heads, -1, dim_head)
26
+ .contiguous(),
27
+ (q, k, v),
28
+ )
29
+
30
+ # force cast to fp32 to avoid overflowing
31
+ if _ATTN_PRECISION =="fp32":
32
+ sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
33
+ else:
34
+ sim = einsum('b i d, b j d -> b i j', q, k) * scale
35
+
36
+ del q, k
37
+
38
+ if mask is not None:
39
+ mask = rearrange(mask, 'b ... -> b (...)')
40
+ max_neg_value = -torch.finfo(sim.dtype).max
41
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
42
+ sim.masked_fill_(~mask, max_neg_value)
43
+
44
+ # attention, what we cannot get enough of
45
+ sim = sim.softmax(dim=-1)
46
+
47
+ out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
48
+ out = (
49
+ out.unsqueeze(0)
50
+ .reshape(b, heads, -1, dim_head)
51
+ .permute(0, 2, 1, 3)
52
+ .reshape(b, -1, heads * dim_head)
53
+ )
54
+ return (out, sim)
55
+
56
+ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
57
+ # reshape and GAP the attention map
58
+ _, hw1, hw2 = attn.shape
59
+ b, _, lh, lw = x0.shape
60
+ attn = attn.reshape(b, -1, hw1, hw2)
61
+ # Global Average Pool
62
+ mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
63
+ ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
64
+ mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
65
+
66
+ # Reshape
67
+ mask = (
68
+ mask.reshape(b, *mid_shape)
69
+ .unsqueeze(1)
70
+ .type(attn.dtype)
71
+ )
72
+ # Upsample
73
+ mask = F.interpolate(mask, (lh, lw))
74
+
75
+ blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
76
+ blurred = blurred * mask + x0 * (1 - mask)
77
+ return blurred
78
+
79
+ def gaussian_blur_2d(img, kernel_size, sigma):
80
+ ksize_half = (kernel_size - 1) * 0.5
81
+
82
+ x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
83
+
84
+ pdf = torch.exp(-0.5 * (x / sigma).pow(2))
85
+
86
+ x_kernel = pdf / pdf.sum()
87
+ x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
88
+
89
+ kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
90
+ kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
91
+
92
+ padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
93
+
94
+ img = F.pad(img, padding, mode="reflect")
95
+ img = F.conv2d(img, kernel2d, groups=img.shape[-3])
96
+ return img
97
+
98
+ class SelfAttentionGuidance:
99
+ @classmethod
100
+ def INPUT_TYPES(s):
101
+ return {"required": { "model": ("MODEL",),
102
+ "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}),
103
+ "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
104
+ }}
105
+ RETURN_TYPES = ("MODEL",)
106
+ FUNCTION = "patch"
107
+
108
+ CATEGORY = "_for_testing"
109
+
110
+ def patch(self, model, scale, blur_sigma):
111
+ m = model.clone()
112
+
113
+ attn_scores = None
114
+
115
+ # TODO: make this work properly with chunked batches
116
+ # currently, we can only save the attn from one UNet call
117
+ def attn_and_record(q, k, v, extra_options):
118
+ nonlocal attn_scores
119
+ # if uncond, save the attention scores
120
+ heads = extra_options["n_heads"]
121
+ cond_or_uncond = extra_options["cond_or_uncond"]
122
+ b = q.shape[0] // len(cond_or_uncond)
123
+ if 1 in cond_or_uncond:
124
+ uncond_index = cond_or_uncond.index(1)
125
+ # do the entire attention operation, but save the attention scores to attn_scores
126
+ (out, sim) = attention_basic_with_sim(q, k, v, heads=heads)
127
+ # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
128
+ n_slices = heads * b
129
+ attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
130
+ return out
131
+ else:
132
+ return optimized_attention(q, k, v, heads=heads)
133
+
134
+ def post_cfg_function(args):
135
+ nonlocal attn_scores
136
+ uncond_attn = attn_scores
137
+
138
+ sag_scale = scale
139
+ sag_sigma = blur_sigma
140
+ sag_threshold = 1.0
141
+ model = args["model"]
142
+ uncond_pred = args["uncond_denoised"]
143
+ uncond = args["uncond"]
144
+ cfg_result = args["denoised"]
145
+ sigma = args["sigma"]
146
+ model_options = args["model_options"]
147
+ x = args["input"]
148
+ if min(cfg_result.shape[2:]) <= 4: #skip when too small to add padding
149
+ return cfg_result
150
+
151
+ # create the adversarially blurred image
152
+ degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
153
+ degraded_noised = degraded + x - uncond_pred
154
+ # call into the UNet
155
+ (sag, _) = ldm_patched.modules.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
156
+ return cfg_result + (degraded - sag) * sag_scale
157
+
158
+ m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
159
+
160
+ # from diffusers:
161
+ # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
162
+ m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
163
+
164
+ return (m, )
165
+
166
+ NODE_CLASS_MAPPINGS = {
167
+ "SelfAttentionGuidance": SelfAttentionGuidance,
168
+ }
169
+
170
+ NODE_DISPLAY_NAME_MAPPINGS = {
171
+ "SelfAttentionGuidance": "Self-Attention Guidance",
172
+ }
ldm_patched/contrib/external_sdupscale.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ import torch
6
+ import ldm_patched.contrib.external
7
+ import ldm_patched.modules.utils
8
+
9
+ class SD_4XUpscale_Conditioning:
10
+ @classmethod
11
+ def INPUT_TYPES(s):
12
+ return {"required": { "images": ("IMAGE",),
13
+ "positive": ("CONDITIONING",),
14
+ "negative": ("CONDITIONING",),
15
+ "scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}),
16
+ "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
17
+ }}
18
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
19
+ RETURN_NAMES = ("positive", "negative", "latent")
20
+
21
+ FUNCTION = "encode"
22
+
23
+ CATEGORY = "conditioning/upscale_diffusion"
24
+
25
+ def encode(self, images, positive, negative, scale_ratio, noise_augmentation):
26
+ width = max(1, round(images.shape[-2] * scale_ratio))
27
+ height = max(1, round(images.shape[-3] * scale_ratio))
28
+
29
+ pixels = ldm_patched.modules.utils.common_upscale((images.movedim(-1,1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center")
30
+
31
+ out_cp = []
32
+ out_cn = []
33
+
34
+ for t in positive:
35
+ n = [t[0], t[1].copy()]
36
+ n[1]['concat_image'] = pixels
37
+ n[1]['noise_augmentation'] = noise_augmentation
38
+ out_cp.append(n)
39
+
40
+ for t in negative:
41
+ n = [t[0], t[1].copy()]
42
+ n[1]['concat_image'] = pixels
43
+ n[1]['noise_augmentation'] = noise_augmentation
44
+ out_cn.append(n)
45
+
46
+ latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
47
+ return (out_cp, out_cn, {"samples":latent})
48
+
49
+ NODE_CLASS_MAPPINGS = {
50
+ "SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning,
51
+ }
ldm_patched/contrib/external_stable3d.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import torch
4
+ import ldm_patched.contrib.external
5
+ import ldm_patched.modules.utils
6
+
7
+ def camera_embeddings(elevation, azimuth):
8
+ elevation = torch.as_tensor([elevation])
9
+ azimuth = torch.as_tensor([azimuth])
10
+ embeddings = torch.stack(
11
+ [
12
+ torch.deg2rad(
13
+ (90 - elevation) - (90)
14
+ ), # Zero123 polar is 90-elevation
15
+ torch.sin(torch.deg2rad(azimuth)),
16
+ torch.cos(torch.deg2rad(azimuth)),
17
+ torch.deg2rad(
18
+ 90 - torch.full_like(elevation, 0)
19
+ ),
20
+ ], dim=-1).unsqueeze(1)
21
+
22
+ return embeddings
23
+
24
+
25
+ class StableZero123_Conditioning:
26
+ @classmethod
27
+ def INPUT_TYPES(s):
28
+ return {"required": { "clip_vision": ("CLIP_VISION",),
29
+ "init_image": ("IMAGE",),
30
+ "vae": ("VAE",),
31
+ "width": ("INT", {"default": 256, "min": 16, "max": ldm_patched.contrib.external.MAX_RESOLUTION, "step": 8}),
32
+ "height": ("INT", {"default": 256, "min": 16, "max": ldm_patched.contrib.external.MAX_RESOLUTION, "step": 8}),
33
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
34
+ "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
35
+ "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
36
+ }}
37
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
38
+ RETURN_NAMES = ("positive", "negative", "latent")
39
+
40
+ FUNCTION = "encode"
41
+
42
+ CATEGORY = "conditioning/3d_models"
43
+
44
+ def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth):
45
+ output = clip_vision.encode_image(init_image)
46
+ pooled = output.image_embeds.unsqueeze(0)
47
+ pixels = ldm_patched.modules.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
48
+ encode_pixels = pixels[:,:,:,:3]
49
+ t = vae.encode(encode_pixels)
50
+ cam_embeds = camera_embeddings(elevation, azimuth)
51
+ cond = torch.cat([pooled, cam_embeds.to(pooled.device).repeat((pooled.shape[0], 1, 1))], dim=-1)
52
+
53
+ positive = [[cond, {"concat_latent_image": t}]]
54
+ negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
55
+ latent = torch.zeros([batch_size, 4, height // 8, width // 8])
56
+ return (positive, negative, {"samples":latent})
57
+
58
+ class StableZero123_Conditioning_Batched:
59
+ @classmethod
60
+ def INPUT_TYPES(s):
61
+ return {"required": { "clip_vision": ("CLIP_VISION",),
62
+ "init_image": ("IMAGE",),
63
+ "vae": ("VAE",),
64
+ "width": ("INT", {"default": 256, "min": 16, "max": ldm_patched.contrib.external.MAX_RESOLUTION, "step": 8}),
65
+ "height": ("INT", {"default": 256, "min": 16, "max": ldm_patched.contrib.external.MAX_RESOLUTION, "step": 8}),
66
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
67
+ "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
68
+ "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
69
+ "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
70
+ "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
71
+ }}
72
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
73
+ RETURN_NAMES = ("positive", "negative", "latent")
74
+
75
+ FUNCTION = "encode"
76
+
77
+ CATEGORY = "conditioning/3d_models"
78
+
79
+ def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment):
80
+ output = clip_vision.encode_image(init_image)
81
+ pooled = output.image_embeds.unsqueeze(0)
82
+ pixels = ldm_patched.modules.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
83
+ encode_pixels = pixels[:,:,:,:3]
84
+ t = vae.encode(encode_pixels)
85
+
86
+ cam_embeds = []
87
+ for i in range(batch_size):
88
+ cam_embeds.append(camera_embeddings(elevation, azimuth))
89
+ elevation += elevation_batch_increment
90
+ azimuth += azimuth_batch_increment
91
+
92
+ cam_embeds = torch.cat(cam_embeds, dim=0)
93
+ cond = torch.cat([ldm_patched.modules.utils.repeat_to_batch_size(pooled, batch_size), cam_embeds], dim=-1)
94
+
95
+ positive = [[cond, {"concat_latent_image": t}]]
96
+ negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
97
+ latent = torch.zeros([batch_size, 4, height // 8, width // 8])
98
+ return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size})
99
+
100
+
101
+ NODE_CLASS_MAPPINGS = {
102
+ "StableZero123_Conditioning": StableZero123_Conditioning,
103
+ "StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched,
104
+ }
ldm_patched/contrib/external_tomesd.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 1st edit: https://github.com/dbolya/tomesd
2
+ # 2nd edit: https://github.com/comfyanonymous/ComfyUI
3
+ # 3rd edit: Forge official
4
+
5
+ import torch
6
+ from typing import Tuple, Callable
7
+ import math
8
+
9
+ def do_nothing(x: torch.Tensor, mode:str=None):
10
+ return x
11
+
12
+
13
+ def mps_gather_workaround(input, dim, index):
14
+ if input.shape[-1] == 1:
15
+ return torch.gather(
16
+ input.unsqueeze(-1),
17
+ dim - 1 if dim < 0 else dim,
18
+ index.unsqueeze(-1)
19
+ ).squeeze(-1)
20
+ else:
21
+ return torch.gather(input, dim, index)
22
+
23
+
24
+ def bipartite_soft_matching_random2d(metric: torch.Tensor,
25
+ w: int, h: int, sx: int, sy: int, r: int,
26
+ no_rand: bool = False) -> Tuple[Callable, Callable]:
27
+ """
28
+ Partitions the tokens into src and dst and merges r tokens from src to dst.
29
+ Dst tokens are partitioned by choosing one randomy in each (sx, sy) region.
30
+ Args:
31
+ - metric [B, N, C]: metric to use for similarity
32
+ - w: image width in tokens
33
+ - h: image height in tokens
34
+ - sx: stride in the x dimension for dst, must divide w
35
+ - sy: stride in the y dimension for dst, must divide h
36
+ - r: number of tokens to remove (by merging)
37
+ - no_rand: if true, disable randomness (use top left corner only)
38
+ """
39
+ B, N, _ = metric.shape
40
+
41
+ if r <= 0 or w == 1 or h == 1:
42
+ return do_nothing, do_nothing
43
+
44
+ gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
45
+
46
+ with torch.no_grad():
47
+
48
+ hsy, wsx = h // sy, w // sx
49
+
50
+ # For each sy by sx kernel, randomly assign one token to be dst and the rest src
51
+ if no_rand:
52
+ rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
53
+ else:
54
+ rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)
55
+
56
+ # The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead
57
+ idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
58
+ idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
59
+ idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)
60
+
61
+ # Image is not divisible by sx or sy so we need to move it into a new buffer
62
+ if (hsy * sy) < h or (wsx * sx) < w:
63
+ idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
64
+ idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
65
+ else:
66
+ idx_buffer = idx_buffer_view
67
+
68
+ # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
69
+ rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)
70
+
71
+ # We're finished with these
72
+ del idx_buffer, idx_buffer_view
73
+
74
+ # rand_idx is currently dst|src, so split them
75
+ num_dst = hsy * wsx
76
+ a_idx = rand_idx[:, num_dst:, :] # src
77
+ b_idx = rand_idx[:, :num_dst, :] # dst
78
+
79
+ def split(x):
80
+ C = x.shape[-1]
81
+ src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
82
+ dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
83
+ return src, dst
84
+
85
+ # Cosine similarity between A and B
86
+ metric = metric / metric.norm(dim=-1, keepdim=True)
87
+ a, b = split(metric)
88
+ scores = a @ b.transpose(-1, -2)
89
+
90
+ # Can't reduce more than the # tokens in src
91
+ r = min(a.shape[1], r)
92
+
93
+ # Find the most similar greedily
94
+ node_max, node_idx = scores.max(dim=-1)
95
+ edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
96
+
97
+ unm_idx = edge_idx[..., r:, :] # Unmerged Tokens
98
+ src_idx = edge_idx[..., :r, :] # Merged Tokens
99
+ dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)
100
+
101
+ def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
102
+ src, dst = split(x)
103
+ n, t1, c = src.shape
104
+
105
+ unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
106
+ src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
107
+ dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
108
+
109
+ return torch.cat([unm, dst], dim=1)
110
+
111
+ def unmerge(x: torch.Tensor) -> torch.Tensor:
112
+ unm_len = unm_idx.shape[1]
113
+ unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
114
+ _, _, c = unm.shape
115
+
116
+ src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))
117
+
118
+ # Combine back to the original shape
119
+ out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
120
+ out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
121
+ out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
122
+ out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)
123
+
124
+ return out
125
+
126
+ return merge, unmerge
127
+
128
+
129
+ def get_functions(x, ratio, original_shape):
130
+ b, c, original_h, original_w = original_shape
131
+ original_tokens = original_h * original_w
132
+ downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
133
+ stride_x = 2
134
+ stride_y = 2
135
+ max_downsample = 1
136
+
137
+ if downsample <= max_downsample:
138
+ w = int(math.ceil(original_w / downsample))
139
+ h = int(math.ceil(original_h / downsample))
140
+ r = int(x.shape[1] * ratio)
141
+ no_rand = False
142
+ m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
143
+ return m, u
144
+
145
+ nothing = lambda y: y
146
+ return nothing, nothing
147
+
148
+
149
+ class TomePatcher:
150
+ def __init__(self):
151
+ self.u = None
152
+
153
+ def patch(self, model, ratio):
154
+ def tomesd_m(q, k, v, extra_options):
155
+ m, self.u = get_functions(q, ratio, extra_options["original_shape"])
156
+ return m(q), k, v
157
+
158
+ def tomesd_u(n, extra_options):
159
+ return self.u(n)
160
+
161
+ m = model.clone()
162
+ m.set_model_attn1_patch(tomesd_m)
163
+ m.set_model_attn1_output_patch(tomesd_u)
164
+ return m
ldm_patched/contrib/external_upscale_model.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ import os
6
+ from ldm_patched.pfn import model_loading
7
+ from ldm_patched.modules import model_management
8
+ import torch
9
+ import ldm_patched.modules.utils
10
+ import ldm_patched.utils.path_utils
11
+
12
+ class UpscaleModelLoader:
13
+ @classmethod
14
+ def INPUT_TYPES(s):
15
+ return {"required": { "model_name": (ldm_patched.utils.path_utils.get_filename_list("upscale_models"), ),
16
+ }}
17
+ RETURN_TYPES = ("UPSCALE_MODEL",)
18
+ FUNCTION = "load_model"
19
+
20
+ CATEGORY = "loaders"
21
+
22
+ def load_model(self, model_name):
23
+ model_path = ldm_patched.utils.path_utils.get_full_path("upscale_models", model_name)
24
+ sd = ldm_patched.modules.utils.load_torch_file(model_path, safe_load=True)
25
+ if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
26
+ sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"module.":""})
27
+ out = model_loading.load_state_dict(sd).eval()
28
+ return (out, )
29
+
30
+
31
+ class ImageUpscaleWithModel:
32
+ @classmethod
33
+ def INPUT_TYPES(s):
34
+ return {"required": { "upscale_model": ("UPSCALE_MODEL",),
35
+ "image": ("IMAGE",),
36
+ }}
37
+ RETURN_TYPES = ("IMAGE",)
38
+ FUNCTION = "upscale"
39
+
40
+ CATEGORY = "image/upscaling"
41
+
42
+ def upscale(self, upscale_model, image):
43
+ device = model_management.get_torch_device()
44
+ upscale_model.to(device)
45
+ in_img = image.movedim(-1,-3).to(device)
46
+ free_memory = model_management.get_free_memory(device)
47
+
48
+ tile = 512
49
+ overlap = 32
50
+
51
+ oom = True
52
+ while oom:
53
+ try:
54
+ steps = in_img.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
55
+ pbar = ldm_patched.modules.utils.ProgressBar(steps)
56
+ s = ldm_patched.modules.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
57
+ oom = False
58
+ except model_management.OOM_EXCEPTION as e:
59
+ tile //= 2
60
+ if tile < 128:
61
+ raise e
62
+
63
+ upscale_model.cpu()
64
+ s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
65
+ return (s,)
66
+
67
+ NODE_CLASS_MAPPINGS = {
68
+ "UpscaleModelLoader": UpscaleModelLoader,
69
+ "ImageUpscaleWithModel": ImageUpscaleWithModel
70
+ }
ldm_patched/contrib/external_video_model.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ import ldm_patched.contrib.external
6
+ import torch
7
+ import ldm_patched.modules.utils
8
+ import ldm_patched.modules.sd
9
+ import ldm_patched.utils.path_utils
10
+ import ldm_patched.contrib.external_model_merging
11
+
12
+
13
+ class ImageOnlyCheckpointLoader:
14
+ @classmethod
15
+ def INPUT_TYPES(s):
16
+ return {"required": { "ckpt_name": (ldm_patched.utils.path_utils.get_filename_list("checkpoints"), ),
17
+ }}
18
+ RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
19
+ FUNCTION = "load_checkpoint"
20
+
21
+ CATEGORY = "loaders/video_models"
22
+
23
+ def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
24
+ ckpt_path = ldm_patched.utils.path_utils.get_full_path("checkpoints", ckpt_name)
25
+ out = ldm_patched.modules.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=ldm_patched.utils.path_utils.get_folder_paths("embeddings"))
26
+ return (out[0], out[3], out[2])
27
+
28
+
29
+ class SVD_img2vid_Conditioning:
30
+ @classmethod
31
+ def INPUT_TYPES(s):
32
+ return {"required": { "clip_vision": ("CLIP_VISION",),
33
+ "init_image": ("IMAGE",),
34
+ "vae": ("VAE",),
35
+ "width": ("INT", {"default": 1024, "min": 16, "max": ldm_patched.contrib.external.MAX_RESOLUTION, "step": 8}),
36
+ "height": ("INT", {"default": 576, "min": 16, "max": ldm_patched.contrib.external.MAX_RESOLUTION, "step": 8}),
37
+ "video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}),
38
+ "motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}),
39
+ "fps": ("INT", {"default": 6, "min": 1, "max": 1024}),
40
+ "augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01})
41
+ }}
42
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
43
+ RETURN_NAMES = ("positive", "negative", "latent")
44
+
45
+ FUNCTION = "encode"
46
+
47
+ CATEGORY = "conditioning/video_models"
48
+
49
+ def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level):
50
+ output = clip_vision.encode_image(init_image)
51
+ pooled = output.image_embeds.unsqueeze(0)
52
+ pixels = ldm_patched.modules.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
53
+ encode_pixels = pixels[:,:,:,:3]
54
+ if augmentation_level > 0:
55
+ encode_pixels += torch.randn_like(pixels) * augmentation_level
56
+ t = vae.encode(encode_pixels)
57
+ positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]]
58
+ negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]]
59
+ latent = torch.zeros([video_frames, 4, height // 8, width // 8])
60
+ return (positive, negative, {"samples":latent})
61
+
62
+ class VideoLinearCFGGuidance:
63
+ @classmethod
64
+ def INPUT_TYPES(s):
65
+ return {"required": { "model": ("MODEL",),
66
+ "min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
67
+ }}
68
+ RETURN_TYPES = ("MODEL",)
69
+ FUNCTION = "patch"
70
+
71
+ CATEGORY = "sampling/video_models"
72
+
73
+ def patch(self, model, min_cfg):
74
+ def linear_cfg(args):
75
+ cond = args["cond"]
76
+ uncond = args["uncond"]
77
+ cond_scale = args["cond_scale"]
78
+
79
+ scale = torch.linspace(min_cfg, cond_scale, cond.shape[0], device=cond.device).reshape((cond.shape[0], 1, 1, 1))
80
+ return uncond + scale * (cond - uncond)
81
+
82
+ m = model.clone()
83
+ m.set_model_sampler_cfg_function(linear_cfg)
84
+ return (m, )
85
+
86
+ class ImageOnlyCheckpointSave(ldm_patched.contrib.external_model_merging.CheckpointSave):
87
+ CATEGORY = "_for_testing"
88
+
89
+ @classmethod
90
+ def INPUT_TYPES(s):
91
+ return {"required": { "model": ("MODEL",),
92
+ "clip_vision": ("CLIP_VISION",),
93
+ "vae": ("VAE",),
94
+ "filename_prefix": ("STRING", {"default": "checkpoints/ldm_patched"}),},
95
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
96
+
97
+ def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None):
98
+ ldm_patched.contrib.external_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
99
+ return {}
100
+
101
+ NODE_CLASS_MAPPINGS = {
102
+ "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
103
+ "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
104
+ "VideoLinearCFGGuidance": VideoLinearCFGGuidance,
105
+ "ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
106
+ }
107
+
108
+ NODE_DISPLAY_NAME_MAPPINGS = {
109
+ "ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)",
110
+ }
ldm_patched/controlnet/cldm.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #taken from: https://github.com/lllyasviel/ControlNet
2
+ #and modified
3
+
4
+ import torch
5
+ import torch as th
6
+ import torch.nn as nn
7
+
8
+ from ldm_patched.ldm.modules.diffusionmodules.util import (
9
+ zero_module,
10
+ timestep_embedding,
11
+ )
12
+
13
+ from ldm_patched.ldm.modules.attention import SpatialTransformer
14
+ from ldm_patched.ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
15
+ from ldm_patched.ldm.util import exists
16
+ import ldm_patched.modules.ops
17
+
18
+ class ControlledUnetModel(UNetModel):
19
+ #implemented in the ldm unet
20
+ pass
21
+
22
+ class ControlNet(nn.Module):
23
+ def __init__(
24
+ self,
25
+ image_size,
26
+ in_channels,
27
+ model_channels,
28
+ hint_channels,
29
+ num_res_blocks,
30
+ dropout=0,
31
+ channel_mult=(1, 2, 4, 8),
32
+ conv_resample=True,
33
+ dims=2,
34
+ num_classes=None,
35
+ use_checkpoint=False,
36
+ dtype=torch.float32,
37
+ num_heads=-1,
38
+ num_head_channels=-1,
39
+ num_heads_upsample=-1,
40
+ use_scale_shift_norm=False,
41
+ resblock_updown=False,
42
+ use_new_attention_order=False,
43
+ use_spatial_transformer=False, # custom transformer support
44
+ transformer_depth=1, # custom transformer support
45
+ context_dim=None, # custom transformer support
46
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
47
+ legacy=True,
48
+ disable_self_attentions=None,
49
+ num_attention_blocks=None,
50
+ disable_middle_self_attn=False,
51
+ use_linear_in_transformer=False,
52
+ adm_in_channels=None,
53
+ transformer_depth_middle=None,
54
+ transformer_depth_output=None,
55
+ device=None,
56
+ operations=ldm_patched.modules.ops.disable_weight_init,
57
+ **kwargs,
58
+ ):
59
+ super().__init__()
60
+ assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
61
+ if use_spatial_transformer:
62
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
63
+
64
+ if context_dim is not None:
65
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
66
+ # from omegaconf.listconfig import ListConfig
67
+ # if type(context_dim) == ListConfig:
68
+ # context_dim = list(context_dim)
69
+
70
+ if num_heads_upsample == -1:
71
+ num_heads_upsample = num_heads
72
+
73
+ if num_heads == -1:
74
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
75
+
76
+ if num_head_channels == -1:
77
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
78
+
79
+ self.dims = dims
80
+ self.image_size = image_size
81
+ self.in_channels = in_channels
82
+ self.model_channels = model_channels
83
+
84
+ if isinstance(num_res_blocks, int):
85
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
86
+ else:
87
+ if len(num_res_blocks) != len(channel_mult):
88
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
89
+ "as a list/tuple (per-level) with the same length as channel_mult")
90
+ self.num_res_blocks = num_res_blocks
91
+
92
+ if disable_self_attentions is not None:
93
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
94
+ assert len(disable_self_attentions) == len(channel_mult)
95
+ if num_attention_blocks is not None:
96
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
97
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
98
+
99
+ transformer_depth = transformer_depth[:]
100
+
101
+ self.dropout = dropout
102
+ self.channel_mult = channel_mult
103
+ self.conv_resample = conv_resample
104
+ self.num_classes = num_classes
105
+ self.use_checkpoint = use_checkpoint
106
+ self.dtype = dtype
107
+ self.num_heads = num_heads
108
+ self.num_head_channels = num_head_channels
109
+ self.num_heads_upsample = num_heads_upsample
110
+ self.predict_codebook_ids = n_embed is not None
111
+
112
+ time_embed_dim = model_channels * 4
113
+ self.time_embed = nn.Sequential(
114
+ operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
115
+ nn.SiLU(),
116
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
117
+ )
118
+
119
+ if self.num_classes is not None:
120
+ if isinstance(self.num_classes, int):
121
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
122
+ elif self.num_classes == "continuous":
123
+ print("setting up linear c_adm embedding layer")
124
+ self.label_emb = nn.Linear(1, time_embed_dim)
125
+ elif self.num_classes == "sequential":
126
+ assert adm_in_channels is not None
127
+ self.label_emb = nn.Sequential(
128
+ nn.Sequential(
129
+ operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
130
+ nn.SiLU(),
131
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
132
+ )
133
+ )
134
+ else:
135
+ raise ValueError()
136
+
137
+ self.input_blocks = nn.ModuleList(
138
+ [
139
+ TimestepEmbedSequential(
140
+ operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
141
+ )
142
+ ]
143
+ )
144
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
145
+
146
+ self.input_hint_block = TimestepEmbedSequential(
147
+ operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
148
+ nn.SiLU(),
149
+ operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
150
+ nn.SiLU(),
151
+ operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
152
+ nn.SiLU(),
153
+ operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
154
+ nn.SiLU(),
155
+ operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
156
+ nn.SiLU(),
157
+ operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
158
+ nn.SiLU(),
159
+ operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
160
+ nn.SiLU(),
161
+ operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
162
+ )
163
+
164
+ self._feature_size = model_channels
165
+ input_block_chans = [model_channels]
166
+ ch = model_channels
167
+ ds = 1
168
+ for level, mult in enumerate(channel_mult):
169
+ for nr in range(self.num_res_blocks[level]):
170
+ layers = [
171
+ ResBlock(
172
+ ch,
173
+ time_embed_dim,
174
+ dropout,
175
+ out_channels=mult * model_channels,
176
+ dims=dims,
177
+ use_checkpoint=use_checkpoint,
178
+ use_scale_shift_norm=use_scale_shift_norm,
179
+ dtype=self.dtype,
180
+ device=device,
181
+ operations=operations,
182
+ )
183
+ ]
184
+ ch = mult * model_channels
185
+ num_transformers = transformer_depth.pop(0)
186
+ if num_transformers > 0:
187
+ if num_head_channels == -1:
188
+ dim_head = ch // num_heads
189
+ else:
190
+ num_heads = ch // num_head_channels
191
+ dim_head = num_head_channels
192
+ if legacy:
193
+ #num_heads = 1
194
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
195
+ if exists(disable_self_attentions):
196
+ disabled_sa = disable_self_attentions[level]
197
+ else:
198
+ disabled_sa = False
199
+
200
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
201
+ layers.append(
202
+ SpatialTransformer(
203
+ ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
204
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
205
+ use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
206
+ )
207
+ )
208
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
209
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
210
+ self._feature_size += ch
211
+ input_block_chans.append(ch)
212
+ if level != len(channel_mult) - 1:
213
+ out_ch = ch
214
+ self.input_blocks.append(
215
+ TimestepEmbedSequential(
216
+ ResBlock(
217
+ ch,
218
+ time_embed_dim,
219
+ dropout,
220
+ out_channels=out_ch,
221
+ dims=dims,
222
+ use_checkpoint=use_checkpoint,
223
+ use_scale_shift_norm=use_scale_shift_norm,
224
+ down=True,
225
+ dtype=self.dtype,
226
+ device=device,
227
+ operations=operations
228
+ )
229
+ if resblock_updown
230
+ else Downsample(
231
+ ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
232
+ )
233
+ )
234
+ )
235
+ ch = out_ch
236
+ input_block_chans.append(ch)
237
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
238
+ ds *= 2
239
+ self._feature_size += ch
240
+
241
+ if num_head_channels == -1:
242
+ dim_head = ch // num_heads
243
+ else:
244
+ num_heads = ch // num_head_channels
245
+ dim_head = num_head_channels
246
+ if legacy:
247
+ #num_heads = 1
248
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
249
+ mid_block = [
250
+ ResBlock(
251
+ ch,
252
+ time_embed_dim,
253
+ dropout,
254
+ dims=dims,
255
+ use_checkpoint=use_checkpoint,
256
+ use_scale_shift_norm=use_scale_shift_norm,
257
+ dtype=self.dtype,
258
+ device=device,
259
+ operations=operations
260
+ )]
261
+ if transformer_depth_middle >= 0:
262
+ mid_block += [SpatialTransformer( # always uses a self-attn
263
+ ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
264
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
265
+ use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
266
+ ),
267
+ ResBlock(
268
+ ch,
269
+ time_embed_dim,
270
+ dropout,
271
+ dims=dims,
272
+ use_checkpoint=use_checkpoint,
273
+ use_scale_shift_norm=use_scale_shift_norm,
274
+ dtype=self.dtype,
275
+ device=device,
276
+ operations=operations
277
+ )]
278
+ self.middle_block = TimestepEmbedSequential(*mid_block)
279
+ self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
280
+ self._feature_size += ch
281
+
282
+ def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
283
+ return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
284
+
285
+ def forward(self, x, hint, timesteps, context, y=None, **kwargs):
286
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
287
+ emb = self.time_embed(t_emb)
288
+
289
+ guided_hint = self.input_hint_block(hint, emb, context)
290
+
291
+ outs = []
292
+
293
+ hs = []
294
+ if self.num_classes is not None:
295
+ assert y.shape[0] == x.shape[0]
296
+ emb = emb + self.label_emb(y)
297
+
298
+ h = x
299
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
300
+ if guided_hint is not None:
301
+ h = module(h, emb, context)
302
+ h += guided_hint
303
+ guided_hint = None
304
+ else:
305
+ h = module(h, emb, context)
306
+ outs.append(zero_conv(h, emb, context))
307
+
308
+ h = self.middle_block(h, emb, context)
309
+ outs.append(self.middle_block_out(h, emb, context))
310
+
311
+ return outs
312
+
ldm_patched/k_diffusion/sampling.py ADDED
@@ -0,0 +1,814 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ import math
6
+
7
+ from scipy import integrate
8
+ import torch
9
+ from torch import nn
10
+ import torchsde
11
+ from tqdm.auto import trange, tqdm
12
+
13
+ from . import utils
14
+
15
+
16
+ def append_zero(x):
17
+ return torch.cat([x, x.new_zeros([1])])
18
+
19
+
20
+ def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
21
+ """Constructs the noise schedule of Karras et al. (2022)."""
22
+ ramp = torch.linspace(0, 1, n, device=device)
23
+ min_inv_rho = sigma_min ** (1 / rho)
24
+ max_inv_rho = sigma_max ** (1 / rho)
25
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
26
+ return append_zero(sigmas).to(device)
27
+
28
+
29
+ def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
30
+ """Constructs an exponential noise schedule."""
31
+ sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
32
+ return append_zero(sigmas)
33
+
34
+
35
+ def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
36
+ """Constructs an polynomial in log sigma noise schedule."""
37
+ ramp = torch.linspace(1, 0, n, device=device) ** rho
38
+ sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))
39
+ return append_zero(sigmas)
40
+
41
+
42
+ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
43
+ """Constructs a continuous VP noise schedule."""
44
+ t = torch.linspace(1, eps_s, n, device=device)
45
+ sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
46
+ return append_zero(sigmas)
47
+
48
+
49
+ def to_d(x, sigma, denoised):
50
+ """Converts a denoiser output to a Karras ODE derivative."""
51
+ return (x - denoised) / utils.append_dims(sigma, x.ndim)
52
+
53
+
54
+ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
55
+ """Calculates the noise level (sigma_down) to step down to and the amount
56
+ of noise to add (sigma_up) when doing an ancestral sampling step."""
57
+ if not eta:
58
+ return sigma_to, 0.
59
+ sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
60
+ sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
61
+ return sigma_down, sigma_up
62
+
63
+
64
+ def default_noise_sampler(x):
65
+ return lambda sigma, sigma_next: torch.randn_like(x)
66
+
67
+
68
+ class BatchedBrownianTree:
69
+ """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
70
+
71
+ def __init__(self, x, t0, t1, seed=None, **kwargs):
72
+ self.cpu_tree = True
73
+ if "cpu" in kwargs:
74
+ self.cpu_tree = kwargs.pop("cpu")
75
+ t0, t1, self.sign = self.sort(t0, t1)
76
+ w0 = kwargs.get('w0', torch.zeros_like(x))
77
+ if seed is None:
78
+ seed = torch.randint(0, 2 ** 63 - 1, []).item()
79
+ self.batched = True
80
+ try:
81
+ assert len(seed) == x.shape[0]
82
+ w0 = w0[0]
83
+ except TypeError:
84
+ seed = [seed]
85
+ self.batched = False
86
+ if self.cpu_tree:
87
+ self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
88
+ else:
89
+ self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
90
+
91
+ @staticmethod
92
+ def sort(a, b):
93
+ return (a, b, 1) if a < b else (b, a, -1)
94
+
95
+ def __call__(self, t0, t1):
96
+ t0, t1, sign = self.sort(t0, t1)
97
+ if self.cpu_tree:
98
+ w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
99
+ else:
100
+ w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
101
+
102
+ return w if self.batched else w[0]
103
+
104
+
105
+ class BrownianTreeNoiseSampler:
106
+ """A noise sampler backed by a torchsde.BrownianTree.
107
+
108
+ Args:
109
+ x (Tensor): The tensor whose shape, device and dtype to use to generate
110
+ random samples.
111
+ sigma_min (float): The low end of the valid interval.
112
+ sigma_max (float): The high end of the valid interval.
113
+ seed (int or List[int]): The random seed. If a list of seeds is
114
+ supplied instead of a single integer, then the noise sampler will
115
+ use one BrownianTree per batch item, each with its own seed.
116
+ transform (callable): A function that maps sigma to the sampler's
117
+ internal timestep.
118
+ """
119
+
120
+ def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
121
+ self.transform = transform
122
+ t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
123
+ self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)
124
+
125
+ def __call__(self, sigma, sigma_next):
126
+ t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
127
+ return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
128
+
129
+
130
+ @torch.no_grad()
131
+ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
132
+ """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
133
+ extra_args = {} if extra_args is None else extra_args
134
+ s_in = x.new_ones([x.shape[0]])
135
+ for i in trange(len(sigmas) - 1, disable=disable):
136
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
137
+ sigma_hat = sigmas[i] * (gamma + 1)
138
+ if gamma > 0:
139
+ eps = torch.randn_like(x) * s_noise
140
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
141
+ denoised = model(x, sigma_hat * s_in, **extra_args)
142
+ d = to_d(x, sigma_hat, denoised)
143
+ if callback is not None:
144
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
145
+ dt = sigmas[i + 1] - sigma_hat
146
+ # Euler method
147
+ x = x + d * dt
148
+ return x
149
+
150
+
151
+ @torch.no_grad()
152
+ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
153
+ """Ancestral sampling with Euler method steps."""
154
+ extra_args = {} if extra_args is None else extra_args
155
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
156
+ s_in = x.new_ones([x.shape[0]])
157
+ for i in trange(len(sigmas) - 1, disable=disable):
158
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
159
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
160
+ if callback is not None:
161
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
162
+ d = to_d(x, sigmas[i], denoised)
163
+ # Euler method
164
+ dt = sigma_down - sigmas[i]
165
+ x = x + d * dt
166
+ if sigmas[i + 1] > 0:
167
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
168
+ return x
169
+
170
+
171
+ @torch.no_grad()
172
+ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
173
+ """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
174
+ extra_args = {} if extra_args is None else extra_args
175
+ s_in = x.new_ones([x.shape[0]])
176
+ for i in trange(len(sigmas) - 1, disable=disable):
177
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
178
+ sigma_hat = sigmas[i] * (gamma + 1)
179
+ if gamma > 0:
180
+ eps = torch.randn_like(x) * s_noise
181
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
182
+ denoised = model(x, sigma_hat * s_in, **extra_args)
183
+ d = to_d(x, sigma_hat, denoised)
184
+ if callback is not None:
185
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
186
+ dt = sigmas[i + 1] - sigma_hat
187
+ if sigmas[i + 1] == 0:
188
+ # Euler method
189
+ x = x + d * dt
190
+ else:
191
+ # Heun's method
192
+ x_2 = x + d * dt
193
+ denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
194
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
195
+ d_prime = (d + d_2) / 2
196
+ x = x + d_prime * dt
197
+ return x
198
+
199
+
200
+ @torch.no_grad()
201
+ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
202
+ """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
203
+ extra_args = {} if extra_args is None else extra_args
204
+ s_in = x.new_ones([x.shape[0]])
205
+ for i in trange(len(sigmas) - 1, disable=disable):
206
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
207
+ sigma_hat = sigmas[i] * (gamma + 1)
208
+ if gamma > 0:
209
+ eps = torch.randn_like(x) * s_noise
210
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
211
+ denoised = model(x, sigma_hat * s_in, **extra_args)
212
+ d = to_d(x, sigma_hat, denoised)
213
+ if callback is not None:
214
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
215
+ if sigmas[i + 1] == 0:
216
+ # Euler method
217
+ dt = sigmas[i + 1] - sigma_hat
218
+ x = x + d * dt
219
+ else:
220
+ # DPM-Solver-2
221
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
222
+ dt_1 = sigma_mid - sigma_hat
223
+ dt_2 = sigmas[i + 1] - sigma_hat
224
+ x_2 = x + d * dt_1
225
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
226
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
227
+ x = x + d_2 * dt_2
228
+ return x
229
+
230
+
231
+ @torch.no_grad()
232
+ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
233
+ """Ancestral sampling with DPM-Solver second-order steps."""
234
+ extra_args = {} if extra_args is None else extra_args
235
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
236
+ s_in = x.new_ones([x.shape[0]])
237
+ for i in trange(len(sigmas) - 1, disable=disable):
238
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
239
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
240
+ if callback is not None:
241
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
242
+ d = to_d(x, sigmas[i], denoised)
243
+ if sigma_down == 0:
244
+ # Euler method
245
+ dt = sigma_down - sigmas[i]
246
+ x = x + d * dt
247
+ else:
248
+ # DPM-Solver-2
249
+ sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
250
+ dt_1 = sigma_mid - sigmas[i]
251
+ dt_2 = sigma_down - sigmas[i]
252
+ x_2 = x + d * dt_1
253
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
254
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
255
+ x = x + d_2 * dt_2
256
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
257
+ return x
258
+
259
+
260
+ def linear_multistep_coeff(order, t, i, j):
261
+ if order - 1 > i:
262
+ raise ValueError(f'Order {order} too high for step {i}')
263
+ def fn(tau):
264
+ prod = 1.
265
+ for k in range(order):
266
+ if j == k:
267
+ continue
268
+ prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
269
+ return prod
270
+ return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
271
+
272
+
273
+ @torch.no_grad()
274
+ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
275
+ extra_args = {} if extra_args is None else extra_args
276
+ s_in = x.new_ones([x.shape[0]])
277
+ sigmas_cpu = sigmas.detach().cpu().numpy()
278
+ ds = []
279
+ for i in trange(len(sigmas) - 1, disable=disable):
280
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
281
+ d = to_d(x, sigmas[i], denoised)
282
+ ds.append(d)
283
+ if len(ds) > order:
284
+ ds.pop(0)
285
+ if callback is not None:
286
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
287
+ cur_order = min(i + 1, order)
288
+ coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
289
+ x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
290
+ return x
291
+
292
+
293
+ class PIDStepSizeController:
294
+ """A PID controller for ODE adaptive step size control."""
295
+ def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
296
+ self.h = h
297
+ self.b1 = (pcoeff + icoeff + dcoeff) / order
298
+ self.b2 = -(pcoeff + 2 * dcoeff) / order
299
+ self.b3 = dcoeff / order
300
+ self.accept_safety = accept_safety
301
+ self.eps = eps
302
+ self.errs = []
303
+
304
+ def limiter(self, x):
305
+ return 1 + math.atan(x - 1)
306
+
307
+ def propose_step(self, error):
308
+ inv_error = 1 / (float(error) + self.eps)
309
+ if not self.errs:
310
+ self.errs = [inv_error, inv_error, inv_error]
311
+ self.errs[0] = inv_error
312
+ factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3
313
+ factor = self.limiter(factor)
314
+ accept = factor >= self.accept_safety
315
+ if accept:
316
+ self.errs[2] = self.errs[1]
317
+ self.errs[1] = self.errs[0]
318
+ self.h *= factor
319
+ return accept
320
+
321
+
322
+ class DPMSolver(nn.Module):
323
+ """DPM-Solver. See https://arxiv.org/abs/2206.00927."""
324
+
325
+ def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
326
+ super().__init__()
327
+ self.model = model
328
+ self.extra_args = {} if extra_args is None else extra_args
329
+ self.eps_callback = eps_callback
330
+ self.info_callback = info_callback
331
+
332
+ def t(self, sigma):
333
+ return -sigma.log()
334
+
335
+ def sigma(self, t):
336
+ return t.neg().exp()
337
+
338
+ def eps(self, eps_cache, key, x, t, *args, **kwargs):
339
+ if key in eps_cache:
340
+ return eps_cache[key], eps_cache
341
+ sigma = self.sigma(t) * x.new_ones([x.shape[0]])
342
+ eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t)
343
+ if self.eps_callback is not None:
344
+ self.eps_callback()
345
+ return eps, {key: eps, **eps_cache}
346
+
347
+ def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
348
+ eps_cache = {} if eps_cache is None else eps_cache
349
+ h = t_next - t
350
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
351
+ x_1 = x - self.sigma(t_next) * h.expm1() * eps
352
+ return x_1, eps_cache
353
+
354
+ def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
355
+ eps_cache = {} if eps_cache is None else eps_cache
356
+ h = t_next - t
357
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
358
+ s1 = t + r1 * h
359
+ u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
360
+ eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
361
+ x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps)
362
+ return x_2, eps_cache
363
+
364
+ def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
365
+ eps_cache = {} if eps_cache is None else eps_cache
366
+ h = t_next - t
367
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
368
+ s1 = t + r1 * h
369
+ s2 = t + r2 * h
370
+ u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
371
+ eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
372
+ u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps)
373
+ eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
374
+ x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
375
+ return x_3, eps_cache
376
+
377
+ def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
378
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
379
+ if not t_end > t_start and eta:
380
+ raise ValueError('eta must be 0 for reverse sampling')
381
+
382
+ m = math.floor(nfe / 3) + 1
383
+ ts = torch.linspace(t_start, t_end, m + 1, device=x.device)
384
+
385
+ if nfe % 3 == 0:
386
+ orders = [3] * (m - 2) + [2, 1]
387
+ else:
388
+ orders = [3] * (m - 1) + [nfe % 3]
389
+
390
+ for i in range(len(orders)):
391
+ eps_cache = {}
392
+ t, t_next = ts[i], ts[i + 1]
393
+ if eta:
394
+ sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
395
+ t_next_ = torch.minimum(t_end, self.t(sd))
396
+ su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
397
+ else:
398
+ t_next_, su = t_next, 0.
399
+
400
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
401
+ denoised = x - self.sigma(t) * eps
402
+ if self.info_callback is not None:
403
+ self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})
404
+
405
+ if orders[i] == 1:
406
+ x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
407
+ elif orders[i] == 2:
408
+ x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
409
+ else:
410
+ x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)
411
+
412
+ x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))
413
+
414
+ return x
415
+
416
+ def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
417
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
418
+ if order not in {2, 3}:
419
+ raise ValueError('order should be 2 or 3')
420
+ forward = t_end > t_start
421
+ if not forward and eta:
422
+ raise ValueError('eta must be 0 for reverse sampling')
423
+ h_init = abs(h_init) * (1 if forward else -1)
424
+ atol = torch.tensor(atol)
425
+ rtol = torch.tensor(rtol)
426
+ s = t_start
427
+ x_prev = x
428
+ accept = True
429
+ pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
430
+ info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}
431
+
432
+ while s < t_end - 1e-5 if forward else s > t_end + 1e-5:
433
+ eps_cache = {}
434
+ t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
435
+ if eta:
436
+ sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
437
+ t_ = torch.minimum(t_end, self.t(sd))
438
+ su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
439
+ else:
440
+ t_, su = t, 0.
441
+
442
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
443
+ denoised = x - self.sigma(s) * eps
444
+
445
+ if order == 2:
446
+ x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
447
+ x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
448
+ else:
449
+ x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
450
+ x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
451
+ delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
452
+ error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
453
+ accept = pid.propose_step(error)
454
+ if accept:
455
+ x_prev = x_low
456
+ x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
457
+ s = t
458
+ info['n_accept'] += 1
459
+ else:
460
+ info['n_reject'] += 1
461
+ info['nfe'] += order
462
+ info['steps'] += 1
463
+
464
+ if self.info_callback is not None:
465
+ self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})
466
+
467
+ return x, info
468
+
469
+
470
+ @torch.no_grad()
471
+ def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
472
+ """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927."""
473
+ if sigma_min <= 0 or sigma_max <= 0:
474
+ raise ValueError('sigma_min and sigma_max must not be 0')
475
+ with tqdm(total=n, disable=disable) as pbar:
476
+ dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
477
+ if callback is not None:
478
+ dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
479
+ return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler)
480
+
481
+
482
+ @torch.no_grad()
483
+ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
484
+ """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927."""
485
+ if sigma_min <= 0 or sigma_max <= 0:
486
+ raise ValueError('sigma_min and sigma_max must not be 0')
487
+ with tqdm(disable=disable) as pbar:
488
+ dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
489
+ if callback is not None:
490
+ dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
491
+ x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
492
+ if return_info:
493
+ return x, info
494
+ return x
495
+
496
+
497
+ @torch.no_grad()
498
+ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
499
+ """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
500
+ extra_args = {} if extra_args is None else extra_args
501
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
502
+ s_in = x.new_ones([x.shape[0]])
503
+ sigma_fn = lambda t: t.neg().exp()
504
+ t_fn = lambda sigma: sigma.log().neg()
505
+
506
+ for i in trange(len(sigmas) - 1, disable=disable):
507
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
508
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
509
+ if callback is not None:
510
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
511
+ if sigma_down == 0:
512
+ # Euler method
513
+ d = to_d(x, sigmas[i], denoised)
514
+ dt = sigma_down - sigmas[i]
515
+ x = x + d * dt
516
+ else:
517
+ # DPM-Solver++(2S)
518
+ t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
519
+ r = 1 / 2
520
+ h = t_next - t
521
+ s = t + r * h
522
+ x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
523
+ denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
524
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
525
+ # Noise addition
526
+ if sigmas[i + 1] > 0:
527
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
528
+ return x
529
+
530
+
531
+ @torch.no_grad()
532
+ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
533
+ """DPM-Solver++ (stochastic)."""
534
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
535
+ seed = extra_args.get("seed", None)
536
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
537
+ extra_args = {} if extra_args is None else extra_args
538
+ s_in = x.new_ones([x.shape[0]])
539
+ sigma_fn = lambda t: t.neg().exp()
540
+ t_fn = lambda sigma: sigma.log().neg()
541
+
542
+ for i in trange(len(sigmas) - 1, disable=disable):
543
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
544
+ if callback is not None:
545
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
546
+ if sigmas[i + 1] == 0:
547
+ # Euler method
548
+ d = to_d(x, sigmas[i], denoised)
549
+ dt = sigmas[i + 1] - sigmas[i]
550
+ x = x + d * dt
551
+ else:
552
+ # DPM-Solver++
553
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
554
+ h = t_next - t
555
+ s = t + h * r
556
+ fac = 1 / (2 * r)
557
+
558
+ # Step 1
559
+ sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
560
+ s_ = t_fn(sd)
561
+ x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
562
+ x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
563
+ denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
564
+
565
+ # Step 2
566
+ sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
567
+ t_next_ = t_fn(sd)
568
+ denoised_d = (1 - fac) * denoised + fac * denoised_2
569
+ x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
570
+ x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
571
+ return x
572
+
573
+
574
+ @torch.no_grad()
575
+ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
576
+ """DPM-Solver++(2M)."""
577
+ extra_args = {} if extra_args is None else extra_args
578
+ s_in = x.new_ones([x.shape[0]])
579
+ sigma_fn = lambda t: t.neg().exp()
580
+ t_fn = lambda sigma: sigma.log().neg()
581
+ old_denoised = None
582
+
583
+ for i in trange(len(sigmas) - 1, disable=disable):
584
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
585
+ if callback is not None:
586
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
587
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
588
+ h = t_next - t
589
+ if old_denoised is None or sigmas[i + 1] == 0:
590
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
591
+ else:
592
+ h_last = t - t_fn(sigmas[i - 1])
593
+ r = h_last / h
594
+ denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
595
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
596
+ old_denoised = denoised
597
+ return x
598
+
599
+ @torch.no_grad()
600
+ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
601
+ """DPM-Solver++(2M) SDE."""
602
+
603
+ if solver_type not in {'heun', 'midpoint'}:
604
+ raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
605
+
606
+ seed = extra_args.get("seed", None)
607
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
608
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
609
+ extra_args = {} if extra_args is None else extra_args
610
+ s_in = x.new_ones([x.shape[0]])
611
+
612
+ old_denoised = None
613
+ h_last = None
614
+ h = None
615
+
616
+ for i in trange(len(sigmas) - 1, disable=disable):
617
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
618
+ if callback is not None:
619
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
620
+ if sigmas[i + 1] == 0:
621
+ # Denoising step
622
+ x = denoised
623
+ else:
624
+ # DPM-Solver++(2M) SDE
625
+ t, s = -sigmas[i].log(), -sigmas[i + 1].log()
626
+ h = s - t
627
+ eta_h = eta * h
628
+
629
+ x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
630
+
631
+ if old_denoised is not None:
632
+ r = h_last / h
633
+ if solver_type == 'heun':
634
+ x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
635
+ elif solver_type == 'midpoint':
636
+ x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
637
+
638
+ if eta:
639
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
640
+
641
+ old_denoised = denoised
642
+ h_last = h
643
+ return x
644
+
645
+ @torch.no_grad()
646
+ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
647
+ """DPM-Solver++(3M) SDE."""
648
+
649
+ seed = extra_args.get("seed", None)
650
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
651
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
652
+ extra_args = {} if extra_args is None else extra_args
653
+ s_in = x.new_ones([x.shape[0]])
654
+
655
+ denoised_1, denoised_2 = None, None
656
+ h, h_1, h_2 = None, None, None
657
+
658
+ for i in trange(len(sigmas) - 1, disable=disable):
659
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
660
+ if callback is not None:
661
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
662
+ if sigmas[i + 1] == 0:
663
+ # Denoising step
664
+ x = denoised
665
+ else:
666
+ t, s = -sigmas[i].log(), -sigmas[i + 1].log()
667
+ h = s - t
668
+ h_eta = h * (eta + 1)
669
+
670
+ x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised
671
+
672
+ if h_2 is not None:
673
+ r0 = h_1 / h
674
+ r1 = h_2 / h
675
+ d1_0 = (denoised - denoised_1) / r0
676
+ d1_1 = (denoised_1 - denoised_2) / r1
677
+ d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
678
+ d2 = (d1_0 - d1_1) / (r0 + r1)
679
+ phi_2 = h_eta.neg().expm1() / h_eta + 1
680
+ phi_3 = phi_2 / h_eta - 0.5
681
+ x = x + phi_2 * d1 - phi_3 * d2
682
+ elif h_1 is not None:
683
+ r = h_1 / h
684
+ d = (denoised - denoised_1) / r
685
+ phi_2 = h_eta.neg().expm1() / h_eta + 1
686
+ x = x + phi_2 * d
687
+
688
+ if eta:
689
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
690
+
691
+ denoised_1, denoised_2 = denoised, denoised_1
692
+ h_1, h_2 = h, h_1
693
+ return x
694
+
695
+ @torch.no_grad()
696
+ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
697
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
698
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
699
+ return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
700
+
701
+ @torch.no_grad()
702
+ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
703
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
704
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
705
+ return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
706
+
707
+ @torch.no_grad()
708
+ def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
709
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
710
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
711
+ return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
712
+
713
+
714
+ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
715
+ alpha_cumprod = 1 / ((sigma * sigma) + 1)
716
+ alpha_cumprod_prev = 1 / ((sigma_prev * sigma_prev) + 1)
717
+ alpha = (alpha_cumprod / alpha_cumprod_prev)
718
+
719
+ mu = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise / (1 - alpha_cumprod).sqrt())
720
+ if sigma_prev > 0:
721
+ mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
722
+ return mu
723
+
724
+ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
725
+ extra_args = {} if extra_args is None else extra_args
726
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
727
+ s_in = x.new_ones([x.shape[0]])
728
+
729
+ for i in trange(len(sigmas) - 1, disable=disable):
730
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
731
+ if callback is not None:
732
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
733
+ x = step_function(x / torch.sqrt(1.0 + sigmas[i] ** 2.0), sigmas[i], sigmas[i + 1], (x - denoised) / sigmas[i], noise_sampler)
734
+ if sigmas[i + 1] != 0:
735
+ x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2.0)
736
+ return x
737
+
738
+
739
+ @torch.no_grad()
740
+ def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
741
+ return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
742
+
743
+ @torch.no_grad()
744
+ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
745
+ extra_args = {} if extra_args is None else extra_args
746
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
747
+ s_in = x.new_ones([x.shape[0]])
748
+ for i in trange(len(sigmas) - 1, disable=disable):
749
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
750
+ if callback is not None:
751
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
752
+
753
+ x = denoised
754
+ if sigmas[i + 1] > 0:
755
+ x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
756
+ return x
757
+
758
+
759
+
760
+ @torch.no_grad()
761
+ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
762
+ # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
763
+ extra_args = {} if extra_args is None else extra_args
764
+ s_in = x.new_ones([x.shape[0]])
765
+ s_end = sigmas[-1]
766
+ for i in trange(len(sigmas) - 1, disable=disable):
767
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
768
+ eps = torch.randn_like(x) * s_noise
769
+ sigma_hat = sigmas[i] * (gamma + 1)
770
+ if gamma > 0:
771
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
772
+ denoised = model(x, sigma_hat * s_in, **extra_args)
773
+ d = to_d(x, sigma_hat, denoised)
774
+ if callback is not None:
775
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
776
+ dt = sigmas[i + 1] - sigma_hat
777
+ if sigmas[i + 1] == s_end:
778
+ # Euler method
779
+ x = x + d * dt
780
+ elif sigmas[i + 2] == s_end:
781
+
782
+ # Heun's method
783
+ x_2 = x + d * dt
784
+ denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
785
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
786
+
787
+ w = 2 * sigmas[0]
788
+ w2 = sigmas[i+1]/w
789
+ w1 = 1 - w2
790
+
791
+ d_prime = d * w1 + d_2 * w2
792
+
793
+
794
+ x = x + d_prime * dt
795
+
796
+ else:
797
+ # Heun++
798
+ x_2 = x + d * dt
799
+ denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
800
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
801
+ dt_2 = sigmas[i + 2] - sigmas[i + 1]
802
+
803
+ x_3 = x_2 + d_2 * dt_2
804
+ denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args)
805
+ d_3 = to_d(x_3, sigmas[i + 2], denoised_3)
806
+
807
+ w = 3 * sigmas[0]
808
+ w2 = sigmas[i + 1] / w
809
+ w3 = sigmas[i + 2] / w
810
+ w1 = 1 - w2 - w3
811
+
812
+ d_prime = w1 * d + w2 * d_2 + w3 * d_3
813
+ x = x + d_prime * dt
814
+ return x
ldm_patched/k_diffusion/utils.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ from contextlib import contextmanager
6
+ import hashlib
7
+ import math
8
+ from pathlib import Path
9
+ import shutil
10
+ import urllib
11
+ import warnings
12
+
13
+ from PIL import Image
14
+ import torch
15
+ from torch import nn, optim
16
+ from torch.utils import data
17
+
18
+
19
+ def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
20
+ """Apply passed in transforms for HuggingFace Datasets."""
21
+ images = [transform(image.convert(mode)) for image in examples[image_key]]
22
+ return {image_key: images}
23
+
24
+
25
+ def append_dims(x, target_dims):
26
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
27
+ dims_to_append = target_dims - x.ndim
28
+ if dims_to_append < 0:
29
+ raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
30
+ expanded = x[(...,) + (None,) * dims_to_append]
31
+ # MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
32
+ # https://github.com/pytorch/pytorch/issues/84364
33
+ return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
34
+
35
+
36
+ def n_params(module):
37
+ """Returns the number of trainable parameters in a module."""
38
+ return sum(p.numel() for p in module.parameters())
39
+
40
+
41
+ def download_file(path, url, digest=None):
42
+ """Downloads a file if it does not exist, optionally checking its SHA-256 hash."""
43
+ path = Path(path)
44
+ path.parent.mkdir(parents=True, exist_ok=True)
45
+ if not path.exists():
46
+ with urllib.request.urlopen(url) as response, open(path, 'wb') as f:
47
+ shutil.copyfileobj(response, f)
48
+ if digest is not None:
49
+ file_digest = hashlib.sha256(open(path, 'rb').read()).hexdigest()
50
+ if digest != file_digest:
51
+ raise OSError(f'hash of {path} (url: {url}) failed to validate')
52
+ return path
53
+
54
+
55
+ @contextmanager
56
+ def train_mode(model, mode=True):
57
+ """A context manager that places a model into training mode and restores
58
+ the previous mode on exit."""
59
+ modes = [module.training for module in model.modules()]
60
+ try:
61
+ yield model.train(mode)
62
+ finally:
63
+ for i, module in enumerate(model.modules()):
64
+ module.training = modes[i]
65
+
66
+
67
+ def eval_mode(model):
68
+ """A context manager that places a model into evaluation mode and restores
69
+ the previous mode on exit."""
70
+ return train_mode(model, False)
71
+
72
+
73
+ @torch.no_grad()
74
+ def ema_update(model, averaged_model, decay):
75
+ """Incorporates updated model parameters into an exponential moving averaged
76
+ version of a model. It should be called after each optimizer step."""
77
+ model_params = dict(model.named_parameters())
78
+ averaged_params = dict(averaged_model.named_parameters())
79
+ assert model_params.keys() == averaged_params.keys()
80
+
81
+ for name, param in model_params.items():
82
+ averaged_params[name].mul_(decay).add_(param, alpha=1 - decay)
83
+
84
+ model_buffers = dict(model.named_buffers())
85
+ averaged_buffers = dict(averaged_model.named_buffers())
86
+ assert model_buffers.keys() == averaged_buffers.keys()
87
+
88
+ for name, buf in model_buffers.items():
89
+ averaged_buffers[name].copy_(buf)
90
+
91
+
92
+ class EMAWarmup:
93
+ """Implements an EMA warmup using an inverse decay schedule.
94
+ If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
95
+ good values for models you plan to train for a million or more steps (reaches decay
96
+ factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
97
+ you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
98
+ 215.4k steps).
99
+ Args:
100
+ inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
101
+ power (float): Exponential factor of EMA warmup. Default: 1.
102
+ min_value (float): The minimum EMA decay rate. Default: 0.
103
+ max_value (float): The maximum EMA decay rate. Default: 1.
104
+ start_at (int): The epoch to start averaging at. Default: 0.
105
+ last_epoch (int): The index of last epoch. Default: 0.
106
+ """
107
+
108
+ def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
109
+ last_epoch=0):
110
+ self.inv_gamma = inv_gamma
111
+ self.power = power
112
+ self.min_value = min_value
113
+ self.max_value = max_value
114
+ self.start_at = start_at
115
+ self.last_epoch = last_epoch
116
+
117
+ def state_dict(self):
118
+ """Returns the state of the class as a :class:`dict`."""
119
+ return dict(self.__dict__.items())
120
+
121
+ def load_state_dict(self, state_dict):
122
+ """Loads the class's state.
123
+ Args:
124
+ state_dict (dict): scaler state. Should be an object returned
125
+ from a call to :meth:`state_dict`.
126
+ """
127
+ self.__dict__.update(state_dict)
128
+
129
+ def get_value(self):
130
+ """Gets the current EMA decay rate."""
131
+ epoch = max(0, self.last_epoch - self.start_at)
132
+ value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
133
+ return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value))
134
+
135
+ def step(self):
136
+ """Updates the step count."""
137
+ self.last_epoch += 1
138
+
139
+
140
+ class InverseLR(optim.lr_scheduler._LRScheduler):
141
+ """Implements an inverse decay learning rate schedule with an optional exponential
142
+ warmup. When last_epoch=-1, sets initial lr as lr.
143
+ inv_gamma is the number of steps/epochs required for the learning rate to decay to
144
+ (1 / 2)**power of its original value.
145
+ Args:
146
+ optimizer (Optimizer): Wrapped optimizer.
147
+ inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
148
+ power (float): Exponential factor of learning rate decay. Default: 1.
149
+ warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
150
+ Default: 0.
151
+ min_lr (float): The minimum learning rate. Default: 0.
152
+ last_epoch (int): The index of last epoch. Default: -1.
153
+ verbose (bool): If ``True``, prints a message to stdout for
154
+ each update. Default: ``False``.
155
+ """
156
+
157
+ def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
158
+ last_epoch=-1, verbose=False):
159
+ self.inv_gamma = inv_gamma
160
+ self.power = power
161
+ if not 0. <= warmup < 1:
162
+ raise ValueError('Invalid value for warmup')
163
+ self.warmup = warmup
164
+ self.min_lr = min_lr
165
+ super().__init__(optimizer, last_epoch, verbose)
166
+
167
+ def get_lr(self):
168
+ if not self._get_lr_called_within_step:
169
+ warnings.warn("To get the last learning rate computed by the scheduler, "
170
+ "please use `get_last_lr()`.")
171
+
172
+ return self._get_closed_form_lr()
173
+
174
+ def _get_closed_form_lr(self):
175
+ warmup = 1 - self.warmup ** (self.last_epoch + 1)
176
+ lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
177
+ return [warmup * max(self.min_lr, base_lr * lr_mult)
178
+ for base_lr in self.base_lrs]
179
+
180
+
181
+ class ExponentialLR(optim.lr_scheduler._LRScheduler):
182
+ """Implements an exponential learning rate schedule with an optional exponential
183
+ warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate
184
+ continuously by decay (default 0.5) every num_steps steps.
185
+ Args:
186
+ optimizer (Optimizer): Wrapped optimizer.
187
+ num_steps (float): The number of steps to decay the learning rate by decay in.
188
+ decay (float): The factor by which to decay the learning rate every num_steps
189
+ steps. Default: 0.5.
190
+ warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
191
+ Default: 0.
192
+ min_lr (float): The minimum learning rate. Default: 0.
193
+ last_epoch (int): The index of last epoch. Default: -1.
194
+ verbose (bool): If ``True``, prints a message to stdout for
195
+ each update. Default: ``False``.
196
+ """
197
+
198
+ def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
199
+ last_epoch=-1, verbose=False):
200
+ self.num_steps = num_steps
201
+ self.decay = decay
202
+ if not 0. <= warmup < 1:
203
+ raise ValueError('Invalid value for warmup')
204
+ self.warmup = warmup
205
+ self.min_lr = min_lr
206
+ super().__init__(optimizer, last_epoch, verbose)
207
+
208
+ def get_lr(self):
209
+ if not self._get_lr_called_within_step:
210
+ warnings.warn("To get the last learning rate computed by the scheduler, "
211
+ "please use `get_last_lr()`.")
212
+
213
+ return self._get_closed_form_lr()
214
+
215
+ def _get_closed_form_lr(self):
216
+ warmup = 1 - self.warmup ** (self.last_epoch + 1)
217
+ lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch
218
+ return [warmup * max(self.min_lr, base_lr * lr_mult)
219
+ for base_lr in self.base_lrs]
220
+
221
+
222
+ def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
223
+ """Draws samples from an lognormal distribution."""
224
+ return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp()
225
+
226
+
227
+ def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
228
+ """Draws samples from an optionally truncated log-logistic distribution."""
229
+ min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64)
230
+ max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64)
231
+ min_cdf = min_value.log().sub(loc).div(scale).sigmoid()
232
+ max_cdf = max_value.log().sub(loc).div(scale).sigmoid()
233
+ u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf
234
+ return u.logit().mul(scale).add(loc).exp().to(dtype)
235
+
236
+
237
+ def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
238
+ """Draws samples from an log-uniform distribution."""
239
+ min_value = math.log(min_value)
240
+ max_value = math.log(max_value)
241
+ return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp()
242
+
243
+
244
+ def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
245
+ """Draws samples from a truncated v-diffusion training timestep distribution."""
246
+ min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi
247
+ max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi
248
+ u = torch.rand(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf
249
+ return torch.tan(u * math.pi / 2) * sigma_data
250
+
251
+
252
+ def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32):
253
+ """Draws samples from a split lognormal distribution."""
254
+ n = torch.randn(shape, device=device, dtype=dtype).abs()
255
+ u = torch.rand(shape, device=device, dtype=dtype)
256
+ n_left = n * -scale_1 + loc
257
+ n_right = n * scale_2 + loc
258
+ ratio = scale_1 / (scale_1 + scale_2)
259
+ return torch.where(u < ratio, n_left, n_right).exp()
260
+
261
+
262
+ class FolderOfImages(data.Dataset):
263
+ """Recursively finds all images in a directory. It does not support
264
+ classes/targets."""
265
+
266
+ IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}
267
+
268
+ def __init__(self, root, transform=None):
269
+ super().__init__()
270
+ self.root = Path(root)
271
+ self.transform = nn.Identity() if transform is None else transform
272
+ self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS)
273
+
274
+ def __repr__(self):
275
+ return f'FolderOfImages(root="{self.root}", len: {len(self)})'
276
+
277
+ def __len__(self):
278
+ return len(self.paths)
279
+
280
+ def __getitem__(self, key):
281
+ path = self.paths[key]
282
+ with open(path, 'rb') as f:
283
+ image = Image.open(f).convert('RGB')
284
+ image = self.transform(image)
285
+ return image,
286
+
287
+
288
+ class CSVLogger:
289
+ def __init__(self, filename, columns):
290
+ self.filename = Path(filename)
291
+ self.columns = columns
292
+ if self.filename.exists():
293
+ self.file = open(self.filename, 'a')
294
+ else:
295
+ self.file = open(self.filename, 'w')
296
+ self.write(*self.columns)
297
+
298
+ def write(self, *args):
299
+ print(*args, sep=',', file=self.file, flush=True)
300
+
301
+
302
+ @contextmanager
303
+ def tf32_mode(cudnn=None, matmul=None):
304
+ """A context manager that sets whether TF32 is allowed on cuDNN or matmul."""
305
+ cudnn_old = torch.backends.cudnn.allow_tf32
306
+ matmul_old = torch.backends.cuda.matmul.allow_tf32
307
+ try:
308
+ if cudnn is not None:
309
+ torch.backends.cudnn.allow_tf32 = cudnn
310
+ if matmul is not None:
311
+ torch.backends.cuda.matmul.allow_tf32 = matmul
312
+ yield
313
+ finally:
314
+ if cudnn is not None:
315
+ torch.backends.cudnn.allow_tf32 = cudnn_old
316
+ if matmul is not None:
317
+ torch.backends.cuda.matmul.allow_tf32 = matmul_old
ldm_patched/ldm/models/autoencoder.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 1st edit by https://github.com/CompVis/latent-diffusion
2
+ # 2nd edit by https://github.com/Stability-AI/stablediffusion
3
+ # 3rd edit by https://github.com/Stability-AI/generative-models
4
+ # 4th edit by https://github.com/comfyanonymous/ComfyUI
5
+ # 5th edit by Forge
6
+
7
+
8
+ import torch
9
+ # import pytorch_lightning as pl
10
+ import torch.nn.functional as F
11
+ from contextlib import contextmanager
12
+ from typing import Any, Dict, List, Optional, Tuple, Union
13
+
14
+ from ldm_patched.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
15
+
16
+ from ldm_patched.ldm.util import instantiate_from_config
17
+ from ldm_patched.ldm.modules.ema import LitEma
18
+ import ldm_patched.modules.ops
19
+
20
+ class DiagonalGaussianRegularizer(torch.nn.Module):
21
+ def __init__(self, sample: bool = True):
22
+ super().__init__()
23
+ self.sample = sample
24
+
25
+ def get_trainable_parameters(self) -> Any:
26
+ yield from ()
27
+
28
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
29
+ log = dict()
30
+ posterior = DiagonalGaussianDistribution(z)
31
+ if self.sample:
32
+ z = posterior.sample()
33
+ else:
34
+ z = posterior.mode()
35
+ kl_loss = posterior.kl()
36
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
37
+ log["kl_loss"] = kl_loss
38
+ return z, log
39
+
40
+
41
+ class AbstractAutoencoder(torch.nn.Module):
42
+ """
43
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
44
+ unCLIP models, etc. Hence, it is fairly general, and specific features
45
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ ema_decay: Union[None, float] = None,
51
+ monitor: Union[None, str] = None,
52
+ input_key: str = "jpg",
53
+ **kwargs,
54
+ ):
55
+ super().__init__()
56
+
57
+ self.input_key = input_key
58
+ self.use_ema = ema_decay is not None
59
+ if monitor is not None:
60
+ self.monitor = monitor
61
+
62
+ if self.use_ema:
63
+ self.model_ema = LitEma(self, decay=ema_decay)
64
+ logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
65
+
66
+ def get_input(self, batch) -> Any:
67
+ raise NotImplementedError()
68
+
69
+ def on_train_batch_end(self, *args, **kwargs):
70
+ # for EMA computation
71
+ if self.use_ema:
72
+ self.model_ema(self)
73
+
74
+ @contextmanager
75
+ def ema_scope(self, context=None):
76
+ if self.use_ema:
77
+ self.model_ema.store(self.parameters())
78
+ self.model_ema.copy_to(self)
79
+ if context is not None:
80
+ logpy.info(f"{context}: Switched to EMA weights")
81
+ try:
82
+ yield None
83
+ finally:
84
+ if self.use_ema:
85
+ self.model_ema.restore(self.parameters())
86
+ if context is not None:
87
+ logpy.info(f"{context}: Restored training weights")
88
+
89
+ def encode(self, *args, **kwargs) -> torch.Tensor:
90
+ raise NotImplementedError("encode()-method of abstract base class called")
91
+
92
+ def decode(self, *args, **kwargs) -> torch.Tensor:
93
+ raise NotImplementedError("decode()-method of abstract base class called")
94
+
95
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
96
+ logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
97
+ return get_obj_from_str(cfg["target"])(
98
+ params, lr=lr, **cfg.get("params", dict())
99
+ )
100
+
101
+ def configure_optimizers(self) -> Any:
102
+ raise NotImplementedError()
103
+
104
+
105
+ class AutoencodingEngine(AbstractAutoencoder):
106
+ """
107
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
108
+ (we also restore them explicitly as special cases for legacy reasons).
109
+ Regularizations such as KL or VQ are moved to the regularizer class.
110
+ """
111
+
112
+ def __init__(
113
+ self,
114
+ *args,
115
+ encoder_config: Dict,
116
+ decoder_config: Dict,
117
+ regularizer_config: Dict,
118
+ **kwargs,
119
+ ):
120
+ super().__init__(*args, **kwargs)
121
+
122
+ self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
123
+ self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
124
+ self.regularization: AbstractRegularizer = instantiate_from_config(
125
+ regularizer_config
126
+ )
127
+
128
+ def get_last_layer(self):
129
+ return self.decoder.get_last_layer()
130
+
131
+ def encode(
132
+ self,
133
+ x: torch.Tensor,
134
+ return_reg_log: bool = False,
135
+ unregularized: bool = False,
136
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
137
+ z = self.encoder(x)
138
+ if unregularized:
139
+ return z, dict()
140
+ z, reg_log = self.regularization(z)
141
+ if return_reg_log:
142
+ return z, reg_log
143
+ return z
144
+
145
+ def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
146
+ x = self.decoder(z, **kwargs)
147
+ return x
148
+
149
+ def forward(
150
+ self, x: torch.Tensor, **additional_decode_kwargs
151
+ ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
152
+ z, reg_log = self.encode(x, return_reg_log=True)
153
+ dec = self.decode(z, **additional_decode_kwargs)
154
+ return z, dec, reg_log
155
+
156
+
157
+ class AutoencodingEngineLegacy(AutoencodingEngine):
158
+ def __init__(self, embed_dim: int, **kwargs):
159
+ self.max_batch_size = kwargs.pop("max_batch_size", None)
160
+ ddconfig = kwargs.pop("ddconfig")
161
+ super().__init__(
162
+ encoder_config={
163
+ "target": "ldm_patched.ldm.modules.diffusionmodules.model.Encoder",
164
+ "params": ddconfig,
165
+ },
166
+ decoder_config={
167
+ "target": "ldm_patched.ldm.modules.diffusionmodules.model.Decoder",
168
+ "params": ddconfig,
169
+ },
170
+ **kwargs,
171
+ )
172
+ self.quant_conv = ldm_patched.modules.ops.disable_weight_init.Conv2d(
173
+ (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
174
+ (1 + ddconfig["double_z"]) * embed_dim,
175
+ 1,
176
+ )
177
+ self.post_quant_conv = ldm_patched.modules.ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
178
+ self.embed_dim = embed_dim
179
+
180
+ def get_autoencoder_params(self) -> list:
181
+ params = super().get_autoencoder_params()
182
+ return params
183
+
184
+ def encode(
185
+ self, x: torch.Tensor, return_reg_log: bool = False
186
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
187
+ if self.max_batch_size is None:
188
+ z = self.encoder(x)
189
+ z = self.quant_conv(z)
190
+ else:
191
+ N = x.shape[0]
192
+ bs = self.max_batch_size
193
+ n_batches = int(math.ceil(N / bs))
194
+ z = list()
195
+ for i_batch in range(n_batches):
196
+ z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
197
+ z_batch = self.quant_conv(z_batch)
198
+ z.append(z_batch)
199
+ z = torch.cat(z, 0)
200
+
201
+ z, reg_log = self.regularization(z)
202
+ if return_reg_log:
203
+ return z, reg_log
204
+ return z
205
+
206
+ def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
207
+ if self.max_batch_size is None:
208
+ dec = self.post_quant_conv(z)
209
+ dec = self.decoder(dec, **decoder_kwargs)
210
+ else:
211
+ N = z.shape[0]
212
+ bs = self.max_batch_size
213
+ n_batches = int(math.ceil(N / bs))
214
+ dec = list()
215
+ for i_batch in range(n_batches):
216
+ dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
217
+ dec_batch = self.decoder(dec_batch, **decoder_kwargs)
218
+ dec.append(dec_batch)
219
+ dec = torch.cat(dec, 0)
220
+
221
+ return dec
222
+
223
+
224
+ class AutoencoderKL(AutoencodingEngineLegacy):
225
+ def __init__(self, **kwargs):
226
+ if "lossconfig" in kwargs:
227
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
228
+ super().__init__(
229
+ regularizer_config={
230
+ "target": (
231
+ "ldm_patched.ldm.models.autoencoder.DiagonalGaussianRegularizer"
232
+ )
233
+ },
234
+ **kwargs,
235
+ )
ldm_patched/ldm/modules/attention.py ADDED
@@ -0,0 +1,788 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 1st edit by https://github.com/CompVis/latent-diffusion
2
+ # 2nd edit by https://github.com/Stability-AI/stablediffusion
3
+ # 3rd edit by https://github.com/Stability-AI/generative-models
4
+ # 4th edit by https://github.com/comfyanonymous/ComfyUI
5
+ # 5th edit by Forge
6
+
7
+
8
+ import math
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn, einsum
12
+ from einops import rearrange, repeat
13
+ from typing import Optional, Any
14
+
15
+ from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
16
+ from .sub_quadratic_attention import efficient_dot_product_attention
17
+
18
+ from ldm_patched.modules import model_management
19
+
20
+ if model_management.xformers_enabled():
21
+ import xformers
22
+ import xformers.ops
23
+
24
+ from ldm_patched.modules.args_parser import args
25
+ import ldm_patched.modules.ops
26
+ ops = ldm_patched.modules.ops.disable_weight_init
27
+
28
+ # CrossAttn precision handling
29
+ if args.disable_attention_upcast:
30
+ print("disabling upcasting of attention")
31
+ _ATTN_PRECISION = "fp16"
32
+ else:
33
+ _ATTN_PRECISION = "fp32"
34
+
35
+
36
+ def exists(val):
37
+ return val is not None
38
+
39
+
40
+ def uniq(arr):
41
+ return{el: True for el in arr}.keys()
42
+
43
+
44
+ def default(val, d):
45
+ if exists(val):
46
+ return val
47
+ return d
48
+
49
+
50
+ def max_neg_value(t):
51
+ return -torch.finfo(t.dtype).max
52
+
53
+
54
+ def init_(tensor):
55
+ dim = tensor.shape[-1]
56
+ std = 1 / math.sqrt(dim)
57
+ tensor.uniform_(-std, std)
58
+ return tensor
59
+
60
+
61
+ # feedforward
62
+ class GEGLU(nn.Module):
63
+ def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
64
+ super().__init__()
65
+ self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
66
+
67
+ def forward(self, x):
68
+ x, gate = self.proj(x).chunk(2, dim=-1)
69
+ return x * F.gelu(gate)
70
+
71
+
72
+ class FeedForward(nn.Module):
73
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=ops):
74
+ super().__init__()
75
+ inner_dim = int(dim * mult)
76
+ dim_out = default(dim_out, dim)
77
+ project_in = nn.Sequential(
78
+ operations.Linear(dim, inner_dim, dtype=dtype, device=device),
79
+ nn.GELU()
80
+ ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device, operations=operations)
81
+
82
+ self.net = nn.Sequential(
83
+ project_in,
84
+ nn.Dropout(dropout),
85
+ operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
86
+ )
87
+
88
+ def forward(self, x):
89
+ return self.net(x)
90
+
91
+ def Normalize(in_channels, dtype=None, device=None):
92
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
93
+
94
+ def attention_basic(q, k, v, heads, mask=None):
95
+ b, _, dim_head = q.shape
96
+ dim_head //= heads
97
+ scale = dim_head ** -0.5
98
+
99
+ h = heads
100
+ q, k, v = map(
101
+ lambda t: t.unsqueeze(3)
102
+ .reshape(b, -1, heads, dim_head)
103
+ .permute(0, 2, 1, 3)
104
+ .reshape(b * heads, -1, dim_head)
105
+ .contiguous(),
106
+ (q, k, v),
107
+ )
108
+
109
+ # force cast to fp32 to avoid overflowing
110
+ if _ATTN_PRECISION =="fp32":
111
+ sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
112
+ else:
113
+ sim = einsum('b i d, b j d -> b i j', q, k) * scale
114
+
115
+ del q, k
116
+
117
+ if exists(mask):
118
+ if mask.dtype == torch.bool:
119
+ mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
120
+ max_neg_value = -torch.finfo(sim.dtype).max
121
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
122
+ sim.masked_fill_(~mask, max_neg_value)
123
+ else:
124
+ sim += mask
125
+
126
+ # attention, what we cannot get enough of
127
+ sim = sim.softmax(dim=-1)
128
+
129
+ out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
130
+ out = (
131
+ out.unsqueeze(0)
132
+ .reshape(b, heads, -1, dim_head)
133
+ .permute(0, 2, 1, 3)
134
+ .reshape(b, -1, heads * dim_head)
135
+ )
136
+ return out
137
+
138
+
139
+ def attention_sub_quad(query, key, value, heads, mask=None):
140
+ b, _, dim_head = query.shape
141
+ dim_head //= heads
142
+
143
+ scale = dim_head ** -0.5
144
+ query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
145
+ value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
146
+
147
+ key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
148
+
149
+ dtype = query.dtype
150
+ upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
151
+ if upcast_attention:
152
+ bytes_per_token = torch.finfo(torch.float32).bits//8
153
+ else:
154
+ bytes_per_token = torch.finfo(query.dtype).bits//8
155
+ batch_x_heads, q_tokens, _ = query.shape
156
+ _, _, k_tokens = key.shape
157
+ qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
158
+
159
+ mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
160
+
161
+ kv_chunk_size_min = None
162
+ kv_chunk_size = None
163
+ query_chunk_size = None
164
+
165
+ for x in [4096, 2048, 1024, 512, 256]:
166
+ count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
167
+ if count >= k_tokens:
168
+ kv_chunk_size = k_tokens
169
+ query_chunk_size = x
170
+ break
171
+
172
+ if query_chunk_size is None:
173
+ query_chunk_size = 512
174
+
175
+ hidden_states = efficient_dot_product_attention(
176
+ query,
177
+ key,
178
+ value,
179
+ query_chunk_size=query_chunk_size,
180
+ kv_chunk_size=kv_chunk_size,
181
+ kv_chunk_size_min=kv_chunk_size_min,
182
+ use_checkpoint=False,
183
+ upcast_attention=upcast_attention,
184
+ mask=mask,
185
+ )
186
+
187
+ hidden_states = hidden_states.to(dtype)
188
+
189
+ hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
190
+ return hidden_states
191
+
192
+ def attention_split(q, k, v, heads, mask=None):
193
+ b, _, dim_head = q.shape
194
+ dim_head //= heads
195
+ scale = dim_head ** -0.5
196
+
197
+ h = heads
198
+ q, k, v = map(
199
+ lambda t: t.unsqueeze(3)
200
+ .reshape(b, -1, heads, dim_head)
201
+ .permute(0, 2, 1, 3)
202
+ .reshape(b * heads, -1, dim_head)
203
+ .contiguous(),
204
+ (q, k, v),
205
+ )
206
+
207
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
208
+
209
+ mem_free_total = model_management.get_free_memory(q.device)
210
+
211
+ if _ATTN_PRECISION =="fp32":
212
+ element_size = 4
213
+ else:
214
+ element_size = q.element_size()
215
+
216
+ gb = 1024 ** 3
217
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
218
+ modifier = 3
219
+ mem_required = tensor_size * modifier
220
+ steps = 1
221
+
222
+
223
+ if mem_required > mem_free_total:
224
+ steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
225
+ # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
226
+ # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
227
+
228
+ if steps > 64:
229
+ max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
230
+ raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
231
+ f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
232
+
233
+ # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
234
+ first_op_done = False
235
+ cleared_cache = False
236
+ while True:
237
+ try:
238
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
239
+ for i in range(0, q.shape[1], slice_size):
240
+ end = i + slice_size
241
+ if _ATTN_PRECISION =="fp32":
242
+ with torch.autocast(enabled=False, device_type = 'cuda'):
243
+ s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
244
+ else:
245
+ s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
246
+
247
+ if mask is not None:
248
+ if len(mask.shape) == 2:
249
+ s1 += mask[i:end]
250
+ else:
251
+ s1 += mask[:, i:end]
252
+
253
+ s2 = s1.softmax(dim=-1).to(v.dtype)
254
+ del s1
255
+ first_op_done = True
256
+
257
+ r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
258
+ del s2
259
+ break
260
+ except model_management.OOM_EXCEPTION as e:
261
+ if first_op_done == False:
262
+ model_management.soft_empty_cache(True)
263
+ if cleared_cache == False:
264
+ cleared_cache = True
265
+ print("out of memory error, emptying cache and trying again")
266
+ continue
267
+ steps *= 2
268
+ if steps > 64:
269
+ raise e
270
+ print("out of memory error, increasing steps and trying again", steps)
271
+ else:
272
+ raise e
273
+
274
+ del q, k, v
275
+
276
+ r1 = (
277
+ r1.unsqueeze(0)
278
+ .reshape(b, heads, -1, dim_head)
279
+ .permute(0, 2, 1, 3)
280
+ .reshape(b, -1, heads * dim_head)
281
+ )
282
+ return r1
283
+
284
+ BROKEN_XFORMERS = False
285
+ try:
286
+ x_vers = xformers.__version__
287
+ #I think 0.0.23 is also broken (q with bs bigger than 65535 gives CUDA error)
288
+ BROKEN_XFORMERS = x_vers.startswith("0.0.21") or x_vers.startswith("0.0.22") or x_vers.startswith("0.0.23")
289
+ except:
290
+ pass
291
+
292
+ def attention_xformers(q, k, v, heads, mask=None):
293
+ b, _, dim_head = q.shape
294
+ dim_head //= heads
295
+ if BROKEN_XFORMERS:
296
+ if b * heads > 65535:
297
+ return attention_pytorch(q, k, v, heads, mask)
298
+
299
+ q, k, v = map(
300
+ lambda t: t.unsqueeze(3)
301
+ .reshape(b, -1, heads, dim_head)
302
+ .permute(0, 2, 1, 3)
303
+ .reshape(b * heads, -1, dim_head)
304
+ .contiguous(),
305
+ (q, k, v),
306
+ )
307
+
308
+ if mask is not None:
309
+ pad = 8 - q.shape[1] % 8
310
+ mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
311
+ mask_out[:, :, :mask.shape[-1]] = mask
312
+ mask = mask_out[:, :, :mask.shape[-1]]
313
+
314
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
315
+
316
+ out = (
317
+ out.unsqueeze(0)
318
+ .reshape(b, heads, -1, dim_head)
319
+ .permute(0, 2, 1, 3)
320
+ .reshape(b, -1, heads * dim_head)
321
+ )
322
+ return out
323
+
324
+ def attention_pytorch(q, k, v, heads, mask=None):
325
+ b, _, dim_head = q.shape
326
+ dim_head //= heads
327
+ q, k, v = map(
328
+ lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
329
+ (q, k, v),
330
+ )
331
+
332
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
333
+ out = (
334
+ out.transpose(1, 2).reshape(b, -1, heads * dim_head)
335
+ )
336
+ return out
337
+
338
+
339
+ optimized_attention = attention_basic
340
+
341
+ if model_management.xformers_enabled():
342
+ print("Using xformers cross attention")
343
+ optimized_attention = attention_xformers
344
+ elif model_management.pytorch_attention_enabled():
345
+ print("Using pytorch cross attention")
346
+ optimized_attention = attention_pytorch
347
+ else:
348
+ if args.attention_split:
349
+ print("Using split optimization for cross attention")
350
+ optimized_attention = attention_split
351
+ else:
352
+ print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --attention-split")
353
+ optimized_attention = attention_sub_quad
354
+
355
+ optimized_attention_masked = optimized_attention
356
+
357
+ def optimized_attention_for_device(device, mask=False, small_input=False):
358
+ if small_input:
359
+ if model_management.pytorch_attention_enabled():
360
+ return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases
361
+ else:
362
+ return attention_basic
363
+
364
+ if device == torch.device("cpu"):
365
+ return attention_sub_quad
366
+
367
+ if mask:
368
+ return optimized_attention_masked
369
+
370
+ return optimized_attention
371
+
372
+
373
+ class CrossAttention(nn.Module):
374
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
375
+ super().__init__()
376
+ inner_dim = dim_head * heads
377
+ context_dim = default(context_dim, query_dim)
378
+
379
+ self.heads = heads
380
+ self.dim_head = dim_head
381
+
382
+ self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
383
+ self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
384
+ self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
385
+
386
+ self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
387
+
388
+ def forward(self, x, context=None, value=None, mask=None, transformer_options=None):
389
+ q = self.to_q(x)
390
+ context = default(context, x)
391
+ k = self.to_k(context)
392
+ if value is not None:
393
+ v = self.to_v(value)
394
+ del value
395
+ else:
396
+ v = self.to_v(context)
397
+
398
+ if mask is None:
399
+ out = optimized_attention(q, k, v, self.heads)
400
+ else:
401
+ out = optimized_attention_masked(q, k, v, self.heads, mask)
402
+ return self.to_out(out)
403
+
404
+
405
+ class BasicTransformerBlock(nn.Module):
406
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
407
+ disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops):
408
+ super().__init__()
409
+
410
+ self.ff_in = ff_in or inner_dim is not None
411
+ if inner_dim is None:
412
+ inner_dim = dim
413
+
414
+ self.is_res = inner_dim == dim
415
+
416
+ if self.ff_in:
417
+ self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
418
+ self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
419
+
420
+ self.disable_self_attn = disable_self_attn
421
+ self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
422
+ context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
423
+ self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
424
+
425
+ if disable_temporal_crossattention:
426
+ if switch_temporal_ca_to_sa:
427
+ raise ValueError
428
+ else:
429
+ self.attn2 = None
430
+ else:
431
+ context_dim_attn2 = None
432
+ if not switch_temporal_ca_to_sa:
433
+ context_dim_attn2 = context_dim
434
+
435
+ self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
436
+ heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
437
+ self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
438
+
439
+ self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
440
+ self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
441
+ self.checkpoint = checkpoint
442
+ self.n_heads = n_heads
443
+ self.d_head = d_head
444
+ self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
445
+
446
+ def forward(self, x, context=None, transformer_options={}):
447
+ return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
448
+
449
+ def _forward(self, x, context=None, transformer_options={}):
450
+ extra_options = {}
451
+ block = transformer_options.get("block", None)
452
+ block_index = transformer_options.get("block_index", 0)
453
+ transformer_patches = {}
454
+ transformer_patches_replace = {}
455
+
456
+ for k in transformer_options:
457
+ if k == "patches":
458
+ transformer_patches = transformer_options[k]
459
+ elif k == "patches_replace":
460
+ transformer_patches_replace = transformer_options[k]
461
+ else:
462
+ extra_options[k] = transformer_options[k]
463
+
464
+ extra_options["n_heads"] = self.n_heads
465
+ extra_options["dim_head"] = self.d_head
466
+
467
+ if self.ff_in:
468
+ x_skip = x
469
+ x = self.ff_in(self.norm_in(x))
470
+ if self.is_res:
471
+ x += x_skip
472
+
473
+ n = self.norm1(x)
474
+ if self.disable_self_attn:
475
+ context_attn1 = context
476
+ else:
477
+ context_attn1 = None
478
+ value_attn1 = None
479
+
480
+ if "attn1_patch" in transformer_patches:
481
+ patch = transformer_patches["attn1_patch"]
482
+ if context_attn1 is None:
483
+ context_attn1 = n
484
+ value_attn1 = context_attn1
485
+ for p in patch:
486
+ n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)
487
+
488
+ if block is not None:
489
+ transformer_block = (block[0], block[1], block_index)
490
+ else:
491
+ transformer_block = None
492
+ attn1_replace_patch = transformer_patches_replace.get("attn1", {})
493
+ block_attn1 = transformer_block
494
+ if block_attn1 not in attn1_replace_patch:
495
+ block_attn1 = block
496
+
497
+ if block_attn1 in attn1_replace_patch:
498
+ if context_attn1 is None:
499
+ context_attn1 = n
500
+ value_attn1 = n
501
+ n = self.attn1.to_q(n)
502
+ context_attn1 = self.attn1.to_k(context_attn1)
503
+ value_attn1 = self.attn1.to_v(value_attn1)
504
+ n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
505
+ n = self.attn1.to_out(n)
506
+ else:
507
+ n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=extra_options)
508
+
509
+ if "attn1_output_patch" in transformer_patches:
510
+ patch = transformer_patches["attn1_output_patch"]
511
+ for p in patch:
512
+ n = p(n, extra_options)
513
+
514
+ x += n
515
+ if "middle_patch" in transformer_patches:
516
+ patch = transformer_patches["middle_patch"]
517
+ for p in patch:
518
+ x = p(x, extra_options)
519
+
520
+ if self.attn2 is not None:
521
+ n = self.norm2(x)
522
+ if self.switch_temporal_ca_to_sa:
523
+ context_attn2 = n
524
+ else:
525
+ context_attn2 = context
526
+ value_attn2 = None
527
+ if "attn2_patch" in transformer_patches:
528
+ patch = transformer_patches["attn2_patch"]
529
+ value_attn2 = context_attn2
530
+ for p in patch:
531
+ n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
532
+
533
+ attn2_replace_patch = transformer_patches_replace.get("attn2", {})
534
+ block_attn2 = transformer_block
535
+ if block_attn2 not in attn2_replace_patch:
536
+ block_attn2 = block
537
+
538
+ if block_attn2 in attn2_replace_patch:
539
+ if value_attn2 is None:
540
+ value_attn2 = context_attn2
541
+ n = self.attn2.to_q(n)
542
+ context_attn2 = self.attn2.to_k(context_attn2)
543
+ value_attn2 = self.attn2.to_v(value_attn2)
544
+ n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
545
+ n = self.attn2.to_out(n)
546
+ else:
547
+ n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=extra_options)
548
+
549
+ if "attn2_output_patch" in transformer_patches:
550
+ patch = transformer_patches["attn2_output_patch"]
551
+ for p in patch:
552
+ n = p(n, extra_options)
553
+
554
+ x += n
555
+ if self.is_res:
556
+ x_skip = x
557
+ x = self.ff(self.norm3(x))
558
+ if self.is_res:
559
+ x += x_skip
560
+
561
+ return x
562
+
563
+
564
+ class SpatialTransformer(nn.Module):
565
+ """
566
+ Transformer block for image-like data.
567
+ First, project the input (aka embedding)
568
+ and reshape to b, t, d.
569
+ Then apply standard transformer action.
570
+ Finally, reshape to image
571
+ NEW: use_linear for more efficiency instead of the 1x1 convs
572
+ """
573
+ def __init__(self, in_channels, n_heads, d_head,
574
+ depth=1, dropout=0., context_dim=None,
575
+ disable_self_attn=False, use_linear=False,
576
+ use_checkpoint=True, dtype=None, device=None, operations=ops):
577
+ super().__init__()
578
+ if exists(context_dim) and not isinstance(context_dim, list):
579
+ context_dim = [context_dim] * depth
580
+ self.in_channels = in_channels
581
+ inner_dim = n_heads * d_head
582
+ self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
583
+ if not use_linear:
584
+ self.proj_in = operations.Conv2d(in_channels,
585
+ inner_dim,
586
+ kernel_size=1,
587
+ stride=1,
588
+ padding=0, dtype=dtype, device=device)
589
+ else:
590
+ self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
591
+
592
+ self.transformer_blocks = nn.ModuleList(
593
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
594
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations)
595
+ for d in range(depth)]
596
+ )
597
+ if not use_linear:
598
+ self.proj_out = operations.Conv2d(inner_dim,in_channels,
599
+ kernel_size=1,
600
+ stride=1,
601
+ padding=0, dtype=dtype, device=device)
602
+ else:
603
+ self.proj_out = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
604
+ self.use_linear = use_linear
605
+
606
+ def forward(self, x, context=None, transformer_options={}):
607
+ # note: if no context is given, cross-attention defaults to self-attention
608
+ if not isinstance(context, list):
609
+ context = [context] * len(self.transformer_blocks)
610
+ b, c, h, w = x.shape
611
+ x_in = x
612
+ x = self.norm(x)
613
+ if not self.use_linear:
614
+ x = self.proj_in(x)
615
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
616
+ if self.use_linear:
617
+ x = self.proj_in(x)
618
+ for i, block in enumerate(self.transformer_blocks):
619
+ transformer_options["block_index"] = i
620
+ x = block(x, context=context[i], transformer_options=transformer_options)
621
+ if self.use_linear:
622
+ x = self.proj_out(x)
623
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
624
+ if not self.use_linear:
625
+ x = self.proj_out(x)
626
+ return x + x_in
627
+
628
+
629
+ class SpatialVideoTransformer(SpatialTransformer):
630
+ def __init__(
631
+ self,
632
+ in_channels,
633
+ n_heads,
634
+ d_head,
635
+ depth=1,
636
+ dropout=0.0,
637
+ use_linear=False,
638
+ context_dim=None,
639
+ use_spatial_context=False,
640
+ timesteps=None,
641
+ merge_strategy: str = "fixed",
642
+ merge_factor: float = 0.5,
643
+ time_context_dim=None,
644
+ ff_in=False,
645
+ checkpoint=False,
646
+ time_depth=1,
647
+ disable_self_attn=False,
648
+ disable_temporal_crossattention=False,
649
+ max_time_embed_period: int = 10000,
650
+ dtype=None, device=None, operations=ops
651
+ ):
652
+ super().__init__(
653
+ in_channels,
654
+ n_heads,
655
+ d_head,
656
+ depth=depth,
657
+ dropout=dropout,
658
+ use_checkpoint=checkpoint,
659
+ context_dim=context_dim,
660
+ use_linear=use_linear,
661
+ disable_self_attn=disable_self_attn,
662
+ dtype=dtype, device=device, operations=operations
663
+ )
664
+ self.time_depth = time_depth
665
+ self.depth = depth
666
+ self.max_time_embed_period = max_time_embed_period
667
+
668
+ time_mix_d_head = d_head
669
+ n_time_mix_heads = n_heads
670
+
671
+ time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
672
+
673
+ inner_dim = n_heads * d_head
674
+ if use_spatial_context:
675
+ time_context_dim = context_dim
676
+
677
+ self.time_stack = nn.ModuleList(
678
+ [
679
+ BasicTransformerBlock(
680
+ inner_dim,
681
+ n_time_mix_heads,
682
+ time_mix_d_head,
683
+ dropout=dropout,
684
+ context_dim=time_context_dim,
685
+ # timesteps=timesteps,
686
+ checkpoint=checkpoint,
687
+ ff_in=ff_in,
688
+ inner_dim=time_mix_inner_dim,
689
+ disable_self_attn=disable_self_attn,
690
+ disable_temporal_crossattention=disable_temporal_crossattention,
691
+ dtype=dtype, device=device, operations=operations
692
+ )
693
+ for _ in range(self.depth)
694
+ ]
695
+ )
696
+
697
+ assert len(self.time_stack) == len(self.transformer_blocks)
698
+
699
+ self.use_spatial_context = use_spatial_context
700
+ self.in_channels = in_channels
701
+
702
+ time_embed_dim = self.in_channels * 4
703
+ self.time_pos_embed = nn.Sequential(
704
+ operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device),
705
+ nn.SiLU(),
706
+ operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device),
707
+ )
708
+
709
+ self.time_mixer = AlphaBlender(
710
+ alpha=merge_factor, merge_strategy=merge_strategy
711
+ )
712
+
713
+ def forward(
714
+ self,
715
+ x: torch.Tensor,
716
+ context: Optional[torch.Tensor] = None,
717
+ time_context: Optional[torch.Tensor] = None,
718
+ timesteps: Optional[int] = None,
719
+ image_only_indicator: Optional[torch.Tensor] = None,
720
+ transformer_options={}
721
+ ) -> torch.Tensor:
722
+ _, _, h, w = x.shape
723
+ x_in = x
724
+ spatial_context = None
725
+ if exists(context):
726
+ spatial_context = context
727
+
728
+ if self.use_spatial_context:
729
+ assert (
730
+ context.ndim == 3
731
+ ), f"n dims of spatial context should be 3 but are {context.ndim}"
732
+
733
+ if time_context is None:
734
+ time_context = context
735
+ time_context_first_timestep = time_context[::timesteps]
736
+ time_context = repeat(
737
+ time_context_first_timestep, "b ... -> (b n) ...", n=h * w
738
+ )
739
+ elif time_context is not None and not self.use_spatial_context:
740
+ time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
741
+ if time_context.ndim == 2:
742
+ time_context = rearrange(time_context, "b c -> b 1 c")
743
+
744
+ x = self.norm(x)
745
+ if not self.use_linear:
746
+ x = self.proj_in(x)
747
+ x = rearrange(x, "b c h w -> b (h w) c")
748
+ if self.use_linear:
749
+ x = self.proj_in(x)
750
+
751
+ num_frames = torch.arange(timesteps, device=x.device)
752
+ num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
753
+ num_frames = rearrange(num_frames, "b t -> (b t)")
754
+ t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype)
755
+ emb = self.time_pos_embed(t_emb)
756
+ emb = emb[:, None, :]
757
+
758
+ for it_, (block, mix_block) in enumerate(
759
+ zip(self.transformer_blocks, self.time_stack)
760
+ ):
761
+ transformer_options["block_index"] = it_
762
+ x = block(
763
+ x,
764
+ context=spatial_context,
765
+ transformer_options=transformer_options,
766
+ )
767
+
768
+ x_mix = x
769
+ x_mix = x_mix + emb
770
+
771
+ B, S, C = x_mix.shape
772
+ x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
773
+ x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
774
+ x_mix = rearrange(
775
+ x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
776
+ )
777
+
778
+ x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator)
779
+
780
+ if self.use_linear:
781
+ x = self.proj_out(x)
782
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
783
+ if not self.use_linear:
784
+ x = self.proj_out(x)
785
+ out = x + x_in
786
+ return out
787
+
788
+
ldm_patched/ldm/modules/diffusionmodules/__init__.py ADDED
File without changes
ldm_patched/ldm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,657 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 1st edit by https://github.com/CompVis/latent-diffusion
2
+ # 2nd edit by https://github.com/Stability-AI/stablediffusion
3
+ # 3rd edit by https://github.com/Stability-AI/generative-models
4
+ # 4th edit by https://github.com/comfyanonymous/ComfyUI
5
+ # 5th edit by Forge
6
+
7
+
8
+ # pytorch_diffusion + derived encoder decoder
9
+ import math
10
+ import torch
11
+ import torch.nn as nn
12
+ import numpy as np
13
+ from einops import rearrange
14
+ from typing import Optional, Any
15
+
16
+ from ldm_patched.modules import model_management
17
+ import ldm_patched.modules.ops
18
+ ops = ldm_patched.modules.ops.disable_weight_init
19
+
20
+ if model_management.xformers_enabled_vae():
21
+ import xformers
22
+ import xformers.ops
23
+
24
+ def get_timestep_embedding(timesteps, embedding_dim):
25
+ """
26
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
27
+ From Fairseq.
28
+ Build sinusoidal embeddings.
29
+ This matches the implementation in tensor2tensor, but differs slightly
30
+ from the description in Section 3.5 of "Attention Is All You Need".
31
+ """
32
+ assert len(timesteps.shape) == 1
33
+
34
+ half_dim = embedding_dim // 2
35
+ emb = math.log(10000) / (half_dim - 1)
36
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
37
+ emb = emb.to(device=timesteps.device)
38
+ emb = timesteps.float()[:, None] * emb[None, :]
39
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
40
+ if embedding_dim % 2 == 1: # zero pad
41
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
42
+ return emb
43
+
44
+
45
+ def nonlinearity(x):
46
+ # swish
47
+ return x*torch.sigmoid(x)
48
+
49
+
50
+ def Normalize(in_channels, num_groups=32):
51
+ return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
52
+
53
+
54
+ class Upsample(nn.Module):
55
+ def __init__(self, in_channels, with_conv):
56
+ super().__init__()
57
+ self.with_conv = with_conv
58
+ if self.with_conv:
59
+ self.conv = ops.Conv2d(in_channels,
60
+ in_channels,
61
+ kernel_size=3,
62
+ stride=1,
63
+ padding=1)
64
+
65
+ def forward(self, x):
66
+ try:
67
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
68
+ except: #operation not implemented for bf16
69
+ b, c, h, w = x.shape
70
+ out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device)
71
+ split = 8
72
+ l = out.shape[1] // split
73
+ for i in range(0, out.shape[1], l):
74
+ out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
75
+ del x
76
+ x = out
77
+
78
+ if self.with_conv:
79
+ x = self.conv(x)
80
+ return x
81
+
82
+
83
+ class Downsample(nn.Module):
84
+ def __init__(self, in_channels, with_conv):
85
+ super().__init__()
86
+ self.with_conv = with_conv
87
+ if self.with_conv:
88
+ # no asymmetric padding in torch conv, must do it ourselves
89
+ self.conv = ops.Conv2d(in_channels,
90
+ in_channels,
91
+ kernel_size=3,
92
+ stride=2,
93
+ padding=0)
94
+
95
+ def forward(self, x):
96
+ if self.with_conv:
97
+ pad = (0,1,0,1)
98
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
99
+ x = self.conv(x)
100
+ else:
101
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
102
+ return x
103
+
104
+
105
+ class ResnetBlock(nn.Module):
106
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
107
+ dropout, temb_channels=512):
108
+ super().__init__()
109
+ self.in_channels = in_channels
110
+ out_channels = in_channels if out_channels is None else out_channels
111
+ self.out_channels = out_channels
112
+ self.use_conv_shortcut = conv_shortcut
113
+
114
+ self.swish = torch.nn.SiLU(inplace=True)
115
+ self.norm1 = Normalize(in_channels)
116
+ self.conv1 = ops.Conv2d(in_channels,
117
+ out_channels,
118
+ kernel_size=3,
119
+ stride=1,
120
+ padding=1)
121
+ if temb_channels > 0:
122
+ self.temb_proj = ops.Linear(temb_channels,
123
+ out_channels)
124
+ self.norm2 = Normalize(out_channels)
125
+ self.dropout = torch.nn.Dropout(dropout, inplace=True)
126
+ self.conv2 = ops.Conv2d(out_channels,
127
+ out_channels,
128
+ kernel_size=3,
129
+ stride=1,
130
+ padding=1)
131
+ if self.in_channels != self.out_channels:
132
+ if self.use_conv_shortcut:
133
+ self.conv_shortcut = ops.Conv2d(in_channels,
134
+ out_channels,
135
+ kernel_size=3,
136
+ stride=1,
137
+ padding=1)
138
+ else:
139
+ self.nin_shortcut = ops.Conv2d(in_channels,
140
+ out_channels,
141
+ kernel_size=1,
142
+ stride=1,
143
+ padding=0)
144
+
145
+ def forward(self, x, temb):
146
+ h = x
147
+ h = self.norm1(h)
148
+ h = self.swish(h)
149
+ h = self.conv1(h)
150
+
151
+ if temb is not None:
152
+ h = h + self.temb_proj(self.swish(temb))[:,:,None,None]
153
+
154
+ h = self.norm2(h)
155
+ h = self.swish(h)
156
+ h = self.dropout(h)
157
+ h = self.conv2(h)
158
+
159
+ if self.in_channels != self.out_channels:
160
+ if self.use_conv_shortcut:
161
+ x = self.conv_shortcut(x)
162
+ else:
163
+ x = self.nin_shortcut(x)
164
+
165
+ return x+h
166
+
167
+ def slice_attention(q, k, v):
168
+ r1 = torch.zeros_like(k, device=q.device)
169
+ scale = (int(q.shape[-1])**(-0.5))
170
+
171
+ mem_free_total = model_management.get_free_memory(q.device)
172
+
173
+ gb = 1024 ** 3
174
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
175
+ modifier = 3 if q.element_size() == 2 else 2.5
176
+ mem_required = tensor_size * modifier
177
+ steps = 1
178
+
179
+ if mem_required > mem_free_total:
180
+ steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
181
+
182
+ while True:
183
+ try:
184
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
185
+ for i in range(0, q.shape[1], slice_size):
186
+ end = i + slice_size
187
+ s1 = torch.bmm(q[:, i:end], k) * scale
188
+
189
+ s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
190
+ del s1
191
+
192
+ r1[:, :, i:end] = torch.bmm(v, s2)
193
+ del s2
194
+ break
195
+ except model_management.OOM_EXCEPTION as e:
196
+ model_management.soft_empty_cache(True)
197
+ steps *= 2
198
+ if steps > 128:
199
+ raise e
200
+ print("out of memory error, increasing steps and trying again", steps)
201
+
202
+ return r1
203
+
204
+ def normal_attention(q, k, v):
205
+ # compute attention
206
+ b,c,h,w = q.shape
207
+
208
+ q = q.reshape(b,c,h*w)
209
+ q = q.permute(0,2,1) # b,hw,c
210
+ k = k.reshape(b,c,h*w) # b,c,hw
211
+ v = v.reshape(b,c,h*w)
212
+
213
+ r1 = slice_attention(q, k, v)
214
+ h_ = r1.reshape(b,c,h,w)
215
+ del r1
216
+ return h_
217
+
218
+ def xformers_attention(q, k, v):
219
+ # compute attention
220
+ B, C, H, W = q.shape
221
+ q, k, v = map(
222
+ lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
223
+ (q, k, v),
224
+ )
225
+
226
+ try:
227
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
228
+ out = out.transpose(1, 2).reshape(B, C, H, W)
229
+ except NotImplementedError as e:
230
+ out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
231
+ return out
232
+
233
+ def pytorch_attention(q, k, v):
234
+ # compute attention
235
+ B, C, H, W = q.shape
236
+ q, k, v = map(
237
+ lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
238
+ (q, k, v),
239
+ )
240
+
241
+ try:
242
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
243
+ out = out.transpose(2, 3).reshape(B, C, H, W)
244
+ except model_management.OOM_EXCEPTION as e:
245
+ print("scaled_dot_product_attention OOMed: switched to slice attention")
246
+ out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
247
+ return out
248
+
249
+
250
+ class AttnBlock(nn.Module):
251
+ def __init__(self, in_channels):
252
+ super().__init__()
253
+ self.in_channels = in_channels
254
+
255
+ self.norm = Normalize(in_channels)
256
+ self.q = ops.Conv2d(in_channels,
257
+ in_channels,
258
+ kernel_size=1,
259
+ stride=1,
260
+ padding=0)
261
+ self.k = ops.Conv2d(in_channels,
262
+ in_channels,
263
+ kernel_size=1,
264
+ stride=1,
265
+ padding=0)
266
+ self.v = ops.Conv2d(in_channels,
267
+ in_channels,
268
+ kernel_size=1,
269
+ stride=1,
270
+ padding=0)
271
+ self.proj_out = ops.Conv2d(in_channels,
272
+ in_channels,
273
+ kernel_size=1,
274
+ stride=1,
275
+ padding=0)
276
+
277
+ if model_management.xformers_enabled_vae():
278
+ print("Using xformers attention in VAE")
279
+ self.optimized_attention = xformers_attention
280
+ elif model_management.pytorch_attention_enabled():
281
+ print("Using pytorch attention in VAE")
282
+ self.optimized_attention = pytorch_attention
283
+ else:
284
+ print("Using split attention in VAE")
285
+ self.optimized_attention = normal_attention
286
+
287
+ def forward(self, x):
288
+ h_ = x
289
+ h_ = self.norm(h_)
290
+ q = self.q(h_)
291
+ k = self.k(h_)
292
+ v = self.v(h_)
293
+
294
+ h_ = self.optimized_attention(q, k, v)
295
+
296
+ h_ = self.proj_out(h_)
297
+
298
+ return x+h_
299
+
300
+
301
+ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
302
+ return AttnBlock(in_channels)
303
+
304
+
305
+ class Model(nn.Module):
306
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
307
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
308
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
309
+ super().__init__()
310
+ if use_linear_attn: attn_type = "linear"
311
+ self.ch = ch
312
+ self.temb_ch = self.ch*4
313
+ self.num_resolutions = len(ch_mult)
314
+ self.num_res_blocks = num_res_blocks
315
+ self.resolution = resolution
316
+ self.in_channels = in_channels
317
+
318
+ self.use_timestep = use_timestep
319
+ if self.use_timestep:
320
+ # timestep embedding
321
+ self.temb = nn.Module()
322
+ self.temb.dense = nn.ModuleList([
323
+ ops.Linear(self.ch,
324
+ self.temb_ch),
325
+ ops.Linear(self.temb_ch,
326
+ self.temb_ch),
327
+ ])
328
+
329
+ # downsampling
330
+ self.conv_in = ops.Conv2d(in_channels,
331
+ self.ch,
332
+ kernel_size=3,
333
+ stride=1,
334
+ padding=1)
335
+
336
+ curr_res = resolution
337
+ in_ch_mult = (1,)+tuple(ch_mult)
338
+ self.down = nn.ModuleList()
339
+ for i_level in range(self.num_resolutions):
340
+ block = nn.ModuleList()
341
+ attn = nn.ModuleList()
342
+ block_in = ch*in_ch_mult[i_level]
343
+ block_out = ch*ch_mult[i_level]
344
+ for i_block in range(self.num_res_blocks):
345
+ block.append(ResnetBlock(in_channels=block_in,
346
+ out_channels=block_out,
347
+ temb_channels=self.temb_ch,
348
+ dropout=dropout))
349
+ block_in = block_out
350
+ if curr_res in attn_resolutions:
351
+ attn.append(make_attn(block_in, attn_type=attn_type))
352
+ down = nn.Module()
353
+ down.block = block
354
+ down.attn = attn
355
+ if i_level != self.num_resolutions-1:
356
+ down.downsample = Downsample(block_in, resamp_with_conv)
357
+ curr_res = curr_res // 2
358
+ self.down.append(down)
359
+
360
+ # middle
361
+ self.mid = nn.Module()
362
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
363
+ out_channels=block_in,
364
+ temb_channels=self.temb_ch,
365
+ dropout=dropout)
366
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
367
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
368
+ out_channels=block_in,
369
+ temb_channels=self.temb_ch,
370
+ dropout=dropout)
371
+
372
+ # upsampling
373
+ self.up = nn.ModuleList()
374
+ for i_level in reversed(range(self.num_resolutions)):
375
+ block = nn.ModuleList()
376
+ attn = nn.ModuleList()
377
+ block_out = ch*ch_mult[i_level]
378
+ skip_in = ch*ch_mult[i_level]
379
+ for i_block in range(self.num_res_blocks+1):
380
+ if i_block == self.num_res_blocks:
381
+ skip_in = ch*in_ch_mult[i_level]
382
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
383
+ out_channels=block_out,
384
+ temb_channels=self.temb_ch,
385
+ dropout=dropout))
386
+ block_in = block_out
387
+ if curr_res in attn_resolutions:
388
+ attn.append(make_attn(block_in, attn_type=attn_type))
389
+ up = nn.Module()
390
+ up.block = block
391
+ up.attn = attn
392
+ if i_level != 0:
393
+ up.upsample = Upsample(block_in, resamp_with_conv)
394
+ curr_res = curr_res * 2
395
+ self.up.insert(0, up) # prepend to get consistent order
396
+
397
+ # end
398
+ self.norm_out = Normalize(block_in)
399
+ self.conv_out = ops.Conv2d(block_in,
400
+ out_ch,
401
+ kernel_size=3,
402
+ stride=1,
403
+ padding=1)
404
+
405
+ def forward(self, x, t=None, context=None):
406
+ #assert x.shape[2] == x.shape[3] == self.resolution
407
+ if context is not None:
408
+ # assume aligned context, cat along channel axis
409
+ x = torch.cat((x, context), dim=1)
410
+ if self.use_timestep:
411
+ # timestep embedding
412
+ assert t is not None
413
+ temb = get_timestep_embedding(t, self.ch)
414
+ temb = self.temb.dense[0](temb)
415
+ temb = nonlinearity(temb)
416
+ temb = self.temb.dense[1](temb)
417
+ else:
418
+ temb = None
419
+
420
+ # downsampling
421
+ hs = [self.conv_in(x)]
422
+ for i_level in range(self.num_resolutions):
423
+ for i_block in range(self.num_res_blocks):
424
+ h = self.down[i_level].block[i_block](hs[-1], temb)
425
+ if len(self.down[i_level].attn) > 0:
426
+ h = self.down[i_level].attn[i_block](h)
427
+ hs.append(h)
428
+ if i_level != self.num_resolutions-1:
429
+ hs.append(self.down[i_level].downsample(hs[-1]))
430
+
431
+ # middle
432
+ h = hs[-1]
433
+ h = self.mid.block_1(h, temb)
434
+ h = self.mid.attn_1(h)
435
+ h = self.mid.block_2(h, temb)
436
+
437
+ # upsampling
438
+ for i_level in reversed(range(self.num_resolutions)):
439
+ for i_block in range(self.num_res_blocks+1):
440
+ h = self.up[i_level].block[i_block](
441
+ torch.cat([h, hs.pop()], dim=1), temb)
442
+ if len(self.up[i_level].attn) > 0:
443
+ h = self.up[i_level].attn[i_block](h)
444
+ if i_level != 0:
445
+ h = self.up[i_level].upsample(h)
446
+
447
+ # end
448
+ h = self.norm_out(h)
449
+ h = nonlinearity(h)
450
+ h = self.conv_out(h)
451
+ return h
452
+
453
+ def get_last_layer(self):
454
+ return self.conv_out.weight
455
+
456
+
457
+ class Encoder(nn.Module):
458
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
459
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
460
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
461
+ **ignore_kwargs):
462
+ super().__init__()
463
+ if use_linear_attn: attn_type = "linear"
464
+ self.ch = ch
465
+ self.temb_ch = 0
466
+ self.num_resolutions = len(ch_mult)
467
+ self.num_res_blocks = num_res_blocks
468
+ self.resolution = resolution
469
+ self.in_channels = in_channels
470
+
471
+ # downsampling
472
+ self.conv_in = ops.Conv2d(in_channels,
473
+ self.ch,
474
+ kernel_size=3,
475
+ stride=1,
476
+ padding=1)
477
+
478
+ curr_res = resolution
479
+ in_ch_mult = (1,)+tuple(ch_mult)
480
+ self.in_ch_mult = in_ch_mult
481
+ self.down = nn.ModuleList()
482
+ for i_level in range(self.num_resolutions):
483
+ block = nn.ModuleList()
484
+ attn = nn.ModuleList()
485
+ block_in = ch*in_ch_mult[i_level]
486
+ block_out = ch*ch_mult[i_level]
487
+ for i_block in range(self.num_res_blocks):
488
+ block.append(ResnetBlock(in_channels=block_in,
489
+ out_channels=block_out,
490
+ temb_channels=self.temb_ch,
491
+ dropout=dropout))
492
+ block_in = block_out
493
+ if curr_res in attn_resolutions:
494
+ attn.append(make_attn(block_in, attn_type=attn_type))
495
+ down = nn.Module()
496
+ down.block = block
497
+ down.attn = attn
498
+ if i_level != self.num_resolutions-1:
499
+ down.downsample = Downsample(block_in, resamp_with_conv)
500
+ curr_res = curr_res // 2
501
+ self.down.append(down)
502
+
503
+ # middle
504
+ self.mid = nn.Module()
505
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
506
+ out_channels=block_in,
507
+ temb_channels=self.temb_ch,
508
+ dropout=dropout)
509
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
510
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
511
+ out_channels=block_in,
512
+ temb_channels=self.temb_ch,
513
+ dropout=dropout)
514
+
515
+ # end
516
+ self.norm_out = Normalize(block_in)
517
+ self.conv_out = ops.Conv2d(block_in,
518
+ 2*z_channels if double_z else z_channels,
519
+ kernel_size=3,
520
+ stride=1,
521
+ padding=1)
522
+
523
+ def forward(self, x):
524
+ # timestep embedding
525
+ temb = None
526
+ # downsampling
527
+ h = self.conv_in(x)
528
+ for i_level in range(self.num_resolutions):
529
+ for i_block in range(self.num_res_blocks):
530
+ h = self.down[i_level].block[i_block](h, temb)
531
+ if len(self.down[i_level].attn) > 0:
532
+ h = self.down[i_level].attn[i_block](h)
533
+ if i_level != self.num_resolutions-1:
534
+ h = self.down[i_level].downsample(h)
535
+
536
+ # middle
537
+ h = self.mid.block_1(h, temb)
538
+ h = self.mid.attn_1(h)
539
+ h = self.mid.block_2(h, temb)
540
+
541
+ # end
542
+ h = self.norm_out(h)
543
+ h = nonlinearity(h)
544
+ h = self.conv_out(h)
545
+ return h
546
+
547
+
548
+ class Decoder(nn.Module):
549
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
550
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
551
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
552
+ conv_out_op=ops.Conv2d,
553
+ resnet_op=ResnetBlock,
554
+ attn_op=AttnBlock,
555
+ **ignorekwargs):
556
+ super().__init__()
557
+ if use_linear_attn: attn_type = "linear"
558
+ self.ch = ch
559
+ self.temb_ch = 0
560
+ self.num_resolutions = len(ch_mult)
561
+ self.num_res_blocks = num_res_blocks
562
+ self.resolution = resolution
563
+ self.in_channels = in_channels
564
+ self.give_pre_end = give_pre_end
565
+ self.tanh_out = tanh_out
566
+
567
+ # compute in_ch_mult, block_in and curr_res at lowest res
568
+ in_ch_mult = (1,)+tuple(ch_mult)
569
+ block_in = ch*ch_mult[self.num_resolutions-1]
570
+ curr_res = resolution // 2**(self.num_resolutions-1)
571
+ self.z_shape = (1,z_channels,curr_res,curr_res)
572
+ print("Working with z of shape {} = {} dimensions.".format(
573
+ self.z_shape, np.prod(self.z_shape)))
574
+
575
+ # z to block_in
576
+ self.conv_in = ops.Conv2d(z_channels,
577
+ block_in,
578
+ kernel_size=3,
579
+ stride=1,
580
+ padding=1)
581
+
582
+ # middle
583
+ self.mid = nn.Module()
584
+ self.mid.block_1 = resnet_op(in_channels=block_in,
585
+ out_channels=block_in,
586
+ temb_channels=self.temb_ch,
587
+ dropout=dropout)
588
+ self.mid.attn_1 = attn_op(block_in)
589
+ self.mid.block_2 = resnet_op(in_channels=block_in,
590
+ out_channels=block_in,
591
+ temb_channels=self.temb_ch,
592
+ dropout=dropout)
593
+
594
+ # upsampling
595
+ self.up = nn.ModuleList()
596
+ for i_level in reversed(range(self.num_resolutions)):
597
+ block = nn.ModuleList()
598
+ attn = nn.ModuleList()
599
+ block_out = ch*ch_mult[i_level]
600
+ for i_block in range(self.num_res_blocks+1):
601
+ block.append(resnet_op(in_channels=block_in,
602
+ out_channels=block_out,
603
+ temb_channels=self.temb_ch,
604
+ dropout=dropout))
605
+ block_in = block_out
606
+ if curr_res in attn_resolutions:
607
+ attn.append(attn_op(block_in))
608
+ up = nn.Module()
609
+ up.block = block
610
+ up.attn = attn
611
+ if i_level != 0:
612
+ up.upsample = Upsample(block_in, resamp_with_conv)
613
+ curr_res = curr_res * 2
614
+ self.up.insert(0, up) # prepend to get consistent order
615
+
616
+ # end
617
+ self.norm_out = Normalize(block_in)
618
+ self.conv_out = conv_out_op(block_in,
619
+ out_ch,
620
+ kernel_size=3,
621
+ stride=1,
622
+ padding=1)
623
+
624
+ def forward(self, z, **kwargs):
625
+ #assert z.shape[1:] == self.z_shape[1:]
626
+ self.last_z_shape = z.shape
627
+
628
+ # timestep embedding
629
+ temb = None
630
+
631
+ # z to block_in
632
+ h = self.conv_in(z)
633
+
634
+ # middle
635
+ h = self.mid.block_1(h, temb, **kwargs)
636
+ h = self.mid.attn_1(h, **kwargs)
637
+ h = self.mid.block_2(h, temb, **kwargs)
638
+
639
+ # upsampling
640
+ for i_level in reversed(range(self.num_resolutions)):
641
+ for i_block in range(self.num_res_blocks+1):
642
+ h = self.up[i_level].block[i_block](h, temb, **kwargs)
643
+ if len(self.up[i_level].attn) > 0:
644
+ h = self.up[i_level].attn[i_block](h, **kwargs)
645
+ if i_level != 0:
646
+ h = self.up[i_level].upsample(h)
647
+
648
+ # end
649
+ if self.give_pre_end:
650
+ return h
651
+
652
+ h = self.norm_out(h)
653
+ h = nonlinearity(h)
654
+ h = self.conv_out(h, **kwargs)
655
+ if self.tanh_out:
656
+ h = torch.tanh(h)
657
+ return h
ldm_patched/ldm/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,933 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 1st edit by https://github.com/CompVis/latent-diffusion
2
+ # 2nd edit by https://github.com/Stability-AI/stablediffusion
3
+ # 3rd edit by https://github.com/Stability-AI/generative-models
4
+ # 4th edit by https://github.com/comfyanonymous/ComfyUI
5
+ # 5th edit by Forge
6
+
7
+
8
+ from abc import abstractmethod
9
+
10
+ import torch as th
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from einops import rearrange
14
+
15
+ from .util import (
16
+ checkpoint,
17
+ avg_pool_nd,
18
+ zero_module,
19
+ timestep_embedding,
20
+ AlphaBlender,
21
+ )
22
+ from ..attention import SpatialTransformer, SpatialVideoTransformer, default
23
+ from ldm_patched.ldm.util import exists
24
+ import ldm_patched.modules.ops
25
+ ops = ldm_patched.modules.ops.disable_weight_init
26
+
27
+ class TimestepBlock(nn.Module):
28
+ """
29
+ Any module where forward() takes timestep embeddings as a second argument.
30
+ """
31
+
32
+ @abstractmethod
33
+ def forward(self, x, emb):
34
+ """
35
+ Apply the module to `x` given `emb` timestep embeddings.
36
+ """
37
+
38
+ #This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
39
+ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
40
+ block_inner_modifiers = transformer_options.get("block_inner_modifiers", [])
41
+
42
+ for layer_index, layer in enumerate(ts):
43
+ for modifier in block_inner_modifiers:
44
+ x = modifier(x, 'before', layer, layer_index, ts, transformer_options)
45
+
46
+ if isinstance(layer, VideoResBlock):
47
+ x = layer(x, emb, num_video_frames, image_only_indicator)
48
+ elif isinstance(layer, TimestepBlock):
49
+ x = layer(x, emb)
50
+ elif isinstance(layer, SpatialVideoTransformer):
51
+ x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options)
52
+ if "transformer_index" in transformer_options:
53
+ transformer_options["transformer_index"] += 1
54
+ elif isinstance(layer, SpatialTransformer):
55
+ x = layer(x, context, transformer_options)
56
+ if "transformer_index" in transformer_options:
57
+ transformer_options["transformer_index"] += 1
58
+ elif isinstance(layer, Upsample):
59
+ x = layer(x, output_shape=output_shape)
60
+ else:
61
+ x = layer(x)
62
+
63
+ for modifier in block_inner_modifiers:
64
+ x = modifier(x, 'after', layer, layer_index, ts, transformer_options)
65
+ return x
66
+
67
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
68
+ """
69
+ A sequential module that passes timestep embeddings to the children that
70
+ support it as an extra input.
71
+ """
72
+
73
+ def forward(self, *args, **kwargs):
74
+ return forward_timestep_embed(self, *args, **kwargs)
75
+
76
+ class Upsample(nn.Module):
77
+ """
78
+ An upsampling layer with an optional convolution.
79
+ :param channels: channels in the inputs and outputs.
80
+ :param use_conv: a bool determining if a convolution is applied.
81
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
82
+ upsampling occurs in the inner-two dimensions.
83
+ """
84
+
85
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops):
86
+ super().__init__()
87
+ self.channels = channels
88
+ self.out_channels = out_channels or channels
89
+ self.use_conv = use_conv
90
+ self.dims = dims
91
+ if use_conv:
92
+ self.conv = operations.conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device)
93
+
94
+ def forward(self, x, output_shape=None):
95
+ assert x.shape[1] == self.channels
96
+ if self.dims == 3:
97
+ shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2]
98
+ if output_shape is not None:
99
+ shape[1] = output_shape[3]
100
+ shape[2] = output_shape[4]
101
+ else:
102
+ shape = [x.shape[2] * 2, x.shape[3] * 2]
103
+ if output_shape is not None:
104
+ shape[0] = output_shape[2]
105
+ shape[1] = output_shape[3]
106
+
107
+ x = F.interpolate(x, size=shape, mode="nearest")
108
+ if self.use_conv:
109
+ x = self.conv(x)
110
+ return x
111
+
112
+ class Downsample(nn.Module):
113
+ """
114
+ A downsampling layer with an optional convolution.
115
+ :param channels: channels in the inputs and outputs.
116
+ :param use_conv: a bool determining if a convolution is applied.
117
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
118
+ downsampling occurs in the inner-two dimensions.
119
+ """
120
+
121
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops):
122
+ super().__init__()
123
+ self.channels = channels
124
+ self.out_channels = out_channels or channels
125
+ self.use_conv = use_conv
126
+ self.dims = dims
127
+ stride = 2 if dims != 3 else (1, 2, 2)
128
+ if use_conv:
129
+ self.op = operations.conv_nd(
130
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device
131
+ )
132
+ else:
133
+ assert self.channels == self.out_channels
134
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
135
+
136
+ def forward(self, x):
137
+ assert x.shape[1] == self.channels
138
+ return self.op(x)
139
+
140
+
141
+ class ResBlock(TimestepBlock):
142
+ """
143
+ A residual block that can optionally change the number of channels.
144
+ :param channels: the number of input channels.
145
+ :param emb_channels: the number of timestep embedding channels.
146
+ :param dropout: the rate of dropout.
147
+ :param out_channels: if specified, the number of out channels.
148
+ :param use_conv: if True and out_channels is specified, use a spatial
149
+ convolution instead of a smaller 1x1 convolution to change the
150
+ channels in the skip connection.
151
+ :param dims: determines if the signal is 1D, 2D, or 3D.
152
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
153
+ :param up: if True, use this block for upsampling.
154
+ :param down: if True, use this block for downsampling.
155
+ """
156
+
157
+ def __init__(
158
+ self,
159
+ channels,
160
+ emb_channels,
161
+ dropout,
162
+ out_channels=None,
163
+ use_conv=False,
164
+ use_scale_shift_norm=False,
165
+ dims=2,
166
+ use_checkpoint=False,
167
+ up=False,
168
+ down=False,
169
+ kernel_size=3,
170
+ exchange_temb_dims=False,
171
+ skip_t_emb=False,
172
+ dtype=None,
173
+ device=None,
174
+ operations=ops
175
+ ):
176
+ super().__init__()
177
+ self.channels = channels
178
+ self.emb_channels = emb_channels
179
+ self.dropout = dropout
180
+ self.out_channels = out_channels or channels
181
+ self.use_conv = use_conv
182
+ self.use_checkpoint = use_checkpoint
183
+ self.use_scale_shift_norm = use_scale_shift_norm
184
+ self.exchange_temb_dims = exchange_temb_dims
185
+
186
+ if isinstance(kernel_size, list):
187
+ padding = [k // 2 for k in kernel_size]
188
+ else:
189
+ padding = kernel_size // 2
190
+
191
+ self.in_layers = nn.Sequential(
192
+ operations.GroupNorm(32, channels, dtype=dtype, device=device),
193
+ nn.SiLU(),
194
+ operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
195
+ )
196
+
197
+ self.updown = up or down
198
+
199
+ if up:
200
+ self.h_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
201
+ self.x_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
202
+ elif down:
203
+ self.h_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
204
+ self.x_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
205
+ else:
206
+ self.h_upd = self.x_upd = nn.Identity()
207
+
208
+ self.skip_t_emb = skip_t_emb
209
+ if self.skip_t_emb:
210
+ self.emb_layers = None
211
+ self.exchange_temb_dims = False
212
+ else:
213
+ self.emb_layers = nn.Sequential(
214
+ nn.SiLU(),
215
+ operations.Linear(
216
+ emb_channels,
217
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
218
+ ),
219
+ )
220
+ self.out_layers = nn.Sequential(
221
+ operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
222
+ nn.SiLU(),
223
+ nn.Dropout(p=dropout),
224
+ operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
225
+ ,
226
+ )
227
+
228
+ if self.out_channels == channels:
229
+ self.skip_connection = nn.Identity()
230
+ elif use_conv:
231
+ self.skip_connection = operations.conv_nd(
232
+ dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device
233
+ )
234
+ else:
235
+ self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
236
+
237
+ def forward(self, x, emb):
238
+ """
239
+ Apply the block to a Tensor, conditioned on a timestep embedding.
240
+ :param x: an [N x C x ...] Tensor of features.
241
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
242
+ :return: an [N x C x ...] Tensor of outputs.
243
+ """
244
+ return checkpoint(
245
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
246
+ )
247
+
248
+
249
+ def _forward(self, x, emb):
250
+ if self.updown:
251
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
252
+ h = in_rest(x)
253
+ h = self.h_upd(h)
254
+ x = self.x_upd(x)
255
+ h = in_conv(h)
256
+ else:
257
+ h = self.in_layers(x)
258
+
259
+ emb_out = None
260
+ if not self.skip_t_emb:
261
+ emb_out = self.emb_layers(emb).type(h.dtype)
262
+ while len(emb_out.shape) < len(h.shape):
263
+ emb_out = emb_out[..., None]
264
+ if self.use_scale_shift_norm:
265
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
266
+ h = out_norm(h)
267
+ if emb_out is not None:
268
+ scale, shift = th.chunk(emb_out, 2, dim=1)
269
+ h *= (1 + scale)
270
+ h += shift
271
+ h = out_rest(h)
272
+ else:
273
+ if emb_out is not None:
274
+ if self.exchange_temb_dims:
275
+ emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
276
+ h = h + emb_out
277
+ h = self.out_layers(h)
278
+ return self.skip_connection(x) + h
279
+
280
+
281
+ class VideoResBlock(ResBlock):
282
+ def __init__(
283
+ self,
284
+ channels: int,
285
+ emb_channels: int,
286
+ dropout: float,
287
+ video_kernel_size=3,
288
+ merge_strategy: str = "fixed",
289
+ merge_factor: float = 0.5,
290
+ out_channels=None,
291
+ use_conv: bool = False,
292
+ use_scale_shift_norm: bool = False,
293
+ dims: int = 2,
294
+ use_checkpoint: bool = False,
295
+ up: bool = False,
296
+ down: bool = False,
297
+ dtype=None,
298
+ device=None,
299
+ operations=ops
300
+ ):
301
+ super().__init__(
302
+ channels,
303
+ emb_channels,
304
+ dropout,
305
+ out_channels=out_channels,
306
+ use_conv=use_conv,
307
+ use_scale_shift_norm=use_scale_shift_norm,
308
+ dims=dims,
309
+ use_checkpoint=use_checkpoint,
310
+ up=up,
311
+ down=down,
312
+ dtype=dtype,
313
+ device=device,
314
+ operations=operations
315
+ )
316
+
317
+ self.time_stack = ResBlock(
318
+ default(out_channels, channels),
319
+ emb_channels,
320
+ dropout=dropout,
321
+ dims=3,
322
+ out_channels=default(out_channels, channels),
323
+ use_scale_shift_norm=False,
324
+ use_conv=False,
325
+ up=False,
326
+ down=False,
327
+ kernel_size=video_kernel_size,
328
+ use_checkpoint=use_checkpoint,
329
+ exchange_temb_dims=True,
330
+ dtype=dtype,
331
+ device=device,
332
+ operations=operations
333
+ )
334
+ self.time_mixer = AlphaBlender(
335
+ alpha=merge_factor,
336
+ merge_strategy=merge_strategy,
337
+ rearrange_pattern="b t -> b 1 t 1 1",
338
+ )
339
+
340
+ def forward(
341
+ self,
342
+ x: th.Tensor,
343
+ emb: th.Tensor,
344
+ num_video_frames: int,
345
+ image_only_indicator = None,
346
+ ) -> th.Tensor:
347
+ x = super().forward(x, emb)
348
+
349
+ x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
350
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
351
+
352
+ x = self.time_stack(
353
+ x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
354
+ )
355
+ x = self.time_mixer(
356
+ x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
357
+ )
358
+ x = rearrange(x, "b c t h w -> (b t) c h w")
359
+ return x
360
+
361
+
362
+ class Timestep(nn.Module):
363
+ def __init__(self, dim):
364
+ super().__init__()
365
+ self.dim = dim
366
+
367
+ def forward(self, t):
368
+ return timestep_embedding(t, self.dim)
369
+
370
+ def apply_control(h, control, name):
371
+ if control is not None and name in control and len(control[name]) > 0:
372
+ ctrl = control[name].pop()
373
+ if ctrl is not None:
374
+ try:
375
+ h += ctrl
376
+ except:
377
+ print("warning control could not be applied", h.shape, ctrl.shape)
378
+ return h
379
+
380
+ class UNetModel(nn.Module):
381
+ """
382
+ The full UNet model with attention and timestep embedding.
383
+ :param in_channels: channels in the input Tensor.
384
+ :param model_channels: base channel count for the model.
385
+ :param out_channels: channels in the output Tensor.
386
+ :param num_res_blocks: number of residual blocks per downsample.
387
+ :param dropout: the dropout probability.
388
+ :param channel_mult: channel multiplier for each level of the UNet.
389
+ :param conv_resample: if True, use learned convolutions for upsampling and
390
+ downsampling.
391
+ :param dims: determines if the signal is 1D, 2D, or 3D.
392
+ :param num_classes: if specified (as an int), then this model will be
393
+ class-conditional with `num_classes` classes.
394
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
395
+ :param num_heads: the number of attention heads in each attention layer.
396
+ :param num_heads_channels: if specified, ignore num_heads and instead use
397
+ a fixed channel width per attention head.
398
+ :param num_heads_upsample: works with num_heads to set a different number
399
+ of heads for upsampling. Deprecated.
400
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
401
+ :param resblock_updown: use residual blocks for up/downsampling.
402
+ :param use_new_attention_order: use a different attention pattern for potentially
403
+ increased efficiency.
404
+ """
405
+
406
+ def __init__(
407
+ self,
408
+ image_size,
409
+ in_channels,
410
+ model_channels,
411
+ out_channels,
412
+ num_res_blocks,
413
+ dropout=0,
414
+ channel_mult=(1, 2, 4, 8),
415
+ conv_resample=True,
416
+ dims=2,
417
+ num_classes=None,
418
+ use_checkpoint=False,
419
+ dtype=th.float32,
420
+ num_heads=-1,
421
+ num_head_channels=-1,
422
+ num_heads_upsample=-1,
423
+ use_scale_shift_norm=False,
424
+ resblock_updown=False,
425
+ use_new_attention_order=False,
426
+ use_spatial_transformer=False, # custom transformer support
427
+ transformer_depth=1, # custom transformer support
428
+ context_dim=None, # custom transformer support
429
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
430
+ legacy=True,
431
+ disable_self_attentions=None,
432
+ num_attention_blocks=None,
433
+ disable_middle_self_attn=False,
434
+ use_linear_in_transformer=False,
435
+ adm_in_channels=None,
436
+ transformer_depth_middle=None,
437
+ transformer_depth_output=None,
438
+ use_temporal_resblock=False,
439
+ use_temporal_attention=False,
440
+ time_context_dim=None,
441
+ extra_ff_mix_layer=False,
442
+ use_spatial_context=False,
443
+ merge_strategy=None,
444
+ merge_factor=0.0,
445
+ video_kernel_size=None,
446
+ disable_temporal_crossattention=False,
447
+ max_ddpm_temb_period=10000,
448
+ device=None,
449
+ operations=ops,
450
+ ):
451
+ super().__init__()
452
+
453
+ if context_dim is not None:
454
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
455
+ # from omegaconf.listconfig import ListConfig
456
+ # if type(context_dim) == ListConfig:
457
+ # context_dim = list(context_dim)
458
+
459
+ if num_heads_upsample == -1:
460
+ num_heads_upsample = num_heads
461
+
462
+ if num_heads == -1:
463
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
464
+
465
+ if num_head_channels == -1:
466
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
467
+
468
+ self.in_channels = in_channels
469
+ self.model_channels = model_channels
470
+ self.out_channels = out_channels
471
+
472
+ if isinstance(num_res_blocks, int):
473
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
474
+ else:
475
+ if len(num_res_blocks) != len(channel_mult):
476
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
477
+ "as a list/tuple (per-level) with the same length as channel_mult")
478
+ self.num_res_blocks = num_res_blocks
479
+
480
+ if disable_self_attentions is not None:
481
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
482
+ assert len(disable_self_attentions) == len(channel_mult)
483
+ if num_attention_blocks is not None:
484
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
485
+
486
+ transformer_depth = transformer_depth[:]
487
+ transformer_depth_output = transformer_depth_output[:]
488
+
489
+ self.dropout = dropout
490
+ self.channel_mult = channel_mult
491
+ self.conv_resample = conv_resample
492
+ self.num_classes = num_classes
493
+ self.use_checkpoint = use_checkpoint
494
+ self.dtype = dtype
495
+ self.num_heads = num_heads
496
+ self.num_head_channels = num_head_channels
497
+ self.num_heads_upsample = num_heads_upsample
498
+ self.use_temporal_resblocks = use_temporal_resblock
499
+ self.predict_codebook_ids = n_embed is not None
500
+
501
+ self.default_num_video_frames = None
502
+ self.default_image_only_indicator = None
503
+
504
+ time_embed_dim = model_channels * 4
505
+ self.time_embed = nn.Sequential(
506
+ operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
507
+ nn.SiLU(),
508
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
509
+ )
510
+
511
+ if self.num_classes is not None:
512
+ if isinstance(self.num_classes, int):
513
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device)
514
+ elif self.num_classes == "continuous":
515
+ print("setting up linear c_adm embedding layer")
516
+ self.label_emb = nn.Linear(1, time_embed_dim)
517
+ elif self.num_classes == "sequential":
518
+ assert adm_in_channels is not None
519
+ self.label_emb = nn.Sequential(
520
+ nn.Sequential(
521
+ operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
522
+ nn.SiLU(),
523
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
524
+ )
525
+ )
526
+ else:
527
+ raise ValueError()
528
+
529
+ self.input_blocks = nn.ModuleList(
530
+ [
531
+ TimestepEmbedSequential(
532
+ operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
533
+ )
534
+ ]
535
+ )
536
+ self._feature_size = model_channels
537
+ input_block_chans = [model_channels]
538
+ ch = model_channels
539
+ ds = 1
540
+
541
+ def get_attention_layer(
542
+ ch,
543
+ num_heads,
544
+ dim_head,
545
+ depth=1,
546
+ context_dim=None,
547
+ use_checkpoint=False,
548
+ disable_self_attn=False,
549
+ ):
550
+ if use_temporal_attention:
551
+ return SpatialVideoTransformer(
552
+ ch,
553
+ num_heads,
554
+ dim_head,
555
+ depth=depth,
556
+ context_dim=context_dim,
557
+ time_context_dim=time_context_dim,
558
+ dropout=dropout,
559
+ ff_in=extra_ff_mix_layer,
560
+ use_spatial_context=use_spatial_context,
561
+ merge_strategy=merge_strategy,
562
+ merge_factor=merge_factor,
563
+ checkpoint=use_checkpoint,
564
+ use_linear=use_linear_in_transformer,
565
+ disable_self_attn=disable_self_attn,
566
+ disable_temporal_crossattention=disable_temporal_crossattention,
567
+ max_time_embed_period=max_ddpm_temb_period,
568
+ dtype=self.dtype, device=device, operations=operations
569
+ )
570
+ else:
571
+ return SpatialTransformer(
572
+ ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
573
+ disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
574
+ use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
575
+ )
576
+
577
+ def get_resblock(
578
+ merge_factor,
579
+ merge_strategy,
580
+ video_kernel_size,
581
+ ch,
582
+ time_embed_dim,
583
+ dropout,
584
+ out_channels,
585
+ dims,
586
+ use_checkpoint,
587
+ use_scale_shift_norm,
588
+ down=False,
589
+ up=False,
590
+ dtype=None,
591
+ device=None,
592
+ operations=ops
593
+ ):
594
+ if self.use_temporal_resblocks:
595
+ return VideoResBlock(
596
+ merge_factor=merge_factor,
597
+ merge_strategy=merge_strategy,
598
+ video_kernel_size=video_kernel_size,
599
+ channels=ch,
600
+ emb_channels=time_embed_dim,
601
+ dropout=dropout,
602
+ out_channels=out_channels,
603
+ dims=dims,
604
+ use_checkpoint=use_checkpoint,
605
+ use_scale_shift_norm=use_scale_shift_norm,
606
+ down=down,
607
+ up=up,
608
+ dtype=dtype,
609
+ device=device,
610
+ operations=operations
611
+ )
612
+ else:
613
+ return ResBlock(
614
+ channels=ch,
615
+ emb_channels=time_embed_dim,
616
+ dropout=dropout,
617
+ out_channels=out_channels,
618
+ use_checkpoint=use_checkpoint,
619
+ dims=dims,
620
+ use_scale_shift_norm=use_scale_shift_norm,
621
+ down=down,
622
+ up=up,
623
+ dtype=dtype,
624
+ device=device,
625
+ operations=operations
626
+ )
627
+
628
+ for level, mult in enumerate(channel_mult):
629
+ for nr in range(self.num_res_blocks[level]):
630
+ layers = [
631
+ get_resblock(
632
+ merge_factor=merge_factor,
633
+ merge_strategy=merge_strategy,
634
+ video_kernel_size=video_kernel_size,
635
+ ch=ch,
636
+ time_embed_dim=time_embed_dim,
637
+ dropout=dropout,
638
+ out_channels=mult * model_channels,
639
+ dims=dims,
640
+ use_checkpoint=use_checkpoint,
641
+ use_scale_shift_norm=use_scale_shift_norm,
642
+ dtype=self.dtype,
643
+ device=device,
644
+ operations=operations,
645
+ )
646
+ ]
647
+ ch = mult * model_channels
648
+ num_transformers = transformer_depth.pop(0)
649
+ if num_transformers > 0:
650
+ if num_head_channels == -1:
651
+ dim_head = ch // num_heads
652
+ else:
653
+ num_heads = ch // num_head_channels
654
+ dim_head = num_head_channels
655
+ if legacy:
656
+ #num_heads = 1
657
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
658
+ if exists(disable_self_attentions):
659
+ disabled_sa = disable_self_attentions[level]
660
+ else:
661
+ disabled_sa = False
662
+
663
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
664
+ layers.append(get_attention_layer(
665
+ ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
666
+ disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint)
667
+ )
668
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
669
+ self._feature_size += ch
670
+ input_block_chans.append(ch)
671
+ if level != len(channel_mult) - 1:
672
+ out_ch = ch
673
+ self.input_blocks.append(
674
+ TimestepEmbedSequential(
675
+ get_resblock(
676
+ merge_factor=merge_factor,
677
+ merge_strategy=merge_strategy,
678
+ video_kernel_size=video_kernel_size,
679
+ ch=ch,
680
+ time_embed_dim=time_embed_dim,
681
+ dropout=dropout,
682
+ out_channels=out_ch,
683
+ dims=dims,
684
+ use_checkpoint=use_checkpoint,
685
+ use_scale_shift_norm=use_scale_shift_norm,
686
+ down=True,
687
+ dtype=self.dtype,
688
+ device=device,
689
+ operations=operations
690
+ )
691
+ if resblock_updown
692
+ else Downsample(
693
+ ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
694
+ )
695
+ )
696
+ )
697
+ ch = out_ch
698
+ input_block_chans.append(ch)
699
+ ds *= 2
700
+ self._feature_size += ch
701
+
702
+ if num_head_channels == -1:
703
+ dim_head = ch // num_heads
704
+ else:
705
+ num_heads = ch // num_head_channels
706
+ dim_head = num_head_channels
707
+ if legacy:
708
+ #num_heads = 1
709
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
710
+ mid_block = [
711
+ get_resblock(
712
+ merge_factor=merge_factor,
713
+ merge_strategy=merge_strategy,
714
+ video_kernel_size=video_kernel_size,
715
+ ch=ch,
716
+ time_embed_dim=time_embed_dim,
717
+ dropout=dropout,
718
+ out_channels=None,
719
+ dims=dims,
720
+ use_checkpoint=use_checkpoint,
721
+ use_scale_shift_norm=use_scale_shift_norm,
722
+ dtype=self.dtype,
723
+ device=device,
724
+ operations=operations
725
+ )]
726
+ if transformer_depth_middle >= 0:
727
+ mid_block += [get_attention_layer( # always uses a self-attn
728
+ ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
729
+ disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
730
+ ),
731
+ get_resblock(
732
+ merge_factor=merge_factor,
733
+ merge_strategy=merge_strategy,
734
+ video_kernel_size=video_kernel_size,
735
+ ch=ch,
736
+ time_embed_dim=time_embed_dim,
737
+ dropout=dropout,
738
+ out_channels=None,
739
+ dims=dims,
740
+ use_checkpoint=use_checkpoint,
741
+ use_scale_shift_norm=use_scale_shift_norm,
742
+ dtype=self.dtype,
743
+ device=device,
744
+ operations=operations
745
+ )]
746
+ self.middle_block = TimestepEmbedSequential(*mid_block)
747
+ self._feature_size += ch
748
+
749
+ self.output_blocks = nn.ModuleList([])
750
+ for level, mult in list(enumerate(channel_mult))[::-1]:
751
+ for i in range(self.num_res_blocks[level] + 1):
752
+ ich = input_block_chans.pop()
753
+ layers = [
754
+ get_resblock(
755
+ merge_factor=merge_factor,
756
+ merge_strategy=merge_strategy,
757
+ video_kernel_size=video_kernel_size,
758
+ ch=ch + ich,
759
+ time_embed_dim=time_embed_dim,
760
+ dropout=dropout,
761
+ out_channels=model_channels * mult,
762
+ dims=dims,
763
+ use_checkpoint=use_checkpoint,
764
+ use_scale_shift_norm=use_scale_shift_norm,
765
+ dtype=self.dtype,
766
+ device=device,
767
+ operations=operations
768
+ )
769
+ ]
770
+ ch = model_channels * mult
771
+ num_transformers = transformer_depth_output.pop()
772
+ if num_transformers > 0:
773
+ if num_head_channels == -1:
774
+ dim_head = ch // num_heads
775
+ else:
776
+ num_heads = ch // num_head_channels
777
+ dim_head = num_head_channels
778
+ if legacy:
779
+ #num_heads = 1
780
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
781
+ if exists(disable_self_attentions):
782
+ disabled_sa = disable_self_attentions[level]
783
+ else:
784
+ disabled_sa = False
785
+
786
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
787
+ layers.append(
788
+ get_attention_layer(
789
+ ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
790
+ disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint
791
+ )
792
+ )
793
+ if level and i == self.num_res_blocks[level]:
794
+ out_ch = ch
795
+ layers.append(
796
+ get_resblock(
797
+ merge_factor=merge_factor,
798
+ merge_strategy=merge_strategy,
799
+ video_kernel_size=video_kernel_size,
800
+ ch=ch,
801
+ time_embed_dim=time_embed_dim,
802
+ dropout=dropout,
803
+ out_channels=out_ch,
804
+ dims=dims,
805
+ use_checkpoint=use_checkpoint,
806
+ use_scale_shift_norm=use_scale_shift_norm,
807
+ up=True,
808
+ dtype=self.dtype,
809
+ device=device,
810
+ operations=operations
811
+ )
812
+ if resblock_updown
813
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations)
814
+ )
815
+ ds //= 2
816
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
817
+ self._feature_size += ch
818
+
819
+ self.out = nn.Sequential(
820
+ operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
821
+ nn.SiLU(),
822
+ zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
823
+ )
824
+ if self.predict_codebook_ids:
825
+ self.id_predictor = nn.Sequential(
826
+ operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
827
+ operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
828
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
829
+ )
830
+
831
+ def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
832
+ """
833
+ Apply the model to an input batch.
834
+ :param x: an [N x C x ...] Tensor of inputs.
835
+ :param timesteps: a 1-D batch of timesteps.
836
+ :param context: conditioning plugged in via crossattn
837
+ :param y: an [N] Tensor of labels, if class-conditional.
838
+ :return: an [N x C x ...] Tensor of outputs.
839
+ """
840
+ transformer_options["original_shape"] = list(x.shape)
841
+ transformer_options["transformer_index"] = 0
842
+ transformer_patches = transformer_options.get("patches", {})
843
+ block_modifiers = transformer_options.get("block_modifiers", [])
844
+
845
+ num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
846
+ image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
847
+ time_context = kwargs.get("time_context", None)
848
+
849
+ assert (y is not None) == (
850
+ self.num_classes is not None
851
+ ), "must specify y if and only if the model is class-conditional"
852
+ hs = []
853
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
854
+ emb = self.time_embed(t_emb)
855
+
856
+ if self.num_classes is not None:
857
+ assert y.shape[0] == x.shape[0]
858
+ emb = emb + self.label_emb(y)
859
+
860
+ h = x
861
+ for id, module in enumerate(self.input_blocks):
862
+ transformer_options["block"] = ("input", id)
863
+
864
+ for block_modifier in block_modifiers:
865
+ h = block_modifier(h, 'before', transformer_options)
866
+
867
+ h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
868
+ h = apply_control(h, control, 'input')
869
+
870
+ for block_modifier in block_modifiers:
871
+ h = block_modifier(h, 'after', transformer_options)
872
+
873
+ if "input_block_patch" in transformer_patches:
874
+ patch = transformer_patches["input_block_patch"]
875
+ for p in patch:
876
+ h = p(h, transformer_options)
877
+
878
+ hs.append(h)
879
+ if "input_block_patch_after_skip" in transformer_patches:
880
+ patch = transformer_patches["input_block_patch_after_skip"]
881
+ for p in patch:
882
+ h = p(h, transformer_options)
883
+
884
+ transformer_options["block"] = ("middle", 0)
885
+
886
+ for block_modifier in block_modifiers:
887
+ h = block_modifier(h, 'before', transformer_options)
888
+
889
+ h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
890
+ h = apply_control(h, control, 'middle')
891
+
892
+ for block_modifier in block_modifiers:
893
+ h = block_modifier(h, 'after', transformer_options)
894
+
895
+ for id, module in enumerate(self.output_blocks):
896
+ transformer_options["block"] = ("output", id)
897
+ hsp = hs.pop()
898
+ hsp = apply_control(hsp, control, 'output')
899
+
900
+ if "output_block_patch" in transformer_patches:
901
+ patch = transformer_patches["output_block_patch"]
902
+ for p in patch:
903
+ h, hsp = p(h, hsp, transformer_options)
904
+
905
+ h = th.cat([h, hsp], dim=1)
906
+ del hsp
907
+ if len(hs) > 0:
908
+ output_shape = hs[-1].shape
909
+ else:
910
+ output_shape = None
911
+
912
+ for block_modifier in block_modifiers:
913
+ h = block_modifier(h, 'before', transformer_options)
914
+
915
+ h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
916
+
917
+ for block_modifier in block_modifiers:
918
+ h = block_modifier(h, 'after', transformer_options)
919
+
920
+ transformer_options["block"] = ("last", 0)
921
+
922
+ for block_modifier in block_modifiers:
923
+ h = block_modifier(h, 'before', transformer_options)
924
+
925
+ if self.predict_codebook_ids:
926
+ h = self.id_predictor(h)
927
+ else:
928
+ h = self.out(h)
929
+
930
+ for block_modifier in block_modifiers:
931
+ h = block_modifier(h, 'after', transformer_options)
932
+
933
+ return h.type(x.dtype)
ldm_patched/ldm/modules/diffusionmodules/upscaling.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 1st edit by https://github.com/CompVis/latent-diffusion
2
+ # 2nd edit by https://github.com/Stability-AI/stablediffusion
3
+ # 3rd edit by https://github.com/Stability-AI/generative-models
4
+ # 4th edit by https://github.com/comfyanonymous/ComfyUI
5
+
6
+ # This file is only for reference, and not used in the backend or runtime.
7
+
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ from functools import partial
13
+
14
+ from .util import extract_into_tensor, make_beta_schedule
15
+ from ldm_patched.ldm.util import default
16
+
17
+
18
+ class AbstractLowScaleModel(nn.Module):
19
+ # for concatenating a downsampled image to the latent representation
20
+ def __init__(self, noise_schedule_config=None):
21
+ super(AbstractLowScaleModel, self).__init__()
22
+ if noise_schedule_config is not None:
23
+ self.register_schedule(**noise_schedule_config)
24
+
25
+ def register_schedule(self, beta_schedule="linear", timesteps=1000,
26
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
27
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
28
+ cosine_s=cosine_s)
29
+ alphas = 1. - betas
30
+ alphas_cumprod = np.cumprod(alphas, axis=0)
31
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
32
+
33
+ timesteps, = betas.shape
34
+ self.num_timesteps = int(timesteps)
35
+ self.linear_start = linear_start
36
+ self.linear_end = linear_end
37
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
38
+
39
+ to_torch = partial(torch.tensor, dtype=torch.float32)
40
+
41
+ self.register_buffer('betas', to_torch(betas))
42
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
43
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
44
+
45
+ # calculations for diffusion q(x_t | x_{t-1}) and others
46
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
47
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
48
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
49
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
50
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
51
+
52
+ def q_sample(self, x_start, t, noise=None, seed=None):
53
+ if noise is None:
54
+ if seed is None:
55
+ noise = torch.randn_like(x_start)
56
+ else:
57
+ noise = torch.randn(x_start.size(), dtype=x_start.dtype, layout=x_start.layout, generator=torch.manual_seed(seed)).to(x_start.device)
58
+ return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
59
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise)
60
+
61
+ def forward(self, x):
62
+ return x, None
63
+
64
+ def decode(self, x):
65
+ return x
66
+
67
+
68
+ class SimpleImageConcat(AbstractLowScaleModel):
69
+ # no noise level conditioning
70
+ def __init__(self):
71
+ super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
72
+ self.max_noise_level = 0
73
+
74
+ def forward(self, x):
75
+ # fix to constant noise level
76
+ return x, torch.zeros(x.shape[0], device=x.device).long()
77
+
78
+
79
+ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
80
+ def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
81
+ super().__init__(noise_schedule_config=noise_schedule_config)
82
+ self.max_noise_level = max_noise_level
83
+
84
+ def forward(self, x, noise_level=None, seed=None):
85
+ if noise_level is None:
86
+ noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
87
+ else:
88
+ assert isinstance(noise_level, torch.Tensor)
89
+ z = self.q_sample(x, noise_level, seed=seed)
90
+ return z, noise_level
91
+
92
+
93
+
ldm_patched/ldm/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 1st edit by https://github.com/CompVis/latent-diffusion
2
+ # 2nd edit by https://github.com/Stability-AI/stablediffusion
3
+ # 3rd edit by https://github.com/Stability-AI/generative-models
4
+
5
+
6
+ # adopted from
7
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
8
+ # and
9
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
10
+ # and
11
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
12
+ #
13
+ # thanks!
14
+
15
+
16
+ import os
17
+ import math
18
+ import torch
19
+ import torch.nn as nn
20
+ import numpy as np
21
+ from einops import repeat, rearrange
22
+
23
+ from ldm_patched.ldm.util import instantiate_from_config
24
+
25
+ class AlphaBlender(nn.Module):
26
+ strategies = ["learned", "fixed", "learned_with_images"]
27
+
28
+ def __init__(
29
+ self,
30
+ alpha: float,
31
+ merge_strategy: str = "learned_with_images",
32
+ rearrange_pattern: str = "b t -> (b t) 1 1",
33
+ ):
34
+ super().__init__()
35
+ self.merge_strategy = merge_strategy
36
+ self.rearrange_pattern = rearrange_pattern
37
+
38
+ assert (
39
+ merge_strategy in self.strategies
40
+ ), f"merge_strategy needs to be in {self.strategies}"
41
+
42
+ if self.merge_strategy == "fixed":
43
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
44
+ elif (
45
+ self.merge_strategy == "learned"
46
+ or self.merge_strategy == "learned_with_images"
47
+ ):
48
+ self.register_parameter(
49
+ "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
50
+ )
51
+ else:
52
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
53
+
54
+ def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
55
+ # skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
56
+ if self.merge_strategy == "fixed":
57
+ # make shape compatible
58
+ # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
59
+ alpha = self.mix_factor.to(image_only_indicator.device)
60
+ elif self.merge_strategy == "learned":
61
+ alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device))
62
+ # make shape compatible
63
+ # alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
64
+ elif self.merge_strategy == "learned_with_images":
65
+ assert image_only_indicator is not None, "need image_only_indicator ..."
66
+ alpha = torch.where(
67
+ image_only_indicator.bool(),
68
+ torch.ones(1, 1, device=image_only_indicator.device),
69
+ rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"),
70
+ )
71
+ alpha = rearrange(alpha, self.rearrange_pattern)
72
+ # make shape compatible
73
+ # alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
74
+ else:
75
+ raise NotImplementedError()
76
+ return alpha
77
+
78
+ def forward(
79
+ self,
80
+ x_spatial,
81
+ x_temporal,
82
+ image_only_indicator=None,
83
+ ) -> torch.Tensor:
84
+ alpha = self.get_alpha(image_only_indicator)
85
+ x = (
86
+ alpha.to(x_spatial.dtype) * x_spatial
87
+ + (1.0 - alpha).to(x_spatial.dtype) * x_temporal
88
+ )
89
+ return x
90
+
91
+
92
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
93
+ if schedule == "linear":
94
+ betas = (
95
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
96
+ )
97
+
98
+ elif schedule == "cosine":
99
+ timesteps = (
100
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
101
+ )
102
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
103
+ alphas = torch.cos(alphas).pow(2)
104
+ alphas = alphas / alphas[0]
105
+ betas = 1 - alphas[1:] / alphas[:-1]
106
+ betas = np.clip(betas, a_min=0, a_max=0.999)
107
+
108
+ elif schedule == "squaredcos_cap_v2": # used for karlo prior
109
+ # return early
110
+ return betas_for_alpha_bar(
111
+ n_timestep,
112
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
113
+ )
114
+
115
+ elif schedule == "sqrt_linear":
116
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
117
+ elif schedule == "sqrt":
118
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
119
+ else:
120
+ raise ValueError(f"schedule '{schedule}' unknown.")
121
+ return betas.numpy()
122
+
123
+
124
+ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
125
+ if ddim_discr_method == 'uniform':
126
+ c = num_ddpm_timesteps // num_ddim_timesteps
127
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
128
+ elif ddim_discr_method == 'quad':
129
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
130
+ else:
131
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
132
+
133
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
134
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
135
+ steps_out = ddim_timesteps + 1
136
+ if verbose:
137
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
138
+ return steps_out
139
+
140
+
141
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
142
+ # select alphas for computing the variance schedule
143
+ alphas = alphacums[ddim_timesteps]
144
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
145
+
146
+ # according the the formula provided in https://arxiv.org/abs/2010.02502
147
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
148
+ if verbose:
149
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
150
+ print(f'For the chosen value of eta, which is {eta}, '
151
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
152
+ return sigmas, alphas, alphas_prev
153
+
154
+
155
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
156
+ """
157
+ Create a beta schedule that discretizes the given alpha_t_bar function,
158
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
159
+ :param num_diffusion_timesteps: the number of betas to produce.
160
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
161
+ produces the cumulative product of (1-beta) up to that
162
+ part of the diffusion process.
163
+ :param max_beta: the maximum beta to use; use values lower than 1 to
164
+ prevent singularities.
165
+ """
166
+ betas = []
167
+ for i in range(num_diffusion_timesteps):
168
+ t1 = i / num_diffusion_timesteps
169
+ t2 = (i + 1) / num_diffusion_timesteps
170
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
171
+ return np.array(betas)
172
+
173
+
174
+ def extract_into_tensor(a, t, x_shape):
175
+ b, *_ = t.shape
176
+ out = a.gather(-1, t)
177
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
178
+
179
+
180
+ def checkpoint(func, inputs, params, flag):
181
+ """
182
+ Evaluate a function without caching intermediate activations, allowing for
183
+ reduced memory at the expense of extra compute in the backward pass.
184
+ :param func: the function to evaluate.
185
+ :param inputs: the argument sequence to pass to `func`.
186
+ :param params: a sequence of parameters `func` depends on but does not
187
+ explicitly take as arguments.
188
+ :param flag: if False, disable gradient checkpointing.
189
+ """
190
+ if flag:
191
+ args = tuple(inputs) + tuple(params)
192
+ return CheckpointFunction.apply(func, len(inputs), *args)
193
+ else:
194
+ return func(*inputs)
195
+
196
+
197
+ class CheckpointFunction(torch.autograd.Function):
198
+ @staticmethod
199
+ def forward(ctx, run_function, length, *args):
200
+ ctx.run_function = run_function
201
+ ctx.input_tensors = list(args[:length])
202
+ ctx.input_params = list(args[length:])
203
+ ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
204
+ "dtype": torch.get_autocast_gpu_dtype(),
205
+ "cache_enabled": torch.is_autocast_cache_enabled()}
206
+ with torch.no_grad():
207
+ output_tensors = ctx.run_function(*ctx.input_tensors)
208
+ return output_tensors
209
+
210
+ @staticmethod
211
+ def backward(ctx, *output_grads):
212
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
213
+ with torch.enable_grad(), \
214
+ torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
215
+ # Fixes a bug where the first op in run_function modifies the
216
+ # Tensor storage in place, which is not allowed for detach()'d
217
+ # Tensors.
218
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
219
+ output_tensors = ctx.run_function(*shallow_copies)
220
+ input_grads = torch.autograd.grad(
221
+ output_tensors,
222
+ ctx.input_tensors + ctx.input_params,
223
+ output_grads,
224
+ allow_unused=True,
225
+ )
226
+ del ctx.input_tensors
227
+ del ctx.input_params
228
+ del output_tensors
229
+ return (None, None) + input_grads
230
+
231
+
232
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
233
+ # Consistent with Kohya to reduce differences between model training and inference.
234
+
235
+ if not repeat_only:
236
+ half = dim // 2
237
+ freqs = torch.exp(
238
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
239
+ ).to(device=timesteps.device)
240
+ args = timesteps[:, None].float() * freqs[None]
241
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
242
+ if dim % 2:
243
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
244
+ else:
245
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
246
+ return embedding
247
+
248
+
249
+ def zero_module(module):
250
+ """
251
+ Zero out the parameters of a module and return it.
252
+ """
253
+ for p in module.parameters():
254
+ p.detach().zero_()
255
+ return module
256
+
257
+
258
+ def scale_module(module, scale):
259
+ """
260
+ Scale the parameters of a module and return it.
261
+ """
262
+ for p in module.parameters():
263
+ p.detach().mul_(scale)
264
+ return module
265
+
266
+
267
+ def mean_flat(tensor):
268
+ """
269
+ Take the mean over all non-batch dimensions.
270
+ """
271
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
272
+
273
+
274
+ def avg_pool_nd(dims, *args, **kwargs):
275
+ """
276
+ Create a 1D, 2D, or 3D average pooling module.
277
+ """
278
+ if dims == 1:
279
+ return nn.AvgPool1d(*args, **kwargs)
280
+ elif dims == 2:
281
+ return nn.AvgPool2d(*args, **kwargs)
282
+ elif dims == 3:
283
+ return nn.AvgPool3d(*args, **kwargs)
284
+ raise ValueError(f"unsupported dimensions: {dims}")
285
+
286
+
287
+ class HybridConditioner(nn.Module):
288
+
289
+ def __init__(self, c_concat_config, c_crossattn_config):
290
+ super().__init__()
291
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
292
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
293
+
294
+ def forward(self, c_concat, c_crossattn):
295
+ c_concat = self.concat_conditioner(c_concat)
296
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
297
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
298
+
299
+
300
+ def noise_like(shape, device, repeat=False):
301
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
302
+ noise = lambda: torch.randn(shape, device=device)
303
+ return repeat_noise() if repeat else noise()
ldm_patched/ldm/modules/distributions/__init__.py ADDED
File without changes
ldm_patched/ldm/modules/distributions/distributions.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 1st edit by https://github.com/CompVis/latent-diffusion
2
+ # 2nd edit by https://github.com/Stability-AI/stablediffusion
3
+
4
+
5
+ import torch
6
+ import numpy as np
7
+
8
+
9
+ class AbstractDistribution:
10
+ def sample(self):
11
+ raise NotImplementedError()
12
+
13
+ def mode(self):
14
+ raise NotImplementedError()
15
+
16
+
17
+ class DiracDistribution(AbstractDistribution):
18
+ def __init__(self, value):
19
+ self.value = value
20
+
21
+ def sample(self):
22
+ return self.value
23
+
24
+ def mode(self):
25
+ return self.value
26
+
27
+
28
+ class DiagonalGaussianDistribution(object):
29
+ def __init__(self, parameters, deterministic=False):
30
+ self.parameters = parameters
31
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
32
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
33
+ self.deterministic = deterministic
34
+ self.std = torch.exp(0.5 * self.logvar)
35
+ self.var = torch.exp(self.logvar)
36
+ if self.deterministic:
37
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
38
+
39
+ def sample(self):
40
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
41
+ return x
42
+
43
+ def kl(self, other=None):
44
+ if self.deterministic:
45
+ return torch.Tensor([0.])
46
+ else:
47
+ if other is None:
48
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
49
+ + self.var - 1.0 - self.logvar,
50
+ dim=[1, 2, 3])
51
+ else:
52
+ return 0.5 * torch.sum(
53
+ torch.pow(self.mean - other.mean, 2) / other.var
54
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
55
+ dim=[1, 2, 3])
56
+
57
+ def nll(self, sample, dims=[1,2,3]):
58
+ if self.deterministic:
59
+ return torch.Tensor([0.])
60
+ logtwopi = np.log(2.0 * np.pi)
61
+ return 0.5 * torch.sum(
62
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
63
+ dim=dims)
64
+
65
+ def mode(self):
66
+ return self.mean
67
+
68
+
69
+ def normal_kl(mean1, logvar1, mean2, logvar2):
70
+ """
71
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
72
+ Compute the KL divergence between two gaussians.
73
+ Shapes are automatically broadcasted, so batches can be compared to
74
+ scalars, among other use cases.
75
+ """
76
+ tensor = None
77
+ for obj in (mean1, logvar1, mean2, logvar2):
78
+ if isinstance(obj, torch.Tensor):
79
+ tensor = obj
80
+ break
81
+ assert tensor is not None, "at least one argument must be a Tensor"
82
+
83
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
84
+ # Tensors, but it does not work for torch.exp().
85
+ logvar1, logvar2 = [
86
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
87
+ for x in (logvar1, logvar2)
88
+ ]
89
+
90
+ return 0.5 * (
91
+ -1.0
92
+ + logvar2
93
+ - logvar1
94
+ + torch.exp(logvar1 - logvar2)
95
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
96
+ )
ldm_patched/ldm/modules/ema.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 1st edit by https://github.com/CompVis/latent-diffusion
2
+ # 2nd edit by https://github.com/Stability-AI/stablediffusion
3
+ # 3rd edit by https://github.com/Stability-AI/generative-models
4
+ # 4th edit by https://github.com/comfyanonymous/ComfyUI
5
+
6
+
7
+ # This file is not used in image diffusion backend.
8
+
9
+
10
+ import torch
11
+ from torch import nn
12
+
13
+
14
+ class LitEma(nn.Module):
15
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
16
+ super().__init__()
17
+ if decay < 0.0 or decay > 1.0:
18
+ raise ValueError('Decay must be between 0 and 1')
19
+
20
+ self.m_name2s_name = {}
21
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
22
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
23
+ else torch.tensor(-1, dtype=torch.int))
24
+
25
+ for name, p in model.named_parameters():
26
+ if p.requires_grad:
27
+ # remove as '.'-character is not allowed in buffers
28
+ s_name = name.replace('.', '')
29
+ self.m_name2s_name.update({name: s_name})
30
+ self.register_buffer(s_name, p.clone().detach().data)
31
+
32
+ self.collected_params = []
33
+
34
+ def reset_num_updates(self):
35
+ del self.num_updates
36
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
37
+
38
+ def forward(self, model):
39
+ decay = self.decay
40
+
41
+ if self.num_updates >= 0:
42
+ self.num_updates += 1
43
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
44
+
45
+ one_minus_decay = 1.0 - decay
46
+
47
+ with torch.no_grad():
48
+ m_param = dict(model.named_parameters())
49
+ shadow_params = dict(self.named_buffers())
50
+
51
+ for key in m_param:
52
+ if m_param[key].requires_grad:
53
+ sname = self.m_name2s_name[key]
54
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
55
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
56
+ else:
57
+ assert not key in self.m_name2s_name
58
+
59
+ def copy_to(self, model):
60
+ m_param = dict(model.named_parameters())
61
+ shadow_params = dict(self.named_buffers())
62
+ for key in m_param:
63
+ if m_param[key].requires_grad:
64
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
65
+ else:
66
+ assert not key in self.m_name2s_name
67
+
68
+ def store(self, parameters):
69
+ """
70
+ Save the current parameters for restoring later.
71
+ Args:
72
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73
+ temporarily stored.
74
+ """
75
+ self.collected_params = [param.clone() for param in parameters]
76
+
77
+ def restore(self, parameters):
78
+ """
79
+ Restore the parameters stored with the `store` method.
80
+ Useful to validate the model with EMA parameters without affecting the
81
+ original optimization process. Store the parameters before the
82
+ `copy_to` method. After validation (or model saving), use this to
83
+ restore the former parameters.
84
+ Args:
85
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
86
+ updated with the stored parameters.
87
+ """
88
+ for c_param, param in zip(self.collected_params, parameters):
89
+ param.data.copy_(c_param.data)
ldm_patched/ldm/modules/encoders/__init__.py ADDED
File without changes
ldm_patched/ldm/modules/encoders/noise_aug_modules.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ from ..diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
6
+ from ..diffusionmodules.openaimodel import Timestep
7
+ import torch
8
+
9
+ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
10
+ def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs):
11
+ super().__init__(*args, **kwargs)
12
+ if clip_stats_path is None:
13
+ clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
14
+ else:
15
+ clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
16
+ self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
17
+ self.register_buffer("data_std", clip_std[None, :], persistent=False)
18
+ self.time_embed = Timestep(timestep_dim)
19
+
20
+ def scale(self, x):
21
+ # re-normalize to centered mean and unit variance
22
+ x = (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device)
23
+ return x
24
+
25
+ def unscale(self, x):
26
+ # back to original data stats
27
+ x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device)
28
+ return x
29
+
30
+ def forward(self, x, noise_level=None, seed=None):
31
+ if noise_level is None:
32
+ noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
33
+ else:
34
+ assert isinstance(noise_level, torch.Tensor)
35
+ x = self.scale(x)
36
+ z = self.q_sample(x, noise_level, seed=seed)
37
+ z = self.unscale(z)
38
+ noise_level = self.time_embed(noise_level)
39
+ return z, noise_level
ldm_patched/ldm/modules/sub_quadratic_attention.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # original source:
2
+ # https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
3
+ # license:
4
+ # MIT
5
+ # credit:
6
+ # Amin Rezaei (original author)
7
+ # Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
8
+ # implementation of:
9
+ # Self-attention Does Not Need O(n2) Memory":
10
+ # https://arxiv.org/abs/2112.05682v2
11
+
12
+ from functools import partial
13
+ import torch
14
+ from torch import Tensor
15
+ from torch.utils.checkpoint import checkpoint
16
+ import math
17
+
18
+ try:
19
+ from typing import Optional, NamedTuple, List, Protocol
20
+ except ImportError:
21
+ from typing import Optional, NamedTuple, List
22
+ from typing_extensions import Protocol
23
+
24
+ from torch import Tensor
25
+ from typing import List
26
+
27
+ from ldm_patched.modules import model_management
28
+
29
+ def dynamic_slice(
30
+ x: Tensor,
31
+ starts: List[int],
32
+ sizes: List[int],
33
+ ) -> Tensor:
34
+ slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
35
+ return x[slicing]
36
+
37
+ class AttnChunk(NamedTuple):
38
+ exp_values: Tensor
39
+ exp_weights_sum: Tensor
40
+ max_score: Tensor
41
+
42
+ class SummarizeChunk(Protocol):
43
+ @staticmethod
44
+ def __call__(
45
+ query: Tensor,
46
+ key_t: Tensor,
47
+ value: Tensor,
48
+ ) -> AttnChunk: ...
49
+
50
+ class ComputeQueryChunkAttn(Protocol):
51
+ @staticmethod
52
+ def __call__(
53
+ query: Tensor,
54
+ key_t: Tensor,
55
+ value: Tensor,
56
+ ) -> Tensor: ...
57
+
58
+ def _summarize_chunk(
59
+ query: Tensor,
60
+ key_t: Tensor,
61
+ value: Tensor,
62
+ scale: float,
63
+ upcast_attention: bool,
64
+ mask,
65
+ ) -> AttnChunk:
66
+ if upcast_attention:
67
+ with torch.autocast(enabled=False, device_type = 'cuda'):
68
+ query = query.float()
69
+ key_t = key_t.float()
70
+ attn_weights = torch.baddbmm(
71
+ torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
72
+ query,
73
+ key_t,
74
+ alpha=scale,
75
+ beta=0,
76
+ )
77
+ else:
78
+ attn_weights = torch.baddbmm(
79
+ torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
80
+ query,
81
+ key_t,
82
+ alpha=scale,
83
+ beta=0,
84
+ )
85
+ max_score, _ = torch.max(attn_weights, -1, keepdim=True)
86
+ max_score = max_score.detach()
87
+ attn_weights -= max_score
88
+ if mask is not None:
89
+ attn_weights += mask
90
+ torch.exp(attn_weights, out=attn_weights)
91
+ exp_weights = attn_weights.to(value.dtype)
92
+ exp_values = torch.bmm(exp_weights, value)
93
+ max_score = max_score.squeeze(-1)
94
+ return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
95
+
96
+ def _query_chunk_attention(
97
+ query: Tensor,
98
+ key_t: Tensor,
99
+ value: Tensor,
100
+ summarize_chunk: SummarizeChunk,
101
+ kv_chunk_size: int,
102
+ mask,
103
+ ) -> Tensor:
104
+ batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
105
+ _, _, v_channels_per_head = value.shape
106
+
107
+ def chunk_scanner(chunk_idx: int, mask) -> AttnChunk:
108
+ key_chunk = dynamic_slice(
109
+ key_t,
110
+ (0, 0, chunk_idx),
111
+ (batch_x_heads, k_channels_per_head, kv_chunk_size)
112
+ )
113
+ value_chunk = dynamic_slice(
114
+ value,
115
+ (0, chunk_idx, 0),
116
+ (batch_x_heads, kv_chunk_size, v_channels_per_head)
117
+ )
118
+ if mask is not None:
119
+ mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]
120
+
121
+ return summarize_chunk(query, key_chunk, value_chunk, mask=mask)
122
+
123
+ chunks: List[AttnChunk] = [
124
+ chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
125
+ ]
126
+ acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
127
+ chunk_values, chunk_weights, chunk_max = acc_chunk
128
+
129
+ global_max, _ = torch.max(chunk_max, 0, keepdim=True)
130
+ max_diffs = torch.exp(chunk_max - global_max)
131
+ chunk_values *= torch.unsqueeze(max_diffs, -1)
132
+ chunk_weights *= max_diffs
133
+
134
+ all_values = chunk_values.sum(dim=0)
135
+ all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
136
+ return all_values / all_weights
137
+
138
+ # TODO: refactor CrossAttention#get_attention_scores to share code with this
139
+ def _get_attention_scores_no_kv_chunking(
140
+ query: Tensor,
141
+ key_t: Tensor,
142
+ value: Tensor,
143
+ scale: float,
144
+ upcast_attention: bool,
145
+ mask,
146
+ ) -> Tensor:
147
+ if upcast_attention:
148
+ with torch.autocast(enabled=False, device_type = 'cuda'):
149
+ query = query.float()
150
+ key_t = key_t.float()
151
+ attn_scores = torch.baddbmm(
152
+ torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
153
+ query,
154
+ key_t,
155
+ alpha=scale,
156
+ beta=0,
157
+ )
158
+ else:
159
+ attn_scores = torch.baddbmm(
160
+ torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
161
+ query,
162
+ key_t,
163
+ alpha=scale,
164
+ beta=0,
165
+ )
166
+
167
+ if mask is not None:
168
+ attn_scores += mask
169
+ try:
170
+ attn_probs = attn_scores.softmax(dim=-1)
171
+ del attn_scores
172
+ except model_management.OOM_EXCEPTION:
173
+ print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
174
+ attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
175
+ torch.exp(attn_scores, out=attn_scores)
176
+ summed = torch.sum(attn_scores, dim=-1, keepdim=True)
177
+ attn_scores /= summed
178
+ attn_probs = attn_scores
179
+
180
+ hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
181
+ return hidden_states_slice
182
+
183
+ class ScannedChunk(NamedTuple):
184
+ chunk_idx: int
185
+ attn_chunk: AttnChunk
186
+
187
+ def efficient_dot_product_attention(
188
+ query: Tensor,
189
+ key_t: Tensor,
190
+ value: Tensor,
191
+ query_chunk_size=1024,
192
+ kv_chunk_size: Optional[int] = None,
193
+ kv_chunk_size_min: Optional[int] = None,
194
+ use_checkpoint=True,
195
+ upcast_attention=False,
196
+ mask = None,
197
+ ):
198
+ """Computes efficient dot-product attention given query, transposed key, and value.
199
+ This is efficient version of attention presented in
200
+ https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
201
+ Args:
202
+ query: queries for calculating attention with shape of
203
+ `[batch * num_heads, tokens, channels_per_head]`.
204
+ key_t: keys for calculating attention with shape of
205
+ `[batch * num_heads, channels_per_head, tokens]`.
206
+ value: values to be used in attention with shape of
207
+ `[batch * num_heads, tokens, channels_per_head]`.
208
+ query_chunk_size: int: query chunks size
209
+ kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
210
+ kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
211
+ use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
212
+ Returns:
213
+ Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
214
+ """
215
+ batch_x_heads, q_tokens, q_channels_per_head = query.shape
216
+ _, _, k_tokens = key_t.shape
217
+ scale = q_channels_per_head ** -0.5
218
+
219
+ kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
220
+ if kv_chunk_size_min is not None:
221
+ kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
222
+
223
+ if mask is not None and len(mask.shape) == 2:
224
+ mask = mask.unsqueeze(0)
225
+
226
+ def get_query_chunk(chunk_idx: int) -> Tensor:
227
+ return dynamic_slice(
228
+ query,
229
+ (0, chunk_idx, 0),
230
+ (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
231
+ )
232
+
233
+ def get_mask_chunk(chunk_idx: int) -> Tensor:
234
+ if mask is None:
235
+ return None
236
+ chunk = min(query_chunk_size, q_tokens)
237
+ return mask[:,chunk_idx:chunk_idx + chunk]
238
+
239
+ summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
240
+ summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
241
+ compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
242
+ _get_attention_scores_no_kv_chunking,
243
+ scale=scale,
244
+ upcast_attention=upcast_attention
245
+ ) if k_tokens <= kv_chunk_size else (
246
+ # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
247
+ partial(
248
+ _query_chunk_attention,
249
+ kv_chunk_size=kv_chunk_size,
250
+ summarize_chunk=summarize_chunk,
251
+ )
252
+ )
253
+
254
+ if q_tokens <= query_chunk_size:
255
+ # fast-path for when there's just 1 query chunk
256
+ return compute_query_chunk_attn(
257
+ query=query,
258
+ key_t=key_t,
259
+ value=value,
260
+ mask=mask,
261
+ )
262
+
263
+ # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
264
+ # and pass slices to be mutated, instead of torch.cat()ing the returned slices
265
+ res = torch.cat([
266
+ compute_query_chunk_attn(
267
+ query=get_query_chunk(i * query_chunk_size),
268
+ key_t=key_t,
269
+ value=value,
270
+ mask=get_mask_chunk(i * query_chunk_size)
271
+ ) for i in range(math.ceil(q_tokens / query_chunk_size))
272
+ ], dim=1)
273
+ return res
ldm_patched/ldm/modules/temporal_ae.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 1st edit by https://github.com/Stability-AI/generative-models
2
+ # 2nd edit by https://github.com/comfyanonymous/ComfyUI
3
+ # 3rd edit by Forge
4
+
5
+ # This file is not used in image diffusion backend. (but used in SVD.)
6
+
7
+
8
+ import functools
9
+ from typing import Callable, Iterable, Union
10
+
11
+ import torch
12
+ from einops import rearrange, repeat
13
+
14
+ import ldm_patched.modules.ops
15
+ ops = ldm_patched.modules.ops.disable_weight_init
16
+
17
+ from .diffusionmodules.model import (
18
+ AttnBlock,
19
+ Decoder,
20
+ ResnetBlock,
21
+ )
22
+ from .diffusionmodules.openaimodel import ResBlock, timestep_embedding
23
+ from .attention import BasicTransformerBlock
24
+
25
+ def partialclass(cls, *args, **kwargs):
26
+ class NewCls(cls):
27
+ __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
28
+
29
+ return NewCls
30
+
31
+
32
+ class VideoResBlock(ResnetBlock):
33
+ def __init__(
34
+ self,
35
+ out_channels,
36
+ *args,
37
+ dropout=0.0,
38
+ video_kernel_size=3,
39
+ alpha=0.0,
40
+ merge_strategy="learned",
41
+ **kwargs,
42
+ ):
43
+ super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
44
+ if video_kernel_size is None:
45
+ video_kernel_size = [3, 1, 1]
46
+ self.time_stack = ResBlock(
47
+ channels=out_channels,
48
+ emb_channels=0,
49
+ dropout=dropout,
50
+ dims=3,
51
+ use_scale_shift_norm=False,
52
+ use_conv=False,
53
+ up=False,
54
+ down=False,
55
+ kernel_size=video_kernel_size,
56
+ use_checkpoint=False,
57
+ skip_t_emb=True,
58
+ )
59
+
60
+ self.merge_strategy = merge_strategy
61
+ if self.merge_strategy == "fixed":
62
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
63
+ elif self.merge_strategy == "learned":
64
+ self.register_parameter(
65
+ "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
66
+ )
67
+ else:
68
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
69
+
70
+ def get_alpha(self, bs):
71
+ if self.merge_strategy == "fixed":
72
+ return self.mix_factor
73
+ elif self.merge_strategy == "learned":
74
+ return torch.sigmoid(self.mix_factor)
75
+ else:
76
+ raise NotImplementedError()
77
+
78
+ def forward(self, x, temb, skip_video=False, timesteps=None):
79
+ b, c, h, w = x.shape
80
+ if timesteps is None:
81
+ timesteps = b
82
+
83
+ x = super().forward(x, temb)
84
+
85
+ if not skip_video:
86
+ x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
87
+
88
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
89
+
90
+ x = self.time_stack(x, temb)
91
+
92
+ alpha = self.get_alpha(bs=b // timesteps).to(x.device)
93
+ x = alpha * x + (1.0 - alpha) * x_mix
94
+
95
+ x = rearrange(x, "b c t h w -> (b t) c h w")
96
+ return x
97
+
98
+
99
+ class AE3DConv(ops.Conv2d):
100
+ def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
101
+ super().__init__(in_channels, out_channels, *args, **kwargs)
102
+ if isinstance(video_kernel_size, Iterable):
103
+ padding = [int(k // 2) for k in video_kernel_size]
104
+ else:
105
+ padding = int(video_kernel_size // 2)
106
+
107
+ self.time_mix_conv = ops.Conv3d(
108
+ in_channels=out_channels,
109
+ out_channels=out_channels,
110
+ kernel_size=video_kernel_size,
111
+ padding=padding,
112
+ )
113
+
114
+ def forward(self, input, timesteps=None, skip_video=False):
115
+ if timesteps is None:
116
+ timesteps = input.shape[0]
117
+ x = super().forward(input)
118
+ if skip_video:
119
+ return x
120
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
121
+ x = self.time_mix_conv(x)
122
+ return rearrange(x, "b c t h w -> (b t) c h w")
123
+
124
+
125
+ class AttnVideoBlock(AttnBlock):
126
+ def __init__(
127
+ self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
128
+ ):
129
+ super().__init__(in_channels)
130
+ # no context, single headed, as in base class
131
+ self.time_mix_block = BasicTransformerBlock(
132
+ dim=in_channels,
133
+ n_heads=1,
134
+ d_head=in_channels,
135
+ checkpoint=False,
136
+ ff_in=True,
137
+ )
138
+
139
+ time_embed_dim = self.in_channels * 4
140
+ self.video_time_embed = torch.nn.Sequential(
141
+ ops.Linear(self.in_channels, time_embed_dim),
142
+ torch.nn.SiLU(),
143
+ ops.Linear(time_embed_dim, self.in_channels),
144
+ )
145
+
146
+ self.merge_strategy = merge_strategy
147
+ if self.merge_strategy == "fixed":
148
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
149
+ elif self.merge_strategy == "learned":
150
+ self.register_parameter(
151
+ "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
152
+ )
153
+ else:
154
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
155
+
156
+ def forward(self, x, timesteps=None, skip_time_block=False):
157
+ if skip_time_block:
158
+ return super().forward(x)
159
+
160
+ if timesteps is None:
161
+ timesteps = x.shape[0]
162
+
163
+ x_in = x
164
+ x = self.attention(x)
165
+ h, w = x.shape[2:]
166
+ x = rearrange(x, "b c h w -> b (h w) c")
167
+
168
+ x_mix = x
169
+ num_frames = torch.arange(timesteps, device=x.device)
170
+ num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
171
+ num_frames = rearrange(num_frames, "b t -> (b t)")
172
+ t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
173
+ emb = self.video_time_embed(t_emb) # b, n_channels
174
+ emb = emb[:, None, :]
175
+ x_mix = x_mix + emb
176
+
177
+ alpha = self.get_alpha().to(x.device)
178
+ x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
179
+ x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
180
+
181
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
182
+ x = self.proj_out(x)
183
+
184
+ return x_in + x
185
+
186
+ def get_alpha(
187
+ self,
188
+ ):
189
+ if self.merge_strategy == "fixed":
190
+ return self.mix_factor
191
+ elif self.merge_strategy == "learned":
192
+ return torch.sigmoid(self.mix_factor)
193
+ else:
194
+ raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
195
+
196
+
197
+
198
+ def make_time_attn(
199
+ in_channels,
200
+ attn_type="vanilla",
201
+ attn_kwargs=None,
202
+ alpha: float = 0,
203
+ merge_strategy: str = "learned",
204
+ ):
205
+ return partialclass(
206
+ AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
207
+ )
208
+
209
+
210
+ class Conv2DWrapper(torch.nn.Conv2d):
211
+ def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
212
+ return super().forward(input)
213
+
214
+
215
+ class VideoDecoder(Decoder):
216
+ available_time_modes = ["all", "conv-only", "attn-only"]
217
+
218
+ def __init__(
219
+ self,
220
+ *args,
221
+ video_kernel_size: Union[int, list] = 3,
222
+ alpha: float = 0.0,
223
+ merge_strategy: str = "learned",
224
+ time_mode: str = "conv-only",
225
+ **kwargs,
226
+ ):
227
+ self.video_kernel_size = video_kernel_size
228
+ self.alpha = alpha
229
+ self.merge_strategy = merge_strategy
230
+ self.time_mode = time_mode
231
+ assert (
232
+ self.time_mode in self.available_time_modes
233
+ ), f"time_mode parameter has to be in {self.available_time_modes}"
234
+
235
+ if self.time_mode != "attn-only":
236
+ kwargs["conv_out_op"] = partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
237
+ if self.time_mode not in ["conv-only", "only-last-conv"]:
238
+ kwargs["attn_op"] = partialclass(make_time_attn, alpha=self.alpha, merge_strategy=self.merge_strategy)
239
+ if self.time_mode not in ["attn-only", "only-last-conv"]:
240
+ kwargs["resnet_op"] = partialclass(VideoResBlock, video_kernel_size=self.video_kernel_size, alpha=self.alpha, merge_strategy=self.merge_strategy)
241
+
242
+ super().__init__(*args, **kwargs)
243
+
244
+ def get_last_layer(self, skip_time_mix=False, **kwargs):
245
+ if self.time_mode == "attn-only":
246
+ raise NotImplementedError("TODO")
247
+ else:
248
+ return (
249
+ self.conv_out.time_mix_conv.weight
250
+ if not skip_time_mix
251
+ else self.conv_out.weight
252
+ )
ldm_patched/ldm/util.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ import importlib
6
+
7
+ import torch
8
+ from torch import optim
9
+ import numpy as np
10
+
11
+ from inspect import isfunction
12
+ from PIL import Image, ImageDraw, ImageFont
13
+
14
+
15
+ def log_txt_as_img(wh, xc, size=10):
16
+ # wh a tuple of (width, height)
17
+ # xc a list of captions to plot
18
+ b = len(xc)
19
+ txts = list()
20
+ for bi in range(b):
21
+ txt = Image.new("RGB", wh, color="white")
22
+ draw = ImageDraw.Draw(txt)
23
+ font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
24
+ nc = int(40 * (wh[0] / 256))
25
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
26
+
27
+ try:
28
+ draw.text((0, 0), lines, fill="black", font=font)
29
+ except UnicodeEncodeError:
30
+ print("Cant encode string for logging. Skipping.")
31
+
32
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
33
+ txts.append(txt)
34
+ txts = np.stack(txts)
35
+ txts = torch.tensor(txts)
36
+ return txts
37
+
38
+
39
+ def ismap(x):
40
+ if not isinstance(x, torch.Tensor):
41
+ return False
42
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
43
+
44
+
45
+ def isimage(x):
46
+ if not isinstance(x,torch.Tensor):
47
+ return False
48
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
49
+
50
+
51
+ def exists(x):
52
+ return x is not None
53
+
54
+
55
+ def default(val, d):
56
+ if exists(val):
57
+ return val
58
+ return d() if isfunction(d) else d
59
+
60
+
61
+ def mean_flat(tensor):
62
+ """
63
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
64
+ Take the mean over all non-batch dimensions.
65
+ """
66
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
67
+
68
+
69
+ def count_params(model, verbose=False):
70
+ total_params = sum(p.numel() for p in model.parameters())
71
+ if verbose:
72
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
73
+ return total_params
74
+
75
+
76
+ def instantiate_from_config(config):
77
+ if not "target" in config:
78
+ if config == '__is_first_stage__':
79
+ return None
80
+ elif config == "__is_unconditional__":
81
+ return None
82
+ raise KeyError("Expected key `target` to instantiate.")
83
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
84
+
85
+
86
+ def get_obj_from_str(string, reload=False):
87
+ module, cls = string.rsplit(".", 1)
88
+ if reload:
89
+ module_imp = importlib.import_module(module)
90
+ importlib.reload(module_imp)
91
+ return getattr(importlib.import_module(module, package=None), cls)
92
+
93
+
94
+ class AdamWwithEMAandWings(optim.Optimizer):
95
+ # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
96
+ def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
97
+ weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
98
+ ema_power=1., param_names=()):
99
+ """AdamW that saves EMA versions of the parameters."""
100
+ if not 0.0 <= lr:
101
+ raise ValueError("Invalid learning rate: {}".format(lr))
102
+ if not 0.0 <= eps:
103
+ raise ValueError("Invalid epsilon value: {}".format(eps))
104
+ if not 0.0 <= betas[0] < 1.0:
105
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
106
+ if not 0.0 <= betas[1] < 1.0:
107
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
108
+ if not 0.0 <= weight_decay:
109
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
110
+ if not 0.0 <= ema_decay <= 1.0:
111
+ raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
112
+ defaults = dict(lr=lr, betas=betas, eps=eps,
113
+ weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
114
+ ema_power=ema_power, param_names=param_names)
115
+ super().__init__(params, defaults)
116
+
117
+ def __setstate__(self, state):
118
+ super().__setstate__(state)
119
+ for group in self.param_groups:
120
+ group.setdefault('amsgrad', False)
121
+
122
+ @torch.no_grad()
123
+ def step(self, closure=None):
124
+ """Performs a single optimization step.
125
+ Args:
126
+ closure (callable, optional): A closure that reevaluates the model
127
+ and returns the loss.
128
+ """
129
+ loss = None
130
+ if closure is not None:
131
+ with torch.enable_grad():
132
+ loss = closure()
133
+
134
+ for group in self.param_groups:
135
+ params_with_grad = []
136
+ grads = []
137
+ exp_avgs = []
138
+ exp_avg_sqs = []
139
+ ema_params_with_grad = []
140
+ state_sums = []
141
+ max_exp_avg_sqs = []
142
+ state_steps = []
143
+ amsgrad = group['amsgrad']
144
+ beta1, beta2 = group['betas']
145
+ ema_decay = group['ema_decay']
146
+ ema_power = group['ema_power']
147
+
148
+ for p in group['params']:
149
+ if p.grad is None:
150
+ continue
151
+ params_with_grad.append(p)
152
+ if p.grad.is_sparse:
153
+ raise RuntimeError('AdamW does not support sparse gradients')
154
+ grads.append(p.grad)
155
+
156
+ state = self.state[p]
157
+
158
+ # State initialization
159
+ if len(state) == 0:
160
+ state['step'] = 0
161
+ # Exponential moving average of gradient values
162
+ state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
163
+ # Exponential moving average of squared gradient values
164
+ state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
165
+ if amsgrad:
166
+ # Maintains max of all exp. moving avg. of sq. grad. values
167
+ state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
168
+ # Exponential moving average of parameter values
169
+ state['param_exp_avg'] = p.detach().float().clone()
170
+
171
+ exp_avgs.append(state['exp_avg'])
172
+ exp_avg_sqs.append(state['exp_avg_sq'])
173
+ ema_params_with_grad.append(state['param_exp_avg'])
174
+
175
+ if amsgrad:
176
+ max_exp_avg_sqs.append(state['max_exp_avg_sq'])
177
+
178
+ # update the steps for each param group update
179
+ state['step'] += 1
180
+ # record the step after step update
181
+ state_steps.append(state['step'])
182
+
183
+ optim._functional.adamw(params_with_grad,
184
+ grads,
185
+ exp_avgs,
186
+ exp_avg_sqs,
187
+ max_exp_avg_sqs,
188
+ state_steps,
189
+ amsgrad=amsgrad,
190
+ beta1=beta1,
191
+ beta2=beta2,
192
+ lr=group['lr'],
193
+ weight_decay=group['weight_decay'],
194
+ eps=group['eps'],
195
+ maximize=False)
196
+
197
+ cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
198
+ for param, ema_param in zip(params_with_grad, ema_params_with_grad):
199
+ ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
200
+
201
+ return loss
ldm_patched/licenses-3rd/chainer ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2015 Preferred Infrastructure, Inc.
2
+ Copyright (c) 2015 Preferred Networks, Inc.
3
+
4
+ Permission is hereby granted, free of charge, to any person obtaining a copy
5
+ of this software and associated documentation files (the "Software"), to deal
6
+ in the Software without restriction, including without limitation the rights
7
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8
+ copies of the Software, and to permit persons to whom the Software is
9
+ furnished to do so, subject to the following conditions:
10
+
11
+ The above copyright notice and this permission notice shall be included in
12
+ all copies or substantial portions of the Software.
13
+
14
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20
+ THE SOFTWARE.
ldm_patched/licenses-3rd/comfyui ADDED
@@ -0,0 +1,674 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GNU GENERAL PUBLIC LICENSE
2
+ Version 3, 29 June 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU General Public License is a free, copyleft license for
11
+ software and other kinds of works.
12
+
13
+ The licenses for most software and other practical works are designed
14
+ to take away your freedom to share and change the works. By contrast,
15
+ the GNU General Public License is intended to guarantee your freedom to
16
+ share and change all versions of a program--to make sure it remains free
17
+ software for all its users. We, the Free Software Foundation, use the
18
+ GNU General Public License for most of our software; it applies also to
19
+ any other work released this way by its authors. You can apply it to
20
+ your programs, too.
21
+
22
+ When we speak of free software, we are referring to freedom, not
23
+ price. Our General Public Licenses are designed to make sure that you
24
+ have the freedom to distribute copies of free software (and charge for
25
+ them if you wish), that you receive source code or can get it if you
26
+ want it, that you can change the software or use pieces of it in new
27
+ free programs, and that you know you can do these things.
28
+
29
+ To protect your rights, we need to prevent others from denying you
30
+ these rights or asking you to surrender the rights. Therefore, you have
31
+ certain responsibilities if you distribute copies of the software, or if
32
+ you modify it: responsibilities to respect the freedom of others.
33
+
34
+ For example, if you distribute copies of such a program, whether
35
+ gratis or for a fee, you must pass on to the recipients the same
36
+ freedoms that you received. You must make sure that they, too, receive
37
+ or can get the source code. And you must show them these terms so they
38
+ know their rights.
39
+
40
+ Developers that use the GNU GPL protect your rights with two steps:
41
+ (1) assert copyright on the software, and (2) offer you this License
42
+ giving you legal permission to copy, distribute and/or modify it.
43
+
44
+ For the developers' and authors' protection, the GPL clearly explains
45
+ that there is no warranty for this free software. For both users' and
46
+ authors' sake, the GPL requires that modified versions be marked as
47
+ changed, so that their problems will not be attributed erroneously to
48
+ authors of previous versions.
49
+
50
+ Some devices are designed to deny users access to install or run
51
+ modified versions of the software inside them, although the manufacturer
52
+ can do so. This is fundamentally incompatible with the aim of
53
+ protecting users' freedom to change the software. The systematic
54
+ pattern of such abuse occurs in the area of products for individuals to
55
+ use, which is precisely where it is most unacceptable. Therefore, we
56
+ have designed this version of the GPL to prohibit the practice for those
57
+ products. If such problems arise substantially in other domains, we
58
+ stand ready to extend this provision to those domains in future versions
59
+ of the GPL, as needed to protect the freedom of users.
60
+
61
+ Finally, every program is threatened constantly by software patents.
62
+ States should not allow patents to restrict development and use of
63
+ software on general-purpose computers, but in those that do, we wish to
64
+ avoid the special danger that patents applied to a free program could
65
+ make it effectively proprietary. To prevent this, the GPL assures that
66
+ patents cannot be used to render the program non-free.
67
+
68
+ The precise terms and conditions for copying, distribution and
69
+ modification follow.
70
+
71
+ TERMS AND CONDITIONS
72
+
73
+ 0. Definitions.
74
+
75
+ "This License" refers to version 3 of the GNU General Public License.
76
+
77
+ "Copyright" also means copyright-like laws that apply to other kinds of
78
+ works, such as semiconductor masks.
79
+
80
+ "The Program" refers to any copyrightable work licensed under this
81
+ License. Each licensee is addressed as "you". "Licensees" and
82
+ "recipients" may be individuals or organizations.
83
+
84
+ To "modify" a work means to copy from or adapt all or part of the work
85
+ in a fashion requiring copyright permission, other than the making of an
86
+ exact copy. The resulting work is called a "modified version" of the
87
+ earlier work or a work "based on" the earlier work.
88
+
89
+ A "covered work" means either the unmodified Program or a work based
90
+ on the Program.
91
+
92
+ To "propagate" a work means to do anything with it that, without
93
+ permission, would make you directly or secondarily liable for
94
+ infringement under applicable copyright law, except executing it on a
95
+ computer or modifying a private copy. Propagation includes copying,
96
+ distribution (with or without modification), making available to the
97
+ public, and in some countries other activities as well.
98
+
99
+ To "convey" a work means any kind of propagation that enables other
100
+ parties to make or receive copies. Mere interaction with a user through
101
+ a computer network, with no transfer of a copy, is not conveying.
102
+
103
+ An interactive user interface displays "Appropriate Legal Notices"
104
+ to the extent that it includes a convenient and prominently visible
105
+ feature that (1) displays an appropriate copyright notice, and (2)
106
+ tells the user that there is no warranty for the work (except to the
107
+ extent that warranties are provided), that licensees may convey the
108
+ work under this License, and how to view a copy of this License. If
109
+ the interface presents a list of user commands or options, such as a
110
+ menu, a prominent item in the list meets this criterion.
111
+
112
+ 1. Source Code.
113
+
114
+ The "source code" for a work means the preferred form of the work
115
+ for making modifications to it. "Object code" means any non-source
116
+ form of a work.
117
+
118
+ A "Standard Interface" means an interface that either is an official
119
+ standard defined by a recognized standards body, or, in the case of
120
+ interfaces specified for a particular programming language, one that
121
+ is widely used among developers working in that language.
122
+
123
+ The "System Libraries" of an executable work include anything, other
124
+ than the work as a whole, that (a) is included in the normal form of
125
+ packaging a Major Component, but which is not part of that Major
126
+ Component, and (b) serves only to enable use of the work with that
127
+ Major Component, or to implement a Standard Interface for which an
128
+ implementation is available to the public in source code form. A
129
+ "Major Component", in this context, means a major essential component
130
+ (kernel, window system, and so on) of the specific operating system
131
+ (if any) on which the executable work runs, or a compiler used to
132
+ produce the work, or an object code interpreter used to run it.
133
+
134
+ The "Corresponding Source" for a work in object code form means all
135
+ the source code needed to generate, install, and (for an executable
136
+ work) run the object code and to modify the work, including scripts to
137
+ control those activities. However, it does not include the work's
138
+ System Libraries, or general-purpose tools or generally available free
139
+ programs which are used unmodified in performing those activities but
140
+ which are not part of the work. For example, Corresponding Source
141
+ includes interface definition files associated with source files for
142
+ the work, and the source code for shared libraries and dynamically
143
+ linked subprograms that the work is specifically designed to require,
144
+ such as by intimate data communication or control flow between those
145
+ subprograms and other parts of the work.
146
+
147
+ The Corresponding Source need not include anything that users
148
+ can regenerate automatically from other parts of the Corresponding
149
+ Source.
150
+
151
+ The Corresponding Source for a work in source code form is that
152
+ same work.
153
+
154
+ 2. Basic Permissions.
155
+
156
+ All rights granted under this License are granted for the term of
157
+ copyright on the Program, and are irrevocable provided the stated
158
+ conditions are met. This License explicitly affirms your unlimited
159
+ permission to run the unmodified Program. The output from running a
160
+ covered work is covered by this License only if the output, given its
161
+ content, constitutes a covered work. This License acknowledges your
162
+ rights of fair use or other equivalent, as provided by copyright law.
163
+
164
+ You may make, run and propagate covered works that you do not
165
+ convey, without conditions so long as your license otherwise remains
166
+ in force. You may convey covered works to others for the sole purpose
167
+ of having them make modifications exclusively for you, or provide you
168
+ with facilities for running those works, provided that you comply with
169
+ the terms of this License in conveying all material for which you do
170
+ not control copyright. Those thus making or running the covered works
171
+ for you must do so exclusively on your behalf, under your direction
172
+ and control, on terms that prohibit them from making any copies of
173
+ your copyrighted material outside their relationship with you.
174
+
175
+ Conveying under any other circumstances is permitted solely under
176
+ the conditions stated below. Sublicensing is not allowed; section 10
177
+ makes it unnecessary.
178
+
179
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180
+
181
+ No covered work shall be deemed part of an effective technological
182
+ measure under any applicable law fulfilling obligations under article
183
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184
+ similar laws prohibiting or restricting circumvention of such
185
+ measures.
186
+
187
+ When you convey a covered work, you waive any legal power to forbid
188
+ circumvention of technological measures to the extent such circumvention
189
+ is effected by exercising rights under this License with respect to
190
+ the covered work, and you disclaim any intention to limit operation or
191
+ modification of the work as a means of enforcing, against the work's
192
+ users, your or third parties' legal rights to forbid circumvention of
193
+ technological measures.
194
+
195
+ 4. Conveying Verbatim Copies.
196
+
197
+ You may convey verbatim copies of the Program's source code as you
198
+ receive it, in any medium, provided that you conspicuously and
199
+ appropriately publish on each copy an appropriate copyright notice;
200
+ keep intact all notices stating that this License and any
201
+ non-permissive terms added in accord with section 7 apply to the code;
202
+ keep intact all notices of the absence of any warranty; and give all
203
+ recipients a copy of this License along with the Program.
204
+
205
+ You may charge any price or no price for each copy that you convey,
206
+ and you may offer support or warranty protection for a fee.
207
+
208
+ 5. Conveying Modified Source Versions.
209
+
210
+ You may convey a work based on the Program, or the modifications to
211
+ produce it from the Program, in the form of source code under the
212
+ terms of section 4, provided that you also meet all of these conditions:
213
+
214
+ a) The work must carry prominent notices stating that you modified
215
+ it, and giving a relevant date.
216
+
217
+ b) The work must carry prominent notices stating that it is
218
+ released under this License and any conditions added under section
219
+ 7. This requirement modifies the requirement in section 4 to
220
+ "keep intact all notices".
221
+
222
+ c) You must license the entire work, as a whole, under this
223
+ License to anyone who comes into possession of a copy. This
224
+ License will therefore apply, along with any applicable section 7
225
+ additional terms, to the whole of the work, and all its parts,
226
+ regardless of how they are packaged. This License gives no
227
+ permission to license the work in any other way, but it does not
228
+ invalidate such permission if you have separately received it.
229
+
230
+ d) If the work has interactive user interfaces, each must display
231
+ Appropriate Legal Notices; however, if the Program has interactive
232
+ interfaces that do not display Appropriate Legal Notices, your
233
+ work need not make them do so.
234
+
235
+ A compilation of a covered work with other separate and independent
236
+ works, which are not by their nature extensions of the covered work,
237
+ and which are not combined with it such as to form a larger program,
238
+ in or on a volume of a storage or distribution medium, is called an
239
+ "aggregate" if the compilation and its resulting copyright are not
240
+ used to limit the access or legal rights of the compilation's users
241
+ beyond what the individual works permit. Inclusion of a covered work
242
+ in an aggregate does not cause this License to apply to the other
243
+ parts of the aggregate.
244
+
245
+ 6. Conveying Non-Source Forms.
246
+
247
+ You may convey a covered work in object code form under the terms
248
+ of sections 4 and 5, provided that you also convey the
249
+ machine-readable Corresponding Source under the terms of this License,
250
+ in one of these ways:
251
+
252
+ a) Convey the object code in, or embodied in, a physical product
253
+ (including a physical distribution medium), accompanied by the
254
+ Corresponding Source fixed on a durable physical medium
255
+ customarily used for software interchange.
256
+
257
+ b) Convey the object code in, or embodied in, a physical product
258
+ (including a physical distribution medium), accompanied by a
259
+ written offer, valid for at least three years and valid for as
260
+ long as you offer spare parts or customer support for that product
261
+ model, to give anyone who possesses the object code either (1) a
262
+ copy of the Corresponding Source for all the software in the
263
+ product that is covered by this License, on a durable physical
264
+ medium customarily used for software interchange, for a price no
265
+ more than your reasonable cost of physically performing this
266
+ conveying of source, or (2) access to copy the
267
+ Corresponding Source from a network server at no charge.
268
+
269
+ c) Convey individual copies of the object code with a copy of the
270
+ written offer to provide the Corresponding Source. This
271
+ alternative is allowed only occasionally and noncommercially, and
272
+ only if you received the object code with such an offer, in accord
273
+ with subsection 6b.
274
+
275
+ d) Convey the object code by offering access from a designated
276
+ place (gratis or for a charge), and offer equivalent access to the
277
+ Corresponding Source in the same way through the same place at no
278
+ further charge. You need not require recipients to copy the
279
+ Corresponding Source along with the object code. If the place to
280
+ copy the object code is a network server, the Corresponding Source
281
+ may be on a different server (operated by you or a third party)
282
+ that supports equivalent copying facilities, provided you maintain
283
+ clear directions next to the object code saying where to find the
284
+ Corresponding Source. Regardless of what server hosts the
285
+ Corresponding Source, you remain obligated to ensure that it is
286
+ available for as long as needed to satisfy these requirements.
287
+
288
+ e) Convey the object code using peer-to-peer transmission, provided
289
+ you inform other peers where the object code and Corresponding
290
+ Source of the work are being offered to the general public at no
291
+ charge under subsection 6d.
292
+
293
+ A separable portion of the object code, whose source code is excluded
294
+ from the Corresponding Source as a System Library, need not be
295
+ included in conveying the object code work.
296
+
297
+ A "User Product" is either (1) a "consumer product", which means any
298
+ tangible personal property which is normally used for personal, family,
299
+ or household purposes, or (2) anything designed or sold for incorporation
300
+ into a dwelling. In determining whether a product is a consumer product,
301
+ doubtful cases shall be resolved in favor of coverage. For a particular
302
+ product received by a particular user, "normally used" refers to a
303
+ typical or common use of that class of product, regardless of the status
304
+ of the particular user or of the way in which the particular user
305
+ actually uses, or expects or is expected to use, the product. A product
306
+ is a consumer product regardless of whether the product has substantial
307
+ commercial, industrial or non-consumer uses, unless such uses represent
308
+ the only significant mode of use of the product.
309
+
310
+ "Installation Information" for a User Product means any methods,
311
+ procedures, authorization keys, or other information required to install
312
+ and execute modified versions of a covered work in that User Product from
313
+ a modified version of its Corresponding Source. The information must
314
+ suffice to ensure that the continued functioning of the modified object
315
+ code is in no case prevented or interfered with solely because
316
+ modification has been made.
317
+
318
+ If you convey an object code work under this section in, or with, or
319
+ specifically for use in, a User Product, and the conveying occurs as
320
+ part of a transaction in which the right of possession and use of the
321
+ User Product is transferred to the recipient in perpetuity or for a
322
+ fixed term (regardless of how the transaction is characterized), the
323
+ Corresponding Source conveyed under this section must be accompanied
324
+ by the Installation Information. But this requirement does not apply
325
+ if neither you nor any third party retains the ability to install
326
+ modified object code on the User Product (for example, the work has
327
+ been installed in ROM).
328
+
329
+ The requirement to provide Installation Information does not include a
330
+ requirement to continue to provide support service, warranty, or updates
331
+ for a work that has been modified or installed by the recipient, or for
332
+ the User Product in which it has been modified or installed. Access to a
333
+ network may be denied when the modification itself materially and
334
+ adversely affects the operation of the network or violates the rules and
335
+ protocols for communication across the network.
336
+
337
+ Corresponding Source conveyed, and Installation Information provided,
338
+ in accord with this section must be in a format that is publicly
339
+ documented (and with an implementation available to the public in
340
+ source code form), and must require no special password or key for
341
+ unpacking, reading or copying.
342
+
343
+ 7. Additional Terms.
344
+
345
+ "Additional permissions" are terms that supplement the terms of this
346
+ License by making exceptions from one or more of its conditions.
347
+ Additional permissions that are applicable to the entire Program shall
348
+ be treated as though they were included in this License, to the extent
349
+ that they are valid under applicable law. If additional permissions
350
+ apply only to part of the Program, that part may be used separately
351
+ under those permissions, but the entire Program remains governed by
352
+ this License without regard to the additional permissions.
353
+
354
+ When you convey a copy of a covered work, you may at your option
355
+ remove any additional permissions from that copy, or from any part of
356
+ it. (Additional permissions may be written to require their own
357
+ removal in certain cases when you modify the work.) You may place
358
+ additional permissions on material, added by you to a covered work,
359
+ for which you have or can give appropriate copyright permission.
360
+
361
+ Notwithstanding any other provision of this License, for material you
362
+ add to a covered work, you may (if authorized by the copyright holders of
363
+ that material) supplement the terms of this License with terms:
364
+
365
+ a) Disclaiming warranty or limiting liability differently from the
366
+ terms of sections 15 and 16 of this License; or
367
+
368
+ b) Requiring preservation of specified reasonable legal notices or
369
+ author attributions in that material or in the Appropriate Legal
370
+ Notices displayed by works containing it; or
371
+
372
+ c) Prohibiting misrepresentation of the origin of that material, or
373
+ requiring that modified versions of such material be marked in
374
+ reasonable ways as different from the original version; or
375
+
376
+ d) Limiting the use for publicity purposes of names of licensors or
377
+ authors of the material; or
378
+
379
+ e) Declining to grant rights under trademark law for use of some
380
+ trade names, trademarks, or service marks; or
381
+
382
+ f) Requiring indemnification of licensors and authors of that
383
+ material by anyone who conveys the material (or modified versions of
384
+ it) with contractual assumptions of liability to the recipient, for
385
+ any liability that these contractual assumptions directly impose on
386
+ those licensors and authors.
387
+
388
+ All other non-permissive additional terms are considered "further
389
+ restrictions" within the meaning of section 10. If the Program as you
390
+ received it, or any part of it, contains a notice stating that it is
391
+ governed by this License along with a term that is a further
392
+ restriction, you may remove that term. If a license document contains
393
+ a further restriction but permits relicensing or conveying under this
394
+ License, you may add to a covered work material governed by the terms
395
+ of that license document, provided that the further restriction does
396
+ not survive such relicensing or conveying.
397
+
398
+ If you add terms to a covered work in accord with this section, you
399
+ must place, in the relevant source files, a statement of the
400
+ additional terms that apply to those files, or a notice indicating
401
+ where to find the applicable terms.
402
+
403
+ Additional terms, permissive or non-permissive, may be stated in the
404
+ form of a separately written license, or stated as exceptions;
405
+ the above requirements apply either way.
406
+
407
+ 8. Termination.
408
+
409
+ You may not propagate or modify a covered work except as expressly
410
+ provided under this License. Any attempt otherwise to propagate or
411
+ modify it is void, and will automatically terminate your rights under
412
+ this License (including any patent licenses granted under the third
413
+ paragraph of section 11).
414
+
415
+ However, if you cease all violation of this License, then your
416
+ license from a particular copyright holder is reinstated (a)
417
+ provisionally, unless and until the copyright holder explicitly and
418
+ finally terminates your license, and (b) permanently, if the copyright
419
+ holder fails to notify you of the violation by some reasonable means
420
+ prior to 60 days after the cessation.
421
+
422
+ Moreover, your license from a particular copyright holder is
423
+ reinstated permanently if the copyright holder notifies you of the
424
+ violation by some reasonable means, this is the first time you have
425
+ received notice of violation of this License (for any work) from that
426
+ copyright holder, and you cure the violation prior to 30 days after
427
+ your receipt of the notice.
428
+
429
+ Termination of your rights under this section does not terminate the
430
+ licenses of parties who have received copies or rights from you under
431
+ this License. If your rights have been terminated and not permanently
432
+ reinstated, you do not qualify to receive new licenses for the same
433
+ material under section 10.
434
+
435
+ 9. Acceptance Not Required for Having Copies.
436
+
437
+ You are not required to accept this License in order to receive or
438
+ run a copy of the Program. Ancillary propagation of a covered work
439
+ occurring solely as a consequence of using peer-to-peer transmission
440
+ to receive a copy likewise does not require acceptance. However,
441
+ nothing other than this License grants you permission to propagate or
442
+ modify any covered work. These actions infringe copyright if you do
443
+ not accept this License. Therefore, by modifying or propagating a
444
+ covered work, you indicate your acceptance of this License to do so.
445
+
446
+ 10. Automatic Licensing of Downstream Recipients.
447
+
448
+ Each time you convey a covered work, the recipient automatically
449
+ receives a license from the original licensors, to run, modify and
450
+ propagate that work, subject to this License. You are not responsible
451
+ for enforcing compliance by third parties with this License.
452
+
453
+ An "entity transaction" is a transaction transferring control of an
454
+ organization, or substantially all assets of one, or subdividing an
455
+ organization, or merging organizations. If propagation of a covered
456
+ work results from an entity transaction, each party to that
457
+ transaction who receives a copy of the work also receives whatever
458
+ licenses to the work the party's predecessor in interest had or could
459
+ give under the previous paragraph, plus a right to possession of the
460
+ Corresponding Source of the work from the predecessor in interest, if
461
+ the predecessor has it or can get it with reasonable efforts.
462
+
463
+ You may not impose any further restrictions on the exercise of the
464
+ rights granted or affirmed under this License. For example, you may
465
+ not impose a license fee, royalty, or other charge for exercise of
466
+ rights granted under this License, and you may not initiate litigation
467
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
468
+ any patent claim is infringed by making, using, selling, offering for
469
+ sale, or importing the Program or any portion of it.
470
+
471
+ 11. Patents.
472
+
473
+ A "contributor" is a copyright holder who authorizes use under this
474
+ License of the Program or a work on which the Program is based. The
475
+ work thus licensed is called the contributor's "contributor version".
476
+
477
+ A contributor's "essential patent claims" are all patent claims
478
+ owned or controlled by the contributor, whether already acquired or
479
+ hereafter acquired, that would be infringed by some manner, permitted
480
+ by this License, of making, using, or selling its contributor version,
481
+ but do not include claims that would be infringed only as a
482
+ consequence of further modification of the contributor version. For
483
+ purposes of this definition, "control" includes the right to grant
484
+ patent sublicenses in a manner consistent with the requirements of
485
+ this License.
486
+
487
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
488
+ patent license under the contributor's essential patent claims, to
489
+ make, use, sell, offer for sale, import and otherwise run, modify and
490
+ propagate the contents of its contributor version.
491
+
492
+ In the following three paragraphs, a "patent license" is any express
493
+ agreement or commitment, however denominated, not to enforce a patent
494
+ (such as an express permission to practice a patent or covenant not to
495
+ sue for patent infringement). To "grant" such a patent license to a
496
+ party means to make such an agreement or commitment not to enforce a
497
+ patent against the party.
498
+
499
+ If you convey a covered work, knowingly relying on a patent license,
500
+ and the Corresponding Source of the work is not available for anyone
501
+ to copy, free of charge and under the terms of this License, through a
502
+ publicly available network server or other readily accessible means,
503
+ then you must either (1) cause the Corresponding Source to be so
504
+ available, or (2) arrange to deprive yourself of the benefit of the
505
+ patent license for this particular work, or (3) arrange, in a manner
506
+ consistent with the requirements of this License, to extend the patent
507
+ license to downstream recipients. "Knowingly relying" means you have
508
+ actual knowledge that, but for the patent license, your conveying the
509
+ covered work in a country, or your recipient's use of the covered work
510
+ in a country, would infringe one or more identifiable patents in that
511
+ country that you have reason to believe are valid.
512
+
513
+ If, pursuant to or in connection with a single transaction or
514
+ arrangement, you convey, or propagate by procuring conveyance of, a
515
+ covered work, and grant a patent license to some of the parties
516
+ receiving the covered work authorizing them to use, propagate, modify
517
+ or convey a specific copy of the covered work, then the patent license
518
+ you grant is automatically extended to all recipients of the covered
519
+ work and works based on it.
520
+
521
+ A patent license is "discriminatory" if it does not include within
522
+ the scope of its coverage, prohibits the exercise of, or is
523
+ conditioned on the non-exercise of one or more of the rights that are
524
+ specifically granted under this License. You may not convey a covered
525
+ work if you are a party to an arrangement with a third party that is
526
+ in the business of distributing software, under which you make payment
527
+ to the third party based on the extent of your activity of conveying
528
+ the work, and under which the third party grants, to any of the
529
+ parties who would receive the covered work from you, a discriminatory
530
+ patent license (a) in connection with copies of the covered work
531
+ conveyed by you (or copies made from those copies), or (b) primarily
532
+ for and in connection with specific products or compilations that
533
+ contain the covered work, unless you entered into that arrangement,
534
+ or that patent license was granted, prior to 28 March 2007.
535
+
536
+ Nothing in this License shall be construed as excluding or limiting
537
+ any implied license or other defenses to infringement that may
538
+ otherwise be available to you under applicable patent law.
539
+
540
+ 12. No Surrender of Others' Freedom.
541
+
542
+ If conditions are imposed on you (whether by court order, agreement or
543
+ otherwise) that contradict the conditions of this License, they do not
544
+ excuse you from the conditions of this License. If you cannot convey a
545
+ covered work so as to satisfy simultaneously your obligations under this
546
+ License and any other pertinent obligations, then as a consequence you may
547
+ not convey it at all. For example, if you agree to terms that obligate you
548
+ to collect a royalty for further conveying from those to whom you convey
549
+ the Program, the only way you could satisfy both those terms and this
550
+ License would be to refrain entirely from conveying the Program.
551
+
552
+ 13. Use with the GNU Affero General Public License.
553
+
554
+ Notwithstanding any other provision of this License, you have
555
+ permission to link or combine any covered work with a work licensed
556
+ under version 3 of the GNU Affero General Public License into a single
557
+ combined work, and to convey the resulting work. The terms of this
558
+ License will continue to apply to the part which is the covered work,
559
+ but the special requirements of the GNU Affero General Public License,
560
+ section 13, concerning interaction through a network will apply to the
561
+ combination as such.
562
+
563
+ 14. Revised Versions of this License.
564
+
565
+ The Free Software Foundation may publish revised and/or new versions of
566
+ the GNU General Public License from time to time. Such new versions will
567
+ be similar in spirit to the present version, but may differ in detail to
568
+ address new problems or concerns.
569
+
570
+ Each version is given a distinguishing version number. If the
571
+ Program specifies that a certain numbered version of the GNU General
572
+ Public License "or any later version" applies to it, you have the
573
+ option of following the terms and conditions either of that numbered
574
+ version or of any later version published by the Free Software
575
+ Foundation. If the Program does not specify a version number of the
576
+ GNU General Public License, you may choose any version ever published
577
+ by the Free Software Foundation.
578
+
579
+ If the Program specifies that a proxy can decide which future
580
+ versions of the GNU General Public License can be used, that proxy's
581
+ public statement of acceptance of a version permanently authorizes you
582
+ to choose that version for the Program.
583
+
584
+ Later license versions may give you additional or different
585
+ permissions. However, no additional obligations are imposed on any
586
+ author or copyright holder as a result of your choosing to follow a
587
+ later version.
588
+
589
+ 15. Disclaimer of Warranty.
590
+
591
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599
+
600
+ 16. Limitation of Liability.
601
+
602
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610
+ SUCH DAMAGES.
611
+
612
+ 17. Interpretation of Sections 15 and 16.
613
+
614
+ If the disclaimer of warranty and limitation of liability provided
615
+ above cannot be given local legal effect according to their terms,
616
+ reviewing courts shall apply local law that most closely approximates
617
+ an absolute waiver of all civil liability in connection with the
618
+ Program, unless a warranty or assumption of liability accompanies a
619
+ copy of the Program in return for a fee.
620
+
621
+ END OF TERMS AND CONDITIONS
622
+
623
+ How to Apply These Terms to Your New Programs
624
+
625
+ If you develop a new program, and you want it to be of the greatest
626
+ possible use to the public, the best way to achieve this is to make it
627
+ free software which everyone can redistribute and change under these terms.
628
+
629
+ To do so, attach the following notices to the program. It is safest
630
+ to attach them to the start of each source file to most effectively
631
+ state the exclusion of warranty; and each file should have at least
632
+ the "copyright" line and a pointer to where the full notice is found.
633
+
634
+ <one line to give the program's name and a brief idea of what it does.>
635
+ Copyright (C) <year> <name of author>
636
+
637
+ This program is free software: you can redistribute it and/or modify
638
+ it under the terms of the GNU General Public License as published by
639
+ the Free Software Foundation, either version 3 of the License, or
640
+ (at your option) any later version.
641
+
642
+ This program is distributed in the hope that it will be useful,
643
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
644
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645
+ GNU General Public License for more details.
646
+
647
+ You should have received a copy of the GNU General Public License
648
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
649
+
650
+ Also add information on how to contact you by electronic and paper mail.
651
+
652
+ If the program does terminal interaction, make it output a short
653
+ notice like this when it starts in an interactive mode:
654
+
655
+ <program> Copyright (C) <year> <name of author>
656
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657
+ This is free software, and you are welcome to redistribute it
658
+ under certain conditions; type `show c' for details.
659
+
660
+ The hypothetical commands `show w' and `show c' should show the appropriate
661
+ parts of the General Public License. Of course, your program's commands
662
+ might be different; for a GUI interface, you would use an "about box".
663
+
664
+ You should also get your employer (if you work as a programmer) or school,
665
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
666
+ For more information on this, and how to apply and follow the GNU GPL, see
667
+ <https://www.gnu.org/licenses/>.
668
+
669
+ The GNU General Public License does not permit incorporating your program
670
+ into proprietary programs. If your program is a subroutine library, you
671
+ may consider it more useful to permit linking proprietary applications with
672
+ the library. If this is what you want to do, use the GNU Lesser General
673
+ Public License instead of this License. But first, please read
674
+ <https://www.gnu.org/licenses/why-not-lgpl.html>.
ldm_patched/licenses-3rd/diffusers ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
ldm_patched/licenses-3rd/kdiffusion ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2022 Katherine Crowson
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ of this software and associated documentation files (the "Software"), to deal
5
+ in the Software without restriction, including without limitation the rights
6
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ copies of the Software, and to permit persons to whom the Software is
8
+ furnished to do so, subject to the following conditions:
9
+
10
+ The above copyright notice and this permission notice shall be included in
11
+ all copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19
+ THE SOFTWARE.
ldm_patched/licenses-3rd/ldm ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
ldm_patched/licenses-3rd/taesd ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Ollin Boer Bohan
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
ldm_patched/licenses-3rd/transformers ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright 2018- The Hugging Face team. All rights reserved.
2
+
3
+ Apache License
4
+ Version 2.0, January 2004
5
+ http://www.apache.org/licenses/
6
+
7
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8
+
9
+ 1. Definitions.
10
+
11
+ "License" shall mean the terms and conditions for use, reproduction,
12
+ and distribution as defined by Sections 1 through 9 of this document.
13
+
14
+ "Licensor" shall mean the copyright owner or entity authorized by
15
+ the copyright owner that is granting the License.
16
+
17
+ "Legal Entity" shall mean the union of the acting entity and all
18
+ other entities that control, are controlled by, or are under common
19
+ control with that entity. For the purposes of this definition,
20
+ "control" means (i) the power, direct or indirect, to cause the
21
+ direction or management of such entity, whether by contract or
22
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
23
+ outstanding shares, or (iii) beneficial ownership of such entity.
24
+
25
+ "You" (or "Your") shall mean an individual or Legal Entity
26
+ exercising permissions granted by this License.
27
+
28
+ "Source" form shall mean the preferred form for making modifications,
29
+ including but not limited to software source code, documentation
30
+ source, and configuration files.
31
+
32
+ "Object" form shall mean any form resulting from mechanical
33
+ transformation or translation of a Source form, including but
34
+ not limited to compiled object code, generated documentation,
35
+ and conversions to other media types.
36
+
37
+ "Work" shall mean the work of authorship, whether in Source or
38
+ Object form, made available under the License, as indicated by a
39
+ copyright notice that is included in or attached to the work
40
+ (an example is provided in the Appendix below).
41
+
42
+ "Derivative Works" shall mean any work, whether in Source or Object
43
+ form, that is based on (or derived from) the Work and for which the
44
+ editorial revisions, annotations, elaborations, or other modifications
45
+ represent, as a whole, an original work of authorship. For the purposes
46
+ of this License, Derivative Works shall not include works that remain
47
+ separable from, or merely link (or bind by name) to the interfaces of,
48
+ the Work and Derivative Works thereof.
49
+
50
+ "Contribution" shall mean any work of authorship, including
51
+ the original version of the Work and any modifications or additions
52
+ to that Work or Derivative Works thereof, that is intentionally
53
+ submitted to Licensor for inclusion in the Work by the copyright owner
54
+ or by an individual or Legal Entity authorized to submit on behalf of
55
+ the copyright owner. For the purposes of this definition, "submitted"
56
+ means any form of electronic, verbal, or written communication sent
57
+ to the Licensor or its representatives, including but not limited to
58
+ communication on electronic mailing lists, source code control systems,
59
+ and issue tracking systems that are managed by, or on behalf of, the
60
+ Licensor for the purpose of discussing and improving the Work, but
61
+ excluding communication that is conspicuously marked or otherwise
62
+ designated in writing by the copyright owner as "Not a Contribution."
63
+
64
+ "Contributor" shall mean Licensor and any individual or Legal Entity
65
+ on behalf of whom a Contribution has been received by Licensor and
66
+ subsequently incorporated within the Work.
67
+
68
+ 2. Grant of Copyright License. Subject to the terms and conditions of
69
+ this License, each Contributor hereby grants to You a perpetual,
70
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71
+ copyright license to reproduce, prepare Derivative Works of,
72
+ publicly display, publicly perform, sublicense, and distribute the
73
+ Work and such Derivative Works in Source or Object form.
74
+
75
+ 3. Grant of Patent License. Subject to the terms and conditions of
76
+ this License, each Contributor hereby grants to You a perpetual,
77
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78
+ (except as stated in this section) patent license to make, have made,
79
+ use, offer to sell, sell, import, and otherwise transfer the Work,
80
+ where such license applies only to those patent claims licensable
81
+ by such Contributor that are necessarily infringed by their
82
+ Contribution(s) alone or by combination of their Contribution(s)
83
+ with the Work to which such Contribution(s) was submitted. If You
84
+ institute patent litigation against any entity (including a
85
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
86
+ or a Contribution incorporated within the Work constitutes direct
87
+ or contributory patent infringement, then any patent licenses
88
+ granted to You under this License for that Work shall terminate
89
+ as of the date such litigation is filed.
90
+
91
+ 4. Redistribution. You may reproduce and distribute copies of the
92
+ Work or Derivative Works thereof in any medium, with or without
93
+ modifications, and in Source or Object form, provided that You
94
+ meet the following conditions:
95
+
96
+ (a) You must give any other recipients of the Work or
97
+ Derivative Works a copy of this License; and
98
+
99
+ (b) You must cause any modified files to carry prominent notices
100
+ stating that You changed the files; and
101
+
102
+ (c) You must retain, in the Source form of any Derivative Works
103
+ that You distribute, all copyright, patent, trademark, and
104
+ attribution notices from the Source form of the Work,
105
+ excluding those notices that do not pertain to any part of
106
+ the Derivative Works; and
107
+
108
+ (d) If the Work includes a "NOTICE" text file as part of its
109
+ distribution, then any Derivative Works that You distribute must
110
+ include a readable copy of the attribution notices contained
111
+ within such NOTICE file, excluding those notices that do not
112
+ pertain to any part of the Derivative Works, in at least one
113
+ of the following places: within a NOTICE text file distributed
114
+ as part of the Derivative Works; within the Source form or
115
+ documentation, if provided along with the Derivative Works; or,
116
+ within a display generated by the Derivative Works, if and
117
+ wherever such third-party notices normally appear. The contents
118
+ of the NOTICE file are for informational purposes only and
119
+ do not modify the License. You may add Your own attribution
120
+ notices within Derivative Works that You distribute, alongside
121
+ or as an addendum to the NOTICE text from the Work, provided
122
+ that such additional attribution notices cannot be construed
123
+ as modifying the License.
124
+
125
+ You may add Your own copyright statement to Your modifications and
126
+ may provide additional or different license terms and conditions
127
+ for use, reproduction, or distribution of Your modifications, or
128
+ for any such Derivative Works as a whole, provided Your use,
129
+ reproduction, and distribution of the Work otherwise complies with
130
+ the conditions stated in this License.
131
+
132
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
133
+ any Contribution intentionally submitted for inclusion in the Work
134
+ by You to the Licensor shall be under the terms and conditions of
135
+ this License, without any additional terms or conditions.
136
+ Notwithstanding the above, nothing herein shall supersede or modify
137
+ the terms of any separate license agreement you may have executed
138
+ with Licensor regarding such Contributions.
139
+
140
+ 6. Trademarks. This License does not grant permission to use the trade
141
+ names, trademarks, service marks, or product names of the Licensor,
142
+ except as required for reasonable and customary use in describing the
143
+ origin of the Work and reproducing the content of the NOTICE file.
144
+
145
+ 7. Disclaimer of Warranty. Unless required by applicable law or
146
+ agreed to in writing, Licensor provides the Work (and each
147
+ Contributor provides its Contributions) on an "AS IS" BASIS,
148
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149
+ implied, including, without limitation, any warranties or conditions
150
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151
+ PARTICULAR PURPOSE. You are solely responsible for determining the
152
+ appropriateness of using or redistributing the Work and assume any
153
+ risks associated with Your exercise of permissions under this License.
154
+
155
+ 8. Limitation of Liability. In no event and under no legal theory,
156
+ whether in tort (including negligence), contract, or otherwise,
157
+ unless required by applicable law (such as deliberate and grossly
158
+ negligent acts) or agreed to in writing, shall any Contributor be
159
+ liable to You for damages, including any direct, indirect, special,
160
+ incidental, or consequential damages of any character arising as a
161
+ result of this License or out of the use or inability to use the
162
+ Work (including but not limited to damages for loss of goodwill,
163
+ work stoppage, computer failure or malfunction, or any and all
164
+ other commercial damages or losses), even if such Contributor
165
+ has been advised of the possibility of such damages.
166
+
167
+ 9. Accepting Warranty or Additional Liability. While redistributing
168
+ the Work or Derivative Works thereof, You may choose to offer,
169
+ and charge a fee for, acceptance of support, warranty, indemnity,
170
+ or other liability obligations and/or rights consistent with this
171
+ License. However, in accepting such obligations, You may act only
172
+ on Your own behalf and on Your sole responsibility, not on behalf
173
+ of any other Contributor, and only if You agree to indemnify,
174
+ defend, and hold each Contributor harmless for any liability
175
+ incurred by, or claims asserted against, such Contributor by reason
176
+ of your accepting any such warranty or additional liability.
177
+
178
+ END OF TERMS AND CONDITIONS
179
+
180
+ APPENDIX: How to apply the Apache License to your work.
181
+
182
+ To apply the Apache License to your work, attach the following
183
+ boilerplate notice, with the fields enclosed by brackets "[]"
184
+ replaced with your own identifying information. (Don't include
185
+ the brackets!) The text should be enclosed in the appropriate
186
+ comment syntax for the file format. We also recommend that a
187
+ file or class name and description of purpose be included on the
188
+ same "printed page" as the copyright notice for easier
189
+ identification within third-party archives.
190
+
191
+ Copyright [yyyy] [name of copyright owner]
192
+
193
+ Licensed under the Apache License, Version 2.0 (the "License");
194
+ you may not use this file except in compliance with the License.
195
+ You may obtain a copy of the License at
196
+
197
+ http://www.apache.org/licenses/LICENSE-2.0
198
+
199
+ Unless required by applicable law or agreed to in writing, software
200
+ distributed under the License is distributed on an "AS IS" BASIS,
201
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202
+ See the License for the specific language governing permissions and
203
+ limitations under the License.
ldm_patched/modules/args_parser.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://github.com/comfyanonymous/ComfyUI
2
+ # This file is only for reference, and not used in the backend or runtime.
3
+
4
+
5
+ import argparse
6
+ import enum
7
+ import ldm_patched.modules.options
8
+
9
+ class EnumAction(argparse.Action):
10
+ """
11
+ Argparse action for handling Enums
12
+ """
13
+ def __init__(self, **kwargs):
14
+ # Pop off the type value
15
+ enum_type = kwargs.pop("type", None)
16
+
17
+ # Ensure an Enum subclass is provided
18
+ if enum_type is None:
19
+ raise ValueError("type must be assigned an Enum when using EnumAction")
20
+ if not issubclass(enum_type, enum.Enum):
21
+ raise TypeError("type must be an Enum when using EnumAction")
22
+
23
+ # Generate choices from the Enum
24
+ choices = tuple(e.value for e in enum_type)
25
+ kwargs.setdefault("choices", choices)
26
+ kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
27
+
28
+ super(EnumAction, self).__init__(**kwargs)
29
+
30
+ self._enum = enum_type
31
+
32
+ def __call__(self, parser, namespace, values, option_string=None):
33
+ # Convert value back into an Enum
34
+ value = self._enum(values)
35
+ setattr(namespace, self.dest, value)
36
+
37
+
38
+ parser = argparse.ArgumentParser()
39
+
40
+ #parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0")
41
+ #parser.add_argument("--port", type=int, default=8188)
42
+ parser.add_argument("--disable-header-check", type=str, default=None, metavar="ORIGIN", nargs="?", const="*")
43
+ parser.add_argument("--web-upload-size", type=float, default=100)
44
+
45
+ parser.add_argument("--external-working-path", type=str, default=None, metavar="PATH", nargs='+', action='append')
46
+ parser.add_argument("--output-path", type=str, default=None)
47
+ parser.add_argument("--temp-path", type=str, default=None)
48
+ parser.add_argument("--cache-path", type=str, default=None)
49
+ parser.add_argument("--in-browser", action="store_true")
50
+ parser.add_argument("--disable-in-browser", action="store_true")
51
+ parser.add_argument("--gpu-device-id", type=int, default=None, metavar="DEVICE_ID")
52
+
53
+ parser.add_argument("--disable-attention-upcast", action="store_true")
54
+
55
+ fp_group = parser.add_mutually_exclusive_group()
56
+ fp_group.add_argument("--all-in-fp32", action="store_true")
57
+ fp_group.add_argument("--all-in-fp16", action="store_true")
58
+
59
+ fpunet_group = parser.add_mutually_exclusive_group()
60
+ fpunet_group.add_argument("--unet-in-bf16", action="store_true")
61
+ fpunet_group.add_argument("--unet-in-fp16", action="store_true")
62
+ fpunet_group.add_argument("--unet-in-fp8-e4m3fn", action="store_true")
63
+ fpunet_group.add_argument("--unet-in-fp8-e5m2", action="store_true")
64
+
65
+ fpvae_group = parser.add_mutually_exclusive_group()
66
+ fpvae_group.add_argument("--vae-in-fp16", action="store_true")
67
+ fpvae_group.add_argument("--vae-in-fp32", action="store_true")
68
+ fpvae_group.add_argument("--vae-in-bf16", action="store_true")
69
+
70
+ parser.add_argument("--vae-in-cpu", action="store_true")
71
+
72
+ fpte_group = parser.add_mutually_exclusive_group()
73
+ fpte_group.add_argument("--clip-in-fp8-e4m3fn", action="store_true")
74
+ fpte_group.add_argument("--clip-in-fp8-e5m2", action="store_true")
75
+ fpte_group.add_argument("--clip-in-fp16", action="store_true")
76
+ fpte_group.add_argument("--clip-in-fp32", action="store_true")
77
+
78
+
79
+ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1)
80
+
81
+ parser.add_argument("--disable-ipex-hijack", action="store_true")
82
+
83
+ class LatentPreviewMethod(enum.Enum):
84
+ NoPreviews = "none"
85
+ Auto = "auto"
86
+ Latent2RGB = "fast"
87
+ TAESD = "taesd"
88
+
89
+ parser.add_argument("--preview-option", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, action=EnumAction)
90
+
91
+ attn_group = parser.add_mutually_exclusive_group()
92
+ attn_group.add_argument("--attention-split", action="store_true")
93
+ attn_group.add_argument("--attention-quad", action="store_true")
94
+ attn_group.add_argument("--attention-pytorch", action="store_true")
95
+
96
+ parser.add_argument("--disable-xformers", action="store_true")
97
+
98
+ vram_group = parser.add_mutually_exclusive_group()
99
+ vram_group.add_argument("--always-gpu", action="store_true")
100
+ vram_group.add_argument("--always-high-vram", action="store_true")
101
+ vram_group.add_argument("--always-normal-vram", action="store_true")
102
+ vram_group.add_argument("--always-low-vram", action="store_true")
103
+ vram_group.add_argument("--always-no-vram", action="store_true")
104
+ vram_group.add_argument("--always-cpu", action="store_true")
105
+
106
+
107
+ parser.add_argument("--always-offload-from-vram", action="store_true")
108
+ parser.add_argument("--pytorch-deterministic", action="store_true")
109
+
110
+ parser.add_argument("--disable-server-log", action="store_true")
111
+ parser.add_argument("--debug-mode", action="store_true")
112
+ parser.add_argument("--is-windows-embedded-python", action="store_true")
113
+
114
+ parser.add_argument("--disable-server-info", action="store_true")
115
+
116
+ parser.add_argument("--multi-user", action="store_true")
117
+
118
+ parser.add_argument("--cuda-malloc", action="store_true")
119
+ parser.add_argument("--cuda-stream", action="store_true")
120
+ parser.add_argument("--pin-shared-memory", action="store_true")
121
+
122
+ if ldm_patched.modules.options.args_parsing:
123
+ args = parser.parse_args([])
124
+ else:
125
+ args = parser.parse_args([])
126
+
127
+ if args.is_windows_embedded_python:
128
+ args.in_browser = True
129
+
130
+ if args.disable_in_browser:
131
+ args.in_browser = False