Commit
•
69c26b8
1
Parent(s):
974def6
Upload folder using huggingface_hub
Browse files- controlnet_flux.py +418 -0
- images/0.jpg +0 -0
- images/1.jpg +0 -0
- images/2.jpg +0 -0
- images/3.jpg +0 -0
- images/alibaba.png +0 -0
- images/alibabaalimama.png +0 -0
- images/alimama.png +0 -0
- images/flux1.jpg +0 -0
- images/flux2.jpg +0 -0
- images/flux3.jpg +0 -0
- main.py +50 -0
- pipeline_flux_controlnet_inpaint.py +1049 -0
- readme.md +78 -0
- transformer_flux.py +525 -0
controlnet_flux.py
ADDED
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
8 |
+
from diffusers.loaders import PeftAdapterMixin
|
9 |
+
from diffusers.models.modeling_utils import ModelMixin
|
10 |
+
from diffusers.models.attention_processor import AttentionProcessor
|
11 |
+
from diffusers.utils import (
|
12 |
+
USE_PEFT_BACKEND,
|
13 |
+
is_torch_version,
|
14 |
+
logging,
|
15 |
+
scale_lora_layers,
|
16 |
+
unscale_lora_layers,
|
17 |
+
)
|
18 |
+
from diffusers.models.controlnet import BaseOutput, zero_module
|
19 |
+
from diffusers.models.embeddings import (
|
20 |
+
CombinedTimestepGuidanceTextProjEmbeddings,
|
21 |
+
CombinedTimestepTextProjEmbeddings,
|
22 |
+
)
|
23 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
24 |
+
from transformer_flux import (
|
25 |
+
EmbedND,
|
26 |
+
FluxSingleTransformerBlock,
|
27 |
+
FluxTransformerBlock,
|
28 |
+
)
|
29 |
+
|
30 |
+
|
31 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
32 |
+
|
33 |
+
|
34 |
+
@dataclass
|
35 |
+
class FluxControlNetOutput(BaseOutput):
|
36 |
+
controlnet_block_samples: Tuple[torch.Tensor]
|
37 |
+
controlnet_single_block_samples: Tuple[torch.Tensor]
|
38 |
+
|
39 |
+
|
40 |
+
class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
41 |
+
_supports_gradient_checkpointing = True
|
42 |
+
|
43 |
+
@register_to_config
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
patch_size: int = 1,
|
47 |
+
in_channels: int = 64,
|
48 |
+
num_layers: int = 19,
|
49 |
+
num_single_layers: int = 38,
|
50 |
+
attention_head_dim: int = 128,
|
51 |
+
num_attention_heads: int = 24,
|
52 |
+
joint_attention_dim: int = 4096,
|
53 |
+
pooled_projection_dim: int = 768,
|
54 |
+
guidance_embeds: bool = False,
|
55 |
+
axes_dims_rope: List[int] = [16, 56, 56],
|
56 |
+
extra_condition_channels: int = 1 * 4,
|
57 |
+
):
|
58 |
+
super().__init__()
|
59 |
+
self.out_channels = in_channels
|
60 |
+
self.inner_dim = num_attention_heads * attention_head_dim
|
61 |
+
|
62 |
+
self.pos_embed = EmbedND(
|
63 |
+
dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope
|
64 |
+
)
|
65 |
+
text_time_guidance_cls = (
|
66 |
+
CombinedTimestepGuidanceTextProjEmbeddings
|
67 |
+
if guidance_embeds
|
68 |
+
else CombinedTimestepTextProjEmbeddings
|
69 |
+
)
|
70 |
+
self.time_text_embed = text_time_guidance_cls(
|
71 |
+
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
|
72 |
+
)
|
73 |
+
|
74 |
+
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
|
75 |
+
self.x_embedder = nn.Linear(in_channels, self.inner_dim)
|
76 |
+
|
77 |
+
self.transformer_blocks = nn.ModuleList(
|
78 |
+
[
|
79 |
+
FluxTransformerBlock(
|
80 |
+
dim=self.inner_dim,
|
81 |
+
num_attention_heads=num_attention_heads,
|
82 |
+
attention_head_dim=attention_head_dim,
|
83 |
+
)
|
84 |
+
for _ in range(num_layers)
|
85 |
+
]
|
86 |
+
)
|
87 |
+
|
88 |
+
self.single_transformer_blocks = nn.ModuleList(
|
89 |
+
[
|
90 |
+
FluxSingleTransformerBlock(
|
91 |
+
dim=self.inner_dim,
|
92 |
+
num_attention_heads=num_attention_heads,
|
93 |
+
attention_head_dim=attention_head_dim,
|
94 |
+
)
|
95 |
+
for _ in range(num_single_layers)
|
96 |
+
]
|
97 |
+
)
|
98 |
+
|
99 |
+
# controlnet_blocks
|
100 |
+
self.controlnet_blocks = nn.ModuleList([])
|
101 |
+
for _ in range(len(self.transformer_blocks)):
|
102 |
+
self.controlnet_blocks.append(
|
103 |
+
zero_module(nn.Linear(self.inner_dim, self.inner_dim))
|
104 |
+
)
|
105 |
+
|
106 |
+
self.controlnet_single_blocks = nn.ModuleList([])
|
107 |
+
for _ in range(len(self.single_transformer_blocks)):
|
108 |
+
self.controlnet_single_blocks.append(
|
109 |
+
zero_module(nn.Linear(self.inner_dim, self.inner_dim))
|
110 |
+
)
|
111 |
+
|
112 |
+
self.controlnet_x_embedder = zero_module(
|
113 |
+
torch.nn.Linear(in_channels + extra_condition_channels, self.inner_dim)
|
114 |
+
)
|
115 |
+
|
116 |
+
self.gradient_checkpointing = False
|
117 |
+
|
118 |
+
@property
|
119 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
120 |
+
def attn_processors(self):
|
121 |
+
r"""
|
122 |
+
Returns:
|
123 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
124 |
+
indexed by its weight name.
|
125 |
+
"""
|
126 |
+
# set recursively
|
127 |
+
processors = {}
|
128 |
+
|
129 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
130 |
+
if hasattr(module, "get_processor"):
|
131 |
+
processors[f"{name}.processor"] = module.get_processor()
|
132 |
+
|
133 |
+
for sub_name, child in module.named_children():
|
134 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
135 |
+
|
136 |
+
return processors
|
137 |
+
|
138 |
+
for name, module in self.named_children():
|
139 |
+
fn_recursive_add_processors(name, module, processors)
|
140 |
+
|
141 |
+
return processors
|
142 |
+
|
143 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
144 |
+
def set_attn_processor(self, processor):
|
145 |
+
r"""
|
146 |
+
Sets the attention processor to use to compute attention.
|
147 |
+
|
148 |
+
Parameters:
|
149 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
150 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
151 |
+
for **all** `Attention` layers.
|
152 |
+
|
153 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
154 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
155 |
+
|
156 |
+
"""
|
157 |
+
count = len(self.attn_processors.keys())
|
158 |
+
|
159 |
+
if isinstance(processor, dict) and len(processor) != count:
|
160 |
+
raise ValueError(
|
161 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
162 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
163 |
+
)
|
164 |
+
|
165 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
166 |
+
if hasattr(module, "set_processor"):
|
167 |
+
if not isinstance(processor, dict):
|
168 |
+
module.set_processor(processor)
|
169 |
+
else:
|
170 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
171 |
+
|
172 |
+
for sub_name, child in module.named_children():
|
173 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
174 |
+
|
175 |
+
for name, module in self.named_children():
|
176 |
+
fn_recursive_attn_processor(name, module, processor)
|
177 |
+
|
178 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
179 |
+
if hasattr(module, "gradient_checkpointing"):
|
180 |
+
module.gradient_checkpointing = value
|
181 |
+
|
182 |
+
@classmethod
|
183 |
+
def from_transformer(
|
184 |
+
cls,
|
185 |
+
transformer,
|
186 |
+
num_layers: int = 4,
|
187 |
+
num_single_layers: int = 10,
|
188 |
+
attention_head_dim: int = 128,
|
189 |
+
num_attention_heads: int = 24,
|
190 |
+
load_weights_from_transformer=True,
|
191 |
+
):
|
192 |
+
config = transformer.config
|
193 |
+
config["num_layers"] = num_layers
|
194 |
+
config["num_single_layers"] = num_single_layers
|
195 |
+
config["attention_head_dim"] = attention_head_dim
|
196 |
+
config["num_attention_heads"] = num_attention_heads
|
197 |
+
|
198 |
+
controlnet = cls(**config)
|
199 |
+
|
200 |
+
if load_weights_from_transformer:
|
201 |
+
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
|
202 |
+
controlnet.time_text_embed.load_state_dict(
|
203 |
+
transformer.time_text_embed.state_dict()
|
204 |
+
)
|
205 |
+
controlnet.context_embedder.load_state_dict(
|
206 |
+
transformer.context_embedder.state_dict()
|
207 |
+
)
|
208 |
+
controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
|
209 |
+
controlnet.transformer_blocks.load_state_dict(
|
210 |
+
transformer.transformer_blocks.state_dict(), strict=False
|
211 |
+
)
|
212 |
+
controlnet.single_transformer_blocks.load_state_dict(
|
213 |
+
transformer.single_transformer_blocks.state_dict(), strict=False
|
214 |
+
)
|
215 |
+
|
216 |
+
controlnet.controlnet_x_embedder = zero_module(
|
217 |
+
controlnet.controlnet_x_embedder
|
218 |
+
)
|
219 |
+
|
220 |
+
return controlnet
|
221 |
+
|
222 |
+
def forward(
|
223 |
+
self,
|
224 |
+
hidden_states: torch.Tensor,
|
225 |
+
controlnet_cond: torch.Tensor,
|
226 |
+
conditioning_scale: float = 1.0,
|
227 |
+
encoder_hidden_states: torch.Tensor = None,
|
228 |
+
pooled_projections: torch.Tensor = None,
|
229 |
+
timestep: torch.LongTensor = None,
|
230 |
+
img_ids: torch.Tensor = None,
|
231 |
+
txt_ids: torch.Tensor = None,
|
232 |
+
guidance: torch.Tensor = None,
|
233 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
234 |
+
return_dict: bool = True,
|
235 |
+
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
236 |
+
"""
|
237 |
+
The [`FluxTransformer2DModel`] forward method.
|
238 |
+
|
239 |
+
Args:
|
240 |
+
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
241 |
+
Input `hidden_states`.
|
242 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
243 |
+
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
244 |
+
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
245 |
+
from the embeddings of input conditions.
|
246 |
+
timestep ( `torch.LongTensor`):
|
247 |
+
Used to indicate denoising step.
|
248 |
+
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
249 |
+
A list of tensors that if specified are added to the residuals of transformer blocks.
|
250 |
+
joint_attention_kwargs (`dict`, *optional*):
|
251 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
252 |
+
`self.processor` in
|
253 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
254 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
255 |
+
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
256 |
+
tuple.
|
257 |
+
|
258 |
+
Returns:
|
259 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
260 |
+
`tuple` where the first element is the sample tensor.
|
261 |
+
"""
|
262 |
+
if joint_attention_kwargs is not None:
|
263 |
+
joint_attention_kwargs = joint_attention_kwargs.copy()
|
264 |
+
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
265 |
+
else:
|
266 |
+
lora_scale = 1.0
|
267 |
+
|
268 |
+
if USE_PEFT_BACKEND:
|
269 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
270 |
+
scale_lora_layers(self, lora_scale)
|
271 |
+
else:
|
272 |
+
if (
|
273 |
+
joint_attention_kwargs is not None
|
274 |
+
and joint_attention_kwargs.get("scale", None) is not None
|
275 |
+
):
|
276 |
+
logger.warning(
|
277 |
+
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
278 |
+
)
|
279 |
+
hidden_states = self.x_embedder(hidden_states)
|
280 |
+
|
281 |
+
# add condition
|
282 |
+
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
|
283 |
+
|
284 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
285 |
+
if guidance is not None:
|
286 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
287 |
+
else:
|
288 |
+
guidance = None
|
289 |
+
temb = (
|
290 |
+
self.time_text_embed(timestep, pooled_projections)
|
291 |
+
if guidance is None
|
292 |
+
else self.time_text_embed(timestep, guidance, pooled_projections)
|
293 |
+
)
|
294 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
295 |
+
|
296 |
+
txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
|
297 |
+
ids = torch.cat((txt_ids, img_ids), dim=1)
|
298 |
+
image_rotary_emb = self.pos_embed(ids)
|
299 |
+
|
300 |
+
block_samples = ()
|
301 |
+
for _, block in enumerate(self.transformer_blocks):
|
302 |
+
if self.training and self.gradient_checkpointing:
|
303 |
+
|
304 |
+
def create_custom_forward(module, return_dict=None):
|
305 |
+
def custom_forward(*inputs):
|
306 |
+
if return_dict is not None:
|
307 |
+
return module(*inputs, return_dict=return_dict)
|
308 |
+
else:
|
309 |
+
return module(*inputs)
|
310 |
+
|
311 |
+
return custom_forward
|
312 |
+
|
313 |
+
ckpt_kwargs: Dict[str, Any] = (
|
314 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
315 |
+
)
|
316 |
+
(
|
317 |
+
encoder_hidden_states,
|
318 |
+
hidden_states,
|
319 |
+
) = torch.utils.checkpoint.checkpoint(
|
320 |
+
create_custom_forward(block),
|
321 |
+
hidden_states,
|
322 |
+
encoder_hidden_states,
|
323 |
+
temb,
|
324 |
+
image_rotary_emb,
|
325 |
+
**ckpt_kwargs,
|
326 |
+
)
|
327 |
+
|
328 |
+
else:
|
329 |
+
encoder_hidden_states, hidden_states = block(
|
330 |
+
hidden_states=hidden_states,
|
331 |
+
encoder_hidden_states=encoder_hidden_states,
|
332 |
+
temb=temb,
|
333 |
+
image_rotary_emb=image_rotary_emb,
|
334 |
+
)
|
335 |
+
block_samples = block_samples + (hidden_states,)
|
336 |
+
|
337 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
338 |
+
|
339 |
+
single_block_samples = ()
|
340 |
+
for _, block in enumerate(self.single_transformer_blocks):
|
341 |
+
if self.training and self.gradient_checkpointing:
|
342 |
+
|
343 |
+
def create_custom_forward(module, return_dict=None):
|
344 |
+
def custom_forward(*inputs):
|
345 |
+
if return_dict is not None:
|
346 |
+
return module(*inputs, return_dict=return_dict)
|
347 |
+
else:
|
348 |
+
return module(*inputs)
|
349 |
+
|
350 |
+
return custom_forward
|
351 |
+
|
352 |
+
ckpt_kwargs: Dict[str, Any] = (
|
353 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
354 |
+
)
|
355 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
356 |
+
create_custom_forward(block),
|
357 |
+
hidden_states,
|
358 |
+
temb,
|
359 |
+
image_rotary_emb,
|
360 |
+
**ckpt_kwargs,
|
361 |
+
)
|
362 |
+
|
363 |
+
else:
|
364 |
+
hidden_states = block(
|
365 |
+
hidden_states=hidden_states,
|
366 |
+
temb=temb,
|
367 |
+
image_rotary_emb=image_rotary_emb,
|
368 |
+
)
|
369 |
+
single_block_samples = single_block_samples + (
|
370 |
+
hidden_states[:, encoder_hidden_states.shape[1] :],
|
371 |
+
)
|
372 |
+
|
373 |
+
# controlnet block
|
374 |
+
controlnet_block_samples = ()
|
375 |
+
for block_sample, controlnet_block in zip(
|
376 |
+
block_samples, self.controlnet_blocks
|
377 |
+
):
|
378 |
+
block_sample = controlnet_block(block_sample)
|
379 |
+
controlnet_block_samples = controlnet_block_samples + (block_sample,)
|
380 |
+
|
381 |
+
controlnet_single_block_samples = ()
|
382 |
+
for single_block_sample, controlnet_block in zip(
|
383 |
+
single_block_samples, self.controlnet_single_blocks
|
384 |
+
):
|
385 |
+
single_block_sample = controlnet_block(single_block_sample)
|
386 |
+
controlnet_single_block_samples = controlnet_single_block_samples + (
|
387 |
+
single_block_sample,
|
388 |
+
)
|
389 |
+
|
390 |
+
# scaling
|
391 |
+
controlnet_block_samples = [
|
392 |
+
sample * conditioning_scale for sample in controlnet_block_samples
|
393 |
+
]
|
394 |
+
controlnet_single_block_samples = [
|
395 |
+
sample * conditioning_scale for sample in controlnet_single_block_samples
|
396 |
+
]
|
397 |
+
|
398 |
+
#
|
399 |
+
controlnet_block_samples = (
|
400 |
+
None if len(controlnet_block_samples) == 0 else controlnet_block_samples
|
401 |
+
)
|
402 |
+
controlnet_single_block_samples = (
|
403 |
+
None
|
404 |
+
if len(controlnet_single_block_samples) == 0
|
405 |
+
else controlnet_single_block_samples
|
406 |
+
)
|
407 |
+
|
408 |
+
if USE_PEFT_BACKEND:
|
409 |
+
# remove `lora_scale` from each PEFT layer
|
410 |
+
unscale_lora_layers(self, lora_scale)
|
411 |
+
|
412 |
+
if not return_dict:
|
413 |
+
return (controlnet_block_samples, controlnet_single_block_samples)
|
414 |
+
|
415 |
+
return FluxControlNetOutput(
|
416 |
+
controlnet_block_samples=controlnet_block_samples,
|
417 |
+
controlnet_single_block_samples=controlnet_single_block_samples,
|
418 |
+
)
|
images/0.jpg
ADDED
images/1.jpg
ADDED
images/2.jpg
ADDED
images/3.jpg
ADDED
images/alibaba.png
ADDED
images/alibabaalimama.png
ADDED
images/alimama.png
ADDED
images/flux1.jpg
ADDED
images/flux2.jpg
ADDED
images/flux3.jpg
ADDED
main.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from diffusers.utils import load_image, check_min_version
|
3 |
+
from controlnet_flux import FluxControlNetModel
|
4 |
+
from transformer_flux import FluxTransformer2DModel
|
5 |
+
from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
|
6 |
+
|
7 |
+
check_min_version("0.30.2")
|
8 |
+
|
9 |
+
# Set image path , mask path and prompt
|
10 |
+
image_path='https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/bucket.png',
|
11 |
+
mask_path='https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/bucket_mask.jpeg',
|
12 |
+
prompt='a person wearing a white shoe, carrying a white bucket with text "FLUX" on it'
|
13 |
+
|
14 |
+
# Build pipeline
|
15 |
+
controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", torch_dtype=torch.bfloat16)
|
16 |
+
transformer = FluxTransformer2DModel.from_pretrained(
|
17 |
+
"black-forest-labs/FLUX.1-dev", subfolder='transformer', torch_dytpe=torch.bfloat16
|
18 |
+
)
|
19 |
+
pipe = FluxControlNetInpaintingPipeline.from_pretrained(
|
20 |
+
"black-forest-labs/FLUX.1-dev",
|
21 |
+
controlnet=controlnet,
|
22 |
+
transformer=transformer,
|
23 |
+
torch_dtype=torch.bfloat16
|
24 |
+
).to("cuda")
|
25 |
+
pipe.transformer.to(torch.bfloat16)
|
26 |
+
pipe.controlnet.to(torch.bfloat16)
|
27 |
+
|
28 |
+
# Load image and mask
|
29 |
+
size = (768, 768)
|
30 |
+
image = load_image(image_path).convert("RGB").resize(size)
|
31 |
+
mask = load_image(mask_path).convert("RGB").resize(size)
|
32 |
+
generator = torch.Generator(device="cuda").manual_seed(24)
|
33 |
+
|
34 |
+
# Inpaint
|
35 |
+
result = pipe(
|
36 |
+
prompt=prompt,
|
37 |
+
height=size[1],
|
38 |
+
width=size[0],
|
39 |
+
control_image=image,
|
40 |
+
control_mask=mask,
|
41 |
+
num_inference_steps=28,
|
42 |
+
generator=generator,
|
43 |
+
controlnet_conditioning_scale=0.9,
|
44 |
+
guidance_scale=3.5,
|
45 |
+
negative_prompt="",
|
46 |
+
true_guidance_scale=3.5
|
47 |
+
).images[0]
|
48 |
+
|
49 |
+
result.save('flux_inpaint.png')
|
50 |
+
print("Successfully inpaint image")
|
pipeline_flux_controlnet_inpaint.py
ADDED
@@ -0,0 +1,1049 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from transformers import (
|
7 |
+
CLIPTextModel,
|
8 |
+
CLIPTokenizer,
|
9 |
+
T5EncoderModel,
|
10 |
+
T5TokenizerFast,
|
11 |
+
)
|
12 |
+
|
13 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
14 |
+
from diffusers.loaders import FluxLoraLoaderMixin
|
15 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
16 |
+
|
17 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
18 |
+
from diffusers.utils import (
|
19 |
+
USE_PEFT_BACKEND,
|
20 |
+
is_torch_xla_available,
|
21 |
+
logging,
|
22 |
+
replace_example_docstring,
|
23 |
+
scale_lora_layers,
|
24 |
+
unscale_lora_layers,
|
25 |
+
)
|
26 |
+
from diffusers.utils.torch_utils import randn_tensor
|
27 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
28 |
+
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
29 |
+
|
30 |
+
from transformer_flux import FluxTransformer2DModel
|
31 |
+
from controlnet_flux import FluxControlNetModel
|
32 |
+
|
33 |
+
if is_torch_xla_available():
|
34 |
+
import torch_xla.core.xla_model as xm
|
35 |
+
|
36 |
+
XLA_AVAILABLE = True
|
37 |
+
else:
|
38 |
+
XLA_AVAILABLE = False
|
39 |
+
|
40 |
+
|
41 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
42 |
+
|
43 |
+
EXAMPLE_DOC_STRING = """
|
44 |
+
Examples:
|
45 |
+
```py
|
46 |
+
>>> import torch
|
47 |
+
>>> from diffusers.utils import load_image
|
48 |
+
>>> from diffusers import FluxControlNetPipeline
|
49 |
+
>>> from diffusers import FluxControlNetModel
|
50 |
+
|
51 |
+
>>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny-alpha"
|
52 |
+
>>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
|
53 |
+
>>> pipe = FluxControlNetPipeline.from_pretrained(
|
54 |
+
... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
|
55 |
+
... )
|
56 |
+
>>> pipe.to("cuda")
|
57 |
+
>>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
|
58 |
+
>>> control_mask = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
|
59 |
+
>>> prompt = "A girl in city, 25 years old, cool, futuristic"
|
60 |
+
>>> image = pipe(
|
61 |
+
... prompt,
|
62 |
+
... control_image=control_image,
|
63 |
+
... controlnet_conditioning_scale=0.6,
|
64 |
+
... num_inference_steps=28,
|
65 |
+
... guidance_scale=3.5,
|
66 |
+
... ).images[0]
|
67 |
+
>>> image.save("flux.png")
|
68 |
+
```
|
69 |
+
"""
|
70 |
+
|
71 |
+
|
72 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
73 |
+
def calculate_shift(
|
74 |
+
image_seq_len,
|
75 |
+
base_seq_len: int = 256,
|
76 |
+
max_seq_len: int = 4096,
|
77 |
+
base_shift: float = 0.5,
|
78 |
+
max_shift: float = 1.16,
|
79 |
+
):
|
80 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
81 |
+
b = base_shift - m * base_seq_len
|
82 |
+
mu = image_seq_len * m + b
|
83 |
+
return mu
|
84 |
+
|
85 |
+
|
86 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
87 |
+
def retrieve_timesteps(
|
88 |
+
scheduler,
|
89 |
+
num_inference_steps: Optional[int] = None,
|
90 |
+
device: Optional[Union[str, torch.device]] = None,
|
91 |
+
timesteps: Optional[List[int]] = None,
|
92 |
+
sigmas: Optional[List[float]] = None,
|
93 |
+
**kwargs,
|
94 |
+
):
|
95 |
+
"""
|
96 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
97 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
scheduler (`SchedulerMixin`):
|
101 |
+
The scheduler to get timesteps from.
|
102 |
+
num_inference_steps (`int`):
|
103 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
104 |
+
must be `None`.
|
105 |
+
device (`str` or `torch.device`, *optional*):
|
106 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
107 |
+
timesteps (`List[int]`, *optional*):
|
108 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
109 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
110 |
+
sigmas (`List[float]`, *optional*):
|
111 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
112 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
116 |
+
second element is the number of inference steps.
|
117 |
+
"""
|
118 |
+
if timesteps is not None and sigmas is not None:
|
119 |
+
raise ValueError(
|
120 |
+
"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
|
121 |
+
)
|
122 |
+
if timesteps is not None:
|
123 |
+
accepts_timesteps = "timesteps" in set(
|
124 |
+
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
125 |
+
)
|
126 |
+
if not accepts_timesteps:
|
127 |
+
raise ValueError(
|
128 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
129 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
130 |
+
)
|
131 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
132 |
+
timesteps = scheduler.timesteps
|
133 |
+
num_inference_steps = len(timesteps)
|
134 |
+
elif sigmas is not None:
|
135 |
+
accept_sigmas = "sigmas" in set(
|
136 |
+
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
137 |
+
)
|
138 |
+
if not accept_sigmas:
|
139 |
+
raise ValueError(
|
140 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
141 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
142 |
+
)
|
143 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
144 |
+
timesteps = scheduler.timesteps
|
145 |
+
num_inference_steps = len(timesteps)
|
146 |
+
else:
|
147 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
148 |
+
timesteps = scheduler.timesteps
|
149 |
+
return timesteps, num_inference_steps
|
150 |
+
|
151 |
+
|
152 |
+
class FluxControlNetInpaintingPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
153 |
+
r"""
|
154 |
+
The Flux pipeline for text-to-image generation.
|
155 |
+
|
156 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
157 |
+
|
158 |
+
Args:
|
159 |
+
transformer ([`FluxTransformer2DModel`]):
|
160 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
161 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
162 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
163 |
+
vae ([`AutoencoderKL`]):
|
164 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
165 |
+
text_encoder ([`CLIPTextModel`]):
|
166 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
167 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
168 |
+
text_encoder_2 ([`T5EncoderModel`]):
|
169 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
170 |
+
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
171 |
+
tokenizer (`CLIPTokenizer`):
|
172 |
+
Tokenizer of class
|
173 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
174 |
+
tokenizer_2 (`T5TokenizerFast`):
|
175 |
+
Second Tokenizer of class
|
176 |
+
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
177 |
+
"""
|
178 |
+
|
179 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
180 |
+
_optional_components = []
|
181 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
182 |
+
|
183 |
+
def __init__(
|
184 |
+
self,
|
185 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
186 |
+
vae: AutoencoderKL,
|
187 |
+
text_encoder: CLIPTextModel,
|
188 |
+
tokenizer: CLIPTokenizer,
|
189 |
+
text_encoder_2: T5EncoderModel,
|
190 |
+
tokenizer_2: T5TokenizerFast,
|
191 |
+
transformer: FluxTransformer2DModel,
|
192 |
+
controlnet: FluxControlNetModel,
|
193 |
+
):
|
194 |
+
super().__init__()
|
195 |
+
|
196 |
+
self.register_modules(
|
197 |
+
vae=vae,
|
198 |
+
text_encoder=text_encoder,
|
199 |
+
text_encoder_2=text_encoder_2,
|
200 |
+
tokenizer=tokenizer,
|
201 |
+
tokenizer_2=tokenizer_2,
|
202 |
+
transformer=transformer,
|
203 |
+
scheduler=scheduler,
|
204 |
+
controlnet=controlnet,
|
205 |
+
)
|
206 |
+
self.vae_scale_factor = (
|
207 |
+
2 ** (len(self.vae.config.block_out_channels))
|
208 |
+
if hasattr(self, "vae") and self.vae is not None
|
209 |
+
else 16
|
210 |
+
)
|
211 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=True, do_convert_rgb=True, do_normalize=True)
|
212 |
+
self.mask_processor = VaeImageProcessor(
|
213 |
+
vae_scale_factor=self.vae_scale_factor,
|
214 |
+
do_resize=True,
|
215 |
+
do_convert_grayscale=True,
|
216 |
+
do_normalize=False,
|
217 |
+
do_binarize=True,
|
218 |
+
)
|
219 |
+
self.tokenizer_max_length = (
|
220 |
+
self.tokenizer.model_max_length
|
221 |
+
if hasattr(self, "tokenizer") and self.tokenizer is not None
|
222 |
+
else 77
|
223 |
+
)
|
224 |
+
self.default_sample_size = 64
|
225 |
+
|
226 |
+
@property
|
227 |
+
def do_classifier_free_guidance(self):
|
228 |
+
return self._guidance_scale > 1
|
229 |
+
|
230 |
+
def _get_t5_prompt_embeds(
|
231 |
+
self,
|
232 |
+
prompt: Union[str, List[str]] = None,
|
233 |
+
num_images_per_prompt: int = 1,
|
234 |
+
max_sequence_length: int = 512,
|
235 |
+
device: Optional[torch.device] = None,
|
236 |
+
dtype: Optional[torch.dtype] = None,
|
237 |
+
):
|
238 |
+
device = device or self._execution_device
|
239 |
+
dtype = dtype or self.text_encoder.dtype
|
240 |
+
|
241 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
242 |
+
batch_size = len(prompt)
|
243 |
+
|
244 |
+
text_inputs = self.tokenizer_2(
|
245 |
+
prompt,
|
246 |
+
padding="max_length",
|
247 |
+
max_length=max_sequence_length,
|
248 |
+
truncation=True,
|
249 |
+
return_length=False,
|
250 |
+
return_overflowing_tokens=False,
|
251 |
+
return_tensors="pt",
|
252 |
+
)
|
253 |
+
text_input_ids = text_inputs.input_ids
|
254 |
+
untruncated_ids = self.tokenizer_2(
|
255 |
+
prompt, padding="longest", return_tensors="pt"
|
256 |
+
).input_ids
|
257 |
+
|
258 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
259 |
+
text_input_ids, untruncated_ids
|
260 |
+
):
|
261 |
+
removed_text = self.tokenizer_2.batch_decode(
|
262 |
+
untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
|
263 |
+
)
|
264 |
+
logger.warning(
|
265 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
266 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
267 |
+
)
|
268 |
+
|
269 |
+
prompt_embeds = self.text_encoder_2(
|
270 |
+
text_input_ids.to(device), output_hidden_states=False
|
271 |
+
)[0]
|
272 |
+
|
273 |
+
dtype = self.text_encoder_2.dtype
|
274 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
275 |
+
|
276 |
+
_, seq_len, _ = prompt_embeds.shape
|
277 |
+
|
278 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
279 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
280 |
+
prompt_embeds = prompt_embeds.view(
|
281 |
+
batch_size * num_images_per_prompt, seq_len, -1
|
282 |
+
)
|
283 |
+
|
284 |
+
return prompt_embeds
|
285 |
+
|
286 |
+
def _get_clip_prompt_embeds(
|
287 |
+
self,
|
288 |
+
prompt: Union[str, List[str]],
|
289 |
+
num_images_per_prompt: int = 1,
|
290 |
+
device: Optional[torch.device] = None,
|
291 |
+
):
|
292 |
+
device = device or self._execution_device
|
293 |
+
|
294 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
295 |
+
batch_size = len(prompt)
|
296 |
+
|
297 |
+
text_inputs = self.tokenizer(
|
298 |
+
prompt,
|
299 |
+
padding="max_length",
|
300 |
+
max_length=self.tokenizer_max_length,
|
301 |
+
truncation=True,
|
302 |
+
return_overflowing_tokens=False,
|
303 |
+
return_length=False,
|
304 |
+
return_tensors="pt",
|
305 |
+
)
|
306 |
+
|
307 |
+
text_input_ids = text_inputs.input_ids
|
308 |
+
untruncated_ids = self.tokenizer(
|
309 |
+
prompt, padding="longest", return_tensors="pt"
|
310 |
+
).input_ids
|
311 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
312 |
+
text_input_ids, untruncated_ids
|
313 |
+
):
|
314 |
+
removed_text = self.tokenizer.batch_decode(
|
315 |
+
untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
|
316 |
+
)
|
317 |
+
logger.warning(
|
318 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
319 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
320 |
+
)
|
321 |
+
prompt_embeds = self.text_encoder(
|
322 |
+
text_input_ids.to(device), output_hidden_states=False
|
323 |
+
)
|
324 |
+
|
325 |
+
# Use pooled output of CLIPTextModel
|
326 |
+
prompt_embeds = prompt_embeds.pooler_output
|
327 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
328 |
+
|
329 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
330 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
331 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
332 |
+
|
333 |
+
return prompt_embeds
|
334 |
+
|
335 |
+
def encode_prompt(
|
336 |
+
self,
|
337 |
+
prompt: Union[str, List[str]],
|
338 |
+
prompt_2: Union[str, List[str]],
|
339 |
+
device: Optional[torch.device] = None,
|
340 |
+
num_images_per_prompt: int = 1,
|
341 |
+
do_classifier_free_guidance: bool = True,
|
342 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
343 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
344 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
345 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
346 |
+
max_sequence_length: int = 512,
|
347 |
+
lora_scale: Optional[float] = None,
|
348 |
+
):
|
349 |
+
r"""
|
350 |
+
|
351 |
+
Args:
|
352 |
+
prompt (`str` or `List[str]`, *optional*):
|
353 |
+
prompt to be encoded
|
354 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
355 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
356 |
+
used in all text-encoders
|
357 |
+
device: (`torch.device`):
|
358 |
+
torch device
|
359 |
+
num_images_per_prompt (`int`):
|
360 |
+
number of images that should be generated per prompt
|
361 |
+
do_classifier_free_guidance (`bool`):
|
362 |
+
whether to use classifier-free guidance or not
|
363 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
364 |
+
negative prompt to be encoded
|
365 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
366 |
+
negative prompt to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is
|
367 |
+
used in all text-encoders
|
368 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
369 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
370 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
371 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
372 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
373 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
374 |
+
clip_skip (`int`, *optional*):
|
375 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
376 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
377 |
+
lora_scale (`float`, *optional*):
|
378 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
379 |
+
"""
|
380 |
+
device = device or self._execution_device
|
381 |
+
|
382 |
+
# set lora scale so that monkey patched LoRA
|
383 |
+
# function of text encoder can correctly access it
|
384 |
+
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
385 |
+
self._lora_scale = lora_scale
|
386 |
+
|
387 |
+
# dynamically adjust the LoRA scale
|
388 |
+
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
389 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
390 |
+
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
391 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
392 |
+
|
393 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
394 |
+
if prompt is not None:
|
395 |
+
batch_size = len(prompt)
|
396 |
+
else:
|
397 |
+
batch_size = prompt_embeds.shape[0]
|
398 |
+
|
399 |
+
if prompt_embeds is None:
|
400 |
+
prompt_2 = prompt_2 or prompt
|
401 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
402 |
+
|
403 |
+
# We only use the pooled prompt output from the CLIPTextModel
|
404 |
+
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
405 |
+
prompt=prompt,
|
406 |
+
device=device,
|
407 |
+
num_images_per_prompt=num_images_per_prompt,
|
408 |
+
)
|
409 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
410 |
+
prompt=prompt_2,
|
411 |
+
num_images_per_prompt=num_images_per_prompt,
|
412 |
+
max_sequence_length=max_sequence_length,
|
413 |
+
device=device,
|
414 |
+
)
|
415 |
+
|
416 |
+
if do_classifier_free_guidance:
|
417 |
+
# 处理 negative prompt
|
418 |
+
negative_prompt = negative_prompt or ""
|
419 |
+
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
420 |
+
|
421 |
+
negative_pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
422 |
+
negative_prompt,
|
423 |
+
device=device,
|
424 |
+
num_images_per_prompt=num_images_per_prompt,
|
425 |
+
)
|
426 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
427 |
+
negative_prompt_2,
|
428 |
+
num_images_per_prompt=num_images_per_prompt,
|
429 |
+
max_sequence_length=max_sequence_length,
|
430 |
+
device=device,
|
431 |
+
)
|
432 |
+
else:
|
433 |
+
negative_pooled_prompt_embeds = None
|
434 |
+
negative_prompt_embeds = None
|
435 |
+
|
436 |
+
if self.text_encoder is not None:
|
437 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
438 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
439 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
440 |
+
|
441 |
+
if self.text_encoder_2 is not None:
|
442 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
443 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
444 |
+
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
445 |
+
|
446 |
+
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(
|
447 |
+
device=device, dtype=self.text_encoder.dtype
|
448 |
+
)
|
449 |
+
|
450 |
+
return prompt_embeds, pooled_prompt_embeds, negative_prompt_embeds, negative_pooled_prompt_embeds,text_ids
|
451 |
+
|
452 |
+
def check_inputs(
|
453 |
+
self,
|
454 |
+
prompt,
|
455 |
+
prompt_2,
|
456 |
+
height,
|
457 |
+
width,
|
458 |
+
prompt_embeds=None,
|
459 |
+
pooled_prompt_embeds=None,
|
460 |
+
callback_on_step_end_tensor_inputs=None,
|
461 |
+
max_sequence_length=None,
|
462 |
+
):
|
463 |
+
if height % 8 != 0 or width % 8 != 0:
|
464 |
+
raise ValueError(
|
465 |
+
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
|
466 |
+
)
|
467 |
+
|
468 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
469 |
+
k in self._callback_tensor_inputs
|
470 |
+
for k in callback_on_step_end_tensor_inputs
|
471 |
+
):
|
472 |
+
raise ValueError(
|
473 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
474 |
+
)
|
475 |
+
|
476 |
+
if prompt is not None and prompt_embeds is not None:
|
477 |
+
raise ValueError(
|
478 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
479 |
+
" only forward one of the two."
|
480 |
+
)
|
481 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
482 |
+
raise ValueError(
|
483 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
484 |
+
" only forward one of the two."
|
485 |
+
)
|
486 |
+
elif prompt is None and prompt_embeds is None:
|
487 |
+
raise ValueError(
|
488 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
489 |
+
)
|
490 |
+
elif prompt is not None and (
|
491 |
+
not isinstance(prompt, str) and not isinstance(prompt, list)
|
492 |
+
):
|
493 |
+
raise ValueError(
|
494 |
+
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
495 |
+
)
|
496 |
+
elif prompt_2 is not None and (
|
497 |
+
not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
|
498 |
+
):
|
499 |
+
raise ValueError(
|
500 |
+
f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
|
501 |
+
)
|
502 |
+
|
503 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
504 |
+
raise ValueError(
|
505 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
506 |
+
)
|
507 |
+
|
508 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
509 |
+
raise ValueError(
|
510 |
+
f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}"
|
511 |
+
)
|
512 |
+
|
513 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux._prepare_latent_image_ids
|
514 |
+
@staticmethod
|
515 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
516 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
517 |
+
latent_image_ids[..., 1] = (
|
518 |
+
latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
519 |
+
)
|
520 |
+
latent_image_ids[..., 2] = (
|
521 |
+
latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
522 |
+
)
|
523 |
+
|
524 |
+
(
|
525 |
+
latent_image_id_height,
|
526 |
+
latent_image_id_width,
|
527 |
+
latent_image_id_channels,
|
528 |
+
) = latent_image_ids.shape
|
529 |
+
|
530 |
+
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
531 |
+
latent_image_ids = latent_image_ids.reshape(
|
532 |
+
batch_size,
|
533 |
+
latent_image_id_height * latent_image_id_width,
|
534 |
+
latent_image_id_channels,
|
535 |
+
)
|
536 |
+
|
537 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
538 |
+
|
539 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux._pack_latents
|
540 |
+
@staticmethod
|
541 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
542 |
+
latents = latents.view(
|
543 |
+
batch_size, num_channels_latents, height // 2, 2, width // 2, 2
|
544 |
+
)
|
545 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
546 |
+
latents = latents.reshape(
|
547 |
+
batch_size, (height // 2) * (width // 2), num_channels_latents * 4
|
548 |
+
)
|
549 |
+
|
550 |
+
return latents
|
551 |
+
|
552 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux._unpack_latents
|
553 |
+
@staticmethod
|
554 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
555 |
+
batch_size, num_patches, channels = latents.shape
|
556 |
+
|
557 |
+
height = height // vae_scale_factor
|
558 |
+
width = width // vae_scale_factor
|
559 |
+
|
560 |
+
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
|
561 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
562 |
+
|
563 |
+
latents = latents.reshape(
|
564 |
+
batch_size, channels // (2 * 2), height * 2, width * 2
|
565 |
+
)
|
566 |
+
|
567 |
+
return latents
|
568 |
+
|
569 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.prepare_latents
|
570 |
+
def prepare_latents(
|
571 |
+
self,
|
572 |
+
batch_size,
|
573 |
+
num_channels_latents,
|
574 |
+
height,
|
575 |
+
width,
|
576 |
+
dtype,
|
577 |
+
device,
|
578 |
+
generator,
|
579 |
+
latents=None,
|
580 |
+
):
|
581 |
+
height = 2 * (int(height) // self.vae_scale_factor)
|
582 |
+
width = 2 * (int(width) // self.vae_scale_factor)
|
583 |
+
|
584 |
+
shape = (batch_size, num_channels_latents, height, width)
|
585 |
+
|
586 |
+
if latents is not None:
|
587 |
+
latent_image_ids = self._prepare_latent_image_ids(
|
588 |
+
batch_size, height, width, device, dtype
|
589 |
+
)
|
590 |
+
return latents.to(device=device, dtype=dtype), latent_image_ids
|
591 |
+
|
592 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
593 |
+
raise ValueError(
|
594 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
595 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
596 |
+
)
|
597 |
+
|
598 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
599 |
+
latents = self._pack_latents(
|
600 |
+
latents, batch_size, num_channels_latents, height, width
|
601 |
+
)
|
602 |
+
|
603 |
+
latent_image_ids = self._prepare_latent_image_ids(
|
604 |
+
batch_size, height, width, device, dtype
|
605 |
+
)
|
606 |
+
|
607 |
+
return latents, latent_image_ids
|
608 |
+
|
609 |
+
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
|
610 |
+
def prepare_image(
|
611 |
+
self,
|
612 |
+
image,
|
613 |
+
width,
|
614 |
+
height,
|
615 |
+
batch_size,
|
616 |
+
num_images_per_prompt,
|
617 |
+
device,
|
618 |
+
dtype,
|
619 |
+
):
|
620 |
+
if isinstance(image, torch.Tensor):
|
621 |
+
pass
|
622 |
+
else:
|
623 |
+
image = self.image_processor.preprocess(image, height=height, width=width)
|
624 |
+
|
625 |
+
image_batch_size = image.shape[0]
|
626 |
+
|
627 |
+
if image_batch_size == 1:
|
628 |
+
repeat_by = batch_size
|
629 |
+
else:
|
630 |
+
# image batch size is the same as prompt batch size
|
631 |
+
repeat_by = num_images_per_prompt
|
632 |
+
|
633 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
634 |
+
|
635 |
+
image = image.to(device=device, dtype=dtype)
|
636 |
+
|
637 |
+
return image
|
638 |
+
|
639 |
+
def prepare_image_with_mask(
|
640 |
+
self,
|
641 |
+
image,
|
642 |
+
mask,
|
643 |
+
width,
|
644 |
+
height,
|
645 |
+
batch_size,
|
646 |
+
num_images_per_prompt,
|
647 |
+
device,
|
648 |
+
dtype,
|
649 |
+
do_classifier_free_guidance = False,
|
650 |
+
):
|
651 |
+
# Prepare image
|
652 |
+
if isinstance(image, torch.Tensor):
|
653 |
+
pass
|
654 |
+
else:
|
655 |
+
image = self.image_processor.preprocess(image, height=height, width=width)
|
656 |
+
|
657 |
+
image_batch_size = image.shape[0]
|
658 |
+
if image_batch_size == 1:
|
659 |
+
repeat_by = batch_size
|
660 |
+
else:
|
661 |
+
# image batch size is the same as prompt batch size
|
662 |
+
repeat_by = num_images_per_prompt
|
663 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
664 |
+
image = image.to(device=device, dtype=dtype)
|
665 |
+
|
666 |
+
# Prepare mask
|
667 |
+
if isinstance(mask, torch.Tensor):
|
668 |
+
pass
|
669 |
+
else:
|
670 |
+
mask = self.mask_processor.preprocess(mask, height=height, width=width)
|
671 |
+
mask = mask.repeat_interleave(repeat_by, dim=0)
|
672 |
+
mask = mask.to(device=device, dtype=dtype)
|
673 |
+
|
674 |
+
# Get masked image
|
675 |
+
masked_image = image.clone()
|
676 |
+
masked_image[(mask > 0.5).repeat(1, 3, 1, 1)] = -1
|
677 |
+
|
678 |
+
# Encode to latents
|
679 |
+
image_latents = self.vae.encode(masked_image.to(self.vae.dtype)).latent_dist.sample()
|
680 |
+
image_latents = (
|
681 |
+
image_latents - self.vae.config.shift_factor
|
682 |
+
) * self.vae.config.scaling_factor
|
683 |
+
image_latents = image_latents.to(dtype)
|
684 |
+
|
685 |
+
mask = torch.nn.functional.interpolate(
|
686 |
+
mask, size=(height // self.vae_scale_factor * 2, width // self.vae_scale_factor * 2)
|
687 |
+
)
|
688 |
+
mask = 1 - mask
|
689 |
+
|
690 |
+
control_image = torch.cat([image_latents, mask], dim=1)
|
691 |
+
|
692 |
+
# Pack cond latents
|
693 |
+
packed_control_image = self._pack_latents(
|
694 |
+
control_image,
|
695 |
+
batch_size * num_images_per_prompt,
|
696 |
+
control_image.shape[1],
|
697 |
+
control_image.shape[2],
|
698 |
+
control_image.shape[3],
|
699 |
+
)
|
700 |
+
|
701 |
+
if do_classifier_free_guidance:
|
702 |
+
packed_control_image = torch.cat([packed_control_image] * 2)
|
703 |
+
|
704 |
+
return packed_control_image, height, width
|
705 |
+
|
706 |
+
@property
|
707 |
+
def guidance_scale(self):
|
708 |
+
return self._guidance_scale
|
709 |
+
|
710 |
+
@property
|
711 |
+
def joint_attention_kwargs(self):
|
712 |
+
return self._joint_attention_kwargs
|
713 |
+
|
714 |
+
@property
|
715 |
+
def num_timesteps(self):
|
716 |
+
return self._num_timesteps
|
717 |
+
|
718 |
+
@property
|
719 |
+
def interrupt(self):
|
720 |
+
return self._interrupt
|
721 |
+
|
722 |
+
@torch.no_grad()
|
723 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
724 |
+
def __call__(
|
725 |
+
self,
|
726 |
+
prompt: Union[str, List[str]] = None,
|
727 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
728 |
+
height: Optional[int] = None,
|
729 |
+
width: Optional[int] = None,
|
730 |
+
num_inference_steps: int = 28,
|
731 |
+
timesteps: List[int] = None,
|
732 |
+
guidance_scale: float = 7.0,
|
733 |
+
true_guidance_scale: float = 3.5 ,
|
734 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
735 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
736 |
+
control_image: PipelineImageInput = None,
|
737 |
+
control_mask: PipelineImageInput = None,
|
738 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
739 |
+
num_images_per_prompt: Optional[int] = 1,
|
740 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
741 |
+
latents: Optional[torch.FloatTensor] = None,
|
742 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
743 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
744 |
+
output_type: Optional[str] = "pil",
|
745 |
+
return_dict: bool = True,
|
746 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
747 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
748 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
749 |
+
max_sequence_length: int = 512,
|
750 |
+
):
|
751 |
+
r"""
|
752 |
+
Function invoked when calling the pipeline for generation.
|
753 |
+
|
754 |
+
Args:
|
755 |
+
prompt (`str` or `List[str]`, *optional*):
|
756 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
757 |
+
instead.
|
758 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
759 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
760 |
+
will be used instead
|
761 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
762 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
763 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
764 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
765 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
766 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
767 |
+
expense of slower inference.
|
768 |
+
timesteps (`List[int]`, *optional*):
|
769 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
770 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
771 |
+
passed will be used. Must be in descending order.
|
772 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
773 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
774 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
775 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
776 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
777 |
+
usually at the expense of lower image quality.
|
778 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
779 |
+
The number of images to generate per prompt.
|
780 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
781 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
782 |
+
to make generation deterministic.
|
783 |
+
latents (`torch.FloatTensor`, *optional*):
|
784 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
785 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
786 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
787 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
788 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
789 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
790 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
791 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
792 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
793 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
794 |
+
The output format of the generate image. Choose between
|
795 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
796 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
797 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
798 |
+
joint_attention_kwargs (`dict`, *optional*):
|
799 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
800 |
+
`self.processor` in
|
801 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
802 |
+
callback_on_step_end (`Callable`, *optional*):
|
803 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
804 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
805 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
806 |
+
`callback_on_step_end_tensor_inputs`.
|
807 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
808 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
809 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
810 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
811 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
812 |
+
|
813 |
+
Examples:
|
814 |
+
|
815 |
+
Returns:
|
816 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
817 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
818 |
+
images.
|
819 |
+
"""
|
820 |
+
|
821 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
822 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
823 |
+
|
824 |
+
# 1. Check inputs. Raise error if not correct
|
825 |
+
self.check_inputs(
|
826 |
+
prompt,
|
827 |
+
prompt_2,
|
828 |
+
height,
|
829 |
+
width,
|
830 |
+
prompt_embeds=prompt_embeds,
|
831 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
832 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
833 |
+
max_sequence_length=max_sequence_length,
|
834 |
+
)
|
835 |
+
|
836 |
+
self._guidance_scale = true_guidance_scale
|
837 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
838 |
+
self._interrupt = False
|
839 |
+
|
840 |
+
# 2. Define call parameters
|
841 |
+
if prompt is not None and isinstance(prompt, str):
|
842 |
+
batch_size = 1
|
843 |
+
elif prompt is not None and isinstance(prompt, list):
|
844 |
+
batch_size = len(prompt)
|
845 |
+
else:
|
846 |
+
batch_size = prompt_embeds.shape[0]
|
847 |
+
|
848 |
+
device = self._execution_device
|
849 |
+
dtype = self.transformer.dtype
|
850 |
+
|
851 |
+
lora_scale = (
|
852 |
+
self.joint_attention_kwargs.get("scale", None)
|
853 |
+
if self.joint_attention_kwargs is not None
|
854 |
+
else None
|
855 |
+
)
|
856 |
+
(
|
857 |
+
prompt_embeds,
|
858 |
+
pooled_prompt_embeds,
|
859 |
+
negative_prompt_embeds,
|
860 |
+
negative_pooled_prompt_embeds,
|
861 |
+
text_ids
|
862 |
+
) = self.encode_prompt(
|
863 |
+
prompt=prompt,
|
864 |
+
prompt_2=prompt_2,
|
865 |
+
prompt_embeds=prompt_embeds,
|
866 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
867 |
+
do_classifier_free_guidance = self.do_classifier_free_guidance,
|
868 |
+
negative_prompt = negative_prompt,
|
869 |
+
negative_prompt_2 = negative_prompt_2,
|
870 |
+
device=device,
|
871 |
+
num_images_per_prompt=num_images_per_prompt,
|
872 |
+
max_sequence_length=max_sequence_length,
|
873 |
+
lora_scale=lora_scale,
|
874 |
+
)
|
875 |
+
|
876 |
+
# 在 encode_prompt 之后
|
877 |
+
if self.do_classifier_free_guidance:
|
878 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim = 0)
|
879 |
+
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim = 0)
|
880 |
+
text_ids = torch.cat([text_ids, text_ids], dim = 0)
|
881 |
+
|
882 |
+
# 3. Prepare control image
|
883 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
884 |
+
if isinstance(self.controlnet, FluxControlNetModel):
|
885 |
+
control_image, height, width = self.prepare_image_with_mask(
|
886 |
+
image=control_image,
|
887 |
+
mask=control_mask,
|
888 |
+
width=width,
|
889 |
+
height=height,
|
890 |
+
batch_size=batch_size * num_images_per_prompt,
|
891 |
+
num_images_per_prompt=num_images_per_prompt,
|
892 |
+
device=device,
|
893 |
+
dtype=dtype,
|
894 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
895 |
+
)
|
896 |
+
|
897 |
+
# 4. Prepare latent variables
|
898 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
899 |
+
latents, latent_image_ids = self.prepare_latents(
|
900 |
+
batch_size * num_images_per_prompt,
|
901 |
+
num_channels_latents,
|
902 |
+
height,
|
903 |
+
width,
|
904 |
+
prompt_embeds.dtype,
|
905 |
+
device,
|
906 |
+
generator,
|
907 |
+
latents,
|
908 |
+
)
|
909 |
+
|
910 |
+
if self.do_classifier_free_guidance:
|
911 |
+
latent_image_ids = torch.cat([latent_image_ids] * 2)
|
912 |
+
|
913 |
+
# 5. Prepare timesteps
|
914 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
915 |
+
image_seq_len = latents.shape[1]
|
916 |
+
mu = calculate_shift(
|
917 |
+
image_seq_len,
|
918 |
+
self.scheduler.config.base_image_seq_len,
|
919 |
+
self.scheduler.config.max_image_seq_len,
|
920 |
+
self.scheduler.config.base_shift,
|
921 |
+
self.scheduler.config.max_shift,
|
922 |
+
)
|
923 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
924 |
+
self.scheduler,
|
925 |
+
num_inference_steps,
|
926 |
+
device,
|
927 |
+
timesteps,
|
928 |
+
sigmas,
|
929 |
+
mu=mu,
|
930 |
+
)
|
931 |
+
|
932 |
+
num_warmup_steps = max(
|
933 |
+
len(timesteps) - num_inference_steps * self.scheduler.order, 0
|
934 |
+
)
|
935 |
+
self._num_timesteps = len(timesteps)
|
936 |
+
|
937 |
+
# 6. Denoising loop
|
938 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
939 |
+
for i, t in enumerate(timesteps):
|
940 |
+
if self.interrupt:
|
941 |
+
continue
|
942 |
+
|
943 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
944 |
+
|
945 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
946 |
+
timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
|
947 |
+
|
948 |
+
# handle guidance
|
949 |
+
if self.transformer.config.guidance_embeds:
|
950 |
+
guidance = torch.tensor([guidance_scale], device=device)
|
951 |
+
guidance = guidance.expand(latent_model_input.shape[0])
|
952 |
+
else:
|
953 |
+
guidance = None
|
954 |
+
|
955 |
+
# controlnet
|
956 |
+
(
|
957 |
+
controlnet_block_samples,
|
958 |
+
controlnet_single_block_samples,
|
959 |
+
) = self.controlnet(
|
960 |
+
hidden_states=latent_model_input,
|
961 |
+
controlnet_cond=control_image,
|
962 |
+
conditioning_scale=controlnet_conditioning_scale,
|
963 |
+
timestep=timestep / 1000,
|
964 |
+
guidance=guidance,
|
965 |
+
pooled_projections=pooled_prompt_embeds,
|
966 |
+
encoder_hidden_states=prompt_embeds,
|
967 |
+
txt_ids=text_ids,
|
968 |
+
img_ids=latent_image_ids,
|
969 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
970 |
+
return_dict=False,
|
971 |
+
)
|
972 |
+
|
973 |
+
noise_pred = self.transformer(
|
974 |
+
hidden_states=latent_model_input,
|
975 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
|
976 |
+
timestep=timestep / 1000,
|
977 |
+
guidance=guidance,
|
978 |
+
pooled_projections=pooled_prompt_embeds,
|
979 |
+
encoder_hidden_states=prompt_embeds,
|
980 |
+
controlnet_block_samples=[
|
981 |
+
sample.to(dtype=self.transformer.dtype)
|
982 |
+
for sample in controlnet_block_samples
|
983 |
+
],
|
984 |
+
controlnet_single_block_samples=[
|
985 |
+
sample.to(dtype=self.transformer.dtype)
|
986 |
+
for sample in controlnet_single_block_samples
|
987 |
+
] if controlnet_single_block_samples is not None else controlnet_single_block_samples,
|
988 |
+
txt_ids=text_ids,
|
989 |
+
img_ids=latent_image_ids,
|
990 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
991 |
+
return_dict=False,
|
992 |
+
)[0]
|
993 |
+
|
994 |
+
# 在生成循环中
|
995 |
+
if self.do_classifier_free_guidance:
|
996 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
997 |
+
noise_pred = noise_pred_uncond + true_guidance_scale * (noise_pred_text - noise_pred_uncond)
|
998 |
+
|
999 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1000 |
+
latents_dtype = latents.dtype
|
1001 |
+
latents = self.scheduler.step(
|
1002 |
+
noise_pred, t, latents, return_dict=False
|
1003 |
+
)[0]
|
1004 |
+
|
1005 |
+
if latents.dtype != latents_dtype:
|
1006 |
+
if torch.backends.mps.is_available():
|
1007 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
1008 |
+
latents = latents.to(latents_dtype)
|
1009 |
+
|
1010 |
+
if callback_on_step_end is not None:
|
1011 |
+
callback_kwargs = {}
|
1012 |
+
for k in callback_on_step_end_tensor_inputs:
|
1013 |
+
callback_kwargs[k] = locals()[k]
|
1014 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1015 |
+
|
1016 |
+
latents = callback_outputs.pop("latents", latents)
|
1017 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1018 |
+
|
1019 |
+
# call the callback, if provided
|
1020 |
+
if i == len(timesteps) - 1 or (
|
1021 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
1022 |
+
):
|
1023 |
+
progress_bar.update()
|
1024 |
+
|
1025 |
+
if XLA_AVAILABLE:
|
1026 |
+
xm.mark_step()
|
1027 |
+
|
1028 |
+
if output_type == "latent":
|
1029 |
+
image = latents
|
1030 |
+
|
1031 |
+
else:
|
1032 |
+
latents = self._unpack_latents(
|
1033 |
+
latents, height, width, self.vae_scale_factor
|
1034 |
+
)
|
1035 |
+
latents = (
|
1036 |
+
latents / self.vae.config.scaling_factor
|
1037 |
+
) + self.vae.config.shift_factor
|
1038 |
+
latents = latents.to(self.vae.dtype)
|
1039 |
+
|
1040 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
1041 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
1042 |
+
|
1043 |
+
# Offload all models
|
1044 |
+
self.maybe_free_model_hooks()
|
1045 |
+
|
1046 |
+
if not return_dict:
|
1047 |
+
return (image,)
|
1048 |
+
|
1049 |
+
return FluxPipelineOutput(images=image)
|
readme.md
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div style="display: flex;align-items: center;">
|
2 |
+
<img src="images/alibabaalimama.png" alt="alibaba" style="width: 40%; height: auto; margin: 0 10px;">
|
3 |
+
</div>
|
4 |
+
|
5 |
+
This repository provides a Inpainting ControlNet checkpoint for [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) model released by researchers from AlimamaCreative Team.
|
6 |
+
|
7 |
+
## News
|
8 |
+
|
9 |
+
🎉 Thanks to @comfyanonymous,ComfyUI now supports inference for Alimama inpainting ControlNet. Workflow can be downloaded from [here](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/alimama-flux-controlnet-inpaint.json).
|
10 |
+
|
11 |
+
ComfyUI Usage Tips:
|
12 |
+
|
13 |
+
* Using the `t5xxl-FP16` and `flux1-dev-fp8` models for 28-step inference, the GPU memory usage is 27GB. The inference time with `cfg=3.5` is 27 seconds, while without `cfg=1` it is 15 seconds. `Hyper-FLUX-lora` can be used to accelerate inference.
|
14 |
+
* You can try adjusting(lower) the parameters `control-strength`, `control-end-percent`, and `cfg` to achieve better results.
|
15 |
+
* The following example uses `control-strength` = 0.9 & `control-end-percent` = 1.0 & `cfg` = 3.5
|
16 |
+
|
17 |
+
| Input | Output | Prompt |
|
18 |
+
|------------------------------|------------------------------|-------------|
|
19 |
+
| ![Image1](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_in_1.png) | ![Image2](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_out_1.png) | <small><i>The image depicts a scene from the anime series Dragon Ball Z, with the characters Goku, <span style="color:red; font-weight:bold;">Elon Musk</span>, and a child version of Gohan sharing a meal of ramen noodles. They are all sitting around a dining table, with Goku and Gohan on one side and Naruto on the other. They are all holding chopsticks and eating the noodles. The table is set with bowls of ramen, cups, and bowls of drinks. The arrangement of the characters and the food creates a sense of camaraderie and shared enjoyment of the meal. |
|
20 |
+
| ![Image3](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_in_2.png) | ![Image4](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_out_2.png) | <small><i>The image is an illustration of a man standing in a cafe. He is wearing a white turtleneck, a camel-colored trench coat, and brown shoes. He is holding a cell phone and appears to be looking at it. There is a small table with <span style="color:red; font-weight:bold;">a cat</span> on it to his right. In the background, there is another man sitting at a table with a laptop. The man is wearing a black turtleneck and a tie. </i></small>|
|
21 |
+
| ![Image5](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_in_3.png) | ![Image6](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_out_3.png) | <small><i>A woman with blonde hair is sitting on a table wearing a <span style="color:red; font-weight:bold;">red and white long dress</span>. She is holding a green phone in her hand and appears to be taking a photo. There is a bag next to her on the table and a handbag beside her on the chair. The woman is looking at the phone with a smile on her face. The background includes a TV on the left wall and a couch on the right. A chair is also present in the scene. </i></small>|
|
22 |
+
| ![Image7](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_in_4.png) | ![Image8](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_out_4.png) | <small><i>The image depicts a beautiful young woman sitting at a desk, reading a book. She has long, wavy brown hair and is wearing a grey shirt with a black cardigan. She is holding a <span style="color:red; font-weight:bold;">red pencil</span> in her left hand and appears to be deep in thought. Surrounding her are numerous books, some stacked on the desk and others placed on a shelf behind her. A potted plant is also visible in the background, adding a touch of greenery to the scene. The image conveys a sense of serenity and intellectual pursuits. </i></small>|
|
23 |
+
|
24 |
+
|
25 |
+
## Model Cards
|
26 |
+
|
27 |
+
<!-- 使用HTML来调整图标大小 -->
|
28 |
+
<a href="https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha" target="_blank">
|
29 |
+
<img src="https://huggingface.co/favicon.ico" alt="Hugging Face" width="25" height="25" /> The model weights have been uploaded to Hugging Face.
|
30 |
+
</a>
|
31 |
+
|
32 |
+
* The model was trained on 12M laion2B and internal source images at resolution 768x768. The inference performs best at this size, with other sizes yielding suboptimal results.
|
33 |
+
|
34 |
+
* The recommended controlnet_conditioning_scale is 0.9 - 0.95.
|
35 |
+
|
36 |
+
* **Please note: This is only the alpha version during the training process. We will release an updated version when we feel ready.**
|
37 |
+
|
38 |
+
## Showcase
|
39 |
+
|
40 |
+
![flux1](images/flux1.jpg)
|
41 |
+
![flux2](images/flux2.jpg)
|
42 |
+
![flux3](images/flux3.jpg)
|
43 |
+
|
44 |
+
## Comparison with SDXL-Inpainting
|
45 |
+
|
46 |
+
Compared with [SDXL-Inpainting](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1)
|
47 |
+
|
48 |
+
From left to right: Input image | Masked image | SDXL inpainting | Ours
|
49 |
+
|
50 |
+
![0](images/0.jpg)
|
51 |
+
<small><i>*The image depicts a beautiful young woman sitting at a desk, reading a book. She has long, wavy brown hair and is wearing a grey shirt with a black cardigan. She is holding a pencil in her left hand and appears to be deep in thought. Surrounding her are numerous books, some stacked on the desk and others placed on a shelf behind her. A potted plant is also visible in the background, adding a touch of greenery to the scene. The image conveys a sense of serenity and intellectual pursuits.*</i></small>
|
52 |
+
|
53 |
+
![0](images/1.jpg)
|
54 |
+
<small><i>A woman with blonde hair is sitting on a table wearing a blue and white long dress. She is holding a green phone in her hand and appears to be taking a photo. There is a bag next to her on the table and a handbag beside her on the chair. The woman is looking at the phone with a smile on her face. The background includes a TV on the left wall and a couch on the right. A chair is also present in the scene.</i></small>
|
55 |
+
|
56 |
+
![0](images/2.jpg)
|
57 |
+
<small><i>The image is an illustration of a man standing in a cafe. He is wearing a white turtleneck, a camel-colored trench coat, and brown shoes. He is holding a cell phone and appears to be looking at it. There is a small table with a cup of coffee on it to his right. In the background, there is another man sitting at a table with a laptop. The man is wearing a black turtleneck and a tie. There are several cups and a cake on the table in the background. The man sitting at the table appears to be typing on the laptop.</i></small>
|
58 |
+
|
59 |
+
![0](images/3.jpg)
|
60 |
+
<small><i>The image depicts a scene from the anime series Dragon Ball Z, with the characters Goku, Naruto, and a child version of Gohan sharing a meal of ramen noodles. They are all sitting around a dining table, with Goku and Gohan on one side and Naruto on the other. They are all holding chopsticks and eating the noodles. The table is set with bowls of ramen, cups, and bowls of drinks. The arrangement of the characters and the food creates a sense of camaraderie and shared enjoyment of the meal.</i></small>
|
61 |
+
|
62 |
+
## Using with Diffusers
|
63 |
+
Step1: install diffusers
|
64 |
+
``` Shell
|
65 |
+
pip install diffusers==0.30.2
|
66 |
+
```
|
67 |
+
|
68 |
+
Step2: clone repo from github
|
69 |
+
``` Shell
|
70 |
+
git clone https://github.com/alimama-creative/FLUX-Controlnet-Inpainting.git
|
71 |
+
```
|
72 |
+
|
73 |
+
Step3: modify the image_path, mask_path, prompt and run
|
74 |
+
``` Shell
|
75 |
+
python main.py
|
76 |
+
```
|
77 |
+
## LICENSE
|
78 |
+
Our weights fall under the [FLUX.1 [dev]](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md) Non-Commercial License.
|
transformer_flux.py
ADDED
@@ -0,0 +1,525 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List, Optional, Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
9 |
+
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
10 |
+
from diffusers.models.attention import FeedForward
|
11 |
+
from diffusers.models.attention_processor import (
|
12 |
+
Attention,
|
13 |
+
FluxAttnProcessor2_0,
|
14 |
+
FluxSingleAttnProcessor2_0,
|
15 |
+
)
|
16 |
+
from diffusers.models.modeling_utils import ModelMixin
|
17 |
+
from diffusers.models.normalization import (
|
18 |
+
AdaLayerNormContinuous,
|
19 |
+
AdaLayerNormZero,
|
20 |
+
AdaLayerNormZeroSingle,
|
21 |
+
)
|
22 |
+
from diffusers.utils import (
|
23 |
+
USE_PEFT_BACKEND,
|
24 |
+
is_torch_version,
|
25 |
+
logging,
|
26 |
+
scale_lora_layers,
|
27 |
+
unscale_lora_layers,
|
28 |
+
)
|
29 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
30 |
+
from diffusers.models.embeddings import (
|
31 |
+
CombinedTimestepGuidanceTextProjEmbeddings,
|
32 |
+
CombinedTimestepTextProjEmbeddings,
|
33 |
+
)
|
34 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
35 |
+
|
36 |
+
|
37 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
38 |
+
|
39 |
+
|
40 |
+
# YiYi to-do: refactor rope related functions/classes
|
41 |
+
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
42 |
+
assert dim % 2 == 0, "The dimension must be even."
|
43 |
+
|
44 |
+
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
45 |
+
omega = 1.0 / (theta**scale)
|
46 |
+
|
47 |
+
batch_size, seq_length = pos.shape
|
48 |
+
out = torch.einsum("...n,d->...nd", pos, omega)
|
49 |
+
cos_out = torch.cos(out)
|
50 |
+
sin_out = torch.sin(out)
|
51 |
+
|
52 |
+
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
|
53 |
+
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
|
54 |
+
return out.float()
|
55 |
+
|
56 |
+
|
57 |
+
# YiYi to-do: refactor rope related functions/classes
|
58 |
+
class EmbedND(nn.Module):
|
59 |
+
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
|
60 |
+
super().__init__()
|
61 |
+
self.dim = dim
|
62 |
+
self.theta = theta
|
63 |
+
self.axes_dim = axes_dim
|
64 |
+
|
65 |
+
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
66 |
+
n_axes = ids.shape[-1]
|
67 |
+
emb = torch.cat(
|
68 |
+
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
69 |
+
dim=-3,
|
70 |
+
)
|
71 |
+
return emb.unsqueeze(1)
|
72 |
+
|
73 |
+
|
74 |
+
@maybe_allow_in_graph
|
75 |
+
class FluxSingleTransformerBlock(nn.Module):
|
76 |
+
r"""
|
77 |
+
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
|
78 |
+
|
79 |
+
Reference: https://arxiv.org/abs/2403.03206
|
80 |
+
|
81 |
+
Parameters:
|
82 |
+
dim (`int`): The number of channels in the input and output.
|
83 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
84 |
+
attention_head_dim (`int`): The number of channels in each head.
|
85 |
+
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
|
86 |
+
processing of `context` conditions.
|
87 |
+
"""
|
88 |
+
|
89 |
+
def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
|
90 |
+
super().__init__()
|
91 |
+
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
92 |
+
|
93 |
+
self.norm = AdaLayerNormZeroSingle(dim)
|
94 |
+
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
|
95 |
+
self.act_mlp = nn.GELU(approximate="tanh")
|
96 |
+
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
97 |
+
|
98 |
+
processor = FluxSingleAttnProcessor2_0()
|
99 |
+
self.attn = Attention(
|
100 |
+
query_dim=dim,
|
101 |
+
cross_attention_dim=None,
|
102 |
+
dim_head=attention_head_dim,
|
103 |
+
heads=num_attention_heads,
|
104 |
+
out_dim=dim,
|
105 |
+
bias=True,
|
106 |
+
processor=processor,
|
107 |
+
qk_norm="rms_norm",
|
108 |
+
eps=1e-6,
|
109 |
+
pre_only=True,
|
110 |
+
)
|
111 |
+
|
112 |
+
def forward(
|
113 |
+
self,
|
114 |
+
hidden_states: torch.FloatTensor,
|
115 |
+
temb: torch.FloatTensor,
|
116 |
+
image_rotary_emb=None,
|
117 |
+
):
|
118 |
+
residual = hidden_states
|
119 |
+
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
120 |
+
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
121 |
+
|
122 |
+
attn_output = self.attn(
|
123 |
+
hidden_states=norm_hidden_states,
|
124 |
+
image_rotary_emb=image_rotary_emb,
|
125 |
+
)
|
126 |
+
|
127 |
+
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
128 |
+
gate = gate.unsqueeze(1)
|
129 |
+
hidden_states = gate * self.proj_out(hidden_states)
|
130 |
+
hidden_states = residual + hidden_states
|
131 |
+
if hidden_states.dtype == torch.float16:
|
132 |
+
hidden_states = hidden_states.clip(-65504, 65504)
|
133 |
+
|
134 |
+
return hidden_states
|
135 |
+
|
136 |
+
|
137 |
+
@maybe_allow_in_graph
|
138 |
+
class FluxTransformerBlock(nn.Module):
|
139 |
+
r"""
|
140 |
+
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
|
141 |
+
|
142 |
+
Reference: https://arxiv.org/abs/2403.03206
|
143 |
+
|
144 |
+
Parameters:
|
145 |
+
dim (`int`): The number of channels in the input and output.
|
146 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
147 |
+
attention_head_dim (`int`): The number of channels in each head.
|
148 |
+
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
|
149 |
+
processing of `context` conditions.
|
150 |
+
"""
|
151 |
+
|
152 |
+
def __init__(
|
153 |
+
self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6
|
154 |
+
):
|
155 |
+
super().__init__()
|
156 |
+
|
157 |
+
self.norm1 = AdaLayerNormZero(dim)
|
158 |
+
|
159 |
+
self.norm1_context = AdaLayerNormZero(dim)
|
160 |
+
|
161 |
+
if hasattr(F, "scaled_dot_product_attention"):
|
162 |
+
processor = FluxAttnProcessor2_0()
|
163 |
+
else:
|
164 |
+
raise ValueError(
|
165 |
+
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
|
166 |
+
)
|
167 |
+
self.attn = Attention(
|
168 |
+
query_dim=dim,
|
169 |
+
cross_attention_dim=None,
|
170 |
+
added_kv_proj_dim=dim,
|
171 |
+
dim_head=attention_head_dim,
|
172 |
+
heads=num_attention_heads,
|
173 |
+
out_dim=dim,
|
174 |
+
context_pre_only=False,
|
175 |
+
bias=True,
|
176 |
+
processor=processor,
|
177 |
+
qk_norm=qk_norm,
|
178 |
+
eps=eps,
|
179 |
+
)
|
180 |
+
|
181 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
182 |
+
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
183 |
+
|
184 |
+
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
185 |
+
self.ff_context = FeedForward(
|
186 |
+
dim=dim, dim_out=dim, activation_fn="gelu-approximate"
|
187 |
+
)
|
188 |
+
|
189 |
+
# let chunk size default to None
|
190 |
+
self._chunk_size = None
|
191 |
+
self._chunk_dim = 0
|
192 |
+
|
193 |
+
def forward(
|
194 |
+
self,
|
195 |
+
hidden_states: torch.FloatTensor,
|
196 |
+
encoder_hidden_states: torch.FloatTensor,
|
197 |
+
temb: torch.FloatTensor,
|
198 |
+
image_rotary_emb=None,
|
199 |
+
):
|
200 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
201 |
+
hidden_states, emb=temb
|
202 |
+
)
|
203 |
+
|
204 |
+
(
|
205 |
+
norm_encoder_hidden_states,
|
206 |
+
c_gate_msa,
|
207 |
+
c_shift_mlp,
|
208 |
+
c_scale_mlp,
|
209 |
+
c_gate_mlp,
|
210 |
+
) = self.norm1_context(encoder_hidden_states, emb=temb)
|
211 |
+
|
212 |
+
# Attention.
|
213 |
+
attn_output, context_attn_output = self.attn(
|
214 |
+
hidden_states=norm_hidden_states,
|
215 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
216 |
+
image_rotary_emb=image_rotary_emb,
|
217 |
+
)
|
218 |
+
|
219 |
+
# Process attention outputs for the `hidden_states`.
|
220 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
221 |
+
hidden_states = hidden_states + attn_output
|
222 |
+
|
223 |
+
norm_hidden_states = self.norm2(hidden_states)
|
224 |
+
norm_hidden_states = (
|
225 |
+
norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
226 |
+
)
|
227 |
+
|
228 |
+
ff_output = self.ff(norm_hidden_states)
|
229 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
230 |
+
|
231 |
+
hidden_states = hidden_states + ff_output
|
232 |
+
|
233 |
+
# Process attention outputs for the `encoder_hidden_states`.
|
234 |
+
|
235 |
+
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
236 |
+
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
237 |
+
|
238 |
+
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
239 |
+
norm_encoder_hidden_states = (
|
240 |
+
norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
|
241 |
+
+ c_shift_mlp[:, None]
|
242 |
+
)
|
243 |
+
|
244 |
+
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
245 |
+
encoder_hidden_states = (
|
246 |
+
encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
247 |
+
)
|
248 |
+
if encoder_hidden_states.dtype == torch.float16:
|
249 |
+
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
250 |
+
|
251 |
+
return encoder_hidden_states, hidden_states
|
252 |
+
|
253 |
+
|
254 |
+
class FluxTransformer2DModel(
|
255 |
+
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
|
256 |
+
):
|
257 |
+
"""
|
258 |
+
The Transformer model introduced in Flux.
|
259 |
+
|
260 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
261 |
+
|
262 |
+
Parameters:
|
263 |
+
patch_size (`int`): Patch size to turn the input data into small patches.
|
264 |
+
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
|
265 |
+
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
|
266 |
+
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
|
267 |
+
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
268 |
+
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
|
269 |
+
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
270 |
+
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
|
271 |
+
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
|
272 |
+
"""
|
273 |
+
|
274 |
+
_supports_gradient_checkpointing = True
|
275 |
+
|
276 |
+
@register_to_config
|
277 |
+
def __init__(
|
278 |
+
self,
|
279 |
+
patch_size: int = 1,
|
280 |
+
in_channels: int = 64,
|
281 |
+
num_layers: int = 19,
|
282 |
+
num_single_layers: int = 38,
|
283 |
+
attention_head_dim: int = 128,
|
284 |
+
num_attention_heads: int = 24,
|
285 |
+
joint_attention_dim: int = 4096,
|
286 |
+
pooled_projection_dim: int = 768,
|
287 |
+
guidance_embeds: bool = False,
|
288 |
+
axes_dims_rope: List[int] = [16, 56, 56],
|
289 |
+
):
|
290 |
+
super().__init__()
|
291 |
+
self.out_channels = in_channels
|
292 |
+
self.inner_dim = (
|
293 |
+
self.config.num_attention_heads * self.config.attention_head_dim
|
294 |
+
)
|
295 |
+
|
296 |
+
self.pos_embed = EmbedND(
|
297 |
+
dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope
|
298 |
+
)
|
299 |
+
text_time_guidance_cls = (
|
300 |
+
CombinedTimestepGuidanceTextProjEmbeddings
|
301 |
+
if guidance_embeds
|
302 |
+
else CombinedTimestepTextProjEmbeddings
|
303 |
+
)
|
304 |
+
self.time_text_embed = text_time_guidance_cls(
|
305 |
+
embedding_dim=self.inner_dim,
|
306 |
+
pooled_projection_dim=self.config.pooled_projection_dim,
|
307 |
+
)
|
308 |
+
|
309 |
+
self.context_embedder = nn.Linear(
|
310 |
+
self.config.joint_attention_dim, self.inner_dim
|
311 |
+
)
|
312 |
+
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
|
313 |
+
|
314 |
+
self.transformer_blocks = nn.ModuleList(
|
315 |
+
[
|
316 |
+
FluxTransformerBlock(
|
317 |
+
dim=self.inner_dim,
|
318 |
+
num_attention_heads=self.config.num_attention_heads,
|
319 |
+
attention_head_dim=self.config.attention_head_dim,
|
320 |
+
)
|
321 |
+
for i in range(self.config.num_layers)
|
322 |
+
]
|
323 |
+
)
|
324 |
+
|
325 |
+
self.single_transformer_blocks = nn.ModuleList(
|
326 |
+
[
|
327 |
+
FluxSingleTransformerBlock(
|
328 |
+
dim=self.inner_dim,
|
329 |
+
num_attention_heads=self.config.num_attention_heads,
|
330 |
+
attention_head_dim=self.config.attention_head_dim,
|
331 |
+
)
|
332 |
+
for i in range(self.config.num_single_layers)
|
333 |
+
]
|
334 |
+
)
|
335 |
+
|
336 |
+
self.norm_out = AdaLayerNormContinuous(
|
337 |
+
self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
|
338 |
+
)
|
339 |
+
self.proj_out = nn.Linear(
|
340 |
+
self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
|
341 |
+
)
|
342 |
+
|
343 |
+
self.gradient_checkpointing = False
|
344 |
+
|
345 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
346 |
+
if hasattr(module, "gradient_checkpointing"):
|
347 |
+
module.gradient_checkpointing = value
|
348 |
+
|
349 |
+
def forward(
|
350 |
+
self,
|
351 |
+
hidden_states: torch.Tensor,
|
352 |
+
encoder_hidden_states: torch.Tensor = None,
|
353 |
+
pooled_projections: torch.Tensor = None,
|
354 |
+
timestep: torch.LongTensor = None,
|
355 |
+
img_ids: torch.Tensor = None,
|
356 |
+
txt_ids: torch.Tensor = None,
|
357 |
+
guidance: torch.Tensor = None,
|
358 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
359 |
+
controlnet_block_samples=None,
|
360 |
+
controlnet_single_block_samples=None,
|
361 |
+
return_dict: bool = True,
|
362 |
+
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
363 |
+
"""
|
364 |
+
The [`FluxTransformer2DModel`] forward method.
|
365 |
+
|
366 |
+
Args:
|
367 |
+
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
368 |
+
Input `hidden_states`.
|
369 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
370 |
+
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
371 |
+
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
372 |
+
from the embeddings of input conditions.
|
373 |
+
timestep ( `torch.LongTensor`):
|
374 |
+
Used to indicate denoising step.
|
375 |
+
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
376 |
+
A list of tensors that if specified are added to the residuals of transformer blocks.
|
377 |
+
joint_attention_kwargs (`dict`, *optional*):
|
378 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
379 |
+
`self.processor` in
|
380 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
381 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
382 |
+
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
383 |
+
tuple.
|
384 |
+
|
385 |
+
Returns:
|
386 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
387 |
+
`tuple` where the first element is the sample tensor.
|
388 |
+
"""
|
389 |
+
if joint_attention_kwargs is not None:
|
390 |
+
joint_attention_kwargs = joint_attention_kwargs.copy()
|
391 |
+
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
392 |
+
else:
|
393 |
+
lora_scale = 1.0
|
394 |
+
|
395 |
+
if USE_PEFT_BACKEND:
|
396 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
397 |
+
scale_lora_layers(self, lora_scale)
|
398 |
+
else:
|
399 |
+
if (
|
400 |
+
joint_attention_kwargs is not None
|
401 |
+
and joint_attention_kwargs.get("scale", None) is not None
|
402 |
+
):
|
403 |
+
logger.warning(
|
404 |
+
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
405 |
+
)
|
406 |
+
hidden_states = self.x_embedder(hidden_states)
|
407 |
+
|
408 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
409 |
+
if guidance is not None:
|
410 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
411 |
+
else:
|
412 |
+
guidance = None
|
413 |
+
temb = (
|
414 |
+
self.time_text_embed(timestep, pooled_projections)
|
415 |
+
if guidance is None
|
416 |
+
else self.time_text_embed(timestep, guidance, pooled_projections)
|
417 |
+
)
|
418 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
419 |
+
|
420 |
+
txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
|
421 |
+
ids = torch.cat((txt_ids, img_ids), dim=1)
|
422 |
+
image_rotary_emb = self.pos_embed(ids)
|
423 |
+
|
424 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
425 |
+
if self.training and self.gradient_checkpointing:
|
426 |
+
|
427 |
+
def create_custom_forward(module, return_dict=None):
|
428 |
+
def custom_forward(*inputs):
|
429 |
+
if return_dict is not None:
|
430 |
+
return module(*inputs, return_dict=return_dict)
|
431 |
+
else:
|
432 |
+
return module(*inputs)
|
433 |
+
|
434 |
+
return custom_forward
|
435 |
+
|
436 |
+
ckpt_kwargs: Dict[str, Any] = (
|
437 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
438 |
+
)
|
439 |
+
(
|
440 |
+
encoder_hidden_states,
|
441 |
+
hidden_states,
|
442 |
+
) = torch.utils.checkpoint.checkpoint(
|
443 |
+
create_custom_forward(block),
|
444 |
+
hidden_states,
|
445 |
+
encoder_hidden_states,
|
446 |
+
temb,
|
447 |
+
image_rotary_emb,
|
448 |
+
**ckpt_kwargs,
|
449 |
+
)
|
450 |
+
|
451 |
+
else:
|
452 |
+
encoder_hidden_states, hidden_states = block(
|
453 |
+
hidden_states=hidden_states,
|
454 |
+
encoder_hidden_states=encoder_hidden_states,
|
455 |
+
temb=temb,
|
456 |
+
image_rotary_emb=image_rotary_emb,
|
457 |
+
)
|
458 |
+
|
459 |
+
# controlnet residual
|
460 |
+
if controlnet_block_samples is not None:
|
461 |
+
interval_control = len(self.transformer_blocks) / len(
|
462 |
+
controlnet_block_samples
|
463 |
+
)
|
464 |
+
interval_control = int(np.ceil(interval_control))
|
465 |
+
hidden_states = (
|
466 |
+
hidden_states
|
467 |
+
+ controlnet_block_samples[index_block // interval_control]
|
468 |
+
)
|
469 |
+
|
470 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
471 |
+
|
472 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
473 |
+
if self.training and self.gradient_checkpointing:
|
474 |
+
|
475 |
+
def create_custom_forward(module, return_dict=None):
|
476 |
+
def custom_forward(*inputs):
|
477 |
+
if return_dict is not None:
|
478 |
+
return module(*inputs, return_dict=return_dict)
|
479 |
+
else:
|
480 |
+
return module(*inputs)
|
481 |
+
|
482 |
+
return custom_forward
|
483 |
+
|
484 |
+
ckpt_kwargs: Dict[str, Any] = (
|
485 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
486 |
+
)
|
487 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
488 |
+
create_custom_forward(block),
|
489 |
+
hidden_states,
|
490 |
+
temb,
|
491 |
+
image_rotary_emb,
|
492 |
+
**ckpt_kwargs,
|
493 |
+
)
|
494 |
+
|
495 |
+
else:
|
496 |
+
hidden_states = block(
|
497 |
+
hidden_states=hidden_states,
|
498 |
+
temb=temb,
|
499 |
+
image_rotary_emb=image_rotary_emb,
|
500 |
+
)
|
501 |
+
|
502 |
+
# controlnet residual
|
503 |
+
if controlnet_single_block_samples is not None:
|
504 |
+
interval_control = len(self.single_transformer_blocks) / len(
|
505 |
+
controlnet_single_block_samples
|
506 |
+
)
|
507 |
+
interval_control = int(np.ceil(interval_control))
|
508 |
+
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
509 |
+
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
510 |
+
+ controlnet_single_block_samples[index_block // interval_control]
|
511 |
+
)
|
512 |
+
|
513 |
+
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
514 |
+
|
515 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
516 |
+
output = self.proj_out(hidden_states)
|
517 |
+
|
518 |
+
if USE_PEFT_BACKEND:
|
519 |
+
# remove `lora_scale` from each PEFT layer
|
520 |
+
unscale_lora_layers(self, lora_scale)
|
521 |
+
|
522 |
+
if not return_dict:
|
523 |
+
return (output,)
|
524 |
+
|
525 |
+
return Transformer2DModelOutput(sample=output)
|