|
|
|
|
|
|
|
|
|
|
|
import gradio as gr |
|
import easygui |
|
import json |
|
import math |
|
import os |
|
import subprocess |
|
import pathlib |
|
import argparse |
|
from library.common_gui import ( |
|
get_folder_path, |
|
remove_doublequote, |
|
get_file_path, |
|
get_any_file_path, |
|
get_saveasfile_path, |
|
color_aug_changed, |
|
save_inference_file, |
|
gradio_advanced_training, |
|
run_cmd_advanced_training, |
|
gradio_training, |
|
gradio_config, |
|
gradio_source_model, |
|
run_cmd_training, |
|
|
|
update_my_data, |
|
check_if_model_exist, |
|
) |
|
from library.dreambooth_folder_creation_gui import ( |
|
gradio_dreambooth_folder_creation_tab, |
|
) |
|
from library.tensorboard_gui import ( |
|
gradio_tensorboard, |
|
start_tensorboard, |
|
stop_tensorboard, |
|
) |
|
from library.dataset_balancing_gui import gradio_dataset_balancing_tab |
|
from library.utilities import utilities_tab |
|
from library.merge_lora_gui import gradio_merge_lora_tab |
|
from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab |
|
from library.verify_lora_gui import gradio_verify_lora_tab |
|
from library.resize_lora_gui import gradio_resize_lora_tab |
|
from library.sampler_gui import sample_gradio_config, run_cmd_sample |
|
from easygui import msgbox |
|
|
|
# Unicode glyphs used as labels for the small icon buttons in the UI.
folder_symbol = '\U0001f4c2'      # open folder
refresh_symbol = '\U0001f504'     # refresh arrows
save_style_symbol = '\U0001f4be'  # floppy disk (save)
document_symbol = '\U0001F4C4'    # page/document
# Captured at import time; this is the process working directory,
# not necessarily the directory containing this file.
path_of_this_folder = os.getcwd()
|
|
|
|
|
def save_configuration(
    save_as,
    file_path,
    pretrained_model_name_or_path,
    v2,
    v_parameterization,
    logging_dir,
    train_data_dir,
    reg_data_dir,
    output_dir,
    max_resolution,
    learning_rate,
    lr_scheduler,
    lr_warmup,
    train_batch_size,
    epoch,
    save_every_n_epochs,
    mixed_precision,
    save_precision,
    seed,
    num_cpu_threads_per_process,
    cache_latents,
    caption_extension,
    enable_bucket,
    gradient_checkpointing,
    full_fp16,
    no_token_padding,
    stop_text_encoder_training,
    xformers,
    save_model_as,
    shuffle_caption,
    save_state,
    resume,
    prior_loss_weight,
    text_encoder_lr,
    unet_lr,
    network_dim,
    lora_network_weights,
    color_aug,
    flip_aug,
    clip_skip,
    gradient_accumulation_steps,
    mem_eff_attn,
    output_name,
    model_list,
    max_token_length,
    max_train_epochs,
    max_data_loader_n_workers,
    network_alpha,
    training_comment,
    keep_tokens,
    lr_scheduler_num_cycles,
    lr_scheduler_power,
    persistent_data_loader_workers,
    bucket_no_upscale,
    random_crop,
    bucket_reso_steps,
    caption_dropout_every_n_epochs,
    caption_dropout_rate,
    optimizer,
    optimizer_args,
    noise_offset,
    LoRA_type,
    conv_dim,
    conv_alpha,
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
    vae_batch_size,
    min_snr_gamma,
    down_lr_weight,
    mid_lr_weight,
    up_lr_weight,
    block_lr_zero_threshold,
    block_dims,
    block_alphas,
    conv_dims,
    conv_alphas,
    weighted_captions,
    unit,
):
    """Serialize all training parameters (everything except the bookkeeping
    arguments ``save_as`` and ``file_path``) to a JSON config file.

    save_as is a hidden Gradio Label whose 'label' field carries the string
    'True' when the caller wants a "Save as..." file dialog. Returns the
    path the config was written to, or the original ``file_path`` when the
    user cancelled the file dialog.
    """
    # Must run first so that locals() contains only the parameters.
    parameters = list(locals().items())

    original_file_path = file_path

    # Hidden Gradio Label trick: its 'label' field is the string 'True'/'False'.
    save_as_bool = save_as.get('label') == 'True'

    if save_as_bool:
        print('Save as...')
        file_path = get_saveasfile_path(file_path)
    else:
        print('Save...')
        if not file_path:
            file_path = get_saveasfile_path(file_path)

    # User cancelled the dialog: keep the previously shown path unchanged.
    if not file_path:
        return original_file_path

    variables = {
        name: value
        for name, value in parameters
        if name not in ('file_path', 'save_as')
    }

    destination_directory = os.path.dirname(file_path)
    # Guard: a bare filename has an empty dirname, and os.makedirs('')
    # raises; only create the directory when there is one to create.
    if destination_directory:
        os.makedirs(destination_directory, exist_ok=True)

    with open(file_path, 'w') as file:
        json.dump(variables, file, indent=2)

    return file_path
