# sociolome/sftp/training/transformer_optimizer.py
import logging
import re
from typing import List, Tuple, Type
import torch
from allennlp.common.from_params import Params, T
from allennlp.training.optimizers import Optimizer
logger = logging.getLogger('optim')
@Optimizer.register('transformer')
class TransformerOptimizer:
"""
    Wrapper around an AllenNLP optimizer.
    It is used to fine-tune a pretrained transformer with some layers frozen and with
    different learning rates for different components. When a component is frozen, the
    wrapper sets its parameters' `requires_grad` flag to False, which saves training
    time and memory.
    Please contact Guanghui Qin for bugs.
    Params:
        base: the base optimizer to wrap.
        embeddings_lr: learning rate for the embedding layer. Set to 0.0 to freeze it.
        encoder_lr: learning rate for the encoder layers. Set to 0.0 to freeze them.
        pooler_lr: learning rate for the pooler layer. Set to 0.0 to freeze it.
        layer_fix: the number of lower encoder layers to freeze (defaults to 0).
        parameter_groups: optional AllenNLP-style parameter groups, forwarded to the
            base optimizer unchanged.
    Example JSON configs:
    1. No-op: nothing is frozen and the base optimizer is used unchanged.
optimizer: {
type: "transformer",
base: {
type: "adam",
lr: 0.001
}
}
2. Fix everything in the transformer.
optimizer: {
type: "transformer",
base: {
type: "adam",
lr: 0.001
},
embeddings_lr: 0.0,
encoder_lr: 0.0,
pooler_lr: 0.0
}
    Or equivalently (assuming the encoder has 24 layers):
optimizer: {
type: "transformer",
base: {
type: "adam",
lr: 0.001
},
embeddings_lr: 0.0,
layer_fix: 24,
pooler_lr: 0.0
}
    3. Freeze the embeddings and the lower 12 encoder layers, and use a small
       learning rate for the rest of the transformer:
optimizer: {
type: "transformer",
base: {
type: "adam",
lr: 0.001
},
embeddings_lr: 0.0,
layer_fix: 12,
encoder_lr: 1e-5,
pooler_lr: 1e-5
}
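    4. (Sketch) Forward extra AllenNLP-style `parameter_groups` to the base optimizer;
       `from_params` below passes them through unchanged, so they can be combined with
       the component learning rates above:
    optimizer: {
        type: "transformer",
        base: {
            type: "adam",
            lr: 0.001
        },
        embeddings_lr: 0.0,
        layer_fix: 12,
        parameter_groups: [
            [[".*bias.*"], {"weight_decay": 0.0}]
        ]
    }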
"""
@classmethod
def from_params(
cls: Type[T],
params: Params,
model_parameters: List[Tuple[str, torch.nn.Parameter]],
**_
):
        param_groups = list()

        def remove_param(keyword_):
            # Freeze every parameter whose name contains `keyword_` and drop it from
            # the list that will be handed to the base optimizer.
            nonlocal model_parameters
            logger.info(f'Fix param with name matching {keyword_}.')
            for name, param in model_parameters:
                if keyword_ in name:
                    logger.debug(f'Fix param {name}.')
                    param.requires_grad_(False)
            model_parameters = list(filter(lambda x: keyword_ not in x[0], model_parameters))

        # Freeze the lowest `layer_fix` encoder layers entirely.
        for i_layer in range(params.pop('layer_fix', 0)):
            remove_param('transformer_model.encoder.layer.{}.'.format(i_layer))
        # For each component, either register a dedicated learning rate or, if the
        # rate is 0.0, freeze the component altogether.
        for specific_lr, keyword in (
                (params.pop('embeddings_lr', None), 'transformer_model.embeddings'),
                (params.pop('encoder_lr', None), 'transformer_model.encoder.layer'),
                (params.pop('pooler_lr', None), 'transformer_model.pooler'),
        ):
            if specific_lr is not None:
                if specific_lr > 0.:
                    # Build an AllenNLP parameter-group regex for this component.
                    pattern = '.*' + keyword.replace('.', r'\.') + '.*'
                    if len([name for name, _ in model_parameters if re.match(pattern, name)]) > 0:
                        param_groups.append([[pattern], {'lr': specific_lr}])
                    else:
                        logger.warning(f'{pattern} is set to use lr {specific_lr} but no param matches.')
                else:
                    remove_param(keyword)
        # User-specified parameter groups are forwarded to the base optimizer unchanged.
        if 'parameter_groups' in params:
            for pg in params.pop('parameter_groups'):
                param_groups.append([pg[0], pg[1].as_dict()])

        # Instantiate the base optimizer with the (possibly reduced) parameter list and
        # the collected parameter groups.
        base_params = params.pop('base')
        return Optimizer.by_name(base_params.pop('type'))(
            model_parameters=model_parameters,
            parameter_groups=param_groups,
            **base_params.as_flat_dict()
        )
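
# Minimal usage sketch (assumption: `model` is an AllenNLP model whose text field
# embedder wraps a pretrained transformer, so parameter names contain
# `transformer_model.`; the config mirrors docstring example 3):
#
#   from allennlp.common import Params
#   optimizer = Optimizer.from_params(
#       model_parameters=[(n, p) for n, p in model.named_parameters() if p.requires_grad],
#       params=Params({
#           'type': 'transformer',
#           'base': {'type': 'adam', 'lr': 1e-3},
#           'embeddings_lr': 0.0,
#           'layer_fix': 12,
#           'encoder_lr': 1e-5,
#           'pooler_lr': 1e-5,
#       }),
#   )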