"""OneChart: an OPT causal language model with a SAM ViT-B vision tower.

This module bundles the prompt/conversation helpers, the image preprocessor,
the vision-augmented OPT backbone and the causal-LM head (including the
auxiliary number decoder used by the reliability check in ``chat``).
"""

import dataclasses
import json
import re
from enum import auto, Enum
from io import BytesIO
from typing import List, Optional, Tuple, Union

import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.nn import CrossEntropyLoss
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers import OPTConfig, OPTModel, OPTForCausalLM, StoppingCriteria, TextStreamer
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

from .sam_vision_b import build_SAM_vit_b

# NOTE(review): all four token literals are empty strings in this copy of the
# file. Presumably they should be the tokenizer's image special tokens --
# confirm against the tokenizer vocabulary before relying on `chat`.
DEFAULT_IMAGE_TOKEN = ""
DEFAULT_IMAGE_PATCH_TOKEN = ''
DEFAULT_IM_START_TOKEN = ''
DEFAULT_IM_END_TOKEN = ''


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "<|im_end|>"
    sep2: Optional[str] = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        """Render the conversation into a single prompt string.

        A message may be a plain string or a 3-tuple whose first element is
        the text. A falsy message (``None``/empty) emits only the role
        prefix, leaving the turn open for the model to complete.

        Raises:
            ValueError: if ``sep_style`` is not a known ``SeparatorStyle``.
        """
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep + '\n'
            for role, message in self.messages:
                if message:
                    if isinstance(message, tuple):
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            # Alternate between the two separators, one per turn.
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if isinstance(message, tuple):
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.MPT:
            if self.system:
                ret = self.system + self.sep
            else:
                ret = ''
            for role, message in self.messages:
                if message:
                    if isinstance(message, tuple):
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        """Append one ``[role, message]`` turn to the history."""
        self.messages.append([role, message])

    def copy(self):
        """Return a copy whose ``messages`` list is independent of this one."""
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            # Previously dropped, which reset every copy to "Unknown".
            version=self.version)


class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any keyword appears in the newly generated text."""

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        # Keep only keywords that tokenize to exactly one id, so they can be
        # matched cheaply against the last generated token.
        self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
        self.keyword_ids = [
            keyword_id[0]
            for keyword_id in self.keyword_ids
            if isinstance(keyword_id, list) and len(keyword_id) == 1
        ]
        self.tokenizer = tokenizer
        self.start_len = None
        self.input_ids = input_ids

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when generation should stop.

        The first call only records the prompt length and never stops.
        """
        if self.start_len is None:
            self.start_len = self.input_ids.shape[1]
            return False

        last_token = output_ids[0, -1]
        for keyword_id in self.keyword_ids:
            if last_token == keyword_id:
                return True

        outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
        # An empty keyword is a substring of every string; skip it so it
        # cannot end generation immediately.
        return any(keyword and keyword in outputs for keyword in self.keywords)


# NOTE(review): `sep2` is an empty string in this copy of the file; confirm
# whether an end-of-sequence marker was intended.
conv_vicuna_v1_1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="",
)


class OneChartImageEvalProcessor:
    """Resize an image to a square, convert it to a tensor and normalize it."""

    def __init__(self, image_size=1024):
        # Zero mean / unit std makes the Normalize step an identity.
        mean = (0., 0., 0.)
        std = (1., 1., 1.)

        self.normalize = transforms.Normalize(mean, std)

        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)


class OneChartConfig(OPTConfig):
    """OPT configuration registered under the ``OneChart`` model type."""
    model_type = "OneChart"