|
|
|
|
|
def open_configuration(
    ask_for_file,
    file_path,
    pretrained_model_name_or_path,
    v2,
    v_parameterization,
    logging_dir,
    train_data_dir,
    reg_data_dir,
    output_dir,
    max_resolution,
    learning_rate,
    lr_scheduler,
    lr_warmup,
    train_batch_size,
    epoch,
    save_every_n_epochs,
    mixed_precision,
    save_precision,
    seed,
    num_cpu_threads_per_process,
    cache_latents,
    caption_extension,
    enable_bucket,
    gradient_checkpointing,
    full_fp16,
    no_token_padding,
    stop_text_encoder_training,
    xformers,
    save_model_as,
    shuffle_caption,
    save_state,
    resume,
    prior_loss_weight,
    text_encoder_lr,
    unet_lr,
    network_dim,
    lora_network_weights,
    color_aug,
    flip_aug,
    clip_skip,
    gradient_accumulation_steps,
    mem_eff_attn,
    output_name,
    model_list,
    max_token_length,
    max_train_epochs,
    max_data_loader_n_workers,
    network_alpha,
    training_comment,
    keep_tokens,
    lr_scheduler_num_cycles,
    lr_scheduler_power,
    persistent_data_loader_workers,
    bucket_no_upscale,
    random_crop,
    bucket_reso_steps,
    caption_dropout_every_n_epochs,
    caption_dropout_rate,
    optimizer,
    optimizer_args,
    noise_offset,
    LoRA_type,
    conv_dim,
    conv_alpha,
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
    vae_batch_size,
    min_snr_gamma,
    down_lr_weight,
    mid_lr_weight,
    up_lr_weight,
    block_lr_zero_threshold,
    block_dims,
    block_alphas,
    conv_dims,
    conv_alphas,
    weighted_captions,
    unit,
):
    """Load a JSON config file and return one value per GUI component.

    ask_for_file is a hidden Gradio Label whose 'label' field is 'True'
    when a file-open dialog should be shown first. For every parameter not
    found in the loaded config, the component's current value is kept.
    The returned tuple is ``(file_path, *settings, locon_row_update)`` —
    the trailing element toggles visibility of the LoCon options row.
    """
    # Must run first so that locals() contains only the parameters.
    parameters = list(locals().items())

    ask_for_file = ask_for_file.get('label') == 'True'

    original_file_path = file_path

    if ask_for_file:
        file_path = get_file_path(file_path)

    if file_path:
        with open(file_path, 'r') as f:
            my_data = json.load(f)
        print('Loading config...')
        # Migrate legacy keys/values to the current format.
        my_data = update_my_data(my_data)
    else:
        # Dialog cancelled (or no path): restore the path, load nothing.
        file_path = original_file_path
        my_data = {}

    values = [file_path]
    for key, value in parameters:
        # Bookkeeping arguments are not GUI settings; skip them.
        if key not in ('ask_for_file', 'file_path'):
            values.append(my_data.get(key, value))

    # Final output: show the LoCon row only for LoCon-type configs.
    if my_data.get('LoRA_type', 'Standard') == 'LoCon':
        values.append(gr.Row.update(visible=True))
    else:
        values.append(gr.Row.update(visible=False))

    return tuple(values)
