Flux Redux

In Redux, a SigLIP vision encoder "reads" the reference image; its patch features are projected into the T5 embedding space and appended to the prompt embeddings that condition the Flux transformer.
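
A minimal shape sketch of that token flow, assuming the google/siglip-so400m-patch14-384 checkpoint used below (729 patch tokens of width 1152) and Flux's 512-token T5 context:

import torch

# Stand-in tensors; only the shapes matter here.
siglip_features = torch.randn(1, 729, 1152)  # SigLIP last_hidden_state
redux_tokens = torch.randn(1, 729, 4096)     # ReduxImageEncoder(siglip_features)
t5_embeds = torch.randn(1, 512, 4096)        # T5 prompt embeddings (or zeros)

# Redux appends the image tokens to the text tokens along the sequence axis.
conditioning = torch.cat((t5_embeds, redux_tokens), dim=-2)
print(conditioning.shape)  # torch.Size([1, 1241, 4096])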

Image variations

The example disables the T5 text encoder so it fits on more hardware. Enable it if you need the prompt to matter.

With T5 disabled, the prompt embeddings are all zeros, so the pipeline neglects the prompt and generates a close variation of the reference image.

from diffusers import FluxTransformer2DModel, FluxPipeline, StableDiffusion3Pipeline
from PIL import Image
from safetensors.torch import load_file
import torch
from torch import nn
from transformers import SiglipImageProcessor, SiglipVisionModel
from typing import Any, Dict, List, Optional, Union

class Styler:
    def set_style(self, image: Image.Image, siglip_path, image_encoder):
        image = image.convert('RGB')
        # Load SigLIP lazily so repeated calls reuse the same encoder.
        if not hasattr(self, 'siglip'):
            self.siglip = SiglipVisionModel.from_pretrained(siglip_path)
            self.siglip_processor = SiglipImageProcessor.from_pretrained(siglip_path)
        with torch.no_grad():
            normalized_cond = self.siglip_processor.preprocess(images=[image],
                                                               do_convert_rgb=True,
                                                               do_resize=True,
                                                               return_tensors='pt')
            self.siglip.to('cuda')
            encoded_cond = self.siglip(**normalized_cond.to('cuda')).last_hidden_state
            self.siglip.to('cpu')
            # Project the SigLIP features into the T5 embedding space and cast
            # to the pipeline dtype (bfloat16 for Flux, float16 for SD3).
            self.image_cond = image_encoder(encoded_cond).to(self.dtype)

class ReduxImageEncoder(nn.Module):
    # Code from black-forest-labs/flux/src/flux/modules/image_embedders.py
    def __init__(self, redux_dim: int = 1152, txt_in_features: int = 4096):
        super().__init__()
        self.redux_up = nn.Linear(redux_dim, txt_in_features * 3)
        self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features)

    def forward(self, x: torch.Tensor):
        return self.redux_down(nn.functional.silu(self.redux_up(x)))

class ReduxFluxPipeline(FluxPipeline, Styler):
    def _get_t5_prompt_embeds(self,
                              prompt: Union[str, List[str]] = None,
                              num_images_per_prompt: int = 1,
                              max_sequence_length: int = 512,
                              device: Optional[torch.device] = None,
                              dtype: Optional[torch.dtype] = None):
        if self.text_encoder_2 is None:
            # Without T5, substitute zeros for the prompt embeddings so the
            # Redux tokens carry all of the conditioning.
            prompt_embeds = torch.zeros((1, max_sequence_length, 4096),
                                        device=device, dtype=dtype or self.dtype)
        else:
            prompt_embeds = super()._get_t5_prompt_embeds(prompt,
                                                          num_images_per_prompt,
                                                          max_sequence_length,
                                                          device,
                                                          dtype)

        # Append the Redux image tokens after the (possibly zeroed) prompt tokens.
        return torch.cat((prompt_embeds, self.image_cond), dim=-2)

if __name__ == '__main__':
    pipe = ReduxFluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev',
                                             text_encoder_2=None,  # Remove this line to re-enable T5.
                                             torch_dtype=torch.bfloat16)
    pipe.to('cuda')
    image_encoder = ReduxImageEncoder()
    image_encoder.requires_grad_(False)
    image_encoder.load_state_dict(load_file('/path/to/flux1-redux-dev.safetensors'))
    image_encoder.to('cuda')
    if isinstance(pipe, Styler):
        pipe.set_style(Image.open('./reference_image.png'),
                       'google/siglip-so400m-patch14-384',
                       image_encoder)
    image = pipe('portrait', num_inference_steps=20).images[0]
    image.save('preview.png')
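
To let the prompt steer the output again, load the pipeline with its T5 encoder and keep everything else the same. A sketch:

# Keeping text_encoder_2 makes the prompt embeddings real instead of zeros;
# the Redux image tokens are still appended after them.
pipe = ReduxFluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev',
                                         torch_dtype=torch.bfloat16)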

Constrained style

Apply the style only during the first inference steps: once the configured step is reached, the Redux tokens are truncated from the conditioning and the remaining steps refine the image without the reference.

class ReduxTransformer2DModel(FluxTransformer2DModel):
    def forward(self,
                hidden_states: torch.Tensor,
                encoder_hidden_states: torch.Tensor = None,
                pooled_projections: torch.Tensor = None,
                timestep: torch.LongTensor = None,
                img_ids: torch.Tensor = None,
                txt_ids: torch.Tensor = None,
                guidance: torch.Tensor = None,
                joint_attention_kwargs: Optional[Dict[str, Any]] = None,
                controlnet_block_samples=None,
                controlnet_single_block_samples=None,
                return_dict: bool = True,
                controlnet_blocks_repeat: bool = False):
        # After `_end_at_step`, drop the Redux image tokens and keep only the
        # first 512 (T5) conditioning tokens; the current denoising step is
        # read from the module-level `pipe`'s scheduler.
        if (hasattr(self, '_end_at_step') and
            pipe.scheduler._step_index is not None and
            pipe.scheduler._step_index >= self._end_at_step):
            encoder_hidden_states = encoder_hidden_states[:, :512, :]
            txt_ids = txt_ids[:512, :]
        noise_pred = super().forward(hidden_states,
                                     encoder_hidden_states,
                                     pooled_projections,
                                     timestep,
                                     img_ids,
                                     txt_ids,
                                     guidance,
                                     joint_attention_kwargs,
                                     controlnet_block_samples,
                                     controlnet_single_block_samples,
                                     return_dict,
                                     controlnet_blocks_repeat)

        return noise_pred

    def end_at_step(self, step):
        # Apply the Redux conditioning only to denoising steps before `step`.
        self._end_at_step = step

if __name__ == '__main__':
    transformer = ReduxTransformer2DModel.from_pretrained('black-forest-labs/FLUX.1-dev',
                                                          subfolder='transformer',
                                                          torch_dtype=torch.bfloat16)
    pipe = ReduxFluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev',
                                             transformer=transformer,
                                             text_encoder_2=None,  # Remove this line to re-enable T5.
                                             torch_dtype=torch.bfloat16)
    pipe.to('cuda')
    image_encoder = ReduxImageEncoder()
    image_encoder.requires_grad_(False)
    image_encoder.load_state_dict(load_file('/path/to/flux1-redux-dev.safetensors'))
    image_encoder.to('cuda')
    if isinstance(pipe, Styler):
        pipe.set_style(Image.open('./reference_image.png'),
                       'google/siglip-so400m-patch14-384',
                       image_encoder)
    if isinstance(transformer, ReduxTransformer2DModel):
        # Style only the first 4 of the 20 inference steps.
        transformer.end_at_step(4)
    image = pipe('portrait', num_inference_steps=20).images[0]
    image.save('preview.png')

SD3.5M Redux

This will not produce the expected result: the shapes line up, but the Redux projector was trained against FLUX.1, so SD3.5M does not interpret the appended tokens as intended.

class ReduxDiffusion3Pipeline(StableDiffusion3Pipeline, Styler):
    def _get_t5_prompt_embeds(self,
                              prompt: Union[str, List[str]] = None,
                              num_images_per_prompt: int = 1,
                              max_sequence_length: int = 256,
                              device: Optional[torch.device] = None,
                              dtype: Optional[torch.dtype] = None):
        if self.text_encoder_3 is None:  # SD3's T5 encoder is text_encoder_3.
            # Without T5, substitute zeros for the prompt embeddings.
            prompt_embeds = torch.zeros((1, max_sequence_length, 4096),
                                        device=device, dtype=dtype or self.dtype)
        else:
            prompt_embeds = super()._get_t5_prompt_embeds(prompt,
                                                          num_images_per_prompt,
                                                          max_sequence_length,
                                                          device,
                                                          dtype)

        # Append the Redux image tokens after the (possibly zeroed) T5 tokens.
        return torch.cat((prompt_embeds, self.image_cond), dim=-2)

if __name__ == '__main__':
    pipe = ReduxDiffusion3Pipeline.from_pretrained('stabilityai/stable-diffusion-3.5-medium',
                                                   text_encoder_3=None,
                                                   tokenizer_3=None,
                                                   torch_dtype=torch.float16)
    pipe.to('cuda')
    image_encoder = ReduxImageEncoder()
    image_encoder.requires_grad_(False)
    image_encoder.load_state_dict(load_file('/path/to/flux1-redux-dev.safetensors'))
    image_encoder.to('cuda')
    if isinstance(pipe, Styler):
        pipe.set_style(Image.open('./reference_image.png'),
                       'google/siglip-so400m-patch14-384',
                       image_encoder)
    image = pipe('portrait', num_inference_steps=20).images[0]
    image.save('preview.png')
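
Dimensionally the concatenation is still valid, since SD3.5's joint attention width matches Redux's output width. A quick sanity check (a sketch using diffusers' config attributes, run after the script above):

# Both widths must match for torch.cat along the sequence axis to work:
# Redux projects to 4096 features; SD3.5's joint attention width is 4096.
assert image_encoder.redux_down.out_features == 4096
assert pipe.transformer.config.joint_attention_dim == 4096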

Disclaimer

Use of this code requires citation and attribution to the author via a link to their Hugging Face profile in all resulting work.
