import math

import torch
import torch.nn as nn
from einops import rearrange, repeat

try:
    from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_func, apply_rotary_emb_kv_
except ImportError:
    apply_rotary_emb_qkv_, apply_rotary_emb_func, apply_rotary_emb_kv_ = None, None, None


class RelativePositionalEncoding(nn.Module):

    def __init__(self, relative_attention_num_buckets, relative_attention_max_distance, n_heads, max_sequence_length, bidirectional=True, randomized_position=False):
        super().__init__()

        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance
        self.n_heads = n_heads
        self.max_sequence_length = max_sequence_length
        self.bidirectional = bidirectional
        self.randomized_position = randomized_position

        self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >= max_distance map to the same bucket, and all relative positions <= -max_distance map to the
        same bucket. This should allow for more graceful generalization to longer sequences than the model has
        been trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))

        # Half of the buckets are for exact increments in position.
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half are for logarithmically bigger bins up to max_distance.
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

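    # Worked example (illustrative, assuming the defaults num_buckets=32, max_distance=128,
    # bidirectional=True): the 32 buckets split into 16 for keys before the query (negative
    # relative positions) and 16 for keys after it. A relative position of -1 is in the exact
    # range and maps to bucket 1, +1 maps to bucket 17, and any distance beyond max_distance is
    # clamped to the last bucket of its half (15 or 31).
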
    def compute_bias(self, query_length, key_length, device=None):
        """Compute binned relative position bias"""
        if device is None:
            device = self.relative_attention_bias.weight.device

        if self.randomized_position:
            context_position = torch.arange(self.max_sequence_length, dtype=torch.long, device=device)
            context_indices_rand, _ = torch.sort(torch.randperm(self.max_sequence_length, device=device)[:query_length])
            context_indices_rand[0] = 0  # keep the first sampled position pinned to 0
            context_position = context_position[context_indices_rand][:, None]

            memory_position = torch.arange(self.max_sequence_length, dtype=torch.long, device=device)
            memory_indices_rand, _ = torch.sort(torch.randperm(self.max_sequence_length, device=device)[:key_length])
            memory_indices_rand[0] = 0
            memory_position = memory_position[memory_indices_rand][None, :]
        else:
            context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
            memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]

        relative_position = memory_position - context_position

        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=self.bidirectional,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # (query_length, key_length, n_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # (1, n_heads, query_length, key_length)
        return values

    def forward(self, q, k=None, v=None):
        query_length = q.shape[1]
        key_length = k.shape[1] if k is not None else query_length
        bias = self.compute_bias(query_length, key_length, device=q.device).contiguous().to(q.dtype)

        return q, k, v, bias


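# Usage sketch (illustrative only; the attention code that consumes these modules is not part of
# this file). The bias returned by RelativePositionalEncoding.forward has shape
# (1, n_heads, query_length, key_length) and is added to the attention logits before the softmax.
# A (batch, seqlen, n_heads, head_dim) layout for q/k/v is assumed here:
#
#   rel_pos = RelativePositionalEncoding(
#       relative_attention_num_buckets=32,
#       relative_attention_max_distance=128,
#       n_heads=8,
#       max_sequence_length=1024,
#   )
#   q, k, v, bias = rel_pos(q, k, v)
#   scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(q.shape[-1]) + bias

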
class ALiBiPositionalEncoding(nn.Module):

    def __init__(self, max_sequence_length, num_heads, mode='symetric', randomized_position=False):
        super().__init__()

        self.max_sequence_length = max_sequence_length
        self.num_heads = num_heads
        self.mode = mode
        self.randomized_position = randomized_position

        self.alibi_bias = self.build_alibi_bias_matrix(num_heads, max_sequence_length, mode)

    @staticmethod
    def fill_with_neg_inf(t):
        """FP16-compatible function that fills a tensor with -inf."""
        return t.float().fill_(float("-inf")).type_as(t)

    def get_slopes(self, n):

        def get_slopes_power_of_2(n):
            start = (2**(-2**-(math.log2(n)-3)))
            ratio = start
            return [start*ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2**math.floor(math.log2(n))
            return get_slopes_power_of_2(closest_power_of_2) + self.get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]

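    # For reference: with a power-of-two head count the slopes form a geometric sequence starting
    # at 2^(-8/n). For n = 8 heads this gives 1/2, 1/4, 1/8, ..., 1/256, as in the ALiBi paper;
    # non-power-of-two head counts interleave slopes taken from the next power of two.
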
    def build_symetric_alibi_bias_matrix(self, num_heads, maxpos):
        context_position = torch.arange(maxpos)[:, None]
        memory_position = torch.arange(maxpos)[None, :]

        relative_position = memory_position - context_position
        relative_position = torch.abs(relative_position).unsqueeze(0).expand(num_heads, -1, -1)

        slopes = torch.Tensor(self.get_slopes(num_heads)) * -1
        alibi = slopes.unsqueeze(1).unsqueeze(1) * relative_position
        return alibi.view(1, num_heads, maxpos, maxpos)

    def build_asymetric_alibi_bias_matrix(self, num_heads, maxpos):
        _future_mask_right = torch.triu(self.fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1).unsqueeze(0).repeat(num_heads // 2, 1, 1)
        _future_mask_left = torch.tril(self.fill_with_neg_inf(torch.zeros([maxpos, maxpos])), -1).unsqueeze(0).repeat(num_heads // 2, 1, 1)

        nonsym_mask = torch.cat((_future_mask_right, _future_mask_left), dim=0).unsqueeze(0)
        slopes = torch.Tensor(self.get_slopes(num_heads // 2)) * -1

        context_position = torch.arange(maxpos)[:, None]
        memory_position = torch.arange(maxpos)[None, :]

        relative_position = memory_position - context_position
        relative_position = torch.abs(relative_position).unsqueeze(0).expand(num_heads // 2, -1, -1)

        alibi = slopes.unsqueeze(1).unsqueeze(1) * relative_position
        alibi = alibi.view(1, num_heads // 2, maxpos, maxpos)
        alibi = alibi.repeat(1, 2, 1, 1)

        return alibi.view(1, num_heads, maxpos, maxpos) + nonsym_mask.view(1, num_heads, maxpos, maxpos)

    def build_alibi_bias_matrix(self, num_heads, maxpos, mode='symetric'):
        if mode == 'symetric':
            return self.build_symetric_alibi_bias_matrix(num_heads, maxpos)
        elif mode == 'asymetric':
            return self.build_asymetric_alibi_bias_matrix(num_heads, maxpos)
        else:
            raise ValueError("ALiBi mode " + mode + " is not implemented.")

    def forward(self, q, k=None, v=None):
        query_length = q.shape[1]
        key_length = k.shape[1] if k is not None else query_length
        assert (self.alibi_bias.shape[2] >= query_length) and (self.alibi_bias.shape[3] >= key_length), \
            "Sequence length larger than allowed alibi bound"

        if self.randomized_position:
            query_indices_rand, _ = torch.sort(torch.randperm(self.max_sequence_length)[:query_length])
            key_indices_rand, _ = torch.sort(torch.randperm(self.max_sequence_length)[:key_length])

            # Keep the first sampled position pinned to 0.
            query_indices_rand[0] = 0
            key_indices_rand[0] = 0

            # Index rows and columns separately so the full
            # (1, num_heads, query_length, key_length) grid is preserved.
            bias = self.alibi_bias[:, :, query_indices_rand][:, :, :, key_indices_rand].to(q.device)
        else:
            bias = self.alibi_bias[:, :, :query_length, :key_length].to(q.device)

        return q, k, v, bias.to(q.dtype).contiguous()


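# Note on the ALiBi modes (derived from the construction above; shapes are
# (1, num_heads, maxpos, maxpos)): 'symetric' adds slope * |i - j| with negative slopes to every
# head, so more distant positions are penalized in both directions; 'asymetric' splits the heads
# in two halves, masking future positions with -inf for the first half and past positions for the
# second half, so each half only looks one way. The bias is added to the attention logits exactly
# like the relative-position bias above.

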
class RotaryPositionalEncoding(nn.Module):

    def __init__(self, dim,
                 max_sequence_length,
                 base=10000.0,
                 interleaved=False,
                 scale_base=None,
                 randomized_position=False):
        super().__init__()

        self.max_sequence_length = max_sequence_length
        self.randomized_position = randomized_position

        self.dim = dim
        self.base = base
        self.interleaved = interleaved
        self.scale_base = scale_base

        inv_freq = self._compute_inv_freq()
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        scale = (
            (torch.arange(0, dim, 2, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        return 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
        )

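    # The i-th frequency is inv_freq[i] = base^(-2i/dim), so position t is rotated by the angles
    # t * inv_freq[i]; when scale_base is set, an xPos-style scale is additionally applied to the
    # cos/sin caches below.
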
    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Rebuild the cache if it does not exist yet, if it lives on the wrong device/dtype,
        # or if it was created in inference mode while we are now training.
        if (
            self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            inv_freq = self._compute_inv_freq(device=device)

            t = torch.arange(seqlen, device=device, dtype=dtype)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
                self._cos_k_cached = None
                self._sin_k_cached = None
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")

                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(self, q, k=None, v=None):
        if self._cos_cached is None:
            self._update_cos_sin_cache(self.max_sequence_length, device=q.device, dtype=q.dtype)

        if k is None and v is None:
            q = apply_rotary_emb_qkv_(
                q,
                self._cos_cached,
                self._sin_cached,
                self._cos_k_cached,
                self._sin_k_cached,
                interleaved=self.interleaved,
                seqlen_offsets=0
            )
        elif v is None and k is not None:
            q = apply_rotary_emb_func(
                q,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                inplace=True,
                seqlen_offsets=0
            )

            k = apply_rotary_emb_kv_(
                k,
                self._cos_cached if self._cos_k_cached is None else self._cos_k_cached,
                self._sin_cached if self._sin_k_cached is None else self._sin_k_cached,
                interleaved=self.interleaved,
                seqlen_offsets=0,
            )
        else:
            q = apply_rotary_emb_func(
                q,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                inplace=True,
                seqlen_offsets=0
            )

            k = apply_rotary_emb_func(
                k,
                self._cos_cached if self._cos_k_cached is None else self._cos_k_cached,
                self._sin_cached if self._sin_k_cached is None else self._sin_k_cached,
                interleaved=self.interleaved,
                seqlen_offsets=0,
            )

            v = apply_rotary_emb_func(
                v,
                self._cos_cached if self._cos_k_cached is None else self._cos_k_cached,
                self._sin_cached if self._sin_k_cached is None else self._sin_k_cached,
                interleaved=self.interleaved,
                seqlen_offsets=0,
            )

        return q, k, v, None


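# Usage note (reflecting the flash-attn call signatures used above): when only q is given, it is
# treated as a packed qkv tensor of shape (batch, seqlen, 3, nheads, headdim); when q and k are
# given, k is a packed kv tensor of shape (batch, seqlen, 2, nheads, headdim); when q, k and v are
# all given, each has shape (batch, seqlen, nheads, headdim) and is rotated independently. The
# rotation modifies the tensors themselves, so no additive bias is returned (the fourth return
# value is None). flash-attn must be installed, otherwise the rotary functions above are None.

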
class FIRE(nn.Module):

    def __init__(self, num_heads=12, mlp_width=32, init_c=0.1, init_L=512., eps=1e-6):
        """
        FIRE attention bias module.

        Args:
            num_heads: number of attention heads.
            mlp_width: width of the MLP.
            init_c: initial value of the log transformation parameter.
            init_L: initial value of the thresholding parameter.
            eps: small constant for numerical stability.
        """
        super().__init__()

        self.mlp = nn.Sequential(
            nn.Linear(1, mlp_width),
            nn.ReLU(),
            nn.Linear(mlp_width, num_heads)
        )

        self.c = nn.Parameter(torch.tensor(init_c))

        self.init_L = nn.Parameter(torch.tensor(init_L),
                                   requires_grad=False)
        self.L_multiplier = nn.Parameter(torch.tensor(1.0))

        self.eps = eps

    def apply_fire(self, seq_length, device):
        """
        Compute the FIRE attention bias.

        Args:
            seq_length: length of the input sequence.
            device: device on which to build the bias.

        Returns:
            attention bias of shape [1, num_heads, seq_len, seq_len]
        """
        positions = torch.arange(seq_length,
                                 dtype=torch.float32,
                                 device=device)

        rel_distance = positions[:, None] - positions[None, :]

        # Threshold the normalizer from below by the learnable length scale L.
        threshold = torch.abs(self.L_multiplier * self.init_L)
        pos_normalizer = torch.max(positions, threshold)
        pos_normalizer = pos_normalizer[:, None]

        # Monotone log transform of the signed relative distances and of the normalizer.
        rel_distance = torch.sign(rel_distance) * torch.log(
            torch.abs(self.c * rel_distance) + 1
        )
        pos_normalizer = torch.log(
            torch.abs(self.c * pos_normalizer) + 1
        ) + self.eps

        normalized_distance = rel_distance / pos_normalizer
        fire_bias = self.mlp(normalized_distance.unsqueeze(-1))  # [seq_len, seq_len, num_heads]
        fire_bias = fire_bias.unsqueeze(0).permute(0, 3, 1, 2)   # [1, num_heads, seq_len, seq_len]
        return fire_bias

    def forward(self, q, k=None, v=None):
        bias = self.apply_fire(q.shape[1], device=q.device).contiguous().to(q.dtype)

        return q, k, v, bias

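
if __name__ == "__main__":
    # Minimal smoke test (illustrative only, not part of the module API): each bias-producing
    # module should return an additive bias of shape (1, num_heads, seq_len, seq_len).
    # RotaryPositionalEncoding is skipped because it requires flash-attn to be installed.
    batch, seq_len, num_heads, head_dim = 2, 16, 8, 32
    q = torch.randn(batch, seq_len, num_heads, head_dim)

    modules = [
        RelativePositionalEncoding(32, 128, num_heads, max_sequence_length=seq_len),
        ALiBiPositionalEncoding(seq_len, num_heads),
        FIRE(num_heads=num_heads),
    ]
    for module in modules:
        _, _, _, bias = module(q, q, q)
        assert bias.shape == (1, num_heads, seq_len, seq_len), (type(module).__name__, bias.shape)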