|
|
|
|
|
def train_model(
    print_only,
    pretrained_model_name_or_path,
    v2,
    v_parameterization,
    logging_dir,
    train_data_dir,
    reg_data_dir,
    output_dir,
    max_resolution,
    learning_rate,
    lr_scheduler,
    lr_warmup,
    train_batch_size,
    epoch,
    save_every_n_epochs,
    mixed_precision,
    save_precision,
    seed,
    num_cpu_threads_per_process,
    cache_latents,
    caption_extension,
    enable_bucket,
    gradient_checkpointing,
    full_fp16,
    no_token_padding,
    stop_text_encoder_training_pct,
    xformers,
    save_model_as,
    shuffle_caption,
    save_state,
    resume,
    prior_loss_weight,
    text_encoder_lr,
    unet_lr,
    network_dim,
    lora_network_weights,
    color_aug,
    flip_aug,
    clip_skip,
    gradient_accumulation_steps,
    mem_eff_attn,
    output_name,
    model_list,
    max_token_length,
    max_train_epochs,
    max_data_loader_n_workers,
    network_alpha,
    training_comment,
    keep_tokens,
    lr_scheduler_num_cycles,
    lr_scheduler_power,
    persistent_data_loader_workers,
    bucket_no_upscale,
    random_crop,
    bucket_reso_steps,
    caption_dropout_every_n_epochs,
    caption_dropout_rate,
    optimizer,
    optimizer_args,
    noise_offset,
    LoRA_type,
    conv_dim,
    conv_alpha,
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
    vae_batch_size,
    min_snr_gamma,
    down_lr_weight,
    mid_lr_weight,
    up_lr_weight,
    block_lr_zero_threshold,
    block_dims,
    block_alphas,
    conv_dims,
    conv_alphas,
    weighted_captions,
    unit,
):
    """Validate the GUI settings, assemble the ``accelerate launch
    "train_network.py"`` command line, and run it (or only print it).

    print_only is a hidden Gradio Label; when its 'label' field is the
    string 'True' the command is printed as a reference instead of being
    executed. Most other arguments map one-to-one onto train_network.py
    CLI flags. ``model_list`` is accepted so the signature matches
    settings_list but is not used here.

    Returns None; shows a message box and aborts on invalid input.
    """
    print_only_bool = True if print_only.get('label') == 'True' else False

    # ---- Input validation: abort with a message box on missing/invalid paths.
    if pretrained_model_name_or_path == '':
        msgbox('Source model information is missing')
        return

    if train_data_dir == '':
        msgbox('Image folder path is missing')
        return

    if not os.path.exists(train_data_dir):
        msgbox('Image folder does not exist')
        return

    # Regularisation folder is optional; only validated when non-empty.
    if reg_data_dir != '':
        if not os.path.exists(reg_data_dir):
            msgbox('Regularisation folder does not exist')
            return

    if output_dir == '':
        msgbox('Output folder path is missing')
        return

    if int(bucket_reso_steps) < 1:
        msgbox('Bucket resolution steps need to be greater than 0')
        return

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Feature not wired up yet for this trainer; force it off with a warning.
    if stop_text_encoder_training_pct > 0:
        msgbox(
            'Output "stop text encoder training" is not yet supported. Ignoring'
        )
        stop_text_encoder_training_pct = 0

    # Asks the user about overwriting existing model files; truthy = abort.
    if check_if_model_exist(output_name, output_dir, save_model_as):
        return

    # Adafactor manages its own LR schedule, so warmup is forced to zero.
    # NOTE(review): the message says "lr_scheduler is set to 'Adafactor'"
    # but the condition checks `optimizer` — message text looks misleading.
    if optimizer == 'Adafactor' and lr_warmup != '0':
        msgbox("Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", title="Warning")
        lr_warmup = '0'

    # Empty textboxes mean "not set"; normalize to 0 for the float() checks below.
    if text_encoder_lr == '':
        text_encoder_lr = 0
    if unet_lr == '':
        unet_lr = 0

    # ---- Step count: dataset subfolders are named "<repeats>_<concept>".
    subfolders = [
        f
        for f in os.listdir(train_data_dir)
        if os.path.isdir(os.path.join(train_data_dir, f))
    ]

    total_steps = 0

    for folder in subfolders:
        try:
            # Leading "<repeats>_" prefix encodes per-image repeat count.
            repeats = int(folder.split('_')[0])

            # Count image files (case-insensitive extension match).
            num_images = len(
                [
                    f
                    for f, lower_f in (
                        (file, file.lower())
                        for file in os.listdir(
                            os.path.join(train_data_dir, folder)
                        )
                    )
                    if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
                ]
            )

            print(f'Folder {folder}: {num_images} images found')

            steps = repeats * num_images

            print(f'Folder {folder}: {steps} steps')

            total_steps += steps

        except ValueError:
            # Folder name has no numeric "<repeats>_" prefix; skip it.
            print(f"Error: '{folder}' does not contain an underscore, skipping...")

    # steps/batch * epochs. NOTE(review): gradient accumulation is not
    # factored in here — confirm this matches the trainer's expectation.
    max_train_steps = int(
        math.ceil(
            float(total_steps)
            / int(train_batch_size)
            * int(epoch)
        )
    )
    print(f'max_train_steps = {max_train_steps}')

    # Convert the "stop text encoder" percentage to an absolute step count.
    if stop_text_encoder_training_pct == None:
        stop_text_encoder_training = 0
    else:
        stop_text_encoder_training = math.ceil(
            float(max_train_steps) / 100 * int(stop_text_encoder_training_pct)
        )
    print(f'stop_text_encoder_training = {stop_text_encoder_training}')

    # lr_warmup is a percentage of total steps.
    lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
    print(f'lr_warmup_steps = {lr_warmup_steps}')

    # ---- Build the command line. Flag order follows the original script.
    run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_network.py"'

    if v2:
        run_cmd += ' --v2'
    if v_parameterization:
        run_cmd += ' --v_parameterization'
    if enable_bucket:
        run_cmd += ' --enable_bucket'
    if no_token_padding:
        run_cmd += ' --no_token_padding'
    if weighted_captions:
        run_cmd += ' --weighted_captions'
    run_cmd += (
        f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
    )
    run_cmd += f' --train_data_dir="{train_data_dir}"'
    if len(reg_data_dir):
        run_cmd += f' --reg_data_dir="{reg_data_dir}"'
    run_cmd += f' --resolution={max_resolution}'
    run_cmd += f' --output_dir="{output_dir}"'
    run_cmd += f' --logging_dir="{logging_dir}"'
    run_cmd += f' --network_alpha="{network_alpha}"'
    if not training_comment == '':
        run_cmd += f' --training_comment="{training_comment}"'
    if not stop_text_encoder_training == 0:
        run_cmd += (
            f' --stop_text_encoder_training={stop_text_encoder_training}'
        )
    if not save_model_as == 'same as source model':
        run_cmd += f' --save_model_as={save_model_as}'
    if not float(prior_loss_weight) == 1.0:
        run_cmd += f' --prior_loss_weight={prior_loss_weight}'

    # ---- Network module selection per LoRA type.
    # LyCORIS variants require the external 'lycoris_lora' package.
    if LoRA_type == 'LoCon' or LoRA_type == 'LyCORIS/LoCon':
        try:
            import lycoris
        except ModuleNotFoundError:
            print(
                "\033[1;31mError:\033[0m The required module 'lycoris_lora' is not installed. Please install by running \033[33mupgrade.ps1\033[0m before running this program."
            )
            return
        run_cmd += f' --network_module=lycoris.kohya'
        run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=lora"'
    if LoRA_type == 'LyCORIS/LoHa':
        try:
            import lycoris
        except ModuleNotFoundError:
            print(
                "\033[1;31mError:\033[0m The required module 'lycoris_lora' is not installed. Please install by running \033[33mupgrade.ps1\033[0m before running this program."
            )
            return
        run_cmd += f' --network_module=lycoris.kohya'
        run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=loha"'

    if LoRA_type in ['Kohya LoCon', 'Standard']:
        # Per-block weight/dim settings forwarded only when non-empty.
        kohya_lora_var_list = ['down_lr_weight', 'mid_lr_weight', 'up_lr_weight', 'block_lr_zero_threshold', 'block_dims', 'block_alphas', 'conv_dims', 'conv_alphas']

        run_cmd += f' --network_module=networks.lora'
        # vars() == locals() here; pick out the truthy per-block settings.
        kohya_lora_vars = {key: value for key, value in vars().items() if key in kohya_lora_var_list and value}

        network_args = ''
        if LoRA_type == 'Kohya LoCon':
            network_args += f' "conv_dim={conv_dim}" "conv_alpha={conv_alpha}"'

        for key, value in kohya_lora_vars.items():
            if value:
                network_args += f' {key}="{value}"'

        if network_args:
            run_cmd += f' --network_args{network_args}'

    if LoRA_type in ['Kohya DyLoRA']:
        kohya_lora_var_list = ['conv_dim', 'conv_alpha', 'down_lr_weight', 'mid_lr_weight', 'up_lr_weight', 'block_lr_zero_threshold', 'block_dims', 'block_alphas', 'conv_dims', 'conv_alphas', 'unit']

        run_cmd += f' --network_module=networks.dylora'
        kohya_lora_vars = {key: value for key, value in vars().items() if key in kohya_lora_var_list and value}

        network_args = ''

        for key, value in kohya_lora_vars.items():
            if value:
                network_args += f' {key}="{value}"'

        if network_args:
            run_cmd += f' --network_args{network_args}'

    # ---- Learning rates: when only one of the two is set, train only
    # that part of the network; when neither is set, abort.
    if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0):
        if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0):
            run_cmd += f' --text_encoder_lr={text_encoder_lr}'
            run_cmd += f' --unet_lr={unet_lr}'
        elif not (float(text_encoder_lr) == 0):
            run_cmd += f' --text_encoder_lr={text_encoder_lr}'
            run_cmd += f' --network_train_text_encoder_only'
        else:
            run_cmd += f' --unet_lr={unet_lr}'
            run_cmd += f' --network_train_unet_only'
    else:
        if float(text_encoder_lr) == 0:
            msgbox('Please input learning rate values.')
            return

    run_cmd += f' --network_dim={network_dim}'

    if not lora_network_weights == '':
        run_cmd += f' --network_weights="{lora_network_weights}"'
    if int(gradient_accumulation_steps) > 1:
        run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
    if not output_name == '':
        run_cmd += f' --output_name="{output_name}"'
    if not lr_scheduler_num_cycles == '':
        run_cmd += f' --lr_scheduler_num_cycles="{lr_scheduler_num_cycles}"'
    else:
        # Default: one scheduler cycle per epoch.
        run_cmd += f' --lr_scheduler_num_cycles="{epoch}"'
    if not lr_scheduler_power == '':
        run_cmd += f' --lr_scheduler_power="{lr_scheduler_power}"'

    # Shared flag builders from library.common_gui.
    run_cmd += run_cmd_training(
        learning_rate=learning_rate,
        lr_scheduler=lr_scheduler,
        lr_warmup_steps=lr_warmup_steps,
        train_batch_size=train_batch_size,
        max_train_steps=max_train_steps,
        save_every_n_epochs=save_every_n_epochs,
        mixed_precision=mixed_precision,
        save_precision=save_precision,
        seed=seed,
        caption_extension=caption_extension,
        cache_latents=cache_latents,
        optimizer=optimizer,
        optimizer_args=optimizer_args,
    )

    run_cmd += run_cmd_advanced_training(
        max_train_epochs=max_train_epochs,
        max_data_loader_n_workers=max_data_loader_n_workers,
        max_token_length=max_token_length,
        resume=resume,
        save_state=save_state,
        mem_eff_attn=mem_eff_attn,
        clip_skip=clip_skip,
        flip_aug=flip_aug,
        color_aug=color_aug,
        shuffle_caption=shuffle_caption,
        gradient_checkpointing=gradient_checkpointing,
        full_fp16=full_fp16,
        xformers=xformers,
        keep_tokens=keep_tokens,
        persistent_data_loader_workers=persistent_data_loader_workers,
        bucket_no_upscale=bucket_no_upscale,
        random_crop=random_crop,
        bucket_reso_steps=bucket_reso_steps,
        caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
        caption_dropout_rate=caption_dropout_rate,
        noise_offset=noise_offset,
        additional_parameters=additional_parameters,
        vae_batch_size=vae_batch_size,
        min_snr_gamma=min_snr_gamma,
    )

    run_cmd += run_cmd_sample(
        sample_every_n_steps,
        sample_every_n_epochs,
        sample_sampler,
        sample_prompts,
        output_dir,
    )

    if print_only_bool:
        # Reference mode: show the command (in color) without executing.
        print(
            '\033[93m\nHere is the trainer command as a reference. It will not be executed:\033[0m\n'
        )
        print('\033[96m' + run_cmd + '\033[0m\n')
    else:
        print(run_cmd)

        # NOTE(review): the command string is passed to a shell on POSIX
        # and to subprocess.run as a string on Windows; paths are only
        # double-quoted, so hostile path input could inject shell syntax.
        if os.name == 'posix':
            os.system(run_cmd)
        else:
            subprocess.run(run_cmd)

        # Write inference .yaml next to the model for v2 models.
        last_dir = pathlib.Path(f'{output_dir}/{output_name}')

        # NOTE(review): the `not` looks inverted at first glance, but LoRA
        # output is a file (not a directory), so this branch is the normal
        # case — confirm against save_inference_file's expectations.
        if not last_dir.is_dir():
            save_inference_file(
                output_dir, v2, v_parameterization, output_name
            )
