root committed on
Commit
b4942cf
1 Parent(s): 23ce00f

Add the Ovis module

ovis/__init__.py ADDED
@@ -0,0 +1,3 @@
+ import os
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
ovis/model/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .visual_tokenizer.clip_visual_tokenizer import ClipVisualTokenizerConfig, ClipVisualTokenizer
+ from .visual_tokenizer.siglip_visual_tokenizer import SiglipVisualTokenizerConfig, SiglipVisualTokenizer
ovis/model/configuration_ovis.py ADDED
@@ -0,0 +1,41 @@
+ from typing import Union, Optional
+
+ from transformers import PretrainedConfig, AutoConfig
+
+
+ class OvisConfig(PretrainedConfig):
+     model_type = "ovis"
+
+     def __init__(
+         self,
+         llm_config: Optional[Union[PretrainedConfig, dict]] = None,
+         visual_tokenizer_config: Optional[Union[PretrainedConfig, dict]] = None,
+         multimodal_max_length=8192,
+         hidden_size=None,
+         conversation_formatter_class=None,
+         llm_attn_implementation=None,
+         disable_tie_weight=False,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         if llm_config is not None:
+             assert isinstance(llm_config, (PretrainedConfig, dict)), \
+                 f"expect `llm_config` to be instance of PretrainedConfig or dict, but got {type(llm_config)} type"
+             if not isinstance(llm_config, PretrainedConfig):
+                 model_type = llm_config['model_type']
+                 llm_config.pop('model_type')
+                 llm_config = AutoConfig.for_model(model_type, **llm_config)
+         self.llm_config = llm_config
+         if visual_tokenizer_config is not None:
+             assert isinstance(visual_tokenizer_config, (PretrainedConfig, dict)), \
+                 f"expect `visual_tokenizer_config` to be instance of PretrainedConfig or dict, but got {type(visual_tokenizer_config)} type"
+             if not isinstance(visual_tokenizer_config, PretrainedConfig):
+                 model_type = visual_tokenizer_config['model_type']
+                 visual_tokenizer_config.pop('model_type')
+                 visual_tokenizer_config = AutoConfig.for_model(model_type, **visual_tokenizer_config)
+         self.visual_tokenizer_config = visual_tokenizer_config
+         self.multimodal_max_length = multimodal_max_length
+         self.hidden_size = hidden_size
+         self.conversation_formatter_class = conversation_formatter_class
+         self.llm_attn_implementation = llm_attn_implementation
+         self.disable_tie_weight = disable_tie_weight
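
A quick illustration (not part of the commit) of how the nested configs are resolved: dict-valued `llm_config` / `visual_tokenizer_config` are rebuilt through `AutoConfig.for_model`. A minimal sketch, assuming the `ovis` package is importable; the field values below are made up for the example.

# Illustrative sketch only; importing ovis.model registers the visual tokenizer configs
# added later in this commit, so AutoConfig can resolve their model_type strings.
import ovis.model  # noqa: F401
from ovis.model.configuration_ovis import OvisConfig

config = OvisConfig(
    llm_config={"model_type": "llama", "hidden_size": 4096},            # resolved via AutoConfig.for_model
    visual_tokenizer_config={"model_type": "siglip_visual_tokenizer"},  # registered by this commit
    multimodal_max_length=8192,
    hidden_size=4096,
    conversation_formatter_class="Llama3ConversationFormatter",
)
print(type(config.llm_config).__name__, type(config.visual_tokenizer_config).__name__)
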
ovis/model/conversation_formatter.py ADDED
@@ -0,0 +1,233 @@
+ from abc import ABC, abstractmethod
+ from typing import List, Dict
+
+ from ovis.util.constants import IMAGE_TOKEN_ID, IGNORE_ID, IMAGE_TOKEN
+
+
+ class ConversationFormatter(ABC):
+     support_tokenizer_types = None
+
+     def __init__(self, tokenizer):
+         tokenizer_type = type(tokenizer).__name__
+         assert tokenizer_type in self.support_tokenizer_types, \
+             f'Invalid tokenizer type, expected one from `{self.support_tokenizer_types}`, but got `{tokenizer_type}`'
+         self.tokenizer = tokenizer
+         self.image_token = IMAGE_TOKEN
+         self.image_token_id = IMAGE_TOKEN_ID
+         self.ignore_id = IGNORE_ID
+
+     def _tokenize_with_image_symbol(self, text):
+         text_chunks = [self.tokenizer(chunk, add_special_tokens=False).input_ids for chunk in
+                        text.split(self.image_token)]
+         token_ids = []
+         num_chunk = len(text_chunks)
+         for i, chunk in enumerate(text_chunks):
+             token_ids.extend(chunk)
+             if i < num_chunk - 1:
+                 token_ids.append(self.image_token_id)
+         return token_ids
+
+     @abstractmethod
+     def format(self, conversations: List[Dict], generation_preface=None):
+         pass
+
+     @abstractmethod
+     def format_query(self, query, generation_preface=""):
+         pass
+
+
+ class QwenConversationFormatter(ConversationFormatter):
+     support_tokenizer_types = ['QWenTokenizer', 'Qwen2TokenizerFast']
+
+     def __init__(self, tokenizer):
+         super().__init__(tokenizer)
+         self.from2role = {
+             "system": "<|im_start|>system\n",
+             "human": "<|im_start|>user\n",
+             "gpt": "<|im_start|>assistant\n",
+         }
+         self.gpt_token_num = None
+         self.im_end = "<|im_end|>\n"
+         self.default_system_prompt = "You are a helpful assistant."
+
+     def format(self, conversations: List[Dict], generation_preface=None):
+         if self.gpt_token_num is None:
+             self.gpt_token_num = len(self.tokenizer(self.from2role["gpt"], add_special_tokens=False).input_ids)
+
+         if conversations[0]["from"] != "system":
+             conversations.insert(0, {
+                 "from": "system",
+                 "value": self.default_system_prompt
+             })
+
+         if generation_preface is not None:
+             conversations.append({
+                 "from": "gpt",
+                 "value": generation_preface
+             })
+
+         prompt = ""
+         input_ids = []
+         labels = []
+         num_conversation = len(conversations)
+         for i, conversation in enumerate(conversations):
+             frm = conversation["from"]
+             role = self.from2role[frm]
+             message = conversation["value"]
+             text = role + message
+             if i < num_conversation - 1 or generation_preface is None:
+                 text += self.im_end
+             prompt += text
+             token_ids = self._tokenize_with_image_symbol(text)
+             input_ids.extend(token_ids)
+             label_ids = [self.ignore_id] * len(token_ids)
+             if frm == "gpt" and generation_preface is None:
+                 # learning the `\n` that follows `im_end` is meaningless, so the last `\n` token is ignored in the labels
+                 label_ids[self.gpt_token_num:-1] = token_ids[self.gpt_token_num:-1]
+             labels.extend(label_ids)
+
+         assert self._tokenize_with_image_symbol(prompt) == input_ids
+         assert len(input_ids) == len(labels)
+
+         return prompt, input_ids, labels
+
+     def format_query(self, query, generation_preface=""):
+         prompt, input_ids, _ = self.format([{
+             "from": "human",
+             "value": query
+         }], generation_preface=generation_preface)
+
+         return prompt, input_ids
+
+
+ class Llama3ConversationFormatter(ConversationFormatter):
+     support_tokenizer_types = ['PreTrainedTokenizerFast']
+
+     def __init__(self, tokenizer):
+         super().__init__(tokenizer)
+         self.from2role = {
+             "system": "<|start_header_id|>system<|end_header_id|>\n\n",
+             "human": "<|start_header_id|>user<|end_header_id|>\n\n",
+             "gpt": "<|start_header_id|>assistant<|end_header_id|>\n\n",
+         }
+         self.gpt_token_num = None
+         self.im_end = "<|eot_id|>"
+         self.default_system_prompt = "You are a helpful and honest multimodal assistant."
+         self.bos_token = "<|begin_of_text|>"
+         self.bos_token_ids = None
+
+     def format(self, conversations: List[Dict], generation_preface=None):
+         if self.gpt_token_num is None:
+             self.gpt_token_num = len(self.tokenizer(self.from2role["gpt"], add_special_tokens=False).input_ids)
+
+         if self.bos_token_ids is None:
+             self.bos_token_ids = self.tokenizer(self.bos_token, add_special_tokens=False).input_ids
+
+         if conversations[0]["from"] != "system":
+             conversations.insert(0, {
+                 "from": "system",
+                 "value": self.default_system_prompt
+             })
+
+         if generation_preface is not None:
+             conversations.append({
+                 "from": "gpt",
+                 "value": generation_preface
+             })
+
+         prompt = "" + self.bos_token
+         input_ids = [] + self.bos_token_ids
+         labels = [] + [IGNORE_ID] * len(input_ids)
+         num_conversation = len(conversations)
+         for i, conversation in enumerate(conversations):
+             frm = conversation["from"]
+             role = self.from2role[frm]
+             message = conversation["value"].strip()
+             text = role + message
+             if i < num_conversation - 1 or generation_preface is None:
+                 text += self.im_end
+             prompt += text
+             token_ids = self._tokenize_with_image_symbol(text)
+             input_ids.extend(token_ids)
+             label_ids = [self.ignore_id] * len(token_ids)
+             if frm == "gpt":
+                 label_ids[self.gpt_token_num:] = token_ids[self.gpt_token_num:]
+             labels.extend(label_ids)
+
+         assert self._tokenize_with_image_symbol(prompt) == input_ids
+         assert len(input_ids) == len(labels)
+
+         return prompt, input_ids, labels
+
+     def format_query(self, query, generation_preface=""):
+         prompt, input_ids, _ = self.format([{
+             "from": "human",
+             "value": query
+         }], generation_preface=generation_preface)
+
+         return prompt, input_ids
+
+
+ class GemmaConversationFormatter(ConversationFormatter):
+     support_tokenizer_types = ['GemmaTokenizer', 'GemmaTokenizerFast']
+
+     def __init__(self, tokenizer):
+         super().__init__(tokenizer)
+         # Gemma does not support a system prompt
+         self.from2role = {
+             "human": "<start_of_turn>user\n",
+             "gpt": "<start_of_turn>model\n",
+         }
+         self.gpt_token_num = None
+         self.im_end = "<end_of_turn>\n"
+         self.bos_token = "<bos>"
+         self.bos_token_ids = None
+
+     def format(self, conversations: List[Dict], generation_preface=None):
+         if self.gpt_token_num is None:
+             self.gpt_token_num = len(self.tokenizer(self.from2role["gpt"], add_special_tokens=False).input_ids)
+
+         if self.bos_token_ids is None:
+             self.bos_token_ids = self.tokenizer(self.bos_token, add_special_tokens=False).input_ids
+
+         if conversations[0]["from"] == "system":
+             raise ValueError("Gemma does not support system prompt")
+
+         if generation_preface is not None:
+             conversations.append({
+                 "from": "gpt",
+                 "value": generation_preface
+             })
+
+         prompt = "" + self.bos_token
+         input_ids = [] + self.bos_token_ids
+         labels = [] + [IGNORE_ID] * len(input_ids)
+         num_conversation = len(conversations)
+         for i, conversation in enumerate(conversations):
+             frm = conversation["from"]
+             role = self.from2role[frm]
+             message = conversation["value"].strip()
+             text = role + message
+             if i < num_conversation - 1 or generation_preface is None:
+                 text += self.im_end
+             prompt += text
+             token_ids = self._tokenize_with_image_symbol(text)
+             input_ids.extend(token_ids)
+             label_ids = [self.ignore_id] * len(token_ids)
+             if frm == "gpt":
+                 # learning the `\n` that follows `im_end` is meaningless, so the last `\n` token is ignored in the labels
+                 label_ids[self.gpt_token_num:-1] = token_ids[self.gpt_token_num:-1]
+             labels.extend(label_ids)
+
+         assert self._tokenize_with_image_symbol(prompt) == input_ids
+         assert len(input_ids) == len(labels)
+
+         return prompt, input_ids, labels
+
+     def format_query(self, query, generation_preface=""):
+         prompt, input_ids, _ = self.format([{
+             "from": "human",
+             "value": query
+         }], generation_preface=generation_preface)
+
+         return prompt, input_ids
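
For reference, a minimal sketch (not part of the diff) of what the Qwen formatter produces; the checkpoint name is an assumption, and any tokenizer whose class appears in `support_tokenizer_types` would work.

# Illustrative sketch; assumes a Qwen2 tokenizer can be loaded (checkpoint name is an example).
from transformers import AutoTokenizer
from ovis.model.conversation_formatter import QwenConversationFormatter
from ovis.util.constants import IMAGE_TOKEN

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
formatter = QwenConversationFormatter(tokenizer)
prompt, input_ids, labels = formatter.format([
    {"from": "human", "value": f"{IMAGE_TOKEN}\nDescribe this picture."},
    {"from": "gpt", "value": "A cat sitting on a sofa."},
])
# `prompt` is the ChatML string (a default system turn is prepended automatically);
# every IMAGE_TOKEN becomes IMAGE_TOKEN_ID in `input_ids`, and `labels` keeps only the
# assistant's reply tokens, with everything else set to IGNORE_ID.
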
ovis/model/modeling_ovis.py ADDED
@@ -0,0 +1,414 @@
+ import logging
+ import os
+ from datetime import datetime
+ from importlib import import_module
+ from typing import List, Union, Callable, Optional, Dict
+
+ import PIL.Image
+ import deepspeed
+ import torch
+ from torch import Tensor
+ from torch.nn import init
+ from transformers import PreTrainedModel, AutoConfig, AutoModel, AutoTokenizer, AutoModelForCausalLM
+ from transformers.cache_utils import HybridCache
+ from transformers.generation.utils import GenerateOutput
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled, deepspeed_config
+
+ from ovis.model.configuration_ovis import OvisConfig
+ from ovis.model.conversation_formatter import ConversationFormatter
+ from ovis.util.constants import IGNORE_ID, BEGIN_LINE, END_LINE, IMAGE_ATOM_ID, IMAGE_INDICATOR_IDS, \
+     IMAGE_TOKEN_ID
+ from ovis.util.utils import rank0_print
+
+
+ class VisualEmbedding(torch.nn.Embedding):
+     def forward(self, visual_tokens: Tensor) -> Tensor:
+         if visual_tokens.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.long]:
+             return super().forward(visual_tokens)
+         return torch.matmul(visual_tokens, self.weight)
+
+     def reset_parameters(self, mean=0., std=1.) -> None:
+         init.normal_(self.weight, mean=mean, std=std)
+         self._fill_padding_idx_with_zero()
+
+
+ class OvisPreTrainedModel(PreTrainedModel):
+     config_class = OvisConfig
+     base_model_prefix = "ovis"
+
+
+ class Ovis(OvisPreTrainedModel):
+
+     def __init__(self, config: OvisConfig, *inputs, **kwargs):
+         super().__init__(config, *inputs, **kwargs)
+         if kwargs.get('train_from_scratch'):
+             self.llm = kwargs['llm']
+             self.generation_config = self.llm.generation_config
+             self.config.llm_config = self.llm.config
+             self.config.hidden_size = self.llm.config.hidden_size  # for deepspeed auto configuration
+             self.text_tokenizer = kwargs['text_tokenizer']
+             self.visual_tokenizer = kwargs['visual_tokenizer']
+             self.config.visual_tokenizer_config = self.visual_tokenizer.config
+         else:
+             attn_kwargs = dict()
+             if self.config.llm_attn_implementation:
+                 attn_kwargs['attn_implementation'] = self.config.llm_attn_implementation
+             self.llm = AutoModelForCausalLM.from_config(self.config.llm_config, **attn_kwargs)
+             assert self.config.hidden_size == self.llm.config.hidden_size, "hidden size mismatch"
+             self.text_tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
+             self.visual_tokenizer = AutoModel.from_config(self.config.visual_tokenizer_config,
+                                                           image_processor_name_or_path=self.config.name_or_path)
+
+         # initialize vte
+         if is_deepspeed_zero3_enabled():
+             with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
+                 self.vte = VisualEmbedding(self.config.visual_tokenizer_config.vocab_size, self.config.hidden_size)
+         else:
+             self.vte = VisualEmbedding(self.config.visual_tokenizer_config.vocab_size, self.config.hidden_size,
+                                        device=self.visual_tokenizer.device, dtype=self.visual_tokenizer.dtype)
+
+         def _merge_modules(modules_list: tuple):
+             merged_modules = []
+             for modules in modules_list:
+                 merged_modules.extend(modules if modules else [])
+             return merged_modules
+
+         self._no_split_modules = _merge_modules((self.llm._no_split_modules, self.visual_tokenizer._no_split_modules))
+         self._skip_keys_device_placement = self.llm._skip_keys_device_placement
+         self._keep_in_fp32_modules = _merge_modules(
+             (self.llm._keep_in_fp32_modules, self.visual_tokenizer._keep_in_fp32_modules))
+         self.is_parallelizable = all((self.llm.is_parallelizable, self.visual_tokenizer.is_parallelizable))
+         self.supports_gradient_checkpointing = all(
+             (self.llm.supports_gradient_checkpointing, self.visual_tokenizer.supports_gradient_checkpointing))
+         self._supports_flash_attn_2 = all(
+             (self.llm._supports_flash_attn_2, self.visual_tokenizer._supports_flash_attn_2))
+         self._supports_sdpa = all((self.llm._supports_sdpa, self.visual_tokenizer._supports_sdpa))
+
+     def get_text_tokenizer(self):
+         return self.text_tokenizer
+
+     def get_visual_tokenizer(self):
+         return self.visual_tokenizer
+
+     def tie_weights(self):
+         if not self.config.disable_tie_weight:
+             self.get_llm().tie_weights()
+
+     def re_init_vte(self, mean, std):
+         vte = self.get_vte()
+         rank0_print(BEGIN_LINE)
+         rank0_print(f'[{datetime.now()}] Before re-initialization of vte: ')
+         with deepspeed.zero.GatheredParameters([vte.weight]):
+             rank0_print(f'vte.weight: {vte.weight}')
+         with deepspeed.zero.GatheredParameters([vte.weight], modifier_rank=0):
+             if not is_deepspeed_zero3_enabled() or deepspeed.comm.get_rank() == 0:
+                 vte.reset_parameters(mean, std)
+         rank0_print(f'[{datetime.now()}] After re-initialization of vte:')
+         with deepspeed.zero.GatheredParameters([vte.weight]):
+             rank0_print(f'vte.weight: {vte.weight}')
+         rank0_print(END_LINE)
+
+     def get_monitor_tensors(self):
+         monitor_tensors = dict(
+             wte=self.get_wte().weight,
+             lm_head=self.get_lm_head().weight,
+             vte=self.get_vte().weight
+         )
+         monitor_tensors.update(
+             {f'visual_tokenizer_{k}': v for k, v in self.get_visual_tokenizer().get_monitor_tensors().items()})
+         return monitor_tensors
+
+     def get_lm_head(self):
+         return self.get_llm().get_output_embeddings()
+
+     def get_llm(self):
+         return self.llm
+
+     def get_vte(self):
+         return self.vte
+
+     def get_wte(self):
+         return self.llm.get_input_embeddings()
+
+     def get_conversation_formatter(self) -> ConversationFormatter:
+         if getattr(self, 'conversation_formatter', None) is None:
+             self.conversation_formatter = getattr(import_module(".conversation_formatter", __package__),
+                                                   self.config.conversation_formatter_class)(self.text_tokenizer)
+         return self.conversation_formatter
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+         labels: Optional[torch.Tensor],
+         pixel_values: List[Optional[torch.Tensor]],
+         **kwargs
+     ):
+         assert self.training, "`forward` can only be used in training. For inference, use `generate`."
+         _, inputs_embeds, labels, attention_mask = self.merge_multimodal(
+             text_input_ids=input_ids,
+             text_attention_masks=attention_mask,
+             text_labels=labels,
+             pixel_values=pixel_values
+         )
+         return self.llm(inputs_embeds=inputs_embeds, labels=labels, attention_mask=attention_mask, **kwargs)
+
+     def merge_multimodal(
+         self,
+         text_input_ids: torch.Tensor,
+         text_attention_masks: torch.Tensor,
+         text_labels: Optional[torch.Tensor],
+         pixel_values: List[Optional[torch.Tensor]]
+     ):
+         input_device = text_input_ids.device
+         visual_vocab_size = self.get_visual_tokenizer().config.vocab_size
+         visual_indicator_embeds = self.get_vte()(
+             torch.tensor(
+                 list(range(visual_vocab_size - 5, visual_vocab_size)),
+                 dtype=torch.long,
+                 device=self.get_visual_tokenizer().device
+             )
+         ).to(device=input_device)
+
+         if self.training:
+             # During training, to be compatible with deepspeed zero, each sample has to include a pixel_value tensor.
+             # For a text-only sample, one can simply use an all-zero tensor as pixel_value, which will be ignored
+             # (see below in this function), so the gradient is not affected.
+             num_images = [x.shape[0] for x in pixel_values]
+             visual_tokens = self.visual_tokenizer(torch.cat([x for x in pixel_values], dim=0))
+             visual_embeds = torch.split(self.get_vte()(visual_tokens).to(dtype=self.dtype, device=input_device),
+                                         split_size_or_sections=num_images, dim=0)
+             visual_input_ids = torch.split(torch.argmax(visual_tokens, dim=-1).to(device=input_device),
+                                            split_size_or_sections=num_images, dim=0)
+             visual_labels = [torch.full(x.shape, IGNORE_ID, dtype=torch.long, device=input_device) for x in
+                              visual_input_ids]
+         else:
+             # At inference time, a sample may be text-only, with `None` as its pixel_value
+             num_images = [x.shape[0] if x is not None else 0 for x in pixel_values]
+             if sum(num_images) > 0:
+                 visual_tokens = self.visual_tokenizer(torch.cat([x for x in pixel_values if x is not None], dim=0))
+                 visual_embeds = torch.split(self.get_vte()(visual_tokens).to(dtype=self.dtype, device=input_device),
+                                             split_size_or_sections=num_images, dim=0)
+                 visual_input_ids = torch.split(torch.argmax(visual_tokens, dim=-1).to(device=input_device),
+                                                split_size_or_sections=num_images, dim=0)
+                 visual_labels = [torch.full(x.shape, IGNORE_ID, dtype=torch.long, device=input_device) for x in
+                                  visual_input_ids]
+             else:
+                 # just placeholders
+                 visual_embeds = [None] * len(num_images)
+                 visual_input_ids = [None] * len(num_images)
+                 visual_labels = [None] * len(num_images)
+             # just placeholders
+             text_labels = torch.full(text_input_ids.shape, IGNORE_ID, dtype=torch.long, device=input_device)
+
+         input_embeds = []
+         attention_masks = []
+         labels = []
+         for text_input_id, text_label, text_attention_mask, visual_embed, visual_input_id, visual_label in zip(
+                 text_input_ids, text_labels, text_attention_masks, visual_embeds, visual_input_ids, visual_labels
+         ):
+             placeholder_token_mask = torch.lt(text_input_id, 0)
+             text_embed = self.get_wte()(torch.masked_fill(text_input_id, placeholder_token_mask, 0))
+             for i, indicator_id in enumerate(IMAGE_INDICATOR_IDS):
+                 text_embed[text_input_id == indicator_id] = visual_indicator_embeds[i]
+             image_atom_positions = torch.where(torch.eq(text_input_id, IMAGE_ATOM_ID))[0].tolist()
+             if len(image_atom_positions) > 0:
+                 input_embed_parts = []
+                 attention_mask_parts = []
+                 label_parts = []
+                 prev_image_atom_position = -1
+                 for index, image_atom_position in enumerate(image_atom_positions):
+                     input_embed_parts.append(
+                         text_embed[prev_image_atom_position + 1:image_atom_position, :])
+                     label_parts.append(
+                         text_label[prev_image_atom_position + 1:image_atom_position])
+                     attention_mask_parts.append(
+                         text_attention_mask[prev_image_atom_position + 1:image_atom_position])
+                     input_embed_parts.append(visual_embed[index])
+                     attention_mask_parts.append(
+                         torch.ones_like(visual_label[index], dtype=torch.bool))
+                     label_parts.append(visual_label[index])
+                     prev_image_atom_position = image_atom_position
+                 if prev_image_atom_position + 1 < text_input_id.shape[0]:
+                     input_embed_parts.append(
+                         text_embed[prev_image_atom_position + 1:, :])
+                     attention_mask_parts.append(
+                         text_attention_mask[prev_image_atom_position + 1:])
+                     label_parts.append(
+                         text_label[prev_image_atom_position + 1:])
+                 input_embed = torch.cat(input_embed_parts, dim=0)
+                 attention_mask = torch.cat(attention_mask_parts, dim=0)
+                 label = torch.cat(label_parts, dim=0)
+             else:
+                 input_embed = text_embed
+                 attention_mask = text_attention_mask
+                 label = text_label
+                 if self.training:
+                     # Make visual_embed & visual_indicator_embeds involved in the backward graph,
+                     # to be compatible with deepspeed zero and ddp.
+                     input_embed += torch.sum(visual_embed * 0.0) + torch.sum(visual_indicator_embeds * 0.0)
+             input_embeds.append(input_embed)
+             attention_masks.append(attention_mask)
+             labels.append(label)
+
+         if self.training:  # pad to self.config.multimodal_max_length for increased training speed
+             padding_size = max(0, self.config.multimodal_max_length - len(input_embeds[0]))
+             input_embeds[0] = torch.nn.ConstantPad2d((0, 0, 0, padding_size), 0.0)(input_embeds[0])
+             attention_masks[0] = torch.nn.ConstantPad1d((0, padding_size), False)(attention_masks[0])
+             labels[0] = torch.nn.ConstantPad1d((0, padding_size), IGNORE_ID)(labels[0])
+         batch_input_embeds = torch.nn.utils.rnn.pad_sequence(
+             input_embeds, batch_first=True, padding_value=0.0)[:, :self.config.multimodal_max_length, :]
+         batch_attention_mask = torch.nn.utils.rnn.pad_sequence(
+             attention_masks, batch_first=True, padding_value=False)[:, :self.config.multimodal_max_length]
+         batch_labels = torch.nn.utils.rnn.pad_sequence(
+             labels, batch_first=True, padding_value=IGNORE_ID)[:, :self.config.multimodal_max_length]
+
+         return visual_input_ids, batch_input_embeds, batch_labels, batch_attention_mask
+
+     def preprocess_inputs(
+         self,
+         text_or_conversations: Union[List[Dict], str],
+         images: Optional[List[PIL.Image.Image]],
+         max_partition=9,
+         generation_preface='',
+         return_labels=False,
+         propagate_exception=True
+     ):
+         # convert text to conversations
+         if isinstance(text_or_conversations, str):
+             conversations = [{
+                 "from": "human",
+                 "value": text_or_conversations
+             }]
+         elif isinstance(text_or_conversations, list):
+             conversations = text_or_conversations
+         else:
+             raise ValueError(f'Invalid type of `text_or_conversations`, expected `List[Dict]` or `str`,'
+                              f' but got {type(text_or_conversations)}')
+
+         # format conversations
+         prompt, raw_input_ids, raw_labels = self.get_conversation_formatter().format(
+             conversations, generation_preface=generation_preface)
+
+         # place image placeholders
+         input_ids = []
+         labels = []
+         pixel_values = []
+         invalidate_label = False
+         image_token_indices = [i for i, v in enumerate(raw_input_ids) if v == IMAGE_TOKEN_ID]
+         last_image_token_index = -1
+         for i in range(len(image_token_indices)):
+             head = 0 if i == 0 else image_token_indices[i - 1] + 1
+             tail = image_token_indices[i]
+             last_image_token_index = tail
+             input_ids.extend(raw_input_ids[head:tail])
+             labels.extend(raw_labels[head:tail])
+             try:
+                 image = images[i]
+                 raw_pixel_values, image_placeholders = self.visual_tokenizer.preprocess_image(
+                     image, max_partition=max_partition)
+             except Exception as e:
+                 if propagate_exception:
+                     raise e
+                 logging.exception(e)
+                 invalidate_label = True
+                 raw_pixel_values, image_placeholders = self.visual_tokenizer.mock_input()
+             input_ids.extend(image_placeholders)
+             labels.extend([IGNORE_ID] * len(image_placeholders))
+             pixel_values.append(raw_pixel_values)
+         input_ids.extend(raw_input_ids[last_image_token_index + 1:])
+         labels.extend(raw_labels[last_image_token_index + 1:])
+
+         # return tensors
+         input_ids = torch.tensor(input_ids, dtype=torch.long)
+         labels = torch.tensor([IGNORE_ID] * len(labels) if invalidate_label else labels, dtype=torch.long)
+         pixel_values = torch.cat(pixel_values, dim=0) if len(pixel_values) > 0 else None
+
+         if return_labels:
+             return prompt, input_ids, pixel_values, labels
+         else:
+             return prompt, input_ids, pixel_values
+
+     def save_pretrained(
+         self,
+         save_directory: Union[str, os.PathLike],
+         is_main_process: bool = True,
+         state_dict: Optional[dict] = None,
+         save_function: Callable = torch.save,
+         push_to_hub: bool = False,
+         max_shard_size: Union[int, str] = "5GB",
+         safe_serialization: bool = True,
+         variant: Optional[str] = None,
+         token: Optional[Union[str, bool]] = None,
+         save_peft_format: bool = True,
+         **kwargs
+     ):
+         super().save_pretrained(save_directory,
+                                 is_main_process=is_main_process,
+                                 state_dict=state_dict,
+                                 save_function=save_function,
+                                 safe_serialization=safe_serialization)
+         self.get_text_tokenizer().save_pretrained(save_directory)
+         self.get_visual_tokenizer().get_image_processor().save_pretrained(save_directory)
+
+         # uncommenting the following would additionally save a separate visual tokenizer
+         # visual_tokenizer_directory = os.path.join(save_directory, 'visual_tokenizer')
+         # self.get_visual_tokenizer().save_pretrained(visual_tokenizer_directory,
+         #                                             is_main_process=is_main_process,
+         #                                             state_dict=None,
+         #                                             save_function=save_function,
+         #                                             safe_serialization=safe_serialization)
+         # self.get_visual_tokenizer().get_image_processor().save_pretrained(visual_tokenizer_directory)
+
+     def _get_hybrid_cache_for_llm(self, max_batch_size: int, max_cache_len: int):
+         cache_cls = HybridCache
+         llm = self.get_llm()
+
+         need_new_cache = (
+             not hasattr(llm, "_cache")
+             or (not isinstance(llm._cache, cache_cls))
+             or llm._cache.max_batch_size != max_batch_size
+             or llm._cache.max_cache_len < max_cache_len
+         )
+
+         if need_new_cache:
+             if hasattr(llm.config, "_pre_quantization_dtype"):
+                 cache_dtype = llm.config._pre_quantization_dtype
+             else:
+                 cache_dtype = llm.dtype
+             llm._cache = cache_cls(
+                 config=llm.config,
+                 max_batch_size=max_batch_size,
+                 max_cache_len=max_cache_len,
+                 device=llm.device,
+                 dtype=cache_dtype,
+             )
+         else:
+             llm._cache.reset()
+         return llm._cache
+
+     # TODO: support batch generation
+     def generate(
+         self,
+         inputs: Optional[torch.Tensor] = None,
+         **kwargs
+     ) -> Union[GenerateOutput, torch.LongTensor]:
+         assert inputs.shape[0] == 1, 'Currently, only `batch_size=1` is supported'
+         _, inputs_embeds, labels, attention_mask = self.merge_multimodal(
+             text_input_ids=inputs,
+             text_attention_masks=kwargs.pop('attention_mask'),
+             text_labels=None,
+             pixel_values=kwargs.pop('pixel_values')
+         )
+         if getattr(self.generation_config, 'cache_implementation', None) == 'hybrid':  # mainly for Gemma2
+             kwargs['past_key_values'] = self._get_hybrid_cache_for_llm(
+                 kwargs.get("num_beams", 1), kwargs['max_new_tokens'] + inputs_embeds.shape[-2])
+             self.get_llm()._supports_cache_class = True
+             kwargs['cache_implementation'] = None
+
+         return self.llm.generate(inputs=None, inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
+
+
+ AutoConfig.register("ovis", OvisConfig)
+ AutoModelForCausalLM.register(OvisConfig, Ovis)
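
One detail worth calling out: `VisualEmbedding` accepts either integer token ids or the soft probability vectors emitted by the visual tokenizer. A minimal standalone sketch (illustrative only; assumes the package and its dependencies, including deepspeed, are installed):

# Illustrative sketch of VisualEmbedding's two input modes (not part of the diff).
import torch
from ovis.model.modeling_ovis import VisualEmbedding

vte = VisualEmbedding(num_embeddings=8, embedding_dim=4)
hard_ids = torch.tensor([1, 3])                        # integer ids -> regular embedding lookup
soft_rows = torch.softmax(torch.randn(2, 8), dim=-1)   # probability rows -> matmul with the weight matrix
print(vte(hard_ids).shape, vte(soft_rows).shape)       # both torch.Size([2, 4])
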
ovis/model/visual_tokenizer/base_visual_tokenizer.py ADDED
@@ -0,0 +1,264 @@
+ from typing import Union, Optional
+
+ import PIL.Image
+ import torch
+ from torch.nn.functional import softmax, gumbel_softmax, pad
+ from transformers import PretrainedConfig, PreTrainedModel, AutoImageProcessor, AutoModel, AutoConfig
+
+ from ovis.util.constants import IMAGE_INDICATOR_IDS, IMAGE_ATOM_ID
+
+
+ class BaseVisualTokenizerConfig(PretrainedConfig):
+     def __init__(
+         self,
+         vocab_size=16384,
+         tokenize_function="softmax",
+         tau=1.0,
+         depths=None,
+         drop_cls_token=False,
+         backbone_config: Optional[Union[PretrainedConfig, dict]] = None,
+         hidden_stride: int = 1,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.vocab_size = vocab_size
+         self.tokenize_function = tokenize_function
+         self.tau = tau
+         if isinstance(depths, str):
+             depths = [int(x) for x in depths.split('|')]
+         self.depths = depths
+         self.backbone_kwargs = {}
+         self.drop_cls_token = drop_cls_token
+         if backbone_config is not None:
+             assert isinstance(backbone_config, (PretrainedConfig, dict)), \
+                 f"expect `backbone_config` to be instance of PretrainedConfig or dict, but got {type(backbone_config)} type"
+             if not isinstance(backbone_config, PretrainedConfig):
+                 model_type = backbone_config['model_type']
+                 backbone_config.pop('model_type')
+                 backbone_config = AutoConfig.for_model(model_type, **backbone_config)
+         self.backbone_config = backbone_config
+         self.hidden_stride = hidden_stride
+
+
+ class BaseVisualTokenizer(PreTrainedModel):
+     base_model_prefix = "backbone"
+     main_input_name = None
+     _image_processor_class = None
+     _image_processor_kwargs = {}
+     _backbone_class = None
+     _backbone_name_or_path = None
+
+     def __init__(self, config: BaseVisualTokenizerConfig, *inputs, **kwargs):
+         super().__init__(config, *inputs, **kwargs)
+         if kwargs.get('train_from_scratch'):
+             self.image_processor = self._image_processor_class.from_pretrained(self._backbone_name_or_path,
+                                                                                **self._image_processor_kwargs)
+             self.backbone = self._backbone_class.from_pretrained(self._backbone_name_or_path,
+                                                                  **self.config.backbone_kwargs)
+             self.config.backbone_config = self.backbone.config
+         else:
+             self.image_processor = AutoImageProcessor.from_pretrained(kwargs['image_processor_name_or_path'])
+             self.backbone = AutoModel.from_config(self.config.backbone_config)
+         head_dim = self.config.vocab_size - len(IMAGE_INDICATOR_IDS)  # reserved tokens for IMAGE_INDICATORS
+         self.head = torch.nn.Sequential(
+             torch.nn.Linear(
+                 self.backbone.config.hidden_size * self.config.hidden_stride * self.config.hidden_stride, head_dim,
+                 bias=False
+             ),
+             torch.nn.LayerNorm(head_dim)
+         )
+
+         assert all((self.image_processor.do_resize,
+                     not getattr(self.image_processor, 'do_center_crop', False),
+                     self.image_processor.do_rescale,
+                     self.image_processor.do_normalize
+                     )), f"image_processor `{self.image_processor}` is not supported currently"
+
+     def get_backbone(self):
+         return self.backbone
+
+     def get_monitor_tensors(self):
+         raise NotImplementedError
+
+     def get_image_processor(self):
+         return self.image_processor
+
+     def mock_input(self):
+         height, width = self.get_image_size()
+         return torch.zeros(1, 3, height, width), self.construct_image_placeholders((1, 1))
+
+     def get_head(self):
+         return self.head
+
+     def get_image_size(self):
+         raise NotImplementedError
+
+     @staticmethod
+     def construct_image_placeholders(grid):
+         image_placeholders = [IMAGE_INDICATOR_IDS[0], IMAGE_ATOM_ID, IMAGE_INDICATOR_IDS[1]]
+         if grid[0] * grid[1] > 1:
+             for r in range(grid[0]):
+                 for c in range(grid[1]):
+                     image_placeholders.append(IMAGE_ATOM_ID)
+                     if c < grid[1] - 1:
+                         image_placeholders.append(IMAGE_INDICATOR_IDS[2])
+                 if r < grid[0] - 1:
+                     image_placeholders.append(IMAGE_INDICATOR_IDS[3])
+         image_placeholders.append(IMAGE_INDICATOR_IDS[4])
+         return image_placeholders
+
+     def preprocess_image(self, image: PIL.Image.Image, max_partition=9, covering_threshold=0.9, convert_to_rgb=True):
+         def _preprocess(img: PIL.Image.Image, side):
+             # first resize and preprocess
+             w, h = img.size
+             if w == h:
+                 new_width = new_height = side
+             elif w > h:
+                 new_width = side
+                 new_height = int(h / w * new_width)
+             else:
+                 new_height = side
+                 new_width = int(w / h * new_height)
+             new_size = dict(height=new_height, width=new_width)
+             pixel_values = self.image_processor.preprocess(img, size=new_size, return_tensors='pt')['pixel_values']
+
+             # then pad to square
+             square_values = torch.zeros([1, 3, side, side], dtype=pixel_values.dtype, device=pixel_values.device)
+             new_height, new_width = pixel_values.shape[2:]
+             if new_height == new_width:
+                 square_values[:, :, :, :] = pixel_values
+             elif new_height > new_width:
+                 from_index = (side - new_width) // 2
+                 square_values[:, :, :, from_index:from_index + new_width] = pixel_values
+             else:
+                 from_index = (side - new_height) // 2
+                 square_values[:, :, from_index:from_index + new_height, :] = pixel_values
+
+             return square_values
+
+         def _partition(img, grid):
+             w, h = img.size
+             row_height = h // grid[0]
+             col_width = w // grid[1]
+
+             partition = []
+             for row in range(grid[0]):
+                 for col in range(grid[1]):
+                     left = col * col_width
+                     upper = row * row_height
+                     right = w if col == grid[1] - 1 else (col + 1) * col_width
+                     lower = h if row == grid[0] - 1 else (row + 1) * row_height
+                     partition.append((left, upper, right, lower))
+
+             return partition
+
+         def _covering_area(left, upper, right, lower, side):
+             w = right - left
+             h = lower - upper
+             w, h = max(w, h), min(w, h)
+             if w > side:
+                 h = h / w * side
+                 w = side
+             return w * h
+
+         def _get_best_grid(img, side):
+             img_area = img.size[0] * img.size[1]
+
+             candidate_grids = []
+             for i in range(1, max_partition + 1):
+                 for j in range(1, max_partition + 1):
+                     if i * j <= max_partition:
+                         candidate_grids.append((i, j))
+
+             all_grids = []
+             good_grids = []
+             for grid in candidate_grids:
+                 partition = _partition(img, grid)
+                 covering_ratio = sum([_covering_area(*p, side) for p in partition]) / img_area
+                 assert covering_ratio <= 1.0
+                 all_grids.append((grid, covering_ratio))
+                 if covering_ratio > covering_threshold:
+                     good_grids.append((grid, covering_ratio))
+
+             if len(good_grids) > 0:
+                 # pick the good partition with minimum #sub_images and break the tie using covering_ratio
+                 return sorted(good_grids, key=lambda x: (x[0][0] * x[0][1], -x[1]))[0][0]
+             else:
+                 # pick the partition with maximum covering_ratio and break the tie using #sub_images
+                 return sorted(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0]
+
+         if convert_to_rgb and image.mode != 'RGB':
+             image = image.convert('RGB')
+
+         sides = self.get_image_size()
+         if sides[0] != sides[1]:
+             raise ValueError('get_image_size() returns non-square size')
+         side = sides[0]
+         grid = _get_best_grid(image, side)
+         partition = _partition(image, grid)
+         crops = [image.crop(p) for p in partition]
+         if len(crops) > 1:
+             crops.insert(0, image)
+         pixel_values = torch.cat([_preprocess(crop, side) for crop in crops], dim=0)
+         image_placeholders = self.construct_image_placeholders(grid)
+         return pixel_values, image_placeholders
+
+     def get_backbone_layer(self, index):
+         return self.backbone.vision_model.encoder.layers[index]
+
+     def tokenize(self, logits):
+         def st_argmax(y_soft, dim):  # straight-through argmax
+             index = y_soft.max(dim, keepdim=True)[1]
+             y_hard = torch.zeros_like(y_soft, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
+             ret = y_hard - y_soft.detach() + y_soft
+             return ret
+
+         if self.config.tokenize_function == 'softmax':
+             tokens = softmax(logits, dim=-1)
+         elif self.config.tokenize_function == 'gumbel_argmax':
+             tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
+         elif self.config.tokenize_function == 'st_argmax':
+             tokens = st_argmax(logits, dim=-1)
+         else:
+             raise ValueError(
+                 f'Invalid `tokenize_function`, expected softmax, gumbel_argmax or st_argmax, but got {self.config.tokenize_function}')
+         return tokens
+
+     def encode(self, pixel_values):
+         output = self.backbone(pixel_values, output_hidden_states=True, return_dict=True)
+         features = output.hidden_states[-1]
+         if self.config.drop_cls_token:
+             features = features[:, 1:, :]
+
+         # merge `hidden_stride * hidden_stride` hidden states together to reduce the token sequence length
+         # e.g., for hidden_stride=3, this leads to a token length reduction: 729 -> 81 for siglip
+         if self.config.hidden_stride > 1:
+             n, l, d = features.shape
+             sqrt_l = int(l ** 0.5)
+             assert sqrt_l ** 2 == l, "The token sequence length should be a perfect square."
+             features = features.reshape(n, sqrt_l, sqrt_l, d)
+             pl = (self.config.hidden_stride - (sqrt_l % self.config.hidden_stride)) % self.config.hidden_stride
+             features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0)
+             sqrt_l += pl
+             features = features.reshape(n, sqrt_l // self.config.hidden_stride, self.config.hidden_stride,
+                                         sqrt_l // self.config.hidden_stride, self.config.hidden_stride, d)
+             features = features.permute(0, 1, 3, 2, 4, 5)  # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d]
+             features = features.flatten(3)  # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d]
+             features = features.reshape(
+                 n, -1, self.config.hidden_stride * self.config.hidden_stride * d)
+
+         return features
+
+     def forward(self, pixel_values) -> torch.Tensor:  # [BatchSize, ImageShape] -> [BatchSize, #Token, VocabSize]
+         features = self.encode(pixel_values)
+         logits = self.head(features)
+         tokens = self.tokenize(logits)
+         # tokens' shape is [BatchSize, #Token, VocabSize-5], so pad with a zero tensor of shape
+         # [BatchSize, #Token, 5] to make the final shape [BatchSize, #Token, VocabSize]
+         batch_size, token_len, _ = tokens.shape
+         padding_tensor = torch.zeros(size=(batch_size, token_len, len(IMAGE_INDICATOR_IDS)),
+                                      dtype=tokens.dtype,
+                                      device=tokens.device,
+                                      layout=tokens.layout,
+                                      requires_grad=False)
+         tokens = torch.cat((tokens, padding_tensor), dim=2)
+         return tokens
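
As a small illustration (not part of the diff) of the placeholder layout that `construct_image_placeholders` builds for a partitioned image:

# Illustrative sketch: placeholder sequence for a 2x2 partition grid.
from ovis.model.visual_tokenizer.base_visual_tokenizer import BaseVisualTokenizer
from ovis.util.constants import IMAGE_ATOM_ID

placeholders = BaseVisualTokenizer.construct_image_placeholders((2, 2))
# One leading atom for the overview image, then one atom per crop, separated by
# column/row indicator ids and closed by the final indicator id.
print(placeholders.count(IMAGE_ATOM_ID))  # 5 atoms: 1 overview + 4 crops
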
ovis/model/visual_tokenizer/clip_visual_tokenizer.py ADDED
@@ -0,0 +1,41 @@
+ from transformers import AutoConfig, AutoModel
+ from transformers import CLIPVisionModel, CLIPImageProcessor
+ from .base_visual_tokenizer import BaseVisualTokenizerConfig, BaseVisualTokenizer
+
+ MODEL_TYPE = "clip_visual_tokenizer"
+
+
+ class ClipVisualTokenizerConfig(BaseVisualTokenizerConfig):
+     model_type = MODEL_TYPE
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         if self.depths:
+             assert len(self.depths) == 1
+             self.backbone_kwargs['num_hidden_layers'] = self.depths[0]
+
+
+ class ClipVisualTokenizer(BaseVisualTokenizer):
+     config_class = ClipVisualTokenizerConfig
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["CLIPEncoderLayer"]
+     _image_processor_class = CLIPImageProcessor
+     _image_processor_kwargs = dict(do_center_crop=False)
+     _backbone_class = CLIPVisionModel
+     _backbone_name_or_path = "openai/clip-vit-large-patch14-336"
+
+     def get_monitor_tensors(self):
+         return dict(
+             backbone_bottom=self.backbone.vision_model.encoder.layers[0].self_attn.k_proj.weight,
+             backbone_top=self.backbone.vision_model.encoder.layers[-1].self_attn.out_proj.weight,
+             head=self.head[0].weight
+         )
+
+     def get_image_size(self):
+         height = self.image_processor.crop_size["height"]
+         width = self.image_processor.crop_size["width"]
+         return height, width
+
+
+ AutoConfig.register(MODEL_TYPE, ClipVisualTokenizerConfig)
+ AutoModel.register(ClipVisualTokenizerConfig, ClipVisualTokenizer)
ovis/model/visual_tokenizer/siglip_visual_tokenizer.py ADDED
@@ -0,0 +1,43 @@
+ from transformers import AutoConfig, AutoModel
+ from transformers import SiglipVisionModel, SiglipImageProcessor
+ from .base_visual_tokenizer import BaseVisualTokenizerConfig, BaseVisualTokenizer
+
+ MODEL_TYPE = "siglip_visual_tokenizer"
+
+
+ class SiglipVisualTokenizerConfig(BaseVisualTokenizerConfig):
+     model_type = MODEL_TYPE
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         if self.drop_cls_token:
+             self.drop_cls_token = False
+         if self.depths:
+             assert len(self.depths) == 1
+             self.backbone_kwargs['num_hidden_layers'] = self.depths[0]
+
+
+ class SiglipVisualTokenizer(BaseVisualTokenizer):
+     config_class = SiglipVisualTokenizerConfig
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["SiglipVisionTransformer"]
+     _image_processor_class = SiglipImageProcessor
+     _image_processor_kwargs = {}
+     _backbone_class = SiglipVisionModel
+     _backbone_name_or_path = "google/siglip-so400m-patch14-384"
+
+     def get_monitor_tensors(self):
+         return dict(
+             backbone_bottom=self.backbone.vision_model.encoder.layers[0].self_attn.k_proj.weight,
+             backbone_top=self.backbone.vision_model.encoder.layers[-1].self_attn.out_proj.weight,
+             head=self.head[0].weight
+         )
+
+     def get_image_size(self):
+         height = self.image_processor.size["height"]
+         width = self.image_processor.size["width"]
+         return height, width
+
+
+ AutoConfig.register(MODEL_TYPE, SiglipVisualTokenizerConfig)
+ AutoModel.register(SiglipVisualTokenizerConfig, SiglipVisualTokenizer)
ovis/serve/runner.py ADDED
@@ -0,0 +1,105 @@
+ from dataclasses import field, dataclass
+ from typing import Optional, Union, List
+
+ import torch
+ from PIL import Image
+
+ from ovis.model.modeling_ovis import Ovis
+ from ovis.util.constants import IMAGE_TOKEN
+
+
+ @dataclass
+ class RunnerArguments:
+     model_path: str
+     max_new_tokens: int = field(default=512)
+     do_sample: bool = field(default=False)
+     top_p: Optional[float] = field(default=None)
+     top_k: Optional[int] = field(default=None)
+     temperature: Optional[float] = field(default=None)
+     max_partition: int = field(default=9)
+
+
+ class OvisRunner:
+     def __init__(self, args: RunnerArguments):
+         self.model_path = args.model_path
+         self.dtype = torch.bfloat16
+         self.device = torch.cuda.current_device()
+         self.model = Ovis.from_pretrained(self.model_path, torch_dtype=self.dtype, multimodal_max_length=8192)
+         self.model = self.model.eval().to(device=self.device)
+         self.eos_token_id = self.model.generation_config.eos_token_id
+         self.text_tokenizer = self.model.get_text_tokenizer()
+         self.pad_token_id = self.text_tokenizer.pad_token_id
+         self.visual_tokenizer = self.model.get_visual_tokenizer()
+         self.conversation_formatter = self.model.get_conversation_formatter()
+         self.image_placeholder = IMAGE_TOKEN
+         self.max_partition = args.max_partition
+         self.gen_kwargs = dict(
+             max_new_tokens=args.max_new_tokens,
+             do_sample=args.do_sample,
+             top_p=args.top_p,
+             top_k=args.top_k,
+             temperature=args.temperature,
+             repetition_penalty=None,
+             eos_token_id=self.eos_token_id,
+             pad_token_id=self.pad_token_id,
+             use_cache=True
+         )
+
+     def preprocess(self, inputs: List[Union[Image.Image, str]]):
+         # for single-image, single-text inputs, ensure the image comes first
+         if len(inputs) == 2 and isinstance(inputs[0], str) and isinstance(inputs[1], Image.Image):
+             inputs = reversed(inputs)
+
+         # build query
+         query = ''
+         images = []
+         for data in inputs:
+             if isinstance(data, Image.Image):
+                 query += self.image_placeholder + '\n'
+                 images.append(data)
+             elif isinstance(data, str):
+                 query += data.replace(self.image_placeholder, '')
+             elif data is not None:
+                 raise RuntimeError(f'Invalid input type, expected `PIL.Image.Image` or `str`, but got {type(data)}')
+
+         # format conversation
+         prompt, input_ids, pixel_values = self.model.preprocess_inputs(
+             query, images, max_partition=self.max_partition)
+         attention_mask = torch.ne(input_ids, self.text_tokenizer.pad_token_id)
+         input_ids = input_ids.unsqueeze(0).to(device=self.device)
+         attention_mask = attention_mask.unsqueeze(0).to(device=self.device)
+         if pixel_values is not None:
+             pixel_values = [pixel_values.to(device=self.device, dtype=self.dtype)]
+         else:
+             pixel_values = [None]
+
+         return prompt, input_ids, attention_mask, pixel_values
+
+     def run(self, inputs: List[Union[Image.Image, str]]):
+         prompt, input_ids, attention_mask, pixel_values = self.preprocess(inputs)
+         output_ids = self.model.generate(
+             input_ids,
+             pixel_values=pixel_values,
+             attention_mask=attention_mask,
+             **self.gen_kwargs
+         )
+         output = self.text_tokenizer.decode(output_ids[0], skip_special_tokens=True)
+         input_token_len = input_ids.shape[1]
+         output_token_len = output_ids.shape[1]
+         response = dict(
+             prompt=prompt,
+             output=output,
+             prompt_tokens=input_token_len,
+             total_tokens=input_token_len + output_token_len
+         )
+         return response
+
+
+ if __name__ == '__main__':
+     runner_args = RunnerArguments(model_path='<model_path>')
+     runner = OvisRunner(runner_args)
+     image = Image.open('<image_path>')
+     text = '<prompt>'
+     response = runner.run([image, text])
+     print(response['output'])
ovis/serve/server.py ADDED
@@ -0,0 +1,41 @@
+ import argparse
+ import os.path
+
+ import gradio as gr
+ from gradio.components import Textbox, Image
+
+ from ovis.serve.runner import RunnerArguments, OvisRunner
+
+
+ class Server:
+     def __init__(self, runner: OvisRunner):
+         self.runner = runner
+
+     def __call__(self, image, text):
+         response = self.runner.run([image, text])
+         output = response["output"]
+         return output
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='Ovis Server')
+     parser.add_argument('--model_path', type=str, required=True)
+     parser.add_argument('--flagging_dir', type=str, default=os.path.expanduser('~/ovis-flagged'))
+     parser.add_argument('--max_partition', type=int, default=9)
+     parser.add_argument('--port', type=int, required=True)
+     args = parser.parse_args()
+
+     os.makedirs(args.flagging_dir, exist_ok=True)
+     runner_args = RunnerArguments(
+         model_path=args.model_path,
+         max_partition=args.max_partition
+     )
+     demo = gr.Interface(
+         fn=Server(OvisRunner(runner_args)),
+         inputs=[Image(type='pil', label='image'),
+                 Textbox(placeholder='Enter your text here...', label='prompt')],
+         outputs=gr.Markdown(),
+         title=args.model_path.split('/')[-1],
+         flagging_dir=args.flagging_dir
+     )
+     demo.launch(server_port=args.port)
ovis/train/__init__.py ADDED
File without changes
ovis/train/arguments.py ADDED
@@ -0,0 +1,48 @@
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ import transformers
+
+
+ @dataclass
+ class ModelArguments:
+     llm_name_or_path: Optional[str] = field(default=None)
+     visual_tokenizer_type: str = field(default=None)
+     visual_vocab_size: int = field(default=8192)
+     visual_drop_cls_token: bool = field(default=False)
+     visual_tokenize_function: str = field(default='softmax')
+     visual_tau: float = field(default=1.0)
+     visual_depths: Optional[str] = field(default=None)
+     visual_hidden_stride: int = field(default=1)
+     multimodal_max_length: int = field(default=2048)
+     conversation_formatter_class: str = field(default=None)
+     pad_token_id: Optional[int] = field(default=None)
+     llm_attn_implementation: Optional[str] = field(default=None)
+     disable_tie_weight: bool = field(default=False)
+
+
+ @dataclass
+ class TrainingArguments(transformers.TrainingArguments):
+     dataset_names: Optional[str] = field(default=None)  # pipe-separated, e.g. a|b|c
+     dataset_info: Optional[str] = field(default='dataset_info_v1_6')
+     ovis_pretrained_path: Optional[str] = field(default=None)
+     visual_tokenizer_pretrained_path: Optional[str] = field(default=None)
+     caption_template: Optional[str] = field(default=None)
+     stage: Optional[int] = field(default=None)
+     train_modules: Optional[str] = field(default=None)
+     cache_dir: Optional[str] = field(default=None)
+     optim: str = field(default="adamw_torch")
+     visual_max_tau: float = field(default=5.0)
+     visual_min_tau: float = field(default=0.05)
+     save_safetensors: bool = field(default=True)
+     monitor_step: int = field(default=100)
+     vte_re_init: bool = field(default=False)
+     text_max_length: int = field(default=1024)
+     max_partitions: str = field(default="9|1|1")
+
+     def __post_init__(self):
+         if self.gradient_checkpointing:
+             self.gradient_checkpointing_kwargs = {"use_reentrant": False}
+         if self.stage < 3:
+             self.save_safetensors = False
+         super().__post_init__()
ovis/train/callback.py ADDED
@@ -0,0 +1,37 @@
+ import deepspeed
+ import torch
+ from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
+
+ from ovis.util.constants import END_LINE, BEGIN_LINE
+ from ovis.util.utils import rank0_print
+
+
+ class TuneTauCallback(TrainerCallback):
+     def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+         visual_tokenizer = kwargs['model'].get_visual_tokenizer()
+         current_step = state.global_step
+         max_step = state.max_steps
+         ratio = current_step / max_step
+         visual_tokenizer.config.tau = args.visual_max_tau - (args.visual_max_tau - args.visual_min_tau) * ratio
+
+
+ class MonitorCallback(TrainerCallback):
+     def _monitoring(self, model, step):
+         with torch.no_grad():
+             with deepspeed.zero.GatheredParameters(model.get_monitor_tensors().values()):
+                 for k, v in model.get_monitor_tensors().items():
+                     rank0_print(BEGIN_LINE)
+                     rank0_print(f'{k} @ step {step} with sum: {v.sum().item()} and content: ')
+                     rank0_print(v)
+                     rank0_print(END_LINE)
+
+     def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+         model = kwargs['model']
+         step = state.global_step
+         if step % args.monitor_step == 0 or step == 10:  # also monitor at step 10 for a fast check
+             self._monitoring(model, step)
+
+     def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+         model = kwargs['model']
+         step = state.global_step
+         self._monitoring(model, step)
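
The tau schedule applied by `TuneTauCallback` is a plain linear anneal from `visual_max_tau` down to `visual_min_tau`. A standalone sketch (illustrative only) using the default values from `TrainingArguments` above and an assumed 1000-step run:

# Illustrative sketch of the linear tau anneal (defaults: visual_max_tau=5.0, visual_min_tau=0.05).
max_tau, min_tau, max_steps = 5.0, 0.05, 1000
for step in (0, 500, 1000):
    tau = max_tau - (max_tau - min_tau) * (step / max_steps)
    print(step, round(tau, 3))  # 0 -> 5.0, 500 -> 2.525, 1000 -> 0.05
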
ovis/train/dataset/__init__.py ADDED
File without changes
ovis/train/dataset/caption_dataset.py ADDED
@@ -0,0 +1,67 @@
1
+ import logging
+ from datetime import datetime
+ from typing import Dict
+
+ import pandas
+ import torch
+
+ from ovis.train.dataset.multimodal_dataset import MultimodalDataset
+ from ovis.util.constants import IMAGE_TOKEN, IGNORE_ID
+ from ovis.util.utils import rank0_print
+
+
+ class CaptionDataset(MultimodalDataset):
+
+     def load(self):
+         rank0_print(f"[{datetime.now()}] Loading dataset {self.name} from {self.meta_file} begin")
+         samples = pandas.read_parquet(self.meta_file, engine='pyarrow')
+         rank0_print(f"[{datetime.now()}] Loading dataset {self.name} end")
+         return samples
+
+     def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
+         sample = self.samples.iloc[i]
+         text = sample['caption']
+         image_path = sample['image_path']
+
+         # read and preprocess image
+         pixel_values, image_placeholders = self.visual_tokenizer.mock_input()
+         valid_image = False
+         image, e = self.read_image(image_path)
+         if image is None:
+             logging.warning(
+                 f'reading image failed with index: {i}, image path: {image_path}, and exception: {e}')
+         else:
+             try:
+                 pixel_values, image_placeholders = self.visual_tokenizer.preprocess_image(
+                     image, max_partition=self.max_partitions[0])
+                 valid_image = True
+             except Exception as e:
+                 logging.warning(
+                     f'preprocessing image failed with index: {i}, image path: {image_path}, and exception: {e}')
+
+         # preprocess text
+         if text is None:
+             logging.warning(f'text is `None`, index: {i}')
+             text = ""
+         if not valid_image:
+             logging.warning(f'image is not valid, so set text as empty, index: {i}, image path: {image_path}')
+             text = ""
+         text = text.replace(IMAGE_TOKEN, '').strip()
+         head, tail = self.caption_template.split(IMAGE_TOKEN)
+         head_ids = self.text_tokenizer(head, add_special_tokens=False).input_ids
+         tail_ids = self.text_tokenizer(tail, add_special_tokens=False).input_ids
+         text_ids = self.text_tokenizer(text, add_special_tokens=False).input_ids
+         input_ids = head_ids + image_placeholders + tail_ids + text_ids
+         labels = [IGNORE_ID] * (len(input_ids) - len(text_ids)) + text_ids
+
+         input_ids = input_ids[:self.text_max_length]
+         labels = labels[:self.text_max_length]
+
+         input_ids = torch.tensor(input_ids, dtype=torch.long)
+         labels = torch.tensor(labels, dtype=torch.long)
+
+         return dict(
+             pixel_values=pixel_values,
+             input_ids=input_ids,
+             labels=labels
+         )
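The label construction above supervises only the caption tokens: the prompt head, the image placeholder ids, and the prompt tail are all masked with IGNORE_ID so they contribute no loss. A toy sketch with made-up token ids (real ids come from the text and visual tokenizers):

```python
IGNORE_ID = -100

# Hypothetical token ids, for illustration only.
head_ids = [1, 2]            # tokenized text before <image> in the caption template
image_placeholders = [-200]  # placeholder ids emitted by the visual tokenizer
tail_ids = [3]               # tokenized text after <image>
text_ids = [7, 8, 9]         # tokenized caption

input_ids = head_ids + image_placeholders + tail_ids + text_ids
labels = [IGNORE_ID] * (len(input_ids) - len(text_ids)) + text_ids

print(input_ids)  # [1, 2, -200, 3, 7, 8, 9]
print(labels)     # [-100, -100, -100, -100, 7, 8, 9]
```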
ovis/train/dataset/conversation_dataset.py ADDED
@@ -0,0 +1,67 @@
+ import copy
+ import json
+ import logging
+ from datetime import datetime
+ from typing import Dict
+
+ import torch
+
+ from ovis.train.dataset.multimodal_dataset import MultimodalDataset
+ from ovis.util.utils import rank0_print
+
+
+ class ConversationDataset(MultimodalDataset):
+     def load(self):
+         rank0_print(f"[{datetime.now()}] Loading dataset {self.name} from {self.meta_file} begin")
+         with open(self.meta_file, 'r', encoding='utf-8') as f:
+             samples = json.load(f)
+         rank0_print(f'#samples: {len(samples)}')
+         rank0_print(f'sample: {samples[0]}')
+         rank0_print(f"[{datetime.now()}] Loading dataset {self.name} end")
+         return samples
+
+     def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
+         sample = self.samples[i]
+         conversations = copy.deepcopy(sample["conversations"])
+
+         images = None
+         max_partition = None
+         if 'image' in sample:
+             image_paths = sample['image']
+             if isinstance(image_paths, str):
+                 image_paths = [image_paths]
+             images = []
+             for image_path in image_paths:
+                 image, e = self.read_image(image_path)
+                 if image is None:
+                     logging.warning(
+                         f'reading image failed with index: {i}, image path: {image_path}, and exception: {e}')
+                     images = None
+                     break
+                 images.append(image)
+         elif 'video' in sample:
+             raise RuntimeError('video is not supported yet')
+
+         if images:
+             max_partition = self.max_partitions[0] if len(images) == 1 else self.max_partitions[1]
+
+         prompt, input_ids, pixel_values, labels = self.model.preprocess_inputs(
+             conversations,
+             images,
+             max_partition=max_partition,
+             generation_preface=None,
+             return_labels=True,
+             propagate_exception=False
+         )
+
+         if pixel_values is None:
+             pixel_values, _ = self.visual_tokenizer.mock_input()
+
+         input_ids = input_ids[:self.text_max_length]
+         labels = labels[:self.text_max_length]
+
+         return dict(
+             pixel_values=pixel_values,
+             input_ids=input_ids,
+             labels=labels
+         )
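For reference, a hypothetical entry of the conversation meta_file. ConversationDataset itself only reads the top-level keys (`conversations`, and optionally `image` as a string or a list of strings); the inner message structure is consumed by `model.preprocess_inputs`, which is not shown here, so the LLaVA-style from/value layout below is an assumption:

```python
# Hypothetical sample from the meta_file JSON list; paths and messages are invented.
sample = {
    "image": "images/0001.jpg",  # may also be a list of image paths
    "conversations": [
        {"from": "human", "value": "<image>\nWhat is in the picture?"},
        {"from": "gpt", "value": "A cat sitting on a windowsill."}
    ]
}
```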
ovis/train/dataset/multimodal_dataset.py ADDED
@@ -0,0 +1,72 @@
+ import logging
+ import os
+ from typing import Dict, Sequence, Union, List
+
+ import torch
+ from PIL import Image
+ from torch.utils.data import Dataset
+ from transformers import PreTrainedTokenizer
+
+ from ovis.model.modeling_ovis import Ovis
+ from ovis.train.arguments import TrainingArguments
+ from ovis.util.constants import IGNORE_ID
+
+
+ class MultimodalDataset(Dataset):
+     def __init__(self, name: str, info: Dict, model: Ovis, training_args: TrainingArguments):
+         self.name = name
+         self.meta_file = info['meta_file']
+         self.image_dir = info['image_dir']
+         self.caption_template = info.get('caption_template', None)
+         self.text_tokenizer = model.get_text_tokenizer()
+         self.visual_tokenizer = model.get_visual_tokenizer()
+         self.image_height, self.image_width = self.visual_tokenizer.get_image_size()
+         self.model = model
+         self.text_max_length = training_args.text_max_length
+         self.max_partitions = [int(m.strip()) for m in training_args.max_partitions.split('|')]
+         self.samples = self.load()
+
+     def load(self):
+         raise NotImplementedError
+
+     def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
+         raise NotImplementedError
+
+     def __len__(self):
+         return len(self.samples)
+
+     def read_image(self, path):
+         try:
+             full_path = os.path.join(self.image_dir, path)
+             image = Image.open(full_path).convert('RGB')
+             return image, None
+         except Exception as e:
+             return None, e
+
+
+ class DataCollatorForMultimodalDataset:
+     def __init__(self, text_tokenizer: PreTrainedTokenizer):
+         self.text_tokenizer = text_tokenizer
+
+     def __call__(self, instances: Sequence[Dict]) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
+         pixel_values, input_ids, labels = tuple([instance[key] for instance in instances]
+                                                 for key in ("pixel_values", "input_ids", "labels"))
+         input_ids = torch.nn.utils.rnn.pad_sequence(
+             input_ids,
+             batch_first=True,
+             padding_value=self.text_tokenizer.pad_token_id)
+         attention_mask = torch.ne(input_ids, self.text_tokenizer.pad_token_id)
+         labels = torch.nn.utils.rnn.pad_sequence(
+             labels,
+             batch_first=True,
+             padding_value=IGNORE_ID)
+         num_valid_label = torch.not_equal(labels, IGNORE_ID).sum().item()
+         if num_valid_label == 0:
+             logging.warning(
+                 f'[DataCollatorForMultimodalDataset] All labels in a batch are ignored, which may lead to training instability\n{input_ids=}\n{attention_mask=}\n{labels=}')
+         return dict(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             labels=labels,
+             pixel_values=pixel_values
+         )
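A minimal usage sketch of `DataCollatorForMultimodalDataset`, assuming the ovis package is importable. The collator only reads `pad_token_id` from the tokenizer, so a stand-in object is used here instead of a real `PreTrainedTokenizer`, and the tensor shapes and ids are purely illustrative:

```python
import torch
from types import SimpleNamespace

from ovis.train.dataset.multimodal_dataset import DataCollatorForMultimodalDataset

# Stand-in "tokenizer": the collator only reads pad_token_id from it.
fake_tokenizer = SimpleNamespace(pad_token_id=0)
collator = DataCollatorForMultimodalDataset(fake_tokenizer)

instances = [
    dict(pixel_values=torch.zeros(1, 3, 224, 224),  # shapes are illustrative
         input_ids=torch.tensor([5, 6, 7]),
         labels=torch.tensor([-100, 6, 7])),
    dict(pixel_values=torch.zeros(2, 3, 224, 224),
         input_ids=torch.tensor([5, 6]),
         labels=torch.tensor([-100, 6])),
]

batch = collator(instances)
print(batch['input_ids'])           # [[5, 6, 7], [5, 6, 0]] -- padded with pad_token_id
print(batch['attention_mask'])      # [[True, True, True], [True, True, False]]
print(batch['labels'])              # [[-100, 6, 7], [-100, 6, -100]] -- padded with IGNORE_ID
print(type(batch['pixel_values']))  # list: per-sample image tensors are not stacked
```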
ovis/train/train.py ADDED
@@ -0,0 +1,206 @@
+ import json
+ import os
+ import pathlib
+
+ import deepspeed
+ import torch
+ import transformers
+ from deepspeed import get_accelerator
+ from torch.utils.data import ConcatDataset
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoConfig
+ from transformers import Trainer
+ from transformers.integrations.deepspeed import unset_hf_deepspeed_config, set_hf_deepspeed_config
+
+ from callback import TuneTauCallback, MonitorCallback
+ from ovis.model.configuration_ovis import OvisConfig
+ from ovis.model.modeling_ovis import Ovis
+ from ovis.train.arguments import ModelArguments, TrainingArguments
+ from ovis.train.dataset.caption_dataset import CaptionDataset
+ from ovis.train.dataset.conversation_dataset import ConversationDataset
+ from ovis.train.dataset.multimodal_dataset import DataCollatorForMultimodalDataset
+ from ovis.util.constants import BEGIN_LINE, END_LINE
+ from ovis.util.utils import smart_unit, rank0_print
+
+
+ def train():
+     # parse args
+     parser = transformers.HfArgumentParser(
+         (ModelArguments, TrainingArguments))
+     model_args, training_args = parser.parse_args_into_dataclasses()
+
+     # save args to checkpoint dir
+     with training_args.main_process_first(local=False):
+         if training_args.process_index == 0:
+             def args2dict(args):
+                 return {k: str(v) for k, v in args.__dict__.items()}
+
+             args_log = json.dumps(dict(
+                 model_args=args2dict(model_args),
+                 training_args=args2dict(training_args)
+             ), ensure_ascii=False, indent=2)
+             print(args_log)
+             os.makedirs(training_args.output_dir, exist_ok=True)
+             with open(os.path.join(training_args.output_dir, 'model_training_args.json'), 'w',
+                       encoding='utf-8') as f:
+                 f.write(args_log + '\n')
+
+     # construct or load ovis model
+     if not training_args.ovis_pretrained_path:  # construct model (S1)
+         # 1. construct ovis config
+         ovis_config = OvisConfig(
+             multimodal_max_length=model_args.multimodal_max_length,
+             conversation_formatter_class=model_args.conversation_formatter_class,
+             llm_attn_implementation=model_args.llm_attn_implementation
+         )
+         # 2. load pretrained llm and text tokenizer
+         attn_kwargs = dict()
+         if model_args.llm_attn_implementation:
+             attn_kwargs['attn_implementation'] = model_args.llm_attn_implementation
+         llm = AutoModelForCausalLM.from_pretrained(model_args.llm_name_or_path, **attn_kwargs)
+         text_tokenizer = AutoTokenizer.from_pretrained(model_args.llm_name_or_path)
+         if text_tokenizer.pad_token_id is None and model_args.pad_token_id is not None:
+             text_tokenizer.pad_token_id = model_args.pad_token_id
+         # 3. construct visual tokenizer
+         # deepspeed zero.Init with bfloat16 fails for visual_tokenizer, so temporarily disable zero.Init here
+         unset_hf_deepspeed_config()
+         if training_args.visual_tokenizer_pretrained_path is not None:
+             visual_tokenizer = AutoModel.from_pretrained(
+                 training_args.visual_tokenizer_pretrained_path,
+                 image_processor_name_or_path=training_args.visual_tokenizer_pretrained_path
+             )
+         else:
+             visual_tokenizer_config = AutoConfig.for_model(
+                 model_type=model_args.visual_tokenizer_type + "_visual_tokenizer",
+                 vocab_size=model_args.visual_vocab_size,
+                 tokenize_function=model_args.visual_tokenize_function,
+                 tau=model_args.visual_tau,
+                 depths=model_args.visual_depths,
+                 drop_cls_token=model_args.visual_drop_cls_token,
+                 hidden_stride=model_args.visual_hidden_stride,
+             )
+             visual_tokenizer = AutoModel.from_config(visual_tokenizer_config, train_from_scratch=True)
+         visual_tokenizer = visual_tokenizer.to(
+             device=torch.device(get_accelerator().device_name(os.getenv("LOCAL_RANK"))))
+         if getattr(training_args, 'hf_deepspeed_config', None) is not None:
+             set_hf_deepspeed_config(training_args.hf_deepspeed_config)
+         # 4. construct ovis model
+         model = Ovis(ovis_config, llm=llm, text_tokenizer=text_tokenizer, visual_tokenizer=visual_tokenizer,
+                      train_from_scratch=True)
+     else:  # load pretrained ovis model
+         model, loading_info = Ovis.from_pretrained(training_args.ovis_pretrained_path,
+                                                    multimodal_max_length=model_args.multimodal_max_length,
+                                                    output_loading_info=True)
+         rank0_print(BEGIN_LINE)
+         rank0_print(f'Loading info of Ovis:\n{loading_info}')
+         rank0_print(END_LINE)
+         training_args.vte_re_init = False
+
+     model.get_llm().config.use_cache = False
+     model.config.use_cache = False
+     text_tokenizer = model.get_text_tokenizer()
+
+     rank0_print(BEGIN_LINE)
+     rank0_print(f'model.config:\n{model.config}')
+     rank0_print(END_LINE)
+
+     # maybe re-init vte
+     if training_args.vte_re_init:
+         with deepspeed.zero.GatheredParameters([model.get_wte().weight]):
+             mean = model.get_wte().weight.mean().item()
+             std = model.get_wte().weight.std().item()
+             rank0_print(f'Statistics of embedding table of LLM: {mean=}, {std=}')
+         model.re_init_vte(mean, std)
+
+     # select train modules
+     model.requires_grad_(False)
+     for module in training_args.train_modules.split('|'):
+         if module == 'all':
+             model.requires_grad_(True)
+         elif module == 'llm':
+             model.get_llm().requires_grad_(True)
+         elif module == 'visual_tokenizer':
+             model.get_visual_tokenizer().requires_grad_(True)
+         elif module == 'visual_tokenizer.backbone':
+             model.get_visual_tokenizer().get_backbone().requires_grad_(True)
+         elif module.startswith('visual_tokenizer.backbone.layer.'):
+             layer_index = int(module[len('visual_tokenizer.backbone.layer.'):])
+             layer = model.get_visual_tokenizer().get_backbone_layer(layer_index)
+             layer.requires_grad_(True)
+         elif module == 'visual_tokenizer.head':
+             model.get_visual_tokenizer().get_head().requires_grad_(True)
+         elif module == 'vte':
+             model.get_vte().requires_grad_(True)
+         else:
+             raise ValueError(f'Invalid train module name: {module}')
+
+     rank0_print(BEGIN_LINE)
+     rank0_print('Parameters to train:')
+     for name, param in model.named_parameters():
+         if param.requires_grad:
+             rank0_print(name)
+     rank0_print(f'LLM\'s attn implementation: {model.get_llm().config._attn_implementation}')
+     rank0_print(END_LINE)
+
+     # construct data module
+     datasets = []
+     dataset_info_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
+                                      f'dataset/{training_args.dataset_info}.json')
+     with open(dataset_info_path, 'r', encoding='utf-8') as f:
+         dataset_info = json.load(f)
+     for name in training_args.dataset_names.split('|'):
+         info = dataset_info[name]
+         data_format = info['data_format']
+         if data_format == 'caption':
+             dataset = CaptionDataset(name, info, model, training_args)
+         elif data_format == 'conversation':
+             dataset = ConversationDataset(name, info, model, training_args)
+         else:
+             raise ValueError(f'Invalid data format `{data_format}` for dataset `{name}`')
+         datasets.append(dataset)
+     data_module = dict(
+         train_dataset=ConcatDataset(datasets),
+         data_collator=DataCollatorForMultimodalDataset(text_tokenizer)
+     )
+
+     # train
+     train_callbacks = [MonitorCallback]
+     if model_args.visual_tokenize_function == 'gumbel_argmax':
+         train_callbacks.append(TuneTauCallback)
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         callbacks=train_callbacks,
+         **data_module
+     )
+     rank0_print(BEGIN_LINE)
+     rank0_print('Dataset sample tensor:')
+     rank0_print(data_module['train_dataset'][0])
+     rank0_print(END_LINE)
+     rank0_print(BEGIN_LINE)
+     rank0_print('Dataset sample input_ids decoding:')
+     rank0_print(text_tokenizer.decode([x for x in data_module['train_dataset'][0]['input_ids'] if x >= 0]))
+     rank0_print(END_LINE)
+     rank0_print(BEGIN_LINE)
+     rank0_print('Dataset sample labels decoding:')
+     rank0_print(text_tokenizer.decode([x for x in data_module['train_dataset'][0]['labels'] if x >= 0]))
+     rank0_print(END_LINE)
+     rank0_print(BEGIN_LINE)
+     rank0_print(f'#param of model: {smart_unit(model.num_parameters())}')
+     rank0_print(f'#param of llm: {smart_unit(model.get_llm().num_parameters())}')
+     rank0_print(f'#param of visual_tokenizer: {smart_unit(model.get_visual_tokenizer().num_parameters())}')
+     rank0_print(f'#param of vte: {smart_unit(model.get_vte().weight.numel())}')
+     rank0_print(END_LINE)
+     if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+         trainer.train(resume_from_checkpoint=True)
+     else:
+         trainer.train()
+     trainer.save_state()
+
+     # save model
+     model.get_llm().config.use_cache = True
+     model.config.use_cache = True
+     trainer.save_model()
+
+
+ if __name__ == '__main__':
+     train()
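`train()` looks up dataset metadata in `dataset/<dataset_info>.json` next to train.py. A hypothetical such file is sketched below, with invented dataset names and paths, restricted to the keys the dataset classes above actually read (`data_format`, `meta_file`, `image_dir`, and `caption_template` for caption data). Datasets defined this way would then be selected with something like `--dataset_names "coco_caption|llava_chat"`:

```python
import json

# Hypothetical dataset_info content (e.g. ovis/train/dataset/pretrain_v1.json).
# Note that caption_template must contain the <image> token, since CaptionDataset
# splits the template on it.
dataset_info = {
    "coco_caption": {
        "data_format": "caption",
        "meta_file": "/data/coco/meta.parquet",
        "image_dir": "/data/coco/images",
        "caption_template": "<image>\nDescribe the image.\n"
    },
    "llava_chat": {
        "data_format": "conversation",
        "meta_file": "/data/llava/meta.json",
        "image_dir": "/data/llava/images"
    }
}
print(json.dumps(dataset_info, indent=2))
```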
ovis/util/__init__.py ADDED
File without changes
ovis/util/constants.py ADDED
@@ -0,0 +1,11 @@
+ # Model Constants
+ IGNORE_ID = -100
+ IMAGE_TOKEN_ID = -200
+ IMAGE_TOKEN = "<image>"
+
+ IMAGE_ATOM_ID = -300
+ IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305]
+
+ # Log & Print
+ BEGIN_LINE = '========================************========================'
+ END_LINE = '------------------------------------------------------------'
ovis/util/utils.py ADDED
@@ -0,0 +1,26 @@
+ import os
+ from importlib import import_module
+
+
+ def rank0_print(*args):
+     if int(os.getenv("LOCAL_PROCESS_RANK", os.getenv("LOCAL_RANK", 0))) == 0:
+         print(*args)
+
+
+ def smart_unit(num):
+     if num / 1.0e9 >= 1:
+         return f'{num / 1.0e9:.2f}B'
+     else:
+         return f'{num / 1.0e6:.2f}M'
+
+
+ def import_class_from_string(full_class_string):
+     # Split the path to get separate module and class names
+     module_path, _, class_name = full_class_string.rpartition('.')
+
+     # Import the module using the module path
+     module = import_module(module_path)
+
+     # Get the class from the imported module
+     cls = getattr(module, class_name)
+     return cls
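A quick usage check for the helpers above, assuming the ovis package is importable; the class path in the last call is from the standard library, chosen only to keep the snippet self-contained:

```python
from collections import OrderedDict

from ovis.util.utils import smart_unit, import_class_from_string

print(smart_unit(7_100_000_000))  # '7.10B'
print(smart_unit(350_000_000))    # '350.00M'

# Resolve a class from its dotted path, as done for conversation_formatter_class.
cls = import_class_from_string('collections.OrderedDict')
assert cls is OrderedDict
```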