Agon H
commited on
Commit
•
d1b836d
1
Parent(s):
a8df195
Upload param_init_fns.py
Browse files- param_init_fns.py +181 -0
param_init_fns.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import warnings
|
3 |
+
from collections.abc import Sequence
|
4 |
+
from functools import partial
|
5 |
+
from typing import Optional, Tuple, Union
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
from .norm import NORM_CLASS_REGISTRY
|
9 |
+
|
10 |
+
def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs):
|
11 |
+
del kwargs
|
12 |
+
if verbose > 1:
|
13 |
+
warnings.warn(f"Initializing network using module's reset_parameters attribute")
|
14 |
+
if hasattr(module, 'reset_parameters'):
|
15 |
+
module.reset_parameters()
|
16 |
+
|
17 |
+
def fused_init_helper_(module: nn.Module, init_fn_):
|
18 |
+
_fused = getattr(module, '_fused', None)
|
19 |
+
if _fused is None:
|
20 |
+
raise RuntimeError(f'Internal logic error')
|
21 |
+
(dim, splits) = _fused
|
22 |
+
splits = (0, *splits, module.weight.size(dim))
|
23 |
+
for (s, e) in zip(splits[:-1], splits[1:]):
|
24 |
+
slice_indices = [slice(None)] * module.weight.ndim
|
25 |
+
slice_indices[dim] = slice(s, e)
|
26 |
+
init_fn_(module.weight[slice_indices])
|
27 |
+
|
28 |
+
def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
|
29 |
+
del kwargs
|
30 |
+
if verbose > 1:
|
31 |
+
warnings.warn(f'If model has bias parameters they are initialized to 0.')
|
32 |
+
init_div_is_residual = init_div_is_residual
|
33 |
+
if init_div_is_residual is False:
|
34 |
+
div_is_residual = 1.0
|
35 |
+
elif init_div_is_residual is True:
|
36 |
+
div_is_residual = math.sqrt(2 * n_layers)
|
37 |
+
elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int):
|
38 |
+
div_is_residual = init_div_is_residual
|
39 |
+
elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
|
40 |
+
div_is_residual = float(init_div_is_residual)
|
41 |
+
else:
|
42 |
+
div_is_residual = 1.0
|
43 |
+
raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}')
|
44 |
+
if init_div_is_residual is not False:
|
45 |
+
if verbose > 1:
|
46 |
+
warnings.warn(f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. ' + f'Set `init_div_is_residual: false` in init config to disable this.')
|
47 |
+
if isinstance(module, nn.Linear):
|
48 |
+
if hasattr(module, '_fused'):
|
49 |
+
fused_init_helper_(module, init_fn_)
|
50 |
+
else:
|
51 |
+
init_fn_(module.weight)
|
52 |
+
if module.bias is not None:
|
53 |
+
torch.nn.init.zeros_(module.bias)
|
54 |
+
if init_div_is_residual is not False and getattr(module, '_is_residual', False):
|
55 |
+
with torch.no_grad():
|
56 |
+
module.weight.div_(div_is_residual)
|
57 |
+
elif isinstance(module, nn.Embedding):
|
58 |
+
if emb_init_std is not None:
|
59 |
+
std = emb_init_std
|
60 |
+
if std == 0:
|
61 |
+
warnings.warn(f'Embedding layer initialized to 0.')
|
62 |
+
emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
|
63 |
+
if verbose > 1:
|
64 |
+
warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.')
|
65 |
+
elif emb_init_uniform_lim is not None:
|
66 |
+
lim = emb_init_uniform_lim
|
67 |
+
if isinstance(lim, Sequence):
|
68 |
+
if len(lim) > 2:
|
69 |
+
raise ValueError(f'Uniform init requires a min and a max limit. User input: {lim}.')
|
70 |
+
if lim[0] == lim[1]:
|
71 |
+
warnings.warn(f'Embedding layer initialized to {lim[0]}.')
|
72 |
+
else:
|
73 |
+
if lim == 0:
|
74 |
+
warnings.warn(f'Embedding layer initialized to 0.')
|
75 |
+
lim = [-lim, lim]
|
76 |
+
(a, b) = lim
|
77 |
+
emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
|
78 |
+
if verbose > 1:
|
79 |
+
warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.')
|
80 |
+
else:
|
81 |
+
emb_init_fn_ = init_fn_
|
82 |
+
emb_init_fn_(module.weight)
|
83 |
+
elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
|
84 |
+
if verbose > 1:
|
85 |
+
warnings.warn(f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.')
|
86 |
+
if hasattr(module, 'weight') and module.weight is not None:
|
87 |
+
torch.nn.init.ones_(module.weight)
|
88 |
+
if hasattr(module, 'bias') and module.bias is not None:
|
89 |
+
torch.nn.init.zeros_(module.bias)
|
90 |
+
elif isinstance(module, nn.MultiheadAttention):
|
91 |
+
if module._qkv_same_embed_dim:
|
92 |
+
assert module.in_proj_weight is not None
|
93 |
+
assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)
|
94 |
+
assert d_model is not None
|
95 |
+
_d = d_model
|
96 |
+
splits = (0, _d, 2 * _d, 3 * _d)
|
97 |
+
for (s, e) in zip(splits[:-1], splits[1:]):
|
98 |
+
init_fn_(module.in_proj_weight[s:e])
|
99 |
+
else:
|
100 |
+
assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None)
|
101 |
+
assert module.in_proj_weight is None
|
102 |
+
init_fn_(module.q_proj_weight)
|
103 |
+
init_fn_(module.k_proj_weight)
|
104 |
+
init_fn_(module.v_proj_weight)
|
105 |
+
if module.in_proj_bias is not None:
|
106 |
+
torch.nn.init.zeros_(module.in_proj_bias)
|
107 |
+
if module.bias_k is not None:
|
108 |
+
torch.nn.init.zeros_(module.bias_k)
|
109 |
+
if module.bias_v is not None:
|
110 |
+
torch.nn.init.zeros_(module.bias_v)
|
111 |
+
init_fn_(module.out_proj.weight)
|
112 |
+
if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False):
|
113 |
+
with torch.no_grad():
|
114 |
+
module.out_proj.weight.div_(div_is_residual)
|
115 |
+
if module.out_proj.bias is not None:
|
116 |
+
torch.nn.init.zeros_(module.out_proj.bias)
|
117 |
+
else:
|
118 |
+
for _ in module.parameters(recurse=False):
|
119 |
+
raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.')
|
120 |
+
|
121 |
+
def _normal_init_(std, mean=0.0):
|
122 |
+
return partial(torch.nn.init.normal_, mean=mean, std=std)
|
123 |
+
|
124 |
+
def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
|
125 |
+
del kwargs
|
126 |
+
init_fn_ = _normal_init_(std=std)
|
127 |
+
if verbose > 1:
|
128 |
+
warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')
|
129 |
+
generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
|
130 |
+
|
131 |
+
def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
|
132 |
+
del kwargs
|
133 |
+
if init_std is None:
|
134 |
+
raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
|
135 |
+
_normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
|
136 |
+
|
137 |
+
def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
|
138 |
+
del kwargs
|
139 |
+
std = math.sqrt(2 / (5 * d_model))
|
140 |
+
_normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
|
141 |
+
|
142 |
+
def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
|
143 |
+
"""From section 2.3.1 of GPT-NeoX-20B:
|
144 |
+
|
145 |
+
An Open-Source AutoregressiveLanguage Model — Black et. al. (2022)
|
146 |
+
see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
|
147 |
+
and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
|
148 |
+
"""
|
149 |
+
del kwargs
|
150 |
+
residual_div = n_layers / math.sqrt(10)
|
151 |
+
if verbose > 1:
|
152 |
+
warnings.warn(f'setting init_div_is_residual to {residual_div}')
|
153 |
+
small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
|
154 |
+
|
155 |
+
def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
|
156 |
+
del kwargs
|
157 |
+
if verbose > 1:
|
158 |
+
warnings.warn(f'Using nn.init.kaiming_uniform_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
|
159 |
+
kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
|
160 |
+
generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
|
161 |
+
|
162 |
+
def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
|
163 |
+
del kwargs
|
164 |
+
if verbose > 1:
|
165 |
+
warnings.warn(f'Using nn.init.kaiming_normal_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
|
166 |
+
kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
|
167 |
+
generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
|
168 |
+
|
169 |
+
def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
|
170 |
+
del kwargs
|
171 |
+
xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
|
172 |
+
if verbose > 1:
|
173 |
+
warnings.warn(f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' + f'gain={init_gain}')
|
174 |
+
generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
|
175 |
+
|
176 |
+
def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
|
177 |
+
xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
|
178 |
+
if verbose > 1:
|
179 |
+
warnings.warn(f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' + f'gain={init_gain}')
|
180 |
+
generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
|
181 |
+
MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_}
|