m2-bert-80M-8k-retrieval / hyena_utils.py
Dan Fu
Automodel support
03ef821
raw
history blame
8.35 kB
# Copyright (c) 2023, Dan Fu and Simran Arora.
# Adapted from https://github.com/HazyResearch/safari/blob/main/src/models/sequence/hyena.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import opt_einsum as oe
contract = oe.contract
""" Utils for the training loop. Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py """
class OptimModule(nn.Module):
""" Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters """
def register(self, name, tensor, lr=None, wd=0.0):
"""Register a tensor with a configurable learning rate and 0 weight decay"""
if lr == 0.0:
self.register_buffer(name, tensor)
else:
self.register_parameter(name, nn.Parameter(tensor))
optim = {}
if lr is not None: optim["lr"] = lr
if wd is not None: optim["weight_decay"] = wd
setattr(getattr(self, name), "_optim", optim)
def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None):
# u.shape: B H L
seqlen = u.shape[-1]
fft_size = 2 * seqlen
k_f = torch.fft.rfft(k, n=fft_size) / fft_size
if k_rev is not None:
k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
k_f = k_f + k_rev_f.conj()
u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
if len(u.shape) > 3:
k_f = k_f.unsqueeze(1)
y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]
out = y + u * D
if gelu:
out = F.gelu(out)
if dropout_mask is not None:
return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype)
else:
return out.to(dtype=u.dtype)
@torch.jit.script
def mul_sum(q, y):
return (q * y).sum(dim=1)
class Sin(nn.Module):
def __init__(self, dim, w=10, w_mod=1, train_freq=True):
super().__init__()
init_tensor = torch.ones(1, dim)
self.freq = (
nn.Parameter(w * init_tensor)
if train_freq
else w * torch.ones(1, dim)
)
self.w_mod = w_mod
def forward(self, x):
return torch.sin(self.w_mod * self.freq * x)
class PositionalEmbedding(OptimModule):
def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float = 1e-5, **kwargs):
"""Complex exponential positional embeddings for Hyena filters."""
super().__init__()
self.seq_len = seq_len
# The time embedding fed to the filteres is normalized so that t_f = 1
t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1
if emb_dim > 1:
bands = (emb_dim - 1) // 2
# To compute the right embeddings we use the "proper" linspace
t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1
f = torch.linspace(1e-4, bands - 1, bands)[None, None]
z = torch.exp(-1j * f * w)
z = torch.cat([t, z.real, z.imag], dim=-1)
self.register("z", z, lr=lr_pos_emb)
self.register("t", t, lr=0.0)
def forward(self, L):
return self.z[:, :L], self.t[:, :L]
class ExponentialModulation(OptimModule):
def __init__(
self,
d_model,
fast_decay_pct=0.3,
slow_decay_pct=1.5,
target=1e-2,
modulation_lr=0.0,
shift: float = 0.0,
**kwargs,
):
super().__init__()
self.shift = shift
max_decay = math.log(target) / fast_decay_pct
min_decay = math.log(target) / slow_decay_pct
deltas = torch.linspace(min_decay, max_decay, d_model)[None, None]
self.register("deltas", deltas, lr=modulation_lr)
def forward(self, t, x):
decay = torch.exp(-t * self.deltas.abs())
x = x * (decay + self.shift)
return x
class HyenaFilter(OptimModule):
def __init__(
self,
d_model,
emb_dim=3, # dim of input to MLP, augments with positional encoding
order=16, # width of the implicit MLP
seq_len=1024,
lr=1e-3,
lr_pos_emb=1e-5,
dropout=0.0,
w=1, # frequency of periodic activations
w_mod=1, # non-learnable modification of w
wd=0, # weight decay of kernel parameters
bias=True,
num_inner_mlps=2,
linear_mixer=False,
modulate: bool = True,
normalized=False,
bidirectional=False,
**kwargs,
):
"""
Implicit long filter with modulation.
Args:
d_model: number of channels in the input
emb_dim: dimension of the positional encoding (`emb_dim` - 1) // 2 is the number of bands
order: width of the FFN
num_inner_mlps: number of inner linear layers inside filter MLP
Note:
filter_dropout is not implemented
"""
super().__init__()
self.d_model=d_model
self.emb_dim=emb_dim
self.seq_len=seq_len
self.modulate=modulate
self.use_bias = bias
self.bidirectional = bidirectional
self.bias = nn.Parameter(torch.randn(self.d_model))
self.dropout = nn.Dropout(dropout)
act = Sin(dim=order, w=w, w_mod=w_mod)
assert (
emb_dim % 2 != 0 and emb_dim >= 3
), "emb_dim must be odd and greater or equal to 3 (time, sine and cosine)"
self.pos_emb = PositionalEmbedding(emb_dim, seq_len, lr_pos_emb)
# uses a variable number of inner linear layers
if linear_mixer is False:
self.implicit_filter = nn.Sequential(
nn.Linear(emb_dim, order),
act,
)
for i in range(num_inner_mlps):
self.implicit_filter.append(nn.Linear(order, order))
self.implicit_filter.append(act)
self.implicit_filter.append(nn.Linear(order, d_model, bias=False))
else:
self.implicit_filter = nn.Sequential(
nn.Linear(emb_dim, d_model, bias=False),
)
if self.bidirectional:
self.implicit_filter_rev = nn.Sequential(
nn.Linear(emb_dim, order),
act,
)
for i in range(num_inner_mlps):
self.implicit_filter_rev.append(nn.Linear(order, order))
self.implicit_filter_rev.append(act)
self.implicit_filter_rev.append(nn.Linear(order, d_model, bias=False))
self.modulation = ExponentialModulation(d_model, **kwargs)
self.normalized = normalized
for c in self.implicit_filter.children():
for name, v in c.state_dict().items():
optim = {"weight_decay": wd, "lr": lr}
setattr(getattr(c, name), "_optim", optim)
def filter(self, L, *args, **kwargs):
z, t = self.pos_emb(L)
h = self.implicit_filter(z)
if self.modulate:
h = self.modulation(t, h)
if self.normalized:
h = h / torch.norm(h, dim=-1, p=1, keepdim=True)
return h
def filter_rev(self, L, *args, **kwargs):
z, t = self.pos_emb(L)
h = self.implicit_filter_rev(z)
if self.modulate:
h = self.modulation(t, h)
if self.normalized:
h = h / torch.norm(h, dim=-1, p=1, keepdim=True)
return h
def forward(self, x, L, k_fwd=None, k_rev=None, bias=None, *args, **kwargs):
if k_fwd is None:
k_fwd = self.filter(L)
if self.bidirectional and k_rev is None:
k_rev = self.filter_rev(L)
# Ensure compatibility with filters that return a tuple
k_fwd = k_fwd[0] if type(k_fwd) is tuple else k_fwd
if bias is None:
bias = self.bias
bias = bias if self.use_bias else 0 * bias
if self.bidirectional:
k_rev = k_rev[0] if type(k_rev) is tuple else k_rev
k = F.pad(k_fwd, (0, L)) \
+ F.pad(k_rev.flip(-1), (L, 0))
else:
k = k_fwd
y = fftconv_ref(
x,
k,
bias,
dropout_mask=None,
gelu=False,
)
return y.to(dtype=x.dtype)