Spaces:
Running
Running
from functools import partial | |
import torch | |
import torch.nn.functional as F | |
from einops import rearrange | |
from rotary_embedding_torch import RotaryEmbedding, broadcat | |
from torch import nn | |
# helpers | |
def exists(val): | |
return val is not None | |
def default(val, d): | |
return val if exists(val) else d | |
def cast_tuple(val, depth = 1): | |
if isinstance(val, list): | |
val = tuple(val) | |
return val if isinstance(val, tuple) else (val,) * depth | |
def max_neg_value(t): | |
return -torch.finfo(t.dtype).max | |
def stable_softmax(t, dim = -1, alpha = 32 ** 2): | |
t = t / alpha | |
t = t - torch.amax(t, dim = dim, keepdim = True).detach() | |
return (t * alpha).softmax(dim = dim) | |
def route_args(router, args, depth): | |
routed_args = [(dict(), dict()) for _ in range(depth)] | |
matched_keys = [key for key in args.keys() if key in router] | |
for key in matched_keys: | |
val = args[key] | |
for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])): | |
new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes) | |
routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) | |
return routed_args | |
# classes | |
class SequentialSequence(nn.Module): | |
def __init__(self, layers, args_route = {}, layer_dropout = 0.): | |
super().__init__() | |
assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers' | |
self.layers = layers | |
self.args_route = args_route | |
self.layer_dropout = layer_dropout | |
def forward(self, x, **kwargs): | |
args = route_args(self.args_route, kwargs, len(self.layers)) | |
layers_and_args = list(zip(self.layers, args)) | |
for (f, g), (f_args, g_args) in layers_and_args: | |
x = x + f(x, **f_args) | |
x = x + g(x, **g_args) | |
return x | |
class DivideMax(nn.Module): | |
def __init__(self, dim): | |
super().__init__() | |
self.dim = dim | |
def forward(self, x): | |
maxes = x.amax(dim = self.dim, keepdim = True).detach() | |
return x / maxes | |
# https://arxiv.org/abs/2103.17239 | |
class LayerScale(nn.Module): | |
def __init__(self, dim, depth, fn): | |
super().__init__() | |
if depth <= 18: | |
init_eps = 0.1 | |
elif depth > 18 and depth <= 24: | |
init_eps = 1e-5 | |
else: | |
init_eps = 1e-6 | |
scale = torch.zeros(1, 1, dim).fill_(init_eps) | |
self.scale = nn.Parameter(scale) | |
self.fn = fn | |
def forward(self, x, **kwargs): | |
return self.fn(x, **kwargs) * self.scale | |
# layer norm | |
class PreNorm(nn.Module): | |
def __init__(self, dim, fn, sandwich = False): | |
super().__init__() | |
self.norm = nn.LayerNorm(dim) | |
self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity() | |
self.fn = fn | |
def forward(self, x, **kwargs): | |
x = self.norm(x) | |
x = self.fn(x, **kwargs) | |
return self.norm_out(x) | |
# feed forward | |
class GEGLU(nn.Module): | |
def forward(self, x): | |
x, gates = x.chunk(2, dim = -1) | |
return x * F.gelu(gates) | |
class FeedForward(nn.Module): | |
def __init__(self, dim, dropout = 0., mult = 4.): | |
super().__init__() | |
self.net = nn.Sequential( | |
nn.Linear(dim, dim * mult * 2), | |
GEGLU(), | |
nn.Dropout(dropout), | |
nn.Linear(dim * mult, dim) | |
) | |
def forward(self, x): | |
return self.net(x) | |
# Attention | |
class Attention(nn.Module): | |
def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0.): | |
super().__init__() | |
inner_dim = dim_head * heads | |
self.heads = heads | |
self.seq_len = seq_len | |
self.scale = dim_head ** -0.5 | |
self.causal = causal | |
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) | |
self.to_out = nn.Sequential( | |
nn.Linear(inner_dim, dim), | |
nn.Dropout(dropout) | |
) | |
def forward(self, x, mask = None): | |
b, n, _, h, device = *x.shape, self.heads, x.device | |
softmax = torch.softmax | |
qkv = self.to_qkv(x).chunk(3, dim = -1) | |
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) | |
q = q * self.scale | |
dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) | |
mask_value = max_neg_value(dots) | |
if exists(mask): | |
mask = rearrange(mask, 'b j -> b () () j') | |
dots.masked_fill_(~mask, mask_value) | |
del mask | |
if self.causal: | |
i, j = dots.shape[-2:] | |
mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool() | |
dots.masked_fill_(mask, mask_value) | |
attn = softmax(dots, dim=-1) | |
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) | |
out = rearrange(out, 'b h n d -> b n (h d)') | |
out = self.to_out(out) | |
return out | |
# main transformer class | |
class Transformer(nn.Module): | |
def __init__( | |
self, | |
*, | |
dim, | |
depth, | |
seq_len, | |
causal = True, | |
heads = 8, | |
dim_head = 64, | |
ff_mult = 4, | |
attn_dropout = 0., | |
ff_dropout = 0., | |
sparse_attn = False, | |
sandwich_norm = False, | |
): | |
super().__init__() | |
layers = nn.ModuleList([]) | |
sparse_layer = cast_tuple(sparse_attn, depth) | |
for ind, sparse_attn in zip(range(depth), sparse_layer): | |
attn = Attention(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout) | |
ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout) | |
layers.append(nn.ModuleList([ | |
LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)), | |
LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich = sandwich_norm)) | |
])) | |
execute_type = SequentialSequence | |
route_attn = ((True, False),) * depth | |
attn_route_map = {'mask': route_attn} | |
self.layers = execute_type(layers, args_route = attn_route_map) | |
def forward(self, x, **kwargs): | |
return self.layers(x, **kwargs) | |