# NeTI / config.py — configuration dataclasses for training and evaluation.
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Dict
from constants import VALIDATION_PROMPTS
from utils.types import PESigmas
@dataclass
class LogConfig:
    """ Parameters for logging and saving """
    # Name of experiment. This will be the name of the output folder
    exp_name: str
    # The output directory where the model predictions and checkpoints will be written
    exp_dir: Path = Path("./outputs")
    # Save a checkpoint every `save_steps` training steps
    save_steps: int = 250
    # [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to
    # `output_dir/runs/**CURRENT_DATETIME_HOSTNAME`
    logging_dir: Path = Path("logs")
    # The integration to report the results to. Supported platforms are "tensorboard"
    # (default), "wandb" and "comet_ml". Use "all" to report to all integrations.
    report_to: str = "tensorboard"
    # Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator`
    checkpoints_total_limit: Optional[int] = None
@dataclass
class DataConfig:
    """ Parameters for data """
    # A folder containing the training data
    train_data_dir: Path
    # A token to use as a placeholder for the concept
    placeholder_token: str
    # Super category token to use for normalizing the mapper output
    super_category_token: Optional[str] = "object"
    # Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process
    dataloader_num_workers: int = 8
    # Choose between 'object' and 'style' - used for selecting the prompts for training
    learnable_property: str = "object"
    # How many times to repeat the training data
    repeats: int = 100
    # The resolution for input images, all the images in the train/validation dataset will be resized to this resolution
    resolution: int = 512
    # Whether to center crop images before resizing to resolution
    center_crop: bool = False
@dataclass
class ModelConfig:
    """ Parameters for defining all models """
    # Path to pretrained model or model identifier from huggingface.co/models
    pretrained_model_name_or_path: str = "CompVis/stable-diffusion-v1-4"
    # Whether to use our Nested Dropout technique
    use_nested_dropout: bool = True
    # Probability to apply nested dropout during training
    nested_dropout_prob: float = 0.5
    # Whether to normalize the norm of the mapper's output vector
    normalize_mapper_output: bool = True
    # Target norm for the mapper's output vector
    target_norm: Optional[float] = None
    # Whether to use positional encoding over the input to the mapper
    use_positional_encoding: bool = True
    # Sigmas used for computing positional encoding: 'sigma_t' for timesteps, 'sigma_l' for layers.
    # NOTE: converted from a dict into a PESigmas object by __post_init__.
    pe_sigmas: Dict[str, float] = field(default_factory=lambda: {'sigma_t': 0.03, 'sigma_l': 2.0})
    # Number of time anchors for computing our positional encodings
    num_pe_time_anchors: int = 10
    # Whether to output the textual bypass vector
    output_bypass: bool = True
    # Revision of pretrained model identifier from huggingface.co/models
    revision: Optional[str] = None
    # Whether training should be resumed from a previous checkpoint.
    mapper_checkpoint_path: Optional[Path] = None

    def __post_init__(self):
        """ Validate the provided sigmas and wrap them in a PESigmas object. """
        if self.pe_sigmas is not None:
            # Fixed typo in the original message ("one for two") — the two sigmas are time and layers.
            assert len(self.pe_sigmas) == 2, "Should provide exactly two sigma values: one for time and one for layers!"
            self.pe_sigmas = PESigmas(sigma_t=self.pe_sigmas['sigma_t'], sigma_l=self.pe_sigmas['sigma_l'])
@dataclass
class EvalConfig:
    """ Parameters for validation """
    # Prompts used during validation to verify that the model is learning the concept
    validation_prompts: List[str] = field(default_factory=lambda: VALIDATION_PROMPTS)
    # How many images to generate per validation prompt
    num_validation_images: int = 4
    # Seeds to use for generating the validation images
    validation_seeds: Optional[List[int]] = field(default_factory=lambda: [42, 420, 501, 5456])
    # Run validation every X steps.
    validation_steps: int = 100
    # Number of denoising steps
    num_denoising_steps: int = 50

    def __post_init__(self):
        """ Resolve default seeds and check that each validation image has its own seed. """
        seeds = self.validation_seeds
        if seeds is None:
            # No explicit seeds given: fall back to deterministic seeds 0..N-1
            seeds = [idx for idx in range(self.num_validation_images)]
            self.validation_seeds = seeds
        assert len(seeds) == self.num_validation_images, \
            "Length of validation_seeds should equal num_validation_images"
@dataclass
class OptimConfig:
    """ Parameters for the optimization process """
    # Total number of training steps to perform.
    max_train_steps: Optional[int] = 1_000
    # Learning rate
    learning_rate: float = 1e-3
    # Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size
    scale_lr: bool = True
    # Batch size (per device) for the training dataloader
    train_batch_size: int = 2
    # Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass
    gradient_checkpointing: bool = False
    # Number of updates steps to accumulate before performing a backward/update pass
    gradient_accumulation_steps: int = 4
    # A seed for reproducible training
    seed: Optional[int] = None
    # The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",
    # "constant", "constant_with_warmup"]
    lr_scheduler: str = "constant"
    # Number of steps for the warmup in the lr scheduler
    lr_warmup_steps: int = 0
    # The beta1 parameter for the Adam optimizer
    adam_beta1: float = 0.9
    # The beta2 parameter for the Adam optimizer
    adam_beta2: float = 0.999
    # Weight decay to use
    adam_weight_decay: float = 1e-2
    # Epsilon value for the Adam optimizer
    adam_epsilon: float = 1e-08
    # Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10.
    # and an Nvidia Ampere GPU.
    mixed_precision: str = "no"
    # Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    allow_tf32: bool = False
@dataclass
class RunConfig:
    """ The main configuration for the coach trainer """
    # NOTE(review): LogConfig and DataConfig have required fields (exp_name, train_data_dir,
    # placeholder_token), so these default factories raise TypeError if RunConfig() is built
    # bare — presumably the sub-configs are populated by a dataclass-based config parser
    # (e.g. pyrallis) rather than constructed with defaults; confirm against the trainer entry point.
    log: LogConfig = field(default_factory=LogConfig)
    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    eval: EvalConfig = field(default_factory=EvalConfig)
    optim: OptimConfig = field(default_factory=OptimConfig)