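# Training configuration for the MAE-based background-generation model (Mae4BgGen)
# in the leffa / IDM-VTON codebase. Sections use OmegaConf-style ${...} interpolation
# and Hydra-style `_target_` instantiation; the `constants` block holds the knobs
# shared across the rest of the file.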
constants:
  img_size: 512
  batch_size: 16
  precision: fp32
  max_epochs: 1000
  max_steps: null
  max_train_steps_per_epoch: null
  evaluate_every_n_train_steps: null
  evaluate_every_n_train_epochs: 10
  max_eval_steps_per_eval_epoch: null
  use_torchsnapshot: false
  checkpoint_every_n_steps: 500
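# The backbone uses ViT-Large geometry (patch 16, embed_dim 1024, depth 24, 16 heads),
# matching the MAE ViT-L checkpoint referenced in the commented pretrained_path below;
# with pretrained_path left null the model trains from scratch.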
model:
  _target_: leffa.models.idm_vton_model.Mae4BgGen
  img_size: ${constants.img_size}
  patch_size: 16
  embed_dim: 1024
  depth: 24
  num_heads: 16
  # pretrained_path: manifold://genads_models/tree/zijianzhou/model/mae/mae_pretrain_vit_large.pth
  pretrained_path: null
  bg_masking_type: min
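# Trainer settings: run length and checkpoint cadence are interpolated from `constants`;
# resume_from_last_ckpt presumably lets a restarted job continue from its most recent
# checkpoint.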
trainer:
  max_epochs: ${constants.max_epochs}
  max_steps: ${constants.max_steps}
  max_train_steps_per_epoch: ${constants.max_train_steps_per_epoch}
  checkpoint_every_n_steps: ${constants.checkpoint_every_n_steps}
  model_entity_id: null
  resume_from_last_ckpt: true
  model_store_checkpoint_version: null
  garbage_collector_interval: 5001
  pretrained_weights: null
  log_dir: manifold://fblearner_flow_run_metrics/tree/torchmultimodal/idm_vton/logs/
  use_pt2: false
  memory_snapshot: false
eval:
  warmup_iters: 0
  evaluate_every_n_train_steps: ${constants.evaluate_every_n_train_steps}
  evaluate_every_n_train_epochs: ${constants.evaluate_every_n_train_epochs}
  max_eval_steps_per_eval_epoch: ${constants.max_eval_steps_per_eval_epoch}
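# Both splits read the same Hive table and partition (ds = '2024-07-20'): the Everstore
# lookup fetches the image bytes and the Manifold lookup fetches the binary background
# mask, the collate chain decodes both to PIL, and MaeTransform prepares them at
# ${constants.img_size}.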
datasets:
  mae_train:
    dataset:
      _target_: media_dataloader.api.EnrichingDataset
      datasource:
        _target_: media_dataloader.api.LazyHiveDataSource
        namespace: ad_metrics
        table: hybrid_3_0_1st_shein_data
        partition_filter_predicate_list:
        - ds = '2024-07-20'
      enrichments:
      - _target_: media_dataloader.api.media_lookups.EverstoreLookups
        lookup_handle_to_media_columns:
          everstore_handle: "image"
      - _target_: media_dataloader.api.media_lookups.ManifoldLookups
        lookup_handle_to_media_columns:
          binary_mask_manifold_path: bg_mask
      collate_fn:
      - _target_: media_dataloader.api.Collate
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: image
        blob_field: image
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: bg_mask
        blob_field: bg_mask
      - _target_: leffa.datasets.transform.MaeTransform
        input_size: ${constants.img_size}
        is_train: true
    dataloader:
      _target_: media_dataloader.api.StatefulDataLoader
      dataset: ${datasets.mae_train.dataset}
      batch_size: ${constants.batch_size}
      num_workers: 8
      prefetch_factor: 2
      pin_memory: true
      persistent_workers: true
      multiprocessing_context: forkserver
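  # mae_test mirrors mae_train (same table and partition) but applies the eval transform
  # (is_train: false) and uses a single-process dataloader.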
  mae_test:
    dataset:
      _target_: media_dataloader.api.EnrichingDataset
      datasource:
        _target_: media_dataloader.api.LazyHiveDataSource
        namespace: ad_metrics
        table: hybrid_3_0_1st_shein_data
        partition_filter_predicate_list:
        - ds = '2024-07-20'
      enrichments:
      - _target_: media_dataloader.api.media_lookups.EverstoreLookups
        lookup_handle_to_media_columns:
          everstore_handle: "image"
      - _target_: media_dataloader.api.media_lookups.ManifoldLookups
        lookup_handle_to_media_columns:
          binary_mask_manifold_path: bg_mask
      collate_fn:
      - _target_: media_dataloader.api.Collate
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: image
        blob_field: image
      - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
        image_field: bg_mask
        blob_field: bg_mask
      - _target_: leffa.datasets.transform.MaeTransform
        input_size: ${constants.img_size}
        is_train: false
    dataloader:
      _target_: media_dataloader.api.StatefulDataLoader
      dataset: ${datasets.mae_test.dataset}
      batch_size: ${constants.batch_size}
      num_workers: 0
      prefetch_factor: null
      pin_memory: true
      persistent_workers: false
      multiprocessing_context: null
seed: 42
train_dataset: ${datasets.mae_train}
eval_dataset: null
# eval_dataset: ${datasets.mae_test}
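# Training unit: plain DDP by default, with an FSDP alternative left commented out below;
# optimization is AdamW (lr 1e-5) with a constant LR schedule and an EMA weight average
# (decay 0.9999) maintained via SWAParams.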
unit:
  _target_: leffa.vton_unit.VtonUnit
  _partial_: true
  model: ${model}
  strategy: ddp
  # strategy:
  #   _target_: leffa.utils.create_fsdp_strategy
  #   sharding_strategy: FULL_SHARD
  #   state_dict_type: SHARDED_STATE_DICT
  #   class_paths:
  #   - leffa.models.idm_vton_model.MaskedAutoencoderViT
  optim_fn:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 1.0e-05
    betas:
    - 0.9
    - 0.999
    eps: 1.0e-08
    weight_decay: 0.01
    amsgrad: false
  lr_scheduler_fn:
    _target_: torch.optim.lr_scheduler.ConstantLR
    _partial_: true
    factor: 1.0
  swa_params:
    _target_: torchtnt.framework.auto_unit.SWAParams
    warmup_steps_or_epochs: 0
    step_or_epoch_update_freq: 1
    averaging_method: ema
    ema_decay: 0.9999
    use_lit: true
  precision: ${constants.precision}
  clip_grad_norm: 1.0
umm_metadata:
  model_type_name: ads_genads_ldm
  model_series_name: ads_genads_ldm
  oncall: ai_genads
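# checkpoint_dir and checkpoint_path are left null here; presumably they are filled in
# per run by the launcher or an override.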
checkpoint:
  checkpoint_dir: null
  checkpoint_path: null
  checkpoint_every_n_steps: ${constants.checkpoint_every_n_steps}
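
# Example (hypothetical) Hydra-style command-line overrides for this config, e.g. to
# warm-start from the MAE ViT-L checkpoint and enable evaluation on the test split:
#   model.pretrained_path=manifold://genads_models/tree/zijianzhou/model/mae/mae_pretrain_vit_large.pth \
#   eval_dataset='${datasets.mae_test}' \
#   constants.batch_size=8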