Allow for multiple images per batch sample.
Browse files- processing_phi3_v.py +24 -9
processing_phi3_v.py
CHANGED
@@ -20,14 +20,19 @@ import re
|
|
20 |
from typing import List, Optional, Union
|
21 |
|
22 |
import torch
|
23 |
-
|
24 |
import transformers
|
25 |
from transformers.feature_extraction_utils import BatchFeature
|
26 |
from transformers.image_utils import ImageInput
|
27 |
from transformers.processing_utils import ProcessorMixin
|
28 |
-
from transformers.tokenization_utils_base import
|
|
|
|
|
|
|
|
|
29 |
from transformers.utils import TensorType
|
30 |
-
|
|
|
|
|
31 |
transformers.Phi3VImageProcessor = Phi3VImageProcessor
|
32 |
|
33 |
class Phi3VProcessor(ProcessorMixin):
|
@@ -144,13 +149,25 @@ class Phi3VProcessor(ProcessorMixin):
|
|
144 |
return self.tokenizer.convert_tokens_to_ids(self.special_image_token)
|
145 |
|
146 |
def _convert_images_texts_to_inputs(self, images, texts, padding=False, truncation=None, max_length=None, return_tensors=None):
|
147 |
-
|
148 |
if not len(images):
|
149 |
model_inputs = self.tokenizer(texts, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length)
|
150 |
return BatchFeature(data={**model_inputs})
|
151 |
|
152 |
pattern = r"<\|image_\d+\|>"
|
153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
|
155 |
if 'num_img_tokens' in images:
|
156 |
num_img_tokens = images['num_img_tokens']
|
@@ -161,10 +178,8 @@ class Phi3VProcessor(ProcessorMixin):
|
|
161 |
|
162 |
images, image_sizes = images['pixel_values'], images['image_sizes']
|
163 |
|
164 |
-
|
165 |
-
|
166 |
-
# image_ids = [int(s.split("|")[1].split("_")[-1]) * -1 for s in image_tags]
|
167 |
-
# image_ids_pad = [[iid]*num_img_tokens[i] for i, iid in enumerate(image_ids)]
|
168 |
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
|
169 |
unique_image_ids = sorted(list(set(image_ids)))
|
170 |
# image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be [1, 4, 5]
|
|
|
20 |
from typing import List, Optional, Union
|
21 |
|
22 |
import torch
|
|
|
23 |
import transformers
|
24 |
from transformers.feature_extraction_utils import BatchFeature
|
25 |
from transformers.image_utils import ImageInput
|
26 |
from transformers.processing_utils import ProcessorMixin
|
27 |
+
from transformers.tokenization_utils_base import (
|
28 |
+
PaddingStrategy,
|
29 |
+
TextInput,
|
30 |
+
TruncationStrategy,
|
31 |
+
)
|
32 |
from transformers.utils import TensorType
|
33 |
+
|
34 |
+
from .image_processing_phi3_v import Phi3VImageProcessor
|
35 |
+
|
36 |
transformers.Phi3VImageProcessor = Phi3VImageProcessor
|
37 |
|
38 |
class Phi3VProcessor(ProcessorMixin):
|
|
|
149 |
return self.tokenizer.convert_tokens_to_ids(self.special_image_token)
|
150 |
|
151 |
def _convert_images_texts_to_inputs(self, images, texts, padding=False, truncation=None, max_length=None, return_tensors=None):
|
|
|
152 |
if not len(images):
|
153 |
model_inputs = self.tokenizer(texts, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length)
|
154 |
return BatchFeature(data={**model_inputs})
|
155 |
|
156 |
pattern = r"<\|image_\d+\|>"
|
157 |
+
|
158 |
+
# Don't over list-comprehend this, it's already hard to read.
|
159 |
+
prompt_chunks = []
|
160 |
+
image_tags = []
|
161 |
+
for text in texts:
|
162 |
+
chunks = re.split(pattern, text)
|
163 |
+
chunk_image_tags = re.findall(pattern, text)
|
164 |
+
for chunk in chunks:
|
165 |
+
tokenized_chunk = self.tokenizer(chunk).input_ids
|
166 |
+
prompt_chunks.append(tokenized_chunk)
|
167 |
+
for tag in chunk_image_tags:
|
168 |
+
image_tags.append(tag)
|
169 |
+
|
170 |
+
|
171 |
|
172 |
if 'num_img_tokens' in images:
|
173 |
num_img_tokens = images['num_img_tokens']
|
|
|
178 |
|
179 |
images, image_sizes = images['pixel_values'], images['image_sizes']
|
180 |
|
181 |
+
|
182 |
+
# image_tags needs to start from 1 to num_images
|
|
|
|
|
183 |
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
|
184 |
unique_image_ids = sorted(list(set(image_ids)))
|
185 |
# image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be [1, 4, 5]
|