|
|
|
|
|
def lora_tab(
    train_data_dir_input=gr.Textbox(),
    reg_data_dir_input=gr.Textbox(),
    output_dir_input=gr.Textbox(),
    logging_dir_input=gr.Textbox(),
):
    """Build the LoRA training tab and wire all of its event handlers.

    Returns the four folder Textbox components so the caller (UI) can pass
    them to the Utilities tab. NOTE: widget creation order here must match
    the parameter order of save_configuration / open_configuration /
    train_model — settings_list below is the single source of that wiring.
    """
    # Hidden Labels used to tell the shared handlers which button fired
    # ('True' = save-as / load-dialog / print-only variants).
    dummy_db_true = gr.Label(value=True, visible=False)
    dummy_db_false = gr.Label(value=False, visible=False)
    gr.Markdown(
        'Train a custom model using kohya train network LoRA python code...'
    )
    # Config open/save/save-as button row (shared component).
    (
        button_open_config,
        button_save_config,
        button_save_as_config,
        config_file_name,
        button_load_config,
    ) = gradio_config()

    # Source model selector (shared component).
    (
        pretrained_model_name_or_path,
        v2,
        v_parameterization,
        save_model_as,
        model_list,
    ) = gradio_source_model(
        save_model_as_choices=[
            'ckpt',
            'safetensors',
        ]
    )

    with gr.Tab('Folders'):
        with gr.Row():
            train_data_dir = gr.Textbox(
                label='Image folder',
                placeholder='Folder where the training folders containing the images are located',
            )
            train_data_dir_folder = gr.Button('π', elem_id='open_folder_small')
            train_data_dir_folder.click(
                get_folder_path,
                outputs=train_data_dir,
                show_progress=False,
            )
            reg_data_dir = gr.Textbox(
                label='Regularisation folder',
                placeholder='(Optional) Folder where where the regularization folders containing the images are located',
            )
            reg_data_dir_folder = gr.Button('π', elem_id='open_folder_small')
            reg_data_dir_folder.click(
                get_folder_path,
                outputs=reg_data_dir,
                show_progress=False,
            )
        with gr.Row():
            output_dir = gr.Textbox(
                label='Output folder',
                placeholder='Folder to output trained model',
            )
            output_dir_folder = gr.Button('π', elem_id='open_folder_small')
            output_dir_folder.click(
                get_folder_path,
                outputs=output_dir,
                show_progress=False,
            )
            logging_dir = gr.Textbox(
                label='Logging folder',
                placeholder='Optional: enable logging and output TensorBoard log to this folder',
            )
            logging_dir_folder = gr.Button('π', elem_id='open_folder_small')
            logging_dir_folder.click(
                get_folder_path,
                outputs=logging_dir,
                show_progress=False,
            )
        with gr.Row():
            output_name = gr.Textbox(
                label='Model output name',
                placeholder='(Name of the model to output)',
                value='last',
                interactive=True,
            )
            training_comment = gr.Textbox(
                label='Training comment',
                placeholder='(Optional) Add training comment to be included in metadata',
                interactive=True,
            )
        # Strip stray double quotes pasted in with Windows paths.
        train_data_dir.change(
            remove_doublequote,
            inputs=[train_data_dir],
            outputs=[train_data_dir],
        )
        reg_data_dir.change(
            remove_doublequote,
            inputs=[reg_data_dir],
            outputs=[reg_data_dir],
        )
        output_dir.change(
            remove_doublequote,
            inputs=[output_dir],
            outputs=[output_dir],
        )
        logging_dir.change(
            remove_doublequote,
            inputs=[logging_dir],
            outputs=[logging_dir],
        )
    with gr.Tab('Training parameters'):
        with gr.Row():
            LoRA_type = gr.Dropdown(
                label='LoRA type',
                choices=[
                    'Kohya DyLoRA',
                    'Kohya LoCon',
                    'LyCORIS/LoCon',
                    'LyCORIS/LoHa',
                    'Standard',
                ],
                value='Standard',
            )
            lora_network_weights = gr.Textbox(
                label='LoRA network weights',
                placeholder='{Optional) Path to existing LoRA network weights to resume training',
            )
            lora_network_weights_file = gr.Button(
                document_symbol, elem_id='open_folder_small'
            )
            lora_network_weights_file.click(
                get_any_file_path,
                inputs=[lora_network_weights],
                outputs=lora_network_weights,
                show_progress=False,
            )
        # Common training hyperparameters (shared component).
        (
            learning_rate,
            lr_scheduler,
            lr_warmup,
            train_batch_size,
            epoch,
            save_every_n_epochs,
            mixed_precision,
            save_precision,
            num_cpu_threads_per_process,
            seed,
            caption_extension,
            cache_latents,
            optimizer,
            optimizer_args,
        ) = gradio_training(
            learning_rate_value='0.0001',
            lr_scheduler_value='cosine',
            lr_warmup_value='10',
        )

        with gr.Row():
            text_encoder_lr = gr.Textbox(
                label='Text Encoder learning rate',
                value='5e-5',
                placeholder='Optional',
            )
            unet_lr = gr.Textbox(
                label='Unet learning rate',
                value='0.0001',
                placeholder='Optional',
            )
            network_dim = gr.Slider(
                minimum=1,
                maximum=1024,
                label='Network Rank (Dimension)',
                value=8,
                step=1,
                interactive=True,
            )
            network_alpha = gr.Slider(
                minimum=0.1,
                maximum=1024,
                label='Network Alpha',
                value=1,
                step=0.1,
                interactive=True,
                info='alpha for LoRA weight scaling',
            )
        # Convolution settings — hidden unless a LoCon/LyCORIS/DyLoRA type
        # is selected (see update_LoRA_settings below).
        with gr.Row(visible=False) as LoCon_row:
            conv_dim = gr.Slider(
                minimum=1,
                maximum=512,
                value=1,
                step=1,
                label='Convolution Rank (Dimension)',
            )
            conv_alpha = gr.Slider(
                minimum=0.1,
                maximum=512,
                value=1,
                step=0.1,
                label='Convolution Alpha',
            )
        # DyLoRA-only setting.
        with gr.Row(visible=False) as kohya_dylora:
            unit = gr.Slider(
                minimum=1,
                maximum=64,
                label='DyLoRA Unit',
                value=1,
                step=1,
                interactive=True,
            )

        def update_LoRA_settings(LoRA_type):
            """Toggle visibility of the type-specific option rows when the
            LoRA type dropdown changes."""
            print('LoRA type changed...')

            # Convolution row: any type with conv layers.
            LoCon_row = LoRA_type in {'LoCon', 'Kohya DyLoRA', 'Kohya LoCon', 'LyCORIS/LoHa', 'LyCORIS/LoCon'}

            # Kohya advanced (per-block) settings: Kohya-native types only.
            LoRA_type_change = LoRA_type in {'Standard', 'Kohya DyLoRA', 'Kohya LoCon'}

            kohya_dylora_visible = LoRA_type == 'Kohya DyLoRA'

            # NOTE(review): targets are gr.Row components but the updates
            # are gr.Group.update — works in this Gradio version; confirm.
            return (
                gr.Group.update(visible=LoCon_row),
                gr.Group.update(visible=LoRA_type_change),
                gr.Group.update(visible=kohya_dylora_visible),
            )

        with gr.Row():
            max_resolution = gr.Textbox(
                label='Max resolution',
                value='512,512',
                placeholder='512,512',
                info='The maximum resolution of dataset images. W,H',
            )
            stop_text_encoder_training = gr.Slider(
                minimum=0,
                maximum=100,
                value=0,
                step=1,
                label='Stop text encoder training',
                info='After what % of steps should the text encoder stop being trained. 0 = train for all steps.',
            )
            enable_bucket = gr.Checkbox(label='Enable buckets', value=True,
                info='Allow non similar resolution dataset images to be trained on.',)

        with gr.Accordion('Advanced Configuration', open=False):
            # Kohya per-block weight/dim settings (Standard/Kohya types).
            with gr.Row(visible=True) as kohya_advanced_lora:
                with gr.Tab(label='Weights'):
                    with gr.Row(visible=True):
                        down_lr_weight = gr.Textbox(
                            label='Down LR weights',
                            placeholder='(Optional) eg: 0,0,0,0,0,0,1,1,1,1,1,1',
                            info='Specify the learning rate weight of the down blocks of U-Net.'
                        )
                        mid_lr_weight = gr.Textbox(
                            label='Mid LR weights',
                            placeholder='(Optional) eg: 0.5',
                            info='Specify the learning rate weight of the mid block of U-Net.'
                        )
                        up_lr_weight = gr.Textbox(
                            label='Up LR weights',
                            placeholder='(Optional) eg: 0,0,0,0,0,0,1,1,1,1,1,1',
                            info='Specify the learning rate weight of the up blocks of U-Net. The same as down_lr_weight.'
                        )
                        block_lr_zero_threshold = gr.Textbox(
                            label='Blocks LR zero threshold',
                            placeholder='(Optional) eg: 0.1',
                            info='If the weight is not more than this value, the LoRA module is not created. The default is 0.'
                        )
                with gr.Tab(label='Blocks'):
                    with gr.Row(visible=True):
                        block_dims = gr.Textbox(
                            label='Block dims',
                            placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2',
                            info='Specify the dim (rank) of each block. Specify 25 numbers.'
                        )
                        block_alphas = gr.Textbox(
                            label='Block alphas',
                            placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2',
                            info='Specify the alpha of each block. Specify 25 numbers as with block_dims. If omitted, the value of network_alpha is used.'
                        )
                with gr.Tab(label='Conv'):
                    with gr.Row(visible=True):
                        conv_dims = gr.Textbox(
                            label='Conv dims',
                            placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2',
                            info='Expand LoRA to Conv2d 3x3 and specify the dim (rank) of each block. Specify 25 numbers.'
                        )
                        conv_alphas = gr.Textbox(
                            label='Conv alphas',
                            placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2',
                            info='Specify the alpha of each block when expanding LoRA to Conv2d 3x3. Specify 25 numbers. If omitted, the value of conv_alpha is used.'
                        )
            with gr.Row():
                no_token_padding = gr.Checkbox(
                    label='No token padding', value=False
                )
                gradient_accumulation_steps = gr.Number(
                    label='Gradient accumulate steps', value='1'
                )
                weighted_captions = gr.Checkbox(
                    label='Weighted captions', value=False, info='Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder.',
                )
            with gr.Row():
                prior_loss_weight = gr.Number(
                    label='Prior loss weight', value=1.0
                )
                lr_scheduler_num_cycles = gr.Textbox(
                    label='LR number of cycles',
                    placeholder='(Optional) For Cosine with restart and polynomial only',
                )

                lr_scheduler_power = gr.Textbox(
                    label='LR power',
                    placeholder='(Optional) For Cosine with restart and polynomial only',
                )
            # Shared advanced-training options (shared component).
            (
                xformers,
                full_fp16,
                gradient_checkpointing,
                shuffle_caption,
                color_aug,
                flip_aug,
                clip_skip,
                mem_eff_attn,
                save_state,
                resume,
                max_token_length,
                max_train_epochs,
                max_data_loader_n_workers,
                keep_tokens,
                persistent_data_loader_workers,
                bucket_no_upscale,
                random_crop,
                bucket_reso_steps,
                caption_dropout_every_n_epochs,
                caption_dropout_rate,
                noise_offset,
                additional_parameters,
                vae_batch_size,
                min_snr_gamma,
            ) = gradio_advanced_training()
            # Color augmentation is incompatible with cached latents; this
            # handler unchecks/re-enables the cache_latents checkbox.
            color_aug.change(
                color_aug_changed,
                inputs=[color_aug],
                outputs=[cache_latents],
            )

        # Sample-image generation settings (shared component).
        (
            sample_every_n_steps,
            sample_every_n_epochs,
            sample_sampler,
            sample_prompts,
        ) = sample_gradio_config()

        LoRA_type.change(
            update_LoRA_settings, inputs=[LoRA_type], outputs=[LoCon_row, kohya_advanced_lora, kohya_dylora]
        )

    with gr.Tab('Tools'):
        gr.Markdown(
            'This section provide Dreambooth tools to help setup your dataset...'
        )
        gradio_dreambooth_folder_creation_tab(
            train_data_dir_input=train_data_dir,
            reg_data_dir_input=reg_data_dir,
            output_dir_input=output_dir,
            logging_dir_input=logging_dir,
        )
        gradio_dataset_balancing_tab()
        gradio_merge_lora_tab()
        gradio_svd_merge_lora_tab()
        gradio_resize_lora_tab()
        gradio_verify_lora_tab()

    button_run = gr.Button('Train model', variant='primary')

    button_print = gr.Button('Print training command')

    # TensorBoard start/stop buttons (shared component).
    button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()

    button_start_tensorboard.click(
        start_tensorboard,
        inputs=logging_dir,
        show_progress=False,
    )

    button_stop_tensorboard.click(
        stop_tensorboard,
        show_progress=False,
    )

    # Order MUST match the parameter order of save_configuration,
    # open_configuration and train_model (after their leading dummy args).
    settings_list = [
        pretrained_model_name_or_path,
        v2,
        v_parameterization,
        logging_dir,
        train_data_dir,
        reg_data_dir,
        output_dir,
        max_resolution,
        learning_rate,
        lr_scheduler,
        lr_warmup,
        train_batch_size,
        epoch,
        save_every_n_epochs,
        mixed_precision,
        save_precision,
        seed,
        num_cpu_threads_per_process,
        cache_latents,
        caption_extension,
        enable_bucket,
        gradient_checkpointing,
        full_fp16,
        no_token_padding,
        stop_text_encoder_training,
        xformers,
        save_model_as,
        shuffle_caption,
        save_state,
        resume,
        prior_loss_weight,
        text_encoder_lr,
        unet_lr,
        network_dim,
        lora_network_weights,
        color_aug,
        flip_aug,
        clip_skip,
        gradient_accumulation_steps,
        mem_eff_attn,
        output_name,
        model_list,
        max_token_length,
        max_train_epochs,
        max_data_loader_n_workers,
        network_alpha,
        training_comment,
        keep_tokens,
        lr_scheduler_num_cycles,
        lr_scheduler_power,
        persistent_data_loader_workers,
        bucket_no_upscale,
        random_crop,
        bucket_reso_steps,
        caption_dropout_every_n_epochs,
        caption_dropout_rate,
        optimizer,
        optimizer_args,
        noise_offset,
        LoRA_type,
        conv_dim,
        conv_alpha,
        sample_every_n_steps,
        sample_every_n_epochs,
        sample_sampler,
        sample_prompts,
        additional_parameters,
        vae_batch_size,
        min_snr_gamma,
        down_lr_weight,mid_lr_weight,up_lr_weight,block_lr_zero_threshold,block_dims,block_alphas,conv_dims,conv_alphas,
        weighted_captions, unit,
    ]

    # Open (save-as dialog) vs load (use current path) share one handler;
    # the hidden dummy Label selects the behavior.
    button_open_config.click(
        open_configuration,
        inputs=[dummy_db_true, config_file_name] + settings_list,
        outputs=[config_file_name] + settings_list + [LoCon_row],
        show_progress=False,
    )

    button_load_config.click(
        open_configuration,
        inputs=[dummy_db_false, config_file_name] + settings_list,
        outputs=[config_file_name] + settings_list + [LoCon_row],
        show_progress=False,
    )

    button_save_config.click(
        save_configuration,
        inputs=[dummy_db_false, config_file_name] + settings_list,
        outputs=[config_file_name],
        show_progress=False,
    )

    button_save_as_config.click(
        save_configuration,
        inputs=[dummy_db_true, config_file_name] + settings_list,
        outputs=[config_file_name],
        show_progress=False,
    )

    # Run vs print-only share train_model; dummy Label selects the mode.
    button_run.click(
        train_model,
        inputs=[dummy_db_false] + settings_list,
        show_progress=False,
    )

    button_print.click(
        train_model,
        inputs=[dummy_db_true] + settings_list,
        show_progress=False,
    )

    return (
        train_data_dir,
        reg_data_dir,
        output_dir,
        logging_dir,
    )
