Flux Redux
In Redux, SigLIP "reads" the image features, which are then passed to the Flux model as T5 embeddings.
Image variations
The example disables the T5 model for wider use. Enable it if needed (remove the `text_encoder_2=None` line).
Otherwise the pipeline will generate nearly the same image as the reference, neglecting the prompt.
from typing import Any, Dict
from typing import List, Optional, Union

import torch
from diffusers import FluxTransformer2DModel, FluxPipeline, StableDiffusion3Pipeline
from PIL import Image
from safetensors.torch import load_file
from torch import nn
from transformers import SiglipImageProcessor, SiglipVisionModel
class Styler:
    """Mixin that turns a reference image into Redux style tokens.

    ``set_style`` encodes the image with SigLIP, projects the hidden states
    with a Redux image encoder and stores the result on ``self.image_cond``.
    """

    def set_style(self,
                  image: Image.Image,
                  siglip_path,
                  encoder: Optional[nn.Module] = None,
                  device: Union[str, torch.device] = 'cuda'):
        """Encode ``image`` and store the style tokens on ``self.image_cond``.

        Args:
            image: Reference image; it is converted to RGB before encoding.
            siglip_path: Path or hub id given to ``from_pretrained`` for both
                the SigLIP vision model and its image processor.
            encoder: Module that projects the SigLIP hidden states. When
                omitted, the module-level ``image_encoder`` is used, which is
                how the example scripts in this file provide it.
            device: Device SigLIP runs on for the forward pass. ``encoder``
                is called with tensors on this device, so it has to live
                there as well.

        Raises:
            NameError: If ``encoder`` is omitted and no module-level
                ``image_encoder`` exists.
        """
        if encoder is None:
            try:
                # Backward compatibility: the original code always read this
                # module-level name.
                encoder = image_encoder
            except NameError as err:
                raise NameError(
                    "set_style() needs an image encoder: pass `encoder=` or "
                    "define a module-level `image_encoder`.") from err

        image = image.convert('RGB')

        # Load SigLIP lazily and reuse it across calls; reload only when a
        # different checkpoint than the one recorded is requested.
        loaded_from = getattr(self, '_siglip_path', siglip_path)
        if not hasattr(self, 'siglip') or loaded_from != siglip_path:
            self.siglip = SiglipVisionModel.from_pretrained(siglip_path)
            self.siglip_processor = SiglipImageProcessor.from_pretrained(siglip_path)
            self._siglip_path = siglip_path

        with torch.no_grad():
            normalized_cond = self.siglip_processor.preprocess(images=[image],
                                                               do_convert_rgb=True,
                                                               do_resize=True,
                                                               return_tensors='pt')
            try:
                self.siglip.to(device)
                encoded_cond = self.siglip(**normalized_cond.to(device)).last_hidden_state
            finally:
                # Always move SigLIP back to the CPU, even if the forward
                # pass fails, so it does not stay on the accelerator.
                self.siglip.to('cpu')
            # SD3 pipeline is in float16.
            self.image_cond = encoder(encoded_cond).to(torch.float16)
class ReduxImageEncoder(nn.Module):
    """Two-layer projection from image features to text-embedding width.

    Code from black-forest-labs/flux/src/flux/modules/image_embedders.py

    The input is widened to three times ``txt_in_features``, passed through
    SiLU and projected down to ``txt_in_features``.
    """

    def __init__(self, redux_dim: int = 1152, txt_in_features: int = 4096):
        """Create the up and down projections.

        Args:
            redux_dim: Size of the last dimension of the input features.
            txt_in_features: Size of the last dimension of the output.
        """
        super().__init__()
        hidden_features = 3 * txt_in_features
        self.redux_up = nn.Linear(redux_dim, hidden_features)
        self.redux_down = nn.Linear(hidden_features, txt_in_features)

    def forward(self, x: torch.Tensor):
        """Return ``redux_down(silu(redux_up(x)))``."""
        widened = self.redux_up(x)
        activated = nn.functional.silu(widened)
        return self.redux_down(activated)
