import gc
import json
import os
import random
import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import Callable, Dict, Optional, Tuple

import numpy as np
import torch
from diffusers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
)
from PIL import Image, PngImagePlugin

MAX_SEED = np.iinfo(np.int32).max


@dataclass
class StyleConfig:
    """Positive/negative prompt pair describing a named prompt style."""

    prompt: str
    negative_prompt: str


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


def seed_everything(seed: int) -> torch.Generator:
    """Seed torch, CUDA, and NumPy, then return a seeded CPU generator."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator
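
# Usage sketch (hedged): `pipe` stands for an already-loaded diffusers pipeline and is
# not defined in this module; only the seed helpers above are.
#
#     seed = randomize_seed_fn(seed=0, randomize_seed=True)
#     image = pipe(prompt, generator=seed_everything(seed)).images[0]

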
def parse_aspect_ratio(aspect_ratio: str) -> Optional[Tuple[int, int]]:
    if aspect_ratio == "Custom":
        return None
    width, height = aspect_ratio.split(" x ")
    return int(width), int(height)


def aspect_ratio_handler(
    aspect_ratio: str, custom_width: int, custom_height: int
) -> Tuple[int, int]:
    if aspect_ratio == "Custom":
        return custom_width, custom_height
    width, height = parse_aspect_ratio(aspect_ratio)
    return width, height


def get_scheduler(scheduler_config: Dict, name: str) -> Optional[Callable]:
    """Build the scheduler matching a UI name; unknown names yield None."""
    scheduler_factory_map = {
        "DPM++ 2M Karras": lambda: DPMSolverMultistepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True
        ),
        "DPM++ SDE Karras": lambda: DPMSolverSinglestepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True
        ),
        "DPM++ 2M SDE Karras": lambda: DPMSolverMultistepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
        ),
        "Euler": lambda: EulerDiscreteScheduler.from_config(scheduler_config),
        "Euler a": lambda: EulerAncestralDiscreteScheduler.from_config(
            scheduler_config
        ),
        "DDIM": lambda: DDIMScheduler.from_config(scheduler_config),
    }
    return scheduler_factory_map.get(name, lambda: None)()
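
# Usage sketch (hedged): swapping the scheduler on an already-loaded pipeline (`pipe`
# is assumed to exist elsewhere); the name strings are the keys defined above.
#
#     pipe.scheduler = get_scheduler(pipe.scheduler.config, "DPM++ 2M Karras")

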
def free_memory() -> None:
    # Collect dead Python references first so their CUDA blocks can actually be released.
    gc.collect()
    torch.cuda.empty_cache()


def preprocess_prompt(
    style_dict: Dict[str, Tuple[str, str]],
    style_name: str,
    positive: str,
    negative: str = "",
    add_style: bool = True,
) -> Tuple[str, str]:
    """Apply the named style template to the prompt and merge the negative prompts."""
    p, n = style_dict.get(style_name, style_dict["(None)"])

    if add_style and positive.strip():
        formatted_positive = p.format(prompt=positive)
    else:
        formatted_positive = positive

    combined_negative = n
    if negative.strip():
        if combined_negative:
            combined_negative += ", " + negative
        else:
            combined_negative = negative

    return formatted_positive, combined_negative
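
# Usage sketch (hedged): style entries are assumed to be (positive_template, negative)
# pairs keyed by style name, with "{prompt}" as the placeholder; the entries below are
# illustrative only.
#
#     styles = {
#         "(None)": ("{prompt}", ""),
#         "Cinematic": ("cinematic still, {prompt}", "cartoon, 3d render"),
#     }
#     pos, neg = preprocess_prompt(styles, "Cinematic", "1girl, night city", "lowres")

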
def common_upscale(
    samples: torch.Tensor,
    width: int,
    height: int,
    upscale_method: str,
) -> torch.Tensor:
    """Resize an (N, C, H, W) tensor to (height, width) via torch interpolate."""
    return torch.nn.functional.interpolate(
        samples, size=(height, width), mode=upscale_method
    )


def upscale(
    samples: torch.Tensor, upscale_method: str, scale_by: float
) -> torch.Tensor:
    """Scale the spatial dimensions of an (N, C, H, W) tensor by scale_by."""
    width = round(samples.shape[3] * scale_by)
    height = round(samples.shape[2] * scale_by)
    return common_upscale(samples, width, height, upscale_method)
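
# Usage sketch (hedged): a 1.5x latent upscale before a second img2img/hires pass.
# "nearest-exact" is one of torch.nn.functional.interpolate's supported modes.
#
#     upscaled_latents = upscale(latents, "nearest-exact", 1.5)

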
def load_wildcard_files(wildcard_dir: str) -> Dict[str, str]:
    """Map "__name__" tokens to the wildcard .txt files found in wildcard_dir."""
    wildcard_files = {}
    for file in os.listdir(wildcard_dir):
        if file.endswith(".txt"):
            # splitext keeps the full stem even for dotted filenames (e.g. "my.chars.txt").
            key = f"__{os.path.splitext(file)[0]}__"
            wildcard_files[key] = os.path.join(wildcard_dir, file)
    return wildcard_files


def get_random_line_from_file(file_path: str) -> str:
    with open(file_path, "r", encoding="utf-8") as file:
        # Skip blank lines so stray newlines in a wildcard file never reach the prompt.
        lines = [line.strip() for line in file if line.strip()]
    if not lines:
        return ""
    return random.choice(lines)


def add_wildcard(prompt: str, wildcard_files: Dict[str, str]) -> str:
    """Replace each "__wildcard__" token in the prompt with a random line from its file."""
    for key, file_path in wildcard_files.items():
        if key in prompt:
            wildcard_line = get_random_line_from_file(file_path)
            prompt = prompt.replace(key, wildcard_line)
    return prompt
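
# Usage sketch (hedged): assuming a "wildcards/" directory that contains "character.txt",
# every "__character__" token in the prompt is swapped for a random line from that file.
#
#     wildcards = load_wildcard_files("wildcards")
#     prompt = add_wildcard("1girl, __character__, night city", wildcards)

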
def preprocess_image_dimensions(width: int, height: int) -> Tuple[int, int]:
    """Round both dimensions down to the nearest multiple of 8."""
    width -= width % 8
    height -= height % 8
    return width, height


def save_image(
    image: Image.Image, metadata: Dict, output_dir: str, is_colab: bool
) -> str:
    """Save the image as a PNG with metadata embedded in a text chunk; return the path."""
    if is_colab:
        current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"image_{current_time}.png"
    else:
        filename = str(uuid.uuid4()) + ".png"
    os.makedirs(output_dir, exist_ok=True)
    filepath = os.path.join(output_dir, filename)
    metadata_str = json.dumps(metadata)
    info = PngImagePlugin.PngInfo()
    info.add_text("metadata", metadata_str)
    image.save(filepath, "PNG", pnginfo=info)
    return filepath
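
# Usage sketch (hedged): the embedded metadata can be read back from the PNG's text
# chunks with Pillow, e.g.:
#
#     path = save_image(image, {"prompt": "1girl", "seed": 42}, "./outputs", is_colab=False)
#     metadata = json.loads(Image.open(path).info["metadata"])

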
def is_google_colab() -> bool:
    try:
        import google.colab  # noqa: F401  (imported only to detect the Colab runtime)
        return True
    except ImportError:
        return False