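"""VisRAG_Ret: a retrieval head built on MiniCPM-V.

The forward pass tokenizes the text prompts, preprocesses (and optionally
slices) the page images, merges the vision features into the text embedding
sequence, runs the underlying LLM, and returns the last hidden states
together with the attention mask so callers can pool them into dense
embeddings for retrieval.
"""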
import torch
from dataclasses import dataclass
from transformers.utils import ModelOutput
from typing import Optional
from .modeling_minicpmv import MiniCPMV
from .modeling_minicpm import MiniCPMForCausalLM
from .resampler import Resampler
from concurrent.futures import ThreadPoolExecutor
def transform_image_mp(img_list, transform, device, max_workers=None):
    """Apply `transform` to every image in a nested list, in parallel, and move the results to `device`."""
    pixel_values = []
    # Run the per-image transform in a ThreadPoolExecutor so preprocessing of a batch overlaps.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for img_batch in img_list:
            img_inps = list(executor.map(transform, img_batch))
            for i in range(len(img_inps)):
                img_inps[i] = img_inps[i].to(device)
            pixel_values.append(img_inps if img_inps else [])
    return pixel_values
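
# Illustrative shapes for transform_image_mp (assuming `transform` maps a PIL.Image
# to a CHW float tensor, as the MiniCPM-V image transform does):
#   img_list     -> [[img_a, img_b], [img_c], []]   # one list of PIL images per example
#   return value -> [[t_a, t_b], [t_c], []]         # same nesting, tensors moved to `device`
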
@dataclass
class BaseModelOutputWithAttentionMask(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    attention_mask: Optional[torch.Tensor] = None

class VisRAG_Ret(MiniCPMV):  # MiniCPMV subclass; ultimately a causal LM
    def fused_tokenize(
        self,
        data_list=None,                 # List[str]
        img_list=None,                  # List[List[PIL.Image]]
        tokenizer=None,
        max_inp_length: Optional[int] = None,
        vision_hidden_states=None,      # default None
        return_vision_hidden_states=False,
        **kwargs):
        """Tokenize the texts and preprocess the images into a single model_inputs dict."""
        assert data_list is not None
        bs = len(data_list)
        if img_list is None:
            img_list = [[] for _ in range(bs)]
        assert bs == len(img_list)

        model_inputs = self._process_list(tokenizer, data_list, max_inp_length, padding_side="right")

        if vision_hidden_states is None:
            pixel_values = transform_image_mp(img_list, self.transform, self.device, max_workers=8)
            model_inputs["pixel_values"] = pixel_values
        else:
            model_inputs["vision_hidden_states"] = vision_hidden_states

        return model_inputs
    def prepare_context(self, inputs, tokenizer):
        """Build the prompt string for one example and collect its (possibly sliced) images."""
        text_, image_ = inputs
        if not isinstance(text_, str):
            raise NotImplementedError(f"chatml format expected, outermost type should be str but got {type(text_)}")
        # 1. add text
        content = text_
        # 2. add image
        if image_:
            if self.config.slice_mode:
                # Crop one image into multiple sub-images -> List[Image].
                images, final_placeholder = self.get_slice_image_placeholder(image_, tokenizer)
                content = final_placeholder + "\n" + content
            else:
                # Keep the single image without cropping -> List[Image].
                images = [image_]
                content = (
                    tokenizer.im_start
                    + tokenizer.unk_token * self.config.query_num
                    + tokenizer.im_end
                    + "\n"
                    + content
                )
        else:
            images = []
        return content, images
    def forward(
        self,
        text,                   # List[str], one string per example
        image,                  # List[PIL.Image], one image per example
        tokenizer,
        vision_hidden_states=None,
        max_inp_length=2048,
        **kwargs):
        processed_image = []
        processed_text = []
        # Build prompts and collect images for the whole batch in parallel.
        with ThreadPoolExecutor(max_workers=8) as executor:
            contexts = list(executor.map(lambda inputs: self.prepare_context(inputs, tokenizer), zip(text, image)))
        for content_, image_ in contexts:
            processed_text.append(content_)
            processed_image.append(image_)

        model_inputs = self.fused_tokenize(
            data_list=processed_text,    # List[str]
            img_list=processed_image,    # List[List[PIL.Image]]
            tokenizer=tokenizer,
            max_inp_length=max_inp_length,
        )

        # Vision encoder forward: merges the image features into the text embedding sequence.
        model_inputs["inputs_embeds"], vision_hidden_states = self.get_vllm_embedding(model_inputs)

        vlm_outputs = self.llm.model(
            input_ids=None,  # image and text are already merged into model_inputs["inputs_embeds"], so no input_ids
            position_ids=None,
            inputs_embeds=model_inputs["inputs_embeds"],
            attention_mask=model_inputs["attention_mask"],
            return_dict=True,
        )
        return BaseModelOutputWithAttentionMask(
            last_hidden_state=vlm_outputs.last_hidden_state,
            attention_mask=model_inputs.attention_mask,
        )
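
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the model definition). It assumes the
# published checkpoint id "openbmb/VisRAG-Ret", a CUDA device, and a local
# image file "page.png"; the masked mean pooling below is an illustrative
# choice for collapsing last_hidden_state into one embedding per example,
# and the upstream retrieval pipeline may pool differently.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoModel, AutoTokenizer

    model_id = "openbmb/VisRAG-Ret"  # assumed checkpoint id
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.float16).eval().cuda()

    queries = ["What does the chart on this page show?"]
    pages = [Image.open("page.png").convert("RGB")]  # hypothetical document page image

    with torch.no_grad():
        # Text-only queries: pass a None image per example (prepare_context then skips the image branch).
        q_out = model(text=queries, image=[None] * len(queries), tokenizer=tokenizer)
        # Image-only pages: pass an empty string per example.
        p_out = model(text=[""] * len(pages), image=pages, tokenizer=tokenizer)

    def mean_pool(out):
        # Mask out padding positions, average over the sequence, and L2-normalize.
        mask = out.attention_mask.unsqueeze(-1).float()
        emb = (out.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
        return torch.nn.functional.normalize(emb, p=2, dim=-1)

    scores = mean_pool(q_out) @ mean_pool(p_out).T  # cosine similarities, shape (num_queries, num_pages)
    print(scores)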