# Spaces: Running on Zero
def create_deepspeed_config(args):
    """Build a DeepSpeed configuration dict from parsed training arguments.

    Args:
        args: Namespace-like object. Attributes read:
            ``global_batch_size``, ``gradient_accumulation_steps``, ``lr``,
            ``weight_decay``, ``beta1``, ``beta2``, ``mixed_precision``
            (compared against ``'fp16'`` and ``'bf16'``), ``clip_grad``
            (``None`` disables gradient clipping) and ``zero_stage``.

    Returns:
        dict: A fresh DeepSpeed config. A ``"zero_optimization"`` section is
        added only when ``zero_stage`` is 0, 1, 2 or 3; any other value
        leaves that section out entirely.
    """
    ds_config = {
        "steps_per_print": 1000,
        "train_batch_size": args.global_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        # "train_micro_batch_size_per_gpu" is deliberately not set: it is
        # determined by (train_batch_size, gradient_accumulation_steps).
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": args.lr,
                "weight_decay": args.weight_decay,
                "bias_correction": True,
                "betas": [args.beta1, args.beta2],
                # DeepSpeed reads adam_w_mode from optimizer.params; at the
                # "optimizer" level (where it used to live) it is ignored.
                "adam_w_mode": True,
            },
        },
        "fp16": {
            "enabled": args.mixed_precision == 'fp16',
            "loss_scale": 0,
            "initial_scale_power": 16,
            "loss_scale_window": 1000,
            "hysteresis": 2,
            "min_loss_scale": 1,
        },
        "bf16": {
            "enabled": args.mixed_precision == 'bf16',
        },
        "zero_allow_untested_optimizer": True,
    }

    if args.clip_grad is not None:
        ds_config["gradient_clipping"] = args.clip_grad

    # Extra ZeRO settings per stage, layered on top of the options that
    # every stage shares (stage, contiguous_gradients, overlap_comm).
    stage_extras = {
        0: {},
        1: {
            "reduce_bucket_size": 5e8,
        },
        2: {
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "allgather_bucket_size": 5e8,
        },
        3: {
            "reduce_bucket_size": 5e8,
            "stage3_prefetch_bucket_size": 5e8,
            "stage3_param_persistence_threshold": 1e6,
            "stage3_max_live_parameters": 1e9,
            "stage3_max_reuse_distance": 1e9,
            "stage3_gather_16bit_weights_on_model_save": True,
        },
    }
    # Unrecognised stages intentionally add no zero_optimization section.
    if args.zero_stage in stage_extras:
        ds_config["zero_optimization"] = {
            "stage": args.zero_stage,
            "contiguous_gradients": True,
            "overlap_comm": True,
            **stage_extras[args.zero_stage],
        }
    return ds_config