class ReduxFluxPipeline(FluxPipeline, Styler):
    """FLUX pipeline whose T5 prompt embeddings are extended with style tokens.

    The style tokens are read from ``self.image_cond``, which
    ``Styler.set_style`` assigns; call it before running the pipeline.
    """

    def _get_t5_prompt_embeds(self,
                              prompt: Union[str, List[str]] = None,
                              num_images_per_prompt: int = 1,
                              max_sequence_length: int = 512,
                              device: Optional[torch.device] = None,
                              dtype: Optional[torch.dtype] = None):
        """Return the T5 embeddings with ``self.image_cond`` appended.

        When ``text_encoder_2`` is ``None`` an all-zero placeholder of length
        ``max_sequence_length`` stands in for the T5 output; otherwise the
        parent implementation encodes ``prompt``.

        Args:
            prompt: Prompt or list of prompts.
            num_images_per_prompt: Images generated per prompt; used to size
                the placeholder batch.
            max_sequence_length: Number of text tokens.
            device: Device for the embeddings.
            dtype: Dtype for the embeddings; the placeholder falls back to
                the transformer's dtype when this is ``None``.

        Returns:
            The text embeddings followed by the style tokens, concatenated
            along the sequence dimension (``dim=-2``).

        Raises:
            AttributeError: If ``self.image_cond`` has not been set.
        """
        if self.text_encoder_2 is None:
            if prompt is None or isinstance(prompt, str):
                batch_size = 1
            else:
                batch_size = len(prompt)
            # Size and type the placeholder like real embeddings so that the
            # concatenation below neither fails on batch size nor silently
            # promotes the dtype.
            prompt_embeds = torch.zeros(
                (batch_size * num_images_per_prompt, max_sequence_length, 4096),
                device=device,
                dtype=self.transformer.dtype if dtype is None else dtype)
        else:
            prompt_embeds = super()._get_t5_prompt_embeds(
                prompt=prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype)

        # Embed Redux.
        # image_cond is stored in float16; align it with the text embeddings
        # because torch.cat would otherwise promote mixed dtypes.
        image_cond = self.image_cond.to(device=prompt_embeds.device,
                                        dtype=prompt_embeds.dtype)
        if image_cond.shape[0] != prompt_embeds.shape[0]:
            # Share the single style across every prompt / image in the batch.
            image_cond = image_cond.expand(prompt_embeds.shape[0], -1, -1)
        return torch.cat((prompt_embeds, image_cond), dim=-2)
if __name__ == '__main__':
    # Image-variation example: FLUX.1-dev with the Redux hook.
    # T5 (text_encoder_2) is disabled here; delete that argument to let the
    # prompt influence the result.
    pipe = ReduxFluxPipeline.from_pretrained(
        'black-forest-labs/FLUX.1-dev',
        torch_dtype=torch.bfloat16,
        text_encoder_2=None,  # Remove this line.
    )
    pipe.to('cuda')

    # Keep the name `image_encoder`: Styler.set_style looks it up at module
    # level.
    redux_weights = load_file('/path/to/flux1-redux-dev.safetensors')
    image_encoder = ReduxImageEncoder()
    image_encoder.requires_grad_(False)
    image_encoder.load_state_dict(redux_weights)
    image_encoder.to('cuda')

    if isinstance(pipe, Styler):
        reference = Image.open('./reference_image.png')
        pipe.set_style(reference, 'google/siglip-so400m-patch14-384')

    result = pipe('portrait', num_inference_steps=20)
    image = result.images[0]
    image.save('preview.png')
