vampnet-choir / vampnet /scheduler.py
hugo flores garcia
recovering from a gittastrophe
41b9d24
raw
history blame
1.21 kB
import copy
from typing import List
import torch
class NoamScheduler:
    """OG scheduler from transformer paper: https://arxiv.org/pdf/1706.03762.pdf
    Implementation from Annotated Transformer: https://nlp.seas.harvard.edu/2018/04/03/attention.html

    On every call to :meth:`step` the learning rate is set to::

        factor * d_model ** -0.5 * min(steps ** -0.5, steps * warmup ** -1.5)

    i.e. it grows linearly with ``steps`` while ``steps < warmup`` and decays
    as the inverse square root of ``steps`` afterwards.

    Args:
        optimizer: Object exposing ``param_groups`` (an iterable of dicts);
            each group's ``"lr"`` entry is overwritten on every step.
        d_model: Model width used to scale the learning rate.
        factor: Constant multiplier applied to the schedule.
        warmup: Number of warmup steps. A value ``<= 0`` disables the warmup
            phase, leaving only the inverse-square-root decay.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        d_model: int = 512,
        factor: float = 1.0,
        warmup: int = 4000,
    ):
        # Store hparams
        self.warmup = warmup
        self.factor = factor
        self.d_model = d_model

        # `lr` stays None until the first call to `step`.
        self.lr = None
        self.steps = 0

        # Store the optimizer
        self.optimizer = optimizer

    def state_dict(self) -> dict:
        """Return every instance attribute except the optimizer reference."""
        return {
            key: value for key, value in self.__dict__.items() if key != "optimizer"
        }

    def load_state_dict(self, state_dict) -> None:
        """Restore attributes previously produced by :meth:`state_dict`.

        An ``"optimizer"`` entry, if present, is ignored so that loading a
        state can never replace the optimizer this scheduler is driving
        (``state_dict`` never emits that key in the first place).
        """
        self.__dict__.update(
            {key: value for key, value in state_dict.items() if key != "optimizer"}
        )

    def step(self) -> None:
        """Advance one step and write the new learning rate to all param groups."""
        self.steps += 1

        decay = self.steps ** (-0.5)
        if self.warmup > 0:
            scale = min(decay, self.steps * self.warmup ** (-1.5))
        else:
            # `warmup ** -1.5` is undefined for 0 and complex for negative
            # values; with no warmup phase only the decay term applies.
            scale = decay

        self.lr = self.factor * (self.d_model ** (-0.5) * scale)

        for p in self.optimizer.param_groups:
            p["lr"] = self.lr