Allow for multiple images per batch sample.

#17
Files changed (1) hide show
  1. processing_phi3_v.py +24 -9
processing_phi3_v.py CHANGED
@@ -20,14 +20,19 @@ import re
20
  from typing import List, Optional, Union
21
 
22
  import torch
23
-
24
  import transformers
25
  from transformers.feature_extraction_utils import BatchFeature
26
  from transformers.image_utils import ImageInput
27
  from transformers.processing_utils import ProcessorMixin
28
- from transformers.tokenization_utils_base import PaddingStrategy, TextInput, TruncationStrategy
 
 
 
 
29
  from transformers.utils import TensorType
30
- from .image_processing_phi3_v import Phi3VImageProcessor
 
 
31
  transformers.Phi3VImageProcessor = Phi3VImageProcessor
32
 
33
  class Phi3VProcessor(ProcessorMixin):
@@ -144,13 +149,25 @@ class Phi3VProcessor(ProcessorMixin):
144
  return self.tokenizer.convert_tokens_to_ids(self.special_image_token)
145
 
146
  def _convert_images_texts_to_inputs(self, images, texts, padding=False, truncation=None, max_length=None, return_tensors=None):
147
-
148
  if not len(images):
149
  model_inputs = self.tokenizer(texts, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length)
150
  return BatchFeature(data={**model_inputs})
151
 
152
  pattern = r"<\|image_\d+\|>"
153
- prompt_chunks = [self.tokenizer(chunk).input_ids for chunk in re.split(pattern, texts)]
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
  if 'num_img_tokens' in images:
156
  num_img_tokens = images['num_img_tokens']
@@ -161,10 +178,8 @@ class Phi3VProcessor(ProcessorMixin):
161
 
162
  images, image_sizes = images['pixel_values'], images['image_sizes']
163
 
164
- # image_tags needs to start from 1 to n
165
- image_tags = re.findall(pattern, texts)
166
- # image_ids = [int(s.split("|")[1].split("_")[-1]) * -1 for s in image_tags]
167
- # image_ids_pad = [[iid]*num_img_tokens[i] for i, iid in enumerate(image_ids)]
168
  image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
169
  unique_image_ids = sorted(list(set(image_ids)))
170
  # image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be [1, 4, 5]
 
20
  from typing import List, Optional, Union
21
 
22
  import torch
 
23
  import transformers
24
  from transformers.feature_extraction_utils import BatchFeature
25
  from transformers.image_utils import ImageInput
26
  from transformers.processing_utils import ProcessorMixin
27
+ from transformers.tokenization_utils_base import (
28
+ PaddingStrategy,
29
+ TextInput,
30
+ TruncationStrategy,
31
+ )
32
  from transformers.utils import TensorType
33
+
34
+ from .image_processing_phi3_v import Phi3VImageProcessor
35
+
36
  transformers.Phi3VImageProcessor = Phi3VImageProcessor
37
 
38
  class Phi3VProcessor(ProcessorMixin):
 
149
  return self.tokenizer.convert_tokens_to_ids(self.special_image_token)
150
 
151
  def _convert_images_texts_to_inputs(self, images, texts, padding=False, truncation=None, max_length=None, return_tensors=None):
 
152
  if not len(images):
153
  model_inputs = self.tokenizer(texts, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length)
154
  return BatchFeature(data={**model_inputs})
155
 
156
  pattern = r"<\|image_\d+\|>"
157
+
158
+ # Don't over list-comprehend this; it's already hard to read.
159
+ prompt_chunks = []
160
+ image_tags = []
161
+ for text in texts:
162
+ chunks = re.split(pattern, text)
163
+ chunk_image_tags = re.findall(pattern, text)
164
+ for chunk in chunks:
165
+ tokenized_chunk = self.tokenizer(chunk).input_ids
166
+ prompt_chunks.append(tokenized_chunk)
167
+ for tag in chunk_image_tags:
168
+ image_tags.append(tag)
169
+
170
+
171
 
172
  if 'num_img_tokens' in images:
173
  num_img_tokens = images['num_img_tokens']
 
178
 
179
  images, image_sizes = images['pixel_values'], images['image_sizes']
180
 
181
+
182
+ # image_ids (parsed from the tags below) need to run from 1 to num_images
 
 
183
  image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
184
  unique_image_ids = sorted(list(set(image_ids)))
185
  # image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be [1, 4, 5]