File size: 6,981 Bytes
6e73cd3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import contextlib
import logging
import math
import warnings
from typing import Any, Dict, Literal, Mapping, Optional, Tuple, Union
from composer.utils import dist
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
from llmfoundry.models.utils import init_empty_weights
log = logging.getLogger(__name__)
def pop_config(cfg: DictConfig,
               key: str,
               must_exist: bool = True,
               default_value: Any = None,
               convert: bool = False) -> Any:
    """Pop a value from the main config file and return it.

    The key is removed from ``cfg`` whenever it is present. A value of
    ``None`` is treated exactly like a missing key.

    Args:
        cfg: Config to pop from; it is mutated in place.
        key: Name of the entry to remove.
        must_exist: If True, a missing (or ``None``) value raises instead of
            falling back to ``default_value``.
        default_value: Returned when the value is missing and ``must_exist``
            is False.
        convert: If True, the popped value must be a ``DictConfig`` or
            ``ListConfig`` and is converted to a plain python object with
            ``OmegaConf.to_container``.

    Returns:
        The popped value (converted if requested), or ``default_value``.

    Raises:
        ValueError: If ``convert`` is True and the popped value is neither a
            ``DictConfig`` nor a ``ListConfig``.
        NameError: If the value is missing and ``must_exist`` is True.
    """
    value = cfg.pop(key, None)

    if value is None:
        if must_exist:
            raise NameError(
                f'The {key} parameter is missing and must exist for execution. Please check your yaml.'
            )
        return default_value

    if not convert:
        return value

    if not isinstance(value, (DictConfig, ListConfig)):
        # Adjacent literals instead of a backslash continuation, so source
        # indentation can never end up inside the message.
        raise ValueError(
            f'The key {key} has a value of type {type(value)} that cannot be '
            'converted to a dict or list. Please check your yaml.')
    return om.to_container(value)
def calculate_batch_size_info(
    global_batch_size: int, device_microbatch_size: Union[int, Literal['auto']]
) -> Tuple[int, Union[int, Literal['auto']], Union[int, Literal['auto']]]:
    """Derive the per-device batch size and gradient accumulation steps.

    Args:
        global_batch_size: Total batch size across all ranks; must be
            divisible by the world size reported by ``dist``.
        device_microbatch_size: Per-device microbatch size, or ``'auto'``.

    Returns:
        A tuple ``(device_batch_size, device_microbatch_size,
        device_grad_accum)``. With ``'auto'`` the last two entries are both
        ``'auto'``. Otherwise the microbatch size is capped at the device
        batch size and the grad-accum count is
        ``ceil(device_batch_size / device_microbatch_size)``.

    Raises:
        ValueError: If ``global_batch_size`` is not divisible by the world
            size, if ``device_microbatch_size`` is an int that is not
            positive, or if it is neither an int nor ``'auto'``.
    """
    # Query the world size once; it is reused in the check and the messages.
    world_size = dist.get_world_size()
    if global_batch_size % world_size != 0:
        raise ValueError(
            f'Global batch size {global_batch_size} is not divisible by {world_size} '
            +
            'as a result, the batch size would be truncated, please adjust `global_batch_size` '
            + f'to be divisible by world size, {world_size}.')
    device_batch_size = global_batch_size // world_size

    if device_microbatch_size == 'auto':
        device_grad_accum = 'auto'
    elif isinstance(device_microbatch_size, int):
        if device_microbatch_size <= 0:
            # Zero would divide by zero below and a negative value would
            # yield a negative accumulation count; reject both up front.
            raise ValueError(
                'device_microbatch_size must be a positive integer or '
                f"'auto', got {device_microbatch_size}.")
        if device_microbatch_size > device_batch_size:
            # `Logger.warn` is deprecated; `warning` is the supported name.
            log.warning(
                'device_microbatch_size > device_batch_size, ' +
                f'will be reduced from {device_microbatch_size} -> {device_batch_size}.'
            )
            device_microbatch_size = device_batch_size
        device_grad_accum = math.ceil(device_batch_size /
                                      device_microbatch_size)
    else:
        raise ValueError(f'Not sure how to parse {device_microbatch_size=}')

    return device_batch_size, device_microbatch_size, device_grad_accum
# Coming soon: this conversion math will be done inside Composer Trainer
def update_batch_size_info(cfg: DictConfig) -> DictConfig:
    """Write the derived batch size fields onto ``cfg`` and return it.

    Reads ``global_train_batch_size`` and ``device_train_microbatch_size``
    from ``cfg``, then stores ``n_gpus``, ``device_train_batch_size``,
    ``device_train_microbatch_size`` and ``device_train_grad_accum`` on it.
    ``device_eval_batch_size`` is only filled in when the user did not
    provide one.

    Args:
        cfg: Config to update; it is mutated in place.

    Returns:
        The same ``cfg`` object.
    """
    batch_info = calculate_batch_size_info(cfg.global_train_batch_size,
                                           cfg.device_train_microbatch_size)
    per_device_batch, microbatch, grad_accum = batch_info

    cfg.n_gpus = dist.get_world_size()
    cfg.device_train_batch_size = per_device_batch
    cfg.device_train_microbatch_size = microbatch
    cfg.device_train_grad_accum = grad_accum

    # A user-provided eval batch size always wins.
    if 'device_eval_batch_size' in cfg:
        return cfg

    # Safely set `device_eval_batch_size` since the user did not provide it
    if cfg.device_train_microbatch_size == 'auto':
        eval_batch_size = 1  # TODO debug auto eval microbatching
    else:
        eval_batch_size = cfg.device_train_microbatch_size
    cfg.device_eval_batch_size = eval_batch_size
    return cfg