Constrained style
Apply the style only during the first steps of the inference.
class ReduxTransformer2DModel(FluxTransformer2DModel):
    """FLUX transformer that drops the appended style tokens after a step.

    Until ``end_at_step`` is called the model behaves exactly like its parent.
    """

    def forward(self,
                hidden_states: torch.Tensor,
                encoder_hidden_states: torch.Tensor = None,
                pooled_projections: torch.Tensor = None,
                timestep: torch.LongTensor = None,
                img_ids: torch.Tensor = None,
                txt_ids: torch.Tensor = None,
                guidance: torch.Tensor = None,
                joint_attention_kwargs: Optional[Dict[str, Any]] = None,
                controlnet_block_samples=None,
                controlnet_single_block_samples=None,
                return_dict: bool = True,
                controlnet_blocks_repeat: bool = False):
        """Run the parent forward pass, optionally without the style tokens.

        Once the scheduler's step index reaches the configured end step,
        ``encoder_hidden_states`` and ``txt_ids`` are truncated to the text
        tokens, i.e. everything past the text length is removed. All
        arguments are otherwise forwarded to the parent unchanged.

        Returns:
            Whatever the parent ``forward`` returns.
        """
        end_step = getattr(self, '_end_at_step', None)
        if end_step is not None:
            # NOTE(review): the step index is read from a module-level `pipe`
            # (the pipeline built by the example script); this model must be
            # used from a module that defines it.
            step_index = pipe.scheduler._step_index
            if step_index is not None and step_index >= end_step:
                text_len = getattr(self, '_text_seq_len', 512)
                encoder_hidden_states = encoder_hidden_states[:, :text_len, :]
                # The ellipsis keeps the slice on the sequence axis whether
                # txt_ids is (seq, 3) or carries a leading batch axis.
                txt_ids = txt_ids[..., :text_len, :]
        # Forward by keyword so a change in the parent's parameter order
        # cannot silently mis-bind arguments.
        noise_pred = super().forward(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            pooled_projections=pooled_projections,
            timestep=timestep,
            img_ids=img_ids,
            txt_ids=txt_ids,
            guidance=guidance,
            joint_attention_kwargs=joint_attention_kwargs,
            controlnet_block_samples=controlnet_block_samples,
            controlnet_single_block_samples=controlnet_single_block_samples,
            return_dict=return_dict,
            controlnet_blocks_repeat=controlnet_blocks_repeat)
        return noise_pred

    def end_at_step(self, step, text_seq_len: int = 512):
        """Stop applying the style from scheduler step ``step`` onwards.

        Args:
            step: First step index at which the style tokens are dropped;
                ``None`` disables the truncation again.
            text_seq_len: Number of leading text tokens to keep. It must
                equal the ``max_sequence_length`` the pipeline was run with.
        """
        self._end_at_step = step
        self._text_seq_len = text_seq_len
if __name__ == '__main__':
    # Constrained-style example: the style is applied only to the first
    # inference steps.
    # Load the custom transformer in the same dtype as the pipeline. A
    # component passed to the pipeline's from_pretrained is used as loaded,
    # so without torch_dtype it would stay in float32.
    transformer = ReduxTransformer2DModel.from_pretrained('black-forest-labs/FLUX.1-dev',
                                                          subfolder='transformer',
                                                          torch_dtype=torch.bfloat16)
    # Keep the name `pipe`: ReduxTransformer2DModel.forward reads
    # `pipe.scheduler` at module level.
    pipe = ReduxFluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev',
                                             transformer=transformer,
                                             text_encoder_2=None,  # Remove this line.
                                             torch_dtype=torch.bfloat16)
    pipe.to('cuda')

    # Keep the name `image_encoder`: Styler.set_style looks it up at module
    # level.
    image_encoder = ReduxImageEncoder()
    image_encoder.requires_grad_(False)
    image_encoder.load_state_dict(load_file('/path/to/flux1-redux-dev.safetensors'))
    image_encoder.to('cuda')

    if isinstance(pipe, Styler):
        pipe.set_style(Image.open('./reference_image.png'),
                       'google/siglip-so400m-patch14-384')
    if isinstance(transformer, ReduxTransformer2DModel):
        # Drop the style tokens from scheduler step 4 onwards.
        transformer.end_at_step(4)

    image = pipe('portrait', num_inference_steps=20).images[0]
    image.save('preview.png')
