asigalov61 committed on
Commit
6593aaf
1 Parent(s): 03a225c

Upload 3 files

Files changed (3)
  1. TMIDIX.py +0 -0
  2. midi_to_colab_audio.py +0 -0
  3. x_transformer_1_23_2.py +2458 -0
TMIDIX.py ADDED
The diff for this file is too large to render. See raw diff
 
midi_to_colab_audio.py ADDED
The diff for this file is too large to render. See raw diff
 
x_transformer_1_23_2.py ADDED
@@ -0,0 +1,2458 @@
1
+ #===================================================================================================================
2
+ #
3
+ # X Transformer Module
4
+ #
5
+ # Partial x-transformers code with useful modifications
6
+ #
7
+ # Version 1.0
8
+ #
9
+ # Original source code courtesy of lucidrains
10
+ # https://github.com/lucidrains/x-transformers
11
+ #
12
+ # Original source code retrieved on 10/10/2023
13
+ #
14
+ # Project Los Angeles
15
+ # Tegridy Code 2023
16
+
17
+ #===================================================================================================================
18
+
19
+ # Critical dependencies
20
+ #
21
+ # !pip install torch
22
+ # !pip install einops
23
+
24
+ #===================================================================================================================
25
+
26
+ from functools import partial
27
+ from typing import Optional, Tuple
28
+
29
+ import torch
30
+ from torch import nn, einsum, Tensor
31
+ import torch.nn.functional as F
32
+
33
+ from collections import namedtuple
34
+ from functools import wraps
35
+ from packaging import version
36
+ from dataclasses import dataclass
37
+
38
+ from einops import rearrange, repeat
39
+
40
+ # constants
41
+
42
+ EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
43
+
44
+ @dataclass
45
+ class Intermediates:
46
+ qk_similarities: Optional[Tensor] = None
47
+ pre_softmax_attn: Optional[Tensor] = None
48
+ post_softmax_attn: Optional[Tensor] = None
49
+ cached_kv: Optional[Tuple[Tensor, Tensor]] = None
50
+
51
+ def to_tuple(self):
52
+ return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
53
+
54
+ # helpers
55
+
56
+ def exists(val):
57
+ return val is not None
58
+
59
+ def default(val, d):
60
+ return val if exists(val) else d
61
+
62
+ def compact(arr):
63
+ return [*filter(exists, arr)]
64
+
65
+ def once(fn):
66
+ called = False
67
+ @wraps(fn)
68
+ def inner(x):
69
+ nonlocal called
70
+ if called:
71
+ return
72
+ called = True
73
+ return fn(x)
74
+ return inner
75
+
76
+ print_once = once(print)
77
+
78
+ # functions for creating causal mask
79
+ # need a special one for onnx cpu (no support for .triu)
80
+
81
+ def create_causal_mask(i, j, device):
82
+ return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
83
+
84
+ def onnx_create_causal_mask(i, j, device):
85
+ r = torch.arange(i, device = device)
86
+ causal_mask = rearrange(r, 'i -> i 1') < rearrange(r, 'j -> 1 j')
87
+ causal_mask = F.pad(causal_mask, (j - i, 0), value = False)
88
+ return causal_mask
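A quick illustration of what these helpers produce (a minimal sketch using the module imports above): both return a boolean mask in which True marks positions to be masked out.

    create_causal_mask(3, 3, device = torch.device('cpu'))
    # tensor([[False,  True,  True],
    #         [False, False,  True],
    #         [False, False, False]])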
89
+
90
+ # main class
91
+
92
+ class Attend(nn.Module):
93
+ def __init__(
94
+ self,
95
+ *,
96
+ dropout = 0.,
97
+ causal = False,
98
+ heads = None,
99
+ talking_heads = False,
100
+ sparse_topk = None,
101
+ scale = None,
102
+ qk_norm = False,
103
+ flash = False,
104
+ add_zero_kv = False,
105
+ onnxable = False
106
+ ):
107
+ super().__init__()
108
+ self.scale = scale
109
+ self.qk_norm = qk_norm
110
+
111
+ self.causal = causal
112
+ self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask
113
+
114
+ self.attn_fn = partial(F.softmax, dtype = torch.float32) if not qk_norm else F.softmax
115
+
116
+ self.dropout = dropout
117
+ self.attn_dropout = nn.Dropout(dropout)
118
+
119
+ # talking heads
120
+
121
+ assert not (flash and talking_heads), 'talking heads not compatible with flash attention'
122
+
123
+ self.talking_heads = talking_heads
124
+ if talking_heads:
125
+ self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
126
+ self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
127
+
128
+ # sparse topk
129
+
130
+ assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
131
+ self.sparse_topk = sparse_topk
132
+
133
+ # add a key / value token composed of zeros
134
+ # in case this helps controlling outliers, proposed by https://www.evanmiller.org/attention-is-off-by-one.html
135
+
136
+ self.add_zero_kv = add_zero_kv
137
+
138
+ # flash attention
139
+
140
+ self.flash = flash
141
+ assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
142
+
143
+ # determine efficient attention configs for cuda and cpu
144
+
145
+ self.cpu_config = EfficientAttentionConfig(True, True, True)
146
+ self.cuda_config = None
147
+
148
+ if not torch.cuda.is_available() or not flash:
149
+ return
150
+
151
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
152
+
153
+ major, minor = device_properties.major, device_properties.minor
154
+
155
+ if (major, minor) == (8, 0):
156
+ print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
157
+ self.cuda_config = EfficientAttentionConfig(True, False, False)
158
+ elif (major, minor) == (9, 0):
159
+ print_once('H100 GPU detected, using flash attention')
160
+ self.cuda_config = EfficientAttentionConfig(True, False, False)
161
+ else:
162
+ print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
163
+ self.cuda_config = EfficientAttentionConfig(False, True, True)
164
+
165
+ def flash_attn(
166
+ self,
167
+ q, k, v,
168
+ mask = None,
169
+ attn_bias = None
170
+ ):
171
+ batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
172
+
173
+ # Recommended for multi-query single-key-value attention by Tri Dao
174
+ # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
175
+
176
+ if k.ndim == 3:
177
+ k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
178
+
179
+ if v.ndim == 3:
180
+ v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
181
+
182
+ # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
183
+
184
+ if self.qk_norm:
185
+ default_scale = q.shape[-1] ** -0.5
186
+ q = q * (self.scale / default_scale)
187
+
188
+ # Check if mask exists and expand to compatible shape
189
+ # The mask is B L, so it would have to be expanded to B H N L
190
+
191
+ causal = self.causal
192
+
193
+ # in the case of kv caching with one token (q_len == 1), just turn off causal masking
194
+ # in speculative decoding, this may go up to 5-6, so right aligned causal mask will be needed there
195
+
196
+ if q_len == 1 and causal:
197
+ causal = False
198
+
199
+ # expand key padding mask
200
+
201
+ if exists(mask):
202
+ assert mask.ndim == 4
203
+ mask = mask.expand(batch, heads, q_len, k_len)
204
+
205
+ # handle kv cache - this should be bypassable in updated flash attention 2
206
+
207
+ if k_len > q_len and causal:
208
+ causal_mask = self.create_causal_mask(q_len, k_len, device = device)
209
+ if not exists(mask):
210
+ mask = ~causal_mask
211
+ else:
212
+ mask = mask & ~causal_mask
213
+ causal = False
214
+
215
+ # manually handle causal mask, if another mask was given
216
+
217
+ row_is_entirely_masked = None
218
+
219
+ if exists(mask) and causal:
220
+ causal_mask = self.create_causal_mask(q_len, k_len, device = device)
221
+ mask = mask & ~causal_mask
222
+
223
+ # protect against an entire row being masked out
224
+
225
+ row_is_entirely_masked = ~mask.any(dim = -1)
226
+ mask[..., 0] = mask[..., 0] | row_is_entirely_masked
227
+
228
+ causal = False
229
+
230
+ # handle alibi positional bias
231
+ # convert from bool to float
232
+
233
+ if exists(attn_bias):
234
+ attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)
235
+
236
+ # if mask given, the mask would already contain the causal mask from above logic
237
+ # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
238
+
239
+ mask_value = -torch.finfo(q.dtype).max
240
+
241
+ if exists(mask):
242
+ attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
243
+ elif causal:
244
+ causal_mask = self.create_causal_mask(q_len, k_len, device = device)
245
+ attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
246
+ causal = False
247
+
248
+ # scaled_dot_product_attention handles attn_mask either as bool or additive bias
249
+ # make it an additive bias here
250
+
251
+ mask = attn_bias
252
+
253
+ # Check if there is a compatible device for flash attention
254
+
255
+ config = self.cuda_config if is_cuda else self.cpu_config
256
+
257
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
258
+
259
+ with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=True, enable_flash=True):
260
+ out = F.scaled_dot_product_attention(
261
+ q, k, v,
262
+ attn_mask = mask,
263
+ dropout_p = self.dropout if self.training else 0.,
264
+ is_causal = causal
265
+ )
266
+
267
+ # for a row that is entirely masked out, should zero out the output of that row token
268
+
269
+ if exists(row_is_entirely_masked):
270
+ out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
271
+
272
+ return out, Intermediates()
273
+
274
+ def forward(
275
+ self,
276
+ q, k, v,
277
+ mask = None,
278
+ attn_bias = None,
279
+ prev_attn = None
280
+ ):
281
+ """
282
+ einstein notation
283
+ b - batch
284
+ h - heads
285
+ n, i, j - sequence length (base sequence length, source, target)
286
+ d - feature dimension
287
+ """
288
+
289
+ n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device
290
+
291
+ scale = default(self.scale, q.shape[-1] ** -0.5)
292
+
293
+ causal = self.causal
294
+
295
+ # handle kv cached decoding
296
+
297
+ if n == 1 and causal:
298
+ causal = False
299
+
300
+ # handle grouped multi-query attention
301
+
302
+ if kv_heads == 1:
303
+ k, v = map(lambda t: rearrange(t, 'b 1 n d -> b n d'), (k, v))
304
+ elif kv_heads < heads:
305
+ k, v = map(lambda t: repeat(t, 'b kvh n d -> b (r kvh) n d', r = heads // kv_heads), (k, v))
306
+
307
+ # handle zero kv, as means for allowing network to attend to nothing
308
+
309
+ if self.add_zero_kv:
310
+ k, v = map(lambda t: F.pad(t, (0, 0, 1, 0), value = 0.), (k, v))
311
+
312
+ if exists(mask):
313
+ mask = F.pad(mask, (1, 0), value = True)
314
+
315
+ if exists(attn_bias):
316
+ attn_bias = F.pad(attn_bias, (1, 0), value = 0.)
317
+
318
+ if self.flash:
319
+ assert not exists(prev_attn), 'residual attention not compatible with flash attention'
320
+ return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
321
+
322
+ kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
323
+
324
+ dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
325
+
326
+ if exists(prev_attn):
327
+ dots = dots + prev_attn
328
+
329
+ qk_similarities = dots.clone()
330
+
331
+ if self.talking_heads:
332
+ dots = self.pre_softmax_talking_heads(dots)
333
+
334
+ if exists(attn_bias):
335
+ dots = dots + attn_bias
336
+
337
+ i, j, dtype = *dots.shape[-2:], dots.dtype
338
+
339
+ mask_value = -torch.finfo(dots.dtype).max
340
+
341
+ if exists(self.sparse_topk) and self.sparse_topk < j:
342
+ top_values, _ = dots.topk(self.sparse_topk, dim = -1)
343
+ sparse_topk_mask = dots < top_values[..., -1:]
344
+ mask = (mask & sparse_topk_mask) if exists(mask) else sparse_topk_mask
345
+
346
+ if exists(mask):
347
+ dots = dots.masked_fill(~mask, mask_value)
348
+
349
+ if causal:
350
+ causal_mask = self.create_causal_mask(i, j, device = device)
351
+ dots = dots.masked_fill(causal_mask, mask_value)
352
+
353
+ pre_softmax_attn = dots.clone()
354
+
355
+ attn = self.attn_fn(dots, dim = -1)
356
+ attn = attn.type(dtype)
357
+
358
+ post_softmax_attn = attn.clone()
359
+
360
+ attn = self.attn_dropout(attn)
361
+
362
+ if self.talking_heads:
363
+ attn = self.post_softmax_talking_heads(attn)
364
+
365
+ out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
366
+
367
+ intermediates = Intermediates(
368
+ qk_similarities = qk_similarities,
369
+ pre_softmax_attn = pre_softmax_attn,
370
+ post_softmax_attn = post_softmax_attn
371
+ )
372
+
373
+ return out, intermediates
374
+
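A minimal usage sketch for the Attend module above; shapes and settings are illustrative, with per-head tensors of shape (batch, heads, seq_len, dim_head).

    attend = Attend(causal = True, dropout = 0.1, flash = False)   # illustrative settings
    q = torch.randn(2, 8, 128, 64)
    k = torch.randn(2, 8, 128, 64)
    v = torch.randn(2, 8, 128, 64)
    out, intermediates = attend(q, k, v)   # out: (2, 8, 128, 64)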
375
+ #===================================================================================================================
376
+
377
+ from math import ceil, log
378
+ from typing import Optional, Union, Tuple, Callable
379
+
380
+ import torch
381
+ from torch import nn, Tensor
382
+ from torch.nn import Module
383
+ import torch.nn.functional as F
384
+
385
+ from einops import rearrange, pack, unpack
386
+
387
+ def exists(val):
388
+ return val is not None
389
+
390
+ def default(val, d):
391
+ return val if exists(val) else d
392
+
393
+ def identity(t, *args, **kwargs):
394
+ return t
395
+
396
+ def cast_tuple(t, length = 1):
397
+ return t if isinstance(t, tuple) else (t,) * length
398
+
399
+ def eval_decorator(fn):
400
+ def inner(self, *args, **kwargs):
401
+ was_training = self.training
402
+ self.eval()
403
+ out = fn(self, *args, **kwargs)
404
+ self.train(was_training)
405
+ return out
406
+ return inner
407
+
408
+ # for variable-length prefixes
409
+
410
+ def align_right(t, lens, pad_id = 0):
411
+ batch, seq_len, device, dtype = *t.shape, t.device, t.dtype
412
+
413
+ assert lens.ndim == 1 and lens.shape[0] == batch
414
+ assert lens.amax() <= seq_len
415
+
416
+ pad_lens = seq_len - lens
417
+ max_pad_len = pad_lens.amax()
418
+
419
+ batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
420
+ prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)
421
+
422
+ t = F.pad(t, (max_pad_len, 0), value = 0)
423
+ offset = max_pad_len - pad_lens
424
+
425
+ aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
426
+ return aligned
427
+
428
+ # nucleus
429
+
430
+ def top_p(logits, thres = 0.9):
431
+ sorted_logits, sorted_indices = torch.sort(logits, descending = True)
432
+ cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)
433
+
434
+ sorted_indices_to_remove = cum_probs > thres
435
+ sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value = False)
436
+
437
+ sorted_logits[sorted_indices_to_remove] = float('-inf')
438
+ return sorted_logits.scatter(1, sorted_indices, sorted_logits)
439
+
440
+ # topk
441
+
442
+ def top_k(logits, frac_num_tokens = 0.1, k = None):
443
+ num_tokens = logits.shape[-1]
444
+
445
+ k = default(k, ceil(frac_num_tokens * num_tokens))
446
+ k = min(k, num_tokens)
447
+
448
+ val, ind = torch.topk(logits, k)
449
+ probs = torch.full_like(logits, float('-inf'))
450
+ probs.scatter_(1, ind, val)
451
+ return probs
452
+
453
+ # top_a
454
+
455
+ def top_a(logits, min_p_pow = 2.0, min_p_ratio = 0.02):
456
+ probs = F.softmax(logits, dim = -1)
457
+ max_probs = torch.amax(probs, dim = -1, keepdim = True)
458
+ limit = torch.pow(max_probs, min_p_pow) * min_p_ratio
459
+ return torch.where(probs < limit, float('-inf'), logits)
460
+
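The three filters above (top_p, top_k, top_a) are interchangeable logit pre-processors; a rough sketch of how one of them combines with temperature sampling (values are illustrative):

    logits = torch.randn(2, 512)                   # (batch, vocab), illustrative
    filtered = top_k(logits, k = 32)               # or top_p(logits, thres = 0.9) / top_a(logits)
    probs = F.softmax(filtered / 0.9, dim = -1)    # temperature of 0.9
    next_token = torch.multinomial(probs, 1)       # (batch, 1)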
461
+ # contrastive decoding function
462
+
463
+ def contrastive_decode_fn(
464
+ expert_logits,
465
+ amateur_logits,
466
+ alpha = 0.1,
467
+ beta = 0.5
468
+ ):
469
+ """
470
+ Appendix A Algorithm 2
471
+ https://arxiv.org/abs/2309.09117
472
+ """
473
+
474
+ cutoff = log(alpha) + expert_logits.amax(dim = -1, keepdim = True)
475
+ diffs = (1 + beta) * expert_logits - beta * amateur_logits
476
+ contrastive_decode_logits = diffs.masked_fill(expert_logits < cutoff, -torch.finfo(expert_logits.dtype).max)
477
+ return contrastive_decode_logits
478
+
479
+ # autoregressive wrapper class
480
+
481
+ class AutoregressiveWrapper(Module):
482
+ def __init__(
483
+ self,
484
+ net,
485
+ ignore_index = -100,
486
+ pad_value = 0,
487
+ mask_prob = 0.,
488
+ add_attn_z_loss = False
489
+ ):
490
+ super().__init__()
491
+ self.pad_value = pad_value
492
+ self.ignore_index = ignore_index
493
+
494
+ self.net = net
495
+ self.max_seq_len = net.max_seq_len
496
+
497
+ # paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432
498
+ assert mask_prob < 1.
499
+ self.mask_prob = mask_prob
500
+
501
+ # whether to add router z-loss
502
+ self.add_attn_z_loss = add_attn_z_loss
503
+
504
+ @torch.no_grad()
505
+ @eval_decorator
506
+ def generate(
507
+ self,
508
+ prompts,
509
+ seq_len,
510
+ eos_token = None,
511
+ temperature = 1.,
512
+ prompt_lens: Optional[Tensor] = None,
513
+ filter_logits_fn: Callable = top_k,
514
+ restrict_to_max_seq_len = True,
515
+ amateur_model: Optional[Union[Module, Tuple[Module]]] = None,
516
+ filter_kwargs: dict = dict(),
517
+ contrastive_decode_kwargs: Union[dict, Tuple[dict]] = dict(
518
+ beta = 0.5,
519
+ alpha = 0.1
520
+ ),
521
+ cache_kv = True,
522
+ verbose=True,
523
+ return_prime=False,
524
+ **kwargs
525
+ ):
526
+ max_seq_len, device = self.max_seq_len, prompts.device
527
+
528
+ prompts, ps = pack([prompts], '* n')
529
+
530
+ b, t = prompts.shape
531
+
532
+ # handle variable-length prompts (prefixes)
533
+
534
+ seq_start_pos = None
535
+ if exists(prompt_lens):
536
+ prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
537
+ seq_start_pos = t - prompt_lens
538
+
539
+ # output to which sampled tokens are appended
540
+
541
+ out = prompts
542
+
543
+ if verbose:
544
+ print("Generating sequence of max length:", seq_len)
545
+
546
+ # kv caches
547
+
548
+ cache = None
549
+
550
+ # if doing contrastive decoding, turn off filter automatically
551
+
552
+ if exists(amateur_model):
553
+ amateur_model = cast_tuple(amateur_model)
554
+ contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)
555
+
556
+ assert len(amateur_model) == len(contrastive_decode_kwargs)
557
+
558
+ amateur_caches = [None] * len(amateur_model)
559
+ filter_logits_fn = identity
560
+
561
+ for i, module in enumerate(amateur_model):
562
+ if isinstance(module, AutoregressiveWrapper):
563
+ amateur_model[i] = module.net
564
+
565
+ module.eval()
566
+
567
+ # sampling up to seq_len
568
+
569
+ for sl in range(seq_len):
570
+
571
+ if restrict_to_max_seq_len:
572
+ x = out[:, -max_seq_len:]
573
+
574
+ if exists(cache):
575
+ for inter in cache.attn_intermediates:
576
+ inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv]
577
+
578
+ logits, new_cache = self.net(
579
+ x,
580
+ return_intermediates = True,
581
+ cache = cache,
582
+ seq_start_pos = seq_start_pos,
583
+ **kwargs
584
+ )
585
+
586
+ if cache_kv and self.net.can_cache_kv:
587
+ cache = new_cache
588
+
589
+ logits = logits[:, -1]
590
+
591
+ # handle contrastive decoding, Li et al.
592
+ # https://arxiv.org/abs/2210.15097
593
+
594
+ if exists(amateur_model):
595
+ for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
596
+ amateur_logits, next_amateur_cache = amateur(
597
+ x,
598
+ return_intermediates = True,
599
+ cache = amateur_cache,
600
+ seq_start_pos = seq_start_pos,
601
+ **kwargs
602
+ )
603
+
604
+ amateur_logits = amateur_logits[:, -1]
605
+
606
+ assert amateur_logits.shape == logits.shape, 'logits dimensions are not the same between amateur and expert models'
607
+ logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)
608
+
609
+ if cache_kv and amateur.can_cache_kv:
610
+ amateur_caches[i] = next_amateur_cache
611
+
612
+ # filter by top_k, top_p (nucleus), top_a, or custom
613
+
614
+ filtered_logits = filter_logits_fn(logits, **filter_kwargs)
615
+
616
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
617
+
618
+ sample = torch.multinomial(probs, 1)
619
+
620
+ out = torch.cat((out, sample), dim=-1)
621
+
622
+ if verbose:
623
+ if sl % 32 == 0:
624
+ print(sl, '/', seq_len)
625
+
626
+ if exists(eos_token):
627
+ is_eos_tokens = (out == eos_token)
628
+
629
+ if is_eos_tokens.any(dim = -1).all():
630
+ # mask out everything after the eos tokens
631
+ shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
632
+ mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
633
+ out = out.masked_fill(mask, self.pad_value)
634
+
635
+ if verbose:
636
+ print('Model called the end of sequence at:', sl, '/', seq_len)
637
+
638
+ break
639
+
640
+ if return_prime:
641
+ return out[:, :]
642
+
643
+ else:
644
+ return out[:, t:]
645
+
646
+ # out, = unpack(out, ps, '* n')
647
+
648
+ # return out
649
+
650
+ def compute_accuracy(self, logits, labels):
651
+ out = torch.argmax(logits, dim=-1)
652
+ out = out.flatten()
653
+ labels = labels.flatten()
654
+
655
+ mask = (labels != self.ignore_index) # can also be self.pad_value (your choice)
656
+ out = out[mask]
657
+ labels = labels[mask]
658
+
659
+ num_right = (out == labels)
660
+ num_right = torch.sum(num_right).type(torch.float32)
661
+
662
+ acc = num_right / len(labels)
663
+ return acc
664
+
665
+ def forward(self, x, **kwargs):
666
+ seq, ignore_index, add_attn_z_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss
667
+
668
+ inp, target = x[:, :-1], x[:, 1:]
669
+ inp = torch.where(inp == ignore_index, self.pad_value, inp)
670
+
671
+ if self.mask_prob > 0.:
672
+ rand = torch.randn(inp.shape, device = x.device)
673
+ rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out
674
+ num_mask = min(int(seq * self.mask_prob), seq - 1)
675
+ indices = rand.topk(num_mask, dim = -1).indices
676
+ mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()
677
+ kwargs.update(self_attn_kv_mask = mask)
678
+
679
+ logits, cache = self.net(
680
+ inp,
681
+ return_intermediates = True,
682
+ return_attn_z_loss = add_attn_z_loss,
683
+ **kwargs
684
+ )
685
+
686
+ acc = self.compute_accuracy(logits, target)
687
+
688
+ loss = F.cross_entropy(
689
+ rearrange(logits, 'b n c -> b c n'),
690
+ target,
691
+ ignore_index = ignore_index
692
+ )
693
+
694
+ if add_attn_z_loss:
695
+ loss = loss + cache.attn_z_loss
696
+
697
+ return loss, acc
698
+
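A usage sketch for the wrapper; it assumes the TransformerWrapper and Decoder classes defined further down in this module (as in upstream x-transformers), and the vocabulary and size values below are placeholders.

    model = TransformerWrapper(                      # placeholder configuration
        num_tokens = 3088,
        max_seq_len = 4096,
        attn_layers = Decoder(dim = 1024, depth = 8, heads = 16, attn_flash = True)
    )
    model = AutoregressiveWrapper(model, ignore_index = 3087, pad_value = 3087)

    seqs = torch.randint(0, 3088, (4, 4097))         # training batch: (batch, max_seq_len + 1)
    loss, acc = model(seqs)                          # next-token loss and accuracy
    loss.backward()

    prompt = torch.randint(0, 3088, (1, 512))
    out = model.generate(prompt, 256, temperature = 0.9, filter_logits_fn = top_p)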
699
+ #===============================================================================
700
+
701
+ import math
702
+ from random import random
703
+
704
+ import torch
705
+ from torch import nn, einsum, Tensor
706
+ import torch.nn.functional as F
707
+
708
+ from functools import partial, wraps
709
+ from inspect import isfunction
710
+ from collections import namedtuple
711
+ from dataclasses import dataclass
712
+ from typing import List, Callable, Optional
713
+
714
+ from einops import rearrange, repeat, reduce, pack, unpack
715
+ from einops.layers.torch import Rearrange
716
+
717
+ # constants
718
+
719
+ DEFAULT_DIM_HEAD = 64
720
+
721
+ @dataclass
722
+ class LayerIntermediates:
723
+ hiddens: Optional[List[Tensor]] = None
724
+ attn_intermediates: Optional[List[Intermediates]] = None
725
+ layer_hiddens: Optional[List[Tensor]] = None
726
+ attn_z_loss: Optional[Tensor] = None
727
+ mems: Optional[Tensor] = None
728
+
729
+ # helpers
730
+
731
+ def exists(val):
732
+ return val is not None
733
+
734
+ def default(val, d):
735
+ if exists(val):
736
+ return val
737
+ return d() if isfunction(d) else d
738
+
739
+ def cast_tuple(val, depth):
740
+ return val if isinstance(val, tuple) else (val,) * depth
741
+
742
+ def divisible_by(num, den):
743
+ return (num % den) == 0
744
+
745
+ def maybe(fn):
746
+ @wraps(fn)
747
+ def inner(x, *args, **kwargs):
748
+ if not exists(x):
749
+ return x
750
+ return fn(x, *args, **kwargs)
751
+ return inner
752
+
753
+ class always():
754
+ def __init__(self, val):
755
+ self.val = val
756
+ def __call__(self, *args, **kwargs):
757
+ return self.val
758
+
759
+ class not_equals():
760
+ def __init__(self, val):
761
+ self.val = val
762
+ def __call__(self, x, *args, **kwargs):
763
+ return x != self.val
764
+
765
+ class equals():
766
+ def __init__(self, val):
767
+ self.val = val
768
+ def __call__(self, x, *args, **kwargs):
769
+ return x == self.val
770
+
771
+ def Sequential(*modules):
772
+ return nn.Sequential(*filter(exists, modules))
773
+
774
+ # tensor helpers
775
+
776
+ def max_neg_value(tensor):
777
+ return -torch.finfo(tensor.dtype).max
778
+
779
+ def l2norm(t, groups = 1):
780
+ t = rearrange(t, '... (g d) -> ... g d', g = groups)
781
+ t = F.normalize(t, p = 2, dim = -1)
782
+ return rearrange(t, '... g d -> ... (g d)')
783
+
784
+ def pad_at_dim(t, pad, dim = -1, value = 0.):
785
+ dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
786
+ zeros = ((0, 0) * dims_from_right)
787
+ return F.pad(t, (*zeros, *pad), value = value)
788
+
789
+ def or_reduce(masks):
790
+ head, *body = masks
791
+ for rest in body:
792
+ head = head | rest
793
+ return head
794
+
795
+ # auxiliary loss helpers
796
+
797
+ def calc_z_loss(
798
+ pre_softmax_attns: List[Tensor],
799
+ mask = None,
800
+ weight = 1.
801
+ ):
802
+ # the same loss applied to the mixture of experts router logits in https://arxiv.org/abs/2202.08906
803
+ # in the paper, in a tiny footnote, they mention using it on attention logits with stabilizing effects
804
+ # also used in PaLM as one of the measures
805
+
806
+ lse = 0.
807
+
808
+ for attn in pre_softmax_attns:
809
+ lse = lse + attn.logsumexp(dim = -1)
810
+
811
+ loss = torch.square(lse)
812
+ loss = reduce(loss, 'b h n -> b n', 'sum')
813
+
814
+ if not exists(mask):
815
+ return loss.mean() * weight
816
+
817
+ loss = loss[mask].sum() / mask.sum().clamp(min = 1e-5)
818
+ return loss * weight
819
+
820
+ # init helpers
821
+
822
+ def init_zero_(layer):
823
+ nn.init.constant_(layer.weight, 0.)
824
+ if exists(layer.bias):
825
+ nn.init.constant_(layer.bias, 0.)
826
+
827
+ # keyword argument helpers
828
+
829
+ def pick_and_pop(keys, d):
830
+ values = list(map(lambda key: d.pop(key), keys))
831
+ return dict(zip(keys, values))
832
+
833
+ def group_dict_by_key(cond, d):
834
+ return_val = [dict(),dict()]
835
+ for key in d.keys():
836
+ match = bool(cond(key))
837
+ ind = int(not match)
838
+ return_val[ind][key] = d[key]
839
+ return (*return_val,)
840
+
841
+ def string_begins_with(prefix, str):
842
+ return str.startswith(prefix)
843
+
844
+ def group_by_key_prefix(prefix, d):
845
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
846
+
847
+ def groupby_prefix_and_trim(prefix, d):
848
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
849
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
850
+ return kwargs_without_prefix, kwargs
851
+
852
+ # structured dropout, more effective than traditional attention dropouts
853
+
854
+ def dropout_seq(seq, mask, dropout):
855
+ b, n, *_, device = *seq.shape, seq.device
856
+ logits = torch.randn(b, n, device = device)
857
+
858
+ if exists(mask):
859
+ mask_value = max_neg_value(logits)
860
+ logits = logits.masked_fill(~mask, mask_value)
861
+
862
+ keep_prob = 1. - dropout
863
+ num_keep = max(1, int(keep_prob * n))
864
+ keep_indices = logits.topk(num_keep, dim = 1).indices
865
+
866
+ batch_indices = torch.arange(b, device = device)
867
+ batch_indices = rearrange(batch_indices, 'b -> b 1')
868
+
869
+ seq = seq[batch_indices, keep_indices]
870
+
871
+ if exists(mask):
872
+ seq_counts = mask.sum(dim = -1)
873
+ seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
874
+ keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')
875
+
876
+ mask = mask[batch_indices, keep_indices] & keep_mask
877
+
878
+ return seq, mask
879
+
880
+ # activations
881
+
882
+ class ReluSquared(nn.Module):
883
+ def forward(self, x):
884
+ return F.relu(x) ** 2
885
+
886
+ # embedding
887
+
888
+ class TokenEmbedding(nn.Module):
889
+ def __init__(self, dim, num_tokens, l2norm_embed = False):
890
+ super().__init__()
891
+ self.l2norm_embed = l2norm_embed
892
+ self.emb = nn.Embedding(num_tokens, dim)
893
+
894
+ def forward(self, x):
895
+ token_emb = self.emb(x)
896
+ return l2norm(token_emb) if self.l2norm_embed else token_emb
897
+
898
+ # positional embeddings
899
+
900
+ class AbsolutePositionalEmbedding(nn.Module):
901
+ def __init__(self, dim, max_seq_len, l2norm_embed = False):
902
+ super().__init__()
903
+ self.scale = dim ** -0.5 if not l2norm_embed else 1.
904
+ self.max_seq_len = max_seq_len
905
+ self.l2norm_embed = l2norm_embed
906
+ self.emb = nn.Embedding(max_seq_len, dim)
907
+
908
+ def forward(self, x, pos = None, seq_start_pos = None):
909
+ seq_len, device = x.shape[1], x.device
910
+ assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
911
+
912
+ if not exists(pos):
913
+ pos = torch.arange(seq_len, device = device)
914
+
915
+ if exists(seq_start_pos):
916
+ pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
917
+
918
+ pos_emb = self.emb(pos)
919
+ pos_emb = pos_emb * self.scale
920
+ return l2norm(pos_emb) if self.l2norm_embed else pos_emb
921
+
922
+ class ScaledSinusoidalEmbedding(nn.Module):
923
+ def __init__(self, dim, theta = 10000):
924
+ super().__init__()
925
+ assert divisible_by(dim, 2)
926
+ self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
927
+
928
+ half_dim = dim // 2
929
+ freq_seq = torch.arange(half_dim).float() / half_dim
930
+ inv_freq = theta ** -freq_seq
931
+ self.register_buffer('inv_freq', inv_freq, persistent = False)
932
+
933
+ def forward(self, x, pos = None, seq_start_pos = None):
934
+ seq_len, device = x.shape[1], x.device
935
+
936
+ if not exists(pos):
937
+ pos = torch.arange(seq_len, device = device)
938
+
939
+ if exists(seq_start_pos):
940
+ pos = pos - seq_start_pos[..., None]
941
+
942
+ emb = einsum('i, j -> i j', pos, self.inv_freq)
943
+ emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
944
+ return emb * self.scale
945
+
946
+ class RelativePositionBias(nn.Module):
947
+ def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
948
+ super().__init__()
949
+ self.scale = scale
950
+ self.causal = causal
951
+ self.num_buckets = num_buckets
952
+ self.max_distance = max_distance
953
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
954
+
955
+ @staticmethod
956
+ def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
957
+ ret = 0
958
+ n = -relative_position
959
+ if not causal:
960
+ num_buckets //= 2
961
+ ret += (n < 0).long() * num_buckets
962
+ n = torch.abs(n)
963
+ else:
964
+ n = torch.max(n, torch.zeros_like(n))
965
+
966
+ max_exact = num_buckets // 2
967
+ is_small = n < max_exact
968
+
969
+ val_if_large = max_exact + (
970
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
971
+ ).long()
972
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
973
+
974
+ ret += torch.where(is_small, n, val_if_large)
975
+ return ret
976
+
977
+ @property
978
+ def device(self):
979
+ return next(self.parameters()).device
980
+
981
+ def forward(self, i, j):
982
+ device = self.device
983
+ q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
984
+ k_pos = torch.arange(j, dtype = torch.long, device = device)
985
+ rel_pos = k_pos[None, :] - q_pos[:, None]
986
+ rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
987
+ values = self.relative_attention_bias(rp_bucket)
988
+ bias = rearrange(values, 'i j h -> h i j')
989
+ return bias * self.scale
990
+
991
+ class DynamicPositionBias(nn.Module):
992
+ def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
993
+ super().__init__()
994
+ assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
995
+ self.log_distance = log_distance
996
+
997
+ self.mlp = nn.ModuleList([])
998
+
999
+ self.mlp.append(Sequential(
1000
+ nn.Linear(1, dim),
1001
+ nn.LayerNorm(dim) if norm else None,
1002
+ nn.SiLU()
1003
+ ))
1004
+
1005
+ for _ in range(depth - 1):
1006
+ self.mlp.append(Sequential(
1007
+ nn.Linear(dim, dim),
1008
+ nn.LayerNorm(dim) if norm else None,
1009
+ nn.SiLU()
1010
+ ))
1011
+
1012
+ self.mlp.append(nn.Linear(dim, heads))
1013
+
1014
+ @property
1015
+ def device(self):
1016
+ return next(self.parameters()).device
1017
+
1018
+ def forward(self, i, j):
1019
+ assert i == j
1020
+ n, device = j, self.device
1021
+
1022
+ # get the (n x n) matrix of distances
1023
+ seq_arange = torch.arange(n, device = device)
1024
+ context_arange = torch.arange(n, device = device)
1025
+ indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
1026
+ indices += (n - 1)
1027
+
1028
+ # input to continuous positions MLP
1029
+ pos = torch.arange(-n + 1, n, device = device).float()
1030
+ pos = rearrange(pos, '... -> ... 1')
1031
+
1032
+ if self.log_distance:
1033
+ pos = torch.sign(pos) * torch.log(pos.abs() + 1) # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)
1034
+
1035
+ for layer in self.mlp:
1036
+ pos = layer(pos)
1037
+
1038
+ # get position biases
1039
+ bias = pos[indices]
1040
+ bias = rearrange(bias, 'i j h -> h i j')
1041
+ return bias
1042
+
1043
+ class AlibiPositionalBias(nn.Module):
1044
+ def __init__(self, heads, total_heads, **kwargs):
1045
+ super().__init__()
1046
+ self.heads = heads
1047
+ self.total_heads = total_heads
1048
+
1049
+ slopes = Tensor(self._get_slopes(heads))
1050
+ slopes = rearrange(slopes, 'h -> h 1 1')
1051
+ self.register_buffer('slopes', slopes, persistent = False)
1052
+ self.register_buffer('bias', None, persistent = False)
1053
+
1054
+ def get_bias(self, i, j, device):
1055
+ i_arange = torch.arange(j - i, j, device = device)
1056
+ j_arange = torch.arange(j, device = device)
1057
+ bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
1058
+ return bias
1059
+
1060
+ @staticmethod
1061
+ def _get_slopes(heads):
1062
+ def get_slopes_power_of_2(n):
1063
+ start = (2**(-2**-(math.log2(n)-3)))
1064
+ ratio = start
1065
+ return [start*ratio**i for i in range(n)]
1066
+
1067
+ if math.log2(heads).is_integer():
1068
+ return get_slopes_power_of_2(heads)
1069
+
1070
+ closest_power_of_2 = 2 ** math.floor(math.log2(heads))
1071
+ return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
1072
+
1073
+ @property
1074
+ def device(self):
1075
+ return next(self.buffers()).device
1076
+
1077
+ def forward(self, i, j):
1078
+ h, device = self.total_heads, self.device
1079
+
1080
+ if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
1081
+ return self.bias[..., -i:, -j:]
1082
+
1083
+ bias = self.get_bias(i, j, device)
1084
+ bias = bias * self.slopes
1085
+
1086
+ num_heads_unalibied = h - bias.shape[0]
1087
+ bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = 0)
1088
+ self.register_buffer('bias', bias, persistent = False)
1089
+
1090
+ return self.bias
1091
+
1092
+ class RotaryEmbedding(nn.Module):
1093
+ def __init__(
1094
+ self,
1095
+ dim,
1096
+ use_xpos = False,
1097
+ scale_base = 512,
1098
+ interpolation_factor = 1.,
1099
+ base = 10000,
1100
+ base_rescale_factor = 1.
1101
+ ):
1102
+ super().__init__()
1103
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
1104
+ # has some connection to NTK literature
1105
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
1106
+ base *= base_rescale_factor ** (dim / (dim - 2))
1107
+
1108
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
1109
+ self.register_buffer('inv_freq', inv_freq)
1110
+
1111
+ assert interpolation_factor >= 1.
1112
+ self.interpolation_factor = interpolation_factor
1113
+
1114
+ if not use_xpos:
1115
+ self.register_buffer('scale', None)
1116
+ return
1117
+
1118
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
1119
+
1120
+ self.scale_base = scale_base
1121
+ self.register_buffer('scale', scale)
1122
+
1123
+ def forward(self, seq_len):
1124
+ device = self.inv_freq.device
1125
+ t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
1126
+
1127
+ t = t / self.interpolation_factor
1128
+
1129
+ freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
1130
+ freqs = torch.cat((freqs, freqs), dim = -1)
1131
+
1132
+ if not exists(self.scale):
1133
+ return freqs, 1.
1134
+
1135
+ power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
1136
+ scale = self.scale ** rearrange(power, 'n -> n 1')
1137
+ scale = torch.cat((scale, scale), dim = -1)
1138
+
1139
+ return freqs, scale
1140
+
1141
+
1142
+ def rotate_half(x):
1143
+ x = rearrange(x, '... (j d) -> ... j d', j = 2)
1144
+ x1, x2 = x.unbind(dim = -2)
1145
+ return torch.cat((-x2, x1), dim = -1)
1146
+
1147
+ def apply_rotary_pos_emb(t, freqs, scale = 1):
1148
+ rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
1149
+ freqs = freqs[-seq_len:, :]
1150
+
1151
+ if t.ndim == 4 and freqs.ndim == 3:
1152
+ freqs = rearrange(freqs, 'b n d -> b 1 n d')
1153
+
1154
+ # partial rotary embeddings, Wang et al. GPT-J
1155
+ t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
1156
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
1157
+ return torch.cat((t, t_unrotated), dim = -1)
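A small sketch of applying the rotary helpers above to per-head queries and keys (shapes are illustrative):

    rotary = RotaryEmbedding(dim = 32)               # rotate the first 32 of 64 head dims
    freqs, xpos_scale = rotary(128)                  # xpos_scale is 1. when use_xpos = False
    q = torch.randn(2, 8, 128, 64)
    k = torch.randn(2, 8, 128, 64)
    q = apply_rotary_pos_emb(q, freqs)               # remaining 32 dims pass through unrotated
    k = apply_rotary_pos_emb(k, freqs)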
1158
+
1159
+ # norms
1160
+
1161
+ class Scale(nn.Module):
1162
+ def __init__(self, value, fn):
1163
+ super().__init__()
1164
+ self.value = value
1165
+ self.fn = fn
1166
+
1167
+ def forward(self, x, **kwargs):
1168
+ out = self.fn(x, **kwargs)
1169
+ scale_fn = lambda t: t * self.value
1170
+
1171
+ if not isinstance(out, tuple):
1172
+ return scale_fn(out)
1173
+
1174
+ return (scale_fn(out[0]), *out[1:])
1175
+
1176
+ class ScaleNorm(nn.Module):
1177
+ def __init__(self, dim, eps = 1e-5):
1178
+ super().__init__()
1179
+ self.eps = eps
1180
+ self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5))
1181
+
1182
+ def forward(self, x):
1183
+ norm = torch.norm(x, dim = -1, keepdim = True)
1184
+ return x / norm.clamp(min = self.eps) * self.g
1185
+
1186
+ class RMSNorm(nn.Module):
1187
+ def __init__(self, dim):
1188
+ super().__init__()
1189
+ self.scale = dim ** 0.5
1190
+ self.g = nn.Parameter(torch.ones(dim))
1191
+
1192
+ def forward(self, x):
1193
+ return F.normalize(x, dim = -1) * self.scale * self.g
1194
+
1195
+ class SimpleRMSNorm(nn.Module):
1196
+ def __init__(self, dim):
1197
+ super().__init__()
1198
+ self.scale = dim ** 0.5
1199
+
1200
+ def forward(self, x):
1201
+ return F.normalize(x, dim = -1) * self.scale
1202
+
1203
+ # residual and residual gates
1204
+
1205
+ class Residual(nn.Module):
1206
+ def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
1207
+ super().__init__()
1208
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
1209
+ self.scale_residual_constant = scale_residual_constant
1210
+
1211
+ def forward(self, x, residual):
1212
+ if exists(self.residual_scale):
1213
+ residual = residual * self.residual_scale
1214
+
1215
+ if self.scale_residual_constant != 1:
1216
+ residual = residual * self.scale_residual_constant
1217
+
1218
+ return x + residual
1219
+
1220
+ class GRUGating(nn.Module):
1221
+ def __init__(self, dim, scale_residual = False, **kwargs):
1222
+ super().__init__()
1223
+ self.gru = nn.GRUCell(dim, dim)
1224
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
1225
+
1226
+ def forward(self, x, residual):
1227
+ if exists(self.residual_scale):
1228
+ residual = residual * self.residual_scale
1229
+
1230
+ gated_output = self.gru(
1231
+ rearrange(x, 'b n d -> (b n) d'),
1232
+ rearrange(residual, 'b n d -> (b n) d')
1233
+ )
1234
+
1235
+ return gated_output.reshape_as(x)
1236
+
1237
+ # token shifting
1238
+
1239
+ def shift(t, amount, mask = None):
1240
+ if amount == 0:
1241
+ return t
1242
+ else:
1243
+ amount = min(amount, t.shape[1])
1244
+
1245
+ if exists(mask):
1246
+ t = t.masked_fill(~mask[..., None], 0.)
1247
+
1248
+ return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.)
1249
+
1250
+ class ShiftTokens(nn.Module):
1251
+ def __init__(self, shifts, fn):
1252
+ super().__init__()
1253
+ self.fn = fn
1254
+ self.shifts = tuple(shifts)
1255
+
1256
+ def forward(self, x, **kwargs):
1257
+ mask = kwargs.get('mask', None)
1258
+ shifts = self.shifts
1259
+ segments = len(shifts)
1260
+ feats_per_shift = x.shape[-1] // segments
1261
+ splitted = x.split(feats_per_shift, dim = -1)
1262
+ segments_to_shift, rest = splitted[:segments], splitted[segments:]
1263
+ segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
1264
+ x = torch.cat((*segments_to_shift, *rest), dim = -1)
1265
+ return self.fn(x, **kwargs)
1266
+
1267
+ # feedforward
1268
+
1269
+ class GLU(nn.Module):
1270
+ def __init__(
1271
+ self,
1272
+ dim_in,
1273
+ dim_out,
1274
+ activation: Callable,
1275
+ mult_bias = False
1276
+ ):
1277
+ super().__init__()
1278
+ self.act = activation
1279
+ self.proj = nn.Linear(dim_in, dim_out * 2)
1280
+ self.mult_bias = nn.Parameter(torch.ones(dim_out)) if mult_bias else 1.
1281
+
1282
+ def forward(self, x):
1283
+ x, gate = self.proj(x).chunk(2, dim = -1)
1284
+ return x * self.act(gate) * self.mult_bias
1285
+
1286
+ class FeedForward(nn.Module):
1287
+ def __init__(
1288
+ self,
1289
+ dim,
1290
+ dim_out = None,
1291
+ mult = 4,
1292
+ glu = False,
1293
+ glu_mult_bias = False,
1294
+ swish = False,
1295
+ relu_squared = False,
1296
+ post_act_ln = False,
1297
+ dropout = 0.,
1298
+ no_bias = False,
1299
+ zero_init_output = False
1300
+ ):
1301
+ super().__init__()
1302
+ inner_dim = int(dim * mult)
1303
+ dim_out = default(dim_out, dim)
1304
+
1305
+ if relu_squared:
1306
+ activation = ReluSquared()
1307
+ elif swish:
1308
+ activation = nn.SiLU()
1309
+ else:
1310
+ activation = nn.GELU()
1311
+
1312
+ if glu:
1313
+ project_in = GLU(dim, inner_dim, activation, mult_bias = glu_mult_bias)
1314
+ else:
1315
+ project_in = nn.Sequential(
1316
+ nn.Linear(dim, inner_dim, bias = not no_bias),
1317
+ activation
1318
+ )
1319
+
1320
+ self.ff = Sequential(
1321
+ project_in,
1322
+ nn.LayerNorm(inner_dim) if post_act_ln else None,
1323
+ nn.Dropout(dropout),
1324
+ nn.Linear(inner_dim, dim_out, bias = not no_bias)
1325
+ )
1326
+
1327
+ # init last linear layer to 0
1328
+ if zero_init_output:
1329
+ init_zero_(self.ff[-1])
1330
+
1331
+ def forward(self, x):
1332
+ return self.ff(x)
1333
+
1334
+ # attention. it is all we need
1335
+
1336
+ class Attention(nn.Module):
1337
+ def __init__(
1338
+ self,
1339
+ dim,
1340
+ dim_head = DEFAULT_DIM_HEAD,
1341
+ heads = 8,
1342
+ causal = False,
1343
+ flash = False,
1344
+ talking_heads = False,
1345
+ head_scale = False,
1346
+ sparse_topk = None,
1347
+ num_mem_kv = 0,
1348
+ dropout = 0.,
1349
+ on_attn = False,
1350
+ gate_value_heads = False,
1351
+ gate_values = False,
1352
+ zero_init_output = False,
1353
+ max_attend_past = None,
1354
+ qk_norm = False,
1355
+ qk_norm_groups = 1,
1356
+ qk_norm_scale = 10,
1357
+ qk_norm_dim_scale = False,
1358
+ one_kv_head = False,
1359
+ kv_heads = None,
1360
+ shared_kv = False,
1361
+ value_dim_head = None,
1362
+ tensor_product = False, # https://arxiv.org/abs/2208.06061
1363
+ add_zero_kv = False, # same as add_zero_attn in pytorch
1364
+ rotary_embed_values = False,
1365
+ onnxable = False
1366
+ ):
1367
+ super().__init__()
1368
+ self.scale = dim_head ** -0.5
1369
+
1370
+ self.heads = heads
1371
+ self.causal = causal
1372
+ self.max_attend_past = max_attend_past
1373
+
1374
+ assert not (exists(kv_heads) and one_kv_head), 'either attn_one_kv_head is set to True (in which case kv_heads is set to 1), or attn_kv_heads is set, but not both'
1375
+
1376
+ value_dim_head = default(value_dim_head, dim_head)
1377
+ kv_heads = default(kv_heads, heads)
1378
+
1379
+ kv_heads = 1 if one_kv_head else kv_heads
1380
+ assert divisible_by(heads, kv_heads)
1381
+
1382
+ self.kv_heads = kv_heads
1383
+
1384
+ q_dim = dim_head * heads
1385
+ k_dim = dim_head * kv_heads
1386
+ v_dim = value_dim_head * kv_heads
1387
+ out_dim = value_dim_head * heads
1388
+
1389
+ self.to_q = nn.Linear(dim, q_dim, bias = False)
1390
+ self.to_k = nn.Linear(dim, k_dim, bias = False)
1391
+
1392
+ # shared key / values, for further memory savings during inference
1393
+ assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
1394
+ self.to_v = nn.Linear(dim, v_dim, bias = False) if not shared_kv else None
1395
+
1396
+ # relations projection from tp-attention
1397
+ self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None
1398
+
1399
+ # add GLU gating for aggregated values, from alphafold2
1400
+ self.to_v_gate = None
1401
+ if gate_values:
1402
+ self.to_v_gate = nn.Linear(dim, out_dim)
1403
+ nn.init.constant_(self.to_v_gate.weight, 0)
1404
+ nn.init.constant_(self.to_v_gate.bias, 10)
1405
+
1406
+ # add per head gating of the output values, from 'Attend to nothing' paper
1407
+ self.to_v_head_gate = None
1408
+ if gate_value_heads:
1409
+ self.to_v_head_gate = nn.Linear(dim, heads)
1410
+ nn.init.constant_(self.to_v_head_gate.weight, 0)
1411
+ nn.init.constant_(self.to_v_head_gate.bias, 10)
1412
+
1413
+ # cosine sim attention
1414
+ self.qk_norm = qk_norm
1415
+ self.qk_norm_groups = qk_norm_groups
1416
+ self.qk_norm_scale = qk_norm_scale
1417
+
1418
+ # whether to use the rmsnorm (equivalent to cosine sim attention when scale is equal to 1) - https://arxiv.org/abs/2302.05442
1419
+ self.qk_norm_dim_scale = qk_norm_dim_scale
1420
+
1421
+ self.qk_norm_q_scale = self.qk_norm_k_scale = 1
1422
+ if qk_norm and qk_norm_dim_scale:
1423
+ self.qk_norm_q_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
1424
+ self.qk_norm_k_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
1425
+
1426
+ assert (not qk_norm) or divisible_by(dim_head, qk_norm_groups), 'dimension per attention head must be divisible by the qk norm groups'
1427
+ assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'
1428
+
1429
+ # attend class - includes core attention algorithm + talking heads
1430
+
1431
+ self.attend = Attend(
1432
+ heads = heads,
1433
+ causal = causal,
1434
+ talking_heads = talking_heads,
1435
+ dropout = dropout,
1436
+ sparse_topk = sparse_topk,
1437
+ qk_norm = qk_norm,
1438
+ scale = qk_norm_scale if qk_norm else self.scale,
1439
+ add_zero_kv = add_zero_kv,
1440
+ flash = flash,
1441
+ onnxable = onnxable
1442
+ )
1443
+
1444
+ # head scaling
1445
+ self.head_scale = head_scale
1446
+ if head_scale:
1447
+ self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))
1448
+
1449
+ # explicit topk sparse attention
1450
+ self.sparse_topk = sparse_topk
1451
+
1452
+ # add memory key / values
1453
+ self.num_mem_kv = num_mem_kv
1454
+ if num_mem_kv > 0:
1455
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
1456
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
1457
+
1458
+ # attention on attention
1459
+ self.attn_on_attn = on_attn
1460
+ self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False)
1461
+
1462
+ # whether to rotate positions into values, for absolute positions in addition to relative
1463
+ self.rotary_embed_values = rotary_embed_values
1464
+
1465
+ # init output projection 0
1466
+ if zero_init_output:
1467
+ init_zero_(self.to_out)
1468
+
1469
+ def forward(
1470
+ self,
1471
+ x,
1472
+ context = None,
1473
+ mask = None,
1474
+ context_mask = None,
1475
+ attn_mask = None,
1476
+ rel_pos = None,
1477
+ rotary_pos_emb = None,
1478
+ prev_attn = None,
1479
+ mem = None,
1480
+ return_intermediates = False,
1481
+ cache: Optional[Intermediates] = None,
1482
+ ):
1483
+ b, n, _, h, kv_h, head_scale, device, has_context = *x.shape, self.heads, self.kv_heads, self.head_scale, x.device, exists(context)
1484
+ kv_input = default(context, x)
1485
+
1486
+ q_input = x
1487
+ k_input = kv_input
1488
+ v_input = kv_input
1489
+ r_input = x
1490
+
1491
+ if exists(mem):
1492
+ k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
1493
+ v_input, _ = pack([mem, v_input], 'b * d')
1494
+
1495
+ q = self.to_q(q_input)
1496
+ k = self.to_k(k_input)
1497
+ v = self.to_v(v_input) if exists(self.to_v) else k
1498
+ r = self.to_r(r_input) if exists(self.to_r) else None
1499
+
1500
+ q = rearrange(q, 'b n (h d) -> b h n d', h = h)
1501
+
1502
+ k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h), (k, v, r))
1503
+
1504
+ if exists(cache) and not has_context:
1505
+ ck, cv = cache.cached_kv
1506
+
1507
+ if exists(mem):
1508
+ mk, k = unpack(k, mem_packed_shape, 'b h * d')
1509
+ mv, v = unpack(v, mem_packed_shape, 'b h * d')
1510
+
1511
+ k = torch.cat((ck, k), dim = -2)
1512
+ v = torch.cat((cv, v), dim = -2)
1513
+
1514
+ if exists(mem):
1515
+ k = torch.cat((mk, k), dim = -2)
1516
+ v = torch.cat((mv, v), dim = -2)
1517
+
1518
+ if return_intermediates:
1519
+ mem_len = mem.shape[-2] if exists(mem) else 0
1520
+ cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
1521
+
1522
+ if self.qk_norm:
1523
+ qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
1524
+ q, k = map(qk_l2norm, (q, k))
1525
+ scale = self.qk_norm_scale
1526
+
1527
+ q = q * self.qk_norm_q_scale
1528
+ k = k * self.qk_norm_k_scale
1529
+
1530
+ if exists(rotary_pos_emb) and not has_context:
1531
+ freqs, xpos_scale = rotary_pos_emb
1532
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
1533
+
1534
+ q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)
1535
+ k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)
1536
+
1537
+ if self.rotary_embed_values:
1538
+ v = apply_rotary_pos_emb(v, freqs, k_xpos_scale)
1539
+
1540
+ input_mask = context_mask
1541
+
1542
+ if not exists(input_mask) and not has_context:
1543
+ input_mask = mask
1544
+
1545
+ if self.num_mem_kv > 0:
1546
+ mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
1547
+
1548
+ if self.qk_norm:
1549
+ mem_k = l2norm(mem_k)
1550
+ mem_k = mem_k * self.qk_norm_k_scale
1551
+
1552
+ k = torch.cat((mem_k, k), dim = -2)
1553
+ v = torch.cat((mem_v, v), dim = -2)
1554
+
1555
+ if exists(input_mask):
1556
+ input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)
1557
+
1558
+ i, j = map(lambda t: t.shape[-2], (q, k))
1559
+
1560
+ # determine masking
1561
+
1562
+ mask_value = max_neg_value(q)
1563
+ masks = []
1564
+ final_attn_mask = None
1565
+
1566
+ if exists(input_mask):
1567
+ input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
1568
+ masks.append(~input_mask)
1569
+
1570
+ if exists(attn_mask):
1571
+ assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4'
1572
+ if attn_mask.ndim == 2:
1573
+ attn_mask = rearrange(attn_mask, 'i j -> 1 1 i j')
1574
+ elif attn_mask.ndim == 3:
1575
+ attn_mask = rearrange(attn_mask, 'h i j -> 1 h i j')
1576
+ masks.append(~attn_mask)
1577
+
1578
+ if exists(self.max_attend_past):
1579
+ range_q = torch.arange(j - i, j, device = device)
1580
+ range_k = torch.arange(j, device = device)
1581
+ dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j')
1582
+ max_attend_past_mask = dist > self.max_attend_past
1583
+ masks.append(max_attend_past_mask)
1584
+
1585
+ if len(masks) > 0:
1586
+ final_attn_mask = ~or_reduce(masks)
1587
+
1588
+ # prepare relative positional bias, if needed
1589
+
1590
+ attn_bias = None
1591
+ if exists(rel_pos):
1592
+ attn_bias = rel_pos(i, j)
1593
+
1594
+ # attention is all we need
1595
+
1596
+ out, intermediates = self.attend(
1597
+ q, k, v,
1598
+ mask = final_attn_mask,
1599
+ attn_bias = attn_bias,
1600
+ prev_attn = prev_attn
1601
+ )
1602
+
1603
+ # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
1604
+
1605
+ if exists(r):
1606
+ out = out * r + out
1607
+
1608
+ # normformer scaling of heads
1609
+
1610
+ if head_scale:
1611
+ out = out * self.head_scale_params
1612
+
1613
+ # per head gating, from https://arxiv.org/abs/2306.12929
1614
+
1615
+ if exists(self.to_v_head_gate):
1616
+ head_gate = self.to_v_head_gate(x)
1617
+ out = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()
1618
+
1619
+ # merge heads
1620
+
1621
+ out = rearrange(out, 'b h n d -> b n (h d)')
1622
+
1623
+ # alphafold2 styled gating of the values
1624
+
1625
+ if exists(self.to_v_gate):
1626
+ gates = self.to_v_gate(x)
1627
+ out = out * gates.sigmoid()
1628
+
1629
+ # combine the heads
1630
+
1631
+ out = self.to_out(out)
1632
+
1633
+ if exists(mask):
1634
+ mask = rearrange(mask, 'b n -> b n 1')
1635
+ out = out.masked_fill(~mask, 0.)
1636
+
1637
+ if not return_intermediates:
1638
+ return out
1639
+
1640
+ intermediates.cached_kv = cached_kv
1641
+
1642
+ return out, intermediates
1643
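+ # a minimal sketch (illustrative, not part of the original module) of how the masks gathered
+ # above combine: each entry in `masks` marks positions to exclude, the entries are or-reduced,
+ # and the result is inverted so the final mask follows the convention of True = may attend
+ def _example_combined_attention_mask():
+     input_mask = torch.tensor([[True, True, True, False]])       # (b, j) - last key is padding
+     causal_mask = torch.ones(4, 4).tril().bool()                  # (i, j) - causal triangle
+     masks = [
+         ~rearrange(input_mask, 'b j -> b 1 1 j'),
+         ~rearrange(causal_mask, 'i j -> 1 1 i j')
+     ]
+     return ~or_reduce(masks)                                       # (1, 1, 4, 4), True = attend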
+
1644
+ class AttentionLayers(nn.Module):
1645
+ def __init__(
1646
+ self,
1647
+ dim,
1648
+ depth,
1649
+ heads = 8,
1650
+ causal = False,
1651
+ cross_attend = False,
1652
+ only_cross = False,
1653
+ use_scalenorm = False,
1654
+ use_rmsnorm = False,
1655
+ use_simple_rmsnorm = False,
1656
+ alibi_pos_bias = False,
1657
+ alibi_num_heads = None,
1658
+ rel_pos_bias = False,
1659
+ rel_pos_num_buckets = 32,
1660
+ rel_pos_max_distance = 128,
1661
+ dynamic_pos_bias = False,
1662
+ dynamic_pos_bias_log_distance = False,
1663
+ dynamic_pos_bias_mlp_depth = 2,
1664
+ dynamic_pos_bias_norm = False,
1665
+ rotary_pos_emb = False,
1666
+ rotary_emb_dim = None,
1667
+ rotary_xpos = False,
1668
+ rotary_interpolation_factor = 1.,
1669
+ rotary_xpos_scale_base = 512,
1670
+ rotary_base_rescale_factor = 1.,
1671
+ custom_layers = None,
1672
+ sandwich_coef = None,
1673
+ par_ratio = None,
1674
+ weight_tie_layers = False, # Albert - https://arxiv.org/abs/1909.11942
1675
+ layers_execute_order = None, # generalizes weight tying, can do arbitrary layer execution orders
1676
+ residual_attn = False,
1677
+ cross_residual_attn = False,
1678
+ macaron = False,
1679
+ pre_norm = True,
1680
+ pre_norm_has_final_norm = True,
1681
+ gate_residual = False,
1682
+ scale_residual = False,
1683
+ scale_residual_constant = 1.,
1684
+ shift_tokens = 0,
1685
+ sandwich_norm = False,
1686
+ resi_dual = False,
1687
+ resi_dual_scale = 1.,
1688
+ zero_init_branch_output = False,
1689
+ layer_dropout = 0.,
1690
+ cross_attn_tokens_dropout = 0.,
1691
+ **kwargs
1692
+ ):
1693
+ super().__init__()
1694
+ rotary_pos_emb = rotary_pos_emb or rotary_xpos
1695
+
1696
+ ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
1697
+ attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
1698
+
1699
+ dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
1700
+
1701
+ self.dim = dim
1702
+ self.depth = depth
1703
+ self.causal = causal
1704
+ self.layers = nn.ModuleList([])
1705
+
1706
+ self.has_pos_emb = rel_pos_bias or rotary_pos_emb
1707
+
1708
+ rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
1709
+
1710
+ assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
1711
+ self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
1712
+
1713
+ assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
1714
+ assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than or equal to the relative position max distance'
1715
+
1716
+ # relative positional bias
1717
+
1718
+ flash_attn = attn_kwargs.get('flash', False)
1719
+ assert (int(rel_pos_bias) + int(dynamic_pos_bias) + int(alibi_pos_bias)) <= 1, 'you can only choose up to one of t5, alibi, or dynamic positional bias'
1720
+
1721
+ self.rel_pos = None
1722
+ if rel_pos_bias:
1723
+ assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
1724
+ self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
1725
+ elif dynamic_pos_bias:
1726
+ assert not flash_attn, 'flash attention not compatible with dynamic positional bias'
1727
+ self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm)
1728
+ elif alibi_pos_bias:
1729
+ alibi_num_heads = default(alibi_num_heads, heads)
1730
+ assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than or equal to the total number of heads'
1731
+ self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads)
1732
+
1733
+ assert (int(sandwich_norm) + int(resi_dual)) <= 1, 'either sandwich norm or resiDual is selected, but not both'
1734
+ assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
1735
+
1736
+ if resi_dual:
1737
+ pre_norm = False
1738
+
1739
+ self.pre_norm = pre_norm
1740
+ self.sandwich_norm = sandwich_norm
1741
+
1742
+ self.resi_dual = resi_dual
1743
+ assert 0 < resi_dual_scale <= 1., 'resiDual prenorm residual must be scaled by a factor greater than 0 and less than or equal to 1.'
1744
+ self.resi_dual_scale = resi_dual_scale
1745
+
1746
+ self.residual_attn = residual_attn
1747
+ self.cross_residual_attn = cross_residual_attn
1748
+ assert not (flash_attn and (residual_attn or cross_residual_attn)), 'flash attention is not compatible with residual attention'
1749
+
1750
+ self.cross_attend = cross_attend
1751
+
1752
+ assert (int(use_scalenorm) + int(use_rmsnorm) + int(use_simple_rmsnorm)) <= 1, 'you can only use either scalenorm, rmsnorm, or simple rmsnorm'
1753
+
1754
+ if use_scalenorm:
1755
+ norm_class = ScaleNorm
1756
+ elif use_rmsnorm:
1757
+ norm_class = RMSNorm
1758
+ elif use_simple_rmsnorm:
1759
+ norm_class = SimpleRMSNorm
1760
+ else:
1761
+ norm_class = nn.LayerNorm
1762
+
1763
+ norm_fn = partial(norm_class, dim)
1764
+
1765
+ if cross_attend and not only_cross:
1766
+ default_block = ('a', 'c', 'f')
1767
+ elif cross_attend and only_cross:
1768
+ default_block = ('c', 'f')
1769
+ else:
1770
+ default_block = ('a', 'f')
1771
+
1772
+ if macaron:
1773
+ default_block = ('f',) + default_block
1774
+
1775
+ # zero init
1776
+
1777
+ if zero_init_branch_output:
1778
+ attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
1779
+ ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
1780
+
1781
+ # setup weight tying, which is a special case of `layers_execute_order`
1782
+
1783
+ assert not (weight_tie_layers and any([*map(exists, (custom_layers, par_ratio, sandwich_coef))]))
1784
+
1785
+ if weight_tie_layers:
1786
+ assert not exists(layers_execute_order)
1787
+ layers_execute_order = tuple(range(len(default_block))) * depth
1788
+ depth = 1
1789
+
1790
+ # calculate layer block order
1791
+
1792
+ if exists(custom_layers):
1793
+ layer_types = custom_layers
1794
+ elif exists(par_ratio):
1795
+ par_depth = depth * len(default_block)
1796
+ assert 1 < par_ratio <= par_depth, 'par ratio out of range'
1797
+ default_block = tuple(filter(not_equals('f'), default_block))
1798
+ par_attn = par_depth // par_ratio
1799
+ depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
1800
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
1801
+ assert len(default_block) <= par_width, 'default block is too large for par_ratio'
1802
+ par_block = default_block + ('f',) * (par_width - len(default_block))
1803
+ par_head = par_block * par_attn
1804
+ layer_types = par_head + ('f',) * (par_depth - len(par_head))
1805
+ elif exists(sandwich_coef):
1806
+ assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be greater than 0 and less than or equal to the depth'
1807
+ layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
1808
+ else:
1809
+ layer_types = default_block * depth
1810
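+ # at this point `layer_types` is a flat tuple of single character codes, e.g. for depth = 2
+ # with cross attention enabled: ('a', 'c', 'f', 'a', 'c', 'f'), where 'a' is self attention,
+ # 'c' is cross attention and 'f' is a feedforward block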
+
1811
+ self.layer_types = layer_types
1812
+ self.layers_execute_order = default(layers_execute_order, tuple(range(len(layer_types))))
1813
+
1814
+ assert all([i < len(self.layer_types) for i in self.layers_execute_order])
1815
+
1816
+ self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
1817
+
1818
+ # stochastic depth
1819
+
1820
+ self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types))
1821
+
1822
+ # structured dropout for cross attending
1823
+
1824
+ self.cross_attn_tokens_dropout = cross_attn_tokens_dropout
1825
+
1826
+ # calculate token shifting
1827
+
1828
+ shift_tokens = cast_tuple(shift_tokens, len(layer_types))
1829
+
1830
+ # final norm, used when pre-norm or resiDual is selected
1831
+
1832
+ self.final_norm = norm_fn() if pre_norm or resi_dual else nn.Identity()
1833
+
1834
+ # iterate and construct layers
1835
+
1836
+ for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
1837
+ is_last_layer = ind == (len(self.layer_types) - 1)
1838
+
1839
+ if layer_type == 'a':
1840
+ layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
1841
+ elif layer_type == 'c':
1842
+ layer = Attention(dim, heads = heads, **attn_kwargs)
1843
+ elif layer_type == 'f':
1844
+ layer = FeedForward(dim, **ff_kwargs)
1845
+ layer = layer if not macaron else Scale(0.5, layer)
1846
+ else:
1847
+ raise Exception(f'invalid layer type {layer_type}')
1848
+
1849
+ if layer_shift_tokens > 0:
1850
+ shift_range_upper = layer_shift_tokens + 1
1851
+ shift_range_lower = -layer_shift_tokens if not causal else 0
1852
+ layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
1853
+
1854
+ residual_fn = GRUGating if gate_residual else Residual
1855
+ residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
1856
+
1857
+ pre_branch_norm = norm_fn() if pre_norm else None
1858
+ post_branch_norm = norm_fn() if sandwich_norm else None
1859
+ post_main_norm = norm_fn() if not pre_norm else None
1860
+
1861
+ norms = nn.ModuleList([
1862
+ pre_branch_norm,
1863
+ post_branch_norm,
1864
+ post_main_norm
1865
+ ])
1866
+
1867
+ self.layers.append(nn.ModuleList([
1868
+ norms,
1869
+ layer,
1870
+ residual
1871
+ ]))
1872
+
1873
+ def forward(
1874
+ self,
1875
+ x,
1876
+ context = None,
1877
+ mask = None,
1878
+ context_mask = None,
1879
+ attn_mask = None,
1880
+ self_attn_kv_mask = None,
1881
+ mems = None,
1882
+ seq_start_pos: Optional[Tensor] = None,
1883
+ cache: Optional[LayerIntermediates] = None,
1884
+ cache_age = 1,
1885
+ return_hiddens = False
1886
+ ):
1887
+ assert not (self.cross_attend ^ exists(context)), 'context must be passed in if and only if cross_attend is set to True'
1888
+
1889
+ # initialize accums
1890
+
1891
+ hiddens = []
1892
+ layer_hiddens = []
1893
+ intermediates = []
1894
+
1895
+ prev_attn = None
1896
+ prev_cross_attn = None
1897
+
1898
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
1899
+
1900
+ # handle left padded sequences
1901
+
1902
+ if exists(seq_start_pos):
1903
+ seq_arange = torch.arange(x.shape[-2], device = x.device, dtype = torch.long)
1904
+ left_pad_mask = seq_arange >= seq_start_pos[..., None]
1905
+
1906
+ if exists(self_attn_kv_mask):
1907
+ self_attn_kv_mask = self_attn_kv_mask & left_pad_mask
1908
+ else:
1909
+ self_attn_kv_mask = left_pad_mask
1910
+
1911
+ # rotary positions
1912
+
1913
+ rotary_pos_emb = None
1914
+
1915
+ if exists(self.rotary_pos_emb):
1916
+ max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
1917
+ rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length)
1918
+
1919
+ # assume cached key / values
1920
+
1921
+ attn_cache = []
1922
+
1923
+ if exists(cache):
1924
+ assert not self.training and self.causal and not any([*map(exists, (mask, attn_mask))])
1925
+
1926
+ if cache_age > 0:
1927
+ x = x[:, -cache_age:] # for spec decoding, may be greater than 1
1928
+
1929
+ attn_cache = cache.attn_intermediates
1930
+
1931
+ iter_attn_cache = iter(attn_cache)
1932
+
1933
+ # outer residual - for resiDual paper
1934
+
1935
+ outer_residual = x * self.resi_dual_scale
1936
+
1937
+ # get layers to be executed
1938
+
1939
+ layer_variables = (
1940
+ self.layer_types,
1941
+ self.layers,
1942
+ self.layer_dropouts
1943
+ )
1944
+
1945
+ layer_variables = tuple(tuple(layer_variable[i] for i in self.layers_execute_order) for layer_variable in layer_variables)
1946
+
1947
+ # go through the attention and feedforward layers
1948
+
1949
+ for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
1950
+ is_last = ind == (len(self.layers) - 1)
1951
+
1952
+ if self.training and layer_dropout > 0. and random() < layer_dropout:
1953
+ continue
1954
+
1955
+ if layer_type == 'a':
1956
+ if return_hiddens:
1957
+ hiddens.append(x)
1958
+ layer_mem = mems.pop(0) if mems else None
1959
+
1960
+ if layer_type == 'c':
1961
+ if self.training and self.cross_attn_tokens_dropout > 0.:
1962
+ context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)
1963
+
1964
+ inner_residual = x
1965
+
1966
+ if return_hiddens:
1967
+ layer_hiddens.append(x)
1968
+
1969
+ pre_norm, post_branch_norm, post_main_norm = norm
1970
+
1971
+ if exists(pre_norm):
1972
+ x = pre_norm(x)
1973
+
1974
+ if layer_type == 'a':
1975
+ out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, return_intermediates = True)
1976
+ elif layer_type == 'c':
1977
+ out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
1978
+ elif layer_type == 'f':
1979
+ out = block(x)
1980
+
1981
+ if self.resi_dual:
1982
+ outer_residual = outer_residual + out * self.resi_dual_scale
1983
+
1984
+ if exists(post_branch_norm):
1985
+ out = post_branch_norm(out)
1986
+
1987
+ x = residual_fn(out, inner_residual)
1988
+
1989
+ if layer_type in ('a', 'c') and return_hiddens:
1990
+ intermediates.append(inter)
1991
+
1992
+ if layer_type == 'a' and self.residual_attn:
1993
+ prev_attn = inter.pre_softmax_attn
1994
+ elif layer_type == 'c' and self.cross_residual_attn:
1995
+ prev_cross_attn = inter.pre_softmax_attn
1996
+
1997
+ if exists(post_main_norm):
1998
+ x = post_main_norm(x)
1999
+
2000
+ if return_hiddens:
2001
+ layer_hiddens.append(x)
2002
+
2003
+ if self.resi_dual:
2004
+ x = x + self.final_norm(outer_residual)
2005
+ else:
2006
+ x = self.final_norm(x)
2007
+
2008
+ if not return_hiddens:
2009
+ return x
2010
+
2011
+ intermediates = LayerIntermediates(
2012
+ hiddens = hiddens,
2013
+ attn_intermediates = intermediates,
2014
+ layer_hiddens = layer_hiddens
2015
+ )
2016
+
2017
+ return x, intermediates
2018
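+ # a minimal sketch (illustrative only) of weight tying: a single block is built and replayed
+ # `depth` times through `layers_execute_order`, so parameters are shared across the replays
+ def _example_weight_tied_stack():
+     return AttentionLayers(dim = 512, depth = 8, heads = 8, causal = True, weight_tie_layers = True)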
+
2019
+ class Encoder(AttentionLayers):
2020
+ def __init__(self, **kwargs):
2021
+ assert 'causal' not in kwargs, 'cannot set causality on encoder'
2022
+ super().__init__(causal = False, **kwargs)
2023
+
2024
+ class Decoder(AttentionLayers):
2025
+ def __init__(self, **kwargs):
2026
+ assert 'causal' not in kwargs, 'cannot set causality on decoder'
2027
+ super().__init__(causal = True, **kwargs)
2028
+
2029
+ class CrossAttender(AttentionLayers):
2030
+ def __init__(self, **kwargs):
2031
+ super().__init__(cross_attend = True, only_cross = True, **kwargs)
2032
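+ # a minimal construction sketch (illustrative only): Encoder / Decoder simply pin `causal`,
+ # and any `attn_` / `ff_` prefixed kwargs are routed to the Attention and FeedForward blocks
+ def _example_decoder_layers():
+     return Decoder(dim = 512, depth = 6, heads = 8, rotary_pos_emb = True, attn_flash = True)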
+
2033
+ class ViTransformerWrapper(nn.Module):
2034
+ def __init__(
2035
+ self,
2036
+ *,
2037
+ image_size,
2038
+ patch_size,
2039
+ attn_layers,
2040
+ channels = 3,
2041
+ num_classes = None,
2042
+ post_emb_norm = False,
2043
+ num_register_tokens = 0,
2044
+ emb_dropout = 0.
2045
+ ):
2046
+ super().__init__()
2047
+ assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
2048
+ assert divisible_by(image_size, patch_size), 'image dimensions must be divisible by the patch size'
2049
+ dim = attn_layers.dim
2050
+ num_patches = (image_size // patch_size) ** 2
2051
+ patch_dim = channels * patch_size ** 2
2052
+
2053
+ self.patch_size = patch_size
2054
+
2055
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
2056
+
2057
+ has_register_tokens = num_register_tokens > 0
2058
+ self.has_register_tokens = has_register_tokens
2059
+
2060
+ if has_register_tokens:
2061
+ self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
2062
+
2063
+ self.patch_to_embedding = nn.Sequential(
2064
+ nn.LayerNorm(patch_dim),
2065
+ nn.Linear(patch_dim, dim),
2066
+ nn.LayerNorm(dim)
2067
+ )
2068
+
2069
+ self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
2070
+ self.dropout = nn.Dropout(emb_dropout)
2071
+
2072
+ self.attn_layers = attn_layers
2073
+
2074
+ self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity()
2075
+
2076
+ def forward(
2077
+ self,
2078
+ img,
2079
+ return_embeddings = False
2080
+ ):
2081
+ b, p = img.shape[0], self.patch_size
2082
+
2083
+ x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
2084
+ x = self.patch_to_embedding(x)
2085
+ n = x.shape[1]
2086
+
2087
+ x = x + self.pos_embedding[:, :n]
2088
+
2089
+ x = self.post_emb_norm(x)
2090
+ x = self.dropout(x)
2091
+
2092
+ if self.has_register_tokens:
2093
+ r = repeat(self.register_tokens, 'n d -> b n d', b = b)
2094
+ x, ps = pack((x, r), 'b * d')
2095
+
2096
+ x = self.attn_layers(x)
2097
+
2098
+ if self.has_register_tokens:
2099
+ x, _ = unpack(x, ps, 'b * d')
2100
+
2101
+ if not exists(self.mlp_head) or return_embeddings:
2102
+ return x
2103
+
2104
+ x = x.mean(dim = -2)
2105
+ return self.mlp_head(x)
2106
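+ # a minimal usage sketch (illustrative only), assuming 256 x 256 RGB inputs
+ def _example_vit():
+     vit = ViTransformerWrapper(
+         image_size = 256,
+         patch_size = 32,
+         num_classes = 1000,
+         attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
+     )
+     images = torch.randn(2, 3, 256, 256)
+     return vit(images)  # (2, 1000) logits; pass return_embeddings = True for patch embeddings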
+
2107
+ class TransformerWrapper(nn.Module):
2108
+ def __init__(
2109
+ self,
2110
+ *,
2111
+ num_tokens,
2112
+ max_seq_len,
2113
+ attn_layers,
2114
+ emb_dim = None,
2115
+ max_mem_len = 0,
2116
+ shift_mem_down = 0,
2117
+ emb_dropout = 0.,
2118
+ post_emb_norm = False,
2119
+ num_memory_tokens = None,
2120
+ memory_tokens_interspersed_every = None,
2121
+ tie_embedding = False,
2122
+ logits_dim = None,
2123
+ use_abs_pos_emb = True,
2124
+ scaled_sinu_pos_emb = False,
2125
+ l2norm_embed = False,
2126
+ emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
2127
+ attn_z_loss_weight = 1e-4,
2128
+ ):
2129
+ super().__init__()
2130
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
2131
+
2132
+ dim = attn_layers.dim
2133
+ emb_dim = default(emb_dim, dim)
2134
+ self.emb_dim = emb_dim
2135
+ self.num_tokens = num_tokens
2136
+
2137
+ self.max_seq_len = max_seq_len
2138
+ self.max_mem_len = max_mem_len
2139
+ self.shift_mem_down = shift_mem_down
2140
+
2141
+ self.l2norm_embed = l2norm_embed
2142
+ self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)
2143
+
2144
+ if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
2145
+ self.pos_emb = always(0)
2146
+ elif scaled_sinu_pos_emb:
2147
+ self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
2148
+ else:
2149
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed)
2150
+
2151
+ self.emb_frac_gradient = emb_frac_gradient # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
2152
+
2153
+ self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
2154
+ self.emb_dropout = nn.Dropout(emb_dropout)
2155
+
2156
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
2157
+ self.attn_layers = attn_layers
2158
+
2159
+ self.init_()
2160
+
2161
+ logits_dim = default(logits_dim, num_tokens)
2162
+ self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()
2163
+
2164
+ # memory tokens (like [cls]) from Memory Transformers paper
2165
+
2166
+ num_memory_tokens = default(num_memory_tokens, 0)
2167
+ self.num_memory_tokens = num_memory_tokens
2168
+ if num_memory_tokens > 0:
2169
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
2170
+
2171
+ self.memory_tokens_interspersed_every = memory_tokens_interspersed_every
2172
+
2173
+ # whether cached kv decoding can be used
2174
+
2175
+ self.can_cache_kv = self.num_memory_tokens == 0
2176
+
2177
+ def init_(self):
2178
+ if self.l2norm_embed:
2179
+ nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)
2180
+ if not isinstance(self.pos_emb, always):
2181
+ nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
2182
+ return
2183
+
2184
+ nn.init.kaiming_normal_(self.token_emb.emb.weight)
2185
+
2186
+ def forward(
2187
+ self,
2188
+ x,
2189
+ return_embeddings = False,
2190
+ return_logits_and_embeddings = False,
2191
+ return_intermediates = False,
2192
+ mask = None,
2193
+ return_mems = False,
2194
+ return_attn = False,
2195
+ mems = None,
2196
+ pos = None,
2197
+ prepend_embeds = None,
2198
+ sum_embeds = None,
2199
+ return_attn_z_loss = False,
2200
+ attn_z_loss_weight = 1e-4,
2201
+ seq_start_pos = None,
2202
+ cache: Optional[LayerIntermediates] = None,
2203
+ **kwargs
2204
+ ):
2205
+ b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
2206
+ return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
2207
+
2208
+ # absolute positional embedding
2209
+
2210
+ external_pos_emb = exists(pos) and pos.dtype != torch.long
2211
+ pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
2212
+ x = self.token_emb(x) + pos_emb
2213
+
2214
+ # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training
2215
+
2216
+ if exists(sum_embeds):
2217
+ x = x + sum_embeds
2218
+
2219
+ # post embedding norm, purportedly leads to greater stabilization
2220
+
2221
+ x = self.post_emb_norm(x)
2222
+
2223
+ # whether to prepend embeds, as in PaLI, for image embeddings
2224
+
2225
+ if exists(prepend_embeds):
2226
+ prepend_seq, prepend_dim = prepend_embeds.shape[1:]
2227
+ assert prepend_dim == x.shape[-1], 'prepended embeddings need to have the same dimension as the text model'
2228
+
2229
+ x = torch.cat((prepend_embeds, x), dim = -2)
2230
+
2231
+ # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model
2232
+
2233
+ if emb_frac_gradient < 1:
2234
+ assert emb_frac_gradient > 0
2235
+ x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)
2236
+
2237
+ # embedding dropout
2238
+
2239
+ x = self.emb_dropout(x)
2240
+
2241
+ x = self.project_emb(x)
2242
+
2243
+ if has_memory_tokens:
2244
+ mem_every = self.memory_tokens_interspersed_every
2245
+
2246
+ if exists(mem_every):
2247
+ assert mem_every > 0
2248
+ assert isinstance(self.attn_layers, Decoder), 'only for decoder'
2249
+ next_seq_len = math.ceil(n / mem_every) * mem_every
2250
+
2251
+ x = pad_at_dim(x, (0, next_seq_len - n), dim = -2, value = 0.)
2252
+ x = rearrange(x, 'b (n m) d -> (b n) m d', m = mem_every)
2253
+
2254
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b = x.shape[0])
2255
+ x, mem_packed_shape = pack((mem, x), 'b * d')
2256
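+ # x is now (b, num_memory_tokens + n, d), or (b * ceil(n / mem_every), num_memory_tokens + mem_every, d)
+ # when memory tokens are interspersed every `mem_every` positions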
+
2257
+ # auto-handle masking after appending memory tokens
2258
+ if not exists(mem_every) and exists(mask):
2259
+ mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)
2260
+
2261
+ if exists(mem_every):
2262
+ x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
2263
+
2264
+ if self.shift_mem_down and exists(mems):
2265
+ mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
2266
+ mems = [*mems_r, *mems_l]
2267
+
2268
+ x, intermediates = self.attn_layers(x, mask = mask, mems = mems, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
2269
+
2270
+ if has_memory_tokens:
2271
+ if exists(mem_every):
2272
+ x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems))
2273
+
2274
+ mem, x = unpack(x, mem_packed_shape, 'b * d')
2275
+
2276
+ if exists(mem_every):
2277
+ x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
2278
+
2279
+ x = x[:, :n]
2280
+
2281
+ if return_logits_and_embeddings:
2282
+ out = (self.to_logits(x), x)
2283
+ elif return_embeddings:
2284
+ out = x
2285
+ else:
2286
+ out = self.to_logits(x)
2287
+
2288
+ if return_attn_z_loss:
2289
+ pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates))
2290
+ intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
2291
+ return_intermediates = True
2292
+
2293
+ if return_mems:
2294
+ hiddens = intermediates.hiddens
2295
+ new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens
2296
+ new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
2297
+
2298
+ if not return_intermediates:
2299
+ return out, new_mems
2300
+
2301
+ intermediates.mems = new_mems
2302
+
2303
+ if return_intermediates:
2304
+ return out, intermediates
2305
+
2306
+ if return_attn:
2307
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
2308
+ return out, attn_maps
2309
+
2310
+ return out
2311
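+ # a minimal usage sketch (illustrative only): a decoder-only language model
+ def _example_language_model():
+     model = TransformerWrapper(
+         num_tokens = 20000,
+         max_seq_len = 1024,
+         attn_layers = Decoder(dim = 512, depth = 12, heads = 8)
+     )
+     tokens = torch.randint(0, 20000, (1, 1024))
+     return model(tokens)  # (1, 1024, 20000) logits; return_embeddings = True skips to_logits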
+
2312
+ class ContinuousTransformerWrapper(nn.Module):
2313
+ def __init__(
2314
+ self,
2315
+ *,
2316
+ max_seq_len,
2317
+ attn_layers,
2318
+ dim_in = None,
2319
+ dim_out = None,
2320
+ emb_dim = None,
2321
+ max_mem_len = 0,
2322
+ post_emb_norm = False,
2323
+ emb_dropout = 0.,
2324
+ use_abs_pos_emb = True,
2325
+ scaled_sinu_pos_emb = False
2326
+ ):
2327
+ super().__init__()
2328
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
2329
+
2330
+ dim = attn_layers.dim
2331
+
2332
+ self.max_seq_len = max_seq_len
2333
+
2334
+ self.max_mem_len = max_mem_len
2335
+
2336
+ if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
2337
+ self.pos_emb = always(0)
2338
+ elif scaled_sinu_pos_emb:
2339
+ self.pos_emb = ScaledSinusoidalEmbedding(dim)
2340
+ else:
2341
+ self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)
2342
+
2343
+ self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
2344
+ self.emb_dropout = nn.Dropout(emb_dropout)
2345
+
2346
+ self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
2347
+
2348
+ self.attn_layers = attn_layers
2349
+
2350
+ self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
2351
+
2352
+ def forward(
2353
+ self,
2354
+ x,
2355
+ return_embeddings = False,
2356
+ return_intermediates = False,
2357
+ return_mems = False,
2358
+ mask = None,
2359
+ return_attn = False,
2360
+ mems = None,
2361
+ pos = None,
2362
+ prepend_embeds = None,
2363
+ **kwargs
2364
+ ):
2365
+ x = self.project_in(x)
2366
+ x = x + self.pos_emb(x, pos = pos)
2367
+
2368
+ x = self.post_emb_norm(x)
2369
+
2370
+ # whether to prepend embeds, as in PaLI, for image embeddings
2371
+
2372
+ if exists(prepend_embeds):
2373
+ _, prepend_dim = prepend_embeds.shape[1:]
2374
+ assert prepend_dim == x.shape[-1], 'prepended embeddings need to have the same dimension as the model'
2375
+
2376
+ x = torch.cat((prepend_embeds, x), dim = -2)
2377
+
2378
+ x = self.emb_dropout(x)
2379
+
2380
+ x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
2381
+
2382
+ out = self.project_out(x) if not return_embeddings else x
2383
+
2384
+ if return_intermediates:
2385
+ return out, intermediates
2386
+
2387
+ if return_mems:
2388
+ hiddens = intermediates.hiddens
2389
+ new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
2390
+ return out, new_mems
2391
+
2392
+ if return_attn:
2393
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
2394
+ return out, attn_maps
2395
+
2396
+ return out
2397
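+ # a minimal usage sketch (illustrative only): projecting continuous feature vectors in and out
+ def _example_continuous_encoder():
+     model = ContinuousTransformerWrapper(
+         dim_in = 32,
+         dim_out = 100,
+         max_seq_len = 1024,
+         attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
+     )
+     feats = torch.randn(1, 256, 32)
+     return model(feats)  # (1, 256, 100)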
+
2398
+ class XTransformer(nn.Module):
2399
+ def __init__(
2400
+ self,
2401
+ *,
2402
+ dim,
2403
+ tie_token_emb = False,
2404
+ ignore_index = -100,
2405
+ pad_value = 0,
2406
+ cross_attn_tokens_dropout = 0.,
2407
+ **kwargs
2408
+ ):
2409
+ super().__init__()
2410
+ enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
2411
+ dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)
2412
+
2413
+ assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of encoder and decoder must be set with the shared `dim` keyword, not `enc_dim` or `dec_dim`'
2414
+ enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
2415
+ enc_transformer_kwargs['emb_dropout'] = enc_kwargs.pop('emb_dropout', 0)
2416
+ enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)
2417
+ enc_transformer_kwargs['scaled_sinu_pos_emb'] = enc_kwargs.pop('scaled_sinu_pos_emb', False)
2418
+ enc_transformer_kwargs['use_abs_pos_emb'] = enc_kwargs.pop('use_abs_pos_emb', True)
2419
+
2420
+ dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
2421
+ dec_transformer_kwargs['emb_dropout'] = dec_kwargs.pop('emb_dropout', 0)
2422
+ dec_transformer_kwargs['scaled_sinu_pos_emb'] = dec_kwargs.pop('scaled_sinu_pos_emb', False)
2423
+ dec_transformer_kwargs['use_abs_pos_emb'] = dec_kwargs.pop('use_abs_pos_emb', True)
2424
+
2425
+ self.cross_attn_tokens_dropout = cross_attn_tokens_dropout # fraction of encoder tokens to drop out when cross attending from the decoder - seen in a couple papers, including Perceiver AR - also a very effective regularization when cross attending to very long memories
2426
+
2427
+ self.encoder = TransformerWrapper(
2428
+ **enc_transformer_kwargs,
2429
+ attn_layers = Encoder(dim = dim, **enc_kwargs)
2430
+ )
2431
+
2432
+ self.decoder = TransformerWrapper(
2433
+ **dec_transformer_kwargs,
2434
+ attn_layers = Decoder(dim = dim, cross_attend = True, **dec_kwargs)
2435
+ )
2436
+
2437
+ if tie_token_emb:
2438
+ self.decoder.token_emb = self.encoder.token_emb
2439
+
2440
+ self.decoder = AutoregressiveWrapper(self.decoder, ignore_index=ignore_index, pad_value=pad_value)
2441
+
2442
+ @torch.no_grad()
2443
+ def generate(self, seq_in, seq_out_start, seq_len, mask = None, attn_mask = None, **kwargs):
2444
+ encodings = self.encoder(seq_in, mask = mask, attn_mask = attn_mask, return_embeddings = True)
2445
+ return self.decoder.generate(seq_out_start, seq_len, context = encodings, context_mask = mask, **kwargs)
2446
+
2447
+ def forward(self, src, tgt, mask = None, attn_mask = None, src_prepend_embeds = None):
2448
+
2449
+ if exists(src_prepend_embeds) and exists(mask):
2450
+ mask = pad_at_dim(mask, (src_prepend_embeds.shape[-2], 0), dim = -1, value = True)
2451
+
2452
+ enc = self.encoder(src, mask = mask, attn_mask = attn_mask, prepend_embeds = src_prepend_embeds, return_embeddings = True)
2453
+
2454
+ if self.training and self.cross_attn_tokens_dropout > 0:
2455
+ enc, mask = dropout_seq(enc, mask, self.cross_attn_tokens_dropout)
2456
+
2457
+ out = self.decoder(tgt, context = enc, context_mask = mask)
2458
+ return out
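+ # a minimal usage sketch (illustrative only): `enc_` / `dec_` prefixed kwargs are routed to the
+ # encoder and decoder wrappers respectively; the decoder is wrapped in AutoregressiveWrapper,
+ # so calling the model with (src, tgt) returns the training loss
+ def _example_seq2seq():
+     model = XTransformer(
+         dim = 512,
+         tie_token_emb = True,
+         enc_num_tokens = 256,
+         enc_depth = 6,
+         enc_heads = 8,
+         enc_max_seq_len = 1024,
+         dec_num_tokens = 256,
+         dec_depth = 6,
+         dec_heads = 8,
+         dec_max_seq_len = 1024
+     )
+     src = torch.randint(0, 256, (1, 1024))
+     tgt = torch.randint(0, 256, (1, 1024))
+     return model(src, tgt)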