class OneChartModel(OPTModel):
    """OPT backbone that splices projected vision features into the input embeddings."""
    config_class = OneChartConfig

    def __init__(self, config: OPTConfig):
        super().__init__(config)
        self.vision_tower = build_SAM_vit_b()
        # Projects 1024-dim vision features to 768-dim embeddings.
        self.mm_projector = nn.Linear(1024, 768)

    def embed_tokens(self, x):
        """Look up token embeddings through the model's input embedding layer."""
        return self.get_input_embeddings()(x)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the OPT decoder, replacing image placeholder embeddings when images are given.

        Args:
            input_ids: token ids; required whenever ``images`` is provided,
                because the image start/end positions are located in them.
            images: iterable of 4-D tensors unpacked as ``(P, C, H, W)``, one
                per sequence in the batch. Only ``P == 1`` is supported.
            Other arguments are forwarded to ``OPTModel.forward``.

        Raises:
            ValueError: if images are given without ``input_ids``, or the
                image start/end tokens are unbalanced or misplaced.
            NotImplementedError: if an image entry has ``P != 1``.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        vision_tower_high = getattr(self, 'vision_tower', None)

        # Test `images` first so that a text-only call made with
        # `inputs_embeds` (input_ids=None) no longer fails on `.shape`.
        if vision_tower_high is not None and images is not None:
            if input_ids is None:
                raise ValueError("`input_ids` must be provided together with `images`.")

            # A single-token step outside training is skipped: no image
            # features are spliced for it.
            if input_ids.shape[1] != 1 or self.training:
                im_start_token = getattr(self.config, "im_start_token", -1)
                im_end_token = getattr(self.config, "im_end_token", -1)

                image_features = []
                for image in images:
                    P, C, H, W = image.shape
                    if P != 1:
                        raise NotImplementedError("Batch inference needs to be implemented.")
                    # The vision tower always runs without gradients.
                    with torch.no_grad():
                        cnn_feature = vision_tower_high(image)
                        cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1)  # 256*1024
                    image_feature = self.mm_projector(cnn_feature)
                    image_features.append(image_feature)

                new_input_embeds = []
                for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
                    if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
                        raise ValueError("The number of image start tokens and image end tokens should be the same.")

                    image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
                    for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
                        per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
                        num_patches = per_cur_image_features.shape[0]

                        # Exactly `num_patches` placeholders must sit between
                        # the start token and its matching end token.
                        if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
                            raise ValueError("The image end token should follow the image start token.")

                        cur_input_embeds = torch.cat(
                            (
                                cur_input_embeds[:image_start_token_pos + 1],
                                per_cur_image_features,
                                cur_input_embeds[image_start_token_pos + num_patches + 1:],
                            ),
                            dim=0,
                        )

                    new_input_embeds.append(cur_input_embeds)

                inputs_embeds = torch.stack(new_input_embeds, dim=0)

        return super().forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )


def _collect_numeric_values(json_dict):
    """Flatten the numeric leaves of a (possibly nested) dict into a list.

    Nested dicts are traversed recursively. A list value anywhere makes the
    current level return ``[]``. String leaves have ``(n)`` / ``[n]``
    annotations removed and are then reduced to digits, ``.`` and ``-``
    before conversion to ``float``. Any error is printed and yields ``[]``.
    """
    rst_str = []
    try:
        for value in json_dict.values():
            if isinstance(value, dict):
                rst_str = rst_str + _collect_numeric_values(value)
            elif isinstance(value, list):
                return []
            elif isinstance(value, (float, int)):
                rst_str.append(value)
            else:
                value = re.sub(r'\(\d+\)|\[\d+\]', '', value)
                num_value = re.sub(r'[^\d.-]', '', str(value))
                if num_value not in ["-", "*", "none", "None", ""]:
                    rst_str.append(float(num_value))
    except Exception as e:
        # Best effort: a malformed structure yields no values rather than
        # aborting the caller.
        print(f"Error: {e}")
        return []
    return rst_str


def _min_max_normalize(rst_list):
    """Min-max scale a list of numbers; lists shorter than 2 are returned as-is."""
    if len(rst_list) < 2:
        return rst_list
    min_vals = min(rst_list)
    max_vals = max(rst_list)
    rst_list = np.array(rst_list)
    # The small epsilon avoids division by zero when all values are equal.
    normalized_tensor = (rst_list - min_vals) / (max_vals - min_vals + 1e-9)
    return list(normalized_tensor)


