---
base_model: black-forest-labs/FLUX.1-dev
language:
- en
tags:
- merge
- flux
---

# Aryanne/flux_swap

This model is a merge of [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) and [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell).

Unlike other merging methods, the tensor values are not blended here; instead, FLUX.1-dev values are replaced with FLUX.1-schnell values in a checkerboard pattern, so roughly 50% of each model is present (if my code is right).

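As a quick illustration (a standalone toy example, not part of the merge script below), this is what a checkerboard substitution with `diagonal_offset = 2` looks like on a small 2-D tensor:

```python
import torch

base = torch.zeros(4, 4)   # stands in for a FLUX.1-dev weight
other = torch.ones(4, 4)   # stands in for a FLUX.1-schnell weight

rows = torch.arange(4).view(-1, 1)
cols = torch.arange(4).view(1, -1)
mask = (rows + cols) % 2 == 0   # True on alternating "checkerboard" cells

merged = torch.where(mask, other, base)
print(merged)
# tensor([[1., 0., 1., 0.],
#         [0., 1., 0., 1.],
#         [1., 0., 1., 0.],
#         [0., 1., 0., 1.]])
```
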
```python
from diffusers import FluxTransformer2DModel
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from huggingface_hub import snapshot_download
from accelerate import init_empty_weights
import safetensors.torch
import glob
import torch

# Build an empty (meta-device) FLUX transformer to receive the merged weights.
with init_empty_weights():
    config = FluxTransformer2DModel.load_config("black-forest-labs/FLUX.1-dev", subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config)

# Download only the transformer shards of both models.
dev_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-dev", allow_patterns="transformer/*")
schnell_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-schnell", allow_patterns="transformer/*")

dev_shards = sorted(glob.glob(f"{dev_ckpt}/transformer/*.safetensors"))
schnell_shards = sorted(glob.glob(f"{schnell_ckpt}/transformer/*.safetensors"))


def swapping_method(base, x, parameters):
    def swap_values(shape, n, base, x):
        # Substitute values from `x` into `base` wherever (row + col) % n == 0,
        # which is a checkerboard pattern for n == 2.
        if x.dim() == 2:
            rows, cols = shape
            rows_range = torch.arange(rows).view(-1, 1)
            cols_range = torch.arange(cols).view(1, -1)
            mask = ((rows_range + cols_range) % n == 0).to(base.device)
        else:
            # 1-D tensors (biases, norms): substitute every n-th element.
            mask = (torch.arange(shape[0]) % n == 0).to(base.device)
        return torch.where(mask, x, base)

    def rand_mask(base, x, percent, seed=None):
        # Substitute a random ~`percent` fraction of elements instead of a
        # regular pattern, optionally with a fixed seed. The global RNG state
        # is saved and restored so the rest of the run is unaffected.
        rng_state = torch.random.get_rng_state()
        if seed is not None:
            torch.manual_seed(seed)
        mask = (torch.rand(base.shape) <= percent).to(base.device)
        torch.random.set_rng_state(rng_state)
        return torch.where(mask, x, base)

    if x.device.type == "cpu":
        x = x.to(torch.bfloat16)
        base = base.to(torch.bfloat16)

    diagonal_offset = parameters.get('diagonal_offset')
    random_mask = parameters.get('random_mask')
    random_mask_seed = parameters.get('random_mask_seed')
    random_mask_seed = int(random_mask_seed) if random_mask_seed is not None else None

    assert (diagonal_offset is not None) and (diagonal_offset % 1 == 0) and (diagonal_offset >= 2), \
        "diagonal_offset must be an integer greater than or equal to 2."

    if random_mask:
        assert 0.0 < random_mask < 1.0, \
            "random_mask must be a number strictly between 0 and 1 (or False/0 to disable)."
        x = rand_mask(base, x, random_mask, random_mask_seed)
    elif not parameters.get('invert_offset'):
        x = swap_values(x.shape, diagonal_offset, base, x)
    else:
        x = swap_values(x.shape, diagonal_offset, x, base)

    del base
    return x


parameters = {
    'diagonal_offset': 2,    # n == 2 gives the ~50/50 checkerboard
    'random_mask': False,    # set to a float in (0, 1) for a random mask instead
    'invert_offset': False,  # True swaps which model fills the masked cells
    # 'random_mask_seed': "899557"
}

merged_state_dict = {}
guidance_state_dict = {}

for dev_shard, schnell_shard in zip(dev_shards, schnell_shards):
    state_dict_dev_temp = safetensors.torch.load_file(dev_shard)
    state_dict_schnell_temp = safetensors.torch.load_file(schnell_shard)

    for k in list(state_dict_dev_temp.keys()):
        if "guidance" not in k:
            merged_state_dict[k] = swapping_method(state_dict_dev_temp.pop(k), state_dict_schnell_temp.pop(k), parameters)
        else:
            # schnell has no guidance embedder, so keep dev's weights for it.
            guidance_state_dict[k] = state_dict_dev_temp.pop(k)

    if len(state_dict_dev_temp) > 0:
        raise ValueError(f"There should not be any residue but got: {list(state_dict_dev_temp.keys())}.")
    if len(state_dict_schnell_temp) > 0:
        raise ValueError(f"There should not be any residue but got: {list(state_dict_schnell_temp.keys())}.")

merged_state_dict.update(guidance_state_dict)
load_model_dict_into_meta(model, merged_state_dict)

model.to(torch.bfloat16).save_pretrained("merged-flux")
```
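
To try the result, here is a minimal usage sketch. The `merged-flux` folder name comes from the script above, but the prompt, step count, and guidance scale are placeholder assumptions you may want to tune, since a dev/schnell hybrid may behave differently from either parent:

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

# Load the merged transformer and plug it into the dev pipeline.
transformer = FluxTransformer2DModel.from_pretrained("merged-flux", torch_dtype=torch.bfloat16)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()  # optional, helps on smaller GPUs

image = pipe(
    "a photo of a cat wearing a tiny wizard hat",  # placeholder prompt
    num_inference_steps=8,   # assumption: a hybrid may need fewer steps than dev
    guidance_scale=3.5,
).images[0]
image.save("sample.png")
```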

Part of the code above was adapted from this [mergekit](https://github.com/Ar57m/mergekit/tree/swapping) branch.

Thanks to SayakPaul, whose code helped me do this merge.