# File: Leffa / leffa / conf / train_mae.yaml
# Last commit: "init code" (b213d84) by franciszzj; file size 5.35 kB.
# Shared values that the other sections of this file pull in through
# `${constants.*}` interpolations. The keys must be nested under `constants`
# for those interpolations to resolve.
constants:
  img_size: 512
  batch_size: 16
  precision: fp32
  max_epochs: 1000
  max_steps: null
  max_train_steps_per_epoch: null
  evaluate_every_n_train_steps: null
  evaluate_every_n_train_epochs: 10
  max_eval_steps_per_eval_epoch: null
  # NOTE(review): not referenced by any interpolation in the visible part of
  # this file.
  use_torchsnapshot: false
  checkpoint_every_n_steps: 500
# Model definition in Hydra style: `_target_` names the class and the sibling
# keys are its settings. The unit section refers to this mapping as `${model}`.
model:
  _target_: leffa.models.idm_vton_model.Mae4BgGen
  img_size: ${constants.img_size}
  patch_size: 16
  embed_dim: 1024
  depth: 24
  num_heads: 16
  # Alternative kept for reference: start from a pretrained checkpoint.
  # pretrained_path: manifold://genads_models/tree/zijianzhou/model/mae/mae_pretrain_vit_large.pth
  pretrained_path: null
  bg_masking_type: min
# Training-loop settings. Epoch/step limits, checkpoint cadence, and the
# evaluation schedule are all taken from `constants`.
trainer:
  max_epochs: ${constants.max_epochs}
  max_steps: ${constants.max_steps}
  max_train_steps_per_epoch: ${constants.max_train_steps_per_epoch}
  checkpoint_every_n_steps: ${constants.checkpoint_every_n_steps}
  model_entity_id: null
  resume_from_last_ckpt: true
  model_store_checkpoint_version: null
  garbage_collector_interval: 5001
  pretrained_weights: null
  log_dir: manifold://fblearner_flow_run_metrics/tree/torchmultimodal/idm_vton/logs/
  use_pt2: false
  memory_snapshot: false
  # NOTE(review): the original indentation was lost; `eval` is reconstructed
  # as a child of `trainer`. Confirm against the code that consumes this
  # config (it may instead be a top-level key).
  eval:
    warmup_iters: 0
    evaluate_every_n_train_steps: ${constants.evaluate_every_n_train_steps}
    evaluate_every_n_train_epochs: ${constants.evaluate_every_n_train_epochs}
    max_eval_steps_per_eval_epoch: ${constants.max_eval_steps_per_eval_epoch}