class OneChartOPTForCausalLM(OPTForCausalLM):
    """OPT causal LM over ``OneChartModel`` with an auxiliary number decoder."""
    config_class = OneChartConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = OneChartModel(config)
        self.vocab_size = config.vocab_size
        # Maps one hidden state to 256 outputs; `forward` stores the first
        # 100 of them in `self.pred_locs` during evaluation.
        self.num_decoder = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size // 2),
            nn.ReLU(),
            nn.Linear(config.hidden_size // 2, config.hidden_size // 2),
            nn.ReLU(),
            nn.Linear(config.hidden_size // 2, 256),
        )
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.pred_locs = []
        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the underlying ``OneChartModel``."""
        return self.model

    def _update_pred_locs(self, input_ids, hidden_states):
        """Refresh ``self.pred_locs`` from the hidden state at the number token.

        Leaves ``self.pred_locs`` untouched when ``input_ids`` is missing,
        the config has no ``number_token``, or the token is absent from
        ``input_ids`` (for example on single-token decoding steps).
        """
        number_token = getattr(self.config, "number_token", None)
        if input_ids is None or number_token is None:
            return
        positions = torch.where(input_ids == number_token)[1]
        if positions.numel() == 0:
            return
        pred_locs = self.num_decoder(hidden_states[:, positions[0], :])  # shape: [batch_size, 256]
        self.pred_locs = pred_locs[0][:100].cpu().tolist()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        loc_labels=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute LM logits (and the shifted cross-entropy loss when ``labels`` is given).

        Outside training, the number decoder output for the first sequence
        is stored in ``self.pred_locs`` whenever the number token is found
        in ``input_ids``. ``token_type_ids``, ``position_ids``, ``head_mask``
        and the encoder arguments are accepted but not used here.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            images=images,
            return_dict=return_dict
        )

        hidden_states = outputs[0]

        if (loc_labels is not None) and len(loc_labels) > 0:
            # NOTE(review): this prediction is computed but not consumed
            # anywhere in this file (no auxiliary loss is added) -- confirm
            # whether a loss term is missing.
            det_patch_token = torch.where(input_ids == self.config.number_token)[1][0]
            pred_locs = self.num_decoder(hidden_states[:, det_patch_token, :])  # shape: [batch_size, 256]

        # At inference time, expose the values predicted by the number head.
        if not self.training:
            self._update_pred_locs(input_ids, hidden_states)

        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
        """Build the keyword arguments for one generation step.

        When a cache is present only the last token (and its position /
        token type) is kept. ``inputs_embeds`` is used solely for the first
        step, i.e. when ``past_key_values`` is ``None``.
        """
        token_type_ids = kwargs.get("token_type_ids", None)
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # Derive positions from the mask; masked slots are set to 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "position_ids": position_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs

    def load_image(self, image_file):
        """Load an image from a local path or an HTTP(S) URL as an RGB ``PIL.Image``.

        Raises:
            requests.HTTPError: if the URL responds with an error status.
        """
        if image_file.startswith(('http://', 'https://')):
            # A timeout prevents an unresponsive server from hanging the call.
            response = requests.get(image_file, timeout=30)
            # Fail on HTTP errors instead of trying to decode an error page.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            image = Image.open(image_file).convert('RGB')
        return image

    def disable_torch_init(self):
        """
        Disable the redundant torch default initialization to accelerate model creation.

        This replaces ``reset_parameters`` on the ``torch.nn.Linear`` and
        ``torch.nn.LayerNorm`` classes themselves, so it affects every such
        layer created afterwards in the process.
        """
        setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
        setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

    def chat(self, tokenizer, image_file, reliable_check=True, print_prompt=False):
        """Extract a chart's key information as text and optionally self-check it.

        Args:
            tokenizer: tokenizer used to encode the prompt and decode the output.
            image_file: local path or HTTP(S) URL of the chart image.
            reliable_check: when True, compare the min-max normalized numbers
                parsed from the generated JSON (``values`` key) with the
                number-decoder predictions and append a verdict.
            print_prompt: when True, print the rendered prompt.

        Returns:
            The generated text, followed by the reliability report if requested.

        Runs on CUDA in bfloat16 autocast.
        """
        dtype = torch.bfloat16
        device = "cuda"

        self.disable_torch_init()
        # Drop predictions left over from a previous call so the check below
        # never compares against another image's numbers.
        self.pred_locs = []

        image_processor_high = OneChartImageEvalProcessor(image_size=1024)
        image_token_len = 256

        image = self.load_image(image_file)
        image_tensor_1 = image_processor_high(image).to(dtype=dtype, device=device)

        query = 'Convert the key information of the chart to a python dict:'
        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN + query + '\n'
        conv = conv_vicuna_v1_1.copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        if print_prompt:
            print(prompt)

        inputs = tokenizer([prompt])
        input_ids = torch.as_tensor(inputs.input_ids).to(device=device)
        # NOTE(review): the stop string is empty in this copy of the file;
        # confirm the intended end-of-sequence marker.
        stop_str = ''
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        with torch.autocast(device, dtype=dtype):
            output_ids = self.generate(
                input_ids,
                images=[image_tensor_1.unsqueeze(0).half()],
                do_sample=False,
                num_beams=1,
                max_new_tokens=4096,
                stopping_criteria=[stopping_criteria]
            )
        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
        # NOTE(review): both arguments are empty here, making this a no-op;
        # presumably a marker token was meant to be stripped -- confirm.
        outputs = outputs.replace("", "")
        outputs = outputs.strip()
        # Guard against an empty stop string: `outputs[:-0]` would otherwise
        # discard the whole answer.
        if stop_str and outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        response_str = outputs

        if reliable_check:
            pred_nums = self.pred_locs
            try:
                outputs_json = json.loads(outputs)
                list_v = _collect_numeric_values(outputs_json['values'])
                list_v = [round(x, 4) for x in _min_max_normalize(list_v)]
                gt_nums = torch.tensor(list_v).reshape(1, -1)
                response_str = response_str + "\n: " + str(pred_nums[:len(list_v)])
                pred_nums_ = torch.tensor(pred_nums[:len(list_v)]).reshape(1, -1)
                reliable_distence = F.l1_loss(pred_nums_, gt_nums)
                response_str = response_str + "\nreliable_distence: " + str(reliable_distence)
                if reliable_distence < 0.1:
                    response_str = response_str + "\nAfter OneChart checking, this prediction is reliable."
                else:
                    response_str = response_str + "\nThis prediction may be has error! "
            except Exception as e:
                # Any parsing or shape problem is reported in the response
                # rather than raised.
                response_str = response_str + "\nThis prediction may be has error! "
                response_str = response_str + "\n" + str(e)

        return response_str