def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]):
    """Validate ``init_device`` and return the context to build the model in.

    Model ``init_device`` is restricted to 'meta', 'cpu' and 'mixed': using
    'cuda' vs. 'cuda:id' is tricky and can lead to common user errors when
    multiple GPUs are available. 'meta' and 'mixed' are only valid with FSDP.

    Side effects:
        * ``model_cfg['init_device']`` is reset to 'cpu' (with a warning)
          when 'meta' is requested without an FSDP config.
        * For 'mixed', ``fsdp_config['sync_module_states']`` is forced to
          True, and ``use_orig_params`` / ``load_monolith_rank0_only``
          default to False / True when not already set.
        * When ``master_weights_dtype`` names a 16-bit dtype and
          ``fsdp_config`` is non-empty, ``fsdp_config['mixed_precision']`` is
          replaced by a dict that keeps any existing ``reduce_dtype`` and
          ``buffer_dtype`` and sets ``param_dtype`` to None.

    Args:
        model_cfg: Model config. Only key-style access is used, so a plain
            dict works as well as a ``DictConfig``.
        fsdp_config: FSDP config dict, or None when FSDP is not used.

    Returns:
        The context returned by ``init_empty_weights()`` when the effective
        init device is 'meta', otherwise ``contextlib.nullcontext()``.

    Raises:
        AssertionError: If ``init_device`` is not 'meta', 'cpu' or 'mixed'.
        NotImplementedError: If ``init_device`` is 'mixed' and
            ``fsdp_config`` is None.
    """
    init_context = contextlib.nullcontext()

    if 'init_device' in model_cfg:
        init_device = model_cfg['init_device']

        # Raised explicitly (not via `assert`) so the validation still runs
        # under `python -O`; the exception type is kept for compatibility.
        if init_device not in ('meta', 'cpu', 'mixed'):
            raise AssertionError(
                f'Invalid init_device {init_device!r}; expected one of '
                "'meta', 'cpu' or 'mixed'.")

        if fsdp_config is None and init_device == 'meta':
            warnings.warn(
                "Using `cfg.model.init_device='meta'` is only valid when using FSDP! "
                + "Reverting to `cfg.model.init_device='cpu'`.")
            init_device = 'cpu'
            model_cfg['init_device'] = init_device

        if init_device == 'meta':
            init_context = init_empty_weights()
        elif init_device == 'mixed':
            if fsdp_config is None:
                raise NotImplementedError(
                    'Using init_device `mixed` is only supported with FSDP. '
                    + 'Please add a FSDP config.')
            # Always set `sync_module_states` to True for mixed initialization
            if not fsdp_config.get('sync_module_states', False):
                warnings.warn(
                    'Setting `sync_module_states = True` for FSDP. This is required '
                    'when using mixed initialization.')
                fsdp_config['sync_module_states'] = True

            # Set defaults for mixed initialization
            fsdp_config.setdefault('use_orig_params', False)
            fsdp_config.setdefault('load_monolith_rank0_only', True)

    # no mixed precision needed for weights when they're already 16 bits
    master_dtype = model_cfg.get('master_weights_dtype')
    small_dtypes = ('bf16', 'f16', 'float16', 'bfloat16', 'amp_fp16',
                    'amp_bf16')
    if fsdp_config and master_dtype in small_dtypes:
        reduce_dtype = None
        buffer_dtype = None
        mixed_precision = fsdp_config.get('mixed_precision')
        # Carry over user-specified reduce/buffer dtypes only when the
        # existing entry is a mapping; any other form is discarded.
        if isinstance(mixed_precision, Mapping):
            reduce_dtype = mixed_precision.get('reduce_dtype')
            buffer_dtype = mixed_precision.get('buffer_dtype')
        fsdp_config['mixed_precision'] = {
            'param_dtype': None,
            'reduce_dtype': reduce_dtype,
            'buffer_dtype': buffer_dtype,
            'keep_low_precision_grads': True,
        }

    return init_context
def log_config(cfg: DictConfig) -> None:
    """Logs the current config and updates the wandb and mlflow configs.

    The config is printed as yaml. If 'wandb' / 'mlflow' appear under
    ``cfg.loggers`` and a run is currently active, the fully resolved config
    is also pushed to that run.

    This function can be called multiple times to update the wandb and MLflow
    config with different variables.

    Args:
        cfg: The config to print and upload.

    Raises:
        ImportError: If a logger is listed in ``cfg.loggers`` but its package
            is not installed.
    """
    print(om.to_yaml(cfg))

    # `loggers` may be missing or explicitly null in the yaml; treat both as
    # "no loggers" instead of failing on `'wandb' in None`.
    loggers = cfg.get('loggers') or {}

    if 'wandb' in loggers:
        # Imported lazily so wandb is only required when it is requested.
        import wandb
        if wandb.run:
            wandb.config.update(om.to_container(cfg, resolve=True))

    if 'mlflow' in loggers:
        # Imported lazily so mlflow is only required when it is requested.
        import mlflow
        if mlflow.active_run():
            mlflow.log_params(params=om.to_container(cfg, resolve=True))
|