import torch

from dataclasses import dataclass
from typing import Optional
from concurrent.futures import ThreadPoolExecutor

from transformers.utils import ModelOutput

from .modeling_minicpmv import MiniCPMV
from .modeling_minicpm import MiniCPMForCausalLM
from .resampler import Resampler


def transform_image_mp(img_list, transform, device, max_workers=None):
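    """Apply `transform` to each image in `img_list` with a thread pool and
    move the resulting tensors to `device`.

    Returns one list of pixel-value tensors per input batch.
    """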
    pixel_values = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for img_batch in img_list:
            # Transform the images of one batch in parallel, then move the
            # resulting tensors onto the target device.
            img_inps = list(executor.map(transform, img_batch))
            pixel_values.append([inp.to(device) for inp in img_inps])

    return pixel_values


@dataclass
class BaseModelOutputWithAttentionMask(ModelOutput):
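    """Model output that carries the attention mask alongside the last hidden
    state, so callers can pool over non-padded positions."""
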
    last_hidden_state: Optional[torch.FloatTensor] = None
    attention_mask: Optional[torch.Tensor] = None


class VisRAG_Ret(MiniCPMV):
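    """MiniCPM-V based retrieval model: encodes interleaved text and images
    into hidden states that can be pooled into dense embeddings."""
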
    def fused_tokenize(
        self,
        data_list=None,
        img_list=None,
        tokenizer=None,
        max_inp_length: Optional[int] = None,
        vision_hidden_states=None,
        return_vision_hidden_states=False,
        **kwargs):
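        """Tokenize a batch of prompts and attach the image tensors (or
        precomputed vision hidden states) to the model inputs."""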
        assert data_list is not None
        bs = len(data_list)
        if img_list is None:
            img_list = [[] for _ in range(bs)]
        assert bs == len(img_list)

        # Tokenize and right-pad the text prompts.
        model_inputs = self._process_list(tokenizer, data_list, max_inp_length, padding_side="right")

        if vision_hidden_states is None:
            # Preprocess the raw images in parallel threads.
            pixel_values = transform_image_mp(img_list, self.transform, self.device, max_workers=8)
            model_inputs["pixel_values"] = pixel_values
        else:
            model_inputs["vision_hidden_states"] = vision_hidden_states

        return model_inputs

    def prepare_context(self, inputs, tokenizer):
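        """Turn one (text, image) pair into a prompt string plus the list of
        images to encode, prepending the image placeholder tokens."""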
        text_, image_ = inputs
        if not isinstance(text_, str):
            raise NotImplementedError(f"chatml format expected: the outermost type must be str, but got {type(text_)}")

        content = text_

        if image_:
            if self.config.slice_mode:
                # Slice the image into sub-images and build the matching placeholder.
                images, final_placeholder = self.get_slice_image_placeholder(
                    image_, tokenizer
                )
                content = final_placeholder + "\n" + content
            else:
                images = [image_]
                content = (
                    tokenizer.im_start
                    + tokenizer.unk_token * self.config.query_num
                    + tokenizer.im_end
                    + "\n"
                    + content
                )
        else:
            images = []

        return content, images

    def forward(
        self,
        text,
        image,
        tokenizer,
        vision_hidden_states=None,
        max_inp_length=2048,
        **kwargs):
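        """Encode a batch of (text, image) pairs and return the last hidden
        states together with the attention mask for downstream pooling."""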
        processed_image = []
        processed_text = []

        # Build the per-sample prompts in parallel threads.
        with ThreadPoolExecutor(max_workers=8) as executor:
            contexts = list(executor.map(lambda inputs: self.prepare_context(inputs, tokenizer), zip(text, image)))

        for content_, image_ in contexts:
            processed_text.append(content_)
            processed_image.append(image_)

        model_inputs = self.fused_tokenize(
            data_list=processed_text,
            img_list=processed_image,
            tokenizer=tokenizer,
            max_inp_length=max_inp_length,
        )

model_inputs["inputs_embeds"], vision_hidden_states = self.get_vllm_embedding(model_inputs) |
|
vlm_outputs = self.llm.model( |
|
input_ids=None, |
|
position_ids=None, |
|
inputs_embeds=model_inputs["inputs_embeds"], |
|
attention_mask=model_inputs["attention_mask"], |
|
return_dict=True |
|
) |
|
|
|
        return BaseModelOutputWithAttentionMask(
            last_hidden_state=vlm_outputs.last_hidden_state,
            attention_mask=model_inputs["attention_mask"],
        )
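

# A minimal usage sketch (hedged: the checkpoint id below is an assumption, and
# the masked mean-pooling shown is one common way to reduce the output to a
# single embedding, not necessarily the exact recipe used upstream):
#
#   from transformers import AutoTokenizer
#   from PIL import Image
#
#   path = "openbmb/VisRAG-Ret"  # assumed checkpoint id
#   tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
#   model = VisRAG_Ret.from_pretrained(path, trust_remote_code=True).eval()
#
#   image = Image.open("page.png").convert("RGB")
#   outputs = model(text=[""], image=[image], tokenizer=tokenizer)
#
#   mask = outputs.attention_mask.unsqueeze(-1).float()
#   embedding = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)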