Qdssa committed on
Commit
0035f04
1 Parent(s): 80d1e06

Upload 269 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete changeset.
Files changed (50):
  1. Backend/.pylintrc +3 -0
  2. Backend/0.4.12 +29 -0
  3. Backend/app/__init__.py +0 -0
  4. Backend/app/app_settings.py +54 -0
  5. Backend/app/frontend_management.py +188 -0
  6. Backend/app/user_manager.py +205 -0
  7. Backend/comfy/checkpoint_pickle.py +13 -0
  8. Backend/comfy/cldm/cldm.py +437 -0
  9. Backend/comfy/cldm/control_types.py +10 -0
  10. Backend/comfy/cldm/mmdit.py +77 -0
  11. Backend/comfy/cli_args.py +180 -0
  12. Backend/comfy/clip_config_bigg.json +23 -0
  13. Backend/comfy/clip_model.py +196 -0
  14. Backend/comfy/clip_vision.py +121 -0
  15. Backend/comfy/clip_vision_config_g.json +18 -0
  16. Backend/comfy/clip_vision_config_h.json +18 -0
  17. Backend/comfy/clip_vision_config_vitl.json +18 -0
  18. Backend/comfy/clip_vision_config_vitl_336.json +18 -0
  19. Backend/comfy/conds.py +83 -0
  20. Backend/comfy/controlnet.py +622 -0
  21. Backend/comfy/diffusers_convert.py +281 -0
  22. Backend/comfy/diffusers_load.py +36 -0
  23. Backend/comfy/extra_samplers/uni_pc.py +875 -0
  24. Backend/comfy/gligen.py +343 -0
  25. Backend/comfy/k_diffusion/deis.py +121 -0
  26. Backend/comfy/k_diffusion/sampling.py +1050 -0
  27. Backend/comfy/k_diffusion/utils.py +313 -0
  28. Backend/comfy/latent_formats.py +170 -0
  29. Backend/comfy/ldm/audio/autoencoder.py +282 -0
  30. Backend/comfy/ldm/audio/dit.py +891 -0
  31. Backend/comfy/ldm/audio/embedders.py +108 -0
  32. Backend/comfy/ldm/aura/mmdit.py +478 -0
  33. Backend/comfy/ldm/cascade/common.py +154 -0
  34. Backend/comfy/ldm/cascade/controlnet.py +93 -0
  35. Backend/comfy/ldm/cascade/stage_a.py +255 -0
  36. Backend/comfy/ldm/cascade/stage_b.py +256 -0
  37. Backend/comfy/ldm/cascade/stage_c.py +273 -0
  38. Backend/comfy/ldm/cascade/stage_c_coder.py +95 -0
  39. Backend/comfy/ldm/common_dit.py +8 -0
  40. Backend/comfy/ldm/flux/layers.py +263 -0
  41. Backend/comfy/ldm/flux/math.py +35 -0
  42. Backend/comfy/ldm/flux/model.py +142 -0
  43. Backend/comfy/ldm/hydit/attn_layers.py +219 -0
  44. Backend/comfy/ldm/hydit/models.py +405 -0
  45. Backend/comfy/ldm/hydit/poolers.py +37 -0
  46. Backend/comfy/ldm/hydit/posemb_layers.py +224 -0
  47. Backend/comfy/ldm/models/autoencoder.py +226 -0
  48. Backend/comfy/ldm/modules/attention.py +865 -0
  49. Backend/comfy/ldm/modules/diffusionmodules/__init__.py +0 -0
  50. Backend/comfy/ldm/modules/diffusionmodules/mmdit.py +955 -0
Backend/.pylintrc ADDED
@@ -0,0 +1,3 @@
+ [MESSAGES CONTROL]
+ disable=all
+ enable=eval-used
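Note (not part of the commit): this .pylintrc disables every pylint check and re-enables only eval-used, so a lint run over the Backend sources reports nothing except calls to eval(). A minimal sketch of the one thing it would flag:

# sketch.py -- with the configuration above, pylint emits only W0123 (eval-used)
user_expr = "1 + 1"
value = eval(user_expr)   # flagged as eval-used; every other check stays disabled
print(value)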
Backend/0.4.12 ADDED
@@ -0,0 +1,29 @@
1
+ Collecting timm
2
+ Downloading timm-1.0.9-py3-none-any.whl.metadata (42 kB)
3
+ Requirement already satisfied: torch in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from timm) (2.4.0)
4
+ Requirement already satisfied: torchvision in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from timm) (0.19.0)
5
+ Requirement already satisfied: pyyaml in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from timm) (6.0.2)
6
+ Requirement already satisfied: huggingface_hub in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from timm) (0.24.6)
7
+ Requirement already satisfied: safetensors in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from timm) (0.4.4)
8
+ Requirement already satisfied: filelock in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from huggingface_hub->timm) (3.15.4)
9
+ Requirement already satisfied: fsspec>=2023.5.0 in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from huggingface_hub->timm) (2024.6.1)
10
+ Requirement already satisfied: packaging>=20.9 in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from huggingface_hub->timm) (24.1)
11
+ Requirement already satisfied: requests in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from huggingface_hub->timm) (2.32.3)
12
+ Requirement already satisfied: tqdm>=4.42.1 in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from huggingface_hub->timm) (4.66.5)
13
+ Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from huggingface_hub->timm) (4.12.2)
14
+ Requirement already satisfied: sympy in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from torch->timm) (1.13.2)
15
+ Requirement already satisfied: networkx in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from torch->timm) (3.3)
16
+ Requirement already satisfied: jinja2 in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from torch->timm) (3.1.4)
17
+ Requirement already satisfied: numpy<2 in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from torchvision->timm) (1.26.4)
18
+ Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from torchvision->timm) (9.5.0)
19
+ Requirement already satisfied: colorama in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from tqdm>=4.42.1->huggingface_hub->timm) (0.4.6)
20
+ Requirement already satisfied: MarkupSafe>=2.0 in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from jinja2->torch->timm) (2.1.5)
21
+ Requirement already satisfied: charset-normalizer<4,>=2 in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from requests->huggingface_hub->timm) (3.3.2)
22
+ Requirement already satisfied: idna<4,>=2.5 in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from requests->huggingface_hub->timm) (3.8)
23
+ Requirement already satisfied: urllib3<3,>=1.21.1 in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from requests->huggingface_hub->timm) (2.2.2)
24
+ Requirement already satisfied: certifi>=2017.4.17 in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from requests->huggingface_hub->timm) (2024.7.4)
25
+ Requirement already satisfied: mpmath<1.4,>=1.1.0 in c:\users\ilya9\anaconda3\envs\comf_upscaler\lib\site-packages (from sympy->torch->timm) (1.3.0)
26
+ Downloading timm-1.0.9-py3-none-any.whl (2.3 MB)
27
+ ---------------------------------------- 2.3/2.3 MB 6.0 MB/s eta 0:00:00
28
+ Installing collected packages: timm
29
+ Successfully installed timm-1.0.9
Backend/app/__init__.py ADDED
File without changes
Backend/app/app_settings.py ADDED
@@ -0,0 +1,54 @@
+ import os
+ import json
+ from aiohttp import web
+
+
+ class AppSettings():
+     def __init__(self, user_manager):
+         self.user_manager = user_manager
+
+     def get_settings(self, request):
+         file = self.user_manager.get_request_user_filepath(
+             request, "comfy.settings.json")
+         if os.path.isfile(file):
+             with open(file) as f:
+                 return json.load(f)
+         else:
+             return {}
+
+     def save_settings(self, request, settings):
+         file = self.user_manager.get_request_user_filepath(
+             request, "comfy.settings.json")
+         with open(file, "w") as f:
+             f.write(json.dumps(settings, indent=4))
+
+     def add_routes(self, routes):
+         @routes.get("/settings")
+         async def get_settings(request):
+             return web.json_response(self.get_settings(request))
+
+         @routes.get("/settings/{id}")
+         async def get_setting(request):
+             value = None
+             settings = self.get_settings(request)
+             setting_id = request.match_info.get("id", None)
+             if setting_id and setting_id in settings:
+                 value = settings[setting_id]
+             return web.json_response(value)
+
+         @routes.post("/settings")
+         async def post_settings(request):
+             settings = self.get_settings(request)
+             new_settings = await request.json()
+             self.save_settings(request, {**settings, **new_settings})
+             return web.Response(status=200)
+
+         @routes.post("/settings/{id}")
+         async def post_setting(request):
+             setting_id = request.match_info.get("id", None)
+             if not setting_id:
+                 return web.Response(status=400)
+             settings = self.get_settings(request)
+             settings[setting_id] = await request.json()
+             self.save_settings(request, settings)
+             return web.Response(status=200)
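Note (not part of the commit): a minimal client-side sketch of these endpoints, assuming the server is running locally on the default port 8188; the setting key is made up for illustration. POST /settings shallow-merges the JSON body into the per-user comfy.settings.json, and GET /settings/{id} returns a single value (or null when the id is unset).

import requests

base = "http://127.0.0.1:8188"
requests.post(f"{base}/settings", json={"Example.Setting": True})   # hypothetical key, merged into comfy.settings.json
print(requests.get(f"{base}/settings").json())                      # full per-user settings dict
print(requests.get(f"{base}/settings/Example.Setting").json())      # one value, or None if the id is unknown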
Backend/app/frontend_management.py ADDED
@@ -0,0 +1,188 @@
1
+ from __future__ import annotations
2
+ import argparse
3
+ import logging
4
+ import os
5
+ import re
6
+ import tempfile
7
+ import zipfile
8
+ from dataclasses import dataclass
9
+ from functools import cached_property
10
+ from pathlib import Path
11
+ from typing import TypedDict
12
+
13
+ import requests
14
+ from typing_extensions import NotRequired
15
+ from comfy.cli_args import DEFAULT_VERSION_STRING
16
+
17
+
18
+ REQUEST_TIMEOUT = 10 # seconds
19
+
20
+
21
+ class Asset(TypedDict):
22
+ url: str
23
+
24
+
25
+ class Release(TypedDict):
26
+ id: int
27
+ tag_name: str
28
+ name: str
29
+ prerelease: bool
30
+ created_at: str
31
+ published_at: str
32
+ body: str
33
+ assets: NotRequired[list[Asset]]
34
+
35
+
36
+ @dataclass
37
+ class FrontEndProvider:
38
+ owner: str
39
+ repo: str
40
+
41
+ @property
42
+ def folder_name(self) -> str:
43
+ return f"{self.owner}_{self.repo}"
44
+
45
+ @property
46
+ def release_url(self) -> str:
47
+ return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases"
48
+
49
+ @cached_property
50
+ def all_releases(self) -> list[Release]:
51
+ releases = []
52
+ api_url = self.release_url
53
+ while api_url:
54
+ response = requests.get(api_url, timeout=REQUEST_TIMEOUT)
55
+ response.raise_for_status() # Raises an HTTPError if the response was an error
56
+ releases.extend(response.json())
57
+ # GitHub uses the Link header to provide pagination links. Check if it exists and update api_url accordingly.
58
+ if "next" in response.links:
59
+ api_url = response.links["next"]["url"]
60
+ else:
61
+ api_url = None
62
+ return releases
63
+
64
+ @cached_property
65
+ def latest_release(self) -> Release:
66
+ latest_release_url = f"{self.release_url}/latest"
67
+ response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT)
68
+ response.raise_for_status() # Raises an HTTPError if the response was an error
69
+ return response.json()
70
+
71
+ def get_release(self, version: str) -> Release:
72
+ if version == "latest":
73
+ return self.latest_release
74
+ else:
75
+ for release in self.all_releases:
76
+ if release["tag_name"] in [version, f"v{version}"]:
77
+ return release
78
+ raise ValueError(f"Version {version} not found in releases")
79
+
80
+
81
+ def download_release_asset_zip(release: Release, destination_path: str) -> None:
82
+ """Download dist.zip from github release."""
83
+ asset_url = None
84
+ for asset in release.get("assets", []):
85
+ if asset["name"] == "dist.zip":
86
+ asset_url = asset["url"]
87
+ break
88
+
89
+ if not asset_url:
90
+ raise ValueError("dist.zip not found in the release assets")
91
+
92
+ # Use a temporary file to download the zip content
93
+ with tempfile.TemporaryFile() as tmp_file:
94
+ headers = {"Accept": "application/octet-stream"}
95
+ response = requests.get(
96
+ asset_url, headers=headers, allow_redirects=True, timeout=REQUEST_TIMEOUT
97
+ )
98
+ response.raise_for_status() # Ensure we got a successful response
99
+
100
+ # Write the content to the temporary file
101
+ tmp_file.write(response.content)
102
+
103
+ # Go back to the beginning of the temporary file
104
+ tmp_file.seek(0)
105
+
106
+ # Extract the zip file content to the destination path
107
+ with zipfile.ZipFile(tmp_file, "r") as zip_ref:
108
+ zip_ref.extractall(destination_path)
109
+
110
+
111
+ class FrontendManager:
112
+ DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web")
113
+ CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
114
+
115
+ @classmethod
116
+ def parse_version_string(cls, value: str) -> tuple[str, str, str]:
117
+ """
118
+ Args:
119
+ value (str): The version string to parse.
120
+
121
+ Returns:
122
+ tuple[str, str, str]: A tuple of (repo owner, repo name, version).
123
+
124
+ Raises:
125
+ argparse.ArgumentTypeError: If the version string is invalid.
126
+ """
127
+ VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$"
128
+ match_result = re.match(VERSION_PATTERN, value)
129
+ if match_result is None:
130
+ raise argparse.ArgumentTypeError(f"Invalid version string: {value}")
131
+
132
+ return match_result.group(1), match_result.group(2), match_result.group(3)
133
+
134
+ @classmethod
135
+ def init_frontend_unsafe(cls, version_string: str) -> str:
136
+ """
137
+ Initializes the frontend for the specified version.
138
+
139
+ Args:
140
+ version_string (str): The version string.
141
+
142
+ Returns:
143
+ str: The path to the initialized frontend.
144
+
145
+ Raises:
146
+ Exception: If there is an error during the initialization process.
147
+ The most likely causes are a request timeout or an invalid download URL.
148
+ """
149
+ if version_string == DEFAULT_VERSION_STRING:
150
+ return cls.DEFAULT_FRONTEND_PATH
151
+
152
+ repo_owner, repo_name, version = cls.parse_version_string(version_string)
153
+ provider = FrontEndProvider(repo_owner, repo_name)
154
+ release = provider.get_release(version)
155
+
156
+ semantic_version = release["tag_name"].lstrip("v")
157
+ web_root = str(
158
+ Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
159
+ )
160
+ if not os.path.exists(web_root):
161
+ os.makedirs(web_root, exist_ok=True)
162
+ logging.info(
163
+ "Downloading frontend(%s) version(%s) to (%s)",
164
+ provider.folder_name,
165
+ semantic_version,
166
+ web_root,
167
+ )
168
+ logging.debug(release)
169
+ download_release_asset_zip(release, destination_path=web_root)
170
+ return web_root
171
+
172
+ @classmethod
173
+ def init_frontend(cls, version_string: str) -> str:
174
+ """
175
+ Initializes the frontend with the specified version string.
176
+
177
+ Args:
178
+ version_string (str): The version string to initialize the frontend with.
179
+
180
+ Returns:
181
+ str: The path of the initialized frontend.
182
+ """
183
+ try:
184
+ return cls.init_frontend_unsafe(version_string)
185
+ except Exception as e:
186
+ logging.error("Failed to initialize frontend: %s", e)
187
+ logging.info("Falling back to the default frontend.")
188
+ return cls.DEFAULT_FRONTEND_PATH
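Note (not part of the commit): a hedged usage sketch of the version-string format accepted by --front-end-version and parsed above, assuming the Backend/app package is importable as app; the repository name is an example value.

from app.frontend_management import FrontendManager

owner, repo, version = FrontendManager.parse_version_string("ExampleOrg/example_frontend@1.2.3")
print(owner, repo, version)   # -> ExampleOrg example_frontend 1.2.3 ("latest" is also accepted as the version part)
# Anything that does not match <owner>/<repo>@<semver|latest> raises argparse.ArgumentTypeError.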
Backend/app/user_manager.py ADDED
@@ -0,0 +1,205 @@
1
+ import json
2
+ import os
3
+ import re
4
+ import uuid
5
+ import glob
6
+ import shutil
7
+ from aiohttp import web
8
+ from comfy.cli_args import args
9
+ from folder_paths import user_directory
10
+ from .app_settings import AppSettings
11
+
12
+ default_user = "default"
13
+ users_file = os.path.join(user_directory, "users.json")
14
+
15
+
16
+ class UserManager():
17
+ def __init__(self):
18
+ global user_directory
19
+
20
+ self.settings = AppSettings(self)
21
+ if not os.path.exists(user_directory):
22
+ os.mkdir(user_directory)
23
+ if not args.multi_user:
24
+ print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
25
+ print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
26
+
27
+ if args.multi_user:
28
+ if os.path.isfile(users_file):
29
+ with open(users_file) as f:
30
+ self.users = json.load(f)
31
+ else:
32
+ self.users = {}
33
+ else:
34
+ self.users = {"default": "default"}
35
+
36
+ def get_request_user_id(self, request):
37
+ user = "default"
38
+ if args.multi_user and "comfy-user" in request.headers:
39
+ user = request.headers["comfy-user"]
40
+
41
+ if user not in self.users:
42
+ raise KeyError("Unknown user: " + user)
43
+
44
+ return user
45
+
46
+ def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
47
+ global user_directory
48
+
49
+ if type == "userdata":
50
+ root_dir = user_directory
51
+ else:
52
+ raise KeyError("Unknown filepath type:" + type)
53
+
54
+ user = self.get_request_user_id(request)
55
+ path = user_root = os.path.abspath(os.path.join(root_dir, user))
56
+
57
+ # prevent leaving /{type}
58
+ if os.path.commonpath((root_dir, user_root)) != root_dir:
59
+ return None
60
+
61
+ if file is not None:
62
+ # prevent leaving /{type}/{user}
63
+ path = os.path.abspath(os.path.join(user_root, file))
64
+ if os.path.commonpath((user_root, path)) != user_root:
65
+ return None
66
+
67
+ parent = os.path.split(path)[0]
68
+
69
+ if create_dir and not os.path.exists(parent):
70
+ os.makedirs(parent, exist_ok=True)
71
+
72
+ return path
73
+
74
+ def add_user(self, name):
75
+ name = name.strip()
76
+ if not name:
77
+ raise ValueError("username not provided")
78
+ user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
79
+ user_id = user_id + "_" + str(uuid.uuid4())
80
+
81
+ self.users[user_id] = name
82
+
83
+ global users_file
84
+ with open(users_file, "w") as f:
85
+ json.dump(self.users, f)
86
+
87
+ return user_id
88
+
89
+ def add_routes(self, routes):
90
+ self.settings.add_routes(routes)
91
+
92
+ @routes.get("/users")
93
+ async def get_users(request):
94
+ if args.multi_user:
95
+ return web.json_response({"storage": "server", "users": self.users})
96
+ else:
97
+ user_dir = self.get_request_user_filepath(request, None, create_dir=False)
98
+ return web.json_response({
99
+ "storage": "server",
100
+ "migrated": os.path.exists(user_dir)
101
+ })
102
+
103
+ @routes.post("/users")
104
+ async def post_users(request):
105
+ body = await request.json()
106
+ username = body["username"]
107
+ if username in self.users.values():
108
+ return web.json_response({"error": "Duplicate username."}, status=400)
109
+
110
+ user_id = self.add_user(username)
111
+ return web.json_response(user_id)
112
+
113
+ @routes.get("/userdata")
114
+ async def listuserdata(request):
115
+ directory = request.rel_url.query.get('dir', '')
116
+ if not directory:
117
+ return web.Response(status=400)
118
+
119
+ path = self.get_request_user_filepath(request, directory)
120
+ if not path:
121
+ return web.Response(status=403)
122
+
123
+ if not os.path.exists(path):
124
+ return web.Response(status=404)
125
+
126
+ recurse = request.rel_url.query.get('recurse', '').lower() == "true"
127
+ results = glob.glob(os.path.join(
128
+ glob.escape(path), '**/*'), recursive=recurse)
129
+ results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)]
130
+
131
+ split_path = request.rel_url.query.get('split', '').lower() == "true"
132
+ if split_path:
133
+ results = [[x] + x.split(os.sep) for x in results]
134
+
135
+ return web.json_response(results)
136
+
137
+ def get_user_data_path(request, check_exists = False, param = "file"):
138
+ file = request.match_info.get(param, None)
139
+ if not file:
140
+ return web.Response(status=400)
141
+
142
+ path = self.get_request_user_filepath(request, file)
143
+ if not path:
144
+ return web.Response(status=403)
145
+
146
+ if check_exists and not os.path.exists(path):
147
+ return web.Response(status=404)
148
+
149
+ return path
150
+
151
+ @routes.get("/userdata/{file}")
152
+ async def getuserdata(request):
153
+ path = get_user_data_path(request, check_exists=True)
154
+ if not isinstance(path, str):
155
+ return path
156
+
157
+ return web.FileResponse(path)
158
+
159
+ @routes.post("/userdata/{file}")
160
+ async def post_userdata(request):
161
+ path = get_user_data_path(request)
162
+ if not isinstance(path, str):
163
+ return path
164
+
165
+ overwrite = request.query["overwrite"] != "false"
166
+ if not overwrite and os.path.exists(path):
167
+ return web.Response(status=409)
168
+
169
+ body = await request.read()
170
+
171
+ with open(path, "wb") as f:
172
+ f.write(body)
173
+
174
+ resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
175
+ return web.json_response(resp)
176
+
177
+ @routes.delete("/userdata/{file}")
178
+ async def delete_userdata(request):
179
+ path = get_user_data_path(request, check_exists=True)
180
+ if not isinstance(path, str):
181
+ return path
182
+
183
+ os.remove(path)
184
+
185
+ return web.Response(status=204)
186
+
187
+ @routes.post("/userdata/{file}/move/{dest}")
188
+ async def move_userdata(request):
189
+ source = get_user_data_path(request, check_exists=True)
190
+ if not isinstance(source, str):
191
+ return source
192
+
193
+ dest = get_user_data_path(request, check_exists=False, param="dest")
194
+ if not isinstance(dest, str):
195
+ return dest
196
+
197
+ overwrite = request.query["overwrite"] != "false"
198
+ if not overwrite and os.path.exists(dest):
199
+ return web.Response(status=409)
200
+
201
+ print(f"moving '{source}' -> '{dest}'")
202
+ shutil.move(source, dest)
203
+
204
+ resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
205
+ return web.json_response(resp)
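Note (not part of the commit): get_request_user_filepath guards against path traversal by resolving the requested file under the per-user root and rejecting anything whose common path escapes that root. An illustrative, self-contained sketch of that check (directory names are hypothetical):

import os

def is_under(root: str, relative: str) -> bool:
    """Mirror of the commonpath guard: True only if `relative` stays inside `root`."""
    root = os.path.abspath(root)
    candidate = os.path.abspath(os.path.join(root, relative))
    return os.path.commonpath((root, candidate)) == root

print(is_under("/srv/comfy/user/default", "workflows/demo.json"))   # True  -> path is served
print(is_under("/srv/comfy/user/default", "../other_user/secret"))  # False -> handler returns None (HTTP 403)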
Backend/comfy/checkpoint_pickle.py ADDED
@@ -0,0 +1,13 @@
+ import pickle
+
+ load = pickle.load
+
+ class Empty:
+     pass
+
+ class Unpickler(pickle.Unpickler):
+     def find_class(self, module, name):
+         #TODO: safe unpickle
+         if module.startswith("pytorch_lightning"):
+             return Empty
+         return super().find_class(module, name)
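Note (not part of the commit): the Unpickler above substitutes the Empty class for anything under the pytorch_lightning namespace, so legacy Lightning-era .ckpt files can be deserialized without importing Lightning. A hedged sketch of how a checkpoint loader could plug it into torch.load (the checkpoint path is hypothetical):

import torch
import comfy.checkpoint_pickle

# torch.load accepts any module exposing an Unpickler via pickle_module;
# pytorch_lightning classes in the pickle stream come back as Empty placeholders.
state_dict = torch.load("legacy_model.ckpt", map_location="cpu",
                        pickle_module=comfy.checkpoint_pickle)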
Backend/comfy/cldm/cldm.py ADDED
@@ -0,0 +1,437 @@
1
+ #taken from: https://github.com/lllyasviel/ControlNet
2
+ #and modified
3
+
4
+ import torch
5
+ import torch as th
6
+ import torch.nn as nn
7
+
8
+ from ..ldm.modules.diffusionmodules.util import (
9
+ zero_module,
10
+ timestep_embedding,
11
+ )
12
+
13
+ from ..ldm.modules.attention import SpatialTransformer
14
+ from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
15
+ from ..ldm.util import exists
16
+ from .control_types import UNION_CONTROLNET_TYPES
17
+ from collections import OrderedDict
18
+ import comfy.ops
19
+ from comfy.ldm.modules.attention import optimized_attention
20
+
21
+ class OptimizedAttention(nn.Module):
22
+ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
23
+ super().__init__()
24
+ self.heads = nhead
25
+ self.c = c
26
+
27
+ self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
28
+ self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
29
+
30
+ def forward(self, x):
31
+ x = self.in_proj(x)
32
+ q, k, v = x.split(self.c, dim=2)
33
+ out = optimized_attention(q, k, v, self.heads)
34
+ return self.out_proj(out)
35
+
36
+ class QuickGELU(nn.Module):
37
+ def forward(self, x: torch.Tensor):
38
+ return x * torch.sigmoid(1.702 * x)
39
+
40
+ class ResBlockUnionControlnet(nn.Module):
41
+ def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
42
+ super().__init__()
43
+ self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations)
44
+ self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device)
45
+ self.mlp = nn.Sequential(
46
+ OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()),
47
+ ("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))]))
48
+ self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device)
49
+
50
+ def attention(self, x: torch.Tensor):
51
+ return self.attn(x)
52
+
53
+ def forward(self, x: torch.Tensor):
54
+ x = x + self.attention(self.ln_1(x))
55
+ x = x + self.mlp(self.ln_2(x))
56
+ return x
57
+
58
+ class ControlledUnetModel(UNetModel):
59
+ #implemented in the ldm unet
60
+ pass
61
+
62
+ class ControlNet(nn.Module):
63
+ def __init__(
64
+ self,
65
+ image_size,
66
+ in_channels,
67
+ model_channels,
68
+ hint_channels,
69
+ num_res_blocks,
70
+ dropout=0,
71
+ channel_mult=(1, 2, 4, 8),
72
+ conv_resample=True,
73
+ dims=2,
74
+ num_classes=None,
75
+ use_checkpoint=False,
76
+ dtype=torch.float32,
77
+ num_heads=-1,
78
+ num_head_channels=-1,
79
+ num_heads_upsample=-1,
80
+ use_scale_shift_norm=False,
81
+ resblock_updown=False,
82
+ use_new_attention_order=False,
83
+ use_spatial_transformer=False, # custom transformer support
84
+ transformer_depth=1, # custom transformer support
85
+ context_dim=None, # custom transformer support
86
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
87
+ legacy=True,
88
+ disable_self_attentions=None,
89
+ num_attention_blocks=None,
90
+ disable_middle_self_attn=False,
91
+ use_linear_in_transformer=False,
92
+ adm_in_channels=None,
93
+ transformer_depth_middle=None,
94
+ transformer_depth_output=None,
95
+ attn_precision=None,
96
+ union_controlnet_num_control_type=None,
97
+ device=None,
98
+ operations=comfy.ops.disable_weight_init,
99
+ **kwargs,
100
+ ):
101
+ super().__init__()
102
+ assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
103
+ if use_spatial_transformer:
104
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
105
+
106
+ if context_dim is not None:
107
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
108
+ # from omegaconf.listconfig import ListConfig
109
+ # if type(context_dim) == ListConfig:
110
+ # context_dim = list(context_dim)
111
+
112
+ if num_heads_upsample == -1:
113
+ num_heads_upsample = num_heads
114
+
115
+ if num_heads == -1:
116
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
117
+
118
+ if num_head_channels == -1:
119
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
120
+
121
+ self.dims = dims
122
+ self.image_size = image_size
123
+ self.in_channels = in_channels
124
+ self.model_channels = model_channels
125
+
126
+ if isinstance(num_res_blocks, int):
127
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
128
+ else:
129
+ if len(num_res_blocks) != len(channel_mult):
130
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
131
+ "as a list/tuple (per-level) with the same length as channel_mult")
132
+ self.num_res_blocks = num_res_blocks
133
+
134
+ if disable_self_attentions is not None:
135
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
136
+ assert len(disable_self_attentions) == len(channel_mult)
137
+ if num_attention_blocks is not None:
138
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
139
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
140
+
141
+ transformer_depth = transformer_depth[:]
142
+
143
+ self.dropout = dropout
144
+ self.channel_mult = channel_mult
145
+ self.conv_resample = conv_resample
146
+ self.num_classes = num_classes
147
+ self.use_checkpoint = use_checkpoint
148
+ self.dtype = dtype
149
+ self.num_heads = num_heads
150
+ self.num_head_channels = num_head_channels
151
+ self.num_heads_upsample = num_heads_upsample
152
+ self.predict_codebook_ids = n_embed is not None
153
+
154
+ time_embed_dim = model_channels * 4
155
+ self.time_embed = nn.Sequential(
156
+ operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
157
+ nn.SiLU(),
158
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
159
+ )
160
+
161
+ if self.num_classes is not None:
162
+ if isinstance(self.num_classes, int):
163
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
164
+ elif self.num_classes == "continuous":
165
+ print("setting up linear c_adm embedding layer")
166
+ self.label_emb = nn.Linear(1, time_embed_dim)
167
+ elif self.num_classes == "sequential":
168
+ assert adm_in_channels is not None
169
+ self.label_emb = nn.Sequential(
170
+ nn.Sequential(
171
+ operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
172
+ nn.SiLU(),
173
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
174
+ )
175
+ )
176
+ else:
177
+ raise ValueError()
178
+
179
+ self.input_blocks = nn.ModuleList(
180
+ [
181
+ TimestepEmbedSequential(
182
+ operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
183
+ )
184
+ ]
185
+ )
186
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
187
+
188
+ self.input_hint_block = TimestepEmbedSequential(
189
+ operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
190
+ nn.SiLU(),
191
+ operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
192
+ nn.SiLU(),
193
+ operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
194
+ nn.SiLU(),
195
+ operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
196
+ nn.SiLU(),
197
+ operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
198
+ nn.SiLU(),
199
+ operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
200
+ nn.SiLU(),
201
+ operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
202
+ nn.SiLU(),
203
+ operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
204
+ )
205
+
206
+ self._feature_size = model_channels
207
+ input_block_chans = [model_channels]
208
+ ch = model_channels
209
+ ds = 1
210
+ for level, mult in enumerate(channel_mult):
211
+ for nr in range(self.num_res_blocks[level]):
212
+ layers = [
213
+ ResBlock(
214
+ ch,
215
+ time_embed_dim,
216
+ dropout,
217
+ out_channels=mult * model_channels,
218
+ dims=dims,
219
+ use_checkpoint=use_checkpoint,
220
+ use_scale_shift_norm=use_scale_shift_norm,
221
+ dtype=self.dtype,
222
+ device=device,
223
+ operations=operations,
224
+ )
225
+ ]
226
+ ch = mult * model_channels
227
+ num_transformers = transformer_depth.pop(0)
228
+ if num_transformers > 0:
229
+ if num_head_channels == -1:
230
+ dim_head = ch // num_heads
231
+ else:
232
+ num_heads = ch // num_head_channels
233
+ dim_head = num_head_channels
234
+ if legacy:
235
+ #num_heads = 1
236
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
237
+ if exists(disable_self_attentions):
238
+ disabled_sa = disable_self_attentions[level]
239
+ else:
240
+ disabled_sa = False
241
+
242
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
243
+ layers.append(
244
+ SpatialTransformer(
245
+ ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
246
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
247
+ use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
248
+ )
249
+ )
250
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
251
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
252
+ self._feature_size += ch
253
+ input_block_chans.append(ch)
254
+ if level != len(channel_mult) - 1:
255
+ out_ch = ch
256
+ self.input_blocks.append(
257
+ TimestepEmbedSequential(
258
+ ResBlock(
259
+ ch,
260
+ time_embed_dim,
261
+ dropout,
262
+ out_channels=out_ch,
263
+ dims=dims,
264
+ use_checkpoint=use_checkpoint,
265
+ use_scale_shift_norm=use_scale_shift_norm,
266
+ down=True,
267
+ dtype=self.dtype,
268
+ device=device,
269
+ operations=operations
270
+ )
271
+ if resblock_updown
272
+ else Downsample(
273
+ ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
274
+ )
275
+ )
276
+ )
277
+ ch = out_ch
278
+ input_block_chans.append(ch)
279
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
280
+ ds *= 2
281
+ self._feature_size += ch
282
+
283
+ if num_head_channels == -1:
284
+ dim_head = ch // num_heads
285
+ else:
286
+ num_heads = ch // num_head_channels
287
+ dim_head = num_head_channels
288
+ if legacy:
289
+ #num_heads = 1
290
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
291
+ mid_block = [
292
+ ResBlock(
293
+ ch,
294
+ time_embed_dim,
295
+ dropout,
296
+ dims=dims,
297
+ use_checkpoint=use_checkpoint,
298
+ use_scale_shift_norm=use_scale_shift_norm,
299
+ dtype=self.dtype,
300
+ device=device,
301
+ operations=operations
302
+ )]
303
+ if transformer_depth_middle >= 0:
304
+ mid_block += [SpatialTransformer( # always uses a self-attn
305
+ ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
306
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
307
+ use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
308
+ ),
309
+ ResBlock(
310
+ ch,
311
+ time_embed_dim,
312
+ dropout,
313
+ dims=dims,
314
+ use_checkpoint=use_checkpoint,
315
+ use_scale_shift_norm=use_scale_shift_norm,
316
+ dtype=self.dtype,
317
+ device=device,
318
+ operations=operations
319
+ )]
320
+ self.middle_block = TimestepEmbedSequential(*mid_block)
321
+ self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
322
+ self._feature_size += ch
323
+
324
+ if union_controlnet_num_control_type is not None:
325
+ self.num_control_type = union_controlnet_num_control_type
326
+ num_trans_channel = 320
327
+ num_trans_head = 8
328
+ num_trans_layer = 1
329
+ num_proj_channel = 320
330
+ # task_scale_factor = num_trans_channel ** 0.5
331
+ self.task_embedding = nn.Parameter(torch.empty(self.num_control_type, num_trans_channel, dtype=self.dtype, device=device))
332
+
333
+ self.transformer_layes = nn.Sequential(*[ResBlockUnionControlnet(num_trans_channel, num_trans_head, dtype=self.dtype, device=device, operations=operations) for _ in range(num_trans_layer)])
334
+ self.spatial_ch_projs = operations.Linear(num_trans_channel, num_proj_channel, dtype=self.dtype, device=device)
335
+ #-----------------------------------------------------------------------------------------------------
336
+
337
+ control_add_embed_dim = 256
338
+ class ControlAddEmbedding(nn.Module):
339
+ def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations=None):
340
+ super().__init__()
341
+ self.num_control_type = num_control_type
342
+ self.in_dim = in_dim
343
+ self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device)
344
+ self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device)
345
+ def forward(self, control_type, dtype, device):
346
+ c_type = torch.zeros((self.num_control_type,), device=device)
347
+ c_type[control_type] = 1.0
348
+ c_type = timestep_embedding(c_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim))
349
+ return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))
350
+
351
+ self.control_add_embedding = ControlAddEmbedding(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
352
+ else:
353
+ self.task_embedding = None
354
+ self.control_add_embedding = None
355
+
356
+ def union_controlnet_merge(self, hint, control_type, emb, context):
357
+ # Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
358
+ inputs = []
359
+ condition_list = []
360
+
361
+ for idx in range(min(1, len(control_type))):
362
+ controlnet_cond = self.input_hint_block(hint[idx], emb, context)
363
+ feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
364
+ if idx < len(control_type):
365
+ feat_seq += self.task_embedding[control_type[idx]].to(dtype=feat_seq.dtype, device=feat_seq.device)
366
+
367
+ inputs.append(feat_seq.unsqueeze(1))
368
+ condition_list.append(controlnet_cond)
369
+
370
+ x = torch.cat(inputs, dim=1)
371
+ x = self.transformer_layes(x)
372
+ controlnet_cond_fuser = None
373
+ for idx in range(len(control_type)):
374
+ alpha = self.spatial_ch_projs(x[:, idx])
375
+ alpha = alpha.unsqueeze(-1).unsqueeze(-1)
376
+ o = condition_list[idx] + alpha
377
+ if controlnet_cond_fuser is None:
378
+ controlnet_cond_fuser = o
379
+ else:
380
+ controlnet_cond_fuser += o
381
+ return controlnet_cond_fuser
382
+
383
+ def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
384
+ return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
385
+
386
+ def forward(self, x, hint, timesteps, context, y=None, **kwargs):
387
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
388
+ emb = self.time_embed(t_emb)
389
+
390
+ guided_hint = None
391
+ if self.control_add_embedding is not None: #Union Controlnet
392
+ control_type = kwargs.get("control_type", [])
393
+
394
+ if any([c >= self.num_control_type for c in control_type]):
395
+ max_type = max(control_type)
396
+ max_type_name = {
397
+ v: k for k, v in UNION_CONTROLNET_TYPES.items()
398
+ }[max_type]
399
+ raise ValueError(
400
+ f"Control type {max_type_name}({max_type}) is out of range for the number of control types" +
401
+ f"({self.num_control_type}) supported.\n" +
402
+ "Please consider using the ProMax ControlNet Union model.\n" +
403
+ "https://huggingface.co/xinsir/controlnet-union-sdxl-1.0/tree/main"
404
+ )
405
+
406
+ emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
407
+ if len(control_type) > 0:
408
+ if len(hint.shape) < 5:
409
+ hint = hint.unsqueeze(dim=0)
410
+ guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)
411
+
412
+ if guided_hint is None:
413
+ guided_hint = self.input_hint_block(hint, emb, context)
414
+
415
+ out_output = []
416
+ out_middle = []
417
+
418
+ hs = []
419
+ if self.num_classes is not None:
420
+ assert y.shape[0] == x.shape[0]
421
+ emb = emb + self.label_emb(y)
422
+
423
+ h = x
424
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
425
+ if guided_hint is not None:
426
+ h = module(h, emb, context)
427
+ h += guided_hint
428
+ guided_hint = None
429
+ else:
430
+ h = module(h, emb, context)
431
+ out_output.append(zero_conv(h, emb, context))
432
+
433
+ h = self.middle_block(h, emb, context)
434
+ out_middle.append(self.middle_block_out(h, emb, context))
435
+
436
+ return {"middle": out_middle, "output": out_output}
437
+
Backend/comfy/cldm/control_types.py ADDED
@@ -0,0 +1,10 @@
+ UNION_CONTROLNET_TYPES = {
+     "openpose": 0,
+     "depth": 1,
+     "hed/pidi/scribble/ted": 2,
+     "canny/lineart/anime_lineart/mlsd": 3,
+     "normal": 4,
+     "segment": 5,
+     "tile": 6,
+     "repaint": 7,
+ }
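Note (not part of the commit): this table maps each control mode supported by the union ControlNet to the integer id used to index its task embedding; callers pass those ids through the control_type argument seen in cldm.py above, e.g.:

from comfy.cldm.control_types import UNION_CONTROLNET_TYPES

control_type = [UNION_CONTROLNET_TYPES["openpose"], UNION_CONTROLNET_TYPES["depth"]]  # -> [0, 1]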
Backend/comfy/cldm/mmdit.py ADDED
@@ -0,0 +1,77 @@
1
+ import torch
2
+ from typing import Dict, Optional
3
+ import comfy.ldm.modules.diffusionmodules.mmdit
4
+
5
+ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
6
+ def __init__(
7
+ self,
8
+ num_blocks = None,
9
+ dtype = None,
10
+ device = None,
11
+ operations = None,
12
+ **kwargs,
13
+ ):
14
+ super().__init__(dtype=dtype, device=device, operations=operations, final_layer=False, num_blocks=num_blocks, **kwargs)
15
+ # controlnet_blocks
16
+ self.controlnet_blocks = torch.nn.ModuleList([])
17
+ for _ in range(len(self.joint_blocks)):
18
+ self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
19
+
20
+ self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
21
+ None,
22
+ self.patch_size,
23
+ self.in_channels,
24
+ self.hidden_size,
25
+ bias=True,
26
+ strict_img_size=False,
27
+ dtype=dtype,
28
+ device=device,
29
+ operations=operations
30
+ )
31
+
32
+ def forward(
33
+ self,
34
+ x: torch.Tensor,
35
+ timesteps: torch.Tensor,
36
+ y: Optional[torch.Tensor] = None,
37
+ context: Optional[torch.Tensor] = None,
38
+ hint = None,
39
+ ) -> torch.Tensor:
40
+
41
+ #weird sd3 controlnet specific stuff
42
+ y = torch.zeros_like(y)
43
+
44
+ if self.context_processor is not None:
45
+ context = self.context_processor(context)
46
+
47
+ hw = x.shape[-2:]
48
+ x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
49
+ x += self.pos_embed_input(hint)
50
+
51
+ c = self.t_embedder(timesteps, dtype=x.dtype)
52
+ if y is not None and self.y_embedder is not None:
53
+ y = self.y_embedder(y)
54
+ c = c + y
55
+
56
+ if context is not None:
57
+ context = self.context_embedder(context)
58
+
59
+ output = []
60
+
61
+ blocks = len(self.joint_blocks)
62
+ for i in range(blocks):
63
+ context, x = self.joint_blocks[i](
64
+ context,
65
+ x,
66
+ c=c,
67
+ use_checkpoint=self.use_checkpoint,
68
+ )
69
+
70
+ out = self.controlnet_blocks[i](x)
71
+ count = self.depth // blocks
72
+ if i == blocks - 1:
73
+ count -= 1
74
+ for j in range(count):
75
+ output.append(out)
76
+
77
+ return {"output": output}
Backend/comfy/cli_args.py ADDED
@@ -0,0 +1,180 @@
1
+ import argparse
2
+ import enum
3
+ import os
4
+ from typing import Optional
5
+ import comfy.options
6
+
7
+
8
+ class EnumAction(argparse.Action):
9
+ """
10
+ Argparse action for handling Enums
11
+ """
12
+ def __init__(self, **kwargs):
13
+ # Pop off the type value
14
+ enum_type = kwargs.pop("type", None)
15
+
16
+ # Ensure an Enum subclass is provided
17
+ if enum_type is None:
18
+ raise ValueError("type must be assigned an Enum when using EnumAction")
19
+ if not issubclass(enum_type, enum.Enum):
20
+ raise TypeError("type must be an Enum when using EnumAction")
21
+
22
+ # Generate choices from the Enum
23
+ choices = tuple(e.value for e in enum_type)
24
+ kwargs.setdefault("choices", choices)
25
+ kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
26
+
27
+ super(EnumAction, self).__init__(**kwargs)
28
+
29
+ self._enum = enum_type
30
+
31
+ def __call__(self, parser, namespace, values, option_string=None):
32
+ # Convert value back into an Enum
33
+ value = self._enum(values)
34
+ setattr(namespace, self.dest, value)
35
+
36
+
37
+ parser = argparse.ArgumentParser()
38
+
39
+ parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
40
+ parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
41
+ parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
42
+ parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
43
+ parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
44
+ parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
45
+
46
+ parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
47
+ parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
48
+ parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
49
+ parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.")
50
+ parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
51
+ parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
52
+ parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
53
+ cm_group = parser.add_mutually_exclusive_group()
54
+ cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
55
+ cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
56
+
57
+
58
+ fp_group = parser.add_mutually_exclusive_group()
59
+ fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
60
+ fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
61
+
62
+ fpunet_group = parser.add_mutually_exclusive_group()
63
+ fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
64
+ fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
65
+ fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
66
+ fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
67
+
68
+ fpvae_group = parser.add_mutually_exclusive_group()
69
+ fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
70
+ fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
71
+ fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
72
+
73
+ parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
74
+
75
+ fpte_group = parser.add_mutually_exclusive_group()
76
+ fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
77
+ fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
78
+ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
79
+ fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
80
+
81
+ parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
82
+
83
+ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
84
+
85
+ parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
86
+
87
+ class LatentPreviewMethod(enum.Enum):
88
+ NoPreviews = "none"
89
+ Auto = "auto"
90
+ Latent2RGB = "latent2rgb"
91
+ TAESD = "taesd"
92
+
93
+ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
94
+
95
+ attn_group = parser.add_mutually_exclusive_group()
96
+ attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
97
+ attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
98
+ attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
99
+
100
+ parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
101
+
102
+ upcast = parser.add_mutually_exclusive_group()
103
+ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
104
+ upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
105
+
106
+
107
+ vram_group = parser.add_mutually_exclusive_group()
108
+ vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
109
+ vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
110
+ vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
111
+ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
112
+ vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
113
+ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
114
+
115
+ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
116
+
117
+ parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
118
+ parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
119
+
120
+ parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
121
+ parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
122
+ parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
123
+
124
+ parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
125
+ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
126
+
127
+ parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
128
+
129
+ parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
130
+
131
+ # The default built-in provider hosted under web/
132
+ DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
133
+
134
+ parser.add_argument(
135
+ "--front-end-version",
136
+ type=str,
137
+ default=DEFAULT_VERSION_STRING,
138
+ help="""
139
+ Specifies the version of the frontend to be used. This command needs internet connectivity to query and
140
+ download available frontend implementations from GitHub releases.
141
+
142
+ The version string should be in the format of:
143
+ [repoOwner]/[repoName]@[version]
144
+ where version is one of: "latest" or a valid version number (e.g. "1.0.0")
145
+ """,
146
+ )
147
+
148
+ def is_valid_directory(path: Optional[str]) -> Optional[str]:
149
+ """Validate if the given path is a directory."""
150
+ if path is None:
151
+ return None
152
+
153
+ if not os.path.isdir(path):
154
+ raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
155
+ return path
156
+
157
+ parser.add_argument(
158
+ "--front-end-root",
159
+ type=is_valid_directory,
160
+ default=None,
161
+ help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
162
+ )
163
+
164
+ if comfy.options.args_parsing:
165
+ args = parser.parse_args()
166
+ else:
167
+ args = parser.parse_args([])
168
+
169
+ if args.windows_standalone_build:
170
+ args.auto_launch = True
171
+
172
+ if args.disable_auto_launch:
173
+ args.auto_launch = False
174
+
175
+ import logging
176
+ logging_level = logging.INFO
177
+ if args.verbose:
178
+ logging_level = logging.DEBUG
179
+
180
+ logging.basicConfig(format="%(message)s", level=logging_level)
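Note (not part of the commit): EnumAction exposes the enum values as argparse choices and stores the matching enum member back on the namespace. A small sketch, assuming comfy.options.args_parsing is left unset so the import-time parse runs on an empty argv:

from comfy.cli_args import parser, LatentPreviewMethod

ns = parser.parse_args(["--listen", "--port", "8188", "--preview-method", "taesd"])
assert ns.listen == "0.0.0.0"                           # bare --listen falls back to the const value
assert ns.preview_method is LatentPreviewMethod.TAESD   # string choice converted back to the enum member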
Backend/comfy/clip_config_bigg.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "architectures": [
+     "CLIPTextModel"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 0,
+   "dropout": 0.0,
+   "eos_token_id": 49407,
+   "hidden_act": "gelu",
+   "hidden_size": 1280,
+   "initializer_factor": 1.0,
+   "initializer_range": 0.02,
+   "intermediate_size": 5120,
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 77,
+   "model_type": "clip_text_model",
+   "num_attention_heads": 20,
+   "num_hidden_layers": 32,
+   "pad_token_id": 1,
+   "projection_dim": 1280,
+   "torch_dtype": "float32",
+   "vocab_size": 49408
+ }
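Note (not part of the commit): this JSON is the kind of config_dict consumed by CLIPTextModel_ in clip_model.py below (hidden_size, num_attention_heads, num_hidden_layers, ...); a trivial sketch of reading it, with the path given relative to the repository root:

import json

with open("Backend/comfy/clip_config_bigg.json") as f:
    config = json.load(f)
print(config["hidden_size"], config["num_hidden_layers"], config["num_attention_heads"])  # 1280 32 20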
Backend/comfy/clip_model.py ADDED
@@ -0,0 +1,196 @@
1
+ import torch
2
+ from comfy.ldm.modules.attention import optimized_attention_for_device
3
+ import comfy.ops
4
+
5
+ class CLIPAttention(torch.nn.Module):
6
+ def __init__(self, embed_dim, heads, dtype, device, operations):
7
+ super().__init__()
8
+
9
+ self.heads = heads
10
+ self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
11
+ self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
12
+ self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
13
+
14
+ self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
15
+
16
+ def forward(self, x, mask=None, optimized_attention=None):
17
+ q = self.q_proj(x)
18
+ k = self.k_proj(x)
19
+ v = self.v_proj(x)
20
+
21
+ out = optimized_attention(q, k, v, self.heads, mask)
22
+ return self.out_proj(out)
23
+
24
+ ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
25
+ "gelu": torch.nn.functional.gelu,
26
+ }
27
+
28
+ class CLIPMLP(torch.nn.Module):
29
+ def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
30
+ super().__init__()
31
+ self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
32
+ self.activation = ACTIVATIONS[activation]
33
+ self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)
34
+
35
+ def forward(self, x):
36
+ x = self.fc1(x)
37
+ x = self.activation(x)
38
+ x = self.fc2(x)
39
+ return x
40
+
41
+ class CLIPLayer(torch.nn.Module):
42
+ def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
43
+ super().__init__()
44
+ self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
45
+ self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
46
+ self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
47
+ self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)
48
+
49
+ def forward(self, x, mask=None, optimized_attention=None):
50
+ x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
51
+ x += self.mlp(self.layer_norm2(x))
52
+ return x
53
+
54
+
55
+ class CLIPEncoder(torch.nn.Module):
56
+ def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
57
+ super().__init__()
58
+ self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
59
+
60
+ def forward(self, x, mask=None, intermediate_output=None):
61
+ optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
62
+
63
+ if intermediate_output is not None:
64
+ if intermediate_output < 0:
65
+ intermediate_output = len(self.layers) + intermediate_output
66
+
67
+ intermediate = None
68
+ for i, l in enumerate(self.layers):
69
+ x = l(x, mask, optimized_attention)
70
+ if i == intermediate_output:
71
+ intermediate = x.clone()
72
+ return x, intermediate
73
+
74
+ class CLIPEmbeddings(torch.nn.Module):
75
+ def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, operations=None):
76
+ super().__init__()
77
+ self.token_embedding = operations.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
78
+ self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
79
+
80
+ def forward(self, input_tokens, dtype=torch.float32):
81
+ return self.token_embedding(input_tokens, out_dtype=dtype) + comfy.ops.cast_to(self.position_embedding.weight, dtype=dtype, device=input_tokens.device)
82
+
83
+
84
+ class CLIPTextModel_(torch.nn.Module):
85
+ def __init__(self, config_dict, dtype, device, operations):
86
+ num_layers = config_dict["num_hidden_layers"]
87
+ embed_dim = config_dict["hidden_size"]
88
+ heads = config_dict["num_attention_heads"]
89
+ intermediate_size = config_dict["intermediate_size"]
90
+ intermediate_activation = config_dict["hidden_act"]
91
+ self.eos_token_id = config_dict["eos_token_id"]
92
+
93
+ super().__init__()
94
+ self.embeddings = CLIPEmbeddings(embed_dim, dtype=dtype, device=device, operations=operations)
95
+ self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
96
+ self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
97
+
98
+ def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
99
+ x = self.embeddings(input_tokens, dtype=dtype)
100
+ mask = None
101
+ if attention_mask is not None:
102
+ mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
103
+ mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
104
+
105
+ causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
106
+ if mask is not None:
107
+ mask += causal_mask
108
+ else:
109
+ mask = causal_mask
110
+
111
+ x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
112
+ x = self.final_layer_norm(x)
113
+ if i is not None and final_layer_norm_intermediate:
114
+ i = self.final_layer_norm(i)
115
+
116
+ pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
117
+ return x, i, pooled_output
118
+
119
+ class CLIPTextModel(torch.nn.Module):
120
+ def __init__(self, config_dict, dtype, device, operations):
121
+ super().__init__()
122
+ self.num_layers = config_dict["num_hidden_layers"]
123
+ self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
124
+ embed_dim = config_dict["hidden_size"]
125
+ self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
126
+ self.text_projection.weight.copy_(torch.eye(embed_dim))
127
+ self.dtype = dtype
128
+
129
+ def get_input_embeddings(self):
130
+ return self.text_model.embeddings.token_embedding
131
+
132
+ def set_input_embeddings(self, embeddings):
133
+ self.text_model.embeddings.token_embedding = embeddings
134
+
135
+ def forward(self, *args, **kwargs):
136
+ x = self.text_model(*args, **kwargs)
137
+ out = self.text_projection(x[2])
138
+ return (x[0], x[1], out, x[2])
139
+
140
+
141
+ class CLIPVisionEmbeddings(torch.nn.Module):
142
+ def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
143
+ super().__init__()
144
+ self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
145
+
146
+ self.patch_embedding = operations.Conv2d(
147
+ in_channels=num_channels,
148
+ out_channels=embed_dim,
149
+ kernel_size=patch_size,
150
+ stride=patch_size,
151
+ bias=False,
152
+ dtype=dtype,
153
+ device=device
154
+ )
155
+
156
+ num_patches = (image_size // patch_size) ** 2
157
+ num_positions = num_patches + 1
158
+ self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
159
+
160
+ def forward(self, pixel_values):
161
+ embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
162
+ return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
163
+
164
+
165
+ class CLIPVision(torch.nn.Module):
166
+ def __init__(self, config_dict, dtype, device, operations):
167
+ super().__init__()
168
+ num_layers = config_dict["num_hidden_layers"]
169
+ embed_dim = config_dict["hidden_size"]
170
+ heads = config_dict["num_attention_heads"]
171
+ intermediate_size = config_dict["intermediate_size"]
172
+ intermediate_activation = config_dict["hidden_act"]
173
+
174
+ self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations)
175
+ self.pre_layrnorm = operations.LayerNorm(embed_dim)
176
+ self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
177
+ self.post_layernorm = operations.LayerNorm(embed_dim)
178
+
179
+ def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
180
+ x = self.embeddings(pixel_values)
181
+ x = self.pre_layrnorm(x)
182
+ #TODO: attention_mask?
183
+ x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
184
+ pooled_output = self.post_layernorm(x[:, 0, :])
185
+ return x, i, pooled_output
186
+
187
+ class CLIPVisionModelProjection(torch.nn.Module):
188
+ def __init__(self, config_dict, dtype, device, operations):
189
+ super().__init__()
190
+ self.vision_model = CLIPVision(config_dict, dtype, device, operations)
191
+ self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
192
+
193
+ def forward(self, *args, **kwargs):
194
+ x = self.vision_model(*args, **kwargs)
195
+ out = self.visual_projection(x[2])
196
+ return (x[0], x[1], out)
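A minimal sketch of how clip_model.py consumes the clip_config_bigg.json file above, assuming it is run from the Backend/ directory with ComfyUI's dependencies installed and that untrained random weights are acceptable, since the goal is only a shape check.

# Build the bigG text encoder from its JSON config and run a dummy forward pass.
import json
import torch
import comfy.ops
import comfy.clip_model

with open("comfy/clip_config_bigg.json") as f:
    config = json.load(f)

with torch.inference_mode():  # built and run without autograd, mirroring how ComfyUI executes models
    model = comfy.clip_model.CLIPTextModel(config, torch.float32, "cpu", comfy.ops.manual_cast)
    tokens = torch.zeros((1, 77), dtype=torch.long)  # dummy token ids; max_position_embeddings is 77
    hidden, intermediate, projected, pooled = model(tokens, intermediate_output=-2)

print(hidden.shape)     # torch.Size([1, 77, 1280]) for this config
print(projected.shape)  # torch.Size([1, 1280]) after the text projection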
Backend/comfy/clip_vision.py ADDED
@@ -0,0 +1,121 @@
1
+ from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
2
+ import os
3
+ import torch
4
+ import json
5
+ import logging
6
+
7
+ import comfy.ops
8
+ import comfy.model_patcher
9
+ import comfy.model_management
10
+ import comfy.utils
11
+ import comfy.clip_model
12
+
13
+ class Output:
14
+ def __getitem__(self, key):
15
+ return getattr(self, key)
16
+ def __setitem__(self, key, item):
17
+ setattr(self, key, item)
18
+
19
+ def clip_preprocess(image, size=224):
20
+ mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=image.device, dtype=image.dtype)
21
+ std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=image.device, dtype=image.dtype)
22
+ image = image.movedim(-1, 1)
23
+ if not (image.shape[2] == size and image.shape[3] == size):
24
+ scale = (size / min(image.shape[2], image.shape[3]))
25
+ image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
26
+ h = (image.shape[2] - size)//2
27
+ w = (image.shape[3] - size)//2
28
+ image = image[:,:,h:h+size,w:w+size]
29
+ image = torch.clip((255. * image), 0, 255).round() / 255.0
30
+ return (image - mean.view([3,1,1])) / std.view([3,1,1])
31
+
32
+ class ClipVisionModel():
33
+ def __init__(self, json_config):
34
+ with open(json_config) as f:
35
+ config = json.load(f)
36
+
37
+ self.image_size = config.get("image_size", 224)
38
+ self.load_device = comfy.model_management.text_encoder_device()
39
+ offload_device = comfy.model_management.text_encoder_offload_device()
40
+ self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
41
+ self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
42
+ self.model.eval()
43
+
44
+ self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
45
+
46
+ def load_sd(self, sd):
47
+ return self.model.load_state_dict(sd, strict=False)
48
+
49
+ def get_sd(self):
50
+ return self.model.state_dict()
51
+
52
+ def encode_image(self, image):
53
+ comfy.model_management.load_model_gpu(self.patcher)
54
+ pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
55
+ out = self.model(pixel_values=pixel_values, intermediate_output=-2)
56
+
57
+ outputs = Output()
58
+ outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
59
+ outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
60
+ outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
61
+ return outputs
62
+
63
+ def convert_to_transformers(sd, prefix):
64
+ sd_k = sd.keys()
65
+ if "{}transformer.resblocks.0.attn.in_proj_weight".format(prefix) in sd_k:
66
+ keys_to_replace = {
67
+ "{}class_embedding".format(prefix): "vision_model.embeddings.class_embedding",
68
+ "{}conv1.weight".format(prefix): "vision_model.embeddings.patch_embedding.weight",
69
+ "{}positional_embedding".format(prefix): "vision_model.embeddings.position_embedding.weight",
70
+ "{}ln_post.bias".format(prefix): "vision_model.post_layernorm.bias",
71
+ "{}ln_post.weight".format(prefix): "vision_model.post_layernorm.weight",
72
+ "{}ln_pre.bias".format(prefix): "vision_model.pre_layrnorm.bias",
73
+ "{}ln_pre.weight".format(prefix): "vision_model.pre_layrnorm.weight",
74
+ }
75
+
76
+ for x in keys_to_replace:
77
+ if x in sd_k:
78
+ sd[keys_to_replace[x]] = sd.pop(x)
79
+
80
+ if "{}proj".format(prefix) in sd_k:
81
+ sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)
82
+
83
+ sd = transformers_convert(sd, prefix, "vision_model.", 48)
84
+ else:
85
+ replace_prefix = {prefix: ""}
86
+ sd = state_dict_prefix_replace(sd, replace_prefix)
87
+ return sd
88
+
89
+ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
90
+ if convert_keys:
91
+ sd = convert_to_transformers(sd, prefix)
92
+ if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
93
+ json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
94
+ elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
95
+ json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
96
+ elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
97
+ if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
98
+ json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
99
+ else:
100
+ json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
101
+ else:
102
+ return None
103
+
104
+ clip = ClipVisionModel(json_config)
105
+ m, u = clip.load_sd(sd)
106
+ if len(m) > 0:
107
+ logging.warning("missing clip vision: {}".format(m))
108
+ u = set(u)
109
+ keys = list(sd.keys())
110
+ for k in keys:
111
+ if k not in u:
112
+ t = sd.pop(k)
113
+ del t
114
+ return clip
115
+
116
+ def load(ckpt_path):
117
+ sd = load_torch_file(ckpt_path)
118
+ if "visual.transformer.resblocks.0.attn.in_proj_weight" in sd:
119
+ return load_clipvision_from_sd(sd, prefix="visual.", convert_keys=True)
120
+ else:
121
+ return load_clipvision_from_sd(sd)
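A minimal usage sketch for clip_vision.py, assuming a working ComfyUI environment; the checkpoint path is hypothetical, and any ViT-L/H/bigG CLIP vision checkpoint that load() recognises would do.

import torch
import comfy.clip_vision

clip = comfy.clip_vision.load("models/clip_vision/clip_vision_h.safetensors")  # hypothetical path
image = torch.rand((1, 512, 512, 3))  # ComfyUI image layout: [batch, height, width, channels] in 0..1
out = clip.encode_image(image)        # resizing, cropping and normalization happen in clip_preprocess()
print(out["image_embeds"].shape)               # e.g. torch.Size([1, 1024]) for the ViT-H config
print(out["penultimate_hidden_states"].shape)  # hidden states from the second-to-last encoder layer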
Backend/comfy/clip_vision_config_g.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "attention_dropout": 0.0,
3
+ "dropout": 0.0,
4
+ "hidden_act": "gelu",
5
+ "hidden_size": 1664,
6
+ "image_size": 224,
7
+ "initializer_factor": 1.0,
8
+ "initializer_range": 0.02,
9
+ "intermediate_size": 8192,
10
+ "layer_norm_eps": 1e-05,
11
+ "model_type": "clip_vision_model",
12
+ "num_attention_heads": 16,
13
+ "num_channels": 3,
14
+ "num_hidden_layers": 48,
15
+ "patch_size": 14,
16
+ "projection_dim": 1280,
17
+ "torch_dtype": "float32"
18
+ }
Backend/comfy/clip_vision_config_h.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "attention_dropout": 0.0,
3
+ "dropout": 0.0,
4
+ "hidden_act": "gelu",
5
+ "hidden_size": 1280,
6
+ "image_size": 224,
7
+ "initializer_factor": 1.0,
8
+ "initializer_range": 0.02,
9
+ "intermediate_size": 5120,
10
+ "layer_norm_eps": 1e-05,
11
+ "model_type": "clip_vision_model",
12
+ "num_attention_heads": 16,
13
+ "num_channels": 3,
14
+ "num_hidden_layers": 32,
15
+ "patch_size": 14,
16
+ "projection_dim": 1024,
17
+ "torch_dtype": "float32"
18
+ }
Backend/comfy/clip_vision_config_vitl.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "attention_dropout": 0.0,
3
+ "dropout": 0.0,
4
+ "hidden_act": "quick_gelu",
5
+ "hidden_size": 1024,
6
+ "image_size": 224,
7
+ "initializer_factor": 1.0,
8
+ "initializer_range": 0.02,
9
+ "intermediate_size": 4096,
10
+ "layer_norm_eps": 1e-05,
11
+ "model_type": "clip_vision_model",
12
+ "num_attention_heads": 16,
13
+ "num_channels": 3,
14
+ "num_hidden_layers": 24,
15
+ "patch_size": 14,
16
+ "projection_dim": 768,
17
+ "torch_dtype": "float32"
18
+ }
Backend/comfy/clip_vision_config_vitl_336.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "attention_dropout": 0.0,
3
+ "dropout": 0.0,
4
+ "hidden_act": "quick_gelu",
5
+ "hidden_size": 1024,
6
+ "image_size": 336,
7
+ "initializer_factor": 1.0,
8
+ "initializer_range": 0.02,
9
+ "intermediate_size": 4096,
10
+ "layer_norm_eps": 1e-5,
11
+ "model_type": "clip_vision_model",
12
+ "num_attention_heads": 16,
13
+ "num_channels": 3,
14
+ "num_hidden_layers": 24,
15
+ "patch_size": 14,
16
+ "projection_dim": 768,
17
+ "torch_dtype": "float32"
18
+ }
Backend/comfy/conds.py ADDED
@@ -0,0 +1,83 @@
1
+ import torch
2
+ import math
3
+ import comfy.utils
4
+
5
+
6
+ def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
7
+ return abs(a*b) // math.gcd(a, b)
8
+
9
+ class CONDRegular:
10
+ def __init__(self, cond):
11
+ self.cond = cond
12
+
13
+ def _copy_with(self, cond):
14
+ return self.__class__(cond)
15
+
16
+ def process_cond(self, batch_size, device, **kwargs):
17
+ return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
18
+
19
+ def can_concat(self, other):
20
+ if self.cond.shape != other.cond.shape:
21
+ return False
22
+ return True
23
+
24
+ def concat(self, others):
25
+ conds = [self.cond]
26
+ for x in others:
27
+ conds.append(x.cond)
28
+ return torch.cat(conds)
29
+
30
+ class CONDNoiseShape(CONDRegular):
31
+ def process_cond(self, batch_size, device, area, **kwargs):
32
+ data = self.cond
33
+ if area is not None:
34
+ dims = len(area) // 2
35
+ for i in range(dims):
36
+ data = data.narrow(i + 2, area[i + dims], area[i])
37
+
38
+ return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
39
+
40
+
41
+ class CONDCrossAttn(CONDRegular):
42
+ def can_concat(self, other):
43
+ s1 = self.cond.shape
44
+ s2 = other.cond.shape
45
+ if s1 != s2:
46
+ if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
47
+ return False
48
+
49
+ mult_min = lcm(s1[1], s2[1])
50
+ diff = mult_min // min(s1[1], s2[1])
51
+ if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
52
+ return False
53
+ return True
54
+
55
+ def concat(self, others):
56
+ conds = [self.cond]
57
+ crossattn_max_len = self.cond.shape[1]
58
+ for x in others:
59
+ c = x.cond
60
+ crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
61
+ conds.append(c)
62
+
63
+ out = []
64
+ for c in conds:
65
+ if c.shape[1] < crossattn_max_len:
66
+ c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
67
+ out.append(c)
68
+ return torch.cat(out)
69
+
70
+ class CONDConstant(CONDRegular):
71
+ def __init__(self, cond):
72
+ self.cond = cond
73
+
74
+ def process_cond(self, batch_size, device, **kwargs):
75
+ return self._copy_with(self.cond)
76
+
77
+ def can_concat(self, other):
78
+ if self.cond != other.cond:
79
+ return False
80
+ return True
81
+
82
+ def concat(self, others):
83
+ return self.cond
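A minimal sketch of the cross-attention concat rule implemented above: conds with different token counts are tile-padded up to the least common multiple before batching, since repeating the sequence along the token axis does not change the attention result. The tensor sizes are illustrative.

import torch
from comfy.conds import CONDCrossAttn

a = CONDCrossAttn(torch.randn(1, 77, 768))   # e.g. a 77-token prompt embedding
b = CONDCrossAttn(torch.randn(1, 154, 768))  # e.g. a 154-token prompt embedding
print(a.can_concat(b))      # True: lcm(77, 154) = 154 and the padding factor 2 is within the limit of 4
print(a.concat([b]).shape)  # torch.Size([2, 154, 768]): a is repeated twice along the token axis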
Backend/comfy/controlnet.py ADDED
@@ -0,0 +1,622 @@
1
+ import torch
2
+ import math
3
+ import os
4
+ import logging
5
+ import comfy.utils
6
+ import comfy.model_management
7
+ import comfy.model_detection
8
+ import comfy.model_patcher
9
+ import comfy.ops
10
+ import comfy.latent_formats
11
+
12
+ import comfy.cldm.cldm
13
+ import comfy.t2i_adapter.adapter
14
+ import comfy.ldm.cascade.controlnet
15
+ import comfy.cldm.mmdit
16
+
17
+
18
+ def broadcast_image_to(tensor, target_batch_size, batched_number):
19
+ current_batch_size = tensor.shape[0]
20
+ #print(current_batch_size, target_batch_size)
21
+ if current_batch_size == 1:
22
+ return tensor
23
+
24
+ per_batch = target_batch_size // batched_number
25
+ tensor = tensor[:per_batch]
26
+
27
+ if per_batch > tensor.shape[0]:
28
+ tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0)
29
+
30
+ current_batch_size = tensor.shape[0]
31
+ if current_batch_size == target_batch_size:
32
+ return tensor
33
+ else:
34
+ return torch.cat([tensor] * batched_number, dim=0)
35
+
36
+ class ControlBase:
37
+ def __init__(self, device=None):
38
+ self.cond_hint_original = None
39
+ self.cond_hint = None
40
+ self.strength = 1.0
41
+ self.timestep_percent_range = (0.0, 1.0)
42
+ self.latent_format = None
43
+ self.vae = None
44
+ self.global_average_pooling = False
45
+ self.timestep_range = None
46
+ self.compression_ratio = 8
47
+ self.upscale_algorithm = 'nearest-exact'
48
+ self.extra_args = {}
49
+
50
+ if device is None:
51
+ device = comfy.model_management.get_torch_device()
52
+ self.device = device
53
+ self.previous_controlnet = None
54
+
55
+ def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
56
+ self.cond_hint_original = cond_hint
57
+ self.strength = strength
58
+ self.timestep_percent_range = timestep_percent_range
59
+ if self.latent_format is not None:
60
+ self.vae = vae
61
+ return self
62
+
63
+ def pre_run(self, model, percent_to_timestep_function):
64
+ self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1]))
65
+ if self.previous_controlnet is not None:
66
+ self.previous_controlnet.pre_run(model, percent_to_timestep_function)
67
+
68
+ def set_previous_controlnet(self, controlnet):
69
+ self.previous_controlnet = controlnet
70
+ return self
71
+
72
+ def cleanup(self):
73
+ if self.previous_controlnet is not None:
74
+ self.previous_controlnet.cleanup()
75
+ if self.cond_hint is not None:
76
+ del self.cond_hint
77
+ self.cond_hint = None
78
+ self.timestep_range = None
79
+
80
+ def get_models(self):
81
+ out = []
82
+ if self.previous_controlnet is not None:
83
+ out += self.previous_controlnet.get_models()
84
+ return out
85
+
86
+ def copy_to(self, c):
87
+ c.cond_hint_original = self.cond_hint_original
88
+ c.strength = self.strength
89
+ c.timestep_percent_range = self.timestep_percent_range
90
+ c.global_average_pooling = self.global_average_pooling
91
+ c.compression_ratio = self.compression_ratio
92
+ c.upscale_algorithm = self.upscale_algorithm
93
+ c.latent_format = self.latent_format
94
+ c.extra_args = self.extra_args.copy()
95
+ c.vae = self.vae
96
+
97
+ def inference_memory_requirements(self, dtype):
98
+ if self.previous_controlnet is not None:
99
+ return self.previous_controlnet.inference_memory_requirements(dtype)
100
+ return 0
101
+
102
+ def control_merge(self, control, control_prev, output_dtype):
103
+ out = {'input':[], 'middle':[], 'output': []}
104
+
105
+ for key in control:
106
+ control_output = control[key]
107
+ applied_to = set()
108
+ for i in range(len(control_output)):
109
+ x = control_output[i]
110
+ if x is not None:
111
+ if self.global_average_pooling:
112
+ x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
113
+
114
+ if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
115
+ applied_to.add(x)
116
+ x *= self.strength
117
+
118
+ if x.dtype != output_dtype:
119
+ x = x.to(output_dtype)
120
+
121
+ out[key].append(x)
122
+
123
+ if control_prev is not None:
124
+ for x in ['input', 'middle', 'output']:
125
+ o = out[x]
126
+ for i in range(len(control_prev[x])):
127
+ prev_val = control_prev[x][i]
128
+ if i >= len(o):
129
+ o.append(prev_val)
130
+ elif prev_val is not None:
131
+ if o[i] is None:
132
+ o[i] = prev_val
133
+ else:
134
+ if o[i].shape[0] < prev_val.shape[0]:
135
+ o[i] = prev_val + o[i]
136
+ else:
137
+ o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
138
+ return out
139
+
140
+ def set_extra_arg(self, argument, value=None):
141
+ self.extra_args[argument] = value
142
+
143
+
144
+ class ControlNet(ControlBase):
145
+ def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
146
+ super().__init__(device)
147
+ self.control_model = control_model
148
+ self.load_device = load_device
149
+ if control_model is not None:
150
+ self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
151
+
152
+ self.compression_ratio = compression_ratio
153
+ self.global_average_pooling = global_average_pooling
154
+ self.model_sampling_current = None
155
+ self.manual_cast_dtype = manual_cast_dtype
156
+ self.latent_format = latent_format
157
+
158
+ def get_control(self, x_noisy, t, cond, batched_number):
159
+ control_prev = None
160
+ if self.previous_controlnet is not None:
161
+ control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
162
+
163
+ if self.timestep_range is not None:
164
+ if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
165
+ if control_prev is not None:
166
+ return control_prev
167
+ else:
168
+ return None
169
+
170
+ dtype = self.control_model.dtype
171
+ if self.manual_cast_dtype is not None:
172
+ dtype = self.manual_cast_dtype
173
+
174
+ output_dtype = x_noisy.dtype
175
+ if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
176
+ if self.cond_hint is not None:
177
+ del self.cond_hint
178
+ self.cond_hint = None
179
+ compression_ratio = self.compression_ratio
180
+ if self.vae is not None:
181
+ compression_ratio *= self.vae.downscale_ratio
182
+ self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
183
+ if self.vae is not None:
184
+ loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
185
+ self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
186
+ comfy.model_management.load_models_gpu(loaded_models)
187
+ if self.latent_format is not None:
188
+ self.cond_hint = self.latent_format.process_in(self.cond_hint)
189
+ self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
190
+ if x_noisy.shape[0] != self.cond_hint.shape[0]:
191
+ self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
192
+
193
+ context = cond.get('crossattn_controlnet', cond['c_crossattn'])
194
+ extra = self.extra_args.copy()
195
+ for c in ["y", "guidance"]: #TODO
196
+ temp = cond.get(c, None)
197
+ if temp is not None:
198
+ extra[c] = temp.to(dtype)
199
+
200
+ timestep = self.model_sampling_current.timestep(t)
201
+ x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
202
+
203
+ control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
204
+ return self.control_merge(control, control_prev, output_dtype)
205
+
206
+ def copy(self):
207
+ c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
208
+ c.control_model = self.control_model
209
+ c.control_model_wrapped = self.control_model_wrapped
210
+ self.copy_to(c)
211
+ return c
212
+
213
+ def get_models(self):
214
+ out = super().get_models()
215
+ out.append(self.control_model_wrapped)
216
+ return out
217
+
218
+ def pre_run(self, model, percent_to_timestep_function):
219
+ super().pre_run(model, percent_to_timestep_function)
220
+ self.model_sampling_current = model.model_sampling
221
+
222
+ def cleanup(self):
223
+ self.model_sampling_current = None
224
+ super().cleanup()
225
+
226
+ class ControlLoraOps:
227
+ class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
228
+ def __init__(self, in_features: int, out_features: int, bias: bool = True,
229
+ device=None, dtype=None) -> None:
230
+ factory_kwargs = {'device': device, 'dtype': dtype}
231
+ super().__init__()
232
+ self.in_features = in_features
233
+ self.out_features = out_features
234
+ self.weight = None
235
+ self.up = None
236
+ self.down = None
237
+ self.bias = None
238
+
239
+ def forward(self, input):
240
+ weight, bias = comfy.ops.cast_bias_weight(self, input)
241
+ if self.up is not None:
242
+ return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
243
+ else:
244
+ return torch.nn.functional.linear(input, weight, bias)
245
+
246
+ class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
247
+ def __init__(
248
+ self,
249
+ in_channels,
250
+ out_channels,
251
+ kernel_size,
252
+ stride=1,
253
+ padding=0,
254
+ dilation=1,
255
+ groups=1,
256
+ bias=True,
257
+ padding_mode='zeros',
258
+ device=None,
259
+ dtype=None
260
+ ):
261
+ super().__init__()
262
+ self.in_channels = in_channels
263
+ self.out_channels = out_channels
264
+ self.kernel_size = kernel_size
265
+ self.stride = stride
266
+ self.padding = padding
267
+ self.dilation = dilation
268
+ self.transposed = False
269
+ self.output_padding = 0
270
+ self.groups = groups
271
+ self.padding_mode = padding_mode
272
+
273
+ self.weight = None
274
+ self.bias = None
275
+ self.up = None
276
+ self.down = None
277
+
278
+
279
+ def forward(self, input):
280
+ weight, bias = comfy.ops.cast_bias_weight(self, input)
281
+ if self.up is not None:
282
+ return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
283
+ else:
284
+ return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
285
+
286
+
287
+ class ControlLora(ControlNet):
288
+ def __init__(self, control_weights, global_average_pooling=False, device=None):
289
+ ControlBase.__init__(self, device)
290
+ self.control_weights = control_weights
291
+ self.global_average_pooling = global_average_pooling
292
+
293
+ def pre_run(self, model, percent_to_timestep_function):
294
+ super().pre_run(model, percent_to_timestep_function)
295
+ controlnet_config = model.model_config.unet_config.copy()
296
+ controlnet_config.pop("out_channels")
297
+ controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
298
+ self.manual_cast_dtype = model.manual_cast_dtype
299
+ dtype = model.get_dtype()
300
+ if self.manual_cast_dtype is None:
301
+ class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
302
+ pass
303
+ else:
304
+ class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast):
305
+ pass
306
+ dtype = self.manual_cast_dtype
307
+
308
+ controlnet_config["operations"] = control_lora_ops
309
+ controlnet_config["dtype"] = dtype
310
+ self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
311
+ self.control_model.to(comfy.model_management.get_torch_device())
312
+ diffusion_model = model.diffusion_model
313
+ sd = diffusion_model.state_dict()
314
+ cm = self.control_model.state_dict()
315
+
316
+ for k in sd:
317
+ weight = sd[k]
318
+ try:
319
+ comfy.utils.set_attr_param(self.control_model, k, weight)
320
+ except:
321
+ pass
322
+
323
+ for k in self.control_weights:
324
+ if k not in {"lora_controlnet"}:
325
+ comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
326
+
327
+ def copy(self):
328
+ c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
329
+ self.copy_to(c)
330
+ return c
331
+
332
+ def cleanup(self):
333
+ del self.control_model
334
+ self.control_model = None
335
+ super().cleanup()
336
+
337
+ def get_models(self):
338
+ out = ControlBase.get_models(self)
339
+ return out
340
+
341
+ def inference_memory_requirements(self, dtype):
342
+ return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
343
+
344
+ def controlnet_config(sd):
345
+ model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
346
+
347
+ supported_inference_dtypes = model_config.supported_inference_dtypes
348
+
349
+ controlnet_config = model_config.unet_config
350
+ unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
351
+ load_device = comfy.model_management.get_torch_device()
352
+ manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
353
+ if manual_cast_dtype is not None:
354
+ operations = comfy.ops.manual_cast
355
+ else:
356
+ operations = comfy.ops.disable_weight_init
357
+
358
+ return model_config, operations, load_device, unet_dtype, manual_cast_dtype
359
+
360
+ def controlnet_load_state_dict(control_model, sd):
361
+ missing, unexpected = control_model.load_state_dict(sd, strict=False)
362
+
363
+ if len(missing) > 0:
364
+ logging.warning("missing controlnet keys: {}".format(missing))
365
+
366
+ if len(unexpected) > 0:
367
+ logging.debug("unexpected controlnet keys: {}".format(unexpected))
368
+ return control_model
369
+
370
+ def load_controlnet_mmdit(sd):
371
+ new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
372
+ model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
373
+ num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
374
+ for k in sd:
375
+ new_sd[k] = sd[k]
376
+
377
+ control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
378
+ control_model = controlnet_load_state_dict(control_model, new_sd)
379
+
380
+ latent_format = comfy.latent_formats.SD3()
381
+ latent_format.shift_factor = 0 #SD3 controlnet weirdness
382
+ control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
383
+ return control
384
+
385
+
386
+ def load_controlnet(ckpt_path, model=None):
387
+ controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
388
+ if "lora_controlnet" in controlnet_data:
389
+ return ControlLora(controlnet_data)
390
+
391
+ controlnet_config = None
392
+ supported_inference_dtypes = None
393
+
394
+ if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
395
+ controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
396
+ diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
397
+ diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
398
+ diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
399
+
400
+ count = 0
401
+ loop = True
402
+ while loop:
403
+ suffix = [".weight", ".bias"]
404
+ for s in suffix:
405
+ k_in = "controlnet_down_blocks.{}{}".format(count, s)
406
+ k_out = "zero_convs.{}.0{}".format(count, s)
407
+ if k_in not in controlnet_data:
408
+ loop = False
409
+ break
410
+ diffusers_keys[k_in] = k_out
411
+ count += 1
412
+
413
+ count = 0
414
+ loop = True
415
+ while loop:
416
+ suffix = [".weight", ".bias"]
417
+ for s in suffix:
418
+ if count == 0:
419
+ k_in = "controlnet_cond_embedding.conv_in{}".format(s)
420
+ else:
421
+ k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
422
+ k_out = "input_hint_block.{}{}".format(count * 2, s)
423
+ if k_in not in controlnet_data:
424
+ k_in = "controlnet_cond_embedding.conv_out{}".format(s)
425
+ loop = False
426
+ diffusers_keys[k_in] = k_out
427
+ count += 1
428
+
429
+ new_sd = {}
430
+ for k in diffusers_keys:
431
+ if k in controlnet_data:
432
+ new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
433
+
434
+ if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet
435
+ controlnet_config["union_controlnet_num_control_type"] = controlnet_data["task_embedding"].shape[0]
436
+ for k in list(controlnet_data.keys()):
437
+ new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
438
+ new_sd[new_k] = controlnet_data.pop(k)
439
+
440
+ leftover_keys = controlnet_data.keys()
441
+ if len(leftover_keys) > 0:
442
+ logging.warning("leftover keys: {}".format(leftover_keys))
443
+ controlnet_data = new_sd
444
+ elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
445
+ return load_controlnet_mmdit(controlnet_data)
446
+
447
+ pth_key = 'control_model.zero_convs.0.0.weight'
448
+ pth = False
449
+ key = 'zero_convs.0.0.weight'
450
+ if pth_key in controlnet_data:
451
+ pth = True
452
+ key = pth_key
453
+ prefix = "control_model."
454
+ elif key in controlnet_data:
455
+ prefix = ""
456
+ else:
457
+ net = load_t2i_adapter(controlnet_data)
458
+ if net is None:
459
+ logging.error("ERROR: checkpoint does not contain controlnet or t2i adapter data: {}".format(ckpt_path))
460
+ return net
461
+
462
+ if controlnet_config is None:
463
+ model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
464
+ supported_inference_dtypes = model_config.supported_inference_dtypes
465
+ controlnet_config = model_config.unet_config
466
+
467
+ load_device = comfy.model_management.get_torch_device()
468
+ if supported_inference_dtypes is None:
469
+ unet_dtype = comfy.model_management.unet_dtype()
470
+ else:
471
+ unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
472
+
473
+ manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
474
+ if manual_cast_dtype is not None:
475
+ controlnet_config["operations"] = comfy.ops.manual_cast
476
+ controlnet_config["dtype"] = unet_dtype
477
+ controlnet_config.pop("out_channels")
478
+ controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
479
+ control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
480
+
481
+ if pth:
482
+ if 'difference' in controlnet_data:
483
+ if model is not None:
484
+ comfy.model_management.load_models_gpu([model])
485
+ model_sd = model.model_state_dict()
486
+ for x in controlnet_data:
487
+ c_m = "control_model."
488
+ if x.startswith(c_m):
489
+ sd_key = "diffusion_model.{}".format(x[len(c_m):])
490
+ if sd_key in model_sd:
491
+ cd = controlnet_data[x]
492
+ cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
493
+ else:
494
+ logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
495
+
496
+ class WeightsLoader(torch.nn.Module):
497
+ pass
498
+ w = WeightsLoader()
499
+ w.control_model = control_model
500
+ missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
501
+ else:
502
+ missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
503
+
504
+ if len(missing) > 0:
505
+ logging.warning("missing controlnet keys: {}".format(missing))
506
+
507
+ if len(unexpected) > 0:
508
+ logging.debug("unexpected controlnet keys: {}".format(unexpected))
509
+
510
+ global_average_pooling = False
511
+ filename = os.path.splitext(ckpt_path)[0]
512
+ if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
513
+ global_average_pooling = True
514
+
515
+ control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
516
+ return control
517
+
518
+ class T2IAdapter(ControlBase):
519
+ def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
520
+ super().__init__(device)
521
+ self.t2i_model = t2i_model
522
+ self.channels_in = channels_in
523
+ self.control_input = None
524
+ self.compression_ratio = compression_ratio
525
+ self.upscale_algorithm = upscale_algorithm
526
+
527
+ def scale_image_to(self, width, height):
528
+ unshuffle_amount = self.t2i_model.unshuffle_amount
529
+ width = math.ceil(width / unshuffle_amount) * unshuffle_amount
530
+ height = math.ceil(height / unshuffle_amount) * unshuffle_amount
531
+ return width, height
532
+
533
+ def get_control(self, x_noisy, t, cond, batched_number):
534
+ control_prev = None
535
+ if self.previous_controlnet is not None:
536
+ control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
537
+
538
+ if self.timestep_range is not None:
539
+ if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
540
+ if control_prev is not None:
541
+ return control_prev
542
+ else:
543
+ return None
544
+
545
+ if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
546
+ if self.cond_hint is not None:
547
+ del self.cond_hint
548
+ self.control_input = None
549
+ self.cond_hint = None
550
+ width, height = self.scale_image_to(x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio)
551
+ self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, self.upscale_algorithm, "center").float().to(self.device)
552
+ if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
553
+ self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
554
+ if x_noisy.shape[0] != self.cond_hint.shape[0]:
555
+ self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
556
+ if self.control_input is None:
557
+ self.t2i_model.to(x_noisy.dtype)
558
+ self.t2i_model.to(self.device)
559
+ self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
560
+ self.t2i_model.cpu()
561
+
562
+ control_input = {}
563
+ for k in self.control_input:
564
+ control_input[k] = list(map(lambda a: None if a is None else a.clone(), self.control_input[k]))
565
+
566
+ return self.control_merge(control_input, control_prev, x_noisy.dtype)
567
+
568
+ def copy(self):
569
+ c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm)
570
+ self.copy_to(c)
571
+ return c
572
+
573
+ def load_t2i_adapter(t2i_data):
574
+ compression_ratio = 8
575
+ upscale_algorithm = 'nearest-exact'
576
+
577
+ if 'adapter' in t2i_data:
578
+ t2i_data = t2i_data['adapter']
579
+ if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
580
+ prefix_replace = {}
581
+ for i in range(4):
582
+ for j in range(2):
583
+ prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
584
+ prefix_replace["adapter.body.{}.".format(i, j)] = "body.{}.".format(i * 2)
585
+ prefix_replace["adapter."] = ""
586
+ t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
587
+ keys = t2i_data.keys()
588
+
589
+ if "body.0.in_conv.weight" in keys:
590
+ cin = t2i_data['body.0.in_conv.weight'].shape[1]
591
+ model_ad = comfy.t2i_adapter.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
592
+ elif 'conv_in.weight' in keys:
593
+ cin = t2i_data['conv_in.weight'].shape[1]
594
+ channel = t2i_data['conv_in.weight'].shape[0]
595
+ ksize = t2i_data['body.0.block2.weight'].shape[2]
596
+ use_conv = False
597
+ down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys))
598
+ if len(down_opts) > 0:
599
+ use_conv = True
600
+ xl = False
601
+ if cin == 256 or cin == 768:
602
+ xl = True
603
+ model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
604
+ elif "backbone.0.0.weight" in keys:
605
+ model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
606
+ compression_ratio = 32
607
+ upscale_algorithm = 'bilinear'
608
+ elif "backbone.10.blocks.0.weight" in keys:
609
+ model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
610
+ compression_ratio = 1
611
+ upscale_algorithm = 'nearest-exact'
612
+ else:
613
+ return None
614
+
615
+ missing, unexpected = model_ad.load_state_dict(t2i_data)
616
+ if len(missing) > 0:
617
+ logging.warning("t2i missing {}".format(missing))
618
+
619
+ if len(unexpected) > 0:
620
+ logging.debug("t2i unexpected {}".format(unexpected))
621
+
622
+ return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)
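A minimal sketch of wiring the classes above together, assuming a working ComfyUI environment; the checkpoint path, hint image and strength are illustrative, and it is the sampler, not this snippet, that calls get_control() every step.

import torch
import comfy.controlnet

control = comfy.controlnet.load_controlnet("models/controlnet/control_canny.safetensors")  # hypothetical file
hint = torch.rand((1, 512, 512, 3))  # preprocessed hint image, [batch, height, width, channels] in 0..1
control = control.set_cond_hint(hint, strength=0.8, timestep_percent_range=(0.0, 1.0))
# During sampling, control.get_control(x_noisy, t, cond, batched_number) returns per-block residuals
# that control_merge() folds into the UNet's input/middle/output features.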
Backend/comfy/diffusers_convert.py ADDED
@@ -0,0 +1,281 @@
1
+ import re
2
+ import torch
3
+ import logging
4
+
5
+ # conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
6
+
7
+ # =================#
8
+ # UNet Conversion #
9
+ # =================#
10
+
11
+ unet_conversion_map = [
12
+ # (stable-diffusion, HF Diffusers)
13
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
14
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
15
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
16
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
17
+ ("input_blocks.0.0.weight", "conv_in.weight"),
18
+ ("input_blocks.0.0.bias", "conv_in.bias"),
19
+ ("out.0.weight", "conv_norm_out.weight"),
20
+ ("out.0.bias", "conv_norm_out.bias"),
21
+ ("out.2.weight", "conv_out.weight"),
22
+ ("out.2.bias", "conv_out.bias"),
23
+ ]
24
+
25
+ unet_conversion_map_resnet = [
26
+ # (stable-diffusion, HF Diffusers)
27
+ ("in_layers.0", "norm1"),
28
+ ("in_layers.2", "conv1"),
29
+ ("out_layers.0", "norm2"),
30
+ ("out_layers.3", "conv2"),
31
+ ("emb_layers.1", "time_emb_proj"),
32
+ ("skip_connection", "conv_shortcut"),
33
+ ]
34
+
35
+ unet_conversion_map_layer = []
36
+ # hardcoded number of downblocks and resnets/attentions...
37
+ # would need smarter logic for other networks.
38
+ for i in range(4):
39
+ # loop over downblocks/upblocks
40
+
41
+ for j in range(2):
42
+ # loop over resnets/attentions for downblocks
43
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
44
+ sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
45
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
46
+
47
+ if i < 3:
48
+ # no attention layers in down_blocks.3
49
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
50
+ sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
51
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
52
+
53
+ for j in range(3):
54
+ # loop over resnets/attentions for upblocks
55
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
56
+ sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
57
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
58
+
59
+ if i > 0:
60
+ # no attention layers in up_blocks.0
61
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
62
+ sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
63
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
64
+
65
+ if i < 3:
66
+ # no downsample in down_blocks.3
67
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
68
+ sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
69
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
70
+
71
+ # no upsample in up_blocks.3
72
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
73
+ sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
74
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
75
+
76
+ hf_mid_atn_prefix = "mid_block.attentions.0."
77
+ sd_mid_atn_prefix = "middle_block.1."
78
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
79
+
80
+ for j in range(2):
81
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
82
+ sd_mid_res_prefix = f"middle_block.{2 * j}."
83
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
84
+
85
+
86
+ def convert_unet_state_dict(unet_state_dict):
87
+ # buyer beware: this is a *brittle* function,
88
+ # and correct output requires that all of these pieces interact in
89
+ # the exact order in which I have arranged them.
90
+ mapping = {k: k for k in unet_state_dict.keys()}
91
+ for sd_name, hf_name in unet_conversion_map:
92
+ mapping[hf_name] = sd_name
93
+ for k, v in mapping.items():
94
+ if "resnets" in k:
95
+ for sd_part, hf_part in unet_conversion_map_resnet:
96
+ v = v.replace(hf_part, sd_part)
97
+ mapping[k] = v
98
+ for k, v in mapping.items():
99
+ for sd_part, hf_part in unet_conversion_map_layer:
100
+ v = v.replace(hf_part, sd_part)
101
+ mapping[k] = v
102
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
103
+ return new_state_dict
104
+
105
+
106
+ # ================#
107
+ # VAE Conversion #
108
+ # ================#
109
+
110
+ vae_conversion_map = [
111
+ # (stable-diffusion, HF Diffusers)
112
+ ("nin_shortcut", "conv_shortcut"),
113
+ ("norm_out", "conv_norm_out"),
114
+ ("mid.attn_1.", "mid_block.attentions.0."),
115
+ ]
116
+
117
+ for i in range(4):
118
+ # down_blocks have two resnets
119
+ for j in range(2):
120
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
121
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
122
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
123
+
124
+ if i < 3:
125
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
126
+ sd_downsample_prefix = f"down.{i}.downsample."
127
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
128
+
129
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
130
+ sd_upsample_prefix = f"up.{3 - i}.upsample."
131
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
132
+
133
+ # up_blocks have three resnets
134
+ # also, up blocks in hf are numbered in reverse from sd
135
+ for j in range(3):
136
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
137
+ sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
138
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
139
+
140
+ # this part accounts for mid blocks in both the encoder and the decoder
141
+ for i in range(2):
142
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
143
+ sd_mid_res_prefix = f"mid.block_{i + 1}."
144
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
145
+
146
+ vae_conversion_map_attn = [
147
+ # (stable-diffusion, HF Diffusers)
148
+ ("norm.", "group_norm."),
149
+ ("q.", "query."),
150
+ ("k.", "key."),
151
+ ("v.", "value."),
152
+ ("q.", "to_q."),
153
+ ("k.", "to_k."),
154
+ ("v.", "to_v."),
155
+ ("proj_out.", "to_out.0."),
156
+ ("proj_out.", "proj_attn."),
157
+ ]
158
+
159
+
160
+ def reshape_weight_for_sd(w):
161
+ # convert HF linear weights to SD conv2d weights
162
+ return w.reshape(*w.shape, 1, 1)
163
+
164
+
165
+ def convert_vae_state_dict(vae_state_dict):
166
+ mapping = {k: k for k in vae_state_dict.keys()}
167
+ for k, v in mapping.items():
168
+ for sd_part, hf_part in vae_conversion_map:
169
+ v = v.replace(hf_part, sd_part)
170
+ mapping[k] = v
171
+ for k, v in mapping.items():
172
+ if "attentions" in k:
173
+ for sd_part, hf_part in vae_conversion_map_attn:
174
+ v = v.replace(hf_part, sd_part)
175
+ mapping[k] = v
176
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
177
+ weights_to_convert = ["q", "k", "v", "proj_out"]
178
+ for k, v in new_state_dict.items():
179
+ for weight_name in weights_to_convert:
180
+ if f"mid.attn_1.{weight_name}.weight" in k:
181
+ logging.debug(f"Reshaping {k} for SD format")
182
+ new_state_dict[k] = reshape_weight_for_sd(v)
183
+ return new_state_dict
184
+
185
+
186
+ # =========================#
187
+ # Text Encoder Conversion #
188
+ # =========================#
189
+
190
+
191
+ textenc_conversion_lst = [
192
+ # (stable-diffusion, HF Diffusers)
193
+ ("resblocks.", "text_model.encoder.layers."),
194
+ ("ln_1", "layer_norm1"),
195
+ ("ln_2", "layer_norm2"),
196
+ (".c_fc.", ".fc1."),
197
+ (".c_proj.", ".fc2."),
198
+ (".attn", ".self_attn"),
199
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
200
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
201
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
202
+ ]
203
+ protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
204
+ textenc_pattern = re.compile("|".join(protected.keys()))
205
+
206
+ # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
207
+ code2idx = {"q": 0, "k": 1, "v": 2}
208
+
209
+ # This function exists because at the time of writing torch.cat can't do fp8 with cuda
210
+ def cat_tensors(tensors):
211
+ x = 0
212
+ for t in tensors:
213
+ x += t.shape[0]
214
+
215
+ shape = [x] + list(tensors[0].shape)[1:]
216
+ out = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype)
217
+
218
+ x = 0
219
+ for t in tensors:
220
+ out[x:x + t.shape[0]] = t
221
+ x += t.shape[0]
222
+
223
+ return out
224
+
225
+ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
226
+ new_state_dict = {}
227
+ capture_qkv_weight = {}
228
+ capture_qkv_bias = {}
229
+ for k, v in text_enc_dict.items():
230
+ if not k.startswith(prefix):
231
+ continue
232
+ if (
233
+ k.endswith(".self_attn.q_proj.weight")
234
+ or k.endswith(".self_attn.k_proj.weight")
235
+ or k.endswith(".self_attn.v_proj.weight")
236
+ ):
237
+ k_pre = k[: -len(".q_proj.weight")]
238
+ k_code = k[-len("q_proj.weight")]
239
+ if k_pre not in capture_qkv_weight:
240
+ capture_qkv_weight[k_pre] = [None, None, None]
241
+ capture_qkv_weight[k_pre][code2idx[k_code]] = v
242
+ continue
243
+
244
+ if (
245
+ k.endswith(".self_attn.q_proj.bias")
246
+ or k.endswith(".self_attn.k_proj.bias")
247
+ or k.endswith(".self_attn.v_proj.bias")
248
+ ):
249
+ k_pre = k[: -len(".q_proj.bias")]
250
+ k_code = k[-len("q_proj.bias")]
251
+ if k_pre not in capture_qkv_bias:
252
+ capture_qkv_bias[k_pre] = [None, None, None]
253
+ capture_qkv_bias[k_pre][code2idx[k_code]] = v
254
+ continue
255
+
256
+ text_proj = "transformer.text_projection.weight"
257
+ if k.endswith(text_proj):
258
+ new_state_dict[k.replace(text_proj, "text_projection")] = v.transpose(0, 1).contiguous()
259
+ else:
260
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
261
+ new_state_dict[relabelled_key] = v
262
+
263
+ for k_pre, tensors in capture_qkv_weight.items():
264
+ if None in tensors:
265
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
266
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
267
+ new_state_dict[relabelled_key + ".in_proj_weight"] = cat_tensors(tensors)
268
+
269
+ for k_pre, tensors in capture_qkv_bias.items():
270
+ if None in tensors:
271
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
272
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
273
+ new_state_dict[relabelled_key + ".in_proj_bias"] = cat_tensors(tensors)
274
+
275
+ return new_state_dict
276
+
277
+
278
+ def convert_text_enc_state_dict(text_enc_dict):
279
+ return text_enc_dict
280
+
281
+
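A minimal sketch of the UNet key remapping above, assuming a Diffusers-format UNet file is available; the input path is hypothetical, and comfy.utils.load_torch_file is the loader used elsewhere in this upload.

import comfy.utils
import comfy.diffusers_convert

diffusers_sd = comfy.utils.load_torch_file("unet/diffusion_pytorch_model.safetensors")  # hypothetical path
sd_layout = comfy.diffusers_convert.convert_unet_state_dict(diffusers_sd)
print(len(diffusers_sd), "tensors in,", len(sd_layout), "tensors out")  # same tensors, renamed keys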
Backend/comfy/diffusers_load.py ADDED
@@ -0,0 +1,36 @@
+ import os
+
+ import comfy.sd
+ import comfy.utils  # load_torch_file below needs this imported explicitly
+
+ def first_file(path, filenames):
+     for f in filenames:
+         p = os.path.join(path, f)
+         if os.path.exists(p):
+             return p
+     return None
+
+ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None):
+     diffusion_model_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"]
+     unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names)
+     vae_path = first_file(os.path.join(model_path, "vae"), diffusion_model_names)
+
+     text_encoder_model_names = ["model.fp16.safetensors", "model.safetensors", "pytorch_model.fp16.bin", "pytorch_model.bin"]
+     text_encoder1_path = first_file(os.path.join(model_path, "text_encoder"), text_encoder_model_names)
+     text_encoder2_path = first_file(os.path.join(model_path, "text_encoder_2"), text_encoder_model_names)
+
+     text_encoder_paths = [text_encoder1_path]
+     if text_encoder2_path is not None:
+         text_encoder_paths.append(text_encoder2_path)
+
+     unet = comfy.sd.load_unet(unet_path)
+
+     clip = None
+     if output_clip:
+         clip = comfy.sd.load_clip(text_encoder_paths, embedding_directory=embedding_directory)
+
+     vae = None
+     if output_vae:
+         sd = comfy.utils.load_torch_file(vae_path)
+         vae = comfy.sd.VAE(sd=sd)
+
+     return (unet, clip, vae)
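A usage sketch for load_diffusers, assuming a local diffusers-format checkpoint folder with unet/, vae/ and text_encoder/ subdirectories; the path below is a placeholder, not a file from this upload.

    # Sketch: turn a diffusers-format checkpoint folder into ComfyUI model objects.
    from comfy.diffusers_load import load_diffusers

    unet, clip, vae = load_diffusers(
        "path/to/diffusers_checkpoint",   # placeholder directory
        output_vae=True,
        output_clip=True,
        embedding_directory=None,
    )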
Backend/comfy/extra_samplers/uni_pc.py ADDED
@@ -0,0 +1,875 @@
1
+ #code taken from: https://github.com/wl-zhao/UniPC and modified
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import math
6
+
7
+ from tqdm.auto import trange, tqdm
8
+
9
+
10
+ class NoiseScheduleVP:
11
+ def __init__(
12
+ self,
13
+ schedule='discrete',
14
+ betas=None,
15
+ alphas_cumprod=None,
16
+ continuous_beta_0=0.1,
17
+ continuous_beta_1=20.,
18
+ ):
19
+ """Create a wrapper class for the forward SDE (VP type).
20
+
21
+ ***
22
+ Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
23
+ We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
24
+ ***
25
+
26
+ The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
27
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
28
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
29
+
30
+ log_alpha_t = self.marginal_log_mean_coeff(t)
31
+ sigma_t = self.marginal_std(t)
32
+ lambda_t = self.marginal_lambda(t)
33
+
34
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
35
+
36
+ t = self.inverse_lambda(lambda_t)
37
+
38
+ ===============================================================
39
+
40
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
41
+
42
+ 1. For discrete-time DPMs:
43
+
44
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
45
+ t_i = (i + 1) / N
46
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
47
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
48
+
49
+ Args:
50
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
51
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
52
+
53
+ Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
54
+
55
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
56
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
57
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
58
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
59
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
60
+ and
61
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
62
+
63
+
64
+ 2. For continuous-time DPMs:
65
+
66
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
67
+ schedule are the default settings in DDPM and improved-DDPM:
68
+
69
+ Args:
70
+ beta_min: A `float` number. The smallest beta for the linear schedule.
71
+ beta_max: A `float` number. The largest beta for the linear schedule.
72
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
73
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
74
+ T: A `float` number. The ending time of the forward process.
75
+
76
+ ===============================================================
77
+
78
+ Args:
79
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
80
+ 'linear' or 'cosine' for continuous-time DPMs.
81
+ Returns:
82
+ A wrapper object of the forward SDE (VP type).
83
+
84
+ ===============================================================
85
+
86
+ Example:
87
+
88
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
89
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
90
+
91
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
92
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
93
+
94
+ # For continuous-time DPMs (VPSDE), linear schedule:
95
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
96
+
97
+ """
98
+
99
+ if schedule not in ['discrete', 'linear', 'cosine']:
100
+ raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
101
+
102
+ self.schedule = schedule
103
+ if schedule == 'discrete':
104
+ if betas is not None:
105
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
106
+ else:
107
+ assert alphas_cumprod is not None
108
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
109
+ self.total_N = len(log_alphas)
110
+ self.T = 1.
111
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
112
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
113
+ else:
114
+ self.total_N = 1000
115
+ self.beta_0 = continuous_beta_0
116
+ self.beta_1 = continuous_beta_1
117
+ self.cosine_s = 0.008
118
+ self.cosine_beta_max = 999.
119
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
120
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
121
+ self.schedule = schedule
122
+ if schedule == 'cosine':
123
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
124
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
125
+ self.T = 0.9946
126
+ else:
127
+ self.T = 1.
128
+
129
+ def marginal_log_mean_coeff(self, t):
130
+ """
131
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
132
+ """
133
+ if self.schedule == 'discrete':
134
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
135
+ elif self.schedule == 'linear':
136
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
137
+ elif self.schedule == 'cosine':
138
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
139
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
140
+ return log_alpha_t
141
+
142
+ def marginal_alpha(self, t):
143
+ """
144
+ Compute alpha_t of a given continuous-time label t in [0, T].
145
+ """
146
+ return torch.exp(self.marginal_log_mean_coeff(t))
147
+
148
+ def marginal_std(self, t):
149
+ """
150
+ Compute sigma_t of a given continuous-time label t in [0, T].
151
+ """
152
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
153
+
154
+ def marginal_lambda(self, t):
155
+ """
156
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
157
+ """
158
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
159
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
160
+ return log_mean_coeff - log_std
161
+
162
+ def inverse_lambda(self, lamb):
163
+ """
164
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
165
+ """
166
+ if self.schedule == 'linear':
167
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
168
+ Delta = self.beta_0**2 + tmp
169
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
170
+ elif self.schedule == 'discrete':
171
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
172
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
173
+ return t.reshape((-1,))
174
+ else:
175
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
176
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
177
+ t = t_fn(log_alpha)
178
+ return t
179
+
180
+
181
+ def model_wrapper(
182
+ model,
183
+ noise_schedule,
184
+ model_type="noise",
185
+ model_kwargs={},
186
+ guidance_type="uncond",
187
+ condition=None,
188
+ unconditional_condition=None,
189
+ guidance_scale=1.,
190
+ classifier_fn=None,
191
+ classifier_kwargs={},
192
+ ):
193
+ """Create a wrapper function for the noise prediction model.
194
+
195
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
196
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
197
+
198
+ We support four types of the diffusion model by setting `model_type`:
199
+
200
+ 1. "noise": noise prediction model. (Trained by predicting noise).
201
+
202
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
203
+
204
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
205
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
206
+
207
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
208
+ arXiv preprint arXiv:2202.00512 (2022).
209
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
210
+ arXiv preprint arXiv:2210.02303 (2022).
211
+
212
+ 4. "score": marginal score function. (Trained by denoising score matching).
213
+ Note that the score function and the noise prediction model follow a simple relationship:
214
+ ```
215
+ noise(x_t, t) = -sigma_t * score(x_t, t)
216
+ ```
217
+
218
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
219
+ 1. "uncond": unconditional sampling by DPMs.
220
+ The input `model` has the following format:
221
+ ``
222
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
223
+ ``
224
+
225
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
226
+ The input `model` has the following format:
227
+ ``
228
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
229
+ ``
230
+
231
+ The input `classifier_fn` has the following format:
232
+ ``
233
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
234
+ ``
235
+
236
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
237
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
238
+
239
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
240
+ The input `model` has the following format:
241
+ ``
242
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
243
+ ``
244
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
245
+
246
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
247
+ arXiv preprint arXiv:2207.12598 (2022).
248
+
249
+
250
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
251
+ or continuous-time labels (i.e. epsilon to T).
252
+
253
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
254
+ ``
255
+ def model_fn(x, t_continuous) -> noise:
256
+ t_input = get_model_input_time(t_continuous)
257
+ return noise_pred(model, x, t_input, **model_kwargs)
258
+ ``
259
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
260
+
261
+ ===============================================================
262
+
263
+ Args:
264
+ model: A diffusion model with the corresponding format described above.
265
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
266
+ model_type: A `str`. The parameterization type of the diffusion model.
267
+ "noise" or "x_start" or "v" or "score".
268
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
269
+ guidance_type: A `str`. The type of the guidance for sampling.
270
+ "uncond" or "classifier" or "classifier-free".
271
+ condition: A pytorch tensor. The condition for the guided sampling.
272
+ Only used for "classifier" or "classifier-free" guidance type.
273
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
274
+ Only used for "classifier-free" guidance type.
275
+ guidance_scale: A `float`. The scale for the guided sampling.
276
+ classifier_fn: A classifier function. Only used for the classifier guidance.
277
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
278
+ Returns:
279
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
280
+ """
281
+
282
+ def get_model_input_time(t_continuous):
283
+ """
284
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
285
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
286
+ For continuous-time DPMs, we just use `t_continuous`.
287
+ """
288
+ if noise_schedule.schedule == 'discrete':
289
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
290
+ else:
291
+ return t_continuous
292
+
293
+ def noise_pred_fn(x, t_continuous, cond=None):
294
+ if t_continuous.reshape((-1,)).shape[0] == 1:
295
+ t_continuous = t_continuous.expand((x.shape[0]))
296
+ t_input = get_model_input_time(t_continuous)
297
+ output = model(x, t_input, **model_kwargs)
298
+ if model_type == "noise":
299
+ return output
300
+ elif model_type == "x_start":
301
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
302
+ dims = x.dim()
303
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
304
+ elif model_type == "v":
305
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
306
+ dims = x.dim()
307
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
308
+ elif model_type == "score":
309
+ sigma_t = noise_schedule.marginal_std(t_continuous)
310
+ dims = x.dim()
311
+ return -expand_dims(sigma_t, dims) * output
312
+
313
+ def cond_grad_fn(x, t_input):
314
+ """
315
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
316
+ """
317
+ with torch.enable_grad():
318
+ x_in = x.detach().requires_grad_(True)
319
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
320
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
321
+
322
+ def model_fn(x, t_continuous):
323
+ """
324
+ The noise prediction model function that is used for DPM-Solver.
325
+ """
326
+ if t_continuous.reshape((-1,)).shape[0] == 1:
327
+ t_continuous = t_continuous.expand((x.shape[0]))
328
+ if guidance_type == "uncond":
329
+ return noise_pred_fn(x, t_continuous)
330
+ elif guidance_type == "classifier":
331
+ assert classifier_fn is not None
332
+ t_input = get_model_input_time(t_continuous)
333
+ cond_grad = cond_grad_fn(x, t_input)
334
+ sigma_t = noise_schedule.marginal_std(t_continuous)
335
+ noise = noise_pred_fn(x, t_continuous)
336
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
337
+ elif guidance_type == "classifier-free":
338
+ if guidance_scale == 1. or unconditional_condition is None:
339
+ return noise_pred_fn(x, t_continuous, cond=condition)
340
+ else:
341
+ x_in = torch.cat([x] * 2)
342
+ t_in = torch.cat([t_continuous] * 2)
343
+ c_in = torch.cat([unconditional_condition, condition])
344
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
345
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
346
+
347
+ assert model_type in ["noise", "x_start", "v"]
348
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
349
+ return model_fn
350
+
351
+
352
+ class UniPC:
353
+ def __init__(
354
+ self,
355
+ model_fn,
356
+ noise_schedule,
357
+ predict_x0=True,
358
+ thresholding=False,
359
+ max_val=1.,
360
+ variant='bh1',
361
+ ):
362
+ """Construct a UniPC.
363
+
364
+ We support both data_prediction and noise_prediction.
365
+ """
366
+ self.model = model_fn
367
+ self.noise_schedule = noise_schedule
368
+ self.variant = variant
369
+ self.predict_x0 = predict_x0
370
+ self.thresholding = thresholding
371
+ self.max_val = max_val
372
+
373
+ def dynamic_thresholding_fn(self, x0, t=None):
374
+ """
375
+ The dynamic thresholding method.
376
+ """
377
+ dims = x0.dim()
378
+ p = self.dynamic_thresholding_ratio
379
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
380
+ s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
381
+ x0 = torch.clamp(x0, -s, s) / s
382
+ return x0
383
+
384
+ def noise_prediction_fn(self, x, t):
385
+ """
386
+ Return the noise prediction model.
387
+ """
388
+ return self.model(x, t)
389
+
390
+ def data_prediction_fn(self, x, t):
391
+ """
392
+ Return the data prediction model (with thresholding).
393
+ """
394
+ noise = self.noise_prediction_fn(x, t)
395
+ dims = x.dim()
396
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
397
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
398
+ if self.thresholding:
399
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
400
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
401
+ s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
402
+ x0 = torch.clamp(x0, -s, s) / s
403
+ return x0
404
+
405
+ def model_fn(self, x, t):
406
+ """
407
+ Convert the model to the noise prediction model or the data prediction model.
408
+ """
409
+ if self.predict_x0:
410
+ return self.data_prediction_fn(x, t)
411
+ else:
412
+ return self.noise_prediction_fn(x, t)
413
+
414
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
415
+ """Compute the intermediate time steps for sampling.
416
+ """
417
+ if skip_type == 'logSNR':
418
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
419
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
420
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
421
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
422
+ elif skip_type == 'time_uniform':
423
+ return torch.linspace(t_T, t_0, N + 1).to(device)
424
+ elif skip_type == 'time_quadratic':
425
+ t_order = 2
426
+ t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
427
+ return t
428
+ else:
429
+ raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
430
+
431
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
432
+ """
433
+ Get the order of each step for sampling by the singlestep DPM-Solver.
434
+ """
435
+ if order == 3:
436
+ K = steps // 3 + 1
437
+ if steps % 3 == 0:
438
+ orders = [3,] * (K - 2) + [2, 1]
439
+ elif steps % 3 == 1:
440
+ orders = [3,] * (K - 1) + [1]
441
+ else:
442
+ orders = [3,] * (K - 1) + [2]
443
+ elif order == 2:
444
+ if steps % 2 == 0:
445
+ K = steps // 2
446
+ orders = [2,] * K
447
+ else:
448
+ K = steps // 2 + 1
449
+ orders = [2,] * (K - 1) + [1]
450
+ elif order == 1:
451
+ K = steps
452
+ orders = [1,] * steps
453
+ else:
454
+ raise ValueError("'order' must be '1' or '2' or '3'.")
455
+ if skip_type == 'logSNR':
456
+ # To reproduce the results in DPM-Solver paper
457
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
458
+ else:
459
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
460
+ return timesteps_outer, orders
461
+
462
+ def denoise_to_zero_fn(self, x, s):
463
+ """
464
+ Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
465
+ """
466
+ return self.data_prediction_fn(x, s)
467
+
468
+ def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
469
+ if len(t.shape) == 0:
470
+ t = t.view(-1)
471
+ if 'bh' in self.variant:
472
+ return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
473
+ else:
474
+ assert self.variant == 'vary_coeff'
475
+ return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
476
+
477
+ def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
478
+ print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
479
+ ns = self.noise_schedule
480
+ assert order <= len(model_prev_list)
481
+
482
+ # first compute rks
483
+ t_prev_0 = t_prev_list[-1]
484
+ lambda_prev_0 = ns.marginal_lambda(t_prev_0)
485
+ lambda_t = ns.marginal_lambda(t)
486
+ model_prev_0 = model_prev_list[-1]
487
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
488
+ log_alpha_t = ns.marginal_log_mean_coeff(t)
489
+ alpha_t = torch.exp(log_alpha_t)
490
+
491
+ h = lambda_t - lambda_prev_0
492
+
493
+ rks = []
494
+ D1s = []
495
+ for i in range(1, order):
496
+ t_prev_i = t_prev_list[-(i + 1)]
497
+ model_prev_i = model_prev_list[-(i + 1)]
498
+ lambda_prev_i = ns.marginal_lambda(t_prev_i)
499
+ rk = (lambda_prev_i - lambda_prev_0) / h
500
+ rks.append(rk)
501
+ D1s.append((model_prev_i - model_prev_0) / rk)
502
+
503
+ rks.append(1.)
504
+ rks = torch.tensor(rks, device=x.device)
505
+
506
+ K = len(rks)
507
+ # build C matrix
508
+ C = []
509
+
510
+ col = torch.ones_like(rks)
511
+ for k in range(1, K + 1):
512
+ C.append(col)
513
+ col = col * rks / (k + 1)
514
+ C = torch.stack(C, dim=1)
515
+
516
+ if len(D1s) > 0:
517
+ D1s = torch.stack(D1s, dim=1) # (B, K)
518
+ C_inv_p = torch.linalg.inv(C[:-1, :-1])
519
+ A_p = C_inv_p
520
+
521
+ if use_corrector:
522
+ print('using corrector')
523
+ C_inv = torch.linalg.inv(C)
524
+ A_c = C_inv
525
+
526
+ hh = -h if self.predict_x0 else h
527
+ h_phi_1 = torch.expm1(hh)
528
+ h_phi_ks = []
529
+ factorial_k = 1
530
+ h_phi_k = h_phi_1
531
+ for k in range(1, K + 2):
532
+ h_phi_ks.append(h_phi_k)
533
+ h_phi_k = h_phi_k / hh - 1 / factorial_k
534
+ factorial_k *= (k + 1)
535
+
536
+ model_t = None
537
+ if self.predict_x0:
538
+ x_t_ = (
539
+ sigma_t / sigma_prev_0 * x
540
+ - alpha_t * h_phi_1 * model_prev_0
541
+ )
542
+ # now predictor
543
+ x_t = x_t_
544
+ if len(D1s) > 0:
545
+ # compute the residuals for predictor
546
+ for k in range(K - 1):
547
+ x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
548
+ # now corrector
549
+ if use_corrector:
550
+ model_t = self.model_fn(x_t, t)
551
+ D1_t = (model_t - model_prev_0)
552
+ x_t = x_t_
553
+ k = 0
554
+ for k in range(K - 1):
555
+ x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
556
+ x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
557
+ else:
558
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
559
+ x_t_ = (
560
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
561
+ - (sigma_t * h_phi_1) * model_prev_0
562
+ )
563
+ # now predictor
564
+ x_t = x_t_
565
+ if len(D1s) > 0:
566
+ # compute the residuals for predictor
567
+ for k in range(K - 1):
568
+ x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
569
+ # now corrector
570
+ if use_corrector:
571
+ model_t = self.model_fn(x_t, t)
572
+ D1_t = (model_t - model_prev_0)
573
+ x_t = x_t_
574
+ k = 0
575
+ for k in range(K - 1):
576
+ x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
577
+ x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
578
+ return x_t, model_t
579
+
580
+ def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
581
+ # print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
582
+ ns = self.noise_schedule
583
+ assert order <= len(model_prev_list)
584
+ dims = x.dim()
585
+
586
+ # first compute rks
587
+ t_prev_0 = t_prev_list[-1]
588
+ lambda_prev_0 = ns.marginal_lambda(t_prev_0)
589
+ lambda_t = ns.marginal_lambda(t)
590
+ model_prev_0 = model_prev_list[-1]
591
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
592
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
593
+ alpha_t = torch.exp(log_alpha_t)
594
+
595
+ h = lambda_t - lambda_prev_0
596
+
597
+ rks = []
598
+ D1s = []
599
+ for i in range(1, order):
600
+ t_prev_i = t_prev_list[-(i + 1)]
601
+ model_prev_i = model_prev_list[-(i + 1)]
602
+ lambda_prev_i = ns.marginal_lambda(t_prev_i)
603
+ rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
604
+ rks.append(rk)
605
+ D1s.append((model_prev_i - model_prev_0) / rk)
606
+
607
+ rks.append(1.)
608
+ rks = torch.tensor(rks, device=x.device)
609
+
610
+ R = []
611
+ b = []
612
+
613
+ hh = -h[0] if self.predict_x0 else h[0]
614
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
615
+ h_phi_k = h_phi_1 / hh - 1
616
+
617
+ factorial_i = 1
618
+
619
+ if self.variant == 'bh1':
620
+ B_h = hh
621
+ elif self.variant == 'bh2':
622
+ B_h = torch.expm1(hh)
623
+ else:
624
+ raise NotImplementedError()
625
+
626
+ for i in range(1, order + 1):
627
+ R.append(torch.pow(rks, i - 1))
628
+ b.append(h_phi_k * factorial_i / B_h)
629
+ factorial_i *= (i + 1)
630
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
631
+
632
+ R = torch.stack(R)
633
+ b = torch.tensor(b, device=x.device)
634
+
635
+ # now predictor
636
+ use_predictor = len(D1s) > 0 and x_t is None
637
+ if len(D1s) > 0:
638
+ D1s = torch.stack(D1s, dim=1) # (B, K)
639
+ if x_t is None:
640
+ # for order 2, we use a simplified version
641
+ if order == 2:
642
+ rhos_p = torch.tensor([0.5], device=b.device)
643
+ else:
644
+ rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
645
+ else:
646
+ D1s = None
647
+
648
+ if use_corrector:
649
+ # print('using corrector')
650
+ # for order 1, we use a simplified version
651
+ if order == 1:
652
+ rhos_c = torch.tensor([0.5], device=b.device)
653
+ else:
654
+ rhos_c = torch.linalg.solve(R, b)
655
+
656
+ model_t = None
657
+ if self.predict_x0:
658
+ x_t_ = (
659
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
660
+ - expand_dims(alpha_t * h_phi_1, dims)* model_prev_0
661
+ )
662
+
663
+ if x_t is None:
664
+ if use_predictor:
665
+ pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
666
+ else:
667
+ pred_res = 0
668
+ x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
669
+
670
+ if use_corrector:
671
+ model_t = self.model_fn(x_t, t)
672
+ if D1s is not None:
673
+ corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
674
+ else:
675
+ corr_res = 0
676
+ D1_t = (model_t - model_prev_0)
677
+ x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
678
+ else:
679
+ x_t_ = (
680
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
681
+ - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
682
+ )
683
+ if x_t is None:
684
+ if use_predictor:
685
+ pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
686
+ else:
687
+ pred_res = 0
688
+ x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res
689
+
690
+ if use_corrector:
691
+ model_t = self.model_fn(x_t, t)
692
+ if D1s is not None:
693
+ corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
694
+ else:
695
+ corr_res = 0
696
+ D1_t = (model_t - model_prev_0)
697
+ x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
698
+ return x_t, model_t
699
+
700
+
701
+ def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform',
702
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
703
+ atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
704
+ ):
705
+ # t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
706
+ # t_T = self.noise_schedule.T if t_start is None else t_start
707
+ device = x.device
708
+ steps = len(timesteps) - 1
709
+ if method == 'multistep':
710
+ assert steps >= order
711
+ # timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
712
+ assert timesteps.shape[0] - 1 == steps
713
+ # with torch.no_grad():
714
+ for step_index in trange(steps, disable=disable_pbar):
715
+ if step_index == 0:
716
+ vec_t = timesteps[0].expand((x.shape[0]))
717
+ model_prev_list = [self.model_fn(x, vec_t)]
718
+ t_prev_list = [vec_t]
719
+ elif step_index < order:
720
+ init_order = step_index
721
+ # Init the first `order` values by lower order multistep DPM-Solver.
722
+ # for init_order in range(1, order):
723
+ vec_t = timesteps[init_order].expand(x.shape[0])
724
+ x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
725
+ if model_x is None:
726
+ model_x = self.model_fn(x, vec_t)
727
+ model_prev_list.append(model_x)
728
+ t_prev_list.append(vec_t)
729
+ else:
730
+ extra_final_step = 0
731
+ if step_index == (steps - 1):
732
+ extra_final_step = 1
733
+ for step in range(step_index, step_index + 1 + extra_final_step):
734
+ vec_t = timesteps[step].expand(x.shape[0])
735
+ if lower_order_final:
736
+ step_order = min(order, steps + 1 - step)
737
+ else:
738
+ step_order = order
739
+ # print('this step order:', step_order)
740
+ if step == steps:
741
+ # print('do not run corrector at the last step')
742
+ use_corrector = False
743
+ else:
744
+ use_corrector = True
745
+ x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
746
+ for i in range(order - 1):
747
+ t_prev_list[i] = t_prev_list[i + 1]
748
+ model_prev_list[i] = model_prev_list[i + 1]
749
+ t_prev_list[-1] = vec_t
750
+ # We do not need to evaluate the final model value.
751
+ if step < steps:
752
+ if model_x is None:
753
+ model_x = self.model_fn(x, vec_t)
754
+ model_prev_list[-1] = model_x
755
+ if callback is not None:
756
+ callback({'x': x, 'i': step_index, 'denoised': model_prev_list[-1]})
757
+ else:
758
+ raise NotImplementedError()
759
+ # if denoise_to_zero:
760
+ # x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
761
+ return x
762
+
763
+
764
+ #############################################################
765
+ # other utility functions
766
+ #############################################################
767
+
768
+ def interpolate_fn(x, xp, yp):
769
+ """
770
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
771
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
772
+ The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
773
+
774
+ Args:
775
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
776
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
777
+ yp: PyTorch tensor with shape [C, K].
778
+ Returns:
779
+ The function values f(x), with shape [N, C].
780
+ """
781
+ N, K = x.shape[0], xp.shape[1]
782
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
783
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
784
+ x_idx = torch.argmin(x_indices, dim=2)
785
+ cand_start_idx = x_idx - 1
786
+ start_idx = torch.where(
787
+ torch.eq(x_idx, 0),
788
+ torch.tensor(1, device=x.device),
789
+ torch.where(
790
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
791
+ ),
792
+ )
793
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
794
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
795
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
796
+ start_idx2 = torch.where(
797
+ torch.eq(x_idx, 0),
798
+ torch.tensor(0, device=x.device),
799
+ torch.where(
800
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
801
+ ),
802
+ )
803
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
804
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
805
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
806
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
807
+ return cand
808
+
809
+
810
+ def expand_dims(v, dims):
811
+ """
812
+ Expand the tensor `v` to the dim `dims`.
813
+
814
+ Args:
815
+ `v`: a PyTorch tensor with shape [N].
816
+ `dims`: an `int`.
817
+ Returns:
818
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
819
+ """
820
+ return v[(...,) + (None,)*(dims - 1)]
821
+
822
+
823
+ class SigmaConvert:
824
+ schedule = ""
825
+ def marginal_log_mean_coeff(self, sigma):
826
+ return 0.5 * torch.log(1 / ((sigma * sigma) + 1))
827
+
828
+ def marginal_alpha(self, t):
829
+ return torch.exp(self.marginal_log_mean_coeff(t))
830
+
831
+ def marginal_std(self, t):
832
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
833
+
834
+ def marginal_lambda(self, t):
835
+ """
836
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
837
+ """
838
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
839
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
840
+ return log_mean_coeff - log_std
841
+
842
+ def predict_eps_sigma(model, input, sigma_in, **kwargs):
843
+ sigma = sigma_in.view(sigma_in.shape[:1] + (1,) * (input.ndim - 1))
844
+ input = input * ((sigma ** 2 + 1.0) ** 0.5)
845
+ return (input - model(input, sigma_in, **kwargs)) / sigma
846
+
847
+
848
+ def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
849
+ timesteps = sigmas.clone()
850
+ if sigmas[-1] == 0:
851
+ timesteps = sigmas[:]
852
+ timesteps[-1] = 0.001
853
+ else:
854
+ timesteps = sigmas.clone()
855
+ ns = SigmaConvert()
856
+
857
+ noise = noise / torch.sqrt(1.0 + timesteps[0] ** 2.0)
858
+ model_type = "noise"
859
+
860
+ model_fn = model_wrapper(
861
+ lambda input, sigma, **kwargs: predict_eps_sigma(model, input, sigma, **kwargs),
862
+ ns,
863
+ model_type=model_type,
864
+ guidance_type="uncond",
865
+ model_kwargs=extra_args,
866
+ )
867
+
868
+ order = min(3, len(timesteps) - 2)
869
+ uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=variant)
870
+ x = uni_pc.sample(noise, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
871
+ x /= ns.marginal_alpha(timesteps[-1])
872
+ return x
873
+
874
+ def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False):
875
+ return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')
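An illustrative sketch of how sample_unipc is driven: it expects a k-diffusion style denoiser model(x, sigma, **extra) that returns the denoised prediction, plus a decreasing sigma schedule ending at 0. The toy denoiser and the schedule values below are assumptions, standing in for ComfyUI's CFG-wrapped model.

    # Sketch: run the UniPC sampler against a toy denoiser and a short schedule.
    import torch
    from comfy.extra_samplers.uni_pc import sample_unipc  # the module shown above

    def toy_denoiser(x, sigma, **kwargs):
        # stand-in: pretend the denoised latent is a damped copy of the input
        return x / (1.0 + sigma.view(-1, 1, 1, 1) ** 2)

    sigmas = torch.cat([torch.linspace(14.6, 0.03, 20), torch.zeros(1)])  # ends at 0
    noise = torch.randn(1, 4, 64, 64) * sigmas[0]
    latent = sample_unipc(toy_denoiser, noise, sigmas, extra_args={}, disable=True)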
Backend/comfy/gligen.py ADDED
@@ -0,0 +1,343 @@
1
+ import math  # used by GatedSelfAttentionDense2 below
+ import torch
2
+ from torch import nn
3
+ from .ldm.modules.attention import CrossAttention
4
+ from inspect import isfunction
5
+ import comfy.ops
6
+ ops = comfy.ops.manual_cast
7
+
8
+ def exists(val):
9
+ return val is not None
10
+
11
+
12
+ def uniq(arr):
13
+ return{el: True for el in arr}.keys()
14
+
15
+
16
+ def default(val, d):
17
+ if exists(val):
18
+ return val
19
+ return d() if isfunction(d) else d
20
+
21
+
22
+ # feedforward
23
+ class GEGLU(nn.Module):
24
+ def __init__(self, dim_in, dim_out):
25
+ super().__init__()
26
+ self.proj = ops.Linear(dim_in, dim_out * 2)
27
+
28
+ def forward(self, x):
29
+ x, gate = self.proj(x).chunk(2, dim=-1)
30
+ return x * torch.nn.functional.gelu(gate)
31
+
32
+
33
+ class FeedForward(nn.Module):
34
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
35
+ super().__init__()
36
+ inner_dim = int(dim * mult)
37
+ dim_out = default(dim_out, dim)
38
+ project_in = nn.Sequential(
39
+ ops.Linear(dim, inner_dim),
40
+ nn.GELU()
41
+ ) if not glu else GEGLU(dim, inner_dim)
42
+
43
+ self.net = nn.Sequential(
44
+ project_in,
45
+ nn.Dropout(dropout),
46
+ ops.Linear(inner_dim, dim_out)
47
+ )
48
+
49
+ def forward(self, x):
50
+ return self.net(x)
51
+
52
+
53
+ class GatedCrossAttentionDense(nn.Module):
54
+ def __init__(self, query_dim, context_dim, n_heads, d_head):
55
+ super().__init__()
56
+
57
+ self.attn = CrossAttention(
58
+ query_dim=query_dim,
59
+ context_dim=context_dim,
60
+ heads=n_heads,
61
+ dim_head=d_head,
62
+ operations=ops)
63
+ self.ff = FeedForward(query_dim, glu=True)
64
+
65
+ self.norm1 = ops.LayerNorm(query_dim)
66
+ self.norm2 = ops.LayerNorm(query_dim)
67
+
68
+ self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
69
+ self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
70
+
71
+ # this can be useful: we can externally change magnitude of tanh(alpha)
72
+ # for example, when it is set to 0, then the entire model is the same as the
73
+ # original one
74
+ self.scale = 1
75
+
76
+ def forward(self, x, objs):
77
+
78
+ x = x + self.scale * \
79
+ torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs)
80
+ x = x + self.scale * \
81
+ torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
82
+
83
+ return x
84
+
85
+
86
+ class GatedSelfAttentionDense(nn.Module):
87
+ def __init__(self, query_dim, context_dim, n_heads, d_head):
88
+ super().__init__()
89
+
90
+ # we need a linear projection since we need cat visual feature and obj
91
+ # feature
92
+ self.linear = ops.Linear(context_dim, query_dim)
93
+
94
+ self.attn = CrossAttention(
95
+ query_dim=query_dim,
96
+ context_dim=query_dim,
97
+ heads=n_heads,
98
+ dim_head=d_head,
99
+ operations=ops)
100
+ self.ff = FeedForward(query_dim, glu=True)
101
+
102
+ self.norm1 = ops.LayerNorm(query_dim)
103
+ self.norm2 = ops.LayerNorm(query_dim)
104
+
105
+ self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
106
+ self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
107
+
108
+ # this can be useful: we can externally change magnitude of tanh(alpha)
109
+ # for example, when it is set to 0, then the entire model is the same as the
110
+ # original one
111
+ self.scale = 1
112
+
113
+ def forward(self, x, objs):
114
+
115
+ N_visual = x.shape[1]
116
+ objs = self.linear(objs)
117
+
118
+ x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn(
119
+ self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :]
120
+ x = x + self.scale * \
121
+ torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
122
+
123
+ return x
124
+
125
+
126
+ class GatedSelfAttentionDense2(nn.Module):
127
+ def __init__(self, query_dim, context_dim, n_heads, d_head):
128
+ super().__init__()
129
+
130
+ # we need a linear projection since we need cat visual feature and obj
131
+ # feature
132
+ self.linear = ops.Linear(context_dim, query_dim)
133
+
134
+ self.attn = CrossAttention(
135
+ query_dim=query_dim, context_dim=query_dim, dim_head=d_head, operations=ops)
136
+ self.ff = FeedForward(query_dim, glu=True)
137
+
138
+ self.norm1 = ops.LayerNorm(query_dim)
139
+ self.norm2 = ops.LayerNorm(query_dim)
140
+
141
+ self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
142
+ self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
143
+
144
+ # this can be useful: we can externally change magnitude of tanh(alpha)
145
+ # for example, when it is set to 0, then the entire model is the same as the
146
+ # original one
147
+ self.scale = 1
148
+
149
+ def forward(self, x, objs):
150
+
151
+ B, N_visual, _ = x.shape
152
+ B, N_ground, _ = objs.shape
153
+
154
+ objs = self.linear(objs)
155
+
156
+ # sanity check
157
+ size_v = math.sqrt(N_visual)
158
+ size_g = math.sqrt(N_ground)
159
+ assert int(size_v) == size_v, "Visual tokens must be square rootable"
160
+ assert int(size_g) == size_g, "Grounding tokens must be square rootable"
161
+ size_v = int(size_v)
162
+ size_g = int(size_g)
163
+
164
+ # select grounding token and resize it to visual token size as residual
165
+ out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[
166
+ :, N_visual:, :]
167
+ out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g)
168
+ out = torch.nn.functional.interpolate(
169
+ out, (size_v, size_v), mode='bicubic')
170
+ residual = out.reshape(B, -1, N_visual).permute(0, 2, 1)
171
+
172
+ # add residual to visual feature
173
+ x = x + self.scale * torch.tanh(self.alpha_attn) * residual
174
+ x = x + self.scale * \
175
+ torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
176
+
177
+ return x
178
+
179
+
180
+ class FourierEmbedder():
181
+ def __init__(self, num_freqs=64, temperature=100):
182
+
183
+ self.num_freqs = num_freqs
184
+ self.temperature = temperature
185
+ self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
186
+
187
+ @torch.no_grad()
188
+ def __call__(self, x, cat_dim=-1):
189
+ "x: arbitrary shape of tensor. dim: cat dim"
190
+ out = []
191
+ for freq in self.freq_bands:
192
+ out.append(torch.sin(freq * x))
193
+ out.append(torch.cos(freq * x))
194
+ return torch.cat(out, cat_dim)
195
+
196
+
197
+ class PositionNet(nn.Module):
198
+ def __init__(self, in_dim, out_dim, fourier_freqs=8):
199
+ super().__init__()
200
+ self.in_dim = in_dim
201
+ self.out_dim = out_dim
202
+
203
+ self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
204
+ self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy
205
+
206
+ self.linears = nn.Sequential(
207
+ ops.Linear(self.in_dim + self.position_dim, 512),
208
+ nn.SiLU(),
209
+ ops.Linear(512, 512),
210
+ nn.SiLU(),
211
+ ops.Linear(512, out_dim),
212
+ )
213
+
214
+ self.null_positive_feature = torch.nn.Parameter(
215
+ torch.zeros([self.in_dim]))
216
+ self.null_position_feature = torch.nn.Parameter(
217
+ torch.zeros([self.position_dim]))
218
+
219
+ def forward(self, boxes, masks, positive_embeddings):
220
+ B, N, _ = boxes.shape
221
+ masks = masks.unsqueeze(-1)
222
+ positive_embeddings = positive_embeddings
223
+
224
+ # embedding position (it may include padding as a placeholder)
225
+ xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
226
+
227
+ # learnable null embedding
228
+ positive_null = self.null_positive_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
229
+ xyxy_null = self.null_position_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
230
+
231
+ # replace padding with learnable null embedding
232
+ positive_embeddings = positive_embeddings * \
233
+ masks + (1 - masks) * positive_null
234
+ xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
235
+
236
+ objs = self.linears(
237
+ torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
238
+ assert objs.shape == torch.Size([B, N, self.out_dim])
239
+ return objs
240
+
241
+
242
+ class Gligen(nn.Module):
243
+ def __init__(self, modules, position_net, key_dim):
244
+ super().__init__()
245
+ self.module_list = nn.ModuleList(modules)
246
+ self.position_net = position_net
247
+ self.key_dim = key_dim
248
+ self.max_objs = 30
249
+ self.current_device = torch.device("cpu")
250
+
251
+ def _set_position(self, boxes, masks, positive_embeddings):
252
+ objs = self.position_net(boxes, masks, positive_embeddings)
253
+ def func(x, extra_options):
254
+ key = extra_options["transformer_index"]
255
+ module = self.module_list[key]
256
+ return module(x, objs.to(device=x.device, dtype=x.dtype))
257
+ return func
258
+
259
+ def set_position(self, latent_image_shape, position_params, device):
260
+ batch, c, h, w = latent_image_shape
261
+ masks = torch.zeros([self.max_objs], device="cpu")
262
+ boxes = []
263
+ positive_embeddings = []
264
+ for p in position_params:
265
+ x1 = (p[4]) / w
266
+ y1 = (p[3]) / h
267
+ x2 = (p[4] + p[2]) / w
268
+ y2 = (p[3] + p[1]) / h
269
+ masks[len(boxes)] = 1.0
270
+ boxes += [torch.tensor((x1, y1, x2, y2)).unsqueeze(0)]
271
+ positive_embeddings += [p[0]]
272
+ append_boxes = []
273
+ append_conds = []
274
+ if len(boxes) < self.max_objs:
275
+ append_boxes = [torch.zeros(
276
+ [self.max_objs - len(boxes), 4], device="cpu")]
277
+ append_conds = [torch.zeros(
278
+ [self.max_objs - len(boxes), self.key_dim], device="cpu")]
279
+
280
+ box_out = torch.cat(
281
+ boxes + append_boxes).unsqueeze(0).repeat(batch, 1, 1)
282
+ masks = masks.unsqueeze(0).repeat(batch, 1)
283
+ conds = torch.cat(positive_embeddings +
284
+ append_conds).unsqueeze(0).repeat(batch, 1, 1)
285
+ return self._set_position(
286
+ box_out.to(device),
287
+ masks.to(device),
288
+ conds.to(device))
289
+
290
+ def set_empty(self, latent_image_shape, device):
291
+ batch, c, h, w = latent_image_shape
292
+ masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1)
293
+ box_out = torch.zeros([self.max_objs, 4],
294
+ device="cpu").repeat(batch, 1, 1)
295
+ conds = torch.zeros([self.max_objs, self.key_dim],
296
+ device="cpu").repeat(batch, 1, 1)
297
+ return self._set_position(
298
+ box_out.to(device),
299
+ masks.to(device),
300
+ conds.to(device))
301
+
302
+
303
+ def load_gligen(sd):
304
+ sd_k = sd.keys()
305
+ output_list = []
306
+ key_dim = 768
307
+ for a in ["input_blocks", "middle_block", "output_blocks"]:
308
+ for b in range(20):
309
+ k_temp = filter(lambda k: "{}.{}.".format(a, b)
310
+ in k and ".fuser." in k, sd_k)
311
+ k_temp = map(lambda k: (k, k.split(".fuser.")[-1]), k_temp)
312
+
313
+ n_sd = {}
314
+ for k in k_temp:
315
+ n_sd[k[1]] = sd[k[0]]
316
+ if len(n_sd) > 0:
317
+ query_dim = n_sd["linear.weight"].shape[0]
318
+ key_dim = n_sd["linear.weight"].shape[1]
319
+
320
+ if key_dim == 768: # SD1.x
321
+ n_heads = 8
322
+ d_head = query_dim // n_heads
323
+ else:
324
+ d_head = 64
325
+ n_heads = query_dim // d_head
326
+
327
+ gated = GatedSelfAttentionDense(
328
+ query_dim, key_dim, n_heads, d_head)
329
+ gated.load_state_dict(n_sd, strict=False)
330
+ output_list.append(gated)
331
+
332
+ if "position_net.null_positive_feature" in sd_k:
333
+ in_dim = sd["position_net.null_positive_feature"].shape[0]
334
+ out_dim = sd["position_net.linears.4.weight"].shape[0]
335
+
336
+ class WeightsLoader(torch.nn.Module):
337
+ pass
338
+ w = WeightsLoader()
339
+ w.position_net = PositionNet(in_dim, out_dim)
340
+ w.load_state_dict(sd, strict=False)
341
+
342
+ gligen = Gligen(output_list, w.position_net, key_dim)
343
+ return gligen
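A hedged sketch of loading a GLIGEN adapter with the helper above: the checkpoint path is a placeholder, and the position_params entry follows the (embedding, height, width, y, x) ordering that set_position unpacks, in latent-space units.

    # Sketch: load a GLIGEN adapter and build the attention patch for one box.
    import torch
    import comfy.utils
    from comfy.gligen import load_gligen

    sd = comfy.utils.load_torch_file("models/gligen/gligen_sd15.safetensors")  # placeholder path
    gligen = load_gligen(sd)

    box_embed = torch.zeros(1, gligen.key_dim)          # stand-in text embedding for the box
    position_params = [(box_embed, 32, 32, 16, 16)]     # (emb, height, width, y, x) in latent units
    patch = gligen.set_position((1, 4, 64, 64), position_params, device="cpu")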
Backend/comfy/k_diffusion/deis.py ADDED
@@ -0,0 +1,121 @@
1
+ #Taken from: https://github.com/zju-pi/diff-sampler/blob/main/gits-main/solver_utils.py
2
+ #under Apache 2 license
3
+ import torch
4
+ import numpy as np
5
+
6
+ # A pytorch reimplementation of DEIS (https://github.com/qsh-zh/deis).
7
+ #############################
8
+ ### Utils for DEIS solver ###
9
+ #############################
10
+ #----------------------------------------------------------------------------
11
+ # Transfer from the input time (sigma) used in EDM to that (t) used in DEIS.
12
+
13
+ def edm2t(edm_steps, epsilon_s=1e-3, sigma_min=0.002, sigma_max=80):
14
+ vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
15
+ vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
16
+ vp_beta_d = 2 * (np.log(torch.tensor(sigma_min).cpu() ** 2 + 1) / epsilon_s - np.log(torch.tensor(sigma_max).cpu() ** 2 + 1)) / (epsilon_s - 1)
17
+ vp_beta_min = np.log(torch.tensor(sigma_max).cpu() ** 2 + 1) - 0.5 * vp_beta_d
18
+ t_steps = vp_sigma_inv(vp_beta_d.clone().detach().cpu(), vp_beta_min.clone().detach().cpu())(edm_steps.clone().detach().cpu())
19
+ return t_steps, vp_beta_min, vp_beta_d + vp_beta_min
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
+ def cal_poly(prev_t, j, taus):
24
+ poly = 1
25
+ for k in range(prev_t.shape[0]):
26
+ if k == j:
27
+ continue
28
+ poly *= (taus - prev_t[k]) / (prev_t[j] - prev_t[k])
29
+ return poly
30
+
31
+ #----------------------------------------------------------------------------
32
+ # Transfer from t to alpha_t.
33
+
34
+ def t2alpha_fn(beta_0, beta_1, t):
35
+ return torch.exp(-0.5 * t ** 2 * (beta_1 - beta_0) - t * beta_0)
36
+
37
+ #----------------------------------------------------------------------------
38
+
39
+ def cal_intergrand(beta_0, beta_1, taus):
40
+ with torch.inference_mode(mode=False):
41
+ taus = taus.clone()
42
+ beta_0 = beta_0.clone()
43
+ beta_1 = beta_1.clone()
44
+ with torch.enable_grad():
45
+ taus.requires_grad_(True)
46
+ alpha = t2alpha_fn(beta_0, beta_1, taus)
47
+ log_alpha = alpha.log()
48
+ log_alpha.sum().backward()
49
+ d_log_alpha_dtau = taus.grad
50
+ integrand = -0.5 * d_log_alpha_dtau / torch.sqrt(alpha * (1 - alpha))
51
+ return integrand
52
+
53
+ #----------------------------------------------------------------------------
54
+
55
+ def get_deis_coeff_list(t_steps, max_order, N=10000, deis_mode='tab'):
56
+ """
57
+ Get the coefficient list for DEIS sampling.
58
+
59
+ Args:
60
+ t_steps: A pytorch tensor. The time steps for sampling.
61
+ max_order: A `int`. Maximum order of the solver. 1 <= max_order <= 4
62
+ N: A `int`. Use how many points to perform the numerical integration when deis_mode=='tab'.
63
+ deis_mode: A `str`. Select between 'tab' and 'rhoab'. Type of DEIS.
64
+ Returns:
65
+ A list with one entry per sampling step, each holding the DEIS coefficients for that step.
66
+ """
67
+ if deis_mode == 'tab':
68
+ t_steps, beta_0, beta_1 = edm2t(t_steps)
69
+ C = []
70
+ for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
71
+ order = min(i+1, max_order)
72
+ if order == 1:
73
+ C.append([])
74
+ else:
75
+ taus = torch.linspace(t_cur, t_next, N) # split the interval for integral approximation
76
+ dtau = (t_next - t_cur) / N
77
+ prev_t = t_steps[[i - k for k in range(order)]]
78
+ coeff_temp = []
79
+ integrand = cal_intergrand(beta_0, beta_1, taus)
80
+ for j in range(order):
81
+ poly = cal_poly(prev_t, j, taus)
82
+ coeff_temp.append(torch.sum(integrand * poly) * dtau)
83
+ C.append(coeff_temp)
84
+
85
+ elif deis_mode == 'rhoab':
86
+ # Analytical solution, second order
87
+ def get_def_intergral_2(a, b, start, end, c):
88
+ coeff = (end**3 - start**3) / 3 - (end**2 - start**2) * (a + b) / 2 + (end - start) * a * b
89
+ return coeff / ((c - a) * (c - b))
90
+
91
+ # Analytical solution, third order
92
+ def get_def_intergral_3(a, b, c, start, end, d):
93
+ coeff = (end**4 - start**4) / 4 - (end**3 - start**3) * (a + b + c) / 3 \
94
+ + (end**2 - start**2) * (a*b + a*c + b*c) / 2 - (end - start) * a * b * c
95
+ return coeff / ((d - a) * (d - b) * (d - c))
96
+
97
+ C = []
98
+ for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
99
+ order = min(i, max_order)
100
+ if order == 0:
101
+ C.append([])
102
+ else:
103
+ prev_t = t_steps[[i - k for k in range(order+1)]]
104
+ if order == 1:
105
+ coeff_cur = ((t_next - prev_t[1])**2 - (t_cur - prev_t[1])**2) / (2 * (t_cur - prev_t[1]))
106
+ coeff_prev1 = (t_next - t_cur)**2 / (2 * (prev_t[1] - t_cur))
107
+ coeff_temp = [coeff_cur, coeff_prev1]
108
+ elif order == 2:
109
+ coeff_cur = get_def_intergral_2(prev_t[1], prev_t[2], t_cur, t_next, t_cur)
110
+ coeff_prev1 = get_def_intergral_2(t_cur, prev_t[2], t_cur, t_next, prev_t[1])
111
+ coeff_prev2 = get_def_intergral_2(t_cur, prev_t[1], t_cur, t_next, prev_t[2])
112
+ coeff_temp = [coeff_cur, coeff_prev1, coeff_prev2]
113
+ elif order == 3:
114
+ coeff_cur = get_def_intergral_3(prev_t[1], prev_t[2], prev_t[3], t_cur, t_next, t_cur)
115
+ coeff_prev1 = get_def_intergral_3(t_cur, prev_t[2], prev_t[3], t_cur, t_next, prev_t[1])
116
+ coeff_prev2 = get_def_intergral_3(t_cur, prev_t[1], prev_t[3], t_cur, t_next, prev_t[2])
117
+ coeff_prev3 = get_def_intergral_3(t_cur, prev_t[1], prev_t[2], t_cur, t_next, prev_t[3])
118
+ coeff_temp = [coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3]
119
+ C.append(coeff_temp)
120
+ return C
121
+
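A small sketch of precomputing the coefficient list: it is computed once per sigma schedule, and the Karras-style schedule built inline below (values illustrative) stands in for the one the sampler would normally pass.

    # Sketch: precompute 'tab'-mode DEIS coefficients for a 10-step Karras-like schedule.
    import torch
    from comfy.k_diffusion.deis import get_deis_coeff_list

    rho, s_min, s_max = 7.0, 0.03, 14.6
    ramp = torch.linspace(0, 1, 10)
    sigmas = (s_max ** (1 / rho) + ramp * (s_min ** (1 / rho) - s_max ** (1 / rho))) ** rho

    coeff_list = get_deis_coeff_list(sigmas, max_order=3, deis_mode='tab')
    # coeff_list[i] holds the per-order weights used when stepping from sigmas[i] to sigmas[i+1].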
Backend/comfy/k_diffusion/sampling.py ADDED
@@ -0,0 +1,1050 @@
1
+ import math
2
+
3
+ from scipy import integrate
4
+ import torch
5
+ from torch import nn
6
+ import torchsde
7
+ from tqdm.auto import trange, tqdm
8
+
9
+ from . import utils
10
+ from . import deis
11
+ import comfy.model_patcher
12
+
13
+ def append_zero(x):
14
+ return torch.cat([x, x.new_zeros([1])])
15
+
16
+
17
+ def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
18
+ """Constructs the noise schedule of Karras et al. (2022)."""
19
+ ramp = torch.linspace(0, 1, n, device=device)
20
+ min_inv_rho = sigma_min ** (1 / rho)
21
+ max_inv_rho = sigma_max ** (1 / rho)
22
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
23
+ return append_zero(sigmas).to(device)
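A quick illustrative call (the sigma bounds are arbitrary example values): rho=7 front-loads the schedule toward sigma_max, and append_zero adds the terminating zero.

sigmas = get_sigmas_karras(n=10, sigma_min=0.03, sigma_max=14.6)
print(sigmas.shape)           # torch.Size([11]): n sigmas plus the appended zero
print(sigmas[0], sigmas[-1])  # roughly 14.6 down to exactly 0.0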
24
+
25
+
26
+ def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
27
+ """Constructs an exponential noise schedule."""
28
+ sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
29
+ return append_zero(sigmas)
30
+
31
+
32
+ def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
33
+ """Constructs an polynomial in log sigma noise schedule."""
34
+ ramp = torch.linspace(1, 0, n, device=device) ** rho
35
+ sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))
36
+ return append_zero(sigmas)
37
+
38
+
39
+ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
40
+ """Constructs a continuous VP noise schedule."""
41
+ t = torch.linspace(1, eps_s, n, device=device)
42
+ sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
43
+ return append_zero(sigmas)
44
+
45
+
46
+ def to_d(x, sigma, denoised):
47
+ """Converts a denoiser output to a Karras ODE derivative."""
48
+ return (x - denoised) / utils.append_dims(sigma, x.ndim)
49
+
50
+
51
+ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
52
+ """Calculates the noise level (sigma_down) to step down to and the amount
53
+ of noise to add (sigma_up) when doing an ancestral sampling step."""
54
+ if not eta:
55
+ return sigma_to, 0.
56
+ sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
57
+ sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
58
+ return sigma_down, sigma_up
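A small numeric sanity check (values chosen only for illustration): by construction sigma_down ** 2 + sigma_up ** 2 == sigma_to ** 2, so the deterministic step plus the injected noise land back at the target noise level.

import torch

sigma_from, sigma_to = torch.tensor(2.0), torch.tensor(1.0)
sigma_down, sigma_up = get_ancestral_step(sigma_from, sigma_to, eta=1.)
print(sigma_down ** 2 + sigma_up ** 2)  # approximately 1.0 == sigma_to ** 2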
59
+
60
+
61
+ def default_noise_sampler(x):
62
+ return lambda sigma, sigma_next: torch.randn_like(x)
63
+
64
+
65
+ class BatchedBrownianTree:
66
+ """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
67
+
68
+ def __init__(self, x, t0, t1, seed=None, **kwargs):
69
+ self.cpu_tree = True
70
+ if "cpu" in kwargs:
71
+ self.cpu_tree = kwargs.pop("cpu")
72
+ t0, t1, self.sign = self.sort(t0, t1)
73
+ w0 = kwargs.get('w0', torch.zeros_like(x))
74
+ if seed is None:
75
+ seed = torch.randint(0, 2 ** 63 - 1, []).item()
76
+ self.batched = True
77
+ try:
78
+ assert len(seed) == x.shape[0]
79
+ w0 = w0[0]
80
+ except TypeError:
81
+ seed = [seed]
82
+ self.batched = False
83
+ if self.cpu_tree:
84
+ self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
85
+ else:
86
+ self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
87
+
88
+ @staticmethod
89
+ def sort(a, b):
90
+ return (a, b, 1) if a < b else (b, a, -1)
91
+
92
+ def __call__(self, t0, t1):
93
+ t0, t1, sign = self.sort(t0, t1)
94
+ if self.cpu_tree:
95
+ w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
96
+ else:
97
+ w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
98
+
99
+ return w if self.batched else w[0]
100
+
101
+
102
+ class BrownianTreeNoiseSampler:
103
+ """A noise sampler backed by a torchsde.BrownianTree.
104
+
105
+ Args:
106
+ x (Tensor): The tensor whose shape, device and dtype to use to generate
107
+ random samples.
108
+ sigma_min (float): The low end of the valid interval.
109
+ sigma_max (float): The high end of the valid interval.
110
+ seed (int or List[int]): The random seed. If a list of seeds is
111
+ supplied instead of a single integer, then the noise sampler will
112
+ use one BrownianTree per batch item, each with its own seed.
113
+ transform (callable): A function that maps sigma to the sampler's
114
+ internal timestep.
115
+ """
116
+
117
+ def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
118
+ self.transform = transform
119
+ t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
120
+ self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)
121
+
122
+ def __call__(self, sigma, sigma_next):
123
+ t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
124
+ return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
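Illustrative usage (the shape, sigma bounds and seeds are placeholders): each call returns unit-variance noise, and overlapping (sigma, sigma_next) queries stay consistent because they are drawn from the same underlying tree.

import torch

x = torch.zeros(2, 4, 8, 8)  # any tensor with the target shape/device/dtype
ns = BrownianTreeNoiseSampler(x, sigma_min=0.03, sigma_max=14.6, seed=[1, 2], cpu=True)
noise = ns(torch.tensor(5.0), torch.tensor(4.0))
print(noise.shape, noise.std())  # torch.Size([2, 4, 8, 8]), std close to 1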
125
+
126
+
127
+ @torch.no_grad()
128
+ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
129
+ """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
130
+ extra_args = {} if extra_args is None else extra_args
131
+ s_in = x.new_ones([x.shape[0]])
132
+ for i in trange(len(sigmas) - 1, disable=disable):
133
+ if s_churn > 0:
134
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
135
+ sigma_hat = sigmas[i] * (gamma + 1)
136
+ else:
137
+ gamma = 0
138
+ sigma_hat = sigmas[i]
139
+
140
+ if gamma > 0:
141
+ eps = torch.randn_like(x) * s_noise
142
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
143
+ denoised = model(x, sigma_hat * s_in, **extra_args)
144
+ d = to_d(x, sigma_hat, denoised)
145
+ if callback is not None:
146
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
147
+ dt = sigmas[i + 1] - sigma_hat
148
+ # Euler method
149
+ x = x + d * dt
150
+ return x
151
+
152
+
153
+ @torch.no_grad()
154
+ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
155
+ """Ancestral sampling with Euler method steps."""
156
+ extra_args = {} if extra_args is None else extra_args
157
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
158
+ s_in = x.new_ones([x.shape[0]])
159
+ for i in trange(len(sigmas) - 1, disable=disable):
160
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
161
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
162
+ if callback is not None:
163
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
164
+ d = to_d(x, sigmas[i], denoised)
165
+ # Euler method
166
+ dt = sigma_down - sigmas[i]
167
+ x = x + d * dt
168
+ if sigmas[i + 1] > 0:
169
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
170
+ return x
171
+
172
+
173
+ @torch.no_grad()
174
+ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
175
+ """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
176
+ extra_args = {} if extra_args is None else extra_args
177
+ s_in = x.new_ones([x.shape[0]])
178
+ for i in trange(len(sigmas) - 1, disable=disable):
179
+ if s_churn > 0:
180
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
181
+ sigma_hat = sigmas[i] * (gamma + 1)
182
+ else:
183
+ gamma = 0
184
+ sigma_hat = sigmas[i]
185
+
187
+ if gamma > 0:
188
+ eps = torch.randn_like(x) * s_noise
189
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
190
+ denoised = model(x, sigma_hat * s_in, **extra_args)
191
+ d = to_d(x, sigma_hat, denoised)
192
+ if callback is not None:
193
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
194
+ dt = sigmas[i + 1] - sigma_hat
195
+ if sigmas[i + 1] == 0:
196
+ # Euler method
197
+ x = x + d * dt
198
+ else:
199
+ # Heun's method
200
+ x_2 = x + d * dt
201
+ denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
202
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
203
+ d_prime = (d + d_2) / 2
204
+ x = x + d_prime * dt
205
+ return x
206
+
207
+
208
+ @torch.no_grad()
209
+ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
210
+ """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
211
+ extra_args = {} if extra_args is None else extra_args
212
+ s_in = x.new_ones([x.shape[0]])
213
+ for i in trange(len(sigmas) - 1, disable=disable):
214
+ if s_churn > 0:
215
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
216
+ sigma_hat = sigmas[i] * (gamma + 1)
217
+ else:
218
+ gamma = 0
219
+ sigma_hat = sigmas[i]
220
+
221
+ if gamma > 0:
222
+ eps = torch.randn_like(x) * s_noise
223
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
224
+ denoised = model(x, sigma_hat * s_in, **extra_args)
225
+ d = to_d(x, sigma_hat, denoised)
226
+ if callback is not None:
227
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
228
+ if sigmas[i + 1] == 0:
229
+ # Euler method
230
+ dt = sigmas[i + 1] - sigma_hat
231
+ x = x + d * dt
232
+ else:
233
+ # DPM-Solver-2
234
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
235
+ dt_1 = sigma_mid - sigma_hat
236
+ dt_2 = sigmas[i + 1] - sigma_hat
237
+ x_2 = x + d * dt_1
238
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
239
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
240
+ x = x + d_2 * dt_2
241
+ return x
242
+
243
+
244
+ @torch.no_grad()
245
+ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
246
+ """Ancestral sampling with DPM-Solver second-order steps."""
247
+ extra_args = {} if extra_args is None else extra_args
248
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
249
+ s_in = x.new_ones([x.shape[0]])
250
+ for i in trange(len(sigmas) - 1, disable=disable):
251
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
252
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
253
+ if callback is not None:
254
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
255
+ d = to_d(x, sigmas[i], denoised)
256
+ if sigma_down == 0:
257
+ # Euler method
258
+ dt = sigma_down - sigmas[i]
259
+ x = x + d * dt
260
+ else:
261
+ # DPM-Solver-2
262
+ sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
263
+ dt_1 = sigma_mid - sigmas[i]
264
+ dt_2 = sigma_down - sigmas[i]
265
+ x_2 = x + d * dt_1
266
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
267
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
268
+ x = x + d_2 * dt_2
269
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
270
+ return x
271
+
272
+
273
+ def linear_multistep_coeff(order, t, i, j):
274
+ if order - 1 > i:
275
+ raise ValueError(f'Order {order} too high for step {i}')
276
+ def fn(tau):
277
+ prod = 1.
278
+ for k in range(order):
279
+ if j == k:
280
+ continue
281
+ prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
282
+ return prod
283
+ return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
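A quick property check on a made-up step grid (illustrative only): the Lagrange basis polynomials sum to one, so for any step the coefficients sum to the step size t[i + 1] - t[i].

import numpy as np

t = np.array([10.0, 7.0, 5.0, 3.0, 1.0])  # made-up decreasing sigma grid
order, i = 4, 3
coeffs = [linear_multistep_coeff(order, t, i, j) for j in range(order)]
print(sum(coeffs), t[i + 1] - t[i])  # both approximately -2.0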
284
+
285
+
286
+ @torch.no_grad()
287
+ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
288
+ extra_args = {} if extra_args is None else extra_args
289
+ s_in = x.new_ones([x.shape[0]])
290
+ sigmas_cpu = sigmas.detach().cpu().numpy()
291
+ ds = []
292
+ for i in trange(len(sigmas) - 1, disable=disable):
293
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
294
+ d = to_d(x, sigmas[i], denoised)
295
+ ds.append(d)
296
+ if len(ds) > order:
297
+ ds.pop(0)
298
+ if callback is not None:
299
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
300
+ cur_order = min(i + 1, order)
301
+ coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
302
+ x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
303
+ return x
304
+
305
+
306
+ class PIDStepSizeController:
307
+ """A PID controller for ODE adaptive step size control."""
308
+ def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
309
+ self.h = h
310
+ self.b1 = (pcoeff + icoeff + dcoeff) / order
311
+ self.b2 = -(pcoeff + 2 * dcoeff) / order
312
+ self.b3 = dcoeff / order
313
+ self.accept_safety = accept_safety
314
+ self.eps = eps
315
+ self.errs = []
316
+
317
+ def limiter(self, x):
318
+ return 1 + math.atan(x - 1)
319
+
320
+ def propose_step(self, error):
321
+ inv_error = 1 / (float(error) + self.eps)
322
+ if not self.errs:
323
+ self.errs = [inv_error, inv_error, inv_error]
324
+ self.errs[0] = inv_error
325
+ factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3
326
+ factor = self.limiter(factor)
327
+ accept = factor >= self.accept_safety
328
+ if accept:
329
+ self.errs[2] = self.errs[1]
330
+ self.errs[1] = self.errs[0]
331
+ self.h *= factor
332
+ return accept
333
+
334
+
335
+ class DPMSolver(nn.Module):
336
+ """DPM-Solver. See https://arxiv.org/abs/2206.00927."""
337
+
338
+ def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
339
+ super().__init__()
340
+ self.model = model
341
+ self.extra_args = {} if extra_args is None else extra_args
342
+ self.eps_callback = eps_callback
343
+ self.info_callback = info_callback
344
+
345
+ def t(self, sigma):
346
+ return -sigma.log()
347
+
348
+ def sigma(self, t):
349
+ return t.neg().exp()
350
+
351
+ def eps(self, eps_cache, key, x, t, *args, **kwargs):
352
+ if key in eps_cache:
353
+ return eps_cache[key], eps_cache
354
+ sigma = self.sigma(t) * x.new_ones([x.shape[0]])
355
+ eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t)
356
+ if self.eps_callback is not None:
357
+ self.eps_callback()
358
+ return eps, {key: eps, **eps_cache}
359
+
360
+ def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
361
+ eps_cache = {} if eps_cache is None else eps_cache
362
+ h = t_next - t
363
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
364
+ x_1 = x - self.sigma(t_next) * h.expm1() * eps
365
+ return x_1, eps_cache
366
+
367
+ def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
368
+ eps_cache = {} if eps_cache is None else eps_cache
369
+ h = t_next - t
370
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
371
+ s1 = t + r1 * h
372
+ u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
373
+ eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
374
+ x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps)
375
+ return x_2, eps_cache
376
+
377
+ def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
378
+ eps_cache = {} if eps_cache is None else eps_cache
379
+ h = t_next - t
380
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
381
+ s1 = t + r1 * h
382
+ s2 = t + r2 * h
383
+ u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
384
+ eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
385
+ u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps)
386
+ eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
387
+ x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
388
+ return x_3, eps_cache
389
+
390
+ def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
391
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
392
+ if not t_end > t_start and eta:
393
+ raise ValueError('eta must be 0 for reverse sampling')
394
+
395
+ m = math.floor(nfe / 3) + 1
396
+ ts = torch.linspace(t_start, t_end, m + 1, device=x.device)
397
+
398
+ if nfe % 3 == 0:
399
+ orders = [3] * (m - 2) + [2, 1]
400
+ else:
401
+ orders = [3] * (m - 1) + [nfe % 3]
402
+
403
+ for i in range(len(orders)):
404
+ eps_cache = {}
405
+ t, t_next = ts[i], ts[i + 1]
406
+ if eta:
407
+ sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
408
+ t_next_ = torch.minimum(t_end, self.t(sd))
409
+ su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
410
+ else:
411
+ t_next_, su = t_next, 0.
412
+
413
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
414
+ denoised = x - self.sigma(t) * eps
415
+ if self.info_callback is not None:
416
+ self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})
417
+
418
+ if orders[i] == 1:
419
+ x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
420
+ elif orders[i] == 2:
421
+ x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
422
+ else:
423
+ x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)
424
+
425
+ x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))
426
+
427
+ return x
428
+
429
+ def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
430
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
431
+ if order not in {2, 3}:
432
+ raise ValueError('order should be 2 or 3')
433
+ forward = t_end > t_start
434
+ if not forward and eta:
435
+ raise ValueError('eta must be 0 for reverse sampling')
436
+ h_init = abs(h_init) * (1 if forward else -1)
437
+ atol = torch.tensor(atol)
438
+ rtol = torch.tensor(rtol)
439
+ s = t_start
440
+ x_prev = x
441
+ accept = True
442
+ pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
443
+ info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}
444
+
445
+ while s < t_end - 1e-5 if forward else s > t_end + 1e-5:
446
+ eps_cache = {}
447
+ t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
448
+ if eta:
449
+ sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
450
+ t_ = torch.minimum(t_end, self.t(sd))
451
+ su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
452
+ else:
453
+ t_, su = t, 0.
454
+
455
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
456
+ denoised = x - self.sigma(s) * eps
457
+
458
+ if order == 2:
459
+ x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
460
+ x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
461
+ else:
462
+ x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
463
+ x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
464
+ delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
465
+ error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
466
+ accept = pid.propose_step(error)
467
+ if accept:
468
+ x_prev = x_low
469
+ x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
470
+ s = t
471
+ info['n_accept'] += 1
472
+ else:
473
+ info['n_reject'] += 1
474
+ info['nfe'] += order
475
+ info['steps'] += 1
476
+
477
+ if self.info_callback is not None:
478
+ self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})
479
+
480
+ return x, info
481
+
482
+
483
+ @torch.no_grad()
484
+ def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
485
+ """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927."""
486
+ if sigma_min <= 0 or sigma_max <= 0:
487
+ raise ValueError('sigma_min and sigma_max must not be 0')
488
+ with tqdm(total=n, disable=disable) as pbar:
489
+ dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
490
+ if callback is not None:
491
+ dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
492
+ return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler)
493
+
494
+
495
+ @torch.no_grad()
496
+ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
497
+ """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927."""
498
+ if sigma_min <= 0 or sigma_max <= 0:
499
+ raise ValueError('sigma_min and sigma_max must not be 0')
500
+ with tqdm(disable=disable) as pbar:
501
+ dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
502
+ if callback is not None:
503
+ dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
504
+ x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
505
+ if return_info:
506
+ return x, info
507
+ return x
508
+
509
+
510
+ @torch.no_grad()
511
+ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
512
+ """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
513
+ extra_args = {} if extra_args is None else extra_args
514
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
515
+ s_in = x.new_ones([x.shape[0]])
516
+ sigma_fn = lambda t: t.neg().exp()
517
+ t_fn = lambda sigma: sigma.log().neg()
518
+
519
+ for i in trange(len(sigmas) - 1, disable=disable):
520
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
521
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
522
+ if callback is not None:
523
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
524
+ if sigma_down == 0:
525
+ # Euler method
526
+ d = to_d(x, sigmas[i], denoised)
527
+ dt = sigma_down - sigmas[i]
528
+ x = x + d * dt
529
+ else:
530
+ # DPM-Solver++(2S)
531
+ t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
532
+ r = 1 / 2
533
+ h = t_next - t
534
+ s = t + r * h
535
+ x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
536
+ denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
537
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
538
+ # Noise addition
539
+ if sigmas[i + 1] > 0:
540
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
541
+ return x
542
+
543
+
544
+ @torch.no_grad()
545
+ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
546
+ """DPM-Solver++ (stochastic)."""
547
+ if len(sigmas) <= 1:
548
+ return x
549
+
550
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
551
+ extra_args = {} if extra_args is None else extra_args
552
+ seed = extra_args.get("seed", None)
553
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
554
+ s_in = x.new_ones([x.shape[0]])
555
+ sigma_fn = lambda t: t.neg().exp()
556
+ t_fn = lambda sigma: sigma.log().neg()
557
+
558
+ for i in trange(len(sigmas) - 1, disable=disable):
559
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
560
+ if callback is not None:
561
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
562
+ if sigmas[i + 1] == 0:
563
+ # Euler method
564
+ d = to_d(x, sigmas[i], denoised)
565
+ dt = sigmas[i + 1] - sigmas[i]
566
+ x = x + d * dt
567
+ else:
568
+ # DPM-Solver++
569
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
570
+ h = t_next - t
571
+ s = t + h * r
572
+ fac = 1 / (2 * r)
573
+
574
+ # Step 1
575
+ sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
576
+ s_ = t_fn(sd)
577
+ x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
578
+ x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
579
+ denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
580
+
581
+ # Step 2
582
+ sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
583
+ t_next_ = t_fn(sd)
584
+ denoised_d = (1 - fac) * denoised + fac * denoised_2
585
+ x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
586
+ x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
587
+ return x
588
+
589
+
590
+ @torch.no_grad()
591
+ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
592
+ """DPM-Solver++(2M)."""
593
+ extra_args = {} if extra_args is None else extra_args
594
+ s_in = x.new_ones([x.shape[0]])
595
+ sigma_fn = lambda t: t.neg().exp()
596
+ t_fn = lambda sigma: sigma.log().neg()
597
+ old_denoised = None
598
+
599
+ for i in trange(len(sigmas) - 1, disable=disable):
600
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
601
+ if callback is not None:
602
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
603
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
604
+ h = t_next - t
605
+ if old_denoised is None or sigmas[i + 1] == 0:
606
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
607
+ else:
608
+ h_last = t - t_fn(sigmas[i - 1])
609
+ r = h_last / h
610
+ denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
611
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
612
+ old_denoised = denoised
613
+ return x
614
+
615
+ @torch.no_grad()
616
+ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
617
+ """DPM-Solver++(2M) SDE."""
618
+ if len(sigmas) <= 1:
619
+ return x
620
+
621
+ if solver_type not in {'heun', 'midpoint'}:
622
+ raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
623
+
624
+ extra_args = {} if extra_args is None else extra_args
625
+ seed = extra_args.get("seed", None)
626
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
627
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
628
+ s_in = x.new_ones([x.shape[0]])
629
+
630
+ old_denoised = None
631
+ h_last = None
632
+ h = None
633
+
634
+ for i in trange(len(sigmas) - 1, disable=disable):
635
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
636
+ if callback is not None:
637
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
638
+ if sigmas[i + 1] == 0:
639
+ # Denoising step
640
+ x = denoised
641
+ else:
642
+ # DPM-Solver++(2M) SDE
643
+ t, s = -sigmas[i].log(), -sigmas[i + 1].log()
644
+ h = s - t
645
+ eta_h = eta * h
646
+
647
+ x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
648
+
649
+ if old_denoised is not None:
650
+ r = h_last / h
651
+ if solver_type == 'heun':
652
+ x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
653
+ elif solver_type == 'midpoint':
654
+ x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
655
+
656
+ if eta:
657
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
658
+
659
+ old_denoised = denoised
660
+ h_last = h
661
+ return x
662
+
663
+ @torch.no_grad()
664
+ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
665
+ """DPM-Solver++(3M) SDE."""
666
+
667
+ if len(sigmas) <= 1:
668
+ return x
669
+
670
+ extra_args = {} if extra_args is None else extra_args
671
+ seed = extra_args.get("seed", None)
672
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
673
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
674
+ s_in = x.new_ones([x.shape[0]])
675
+
676
+ denoised_1, denoised_2 = None, None
677
+ h, h_1, h_2 = None, None, None
678
+
679
+ for i in trange(len(sigmas) - 1, disable=disable):
680
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
681
+ if callback is not None:
682
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
683
+ if sigmas[i + 1] == 0:
684
+ # Denoising step
685
+ x = denoised
686
+ else:
687
+ t, s = -sigmas[i].log(), -sigmas[i + 1].log()
688
+ h = s - t
689
+ h_eta = h * (eta + 1)
690
+
691
+ x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised
692
+
693
+ if h_2 is not None:
694
+ r0 = h_1 / h
695
+ r1 = h_2 / h
696
+ d1_0 = (denoised - denoised_1) / r0
697
+ d1_1 = (denoised_1 - denoised_2) / r1
698
+ d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
699
+ d2 = (d1_0 - d1_1) / (r0 + r1)
700
+ phi_2 = h_eta.neg().expm1() / h_eta + 1
701
+ phi_3 = phi_2 / h_eta - 0.5
702
+ x = x + phi_2 * d1 - phi_3 * d2
703
+ elif h_1 is not None:
704
+ r = h_1 / h
705
+ d = (denoised - denoised_1) / r
706
+ phi_2 = h_eta.neg().expm1() / h_eta + 1
707
+ x = x + phi_2 * d
708
+
709
+ if eta:
710
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
711
+
712
+ denoised_1, denoised_2 = denoised, denoised_1
713
+ h_1, h_2 = h, h_1
714
+ return x
715
+
716
+ @torch.no_grad()
717
+ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
718
+ if len(sigmas) <= 1:
719
+ return x
720
+
721
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
722
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
723
+ return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
724
+
725
+ @torch.no_grad()
726
+ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
727
+ if len(sigmas) <= 1:
728
+ return x
729
+
730
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
731
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
732
+ return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
733
+
734
+ @torch.no_grad()
735
+ def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
736
+ if len(sigmas) <= 1:
737
+ return x
738
+
739
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
740
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
741
+ return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
742
+
743
+
744
+ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
745
+ alpha_cumprod = 1 / ((sigma * sigma) + 1)
746
+ alpha_cumprod_prev = 1 / ((sigma_prev * sigma_prev) + 1)
747
+ alpha = (alpha_cumprod / alpha_cumprod_prev)
748
+
749
+ mu = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise / (1 - alpha_cumprod).sqrt())
750
+ if sigma_prev > 0:
751
+ mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
752
+ return mu
753
+
754
+ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
755
+ extra_args = {} if extra_args is None else extra_args
756
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
757
+ s_in = x.new_ones([x.shape[0]])
758
+
759
+ for i in trange(len(sigmas) - 1, disable=disable):
760
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
761
+ if callback is not None:
762
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
763
+ x = step_function(x / torch.sqrt(1.0 + sigmas[i] ** 2.0), sigmas[i], sigmas[i + 1], (x - denoised) / sigmas[i], noise_sampler)
764
+ if sigmas[i + 1] != 0:
765
+ x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2.0)
766
+ return x
767
+
768
+
769
+ @torch.no_grad()
770
+ def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
771
+ return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
772
+
773
+ @torch.no_grad()
774
+ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
775
+ extra_args = {} if extra_args is None else extra_args
776
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
777
+ s_in = x.new_ones([x.shape[0]])
778
+ for i in trange(len(sigmas) - 1, disable=disable):
779
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
780
+ if callback is not None:
781
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
782
+
783
+ x = denoised
784
+ if sigmas[i + 1] > 0:
785
+ x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x)
786
+ return x
787
+
788
+
789
+
790
+ @torch.no_grad()
791
+ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
792
+ # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
793
+ extra_args = {} if extra_args is None else extra_args
794
+ s_in = x.new_ones([x.shape[0]])
795
+ s_end = sigmas[-1]
796
+ for i in trange(len(sigmas) - 1, disable=disable):
797
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
798
+ eps = torch.randn_like(x) * s_noise
799
+ sigma_hat = sigmas[i] * (gamma + 1)
800
+ if gamma > 0:
801
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
802
+ denoised = model(x, sigma_hat * s_in, **extra_args)
803
+ d = to_d(x, sigma_hat, denoised)
804
+ if callback is not None:
805
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
806
+ dt = sigmas[i + 1] - sigma_hat
807
+ if sigmas[i + 1] == s_end:
808
+ # Euler method
809
+ x = x + d * dt
810
+ elif sigmas[i + 2] == s_end:
811
+
812
+ # Heun's method
813
+ x_2 = x + d * dt
814
+ denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
815
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
816
+
817
+ w = 2 * sigmas[0]
818
+ w2 = sigmas[i+1]/w
819
+ w1 = 1 - w2
820
+
821
+ d_prime = d * w1 + d_2 * w2
822
+
823
+
824
+ x = x + d_prime * dt
825
+
826
+ else:
827
+ # Heun++
828
+ x_2 = x + d * dt
829
+ denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
830
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
831
+ dt_2 = sigmas[i + 2] - sigmas[i + 1]
832
+
833
+ x_3 = x_2 + d_2 * dt_2
834
+ denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args)
835
+ d_3 = to_d(x_3, sigmas[i + 2], denoised_3)
836
+
837
+ w = 3 * sigmas[0]
838
+ w2 = sigmas[i + 1] / w
839
+ w3 = sigmas[i + 2] / w
840
+ w1 = 1 - w2 - w3
841
+
842
+ d_prime = w1 * d + w2 * d_2 + w3 * d_3
843
+ x = x + d_prime * dt
844
+ return x
845
+
846
+
847
+ #From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
848
+ #under Apache 2 license
849
+ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
850
+ extra_args = {} if extra_args is None else extra_args
851
+ s_in = x.new_ones([x.shape[0]])
852
+
853
+ x_next = x
854
+
855
+ buffer_model = []
856
+ for i in trange(len(sigmas) - 1, disable=disable):
857
+ t_cur = sigmas[i]
858
+ t_next = sigmas[i + 1]
859
+
860
+ x_cur = x_next
861
+
862
+ denoised = model(x_cur, t_cur * s_in, **extra_args)
863
+ if callback is not None:
864
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
865
+
866
+ d_cur = (x_cur - denoised) / t_cur
867
+
868
+ order = min(max_order, i+1)
869
+ if order == 1: # First Euler step.
870
+ x_next = x_cur + (t_next - t_cur) * d_cur
871
+ elif order == 2: # Use one history point.
872
+ x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
873
+ elif order == 3: # Use two history points.
874
+ x_next = x_cur + (t_next - t_cur) * (23 * d_cur - 16 * buffer_model[-1] + 5 * buffer_model[-2]) / 12
875
+ elif order == 4: # Use three history points.
876
+ x_next = x_cur + (t_next - t_cur) * (55 * d_cur - 59 * buffer_model[-1] + 37 * buffer_model[-2] - 9 * buffer_model[-3]) / 24
877
+
878
+ if len(buffer_model) == max_order - 1:
879
+ for k in range(max_order - 2):
880
+ buffer_model[k] = buffer_model[k+1]
881
+ buffer_model[-1] = d_cur
882
+ else:
883
+ buffer_model.append(d_cur)
884
+
885
+ return x_next
886
+
887
+ #From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
888
+ #under Apache 2 license
889
+ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
890
+ extra_args = {} if extra_args is None else extra_args
891
+ s_in = x.new_ones([x.shape[0]])
892
+
893
+ x_next = x
894
+ t_steps = sigmas
895
+
896
+ buffer_model = []
897
+ for i in trange(len(sigmas) - 1, disable=disable):
898
+ t_cur = sigmas[i]
899
+ t_next = sigmas[i + 1]
900
+
901
+ x_cur = x_next
902
+
903
+ denoised = model(x_cur, t_cur * s_in, **extra_args)
904
+ if callback is not None:
905
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
906
+
907
+ d_cur = (x_cur - denoised) / t_cur
908
+
909
+ order = min(max_order, i+1)
910
+ if order == 1: # First Euler step.
911
+ x_next = x_cur + (t_next - t_cur) * d_cur
912
+ elif order == 2: # Use one history point.
913
+ h_n = (t_next - t_cur)
914
+ h_n_1 = (t_cur - t_steps[i-1])
915
+ coeff1 = (2 + (h_n / h_n_1)) / 2
916
+ coeff2 = -(h_n / h_n_1) / 2
917
+ x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1])
918
+ elif order == 3: # Use two history points.
919
+ h_n = (t_next - t_cur)
920
+ h_n_1 = (t_cur - t_steps[i-1])
921
+ h_n_2 = (t_steps[i-1] - t_steps[i-2])
922
+ temp = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
923
+ coeff1 = (2 + (h_n / h_n_1)) / 2 + temp
924
+ coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp
925
+ coeff3 = temp * h_n_1 / h_n_2
926
+ x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2])
927
+ elif order == 4: # Use three history points.
928
+ h_n = (t_next - t_cur)
929
+ h_n_1 = (t_cur - t_steps[i-1])
930
+ h_n_2 = (t_steps[i-1] - t_steps[i-2])
931
+ h_n_3 = (t_steps[i-2] - t_steps[i-3])
932
+ temp1 = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
933
+ temp2 = ((1 - h_n / (3 * (h_n + h_n_1))) / 2 + (1 - h_n / (2 * (h_n + h_n_1))) * h_n / (6 * (h_n + h_n_1 + h_n_2))) \
934
+ * (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3))
935
+ coeff1 = (2 + (h_n / h_n_1)) / 2 + temp1 + temp2
936
+ coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp1 - (1 + (h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3)))) * temp2
937
+ coeff3 = temp1 * h_n_1 / h_n_2 + ((h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * (1 + h_n_2 / h_n_3)) * temp2
938
+ coeff4 = -temp2 * (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * h_n_1 / h_n_2
939
+ x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2] + coeff4 * buffer_model[-3])
940
+
941
+ if len(buffer_model) == max_order - 1:
942
+ for k in range(max_order - 2):
943
+ buffer_model[k] = buffer_model[k+1]
944
+ buffer_model[-1] = d_cur.detach()
945
+ else:
946
+ buffer_model.append(d_cur.detach())
947
+
948
+ return x_next
949
+
950
+ #From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
951
+ #under Apache 2 license
952
+ @torch.no_grad()
953
+ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=3, deis_mode='tab'):
954
+ extra_args = {} if extra_args is None else extra_args
955
+ s_in = x.new_ones([x.shape[0]])
956
+
957
+ x_next = x
958
+ t_steps = sigmas
959
+
960
+ coeff_list = deis.get_deis_coeff_list(t_steps, max_order, deis_mode=deis_mode)
961
+
962
+ buffer_model = []
963
+ for i in trange(len(sigmas) - 1, disable=disable):
964
+ t_cur = sigmas[i]
965
+ t_next = sigmas[i + 1]
966
+
967
+ x_cur = x_next
968
+
969
+ denoised = model(x_cur, t_cur * s_in, **extra_args)
970
+ if callback is not None:
971
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
972
+
973
+ d_cur = (x_cur - denoised) / t_cur
974
+
975
+ order = min(max_order, i+1)
976
+ if t_next <= 0:
977
+ order = 1
978
+
979
+ if order == 1: # First Euler step.
980
+ x_next = x_cur + (t_next - t_cur) * d_cur
981
+ elif order == 2: # Use one history point.
982
+ coeff_cur, coeff_prev1 = coeff_list[i]
983
+ x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1]
984
+ elif order == 3: # Use two history points.
985
+ coeff_cur, coeff_prev1, coeff_prev2 = coeff_list[i]
986
+ x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2]
987
+ elif order == 4: # Use three history points.
988
+ coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3 = coeff_list[i]
989
+ x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] + coeff_prev3 * buffer_model[-3]
990
+
991
+ if len(buffer_model) == max_order - 1:
992
+ for k in range(max_order - 2):
993
+ buffer_model[k] = buffer_model[k+1]
994
+ buffer_model[-1] = d_cur.detach()
995
+ else:
996
+ buffer_model.append(d_cur.detach())
997
+
998
+ return x_next
999
+
1000
+ @torch.no_grad()
1001
+ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
1002
+ extra_args = {} if extra_args is None else extra_args
1003
+
1004
+ temp = [0]
1005
+ def post_cfg_function(args):
1006
+ temp[0] = args["uncond_denoised"]
1007
+ return args["denoised"]
1008
+
1009
+ model_options = extra_args.get("model_options", {}).copy()
1010
+ extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
1011
+
1012
+ s_in = x.new_ones([x.shape[0]])
1013
+ for i in trange(len(sigmas) - 1, disable=disable):
1014
+ sigma_hat = sigmas[i]
1015
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1016
+ d = to_d(x, sigma_hat, temp[0])
1017
+ if callback is not None:
1018
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1019
+ # Euler step (CFG++ form: the update is written in terms of the denoised prediction)
1021
+ x = denoised + d * sigmas[i + 1]
1022
+ return x
1023
+
1024
+ @torch.no_grad()
1025
+ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
1026
+ """Ancestral sampling with Euler method steps."""
1027
+ extra_args = {} if extra_args is None else extra_args
1028
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
1029
+
1030
+ temp = [0]
1031
+ def post_cfg_function(args):
1032
+ temp[0] = args["uncond_denoised"]
1033
+ return args["denoised"]
1034
+
1035
+ model_options = extra_args.get("model_options", {}).copy()
1036
+ extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
1037
+
1038
+ s_in = x.new_ones([x.shape[0]])
1039
+ for i in trange(len(sigmas) - 1, disable=disable):
1040
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
1041
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
1042
+ if callback is not None:
1043
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
1044
+ d = to_d(x, sigmas[i], temp[0])
1045
+ # Euler step (CFG++ form: the update is written in terms of the denoised prediction)
1047
+ x = denoised + d * sigma_down
1048
+ if sigmas[i + 1] > 0:
1049
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
1050
+ return x
Backend/comfy/k_diffusion/utils.py ADDED
@@ -0,0 +1,313 @@
1
+ from contextlib import contextmanager
2
+ import hashlib
3
+ import math
4
+ from pathlib import Path
5
+ import shutil
6
+ import urllib
7
+ import warnings
8
+
9
+ from PIL import Image
10
+ import torch
11
+ from torch import nn, optim
12
+ from torch.utils import data
13
+
14
+
15
+ def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
16
+ """Apply passed in transforms for HuggingFace Datasets."""
17
+ images = [transform(image.convert(mode)) for image in examples[image_key]]
18
+ return {image_key: images}
19
+
20
+
21
+ def append_dims(x, target_dims):
22
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
23
+ dims_to_append = target_dims - x.ndim
24
+ if dims_to_append < 0:
25
+ raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
26
+ expanded = x[(...,) + (None,) * dims_to_append]
27
+ # MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
28
+ # https://github.com/pytorch/pytorch/issues/84364
29
+ return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
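Illustrative call (shapes are arbitrary): a per-batch sigma vector gets trailing singleton dims so it broadcasts against a 4-D latent.

import torch

sigma = torch.tensor([1.0, 2.0])              # shape [2]
latent = torch.zeros(2, 4, 8, 8)
print(append_dims(sigma, latent.ndim).shape)  # torch.Size([2, 1, 1, 1])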
30
+
31
+
32
+ def n_params(module):
33
+ """Returns the number of trainable parameters in a module."""
34
+ return sum(p.numel() for p in module.parameters())
35
+
36
+
37
+ def download_file(path, url, digest=None):
38
+ """Downloads a file if it does not exist, optionally checking its SHA-256 hash."""
39
+ path = Path(path)
40
+ path.parent.mkdir(parents=True, exist_ok=True)
41
+ if not path.exists():
42
+ with urllib.request.urlopen(url) as response, open(path, 'wb') as f:
43
+ shutil.copyfileobj(response, f)
44
+ if digest is not None:
45
+ file_digest = hashlib.sha256(open(path, 'rb').read()).hexdigest()
46
+ if digest != file_digest:
47
+ raise OSError(f'hash of {path} (url: {url}) failed to validate')
48
+ return path
49
+
50
+
51
+ @contextmanager
52
+ def train_mode(model, mode=True):
53
+ """A context manager that places a model into training mode and restores
54
+ the previous mode on exit."""
55
+ modes = [module.training for module in model.modules()]
56
+ try:
57
+ yield model.train(mode)
58
+ finally:
59
+ for i, module in enumerate(model.modules()):
60
+ module.training = modes[i]
61
+
62
+
63
+ def eval_mode(model):
64
+ """A context manager that places a model into evaluation mode and restores
65
+ the previous mode on exit."""
66
+ return train_mode(model, False)
67
+
68
+
69
+ @torch.no_grad()
70
+ def ema_update(model, averaged_model, decay):
71
+ """Incorporates updated model parameters into an exponential moving averaged
72
+ version of a model. It should be called after each optimizer step."""
73
+ model_params = dict(model.named_parameters())
74
+ averaged_params = dict(averaged_model.named_parameters())
75
+ assert model_params.keys() == averaged_params.keys()
76
+
77
+ for name, param in model_params.items():
78
+ averaged_params[name].mul_(decay).add_(param, alpha=1 - decay)
79
+
80
+ model_buffers = dict(model.named_buffers())
81
+ averaged_buffers = dict(averaged_model.named_buffers())
82
+ assert model_buffers.keys() == averaged_buffers.keys()
83
+
84
+ for name, buf in model_buffers.items():
85
+ averaged_buffers[name].copy_(buf)
86
+
87
+
88
+ class EMAWarmup:
89
+ """Implements an EMA warmup using an inverse decay schedule.
90
+ If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
91
+ good values for models you plan to train for a million or more steps (reaches decay
92
+ factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
93
+ you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
94
+ 215.4k steps).
95
+ Args:
96
+ inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
97
+ power (float): Exponential factor of EMA warmup. Default: 1.
98
+ min_value (float): The minimum EMA decay rate. Default: 0.
99
+ max_value (float): The maximum EMA decay rate. Default: 1.
100
+ start_at (int): The epoch to start averaging at. Default: 0.
101
+ last_epoch (int): The index of last epoch. Default: 0.
102
+ """
103
+
104
+ def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
105
+ last_epoch=0):
106
+ self.inv_gamma = inv_gamma
107
+ self.power = power
108
+ self.min_value = min_value
109
+ self.max_value = max_value
110
+ self.start_at = start_at
111
+ self.last_epoch = last_epoch
112
+
113
+ def state_dict(self):
114
+ """Returns the state of the class as a :class:`dict`."""
115
+ return dict(self.__dict__.items())
116
+
117
+ def load_state_dict(self, state_dict):
118
+ """Loads the class's state.
119
+ Args:
120
+ state_dict (dict): scaler state. Should be an object returned
121
+ from a call to :meth:`state_dict`.
122
+ """
123
+ self.__dict__.update(state_dict)
124
+
125
+ def get_value(self):
126
+ """Gets the current EMA decay rate."""
127
+ epoch = max(0, self.last_epoch - self.start_at)
128
+ value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
129
+ return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value))
130
+
131
+ def step(self):
132
+ """Updates the step count."""
133
+ self.last_epoch += 1
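An illustrative training-loop fragment pairing EMAWarmup with ema_update above (the model and step count are placeholders, not part of the uploaded file):

import copy

net = nn.Linear(4, 4)               # stand-in for a real model
net_ema = copy.deepcopy(net)
ema_sched = EMAWarmup(power=2 / 3)  # one of the settings suggested in the docstring

for _ in range(3):                  # stand-in for the optimizer loop
    # ... forward / backward / optimizer.step() would go here ...
    ema_update(net, net_ema, ema_sched.get_value())
    ema_sched.step()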
134
+
135
+
136
+ class InverseLR(optim.lr_scheduler._LRScheduler):
137
+ """Implements an inverse decay learning rate schedule with an optional exponential
138
+ warmup. When last_epoch=-1, sets initial lr as lr.
139
+ inv_gamma is the number of steps/epochs required for the learning rate to decay to
140
+ (1 / 2)**power of its original value.
141
+ Args:
142
+ optimizer (Optimizer): Wrapped optimizer.
143
+ inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
144
+ power (float): Exponential factor of learning rate decay. Default: 1.
145
+ warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
146
+ Default: 0.
147
+ min_lr (float): The minimum learning rate. Default: 0.
148
+ last_epoch (int): The index of last epoch. Default: -1.
149
+ verbose (bool): If ``True``, prints a message to stdout for
150
+ each update. Default: ``False``.
151
+ """
152
+
153
+ def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
154
+ last_epoch=-1, verbose=False):
155
+ self.inv_gamma = inv_gamma
156
+ self.power = power
157
+ if not 0. <= warmup < 1:
158
+ raise ValueError('Invalid value for warmup')
159
+ self.warmup = warmup
160
+ self.min_lr = min_lr
161
+ super().__init__(optimizer, last_epoch, verbose)
162
+
163
+ def get_lr(self):
164
+ if not self._get_lr_called_within_step:
165
+ warnings.warn("To get the last learning rate computed by the scheduler, "
166
+ "please use `get_last_lr()`.")
167
+
168
+ return self._get_closed_form_lr()
169
+
170
+ def _get_closed_form_lr(self):
171
+ warmup = 1 - self.warmup ** (self.last_epoch + 1)
172
+ lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
173
+ return [warmup * max(self.min_lr, base_lr * lr_mult)
174
+ for base_lr in self.base_lrs]
175
+
176
+
177
+ class ExponentialLR(optim.lr_scheduler._LRScheduler):
178
+ """Implements an exponential learning rate schedule with an optional exponential
179
+ warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate
180
+ continuously by decay (default 0.5) every num_steps steps.
181
+ Args:
182
+ optimizer (Optimizer): Wrapped optimizer.
183
+ num_steps (float): The number of steps to decay the learning rate by decay in.
184
+ decay (float): The factor by which to decay the learning rate every num_steps
185
+ steps. Default: 0.5.
186
+ warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
187
+ Default: 0.
188
+ min_lr (float): The minimum learning rate. Default: 0.
189
+ last_epoch (int): The index of last epoch. Default: -1.
190
+ verbose (bool): If ``True``, prints a message to stdout for
191
+ each update. Default: ``False``.
192
+ """
193
+
194
+ def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
195
+ last_epoch=-1, verbose=False):
196
+ self.num_steps = num_steps
197
+ self.decay = decay
198
+ if not 0. <= warmup < 1:
199
+ raise ValueError('Invalid value for warmup')
200
+ self.warmup = warmup
201
+ self.min_lr = min_lr
202
+ super().__init__(optimizer, last_epoch, verbose)
203
+
204
+ def get_lr(self):
205
+ if not self._get_lr_called_within_step:
206
+ warnings.warn("To get the last learning rate computed by the scheduler, "
207
+ "please use `get_last_lr()`.")
208
+
209
+ return self._get_closed_form_lr()
210
+
211
+ def _get_closed_form_lr(self):
212
+ warmup = 1 - self.warmup ** (self.last_epoch + 1)
213
+ lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch
214
+ return [warmup * max(self.min_lr, base_lr * lr_mult)
215
+ for base_lr in self.base_lrs]
216
+
217
+
218
+ def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
219
+ """Draws samples from an lognormal distribution."""
220
+ return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp()
221
+
222
+
223
+ def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
224
+ """Draws samples from an optionally truncated log-logistic distribution."""
225
+ min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64)
226
+ max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64)
227
+ min_cdf = min_value.log().sub(loc).div(scale).sigmoid()
228
+ max_cdf = max_value.log().sub(loc).div(scale).sigmoid()
229
+ u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf
230
+ return u.logit().mul(scale).add(loc).exp().to(dtype)
231
+
232
+
233
+ def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
234
+ """Draws samples from an log-uniform distribution."""
235
+ min_value = math.log(min_value)
236
+ max_value = math.log(max_value)
237
+ return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp()
238
+
239
+
240
+ def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
241
+ """Draws samples from a truncated v-diffusion training timestep distribution."""
242
+ min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi
243
+ max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi
244
+ u = torch.rand(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf
245
+ return torch.tan(u * math.pi / 2) * sigma_data
246
+
247
+
248
+ def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32):
249
+ """Draws samples from a split lognormal distribution."""
250
+ n = torch.randn(shape, device=device, dtype=dtype).abs()
251
+ u = torch.rand(shape, device=device, dtype=dtype)
252
+ n_left = n * -scale_1 + loc
253
+ n_right = n * scale_2 + loc
254
+ ratio = scale_1 / (scale_1 + scale_2)
255
+ return torch.where(u < ratio, n_left, n_right).exp()
256
+
257
+
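The samplers above draw per-example noise levels or timesteps for training; a small illustrative sketch follows (the loc/scale/sigma_data values are placeholders, not recommendations):

    shape = (16,)
    sigmas_lognormal = rand_log_normal(shape, loc=-1.2, scale=1.2)       # lognormal sigma draw
    sigmas_v = rand_v_diffusion(shape, sigma_data=0.5, max_value=80.0)   # truncated v-diffusion timesteps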
258
+ class FolderOfImages(data.Dataset):
259
+ """Recursively finds all images in a directory. It does not support
260
+ classes/targets."""
261
+
262
+ IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}
263
+
264
+ def __init__(self, root, transform=None):
265
+ super().__init__()
266
+ self.root = Path(root)
267
+ self.transform = nn.Identity() if transform is None else transform
268
+ self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS)
269
+
270
+ def __repr__(self):
271
+ return f'FolderOfImages(root="{self.root}", len: {len(self)})'
272
+
273
+ def __len__(self):
274
+ return len(self.paths)
275
+
276
+ def __getitem__(self, key):
277
+ path = self.paths[key]
278
+ with open(path, 'rb') as f:
279
+ image = Image.open(f).convert('RGB')
280
+ image = self.transform(image)
281
+ return image,
282
+
283
+
284
+ class CSVLogger:
285
+ def __init__(self, filename, columns):
286
+ self.filename = Path(filename)
287
+ self.columns = columns
288
+ if self.filename.exists():
289
+ self.file = open(self.filename, 'a')
290
+ else:
291
+ self.file = open(self.filename, 'w')
292
+ self.write(*self.columns)
293
+
294
+ def write(self, *args):
295
+ print(*args, sep=',', file=self.file, flush=True)
296
+
297
+
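A minimal usage sketch for the CSVLogger above (the filename and columns are placeholders); it appends to an existing file and only writes the header when creating a new one:

    log = CSVLogger('metrics.csv', ['step', 'loss'])
    log.write(1, 0.253)
    log.write(2, 0.241)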
298
+ @contextmanager
299
+ def tf32_mode(cudnn=None, matmul=None):
300
+ """A context manager that sets whether TF32 is allowed on cuDNN or matmul."""
301
+ cudnn_old = torch.backends.cudnn.allow_tf32
302
+ matmul_old = torch.backends.cuda.matmul.allow_tf32
303
+ try:
304
+ if cudnn is not None:
305
+ torch.backends.cudnn.allow_tf32 = cudnn
306
+ if matmul is not None:
307
+ torch.backends.cuda.matmul.allow_tf32 = matmul
308
+ yield
309
+ finally:
310
+ if cudnn is not None:
311
+ torch.backends.cudnn.allow_tf32 = cudnn_old
312
+ if matmul is not None:
313
+ torch.backends.cuda.matmul.allow_tf32 = matmul_old
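A usage sketch tying the schedulers and the TF32 context manager together, assuming this module is importable as comfy.k_diffusion.utils (the toy model, step count, and hyperparameters are placeholders):

    import torch
    from torch import nn, optim
    from comfy.k_diffusion import utils

    model = nn.Linear(4, 4)                                      # toy model, illustration only
    opt = optim.AdamW(model.parameters(), lr=1e-3)
    # halve the learning rate every 10_000 steps, with a short exponential warmup
    sched = utils.ExponentialLR(opt, num_steps=10_000, decay=0.5, warmup=0.99)

    with utils.tf32_mode(matmul=True):                           # allow TF32 matmuls only inside this block
        for _ in range(100):
            opt.step()
            sched.step()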
Backend/comfy/latent_formats.py ADDED
@@ -0,0 +1,170 @@
1
+ import torch
2
+
3
+ class LatentFormat:
4
+ scale_factor = 1.0
5
+ latent_channels = 4
6
+ latent_rgb_factors = None
7
+ taesd_decoder_name = None
8
+
9
+ def process_in(self, latent):
10
+ return latent * self.scale_factor
11
+
12
+ def process_out(self, latent):
13
+ return latent / self.scale_factor
14
+
15
+ class SD15(LatentFormat):
16
+ def __init__(self, scale_factor=0.18215):
17
+ self.scale_factor = scale_factor
18
+ self.latent_rgb_factors = [
19
+ # R G B
20
+ [ 0.3512, 0.2297, 0.3227],
21
+ [ 0.3250, 0.4974, 0.2350],
22
+ [-0.2829, 0.1762, 0.2721],
23
+ [-0.2120, -0.2616, -0.7177]
24
+ ]
25
+ self.taesd_decoder_name = "taesd_decoder"
26
+
27
+ class SDXL(LatentFormat):
28
+ scale_factor = 0.13025
29
+
30
+ def __init__(self):
31
+ self.latent_rgb_factors = [
32
+ # R G B
33
+ [ 0.3920, 0.4054, 0.4549],
34
+ [-0.2634, -0.0196, 0.0653],
35
+ [ 0.0568, 0.1687, -0.0755],
36
+ [-0.3112, -0.2359, -0.2076]
37
+ ]
38
+ self.taesd_decoder_name = "taesdxl_decoder"
39
+
40
+ class SDXL_Playground_2_5(LatentFormat):
41
+ def __init__(self):
42
+ self.scale_factor = 0.5
43
+ self.latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1)
44
+ self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1)
45
+
46
+ self.latent_rgb_factors = [
47
+ # R G B
48
+ [ 0.3920, 0.4054, 0.4549],
49
+ [-0.2634, -0.0196, 0.0653],
50
+ [ 0.0568, 0.1687, -0.0755],
51
+ [-0.3112, -0.2359, -0.2076]
52
+ ]
53
+ self.taesd_decoder_name = "taesdxl_decoder"
54
+
55
+ def process_in(self, latent):
56
+ latents_mean = self.latents_mean.to(latent.device, latent.dtype)
57
+ latents_std = self.latents_std.to(latent.device, latent.dtype)
58
+ return (latent - latents_mean) * self.scale_factor / latents_std
59
+
60
+ def process_out(self, latent):
61
+ latents_mean = self.latents_mean.to(latent.device, latent.dtype)
62
+ latents_std = self.latents_std.to(latent.device, latent.dtype)
63
+ return latent * latents_std / self.scale_factor + latents_mean
64
+
65
+
66
+ class SD_X4(LatentFormat):
67
+ def __init__(self):
68
+ self.scale_factor = 0.08333
69
+ self.latent_rgb_factors = [
70
+ [-0.2340, -0.3863, -0.3257],
71
+ [ 0.0994, 0.0885, -0.0908],
72
+ [-0.2833, -0.2349, -0.3741],
73
+ [ 0.2523, -0.0055, -0.1651]
74
+ ]
75
+
76
+ class SC_Prior(LatentFormat):
77
+ latent_channels = 16
78
+ def __init__(self):
79
+ self.scale_factor = 1.0
80
+ self.latent_rgb_factors = [
81
+ [-0.0326, -0.0204, -0.0127],
82
+ [-0.1592, -0.0427, 0.0216],
83
+ [ 0.0873, 0.0638, -0.0020],
84
+ [-0.0602, 0.0442, 0.1304],
85
+ [ 0.0800, -0.0313, -0.1796],
86
+ [-0.0810, -0.0638, -0.1581],
87
+ [ 0.1791, 0.1180, 0.0967],
88
+ [ 0.0740, 0.1416, 0.0432],
89
+ [-0.1745, -0.1888, -0.1373],
90
+ [ 0.2412, 0.1577, 0.0928],
91
+ [ 0.1908, 0.0998, 0.0682],
92
+ [ 0.0209, 0.0365, -0.0092],
93
+ [ 0.0448, -0.0650, -0.1728],
94
+ [-0.1658, -0.1045, -0.1308],
95
+ [ 0.0542, 0.1545, 0.1325],
96
+ [-0.0352, -0.1672, -0.2541]
97
+ ]
98
+
99
+ class SC_B(LatentFormat):
100
+ def __init__(self):
101
+ self.scale_factor = 1.0 / 0.43
102
+ self.latent_rgb_factors = [
103
+ [ 0.1121, 0.2006, 0.1023],
104
+ [-0.2093, -0.0222, -0.0195],
105
+ [-0.3087, -0.1535, 0.0366],
106
+ [ 0.0290, -0.1574, -0.4078]
107
+ ]
108
+
109
+ class SD3(LatentFormat):
110
+ latent_channels = 16
111
+ def __init__(self):
112
+ self.scale_factor = 1.5305
113
+ self.shift_factor = 0.0609
114
+ self.latent_rgb_factors = [
115
+ [-0.0645, 0.0177, 0.1052],
116
+ [ 0.0028, 0.0312, 0.0650],
117
+ [ 0.1848, 0.0762, 0.0360],
118
+ [ 0.0944, 0.0360, 0.0889],
119
+ [ 0.0897, 0.0506, -0.0364],
120
+ [-0.0020, 0.1203, 0.0284],
121
+ [ 0.0855, 0.0118, 0.0283],
122
+ [-0.0539, 0.0658, 0.1047],
123
+ [-0.0057, 0.0116, 0.0700],
124
+ [-0.0412, 0.0281, -0.0039],
125
+ [ 0.1106, 0.1171, 0.1220],
126
+ [-0.0248, 0.0682, -0.0481],
127
+ [ 0.0815, 0.0846, 0.1207],
128
+ [-0.0120, -0.0055, -0.0867],
129
+ [-0.0749, -0.0634, -0.0456],
130
+ [-0.1418, -0.1457, -0.1259]
131
+ ]
132
+ self.taesd_decoder_name = "taesd3_decoder"
133
+
134
+ def process_in(self, latent):
135
+ return (latent - self.shift_factor) * self.scale_factor
136
+
137
+ def process_out(self, latent):
138
+ return (latent / self.scale_factor) + self.shift_factor
139
+
140
+ class StableAudio1(LatentFormat):
141
+ latent_channels = 64
142
+
143
+ class Flux(SD3):
144
+ def __init__(self):
145
+ self.scale_factor = 0.3611
146
+ self.shift_factor = 0.1159
147
+ self.latent_rgb_factors =[
148
+ [-0.0404, 0.0159, 0.0609],
149
+ [ 0.0043, 0.0298, 0.0850],
150
+ [ 0.0328, -0.0749, -0.0503],
151
+ [-0.0245, 0.0085, 0.0549],
152
+ [ 0.0966, 0.0894, 0.0530],
153
+ [ 0.0035, 0.0399, 0.0123],
154
+ [ 0.0583, 0.1184, 0.1262],
155
+ [-0.0191, -0.0206, -0.0306],
156
+ [-0.0324, 0.0055, 0.1001],
157
+ [ 0.0955, 0.0659, -0.0545],
158
+ [-0.0504, 0.0231, -0.0013],
159
+ [ 0.0500, -0.0008, -0.0088],
160
+ [ 0.0982, 0.0941, 0.0976],
161
+ [-0.1233, -0.0280, -0.0897],
162
+ [-0.0005, -0.0530, -0.0020],
163
+ [-0.1273, -0.0932, -0.0680]
164
+ ]
165
+
166
+ def process_in(self, latent):
167
+ return (latent - self.shift_factor) * self.scale_factor
168
+
169
+ def process_out(self, latent):
170
+ return (latent / self.scale_factor) + self.shift_factor
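A brief sketch of how these latent formats are applied around a VAE, assuming the module resolves as comfy.latent_formats (the tensor shape is a placeholder):

    import torch
    from comfy import latent_formats

    fmt = latent_formats.SD15()
    vae_latent = torch.randn(1, 4, 64, 64)        # raw VAE-encoded latent
    model_latent = fmt.process_in(vae_latent)     # scale into the diffusion model's latent space
    restored = fmt.process_out(model_latent)      # undo the scaling before decoding
    assert torch.allclose(restored, vae_latent, atol=1e-6)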
Backend/comfy/ldm/audio/autoencoder.py ADDED
@@ -0,0 +1,282 @@
1
+ # code adapted from: https://github.com/Stability-AI/stable-audio-tools
2
+
3
+ import torch
4
+ from torch import nn
5
+ from typing import Literal, Dict, Any
6
+ import math
7
+ import comfy.ops
8
+ ops = comfy.ops.disable_weight_init
9
+
10
+ def vae_sample(mean, scale):
11
+ stdev = nn.functional.softplus(scale) + 1e-4
12
+ var = stdev * stdev
13
+ logvar = torch.log(var)
14
+ latents = torch.randn_like(mean) * stdev + mean
15
+
16
+ kl = (mean * mean + var - logvar - 1).sum(1).mean()
17
+
18
+ return latents, kl
19
+
20
+ class VAEBottleneck(nn.Module):
21
+ def __init__(self):
22
+ super().__init__()
23
+ self.is_discrete = False
24
+
25
+ def encode(self, x, return_info=False, **kwargs):
26
+ info = {}
27
+
28
+ mean, scale = x.chunk(2, dim=1)
29
+
30
+ x, kl = vae_sample(mean, scale)
31
+
32
+ info["kl"] = kl
33
+
34
+ if return_info:
35
+ return x, info
36
+ else:
37
+ return x
38
+
39
+ def decode(self, x):
40
+ return x
41
+
42
+
43
+ def snake_beta(x, alpha, beta):
44
+ return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
45
+
46
+ # Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
47
+ class SnakeBeta(nn.Module):
48
+
49
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
50
+ super(SnakeBeta, self).__init__()
51
+ self.in_features = in_features
52
+
53
+ # initialize alpha
54
+ self.alpha_logscale = alpha_logscale
55
+ if self.alpha_logscale: # log scale alphas initialized to zeros
56
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
57
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
58
+ else: # linear scale alphas initialized to ones
59
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
60
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
61
+
62
+ # self.alpha.requires_grad = alpha_trainable
63
+ # self.beta.requires_grad = alpha_trainable
64
+
65
+ self.no_div_by_zero = 0.000000001
66
+
67
+ def forward(self, x):
68
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1).to(x.device) # line up with x to [B, C, T]
69
+ beta = self.beta.unsqueeze(0).unsqueeze(-1).to(x.device)
70
+ if self.alpha_logscale:
71
+ alpha = torch.exp(alpha)
72
+ beta = torch.exp(beta)
73
+ x = snake_beta(x, alpha, beta)
74
+
75
+ return x
76
+
77
+ def WNConv1d(*args, **kwargs):
78
+ try:
79
+ return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
80
+ except:
81
+ return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
82
+
83
+ def WNConvTranspose1d(*args, **kwargs):
84
+ try:
85
+ return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
86
+ except:
87
+ return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
88
+
89
+ def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
90
+ if activation == "elu":
91
+ act = torch.nn.ELU()
92
+ elif activation == "snake":
93
+ act = SnakeBeta(channels)
94
+ elif activation == "none":
95
+ act = torch.nn.Identity()
96
+ else:
97
+ raise ValueError(f"Unknown activation {activation}")
98
+
99
+ if antialias:
100
+ act = Activation1d(act)
101
+
102
+ return act
103
+
104
+
105
+ class ResidualUnit(nn.Module):
106
+ def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
107
+ super().__init__()
108
+
109
+ self.dilation = dilation
110
+
111
+ padding = (dilation * (7-1)) // 2
112
+
113
+ self.layers = nn.Sequential(
114
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
115
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
116
+ kernel_size=7, dilation=dilation, padding=padding),
117
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
118
+ WNConv1d(in_channels=out_channels, out_channels=out_channels,
119
+ kernel_size=1)
120
+ )
121
+
122
+ def forward(self, x):
123
+ res = x
124
+
125
+ #x = checkpoint(self.layers, x)
126
+ x = self.layers(x)
127
+
128
+ return x + res
129
+
130
+ class EncoderBlock(nn.Module):
131
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
132
+ super().__init__()
133
+
134
+ self.layers = nn.Sequential(
135
+ ResidualUnit(in_channels=in_channels,
136
+ out_channels=in_channels, dilation=1, use_snake=use_snake),
137
+ ResidualUnit(in_channels=in_channels,
138
+ out_channels=in_channels, dilation=3, use_snake=use_snake),
139
+ ResidualUnit(in_channels=in_channels,
140
+ out_channels=in_channels, dilation=9, use_snake=use_snake),
141
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
142
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
143
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
144
+ )
145
+
146
+ def forward(self, x):
147
+ return self.layers(x)
148
+
149
+ class DecoderBlock(nn.Module):
150
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
151
+ super().__init__()
152
+
153
+ if use_nearest_upsample:
154
+ upsample_layer = nn.Sequential(
155
+ nn.Upsample(scale_factor=stride, mode="nearest"),
156
+ WNConv1d(in_channels=in_channels,
157
+ out_channels=out_channels,
158
+ kernel_size=2*stride,
159
+ stride=1,
160
+ bias=False,
161
+ padding='same')
162
+ )
163
+ else:
164
+ upsample_layer = WNConvTranspose1d(in_channels=in_channels,
165
+ out_channels=out_channels,
166
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
167
+
168
+ self.layers = nn.Sequential(
169
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
170
+ upsample_layer,
171
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
172
+ dilation=1, use_snake=use_snake),
173
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
174
+ dilation=3, use_snake=use_snake),
175
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
176
+ dilation=9, use_snake=use_snake),
177
+ )
178
+
179
+ def forward(self, x):
180
+ return self.layers(x)
181
+
182
+ class OobleckEncoder(nn.Module):
183
+ def __init__(self,
184
+ in_channels=2,
185
+ channels=128,
186
+ latent_dim=32,
187
+ c_mults = [1, 2, 4, 8],
188
+ strides = [2, 4, 8, 8],
189
+ use_snake=False,
190
+ antialias_activation=False
191
+ ):
192
+ super().__init__()
193
+
194
+ c_mults = [1] + c_mults
195
+
196
+ self.depth = len(c_mults)
197
+
198
+ layers = [
199
+ WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
200
+ ]
201
+
202
+ for i in range(self.depth-1):
203
+ layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
204
+
205
+ layers += [
206
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
207
+ WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
208
+ ]
209
+
210
+ self.layers = nn.Sequential(*layers)
211
+
212
+ def forward(self, x):
213
+ return self.layers(x)
214
+
215
+
216
+ class OobleckDecoder(nn.Module):
217
+ def __init__(self,
218
+ out_channels=2,
219
+ channels=128,
220
+ latent_dim=32,
221
+ c_mults = [1, 2, 4, 8],
222
+ strides = [2, 4, 8, 8],
223
+ use_snake=False,
224
+ antialias_activation=False,
225
+ use_nearest_upsample=False,
226
+ final_tanh=True):
227
+ super().__init__()
228
+
229
+ c_mults = [1] + c_mults
230
+
231
+ self.depth = len(c_mults)
232
+
233
+ layers = [
234
+ WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
235
+ ]
236
+
237
+ for i in range(self.depth-1, 0, -1):
238
+ layers += [DecoderBlock(
239
+ in_channels=c_mults[i]*channels,
240
+ out_channels=c_mults[i-1]*channels,
241
+ stride=strides[i-1],
242
+ use_snake=use_snake,
243
+ antialias_activation=antialias_activation,
244
+ use_nearest_upsample=use_nearest_upsample
245
+ )
246
+ ]
247
+
248
+ layers += [
249
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
250
+ WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
251
+ nn.Tanh() if final_tanh else nn.Identity()
252
+ ]
253
+
254
+ self.layers = nn.Sequential(*layers)
255
+
256
+ def forward(self, x):
257
+ return self.layers(x)
258
+
259
+
260
+ class AudioOobleckVAE(nn.Module):
261
+ def __init__(self,
262
+ in_channels=2,
263
+ channels=128,
264
+ latent_dim=64,
265
+ c_mults = [1, 2, 4, 8, 16],
266
+ strides = [2, 4, 4, 8, 8],
267
+ use_snake=True,
268
+ antialias_activation=False,
269
+ use_nearest_upsample=False,
270
+ final_tanh=False):
271
+ super().__init__()
272
+ self.encoder = OobleckEncoder(in_channels, channels, latent_dim * 2, c_mults, strides, use_snake, antialias_activation)
273
+ self.decoder = OobleckDecoder(in_channels, channels, latent_dim, c_mults, strides, use_snake, antialias_activation,
274
+ use_nearest_upsample=use_nearest_upsample, final_tanh=final_tanh)
275
+ self.bottleneck = VAEBottleneck()
276
+
277
+ def encode(self, x):
278
+ return self.bottleneck.encode(self.encoder(x))
279
+
280
+ def decode(self, x):
281
+ return self.decoder(self.bottleneck.decode(x))
282
+
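A rough shape-check sketch for the Oobleck VAE above, assuming the module path comfy.ldm.audio.autoencoder; the weights are uninitialized here, so only the shapes are meaningful:

    import torch
    from comfy.ldm.audio.autoencoder import AudioOobleckVAE

    vae = AudioOobleckVAE()                       # defaults: stereo in, 64 latent channels, strides 2*4*4*8*8 = 2048
    audio = torch.randn(1, 2, 2048 * 32)          # [batch, channels, samples]
    latent = vae.encode(audio)                    # -> [1, 64, 32]
    recon = vae.decode(latent)                    # -> [1, 2, 2048 * 32]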
Backend/comfy/ldm/audio/dit.py ADDED
@@ -0,0 +1,891 @@
1
+ # code adapted from: https://github.com/Stability-AI/stable-audio-tools
2
+
3
+ from comfy.ldm.modules.attention import optimized_attention
4
+ import typing as tp
5
+
6
+ import torch
7
+
8
+ from einops import rearrange
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+ import math
12
+ import comfy.ops
13
+
14
+ class FourierFeatures(nn.Module):
15
+ def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
16
+ super().__init__()
17
+ assert out_features % 2 == 0
18
+ self.weight = nn.Parameter(torch.empty(
19
+ [out_features // 2, in_features], dtype=dtype, device=device))
20
+
21
+ def forward(self, input):
22
+ f = 2 * math.pi * input @ comfy.ops.cast_to_input(self.weight.T, input)
23
+ return torch.cat([f.cos(), f.sin()], dim=-1)
24
+
25
+ # norms
26
+ class LayerNorm(nn.Module):
27
+ def __init__(self, dim, bias=False, fix_scale=False, dtype=None, device=None):
28
+ """
29
+ Bias-less layernorm has been shown to be more stable; most newer models have moved towards RMSNorm, which is also bias-less.
30
+ """
31
+ super().__init__()
32
+
33
+ self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
34
+
35
+ if bias:
36
+ self.beta = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
37
+ else:
38
+ self.beta = None
39
+
40
+ def forward(self, x):
41
+ beta = self.beta
42
+ if beta is not None:
43
+ beta = comfy.ops.cast_to_input(beta, x)
44
+ return F.layer_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x), bias=beta)
45
+
46
+ class GLU(nn.Module):
47
+ def __init__(
48
+ self,
49
+ dim_in,
50
+ dim_out,
51
+ activation,
52
+ use_conv = False,
53
+ conv_kernel_size = 3,
54
+ dtype=None,
55
+ device=None,
56
+ operations=None,
57
+ ):
58
+ super().__init__()
59
+ self.act = activation
60
+ self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2), dtype=dtype, device=device)
61
+ self.use_conv = use_conv
62
+
63
+ def forward(self, x):
64
+ if self.use_conv:
65
+ x = rearrange(x, 'b n d -> b d n')
66
+ x = self.proj(x)
67
+ x = rearrange(x, 'b d n -> b n d')
68
+ else:
69
+ x = self.proj(x)
70
+
71
+ x, gate = x.chunk(2, dim = -1)
72
+ return x * self.act(gate)
73
+
74
+ class AbsolutePositionalEmbedding(nn.Module):
75
+ def __init__(self, dim, max_seq_len):
76
+ super().__init__()
77
+ self.scale = dim ** -0.5
78
+ self.max_seq_len = max_seq_len
79
+ self.emb = nn.Embedding(max_seq_len, dim)
80
+
81
+ def forward(self, x, pos = None, seq_start_pos = None):
82
+ seq_len, device = x.shape[1], x.device
83
+ assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
84
+
85
+ if pos is None:
86
+ pos = torch.arange(seq_len, device = device)
87
+
88
+ if seq_start_pos is not None:
89
+ pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
90
+
91
+ pos_emb = self.emb(pos)
92
+ pos_emb = pos_emb * self.scale
93
+ return pos_emb
94
+
95
+ class ScaledSinusoidalEmbedding(nn.Module):
96
+ def __init__(self, dim, theta = 10000):
97
+ super().__init__()
98
+ assert (dim % 2) == 0, 'dimension must be divisible by 2'
99
+ self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
100
+
101
+ half_dim = dim // 2
102
+ freq_seq = torch.arange(half_dim).float() / half_dim
103
+ inv_freq = theta ** -freq_seq
104
+ self.register_buffer('inv_freq', inv_freq, persistent = False)
105
+
106
+ def forward(self, x, pos = None, seq_start_pos = None):
107
+ seq_len, device = x.shape[1], x.device
108
+
109
+ if pos is None:
110
+ pos = torch.arange(seq_len, device = device)
111
+
112
+ if seq_start_pos is not None:
113
+ pos = pos - seq_start_pos[..., None]
114
+
115
+ emb = torch.einsum('i, j -> i j', pos, self.inv_freq)
116
+ emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
117
+ return emb * self.scale
118
+
119
+ class RotaryEmbedding(nn.Module):
120
+ def __init__(
121
+ self,
122
+ dim,
123
+ use_xpos = False,
124
+ scale_base = 512,
125
+ interpolation_factor = 1.,
126
+ base = 10000,
127
+ base_rescale_factor = 1.,
128
+ dtype=None,
129
+ device=None,
130
+ ):
131
+ super().__init__()
132
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
133
+ # has some connection to NTK literature
134
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
135
+ base *= base_rescale_factor ** (dim / (dim - 2))
136
+
137
+ # inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
138
+ self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype))
139
+
140
+ assert interpolation_factor >= 1.
141
+ self.interpolation_factor = interpolation_factor
142
+
143
+ if not use_xpos:
144
+ self.register_buffer('scale', None)
145
+ return
146
+
147
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
148
+
149
+ self.scale_base = scale_base
150
+ self.register_buffer('scale', scale)
151
+
152
+ def forward_from_seq_len(self, seq_len, device, dtype):
153
+ # device = self.inv_freq.device
154
+
155
+ t = torch.arange(seq_len, device=device, dtype=dtype)
156
+ return self.forward(t)
157
+
158
+ def forward(self, t):
159
+ # device = self.inv_freq.device
160
+ device = t.device
161
+ dtype = t.dtype
162
+
163
+ # t = t.to(torch.float32)
164
+
165
+ t = t / self.interpolation_factor
166
+
167
+ freqs = torch.einsum('i , j -> i j', t, comfy.ops.cast_to_input(self.inv_freq, t))
168
+ freqs = torch.cat((freqs, freqs), dim = -1)
169
+
170
+ if self.scale is None:
171
+ return freqs, 1.
172
+
173
+ power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
174
+ scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
175
+ scale = torch.cat((scale, scale), dim = -1)
176
+
177
+ return freqs, scale
178
+
179
+ def rotate_half(x):
180
+ x = rearrange(x, '... (j d) -> ... j d', j = 2)
181
+ x1, x2 = x.unbind(dim = -2)
182
+ return torch.cat((-x2, x1), dim = -1)
183
+
184
+ def apply_rotary_pos_emb(t, freqs, scale = 1):
185
+ out_dtype = t.dtype
186
+
187
+ # cast to float32 if necessary for numerical stability
188
+ dtype = t.dtype #reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
189
+ rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
190
+ freqs, t = freqs.to(dtype), t.to(dtype)
191
+ freqs = freqs[-seq_len:, :]
192
+
193
+ if t.ndim == 4 and freqs.ndim == 3:
194
+ freqs = rearrange(freqs, 'b n d -> b 1 n d')
195
+
196
+ # partial rotary embeddings, Wang et al. GPT-J
197
+ t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
198
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
199
+
200
+ t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
201
+
202
+ return torch.cat((t, t_unrotated), dim = -1)
203
+
204
+ class FeedForward(nn.Module):
205
+ def __init__(
206
+ self,
207
+ dim,
208
+ dim_out = None,
209
+ mult = 4,
210
+ no_bias = False,
211
+ glu = True,
212
+ use_conv = False,
213
+ conv_kernel_size = 3,
214
+ zero_init_output = True,
215
+ dtype=None,
216
+ device=None,
217
+ operations=None,
218
+ ):
219
+ super().__init__()
220
+ inner_dim = int(dim * mult)
221
+
222
+ # Default to SwiGLU
223
+
224
+ activation = nn.SiLU()
225
+
226
+ dim_out = dim if dim_out is None else dim_out
227
+
228
+ if glu:
229
+ linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
230
+ else:
231
+ linear_in = nn.Sequential(
232
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
233
+ operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
234
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
235
+ activation
236
+ )
237
+
238
+ linear_out = operations.Linear(inner_dim, dim_out, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device)
239
+
240
+ # # init last linear layer to 0
241
+ # if zero_init_output:
242
+ # nn.init.zeros_(linear_out.weight)
243
+ # if not no_bias:
244
+ # nn.init.zeros_(linear_out.bias)
245
+
246
+
247
+ self.ff = nn.Sequential(
248
+ linear_in,
249
+ Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
250
+ linear_out,
251
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
252
+ )
253
+
254
+ def forward(self, x):
255
+ return self.ff(x)
256
+
257
+ class Attention(nn.Module):
258
+ def __init__(
259
+ self,
260
+ dim,
261
+ dim_heads = 64,
262
+ dim_context = None,
263
+ causal = False,
264
+ zero_init_output=True,
265
+ qk_norm = False,
266
+ natten_kernel_size = None,
267
+ dtype=None,
268
+ device=None,
269
+ operations=None,
270
+ ):
271
+ super().__init__()
272
+ self.dim = dim
273
+ self.dim_heads = dim_heads
274
+ self.causal = causal
275
+
276
+ dim_kv = dim_context if dim_context is not None else dim
277
+
278
+ self.num_heads = dim // dim_heads
279
+ self.kv_heads = dim_kv // dim_heads
280
+
281
+ if dim_context is not None:
282
+ self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
283
+ self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device)
284
+ else:
285
+ self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device)
286
+
287
+ self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
288
+
289
+ # if zero_init_output:
290
+ # nn.init.zeros_(self.to_out.weight)
291
+
292
+ self.qk_norm = qk_norm
293
+
294
+
295
+ def forward(
296
+ self,
297
+ x,
298
+ context = None,
299
+ mask = None,
300
+ context_mask = None,
301
+ rotary_pos_emb = None,
302
+ causal = None
303
+ ):
304
+ h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
305
+
306
+ kv_input = context if has_context else x
307
+
308
+ if hasattr(self, 'to_q'):
309
+ # Use separate linear projections for q and k/v
310
+ q = self.to_q(x)
311
+ q = rearrange(q, 'b n (h d) -> b h n d', h = h)
312
+
313
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
314
+
315
+ k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
316
+ else:
317
+ # Use fused linear projection
318
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
319
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
320
+
321
+ # Normalize q and k for cosine sim attention
322
+ if self.qk_norm:
323
+ q = F.normalize(q, dim=-1)
324
+ k = F.normalize(k, dim=-1)
325
+
326
+ if rotary_pos_emb is not None and not has_context:
327
+ freqs, _ = rotary_pos_emb
328
+
329
+ q_dtype = q.dtype
330
+ k_dtype = k.dtype
331
+
332
+ q = q.to(torch.float32)
333
+ k = k.to(torch.float32)
334
+ freqs = freqs.to(torch.float32)
335
+
336
+ q = apply_rotary_pos_emb(q, freqs)
337
+ k = apply_rotary_pos_emb(k, freqs)
338
+
339
+ q = q.to(q_dtype)
340
+ k = k.to(k_dtype)
341
+
342
+ input_mask = context_mask
343
+
344
+ if input_mask is None and not has_context:
345
+ input_mask = mask
346
+
347
+ # determine masking
348
+ masks = []
349
+ final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
350
+
351
+ if input_mask is not None:
352
+ input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
353
+ masks.append(~input_mask)
354
+
355
+ # Other masks will be added here later
356
+
357
+ if len(masks) > 0:
358
+ final_attn_mask = ~or_reduce(masks)
359
+
360
+ n, device = q.shape[-2], q.device
361
+
362
+ causal = self.causal if causal is None else causal
363
+
364
+ if n == 1 and causal:
365
+ causal = False
366
+
367
+ if h != kv_h:
368
+ # Repeat interleave kv_heads to match q_heads
369
+ heads_per_kv_head = h // kv_h
370
+ k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
371
+
372
+ out = optimized_attention(q, k, v, h, skip_reshape=True)
373
+ out = self.to_out(out)
374
+
375
+ if mask is not None:
376
+ mask = rearrange(mask, 'b n -> b n 1')
377
+ out = out.masked_fill(~mask, 0.)
378
+
379
+ return out
380
+
381
+ class ConformerModule(nn.Module):
382
+ def __init__(
383
+ self,
384
+ dim,
385
+ norm_kwargs = {},
386
+ ):
387
+
388
+ super().__init__()
389
+
390
+ self.dim = dim
391
+
392
+ self.in_norm = LayerNorm(dim, **norm_kwargs)
393
+ self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
394
+ self.glu = GLU(dim, dim, nn.SiLU())
395
+ self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
396
+ self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
397
+ self.swish = nn.SiLU()
398
+ self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
399
+
400
+ def forward(self, x):
401
+ x = self.in_norm(x)
402
+ x = rearrange(x, 'b n d -> b d n')
403
+ x = self.pointwise_conv(x)
404
+ x = rearrange(x, 'b d n -> b n d')
405
+ x = self.glu(x)
406
+ x = rearrange(x, 'b n d -> b d n')
407
+ x = self.depthwise_conv(x)
408
+ x = rearrange(x, 'b d n -> b n d')
409
+ x = self.mid_norm(x)
410
+ x = self.swish(x)
411
+ x = rearrange(x, 'b n d -> b d n')
412
+ x = self.pointwise_conv_2(x)
413
+ x = rearrange(x, 'b d n -> b n d')
414
+
415
+ return x
416
+
417
+ class TransformerBlock(nn.Module):
418
+ def __init__(
419
+ self,
420
+ dim,
421
+ dim_heads = 64,
422
+ cross_attend = False,
423
+ dim_context = None,
424
+ global_cond_dim = None,
425
+ causal = False,
426
+ zero_init_branch_outputs = True,
427
+ conformer = False,
428
+ layer_ix = -1,
429
+ remove_norms = False,
430
+ attn_kwargs = {},
431
+ ff_kwargs = {},
432
+ norm_kwargs = {},
433
+ dtype=None,
434
+ device=None,
435
+ operations=None,
436
+ ):
437
+
438
+ super().__init__()
439
+ self.dim = dim
440
+ self.dim_heads = dim_heads
441
+ self.cross_attend = cross_attend
442
+ self.dim_context = dim_context
443
+ self.causal = causal
444
+
445
+ self.pre_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
446
+
447
+ self.self_attn = Attention(
448
+ dim,
449
+ dim_heads = dim_heads,
450
+ causal = causal,
451
+ zero_init_output=zero_init_branch_outputs,
452
+ dtype=dtype,
453
+ device=device,
454
+ operations=operations,
455
+ **attn_kwargs
456
+ )
457
+
458
+ if cross_attend:
459
+ self.cross_attend_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
460
+ self.cross_attn = Attention(
461
+ dim,
462
+ dim_heads = dim_heads,
463
+ dim_context=dim_context,
464
+ causal = causal,
465
+ zero_init_output=zero_init_branch_outputs,
466
+ dtype=dtype,
467
+ device=device,
468
+ operations=operations,
469
+ **attn_kwargs
470
+ )
471
+
472
+ self.ff_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
473
+ self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations,**ff_kwargs)
474
+
475
+ self.layer_ix = layer_ix
476
+
477
+ self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None
478
+
479
+ self.global_cond_dim = global_cond_dim
480
+
481
+ if global_cond_dim is not None:
482
+ self.to_scale_shift_gate = nn.Sequential(
483
+ nn.SiLU(),
484
+ nn.Linear(global_cond_dim, dim * 6, bias=False)
485
+ )
486
+
487
+ nn.init.zeros_(self.to_scale_shift_gate[1].weight)
488
+ #nn.init.zeros_(self.to_scale_shift_gate_self[1].bias)
489
+
490
+ def forward(
491
+ self,
492
+ x,
493
+ context = None,
494
+ global_cond=None,
495
+ mask = None,
496
+ context_mask = None,
497
+ rotary_pos_emb = None
498
+ ):
499
+ if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
500
+
501
+ scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)
502
+
503
+ # self-attention with adaLN
504
+ residual = x
505
+ x = self.pre_norm(x)
506
+ x = x * (1 + scale_self) + shift_self
507
+ x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
508
+ x = x * torch.sigmoid(1 - gate_self)
509
+ x = x + residual
510
+
511
+ if context is not None:
512
+ x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
513
+
514
+ if self.conformer is not None:
515
+ x = x + self.conformer(x)
516
+
517
+ # feedforward with adaLN
518
+ residual = x
519
+ x = self.ff_norm(x)
520
+ x = x * (1 + scale_ff) + shift_ff
521
+ x = self.ff(x)
522
+ x = x * torch.sigmoid(1 - gate_ff)
523
+ x = x + residual
524
+
525
+ else:
526
+ x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
527
+
528
+ if context is not None:
529
+ x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
530
+
531
+ if self.conformer is not None:
532
+ x = x + self.conformer(x)
533
+
534
+ x = x + self.ff(self.ff_norm(x))
535
+
536
+ return x
537
+
538
+ class ContinuousTransformer(nn.Module):
539
+ def __init__(
540
+ self,
541
+ dim,
542
+ depth,
543
+ *,
544
+ dim_in = None,
545
+ dim_out = None,
546
+ dim_heads = 64,
547
+ cross_attend=False,
548
+ cond_token_dim=None,
549
+ global_cond_dim=None,
550
+ causal=False,
551
+ rotary_pos_emb=True,
552
+ zero_init_branch_outputs=True,
553
+ conformer=False,
554
+ use_sinusoidal_emb=False,
555
+ use_abs_pos_emb=False,
556
+ abs_pos_emb_max_length=10000,
557
+ dtype=None,
558
+ device=None,
559
+ operations=None,
560
+ **kwargs
561
+ ):
562
+
563
+ super().__init__()
564
+
565
+ self.dim = dim
566
+ self.depth = depth
567
+ self.causal = causal
568
+ self.layers = nn.ModuleList([])
569
+
570
+ self.project_in = operations.Linear(dim_in, dim, bias=False, dtype=dtype, device=device) if dim_in is not None else nn.Identity()
571
+ self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()
572
+
573
+ if rotary_pos_emb:
574
+ self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32), device=device, dtype=dtype)
575
+ else:
576
+ self.rotary_pos_emb = None
577
+
578
+ self.use_sinusoidal_emb = use_sinusoidal_emb
579
+ if use_sinusoidal_emb:
580
+ self.pos_emb = ScaledSinusoidalEmbedding(dim)
581
+
582
+ self.use_abs_pos_emb = use_abs_pos_emb
583
+ if use_abs_pos_emb:
584
+ self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)
585
+
586
+ for i in range(depth):
587
+ self.layers.append(
588
+ TransformerBlock(
589
+ dim,
590
+ dim_heads = dim_heads,
591
+ cross_attend = cross_attend,
592
+ dim_context = cond_token_dim,
593
+ global_cond_dim = global_cond_dim,
594
+ causal = causal,
595
+ zero_init_branch_outputs = zero_init_branch_outputs,
596
+ conformer=conformer,
597
+ layer_ix=i,
598
+ dtype=dtype,
599
+ device=device,
600
+ operations=operations,
601
+ **kwargs
602
+ )
603
+ )
604
+
605
+ def forward(
606
+ self,
607
+ x,
608
+ mask = None,
609
+ prepend_embeds = None,
610
+ prepend_mask = None,
611
+ global_cond = None,
612
+ return_info = False,
613
+ **kwargs
614
+ ):
615
+ batch, seq, device = *x.shape[:2], x.device
616
+
617
+ info = {
618
+ "hidden_states": [],
619
+ }
620
+
621
+ x = self.project_in(x)
622
+
623
+ if prepend_embeds is not None:
624
+ prepend_length, prepend_dim = prepend_embeds.shape[1:]
625
+
626
+ assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
627
+
628
+ x = torch.cat((prepend_embeds, x), dim = -2)
629
+
630
+ if prepend_mask is not None or mask is not None:
631
+ mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool)
632
+ prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool)
633
+
634
+ mask = torch.cat((prepend_mask, mask), dim = -1)
635
+
636
+ # Attention layers
637
+
638
+ if self.rotary_pos_emb is not None:
639
+ rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
640
+ else:
641
+ rotary_pos_emb = None
642
+
643
+ if self.use_sinusoidal_emb or self.use_abs_pos_emb:
644
+ x = x + self.pos_emb(x)
645
+
646
+ # Iterate over the transformer layers
647
+ for layer in self.layers:
648
+ x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
649
+ # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
650
+
651
+ if return_info:
652
+ info["hidden_states"].append(x)
653
+
654
+ x = self.project_out(x)
655
+
656
+ if return_info:
657
+ return x, info
658
+
659
+ return x
660
+
661
+ class AudioDiffusionTransformer(nn.Module):
662
+ def __init__(self,
663
+ io_channels=64,
664
+ patch_size=1,
665
+ embed_dim=1536,
666
+ cond_token_dim=768,
667
+ project_cond_tokens=False,
668
+ global_cond_dim=1536,
669
+ project_global_cond=True,
670
+ input_concat_dim=0,
671
+ prepend_cond_dim=0,
672
+ depth=24,
673
+ num_heads=24,
674
+ transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
675
+ global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
676
+ audio_model="",
677
+ dtype=None,
678
+ device=None,
679
+ operations=None,
680
+ **kwargs):
681
+
682
+ super().__init__()
683
+
684
+ self.dtype = dtype
685
+ self.cond_token_dim = cond_token_dim
686
+
687
+ # Timestep embeddings
688
+ timestep_features_dim = 256
689
+
690
+ self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device)
691
+
692
+ self.to_timestep_embed = nn.Sequential(
693
+ operations.Linear(timestep_features_dim, embed_dim, bias=True, dtype=dtype, device=device),
694
+ nn.SiLU(),
695
+ operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device),
696
+ )
697
+
698
+ if cond_token_dim > 0:
699
+ # Conditioning tokens
700
+
701
+ cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
702
+ self.to_cond_embed = nn.Sequential(
703
+ operations.Linear(cond_token_dim, cond_embed_dim, bias=False, dtype=dtype, device=device),
704
+ nn.SiLU(),
705
+ operations.Linear(cond_embed_dim, cond_embed_dim, bias=False, dtype=dtype, device=device)
706
+ )
707
+ else:
708
+ cond_embed_dim = 0
709
+
710
+ if global_cond_dim > 0:
711
+ # Global conditioning
712
+ global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
713
+ self.to_global_embed = nn.Sequential(
714
+ operations.Linear(global_cond_dim, global_embed_dim, bias=False, dtype=dtype, device=device),
715
+ nn.SiLU(),
716
+ operations.Linear(global_embed_dim, global_embed_dim, bias=False, dtype=dtype, device=device)
717
+ )
718
+
719
+ if prepend_cond_dim > 0:
720
+ # Prepend conditioning
721
+ self.to_prepend_embed = nn.Sequential(
722
+ operations.Linear(prepend_cond_dim, embed_dim, bias=False, dtype=dtype, device=device),
723
+ nn.SiLU(),
724
+ operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
725
+ )
726
+
727
+ self.input_concat_dim = input_concat_dim
728
+
729
+ dim_in = io_channels + self.input_concat_dim
730
+
731
+ self.patch_size = patch_size
732
+
733
+ # Transformer
734
+
735
+ self.transformer_type = transformer_type
736
+
737
+ self.global_cond_type = global_cond_type
738
+
739
+ if self.transformer_type == "continuous_transformer":
740
+
741
+ global_dim = None
742
+
743
+ if self.global_cond_type == "adaLN":
744
+ # The global conditioning is projected to the embed_dim already at this point
745
+ global_dim = embed_dim
746
+
747
+ self.transformer = ContinuousTransformer(
748
+ dim=embed_dim,
749
+ depth=depth,
750
+ dim_heads=embed_dim // num_heads,
751
+ dim_in=dim_in * patch_size,
752
+ dim_out=io_channels * patch_size,
753
+ cross_attend = cond_token_dim > 0,
754
+ cond_token_dim = cond_embed_dim,
755
+ global_cond_dim=global_dim,
756
+ dtype=dtype,
757
+ device=device,
758
+ operations=operations,
759
+ **kwargs
760
+ )
761
+ else:
762
+ raise ValueError(f"Unknown transformer type: {self.transformer_type}")
763
+
764
+ self.preprocess_conv = operations.Conv1d(dim_in, dim_in, 1, bias=False, dtype=dtype, device=device)
765
+ self.postprocess_conv = operations.Conv1d(io_channels, io_channels, 1, bias=False, dtype=dtype, device=device)
766
+
767
+ def _forward(
768
+ self,
769
+ x,
770
+ t,
771
+ mask=None,
772
+ cross_attn_cond=None,
773
+ cross_attn_cond_mask=None,
774
+ input_concat_cond=None,
775
+ global_embed=None,
776
+ prepend_cond=None,
777
+ prepend_cond_mask=None,
778
+ return_info=False,
779
+ **kwargs):
780
+
781
+ if cross_attn_cond is not None:
782
+ cross_attn_cond = self.to_cond_embed(cross_attn_cond)
783
+
784
+ if global_embed is not None:
785
+ # Project the global conditioning to the embedding dimension
786
+ global_embed = self.to_global_embed(global_embed)
787
+
788
+ prepend_inputs = None
789
+ prepend_mask = None
790
+ prepend_length = 0
791
+ if prepend_cond is not None:
792
+ # Project the prepend conditioning to the embedding dimension
793
+ prepend_cond = self.to_prepend_embed(prepend_cond)
794
+
795
+ prepend_inputs = prepend_cond
796
+ if prepend_cond_mask is not None:
797
+ prepend_mask = prepend_cond_mask
798
+
799
+ if input_concat_cond is not None:
800
+
801
+ # Interpolate input_concat_cond to the same length as x
802
+ if input_concat_cond.shape[2] != x.shape[2]:
803
+ input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
804
+
805
+ x = torch.cat([x, input_concat_cond], dim=1)
806
+
807
+ # Get the batch of timestep embeddings
808
+ timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None]).to(x.dtype)) # (b, embed_dim)
809
+
810
+ # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
811
+ if global_embed is not None:
812
+ global_embed = global_embed + timestep_embed
813
+ else:
814
+ global_embed = timestep_embed
815
+
816
+ # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
817
+ if self.global_cond_type == "prepend":
818
+ if prepend_inputs is None:
819
+ # Prepend inputs are just the global embed, and the mask is all ones
820
+ prepend_inputs = global_embed.unsqueeze(1)
821
+ prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
822
+ else:
823
+ # Prepend inputs are the prepend conditioning + the global embed
824
+ prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
825
+ prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)
826
+
827
+ prepend_length = prepend_inputs.shape[1]
828
+
829
+ x = self.preprocess_conv(x) + x
830
+
831
+ x = rearrange(x, "b c t -> b t c")
832
+
833
+ extra_args = {}
834
+
835
+ if self.global_cond_type == "adaLN":
836
+ extra_args["global_cond"] = global_embed
837
+
838
+ if self.patch_size > 1:
839
+ x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
840
+
841
+ if self.transformer_type == "x-transformers":
842
+ output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs)
843
+ elif self.transformer_type == "continuous_transformer":
844
+ output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)
845
+
846
+ if return_info:
847
+ output, info = output
848
+ elif self.transformer_type == "mm_transformer":
849
+ output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs)
850
+
851
+ output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]
852
+
853
+ if self.patch_size > 1:
854
+ output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
855
+
856
+ output = self.postprocess_conv(output) + output
857
+
858
+ if return_info:
859
+ return output, info
860
+
861
+ return output
862
+
863
+ def forward(
864
+ self,
865
+ x,
866
+ timestep,
867
+ context=None,
868
+ context_mask=None,
869
+ input_concat_cond=None,
870
+ global_embed=None,
871
+ negative_global_embed=None,
872
+ prepend_cond=None,
873
+ prepend_cond_mask=None,
874
+ mask=None,
875
+ return_info=False,
876
+ control=None,
877
+ transformer_options={},
878
+ **kwargs):
879
+ return self._forward(
880
+ x,
881
+ timestep,
882
+ cross_attn_cond=context,
883
+ cross_attn_cond_mask=context_mask,
884
+ input_concat_cond=input_concat_cond,
885
+ global_embed=global_embed,
886
+ prepend_cond=prepend_cond,
887
+ prepend_cond_mask=prepend_cond_mask,
888
+ mask=mask,
889
+ return_info=return_info,
890
+ **kwargs
891
+ )
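As a small illustration of the rotary helpers defined above, standard RoPE frequencies can be built by hand and applied to a query tensor (the inv_freq recipe below is the usual RoPE construction and is an assumption; in the model itself inv_freq is loaded from the checkpoint):

    import torch
    from comfy.ldm.audio.dit import apply_rotary_pos_emb

    b, h, n, d = 1, 8, 16, 64                     # batch, heads, sequence length, head dim
    q = torch.randn(b, h, n, d)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))
    t = torch.arange(n).float()
    freqs = torch.einsum('i,j->ij', t, inv_freq)
    freqs = torch.cat((freqs, freqs), dim=-1)     # [n, d], matching rotate_half's half-split layout
    q_rot = apply_rotary_pos_emb(q, freqs)        # same shape as q, with positions rotated in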
Backend/comfy/ldm/audio/embedders.py ADDED
@@ -0,0 +1,108 @@
1
+ # code adapted from: https://github.com/Stability-AI/stable-audio-tools
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch import Tensor, einsum
6
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
7
+ from einops import rearrange
8
+ import math
9
+ import comfy.ops
10
+
11
+ class LearnedPositionalEmbedding(nn.Module):
12
+ """Used for continuous time"""
13
+
14
+ def __init__(self, dim: int):
15
+ super().__init__()
16
+ assert (dim % 2) == 0
17
+ half_dim = dim // 2
18
+ self.weights = nn.Parameter(torch.empty(half_dim))
19
+
20
+ def forward(self, x: Tensor) -> Tensor:
21
+ x = rearrange(x, "b -> b 1")
22
+ freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi
23
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
24
+ fouriered = torch.cat((x, fouriered), dim=-1)
25
+ return fouriered
26
+
27
+ def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
28
+ return nn.Sequential(
29
+ LearnedPositionalEmbedding(dim),
30
+ comfy.ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features),
31
+ )
32
+
33
+
34
+ class NumberEmbedder(nn.Module):
35
+ def __init__(
36
+ self,
37
+ features: int,
38
+ dim: int = 256,
39
+ ):
40
+ super().__init__()
41
+ self.features = features
42
+ self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
43
+
44
+ def forward(self, x: Union[List[float], Tensor]) -> Tensor:
45
+ if not torch.is_tensor(x):
46
+ device = next(self.embedding.parameters()).device
47
+ x = torch.tensor(x, device=device)
48
+ assert isinstance(x, Tensor)
49
+ shape = x.shape
50
+ x = rearrange(x, "... -> (...)")
51
+ embedding = self.embedding(x)
52
+ x = embedding.view(*shape, self.features)
53
+ return x # type: ignore
54
+
55
+
56
+ class Conditioner(nn.Module):
57
+ def __init__(
58
+ self,
59
+ dim: int,
60
+ output_dim: int,
61
+ project_out: bool = False
62
+ ):
63
+
64
+ super().__init__()
65
+
66
+ self.dim = dim
67
+ self.output_dim = output_dim
68
+ self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity()
69
+
70
+ def forward(self, x):
71
+ raise NotImplementedError()
72
+
73
+ class NumberConditioner(Conditioner):
74
+ '''
75
+ Conditioner that takes a list of floats, normalizes them to a given range, and returns a list of embeddings
76
+ '''
77
+ def __init__(self,
78
+ output_dim: int,
79
+ min_val: float=0,
80
+ max_val: float=1
81
+ ):
82
+ super().__init__(output_dim, output_dim)
83
+
84
+ self.min_val = min_val
85
+ self.max_val = max_val
86
+
87
+ self.embedder = NumberEmbedder(features=output_dim)
88
+
89
+ def forward(self, floats, device=None):
90
+ # Cast the inputs to floats
91
+ floats = [float(x) for x in floats]
92
+
93
+ if device is None:
94
+ device = next(self.embedder.parameters()).device
95
+
96
+ floats = torch.tensor(floats).to(device)
97
+
98
+ floats = floats.clamp(self.min_val, self.max_val)
99
+
100
+ normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)
101
+
102
+ # Cast floats to same type as embedder
103
+ embedder_dtype = next(self.embedder.parameters()).dtype
104
+ normalized_floats = normalized_floats.to(embedder_dtype)
105
+
106
+ float_embeds = self.embedder(normalized_floats).unsqueeze(1)
107
+
108
+ return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
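A minimal sketch of the NumberConditioner above (the output_dim, value range, and inputs are placeholders; the embedder weights are untrained here):

    import torch
    from comfy.ldm.audio.embedders import NumberConditioner

    cond = NumberConditioner(output_dim=768, min_val=0.0, max_val=512.0)
    embeds, mask = cond([30.0, 247.0])            # one float per batch item, e.g. seconds values
    # embeds: [2, 1, 768] embeddings, mask: [2, 1] of ones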
Backend/comfy/ldm/aura/mmdit.py ADDED
@@ -0,0 +1,478 @@
1
+ #AuraFlow MMDiT
2
+ #Originally written by the AuraFlow Authors
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from comfy.ldm.modules.attention import optimized_attention
11
+ import comfy.ops
12
+ import comfy.ldm.common_dit
13
+
14
+ def modulate(x, shift, scale):
15
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
16
+
17
+
18
+ def find_multiple(n: int, k: int) -> int:
19
+ if n % k == 0:
20
+ return n
21
+ return n + k - (n % k)
22
+
23
+
24
+ class MLP(nn.Module):
25
+ def __init__(self, dim, hidden_dim=None, dtype=None, device=None, operations=None) -> None:
26
+ super().__init__()
27
+ if hidden_dim is None:
28
+ hidden_dim = 4 * dim
29
+
30
+ n_hidden = int(2 * hidden_dim / 3)
31
+ n_hidden = find_multiple(n_hidden, 256)
32
+
33
+ self.c_fc1 = operations.Linear(dim, n_hidden, bias=False, dtype=dtype, device=device)
34
+ self.c_fc2 = operations.Linear(dim, n_hidden, bias=False, dtype=dtype, device=device)
35
+ self.c_proj = operations.Linear(n_hidden, dim, bias=False, dtype=dtype, device=device)
36
+
37
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
38
+ x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
39
+ x = self.c_proj(x)
40
+ return x
41
+
42
+
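A tiny sketch of the helpers above, assuming the module resolves as comfy.ldm.aura.mmdit (the dim value is a placeholder):

    import torch
    from comfy.ldm.aura.mmdit import find_multiple, modulate

    hidden = find_multiple(int(2 * (4 * 1024) / 3), 256)   # SwiGLU hidden width, rounded up to a multiple of 256
    x = torch.randn(2, 16, 8)                               # [batch, tokens, dim]
    shift, scale = torch.zeros(2, 8), torch.ones(2, 8)
    y = modulate(x, shift, scale)                           # adaLN-style modulation: x * (1 + scale) + shift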
43
+ class MultiHeadLayerNorm(nn.Module):
44
+ def __init__(self, hidden_size=None, eps=1e-5, dtype=None, device=None):
45
+ # Copy pasta from https://github.com/huggingface/transformers/blob/e5f71ecaae50ea476d1e12351003790273c4b2ed/src/transformers/models/cohere/modeling_cohere.py#L78
46
+
47
+ super().__init__()
48
+ self.weight = nn.Parameter(torch.empty(hidden_size, dtype=dtype, device=device))
49
+ self.variance_epsilon = eps
50
+
51
+ def forward(self, hidden_states):
52
+ input_dtype = hidden_states.dtype
53
+ hidden_states = hidden_states.to(torch.float32)
54
+ mean = hidden_states.mean(-1, keepdim=True)
55
+ variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
56
+ hidden_states = (hidden_states - mean) * torch.rsqrt(
57
+ variance + self.variance_epsilon
58
+ )
59
+ hidden_states = self.weight.to(torch.float32) * hidden_states
60
+ return hidden_states.to(input_dtype)
61
+
62
+ class SingleAttention(nn.Module):
63
+ def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
64
+ super().__init__()
65
+
66
+ self.n_heads = n_heads
67
+ self.head_dim = dim // n_heads
68
+
69
+ # this is for cond
70
+ self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
71
+ self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
72
+ self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
73
+ self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
74
+
75
+ self.q_norm1 = (
76
+ MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
77
+ if mh_qknorm
78
+ else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
79
+ )
80
+ self.k_norm1 = (
81
+ MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
82
+ if mh_qknorm
83
+ else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
84
+ )
85
+
86
+ #@torch.compile()
87
+ def forward(self, c):
88
+
89
+ bsz, seqlen1, _ = c.shape
90
+
91
+ q, k, v = self.w1q(c), self.w1k(c), self.w1v(c)
92
+ q = q.view(bsz, seqlen1, self.n_heads, self.head_dim)
93
+ k = k.view(bsz, seqlen1, self.n_heads, self.head_dim)
94
+ v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
95
+ q, k = self.q_norm1(q), self.k_norm1(k)
96
+
97
+ output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
98
+ c = self.w1o(output)
99
+ return c
100
+
101
+
102
+
103
+ class DoubleAttention(nn.Module):
104
+ def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
105
+ super().__init__()
106
+
107
+ self.n_heads = n_heads
108
+ self.head_dim = dim // n_heads
109
+
110
+ # this is for cond
111
+ self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
112
+ self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
113
+ self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
114
+ self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
115
+
116
+ # this is for x
117
+ self.w2q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
118
+ self.w2k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
119
+ self.w2v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
120
+ self.w2o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
121
+
122
+ self.q_norm1 = (
123
+ MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
124
+ if mh_qknorm
125
+ else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
126
+ )
127
+ self.k_norm1 = (
128
+ MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
129
+ if mh_qknorm
130
+ else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
131
+ )
132
+
133
+ self.q_norm2 = (
134
+ MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
135
+ if mh_qknorm
136
+ else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
137
+ )
138
+ self.k_norm2 = (
139
+ MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
140
+ if mh_qknorm
141
+ else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
142
+ )
143
+
144
+
145
+ #@torch.compile()
146
+ def forward(self, c, x):
147
+
148
+ bsz, seqlen1, _ = c.shape
149
+ bsz, seqlen2, _ = x.shape
150
+ seqlen = seqlen1 + seqlen2
151
+
152
+ cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c)
153
+ cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim)
154
+ ck = ck.view(bsz, seqlen1, self.n_heads, self.head_dim)
155
+ cv = cv.view(bsz, seqlen1, self.n_heads, self.head_dim)
156
+ cq, ck = self.q_norm1(cq), self.k_norm1(ck)
157
+
158
+ xq, xk, xv = self.w2q(x), self.w2k(x), self.w2v(x)
159
+ xq = xq.view(bsz, seqlen2, self.n_heads, self.head_dim)
160
+ xk = xk.view(bsz, seqlen2, self.n_heads, self.head_dim)
161
+ xv = xv.view(bsz, seqlen2, self.n_heads, self.head_dim)
162
+ xq, xk = self.q_norm2(xq), self.k_norm2(xk)
163
+
164
+ # concat all
165
+ q, k, v = (
166
+ torch.cat([cq, xq], dim=1),
167
+ torch.cat([ck, xk], dim=1),
168
+ torch.cat([cv, xv], dim=1),
169
+ )
170
+
171
+ output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
172
+
173
+ c, x = output.split([seqlen1, seqlen2], dim=1)
174
+ c = self.w1o(c)
175
+ x = self.w2o(x)
176
+
177
+ return c, x
178
+
179
+
180
+ class MMDiTBlock(nn.Module):
181
+ def __init__(self, dim, heads=8, global_conddim=1024, is_last=False, dtype=None, device=None, operations=None):
182
+ super().__init__()
183
+
184
+ self.normC1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
185
+ self.normC2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
186
+ if not is_last:
187
+ self.mlpC = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
188
+ self.modC = nn.Sequential(
189
+ nn.SiLU(),
190
+ operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
191
+ )
192
+ else:
193
+ self.modC = nn.Sequential(
194
+ nn.SiLU(),
195
+ operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
196
+ )
197
+
198
+ self.normX1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
199
+ self.normX2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
200
+ self.mlpX = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
201
+ self.modX = nn.Sequential(
202
+ nn.SiLU(),
203
+ operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
204
+ )
205
+
206
+ self.attn = DoubleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
207
+ self.is_last = is_last
208
+
209
+ #@torch.compile()
210
+ def forward(self, c, x, global_cond, **kwargs):
211
+
212
+ cres, xres = c, x
213
+
214
+ cshift_msa, cscale_msa, cgate_msa, cshift_mlp, cscale_mlp, cgate_mlp = (
215
+ self.modC(global_cond).chunk(6, dim=1)
216
+ )
217
+
218
+ c = modulate(self.normC1(c), cshift_msa, cscale_msa)
219
+
220
+ # xpath
221
+ xshift_msa, xscale_msa, xgate_msa, xshift_mlp, xscale_mlp, xgate_mlp = (
222
+ self.modX(global_cond).chunk(6, dim=1)
223
+ )
224
+
225
+ x = modulate(self.normX1(x), xshift_msa, xscale_msa)
226
+
227
+ # attention
228
+ c, x = self.attn(c, x)
229
+
230
+
231
+ c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
232
+ c = cgate_mlp.unsqueeze(1) * self.mlpC(modulate(c, cshift_mlp, cscale_mlp))
233
+ c = cres + c
234
+
235
+ x = self.normX2(xres + xgate_msa.unsqueeze(1) * x)
236
+ x = xgate_mlp.unsqueeze(1) * self.mlpX(modulate(x, xshift_mlp, xscale_mlp))
237
+ x = xres + x
238
+
239
+ return c, x
240
+
241
+ class DiTBlock(nn.Module):
242
+ # like MMDiTBlock, but it only has X
243
+ def __init__(self, dim, heads=8, global_conddim=1024, dtype=None, device=None, operations=None):
244
+ super().__init__()
245
+
246
+ self.norm1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
247
+ self.norm2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
248
+
249
+ self.modCX = nn.Sequential(
250
+ nn.SiLU(),
251
+ operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
252
+ )
253
+
254
+ self.attn = SingleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
255
+ self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
256
+
257
+ #@torch.compile()
258
+ def forward(self, cx, global_cond, **kwargs):
259
+ cxres = cx
260
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
261
+ global_cond
262
+ ).chunk(6, dim=1)
263
+ cx = modulate(self.norm1(cx), shift_msa, scale_msa)
264
+ cx = self.attn(cx)
265
+ cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
266
+ mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
267
+ cx = gate_mlp.unsqueeze(1) * mlpout
268
+
269
+ cx = cxres + cx
270
+
271
+ return cx
272
+
273
+
274
+
275
+ class TimestepEmbedder(nn.Module):
276
+ def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
277
+ super().__init__()
278
+ self.mlp = nn.Sequential(
279
+ operations.Linear(frequency_embedding_size, hidden_size, dtype=dtype, device=device),
280
+ nn.SiLU(),
281
+ operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device),
282
+ )
283
+ self.frequency_embedding_size = frequency_embedding_size
284
+
285
+ @staticmethod
286
+ def timestep_embedding(t, dim, max_period=10000):
287
+ half = dim // 2
288
+ freqs = 1000 * torch.exp(
289
+ -math.log(max_period) * torch.arange(start=0, end=half) / half
290
+ ).to(t.device)
291
+ args = t[:, None] * freqs[None]
292
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
293
+ if dim % 2:
294
+ embedding = torch.cat(
295
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
296
+ )
297
+ return embedding
298
+
299
+ #@torch.compile()
300
+ def forward(self, t, dtype):
301
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
302
+ t_emb = self.mlp(t_freq)
303
+ return t_emb
304
+
305
+
306
+ class MMDiT(nn.Module):
307
+ def __init__(
308
+ self,
309
+ in_channels=4,
310
+ out_channels=4,
311
+ patch_size=2,
312
+ dim=3072,
313
+ n_layers=36,
314
+ n_double_layers=4,
315
+ n_heads=12,
316
+ global_conddim=3072,
317
+ cond_seq_dim=2048,
318
+ max_seq=32 * 32,
319
+ device=None,
320
+ dtype=None,
321
+ operations=None,
322
+ ):
323
+ super().__init__()
324
+ self.dtype = dtype
325
+
326
+ self.t_embedder = TimestepEmbedder(global_conddim, dtype=dtype, device=device, operations=operations)
327
+
328
+ self.cond_seq_linear = operations.Linear(
329
+ cond_seq_dim, dim, bias=False, dtype=dtype, device=device
330
+ ) # linear for something like text sequence.
331
+ self.init_x_linear = operations.Linear(
332
+ patch_size * patch_size * in_channels, dim, dtype=dtype, device=device
333
+ ) # init linear for patchified image.
334
+
335
+ self.positional_encoding = nn.Parameter(torch.empty(1, max_seq, dim, dtype=dtype, device=device))
336
+ self.register_tokens = nn.Parameter(torch.empty(1, 8, dim, dtype=dtype, device=device))
337
+
338
+ self.double_layers = nn.ModuleList([])
339
+ self.single_layers = nn.ModuleList([])
340
+
341
+
342
+ for idx in range(n_double_layers):
343
+ self.double_layers.append(
344
+ MMDiTBlock(dim, n_heads, global_conddim, is_last=(idx == n_layers - 1), dtype=dtype, device=device, operations=operations)
345
+ )
346
+
347
+ for idx in range(n_double_layers, n_layers):
348
+ self.single_layers.append(
349
+ DiTBlock(dim, n_heads, global_conddim, dtype=dtype, device=device, operations=operations)
350
+ )
351
+
352
+
353
+ self.final_linear = operations.Linear(
354
+ dim, patch_size * patch_size * out_channels, bias=False, dtype=dtype, device=device
355
+ )
356
+
357
+ self.modF = nn.Sequential(
358
+ nn.SiLU(),
359
+ operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
360
+ )
361
+
362
+ self.out_channels = out_channels
363
+ self.patch_size = patch_size
364
+ self.n_double_layers = n_double_layers
365
+ self.n_layers = n_layers
366
+
367
+ self.h_max = round(max_seq**0.5)
368
+ self.w_max = round(max_seq**0.5)
369
+
370
+ @torch.no_grad()
371
+ def extend_pe(self, init_dim=(16, 16), target_dim=(64, 64)):
372
+ # extend pe
373
+ pe_data = self.positional_encoding.data.squeeze(0)[: init_dim[0] * init_dim[1]]
374
+
375
+ pe_as_2d = pe_data.view(init_dim[0], init_dim[1], -1).permute(2, 0, 1)
376
+
377
+ # now we need to extend this to target_dim. for this we will use interpolation.
378
+ # we will use torch.nn.functional.interpolate
379
+ pe_as_2d = F.interpolate(
380
+ pe_as_2d.unsqueeze(0), size=target_dim, mode="bilinear"
381
+ )
382
+ pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
383
+ self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
384
+ self.h_max, self.w_max = target_dim
385
+ print("PE extended to", target_dim)
386
+
387
+ def pe_selection_index_based_on_dim(self, h, w):
388
+ h_p, w_p = h // self.patch_size, w // self.patch_size
389
+ original_pe_indexes = torch.arange(self.positional_encoding.shape[1])
390
+ original_pe_indexes = original_pe_indexes.view(self.h_max, self.w_max)
391
+ starth = self.h_max // 2 - h_p // 2
392
+ endh = starth + h_p
393
+ startw = self.w_max // 2 - w_p // 2
394
+ endw = startw + w_p
395
+ original_pe_indexes = original_pe_indexes[
396
+ starth:endh, startw:endw
397
+ ]
398
+ return original_pe_indexes.flatten()
399
+
400
+ def unpatchify(self, x, h, w):
401
+ c = self.out_channels
402
+ p = self.patch_size
403
+
404
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
405
+ x = torch.einsum("nhwpqc->nchpwq", x)
406
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
407
+ return imgs
408
+
409
+ def patchify(self, x):
410
+ B, C, H, W = x.size()
411
+ x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
412
+ x = x.view(
413
+ B,
414
+ C,
415
+ (H + 1) // self.patch_size,
416
+ self.patch_size,
417
+ (W + 1) // self.patch_size,
418
+ self.patch_size,
419
+ )
420
+ x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
421
+ return x
422
+
423
+ def apply_pos_embeds(self, x, h, w):
424
+ h = (h + 1) // self.patch_size
425
+ w = (w + 1) // self.patch_size
426
+ max_dim = max(h, w)
427
+
428
+ cur_dim = self.h_max
429
+ pos_encoding = comfy.ops.cast_to_input(self.positional_encoding.reshape(1, cur_dim, cur_dim, -1), x)
430
+
431
+ if max_dim > cur_dim:
432
+ pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
433
+ cur_dim = max_dim
434
+
435
+ from_h = (cur_dim - h) // 2
436
+ from_w = (cur_dim - w) // 2
437
+ pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w]
438
+ return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
439
+
440
+ def forward(self, x, timestep, context, **kwargs):
441
+ # patchify x, add PE
442
+ b, c, h, w = x.shape
443
+
444
+ # pe_indexes = self.pe_selection_index_based_on_dim(h, w)
445
+ # print(pe_indexes, pe_indexes.shape)
446
+
447
+ x = self.init_x_linear(self.patchify(x)) # B, T_x, D
448
+ x = self.apply_pos_embeds(x, h, w)
449
+ # x = x + self.positional_encoding[:, : x.size(1)].to(device=x.device, dtype=x.dtype)
450
+ # x = x + self.positional_encoding[:, pe_indexes].to(device=x.device, dtype=x.dtype)
451
+
452
+ # process conditions for MMDiT Blocks
453
+ c_seq = context # B, T_c, D_c
454
+ t = timestep
455
+
456
+ c = self.cond_seq_linear(c_seq) # B, T_c, D
457
+ c = torch.cat([comfy.ops.cast_to_input(self.register_tokens, c).repeat(c.size(0), 1, 1), c], dim=1)
458
+
459
+ global_cond = self.t_embedder(t, x.dtype) # B, D
460
+
461
+ if len(self.double_layers) > 0:
462
+ for layer in self.double_layers:
463
+ c, x = layer(c, x, global_cond, **kwargs)
464
+
465
+ if len(self.single_layers) > 0:
466
+ c_len = c.size(1)
467
+ cx = torch.cat([c, x], dim=1)
468
+ for layer in self.single_layers:
469
+ cx = layer(cx, global_cond, **kwargs)
470
+
471
+ x = cx[:, c_len:]
472
+
473
+ fshift, fscale = self.modF(global_cond).chunk(2, dim=1)
474
+
475
+ x = modulate(x, fshift, fscale)
476
+ x = self.final_linear(x)
477
+ x = self.unpatchify(x, (h + 1) // self.patch_size, (w + 1) // self.patch_size)[:,:,:h,:w]
478
+ return x
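
For reference, a self-contained sketch of the adaLN-style modulation that mmdit.py applies throughout: the global conditioning vector is projected to six shift/scale/gate chunks, and the shift/scale pair is applied per token before attention or the MLP. Dimensions are illustrative.

    import torch
    import torch.nn as nn

    def modulate(x, shift, scale):
        # x: (B, T, D); shift/scale: (B, D), broadcast over the token axis
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    B, T, D = 2, 16, 32
    x = torch.randn(B, T, D)
    global_cond = torch.randn(B, D)
    mod = nn.Sequential(nn.SiLU(), nn.Linear(D, 6 * D, bias=False))   # mirrors modC/modX/modCX
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod(global_cond).chunk(6, dim=1)
    x = modulate(x, shift_msa, scale_msa)          # pre-attention modulation
    # ... attention ..., then: x = residual + gate_msa.unsqueeze(1) * x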
Backend/comfy/ldm/cascade/common.py ADDED
@@ -0,0 +1,154 @@
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from comfy.ldm.modules.attention import optimized_attention
22
+ import comfy.ops
23
+
24
+ class OptimizedAttention(nn.Module):
25
+ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
26
+ super().__init__()
27
+ self.heads = nhead
28
+
29
+ self.to_q = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
30
+ self.to_k = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
31
+ self.to_v = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
32
+
33
+ self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
34
+
35
+ def forward(self, q, k, v):
36
+ q = self.to_q(q)
37
+ k = self.to_k(k)
38
+ v = self.to_v(v)
39
+
40
+ out = optimized_attention(q, k, v, self.heads)
41
+
42
+ return self.out_proj(out)
43
+
44
+ class Attention2D(nn.Module):
45
+ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
46
+ super().__init__()
47
+ self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
48
+ # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
49
+
50
+ def forward(self, x, kv, self_attn=False):
51
+ orig_shape = x.shape
52
+ x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # BxCxHxW -> Bx(HxW)xC
53
+ if self_attn:
54
+ kv = torch.cat([x, kv], dim=1)
55
+ # x = self.attn(x, kv, kv, need_weights=False)[0]
56
+ x = self.attn(x, kv, kv)
57
+ x = x.permute(0, 2, 1).view(*orig_shape)
58
+ return x
59
+
60
+
61
+ def LayerNorm2d_op(operations):
62
+ class LayerNorm2d(operations.LayerNorm):
63
+ def __init__(self, *args, **kwargs):
64
+ super().__init__(*args, **kwargs)
65
+
66
+ def forward(self, x):
67
+ return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
68
+ return LayerNorm2d
69
+
70
+ class GlobalResponseNorm(nn.Module):
71
+ "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
72
+ def __init__(self, dim, dtype=None, device=None):
73
+ super().__init__()
74
+ self.gamma = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
75
+ self.beta = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
76
+
77
+ def forward(self, x):
78
+ Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
79
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
80
+ return comfy.ops.cast_to_input(self.gamma, x) * (x * Nx) + comfy.ops.cast_to_input(self.beta, x) + x
81
+
82
+
83
+ class ResBlock(nn.Module):
84
+ def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None): # , num_heads=4, expansion=2):
85
+ super().__init__()
86
+ self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, dtype=dtype, device=device)
87
+ # self.depthwise = SAMBlock(c, num_heads, expansion)
88
+ self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
89
+ self.channelwise = nn.Sequential(
90
+ operations.Linear(c + c_skip, c * 4, dtype=dtype, device=device),
91
+ nn.GELU(),
92
+ GlobalResponseNorm(c * 4, dtype=dtype, device=device),
93
+ nn.Dropout(dropout),
94
+ operations.Linear(c * 4, c, dtype=dtype, device=device)
95
+ )
96
+
97
+ def forward(self, x, x_skip=None):
98
+ x_res = x
99
+ x = self.norm(self.depthwise(x))
100
+ if x_skip is not None:
101
+ x = torch.cat([x, x_skip], dim=1)
102
+ x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
103
+ return x + x_res
104
+
105
+
106
+ class AttnBlock(nn.Module):
107
+ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None):
108
+ super().__init__()
109
+ self.self_attn = self_attn
110
+ self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
111
+ self.attention = Attention2D(c, nhead, dropout, dtype=dtype, device=device, operations=operations)
112
+ self.kv_mapper = nn.Sequential(
113
+ nn.SiLU(),
114
+ operations.Linear(c_cond, c, dtype=dtype, device=device)
115
+ )
116
+
117
+ def forward(self, x, kv):
118
+ kv = self.kv_mapper(kv)
119
+ x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
120
+ return x
121
+
122
+
123
+ class FeedForwardBlock(nn.Module):
124
+ def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None):
125
+ super().__init__()
126
+ self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
127
+ self.channelwise = nn.Sequential(
128
+ operations.Linear(c, c * 4, dtype=dtype, device=device),
129
+ nn.GELU(),
130
+ GlobalResponseNorm(c * 4, dtype=dtype, device=device),
131
+ nn.Dropout(dropout),
132
+ operations.Linear(c * 4, c, dtype=dtype, device=device)
133
+ )
134
+
135
+ def forward(self, x):
136
+ x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
137
+ return x
138
+
139
+
140
+ class TimestepBlock(nn.Module):
141
+ def __init__(self, c, c_timestep, conds=['sca'], dtype=None, device=None, operations=None):
142
+ super().__init__()
143
+ self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)
144
+ self.conds = conds
145
+ for cname in conds:
146
+ setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device))
147
+
148
+ def forward(self, x, t):
149
+ t = t.chunk(len(self.conds) + 1, dim=1)
150
+ a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
151
+ for i, c in enumerate(self.conds):
152
+ ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
153
+ a, b = a + ac, b + bc
154
+ return x * (1 + a) + b
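
A hedged sketch of how the TimestepBlock above consumes its conditioning: t is expected to be the concatenation of one c_timestep-wide embedding per extra condition plus the base timestep, which is what the stage models assemble with torch.cat before calling it. The sizes below are arbitrary.

    import torch
    import torch.nn as nn

    c, c_timestep, conds = 8, 4, ['sca']
    t = torch.randn(3, c_timestep * (len(conds) + 1))      # base timestep + one extra condition
    chunks = t.chunk(len(conds) + 1, dim=1)

    base_mapper = nn.Linear(c_timestep, c * 2)              # stand-in for self.mapper
    a, b = base_mapper(chunks[0])[:, :, None, None].chunk(2, dim=1)

    x = torch.randn(3, c, 16, 16)
    out = x * (1 + a) + b                                   # same affine form as TimestepBlock.forward
    print(out.shape)                                        # torch.Size([3, 8, 16, 16])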
Backend/comfy/ldm/cascade/controlnet.py ADDED
@@ -0,0 +1,93 @@
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import torch
20
+ import torchvision
21
+ from torch import nn
22
+ from .common import LayerNorm2d_op
23
+
24
+
25
+ class CNetResBlock(nn.Module):
26
+ def __init__(self, c, dtype=None, device=None, operations=None):
27
+ super().__init__()
28
+ self.blocks = nn.Sequential(
29
+ LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
30
+ nn.GELU(),
31
+ operations.Conv2d(c, c, kernel_size=3, padding=1),
32
+ LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
33
+ nn.GELU(),
34
+ operations.Conv2d(c, c, kernel_size=3, padding=1),
35
+ )
36
+
37
+ def forward(self, x):
38
+ return x + self.blocks(x)
39
+
40
+
41
+ class ControlNet(nn.Module):
42
+ def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None, dtype=None, device=None, operations=nn):
43
+ super().__init__()
44
+ if bottleneck_mode is None:
45
+ bottleneck_mode = 'effnet'
46
+ self.proj_blocks = proj_blocks
47
+ if bottleneck_mode == 'effnet':
48
+ embd_channels = 1280
49
+ self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
50
+ if c_in != 3:
51
+ in_weights = self.backbone[0][0].weight.data
52
+ self.backbone[0][0] = operations.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False, dtype=dtype, device=device)
53
+ if c_in > 3:
54
+ # nn.init.constant_(self.backbone[0][0].weight, 0)
55
+ self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone()
56
+ else:
57
+ self.backbone[0][0].weight.data = in_weights[:, :c_in].clone()
58
+ elif bottleneck_mode == 'simple':
59
+ embd_channels = c_in
60
+ self.backbone = nn.Sequential(
61
+ operations.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1, dtype=dtype, device=device),
62
+ nn.LeakyReLU(0.2, inplace=True),
63
+ operations.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1, dtype=dtype, device=device),
64
+ )
65
+ elif bottleneck_mode == 'large':
66
+ self.backbone = nn.Sequential(
67
+ operations.Conv2d(c_in, 4096 * 4, kernel_size=1, dtype=dtype, device=device),
68
+ nn.LeakyReLU(0.2, inplace=True),
69
+ operations.Conv2d(4096 * 4, 1024, kernel_size=1, dtype=dtype, device=device),
70
+ *[CNetResBlock(1024, dtype=dtype, device=device, operations=operations) for _ in range(8)],
71
+ operations.Conv2d(1024, 1280, kernel_size=1, dtype=dtype, device=device),
72
+ )
73
+ embd_channels = 1280
74
+ else:
75
+ raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}')
76
+ self.projections = nn.ModuleList()
77
+ for _ in range(len(proj_blocks)):
78
+ self.projections.append(nn.Sequential(
79
+ operations.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False, dtype=dtype, device=device),
80
+ nn.LeakyReLU(0.2, inplace=True),
81
+ operations.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False, dtype=dtype, device=device),
82
+ ))
83
+ # nn.init.constant_(self.projections[-1][-1].weight, 0) # zero output projection
84
+ self.xl = False
85
+ self.input_channels = c_in
86
+ self.unshuffle_amount = 8
87
+
88
+ def forward(self, x):
89
+ x = self.backbone(x)
90
+ proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)]
91
+ for i, idx in enumerate(self.proj_blocks):
92
+ proj_outputs[idx] = self.projections[i](x)
93
+ return {"input": proj_outputs[::-1]}
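
An illustrative instantiation of the ControlNet above in its 'simple' bottleneck mode, which sidesteps the torchvision EfficientNet backbone; proj_blocks and the input resolution are arbitrary choices, and the import assumes the Backend directory is on the Python path.

    import torch
    from comfy.ldm.cascade.controlnet import ControlNet   # assumes Backend/ is importable

    cnet = ControlNet(c_in=3, c_proj=2048, proj_blocks=[0, 2], bottleneck_mode='simple')
    out = cnet(torch.randn(1, 3, 64, 64))
    feats = out["input"]            # length max(proj_blocks) + 1, reversed, with None gaps
    print([None if f is None else tuple(f.shape) for f in feats])
    # [(1, 2048, 64, 64), None, (1, 2048, 64, 64)]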
Backend/comfy/ldm/cascade/stage_a.py ADDED
@@ -0,0 +1,255 @@
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import torch
20
+ from torch import nn
21
+ from torch.autograd import Function
22
+
23
+ class vector_quantize(Function):
24
+ @staticmethod
25
+ def forward(ctx, x, codebook):
26
+ with torch.no_grad():
27
+ codebook_sqr = torch.sum(codebook ** 2, dim=1)
28
+ x_sqr = torch.sum(x ** 2, dim=1, keepdim=True)
29
+
30
+ dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0)
31
+ _, indices = dist.min(dim=1)
32
+
33
+ ctx.save_for_backward(indices, codebook)
34
+ ctx.mark_non_differentiable(indices)
35
+
36
+ nn = torch.index_select(codebook, 0, indices)
37
+ return nn, indices
38
+
39
+ @staticmethod
40
+ def backward(ctx, grad_output, grad_indices):
41
+ grad_inputs, grad_codebook = None, None
42
+
43
+ if ctx.needs_input_grad[0]:
44
+ grad_inputs = grad_output.clone()
45
+ if ctx.needs_input_grad[1]:
46
+ # Gradient wrt. the codebook
47
+ indices, codebook = ctx.saved_tensors
48
+
49
+ grad_codebook = torch.zeros_like(codebook)
50
+ grad_codebook.index_add_(0, indices, grad_output)
51
+
52
+ return (grad_inputs, grad_codebook)
53
+
54
+
55
+ class VectorQuantize(nn.Module):
56
+ def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
57
+ """
58
+ Takes an input of variable size (as long as the last dimension matches the embedding size).
59
+ Returns one tensor containing the nearest neighbour embeddings to each of the inputs,
60
+ with the same size as the input, vq and commitment components for the loss as a tuple
61
+ in the second output and the indices of the quantized vectors in the third:
62
+ quantized, (vq_loss, commit_loss), indices
63
+ """
64
+ super(VectorQuantize, self).__init__()
65
+
66
+ self.codebook = nn.Embedding(k, embedding_size)
67
+ self.codebook.weight.data.uniform_(-1./k, 1./k)
68
+ self.vq = vector_quantize.apply
69
+
70
+ self.ema_decay = ema_decay
71
+ self.ema_loss = ema_loss
72
+ if ema_loss:
73
+ self.register_buffer('ema_element_count', torch.ones(k))
74
+ self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight))
75
+
76
+ def _laplace_smoothing(self, x, epsilon):
77
+ n = torch.sum(x)
78
+ return ((x + epsilon) / (n + x.size(0) * epsilon) * n)
79
+
80
+ def _updateEMA(self, z_e_x, indices):
81
+ mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
82
+ elem_count = mask.sum(dim=0)
83
+ weight_sum = torch.mm(mask.t(), z_e_x)
84
+
85
+ self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count)
86
+ self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
87
+ self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)
88
+
89
+ self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
90
+
91
+ def idx2vq(self, idx, dim=-1):
92
+ q_idx = self.codebook(idx)
93
+ if dim != -1:
94
+ q_idx = q_idx.movedim(-1, dim)
95
+ return q_idx
96
+
97
+ def forward(self, x, get_losses=True, dim=-1):
98
+ if dim != -1:
99
+ x = x.movedim(dim, -1)
100
+ z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x
101
+ z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach())
102
+ vq_loss, commit_loss = None, None
103
+ if self.ema_loss and self.training:
104
+ self._updateEMA(z_e_x.detach(), indices.detach())
105
+ # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss
106
+ z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices)
107
+ if get_losses:
108
+ vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean()
109
+ commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean()
110
+
111
+ z_q_x = z_q_x.view(x.shape)
112
+ if dim != -1:
113
+ z_q_x = z_q_x.movedim(-1, dim)
114
+ return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1])
115
+
116
+
117
+ class ResBlock(nn.Module):
118
+ def __init__(self, c, c_hidden):
119
+ super().__init__()
120
+ # depthwise/attention
121
+ self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
122
+ self.depthwise = nn.Sequential(
123
+ nn.ReplicationPad2d(1),
124
+ nn.Conv2d(c, c, kernel_size=3, groups=c)
125
+ )
126
+
127
+ # channelwise
128
+ self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
129
+ self.channelwise = nn.Sequential(
130
+ nn.Linear(c, c_hidden),
131
+ nn.GELU(),
132
+ nn.Linear(c_hidden, c),
133
+ )
134
+
135
+ self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
136
+
137
+ # Init weights
138
+ def _basic_init(module):
139
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
140
+ torch.nn.init.xavier_uniform_(module.weight)
141
+ if module.bias is not None:
142
+ nn.init.constant_(module.bias, 0)
143
+
144
+ self.apply(_basic_init)
145
+
146
+ def _norm(self, x, norm):
147
+ return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
148
+
149
+ def forward(self, x):
150
+ mods = self.gammas
151
+
152
+ x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
153
+ try:
154
+ x = x + self.depthwise(x_temp) * mods[2]
155
+ except: # operation not implemented for bf16
156
+ x_temp = self.depthwise[0](x_temp.float()).to(x.dtype)
157
+ x = x + self.depthwise[1](x_temp) * mods[2]
158
+
159
+ x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
160
+ x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
161
+
162
+ return x
163
+
164
+
165
+ class StageA(nn.Module):
166
+ def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192):
167
+ super().__init__()
168
+ self.c_latent = c_latent
169
+ c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]
170
+
171
+ # Encoder blocks
172
+ self.in_block = nn.Sequential(
173
+ nn.PixelUnshuffle(2),
174
+ nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
175
+ )
176
+ down_blocks = []
177
+ for i in range(levels):
178
+ if i > 0:
179
+ down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
180
+ block = ResBlock(c_levels[i], c_levels[i] * 4)
181
+ down_blocks.append(block)
182
+ down_blocks.append(nn.Sequential(
183
+ nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
184
+ nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
185
+ ))
186
+ self.down_blocks = nn.Sequential(*down_blocks)
187
+ self.down_blocks[0]
188
+
189
+ self.codebook_size = codebook_size
190
+ self.vquantizer = VectorQuantize(c_latent, k=codebook_size)
191
+
192
+ # Decoder blocks
193
+ up_blocks = [nn.Sequential(
194
+ nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
195
+ )]
196
+ for i in range(levels):
197
+ for j in range(bottleneck_blocks if i == 0 else 1):
198
+ block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
199
+ up_blocks.append(block)
200
+ if i < levels - 1:
201
+ up_blocks.append(
202
+ nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
203
+ padding=1))
204
+ self.up_blocks = nn.Sequential(*up_blocks)
205
+ self.out_block = nn.Sequential(
206
+ nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
207
+ nn.PixelShuffle(2),
208
+ )
209
+
210
+ def encode(self, x, quantize=False):
211
+ x = self.in_block(x)
212
+ x = self.down_blocks(x)
213
+ if quantize:
214
+ qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
215
+ return qe, x, indices, vq_loss + commit_loss * 0.25
216
+ else:
217
+ return x
218
+
219
+ def decode(self, x):
220
+ x = self.up_blocks(x)
221
+ x = self.out_block(x)
222
+ return x
223
+
224
+ def forward(self, x, quantize=False):
225
+ qe, x, _, vq_loss = self.encode(x, quantize)
226
+ x = self.decode(qe)
227
+ return x, vq_loss
228
+
229
+
230
+ class Discriminator(nn.Module):
231
+ def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
232
+ super().__init__()
233
+ d = max(depth - 3, 3)
234
+ layers = [
235
+ nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
236
+ nn.LeakyReLU(0.2),
237
+ ]
238
+ for i in range(depth - 1):
239
+ c_in = c_hidden // (2 ** max((d - i), 0))
240
+ c_out = c_hidden // (2 ** max((d - 1 - i), 0))
241
+ layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
242
+ layers.append(nn.InstanceNorm2d(c_out))
243
+ layers.append(nn.LeakyReLU(0.2))
244
+ self.encoder = nn.Sequential(*layers)
245
+ self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
246
+ self.logits = nn.Sigmoid()
247
+
248
+ def forward(self, x, cond=None):
249
+ x = self.encoder(x)
250
+ if cond is not None:
251
+ cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1))
252
+ x = torch.cat([x, cond], dim=1)
253
+ x = self.shuffle(x)
254
+ x = self.logits(x)
255
+ return x
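
A minimal usage sketch for the VectorQuantize module defined above, matching the quantized, (vq_loss, commit_loss), indices contract described in its docstring; the sizes are illustrative and the import again assumes Backend/ is on the path.

    import torch
    from comfy.ldm.cascade.stage_a import VectorQuantize   # assumes Backend/ is importable

    vq = VectorQuantize(embedding_size=4, k=16)
    latent = torch.randn(2, 4, 8, 8)                        # B, C, H, W with C == embedding_size
    quantized, (vq_loss, commit_loss), indices = vq(latent, dim=1)
    print(quantized.shape, indices.shape)                   # torch.Size([2, 4, 8, 8]) torch.Size([2, 8, 8])
    print(vq_loss.item(), commit_loss.item())               # codebook / commitment losses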
Backend/comfy/ldm/cascade/stage_b.py ADDED
@@ -0,0 +1,256 @@
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import math
20
+ import torch
21
+ from torch import nn
22
+ from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
23
+
24
+ class StageB(nn.Module):
25
+ def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280],
26
+ nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
27
+ block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280,
28
+ c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.0, 0.0], self_attn=True,
29
+ t_conds=['sca'], stable_cascade_stage=None, dtype=None, device=None, operations=None):
30
+ super().__init__()
31
+ self.dtype = dtype
32
+ self.c_r = c_r
33
+ self.t_conds = t_conds
34
+ self.c_clip_seq = c_clip_seq
35
+ if not isinstance(dropout, list):
36
+ dropout = [dropout] * len(c_hidden)
37
+ if not isinstance(self_attn, list):
38
+ self_attn = [self_attn] * len(c_hidden)
39
+
40
+ # CONDITIONING
41
+ self.effnet_mapper = nn.Sequential(
42
+ operations.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
43
+ nn.GELU(),
44
+ operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
45
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
46
+ )
47
+ self.pixels_mapper = nn.Sequential(
48
+ operations.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
49
+ nn.GELU(),
50
+ operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
51
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
52
+ )
53
+ self.clip_mapper = operations.Linear(c_clip, c_cond * c_clip_seq, dtype=dtype, device=device)
54
+ self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
55
+
56
+ self.embedding = nn.Sequential(
57
+ nn.PixelUnshuffle(patch_size),
58
+ operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
59
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
60
+ )
61
+
62
+ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
63
+ if block_type == 'C':
64
+ return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
65
+ elif block_type == 'A':
66
+ return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
67
+ elif block_type == 'F':
68
+ return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
69
+ elif block_type == 'T':
70
+ return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
71
+ else:
72
+ raise Exception(f'Block type {block_type} not supported')
73
+
74
+ # BLOCKS
75
+ # -- down blocks
76
+ self.down_blocks = nn.ModuleList()
77
+ self.down_downscalers = nn.ModuleList()
78
+ self.down_repeat_mappers = nn.ModuleList()
79
+ for i in range(len(c_hidden)):
80
+ if i > 0:
81
+ self.down_downscalers.append(nn.Sequential(
82
+ LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
83
+ operations.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2, dtype=dtype, device=device),
84
+ ))
85
+ else:
86
+ self.down_downscalers.append(nn.Identity())
87
+ down_block = nn.ModuleList()
88
+ for _ in range(blocks[0][i]):
89
+ for block_type in level_config[i]:
90
+ block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
91
+ down_block.append(block)
92
+ self.down_blocks.append(down_block)
93
+ if block_repeat is not None:
94
+ block_repeat_mappers = nn.ModuleList()
95
+ for _ in range(block_repeat[0][i] - 1):
96
+ block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
97
+ self.down_repeat_mappers.append(block_repeat_mappers)
98
+
99
+ # -- up blocks
100
+ self.up_blocks = nn.ModuleList()
101
+ self.up_upscalers = nn.ModuleList()
102
+ self.up_repeat_mappers = nn.ModuleList()
103
+ for i in reversed(range(len(c_hidden))):
104
+ if i > 0:
105
+ self.up_upscalers.append(nn.Sequential(
106
+ LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
107
+ operations.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2, dtype=dtype, device=device),
108
+ ))
109
+ else:
110
+ self.up_upscalers.append(nn.Identity())
111
+ up_block = nn.ModuleList()
112
+ for j in range(blocks[1][::-1][i]):
113
+ for k, block_type in enumerate(level_config[i]):
114
+ c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
115
+ block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
116
+ self_attn=self_attn[i])
117
+ up_block.append(block)
118
+ self.up_blocks.append(up_block)
119
+ if block_repeat is not None:
120
+ block_repeat_mappers = nn.ModuleList()
121
+ for _ in range(block_repeat[1][::-1][i] - 1):
122
+ block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
123
+ self.up_repeat_mappers.append(block_repeat_mappers)
124
+
125
+ # OUTPUT
126
+ self.clf = nn.Sequential(
127
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
128
+ operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
129
+ nn.PixelShuffle(patch_size),
130
+ )
131
+
132
+ # --- WEIGHT INIT ---
133
+ # self.apply(self._init_weights) # General init
134
+ # nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings
135
+ # nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings
136
+ # nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings
137
+ # nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings
138
+ # nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
139
+ # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
140
+ # nn.init.constant_(self.clf[1].weight, 0) # outputs
141
+ #
142
+ # # blocks
143
+ # for level_block in self.down_blocks + self.up_blocks:
144
+ # for block in level_block:
145
+ # if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
146
+ # block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
147
+ # elif isinstance(block, TimestepBlock):
148
+ # for layer in block.modules():
149
+ # if isinstance(layer, nn.Linear):
150
+ # nn.init.constant_(layer.weight, 0)
151
+ #
152
+ # def _init_weights(self, m):
153
+ # if isinstance(m, (nn.Conv2d, nn.Linear)):
154
+ # torch.nn.init.xavier_uniform_(m.weight)
155
+ # if m.bias is not None:
156
+ # nn.init.constant_(m.bias, 0)
157
+
158
+ def gen_r_embedding(self, r, max_positions=10000):
159
+ r = r * max_positions
160
+ half_dim = self.c_r // 2
161
+ emb = math.log(max_positions) / (half_dim - 1)
162
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
163
+ emb = r[:, None] * emb[None, :]
164
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
165
+ if self.c_r % 2 == 1: # zero pad
166
+ emb = nn.functional.pad(emb, (0, 1), mode='constant')
167
+ return emb
168
+
169
+ def gen_c_embeddings(self, clip):
170
+ if len(clip.shape) == 2:
171
+ clip = clip.unsqueeze(1)
172
+ clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
173
+ clip = self.clip_norm(clip)
174
+ return clip
175
+
176
+ def _down_encode(self, x, r_embed, clip):
177
+ level_outputs = []
178
+ block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
179
+ for down_block, downscaler, repmap in block_group:
180
+ x = downscaler(x)
181
+ for i in range(len(repmap) + 1):
182
+ for block in down_block:
183
+ if isinstance(block, ResBlock) or (
184
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
185
+ ResBlock)):
186
+ x = block(x)
187
+ elif isinstance(block, AttnBlock) or (
188
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
189
+ AttnBlock)):
190
+ x = block(x, clip)
191
+ elif isinstance(block, TimestepBlock) or (
192
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
193
+ TimestepBlock)):
194
+ x = block(x, r_embed)
195
+ else:
196
+ x = block(x)
197
+ if i < len(repmap):
198
+ x = repmap[i](x)
199
+ level_outputs.insert(0, x)
200
+ return level_outputs
201
+
202
+ def _up_decode(self, level_outputs, r_embed, clip):
203
+ x = level_outputs[0]
204
+ block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
205
+ for i, (up_block, upscaler, repmap) in enumerate(block_group):
206
+ for j in range(len(repmap) + 1):
207
+ for k, block in enumerate(up_block):
208
+ if isinstance(block, ResBlock) or (
209
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
210
+ ResBlock)):
211
+ skip = level_outputs[i] if k == 0 and i > 0 else None
212
+ if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
213
+ x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
214
+ align_corners=True)
215
+ x = block(x, skip)
216
+ elif isinstance(block, AttnBlock) or (
217
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
218
+ AttnBlock)):
219
+ x = block(x, clip)
220
+ elif isinstance(block, TimestepBlock) or (
221
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
222
+ TimestepBlock)):
223
+ x = block(x, r_embed)
224
+ else:
225
+ x = block(x)
226
+ if j < len(repmap):
227
+ x = repmap[j](x)
228
+ x = upscaler(x)
229
+ return x
230
+
231
+ def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
232
+ if pixels is None:
233
+ pixels = x.new_zeros(x.size(0), 3, 8, 8)
234
+
235
+ # Process the conditioning embeddings
236
+ r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
237
+ for c in self.t_conds:
238
+ t_cond = kwargs.get(c, torch.zeros_like(r))
239
+ r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
240
+ clip = self.gen_c_embeddings(clip)
241
+
242
+ # Model Blocks
243
+ x = self.embedding(x)
244
+ x = x + self.effnet_mapper(
245
+ nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
246
+ x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
247
+ align_corners=True)
248
+ level_outputs = self._down_encode(x, r_embed, clip)
249
+ x = self._up_decode(level_outputs, r_embed, clip)
250
+ return self.clf(x)
251
+
252
+ def update_weights_ema(self, src_model, beta=0.999):
253
+ for self_params, src_params in zip(self.parameters(), src_model.parameters()):
254
+ self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
255
+ for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
256
+ self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
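
For reference, the sinusoidal conditioning embedding built by gen_r_embedding above (StageC below uses identical code) in a self-contained form; the body is copied from the method and only wrapped as a free function for illustration.

    import math
    import torch

    def gen_r_embedding(r, c_r=64, max_positions=10000):
        # r is a batch of values in [0, 1]; returns (B, c_r) sin/cos features
        r = r * max_positions
        half_dim = c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if c_r % 2 == 1:  # zero pad odd widths
            emb = torch.nn.functional.pad(emb, (0, 1), mode='constant')
        return emb

    print(gen_r_embedding(torch.tensor([0.0, 0.5, 1.0])).shape)   # torch.Size([3, 64])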
Backend/comfy/ldm/cascade/stage_c.py ADDED
@@ -0,0 +1,273 @@
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import torch
20
+ from torch import nn
21
+ import math
22
+ from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
23
+ # from .controlnet import ControlNetDeliverer
24
+
25
+ class UpDownBlock2d(nn.Module):
26
+ def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None):
27
+ super().__init__()
28
+ assert mode in ['up', 'down']
29
+ interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear',
30
+ align_corners=True) if enabled else nn.Identity()
31
+ mapping = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device)
32
+ self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation])
33
+
34
+ def forward(self, x):
35
+ for block in self.blocks:
36
+ x = block(x)
37
+ return x
38
+
39
+
40
+ class StageC(nn.Module):
41
+ def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
42
+ blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
43
+ c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
44
+ dropout=[0.0, 0.0], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None,
45
+ dtype=None, device=None, operations=None):
46
+ super().__init__()
47
+ self.dtype = dtype
48
+ self.c_r = c_r
49
+ self.t_conds = t_conds
50
+ self.c_clip_seq = c_clip_seq
51
+ if not isinstance(dropout, list):
52
+ dropout = [dropout] * len(c_hidden)
53
+ if not isinstance(self_attn, list):
54
+ self_attn = [self_attn] * len(c_hidden)
55
+
56
+ # CONDITIONING
57
+ self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device)
58
+ self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device)
59
+ self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device)
60
+ self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
61
+
62
+ self.embedding = nn.Sequential(
63
+ nn.PixelUnshuffle(patch_size),
64
+ operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
65
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6)
66
+ )
67
+
68
+ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
69
+ if block_type == 'C':
70
+ return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
71
+ elif block_type == 'A':
72
+ return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
73
+ elif block_type == 'F':
74
+ return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
75
+ elif block_type == 'T':
76
+ return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
77
+ else:
78
+ raise Exception(f'Block type {block_type} not supported')
79
+
80
+ # BLOCKS
81
+ # -- down blocks
82
+ self.down_blocks = nn.ModuleList()
83
+ self.down_downscalers = nn.ModuleList()
84
+ self.down_repeat_mappers = nn.ModuleList()
85
+ for i in range(len(c_hidden)):
86
+ if i > 0:
87
+ self.down_downscalers.append(nn.Sequential(
88
+ LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
89
+ UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
90
+ ))
91
+ else:
92
+ self.down_downscalers.append(nn.Identity())
93
+ down_block = nn.ModuleList()
94
+ for _ in range(blocks[0][i]):
95
+ for block_type in level_config[i]:
96
+ block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
97
+ down_block.append(block)
98
+ self.down_blocks.append(down_block)
99
+ if block_repeat is not None:
100
+ block_repeat_mappers = nn.ModuleList()
101
+ for _ in range(block_repeat[0][i] - 1):
102
+ block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
103
+ self.down_repeat_mappers.append(block_repeat_mappers)
104
+
105
+ # -- up blocks
106
+ self.up_blocks = nn.ModuleList()
107
+ self.up_upscalers = nn.ModuleList()
108
+ self.up_repeat_mappers = nn.ModuleList()
109
+ for i in reversed(range(len(c_hidden))):
110
+ if i > 0:
111
+ self.up_upscalers.append(nn.Sequential(
112
+ LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6),
113
+ UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
114
+ ))
115
+ else:
116
+ self.up_upscalers.append(nn.Identity())
117
+ up_block = nn.ModuleList()
118
+ for j in range(blocks[1][::-1][i]):
119
+ for k, block_type in enumerate(level_config[i]):
120
+ c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
121
+ block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
122
+ self_attn=self_attn[i])
123
+ up_block.append(block)
124
+ self.up_blocks.append(up_block)
125
+ if block_repeat is not None:
126
+ block_repeat_mappers = nn.ModuleList()
127
+ for _ in range(block_repeat[1][::-1][i] - 1):
128
+ block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
129
+ self.up_repeat_mappers.append(block_repeat_mappers)
130
+
131
+ # OUTPUT
132
+ self.clf = nn.Sequential(
133
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
134
+ operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
135
+ nn.PixelShuffle(patch_size),
136
+ )
137
+
138
+ # --- WEIGHT INIT ---
139
+ # self.apply(self._init_weights) # General init
140
+ # nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings
141
+ # nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings
142
+ # nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
143
+ # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
144
+ # nn.init.constant_(self.clf[1].weight, 0) # outputs
145
+ #
146
+ # # blocks
147
+ # for level_block in self.down_blocks + self.up_blocks:
148
+ # for block in level_block:
149
+ # if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
150
+ # block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
151
+ # elif isinstance(block, TimestepBlock):
152
+ # for layer in block.modules():
153
+ # if isinstance(layer, nn.Linear):
154
+ # nn.init.constant_(layer.weight, 0)
155
+ #
156
+ # def _init_weights(self, m):
157
+ # if isinstance(m, (nn.Conv2d, nn.Linear)):
158
+ # torch.nn.init.xavier_uniform_(m.weight)
159
+ # if m.bias is not None:
160
+ # nn.init.constant_(m.bias, 0)
161
+
162
+ def gen_r_embedding(self, r, max_positions=10000):
163
+ r = r * max_positions
164
+ half_dim = self.c_r // 2
165
+ emb = math.log(max_positions) / (half_dim - 1)
166
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
167
+ emb = r[:, None] * emb[None, :]
168
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
169
+ if self.c_r % 2 == 1: # zero pad
170
+ emb = nn.functional.pad(emb, (0, 1), mode='constant')
171
+ return emb
172
+
173
+ def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
174
+ clip_txt = self.clip_txt_mapper(clip_txt)
175
+ if len(clip_txt_pooled.shape) == 2:
176
+ clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
177
+ if len(clip_img.shape) == 2:
178
+ clip_img = clip_img.unsqueeze(1)
179
+ clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
180
+ clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
181
+ clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
182
+ clip = self.clip_norm(clip)
183
+ return clip
184
+
185
+ def _down_encode(self, x, r_embed, clip, cnet=None):
186
+ level_outputs = []
187
+ block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
188
+ for down_block, downscaler, repmap in block_group:
189
+ x = downscaler(x)
190
+ for i in range(len(repmap) + 1):
191
+ for block in down_block:
192
+ if isinstance(block, ResBlock) or (
193
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
194
+ ResBlock)):
195
+ if cnet is not None:
196
+ next_cnet = cnet.pop()
197
+ if next_cnet is not None:
198
+ x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
199
+ align_corners=True).to(x.dtype)
200
+ x = block(x)
201
+ elif isinstance(block, AttnBlock) or (
202
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
203
+ AttnBlock)):
204
+ x = block(x, clip)
205
+ elif isinstance(block, TimestepBlock) or (
206
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
207
+ TimestepBlock)):
208
+ x = block(x, r_embed)
209
+ else:
210
+ x = block(x)
211
+ if i < len(repmap):
212
+ x = repmap[i](x)
213
+ level_outputs.insert(0, x)
214
+ return level_outputs
215
+
216
+ def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
217
+ x = level_outputs[0]
218
+ block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
219
+ for i, (up_block, upscaler, repmap) in enumerate(block_group):
220
+ for j in range(len(repmap) + 1):
221
+ for k, block in enumerate(up_block):
222
+ if isinstance(block, ResBlock) or (
223
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
224
+ ResBlock)):
225
+ skip = level_outputs[i] if k == 0 and i > 0 else None
226
+ if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
227
+ x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
228
+ align_corners=True)
229
+ if cnet is not None:
230
+ next_cnet = cnet.pop()
231
+ if next_cnet is not None:
232
+ x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
233
+ align_corners=True).to(x.dtype)
234
+ x = block(x, skip)
235
+ elif isinstance(block, AttnBlock) or (
236
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
237
+ AttnBlock)):
238
+ x = block(x, clip)
239
+ elif isinstance(block, TimestepBlock) or (
240
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
241
+ TimestepBlock)):
242
+ x = block(x, r_embed)
243
+ else:
244
+ x = block(x)
245
+ if j < len(repmap):
246
+ x = repmap[j](x)
247
+ x = upscaler(x)
248
+ return x
249
+
250
+ def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
251
+ # Process the conditioning embeddings
252
+ r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
253
+ for c in self.t_conds:
254
+ t_cond = kwargs.get(c, torch.zeros_like(r))
255
+ r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
256
+ clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)
257
+
258
+ if control is not None:
259
+ cnet = control.get("input")
260
+ else:
261
+ cnet = None
262
+
263
+ # Model Blocks
264
+ x = self.embedding(x)
265
+ level_outputs = self._down_encode(x, r_embed, clip, cnet)
266
+ x = self._up_decode(level_outputs, r_embed, clip, cnet)
267
+ return self.clf(x)
268
+
269
+ def update_weights_ema(self, src_model, beta=0.999):
270
+ for self_params, src_params in zip(self.parameters(), src_model.parameters()):
271
+ self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
272
+ for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
273
+ self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
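The r embedding above is a standard sinusoidal mapping; the following standalone sketch uses made-up values (c_r = 64, two timesteps) and is not part of the uploaded file, it only mirrors the math in gen_r_embedding to show the resulting shape.

import math
import torch

# Hypothetical parameters: c_r = 64, two normalized timesteps in [0, 1].
c_r, max_positions = 64, 10000
r = torch.tensor([0.25, 0.5]) * max_positions
half_dim = c_r // 2
freq = math.log(max_positions) / (half_dim - 1)
freqs = torch.arange(half_dim).float().mul(-freq).exp()
emb = r[:, None] * freqs[None, :]
emb = torch.cat([emb.sin(), emb.cos()], dim=1)  # (batch, c_r)
assert emb.shape == (2, 64)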
Backend/comfy/ldm/cascade/stage_c_coder.py ADDED
@@ -0,0 +1,95 @@
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+ import torch
19
+ import torchvision
20
+ from torch import nn
21
+
22
+
23
+ # EfficientNet
24
+ class EfficientNetEncoder(nn.Module):
25
+ def __init__(self, c_latent=16):
26
+ super().__init__()
27
+ self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
28
+ self.mapper = nn.Sequential(
29
+ nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
30
+ nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
31
+ )
32
+ self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
33
+ self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]))
34
+
35
+ def forward(self, x):
36
+ x = x * 0.5 + 0.5
37
+ x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
38
+ o = self.mapper(self.backbone(x))
39
+ return o
40
+
41
+
42
+ # Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
43
+ class Previewer(nn.Module):
44
+ def __init__(self, c_in=16, c_hidden=512, c_out=3):
45
+ super().__init__()
46
+ self.blocks = nn.Sequential(
47
+ nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
48
+ nn.GELU(),
49
+ nn.BatchNorm2d(c_hidden),
50
+
51
+ nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
52
+ nn.GELU(),
53
+ nn.BatchNorm2d(c_hidden),
54
+
55
+ nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
56
+ nn.GELU(),
57
+ nn.BatchNorm2d(c_hidden // 2),
58
+
59
+ nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
60
+ nn.GELU(),
61
+ nn.BatchNorm2d(c_hidden // 2),
62
+
63
+ nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
64
+ nn.GELU(),
65
+ nn.BatchNorm2d(c_hidden // 4),
66
+
67
+ nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
68
+ nn.GELU(),
69
+ nn.BatchNorm2d(c_hidden // 4),
70
+
71
+ nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
72
+ nn.GELU(),
73
+ nn.BatchNorm2d(c_hidden // 4),
74
+
75
+ nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
76
+ nn.GELU(),
77
+ nn.BatchNorm2d(c_hidden // 4),
78
+
79
+ nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
80
+ )
81
+
82
+ def forward(self, x):
83
+ return (self.blocks(x) - 0.5) * 2.0
84
+
85
+ class StageC_coder(nn.Module):
86
+ def __init__(self):
87
+ super().__init__()
88
+ self.previewer = Previewer()
89
+ self.encoder = EfficientNetEncoder()
90
+
91
+ def encode(self, x):
92
+ return self.encoder(x)
93
+
94
+ def decode(self, x):
95
+ return self.previewer(x)
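A hypothetical smoke test of the two coders above (not part of the upload): EfficientNet-V2-S downsamples by 32, so a 768x768 RGB batch encodes to 16x24x24 Stage C latents, and the Previewer decodes them back to 3x192x192 as its comment states.

import torch

coder = StageC_coder()  # classes defined in the file above
with torch.no_grad():
    latents = coder.encode(torch.randn(1, 3, 768, 768))  # (1, 16, 24, 24)
    preview = coder.decode(latents)                       # (1, 3, 192, 192)
assert latents.shape == (1, 16, 24, 24)
assert preview.shape == (1, 3, 192, 192)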
Backend/comfy/ldm/common_dit.py ADDED
@@ -0,0 +1,8 @@
1
+ import torch
2
+
3
+ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
4
+ if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
5
+ padding_mode = "reflect"
6
+ pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
7
+ pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
8
+ return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
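A quick check of the padding arithmetic in pad_to_patch_size, with made-up sizes: a 5x7 spatial map padded to a (2, 2) patch grid gains one row and one column on the bottom/right.

import torch

img = torch.randn(1, 3, 5, 7)
pad_h = (2 - img.shape[-2] % 2) % 2  # -> 1
pad_w = (2 - img.shape[-1] % 2) % 2  # -> 1
padded = torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode="circular")
assert padded.shape == (1, 3, 6, 8)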
Backend/comfy/ldm/flux/layers.py ADDED
@@ -0,0 +1,263 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from torch import Tensor, nn
7
+
8
+ from .math import attention, rope
9
+ import comfy.ops
10
+
11
+ class EmbedND(nn.Module):
12
+ def __init__(self, dim: int, theta: int, axes_dim: list):
13
+ super().__init__()
14
+ self.dim = dim
15
+ self.theta = theta
16
+ self.axes_dim = axes_dim
17
+
18
+ def forward(self, ids: Tensor) -> Tensor:
19
+ n_axes = ids.shape[-1]
20
+ emb = torch.cat(
21
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
22
+ dim=-3,
23
+ )
24
+
25
+ return emb.unsqueeze(1)
26
+
27
+
28
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
29
+ """
30
+ Create sinusoidal timestep embeddings.
31
+ :param t: a 1-D Tensor of N indices, one per batch element.
32
+ These may be fractional.
33
+ :param dim: the dimension of the output.
34
+ :param max_period: controls the minimum frequency of the embeddings.
35
+ :return: an (N, D) Tensor of positional embeddings.
36
+ """
37
+ t = time_factor * t
38
+ half = dim // 2
39
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
40
+ t.device
41
+ )
42
+
43
+ args = t[:, None].float() * freqs[None]
44
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
45
+ if dim % 2:
46
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
47
+ if torch.is_floating_point(t):
48
+ embedding = embedding.to(t)
49
+ return embedding
50
+
51
+
52
+ class MLPEmbedder(nn.Module):
53
+ def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
54
+ super().__init__()
55
+ self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
56
+ self.silu = nn.SiLU()
57
+ self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
58
+
59
+ def forward(self, x: Tensor) -> Tensor:
60
+ return self.out_layer(self.silu(self.in_layer(x)))
61
+
62
+
63
+ class RMSNorm(torch.nn.Module):
64
+ def __init__(self, dim: int, dtype=None, device=None, operations=None):
65
+ super().__init__()
66
+ self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
67
+
68
+ def forward(self, x: Tensor):
69
+ x_dtype = x.dtype
70
+ x = x.float()
71
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
72
+ return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
73
+
74
+
75
+ class QKNorm(torch.nn.Module):
76
+ def __init__(self, dim: int, dtype=None, device=None, operations=None):
77
+ super().__init__()
78
+ self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
79
+ self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
80
+
81
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
82
+ q = self.query_norm(q)
83
+ k = self.key_norm(k)
84
+ return q.to(v), k.to(v)
85
+
86
+
87
+ class SelfAttention(nn.Module):
88
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
89
+ super().__init__()
90
+ self.num_heads = num_heads
91
+ head_dim = dim // num_heads
92
+
93
+ self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
94
+ self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
95
+ self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
96
+
97
+ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
98
+ qkv = self.qkv(x)
99
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
100
+ q, k = self.norm(q, k, v)
101
+ x = attention(q, k, v, pe=pe)
102
+ x = self.proj(x)
103
+ return x
104
+
105
+
106
+ @dataclass
107
+ class ModulationOut:
108
+ shift: Tensor
109
+ scale: Tensor
110
+ gate: Tensor
111
+
112
+
113
+ class Modulation(nn.Module):
114
+ def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
115
+ super().__init__()
116
+ self.is_double = double
117
+ self.multiplier = 6 if double else 3
118
+ self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
119
+
120
+ def forward(self, vec: Tensor) -> tuple:
121
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
122
+
123
+ return (
124
+ ModulationOut(*out[:3]),
125
+ ModulationOut(*out[3:]) if self.is_double else None,
126
+ )
127
+
128
+
129
+ class DoubleStreamBlock(nn.Module):
130
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None, operations=None):
131
+ super().__init__()
132
+
133
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
134
+ self.num_heads = num_heads
135
+ self.hidden_size = hidden_size
136
+ self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
137
+ self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
138
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
139
+
140
+ self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
141
+ self.img_mlp = nn.Sequential(
142
+ operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
143
+ nn.GELU(approximate="tanh"),
144
+ operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
145
+ )
146
+
147
+ self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
148
+ self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
149
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
150
+
151
+ self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
152
+ self.txt_mlp = nn.Sequential(
153
+ operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
154
+ nn.GELU(approximate="tanh"),
155
+ operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
156
+ )
157
+
158
+ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
159
+ img_mod1, img_mod2 = self.img_mod(vec)
160
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
161
+
162
+ # prepare image for attention
163
+ img_modulated = self.img_norm1(img)
164
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
165
+ img_qkv = self.img_attn.qkv(img_modulated)
166
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
167
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
168
+
169
+ # prepare txt for attention
170
+ txt_modulated = self.txt_norm1(txt)
171
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
172
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
173
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
174
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
175
+
176
+ # run actual attention
177
+ q = torch.cat((txt_q, img_q), dim=2)
178
+ k = torch.cat((txt_k, img_k), dim=2)
179
+ v = torch.cat((txt_v, img_v), dim=2)
180
+
181
+ attn = attention(q, k, v, pe=pe)
182
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
183
+
184
+ # calculate the img blocks
185
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
186
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
187
+
188
+ # calculate the txt blocks
189
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
190
+ txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
191
+
192
+ if txt.dtype == torch.float16:
193
+ txt = txt.clip(-65504, 65504)
194
+
195
+ return img, txt
196
+
197
+
198
+ class SingleStreamBlock(nn.Module):
199
+ """
200
+ A DiT block with parallel linear layers as described in
201
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
202
+ """
203
+
204
+ def __init__(
205
+ self,
206
+ hidden_size: int,
207
+ num_heads: int,
208
+ mlp_ratio: float = 4.0,
209
+ qk_scale: float = None,
210
+ dtype=None,
211
+ device=None,
212
+ operations=None
213
+ ):
214
+ super().__init__()
215
+ self.hidden_dim = hidden_size
216
+ self.num_heads = num_heads
217
+ head_dim = hidden_size // num_heads
218
+ self.scale = qk_scale or head_dim**-0.5
219
+
220
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
221
+ # qkv and mlp_in
222
+ self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
223
+ # proj and mlp_out
224
+ self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
225
+
226
+ self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
227
+
228
+ self.hidden_size = hidden_size
229
+ self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
230
+
231
+ self.mlp_act = nn.GELU(approximate="tanh")
232
+ self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
233
+
234
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
235
+ mod, _ = self.modulation(vec)
236
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
237
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
238
+
239
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
240
+ q, k = self.norm(q, k, v)
241
+
242
+ # compute attention
243
+ attn = attention(q, k, v, pe=pe)
244
+ # compute activation in mlp stream, cat again and run second linear layer
245
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
246
+ x = x + mod.gate * output
247
+ if x.dtype == torch.float16:
248
+ x = x.clip(-65504, 65504)
249
+ return x
250
+
251
+
252
+ class LastLayer(nn.Module):
253
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
254
+ super().__init__()
255
+ self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
256
+ self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
257
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
258
+
259
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
260
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
261
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
262
+ x = self.linear(x)
263
+ return x
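Two small shape checks with hypothetical values: timestep_embedding above yields the 256-wide sinusoidal table consumed by MLPEmbedder in time_in and guidance_in, and a plain nn.Linear stand-in mimics how the double Modulation chunks its projection into two (shift, scale, gate) triples.

import torch

emb = timestep_embedding(torch.tensor([0.0, 0.25, 0.5, 1.0]), 256)  # function above
assert emb.shape == (4, 256)

# Stand-in for a double Modulation with dim = 32 (hypothetical sizes).
vec = torch.randn(4, 32)
lin = torch.nn.Linear(32, 6 * 32)
out = lin(torch.nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
assert len(out) == 6 and out[0].shape == (4, 1, 32)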
Backend/comfy/ldm/flux/math.py ADDED
@@ -0,0 +1,35 @@
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import Tensor
4
+ from comfy.ldm.modules.attention import optimized_attention
5
+ import comfy.model_management
6
+
7
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
8
+ q, k = apply_rope(q, k, pe)
9
+
10
+ heads = q.shape[1]
11
+ x = optimized_attention(q, k, v, heads, skip_reshape=True)
12
+ return x
13
+
14
+
15
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
16
+ assert dim % 2 == 0
17
+ if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu():
18
+ device = torch.device("cpu")
19
+ else:
20
+ device = pos.device
21
+
22
+ scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
23
+ omega = 1.0 / (theta**scale)
24
+ out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
25
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
26
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
27
+ return out.to(dtype=torch.float32, device=pos.device)
28
+
29
+
30
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
31
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
32
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
33
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
34
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
35
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
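A minimal sketch (assumed shapes, not from the repo) of what apply_rope above expects: q and k are (B, H, L, D), pe is (B, 1, L, D // 2, 2, 2), and the last two axes act as a 2x2 rotation matrix; apply_rope is the function defined in this file.

import torch

B, H, L, D = 1, 2, 4, 8
q, k = torch.randn(B, H, L, D), torch.randn(B, H, L, D)
angles = torch.randn(B, 1, L, D // 2)
pe = torch.stack([angles.cos(), -angles.sin(), angles.sin(), angles.cos()], dim=-1)
pe = pe.reshape(B, 1, L, D // 2, 2, 2)
q_rot, k_rot = apply_rope(q, k, pe)
assert q_rot.shape == q.shape and k_rot.shape == k.shape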
Backend/comfy/ldm/flux/model.py ADDED
@@ -0,0 +1,142 @@
1
+ #Original code can be found on: https://github.com/black-forest-labs/flux
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ from torch import Tensor, nn
7
+
8
+ from .layers import (
9
+ DoubleStreamBlock,
10
+ EmbedND,
11
+ LastLayer,
12
+ MLPEmbedder,
13
+ SingleStreamBlock,
14
+ timestep_embedding,
15
+ )
16
+
17
+ from einops import rearrange, repeat
18
+ import comfy.ldm.common_dit
19
+
20
+ @dataclass
21
+ class FluxParams:
22
+ in_channels: int
23
+ vec_in_dim: int
24
+ context_in_dim: int
25
+ hidden_size: int
26
+ mlp_ratio: float
27
+ num_heads: int
28
+ depth: int
29
+ depth_single_blocks: int
30
+ axes_dim: list
31
+ theta: int
32
+ qkv_bias: bool
33
+ guidance_embed: bool
34
+
35
+
36
+ class Flux(nn.Module):
37
+ """
38
+ Transformer model for flow matching on sequences.
39
+ """
40
+
41
+ def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
42
+ super().__init__()
43
+ self.dtype = dtype
44
+ params = FluxParams(**kwargs)
45
+ self.params = params
46
+ self.in_channels = params.in_channels * 2 * 2
47
+ self.out_channels = self.in_channels
48
+ if params.hidden_size % params.num_heads != 0:
49
+ raise ValueError(
50
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
51
+ )
52
+ pe_dim = params.hidden_size // params.num_heads
53
+ if sum(params.axes_dim) != pe_dim:
54
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
55
+ self.hidden_size = params.hidden_size
56
+ self.num_heads = params.num_heads
57
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
58
+ self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
59
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
60
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
61
+ self.guidance_in = (
62
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
63
+ )
64
+ self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
65
+
66
+ self.double_blocks = nn.ModuleList(
67
+ [
68
+ DoubleStreamBlock(
69
+ self.hidden_size,
70
+ self.num_heads,
71
+ mlp_ratio=params.mlp_ratio,
72
+ qkv_bias=params.qkv_bias,
73
+ dtype=dtype, device=device, operations=operations
74
+ )
75
+ for _ in range(params.depth)
76
+ ]
77
+ )
78
+
79
+ self.single_blocks = nn.ModuleList(
80
+ [
81
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
82
+ for _ in range(params.depth_single_blocks)
83
+ ]
84
+ )
85
+
86
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
87
+
88
+ def forward_orig(
89
+ self,
90
+ img: Tensor,
91
+ img_ids: Tensor,
92
+ txt: Tensor,
93
+ txt_ids: Tensor,
94
+ timesteps: Tensor,
95
+ y: Tensor,
96
+ guidance: Tensor = None,
97
+ ) -> Tensor:
98
+ if img.ndim != 3 or txt.ndim != 3:
99
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
100
+
101
+ # running on sequences img
102
+ img = self.img_in(img)
103
+ vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
104
+ if self.params.guidance_embed:
105
+ if guidance is None:
106
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
107
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
108
+
109
+ vec = vec + self.vector_in(y)
110
+ txt = self.txt_in(txt)
111
+
112
+ ids = torch.cat((txt_ids, img_ids), dim=1)
113
+ pe = self.pe_embedder(ids)
114
+
115
+ for block in self.double_blocks:
116
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
117
+
118
+ img = torch.cat((txt, img), 1)
119
+ for block in self.single_blocks:
120
+ img = block(img, vec=vec, pe=pe)
121
+ img = img[:, txt.shape[1] :, ...]
122
+
123
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
124
+ return img
125
+
126
+ def forward(self, x, timestep, context, y, guidance, **kwargs):
127
+ bs, c, h, w = x.shape
128
+ patch_size = 2
129
+ x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
130
+
131
+ img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
132
+
133
+ h_len = ((h + (patch_size // 2)) // patch_size)
134
+ w_len = ((w + (patch_size // 2)) // patch_size)
135
+ img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
136
+ img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
137
+ img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
138
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
139
+
140
+ txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
141
+ out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
142
+ return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
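A small shape walk-through with hypothetical sizes of the 2x2 patchify used in forward above, which is also why the constructor sets in_channels = params.in_channels * 2 * 2.

import torch
from einops import rearrange

x = torch.randn(1, 16, 8, 8)  # (B, C, H, W), already aligned to the 2x2 patch grid
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
assert img.shape == (1, 16, 64)  # 4x4 tokens, 16 * 2 * 2 features each
back = rearrange(img, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=4, w=4, ph=2, pw=2)
assert torch.equal(back, x)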
Backend/comfy/ldm/hydit/attn_layers.py ADDED
@@ -0,0 +1,219 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Tuple, Union, Optional
4
+ from comfy.ldm.modules.attention import optimized_attention
5
+
6
+
7
+ def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
8
+ """
9
+ Reshape frequency tensor for broadcasting it with another tensor.
10
+
11
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
12
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
13
+
14
+ Args:
15
+ freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
16
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
17
+ head_first (bool): head dimension first (except batch dim) or not.
18
+
19
+ Returns:
20
+ torch.Tensor: Reshaped frequency tensor.
21
+
22
+ Raises:
23
+ AssertionError: If the frequency tensor doesn't match the expected shape.
24
+ AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
25
+ """
26
+ ndim = x.ndim
27
+ assert 0 <= 1 < ndim
28
+
29
+ if isinstance(freqs_cis, tuple):
30
+ # freqs_cis: (cos, sin) in real space
31
+ if head_first:
32
+ assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
33
+ shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
34
+ else:
35
+ assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
36
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
37
+ return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
38
+ else:
39
+ # freqs_cis: values in complex space
40
+ if head_first:
41
+ assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
42
+ shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
43
+ else:
44
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
45
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
46
+ return freqs_cis.view(*shape)
47
+
48
+
49
+ def rotate_half(x):
50
+ x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
51
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
52
+
53
+
54
+ def apply_rotary_emb(
55
+ xq: torch.Tensor,
56
+ xk: Optional[torch.Tensor],
57
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
58
+ head_first: bool = False,
59
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
60
+ """
61
+ Apply rotary embeddings to input tensors using the given frequency tensor.
62
+
63
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
64
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
65
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
66
+ returned as real tensors.
67
+
68
+ Args:
69
+ xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
70
+ xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
71
+ freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials.
72
+ head_first (bool): head dimension first (except batch dim) or not.
73
+
74
+ Returns:
75
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
76
+
77
+ """
78
+ xk_out = None
79
+ if isinstance(freqs_cis, tuple):
80
+ cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
81
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
82
+ xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
83
+ if xk is not None:
84
+ xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
85
+ else:
86
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
87
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
88
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
89
+ if xk is not None:
90
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2]
91
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
92
+
93
+ return xq_out, xk_out
94
+
95
+
96
+
97
+ class CrossAttention(nn.Module):
98
+ """
99
+ Use QK Normalization.
100
+ """
101
+ def __init__(self,
102
+ qdim,
103
+ kdim,
104
+ num_heads,
105
+ qkv_bias=True,
106
+ qk_norm=False,
107
+ attn_drop=0.0,
108
+ proj_drop=0.0,
109
+ attn_precision=None,
110
+ device=None,
111
+ dtype=None,
112
+ operations=None,
113
+ ):
114
+ factory_kwargs = {'device': device, 'dtype': dtype}
115
+ super().__init__()
116
+ self.attn_precision = attn_precision
117
+ self.qdim = qdim
118
+ self.kdim = kdim
119
+ self.num_heads = num_heads
120
+ assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
121
+ self.head_dim = self.qdim // num_heads
122
+ assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
123
+ self.scale = self.head_dim ** -0.5
124
+
125
+ self.q_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
126
+ self.kv_proj = operations.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
127
+
128
+ # TODO: eps should be 1 / 65530 if using fp16
129
+ self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
130
+ self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
131
+ self.attn_drop = nn.Dropout(attn_drop)
132
+ self.out_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
133
+ self.proj_drop = nn.Dropout(proj_drop)
134
+
135
+ def forward(self, x, y, freqs_cis_img=None):
136
+ """
137
+ Parameters
138
+ ----------
139
+ x: torch.Tensor
140
+ (batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
141
+ y: torch.Tensor
142
+ (batch, seqlen2, hidden_dim2)
143
+ freqs_cis_img: torch.Tensor
144
+ (batch, hidden_dim // 2), RoPE for image
145
+ """
146
+ b, s1, c = x.shape # [b, s1, D]
147
+ _, s2, c = y.shape # [b, s2, 1024]
148
+
149
+ q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d]
150
+ kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim) # [b, s2, 2, h, d]
151
+ k, v = kv.unbind(dim=2) # [b, s, h, d]
152
+ q = self.q_norm(q)
153
+ k = self.k_norm(k)
154
+
155
+ # Apply RoPE if needed
156
+ if freqs_cis_img is not None:
157
+ qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
158
+ assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
159
+ q = qq
160
+
161
+ q = q.transpose(-2, -3).contiguous() # q -> B, L1, H, C - B, H, L1, C
162
+ k = k.transpose(-2, -3).contiguous() # k -> B, L2, H, C - B, H, C, L2
163
+ v = v.transpose(-2, -3).contiguous()
164
+
165
+ context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
166
+
167
+ out = self.out_proj(context) # context.reshape - B, L1, -1
168
+ out = self.proj_drop(out)
169
+
170
+ out_tuple = (out,)
171
+
172
+ return out_tuple
173
+
174
+
175
+ class Attention(nn.Module):
176
+ """
177
+ We rename some layer names to align with flash attention
178
+ """
179
+ def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0., attn_precision=None, dtype=None, device=None, operations=None):
180
+ super().__init__()
181
+ self.attn_precision = attn_precision
182
+ self.dim = dim
183
+ self.num_heads = num_heads
184
+ assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
185
+ self.head_dim = self.dim // num_heads
186
+ # This assertion is aligned with flash attention
187
+ assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
188
+ self.scale = self.head_dim ** -0.5
189
+
190
+ # qkv --> Wqkv
191
+ self.Wqkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
192
+ # TODO: eps should be 1 / 65530 if using fp16
193
+ self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
194
+ self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
195
+ self.attn_drop = nn.Dropout(attn_drop)
196
+ self.out_proj = operations.Linear(dim, dim, dtype=dtype, device=device)
197
+ self.proj_drop = nn.Dropout(proj_drop)
198
+
199
+ def forward(self, x, freqs_cis_img=None):
200
+ B, N, C = x.shape
201
+ qkv = self.Wqkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) # [3, b, h, s, d]
202
+ q, k, v = qkv.unbind(0) # [b, h, s, d]
203
+ q = self.q_norm(q) # [b, h, s, d]
204
+ k = self.k_norm(k) # [b, h, s, d]
205
+
206
+ # Apply RoPE if needed
207
+ if freqs_cis_img is not None:
208
+ qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True)
209
+ assert qq.shape == q.shape and kk.shape == k.shape, \
210
+ f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
211
+ q, k = qq, kk
212
+
213
+ x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
214
+ x = self.out_proj(x)
215
+ x = self.proj_drop(x)
216
+
217
+ out_tuple = (x,)
218
+
219
+ return out_tuple
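A minimal sketch (assumed shapes) of the real-valued RoPE path in apply_rotary_emb above: freqs_cis is a (cos, sin) tuple shaped [S, D] and the query/key tensors are [B, S, H, D] with head_first=False; apply_rotary_emb is the function defined in this file.

import torch

B, S, H, D = 2, 16, 4, 64
xq, xk = torch.randn(B, S, H, D), torch.randn(B, S, H, D)
angles = torch.randn(S, D // 2)
cos = torch.repeat_interleave(angles.cos(), 2, dim=-1)  # [S, D]
sin = torch.repeat_interleave(angles.sin(), 2, dim=-1)  # [S, D]
q_rot, k_rot = apply_rotary_emb(xq, xk, (cos, sin), head_first=False)
assert q_rot.shape == xq.shape and k_rot.shape == xk.shape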
Backend/comfy/ldm/hydit/models.py ADDED
@@ -0,0 +1,405 @@
1
+ from typing import Any
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ import comfy.ops
8
+ from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
9
+ from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
10
+ from torch.utils import checkpoint
11
+
12
+ from .attn_layers import Attention, CrossAttention
13
+ from .poolers import AttentionPool
14
+ from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
15
+
16
+ def calc_rope(x, patch_size, head_size):
17
+ th = (x.shape[2] + (patch_size // 2)) // patch_size
18
+ tw = (x.shape[3] + (patch_size // 2)) // patch_size
19
+ base_size = 512 // 8 // patch_size
20
+ start, stop = get_fill_resize_and_crop((th, tw), base_size)
21
+ sub_args = [start, stop, (th, tw)]
22
+ # head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
23
+ rope = get_2d_rotary_pos_embed(head_size, *sub_args)
24
+ return rope
25
+
26
+
27
+ def modulate(x, shift, scale):
28
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
29
+
30
+
31
+ class HunYuanDiTBlock(nn.Module):
32
+ """
33
+ A HunYuanDiT block with `add` conditioning.
34
+ """
35
+ def __init__(self,
36
+ hidden_size,
37
+ c_emb_size,
38
+ num_heads,
39
+ mlp_ratio=4.0,
40
+ text_states_dim=1024,
41
+ qk_norm=False,
42
+ norm_type="layer",
43
+ skip=False,
44
+ attn_precision=None,
45
+ dtype=None,
46
+ device=None,
47
+ operations=None,
48
+ ):
49
+ super().__init__()
50
+ use_ele_affine = True
51
+
52
+ if norm_type == "layer":
53
+ norm_layer = operations.LayerNorm
54
+ elif norm_type == "rms":
55
+ norm_layer = RMSNorm
56
+ else:
57
+ raise ValueError(f"Unknown norm_type: {norm_type}")
58
+
59
+ # ========================= Self-Attention =========================
60
+ self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
61
+ self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
62
+
63
+ # ========================= FFN =========================
64
+ self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
65
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
66
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
67
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0, dtype=dtype, device=device, operations=operations)
68
+
69
+ # ========================= Add =========================
70
+ # Simply use add like SDXL.
71
+ self.default_modulation = nn.Sequential(
72
+ nn.SiLU(),
73
+ operations.Linear(c_emb_size, hidden_size, bias=True, dtype=dtype, device=device)
74
+ )
75
+
76
+ # ========================= Cross-Attention =========================
77
+ self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True,
78
+ qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
79
+ self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
80
+
81
+ # ========================= Skip Connection =========================
82
+ if skip:
83
+ self.skip_norm = norm_layer(2 * hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
84
+ self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, dtype=dtype, device=device)
85
+ else:
86
+ self.skip_linear = None
87
+
88
+ self.gradient_checkpointing = False
89
+
90
+ def _forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
91
+ # Long Skip Connection
92
+ if self.skip_linear is not None:
93
+ cat = torch.cat([x, skip], dim=-1)
94
+ cat = self.skip_norm(cat)
95
+ x = self.skip_linear(cat)
96
+
97
+ # Self-Attention
98
+ shift_msa = self.default_modulation(c).unsqueeze(dim=1)
99
+ attn_inputs = (
100
+ self.norm1(x) + shift_msa, freq_cis_img,
101
+ )
102
+ x = x + self.attn1(*attn_inputs)[0]
103
+
104
+ # Cross-Attention
105
+ cross_inputs = (
106
+ self.norm3(x), text_states, freq_cis_img
107
+ )
108
+ x = x + self.attn2(*cross_inputs)[0]
109
+
110
+ # FFN Layer
111
+ mlp_inputs = self.norm2(x)
112
+ x = x + self.mlp(mlp_inputs)
113
+
114
+ return x
115
+
116
+ def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
117
+ if self.gradient_checkpointing and self.training:
118
+ return checkpoint.checkpoint(self._forward, x, c, text_states, freq_cis_img, skip)
119
+ return self._forward(x, c, text_states, freq_cis_img, skip)
120
+
121
+
122
+ class FinalLayer(nn.Module):
123
+ """
124
+ The final layer of HunYuanDiT.
125
+ """
126
+ def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels, dtype=None, device=None, operations=None):
127
+ super().__init__()
128
+ self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
129
+ self.linear = operations.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
130
+ self.adaLN_modulation = nn.Sequential(
131
+ nn.SiLU(),
132
+ operations.Linear(c_emb_size, 2 * final_hidden_size, bias=True, dtype=dtype, device=device)
133
+ )
134
+
135
+ def forward(self, x, c):
136
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
137
+ x = modulate(self.norm_final(x), shift, scale)
138
+ x = self.linear(x)
139
+ return x
140
+
141
+
142
+ class HunYuanDiT(nn.Module):
143
+ """
144
+ HunYuanDiT: Diffusion model with a Transformer backbone.
145
+
146
+ Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
147
+
148
+ Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline.
149
+
150
+ Parameters
151
+ ----------
152
+ args: argparse.Namespace
153
+ The arguments parsed by argparse.
154
+ input_size: tuple
155
+ The size of the input image.
156
+ patch_size: int
157
+ The size of the patch.
158
+ in_channels: int
159
+ The number of input channels.
160
+ hidden_size: int
161
+ The hidden size of the transformer backbone.
162
+ depth: int
163
+ The number of transformer blocks.
164
+ num_heads: int
165
+ The number of attention heads.
166
+ mlp_ratio: float
167
+ The ratio of the hidden size of the MLP in the transformer block.
168
+ log_fn: callable
169
+ The logging function.
170
+ """
171
+ #@register_to_config
172
+ def __init__(self,
173
+ input_size: tuple = 32,
174
+ patch_size: int = 2,
175
+ in_channels: int = 4,
176
+ hidden_size: int = 1152,
177
+ depth: int = 28,
178
+ num_heads: int = 16,
179
+ mlp_ratio: float = 4.0,
180
+ text_states_dim = 1024,
181
+ text_states_dim_t5 = 2048,
182
+ text_len = 77,
183
+ text_len_t5 = 256,
184
+ qk_norm = True,# See http://arxiv.org/abs/2302.05442 for details.
185
+ size_cond = False,
186
+ use_style_cond = False,
187
+ learn_sigma = True,
188
+ norm = "layer",
189
+ log_fn: callable = print,
190
+ attn_precision=None,
191
+ dtype=None,
192
+ device=None,
193
+ operations=None,
194
+ **kwargs,
195
+ ):
196
+ super().__init__()
197
+ self.log_fn = log_fn
198
+ self.depth = depth
199
+ self.learn_sigma = learn_sigma
200
+ self.in_channels = in_channels
201
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
202
+ self.patch_size = patch_size
203
+ self.num_heads = num_heads
204
+ self.hidden_size = hidden_size
205
+ self.text_states_dim = text_states_dim
206
+ self.text_states_dim_t5 = text_states_dim_t5
207
+ self.text_len = text_len
208
+ self.text_len_t5 = text_len_t5
209
+ self.size_cond = size_cond
210
+ self.use_style_cond = use_style_cond
211
+ self.norm = norm
212
+ self.dtype = dtype
213
+ #import pdb
214
+ #pdb.set_trace()
215
+
216
+ self.mlp_t5 = nn.Sequential(
217
+ operations.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True, dtype=dtype, device=device),
218
+ nn.SiLU(),
219
+ operations.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True, dtype=dtype, device=device),
220
+ )
221
+ # learnable replace
222
+ self.text_embedding_padding = nn.Parameter(
223
+ torch.empty(self.text_len + self.text_len_t5, self.text_states_dim, dtype=dtype, device=device))
224
+
225
+ # Attention pooling
226
+ pooler_out_dim = 1024
227
+ self.pooler = AttentionPool(self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=pooler_out_dim, dtype=dtype, device=device, operations=operations)
228
+
229
+ # Dimension of the extra input vectors
230
+ self.extra_in_dim = pooler_out_dim
231
+
232
+ if self.size_cond:
233
+ # Image size and crop size conditions
234
+ self.extra_in_dim += 6 * 256
235
+
236
+ if self.use_style_cond:
237
+ # Here we use a default learned embedder layer for future extension.
238
+ self.style_embedder = operations.Embedding(1, hidden_size, dtype=dtype, device=device)
239
+ self.extra_in_dim += hidden_size
240
+
241
+ # Text embedding for `add`
242
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, dtype=dtype, device=device, operations=operations)
243
+ self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device, operations=operations)
244
+ self.extra_embedder = nn.Sequential(
245
+ operations.Linear(self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device),
246
+ nn.SiLU(),
247
+ operations.Linear(hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device),
248
+ )
249
+
250
+ # Image embedding
251
+ num_patches = self.x_embedder.num_patches
252
+
253
+ # HunYuanDiT blocks
254
+ self.blocks = nn.ModuleList([
255
+ HunYuanDiTBlock(hidden_size=hidden_size,
256
+ c_emb_size=hidden_size,
257
+ num_heads=num_heads,
258
+ mlp_ratio=mlp_ratio,
259
+ text_states_dim=self.text_states_dim,
260
+ qk_norm=qk_norm,
261
+ norm_type=self.norm,
262
+ skip=layer > depth // 2,
263
+ attn_precision=attn_precision,
264
+ dtype=dtype,
265
+ device=device,
266
+ operations=operations,
267
+ )
268
+ for layer in range(depth)
269
+ ])
270
+
271
+ self.final_layer = FinalLayer(hidden_size, hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
272
+ self.unpatchify_channels = self.out_channels
273
+
274
+
275
+
276
+ def forward(self,
277
+ x,
278
+ t,
279
+ context,#encoder_hidden_states=None,
280
+ text_embedding_mask=None,
281
+ encoder_hidden_states_t5=None,
282
+ text_embedding_mask_t5=None,
283
+ image_meta_size=None,
284
+ style=None,
285
+ return_dict=False,
286
+ control=None,
287
+ transformer_options=None,
288
+ ):
289
+ """
290
+ Forward pass of the encoder.
291
+
292
+ Parameters
293
+ ----------
294
+ x: torch.Tensor
295
+ (B, D, H, W)
296
+ t: torch.Tensor
297
+ (B)
298
+ encoder_hidden_states: torch.Tensor
299
+ CLIP text embedding, (B, L_clip, D)
300
+ text_embedding_mask: torch.Tensor
301
+ CLIP text embedding mask, (B, L_clip)
302
+ encoder_hidden_states_t5: torch.Tensor
303
+ T5 text embedding, (B, L_t5, D)
304
+ text_embedding_mask_t5: torch.Tensor
305
+ T5 text embedding mask, (B, L_t5)
306
+ image_meta_size: torch.Tensor
307
+ (B, 6)
308
+ style: torch.Tensor
309
+ (B)
310
+ cos_cis_img: torch.Tensor
311
+ sin_cis_img: torch.Tensor
312
+ return_dict: bool
313
+ Whether to return a dictionary.
314
+ """
315
+ #import pdb
316
+ #pdb.set_trace()
317
+ encoder_hidden_states = context
318
+ text_states = encoder_hidden_states # 2,77,1024
319
+ text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
320
+ text_states_mask = text_embedding_mask.bool() # 2,77
321
+ text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
322
+ b_t5, l_t5, c_t5 = text_states_t5.shape
323
+ text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
324
+
325
+ padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)
326
+
327
+ text_states[:,-self.text_len:] = torch.where(text_states_mask[:,-self.text_len:].unsqueeze(2), text_states[:,-self.text_len:], padding[:self.text_len])
328
+ text_states_t5[:,-self.text_len_t5:] = torch.where(text_states_t5_mask[:,-self.text_len_t5:].unsqueeze(2), text_states_t5[:,-self.text_len_t5:], padding[self.text_len:])
329
+
330
+ text_states = torch.cat([text_states, text_states_t5], dim=1) # 2,205,1024
331
+ # clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1)
332
+
333
+ _, _, oh, ow = x.shape
334
+ th, tw = (oh + (self.patch_size // 2)) // self.patch_size, (ow + (self.patch_size // 2)) // self.patch_size
335
+
336
+
337
+ # Get image RoPE embedding according to `reso`lution.
338
+ freqs_cis_img = calc_rope(x, self.patch_size, self.hidden_size // self.num_heads) #(cos_cis_img, sin_cis_img)
339
+
340
+ # ========================= Build time and image embedding =========================
341
+ t = self.t_embedder(t, dtype=x.dtype)
342
+ x = self.x_embedder(x)
343
+
344
+ # ========================= Concatenate all extra vectors =========================
345
+ # Build text tokens with pooling
346
+ extra_vec = self.pooler(encoder_hidden_states_t5)
347
+
348
+ # Build image meta size tokens if applicable
349
+ if self.size_cond:
350
+ image_meta_size = timestep_embedding(image_meta_size.view(-1), 256).to(x.dtype) # [B * 6, 256]
351
+ image_meta_size = image_meta_size.view(-1, 6 * 256)
352
+ extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256]
353
+
354
+ # Build style tokens
355
+ if self.use_style_cond:
356
+ if style is None:
357
+ style = torch.zeros((extra_vec.shape[0],), device=x.device, dtype=torch.int)
358
+ style_embedding = self.style_embedder(style, out_dtype=x.dtype)
359
+ extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
360
+
361
+ # Concatenate all extra vectors
362
+ c = t + self.extra_embedder(extra_vec) # [B, D]
363
+
364
+ controls = None
365
+ # ========================= Forward pass through HunYuanDiT blocks =========================
366
+ skips = []
367
+ for layer, block in enumerate(self.blocks):
368
+ if layer > self.depth // 2:
369
+ if controls is not None:
370
+ skip = skips.pop() + controls.pop()
371
+ else:
372
+ skip = skips.pop()
373
+ x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
374
+ else:
375
+ x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
376
+
377
+ if layer < (self.depth // 2 - 1):
378
+ skips.append(x)
379
+ if controls is not None and len(controls) != 0:
380
+ raise ValueError("The number of controls is not equal to the number of skip connections.")
381
+
382
+ # ========================= Final layer =========================
383
+ x = self.final_layer(x, c) # (N, L, patch_size ** 2 * out_channels)
384
+ x = self.unpatchify(x, th, tw) # (N, out_channels, H, W)
385
+
386
+ if return_dict:
387
+ return {'x': x}
388
+ if self.learn_sigma:
389
+ return x[:,:self.out_channels // 2,:oh,:ow]
390
+ return x[:,:,:oh,:ow]
391
+
392
+ def unpatchify(self, x, h, w):
393
+ """
394
+ x: (N, T, patch_size**2 * C)
395
+ imgs: (N, H, W, C)
396
+ """
397
+ c = self.unpatchify_channels
398
+ p = self.x_embedder.patch_size[0]
399
+ # h = w = int(x.shape[1] ** 0.5)
400
+ assert h * w == x.shape[1]
401
+
402
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
403
+ x = torch.einsum('nhwpqc->nchpwq', x)
404
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
405
+ return imgs
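A standalone shape check with hypothetical sizes of the unpatchify math above: with patch size p = 2 and c = 8 channels, an (N, h*w, p*p*c) token grid folds back to (N, c, h*p, w*p).

import torch

N, h, w, p, c = 1, 4, 6, 2, 8
x = torch.randn(N, h * w, p * p * c)
x = x.reshape(N, h, w, p, p, c)
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(N, c, h * p, w * p)
assert imgs.shape == (1, 8, 8, 12)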
Backend/comfy/ldm/hydit/poolers.py ADDED
@@ -0,0 +1,37 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from comfy.ldm.modules.attention import optimized_attention
5
+ import comfy.ops
6
+
7
+ class AttentionPool(nn.Module):
8
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None):
9
+ super().__init__()
10
+ self.positional_embedding = nn.Parameter(torch.empty(spacial_dim + 1, embed_dim, dtype=dtype, device=device))
11
+ self.k_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
12
+ self.q_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
13
+ self.v_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
14
+ self.c_proj = operations.Linear(embed_dim, output_dim or embed_dim, dtype=dtype, device=device)
15
+ self.num_heads = num_heads
16
+ self.embed_dim = embed_dim
17
+
18
+ def forward(self, x):
19
+ x = x[:,:self.positional_embedding.shape[0] - 1]
20
+ x = x.permute(1, 0, 2) # NLC -> LNC
21
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
22
+ x = x + comfy.ops.cast_to_input(self.positional_embedding[:, None, :], x) # (L+1)NC
23
+
24
+ q = self.q_proj(x[:1])
25
+ k = self.k_proj(x)
26
+ v = self.v_proj(x)
27
+
28
+ batch_size = q.shape[1]
29
+ head_dim = self.embed_dim // self.num_heads
30
+ q = q.view(1, batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
31
+ k = k.view(k.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
32
+ v = v.view(v.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
33
+
34
+ attn_output = optimized_attention(q, k, v, self.num_heads, skip_reshape=True).transpose(0, 1)
35
+
36
+ attn_output = self.c_proj(attn_output)
37
+ return attn_output.squeeze(0)
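AttentionPool reduces a token sequence to a single vector: the query is the mean token plus a learned positional term, attended over all tokens. A minimal shape check, assuming the Backend directory is importable and using the repo's comfy.ops.disable_weight_init namespace (sizes below are hypothetical; the positional embedding is left uninitialized here):

import torch
import comfy.ops
from comfy.ldm.hydit.poolers import AttentionPool

pool = AttentionPool(spacial_dim=77, embed_dim=1024, num_heads=8, output_dim=1024,
                     dtype=torch.float32, device="cpu",
                     operations=comfy.ops.disable_weight_init)
tokens = torch.randn(2, 77, 1024)   # (batch, tokens, embed_dim)
pooled = pool(tokens)               # -> (batch, output_dim)
print(pooled.shape)                 # torch.Size([2, 1024])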
Backend/comfy/ldm/hydit/posemb_layers.py ADDED
@@ -0,0 +1,224 @@
1
+ import torch
2
+ import numpy as np
3
+ from typing import Union
4
+
5
+
6
+ def _to_tuple(x):
7
+ if isinstance(x, int):
8
+ return x, x
9
+ else:
10
+ return x
11
+
12
+
13
+ def get_fill_resize_and_crop(src, tgt):
14
+ th, tw = _to_tuple(tgt)
15
+ h, w = _to_tuple(src)
16
+
17
+ tr = th / tw # aspect ratio of tgt (the base resolution at the call site)
+ r = h / w # aspect ratio of src (the requested target resolution)
19
+
20
+ # resize
21
+ if r > tr:
22
+ resize_height = th
23
+ resize_width = int(round(th / h * w))
24
+ else:
25
+ resize_width = tw
26
+ resize_height = int(round(tw / w * h)) # scale to tgt's width while keeping src's aspect ratio
27
+
28
+ crop_top = int(round((th - resize_height) / 2.0))
29
+ crop_left = int(round((tw - resize_width) / 2.0))
30
+
31
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
32
+
33
+
34
+ def get_meshgrid(start, *args):
35
+ if len(args) == 0:
36
+ # start is grid_size
37
+ num = _to_tuple(start)
38
+ start = (0, 0)
39
+ stop = num
40
+ elif len(args) == 1:
41
+ # start is start, args[0] is stop, step is 1
42
+ start = _to_tuple(start)
43
+ stop = _to_tuple(args[0])
44
+ num = (stop[0] - start[0], stop[1] - start[1])
45
+ elif len(args) == 2:
46
+ # start is start, args[0] is stop, args[1] is num
47
+ start = _to_tuple(start)
48
+ stop = _to_tuple(args[0])
49
+ num = _to_tuple(args[1])
50
+ else:
51
+ raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
52
+
53
+ grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)
54
+ grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
55
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
56
+ grid = np.stack(grid, axis=0) # [2, H, W]
57
+ return grid
58
+
59
+ #################################################################################
60
+ # Sine/Cosine Positional Embedding Functions #
61
+ #################################################################################
62
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
63
+
64
+ def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0):
65
+ """
66
+ start, *args: grid specification forwarded to get_meshgrid (num, or start/stop[/num])
67
+ return:
68
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
69
+ """
70
+ grid = get_meshgrid(start, *args) # [2, H, w]
71
+ # grid_h = np.arange(grid_size, dtype=np.float32)
72
+ # grid_w = np.arange(grid_size, dtype=np.float32)
73
+ # grid = np.meshgrid(grid_w, grid_h) # here w goes first
74
+ # grid = np.stack(grid, axis=0) # [2, W, H]
75
+
76
+ grid = grid.reshape([2, 1, *grid.shape[1:]])
77
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
78
+ if cls_token and extra_tokens > 0:
79
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
80
+ return pos_embed
81
+
82
+
83
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
84
+ assert embed_dim % 2 == 0
85
+
86
+ # use half of dimensions to encode grid_h
87
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
88
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
89
+
90
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
91
+ return emb
92
+
93
+
94
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
95
+ """
96
+ embed_dim: output dimension for each position
97
+ pos: positions to be encoded, flattened to size (M,)
98
+ out: (M, D)
99
+ """
100
+ assert embed_dim % 2 == 0
101
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
102
+ omega /= embed_dim / 2.
103
+ omega = 1. / 10000**omega # (D/2,)
104
+
105
+ pos = pos.reshape(-1) # (M,)
106
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
107
+
108
+ emb_sin = np.sin(out) # (M, D/2)
109
+ emb_cos = np.cos(out) # (M, D/2)
110
+
111
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
112
+ return emb
113
+
114
+
115
+ #################################################################################
116
+ # Rotary Positional Embedding Functions #
117
+ #################################################################################
118
+ # https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443
119
+
120
+ def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
121
+ """
122
+ This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure.
123
+
124
+ Parameters
125
+ ----------
126
+ embed_dim: int
127
+ embedding dimension size
128
+ start: int or tuple of int
129
+ If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1;
130
+ If len(args) == 2, start is start, args[0] is stop, args[1] is num.
131
+ use_real: bool
132
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
133
+
134
+ Returns
135
+ -------
136
+ pos_embed: torch.Tensor
137
+ [HW, D/2]
138
+ """
139
+ grid = get_meshgrid(start, *args) # [2, H, w]
140
+ grid = grid.reshape([2, 1, *grid.shape[1:]]) # [2, 1, H, W] sampling grid at the target resolution
141
+ pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
142
+ return pos_embed
143
+
144
+
145
+ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
146
+ assert embed_dim % 4 == 0
147
+
148
+ # use half of dimensions to encode grid_h
149
+ emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
150
+ emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
151
+
152
+ if use_real:
153
+ cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
154
+ sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
155
+ return cos, sin
156
+ else:
157
+ emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
158
+ return emb
159
+
160
+
161
+ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
162
+ """
163
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
164
+
165
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
166
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
167
+ The returned tensor contains complex values in complex64 data type.
168
+
169
+ Args:
170
+ dim (int): Dimension of the frequency tensor.
171
+ pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar
172
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
173
+ use_real (bool, optional): If True, return real part and imaginary part separately.
174
+ Otherwise, return complex numbers.
175
+
176
+ Returns:
177
+ torch.Tensor: Precomputed frequency tensor with complex exponentials. [S, D/2]
178
+
179
+ """
180
+ if isinstance(pos, int):
181
+ pos = np.arange(pos)
182
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
183
+ t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
184
+ freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
185
+ if use_real:
186
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
187
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
188
+ return freqs_cos, freqs_sin
189
+ else:
190
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
191
+ return freqs_cis
192
+
193
+
194
+
195
+ def calc_sizes(rope_img, patch_size, th, tw):
196
+ if rope_img == 'extend':
197
+ # Expansion mode
198
+ sub_args = [(th, tw)]
199
+ elif rope_img.startswith('base'):
200
+ # Based on the specified dimensions, other dimensions are obtained through interpolation.
201
+ base_size = int(rope_img[4:]) // 8 // patch_size
202
+ start, stop = get_fill_resize_and_crop((th, tw), base_size)
203
+ sub_args = [start, stop, (th, tw)]
204
+ else:
205
+ raise ValueError(f"Unknown rope_img: {rope_img}")
206
+ return sub_args
207
+
208
+
209
+ def init_image_posemb(rope_img,
210
+ resolutions,
211
+ patch_size,
212
+ hidden_size,
213
+ num_heads,
214
+ log_fn,
215
+ rope_real=True,
216
+ ):
217
+ freqs_cis_img = {}
218
+ for reso in resolutions:
219
+ th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size
220
+ sub_args = calc_sizes(rope_img, patch_size, th, tw)
221
+ freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real)
222
+ log_fn(f" Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) "
223
+ f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}")
224
+ return freqs_cis_img
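init_image_posemb caches one rotary table per supported resolution; the helpers above define the shape conventions it relies on. A small sketch with hypothetical sizes (a 16x16 patch grid and 64-dim attention heads), assuming the Backend directory is on the import path:

import torch
from comfy.ldm.hydit.posemb_layers import get_2d_sincos_pos_embed, get_2d_rotary_pos_embed

pe = get_2d_sincos_pos_embed(64, (16, 16))                 # numpy array, shape (256, 64)
cos, sin = get_2d_rotary_pos_embed(64, (16, 16), use_real=True)
print(pe.shape, cos.shape, sin.shape)                      # (256, 64), [256, 64], [256, 64]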
Backend/comfy/ldm/models/autoencoder.py ADDED
@@ -0,0 +1,226 @@
1
+ import logging
+ import math
+ import torch
+ from contextlib import contextmanager
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+ from comfy.ldm.util import get_obj_from_str, instantiate_from_config
+ from comfy.ldm.modules.ema import LitEma
+ import comfy.ops
+
+ # logger used by the EMA helpers; math and get_obj_from_str are needed by the batched encode/decode and optimizer paths below
+ logpy = logging.getLogger(__name__)
10
+
11
+ class DiagonalGaussianRegularizer(torch.nn.Module):
12
+ def __init__(self, sample: bool = True):
13
+ super().__init__()
14
+ self.sample = sample
15
+
16
+ def get_trainable_parameters(self) -> Any:
17
+ yield from ()
18
+
19
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
20
+ log = dict()
21
+ posterior = DiagonalGaussianDistribution(z)
22
+ if self.sample:
23
+ z = posterior.sample()
24
+ else:
25
+ z = posterior.mode()
26
+ kl_loss = posterior.kl()
27
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
28
+ log["kl_loss"] = kl_loss
29
+ return z, log
30
+
31
+
32
+ class AbstractAutoencoder(torch.nn.Module):
33
+ """
34
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
35
+ unCLIP models, etc. Hence, it is fairly general, and specific features
36
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ ema_decay: Union[None, float] = None,
42
+ monitor: Union[None, str] = None,
43
+ input_key: str = "jpg",
44
+ **kwargs,
45
+ ):
46
+ super().__init__()
47
+
48
+ self.input_key = input_key
49
+ self.use_ema = ema_decay is not None
50
+ if monitor is not None:
51
+ self.monitor = monitor
52
+
53
+ if self.use_ema:
54
+ self.model_ema = LitEma(self, decay=ema_decay)
55
+ logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
56
+
57
+ def get_input(self, batch) -> Any:
58
+ raise NotImplementedError()
59
+
60
+ def on_train_batch_end(self, *args, **kwargs):
61
+ # for EMA computation
62
+ if self.use_ema:
63
+ self.model_ema(self)
64
+
65
+ @contextmanager
66
+ def ema_scope(self, context=None):
67
+ if self.use_ema:
68
+ self.model_ema.store(self.parameters())
69
+ self.model_ema.copy_to(self)
70
+ if context is not None:
71
+ logpy.info(f"{context}: Switched to EMA weights")
72
+ try:
73
+ yield None
74
+ finally:
75
+ if self.use_ema:
76
+ self.model_ema.restore(self.parameters())
77
+ if context is not None:
78
+ logpy.info(f"{context}: Restored training weights")
79
+
80
+ def encode(self, *args, **kwargs) -> torch.Tensor:
81
+ raise NotImplementedError("encode()-method of abstract base class called")
82
+
83
+ def decode(self, *args, **kwargs) -> torch.Tensor:
84
+ raise NotImplementedError("decode()-method of abstract base class called")
85
+
86
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
87
+ logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
88
+ return get_obj_from_str(cfg["target"])(
89
+ params, lr=lr, **cfg.get("params", dict())
90
+ )
91
+
92
+ def configure_optimizers(self) -> Any:
93
+ raise NotImplementedError()
94
+
95
+
96
+ class AutoencodingEngine(AbstractAutoencoder):
97
+ """
98
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
99
+ (we also restore them explicitly as special cases for legacy reasons).
100
+ Regularizations such as KL or VQ are moved to the regularizer class.
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ *args,
106
+ encoder_config: Dict,
107
+ decoder_config: Dict,
108
+ regularizer_config: Dict,
109
+ **kwargs,
110
+ ):
111
+ super().__init__(*args, **kwargs)
112
+
113
+ self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
114
+ self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
115
+ self.regularization: AbstractRegularizer = instantiate_from_config(
116
+ regularizer_config
117
+ )
118
+
119
+ def get_last_layer(self):
120
+ return self.decoder.get_last_layer()
121
+
122
+ def encode(
123
+ self,
124
+ x: torch.Tensor,
125
+ return_reg_log: bool = False,
126
+ unregularized: bool = False,
127
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
128
+ z = self.encoder(x)
129
+ if unregularized:
130
+ return z, dict()
131
+ z, reg_log = self.regularization(z)
132
+ if return_reg_log:
133
+ return z, reg_log
134
+ return z
135
+
136
+ def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
137
+ x = self.decoder(z, **kwargs)
138
+ return x
139
+
140
+ def forward(
141
+ self, x: torch.Tensor, **additional_decode_kwargs
142
+ ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
143
+ z, reg_log = self.encode(x, return_reg_log=True)
144
+ dec = self.decode(z, **additional_decode_kwargs)
145
+ return z, dec, reg_log
146
+
147
+
148
+ class AutoencodingEngineLegacy(AutoencodingEngine):
149
+ def __init__(self, embed_dim: int, **kwargs):
150
+ self.max_batch_size = kwargs.pop("max_batch_size", None)
151
+ ddconfig = kwargs.pop("ddconfig")
152
+ super().__init__(
153
+ encoder_config={
154
+ "target": "comfy.ldm.modules.diffusionmodules.model.Encoder",
155
+ "params": ddconfig,
156
+ },
157
+ decoder_config={
158
+ "target": "comfy.ldm.modules.diffusionmodules.model.Decoder",
159
+ "params": ddconfig,
160
+ },
161
+ **kwargs,
162
+ )
163
+ self.quant_conv = comfy.ops.disable_weight_init.Conv2d(
164
+ (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
165
+ (1 + ddconfig["double_z"]) * embed_dim,
166
+ 1,
167
+ )
168
+ self.post_quant_conv = comfy.ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
169
+ self.embed_dim = embed_dim
170
+
171
+ def get_autoencoder_params(self) -> list:
172
+ params = super().get_autoencoder_params()
173
+ return params
174
+
175
+ def encode(
176
+ self, x: torch.Tensor, return_reg_log: bool = False
177
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
178
+ if self.max_batch_size is None:
179
+ z = self.encoder(x)
180
+ z = self.quant_conv(z)
181
+ else:
182
+ N = x.shape[0]
183
+ bs = self.max_batch_size
184
+ n_batches = int(math.ceil(N / bs))
185
+ z = list()
186
+ for i_batch in range(n_batches):
187
+ z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
188
+ z_batch = self.quant_conv(z_batch)
189
+ z.append(z_batch)
190
+ z = torch.cat(z, 0)
191
+
192
+ z, reg_log = self.regularization(z)
193
+ if return_reg_log:
194
+ return z, reg_log
195
+ return z
196
+
197
+ def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
198
+ if self.max_batch_size is None:
199
+ dec = self.post_quant_conv(z)
200
+ dec = self.decoder(dec, **decoder_kwargs)
201
+ else:
202
+ N = z.shape[0]
203
+ bs = self.max_batch_size
204
+ n_batches = int(math.ceil(N / bs))
205
+ dec = list()
206
+ for i_batch in range(n_batches):
207
+ dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
208
+ dec_batch = self.decoder(dec_batch, **decoder_kwargs)
209
+ dec.append(dec_batch)
210
+ dec = torch.cat(dec, 0)
211
+
212
+ return dec
213
+
214
+
215
+ class AutoencoderKL(AutoencodingEngineLegacy):
216
+ def __init__(self, **kwargs):
217
+ if "lossconfig" in kwargs:
218
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
219
+ super().__init__(
220
+ regularizer_config={
221
+ "target": (
222
+ "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"
223
+ )
224
+ },
225
+ **kwargs,
226
+ )
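For reference, DiagonalGaussianRegularizer expects the encoder output to stack mean and log-variance along the channel axis and returns a sampled latent plus a KL estimate. A standalone sketch (channel counts are hypothetical; assumes the Backend directory is importable):

import torch
from comfy.ldm.models.autoencoder import DiagonalGaussianRegularizer

reg = DiagonalGaussianRegularizer(sample=True)
z_params = torch.randn(1, 8, 32, 32)      # 2 * z_channels: first half mean, second half logvar
z, log = reg(z_params)
print(z.shape, float(log["kl_loss"]))     # torch.Size([1, 4, 32, 32]) and a scalar KL value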
Backend/comfy/ldm/modules/attention.py ADDED
@@ -0,0 +1,865 @@
1
+ import math
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn, einsum
5
+ from einops import rearrange, repeat
6
+ from typing import Optional
7
+ import logging
8
+
9
+ from .diffusionmodules.util import AlphaBlender, timestep_embedding
10
+ from .sub_quadratic_attention import efficient_dot_product_attention
11
+
12
+ from comfy import model_management
13
+
14
+ if model_management.xformers_enabled():
15
+ import xformers
16
+ import xformers.ops
17
+
18
+ from comfy.cli_args import args
19
+ import comfy.ops
20
+ ops = comfy.ops.disable_weight_init
21
+
22
+ FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
23
+
24
+ def get_attn_precision(attn_precision):
25
+ if args.dont_upcast_attention:
26
+ return None
27
+ if FORCE_UPCAST_ATTENTION_DTYPE is not None:
28
+ return FORCE_UPCAST_ATTENTION_DTYPE
29
+ return attn_precision
30
+
31
+ def exists(val):
32
+ return val is not None
33
+
34
+
35
+ def uniq(arr):
36
+ return{el: True for el in arr}.keys()
37
+
38
+
39
+ def default(val, d):
40
+ if exists(val):
41
+ return val
42
+ return d
43
+
44
+
45
+ def max_neg_value(t):
46
+ return -torch.finfo(t.dtype).max
47
+
48
+
49
+ def init_(tensor):
50
+ dim = tensor.shape[-1]
51
+ std = 1 / math.sqrt(dim)
52
+ tensor.uniform_(-std, std)
53
+ return tensor
54
+
55
+
56
+ # feedforward
57
+ class GEGLU(nn.Module):
58
+ def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
59
+ super().__init__()
60
+ self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
61
+
62
+ def forward(self, x):
63
+ x, gate = self.proj(x).chunk(2, dim=-1)
64
+ return x * F.gelu(gate)
65
+
66
+
67
+ class FeedForward(nn.Module):
68
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=ops):
69
+ super().__init__()
70
+ inner_dim = int(dim * mult)
71
+ dim_out = default(dim_out, dim)
72
+ project_in = nn.Sequential(
73
+ operations.Linear(dim, inner_dim, dtype=dtype, device=device),
74
+ nn.GELU()
75
+ ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device, operations=operations)
76
+
77
+ self.net = nn.Sequential(
78
+ project_in,
79
+ nn.Dropout(dropout),
80
+ operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
81
+ )
82
+
83
+ def forward(self, x):
84
+ return self.net(x)
85
+
86
+ def Normalize(in_channels, dtype=None, device=None):
87
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
88
+
89
+ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
90
+ attn_precision = get_attn_precision(attn_precision)
91
+
92
+ if skip_reshape:
93
+ b, _, _, dim_head = q.shape
94
+ else:
95
+ b, _, dim_head = q.shape
96
+ dim_head //= heads
97
+
98
+ scale = dim_head ** -0.5
99
+
100
+ h = heads
101
+ if skip_reshape:
102
+ q, k, v = map(
103
+ lambda t: t.reshape(b * heads, -1, dim_head),
104
+ (q, k, v),
105
+ )
106
+ else:
107
+ q, k, v = map(
108
+ lambda t: t.unsqueeze(3)
109
+ .reshape(b, -1, heads, dim_head)
110
+ .permute(0, 2, 1, 3)
111
+ .reshape(b * heads, -1, dim_head)
112
+ .contiguous(),
113
+ (q, k, v),
114
+ )
115
+
116
+ # force cast to fp32 to avoid overflowing
117
+ if attn_precision == torch.float32:
118
+ sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
119
+ else:
120
+ sim = einsum('b i d, b j d -> b i j', q, k) * scale
121
+
122
+ del q, k
123
+
124
+ if exists(mask):
125
+ if mask.dtype == torch.bool:
126
+ mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
127
+ max_neg_value = -torch.finfo(sim.dtype).max
128
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
129
+ sim.masked_fill_(~mask, max_neg_value)
130
+ else:
131
+ if len(mask.shape) == 2:
132
+ bs = 1
133
+ else:
134
+ bs = mask.shape[0]
135
+ mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
136
+ sim.add_(mask)
137
+
138
+ # attention, what we cannot get enough of
139
+ sim = sim.softmax(dim=-1)
140
+
141
+ out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
142
+ out = (
143
+ out.unsqueeze(0)
144
+ .reshape(b, heads, -1, dim_head)
145
+ .permute(0, 2, 1, 3)
146
+ .reshape(b, -1, heads * dim_head)
147
+ )
148
+ return out
149
+
150
+
151
+ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
152
+ attn_precision = get_attn_precision(attn_precision)
153
+
154
+ if skip_reshape:
155
+ b, _, _, dim_head = query.shape
156
+ else:
157
+ b, _, dim_head = query.shape
158
+ dim_head //= heads
159
+
160
+ scale = dim_head ** -0.5
161
+
162
+ if skip_reshape:
163
+ query = query.reshape(b * heads, -1, dim_head)
164
+ value = value.reshape(b * heads, -1, dim_head)
165
+ key = key.reshape(b * heads, -1, dim_head).movedim(1, 2)
166
+ else:
167
+ query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
168
+ value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
169
+ key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
170
+
171
+
172
+ dtype = query.dtype
173
+ upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
174
+ if upcast_attention:
175
+ bytes_per_token = torch.finfo(torch.float32).bits//8
176
+ else:
177
+ bytes_per_token = torch.finfo(query.dtype).bits//8
178
+ batch_x_heads, q_tokens, _ = query.shape
179
+ _, _, k_tokens = key.shape
180
+ qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
181
+
182
+ mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
183
+
184
+ kv_chunk_size_min = None
185
+ kv_chunk_size = None
186
+ query_chunk_size = None
187
+
188
+ for x in [4096, 2048, 1024, 512, 256]:
189
+ count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
190
+ if count >= k_tokens:
191
+ kv_chunk_size = k_tokens
192
+ query_chunk_size = x
193
+ break
194
+
195
+ if query_chunk_size is None:
196
+ query_chunk_size = 512
197
+
198
+ if mask is not None:
199
+ if len(mask.shape) == 2:
200
+ bs = 1
201
+ else:
202
+ bs = mask.shape[0]
203
+ mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
204
+
205
+ hidden_states = efficient_dot_product_attention(
206
+ query,
207
+ key,
208
+ value,
209
+ query_chunk_size=query_chunk_size,
210
+ kv_chunk_size=kv_chunk_size,
211
+ kv_chunk_size_min=kv_chunk_size_min,
212
+ use_checkpoint=False,
213
+ upcast_attention=upcast_attention,
214
+ mask=mask,
215
+ )
216
+
217
+ hidden_states = hidden_states.to(dtype)
218
+
219
+ hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
220
+ return hidden_states
221
+
222
+ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
223
+ attn_precision = get_attn_precision(attn_precision)
224
+
225
+ if skip_reshape:
226
+ b, _, _, dim_head = q.shape
227
+ else:
228
+ b, _, dim_head = q.shape
229
+ dim_head //= heads
230
+
231
+ scale = dim_head ** -0.5
232
+
233
+ h = heads
234
+ if skip_reshape:
235
+ q, k, v = map(
236
+ lambda t: t.reshape(b * heads, -1, dim_head),
237
+ (q, k, v),
238
+ )
239
+ else:
240
+ q, k, v = map(
241
+ lambda t: t.unsqueeze(3)
242
+ .reshape(b, -1, heads, dim_head)
243
+ .permute(0, 2, 1, 3)
244
+ .reshape(b * heads, -1, dim_head)
245
+ .contiguous(),
246
+ (q, k, v),
247
+ )
248
+
249
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
250
+
251
+ mem_free_total = model_management.get_free_memory(q.device)
252
+
253
+ if attn_precision == torch.float32:
254
+ element_size = 4
255
+ upcast = True
256
+ else:
257
+ element_size = q.element_size()
258
+ upcast = False
259
+
260
+ gb = 1024 ** 3
261
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
262
+ modifier = 3
263
+ mem_required = tensor_size * modifier
264
+ steps = 1
265
+
266
+
267
+ if mem_required > mem_free_total:
268
+ steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
269
+ # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
270
+ # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
271
+
272
+ if steps > 64:
273
+ max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
274
+ raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
275
+ f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
276
+
277
+ if mask is not None:
278
+ if len(mask.shape) == 2:
279
+ bs = 1
280
+ else:
281
+ bs = mask.shape[0]
282
+ mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
283
+
284
+ # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
285
+ first_op_done = False
286
+ cleared_cache = False
287
+ while True:
288
+ try:
289
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
290
+ for i in range(0, q.shape[1], slice_size):
291
+ end = i + slice_size
292
+ if upcast:
293
+ with torch.autocast(enabled=False, device_type = 'cuda'):
294
+ s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
295
+ else:
296
+ s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
297
+
298
+ if mask is not None:
299
+ if len(mask.shape) == 2:
300
+ s1 += mask[i:end]
301
+ else:
302
+ s1 += mask[:, i:end]
303
+
304
+ s2 = s1.softmax(dim=-1).to(v.dtype)
305
+ del s1
306
+ first_op_done = True
307
+
308
+ r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
309
+ del s2
310
+ break
311
+ except model_management.OOM_EXCEPTION as e:
312
+ if first_op_done == False:
313
+ model_management.soft_empty_cache(True)
314
+ if cleared_cache == False:
315
+ cleared_cache = True
316
+ logging.warning("out of memory error, emptying cache and trying again")
317
+ continue
318
+ steps *= 2
319
+ if steps > 64:
320
+ raise e
321
+ logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
322
+ else:
323
+ raise e
324
+
325
+ del q, k, v
326
+
327
+ r1 = (
328
+ r1.unsqueeze(0)
329
+ .reshape(b, heads, -1, dim_head)
330
+ .permute(0, 2, 1, 3)
331
+ .reshape(b, -1, heads * dim_head)
332
+ )
333
+ return r1
334
+
335
+ BROKEN_XFORMERS = False
336
+ try:
337
+ x_vers = xformers.__version__
338
+ # XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
339
+ BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
340
+ except:
341
+ pass
342
+
343
+ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
344
+ if skip_reshape:
345
+ b, _, _, dim_head = q.shape
346
+ else:
347
+ b, _, dim_head = q.shape
348
+ dim_head //= heads
349
+
350
+ disabled_xformers = False
351
+
352
+ if BROKEN_XFORMERS:
353
+ if b * heads > 65535:
354
+ disabled_xformers = True
355
+
356
+ if not disabled_xformers:
357
+ if torch.jit.is_tracing() or torch.jit.is_scripting():
358
+ disabled_xformers = True
359
+
360
+ if disabled_xformers:
361
+ return attention_pytorch(q, k, v, heads, mask)
362
+
363
+ if skip_reshape:
364
+ q, k, v = map(
365
+ lambda t: t.reshape(b * heads, -1, dim_head),
366
+ (q, k, v),
367
+ )
368
+ else:
369
+ q, k, v = map(
370
+ lambda t: t.reshape(b, -1, heads, dim_head),
371
+ (q, k, v),
372
+ )
373
+
374
+ if mask is not None:
375
+ pad = 8 - q.shape[1] % 8
376
+ mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
377
+ mask_out[:, :, :mask.shape[-1]] = mask
378
+ mask = mask_out[:, :, :mask.shape[-1]]
379
+
380
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
381
+
382
+ if skip_reshape:
383
+ out = (
384
+ out.unsqueeze(0)
385
+ .reshape(b, heads, -1, dim_head)
386
+ .permute(0, 2, 1, 3)
387
+ .reshape(b, -1, heads * dim_head)
388
+ )
389
+ else:
390
+ out = (
391
+ out.reshape(b, -1, heads * dim_head)
392
+ )
393
+
394
+ return out
395
+
396
+ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
397
+ if skip_reshape:
398
+ b, _, _, dim_head = q.shape
399
+ else:
400
+ b, _, dim_head = q.shape
401
+ dim_head //= heads
402
+ q, k, v = map(
403
+ lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
404
+ (q, k, v),
405
+ )
406
+
407
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
408
+ out = (
409
+ out.transpose(1, 2).reshape(b, -1, heads * dim_head)
410
+ )
411
+ return out
412
+
413
+
414
+ optimized_attention = attention_basic
415
+
416
+ if model_management.xformers_enabled():
417
+ logging.info("Using xformers cross attention")
418
+ optimized_attention = attention_xformers
419
+ elif model_management.pytorch_attention_enabled():
420
+ logging.info("Using pytorch cross attention")
421
+ optimized_attention = attention_pytorch
422
+ else:
423
+ if args.use_split_cross_attention:
424
+ logging.info("Using split optimization for cross attention")
425
+ optimized_attention = attention_split
426
+ else:
427
+ logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
428
+ optimized_attention = attention_sub_quad
429
+
430
+ optimized_attention_masked = optimized_attention
431
+
432
+ def optimized_attention_for_device(device, mask=False, small_input=False):
433
+ if small_input:
434
+ if model_management.pytorch_attention_enabled():
435
+ return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases
436
+ else:
437
+ return attention_basic
438
+
439
+ if device == torch.device("cpu"):
440
+ return attention_sub_quad
441
+
442
+ if mask:
443
+ return optimized_attention_masked
444
+
445
+ return optimized_attention
446
+
447
+
448
+ class CrossAttention(nn.Module):
449
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=ops):
450
+ super().__init__()
451
+ inner_dim = dim_head * heads
452
+ context_dim = default(context_dim, query_dim)
453
+ self.attn_precision = attn_precision
454
+
455
+ self.heads = heads
456
+ self.dim_head = dim_head
457
+
458
+ self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
459
+ self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
460
+ self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
461
+
462
+ self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
463
+
464
+ def forward(self, x, context=None, value=None, mask=None):
465
+ q = self.to_q(x)
466
+ context = default(context, x)
467
+ k = self.to_k(context)
468
+ if value is not None:
469
+ v = self.to_v(value)
470
+ del value
471
+ else:
472
+ v = self.to_v(context)
473
+
474
+ if mask is None:
475
+ out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
476
+ else:
477
+ out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
478
+ return self.to_out(out)
479
+
480
+
481
+ class BasicTransformerBlock(nn.Module):
482
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
483
+ disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, attn_precision=None, dtype=None, device=None, operations=ops):
484
+ super().__init__()
485
+
486
+ self.ff_in = ff_in or inner_dim is not None
487
+ if inner_dim is None:
488
+ inner_dim = dim
489
+
490
+ self.is_res = inner_dim == dim
491
+ self.attn_precision = attn_precision
492
+
493
+ if self.ff_in:
494
+ self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
495
+ self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
496
+
497
+ self.disable_self_attn = disable_self_attn
498
+ self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
499
+ context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
500
+ self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
501
+
502
+ if disable_temporal_crossattention:
503
+ if switch_temporal_ca_to_sa:
504
+ raise ValueError
505
+ else:
506
+ self.attn2 = None
507
+ else:
508
+ context_dim_attn2 = None
509
+ if not switch_temporal_ca_to_sa:
510
+ context_dim_attn2 = context_dim
511
+
512
+ self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
513
+ heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
514
+ self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
515
+
516
+ self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
517
+ self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
518
+ self.n_heads = n_heads
519
+ self.d_head = d_head
520
+ self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
521
+
522
+ def forward(self, x, context=None, transformer_options={}):
523
+ extra_options = {}
524
+ block = transformer_options.get("block", None)
525
+ block_index = transformer_options.get("block_index", 0)
526
+ transformer_patches = {}
527
+ transformer_patches_replace = {}
528
+
529
+ for k in transformer_options:
530
+ if k == "patches":
531
+ transformer_patches = transformer_options[k]
532
+ elif k == "patches_replace":
533
+ transformer_patches_replace = transformer_options[k]
534
+ else:
535
+ extra_options[k] = transformer_options[k]
536
+
537
+ extra_options["n_heads"] = self.n_heads
538
+ extra_options["dim_head"] = self.d_head
539
+ extra_options["attn_precision"] = self.attn_precision
540
+
541
+ if self.ff_in:
542
+ x_skip = x
543
+ x = self.ff_in(self.norm_in(x))
544
+ if self.is_res:
545
+ x += x_skip
546
+
547
+ n = self.norm1(x)
548
+ if self.disable_self_attn:
549
+ context_attn1 = context
550
+ else:
551
+ context_attn1 = None
552
+ value_attn1 = None
553
+
554
+ if "attn1_patch" in transformer_patches:
555
+ patch = transformer_patches["attn1_patch"]
556
+ if context_attn1 is None:
557
+ context_attn1 = n
558
+ value_attn1 = context_attn1
559
+ for p in patch:
560
+ n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)
561
+
562
+ if block is not None:
563
+ transformer_block = (block[0], block[1], block_index)
564
+ else:
565
+ transformer_block = None
566
+ attn1_replace_patch = transformer_patches_replace.get("attn1", {})
567
+ block_attn1 = transformer_block
568
+ if block_attn1 not in attn1_replace_patch:
569
+ block_attn1 = block
570
+
571
+ if block_attn1 in attn1_replace_patch:
572
+ if context_attn1 is None:
573
+ context_attn1 = n
574
+ value_attn1 = n
575
+ n = self.attn1.to_q(n)
576
+ context_attn1 = self.attn1.to_k(context_attn1)
577
+ value_attn1 = self.attn1.to_v(value_attn1)
578
+ n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
579
+ n = self.attn1.to_out(n)
580
+ else:
581
+ n = self.attn1(n, context=context_attn1, value=value_attn1)
582
+
583
+ if "attn1_output_patch" in transformer_patches:
584
+ patch = transformer_patches["attn1_output_patch"]
585
+ for p in patch:
586
+ n = p(n, extra_options)
587
+
588
+ x += n
589
+ if "middle_patch" in transformer_patches:
590
+ patch = transformer_patches["middle_patch"]
591
+ for p in patch:
592
+ x = p(x, extra_options)
593
+
594
+ if self.attn2 is not None:
595
+ n = self.norm2(x)
596
+ if self.switch_temporal_ca_to_sa:
597
+ context_attn2 = n
598
+ else:
599
+ context_attn2 = context
600
+ value_attn2 = None
601
+ if "attn2_patch" in transformer_patches:
602
+ patch = transformer_patches["attn2_patch"]
603
+ value_attn2 = context_attn2
604
+ for p in patch:
605
+ n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
606
+
607
+ attn2_replace_patch = transformer_patches_replace.get("attn2", {})
608
+ block_attn2 = transformer_block
609
+ if block_attn2 not in attn2_replace_patch:
610
+ block_attn2 = block
611
+
612
+ if block_attn2 in attn2_replace_patch:
613
+ if value_attn2 is None:
614
+ value_attn2 = context_attn2
615
+ n = self.attn2.to_q(n)
616
+ context_attn2 = self.attn2.to_k(context_attn2)
617
+ value_attn2 = self.attn2.to_v(value_attn2)
618
+ n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
619
+ n = self.attn2.to_out(n)
620
+ else:
621
+ n = self.attn2(n, context=context_attn2, value=value_attn2)
622
+
623
+ if "attn2_output_patch" in transformer_patches:
624
+ patch = transformer_patches["attn2_output_patch"]
625
+ for p in patch:
626
+ n = p(n, extra_options)
627
+
628
+ x += n
629
+ if self.is_res:
630
+ x_skip = x
631
+ x = self.ff(self.norm3(x))
632
+ if self.is_res:
633
+ x += x_skip
634
+
635
+ return x
636
+
637
+
638
+ class SpatialTransformer(nn.Module):
639
+ """
640
+ Transformer block for image-like data.
641
+ First, project the input (aka embedding)
642
+ and reshape to b, t, d.
643
+ Then apply standard transformer action.
644
+ Finally, reshape to image
645
+ NEW: use_linear for more efficiency instead of the 1x1 convs
646
+ """
647
+ def __init__(self, in_channels, n_heads, d_head,
648
+ depth=1, dropout=0., context_dim=None,
649
+ disable_self_attn=False, use_linear=False,
650
+ use_checkpoint=True, attn_precision=None, dtype=None, device=None, operations=ops):
651
+ super().__init__()
652
+ if exists(context_dim) and not isinstance(context_dim, list):
653
+ context_dim = [context_dim] * depth
654
+ self.in_channels = in_channels
655
+ inner_dim = n_heads * d_head
656
+ self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
657
+ if not use_linear:
658
+ self.proj_in = operations.Conv2d(in_channels,
659
+ inner_dim,
660
+ kernel_size=1,
661
+ stride=1,
662
+ padding=0, dtype=dtype, device=device)
663
+ else:
664
+ self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
665
+
666
+ self.transformer_blocks = nn.ModuleList(
667
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
668
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
669
+ for d in range(depth)]
670
+ )
671
+ if not use_linear:
672
+ self.proj_out = operations.Conv2d(inner_dim,in_channels,
673
+ kernel_size=1,
674
+ stride=1,
675
+ padding=0, dtype=dtype, device=device)
676
+ else:
677
+ self.proj_out = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
678
+ self.use_linear = use_linear
679
+
680
+ def forward(self, x, context=None, transformer_options={}):
681
+ # note: if no context is given, cross-attention defaults to self-attention
682
+ if not isinstance(context, list):
683
+ context = [context] * len(self.transformer_blocks)
684
+ b, c, h, w = x.shape
685
+ x_in = x
686
+ x = self.norm(x)
687
+ if not self.use_linear:
688
+ x = self.proj_in(x)
689
+ x = x.movedim(1, 3).flatten(1, 2).contiguous()
690
+ if self.use_linear:
691
+ x = self.proj_in(x)
692
+ for i, block in enumerate(self.transformer_blocks):
693
+ transformer_options["block_index"] = i
694
+ x = block(x, context=context[i], transformer_options=transformer_options)
695
+ if self.use_linear:
696
+ x = self.proj_out(x)
697
+ x = x.reshape(x.shape[0], h, w, x.shape[-1]).movedim(3, 1).contiguous()
698
+ if not self.use_linear:
699
+ x = self.proj_out(x)
700
+ return x + x_in
701
+
702
+
703
+ class SpatialVideoTransformer(SpatialTransformer):
704
+ def __init__(
705
+ self,
706
+ in_channels,
707
+ n_heads,
708
+ d_head,
709
+ depth=1,
710
+ dropout=0.0,
711
+ use_linear=False,
712
+ context_dim=None,
713
+ use_spatial_context=False,
714
+ timesteps=None,
715
+ merge_strategy: str = "fixed",
716
+ merge_factor: float = 0.5,
717
+ time_context_dim=None,
718
+ ff_in=False,
719
+ checkpoint=False,
720
+ time_depth=1,
721
+ disable_self_attn=False,
722
+ disable_temporal_crossattention=False,
723
+ max_time_embed_period: int = 10000,
724
+ attn_precision=None,
725
+ dtype=None, device=None, operations=ops
726
+ ):
727
+ super().__init__(
728
+ in_channels,
729
+ n_heads,
730
+ d_head,
731
+ depth=depth,
732
+ dropout=dropout,
733
+ use_checkpoint=checkpoint,
734
+ context_dim=context_dim,
735
+ use_linear=use_linear,
736
+ disable_self_attn=disable_self_attn,
737
+ attn_precision=attn_precision,
738
+ dtype=dtype, device=device, operations=operations
739
+ )
740
+ self.time_depth = time_depth
741
+ self.depth = depth
742
+ self.max_time_embed_period = max_time_embed_period
743
+
744
+ time_mix_d_head = d_head
745
+ n_time_mix_heads = n_heads
746
+
747
+ time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
748
+
749
+ inner_dim = n_heads * d_head
750
+ if use_spatial_context:
751
+ time_context_dim = context_dim
752
+
753
+ self.time_stack = nn.ModuleList(
754
+ [
755
+ BasicTransformerBlock(
756
+ inner_dim,
757
+ n_time_mix_heads,
758
+ time_mix_d_head,
759
+ dropout=dropout,
760
+ context_dim=time_context_dim,
761
+ # timesteps=timesteps,
762
+ checkpoint=checkpoint,
763
+ ff_in=ff_in,
764
+ inner_dim=time_mix_inner_dim,
765
+ disable_self_attn=disable_self_attn,
766
+ disable_temporal_crossattention=disable_temporal_crossattention,
767
+ attn_precision=attn_precision,
768
+ dtype=dtype, device=device, operations=operations
769
+ )
770
+ for _ in range(self.depth)
771
+ ]
772
+ )
773
+
774
+ assert len(self.time_stack) == len(self.transformer_blocks)
775
+
776
+ self.use_spatial_context = use_spatial_context
777
+ self.in_channels = in_channels
778
+
779
+ time_embed_dim = self.in_channels * 4
780
+ self.time_pos_embed = nn.Sequential(
781
+ operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device),
782
+ nn.SiLU(),
783
+ operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device),
784
+ )
785
+
786
+ self.time_mixer = AlphaBlender(
787
+ alpha=merge_factor, merge_strategy=merge_strategy
788
+ )
789
+
790
+ def forward(
791
+ self,
792
+ x: torch.Tensor,
793
+ context: Optional[torch.Tensor] = None,
794
+ time_context: Optional[torch.Tensor] = None,
795
+ timesteps: Optional[int] = None,
796
+ image_only_indicator: Optional[torch.Tensor] = None,
797
+ transformer_options={}
798
+ ) -> torch.Tensor:
799
+ _, _, h, w = x.shape
800
+ x_in = x
801
+ spatial_context = None
802
+ if exists(context):
803
+ spatial_context = context
804
+
805
+ if self.use_spatial_context:
806
+ assert (
807
+ context.ndim == 3
808
+ ), f"n dims of spatial context should be 3 but are {context.ndim}"
809
+
810
+ if time_context is None:
811
+ time_context = context
812
+ time_context_first_timestep = time_context[::timesteps]
813
+ time_context = repeat(
814
+ time_context_first_timestep, "b ... -> (b n) ...", n=h * w
815
+ )
816
+ elif time_context is not None and not self.use_spatial_context:
817
+ time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
818
+ if time_context.ndim == 2:
819
+ time_context = rearrange(time_context, "b c -> b 1 c")
820
+
821
+ x = self.norm(x)
822
+ if not self.use_linear:
823
+ x = self.proj_in(x)
824
+ x = rearrange(x, "b c h w -> b (h w) c")
825
+ if self.use_linear:
826
+ x = self.proj_in(x)
827
+
828
+ num_frames = torch.arange(timesteps, device=x.device)
829
+ num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
830
+ num_frames = rearrange(num_frames, "b t -> (b t)")
831
+ t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype)
832
+ emb = self.time_pos_embed(t_emb)
833
+ emb = emb[:, None, :]
834
+
835
+ for it_, (block, mix_block) in enumerate(
836
+ zip(self.transformer_blocks, self.time_stack)
837
+ ):
838
+ transformer_options["block_index"] = it_
839
+ x = block(
840
+ x,
841
+ context=spatial_context,
842
+ transformer_options=transformer_options,
843
+ )
844
+
845
+ x_mix = x
846
+ x_mix = x_mix + emb
847
+
848
+ B, S, C = x_mix.shape
849
+ x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
850
+ x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
851
+ x_mix = rearrange(
852
+ x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
853
+ )
854
+
855
+ x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator)
856
+
857
+ if self.use_linear:
858
+ x = self.proj_out(x)
859
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
860
+ if not self.use_linear:
861
+ x = self.proj_out(x)
862
+ out = x + x_in
863
+ return out
864
+
865
+
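The module picks one attention kernel at import time (xformers, PyTorch SDPA, split, or sub-quadratic) and exposes it as optimized_attention. Calling it directly looks like this sketch with made-up sizes (the last dimension must equal heads * dim_head; assumes the ComfyUI backend and its dependencies are importable):

import torch
from comfy.ldm.modules.attention import optimized_attention

q = torch.randn(2, 4096, 640)   # (batch, query tokens, heads * dim_head)
k = torch.randn(2, 77, 640)     # (batch, context tokens, heads * dim_head)
v = torch.randn(2, 77, 640)
out = optimized_attention(q, k, v, heads=10)
print(out.shape)                # torch.Size([2, 4096, 640])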
Backend/comfy/ldm/modules/diffusionmodules/__init__.py ADDED
File without changes
Backend/comfy/ldm/modules/diffusionmodules/mmdit.py ADDED
@@ -0,0 +1,955 @@
1
+ import logging
+ import math
+ from functools import partial
+ from typing import Dict, Optional
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from .. import attention
+ from einops import rearrange, repeat
+ from .util import timestep_embedding
+ import comfy.ops
+ import comfy.ldm.common_dit
13
+
14
+ def default(x, y):
15
+ if x is not None:
16
+ return x
17
+ return y
18
+
19
+ class Mlp(nn.Module):
20
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
21
+ """
22
+ def __init__(
23
+ self,
24
+ in_features,
25
+ hidden_features=None,
26
+ out_features=None,
27
+ act_layer=nn.GELU,
28
+ norm_layer=None,
29
+ bias=True,
30
+ drop=0.,
31
+ use_conv=False,
32
+ dtype=None,
33
+ device=None,
34
+ operations=None,
35
+ ):
36
+ super().__init__()
37
+ out_features = out_features or in_features
38
+ hidden_features = hidden_features or in_features
39
+ drop_probs = drop
40
+ linear_layer = partial(operations.Conv2d, kernel_size=1) if use_conv else operations.Linear
41
+
42
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
43
+ self.act = act_layer()
44
+ self.drop1 = nn.Dropout(drop_probs)
45
+ self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
46
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
47
+ self.drop2 = nn.Dropout(drop_probs)
48
+
49
+ def forward(self, x):
50
+ x = self.fc1(x)
51
+ x = self.act(x)
52
+ x = self.drop1(x)
53
+ x = self.norm(x)
54
+ x = self.fc2(x)
55
+ x = self.drop2(x)
56
+ return x
57
+
58
+ class PatchEmbed(nn.Module):
59
+ """ 2D Image to Patch Embedding
60
+ """
61
+ dynamic_img_pad: torch.jit.Final[bool]
62
+
63
+ def __init__(
64
+ self,
65
+ img_size: Optional[int] = 224,
66
+ patch_size: int = 16,
67
+ in_chans: int = 3,
68
+ embed_dim: int = 768,
69
+ norm_layer = None,
70
+ flatten: bool = True,
71
+ bias: bool = True,
72
+ strict_img_size: bool = True,
73
+ dynamic_img_pad: bool = True,
74
+ padding_mode='circular',
75
+ dtype=None,
76
+ device=None,
77
+ operations=None,
78
+ ):
79
+ super().__init__()
80
+ self.patch_size = (patch_size, patch_size)
81
+ self.padding_mode = padding_mode
82
+ if img_size is not None:
83
+ self.img_size = (img_size, img_size)
84
+ self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
85
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
86
+ else:
87
+ self.img_size = None
88
+ self.grid_size = None
89
+ self.num_patches = None
90
+
91
+ # flatten spatial dim and transpose to channels last, kept for bwd compat
92
+ self.flatten = flatten
93
+ self.strict_img_size = strict_img_size
94
+ self.dynamic_img_pad = dynamic_img_pad
95
+
96
+ self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
97
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
98
+
99
+ def forward(self, x):
100
+ B, C, H, W = x.shape
101
+ # if self.img_size is not None:
102
+ # if self.strict_img_size:
103
+ # _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
104
+ # _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
105
+ # elif not self.dynamic_img_pad:
106
+ # _assert(
107
+ # H % self.patch_size[0] == 0,
108
+ # f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
109
+ # )
110
+ # _assert(
111
+ # W % self.patch_size[1] == 0,
112
+ # f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
113
+ # )
114
+ if self.dynamic_img_pad:
115
+ x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode)
116
+ x = self.proj(x)
117
+ if self.flatten:
118
+ x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
119
+ x = self.norm(x)
120
+ return x
121
+
122
+ def modulate(x, shift, scale):
123
+ if shift is None:
124
+ shift = torch.zeros_like(scale)
125
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
126
+
127
+
128
+ #################################################################################
129
+ # Sine/Cosine Positional Embedding Functions #
130
+ #################################################################################
131
+
132
+
133
+def get_2d_sincos_pos_embed(
+    embed_dim,
+    grid_size,
+    cls_token=False,
+    extra_tokens=0,
+    scaling_factor=None,
+    offset=None,
+):
+    """
+    grid_size: int of the grid height and width
+    return:
+    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    grid_h = np.arange(grid_size, dtype=np.float32)
+    grid_w = np.arange(grid_size, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+    if scaling_factor is not None:
+        grid = grid / scaling_factor
+    if offset is not None:
+        grid = grid - offset
+
+    grid = grid.reshape([2, 1, grid_size, grid_size])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token and extra_tokens > 0:
+        pos_embed = np.concatenate(
+            [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
+        )
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=np.float64)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos, device=None, dtype=torch.float32):
+    omega = torch.arange(embed_dim // 2, device=device, dtype=dtype)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)
+    pos = pos.reshape(-1)  # (M,)
+    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+    emb_sin = torch.sin(out)  # (M, D/2)
+    emb_cos = torch.cos(out)  # (M, D/2)
+    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
+    return emb
+
+
+def get_2d_sincos_pos_embed_torch(embed_dim, w, h, val_center=7.5, val_magnitude=7.5, device=None, dtype=torch.float32):
+    small = min(h, w)
+    val_h = (h / small) * val_magnitude
+    val_w = (w / small) * val_magnitude
+    grid_h, grid_w = torch.meshgrid(torch.linspace(-val_h + val_center, val_h + val_center, h, device=device, dtype=dtype), torch.linspace(-val_w + val_center, val_w + val_center, w, device=device, dtype=dtype), indexing='ij')
+    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
+    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
+    emb = torch.cat([emb_w, emb_h], dim=1)  # (H*W, D)
+    return emb
+
+
+#################################################################################
+#                 Embedding Layers for Timesteps and Class Labels               #
+#################################################################################
+
+
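+# Note: TimestepEmbedder maps the scalar diffusion timestep to a vector by first computing a fixed
+# sinusoidal frequency embedding (timestep_embedding(), presumably imported earlier in this file)
+# and then running it through a small Linear -> SiLU -> Linear MLP. VectorEmbedder applies the same
+# two-layer MLP to an already-vector-valued conditioning input (e.g. pooled text features).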
+class TimestepEmbedder(nn.Module):
+    """
+    Embeds scalar timesteps into vector representations.
+    """
+
+    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
+        )
+        self.frequency_embedding_size = frequency_embedding_size
+
+    def forward(self, t, dtype, **kwargs):
+        t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
+        t_emb = self.mlp(t_freq)
+        return t_emb
+
+
+class VectorEmbedder(nn.Module):
+    """
+    Embeds a flat vector of dimension input_dim
+    """
+
+    def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            operations.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        emb = self.mlp(x)
+        return emb
+
+
+#################################################################################
+#                                 Core DiT Model                                #
+#################################################################################
+
+
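+# Note: the fused qkv projection produces a (B, L, 3 * num_heads * head_dim) tensor; split_qkv()
+# reshapes it to (3, B, L, num_heads, head_dim) and returns the q/k/v slices, while
+# optimized_attention() forwards them to comfy's attention backend (attention.optimized_attention,
+# presumably imported at the top of this file). SelfAttention optionally RMS- or LayerNorm-
+# normalizes q and k per head (qk_norm) and, when pre_only=True, omits the output projection
+# because such a block only feeds the joint attention and its own output is discarded.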
+def split_qkv(qkv, head_dim):
+    qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
+    return qkv[0], qkv[1], qkv[2]
+
+
+def optimized_attention(qkv, num_heads):
+    return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)
+
+
+class SelfAttention(nn.Module):
+    ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = False,
+        qk_scale: Optional[float] = None,
+        proj_drop: float = 0.0,
+        attn_mode: str = "xformers",
+        pre_only: bool = False,
+        qk_norm: Optional[str] = None,
+        rmsnorm: bool = False,
+        dtype=None,
+        device=None,
+        operations=None,
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+
+        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
+        if not pre_only:
+            self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
+            self.proj_drop = nn.Dropout(proj_drop)
+        assert attn_mode in self.ATTENTION_MODES
+        self.attn_mode = attn_mode
+        self.pre_only = pre_only
+
+        if qk_norm == "rms":
+            self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
+            self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
+        elif qk_norm == "ln":
+            self.ln_q = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
+            self.ln_k = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
+        elif qk_norm is None:
+            self.ln_q = nn.Identity()
+            self.ln_k = nn.Identity()
+        else:
+            raise ValueError(qk_norm)
+
+    def pre_attention(self, x: torch.Tensor) -> torch.Tensor:
+        B, L, C = x.shape
+        qkv = self.qkv(x)
+        q, k, v = split_qkv(qkv, self.head_dim)
+        q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
+        k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
+        return (q, k, v)
+
+    def post_attention(self, x: torch.Tensor) -> torch.Tensor:
+        assert not self.pre_only
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        qkv = self.pre_attention(x)
+        x = optimized_attention(
+            qkv, num_heads=self.num_heads
+        )
+        x = self.post_attention(x)
+        return x
+
+
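+# Note: RMSNorm normalizes by the root-mean-square of the features only (no mean subtraction):
+# y = x / sqrt(mean(x**2) + eps), optionally followed by a learned per-channel scale when
+# elementwise_affine=True. It is used here both for qk-normalization and as an optional
+# replacement for LayerNorm inside the blocks.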
+class RMSNorm(torch.nn.Module):
+    def __init__(
+        self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None
+    ):
+        """
+        Initialize the RMSNorm normalization layer.
+        Args:
+            dim (int): The dimension of the input tensor.
+            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+        Attributes:
+            eps (float): A small value added to the denominator for numerical stability.
+            weight (nn.Parameter): Learnable scaling parameter.
+        """
+        super().__init__()
+        self.eps = eps
+        self.learnable_scale = elementwise_affine
+        if self.learnable_scale:
+            self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
+        else:
+            self.register_parameter("weight", None)
+
+    def _norm(self, x):
+        """
+        Apply the RMSNorm normalization to the input tensor.
+        Args:
+            x (torch.Tensor): The input tensor.
+        Returns:
+            torch.Tensor: The normalized tensor.
+        """
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        """
+        Forward pass through the RMSNorm layer.
+        Args:
+            x (torch.Tensor): The input tensor.
+        Returns:
+            torch.Tensor: The output tensor after applying RMSNorm.
+        """
+        x = self._norm(x)
+        if self.learnable_scale:
+            return x * self.weight.to(device=x.device, dtype=x.dtype)
+        else:
+            return x
+
+
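+# Note: SwiGLUFeedForward is the LLaMA-style gated MLP, out = w2(silu(w1(x)) * w3(x)). The hidden
+# width is first shrunk to 2/3 of the requested value and then rounded up to a multiple of
+# `multiple_of`, so e.g. hidden_dim=6144 with multiple_of=256 gives int(2 * 6144 / 3) = 4096 units.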
+class SwiGLUFeedForward(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int,
+        ffn_dim_multiplier: Optional[float] = None,
+    ):
+        """
+        Initialize the FeedForward module.
+
+        Args:
+            dim (int): Input dimension.
+            hidden_dim (int): Hidden dimension of the feedforward layer.
+            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+            ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
+
+        Attributes:
+            w1 (ColumnParallelLinear): Linear transformation for the first layer.
+            w2 (RowParallelLinear): Linear transformation for the second layer.
+            w3 (ColumnParallelLinear): Linear transformation for the third layer.
+
+        """
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        # custom dim factor multiplier
+        if ffn_dim_multiplier is not None:
+            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x):
+        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
+
+
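+# Note: DismantledBlock is a single DiT block whose attention is split into pre_attention()
+# (adaLN-modulate, then project to q/k/v) and post_attention() (output projection plus gated
+# residual adds for attention and MLP). Splitting it this way lets block_mixing() run one joint
+# attention over the concatenated text and image token streams before each stream is finished
+# separately. The adaLN_modulation MLP emits 6 modulation vectors per block (2 for a pre_only
+# block; 4/1 when scale_mod_only is set and the shifts are dropped).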
+class DismantledBlock(nn.Module):
+    """
+    A DiT block with gated adaptive layer norm (adaLN) conditioning.
+    """
+
+    ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        attn_mode: str = "xformers",
+        qkv_bias: bool = False,
+        pre_only: bool = False,
+        rmsnorm: bool = False,
+        scale_mod_only: bool = False,
+        swiglu: bool = False,
+        qk_norm: Optional[str] = None,
+        dtype=None,
+        device=None,
+        operations=None,
+        **block_kwargs,
+    ):
+        super().__init__()
+        assert attn_mode in self.ATTENTION_MODES
+        if not rmsnorm:
+            self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+        else:
+            self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.attn = SelfAttention(
+            dim=hidden_size,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            attn_mode=attn_mode,
+            pre_only=pre_only,
+            qk_norm=qk_norm,
+            rmsnorm=rmsnorm,
+            dtype=dtype,
+            device=device,
+            operations=operations
+        )
+        if not pre_only:
+            if not rmsnorm:
+                self.norm2 = operations.LayerNorm(
+                    hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
+                )
+            else:
+                self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        if not pre_only:
+            if not swiglu:
+                self.mlp = Mlp(
+                    in_features=hidden_size,
+                    hidden_features=mlp_hidden_dim,
+                    act_layer=lambda: nn.GELU(approximate="tanh"),
+                    drop=0,
+                    dtype=dtype,
+                    device=device,
+                    operations=operations
+                )
+            else:
+                self.mlp = SwiGLUFeedForward(
+                    dim=hidden_size,
+                    hidden_dim=mlp_hidden_dim,
+                    multiple_of=256,
+                )
+        self.scale_mod_only = scale_mod_only
+        if not scale_mod_only:
+            n_mods = 6 if not pre_only else 2
+        else:
+            n_mods = 4 if not pre_only else 1
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(), operations.Linear(hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device)
+        )
+        self.pre_only = pre_only
+
+    def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+        if not self.pre_only:
+            if not self.scale_mod_only:
+                (
+                    shift_msa,
+                    scale_msa,
+                    gate_msa,
+                    shift_mlp,
+                    scale_mlp,
+                    gate_mlp,
+                ) = self.adaLN_modulation(c).chunk(6, dim=1)
+            else:
+                shift_msa = None
+                shift_mlp = None
+                (
+                    scale_msa,
+                    gate_msa,
+                    scale_mlp,
+                    gate_mlp,
+                ) = self.adaLN_modulation(
+                    c
+                ).chunk(4, dim=1)
+            qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
+            return qkv, (
+                x,
+                gate_msa,
+                shift_mlp,
+                scale_mlp,
+                gate_mlp,
+            )
+        else:
+            if not self.scale_mod_only:
+                (
+                    shift_msa,
+                    scale_msa,
+                ) = self.adaLN_modulation(
+                    c
+                ).chunk(2, dim=1)
+            else:
+                shift_msa = None
+                scale_msa = self.adaLN_modulation(c)
+            qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
+            return qkv, None
+
+    def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
+        assert not self.pre_only
+        x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
+        x = x + gate_mlp.unsqueeze(1) * self.mlp(
+            modulate(self.norm2(x), shift_mlp, scale_mlp)
+        )
+        return x
+
+    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+        assert not self.pre_only
+        qkv, intermediates = self.pre_attention(x, c)
+        attn = optimized_attention(
+            qkv,
+            num_heads=self.attn.num_heads,
+        )
+        return self.post_attention(attn, *intermediates)
+
+
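+# Note: block_mixing() is the heart of MMDiT's joint attention. Each of q/k/v from the context
+# (text) block is concatenated with the corresponding tensor from the x (image) block along the
+# sequence dimension, a single attention call mixes both streams, and the result is split back at
+# the original context length so each stream gets its own post_attention(). The wrapper simply
+# routes through torch.utils.checkpoint when activation checkpointing is requested.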
+def block_mixing(*args, use_checkpoint=True, **kwargs):
+    if use_checkpoint:
+        return torch.utils.checkpoint.checkpoint(
+            _block_mixing, *args, use_reentrant=False, **kwargs
+        )
+    else:
+        return _block_mixing(*args, **kwargs)
+
+
+def _block_mixing(context, x, context_block, x_block, c):
+    context_qkv, context_intermediates = context_block.pre_attention(context, c)
+
+    x_qkv, x_intermediates = x_block.pre_attention(x, c)
+
+    o = []
+    for t in range(3):
+        o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1))
+    qkv = tuple(o)
+
+    attn = optimized_attention(
+        qkv,
+        num_heads=x_block.attn.num_heads,
+    )
+    context_attn, x_attn = (
+        attn[:, : context_qkv[0].shape[1]],
+        attn[:, context_qkv[0].shape[1] :],
+    )
+
+    if not context_block.pre_only:
+        context = context_block.post_attention(context_attn, *context_intermediates)
+
+    else:
+        context = None
+    x = x_block.post_attention(x_attn, *x_intermediates)
+    return context, x
+
+
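+# Note: JointBlock pairs one DismantledBlock for the context stream with one for the image stream
+# and forwards both through block_mixing(); per its docstring it exists mainly so the pair can be
+# treated as a single FSDP / checkpointing unit.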
+class JointBlock(nn.Module):
+    """just a small wrapper to serve as a fsdp unit"""
+
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        super().__init__()
+        pre_only = kwargs.pop("pre_only")
+        qk_norm = kwargs.pop("qk_norm", None)
+        self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
+        self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)
+
+    def forward(self, *args, **kwargs):
+        return block_mixing(
+            *args, context_block=self.context_block, x_block=self.x_block, **kwargs
+        )
+
+
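+# Note: FinalLayer applies one last shift/scale modulation (no gate) and a linear projection from
+# hidden_size to patch_size * patch_size * out_channels, producing the per-token patch pixels that
+# unpatchify() later folds back into an image-shaped latent.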
+class FinalLayer(nn.Module):
+    """
+    The final layer of DiT.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        patch_size: int,
+        out_channels: int,
+        total_out_channels: Optional[int] = None,
+        dtype=None,
+        device=None,
+        operations=None,
+    ):
+        super().__init__()
+        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+        self.linear = (
+            operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
+            if (total_out_channels is None)
+            else operations.Linear(hidden_size, total_out_channels, bias=True, dtype=dtype, device=device)
+        )
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
+        )
+
+    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+        x = modulate(self.norm_final(x), shift, scale)
+        x = self.linear(x)
+        return x
+
+
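+# Note: SelfAttentionContext / ContextProcessorBlock / ContextProcessor form an optional plain
+# pre-norm transformer that refines the text context before it enters the joint blocks; it is only
+# built when context_processor_layers is set on MMDiT and is skipped entirely otherwise.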
+class SelfAttentionContext(nn.Module):
+    def __init__(self, dim, heads=8, dim_head=64, dtype=None, device=None, operations=None):
+        super().__init__()
+        dim_head = dim // heads
+        inner_dim = dim
+
+        self.heads = heads
+        self.dim_head = dim_head
+
+        self.qkv = operations.Linear(dim, dim * 3, bias=True, dtype=dtype, device=device)
+
+        self.proj = operations.Linear(inner_dim, dim, dtype=dtype, device=device)
+
+    def forward(self, x):
+        qkv = self.qkv(x)
+        q, k, v = split_qkv(qkv, self.dim_head)
+        x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
+        return self.proj(x)
+
+
+class ContextProcessorBlock(nn.Module):
+    def __init__(self, context_size, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.norm1 = operations.LayerNorm(context_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+        self.attn = SelfAttentionContext(context_size, dtype=dtype, device=device, operations=operations)
+        self.norm2 = operations.LayerNorm(context_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+        self.mlp = Mlp(in_features=context_size, hidden_features=(context_size * 4), act_layer=lambda: nn.GELU(approximate="tanh"), drop=0, dtype=dtype, device=device, operations=operations)
+
+    def forward(self, x):
+        x += self.attn(self.norm1(x))
+        x += self.mlp(self.norm2(x))
+        return x
+
+
+class ContextProcessor(nn.Module):
+    def __init__(self, context_size, num_layers, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.layers = torch.nn.ModuleList([ContextProcessorBlock(context_size, dtype=dtype, device=device, operations=operations) for i in range(num_layers)])
+        self.norm = operations.LayerNorm(context_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+
+    def forward(self, x):
+        for i, l in enumerate(self.layers):
+            x = l(x)
+        return self.norm(x)
+
+
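+# Note: MMDiT ties everything together. Width is derived from depth ("apply magic" below):
+# hidden_size = 64 * depth and num_heads = depth, i.e. a fixed head dimension of 64, so depth=24
+# gives a 1536-wide model with 24 heads. PatchEmbed turns the latent into tokens, TimestepEmbedder
+# (plus the optional pooled-vector y_embedder) builds the adaLN conditioning vector c, the context
+# (text) tokens are optionally pre-processed and linearly projected to hidden_size, and the stack
+# of JointBlocks plus FinalLayer produces per-patch outputs that are unpatchified back to an image.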
+class MMDiT(nn.Module):
+    """
+    Diffusion model with a Transformer backbone.
+    """
+
+    def __init__(
+        self,
+        input_size: int = 32,
+        patch_size: int = 2,
+        in_channels: int = 4,
+        depth: int = 28,
+        # hidden_size: Optional[int] = None,
+        # num_heads: Optional[int] = None,
+        mlp_ratio: float = 4.0,
+        learn_sigma: bool = False,
+        adm_in_channels: Optional[int] = None,
+        context_embedder_config: Optional[Dict] = None,
+        compile_core: bool = False,
+        use_checkpoint: bool = False,
+        register_length: int = 0,
+        attn_mode: str = "torch",
+        rmsnorm: bool = False,
+        scale_mod_only: bool = False,
+        swiglu: bool = False,
+        out_channels: Optional[int] = None,
+        pos_embed_scaling_factor: Optional[float] = None,
+        pos_embed_offset: Optional[float] = None,
+        pos_embed_max_size: Optional[int] = None,
+        num_patches = None,
+        qk_norm: Optional[str] = None,
+        qkv_bias: bool = True,
+        context_processor_layers = None,
+        context_size = 4096,
+        num_blocks = None,
+        final_layer = True,
+        dtype = None, #TODO
+        device = None,
+        operations = None,
+    ):
+        super().__init__()
+        self.dtype = dtype
+        self.learn_sigma = learn_sigma
+        self.in_channels = in_channels
+        default_out_channels = in_channels * 2 if learn_sigma else in_channels
+        self.out_channels = default(out_channels, default_out_channels)
+        self.patch_size = patch_size
+        self.pos_embed_scaling_factor = pos_embed_scaling_factor
+        self.pos_embed_offset = pos_embed_offset
+        self.pos_embed_max_size = pos_embed_max_size
+
+        # hidden_size = default(hidden_size, 64 * depth)
+        # num_heads = default(num_heads, hidden_size // 64)
+
+        # apply magic --> this defines a head_size of 64
+        self.hidden_size = 64 * depth
+        num_heads = depth
+        if num_blocks is None:
+            num_blocks = depth
+
+        self.depth = depth
+        self.num_heads = num_heads
+
+        self.x_embedder = PatchEmbed(
+            input_size,
+            patch_size,
+            in_channels,
+            self.hidden_size,
+            bias=True,
+            strict_img_size=self.pos_embed_max_size is None,
+            dtype=dtype,
+            device=device,
+            operations=operations
+        )
+        self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)
+
+        self.y_embedder = None
+        if adm_in_channels is not None:
+            assert isinstance(adm_in_channels, int)
+            self.y_embedder = VectorEmbedder(adm_in_channels, self.hidden_size, dtype=dtype, device=device, operations=operations)
+
+        if context_processor_layers is not None:
+            self.context_processor = ContextProcessor(context_size, context_processor_layers, dtype=dtype, device=device, operations=operations)
+        else:
+            self.context_processor = None
+
+        self.context_embedder = nn.Identity()
+        if context_embedder_config is not None:
+            if context_embedder_config["target"] == "torch.nn.Linear":
+                self.context_embedder = operations.Linear(**context_embedder_config["params"], dtype=dtype, device=device)
+
+        self.register_length = register_length
+        if self.register_length > 0:
+            self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size, dtype=dtype, device=device))
+
+        # num_patches = self.x_embedder.num_patches
+        # Will use fixed sin-cos embedding:
+        # just use a buffer already
+        if num_patches is not None:
+            self.register_buffer(
+                "pos_embed",
+                torch.empty(1, num_patches, self.hidden_size, dtype=dtype, device=device),
+            )
+        else:
+            self.pos_embed = None
+
+        self.use_checkpoint = use_checkpoint
+        self.joint_blocks = nn.ModuleList(
+            [
+                JointBlock(
+                    self.hidden_size,
+                    num_heads,
+                    mlp_ratio=mlp_ratio,
+                    qkv_bias=qkv_bias,
+                    attn_mode=attn_mode,
+                    pre_only=(i == num_blocks - 1) and final_layer,
+                    rmsnorm=rmsnorm,
+                    scale_mod_only=scale_mod_only,
+                    swiglu=swiglu,
+                    qk_norm=qk_norm,
+                    dtype=dtype,
+                    device=device,
+                    operations=operations
+                )
+                for i in range(num_blocks)
+            ]
+        )
+
+        if final_layer:
+            self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
+
+        if compile_core:
+            assert False
+            self.forward_core_with_concat = torch.compile(self.forward_core_with_concat)
+
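+    # Note: cropped_pos_embed() converts the requested latent height/width into a patch grid and
+    # either crops the centre of the pre-registered pos_embed buffer (when pos_embed_max_size is
+    # set) or synthesizes a sine/cosine grid on the fly via get_2d_sincos_pos_embed_torch when no
+    # buffer exists.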
+    def cropped_pos_embed(self, hw, device=None):
+        p = self.x_embedder.patch_size[0]
+        h, w = hw
+        # patched size
+        h = (h + 1) // p
+        w = (w + 1) // p
+        if self.pos_embed is None:
+            return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device)
+        assert self.pos_embed_max_size is not None
+        assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
+        assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
+        top = (self.pos_embed_max_size - h) // 2
+        left = (self.pos_embed_max_size - w) // 2
+        spatial_pos_embed = rearrange(
+            self.pos_embed,
+            "1 (h w) c -> 1 h w c",
+            h=self.pos_embed_max_size,
+            w=self.pos_embed_max_size,
+        )
+        spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
+        spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
+        # print(spatial_pos_embed, top, left, h, w)
+        # # t = get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, 7.875, 7.875, device=device) #matches exactly for 1024 res
+        # t = get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, 7.5, 7.5, device=device) #scales better
+        # # print(t)
+        # return t
+        return spatial_pos_embed
+
+    def unpatchify(self, x, hw=None):
+        """
+        x: (N, T, patch_size**2 * C)
+        imgs: (N, H, W, C)
+        """
+        c = self.out_channels
+        p = self.x_embedder.patch_size[0]
+        if hw is None:
+            h = w = int(x.shape[1] ** 0.5)
+        else:
+            h, w = hw
+            h = (h + 1) // p
+            w = (w + 1) // p
+        assert h * w == x.shape[1]
+
+        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+        x = torch.einsum("nhwpqc->nchpwq", x)
+        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
+        return imgs
+
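+    # Note: forward_core_with_concat() optionally prepends the learned register tokens to the
+    # context, then runs the stack of joint blocks; after each block a residual from a
+    # ControlNet-style `control` dict ("output" list) may be added to the image tokens before
+    # FinalLayer maps the result to patch pixels.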
+    def forward_core_with_concat(
+        self,
+        x: torch.Tensor,
+        c_mod: torch.Tensor,
+        context: Optional[torch.Tensor] = None,
+        control = None,
+    ) -> torch.Tensor:
+        if self.register_length > 0:
+            context = torch.cat(
+                (
+                    repeat(self.register, "1 ... -> b ...", b=x.shape[0]),
+                    default(context, torch.Tensor([]).type_as(x)),
+                ),
+                1,
+            )
+
+        # context is B, L', D
+        # x is B, L, D
+        blocks = len(self.joint_blocks)
+        for i in range(blocks):
+            context, x = self.joint_blocks[i](
+                context,
+                x,
+                c=c_mod,
+                use_checkpoint=self.use_checkpoint,
+            )
+            if control is not None:
+                control_o = control.get("output")
+                if i < len(control_o):
+                    add = control_o[i]
+                    if add is not None:
+                        x += add
+
+        x = self.final_layer(x, c_mod)  # (N, T, patch_size ** 2 * out_channels)
+        return x
+
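+    # Note: forward() is the full pipeline: optional context processing, patch embedding plus the
+    # cropped positional embedding, timestep (+ pooled y) conditioning, context projection, the
+    # joint-block core, and unpatchify. The final slice to hw trims any rows/columns introduced by
+    # the patch-size rounding/padding so the output matches the input latent size.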
+    def forward(
+        self,
+        x: torch.Tensor,
+        t: torch.Tensor,
+        y: Optional[torch.Tensor] = None,
+        context: Optional[torch.Tensor] = None,
+        control = None,
+    ) -> torch.Tensor:
+        """
+        Forward pass of DiT.
+        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+        t: (N,) tensor of diffusion timesteps
+        y: (N,) tensor of class labels
+        """
+
+        if self.context_processor is not None:
+            context = self.context_processor(context)
+
+        hw = x.shape[-2:]
+        x = self.x_embedder(x) + comfy.ops.cast_to_input(self.cropped_pos_embed(hw, device=x.device), x)
+        c = self.t_embedder(t, dtype=x.dtype)  # (N, D)
+        if y is not None and self.y_embedder is not None:
+            y = self.y_embedder(y)  # (N, D)
+            c = c + y  # (N, D)
+
+        if context is not None:
+            context = self.context_embedder(context)
+
+        x = self.forward_core_with_concat(x, c, context, control)
+
+        x = self.unpatchify(x, hw=hw)  # (N, out_channels, H, W)
+        return x[:,:,:hw[-2],:hw[-1]]
+
+
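+# Note: OpenAISignatureMMDITWrapper only adapts the argument order/names (timesteps, context, y,
+# **kwargs) to the signature comfy's UNet-style model wrappers expect; it adds no new behaviour.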
+class OpenAISignatureMMDITWrapper(MMDiT):
+    def forward(
+        self,
+        x: torch.Tensor,
+        timesteps: torch.Tensor,
+        context: Optional[torch.Tensor] = None,
+        y: Optional[torch.Tensor] = None,
+        control = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        return super().forward(x, timesteps, context=context, y=y, control=control)
+
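+# Rough usage sketch (hypothetical configuration, for illustration only - the real values come
+# from the checkpoint's config when comfy builds the model):
+#
+#   model = OpenAISignatureMMDITWrapper(
+#       patch_size=2, in_channels=16, depth=24,            # 24 * 64 = 1536 hidden units, 24 heads
+#       adm_in_channels=2048, context_size=4096,
+#       pos_embed_max_size=192, num_patches=192 * 192,
+#       context_embedder_config={"target": "torch.nn.Linear",
+#                                "params": {"in_features": 4096, "out_features": 1536}},
+#       dtype=torch.float16, operations=comfy.ops.disable_weight_init,
+#   )
+#   out = model(latent, timesteps, context=text_tokens, y=pooled)  # same spatial size as `latent`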