from __future__ import annotations

import datetime
import os
import pathlib
import shlex
import shutil
import subprocess

import gradio as gr
import PIL.Image
import slugify
import torch

from constants import UploadTarget


def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
    """Return *image* centred on a black square canvas.

    A square image is returned unchanged (the same object). Otherwise a new
    image of size ``max(w, h)`` in both dimensions is created in the same
    mode as *image* and the original is pasted in the middle of the shorter
    axis.

    Args:
        image: Image to pad.

    Returns:
        The original image if it is already square, else a new padded image.
    """
    w, h = image.size
    if w == h:
        return image
    side = max(w, h)
    # NOTE(review): the fill is an RGB triple, so this presumably only works
    # for modes that accept one; callers should convert to RGB first.
    canvas = PIL.Image.new(image.mode, (side, side), (0, 0, 0))
    canvas.paste(image, ((side - w) // 2, (side - h) // 2))
    return canvas


class Trainer:
    """Prepare training images and launch the DreamBooth LoRA training script."""

    def prepare_dataset(self, instance_images: list, resolution: int,
                        instance_data_dir: pathlib.Path) -> None:
        """Write the uploaded images as square RGB JPEGs into a clean directory.

        Any existing ``instance_data_dir`` is removed and recreated. Each
        input is converted to RGB, padded to a square with ``pad_image``,
        resized to ``resolution`` x ``resolution`` and saved as
        ``000.jpg``, ``001.jpg``, ...

        Args:
            instance_images: Objects whose ``.name`` attribute is the path of
                an image file (presumably Gradio upload temp files).
            resolution: Edge length, in pixels, of the saved images.
            instance_data_dir: Directory that receives the JPEG files.
        """
        shutil.rmtree(instance_data_dir, ignore_errors=True)
        instance_data_dir.mkdir(parents=True)
        for i, temp_path in enumerate(instance_images):
            # Close the source file promptly, and convert to RGB *before*
            # padding so that pad_image's (0, 0, 0) fill is always valid for
            # the image mode.
            with PIL.Image.open(temp_path.name) as source:
                image = source.convert('RGB')
            image = pad_image(image)
            image = image.resize((resolution, resolution))
            out_path = instance_data_dir / f'{i:03d}.jpg'
            image.save(out_path, format='JPEG', quality=100)

    def run(
        self,
        instance_images: list | None,
        instance_prompt: str,
        output_model_name: str,
        overwrite_existing_model: bool,
        validation_prompt: str,
        base_model: str,
        resolution_s: str,
        n_steps: int,
        learning_rate: float,
        gradient_accumulation: int,
        seed: int,
        fp16: bool,
        use_8bit_adam: bool,
        checkpointing_steps: int,
        use_wandb: bool,
        validation_epochs: int,
        upload_to_hub: bool,
        use_private_repo: bool,
        delete_existing_repo: bool,
        upload_to: str,
    ) -> str:
        """Validate the inputs, build the dataset and run the training script.

        The training command is ``accelerate launch train_dreambooth_lora.py``
        (resolved relative to the current working directory) and most
        arguments are forwarded to it as the command-line flag of the
        corresponding name. The command that was run is recorded in
        ``train.sh`` inside the output directory.

        Args:
            instance_images: Uploaded image files; must be non-empty.
            instance_prompt: Forwarded as ``--instance_prompt``; required.
            output_model_name: Name of the experiment. It is slugified; when
                empty (before or after slugifying) a timestamp is used.
            overwrite_existing_model: Remove an existing output directory of
                the same name before training.
            validation_prompt: Forwarded as ``--validation_prompt``; required.
            base_model: Forwarded as ``--pretrained_model_name_or_path``.
            resolution_s: Image resolution as a string; parsed with ``int``.
            n_steps: Forwarded as ``--max_train_steps``.
            learning_rate: Forwarded as ``--learning_rate``.
            gradient_accumulation: Forwarded as
                ``--gradient_accumulation_steps``.
            seed: Forwarded as ``--seed``.
            fp16: Adds ``--mixed_precision fp16``.
            use_8bit_adam: Adds ``--use_8bit_adam``.
            checkpointing_steps: Forwarded as ``--checkpointing_steps``.
            use_wandb: Adds ``--report_to wandb``.
            validation_epochs: Forwarded as ``--validation_epochs``.
            upload_to_hub: Adds ``--push_to_hub`` and the token read from the
                ``HF_TOKEN`` environment variable.
            use_private_repo: With ``upload_to_hub``, adds ``--private_repo``.
            delete_existing_repo: With ``upload_to_hub``, adds
                ``--delete_existing_repo``.
            upload_to: With ``upload_to_hub``, adds
                ``--upload_to_lora_library`` when equal to
                ``UploadTarget.LORA_LIBRARY.value``.

        Returns:
            The message ``'Training completed!'``.

        Raises:
            gr.Error: If CUDA is unavailable, a required input is missing,
                ``HF_TOKEN`` is unset while uploading, the output directory
                already exists and may not be overwritten, or the training
                process exits with a non-zero status.
        """
        if not torch.cuda.is_available():
            raise gr.Error('CUDA is not available.')
        if not instance_images:
            raise gr.Error('You need to upload images.')
        if not instance_prompt:
            raise gr.Error('The instance prompt is missing.')
        if not validation_prompt:
            raise gr.Error('The validation prompt is missing.')

        hf_token = None
        if upload_to_hub:
            # Fail before touching the filesystem; otherwise the literal
            # string "None" would be passed to the script as the token.
            hf_token = os.getenv('HF_TOKEN')
            if not hf_token:
                raise gr.Error(
                    'The HF_TOKEN environment variable is not set.')

        resolution = int(resolution_s)

        # A name made only of characters that slugify strips would become ''
        # and make output_dir the shared "experiments" directory itself, so
        # fall back to the timestamp in that case too.
        output_model_name = slugify.slugify(output_model_name or '')
        if not output_model_name:
            output_model_name = datetime.datetime.now().strftime(
                '%Y-%m-%d-%H-%M-%S')

        repo_dir = pathlib.Path(__file__).parent
        output_dir = repo_dir / 'experiments' / output_model_name
        if overwrite_existing_model or upload_to_hub:
            shutil.rmtree(output_dir, ignore_errors=True)
        if not upload_to_hub:
            try:
                output_dir.mkdir(parents=True)
            except FileExistsError as e:
                raise gr.Error(
                    f'A model named "{output_model_name}" already exists. '
                    'Allow overwriting or choose another name.') from e

        instance_data_dir = repo_dir / 'training_data' / output_model_name
        self.prepare_dataset(instance_images, resolution, instance_data_dir)

        # Build the argument vector directly instead of formatting a shell
        # string and re-splitting it: prompts and paths are passed verbatim
        # even when they contain quotes or whitespace.
        command = [
            'accelerate',
            'launch',
            'train_dreambooth_lora.py',
            f'--pretrained_model_name_or_path={base_model}',
            f'--instance_data_dir={instance_data_dir}',
            f'--output_dir={output_dir}',
            f'--instance_prompt={instance_prompt}',
            f'--resolution={resolution}',
            '--train_batch_size=1',
            f'--gradient_accumulation_steps={gradient_accumulation}',
            f'--learning_rate={learning_rate}',
            '--lr_scheduler=constant',
            '--lr_warmup_steps=0',
            f'--max_train_steps={n_steps}',
            f'--checkpointing_steps={checkpointing_steps}',
            f'--validation_prompt={validation_prompt}',
            f'--validation_epochs={validation_epochs}',
            f'--seed={seed}',
        ]
        if fp16:
            command += ['--mixed_precision', 'fp16']
        if use_8bit_adam:
            command.append('--use_8bit_adam')
        if use_wandb:
            command += ['--report_to', 'wandb']

        # Shell-quoted copy of the command for train.sh. It is kept separate
        # from ``command`` so the token can be replaced by a reference to the
        # environment variable rather than written to disk.
        script_words = [shlex.quote(word) for word in command]
        if upload_to_hub:
            command += ['--push_to_hub', '--hub_token', hf_token]
            script_words += ['--push_to_hub', '--hub_token', '"$HF_TOKEN"']
            hub_flags = []
            if use_private_repo:
                hub_flags.append('--private_repo')
            if delete_existing_repo:
                hub_flags.append('--delete_existing_repo')
            if upload_to == UploadTarget.LORA_LIBRARY.value:
                hub_flags.append('--upload_to_lora_library')
            command += hub_flags
            script_words += hub_flags

        result = subprocess.run(command)

        # When uploading, the directory is not created here beforehand, so
        # make sure it exists before recording the command.
        output_dir.mkdir(parents=True, exist_ok=True)
        with open(output_dir / 'train.sh', 'w', encoding='utf-8') as f:
            f.write(' '.join(script_words))

        if result.returncode != 0:
            raise gr.Error(
                f'Training failed (exit code {result.returncode}).')
        return 'Training completed!'