# Repository-viewer residue (not part of the config): Leffa / leffa/conf/train.yaml | franciszzj's picture | init code | b213d84 | raw | history blame | 2.44 kB
### Main entry point for the Hydra training config.
### Only top-level configurations can be run directly; the full config is
### decomposed into multiple subfolders for better reusability.
# Random seed for the run (how it is consumed is not visible in this file).
seed: 42

# Hydra defaults list: one option is selected from each config group, and
# several options at once from the `datasets` group.
defaults:
  - constants: base
  - model: simple_vton_sd15
  - trainer: base
  - eval: base
  # Multiple options from one group must be a list nested under the group
  # key; each is then referenced below as ${datasets.<name>}.
  - datasets:
      - viton_hd_train
      - viton_hd_test
      - viton_hd_test_local
      - dress_code_train
      - dress_code_test
      - deepfashion_train
      - deepfashion_test
  # Listed last so that values defined in this file are composed after
  # (and therefore override) the groups above.
  - _self_
# Training dataset: interpolates the `viton_hd_train` entry of the `datasets`
# group loaded in `defaults`. The commented-out lines below are the
# alternative choices (dress_code_train, deepfashion_train).
train_dataset: ${datasets.viton_hd_train}
# train_dataset: ${datasets.dress_code_train}
# train_dataset: ${datasets.deepfashion_train}
# No evaluation dataset is selected in this config.
eval_dataset: null
# Training unit. `_target_` names the class Hydra instantiates; with
# `_partial_: true` Hydra returns a partial, so any remaining constructor
# arguments are supplied by the caller.
# NOTE(review): nesting below was reconstructed from key semantics; confirm
# against the consumer (leffa.vton_unit.VtonUnit) before relying on it.
unit:
  _target_: leffa.vton_unit.VtonUnit
  _partial_: true
  model: ${model}

  # strategy: ddp
  # Strategy object built by leffa.utils.create_fsdp_strategy.
  strategy:
    _target_: leffa.utils.create_fsdp_strategy
    sharding_strategy: SHARD_GRAD_OP
    state_dict_type: SHARDED_STATE_DICT
    # Both dtypes follow the precision chosen in the `constants` group.
    mixed_precision:
      param_dtype: ${constants.precision}
      reduce_dtype: ${constants.precision}
      cast_forward_inputs: true
    # Fully-qualified module classes handed to the strategy factory
    # (presumably the FSDP wrapping units; verify in create_fsdp_strategy).
    class_paths:
      # For VAE (first stage)
      - diffusers.models.unets.unet_2d_blocks.DownEncoderBlock2D
      - diffusers.models.unets.unet_2d_blocks.UNetMidBlock2D
      - diffusers.models.unets.unet_2d_blocks.UpDecoderBlock2D
      # For UNet (unet stage) IdmVton
      - leffa.models.diffusion_model.attentionhacked_tryon.BasicTransformerBlock
      - leffa.models.diffusion_model.attentionhacked_garment.BasicTransformerBlock
      # For UNet (unet stage) CatVton
      - diffusers.models.attention.BasicTransformerBlock
      # For CLIP (condition stage)
      - transformers.CLIPTextModel
      - transformers.CLIPTextModelWithProjection
      - transformers.CLIPVisionModelWithProjection

  # Optimizer factory: a partial of torch.optim.AdamW, so the parameters to
  # optimize are passed in later by the caller.
  optim_fn:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 1.0e-5
    betas: [0.9, 0.999]
    eps: 1.0e-8
    weight_decay: 1.0e-2
    amsgrad: false

  # LR scheduler factory: ConstantLR with factor 1.0, i.e. the learning rate
  # is left unscaled.
  lr_scheduler_fn:
    _target_: torch.optim.lr_scheduler.ConstantLR
    _partial_: true
    factor: 1.0

  # Weight averaging: EMA with decay 0.9999, starting at step/epoch 0 and
  # updated every step/epoch.
  swa_params:
    _target_: torchtnt.framework.auto_unit.SWAParams
    warmup_steps_or_epochs: 0
    step_or_epoch_update_freq: 1
    averaging_method: ema
    ema_decay: 0.9999
    use_lit: true

  precision: ${constants.precision}
  clip_grad_norm: 1.0

  # NOTE(review): placed under `unit` because it directly follows the unit
  # arguments in the original; confirm it is not meant to be top-level.
  umm_metadata:
    model_type_name: ads_genads_ldm
    model_series_name: ads_genads_ldm
    oncall: ai_genads
# Checkpoint settings. Directory and path are left unset (null) here; the
# save interval is taken from the `constants` group.
checkpoint:
  checkpoint_dir: null
  checkpoint_path: null
  checkpoint_every_n_steps: ${constants.checkpoint_every_n_steps}