---
base_model:
- black-forest-labs/FLUX.1-dev
- black-forest-labs/FLUX.1-schnell
language:
- en
tags:
- merge
- flux
---

# Aryanne/flux_swap

This model is a merge of [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) and [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell).

Unlike other merge methods, the values in the tensors are not changed here. Instead, the values of FLUX.1-dev are substituted, in a checkerboard pattern, with the values of FLUX.1-schnell, so ~50% of each model is present here (if my code is right).

```python
from diffusers import FluxTransformer2DModel
from huggingface_hub import snapshot_download
from accelerate import init_empty_weights
from diffusers.models.model_loading_utils import load_model_dict_into_meta
import safetensors.torch
import glob
import torch
import gc

# Build the transformer on the meta device; the real weights are loaded
# from the merged state dict at the end of the script.
with init_empty_weights():
    config = FluxTransformer2DModel.load_config("black-forest-labs/FLUX.1-dev", subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config)

dev_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-dev", allow_patterns="transformer/*")
schnell_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-schnell", allow_patterns="transformer/*")

# Sorted so that shard i of one checkpoint is paired with shard i of the other.
dev_shards = sorted(glob.glob(f"{dev_ckpt}/transformer/*.safetensors"))
schnell_shards = sorted(glob.glob(f"{schnell_ckpt}/transformer/*.safetensors"))


def swapping_method(base, x, parameters):
    """Merge two tensors by picking each element from one of them.

    No value is averaged or interpolated: every element of the result is
    copied unchanged from either ``base`` or ``x``, selected by a boolean mask.

    Args:
        base: Tensor supplying the elements where the mask is False.
        x: Tensor supplying the elements where the mask is True. Must have
            the same shape as ``base``.
        parameters: Dict of options:
            ``diagonal_offset`` (int >= 2, required): period ``n`` of the
                regular mask. For 2-D tensors element ``(r, c)`` is selected
                when ``(r + c) % n == 0`` (a checkerboard for ``n == 2``);
                for other tensors index ``i`` of the first dimension is
                selected when ``i % n == 0``.
            ``random_mask``: ``0``/``False`` to use the regular mask,
                otherwise a number strictly between 0 and 1 giving the
                approximate fraction of elements taken from ``x``.
            ``random_mask_seed`` (optional): seed for the random mask;
                ``None`` means a non-deterministic seed.
            ``invert_offset``: only an explicit ``False`` keeps the roles
                described above; any other value (including a missing key)
                swaps the roles of ``base`` and ``x`` in the regular mask.

    Returns:
        The merged tensor. When ``x`` lives on the CPU both inputs are cast
        to ``torch.bfloat16`` first, so the result is bfloat16 as well.

    Raises:
        AssertionError: If a parameter is missing or out of range.
    """

    def swap_values(shape, n, base, x):
        # The mask is built directly on the tensors' device (including the
        # device index) so torch.where never mixes devices.
        if x.dim() == 2:
            rows, cols = shape
            rows_range = torch.arange(rows, device=base.device).view(-1, 1)
            cols_range = torch.arange(cols, device=base.device).view(1, -1)
            mask = (rows_range + cols_range) % n == 0
        else:
            rows_range = torch.arange(shape[0], device=base.device)
            mask = rows_range % n == 0
        return torch.where(mask, x, base)

    def rand_mask(base, x, percent, seed=None):
        # A private generator keeps the global RNG state untouched.
        # (torch.seed() does not merely read the seed, it re-seeds the
        # global RNG, so it cannot be used to save and restore state.)
        generator = torch.Generator()
        if seed is None:
            generator.seed()  # non-deterministic seed
        else:
            generator.manual_seed(seed)
        random = torch.rand(base.shape, generator=generator)
        mask = (random <= percent).to(base.device)
        return torch.where(mask, x, base)

    if x.device.type == "cpu":
        x = x.to(torch.bfloat16)
        base = base.to(torch.bfloat16)

    diagonal_offset = parameters.get('diagonal_offset')
    random_mask = parameters.get('random_mask')
    random_mask_seed = parameters.get('random_mask_seed')
    random_mask_seed = int(random_mask_seed) if random_mask_seed is not None else random_mask_seed

    assert (diagonal_offset is not None) and (diagonal_offset % 1 == 0) and (diagonal_offset >= 2), "The diagonal_offset must be an integer greater than or equal to 2."

    # Note: False == 0.0 in Python, so 'random_mask': False selects the
    # regular (checkerboard) mask below.
    if random_mask != 0.0:
        assert (random_mask is not None) and (random_mask < 1.0) and (random_mask > 0.0) , "The random_mask parameter can't be empty, 0, 1, or None, it must be a number between 0 and 1."
        assert random_mask_seed is None or (isinstance(random_mask_seed, int) and random_mask_seed % 1 == 0), "The random_mask_seed parameter must be None or an integer, None is a random seed."
        x = rand_mask(base, x, random_mask, random_mask_seed)
    else:
        if parameters.get('invert_offset') == False:
            x = swap_values(x.shape, diagonal_offset, base, x)
        else:
            x = swap_values(x.shape, diagonal_offset, x, base)
    return x


parameters = {
    'diagonal_offset': 2,
    'random_mask': False,
    'invert_offset': False,
    # 'random_mask_seed': "899557"
}

merged_state_dict = {}
guidance_state_dict = {}

# Shards are paired by position, so both checkpoints must be split into the
# same number of files.
if len(dev_shards) != len(schnell_shards):
    raise ValueError(
        f"Shard count mismatch: {len(dev_shards)} FLUX.1-dev shards vs "
        f"{len(schnell_shards)} FLUX.1-schnell shards."
    )

for dev_shard, schnell_shard in zip(dev_shards, schnell_shards):
    state_dict_dev_temp = safetensors.torch.load_file(dev_shard)
    state_dict_schnell_temp = safetensors.torch.load_file(schnell_shard)

    # Iterate over a snapshot of the keys because entries are popped as we go.
    for k in list(state_dict_dev_temp):
        if "guidance" not in k:
            merged_state_dict[k] = swapping_method(
                state_dict_dev_temp.pop(k), state_dict_schnell_temp.pop(k), parameters
            )
        else:
            # Guidance tensors are copied from FLUX.1-dev unchanged
            # (presumably FLUX.1-schnell has no counterpart for them).
            guidance_state_dict[k] = state_dict_dev_temp.pop(k)

    if state_dict_dev_temp:
        raise ValueError(f"There should not be any residue but got: {list(state_dict_dev_temp.keys())}.")
    if state_dict_schnell_temp:
        raise ValueError(f"There should not be any residue but got: {list(state_dict_schnell_temp.keys())}.")

merged_state_dict.update(guidance_state_dict)
load_model_dict_into_meta(model, merged_state_dict)

model.to(torch.bfloat16).save_pretrained("merged-flux")
```

This uses a piece of code from [mergekit](https://github.com/Ar57m/mergekit/tree/swapping).

Thanks to SayakPaul for the code that helped me do this merge.