# Dataset definitions. Each entry holds a `dataset` (source, enrichments,
# transforms) and a `dataloader` that wraps that dataset.
# NOTE(review): the original indentation was lost. The paths
# `datasets.<name>.dataset` and `datasets.<name>` are fixed by interpolations
# in this file; the placement of `collate_fn` under `dataset` is a best
# reconstruction and should be confirmed against EnrichingDataset's signature.
datasets:
  mae_train:
    dataset:
      _target_: media_dataloader.api.EnrichingDataset
      datasource:
        _target_: media_dataloader.api.LazyHiveDataSource
        namespace: ad_metrics
        table: hybrid_3_0_1st_shein_data
        partition_filter_predicate_list:
          - ds = '2024-07-20'
      enrichments:
        - _target_: media_dataloader.api.media_lookups.EverstoreLookups
          lookup_handle_to_media_columns:
            everstore_handle: "image"
        - _target_: media_dataloader.api.media_lookups.ManifoldLookups
          lookup_handle_to_media_columns:
            binary_mask_manifold_path: bg_mask
      collate_fn:
        - _target_: media_dataloader.api.Collate
        - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
          image_field: image
          blob_field: image
        - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
          image_field: bg_mask
          blob_field: bg_mask
        - _target_: leffa.datasets.transform.MaeTransform
          input_size: ${constants.img_size}
          is_train: true
    dataloader:
      _target_: media_dataloader.api.StatefulDataLoader
      dataset: ${datasets.mae_train.dataset}
      batch_size: ${constants.batch_size}
      num_workers: 8
      prefetch_factor: 2
      pin_memory: true
      persistent_workers: true
      multiprocessing_context: forkserver
  mae_test:
    # Same source table and partition as mae_train; only the final transform
    # differs (is_train: false).
    dataset:
      _target_: media_dataloader.api.EnrichingDataset
      datasource:
        _target_: media_dataloader.api.LazyHiveDataSource
        namespace: ad_metrics
        table: hybrid_3_0_1st_shein_data
        partition_filter_predicate_list:
          - ds = '2024-07-20'
      enrichments:
        - _target_: media_dataloader.api.media_lookups.EverstoreLookups
          lookup_handle_to_media_columns:
            everstore_handle: "image"
        - _target_: media_dataloader.api.media_lookups.ManifoldLookups
          lookup_handle_to_media_columns:
            binary_mask_manifold_path: bg_mask
      collate_fn:
        - _target_: media_dataloader.api.Collate
        - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
          image_field: image
          blob_field: image
        - _target_: torchmultimodal.fb.genai.transforms.hive_transforms.EverstoreImageToPILTransform
          image_field: bg_mask
          blob_field: bg_mask
        - _target_: leffa.datasets.transform.MaeTransform
          input_size: ${constants.img_size}
          is_train: false
    dataloader:
      _target_: media_dataloader.api.StatefulDataLoader
      # Wraps this entry's own dataset so the is_train: false transform above
      # is the one actually used when mae_test is selected.
      dataset: ${datasets.mae_test.dataset}
      batch_size: ${constants.batch_size}
      # Worker-related options are all switched off for this loader.
      num_workers: 0
      prefetch_factor: null
      pin_memory: true
      persistent_workers: false
      multiprocessing_context: null
# NOTE(review): presumably the run-wide random seed; confirm against the launcher.
seed: 42
# Training data: interpolates the `datasets.mae_train` entry.
train_dataset: ${datasets.mae_train}
# Evaluation data is currently unset (null). The commented line below is the
# alternative that interpolates the `datasets.mae_test` entry.
eval_dataset: null
# eval_dataset: ${datasets.mae_test}
# Training unit in Hydra style. `_partial_: true` entries are built as
# partials, to be completed by the caller with the remaining arguments.
# NOTE(review): the original indentation was lost; children are grouped by
# the constructor each `_target_` names. Confirm `use_lit` belongs to
# SWAParams and `precision` / `clip_grad_norm` belong to the unit itself.
unit:
  _target_: leffa.vton_unit.VtonUnit
  _partial_: true
  model: ${model}
  strategy: ddp
  # Alternative kept for reference: FSDP strategy instead of DDP.
  # strategy:
  #   _target_: leffa.utils.create_fsdp_strategy
  #   sharding_strategy: FULL_SHARD
  #   state_dict_type: SHARDED_STATE_DICT
  #   class_paths:
  #     - leffa.models.idm_vton_model.MaskedAutoencoderViT
  optim_fn:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 1.0e-05
    betas:
      - 0.9
      - 0.999
    eps: 1.0e-08
    weight_decay: 0.01
    amsgrad: false
  lr_scheduler_fn:
    _target_: torch.optim.lr_scheduler.ConstantLR
    _partial_: true
    # A factor of 1.0 leaves the optimizer's learning rate unscaled.
    factor: 1.0
  swa_params:
    _target_: torchtnt.framework.auto_unit.SWAParams
    warmup_steps_or_epochs: 0
    step_or_epoch_update_freq: 1
    averaging_method: ema
    ema_decay: 0.9999
    use_lit: true
  precision: ${constants.precision}
  clip_grad_norm: 1.0
# Metadata block: model type/series names and the owning oncall.
# NOTE(review): the consumer of this block is not visible in this file.
umm_metadata:
  model_type_name: ads_genads_ldm
  model_series_name: ads_genads_ldm
  oncall: ai_genads
# Checkpoint settings. Directory and path are unset (null) here; the save
# cadence is taken from `constants`.
checkpoint:
  checkpoint_dir: null
  checkpoint_path: null
  checkpoint_every_n_steps: ${constants.checkpoint_every_n_steps}