junlong jia committed on
Commit
a0b5907
1 Parent(s): 9488eaa

commit from junlongjia

config.json ADDED
@@ -0,0 +1,144 @@
1
+ {
2
+ "architectures": [
3
+ "TinyLlavaForConditionalGeneration"
4
+ ],
5
+ "cache_dir": null,
6
+ "connector_type": "mlp2x_gelu",
7
+ "hidden_size": 1536,
8
+ "ignore_index": -100,
9
+ "image_aspect_ratio": "square",
10
+ "image_token_index": -200,
11
+ "llm_model_name_or_path": "apple/OpenELM-450M-Instruct",
12
+ "model_type": "tinyllava",
13
+ "num_queries": 128,
14
+ "num_resampler_layers": 3,
15
+ "pad_token": "<unk>",
16
+ "pad_token_id": 0,
17
+ "resampler_hidden_size": 768,
18
+ "text_config": {
19
+ "_name_or_path": "apple/OpenELM-450M-Instruct",
20
+ "activation_fn_name": "swish",
21
+ "architectures": [
22
+ "OpenELMForCausalLM"
23
+ ],
24
+ "auto_map": {
25
+ "AutoConfig": "apple/OpenELM-450M-Instruct--configuration_openelm.OpenELMConfig",
26
+ "AutoModelForCausalLM": "apple/OpenELM-450M-Instruct--modeling_openelm.OpenELMForCausalLM"
27
+ },
28
+ "ffn_dim_divisor": 256,
29
+ "ffn_multipliers": [
30
+ 0.5,
31
+ 0.68,
32
+ 0.87,
33
+ 1.05,
34
+ 1.24,
35
+ 1.42,
36
+ 1.61,
37
+ 1.79,
38
+ 1.97,
39
+ 2.16,
40
+ 2.34,
41
+ 2.53,
42
+ 2.71,
43
+ 2.89,
44
+ 3.08,
45
+ 3.26,
46
+ 3.45,
47
+ 3.63,
48
+ 3.82,
49
+ 4.0
50
+ ],
51
+ "ffn_with_glu": true,
52
+ "head_dim": 64,
53
+ "max_context_length": 2048,
54
+ "model_dim": 1536,
55
+ "model_type": "openelm",
56
+ "normalization_layer_name": "rms_norm",
57
+ "normalize_qk_projections": true,
58
+ "num_gqa_groups": 4,
59
+ "num_kv_heads": [
60
+ 3,
61
+ 3,
62
+ 3,
63
+ 4,
64
+ 4,
65
+ 4,
66
+ 4,
67
+ 4,
68
+ 4,
69
+ 4,
70
+ 5,
71
+ 5,
72
+ 5,
73
+ 5,
74
+ 5,
75
+ 5,
76
+ 6,
77
+ 6,
78
+ 6,
79
+ 6
80
+ ],
81
+ "num_query_heads": [
82
+ 12,
83
+ 12,
84
+ 12,
85
+ 16,
86
+ 16,
87
+ 16,
88
+ 16,
89
+ 16,
90
+ 16,
91
+ 16,
92
+ 20,
93
+ 20,
94
+ 20,
95
+ 20,
96
+ 20,
97
+ 20,
98
+ 24,
99
+ 24,
100
+ 24,
101
+ 24
102
+ ],
103
+ "num_transformer_layers": 20,
104
+ "qkv_multipliers": [
105
+ 0.5,
106
+ 1.0
107
+ ],
108
+ "rope_freq_constant": 10000,
109
+ "rope_max_length": 4096,
110
+ "share_input_output_layers": true,
111
+ "tie_word_embeddings": true,
112
+ "torch_dtype": "float16"
113
+ },
114
+ "tokenizer_model_max_length": 2048,
115
+ "tokenizer_name_or_path": "meta-llama/Llama-2-7b-hf",
116
+ "tokenizer_padding_side": "right",
117
+ "tokenizer_use_fast": false,
118
+ "torch_dtype": "float16",
119
+ "transformers_version": "4.39.3",
120
+ "tune_type_connector": "full",
121
+ "tune_type_llm": "full",
122
+ "tune_type_vision_tower": "frozen",
123
+ "tune_vision_tower_from_layer": 0,
124
+ "use_cache": true,
125
+ "vision_config": {
126
+ "hidden_act": "gelu_pytorch_tanh",
127
+ "hidden_size": 1152,
128
+ "image_size": 384,
129
+ "intermediate_size": 4304,
130
+ "layer_norm_eps": 1e-06,
131
+ "model_name_or_path": "google/siglip-so400m-patch14-384",
132
+ "model_name_or_path2": "",
133
+ "model_type": "siglip_vision_model",
134
+ "num_attention_heads": 16,
135
+ "num_hidden_layers": 27,
136
+ "patch_size": 14
137
+ },
138
+ "vision_feature_layer": -2,
139
+ "vision_feature_select_strategy": "patch",
140
+ "vision_hidden_size": 1152,
141
+ "vision_model_name_or_path": "google/siglip-so400m-patch14-384",
142
+ "vision_model_name_or_path2": "",
143
+ "vocab_size": 32000
144
+ }
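
The config above wires an apple/OpenELM-450M-Instruct language model to a google/siglip-so400m-patch14-384 vision tower through an mlp2x_gelu connector. A minimal loading sketch that mirrors generate_model.py later in this commit, assuming modeling_tinyllava_elm.py from this repo is importable and that model_path is a placeholder for the checkpoint directory:

import torch
from transformers import AutoTokenizer
from modeling_tinyllava_elm import TinyLlavaForConditionalGeneration  # module imported by generate_model.py below

model_path = "path/to/this/checkpoint"  # placeholder
model = TinyLlavaForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.float16,  # matches "torch_dtype": "float16"
)
tokenizer = AutoTokenizer.from_pretrained(
    model_path,                 # assumes the tokenizer files ship alongside the checkpoint
    use_fast=False,             # matches "tokenizer_use_fast": false
    model_max_length=2048,      # matches "tokenizer_model_max_length"
    padding_side="right",       # matches "tokenizer_padding_side"
)
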
configuration.py ADDED
@@ -0,0 +1,112 @@
1
+ from transformers import PretrainedConfig
2
+ from transformers import CONFIG_MAPPING
3
+ from transformers import AutoConfig
4
+ from utils import *
5
+
6
+ class TinyLlavaConfig(PretrainedConfig):
7
+
8
+ model_type = "tinyllava"
9
+ def __init__(
10
+ self,
11
+ llm_model_name_or_path = '',
12
+ tokenizer_name_or_path = None,
13
+ vision_model_name_or_path = '',
14
+ vision_model_name_or_path2 = '',
15
+ connector_type = None,
16
+ text_config=None,
17
+ hidden_size=2048,
18
+ vocab_size=32000,
19
+ ignore_index=-100,
20
+ image_token_index=32000,
21
+ pad_token = None,
22
+ pad_token_id = None,
23
+ tokenizer_padding_side = 'right',
24
+ tokenizer_model_max_length = 2048,
25
+ vision_config = None,
26
+ vision_hidden_size = None,
27
+ vision_feature_layer = -2,
28
+ vision_feature_select_strategy = 'patch',
29
+ image_aspect_ratio = 'square',
30
+ resampler_hidden_size = None,
31
+ num_queries = None,
32
+ num_resampler_layers = None,
33
+ use_cache = False,
34
+ cache_dir = None,
35
+ tokenizer_use_fast = False,
36
+ tune_type_llm = 'frozen',
37
+ tune_type_connector = 'frozen',
38
+ tune_type_vision_tower = 'frozen',
39
+ tune_vision_tower_from_layer = -1,
40
+
41
+ **kwargs
42
+
43
+ ):
44
+ self.llm_model_name_or_path = llm_model_name_or_path
45
+ self.tokenizer_name_or_path = tokenizer_name_or_path or self.llm_model_name_or_path
46
+ self.vision_model_name_or_path = vision_model_name_or_path
47
+ self.vision_model_name_or_path2 = vision_model_name_or_path2
48
+ self.connector_type = connector_type
49
+ self.tune_type_llm = tune_type_llm
50
+ self.tune_type_connector = tune_type_connector
51
+ self.tune_type_vision_tower = tune_type_vision_tower
52
+ self.tune_vision_tower_from_layer = tune_vision_tower_from_layer
53
+
54
+ self.ignore_index = IGNORE_INDEX
55
+ self.image_token_index = IMAGE_TOKEN_INDEX
56
+ self.pad_token = pad_token
57
+ self.pad_token_id = pad_token_id
58
+ self.tokenizer_padding_side = tokenizer_padding_side
59
+ self.tokenizer_model_max_length = tokenizer_model_max_length
60
+ self.vision_feature_layer = vision_feature_layer
61
+ self.vision_feature_select_strategy = vision_feature_select_strategy
62
+ self.image_aspect_ratio = image_aspect_ratio
63
+ self.resampler_hidden_size = resampler_hidden_size
64
+ self.num_queries = num_queries
65
+ self.num_resampler_layers = num_resampler_layers
66
+ self.use_cache = use_cache
67
+ self.cache_dir = cache_dir
68
+ self.tokenizer_use_fast = tokenizer_use_fast
69
+ self._load_text_config(text_config)
70
+ self._load_vision_config(vision_config)
71
+
72
+ super().__init__(**kwargs)
73
+
74
+
75
+ def _load_text_config(self, text_config=None):
76
+ if self.llm_model_name_or_path is None or self.llm_model_name_or_path == '':
77
+ self.text_config = CONFIG_MAPPING['llama']()
78
+
79
+ else:
80
+ self.text_config = AutoConfig.from_pretrained(self.llm_model_name_or_path, trust_remote_code=True)
81
+ if text_config is not None:
82
+ self.text_config = self.text_config.from_dict(text_config)
83
+
84
+ self.hidden_size = getattr(self.text_config, 'hidden_size', getattr(self.text_config, 'model_dim', None))
85
+ self.vocab_size = getattr(self.text_config, 'vocab_size', None)
86
+
87
+
88
+
89
+ def _load_vision_config(self, vision_config=None):
90
+ if self.vision_model_name_or_path is None or self.vision_model_name_or_path == '':
91
+ self.vision_config = CONFIG_MAPPING['clip_vision_model'](
92
+ intermediate_size=4096,
93
+ hidden_size=1024,
94
+ patch_size=14,
95
+ image_size=336,
96
+ num_hidden_layers=24,
97
+ num_attention_heads=16,
98
+ vocab_size=32000,
99
+ projection_dim=768,
100
+ )
101
+
102
+ else:
103
+ self.vision_config = AutoConfig.from_pretrained(self.vision_model_name_or_path.split(':')[-1])
104
+ self.vision_config = getattr(self.vision_config, 'vision_config', self.vision_config)
105
+ if vision_config is not None:
106
+ self.vision_config = self.vision_config.from_dict(vision_config)
107
+
108
+ self.vision_config.model_name_or_path = self.vision_model_name_or_path.split(':')[-1]
109
+ self.vision_config.model_name_or_path2 = self.vision_model_name_or_path2.split(':')[-1]
110
+ self.vision_hidden_size = getattr(self.vision_config, 'hidden_size', None)
111
+
112
+
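
TinyLlavaConfig above derives its top-level fields from the sub-configs: hidden_size falls back to the OpenELM model_dim (OpenELMConfig has no hidden_size attribute), vocab_size comes from the text config, and vision_hidden_size is read from the SigLIP vision config. A minimal sketch, assuming utils.py from this repo (which provides IGNORE_INDEX and IMAGE_TOKEN_INDEX) is importable and the referenced Hub repos are reachable:

from configuration import TinyLlavaConfig

cfg = TinyLlavaConfig(
    llm_model_name_or_path="apple/OpenELM-450M-Instruct",
    vision_model_name_or_path="google/siglip-so400m-patch14-384",
    connector_type="mlp2x_gelu",
)
# Matches the values serialized in config.json above: 1536, 1152, 32000
print(cfg.hidden_size, cfg.vision_hidden_size, cfg.vocab_size)
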
conversion.py ADDED
@@ -0,0 +1,496 @@
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+
5
+
6
+ class SeparatorStyle(Enum):
7
+ """Different separator style."""
8
+ SINGLE = auto()
9
+ TWO = auto()
10
+ MPT = auto()
11
+ PLAIN = auto()
12
+ LLAMA_2 = auto()
13
+ TINY_LLAMA = auto()
14
+ QWEN_2 = auto()
15
+
16
+
17
+ @dataclasses.dataclass
18
+ class Conversation:
19
+ """A class that keeps all conversation history."""
20
+ system: str
21
+ roles: List[str]
22
+ messages: List[List[str]]
23
+ offset: int
24
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
25
+ sep: str = "###"
26
+ sep2: str = None
27
+ version: str = "Unknown"
28
+
29
+ skip_next: bool = False
30
+
31
+ def get_prompt(self):
32
+ messages = self.messages
33
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
34
+ messages = self.messages.copy()
35
+ init_role, init_msg = messages[0].copy()
36
+ init_msg = init_msg[0].replace("<image>", "").strip()
37
+ if 'mmtag' in self.version:
38
+ messages[0] = (init_role, init_msg)
39
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
40
+ messages.insert(1, (self.roles[1], "Received."))
41
+ else:
42
+ messages[0] = (init_role, "<image>\n" + init_msg)
43
+
44
+ if self.sep_style == SeparatorStyle.SINGLE:
45
+ ret = self.system + self.sep
46
+ for role, message in messages:
47
+ if message:
48
+ if type(message) is tuple:
49
+ message, _, _ = message
50
+ ret += role + ": " + message + self.sep
51
+ else:
52
+ ret += role + ":"
53
+ elif self.sep_style == SeparatorStyle.TWO:
54
+ seps = [self.sep, self.sep2]
55
+ ret = self.system + seps[0]
56
+ for i, (role, message) in enumerate(messages):
57
+ if message:
58
+ if type(message) is tuple:
59
+ message, _, _ = message
60
+ ret += role + ": " + message + seps[i % 2]
61
+ else:
62
+ ret += role + ":"
63
+ elif self.sep_style == SeparatorStyle.MPT:
64
+ ret = self.system + self.sep
65
+ for role, message in messages:
66
+ if message:
67
+ if type(message) is tuple:
68
+ message, _, _ = message
69
+ ret += role + message + self.sep
70
+ else:
71
+ ret += role
72
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
73
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
74
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
75
+ ret = ""
76
+
77
+ for i, (role, message) in enumerate(messages):
78
+ if i == 0:
79
+ assert message, "first message should not be none"
80
+ assert role == self.roles[0], "first message should come from user"
81
+ if message:
82
+ if type(message) is tuple:
83
+ message, _, _ = message
84
+ if i == 0: message = wrap_sys(self.system) + message
85
+ if i % 2 == 0:
86
+ message = wrap_inst(message)
87
+ ret += self.sep + message
88
+ else:
89
+ ret += " " + message + " " + self.sep2
90
+ else:
91
+ ret += ""
92
+ ret = ret.lstrip(self.sep)
93
+ elif self.sep_style == SeparatorStyle.TINY_LLAMA:
94
+ sep = "</s>"
95
+ wrap_sys = lambda msg: f"<|system|>\n{msg}\n"
96
+ wrap_user = lambda msg: f"<|user|>\n{msg}\n"
97
+ wrap_assistant = lambda msg: f"<|assistant|>\n{msg}"
98
+ ret = ""
99
+
100
+ for i, (role, message) in enumerate(messages):
101
+ if i == 0:
102
+ assert message, "first message should not be none"
103
+ assert role == self.roles[0], "first message should come from user"
104
+ if message:
105
+ if type(message) is tuple:
106
+ message, _, _ = message
107
+ if i % 2 == 0:
108
+ message = wrap_user(message)
109
+ if i == 0:
110
+ message = wrap_sys(self.system) + message
111
+ ret += self.sep + message
112
+ else:
113
+ message = wrap_assistant(message) + self.sep2
114
+ ret += message
115
+ else:
116
+ ret += "<|assistant|>\n"
117
+ ret = ret.lstrip(self.sep)
118
+ elif self.sep_style == SeparatorStyle.QWEN_2:
119
+ ret = self.system + self.sep
120
+ for role, message in messages:
121
+ if message:
122
+ if type(message) is tuple:
123
+ message, _, _ = message
124
+ ret += role + message + self.sep
125
+ else:
126
+ ret += role
127
+ elif self.sep_style == SeparatorStyle.PLAIN:
128
+ seps = [self.sep, self.sep2]
129
+ ret = self.system
130
+ for i, (role, message) in enumerate(messages):
131
+ if message:
132
+ if type(message) is tuple:
133
+ message, _, _ = message
134
+ ret += message + seps[i % 2]
135
+ else:
136
+ ret += ""
137
+ else:
138
+ raise ValueError(f"Invalid style: {self.sep_style}")
139
+
140
+ return ret
141
+
142
+ def append_message(self, role, message):
143
+ self.messages.append([role, message])
144
+
145
+ def get_images(self, return_pil=False):
146
+ images = []
147
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
148
+ if i % 2 == 0:
149
+ if type(msg) is tuple:
150
+ import base64
151
+ from io import BytesIO
152
+ from PIL import Image
153
+ msg, image, image_process_mode = msg
154
+ if image_process_mode == "Pad":
155
+ def expand2square(pil_img, background_color=(122, 116, 104)):
156
+ width, height = pil_img.size
157
+ if width == height:
158
+ return pil_img
159
+ elif width > height:
160
+ result = Image.new(pil_img.mode, (width, width), background_color)
161
+ result.paste(pil_img, (0, (width - height) // 2))
162
+ return result
163
+ else:
164
+ result = Image.new(pil_img.mode, (height, height), background_color)
165
+ result.paste(pil_img, ((height - width) // 2, 0))
166
+ return result
167
+ image = expand2square(image)
168
+ elif image_process_mode in ["Default", "Crop"]:
169
+ pass
170
+ elif image_process_mode == "Resize":
171
+ image = image.resize((336, 336))
172
+ else:
173
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
174
+ max_hw, min_hw = max(image.size), min(image.size)
175
+ aspect_ratio = max_hw / min_hw
176
+ max_len, min_len = 800, 400
177
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
178
+ longest_edge = int(shortest_edge * aspect_ratio)
179
+ W, H = image.size
180
+ if longest_edge != max(image.size):
181
+ if H > W:
182
+ H, W = longest_edge, shortest_edge
183
+ else:
184
+ H, W = shortest_edge, longest_edge
185
+ image = image.resize((W, H))
186
+ if return_pil:
187
+ images.append(image)
188
+ else:
189
+ buffered = BytesIO()
190
+ image.save(buffered, format="PNG")
191
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
192
+ images.append(img_b64_str)
193
+ return images
194
+
195
+ def to_gradio_chatbot(self):
196
+ ret = []
197
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
198
+ if i % 2 == 0:
199
+ if type(msg) is tuple:
200
+ import base64
201
+ from io import BytesIO
202
+ msg, image, image_process_mode = msg
203
+ max_hw, min_hw = max(image.size), min(image.size)
204
+ aspect_ratio = max_hw / min_hw
205
+ max_len, min_len = 800, 400
206
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
207
+ longest_edge = int(shortest_edge * aspect_ratio)
208
+ W, H = image.size
209
+ if H > W:
210
+ H, W = longest_edge, shortest_edge
211
+ else:
212
+ H, W = shortest_edge, longest_edge
213
+ image = image.resize((W, H))
214
+ buffered = BytesIO()
215
+ image.save(buffered, format="JPEG")
216
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
217
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
218
+ msg = img_str + msg.replace('<image>', '').strip()
219
+ ret.append([msg, None])
220
+ else:
221
+ ret.append([msg, None])
222
+ else:
223
+ ret[-1][-1] = msg
224
+ return ret
225
+
226
+ def copy(self):
227
+ return Conversation(
228
+ system=self.system,
229
+ roles=self.roles,
230
+ messages=[[x, y] for x, y in self.messages],
231
+ offset=self.offset,
232
+ sep_style=self.sep_style,
233
+ sep=self.sep,
234
+ sep2=self.sep2,
235
+ version=self.version)
236
+
237
+ def dict(self):
238
+ if len(self.get_images()) > 0:
239
+ return {
240
+ "system": self.system,
241
+ "roles": self.roles,
242
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
243
+ "offset": self.offset,
244
+ "sep": self.sep,
245
+ "sep2": self.sep2,
246
+ }
247
+ return {
248
+ "system": self.system,
249
+ "roles": self.roles,
250
+ "messages": self.messages,
251
+ "offset": self.offset,
252
+ "sep": self.sep,
253
+ "sep2": self.sep2,
254
+ }
255
+
256
+
257
+ conv_vicuna_v0 = Conversation(
258
+ system="A chat between a curious human and an artificial intelligence assistant. "
259
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
260
+ roles=("Human", "Assistant"),
261
+ messages=(
262
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
263
+ ("Assistant",
264
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
265
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
266
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
267
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
268
+ "renewable and non-renewable energy sources:\n"
269
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
270
+ "energy sources are finite and will eventually run out.\n"
271
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
272
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
273
+ "and other negative effects.\n"
274
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
275
+ "have lower operational costs than non-renewable sources.\n"
276
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
277
+ "locations than non-renewable sources.\n"
278
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
279
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
280
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
281
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
282
+ ),
283
+ offset=2,
284
+ sep_style=SeparatorStyle.SINGLE,
285
+ sep="###",
286
+ )
287
+
288
+ conv_vicuna_v1 = Conversation(
289
+ system="A chat between a curious user and an artificial intelligence assistant. "
290
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
291
+ roles=("USER", "ASSISTANT"),
292
+ version="v1",
293
+ messages=(),
294
+ offset=0,
295
+ sep_style=SeparatorStyle.TWO,
296
+ sep=" ",
297
+ sep2="</s>",
298
+ )
299
+
300
+ conv_llama_2 = Conversation(
301
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
302
+
303
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
304
+ roles=("USER", "ASSISTANT"),
305
+ version="llama_v2",
306
+ messages=(),
307
+ offset=0,
308
+ sep_style=SeparatorStyle.LLAMA_2,
309
+ sep="<s>",
310
+ sep2="</s>",
311
+ )
312
+
313
+ conv_llava_llama_2 = Conversation(
314
+ system="You are a helpful language and vision assistant. "
315
+ "You are able to understand the visual content that the user provides, "
316
+ "and assist the user with a variety of tasks using natural language.",
317
+ roles=("USER", "ASSISTANT"),
318
+ version="llama_v2",
319
+ messages=(),
320
+ offset=0,
321
+ sep_style=SeparatorStyle.LLAMA_2,
322
+ sep="<s>",
323
+ sep2="</s>",
324
+ )
325
+
326
+ conv_tiny_llava_tiny_llama = Conversation(
327
+ system="You are a helpful language and vision assistant. "
328
+ "You are able to understand the visual content that the user provides, "
329
+ "and assist the user with a variety of tasks using natural language.",
330
+ roles=("USER", "ASSISTANT"),
331
+ version="tiny_llama",
332
+ messages=(),
333
+ offset=0,
334
+ sep_style=SeparatorStyle.TINY_LLAMA,
335
+ sep="<s>",
336
+ sep2="</s>"
337
+ )
338
+
339
+
340
+ conv_mpt = Conversation(
341
+ system="""<|im_start|>system
342
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
343
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
344
+ version="mpt",
345
+ messages=(),
346
+ offset=0,
347
+ sep_style=SeparatorStyle.MPT,
348
+ sep="<|im_end|>",
349
+ )
350
+
351
+ conv_llava_plain = Conversation(
352
+ system="",
353
+ roles=("", ""),
354
+ messages=(
355
+ ),
356
+ version='plain',
357
+ offset=0,
358
+ sep_style=SeparatorStyle.PLAIN,
359
+ sep="\n",
360
+ )
361
+
362
+ conv_llava_v0 = Conversation(
363
+ system="A chat between a curious human and an artificial intelligence assistant. "
364
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
365
+ roles=("Human", "Assistant"),
366
+ messages=(
367
+ ),
368
+ offset=0,
369
+ sep_style=SeparatorStyle.SINGLE,
370
+ sep="###",
371
+ )
372
+
373
+ conv_llava_v0_mmtag = Conversation(
374
+ system="A chat between a curious user and an artificial intelligence assistant. "
375
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
376
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
377
+ roles=("Human", "Assistant"),
378
+ messages=(
379
+ ),
380
+ offset=0,
381
+ sep_style=SeparatorStyle.SINGLE,
382
+ sep="###",
383
+ version="v0_mmtag",
384
+ )
385
+
386
+ conv_llava_v1 = Conversation(
387
+ system="A chat between a curious human and an artificial intelligence assistant. "
388
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
389
+ roles=("USER", "ASSISTANT"),
390
+ version="v1",
391
+ messages=(),
392
+ offset=0,
393
+ sep_style=SeparatorStyle.TWO,
394
+ sep=" ",
395
+ sep2="</s>",
396
+ )
397
+
398
+ conv_llava_v1_mmtag = Conversation(
399
+ system="A chat between a curious user and an artificial intelligence assistant. "
400
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
401
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
402
+ roles=("USER", "ASSISTANT"),
403
+ messages=(),
404
+ offset=0,
405
+ sep_style=SeparatorStyle.TWO,
406
+ sep=" ",
407
+ sep2="</s>",
408
+ version="v1_mmtag",
409
+ )
410
+
411
+ conv_phi_v0 = Conversation(
412
+ system="A chat between a curious user and an artificial intelligence assistant. "
413
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
414
+ roles=("USER", "ASSISTANT"),
415
+ version="phi",
416
+ messages=(),
417
+ offset=0,
418
+ sep_style=SeparatorStyle.TWO,
419
+ sep=" ",
420
+ sep2="<|endoftext|>",
421
+ )
422
+
423
+ conv_stablelm = Conversation(
424
+ system="A chat between a curious user and an artificial intelligence assistant. "
425
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
426
+ roles=("USER", "ASSISTANT"),
427
+ version="stablelm",
428
+ messages=(),
429
+ offset=0,
430
+ sep_style=SeparatorStyle.TWO,
431
+ sep=" ",
432
+ sep2="<|endoftext|>",
433
+ )
434
+
435
+ conv_mistral_instruct = Conversation(
436
+ system="",
437
+ roles=("USER", "ASSISTANT"),
438
+ version="llama_v2",
439
+ messages=(),
440
+ offset=0,
441
+ sep_style=SeparatorStyle.LLAMA_2,
442
+ sep="",
443
+ sep2="</s>",
444
+ )
445
+
446
+ conv_chatml_direct = Conversation(
447
+ system="""<|im_start|>system
448
+ Answer the questions.""",
449
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
450
+ version="mpt",
451
+ messages=(),
452
+ offset=0,
453
+ sep_style=SeparatorStyle.MPT,
454
+ sep="<|im_end|>",
455
+ )
456
+
457
+ conv_qwen2 = Conversation(
458
+ system="<|im_start|>system\nYou are a helpful assistant",
459
+ roles=("<im_start>user\n", "<im_start>assistant\n"),
460
+ version="mpt",
461
+ messages=(),
462
+ offset=0,
463
+ sep_style=SeparatorStyle.MPT,
464
+ sep="<im_end>"
465
+ )
466
+
467
+ default_conversation = conv_phi_v0
468
+ conv_templates = {
469
+ "default": conv_vicuna_v0,
470
+ "v0": conv_vicuna_v0,
471
+ "v1": conv_vicuna_v1,
472
+ "vicuna_v1": conv_vicuna_v1,
473
+ "llama_2": conv_llama_2,
474
+
475
+ "plain": conv_llava_plain,
476
+ "v0_plain": conv_llava_plain,
477
+ "llava_v0": conv_llava_v0,
478
+ "v0_mmtag": conv_llava_v0_mmtag,
479
+ "llava_v1": conv_llava_v1,
480
+ "v1_mmtag": conv_llava_v1_mmtag,
481
+ "llava_llama_2": conv_llava_llama_2,
482
+
483
+ "mpt": conv_mpt,
484
+
485
+ "tiny_llama": conv_tiny_llava_tiny_llama,
486
+ "phi": conv_phi_v0,
487
+
488
+ # added by llava-1.6
489
+ "mistral_instruct": conv_mistral_instruct,
490
+ "chatml_direct": conv_chatml_direct,
491
+ "mistral_direct": conv_chatml_direct,
492
+ }
493
+
494
+
495
+ if __name__ == "__main__":
496
+ print(default_conversation.get_prompt())
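
The templates above drive prompt construction; conv_phi_v0 (SeparatorStyle.TWO with sep=" " and sep2="<|endoftext|>") is both the module-level default_conversation and the template used by generate_model.py below. A short sketch of how a prompt is assembled:

from conversion import conv_phi_v0

conv = conv_phi_v0.copy()
conv.append_message(conv.roles[0], "<image>\nDescribe this image.")
conv.append_message(conv.roles[1], None)  # None leaves the assistant turn open
print(conv.get_prompt())
# "A chat between a curious user and an artificial intelligence assistant. ...
#  USER: <image>\nDescribe this image. ASSISTANT:"
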
generate_model.py ADDED
@@ -0,0 +1,204 @@
1
+ import argparse
2
+ import time
3
+ import logging
4
+ import requests
5
+ import os
6
+ from PIL import Image
7
+ from io import BytesIO
8
+
9
+ from PIL import Image
10
+ import torch
11
+ from transformers import AutoTokenizer
12
+
13
+ from modeling_tinyllava_elm import TinyLlavaForConditionalGeneration
14
+ from configuration import *
15
+ from conversion import *
16
+ from utils import *
17
+
18
+
19
+
20
+ def load_image(image_file):
21
+ if image_file.startswith("http") or image_file.startswith("https"):
22
+ response = requests.get(image_file)
23
+ image = Image.open(BytesIO(response.content)).convert("RGB")
24
+ else:
25
+ image = Image.open(image_file).convert("RGB")
26
+ return image
27
+
28
+
29
+ def generate(
30
+ prompt: str,
31
+ model: str,
32
+ tokenizer = None,
33
+ image: str = None,
34
+ device: str = None,
35
+ max_new_tokens: int = 1024,
36
+ num_beams = 1,
37
+ top_p=None,
38
+ temperature=0.2
39
+ ):
40
+ if not device:
41
+ if torch.cuda.is_available() and torch.cuda.device_count():
42
+ device = "cuda:0"
43
+ logging.warning(
44
+ 'inference device is not set, using cuda:0, %s',
45
+ torch.cuda.get_device_name(0)
46
+ )
47
+ else:
48
+ device = 'cpu'
49
+ logging.warning(
50
+ (
51
+ 'No CUDA device detected, using cpu, '
52
+ 'expect slower speeds.'
53
+ )
54
+ )
55
+
56
+ if 'cuda' in device and not torch.cuda.is_available():
57
+ raise ValueError('CUDA device requested but no CUDA device detected.')
58
+
59
+ if isinstance(model, str):
60
+ checkpoint_path = model
61
+ # print(f'loading model from {checkpoint_path}...')
62
+ model = TinyLlavaForConditionalGeneration.from_pretrained(
63
+ checkpoint_path,
64
+ torch_dtype=torch.float16,
65
+ )
66
+ # print('model load over')
67
+ config = model.config
68
+ if tokenizer is None:
69
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, use_fast=False, model_max_length = config.tokenizer_model_max_length,
70
+ padding_side = config.tokenizer_padding_side)
71
+ image_processor = model.vision_tower._image_processor
72
+ context_len = getattr(config, 'max_sequence_length', 2048)
73
+ model.to(device).eval()
74
+
75
+
76
+ if image is not None:
77
+ prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
78
+ conv = conv_phi_v0.copy()
79
+ conv.append_message(conv.roles[0], prompt)
80
+ conv.append_message(conv.roles[1], None)
81
+ prompt = conv.get_prompt()
82
+ if image is not None:
83
+ # print('loading image...')
84
+ image = load_image(image)
85
+ # print('load image over')
86
+ image_tensor = process_images(image, image_processor, config).to(model.device, dtype=torch.float16)
87
+
88
+ input_ids = (
89
+ tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
90
+ .unsqueeze(0)
91
+ .to(model.device)
92
+ )
93
+ # Generate
94
+ stime = time.time()
95
+ # stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
96
+ # keywords = [stop_str]
97
+ # stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
98
+ # print('start inference...')
99
+ with torch.inference_mode():
100
+ output_ids = model.generate(
101
+ input_ids,
102
+ images=image_tensor,
103
+ do_sample=True if temperature > 0 else False,
104
+ temperature=temperature,
105
+ top_p=top_p,
106
+ num_beams=num_beams,
107
+ pad_token_id=tokenizer.pad_token_id,
108
+ max_new_tokens=max_new_tokens,
109
+ use_cache=True,
110
+ # stopping_criteria=[stopping_criteria],
111
+ )
112
+
113
+ # print('inference over')
114
+ generation_time = time.time() - stime
115
+ outputs = tokenizer.batch_decode(
116
+ output_ids, skip_special_tokens=True
117
+ )[0]
118
+ # outputs = outputs.strip()
119
+ # if outputs.endswith(stop_str):
120
+ # outputs = outputs[: -len(stop_str)]
121
+ outputs = outputs.strip()
122
+
123
+ return outputs, generation_time
124
+ def tinyllava_elm_generate_parser():
125
+ """Argument Parser"""
126
+
127
+ class KwargsParser(argparse.Action):
128
+ """Parser action class to parse kwargs of form key=value"""
129
+ def __call__(self, parser, namespace, values, option_string=None):
130
+ setattr(namespace, self.dest, dict())
131
+ for val in values:
132
+ if '=' not in val:
133
+ raise ValueError(
134
+ (
135
+ 'Argument parsing error, kwargs are expected in'
136
+ ' the form of key=value.'
137
+ )
138
+ )
139
+ kwarg_k, kwarg_v = val.split('=')
140
+ try:
141
+ converted_v = int(kwarg_v)
142
+ except ValueError:
143
+ try:
144
+ converted_v = float(kwarg_v)
145
+ except ValueError:
146
+ converted_v = kwarg_v
147
+ getattr(namespace, self.dest)[kwarg_k] = converted_v
148
+
149
+ parser = argparse.ArgumentParser('TinyLLaVA-OpenELM Generate Module')
150
+ parser.add_argument(
151
+ '--model',
152
+ dest='model',
153
+ help='Path to the hf converted model.',
154
+ required=True,
155
+ type=str,
156
+ )
157
+ parser.add_argument(
158
+ '--prompt',
159
+ dest='prompt',
160
+ help='Prompt for LLM call.',
161
+ default='',
162
+ type=str,
163
+ )
164
+ parser.add_argument(
165
+ '--device',
166
+ dest='device',
167
+ help='Device used for inference.',
168
+ type=str,
169
+ )
170
+ parser.add_argument("--image", type=str, default=None)
171
+ parser.add_argument("--temperature", type=float, default=0)
172
+ parser.add_argument("--top_p", type=float, default=None)
173
+ parser.add_argument("--num_beams", type=int, default=1)
174
+ parser.add_argument("--max_new_tokens", type=int, default=512)
175
+ return parser.parse_args()
176
+
177
+
178
+ if __name__ == '__main__':
179
+ args = tinyllava_elm_generate_parser()
180
+ prompt = args.prompt
181
+ model = TinyLlavaForConditionalGeneration.from_pretrained(args.model)
182
+
183
+ output_text, generation_time = generate(
184
+ prompt=prompt,
185
+ image=args.image,
186
+ model=args.model,
187
+ device=args.device,
188
+ max_new_tokens = args.max_new_tokens,
189
+ num_beams = args.num_beams,
190
+ top_p=args.top_p,
191
+ temperature=args.temperature
192
+ )
193
+
194
+ print_txt = (
195
+ f'\r\n{"=" * os.get_terminal_size().columns}\r\n'
196
+ '\033[1m Prompt + Generated Output\033[0m\r\n'
197
+ f'{"-" * os.get_terminal_size().columns}\r\n'
198
+ f'{output_text}\r\n'
199
+ f'{"-" * os.get_terminal_size().columns}\r\n'
200
+ '\r\nGeneration took'
201
+ f'\033[1m\033[92m {round(generation_time, 2)} \033[0m'
202
+ 'seconds.\r\n'
203
+ )
204
+ print(print_txt)
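
Besides the CLI entry point, generate() above can be called directly. A usage sketch, assuming the repo files are importable and a CUDA-capable machine (the helper defaults to cuda:0 when no device is given); the checkpoint path and image URL are placeholders:

from generate_model import generate

output_text, generation_time = generate(
    prompt="What is shown in this image?",
    model="path/to/checkpoint",           # a string path is loaded inside generate()
    image="https://example.com/cat.jpg",  # or a local file path
    max_new_tokens=256,
    temperature=0,                        # temperature <= 0 disables sampling
)
print(output_text, f"({generation_time:.2f} s)")
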
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "transformers_version": "4.39.3",
6
+ "use_cache": false
7
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3537a54391e475003f38a1ae2a167a2873d5e97bfe048b846ea491f0a4d63d56
3
+ size 1779161568
modeling_elm.py ADDED
@@ -0,0 +1,1288 @@
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4
+ #
5
+
6
+ from typing import List, Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.utils.checkpoint
10
+ from torch import Tensor, nn
11
+ from torch.nn import CrossEntropyLoss
12
+ from torch.nn import functional as F
13
+ from transformers import PreTrainedModel
14
+ from transformers.activations import ACT2FN
15
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
16
+ from transformers.modeling_outputs import (
17
+ BaseModelOutputWithPast,
18
+ CausalLMOutputWithPast,
19
+ )
20
+ from transformers.utils import logging
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ # this import has to be relative, otherwise, when setting trust_remote_code=True
25
+ # huggingface transformers won't be able to load the module correctly
26
+ from numbers import Number
27
+ from typing import List, Optional, Union
28
+
29
+ import numpy as np
30
+ from transformers import PretrainedConfig, AutoTokenizer
31
+
32
+
33
+ def make_divisible(
34
+ v: Union[float, int],
35
+ divisor: Optional[int] = 8,
36
+ min_value: Optional[Union[float, int]] = None,
37
+ ) -> Union[float, int]:
38
+ """
39
+ This function is taken from the original tf repo.
40
+ It ensures that all layers have a channel number that is divisible by the divisor
41
+ It can be seen at:
42
+ https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62
43
+ Args:
44
+ v: input value
45
+ divisor: default to 8
46
+ min_value: minimum divisor value
47
+ Returns:
48
+ new_v: new divisible value
49
+ """
50
+ if min_value is None:
51
+ min_value = divisor
52
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
53
+ # Make sure that round down does not go down by more than 10%.
54
+ if new_v < 0.9 * v:
55
+ new_v += divisor
56
+ return new_v
57
+
58
+
59
+ def compute_heads(model_dim: int, head_dim: int) -> int:
60
+ """Compute the number of heads.
61
+ Args:
62
+ model_dim: Model dimension.
63
+ head_dim: Head dimension.
64
+ Returns:
65
+ An integer denoting number of heads in multi-head attention is returned.
66
+ Raises:
67
+ ValueError: if model dimension is not divisible by head dimension.
68
+ """
69
+ if model_dim % head_dim == 0:
70
+ return model_dim // head_dim
71
+ else:
72
+ raise ValueError(
73
+ f"Model dimension should be divisible by head dimension. Got: {model_dim} and {head_dim}."
74
+ )
75
+
76
+
77
+ OpenELM_CONFIGS = {
78
+ "OpenELM-270M": dict(
79
+ num_transformer_layers=16,
80
+ model_dim=1280,
81
+ head_dim=64,
82
+ num_gqa_groups=4,
83
+ normalize_qk_projections=True,
84
+ share_input_output_layers=True,
85
+ # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
86
+ ffn_multipliers=(0.5, 4.0),
87
+ qkv_multipliers=(0.5, 1.0),
88
+ ),
89
+ "OpenELM-450M": dict(
90
+ num_transformer_layers=20,
91
+ model_dim=1536,
92
+ head_dim=64,
93
+ num_gqa_groups=4,
94
+ normalize_qk_projections=True,
95
+ share_input_output_layers=True,
96
+ # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
97
+ ffn_multipliers=(0.5, 4.0),
98
+ qkv_multipliers=(0.5, 1.0),
99
+ ),
100
+ "OpenELM-1_1B": dict(
101
+ num_transformer_layers=28,
102
+ model_dim=2048,
103
+ head_dim=64,
104
+ num_gqa_groups=4,
105
+ normalize_qk_projections=True,
106
+ share_input_output_layers=True,
107
+ # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
108
+ ffn_multipliers=(0.5, 4.0),
109
+ qkv_multipliers=(0.5, 1.0),
110
+ ),
111
+ "OpenELM-3B": dict(
112
+ num_transformer_layers=36,
113
+ model_dim=3072,
114
+ head_dim=128,
115
+ num_gqa_groups=4,
116
+ normalize_qk_projections=True,
117
+ share_input_output_layers=True,
118
+ # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
119
+ ffn_multipliers=(0.5, 4.0),
120
+ qkv_multipliers=(0.5, 1.0),
121
+ ),
122
+ }
123
+
124
+
125
+ class OpenELMConfig(PretrainedConfig):
126
+ r"""
127
+ This is the configuration class to store the configuration of a [`OpenELMModel`]. It is used to instantiate an OpenELM model according to the specified arguments, defining the model architecture.
128
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
129
+ documentation from [`PretrainedConfig`] for more information.
130
+ Args:
131
+ vocab_size (`int`, *optional*, defaults to 32000):
132
+ Vocabulary size of the OpenELM model.
133
+ max_context_length (`int`, *optional*, defaults to 2048):
134
+ Maximum number of input tokens.
135
+ num_transformer_layers (`int`, *optional*, defaults to 12):
136
+ Number of hidden layers in the Transformer decoder.
137
+ model_dim (`int`, *optional*, defaults to 2048):
138
+ Dimension of the hidden representations.
139
+ head_dim (`int`, *optional*, defaults to 128):
140
+ The attention head dimension.
141
+ qkv_multipliers (`Union[Number, List[Number]]`, *optional*, defaults to 1.0):
142
+ If the qkv_multipliers is a Number, then all attention layers have the same latent dimensions,
143
+ resulting in uniform allocation of parameters.
144
+ If the qkv_multipliers is a List of Number, then each attention layer have different latent dimensions
145
+ assuming qkv_multipliers[0] != qkv_multipliers[1]. This results in variable allocation of parameters in attention layer.
146
+ This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
147
+ num_query_heads (`Union[int, None]`, *optional*, defaults to None):
148
+ The number of query heads, computed from `compute_heads(model_dim=model_dim, head_dim=head_dim)`.
149
+ num_gqa_groups (`int`, *optional*, defaults to 1):
150
+ This variable allows to switch between multi-head attention, group query attention, and multi-query attention.
151
+ When num_gqa_groups == 1, then it is multi-head attention.
152
+ When 1 < num_gqa_groups < num_heads and num_heads is divisible by num_gqa_groups, then it is group query attention
153
+ When num_gqa_groups == num_heads, then it is multi-query attention
154
+ ffn_multipliers (`Union[Number, List[Number]]`, *optional*, defaults to 4.0):
155
+ Feed-forward network (FFN) multipliers.
156
+ If the ffn_multipliers is a Number, then all FFN layers have the same latent dimensions,
157
+ resulting in uniform allocation of parameters.
158
+ If the ffn_multipliers is a List of Number, then each FFN layer have different latent dimensions
159
+ assuming ffn_multipliers[0] != ffn_multipliers[1]. This results in variable allocation of parameters in FFN layer.
160
+ This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
161
+ ffn_with_glu (`bool`, *optional*, defaults to True):
162
+ Whether to use FFN with Gated Linear Unit (GLU)
163
+ ffn_dim_divisor (`int`, *optional*, defaults to 256):
164
+ The ffn layer dimension divisor.
165
+ activation_fn_name (`str` or `function`, *optional*, defaults to `"swish"`):
166
+ The non-linear activation function (function or string) in the decoder.
167
+ normalization_layer_name (`str` or `function`, *optional*, defaults to `"rms_norm"`):
168
+ Type of normalization layer.
169
+ normalize_qk_projections (`bool`, *optional*, defaults to False):
170
+ Whether to normalize queries and keys after projections
171
+ share_input_output_layers (`bool`, *optional*, defaults to False):
172
+ Whether to share the embedding between input and output linear layer
173
+ rope_freq_constant (`int`, *optional*, defaults to 10000):
174
+ The base period of the RoPE embeddings.
175
+ rope_max_length (`int`, *optional*, defaults to 4096):
176
+ That rope_max_length is set to twice of max_context_length.
177
+ This allows flexibility in token lengths during training or fine-tuning.
178
+ initializer_range (`float`, *optional*, defaults to 0.02):
179
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
180
+ use_cache (`bool`, *optional*, defaults to `True`):
181
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
182
+ relevant if `config.is_decoder=True`.
183
+ bos_token_id (`int`, *optional*, defaults to 1):
184
+ Beginning of stream token id.
185
+ eos_token_id (`int`, *optional*, defaults to 2):
186
+ End of stream token id.
187
+ """
188
+
189
+ model_type = "openelm"
190
+
191
+ def __init__(
192
+ self,
193
+ vocab_size: int = 32000,
194
+ max_context_length: int = 2048,
195
+ num_transformer_layers: int = 12,
196
+ model_dim: int = 2048,
197
+ head_dim: int = 128,
198
+ qkv_multipliers: Union[Number, List[Number]] = 1.0,
199
+ num_query_heads: Union[int, None] = None,
200
+ num_gqa_groups: int = 1,
201
+ ffn_multipliers: Union[Number, List[Number]] = 4.0,
202
+ ffn_with_glu: bool = True,
203
+ ffn_dim_divisor: int = 256,
204
+ activation_fn_name: str = "swish",
205
+ normalization_layer_name: str = "rms_norm",
206
+ normalize_qk_projections: bool = False,
207
+ share_input_output_layers: bool = False,
208
+ rope_freq_constant: int = 10000,
209
+ rope_max_length: int = 4096,
210
+ initializer_range: float = 0.02,
211
+ use_cache: bool = True,
212
+ bos_token_id: int = 1,
213
+ eos_token_id: int = 2,
214
+ **kwargs,
215
+ ) -> None:
216
+ self.vocab_size = vocab_size
217
+ self.max_context_length = max_context_length
218
+ self.num_transformer_layers = num_transformer_layers
219
+ self.model_dim = model_dim
220
+ self.head_dim = head_dim
221
+ self.qkv_multipliers = qkv_multipliers
222
+ self.num_query_heads = num_query_heads
223
+ self.num_gqa_groups = num_gqa_groups
224
+ self.ffn_multipliers = ffn_multipliers
225
+ self.ffn_with_glu = ffn_with_glu
226
+ self.ffn_dim_divisor = ffn_dim_divisor
227
+ self.activation_fn_name = activation_fn_name
228
+ self.normalization_layer_name = normalization_layer_name
229
+ self.normalize_qk_projections = normalize_qk_projections
230
+ self.share_input_output_layers = share_input_output_layers
231
+ self.rope_freq_constant = rope_freq_constant
232
+ self.rope_max_length = rope_max_length
233
+ self.num_query_heads = (
234
+ compute_heads(model_dim=model_dim, head_dim=head_dim)
235
+ if num_query_heads is None
236
+ else num_query_heads
237
+ )
238
+ self.initializer_range = initializer_range
239
+
240
+ self.__post_init__()
241
+ super().__init__(
242
+ use_cache=use_cache,
243
+ bos_token_id=bos_token_id,
244
+ eos_token_id=eos_token_id,
245
+ **kwargs,
246
+ )
247
+
248
+ def __post_init__(self) -> None:
249
+ if self.num_gqa_groups is not None:
250
+ head_multiple_of = self.num_gqa_groups
251
+ else:
252
+ head_multiple_of = 2
253
+
254
+ if isinstance(self.qkv_multipliers, Number):
255
+ # All attention layers have the same latent dimensions, resulting in uniform allocation of parameters.
256
+ qkv_dim = make_divisible(
257
+ self.model_dim * self.qkv_multipliers,
258
+ divisor=self.head_dim * head_multiple_of,
259
+ )
260
+ query_dims = [int(qkv_dim)] * self.num_transformer_layers
261
+
262
+ elif (
263
+ isinstance(self.qkv_multipliers, (tuple, list))
264
+ and len(self.qkv_multipliers) == 2
265
+ ):
266
+ # Each attention layer have different latent dimensions assuming qkv_multipliers[0] != qkv_multipliers[1].
267
+ # This results in variable allocation of parameters in attention layer.
268
+ # This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
269
+ qkv_multipliers = [
270
+ round(v, 2)
271
+ for v in np.linspace(
272
+ self.qkv_multipliers[0],
273
+ self.qkv_multipliers[1],
274
+ num=self.num_transformer_layers,
275
+ dtype=float,
276
+ )
277
+ ]
278
+ # Make sure that scaled model dimension is divisible by scaled head dimension.
279
+ query_dims = [
280
+ int(
281
+ make_divisible(
282
+ self.model_dim * m, divisor=self.head_dim * head_multiple_of
283
+ )
284
+ )
285
+ for m in qkv_multipliers
286
+ ]
287
+ else:
288
+ raise NotImplementedError(
289
+ f"QKV multipliers should be a single number or a list containing exactly two numbers. Got: {self.qkv_multipliers}."
290
+ )
291
+
292
+ # compute the number of query, key, and value heads
293
+ # For multi-head and multi-query attention, the number of heads for query, key, and value are the same.
294
+ # For group query attention, the number of key and value heads are the same.
295
+ self.num_query_heads = [
296
+ int(compute_heads(q_dim, self.head_dim)) for q_dim in query_dims
297
+ ]
298
+ self.num_kv_heads = [
299
+ q_heads // self.num_gqa_groups for q_heads in self.num_query_heads
300
+ ]
301
+
302
+ # Feed-forward network (FFN) multipliers
303
+ if isinstance(self.ffn_multipliers, Number):
304
+ # All FFN layers have the same latent dimensions, resulting in uniform allocation of parameters.
305
+ self.ffn_multipliers = [self.ffn_multipliers] * self.num_transformer_layers
306
+ elif isinstance(self.ffn_multipliers, (tuple, list)):
307
+ # Each FFN layer have different latent dimensions assuming ffn_multipliers[0] != ffn_multipliers[1].
308
+ # This results in variable allocation of parameters in FFN layer.
309
+ # This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
310
+ if len(self.ffn_multipliers) == 2:
311
+ self.ffn_multipliers = [
312
+ round(v, 2)
313
+ for v in np.linspace(
314
+ self.ffn_multipliers[0],
315
+ self.ffn_multipliers[1],
316
+ num=self.num_transformer_layers,
317
+ dtype=float,
318
+ )
319
+ ]
320
+ else:
321
+ assert (
322
+ len(self.ffn_multipliers) == self.num_transformer_layers
323
+ ), f"{len(self.ffn_multipliers)=}!={self.num_transformer_layers=}"
324
+ else:
325
+ raise NotImplementedError(
326
+ f"FFN multipliers should be a single number or a list containing exactly two numbers. Got: {self.ffn_multipliers}."
327
+ )
328
+
329
+ # check num_query_heads divisible by num_kv_heads for every layer
330
+ for layer_idx in range(len(query_dims)):
331
+ assert self.num_query_heads[layer_idx] % self.num_kv_heads[layer_idx] == 0
332
+
333
+ class OpenELMRMSNorm(nn.Module):
334
+ def __init__(self, num_features: int, eps: float = 1e-6):
335
+ """
336
+ Initialize the OpenELMRMSNorm normalization layer.
337
+ Args:
338
+ dim (int): The dimension of the input tensor.
339
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
340
+ Attributes:
341
+ eps (float): A small value added to the denominator for numerical stability.
342
+ weight (nn.Parameter): Learnable scaling parameter.
343
+ """
344
+ super().__init__()
345
+ self.eps = eps
346
+ self.weight = nn.Parameter(torch.ones(num_features))
347
+ self.num_features = num_features
348
+
349
+ def _norm(self, x: Tensor) -> Tensor:
350
+ """
351
+ Apply the OpenELMRMSNorm normalization to the input tensor.
352
+ Args:
353
+ x (torch.Tensor): The input tensor.
354
+ Returns:
355
+ torch.Tensor: The normalized tensor.
356
+ """
357
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
358
+
359
+ def forward(self, x: Tensor) -> Tensor:
360
+ """
361
+ Forward pass through the OpenELMRMSNorm layer.
362
+ Args:
363
+ x (torch.Tensor): The input tensor.
364
+ Returns:
365
+ torch.Tensor: The output tensor after applying OpenELMRMSNorm.
366
+ """
367
+ output = self._norm(x.float()).type_as(x)
368
+ return output * self.weight
369
+
370
+ def extra_repr(self) -> str:
371
+ return (
372
+ super().extra_repr() + f"num_features={self.num_features}, eps={self.eps}"
373
+ )
374
+
375
+
376
+ class OpenELMPreTrainedModel(PreTrainedModel):
377
+ config_class = OpenELMConfig
378
+ base_model_prefix = "transformer"
379
+ supports_gradient_checkpointing = True
380
+ _no_split_modules = ["OpenELMDecoderLayer"]
381
+ _skip_keys_device_placement = "past_key_values"
382
+
383
+ def __init__(self, *inputs, **kwargs) -> None:
384
+ super().__init__(*inputs, **kwargs)
385
+
386
+ def _init_weights(self, module: nn.Module) -> None:
387
+ """Initialize the weights."""
388
+ if isinstance(module, nn.Linear):
389
+ # Slightly different from the TF version which uses truncated_normal for initialization
390
+ # cf https://github.com/pytorch/pytorch/pull/5617
391
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
392
+ if module.bias is not None:
393
+ module.bias.data.zero_()
394
+ elif isinstance(module, nn.Embedding):
395
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
396
+ if module.padding_idx is not None:
397
+ module.weight.data[module.padding_idx].zero_()
398
+ elif isinstance(module, OpenELMRMSNorm):
399
+ module.weight.data.fill_(1.0)
400
+
401
+
402
+ def _rotate_half(x: Tensor) -> Tensor:
403
+ x1, x2 = x.chunk(2, dim=-1)
404
+ return torch.cat((-x2, x1), dim=-1)
405
+
406
+
407
+ def _apply_rotary_pos_emb(x: Tensor, pos_sin: Tensor, pos_cos: Tensor) -> Tensor:
408
+ return (x * pos_cos) + (_rotate_half(x) * pos_sin)
409
+
410
+
411
+ class OpenELMRotaryEmbedding(torch.nn.Module):
412
+ """
413
+ The rotary position embeddings (aka RoPE) from `RoFormer <https://arxiv.org/abs/2104.09864>`_.
414
+ RoPE encodes the position information of tokens using a rotation matrix, and is able to capture
415
+ explicit relative positional dependencies.
416
+ Args:
417
+ model_dim: The dimensionality of the model's hidden state.
418
+ max_seq_length: Maximum sequence length.
419
+ freq_constant: A constant used for computing frequencies.
420
+ """
421
+
422
+ def __init__(
423
+ self, model_dim: int, max_seq_length: int, freq_constant: int = 10000
424
+ ) -> None:
425
+ inv_freq = 1.0 / (
426
+ freq_constant
427
+ ** (torch.arange(0, model_dim, 2, dtype=torch.float32) / model_dim)
428
+ )
429
+ super().__init__()
430
+
431
+ self.model_dim = model_dim
432
+ self.freq_constant = freq_constant
433
+ self.max_seq_length = max_seq_length
434
+
435
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
436
+ self._cached_cos = None
437
+ self._cached_sin = None
438
+ self._cached_seq_length = max_seq_length
439
+ self._compute_sin_cos_embeddings(max_seq_length)
440
+
441
+ def extra_repr(self) -> str:
442
+ return f"\tmodel_dim={self.model_dim}, max_seq_length={self.max_seq_length}, freq_constant={self.freq_constant}"
443
+
444
+ def _compute_sin_cos_embeddings(
445
+ self,
446
+ key_len: int,
447
+ key_device: torch.device = torch.device("cpu"),
448
+ key_dtype: torch.dtype = torch.float32,
449
+ ) -> None:
450
+ """
451
+ Compute sine and cos embeddings.
452
+ Args:
453
+ key_len: Number of tokens in the key embeddings in the transformer model.
454
+ device: Device where the key embeddings are stored.
455
+ key_dtype: Data type of the key embeddings.
456
+ Returns:
457
+ None
458
+ ...note:
459
+ We recalculate the sine and cosine embeddings if any of the following conditions are met:
460
+ 1. The number of tokens in key embeddings are greater than the cached sequence length.
461
+ 2. Sine and cosine caches are empty.
462
+ 3. The device and data type of sine and cosine embeddings does not match with the key embeddings.
463
+ """
464
+ if (
465
+ key_len > self._cached_seq_length
466
+ or self._cached_cos is None
467
+ or (self._cached_cos is not None and self._cached_cos.device != key_device)
468
+ or (self._cached_cos is not None and self._cached_cos.dtype != key_dtype)
469
+ or self._cached_sin is None
470
+ or (self._cached_sin is not None and self._cached_sin.device != key_device)
471
+ or (self._cached_sin is not None and self._cached_sin.dtype != key_dtype)
472
+ ):
473
+ self._cached_seq_length = max(key_len, self._cached_seq_length)
474
+
475
+ # The shape of 'pos_index' is [number of key tokens]
476
+ pos_index = torch.arange(
477
+ self._cached_seq_length,
478
+ dtype=torch.float32,
479
+ device=self.inv_freq.device,
480
+ )
481
+ # The shape of 'pos_index_theta' is [number of key tokens, model dimension]
482
+ pos_index_theta = torch.einsum("i,j->ij", pos_index, self.inv_freq)
483
+ # The shape of 'emb' is [number of key tokens, model dimension]
484
+ emb = torch.cat((pos_index_theta, pos_index_theta), dim=-1)
485
+
486
+ # the shape of cos and sin embeddings is [number of key tokens, model_dim]
487
+ cos_emb = emb.cos().to(dtype=key_dtype, device=key_device)
488
+ sin_emb = emb.sin().to(dtype=key_dtype, device=key_device)
489
+
490
+ # the shape of cached cos and sin embeddings is [1, 1, number of key tokens, model_dim]
491
+ self._cached_cos = cos_emb[None, None, :, :]
492
+ self._cached_sin = sin_emb[None, None, :, :]
493
+
494
+ def forward(
495
+ self,
496
+ query: torch.Tensor,
497
+ key: torch.Tensor,
498
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
499
+ """
500
+ The forward function of RoPE embeddings.
501
+ Args:
502
+ query: Query embeddings in the transformer model. The shape of query embeddings is
503
+ [Batch, number of query heads, number of query tokens, model dimension].
504
+ key: Key embeddings in the transformer model. The shape of key embeddings is
505
+ [Batch, number of key heads, number of key tokens, model dimension].
506
+ Returns:
507
+ A tuple containing the query and key embeddings with positional information. The shape of the returned query
508
+ and key embeddings is the same as the input query and key embeddings respectively.
509
+ .. note::
510
+ The RoPE embedding computation is done in full-precision. After the computation, input query and key tensors
511
+ are cast back to the original input datatype.
512
+ """
513
+ dim = key.shape[-1]
514
+ key_len = key.shape[2]
515
+ query_len = query.shape[2]
516
+
517
+ assert dim == self.model_dim
518
+ assert key.device == query.device
519
+ assert key.dtype == query.dtype
520
+
521
+ # In the context of self-attention, the lengths of keys and queries are equal.
522
+ # However, in generation tasks, such as predicting the next token in a sequence, the lengths of keys and queries
523
+ # can differ. For instance, when employing key-value (KV) caching for sequence prediction, the keys
524
+ # represent embeddings of previous tokens and the current token, while the query corresponds
525
+ # to the embedding of the current token only.
526
+ assert (
527
+ key_len >= query_len
528
+ ), "Number of keys has to be greater than or equal to number of queries."
529
+
530
+ query_float = query.float()
531
+ key_float = key.float()
532
+
533
+ self._compute_sin_cos_embeddings(
534
+ key_len, key_device=key_float.device, key_dtype=key_float.dtype
535
+ )
536
+ query_float = _apply_rotary_pos_emb(
537
+ x=query_float,
538
+ pos_sin=self._cached_sin[..., key_len - query_len : key_len, :],
539
+ pos_cos=self._cached_cos[..., key_len - query_len : key_len, :],
540
+ )
541
+ key_float = _apply_rotary_pos_emb(
542
+ x=key_float,
543
+ pos_sin=self._cached_sin[..., :key_len, :],
544
+ pos_cos=self._cached_cos[..., :key_len, :],
545
+ )
546
+
547
+ return query_float.type_as(query), key_float.type_as(key)
548
+
549
+
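# Illustrative sketch (not part of the committed module): `_apply_rotary_pos_emb`
# is defined earlier in modeling_elm.py and is not shown in this hunk. It is
# assumed to follow the standard "rotate half" formulation sketched below.
import torch

def _rotate_half_sketch(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension into two halves and swap them with a sign flip:
    # (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def _apply_rotary_pos_emb_sketch(x, pos_sin, pos_cos):
    # Element-wise rotation with the cached [1, 1, seq, head_dim] sin/cos tables
    # produced by _compute_sin_cos_embeddings above.
    return (x * pos_cos) + (_rotate_half_sketch(x) * pos_sin)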
550
+ class OpenELMMultiHeadCausalAttention(nn.Module):
551
+ def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
552
+ super().__init__()
553
+ self.layer_idx = layer_idx
554
+ head_dim = config.head_dim
555
+ q_heads = config.num_query_heads[layer_idx]
556
+ k_heads = config.num_kv_heads[layer_idx]
557
+ v_heads = config.num_kv_heads[layer_idx]
558
+
559
+ self.qkv_proj = nn.Linear(
560
+ in_features=config.model_dim,
561
+ out_features=(q_heads + k_heads + v_heads) * head_dim,
562
+ bias=False,
563
+ )
564
+
565
+ self.pos_embedding = OpenELMRotaryEmbedding(
566
+ model_dim=config.head_dim,
567
+ max_seq_length=config.rope_max_length,
568
+ freq_constant=config.rope_freq_constant,
569
+ )
570
+
571
+ if config.normalize_qk_projections:
572
+ self.q_norm = OpenELMRMSNorm(
573
+ num_features=config.head_dim,
574
+ )
575
+ self.k_norm = OpenELMRMSNorm(
576
+ num_features=config.head_dim,
577
+ )
578
+ else:
579
+ self.q_norm = None
580
+ self.k_norm = None
581
+
582
+ self.out_proj = nn.Linear(
583
+ in_features=q_heads * head_dim,
584
+ out_features=config.model_dim,
585
+ bias=False,
586
+ )
587
+
588
+ self.head_dim = config.head_dim
589
+ self.num_q_heads = q_heads
590
+ self.num_k_heads = k_heads
591
+ self.num_v_heads = v_heads
592
+ self.transformer_dim = config.model_dim
593
+ self.num_groups = self.num_q_heads // self.num_k_heads
594
+
595
+ def extra_repr(self) -> str:
596
+ return (
597
+ super().extra_repr()
598
+ + f"query_heads={self.num_q_heads}, key_heads={self.num_k_heads}, value_heads={self.num_v_heads}"
599
+ )
600
+
601
+ def forward(
602
+ self,
603
+ hidden_states: torch.Tensor,
604
+ attention_mask: Optional[torch.Tensor] = None,
605
+ past_key_value: Optional[Cache] = None,
606
+ output_attentions: bool = False,
607
+ use_cache: bool = False,
608
+ cache_position: Optional[torch.LongTensor] = None,
609
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
610
+ """
611
+ Forward pass of multi-head self-attention.
612
+ Args:
613
+ hidden_states: Input tensor of the shape [batch size, sequence length, model dimension].
614
+ past_key_value: Tensor storing the cached keys and values.
615
+ output_attentions: Whether to return the attention weights.
616
+ use_cache: Specifies whether to use kv-cache for generation.
617
+ cache_position: Token positions used for updating the kv-cache.
618
+ Returns:
619
+ The output of the same shape as the input, optionally with a tensor containing cached keys and values.
620
+ """
621
+
622
+ # scaled_dot_product_attention does not return attention weights, set output_attentions to False
623
+ output_attentions = False
624
+ batch_size, seq_length, d_model = hidden_states.size()
625
+
626
+ # [B, S, d] --> [B, S, (q_h + k_h + v_h) * h]
627
+ qkv = self.qkv_proj(hidden_states)
628
+ # [B, S, (q_h + k_h + v_h) * h] --> [B, S, (q_h + k_h + v_h), h]
629
+ qkv = qkv.reshape(
630
+ batch_size,
631
+ seq_length,
632
+ self.num_q_heads + self.num_k_heads + self.num_v_heads,
633
+ self.head_dim,
634
+ )
635
+ # [B, S, (q_h + k_h + v_h), h] --> [B, (q_h + k_h + v_h), S, h]
636
+ qkv = qkv.transpose(1, 2)
637
+ # [B, (q_h + k_h + v_h), S, h] --> [B, q_h, S, h], [B, k_h, S, h], [B, v_h, S, h]
638
+ queries, keys, values = qkv.split(
639
+ [self.num_q_heads, self.num_k_heads, self.num_v_heads], dim=1
640
+ )
641
+
642
+ if self.q_norm is not None:
643
+ queries = self.q_norm(queries)
644
+
645
+ if self.k_norm is not None:
646
+ keys = self.k_norm(keys)
647
+
648
+ past_key_value = getattr(self, "past_key_value", past_key_value)
649
+
650
+ if past_key_value is not None:
651
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
652
+ # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
653
+ cache_kwargs = {"cache_position": cache_position}
654
+ keys, values = past_key_value.update(
655
+ keys, values, self.layer_idx, cache_kwargs
656
+ )
657
+
658
+ # Add positional embedding
659
+ queries, keys = self.pos_embedding(queries, keys)
660
+
661
+ if self.num_groups != 1:
662
+ # GQA
663
+ # [B, k_h, S, h] --> [B, q_h, S, h]
664
+ keys = keys.repeat_interleave(self.num_groups, dim=1)
665
+ # [B, v_h, S, h] --> [B, q_h, S, h]
666
+ values = values.repeat_interleave(self.num_groups, dim=1)
667
+
668
+ causal_mask = attention_mask
669
+ if attention_mask is not None and cache_position is not None:
670
+ causal_mask = causal_mask[:, :, cache_position, : keys.shape[-2]]
671
+
672
+ attn_output = F.scaled_dot_product_attention(
673
+ queries,
674
+ keys,
675
+ values,
676
+ attn_mask=causal_mask,
677
+ dropout_p=0,
678
+ )
679
+
680
+ attn_output = attn_output.transpose(1, 2).contiguous()
681
+ attn_output = attn_output.reshape(
682
+ batch_size, seq_length, self.num_q_heads * self.head_dim
683
+ )
684
+ attn_output = self.out_proj(attn_output)
685
+ if not output_attentions:
686
+ attn_weights = None
687
+ return attn_output, attn_weights, past_key_value
688
+
689
+
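# Illustrative sketch (not part of the committed module): grouped-query
# attention in the forward pass above expands keys/values so that each of the
# `num_groups = num_q_heads // num_k_heads` query heads sharing a KV head sees
# its own copy. Head counts below are hypothetical.
import torch

B, S, h = 2, 8, 64           # batch, sequence length, head_dim
q_heads, kv_heads = 12, 3    # num_groups = 4
keys = torch.randn(B, kv_heads, S, h)
keys = keys.repeat_interleave(q_heads // kv_heads, dim=1)
assert keys.shape == (B, q_heads, S, h)   # now aligned with the query heads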
690
+ class OpenELMFeedForwardNetwork(nn.Module):
691
+ def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
692
+ super().__init__()
693
+ ffn_multiplier = config.ffn_multipliers[layer_idx]
694
+ intermediate_dim = int(
695
+ make_divisible(
696
+ ffn_multiplier * config.model_dim,
697
+ divisor=config.ffn_dim_divisor,
698
+ )
699
+ )
700
+ if config.ffn_with_glu:
701
+ # FFN with Gated linear unit, as described in https://arxiv.org/abs/2002.05202v1.
702
+ self.proj_1 = nn.Linear(
703
+ in_features=config.model_dim,
704
+ out_features=2 * intermediate_dim,
705
+ bias=False,
706
+ )
707
+ self.proj_2 = nn.Linear(
708
+ in_features=intermediate_dim,
709
+ out_features=config.model_dim,
710
+ bias=False,
711
+ )
712
+ self.ffn_with_glu = True
713
+ else:
714
+ # Standard FFN, as described in https://arxiv.org/abs/1706.03762
715
+ self.proj_1 = nn.Linear(
716
+ in_features=config.model_dim,
717
+ out_features=intermediate_dim,
718
+ bias=False,
719
+ )
720
+ self.proj_2 = nn.Linear(
721
+ in_features=intermediate_dim,
722
+ out_features=config.model_dim,
723
+ bias=False,
724
+ )
725
+ self.ffn_with_glu = False
726
+
727
+ self.act = ACT2FN[config.activation_fn_name]
728
+
729
+ def extra_repr(self) -> str:
730
+ return super().extra_repr() + f"(ffn_with_glu) : {self.ffn_with_glu}"
731
+
732
+ def forward(self, x: Tensor) -> Tensor:
733
+ """Forward function of FFN layer.
734
+ Args:
735
+ x: Input tensor of the shape [batch size, sequence length, model dimension].
736
+ Returns:
737
+ A tensor of the same shape as the input.
738
+ """
739
+ if self.ffn_with_glu:
740
+ y_12 = self.proj_1(x)
741
+ y_1, y_2 = y_12.chunk(2, dim=-1)
742
+ y = self.act(y_1) * y_2
743
+ return self.proj_2(y)
744
+ else:
745
+ return self.proj_2(self.act(self.proj_1(x)))
746
+
747
+
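# Illustrative sketch (not part of the committed module): how the GLU branch
# above is wired, with hypothetical sizes. `make_divisible` (defined earlier in
# the file) is assumed to round 0.5 * 1536 = 768 to a multiple of
# ffn_dim_divisor, giving intermediate_dim = 768.
import torch
import torch.nn as nn
import torch.nn.functional as F

model_dim, intermediate_dim = 1536, 768
proj_1 = nn.Linear(model_dim, 2 * intermediate_dim, bias=False)
proj_2 = nn.Linear(intermediate_dim, model_dim, bias=False)
x = torch.randn(1, 4, model_dim)
y_1, y_2 = proj_1(x).chunk(2, dim=-1)    # gate / value split
out = proj_2(F.silu(y_1) * y_2)          # "swish" activation == SiLU
assert out.shape == x.shape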
748
+ class OpenELMDecoderLayer(nn.Module):
749
+ def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
750
+ super().__init__()
751
+ self.attn = OpenELMMultiHeadCausalAttention(config=config, layer_idx=layer_idx)
752
+ self.ffn = OpenELMFeedForwardNetwork(config=config, layer_idx=layer_idx)
753
+ self.ffn_norm = OpenELMRMSNorm(
754
+ num_features=config.model_dim,
755
+ )
756
+ self.attn_norm = OpenELMRMSNorm(
757
+ num_features=config.model_dim,
758
+ )
759
+
760
+ def forward(
761
+ self,
762
+ hidden_states: torch.Tensor,
763
+ attention_mask: Optional[torch.Tensor] = None,
764
+ position_ids: Optional[torch.LongTensor] = None,
765
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
766
+ output_attentions: Optional[bool] = False,
767
+ use_cache: Optional[bool] = False,
768
+ cache_position: Optional[torch.LongTensor] = None,
769
+ **kwargs,
770
+ ) -> Tuple[
771
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
772
+ ]:
773
+ """
774
+ Args:
775
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
776
+ attention_mask (`torch.FloatTensor`, *optional*):
777
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
778
+ query_sequence_length, key_sequence_length)` if default attention is used.
779
+ output_attentions (`bool`, *optional*):
780
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
781
+ returned tensors for more detail.
782
+ use_cache (`bool`, *optional*):
783
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
784
+ (see `past_key_values`).
785
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
786
+ """
787
+ residual = hidden_states
788
+ hidden_states = self.attn_norm(hidden_states)
789
+
790
+ # Self Attention
791
+ hidden_states, self_attn_weights, present_key_value = self.attn(
792
+ hidden_states=hidden_states,
793
+ attention_mask=attention_mask,
794
+ past_key_value=past_key_value,
795
+ output_attentions=output_attentions,
796
+ use_cache=use_cache,
797
+ cache_position=cache_position,
798
+ **kwargs,
799
+ )
800
+ hidden_states = residual + hidden_states
801
+
802
+ # Fully Connected
803
+ residual = hidden_states
804
+ hidden_states = self.ffn_norm(hidden_states)
805
+ hidden_states = self.ffn(hidden_states)
806
+ hidden_states = residual + hidden_states
807
+
808
+ outputs = (hidden_states,)
809
+
810
+ if output_attentions:
811
+ outputs += (self_attn_weights,)
812
+
813
+ if use_cache:
814
+ outputs += (present_key_value,)
815
+
816
+ return outputs
817
+
818
+
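# Illustrative sketch (not part of the committed module): the decoder layer
# above is a standard pre-norm block. In pseudo-form:
#
#   h = h + attn(attn_norm(h))   # self-attention sub-block
#   h = h + ffn(ffn_norm(h))     # feed-forward sub-block
#
# RMSNorm is applied to the input of each sub-block, and the residual is added
# to the un-normalized stream.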
819
+ class OpenELMModel(OpenELMPreTrainedModel):
820
+ config_class = OpenELMConfig
821
+
822
+ def __init__(self, config: OpenELMConfig):
823
+ super().__init__(config)
824
+ self.config = config
825
+
826
+ self.token_embeddings = nn.Embedding(
827
+ embedding_dim=config.model_dim,
828
+ num_embeddings=config.vocab_size,
829
+ )
830
+
831
+ self.layers = nn.ModuleList(
832
+ OpenELMDecoderLayer(config=config, layer_idx=layer_idx)
833
+ for layer_idx in range(config.num_transformer_layers)
834
+ )
835
+ self.norm = OpenELMRMSNorm(num_features=config.model_dim)
836
+ if config.share_input_output_layers:
837
+ self.classifier = None
838
+ else:
839
+ self.classifier = nn.Linear(
840
+ in_features=config.model_dim,
841
+ out_features=config.vocab_size,
842
+ bias=False,
843
+ )
844
+ self.num_transformer_layers = config.num_transformer_layers
845
+ self.gradient_checkpointing = False
846
+
847
+ # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
848
+ # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_context_length`.
849
+ causal_mask = torch.full(
850
+ (config.max_context_length, config.max_context_length),
851
+ fill_value=True,
852
+ dtype=torch.bool,
853
+ )
854
+ self.register_buffer(
855
+ "causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False
856
+ )
857
+
858
+ # Initialize weights and apply final processing
859
+ self.post_init()
860
+ self.reset_parameters(config=config)
861
+
862
+ def get_input_embeddings(self):
863
+ return self.token_embeddings
864
+
865
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
866
+ self.token_embeddings = new_embeddings
867
+
868
+ def reset_parameters(self, config: OpenELMConfig) -> None:
869
+ """Initialize the layers in Language Model
870
+ The initialization scheme follows `OPT <https://arxiv.org/pdf/2205.01068.pdf>`_.
871
+ Args:
872
+ config: Model configuration; output projections use the depth-scaled standard deviation described in Megatron-LM.
873
+ Returns:
874
+ None
875
+ """
876
+ for module in self.modules():
877
+ if isinstance(module, nn.Linear):
878
+ std = module.in_features**-0.5
879
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
880
+ if module.bias is not None:
881
+ torch.nn.init.zeros_(module.bias)
882
+ elif isinstance(module, nn.Embedding):
883
+ std = module.embedding_dim**-0.5
884
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
885
+ elif isinstance(module, OpenELMRMSNorm):
886
+ if module.weight is not None:
887
+ torch.nn.init.ones_(module.weight)
888
+ if hasattr(module, "bias") and module.bias is not None:
889
+ torch.nn.init.zeros_(module.bias)
890
+
891
+ model_dim = config.model_dim
892
+ n_layers = config.num_transformer_layers
893
+ std = (model_dim**-0.5) * ((2 * n_layers) ** -0.5)
894
+ for param_name, param in self.named_parameters():
895
+ if param_name.endswith("out_proj.weight") or param_name.endswith(
896
+ "ffn.proj_2.weight"
897
+ ):
898
+ torch.nn.init.normal_(param, mean=0.0, std=std)
899
+
900
+ def forward(
901
+ self,
902
+ input_ids: torch.LongTensor = None,
903
+ attention_mask: Optional[torch.Tensor] = None,
904
+ position_ids: Optional[torch.LongTensor] = None,
905
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
906
+ inputs_embeds: Optional[torch.FloatTensor] = None,
907
+ use_cache: Optional[bool] = None,
908
+ output_attentions: Optional[bool] = None,
909
+ output_hidden_states: Optional[bool] = None,
910
+ return_dict: Optional[bool] = None,
911
+ cache_position: Optional[torch.LongTensor] = None,
912
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
913
+ output_attentions = (
914
+ output_attentions
915
+ if output_attentions is not None
916
+ else self.config.output_attentions
917
+ )
918
+ output_hidden_states = (
919
+ output_hidden_states
920
+ if output_hidden_states is not None
921
+ else self.config.output_hidden_states
922
+ )
923
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
924
+ return_dict = (
925
+ return_dict if return_dict is not None else self.config.use_return_dict
926
+ )
927
+
928
+ if (input_ids is None) ^ (inputs_embeds is not None):
929
+ raise ValueError(
930
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
931
+ )
932
+
933
+ if self.gradient_checkpointing and self.training and use_cache:
934
+ logger.warning_once(
935
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
936
+ )
937
+ use_cache = False
938
+
939
+ if inputs_embeds is None:
940
+ inputs_embeds = self.token_embeddings(input_ids)
941
+
942
+ past_seen_tokens = 0
943
+ if use_cache: # kept for BC (cache positions)
944
+ if not isinstance(past_key_values, StaticCache):
945
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
946
+ past_seen_tokens = past_key_values.get_seq_length()
947
+
948
+ if cache_position is None:
949
+ cache_position = torch.arange(
950
+ past_seen_tokens,
951
+ past_seen_tokens + inputs_embeds.shape[1],
952
+ device=inputs_embeds.device,
953
+ )
954
+
955
+ if position_ids is None:
956
+ position_ids = cache_position.unsqueeze(0)
957
+
958
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
959
+
960
+ # embed positions
961
+ hidden_states = inputs_embeds
962
+
963
+ # decoder layers
964
+ all_hidden_states = () if output_hidden_states else None
965
+ all_self_attns = () if output_attentions else None
966
+ next_decoder_cache = None
967
+
968
+ for decoder_layer in self.layers:
969
+ if output_hidden_states:
970
+ all_hidden_states += (hidden_states,)
971
+
972
+ if self.gradient_checkpointing and self.training:
973
+ layer_outputs = self._gradient_checkpointing_func(
974
+ decoder_layer.__call__,
975
+ hidden_states,
976
+ causal_mask,
977
+ position_ids,
978
+ past_key_values,
979
+ output_attentions,
980
+ use_cache,
981
+ cache_position,
982
+ )
983
+ else:
984
+ layer_outputs = decoder_layer(
985
+ hidden_states,
986
+ attention_mask=causal_mask,
987
+ position_ids=position_ids,
988
+ past_key_value=past_key_values,
989
+ output_attentions=output_attentions,
990
+ use_cache=use_cache,
991
+ cache_position=cache_position,
992
+ )
993
+
994
+ hidden_states = layer_outputs[0]
995
+
996
+ if use_cache:
997
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
998
+
999
+ if output_attentions:
1000
+ all_self_attns += (layer_outputs[1],)
1001
+
1002
+ hidden_states = self.norm(hidden_states)
1003
+
1004
+ # add hidden states from the last decoder layer
1005
+ if output_hidden_states:
1006
+ all_hidden_states += (hidden_states,)
1007
+
1008
+ next_cache = None
1009
+ if use_cache:
1010
+ next_cache = (
1011
+ next_decoder_cache.to_legacy_cache()
1012
+ if isinstance(next_decoder_cache, Cache)
1013
+ else next_decoder_cache
1014
+ )
1015
+ if not return_dict:
1016
+ return tuple(
1017
+ v
1018
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1019
+ if v is not None
1020
+ )
1021
+ return BaseModelOutputWithPast(
1022
+ last_hidden_state=hidden_states,
1023
+ past_key_values=next_cache,
1024
+ hidden_states=all_hidden_states,
1025
+ attentions=all_self_attns,
1026
+ )
1027
+
1028
+ def _update_causal_mask(self, attention_mask, input_tensor):
1029
+ if self.config._attn_implementation == "flash_attention_2":
1030
+ if attention_mask is not None and 0.0 in attention_mask:
1031
+ return attention_mask
1032
+ return None
1033
+
1034
+ batch_size, seq_length = input_tensor.shape[:2]
1035
+ dtype = input_tensor.dtype
1036
+ device = input_tensor.device
1037
+
1038
+ # support going beyond cached `max_position_embedding`
1039
+ if seq_length > self.causal_mask.shape[-1]:
1040
+ causal_mask = torch.full(
1041
+ (2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]),
1042
+ fill_value=1,
1043
+ )
1044
+ self.register_buffer(
1045
+ "causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False
1046
+ )
1047
+
1048
+ # We use the current dtype to avoid any overflows
1049
+ min_dtype = torch.finfo(dtype).min
1050
+ causal_mask = (
1051
+ self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype)
1052
+ * min_dtype
1053
+ )
1054
+
1055
+ causal_mask = causal_mask.to(dtype=dtype, device=device)
1056
+ if attention_mask is not None and attention_mask.dim() == 2:
1057
+ mask_length = attention_mask.shape[-1]
1058
+ padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[
1059
+ :, None, None, :
1060
+ ].eq(0.0)
1061
+ causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
1062
+ padding_mask, min_dtype
1063
+ )
1064
+
1065
+ if self.config._attn_implementation == "sdpa" and attention_mask is not None:
1066
+ # For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
1067
+ is_tracing = (
1068
+ torch.jit.is_tracing()
1069
+ or isinstance(input_tensor, torch.fx.Proxy)
1070
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
1071
+ )
1072
+ if not is_tracing and torch.any(attention_mask != 1):
1073
+ # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
1074
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1075
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1076
+ causal_mask = causal_mask.mul(
1077
+ ~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)
1078
+ ).to(dtype)
1079
+
1080
+ return causal_mask
1081
+
1082
+
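# Illustrative sketch (not part of the committed module): how the boolean
# `causal_mask` buffer registered above becomes the additive mask passed to
# scaled_dot_product_attention (tiny context length for illustration).
import torch

ctx = 4
bool_mask = torch.triu(torch.full((ctx, ctx), True, dtype=torch.bool), diagonal=1)
dtype = torch.float16
additive = bool_mask[None, None, :, :].to(dtype) * torch.finfo(dtype).min
# Row i now holds 0 for positions <= i and a large negative value for future
# positions, so softmax(QK^T / sqrt(d) + additive) only attends to the past.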
1083
+ class OpenELMForCausalLM(OpenELMPreTrainedModel):
1084
+ _tied_weights_keys = ["lm_head.weight"]
1085
+
1086
+ def __init__(self, config: OpenELMConfig):
1087
+ super().__init__(config)
1088
+ self.transformer = OpenELMModel(config)
1089
+ self.vocab_size = config.vocab_size
1090
+ if config.share_input_output_layers:
1091
+ self.lm_head = None
1092
+ else:
1093
+ self.lm_head = nn.Linear(config.model_dim, config.vocab_size, bias=False)
1094
+
1095
+ # Initialize weights and apply final processing
1096
+ self.post_init()
1097
+
1098
+ def get_input_embeddings(self):
1099
+ return self.transformer.token_embeddings
1100
+
1101
+ def set_input_embeddings(self, value):
1102
+ self.transformer.token_embeddings = value
1103
+
1104
+ def get_output_embeddings(self):
1105
+ return self.lm_head
1106
+
1107
+ def set_output_embeddings(self, new_embeddings):
1108
+ self.lm_head = new_embeddings
1109
+
1110
+ def set_decoder(self, decoder):
1111
+ self.transformer = decoder
1112
+
1113
+ def get_decoder(self):
1114
+ return self.transformer
1115
+
1116
+ def forward(
1117
+ self,
1118
+ input_ids: torch.LongTensor = None,
1119
+ attention_mask: Optional[torch.Tensor] = None,
1120
+ position_ids: Optional[torch.LongTensor] = None,
1121
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1122
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1123
+ labels: Optional[torch.LongTensor] = None,
1124
+ use_cache: Optional[bool] = None,
1125
+ output_attentions: Optional[bool] = None,
1126
+ output_hidden_states: Optional[bool] = None,
1127
+ return_dict: Optional[bool] = None,
1128
+ cache_position: Optional[torch.LongTensor] = None,
1129
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1130
+ output_attentions = (
1131
+ output_attentions
1132
+ if output_attentions is not None
1133
+ else self.config.output_attentions
1134
+ )
1135
+ output_hidden_states = (
1136
+ output_hidden_states
1137
+ if output_hidden_states is not None
1138
+ else self.config.output_hidden_states
1139
+ )
1140
+ return_dict = (
1141
+ return_dict if return_dict is not None else self.config.use_return_dict
1142
+ )
1143
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1144
+ outputs = self.transformer(
1145
+ input_ids=input_ids,
1146
+ attention_mask=attention_mask,
1147
+ position_ids=position_ids,
1148
+ past_key_values=past_key_values,
1149
+ inputs_embeds=inputs_embeds,
1150
+ use_cache=use_cache,
1151
+ output_attentions=output_attentions,
1152
+ output_hidden_states=output_hidden_states,
1153
+ return_dict=return_dict,
1154
+ cache_position=cache_position,
1155
+ )
1156
+
1157
+ hidden_states = outputs[0]
1158
+ if self.lm_head is None:
1159
+ # shared
1160
+ logits = F.linear(
1161
+ hidden_states, weight=self.transformer.token_embeddings.weight
1162
+ )
1163
+ else:
1164
+ logits = self.lm_head(hidden_states)
1165
+ logits = logits[:, : self.config.vocab_size]
1166
+ loss = None
1167
+ if labels is not None:
1168
+ # Shift so that tokens < n predict n
1169
+ shift_logits = logits[..., :-1, :].contiguous()
1170
+ shift_labels = labels[..., 1:].contiguous()
1171
+ # Flatten the tokens
1172
+ loss_fct = CrossEntropyLoss()
1173
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1174
+ shift_labels = shift_labels.view(-1)
1175
+ # Enable model parallelism
1176
+ shift_labels = shift_labels.to(shift_logits.device)
1177
+ loss = loss_fct(shift_logits, shift_labels)
1178
+
1179
+ if not return_dict:
1180
+ output = (logits,) + outputs[1:]
1181
+ return (loss,) + output if loss is not None else output
1182
+
1183
+ return CausalLMOutputWithPast(
1184
+ loss=loss,
1185
+ logits=logits,
1186
+ past_key_values=outputs.past_key_values,
1187
+ hidden_states=outputs.hidden_states,
1188
+ attentions=outputs.attentions,
1189
+ )
1190
+
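# Illustrative sketch (not part of the committed module): the label shift used
# in the loss above trains position t to predict token t + 1, so logits drop
# their last step and labels drop their first one. Toy sizes only.
import torch
import torch.nn.functional as F

vocab, B, S = 32, 1, 5
logits = torch.randn(B, S, vocab)
labels = torch.randint(0, vocab, (B, S))
shift_logits = logits[..., :-1, :].reshape(-1, vocab)   # steps 0 .. S-2
shift_labels = labels[..., 1:].reshape(-1)              # tokens 1 .. S-1
loss = F.cross_entropy(shift_logits, shift_labels)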
1191
+ def prepare_inputs_for_generation(
1192
+ self,
1193
+ input_ids,
1194
+ past_key_values=None,
1195
+ attention_mask=None,
1196
+ inputs_embeds=None,
1197
+ **kwargs,
1198
+ ):
1199
+ past_length = 0
1200
+ if past_key_values is not None:
1201
+ if isinstance(past_key_values, Cache):
1202
+ cache_length = past_key_values.get_seq_length()
1203
+ past_length = past_key_values.seen_tokens
1204
+ max_cache_length = past_key_values.get_max_length()
1205
+ else:
1206
+ cache_length = past_length = past_key_values[0][0].shape[2]
1207
+ max_cache_length = None
1208
+
1209
+ # Keep only the unprocessed tokens:
1210
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1211
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1212
+ # input)
1213
+ if (
1214
+ attention_mask is not None
1215
+ and attention_mask.shape[1] > input_ids.shape[1]
1216
+ ):
1217
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1218
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1219
+ # input_ids based on the past_length.
1220
+ elif past_length < input_ids.shape[1]:
1221
+ input_ids = input_ids[:, past_length:]
1222
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1223
+
1224
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1225
+ if (
1226
+ max_cache_length is not None
1227
+ and attention_mask is not None
1228
+ and cache_length + input_ids.shape[1] > max_cache_length
1229
+ ):
1230
+ attention_mask = attention_mask[:, -max_cache_length:]
1231
+
1232
+ position_ids = kwargs.get("position_ids", None)
1233
+ if attention_mask is not None and position_ids is None:
1234
+ # create position_ids on the fly for batch generation
1235
+ position_ids = attention_mask.long().cumsum(-1) - 1
1236
+ position_ids.masked_fill_(attention_mask == 0, 1)
1237
+ if past_key_values:
1238
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1239
+
1240
+ if self.generation_config.cache_implementation == "static":
1241
+ # generation with static cache
1242
+ cache_position = kwargs.get("cache_position", None)
1243
+ if cache_position is None:
1244
+ past_length = 0
1245
+ else:
1246
+ past_length = cache_position[-1] + 1
1247
+ input_ids = input_ids[:, past_length:]
1248
+ position_ids = position_ids[:, past_length:]
1249
+
1250
+ # we should only keep a `cache_position` in generate, and do +=1.
1251
+ # same goes for position ids. Could also help with continued generation.
1252
+ cache_position = torch.arange(
1253
+ past_length,
1254
+ past_length + position_ids.shape[-1],
1255
+ device=position_ids.device,
1256
+ )
1257
+
1258
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1259
+ if inputs_embeds is not None and past_key_values is None:
1260
+ model_inputs = {"inputs_embeds": inputs_embeds}
1261
+ else:
1262
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
1263
+ # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
1264
+ # We could use `next_tokens` directly instead.
1265
+ model_inputs = {"input_ids": input_ids.contiguous()}
1266
+
1267
+ model_inputs.update(
1268
+ {
1269
+ "position_ids": position_ids.contiguous(),
1270
+ "cache_position": cache_position,
1271
+ "past_key_values": past_key_values,
1272
+ "use_cache": kwargs.get("use_cache"),
1273
+ "attention_mask": attention_mask,
1274
+ }
1275
+ )
1276
+ return model_inputs
1277
+
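# Illustrative sketch (not part of the committed module): what
# prepare_inputs_for_generation does during incremental decoding, with
# hypothetical numbers. If 7 tokens are already cached (past_length == 7) and
# `input_ids` holds 8 tokens, only the single unprocessed token
# input_ids[:, 7:] is fed to the model, together with position_ids /
# cache_position of [7]. On the first step (no past_key_values) the full
# prompt is passed instead.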
1278
+ @staticmethod
1279
+ def _reorder_cache(past_key_values, beam_idx):
1280
+ reordered_past = ()
1281
+ for layer_past in past_key_values:
1282
+ reordered_past += (
1283
+ tuple(
1284
+ past_state.index_select(0, beam_idx.to(past_state.device))
1285
+ for past_state in layer_past
1286
+ ),
1287
+ )
1288
+ return reordered_past
modeling_tinyllava_elm.py ADDED
@@ -0,0 +1,413 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional, Tuple, Union
3
+ import ast
4
+ import re
5
+
6
+ import torch
7
+ import torch.utils.checkpoint
8
+ from torch import nn
9
+
10
+ from transformers import PreTrainedModel
11
+ from transformers.modeling_outputs import CausalLMOutputWithPast
12
+ from transformers.generation.utils import GenerateOutput
13
+ from transformers import CLIPVisionModel, CLIPImageProcessor, SiglipVisionModel, SiglipImageProcessor
14
+
15
+ from configuration import TinyLlavaConfig
16
+ from utils import *
17
+ from modeling_elm import OpenELMForCausalLM
18
+
19
+ # from tinyllava.utils.data_utils import get_value_from_kwargs
20
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
21
+ WORKER_HEART_BEAT_INTERVAL = 15
22
+
23
+ LOGDIR = "."
24
+ import os
25
+
26
+ ACT_TYPE = {
27
+ 'relu': nn.ReLU,
28
+ 'gelu': nn.GELU
29
+ }
30
+
31
+ class Connector(nn.Module):
32
+ def __init__(self, config=None):
33
+ super().__init__()
34
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', config.connector_type)
35
+ act_type = config.connector_type.split('_')[-1]
36
+ mlp_depth = int(mlp_gelu_match.group(1))
37
+ modules = [nn.Linear(config.vision_hidden_size, config.hidden_size)]
38
+ for _ in range(1, mlp_depth):
39
+ modules.append(ACT_TYPE[act_type]())
40
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
41
+
42
+ self._connector = nn.Sequential(*modules)
43
+
44
+ def forward(self, x):
45
+ return self._connector(x)
46
+
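# Illustrative sketch (not part of the committed module): for a connector_type
# of "mlp2x_gelu", the regex above yields mlp_depth = 2 and act_type = "gelu",
# so the connector expands to:
#
#   nn.Sequential(
#       nn.Linear(config.vision_hidden_size, config.hidden_size),
#       nn.GELU(),
#       nn.Linear(config.hidden_size, config.hidden_size),
#   )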
47
+ class VisionTower(nn.Module):
48
+ def __init__(self, cfg, model_name_or_path = 'clip'):
49
+ super().__init__()
50
+ if 'clip' in model_name_or_path:
51
+ self._vision_tower = CLIPVisionModel(cfg)
52
+ self._image_processor = CLIPImageProcessor.from_pretrained(cfg.model_name_or_path)
53
+ else:
54
+ self._vision_tower = SiglipVisionModel(cfg)
55
+ self._image_processor = SiglipImageProcessor.from_pretrained(cfg.model_name_or_path)
56
+
57
+ self.config = cfg
58
+
59
+
60
+
61
+ def forward(self, x, **kwargs):
62
+ image_features = self._vision_tower(x, output_hidden_states=True)
63
+ image_features = image_features.hidden_states[kwargs.get('vision_feature_layer', -2)]
64
+
65
+ if kwargs.get('vision_feature_select_strategy', 'patch') == 'patch':
66
+ image_features = image_features[:, 1:]
67
+ elif kwargs.get('vision_feature_select_strategy', 'patch') == 'cls_patch':
68
+ image_features = image_features
69
+ else:
70
+ raise ValueError(f"Unexpected select feature: {kwargs.get('vision_feature_select_strategy')}")
71
+
72
+ return image_features
73
+
74
+
75
+
76
+ @property
77
+ def vision_tower(self):
78
+ return self._vision_tower
79
+
80
+ @vision_tower.setter
81
+ def vision_tower(self, vision_tower):
82
+ self._vision_tower = vision_tower
83
+
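# Illustrative sketch (not part of the committed module): with the defaults
# used in forward() above (vision_feature_layer=-2, strategy='patch'), the
# tower effectively computes
#
#   feats = self._vision_tower(x, output_hidden_states=True).hidden_states[-2]
#   feats = feats[:, 1:]    # drop the first token (CLS position for CLIP-style towers)
#
# leaving a tensor of shape [batch, num_patches, vision_hidden_size] for the connector.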
84
+ def get_value_from_kwargs(kwargs, name):
85
+ if name in kwargs:
86
+ return kwargs.pop(name)
87
+ else:
88
+ return None
89
+
90
+
91
+
92
+ class TinyLlavaPreTrainedModel(PreTrainedModel):
93
+ config_class = TinyLlavaConfig
94
+ base_model_prefix = "model"
95
+ supports_gradient_checkpointing = True
96
+ _no_split_modules = ["LlavaVisionAttention"]
97
+ _skip_keys_device_placement = "past_key_values"
98
+ _supports_flash_attn_2 = True
99
+
100
+ def _init_weights(self, module):
101
+ std = (
102
+ self.config.initializer_range
103
+ if hasattr(self.config, "initializer_range")
104
+ else self.config.text_config.initializer_range
105
+ )
106
+
107
+ if hasattr(module, "class_embedding"):
108
+ module.class_embedding.data.normal_(mean=0.0, std=std)
109
+
110
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
111
+ module.weight.data.normal_(mean=0.0, std=std)
112
+ if module.bias is not None:
113
+ module.bias.data.zero_()
114
+ elif isinstance(module, nn.Embedding):
115
+ module.weight.data.normal_(mean=0.0, std=std)
116
+ if module.padding_idx is not None:
117
+ module.weight.data[module.padding_idx].zero_()
118
+
119
+ @property
120
+ def _supports_sdpa(self):
121
+ return self.language_model._supports_sdpa
122
+
123
+
124
+ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
125
+ def __init__(self, config: TinyLlavaConfig):
126
+
127
+ super().__init__(config)
128
+
129
+ self.language_model = OpenELMForCausalLM(config.text_config)
130
+ self.vision_tower = VisionTower(config.vision_config, config.vision_model_name_or_path)
131
+ self.connector = Connector(config)
132
+ self.post_init()
133
+
134
+
135
+ def get_input_embeddings(self):
136
+ return self.language_model.get_input_embeddings()
137
+
138
+ def set_input_embeddings(self, value):
139
+ self.language_model.set_input_embeddings(value)
140
+
141
+ def get_output_embeddings(self):
142
+ return self.language_model.get_output_embeddings()
143
+
144
+ def set_output_embeddings(self, new_embeddings):
145
+ self.language_model.set_output_embeddings(new_embeddings)
146
+
147
+ def set_decoder(self, decoder):
148
+ self.language_model.set_decoder(decoder)
149
+
150
+ def get_decoder(self):
151
+ return self.language_model.get_decoder()
152
+
153
+ def tie_weights(self):
154
+ return self.language_model.tie_weights()
155
+
156
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
157
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
158
+ # update vocab size
159
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
160
+ self.config.vocab_size = model_embeds.num_embeddings
161
+ self.vocab_size = model_embeds.num_embeddings
162
+ return model_embeds
163
+
164
+
165
+ def forward(
166
+ self,
167
+ input_ids: torch.LongTensor = None,
168
+ attention_mask: Optional[torch.Tensor] = None,
169
+ position_ids: Optional[torch.LongTensor] = None,
170
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
171
+ inputs_embeds: Optional[torch.FloatTensor] = None,
172
+ labels: Optional[torch.LongTensor] = None,
173
+ use_cache: Optional[bool] = None,
174
+ output_attentions: Optional[bool] = None,
175
+ output_hidden_states: Optional[bool] = None,
176
+ images: Optional[torch.FloatTensor] = None,
177
+ image_sizes: Optional[List[List[int]]] = None,
178
+ return_dict: Optional[bool] = None,
179
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
180
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
181
+ if inputs_embeds is None:
182
+ (
183
+ input_ids,
184
+ position_ids,
185
+ attention_mask,
186
+ past_key_values,
187
+ inputs_embeds,
188
+ labels
189
+ ) = self.prepare_inputs_labels_for_multimodal(
190
+ input_ids,
191
+ position_ids,
192
+ attention_mask,
193
+ past_key_values,
194
+ labels,
195
+ images,
196
+ image_sizes
197
+ )
198
+ return self.language_model.forward(
199
+ input_ids=input_ids,
200
+ attention_mask=attention_mask,
201
+ position_ids=position_ids,
202
+ past_key_values=past_key_values,
203
+ inputs_embeds=inputs_embeds,
204
+ labels=labels,
205
+ use_cache=use_cache,
206
+ output_attentions=output_attentions,
207
+ output_hidden_states=output_hidden_states,
208
+ return_dict=return_dict
209
+ )
210
+
211
+ @torch.no_grad()
212
+ def generate(
213
+ self,
214
+ inputs: Optional[torch.Tensor] = None,
215
+ images: Optional[torch.Tensor] = None,
216
+ image_sizes: Optional[torch.Tensor] = None,
217
+ **kwargs,
218
+ ) -> Union[GenerateOutput, torch.LongTensor]:
219
+ position_ids = kwargs.pop("position_ids", None)
220
+ attention_mask = kwargs.pop("attention_mask", None)
221
+ if "inputs_embeds" in kwargs:
222
+ raise NotImplementedError("`inputs_embeds` is not supported")
223
+
224
+ if images is not None:
225
+ (
226
+ inputs,
227
+ position_ids,
228
+ attention_mask,
229
+ _,
230
+ inputs_embeds,
231
+ _
232
+ ) = self.prepare_inputs_labels_for_multimodal(
233
+ inputs,
234
+ position_ids,
235
+ attention_mask,
236
+ None,
237
+ None,
238
+ images,
239
+ image_sizes=image_sizes
240
+ )
241
+ else:
242
+ inputs_embeds = self.language_model.get_input_embeddings()(inputs)
243
+
244
+ return self.language_model.generate(
245
+ position_ids=position_ids,
246
+ attention_mask=attention_mask,
247
+ inputs_embeds=inputs_embeds,
248
+ **kwargs
249
+ )
250
+
251
+ def encode_images(self, images):
252
+ kwargs = {}
253
+ kwargs['vision_feature_layer'] = self.config.vision_feature_layer
254
+ kwargs['vision_feature_select_strategy'] = self.config.vision_feature_select_strategy
255
+ images = images.to(device=self.device, dtype=self.dtype)
256
+ image_features = self.vision_tower(images, **kwargs)
257
+ image_features = self.connector(image_features)
258
+ return image_features
259
+
260
+
261
+
262
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
263
+ inputs_embeds=None, **kwargs):
264
+ images = kwargs.pop("images", None)
265
+ image_sizes = kwargs.pop("image_sizes", None)
266
+ inputs = self.language_model.prepare_inputs_for_generation(
267
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
268
+ )
269
+ if images is not None:
270
+ inputs['images'] = images
271
+ if image_sizes is not None:
272
+ inputs['image_sizes'] = image_sizes
273
+ return inputs
274
+
275
+ def prepare_inputs_labels_for_multimodal(
276
+ self, input_ids, position_ids, attention_mask, past_key_values, labels,
277
+ images, image_sizes=None
278
+ ):
279
+ vision_tower = self.vision_tower
280
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
281
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
282
+
283
+
284
+ image_features = self.encode_images(images)
285
+
286
+ # TODO: image start / end is not implemented here to support pretraining.
287
+ if getattr(self.config, 'tune_mm_mlp_adapter', False):
288
+ raise NotImplementedError
289
+
290
+ # Let's just add dummy tensors if they do not exist,
291
+ # it is a headache to deal with None all the time.
292
+ # But it is not ideal, and if you have a better idea,
293
+ # please open an issue / submit a PR, thanks.
294
+ _labels = labels
295
+ _position_ids = position_ids
296
+ _attention_mask = attention_mask
297
+ if attention_mask is None:
298
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
299
+ else:
300
+ attention_mask = attention_mask.bool()
301
+ if position_ids is None:
302
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
303
+ if labels is None:
304
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
305
+
306
+ # remove the padding using attention_mask -- FIXME
307
+ _input_ids = input_ids
308
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
309
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
310
+
311
+ new_input_embeds = []
312
+ new_labels = []
313
+ cur_image_idx = 0
314
+ for batch_idx, cur_input_ids in enumerate(input_ids):
315
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
316
+ if num_images == 0:
317
+ cur_image_features = image_features[cur_image_idx]
318
+ cur_input_embeds_1 = self.language_model.get_input_embeddings()(cur_input_ids)
319
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
320
+ new_input_embeds.append(cur_input_embeds)
321
+ new_labels.append(labels[batch_idx])
322
+ cur_image_idx += 1
323
+ continue
324
+
325
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
326
+ cur_input_ids_noim = []
327
+ cur_labels = labels[batch_idx]
328
+ cur_labels_noim = []
329
+ for i in range(len(image_token_indices) - 1):
330
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
331
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
332
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
333
+ cur_input_embeds = self.language_model.get_input_embeddings()(torch.cat(cur_input_ids_noim))
334
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
335
+ cur_new_input_embeds = []
336
+ cur_new_labels = []
337
+
338
+ for i in range(num_images + 1):
339
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
340
+ cur_new_labels.append(cur_labels_noim[i])
341
+ if i < num_images:
342
+ cur_image_features = image_features[cur_image_idx]
343
+ cur_image_idx += 1
344
+ cur_new_input_embeds.append(cur_image_features)
345
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
346
+
347
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
348
+
349
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
350
+ cur_new_labels = torch.cat(cur_new_labels)
351
+
352
+ new_input_embeds.append(cur_new_input_embeds)
353
+ new_labels.append(cur_new_labels)
354
+
355
+ # Truncate sequences to max length as image embeddings can make the sequence longer
356
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
357
+ if tokenizer_model_max_length is not None:
358
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
359
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
360
+
361
+ # Combine them
362
+ max_len = max(x.shape[0] for x in new_input_embeds)
363
+ batch_size = len(new_input_embeds)
364
+
365
+ new_input_embeds_padded = []
366
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
367
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
368
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
369
+
370
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
371
+ cur_len = cur_new_embed.shape[0]
372
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
373
+ new_input_embeds_padded.append(torch.cat((
374
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
375
+ cur_new_embed
376
+ ), dim=0))
377
+ if cur_len > 0:
378
+ new_labels_padded[i, -cur_len:] = cur_new_labels
379
+ attention_mask[i, -cur_len:] = True
380
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
381
+ else:
382
+ new_input_embeds_padded.append(torch.cat((
383
+ cur_new_embed,
384
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
385
+ ), dim=0))
386
+ if cur_len > 0:
387
+ new_labels_padded[i, :cur_len] = cur_new_labels
388
+ attention_mask[i, :cur_len] = True
389
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
390
+
391
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
392
+
393
+ if _labels is None:
394
+ new_labels = None
395
+ else:
396
+ new_labels = new_labels_padded
397
+
398
+ if _attention_mask is None:
399
+ attention_mask = None
400
+ else:
401
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
402
+
403
+ if _position_ids is None:
404
+ position_ids = None
405
+
406
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
407
+
408
+
409
+
410
+
411
+
412
+
413
+
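# Illustrative sketch (not part of the committed module): what
# prepare_inputs_labels_for_multimodal does to a single sequence, with made-up
# token ids (IMAGE_TOKEN_INDEX == -200 is defined in utils.py below). Text
# spans around the image token are embedded by the language model, the image
# span is replaced by the connector's visual features, and labels covering the
# inserted visual tokens are set to IGNORE_INDEX.
import torch

IMAGE_TOKEN_INDEX = -200
cur_input_ids = torch.tensor([1, 319, -200, 526, 29973])   # hypothetical ids
image_token_indices = (
    [-1]
    + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
    + [cur_input_ids.shape[0]]
)
# -> [-1, 2, 5]; the text chunks are ids[0:2] and ids[3:5], and the visual
#    features are concatenated between their embeddings.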
special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<unk>",
17
+ "unk_token": {
18
+ "content": "<unk>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
tokenizer_config.json ADDED
@@ -0,0 +1,43 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": true,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<unk>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "</s>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ }
30
+ },
31
+ "bos_token": "<s>",
32
+ "clean_up_tokenization_spaces": false,
33
+ "eos_token": "</s>",
34
+ "legacy": false,
35
+ "model_max_length": 2048,
36
+ "pad_token": "<unk>",
37
+ "padding_side": "right",
38
+ "sp_model_kwargs": {},
39
+ "spaces_between_special_tokens": false,
40
+ "tokenizer_class": "LlamaTokenizer",
41
+ "unk_token": "<unk>",
42
+ "use_default_system_prompt": false
43
+ }
trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28c57737631f2f1f316a23a19856352df64c085809582ee4e34ef5271a9af2b0
3
+ size 7099
utils.py ADDED
@@ -0,0 +1,259 @@
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+
5
+ import torch
6
+ from transformers import StoppingCriteria
7
+
8
+ import math
9
+ import ast
10
+
11
+ # Model Constants
12
+ IGNORE_INDEX = -100
13
+ IMAGE_TOKEN_INDEX = -200
14
+ DEFAULT_IMAGE_TOKEN = "<image>"
15
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
16
+ DEFAULT_IM_START_TOKEN = "<im_start>"
17
+ DEFAULT_IM_END_TOKEN = "<im_end>"
18
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
19
+
20
+ def select_best_resolution(original_size, possible_resolutions):
21
+ """
22
+ Selects the best resolution from a list of possible resolutions based on the original size.
23
+
24
+ Args:
25
+ original_size (tuple): The original size of the image in the format (width, height).
26
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
27
+
28
+ Returns:
29
+ tuple: The best fit resolution in the format (width, height).
30
+ """
31
+ original_width, original_height = original_size
32
+ best_fit = None
33
+ max_effective_resolution = 0
34
+ min_wasted_resolution = float('inf')
35
+
36
+ for width, height in possible_resolutions:
37
+ scale = min(width / original_width, height / original_height)
38
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
39
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
40
+ wasted_resolution = (width * height) - effective_resolution
41
+
42
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
43
+ max_effective_resolution = effective_resolution
44
+ min_wasted_resolution = wasted_resolution
45
+ best_fit = (width, height)
46
+
47
+ return best_fit
48
+
49
+
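# Illustrative sketch (not part of the committed module): a worked example of
# select_best_resolution with made-up candidates.
#   original_size = (500, 300)
#   possible_resolutions = [(336, 672), (672, 336)]
# For (672, 336): scale = min(672/500, 336/300) = 1.12, downscaled = (560, 336),
#   effective = min(560*336, 500*300) = 150000, wasted = 225792 - 150000 = 75792.
# For (336, 672): scale = 0.672, downscaled = (336, 201), effective = 67536.
# (672, 336) has the larger effective resolution and is returned as best_fit.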
50
+ ## added by llava-1.6
51
+ def resize_and_pad_image(image, target_resolution):
52
+ """
53
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
54
+
55
+ Args:
56
+ image (PIL.Image.Image): The input image.
57
+ target_resolution (tuple): The target resolution (width, height) of the image.
58
+
59
+ Returns:
60
+ PIL.Image.Image: The resized and padded image.
61
+ """
62
+ original_width, original_height = image.size
63
+ target_width, target_height = target_resolution
64
+
65
+ scale_w = target_width / original_width
66
+ scale_h = target_height / original_height
67
+
68
+ if scale_w < scale_h:
69
+ new_width = target_width
70
+ new_height = min(math.ceil(original_height * scale_w), target_height)
71
+ else:
72
+ new_height = target_height
73
+ new_width = min(math.ceil(original_width * scale_h), target_width)
74
+
75
+ # Resize the image
76
+ resized_image = image.resize((new_width, new_height))
77
+
78
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
79
+ paste_x = (target_width - new_width) // 2
80
+ paste_y = (target_height - new_height) // 2
81
+ new_image.paste(resized_image, (paste_x, paste_y))
82
+
83
+ return new_image
84
+
85
+
86
+ ## added by llava-1.6
87
+ def divide_to_patches(image, patch_size):
88
+ """
89
+ Divides an image into patches of a specified size.
90
+
91
+ Args:
92
+ image (PIL.Image.Image): The input image.
93
+ patch_size (int): The size of each patch.
94
+
95
+ Returns:
96
+ list: A list of PIL.Image.Image objects representing the patches.
97
+ """
98
+ patches = []
99
+ width, height = image.size
100
+ for i in range(0, height, patch_size):
101
+ for j in range(0, width, patch_size):
102
+ box = (j, i, j + patch_size, i + patch_size)
103
+ patch = image.crop(box)
104
+ patches.append(patch)
105
+
106
+ return patches
107
+
108
+
109
+ ## added by llava-1.6
110
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
111
+ """
112
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
113
+
114
+ Args:
115
+ image_size (tuple): The size of the input image in the format (width, height).
116
+ grid_pinpoints (str or list): A list of possible resolutions, or its string representation.
117
+ patch_size (int): The size of each image patch.
118
+
119
+ Returns:
120
+ tuple: The shape of the image patch grid in the format (width, height).
121
+ """
122
+ if type(grid_pinpoints) is list:
123
+ possible_resolutions = grid_pinpoints
124
+ else:
125
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
126
+ width, height = select_best_resolution(image_size, possible_resolutions)
127
+ return width // patch_size, height // patch_size
128
+
129
+
130
+ ## added by llava-1.6
131
+ def process_anyres_image(image, processor, grid_pinpoints):
132
+ """
133
+ Process an image with variable resolutions.
134
+
135
+ Args:
136
+ image (PIL.Image.Image): The input image to be processed.
137
+ processor: The image processor object.
138
+ grid_pinpoints (str or list): A list of possible resolutions, or its string representation.
139
+
140
+ Returns:
141
+ torch.Tensor: A tensor containing the processed image patches.
142
+ """
143
+ if type(grid_pinpoints) is list:
144
+ possible_resolutions = grid_pinpoints
145
+ else:
146
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
147
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
148
+ image_padded = resize_and_pad_image(image, best_resolution)
149
+
150
+ patches = divide_to_patches(image_padded, processor.crop_size['height'])
151
+
152
+ image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
153
+
154
+ image_patches = [image_original_resize] + patches
155
+ image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
156
+ for image_patch in image_patches]
157
+ return torch.stack(image_patches, dim=0)
158
+
159
+
160
+ def load_image_from_base64(image):
161
+ return Image.open(BytesIO(base64.b64decode(image)))
162
+
163
+
164
+ def expand2square(pil_img, background_color):
165
+ width, height = pil_img.size
166
+ if width == height:
167
+ return pil_img
168
+ elif width > height:
169
+ result = Image.new(pil_img.mode, (width, width), background_color)
170
+ result.paste(pil_img, (0, (width - height) // 2))
171
+ return result
172
+ else:
173
+ result = Image.new(pil_img.mode, (height, height), background_color)
174
+ result.paste(pil_img, ((height - width) // 2, 0))
175
+ return result
176
+
177
+
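# Illustrative sketch (not part of the committed module): expand2square pads
# the short side symmetrically. For a hypothetical 640x480 RGB image:
#
#   from PIL import Image
#   img = Image.new('RGB', (640, 480))
#   out = expand2square(img, (127, 127, 127))
#   # out.size == (640, 640); the original is pasted at y = (640 - 480) // 2 = 80.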
178
+ def process_images(images, image_processor, model_cfg):
179
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
180
+ new_images = []
181
+ if image_aspect_ratio == 'pad':
182
+ for image in images:
183
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
184
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
185
+ new_images.append(image)
186
+ elif image_aspect_ratio == "anyres":
187
+ for image in images:
188
+ image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
189
+ new_images.append(image)
190
+ else:
191
+ return image_processor(images, return_tensors='pt')['pixel_values']
192
+ if all(x.shape == new_images[0].shape for x in new_images):
193
+ new_images = torch.stack(new_images, dim=0)
194
+ return new_images
195
+
196
+
197
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
198
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
199
+
200
+ def insert_separator(X, sep):
201
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
202
+
203
+ input_ids = []
204
+ offset = 0
205
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
206
+ offset = 1
207
+ input_ids.append(prompt_chunks[0][0])
208
+
209
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
210
+ input_ids.extend(x[offset:])
211
+
212
+ if return_tensors is not None:
213
+ if return_tensors == 'pt':
214
+ return torch.tensor(input_ids, dtype=torch.long)
215
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
216
+ return input_ids
217
+
218
+
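# Illustrative sketch (not part of the committed module): tokenizer_image_token
# with made-up token ids. Suppose the tokenizer maps
#   "What is" -> [1, 1724, 338]    (1 being bos_token_id)
#   " this?"  -> [1, 445, 29973]
# Then for the prompt "What is<image> this?":
#   prompt_chunks = [[1, 1724, 338], [1, 445, 29973]], offset = 1,
#   the separator is [IMAGE_TOKEN_INDEX] * (offset + 1) with its first `offset`
#   entries dropped when extended, giving
#   input_ids = [1, 1724, 338, -200, 445, 29973].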
219
+ def get_model_name_from_path(model_path):
220
+ model_path = model_path.strip("/")
221
+ model_paths = model_path.split("/")
222
+ if model_paths[-1].startswith('checkpoint-'):
223
+ return model_paths[-2] + "_" + model_paths[-1]
224
+ else:
225
+ return model_paths[-1]
226
+
227
+
228
+ class KeywordsStoppingCriteria(StoppingCriteria):
229
+ def __init__(self, keywords, tokenizer, input_ids):
230
+ self.keywords = keywords
231
+ self.keyword_ids = []
232
+ self.max_keyword_len = 0
233
+ for keyword in keywords:
234
+ cur_keyword_ids = tokenizer(keyword).input_ids
235
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
236
+ cur_keyword_ids = cur_keyword_ids[1:]
237
+ if len(cur_keyword_ids) > self.max_keyword_len:
238
+ self.max_keyword_len = len(cur_keyword_ids)
239
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
240
+ self.tokenizer = tokenizer
241
+ self.start_len = input_ids.shape[1]
242
+
243
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
244
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
245
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
246
+ for keyword_id in self.keyword_ids:
247
+ if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
248
+ return True
249
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
250
+ for keyword in self.keywords:
251
+ if keyword in outputs:
252
+ return True
253
+ return False
254
+
255
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
256
+ outputs = []
257
+ for i in range(output_ids.shape[0]):
258
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
259
+ return all(outputs)
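# Illustrative sketch (not part of the committed module): typical use of
# KeywordsStoppingCriteria with generate(); `tokenizer`, `model` and
# `input_ids` are placeholders for whatever the caller has loaded.
#
#   from transformers import StoppingCriteriaList
#   stop = KeywordsStoppingCriteria(["</s>"], tokenizer, input_ids)
#   output_ids = model.generate(
#       inputs_embeds=inputs_embeds,
#       stopping_criteria=StoppingCriteriaList([stop]),
#       max_new_tokens=128,
#   )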