metadata
base_model:
- black-forest-labs/FLUX.1-dev
- black-forest-labs/FLUX.1-schnell
language:
- en
license: other
license_name: flux-1-dev-non-commercial-license
license_link: LICENSE.md
tags:
- merge
- flux
Aryanne/flux_swap
This model is a merge of black-forest-labs/FLUX.1-dev and black-forest-labs/FLUX.1-schnell.
Unlike other merge methods, the tensor values themselves are not changed here: values from FLUX.1-schnell are substituted into FLUX.1-dev in a checkerboard pattern, so ~50% of each model is present (if my code is right).
from diffusers import FluxTransformer2DModel
from huggingface_hub import snapshot_download
from accelerate import init_empty_weights
from diffusers.models.model_loading_utils import load_model_dict_into_meta
import safetensors.torch
import glob
import torch
import gc
# Instantiate the FLUX.1-dev transformer architecture from its config inside
# accelerate's init_empty_weights() context; the merged weights are loaded
# into this model once the merge is complete.
with init_empty_weights():
    config = FluxTransformer2DModel.load_config("black-forest-labs/FLUX.1-dev", subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config)

# Fetch only the "transformer/" files of each source checkpoint.
dev_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-dev", allow_patterns="transformer/*")
schnell_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-schnell", allow_patterns="transformer/*")

# Sorted so that the i-th dev shard is paired with the i-th schnell shard when
# the two lists are walked by index.
dev_shards = sorted(glob.glob(f"{dev_ckpt}/transformer/*.safetensors"))
schnell_shards = sorted(glob.glob(f"{schnell_ckpt}/transformer/*.safetensors"))
def _checkerboard_mask(shape, n, device):
    """Return a bool mask of ``shape`` that is True where the index sum is a multiple of ``n``.

    For a 1-D shape this selects every ``n``-th element; for a 2-D shape it
    selects the cells with ``(row + col) % n == 0``. Higher ranks follow the
    same rule over all axes, and a 0-d shape yields a single True.
    """
    index_sum = torch.zeros((), dtype=torch.long, device=device)
    for axis, size in enumerate(shape):
        # Reshape each axis index so broadcasting adds it along its own axis.
        view = [1] * len(shape)
        view[axis] = size
        index_sum = index_sum + torch.arange(size, device=device).view(view)
    return index_sum % n == 0


def _random_mask(shape, percent, seed, device):
    """Return a bool mask of ``shape`` that is True with probability ``percent``.

    A private generator is used so the global torch RNG is left untouched.
    ``seed=None`` draws a fresh non-deterministic seed.
    """
    generator = torch.Generator()
    if seed is None:
        generator.seed()
    else:
        generator.manual_seed(seed)
    return (torch.rand(shape, generator=generator) <= percent).to(device)


def swapping_method(base, x, parameters):
    """Build a tensor by picking each element from either ``base`` or ``x``.

    No values are blended: every output element is copied verbatim from one
    of the two inputs according to a boolean mask.

    Args:
        base: Tensor supplying the elements where the mask is False.
        x: Tensor supplying the elements where the mask is True. Expected to
            have the same shape as ``base``.
        parameters: Mapping with the keys
            ``diagonal_offset`` - integer >= 2; modulus of the checkerboard
                mask (validated even when the random mask is used).
            ``random_mask`` - ``0.0``/``False`` selects the checkerboard mask;
                a value strictly between 0 and 1 selects a random mask in
                which roughly that fraction of elements comes from ``x``.
            ``random_mask_seed`` - optional seed for the random mask
                (converted with ``int()``); ``None`` means a random seed.
            ``invert_offset`` - when it compares equal to ``False`` the mask
                picks from ``x``; any other value, including a missing key,
                swaps the roles of ``base`` and ``x`` in checkerboard mode.

    Returns:
        The mixed tensor. CPU inputs are cast to ``torch.bfloat16`` first.

    Raises:
        AssertionError: If ``diagonal_offset`` or ``random_mask`` is invalid.
    """
    if x.device.type == "cpu":
        x = x.to(torch.bfloat16)
        base = base.to(torch.bfloat16)

    diagonal_offset = parameters.get('diagonal_offset')
    random_mask = parameters.get('random_mask')
    random_mask_seed = parameters.get('random_mask_seed')
    if random_mask_seed is not None:
        # int() either yields a valid integer seed or raises, so no further
        # validation of the seed is needed.
        random_mask_seed = int(random_mask_seed)

    assert (diagonal_offset is not None) and (diagonal_offset % 1 == 0) and (diagonal_offset >= 2), "The diagonal_offset must be an integer greater than or equal to 2."

    if random_mask != 0.0:
        assert (random_mask is not None) and (random_mask < 1.0) and (random_mask > 0.0) , "The random_mask parameter can't be empty, 0, 1, or None, it must be a number between 0 and 1."
        mask = _random_mask(base.shape, random_mask, random_mask_seed, base.device)
        return torch.where(mask, x, base)

    mask = _checkerboard_mask(x.shape, diagonal_offset, base.device)
    # NOTE: the comparison is deliberately `== False` (not `not ...`) so that
    # a missing key keeps taking the inverted branch, as it always has.
    if parameters.get('invert_offset') == False:  # noqa: E712
        return torch.where(mask, x, base)
    return torch.where(mask, base, x)
# Settings handed to swapping_method for every tensor that gets merged.
parameters = {
    # Modulus of the checkerboard mask: with 2, the positions whose index sum
    # is even are taken from the second tensor passed to swapping_method.
    'diagonal_offset': 2,
    # False compares equal to 0.0, which selects the checkerboard mask instead
    # of the random mask.
    'random_mask': False,
    # False keeps the roles as passed; any other value swaps which tensor the
    # masked positions are taken from.
    'invert_offset': False,
    # Optional seed for the random mask (converted with int()); it only has an
    # effect when random_mask is a value between 0 and 1.
    # 'random_mask_seed': "899557"
}
# Merge the two checkpoints shard by shard, then load the result into the
# empty model and save it.
merged_state_dict = {}
guidance_state_dict = {}

# Shards are paired by position, so both checkpoints must be split into the
# same number of files; fail early instead of silently dropping shards.
if len(dev_shards) != len(schnell_shards):
    raise ValueError(
        f"Shard count mismatch: {len(dev_shards)} dev shards vs {len(schnell_shards)} schnell shards."
    )

for dev_shard, schnell_shard in zip(dev_shards, schnell_shards):
    state_dict_dev_temp = safetensors.torch.load_file(dev_shard)
    state_dict_schnell_temp = safetensors.torch.load_file(schnell_shard)

    # Iterate over a snapshot of the keys because entries are popped as they
    # are consumed, which also lets the residue checks below work.
    for k in list(state_dict_dev_temp):
        if "guidance" not in k:
            merged_state_dict[k] = swapping_method(
                state_dict_dev_temp.pop(k), state_dict_schnell_temp.pop(k), parameters
            )
        else:
            # Guidance tensors are copied from dev unchanged; no schnell
            # counterpart is read for them.
            guidance_state_dict[k] = state_dict_dev_temp.pop(k)

    # Every tensor of both shards must have been consumed.
    if len(state_dict_dev_temp) > 0:
        raise ValueError(f"There should not be any residue but got: {list(state_dict_dev_temp.keys())}.")
    if len(state_dict_schnell_temp) > 0:
        # Report the schnell leftovers (the dev dict is empty at this point).
        raise ValueError(f"There should not be any residue but got: {list(state_dict_schnell_temp.keys())}.")

merged_state_dict.update(guidance_state_dict)
load_model_dict_into_meta(model, merged_state_dict)
model.to(torch.bfloat16).save_pretrained("merged-flux")
Part of this code was taken from mergekit.
Thanks to SayakPaul for the code that helped me do this merge.