vikhyatk commited on
Commit
0176a51
1 Parent(s): 8e12426
README.md CHANGED
@@ -1,12 +1,14 @@
1
  ---
2
- title: Moondream1
3
- emoji: 😻
4
- colorFrom: pink
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 4.15.0
8
  app_file: app.py
9
  pinned: false
 
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: moondream1
3
+ emoji: 🌔
4
+ colorFrom: indigo
5
+ colorTo: blue
6
  sdk: gradio
7
  sdk_version: 4.15.0
8
  app_file: app.py
9
  pinned: false
10
+ preload_from_hub:
11
+ - vikhyatk/moondream1
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,1228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from PIL import Image
5
+ from einops import rearrange
6
+ from torchvision.transforms.v2 import (
7
+ Compose,
8
+ Resize,
9
+ InterpolationMode,
10
+ ToImage,
11
+ ToDtype,
12
+ Normalize,
13
+ )
14
+
15
+ from transformers import CodeGenTokenizerFast as Tokenizer
16
+ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
17
+ import re
18
+
19
+ import math
20
+ from typing import Optional
21
+
22
+ from transformers import PretrainedConfig
23
+
24
+
25
+ import math
26
+ from dataclasses import dataclass, field
27
+ from typing import Any, Dict, Optional, Tuple, Union
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ from einops import rearrange, repeat
32
+ from transformers import PretrainedConfig, PreTrainedModel
33
+ from transformers.activations import ACT2FN
34
+ from transformers.modeling_outputs import CausalLMOutputWithPast
35
+
36
+ pad_input, unpad_input = None, None
37
+ FlashRotaryEmbedding = None
38
+ FlashSelfAttention, FlashCrossAttention = None, None
39
+ FusedDense = None
40
+
41
+ if torch.cuda.is_available():
42
+ DEVICE = "cuda"
43
+ DTYPE = torch.float16
44
+ else:
45
+ DEVICE = "cpu"
46
+ DTYPE = torch.float32
47
+
48
+
49
+ class PhiConfig(PretrainedConfig):
50
+ """Phi configuration."""
51
+
52
+ model_type = "phi-msft"
53
+ attribute_map = {
54
+ "max_position_embeddings": "n_positions",
55
+ "hidden_size": "n_embd",
56
+ "num_attention_heads": "n_head",
57
+ "num_hidden_layers": "n_layer",
58
+ }
59
+
60
+ def __init__(
61
+ self,
62
+ vocab_size: int = 50304,
63
+ n_positions: int = 2048,
64
+ n_embd: int = 1024,
65
+ n_layer: int = 20,
66
+ n_inner: Optional[int] = None,
67
+ n_head: int = 16,
68
+ n_head_kv: Optional[int] = None,
69
+ rotary_dim: Optional[int] = 32,
70
+ activation_function: Optional[str] = "gelu_new",
71
+ flash_attn: bool = False,
72
+ flash_rotary: bool = False,
73
+ fused_dense: bool = False,
74
+ attn_pdrop: float = 0.0,
75
+ embd_pdrop: float = 0.0,
76
+ resid_pdrop: float = 0.0,
77
+ layer_norm_epsilon: float = 1e-5,
78
+ initializer_range: float = 0.02,
79
+ tie_word_embeddings: bool = False,
80
+ pad_vocab_size_multiple: int = 64,
81
+ gradient_checkpointing: bool = False,
82
+ **kwargs,
83
+ ) -> None:
84
+ self.vocab_size = int(
85
+ math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
86
+ )
87
+ self.n_positions = n_positions
88
+ self.n_embd = n_embd
89
+ self.n_layer = n_layer
90
+ self.n_inner = n_inner
91
+ self.n_head = n_head
92
+ self.n_head_kv = n_head_kv
93
+ self.rotary_dim = min(rotary_dim, n_embd // n_head)
94
+ self.activation_function = activation_function
95
+ self.flash_attn = flash_attn
96
+ self.flash_rotary = flash_rotary
97
+ self.fused_dense = fused_dense
98
+ self.attn_pdrop = attn_pdrop
99
+ self.embd_pdrop = embd_pdrop
100
+ self.resid_pdrop = resid_pdrop
101
+ self.layer_norm_epsilon = layer_norm_epsilon
102
+ self.initializer_range = initializer_range
103
+ self.gradient_checkpointing = gradient_checkpointing
104
+
105
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
106
+
107
+
108
+ @dataclass
109
+ class InferenceParams:
110
+ """Inference parameters passed to model to efficiently calculate
111
+ and store context during inference.
112
+
113
+ Reference:
114
+ https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
115
+
116
+ Args:
117
+ max_seqlen: Maximum sequence length.
118
+ max_batch_size: Maximum batch size.
119
+ seqlen_offset: Sequence length offset.
120
+ batch_size_offset: Batch size offset.
121
+ key_value_memory_dict: Key value memory dictionary.
122
+ lengths_per_sample: Lengths per sample.
123
+
124
+ """
125
+
126
+ max_seqlen: int = field(metadata={"help": "Maximum sequence length."})
127
+
128
+ max_batch_size: int = field(metadata={"help": "Maximum batch size."})
129
+
130
+ seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
131
+
132
+ batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
133
+
134
+ key_value_memory_dict: Dict[str, Any] = field(
135
+ default_factory=dict, metadata={"help": "Key value memory dictionary."}
136
+ )
137
+
138
+ lengths_per_sample: torch.Tensor = field(
139
+ default=None, metadata={"help": "Lengths per sample."}
140
+ )
141
+
142
+
143
+ class Embedding(nn.Module):
144
+ """Token embedding with dropout."""
145
+
146
+ def __init__(self, config: PretrainedConfig) -> None:
147
+ super().__init__()
148
+
149
+ self.wte = nn.Embedding(config.vocab_size, config.n_embd)
150
+ self.drop = nn.Dropout(config.embd_pdrop)
151
+
152
+ def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
153
+ input_shape = input_ids.size()
154
+ input_ids = input_ids.view(-1, input_shape[-1])
155
+
156
+ hidden_states = self.wte(input_ids)
157
+ hidden_states = self.drop(hidden_states)
158
+
159
+ return hidden_states
160
+
161
+
162
+ # @torch.compile
163
+ def _apply_rotary_emb(
164
+ x: torch.FloatTensor,
165
+ cos: torch.FloatTensor,
166
+ sin: torch.FloatTensor,
167
+ ) -> torch.FloatTensor:
168
+ _, seqlen, _, _ = x.shape
169
+ _, rotary_dim = cos.shape
170
+ rotary_dim *= 2
171
+
172
+ x_rot = x[:, :, :, :rotary_dim]
173
+ x_pass = x[:, :, :, rotary_dim:]
174
+
175
+ x1, x2 = x_rot.chunk(2, dim=-1)
176
+ c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
177
+ sin[:seqlen], "s d -> s 1 d"
178
+ )
179
+ x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]
180
+
181
+ x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype)
182
+
183
+ return torch.cat([x_rot, x_pass], axis=-1)
184
+
185
+
186
+ # @torch.compile
187
+ def _apply_rotary_emb_kv(
188
+ kv: torch.FloatTensor,
189
+ cos: torch.FloatTensor,
190
+ sin: torch.FloatTensor,
191
+ cos_k: Optional[torch.FloatTensor] = None,
192
+ sin_k: Optional[torch.FloatTensor] = None,
193
+ ) -> torch.FloatTensor:
194
+ _, seqlen, _, _, _ = kv.shape
195
+ _, rotary_dim = cos.shape
196
+ rotary_dim *= 2
197
+
198
+ k_rot = kv[:, :, 0, :, :rotary_dim]
199
+ k_pass = kv[:, :, 0, :, rotary_dim:]
200
+
201
+ k1, k2 = k_rot.chunk(2, dim=-1)
202
+ c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
203
+ sin[:seqlen], "s d -> s 1 d"
204
+ )
205
+ k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]]
206
+
207
+ k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype)
208
+
209
+ return torch.cat(
210
+ [
211
+ torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
212
+ kv[:, :, 1:2, :, :],
213
+ ],
214
+ axis=2,
215
+ )
216
+
217
+
218
+ # @torch.compile
219
+ def _apply_rotary_emb_qkv(
220
+ qkv: torch.FloatTensor,
221
+ cos: torch.FloatTensor,
222
+ sin: torch.FloatTensor,
223
+ cos_k: Optional[torch.FloatTensor] = None,
224
+ sin_k: Optional[torch.FloatTensor] = None,
225
+ ) -> torch.FloatTensor:
226
+ _, seqlen, _, _, _ = qkv.shape
227
+ _, rotary_dim = cos.shape
228
+ rotary_dim *= 2
229
+
230
+ q_rot = qkv[:, :, 0, :, :rotary_dim]
231
+ q_pass = qkv[:, :, 0, :, rotary_dim:]
232
+
233
+ k_rot = qkv[:, :, 1, :, :rotary_dim]
234
+ k_pass = qkv[:, :, 1, :, rotary_dim:]
235
+
236
+ q1, q2 = q_rot.chunk(2, dim=-1)
237
+ k1, k2 = k_rot.chunk(2, dim=-1)
238
+ c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
239
+ sin[:seqlen], "s d -> s 1 d"
240
+ )
241
+ q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
242
+
243
+ q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
244
+ k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
245
+
246
+ return torch.cat(
247
+ [
248
+ torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
249
+ torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
250
+ qkv[:, :, 2:3, :, :],
251
+ ],
252
+ axis=2,
253
+ )
254
+
255
+
256
+ class RotaryEmbedding(nn.Module):
257
+ """Rotary positional embedding (RoPE).
258
+
259
+ Reference:
260
+ RoFormer: Enhanced Transformer with Rotary Position Embedding.
261
+ https://arxiv.org/pdf/2104.09864.pdf.
262
+
263
+ """
264
+
265
+ def __init__(
266
+ self,
267
+ dim: int,
268
+ base: int = 10000,
269
+ scale_base: Optional[float] = None,
270
+ pos_idx_in_fp32: bool = True,
271
+ max_position_embeddings: int = 2048,
272
+ device: Optional[str] = None,
273
+ **kwargs,
274
+ ) -> None:
275
+ super().__init__()
276
+
277
+ if scale_base is not None:
278
+ raise NotImplementedError
279
+
280
+ self.dim = dim
281
+ self.base = float(base)
282
+ self.scale_base = scale_base
283
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
284
+ self.max_position_embeddings = max_position_embeddings
285
+ self.device = device
286
+
287
+ # Generate and save the inverse frequency buffer (non-trainable)
288
+ inv_freq = self._compute_inv_freq(device)
289
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
290
+
291
+ # Generate and save the scale buffer (non-trainable)
292
+ scale = (
293
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
294
+ / (1.4 * dim)
295
+ if scale_base is not None
296
+ else None
297
+ )
298
+ self.register_buffer("scale", scale, persistent=False)
299
+
300
+ # Initialize cached attributes since ONNX can't rely on dynamic initialization
301
+ self._update_cos_sin_cache(
302
+ max_position_embeddings, device=device, dtype=torch.float32
303
+ )
304
+
305
+ def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
306
+ return 1.0 / (
307
+ self.base
308
+ ** (
309
+ torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
310
+ / self.dim
311
+ )
312
+ )
313
+
314
+ def _update_cos_sin_cache(
315
+ self,
316
+ seqlen: int,
317
+ device: Optional[str] = None,
318
+ dtype: Optional[torch.dtype] = None,
319
+ ) -> None:
320
+ self._seq_len_cached = seqlen
321
+
322
+ # fp32 is preferred since the output of `torch.arange` can be quite large
323
+ # and bf16 would lose a lot of precision
324
+ if self.pos_idx_in_fp32:
325
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
326
+ if self.inv_freq.dtype != torch.float32:
327
+ inv_freq = self._compute_inv_freq(device=device)
328
+ else:
329
+ inv_freq = self.inv_freq
330
+ else:
331
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
332
+ inv_freq = self.inv_freq
333
+
334
+ # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
335
+ freqs = torch.outer(t, inv_freq)
336
+ if self.scale is None:
337
+ self._cos_cached = torch.cos(freqs).to(dtype)
338
+ self._sin_cached = torch.sin(freqs).to(dtype)
339
+ else:
340
+ power = (
341
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
342
+ - seqlen // 2
343
+ ) / self.scale_base
344
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
345
+
346
+ # Force the scale multiplication to happen in fp32
347
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
348
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
349
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
350
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
351
+
352
+ def forward(
353
+ self,
354
+ qkv: torch.Tensor,
355
+ kv: Optional[torch.Tensor] = None,
356
+ seqlen_offset: int = 0,
357
+ **kwargs,
358
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
359
+ if (
360
+ self._seq_len_cached < qkv.shape[1] + seqlen_offset
361
+ or self._cos_cached.device != qkv.device
362
+ or self._cos_cached.dtype != qkv.dtype
363
+ or (self.training and self._cos_cached.is_inference())
364
+ ):
365
+ self._update_cos_sin_cache(
366
+ qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype
367
+ )
368
+
369
+ if kv is None:
370
+ return _apply_rotary_emb_qkv(
371
+ qkv,
372
+ self._cos_cached[seqlen_offset:],
373
+ self._sin_cached[seqlen_offset:],
374
+ )
375
+ else:
376
+ q = _apply_rotary_emb(
377
+ qkv,
378
+ self._cos_cached[seqlen_offset:],
379
+ self._sin_cached[seqlen_offset:],
380
+ )
381
+ kv = _apply_rotary_emb_kv(
382
+ kv,
383
+ self._cos_cached[seqlen_offset:],
384
+ self._sin_cached[seqlen_offset:],
385
+ )
386
+
387
+ return q, kv
388
+
389
+
390
+ class MLP(nn.Module):
391
+ """Multi-Layer Perceptron.
392
+
393
+ Reference:
394
+ Attention Is All You Need.
395
+ https://arxiv.org/pdf/1706.03762.pdf.
396
+
397
+ """
398
+
399
+ def __init__(
400
+ self,
401
+ config: PretrainedConfig,
402
+ n_inner: Optional[int] = None,
403
+ act_fn: Optional[str] = None,
404
+ ) -> None:
405
+ super().__init__()
406
+
407
+ act_fn = config.activation_function if act_fn is None else act_fn
408
+
409
+ n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
410
+ n_inner = n_inner if n_inner is not None else 4 * config.n_embd
411
+
412
+ self.fc1 = nn.Linear(config.n_embd, n_inner)
413
+ self.fc2 = nn.Linear(n_inner, config.n_embd)
414
+ self.act = ACT2FN[act_fn]
415
+
416
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
417
+ hidden_states = self.fc1(hidden_states)
418
+ hidden_states = self.act(hidden_states)
419
+ hidden_states = self.fc2(hidden_states)
420
+
421
+ return hidden_states
422
+
423
+
424
+ class SelfAttention(nn.Module):
425
+ """Self-attention layer (compatible with PyTorch).
426
+
427
+ Reference:
428
+ https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
429
+
430
+ """
431
+
432
+ def __init__(
433
+ self,
434
+ causal: bool = True,
435
+ softmax_scale: Optional[float] = None,
436
+ attention_dropout: float = 0.0,
437
+ ) -> None:
438
+ super().__init__()
439
+
440
+ self.causal = causal
441
+ self.softmax_scale = softmax_scale
442
+ self.drop = nn.Dropout(attention_dropout)
443
+
444
+ @torch.autocast("cpu", enabled=False)
445
+ @torch.autocast("cuda", enabled=False)
446
+ def forward(
447
+ self,
448
+ qkv: torch.FloatTensor,
449
+ causal: bool = None,
450
+ key_padding_mask: Optional[torch.BoolTensor] = None,
451
+ **kwargs,
452
+ ) -> torch.FloatTensor:
453
+ batch_size, seqlen = qkv.shape[0], qkv.shape[1]
454
+ q, k, v = qkv.unbind(dim=2)
455
+
456
+ q = q.to(torch.float32)
457
+ k = k.to(torch.float32)
458
+
459
+ causal = self.causal if causal is None else causal
460
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
461
+
462
+ # Autocast is manually disabled to avoid `torch.einsum` performing the operation
463
+ # using float16, which might lead to overflow
464
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
465
+
466
+ if key_padding_mask is not None:
467
+ padding_mask = torch.full(
468
+ (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
469
+ )
470
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
471
+
472
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
473
+
474
+ if causal:
475
+ causal_mask = torch.triu(
476
+ torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
477
+ )
478
+ scores = scores + causal_mask.to(dtype=scores.dtype)
479
+
480
+ attention = torch.softmax(scores, dim=-1).to(v.dtype)
481
+ attention = self.drop(attention)
482
+
483
+ output = torch.einsum("bhts,bshd->bthd", attention, v)
484
+
485
+ return output
486
+
487
+
488
+ class CrossAttention(nn.Module):
489
+ """Cross-attention layer (compatible with PyTorch).
490
+
491
+ Reference:
492
+ https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
493
+
494
+ """
495
+
496
+ def __init__(
497
+ self,
498
+ causal: bool = True,
499
+ softmax_scale: Optional[float] = None,
500
+ attention_dropout: float = 0.0,
501
+ ) -> None:
502
+ super().__init__()
503
+
504
+ self.causal = causal
505
+ self.softmax_scale = softmax_scale
506
+ self.drop = nn.Dropout(attention_dropout)
507
+
508
+ @torch.autocast("cpu", enabled=False)
509
+ @torch.autocast("cuda", enabled=False)
510
+ def forward(
511
+ self,
512
+ q: torch.FloatTensor,
513
+ kv: torch.FloatTensor,
514
+ causal: bool = None,
515
+ key_padding_mask: Optional[torch.BoolTensor] = None,
516
+ **kwargs,
517
+ ) -> torch.FloatTensor:
518
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
519
+ seqlen_k = kv.shape[1]
520
+
521
+ if kv.shape[3] != q.shape[2]:
522
+ kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
523
+ k, v = kv.unbind(dim=2)
524
+
525
+ q = q.to(torch.float32)
526
+ k = k.to(torch.float32)
527
+
528
+ causal = self.causal if causal is None else causal
529
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
530
+
531
+ # Autocast is manually disabled to avoid `torch.einsum` performing the operation
532
+ # using float16, which might lead to overflow
533
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
534
+
535
+ if key_padding_mask is not None:
536
+ padding_mask = torch.full(
537
+ (batch_size, seqlen_k),
538
+ -10000.0,
539
+ dtype=scores.dtype,
540
+ device=scores.device,
541
+ )
542
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
543
+
544
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
545
+
546
+ if causal:
547
+ rows = rearrange(
548
+ torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
549
+ )
550
+ cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
551
+ causal_mask = cols > rows + seqlen_k - seqlen_q
552
+
553
+ scores = scores.masked_fill(causal_mask, -10000.0)
554
+
555
+ attention = torch.softmax(scores, dim=-1).to(v.dtype)
556
+ attention = self.drop(attention)
557
+
558
+ output = torch.einsum("bhts,bshd->bthd", attention, v)
559
+
560
+ return output
561
+
562
+
563
+ def _find_mha_dims(
564
+ config: PretrainedConfig,
565
+ n_head: Optional[int] = None,
566
+ n_head_kv: Optional[int] = None,
567
+ head_dim: Optional[int] = None,
568
+ ) -> Tuple[int, int]:
569
+ if n_head is None and head_dim is None:
570
+ head_dim = config.n_embd // config.n_head
571
+ n_head = config.n_head
572
+ elif n_head is None or head_dim is None:
573
+ raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
574
+
575
+ if n_head_kv is None:
576
+ n_head_kv = getattr(config, "n_head_kv", None) or n_head
577
+
578
+ return n_head, n_head_kv, head_dim
579
+
580
+
581
+ def _update_kv_cache(
582
+ kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int
583
+ ) -> torch.FloatTensor:
584
+ num_heads, head_dim = kv.shape[-2:]
585
+
586
+ if layer_idx not in inference_params.key_value_memory_dict:
587
+ inference_params.key_value_memory_dict[layer_idx] = torch.empty(
588
+ inference_params.max_batch_size,
589
+ inference_params.max_seqlen,
590
+ 2,
591
+ num_heads,
592
+ head_dim,
593
+ dtype=kv.dtype,
594
+ device=kv.device,
595
+ )
596
+
597
+ batch_start = inference_params.batch_size_offset
598
+ batch_end = batch_start + kv.shape[0]
599
+
600
+ sequence_start = inference_params.seqlen_offset
601
+ sequence_end = sequence_start + kv.shape[1]
602
+
603
+ # When the current sequence length is equal to or larger than the maximum sequence length,
604
+ # we need to concatenate the current `kv` with the cached `kv` to expand its length
605
+ if sequence_end >= inference_params.max_seqlen:
606
+ inference_params.key_value_memory_dict[layer_idx] = torch.concatenate(
607
+ (inference_params.key_value_memory_dict[layer_idx], kv), dim=1
608
+ )
609
+
610
+ inference_params.key_value_memory_dict[layer_idx][
611
+ batch_start:batch_end, sequence_start:sequence_end, ...
612
+ ] = kv
613
+ kv = inference_params.key_value_memory_dict[layer_idx][
614
+ batch_start:batch_end, :sequence_end, ...
615
+ ]
616
+
617
+ return kv
618
+
619
+
620
+ class MHA(nn.Module):
621
+ """Multi-head attention layer."""
622
+
623
+ def __init__(
624
+ self,
625
+ config: PretrainedConfig,
626
+ dtype: Optional[torch.dtype] = None,
627
+ device: Optional[str] = None,
628
+ rotary_dim: Optional[int] = None,
629
+ rotary_base: float = 10000.0,
630
+ rotary_scale_base: Optional[float] = None,
631
+ n_head: Optional[int] = None,
632
+ n_head_kv: Optional[int] = None,
633
+ head_dim: Optional[int] = None,
634
+ bias: bool = True,
635
+ causal: bool = True,
636
+ softmax_scale: Optional[float] = None,
637
+ layer_idx: Optional[int] = None,
638
+ return_residual: bool = False,
639
+ checkpointing: bool = False,
640
+ ) -> None:
641
+ super().__init__()
642
+
643
+ # Rotary embedding
644
+ self.rotary_dim = (
645
+ rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
646
+ )
647
+
648
+ if self.rotary_dim > 0:
649
+ self.rotary_emb = RotaryEmbedding(
650
+ self.rotary_dim,
651
+ base=rotary_base,
652
+ scale_base=rotary_scale_base,
653
+ device=device,
654
+ max_position_embeddings=config.n_positions,
655
+ )
656
+
657
+ # MLP
658
+ self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(
659
+ config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim
660
+ )
661
+ op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
662
+ hidden_size = config.n_embd
663
+
664
+ linear_cls = FusedDense if config.fused_dense else nn.Linear
665
+ if linear_cls is None:
666
+ linear_cls = nn.Linear
667
+
668
+ self.Wqkv = linear_cls(
669
+ hidden_size, op_size, bias=bias, device=device, dtype=dtype
670
+ )
671
+ self.out_proj = linear_cls(
672
+ hidden_size, hidden_size, bias=bias, device=device, dtype=dtype
673
+ )
674
+
675
+ # Attention
676
+ self.inner_attn = SelfAttention(
677
+ causal=causal,
678
+ softmax_scale=softmax_scale,
679
+ attention_dropout=config.attn_pdrop,
680
+ )
681
+ self.inner_cross_attn = CrossAttention(
682
+ causal=causal,
683
+ softmax_scale=softmax_scale,
684
+ attention_dropout=config.attn_pdrop,
685
+ )
686
+
687
+ self.layer_idx = layer_idx
688
+ self.return_residual = return_residual
689
+ self.checkpointing = checkpointing
690
+
691
+ def _forward_self_attn(
692
+ self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor]
693
+ ) -> torch.FloatTensor:
694
+ qkv = self.Wqkv(x)
695
+ qkv = rearrange(
696
+ qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
697
+ )
698
+
699
+ if self.rotary_dim > 0:
700
+ qkv = self.rotary_emb(qkv)
701
+
702
+ if self.checkpointing:
703
+ return torch.utils.checkpoint.checkpoint(
704
+ self.inner_attn, qkv, key_padding_mask=key_padding_mask
705
+ )
706
+
707
+ return self.inner_attn(qkv, key_padding_mask=key_padding_mask)
708
+
709
+ def _forward_cross_attn(
710
+ self,
711
+ x: torch.FloatTensor,
712
+ past_key_values: Optional[InferenceParams],
713
+ key_padding_mask: Optional[torch.BoolTensor],
714
+ ) -> torch.FloatTensor:
715
+ batch_size = x.shape[0]
716
+
717
+ qkv = self.Wqkv(x)
718
+
719
+ q = qkv[..., : self.n_head * self.head_dim]
720
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
721
+
722
+ kv = qkv[..., self.n_head * self.head_dim :]
723
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
724
+
725
+ seqlen_offset = (
726
+ past_key_values.seqlen_offset if past_key_values is not None else 0
727
+ )
728
+ causal = None if seqlen_offset == 0 else False
729
+ if self.rotary_dim > 0:
730
+ q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset)
731
+
732
+ if past_key_values is not None:
733
+ kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
734
+
735
+ if self.checkpointing:
736
+ return torch.utils.checkpoint.checkpoint(
737
+ self.inner_cross_attn,
738
+ q,
739
+ kv,
740
+ key_padding_mask=key_padding_mask,
741
+ causal=causal,
742
+ )
743
+
744
+ return self.inner_cross_attn(
745
+ q, kv, key_padding_mask=key_padding_mask, causal=causal
746
+ )
747
+
748
+ def forward(
749
+ self,
750
+ x: torch.FloatTensor,
751
+ past_key_values: Optional[InferenceParams] = None,
752
+ attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
753
+ **kwargs,
754
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
755
+ if attention_mask is not None:
756
+ attention_mask = attention_mask.bool()
757
+ else:
758
+ attention_mask = None
759
+
760
+ # MHA
761
+ if self.n_head == self.n_head_kv:
762
+ if past_key_values is None:
763
+ # If `past_key_values` are not supplied, we run self-attention
764
+ attn_output = self._forward_self_attn(x, attention_mask)
765
+ else:
766
+ # If `past_key_values` are supplied, it means that we might have cached values and
767
+ # could take advantage of cross-attention
768
+ attn_output = self._forward_cross_attn(
769
+ x, past_key_values, attention_mask
770
+ )
771
+ # MQA / GQA
772
+ else:
773
+ # Regardless of `past_key_values` being supplied or not, it always use cross-attention
774
+ # because `q` and `kv` lengths might be different
775
+ attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
776
+
777
+ output = rearrange(attn_output, "... h d -> ... (h d)")
778
+ output = self.out_proj(output)
779
+
780
+ return output if not self.return_residual else (output, x)
781
+
782
+
783
+ class ParallelBlock(nn.Module):
784
+ """Parallel block.
785
+
786
+ This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
787
+
788
+ """
789
+
790
+ def __init__(
791
+ self,
792
+ config: PretrainedConfig,
793
+ block_idx: Optional[int] = None,
794
+ ) -> None:
795
+ super().__init__()
796
+
797
+ self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
798
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
799
+ self.block_idx = block_idx
800
+
801
+ self.mixer = MHA(config, layer_idx=block_idx)
802
+ self.mlp = MLP(config)
803
+
804
+ def forward(
805
+ self,
806
+ hidden_states: torch.FloatTensor,
807
+ past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
808
+ attention_mask: Optional[torch.BoolTensor] = None,
809
+ **kwargs,
810
+ ) -> torch.FloatTensor:
811
+ residual = hidden_states
812
+ hidden_states = self.ln(hidden_states)
813
+
814
+ attn_outputs = self.mixer(
815
+ hidden_states,
816
+ past_key_values=past_key_values,
817
+ attention_mask=attention_mask,
818
+ )
819
+ if isinstance(attn_outputs, tuple):
820
+ attn_outputs = attn_outputs[0]
821
+
822
+ attn_outputs = self.resid_dropout(attn_outputs)
823
+ feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
824
+
825
+ hidden_states = attn_outputs + feed_forward_hidden_states + residual
826
+
827
+ return hidden_states
828
+
829
+
830
+ class CausalLMHead(nn.Module):
831
+ """Causal Language Modeling head.
832
+
833
+ Reference:
834
+ Improving Language Understanding by Generative Pre-Training.
835
+ https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
836
+
837
+ """
838
+
839
+ def __init__(self, config: PretrainedConfig) -> None:
840
+ super().__init__()
841
+
842
+ self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
843
+ self.linear = nn.Linear(config.n_embd, config.vocab_size)
844
+
845
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
846
+ hidden_states = self.ln(hidden_states)
847
+ logits = self.linear(hidden_states).to(torch.float32)
848
+
849
+ return logits
850
+
851
+
852
+ class CausalLMLoss(nn.Module):
853
+ """Causal Language Modeling loss.
854
+
855
+ Reference:
856
+ Improving Language Understanding by Generative Pre-Training.
857
+ https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
858
+
859
+ """
860
+
861
+ def __init__(self, shift_labels: bool = True) -> None:
862
+ super().__init__()
863
+
864
+ self.shift_labels = shift_labels
865
+ self.loss_fct = nn.CrossEntropyLoss()
866
+
867
+ def forward(
868
+ self, logits: torch.FloatTensor, labels: torch.LongTensor
869
+ ) -> torch.FloatTensor:
870
+ if self.shift_labels:
871
+ logits = logits[..., :-1, :].contiguous()
872
+ labels = labels[..., 1:].contiguous()
873
+
874
+ loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
875
+
876
+ return loss
877
+
878
+
879
+ class PhiPreTrainedModel(PreTrainedModel):
880
+ """Phi pre-trained model."""
881
+
882
+ config_class = PhiConfig
883
+ base_model_prefix = "transformer"
884
+ supports_gradient_checkpointing = False
885
+ _no_split_modules = ["ParallelBlock"]
886
+
887
+ def __init__(self, *inputs, **kwargs) -> None:
888
+ super().__init__(*inputs, **kwargs)
889
+
890
+ def prepare_inputs_for_generation(
891
+ self,
892
+ input_ids: torch.LongTensor = None,
893
+ inputs_embeds: torch.FloatTensor = None,
894
+ past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
895
+ attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
896
+ **kwargs,
897
+ ) -> Dict[str, Any]:
898
+ if inputs_embeds is not None:
899
+ max_batch_size = inputs_embeds.shape[0]
900
+ seqlen_offset = inputs_embeds.shape[1] + input_ids.shape[1] - 2
901
+ elif input_ids is not None:
902
+ max_batch_size = input_ids.shape[0]
903
+ seqlen_offset = input_ids.shape[1] - 1
904
+ else:
905
+ raise ValueError(
906
+ "You have to specify either `input_ids` or `inputs_embeds`."
907
+ )
908
+
909
+ args = {}
910
+
911
+ if past_key_values is None or not (
912
+ isinstance(past_key_values, InferenceParams)
913
+ ):
914
+ past_key_values = InferenceParams(
915
+ max_seqlen=self.config.n_positions,
916
+ max_batch_size=max_batch_size,
917
+ seqlen_offset=0,
918
+ batch_size_offset=0,
919
+ key_value_memory_dict={},
920
+ lengths_per_sample=None,
921
+ )
922
+ if inputs_embeds is not None:
923
+ args = {"inputs_embeds": inputs_embeds}
924
+ elif input_ids is not None:
925
+ args = {"input_ids": input_ids}
926
+ else:
927
+ raise ValueError(
928
+ "You have to specify either `input_ids` or `inputs_embeds`."
929
+ )
930
+ else:
931
+ # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
932
+ past_key_values.seqlen_offset = seqlen_offset
933
+ input_ids = input_ids[:, -1].unsqueeze(-1)
934
+ args = {"input_ids": input_ids}
935
+
936
+ return {
937
+ **args,
938
+ "past_key_values": past_key_values,
939
+ "attention_mask": attention_mask,
940
+ }
941
+
942
+
943
+ class PhiModel(PhiPreTrainedModel):
944
+ """Phi model."""
945
+
946
+ _keys_to_ignore_on_load_missing = [""]
947
+ _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
948
+
949
+ def __init__(self, config: PhiConfig) -> None:
950
+ super().__init__(config)
951
+
952
+ self.embd = Embedding(config)
953
+ self.h = nn.ModuleList(
954
+ [ParallelBlock(config, block_idx=i) for i in range(config.n_layer)]
955
+ )
956
+ self.gradient_checkpointing = config.gradient_checkpointing
957
+ self.post_init()
958
+
959
+ def get_input_embeddings(self) -> nn.Embedding:
960
+ return self.embd.wte
961
+
962
+ def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
963
+ self.embd.wte = new_embeddings
964
+
965
+ def forward(
966
+ self,
967
+ input_ids: torch.LongTensor = None,
968
+ inputs_embeds: torch.FloatTensor = None,
969
+ past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
970
+ attention_mask: Optional[torch.BoolTensor] = None,
971
+ ) -> torch.FloatTensor:
972
+ if input_ids is not None and inputs_embeds is not None:
973
+ raise ValueError(
974
+ "You cannot specify both `input_ids` and `inputs_embeds` at the same time."
975
+ )
976
+ elif input_ids is None and inputs_embeds is None:
977
+ raise ValueError(
978
+ "You have to specify either `input_ids` or `inputs_embeds`."
979
+ )
980
+ elif input_ids is not None:
981
+ hidden_states = self.embd(input_ids)
982
+ else:
983
+ hidden_states = inputs_embeds
984
+
985
+ for layer in self.h:
986
+ if self.gradient_checkpointing:
987
+ hidden_states = torch.utils.checkpoint.checkpoint(
988
+ layer.__call__,
989
+ hidden_states,
990
+ past_key_values,
991
+ attention_mask,
992
+ use_reentrant=True,
993
+ )
994
+ else:
995
+ hidden_states = layer(
996
+ hidden_states,
997
+ past_key_values=past_key_values,
998
+ attention_mask=attention_mask,
999
+ )
1000
+
1001
+ return hidden_states
1002
+
1003
+
1004
+ class PhiForCausalLM(PhiPreTrainedModel):
1005
+ """Phi for Causal Language Modeling."""
1006
+
1007
+ _keys_to_ignore_on_load_missing = [""]
1008
+ _keys_to_ignore_on_load_unexpected = [
1009
+ r"transformer\.h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"
1010
+ ]
1011
+
1012
+ def __init__(self, config: PhiConfig) -> None:
1013
+ super().__init__(config)
1014
+
1015
+ self.transformer = PhiModel(config)
1016
+ self.lm_head = CausalLMHead(config)
1017
+ self.loss = CausalLMLoss()
1018
+
1019
+ self.post_init()
1020
+
1021
+ def get_output_embeddings(self) -> nn.Linear:
1022
+ return self.lm_head.linear
1023
+
1024
+ def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
1025
+ self.lm_head.linear = new_embeddings
1026
+
1027
+ def forward(
1028
+ self,
1029
+ input_ids: torch.LongTensor = None,
1030
+ inputs_embeds: torch.FloatTensor = None,
1031
+ past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
1032
+ attention_mask: Optional[torch.BoolTensor] = None,
1033
+ labels: Optional[torch.LongTensor] = None,
1034
+ **kwargs,
1035
+ ) -> CausalLMOutputWithPast:
1036
+ hidden_states = self.transformer(
1037
+ input_ids,
1038
+ inputs_embeds,
1039
+ past_key_values=past_key_values,
1040
+ attention_mask=attention_mask,
1041
+ )
1042
+ lm_logits = self.lm_head(hidden_states)
1043
+
1044
+ loss = None
1045
+ if labels is not None:
1046
+ loss = self.loss(lm_logits, labels)
1047
+
1048
+ return CausalLMOutputWithPast(
1049
+ loss=loss, logits=lm_logits, past_key_values=past_key_values
1050
+ )
1051
+
1052
+
1053
+ class VisionEncoder(nn.Module):
1054
+ def __init__(self, model_path: str = "model") -> None:
1055
+ super().__init__()
1056
+ self.model = torch.jit.load(f"{model_path}/vision.pt").to(DEVICE, dtype=DTYPE)
1057
+ self.preprocess = Compose(
1058
+ [
1059
+ Resize(size=(384, 384), interpolation=InterpolationMode.BICUBIC),
1060
+ ToImage(),
1061
+ ToDtype(torch.float32, scale=True),
1062
+ Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
1063
+ ]
1064
+ )
1065
+
1066
+ def __call__(self, image: Image) -> torch.Tensor:
1067
+ with torch.no_grad():
1068
+ image_vec = self.preprocess(image.convert("RGB")).unsqueeze(0)
1069
+ image_vec = image_vec[:, :, :-6, :-6]
1070
+ image_vec = rearrange(
1071
+ image_vec, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=14, p2=14
1072
+ )
1073
+
1074
+ image_vec = image_vec.to(DEVICE, dtype=DTYPE)
1075
+ return self.model(image_vec)
1076
+
1077
+
1078
+ class TextModel(nn.Module):
1079
+ def __init__(self, model_path: str = "model") -> None:
1080
+ super().__init__()
1081
+ self.tokenizer = Tokenizer.from_pretrained(f"{model_path}/tokenizer")
1082
+ phi_config = PhiConfig.from_pretrained(f"{model_path}/text_model_cfg.json")
1083
+
1084
+ with init_empty_weights():
1085
+ self.model = PhiForCausalLM(phi_config)
1086
+
1087
+ self.model = load_checkpoint_and_dispatch(
1088
+ self.model,
1089
+ f"{model_path}/text_model.pt",
1090
+ device_map={"": DEVICE},
1091
+ dtype=DTYPE,
1092
+ )
1093
+
1094
+ self.text_emb = self.model.get_input_embeddings()
1095
+
1096
+ def input_embeds(self, prompt, image_embeds):
1097
+ embeds = []
1098
+
1099
+ def _add_toks(toks):
1100
+ embeds.append(self.text_emb(toks))
1101
+
1102
+ def _tokenize(txt):
1103
+ return self.tokenizer(
1104
+ txt, return_tensors="pt", add_special_tokens=False
1105
+ ).input_ids.to(self.model.device)
1106
+
1107
+ # Add BOS token
1108
+ _add_toks(
1109
+ torch.tensor([[self.tokenizer.bos_token_id]], device=self.model.device)
1110
+ )
1111
+
1112
+ if "<image>" not in prompt:
1113
+ embeds.append(self.text_emb(_tokenize(prompt)))
1114
+ else:
1115
+ assert prompt.count("<image>") == 1
1116
+ before, after = prompt.split("<image>")
1117
+ embeds.append(self.text_emb(_tokenize(f"{before}<image>")))
1118
+ embeds.append(image_embeds.to(self.model.device))
1119
+ embeds.append(self.text_emb(_tokenize(f"</image>{after}")))
1120
+
1121
+ return torch.cat(embeds, dim=1)
1122
+
1123
+ def generate(
1124
+ self, image_embeds, prompt, eos_text="Human:", max_new_tokens=128, **kwargs
1125
+ ):
1126
+ eos_tokens = self.tokenizer(eos_text, add_special_tokens=False)[0].ids
1127
+
1128
+ generate_config = {
1129
+ "eos_token_id": eos_tokens,
1130
+ "bos_token_id": self.tokenizer.bos_token_id,
1131
+ "pad_token_id": self.tokenizer.eos_token_id,
1132
+ "max_new_tokens": max_new_tokens,
1133
+ **kwargs,
1134
+ }
1135
+
1136
+ with torch.no_grad():
1137
+ inputs_embeds = self.input_embeds(prompt, image_embeds)
1138
+ output_ids = self.model.generate(
1139
+ inputs_embeds=inputs_embeds, **generate_config
1140
+ )
1141
+
1142
+ return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
1143
+
1144
+ def answer_question(self, image_embeds, question, **kwargs):
1145
+ prompt = f"<image>\n\nQuestion: {question}\n\nAnswer:"
1146
+ answer = self.generate(
1147
+ image_embeds,
1148
+ prompt,
1149
+ eos_text="<END>",
1150
+ max_new_tokens=128,
1151
+ **kwargs,
1152
+ )[0]
1153
+
1154
+ return re.sub("<$", "", re.sub("END$", "", answer)).strip()
1155
+
1156
+
1157
+ ##### GRADIO INTERFACE #####
1158
+
1159
+ import gradio as gr
1160
+ from huggingface_hub import snapshot_download
1161
+ from threading import Thread
1162
+ from transformers import TextIteratorStreamer
1163
+ import hashlib
1164
+ import os
1165
+
1166
+ model_path = snapshot_download("vikhyatk/moondream1")
1167
+
1168
+ vision_encoder = VisionEncoder(model_path).to(DEVICE, dtype=DTYPE)
1169
+ text_model = TextModel(model_path).to(DEVICE, dtype=DTYPE)
1170
+
1171
+
1172
+ def cached_vision_encoder(image):
1173
+ # Calculate checksum of the image
1174
+ image_hash = hashlib.sha256(image.tobytes()).hexdigest()
1175
+
1176
+ # Check if `image_encoder_cache/{image_hash}.pt` exists, if so load and return it.
1177
+ # Otherwise, save the encoded image to `image_encoder_cache/{image_hash}.pt` and return it.
1178
+ cache_path = f"image_encoder_cache/{image_hash}.pt"
1179
+ if os.path.exists(cache_path):
1180
+ return torch.load(cache_path).to(DEVICE, dtype=DTYPE)
1181
+ else:
1182
+ image_vec = vision_encoder(image).to("cpu", dtype=torch.float16)
1183
+ os.makedirs("image_encoder_cache", exist_ok=True)
1184
+ torch.save(image_vec, cache_path)
1185
+ return image_vec.to(DEVICE, dtype=DTYPE)
1186
+
1187
+
1188
+ def answer_question(image, question):
1189
+ yield "Encoding image..."
1190
+
1191
+ streamer = TextIteratorStreamer(text_model.tokenizer, skip_special_tokens=True)
1192
+ generation_kwargs = dict(
1193
+ image_embeds=cached_vision_encoder(image), question=question, streamer=streamer
1194
+ )
1195
+ thread = Thread(target=text_model.answer_question, kwargs=generation_kwargs)
1196
+ thread.start()
1197
+
1198
+ buffer = ""
1199
+ for new_text in streamer:
1200
+ buffer += new_text
1201
+ if len(buffer) > 1:
1202
+ yield re.sub("<$", "", re.sub("END$", "", buffer))
1203
+
1204
+
1205
+ gr.Interface(
1206
+ title="🌔 moondream1",
1207
+ description="""
1208
+ moondream1 is a tiny (1.6B parameter) vision language model that performs
1209
+ competitively with models twice its size. It is trained on the LLaVa training
1210
+ dataset, and initialized with SigLIP as the vision tower and Phi-1.5 as the
1211
+ text encoder. Check out the <a href="https://huggingface.co/vikhyatk/moondream1">HuggingFace
1212
+ model card</a> for more details.
1213
+ """,
1214
+ fn=answer_question,
1215
+ inputs=[gr.Image(type="pil"), gr.Textbox(lines=2, label="Question")],
1216
+ examples=[
1217
+ [Image.open("assets/demo-1.jpg"), "Who is the author of this book?"],
1218
+ [Image.open("assets/demo-2.jpg"), "What type of food is the girl eating?"],
1219
+ [
1220
+ Image.open("assets/demo-3.jpg"),
1221
+ "What kind of public transportation is in the image?",
1222
+ ],
1223
+ [Image.open("assets/demo-4.jpg"), "What is the girl looking at?"],
1224
+ [Image.open("assets/demo-5.jpg"), "What kind of dog is in the picture?"],
1225
+ ],
1226
+ outputs=gr.TextArea(label="Answer"),
1227
+ allow_flagging=False,
1228
+ ).launch()
assets/demo-1.jpg ADDED
assets/demo-2.jpg ADDED
assets/demo-3.jpg ADDED
assets/demo-4.jpg ADDED
assets/demo-5.jpg ADDED
image_encoder_cache/33c2dc3e4183b82eb8e09ecceb8c3f30262237d6ee0f49904e1fbde913195ffb.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c3d02a8ff0a7ec5883e696bc513bee6df9068127fa263bd83066584db20f0f5
3
+ size 2987641
image_encoder_cache/7ffdc9dcb8e0304e8658c0011cc71cb993e4729af95ebbe890f91a4dfce46170.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b1610027c6a749a9ab9c41a0c5cdec466e29e1945a0ebcd73e4e7edd261bb80
3
+ size 2987641
image_encoder_cache/b3ba6d3612f786df76a0cc5603c9a3d4cc6bede86d4c0f3001092fe0cc67f132.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9468baee42cba3ce022fc1a59c3f69aa26bced6d888b0ddf90ca6836752ba48e
3
+ size 2987641
image_encoder_cache/f5c5b77a4b0025925f7313a019fd91e435caf4b1ce651794c0c7bf5c4ea92827.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd546e1eec69e7c57179cb578b292a6a4403985309b8e54db6cedeb1517bf2a2
3
+ size 2987641
image_encoder_cache/f6defd804af24f18ea5d966b6b248f3dccc514b06155648592b93af567157a4e.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:613ffac2302c6b8c780235ad62b175850b07c8b8c7dc39148b5c57a324ed9868
3
+ size 2987641
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ accelerate==0.25.0
2
+ huggingface-hub==0.20.1
3
+ Pillow==10.1.0
4
+ torch==2.1.2
5
+ torchvision==0.16.2
6
+ transformers==4.36.2
7
+ einops==0.7.0