SD3.5M Redux
This will not produce the expected result.
class ReduxDiffusion3Pipeline(StableDiffusion3Pipeline, Styler):
    """SD3 pipeline whose T5 prompt embeddings are extended with style tokens.

    The style tokens are read from ``self.image_cond``, which
    ``Styler.set_style`` assigns; call it before running the pipeline.
    """

    def _get_t5_prompt_embeds(self,
                              prompt: Union[str, List[str]] = None,
                              num_images_per_prompt: int = 1,
                              max_sequence_length: int = 256,
                              device: Optional[torch.device] = None,
                              dtype: Optional[torch.dtype] = None):
        """Return the T5 embeddings with ``self.image_cond`` appended.

        In Stable Diffusion 3 the T5 encoder is ``text_encoder_3``
        (``text_encoder_2`` is a CLIP encoder), so that is the component
        checked here. When it is ``None`` an all-zero placeholder of length
        ``max_sequence_length`` stands in for the T5 output; otherwise the
        parent implementation encodes ``prompt``.

        Args:
            prompt: Prompt or list of prompts.
            num_images_per_prompt: Images generated per prompt; used to size
                the placeholder batch.
            max_sequence_length: Number of text tokens.
            device: Device for the embeddings.
            dtype: Dtype for the embeddings; the placeholder falls back to
                the transformer's dtype when this is ``None``.

        Returns:
            The text embeddings followed by the style tokens, concatenated
            along the sequence dimension (``dim=-2``).

        Raises:
            AttributeError: If ``self.image_cond`` has not been set.
        """
        if self.text_encoder_3 is None:
            if prompt is None or isinstance(prompt, str):
                batch_size = 1
            else:
                batch_size = len(prompt)
            # Size and type the placeholder like real embeddings so that the
            # concatenation below neither fails on batch size nor silently
            # promotes the dtype.
            prompt_embeds = torch.zeros(
                (batch_size * num_images_per_prompt, max_sequence_length, 4096),
                device=device,
                dtype=self.transformer.dtype if dtype is None else dtype)
        else:
            prompt_embeds = super()._get_t5_prompt_embeds(
                prompt=prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype)

        # Embed Redux.
        # Align device and dtype with the text embeddings because torch.cat
        # would otherwise promote mixed dtypes.
        image_cond = self.image_cond.to(device=prompt_embeds.device,
                                        dtype=prompt_embeds.dtype)
        if image_cond.shape[0] != prompt_embeds.shape[0]:
            # Share the single style across every prompt / image in the batch.
            image_cond = image_cond.expand(prompt_embeds.shape[0], -1, -1)
        return torch.cat((prompt_embeds, image_cond), dim=-2)
if __name__ == '__main__':
    # SD3.5 Medium example: the pipeline is loaded in float16 without the
    # third text encoder and its tokenizer.
    pipe = ReduxDiffusion3Pipeline.from_pretrained(
        'stabilityai/stable-diffusion-3.5-medium',
        torch_dtype=torch.float16,
        text_encoder_3=None,
        tokenizer_3=None,
    )
    pipe.to('cuda')

    # Keep the name `image_encoder`: Styler.set_style looks it up at module
    # level.
    redux_weights = load_file('/path/to/flux1-redux-dev.safetensors')
    image_encoder = ReduxImageEncoder()
    image_encoder.requires_grad_(False)
    image_encoder.load_state_dict(redux_weights)
    image_encoder.to('cuda')

    if isinstance(pipe, Styler):
        reference = Image.open('./reference_image.png')
        pipe.set_style(reference, 'google/siglip-so400m-patch14-384')

    result = pipe('portrait', num_inference_steps=20)
    image = result.images[0]
    image.save('preview.png')
Disclaimer
Use of this code requires citation and attribution to the author via a link to their Hugging Face profile in all resulting work.