multimodalart committed
Commit 5decbb5 · 1 Parent(s): 9b08e5f

Create app.py

Files changed (1): app.py +359 -0
app.py ADDED
@@ -0,0 +1,359 @@
import os
is_spaces = True if os.environ.get('SPACE_ID') else False

if is_spaces:
    import spaces
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import sys

from dotenv import load_dotenv
load_dotenv()

# Add the current working directory to the Python path
sys.path.insert(0, os.getcwd())

import gradio as gr
from PIL import Image
import torch
import uuid
import shutil
import json
import yaml
from slugify import slugify
from transformers import AutoProcessor, AutoModelForCausalLM

if not is_spaces:
    from toolkit.job import get_job

MAX_IMAGES = 150

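# Added note on the two execution modes: when the SPACE_ID environment variable is
# present the app assumes it is running on Hugging Face Spaces, so it imports the
# `spaces` package and leaves training to `start_training_spaces` (still a
# placeholder below); locally it instead imports `get_job` from Ostris' ai-toolkit
# and runs the training job in-process. MAX_IMAGES only caps how many captioning
# rows are pre-built in the UI further down.
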
def load_captioning(uploaded_images, concept_sentence):
    updates = []
    if len(uploaded_images) <= 1:
        raise gr.Error(
            "Please upload at least 2 images to train your model (the ideal number with default settings is between 4-30)"
        )
    elif len(uploaded_images) > MAX_IMAGES:
        raise gr.Error(
            f"For now, only {MAX_IMAGES} or fewer images are allowed for training"
        )
    # Update for the captioning_area
    updates.append(gr.update(visible=True))
    # Update visibility and image for each captioning row and image
    for i in range(1, MAX_IMAGES + 1):
        # Determine if the current row and image should be visible
        visible = i <= len(uploaded_images)

        # Update visibility of the captioning row
        updates.append(gr.update(visible=visible))

        # Update for the image component - display the image if available, otherwise hide it
        image_value = uploaded_images[i - 1] if visible else None
        updates.append(gr.update(value=image_value, visible=visible))

        # Update the value of the captioning textbox
        text_value = "[trigger]" if visible and concept_sentence else None
        updates.append(gr.update(value=text_value, visible=visible))

    # Update for the sample caption area
    updates.append(gr.update(visible=True))
    updates.append(gr.update(placeholder=f'A photo of {concept_sentence} holding a sign that reads "Hello friend"'))
    updates.append(gr.update(placeholder=f'A mountainous landscape in the style of {concept_sentence}'))
    updates.append(gr.update(placeholder=f'A {concept_sentence} in a mall'))
    return updates

if is_spaces:
    load_captioning = spaces.GPU()(load_captioning)

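# Added note: on ZeroGPU Spaces, wrapping a function with spaces.GPU() requests a GPU
# allocation for the duration of the call; when running locally the function is left
# unwrapped.
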
def create_dataset(*inputs):
    print("Creating dataset")
    images = inputs[0]
    destination_folder = str(uuid.uuid4())
    if not os.path.exists(destination_folder):
        os.makedirs(destination_folder)

    jsonl_file_path = os.path.join(destination_folder, 'metadata.jsonl')
    with open(jsonl_file_path, 'a') as jsonl_file:
        for index, image in enumerate(images):
            new_image_path = shutil.copy(image, destination_folder)

            original_caption = inputs[index + 1]
            file_name = os.path.basename(new_image_path)

            data = {"file_name": file_name, "prompt": original_caption}

            jsonl_file.write(json.dumps(data) + "\n")

    return destination_folder

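# Added note: create_dataset copies the uploaded images into a fresh UUID-named folder
# and writes one JSON object per image to metadata.jsonl, e.g. (illustrative file name
# and caption):
# {"file_name": "photo_01.jpg", "prompt": "a p3rs0n sitting on a park bench [trigger]"}
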
def run_captioning(images, concept_sentence, *captions):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16
    model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True).to(device)
    processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)

    captions = list(captions)
    for i, image_path in enumerate(images):
        print(captions[i])
        if isinstance(image_path, str):  # If image is a file path
            image = Image.open(image_path).convert('RGB')

        prompt = "<DETAILED_CAPTION>"
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)

        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3
        )

        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
        caption_text = parsed_answer['<DETAILED_CAPTION>'].replace("The image shows ", "")
        if concept_sentence:
            caption_text = f"{caption_text} [trigger]"
        captions[i] = caption_text

        yield captions

    model.to("cpu")
    del model
    del processor

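# Added note: run_captioning uses Florence-2's "<DETAILED_CAPTION>" task prompt and
# yields the captions list after each image, so the caption textboxes fill in
# progressively; when a trigger word/sentence is set, " [trigger]" is appended to
# every generated caption.
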
def start_training(
    lora_name,
    concept_sentence,
    steps,
    lr,
    rank,
    dataset_folder,
    sample_1,
    sample_2,
    sample_3,
):
    if not lora_name:
        raise gr.Error("You forgot to insert your LoRA name! This name has to be unique.")
    print("Started training")
    slugged_lora_name = slugify(lora_name)

    # Load the default config
    with open("config/examples/train_lora_flux_24gb.yaml", "r") as f:
        config = yaml.safe_load(f)

    # Update the config with user inputs
    config['config']['name'] = slugged_lora_name
    config['config']['process'][0]['model']['low_vram'] = True
    config['config']['process'][0]['train']['skip_first_sample'] = True
    config['config']['process'][0]['train']['steps'] = int(steps)
    config['config']['process'][0]['train']['lr'] = float(lr)
    config['config']['process'][0]['network']['linear'] = int(rank)
    config['config']['process'][0]['network']['linear_alpha'] = int(rank)
    config['config']['process'][0]['datasets'][0]['folder_path'] = dataset_folder
    if concept_sentence:
        config['config']['process'][0]['trigger_word'] = concept_sentence
    if sample_1 or sample_2 or sample_3:
        config['config']['process'][0]['train']['disable_sampling'] = False
        config['config']['process'][0]['sample']['sample_every'] = steps
        config['config']['process'][0]['sample']['prompts'] = []
        if sample_1:
            config['config']['process'][0]['sample']['prompts'].append(sample_1)
        if sample_2:
            config['config']['process'][0]['sample']['prompts'].append(sample_2)
        if sample_3:
            config['config']['process'][0]['sample']['prompts'].append(sample_3)
    else:
        config['config']['process'][0]['train']['disable_sampling'] = True

    # Save the updated config
    config_path = f"config/{slugged_lora_name}.yaml"
    with open(config_path, "w") as f:
        yaml.dump(config, f)

    # Run the job
    job = get_job(config_path)
    job.run()
    job.cleanup()

    return f"Training completed successfully. Model saved as {slugged_lora_name}"

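# Added note: start_training only overrides the keys below on top of ai-toolkit's
# config/examples/train_lora_flux_24gb.yaml; every other field is inherited from that
# example config. Illustrative values shown:
#
# config:
#   name: my-flux-lora              # slugified LoRA name
#   process:
#     - model:
#         low_vram: true
#       network:
#         linear: 16                # LoRA rank
#         linear_alpha: 16
#       train:
#         skip_first_sample: true
#         steps: 1000
#         lr: 0.0004
#         disable_sampling: false   # true when no sample prompts are given
#       sample:
#         sample_every: 1000
#         prompts: ["A photo of ... [trigger]"]
#       datasets:
#         - folder_path: <uuid dataset folder>
#       trigger_word: p3rs0n        # only set when a trigger word was provided
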
def start_training_spaces(
    lora_name,
    concept_sentence,
    steps,
    lr,
    rank,
    dataset_folder,
    sample_1,
    sample_2,
    sample_3,
):
    # Feel free to include the spacerunner stuff here @abhishek
    pass

theme = gr.themes.Monochrome(
    text_size=gr.themes.Size(lg="18px", md="15px", sm="13px", xl="22px", xs="12px", xxl="24px", xxs="9px"),
    font=[gr.themes.GoogleFont('Source Sans Pro'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
)
css = '''
#component-1{text-align:center}
.main_ui_logged_out{opacity: 0.5; pointer-events: none}
.tabitem{border: 0px}
'''

def swap_visibility(profile: gr.OAuthProfile | None):
    print(profile)
    if is_spaces:
        if profile is None:
            return gr.update(elem_classes=["main_ui_logged_out"])
        else:
            print(profile.name)
            return gr.update(elem_classes=["main_ui_logged_in"])
    else:
        return gr.update(elem_classes=["main_ui_logged_in"])

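# Added note: Gradio passes the logged-in gr.OAuthProfile (or None) into functions that
# declare a parameter with that type annotation, which is how demo.load() feeds
# swap_visibility below; the returned elem_classes switch the CSS defined above so the
# main UI is dimmed and non-interactive until the user signs in on Spaces.
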
with gr.Blocks(theme=theme, css=css) as demo:
    gr.Markdown('''# LoRA Ease for FLUX 🧞‍♂️
### Train a high quality FLUX LoRA in a breeze ༄ using [Ostris' AI Toolkit](https://github.com/ostris/ai-toolkit) and [AutoTrain Advanced](https://github.com/huggingface/autotrain-advanced)''')
    gr.LoginButton(visible=is_spaces)
    with gr.Tab("Train on Spaces" if is_spaces else "Train locally"):
        with gr.Column(elem_classes="main_ui_logged_out") as main_ui:
            with gr.Row():
                lora_name = gr.Textbox(label="The name of your LoRA", info="This has to be a unique name", placeholder="e.g.: Persian Miniature Painting style, Cat Toy")
                #training_option = gr.Radio(
                #    label="What are you training?", choices=["object", "style", "character", "face", "custom"]
                #)
                concept_sentence = gr.Textbox(
                    label="Trigger word/sentence",
                    info="Trigger word or sentence to be used",
                    placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'",
                    interactive=True,
                )
            with gr.Group(visible=True) as image_upload:
                with gr.Row():
                    images = gr.File(
                        file_types=["image"],
                        label="Upload your images",
                        file_count="multiple",
                        interactive=True,
                        visible=True,
                        scale=1,
                    )
                    with gr.Column(scale=3, visible=False) as captioning_area:
                        with gr.Column():
                            gr.Markdown("""# Custom captioning
You can optionally add a custom caption for each image (or use an AI model for this). [trigger] will represent your concept sentence/trigger word.
""")
                            do_captioning = gr.Button("Add AI captions with Florence-2")
                            output_components = [captioning_area]
                            caption_list = []
                            for i in range(1, MAX_IMAGES + 1):
                                locals()[f"captioning_row_{i}"] = gr.Row(visible=False)
                                with locals()[f"captioning_row_{i}"]:
                                    locals()[f"image_{i}"] = gr.Image(
                                        type="filepath",
                                        width=111,
                                        height=111,
                                        min_width=111,
                                        interactive=False,
                                        scale=2,
                                        show_label=False,
                                        show_share_button=False,
                                        show_download_button=False
                                    )
                                    locals()[f"caption_{i}"] = gr.Textbox(
                                        label=f"Caption {i}", scale=15, interactive=True
                                    )

                                output_components.append(locals()[f"captioning_row_{i}"])
                                output_components.append(locals()[f"image_{i}"])
                                output_components.append(locals()[f"caption_{i}"])
                                caption_list.append(locals()[f"caption_{i}"])

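            # Added note: output_components must stay in the exact order in which
            # load_captioning returns its gr.update() objects: the captioning area
            # first, then a (row, image, caption) triple per slot, and finally the
            # sample-prompts accordion and its three textboxes, which are appended
            # right below.
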
            with gr.Accordion("Advanced options", open=False):
                steps = gr.Number(label="Steps", value=1000, minimum=1, maximum=10000, step=1)
                lr = gr.Number(label="Learning Rate", value=4e-4, minimum=1e-6, maximum=1e-3, step=1e-6)
                rank = gr.Number(label="LoRA Rank", value=16, minimum=4, maximum=128, step=4)

            with gr.Accordion("Sample prompts", visible=False) as sample:
                gr.Markdown("Include sample prompts to test out your trained model. Don't forget to include your trigger word/sentence (optional)")
                sample_1 = gr.Textbox(label="Test prompt 1")
                sample_2 = gr.Textbox(label="Test prompt 2")
                sample_3 = gr.Textbox(label="Test prompt 3")

            output_components.append(sample)
            output_components.append(sample_1)
            output_components.append(sample_2)
            output_components.append(sample_3)
            start = gr.Button("Start training")
            progress_area = gr.Markdown("")

    with gr.Tab("Train locally" if is_spaces else "Instructions"):
        gr.Markdown(f'''To use FLUX LoRA Ease locally with this UI, you can clone this repository (yes, HF Spaces are git repos!)
```bash
git clone https://huggingface.co/spaces/flux-train/flux-lora-trainer
cd flux-lora-trainer
```

Then you can install ai-toolkit
```bash
git clone https://github.com/ostris/ai-toolkit.git
cd ai-toolkit
git submodule update --init --recursive
python3 -m venv venv
source venv/bin/activate
# .\\venv\\Scripts\\activate on windows
# install torch first
pip3 install torch
pip3 install -r requirements.txt
cd ..
```

Now you can run FLUX LoRA Ease locally with:
```bash
python app.py
```
If you prefer the command line, you can run Ostris' [AI Toolkit](https://github.com/ostris/ai-toolkit) yourself.
''')

    dataset_folder = gr.State()

    images.upload(
        load_captioning,
        inputs=[images, concept_sentence],
        outputs=output_components,
        queue=False
    )

    start.click(
        fn=create_dataset,
        inputs=[images] + caption_list,
        outputs=dataset_folder,
        queue=False
    ).then(
        fn=start_training_spaces if is_spaces else start_training,
        inputs=[
            lora_name,
            concept_sentence,
            steps,
            lr,
            rank,
            dataset_folder,
            sample_1,
            sample_2,
            sample_3,
        ],
        outputs=progress_area,
        queue=False
    )

    do_captioning.click(
        fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list
    )
    demo.load(fn=swap_visibility, outputs=main_ui, queue=False)

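# Added note on the event chain: uploading images reveals and pre-fills the captioning
# UI via load_captioning; "Start training" first runs create_dataset (the resulting
# folder is kept in the dataset_folder gr.State) and then chains into start_training /
# start_training_spaces via .then(); the Florence-2 button streams captions straight
# into the caption textboxes.
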
if __name__ == "__main__":
    demo.queue()
    demo.launch(share=True)