|
|
|
|
|
def UI(**kwargs):
    """Build the top-level Gradio app (LoRA + Utilities tabs) and launch it.

    Recognized keyword arguments: username/password (HTTP auth, enabled
    when username is non-empty), server_port (used when > 0), inbrowser,
    and listen (bind to 0.0.0.0 for LAN access; defaults to True).
    """
    css = ''

    # Optional user stylesheet, loaded from the working directory.
    if os.path.exists('./style.css'):
        with open('./style.css', 'r', encoding='utf8') as file:
            print('Load CSS...')
            css += file.read() + '\n'

    interface = gr.Blocks(css=css)

    with interface:
        with gr.Tab('LoRA'):
            (
                train_data_dir_input,
                reg_data_dir_input,
                output_dir_input,
                logging_dir_input,
            ) = lora_tab()
        with gr.Tab('Utilities'):
            utilities_tab(
                train_data_dir_input=train_data_dir_input,
                reg_data_dir_input=reg_data_dir_input,
                output_dir_input=output_dir_input,
                logging_dir_input=logging_dir_input,
                enable_copy_info_button=True,
            )

    launch_kwargs = {}
    username = kwargs.get('username', None)
    # Auth is enabled whenever a username was supplied (empty string means
    # "no auth"; note that an explicit None would also enable auth, as before).
    if username != '':
        launch_kwargs['auth'] = (username, kwargs.get('password', None))
    server_port = kwargs.get('server_port', 0)
    if server_port > 0:
        launch_kwargs['server_port'] = server_port
    inbrowser = kwargs.get('inbrowser', False)
    if inbrowser:
        launch_kwargs['inbrowser'] = inbrowser
    # NOTE(review): 'listen' defaults to True, so unless the caller passes
    # listen=False the server binds to 0.0.0.0 (reachable from the LAN).
    if kwargs.get('listen', True):
        launch_kwargs['server_name'] = '0.0.0.0'
    print(launch_kwargs)
    interface.launch(**launch_kwargs)
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: parse server/auth options and launch the UI.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--username', type=str, default='', help='Username for authentication'
    )
    parser.add_argument(
        '--password', type=str, default='', help='Password for authentication'
    )
    parser.add_argument(
        '--server_port',
        type=int,
        default=0,
        help='Port to run the server listener on',
    )
    parser.add_argument(
        '--inbrowser', action='store_true', help='Open in browser'
    )
    parser.add_argument(
        '--listen',
        action='store_true',
        help='Launch gradio with server name 0.0.0.0, allowing LAN access',
    )

    args = parser.parse_args()

    UI(
        username=args.username,
        password=args.password,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
        # Bug fix: --listen was parsed but never forwarded, so the flag had
        # no effect (UI fell back to its own default). Forward it so the
        # flag controls LAN binding as its help text documents.
        listen=args.listen,
    )
|
|