kppkkp committed
Commit 73bcc49
1 Parent(s): 3daadf9

Upload 3 files

Files changed (3)
  1. config.json +9 -6
  2. modeling_OneChart.py +496 -0
  3. sam_vision_b.py +468 -0
config.json CHANGED
@@ -1,11 +1,15 @@
 {
-  "_name_or_path": "/data/hypertext/xpkong/detvary/checkpoints/318-train_chartqax10-min_max_normalize_l1/checkpoint-51500/",
+  "_name_or_path": "kppkkp/OneChart",
   "_remove_final_layer_norm": false,
   "activation_dropout": 0.0,
   "activation_function": "relu",
   "architectures": [
-    "MMGPTOPTForCausalLM"
+    "OneChartOPTForCausalLM"
   ],
+  "auto_map": {
+    "AutoConfig": "modeling_OneChart.OneChartConfig",
+    "AutoModel": "modeling_OneChart.OneChartOPTForCausalLM"
+  },
   "attention_dropout": 0.0,
   "bos_token_id": 2,
   "do_layer_norm_before": true,
@@ -18,15 +22,15 @@
   "im_end_token": 50267,
   "im_patch_token": 50265,
   "im_start_token": 50266,
+  "number_token": 50268,
   "image_token_len": 256,
   "init_std": 0.02,
   "layer_norm_elementwise_affine": true,
   "layerdrop": 0.0,
   "max_position_embeddings": 4096,
-  "model_type": "mmgpt",
+  "model_type": "OneChart",
   "num_attention_heads": 12,
   "num_hidden_layers": 12,
-  "number_token": 50268,
   "pad_token_id": 1,
   "prefix": "</s>",
   "torch_dtype": "bfloat16",
@@ -34,7 +38,6 @@
   "use_cache": true,
   "use_im_start_end": true,
   "vision_select_layer": -2,
-  "vision_tower": "/mnt/host0/vit-large-patch14",
   "vocab_size": 50269,
   "word_embed_proj_dim": 768
-}
+}
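
The key functional change above is the new "auto_map" block, which lets transformers resolve the custom OneChartConfig / OneChartOPTForCausalLM classes shipped in this commit. A minimal loading sketch (not part of the commit; it assumes you are willing to run the repo's remote code, and leaves device/dtype placement to the chat example further below):

# Loading sketch: auto_map routes AutoConfig/AutoModel to modeling_OneChart.py.
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("kppkkp/OneChart", trust_remote_code=True)
print(config.model_type)    # "OneChart", as set in this diff
print(config.number_token)  # 50268, the <Number> token id used by the number head

model = AutoModel.from_pretrained(
    "kppkkp/OneChart",
    trust_remote_code=True,  # required so AutoModel maps to OneChartOPTForCausalLM
)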
modeling_OneChart.py ADDED
@@ -0,0 +1,496 @@
from transformers import OPTConfig, OPTModel, OPTForCausalLM, StoppingCriteria, TextStreamer
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from typing import List, Optional, Tuple, Union
import requests
from PIL import Image
from io import BytesIO
import json
import re
import torch
import numpy as np
import torch.nn as nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from .sam_vision_b import build_SAM_vit_b
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import dataclasses

DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'

from enum import auto, Enum


class SeparatorStyle(Enum):
    """Different separator styles."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "<|im_end|>"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep + '\n'
            for role, message in self.messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.MPT:
            if self.system:
                ret = self.system + self.sep
            else:
                ret = ''
            for role, message in self.messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2)


class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
        self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1]
        self.tokenizer = tokenizer
        self.start_len = None
        self.input_ids = input_ids

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if self.start_len is None:
            self.start_len = self.input_ids.shape[1]
        else:
            for keyword_id in self.keyword_ids:
                if output_ids[0, -1] == keyword_id:
                    return True
            outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
            for keyword in self.keywords:
                if keyword in outputs:
                    return True
        return False


conv_vicuna_v1_1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)


class OneChartImageEvalProcessor:
    def __init__(self, image_size=1024):
        mean = (0., 0., 0.)
        std = (1., 1., 1.)
        self.normalize = transforms.Normalize(mean, std)
        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)


class OneChartConfig(OPTConfig):
    model_type = "OneChart"


class OneChartModel(OPTModel):
    config_class = OneChartConfig

    def __init__(self, config: OPTConfig):
        super(OneChartModel, self).__init__(config)
        self.vision_tower = build_SAM_vit_b()
        self.mm_projector = nn.Linear(1024, 768)

    def embed_tokens(self, x):
        return self.get_input_embeddings()(x)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        vision_tower_high = getattr(self, 'vision_tower', None)
        if vision_tower_high is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
            use_im_start_end = getattr(self.config, "use_im_start_end", -1)
            vision_select_layer = getattr(self.config, "vision_select_layer", -1)
            im_patch_token = getattr(self.config, "im_patch_token", -1)
            im_start_token = getattr(self.config, "im_start_token", -1)
            im_end_token = getattr(self.config, "im_end_token", -1)
            freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False)

            # Encode each image with the SAM vision tower and project it to the LM width.
            image_features = []
            for image in images:
                P, C, H, W = image.shape
                if P == 1:
                    with torch.set_grad_enabled(False):
                        cnn_feature = vision_tower_high(image)
                        cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1)  # 256 tokens x 1024 dims
                    image_feature = self.mm_projector(cnn_feature)
                    image_features.append(image_feature)
                else:
                    raise NotImplementedError("Batch inference needs to be implemented.")

            # Splice the projected image features between the <img> and </img> tokens.
            use_im_start_end = True
            new_input_embeds = []
            for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
                if use_im_start_end:
                    if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
                        raise ValueError("The number of image start tokens and image end tokens should be the same.")

                    image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
                    for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
                        per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
                        num_patches = per_cur_image_features.shape[0]

                        if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
                            raise ValueError("The image end token should follow the image start token.")

                        cur_input_embeds = torch.cat(
                            (
                                cur_input_embeds[:image_start_token_pos + 1],
                                per_cur_image_features,
                                cur_input_embeds[image_start_token_pos + num_patches + 1:]
                            ),
                            dim=0
                        )

                    new_input_embeds.append(cur_input_embeds)
                else:
                    raise NotImplementedError

            inputs_embeds = torch.stack(new_input_embeds, dim=0)

        return super(OneChartModel, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )


class OneChartOPTForCausalLM(OPTForCausalLM):
    config_class = OneChartConfig

    def __init__(self, config):
        super(OneChartOPTForCausalLM, self).__init__(config)
        self.model = OneChartModel(config)
        self.vocab_size = config.vocab_size
        # Auxiliary number head: maps the hidden state at the <Number> token to up to 256 values.
        self.num_decoder = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size // 2),
            nn.ReLU(),
            nn.Linear(config.hidden_size // 2, config.hidden_size // 2),
            nn.ReLU(),
            nn.Linear(config.hidden_size // 2, 256),
        )
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.pred_locs = []
        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        loc_labels=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            images=images,
            return_dict=return_dict
        )

        hidden_states = outputs[0]
        if (loc_labels is not None) and len(loc_labels) > 0:
            det_patch_token = torch.where(input_ids == self.config.number_token)[1][0]
            pred_locs = self.num_decoder(hidden_states[:, det_patch_token, :])  # shape: [batch_size, 256]

        # At inference time, expose the values predicted by the number head.
        if not self.training:
            try:
                det_patch_token = torch.where(input_ids == self.config.number_token)[1][0]
                pred_locs = self.num_decoder(hidden_states[:, det_patch_token, :])  # shape: [batch_size, 256]
                self.pred_locs = pred_locs[0][:100].cpu().tolist()
            except Exception as e:
                pass

        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
        token_type_ids = kwargs.get("token_type_ids", None)
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "position_ids": position_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs

    def load_image(self, image_file):
        if image_file.startswith('http') or image_file.startswith('https'):
            response = requests.get(image_file)
            image = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            image = Image.open(image_file).convert('RGB')
        return image

    def disable_torch_init(self):
        """
        Disable the redundant torch default initialization to accelerate model creation.
        """
        setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
        setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

    def chat(self, tokenizer, image_file, reliable_check=True, print_prompt=False):
        dtype = torch.bfloat16
        device = "cuda"

        def list_json_value(json_dict):
            """Flatten the numeric values of a (possibly nested) dict into a list of floats."""
            rst_str = []
            sort_flag = True
            try:
                for key, value in json_dict.items():
                    if isinstance(value, dict):
                        decimal_out = list_json_value(value)
                        rst_str = rst_str + decimal_out
                        sort_flag = False
                    elif isinstance(value, list):
                        return []
                    else:
                        if isinstance(value, float) or isinstance(value, int):
                            rst_str.append(value)
                        else:
                            value = re.sub(r'\(\d+\)|\[\d+\]', '', value)
                            num_value = re.sub(r'[^\d.-]', '', str(value))
                            if num_value not in ["-", "*", "none", "None", ""]:
                                rst_str.append(float(num_value))
            except Exception as e:
                print(f"Error: {e}")
                return []
            return rst_str

        def norm_(rst_list):
            """Min-max normalize a list of numbers to [0, 1]."""
            if len(rst_list) < 2:
                return rst_list
            min_vals = min(rst_list)
            max_vals = max(rst_list)
            rst_list = np.array(rst_list)
            normalized_tensor = (rst_list - min_vals) / (max_vals - min_vals + 1e-9)
            return list(normalized_tensor)

        self.disable_torch_init()
        image_processor_high = OneChartImageEvalProcessor(image_size=1024)
        use_im_start_end = True
        image_token_len = 256
        image = self.load_image(image_file)
        image_tensor_1 = image_processor_high(image).to(dtype=dtype, device=device)

        query = 'Convert the key information of the chart to a python dict:'
        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN + query + '\n'
        conv = conv_vicuna_v1_1.copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        if print_prompt:
            print(prompt)

        inputs = tokenizer([prompt])
        input_ids = torch.as_tensor(inputs.input_ids).to(device=device)
        stop_str = '</s>'
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        with torch.autocast(device, dtype=dtype):
            output_ids = self.generate(
                input_ids,
                images=[image_tensor_1.unsqueeze(0).half()],
                do_sample=False,
                num_beams=1,
                # no_repeat_ngram_size=20,
                # streamer=streamer,
                max_new_tokens=4096,
                stopping_criteria=[stopping_criteria]
            )
        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
        outputs = outputs.replace("<Number>", "")
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        response_str = outputs

        if reliable_check:
            # Compare the number head's predictions against the values parsed from the generated
            # dict; a small L1 distance indicates a self-consistent (reliable) prediction.
            pred_nums = self.pred_locs
            try:
                outputs_json = json.loads(outputs)
                list_v = list_json_value(outputs_json['values'])
                list_v = [round(x, 4) for x in norm_(list_v)]
                gt_nums = torch.tensor(list_v).reshape(1, -1)
                response_str = response_str + "\n<Chart>: " + str(pred_nums[:len(list_v)])
                pred_nums_ = torch.tensor(pred_nums[:len(list_v)]).reshape(1, -1)
                reliable_distence = F.l1_loss(pred_nums_, gt_nums)
                response_str = response_str + "\nreliable_distence: " + str(reliable_distence)
                if reliable_distence < 0.1:
                    response_str = response_str + "\nAfter OneChart checking, this prediction is reliable."
                else:
                    response_str = response_str + "\nThis prediction may have errors!"
            except Exception as e:
                response_str = response_str + "\nThis prediction may have errors!"
                response_str = response_str + "\n" + str(e)

        return response_str
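
For reference, a usage sketch for the chat() entry point defined above (not part of the commit; it assumes a CUDA GPU, that the repo also ships standard tokenizer files, and a hypothetical local image path "chart.png"):

# Usage sketch for OneChart's chat() method.
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("kppkkp/OneChart", trust_remote_code=True)
model = AutoModel.from_pretrained("kppkkp/OneChart", trust_remote_code=True)
model = model.eval().to(device="cuda", dtype=torch.bfloat16)  # chat() hard-codes cuda + bf16

# Prompts "Convert the key information of the chart to a python dict:" and, with
# reliable_check=True, appends the number-head consistency check to the response.
response = model.chat(tokenizer, "chart.png", reliable_check=True, print_prompt=False)
print(response)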
sam_vision_b.py ADDED
@@ -0,0 +1,468 @@
import torch
import torch.nn.functional as F
from typing import Optional, Tuple, Type
from functools import partial
import torch.nn as nn


class MLPBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin2(self.act(self.lin1(x)))


class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


class ImageEncoderViT(nn.Module):
    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (list): Indexes for blocks using global attention.
        """
        super().__init__()
        self.img_size = img_size

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
            )

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
            )
            self.blocks.append(block)

        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dim,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

        self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
        self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        for blk in self.blocks:
            x = blk(x)

        x = self.neck(x.permute(0, 3, 1, 2))
        x = self.net_2(x)
        x = self.net_3(x)

        return x


class Block(nn.Module):
    """Transformer blocks with support of window attention and residual propagation blocks"""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, then
                use global attention.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size, window_size),
        )

        self.norm2 = norm_layer(dim)
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)

        self.window_size = window_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = self.norm1(x)
        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + x
        x = x + self.mlp(self.norm2(x))

        return x


class Attention(nn.Module):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            # initialize relative positional embeddings
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, H * W, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (B * nHead, H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)

        if self.use_rel_pos:
            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))

        attn = attn.softmax(dim=-1)
        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        x = self.proj(x)

        return x


def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """
    Window unpartition into original sequences and removing padding.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x


def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        # Interpolate rel pos.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]


def add_decomposed_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    """
    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)

    return attn


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x


def build_SAM_vit_b(checkpoint=None):
    return _build_SAM_vision(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
    )


def _build_SAM_vision(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    image_encoder = ImageEncoderViT(
        depth=encoder_depth,
        embed_dim=encoder_embed_dim,
        img_size=image_size,
        mlp_ratio=4,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
        num_heads=encoder_num_heads,
        patch_size=vit_patch_size,
        qkv_bias=True,
        use_rel_pos=True,
        global_attn_indexes=encoder_global_attn_indexes,
        window_size=14,
        out_chans=prompt_embed_dim,
    )

    return image_encoder
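
As a sanity check, here is a small shape-verification sketch (not part of the commit, illustrative only): it builds the encoder above with randomly initialized weights and confirms that a 1024x1024 input yields a [1, 1024, 16, 16] feature map, i.e. the 256 tokens of width 1024 that OneChartModel's mm_projector (Linear(1024, 768)) and the config's image_token_len = 256 expect.

# Shape-check sketch for the SAM ViT-B vision tower defined above.
import torch

encoder = build_SAM_vit_b()  # no checkpoint is loaded here
with torch.no_grad():
    feats = encoder(torch.zeros(1, 3, 1024, 1024))
print(feats.shape)  # torch.Size([1, 1024, 16, 16])

tokens = feats.flatten(2).permute(0, 2, 1)  # same reshaping as OneChartModel.forward
print(tokens.shape)  # torch.Size([1, 256, 1024])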