### Main entry point for the training config in Hydra.
### Only top-level configurations can be run; the full config is decomposed
### into multiple subfolders for better reusability.
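#
# Example launch (the script name and config name below are placeholders; any key
# in this file can be overridden from the CLI with standard Hydra dot-path syntax):
#   python train.py --config-name <this_config_name> seed=123 unit.optim_fn.lr=2.0e-5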

seed: 42

defaults:
  - constants: base
  - model: simple_vton_sd15
  - trainer: base
  - eval: base
  - datasets:
    - viton_hd_train
    - viton_hd_test
    - viton_hd_test_local
    - dress_code_train
    - dress_code_test
    - deepfashion_train
    - deepfashion_test
  - _self_

train_dataset: ${datasets.viton_hd_train}
# train_dataset: ${datasets.dress_code_train}
# train_dataset: ${datasets.deepfashion_train}
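# Select a different training dataset by uncommenting one of the alternatives
# above, or via a command-line override (entry point is hypothetical; quote the
# value so the interpolation is passed through to Hydra), e.g.:
#   python train.py 'train_dataset=${datasets.dress_code_train}'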
eval_dataset: null

unit:
  _target_: leffa.vton_unit.VtonUnit
  _partial_: True  # Hydra instantiates this as a functools.partial; the remaining arguments are supplied by the caller
  model: ${model}
  # strategy: ddp
  strategy:
    _target_: leffa.utils.create_fsdp_strategy
    sharding_strategy: SHARD_GRAD_OP  # shard gradients and optimizer state; parameters stay unsharded during forward/backward (roughly ZeRO-2)
    state_dict_type: SHARDED_STATE_DICT  # each rank checkpoints only its own shard
    mixed_precision:
      param_dtype: ${constants.precision}
      reduce_dtype: ${constants.precision}
      cast_forward_inputs: True
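    # class_paths presumably lists the module classes used as FSDP wrapping
    # boundaries (auto-wrap policy): each instance of these classes becomes its
    # own FSDP unit.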
    class_paths:
      # For VAE (first stage)
      - diffusers.models.unets.unet_2d_blocks.DownEncoderBlock2D
      - diffusers.models.unets.unet_2d_blocks.UNetMidBlock2D
      - diffusers.models.unets.unet_2d_blocks.UpDecoderBlock2D
      # For UNet (unet stage) IdmVton
      - leffa.models.diffusion_model.attentionhacked_tryon.BasicTransformerBlock
      - leffa.models.diffusion_model.attentionhacked_garment.BasicTransformerBlock
      # For UNet (unet stage) CatVton
      - diffusers.models.attention.BasicTransformerBlock
      # For CLIP (condition stage)
      - transformers.CLIPTextModel
      - transformers.CLIPTextModelWithProjection
      - transformers.CLIPVisionModelWithProjection
  optim_fn:
    _target_: torch.optim.AdamW
    _partial_: True
    lr: 1.0e-5
    betas: [0.9, 0.999]
    eps: 1.0e-8
    weight_decay: 1.0e-2
    amsgrad: False
  lr_scheduler_fn:
    _target_: torch.optim.lr_scheduler.ConstantLR
    _partial_: True
    factor: 1.0
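  # Weight averaging via torchtnt's SWAParams: configured as an EMA of the model
  # weights with decay 0.9999, updated at every step/epoch with no warmup.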
  swa_params:
    _target_: torchtnt.framework.auto_unit.SWAParams
    warmup_steps_or_epochs: 0
    step_or_epoch_update_freq: 1
    averaging_method: ema
    ema_decay: 0.9999
    use_lit: True
  precision: ${constants.precision}
  clip_grad_norm: 1.0

umm_metadata:
  model_type_name: ads_genads_ldm
  model_series_name: ads_genads_ldm
  oncall: ai_genads

checkpoint:
  checkpoint_dir: null
  checkpoint_path: null
  checkpoint_every_n_steps: ${constants.checkpoint_every_n_steps}