hieungo1410's picture
'add'
8cb4f3b
raw
history blame
1.43 kB
import torch, math
class AdamOptim(torch.optim.Optimizer):
    """Adam-style optimizer with bias-corrected first and second moment estimates.

    Per parameter it keeps a step counter plus exponential moving averages of
    the gradient (``exp_avg``) and of the squared gradient (``exp_avg_sq``).

    Args:
        params: Iterable of parameters, or of dicts defining parameter groups.
        lr: Learning rate.
        betas: ``(beta1, beta2)`` decay rates for the two moving averages.
        eps: Constant added to the denominator of the update.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-9, **kwargs):
        defaults = dict(lr=lr, betas=betas, eps=eps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Apply one update to every parameter that has a gradient.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is invoked with gradient tracking enabled.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is
            given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, so autograd has to be switched back
            # on for the closure's forward/backward pass.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Hyperparameters are constant within a group; read them once.
            beta1, beta2 = group['betas']
            lr = group['lr']
            eps = group['eps']

            for p in group['params']:
                if p.grad is None:
                    # No backward pass reached this parameter.
                    continue
                grad = p.grad

                state = self.state[p]
                if len(state) == 0:
                    # First update for this parameter: create its state.
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Update the moving averages in place.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # eps is added after the bias-corrected square root.
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
                step_size = lr / bias_correction1

                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss