princepride commited on
Commit
8fc9247
1 Parent(s): 047f82c

Upload image_processing_minicpmv.py

Browse files
Files changed (1) hide show
  1. image_processing_minicpmv.py +418 -0
image_processing_minicpmv.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Union, Dict, Any, List
2
+
3
+ import torch
4
+ import math
5
+ import PIL.Image
6
+ import PIL.ImageSequence
7
+ import numpy as np
8
+ import PIL
9
+ from PIL import Image
10
+
11
+ from transformers.utils import TensorType, requires_backends, is_torch_dtype, is_torch_device
12
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
13
+ from transformers import AutoImageProcessor
14
+ from transformers.image_transforms import to_channel_dimension_format
15
+ from transformers.image_utils import (
16
+ ImageInput,
17
+ make_list_of_images,
18
+ valid_images,
19
+ is_torch_tensor,
20
+ is_batched,
21
+ to_numpy_array,
22
+ infer_channel_dimension_format,
23
+ ChannelDimension
24
+ )
25
+
26
+
27
+ def recursive_converter(converter, value):
28
+ if isinstance(value, list):
29
+ new_value = []
30
+ for v in value:
31
+ new_value += [recursive_converter(converter, v)]
32
+ return new_value
33
+ else:
34
+ return converter(value)
35
+
36
+
37
+ class MiniCPMVBatchFeature(BatchFeature):
38
+ r"""
39
+ Extend from BatchFeature for supporting various image size
40
+ """
41
+ def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
42
+ super().__init__(data)
43
+ self.convert_to_tensors(tensor_type=tensor_type)
44
+
45
+ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
46
+ if tensor_type is None:
47
+ return self
48
+
49
+ is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)
50
+
51
+ def converter(value):
52
+ try:
53
+ if not is_tensor(value):
54
+ tensor = as_tensor(value)
55
+ return tensor
56
+ except: # noqa E722
57
+ if key == "overflowing_values":
58
+ raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
59
+ raise ValueError(
60
+ "Unable to create tensor, you should probably activate padding "
61
+ "with 'padding=True' to have batched tensors with the same length."
62
+ )
63
+
64
+
65
+ for key, value in self.items():
66
+ self[key] = recursive_converter(converter, value)
67
+ return self
68
+
69
+ def to(self, *args, **kwargs) -> "MiniCPMVBatchFeature":
70
+ requires_backends(self, ["torch"])
71
+ import torch
72
+
73
+ def cast_tensor(v):
74
+ # check if v is a floating point
75
+ if torch.is_floating_point(v):
76
+ # cast and send to device
77
+ return v.to(*args, **kwargs)
78
+ elif device is not None:
79
+ return v.to(device=device)
80
+ else:
81
+ return v
82
+
83
+ new_data = {}
84
+ device = kwargs.get("device")
85
+ # Check if the args are a device or a dtype
86
+ if device is None and len(args) > 0:
87
+ # device should be always the first argument
88
+ arg = args[0]
89
+ if is_torch_dtype(arg):
90
+ # The first argument is a dtype
91
+ pass
92
+ elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
93
+ device = arg
94
+ else:
95
+ # it's something else
96
+ raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
97
+ # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
98
+ for k, v in self.items():
99
+ new_data[k] = recursive_converter(cast_tensor, v)
100
+ self.data = new_data
101
+ return self
102
+
103
+
104
+ class MiniCPMVImageProcessor(BaseImageProcessor):
105
+ model_input_names = ["pixel_values"]
106
+
107
+ def __init__(
108
+ self,
109
+ max_slice_nums=9,
110
+ scale_resolution=448,
111
+ patch_size=14,
112
+ **kwargs):
113
+ super().__init__(**kwargs)
114
+ self.max_slice_nums = max_slice_nums
115
+ self.scale_resolution = scale_resolution
116
+ self.patch_size = patch_size
117
+ self.use_image_id = kwargs.pop("use_image_id", False)
118
+ self.image_feature_size = kwargs.pop("image_feature_size", 64)
119
+ self.im_start_token = kwargs.pop("im_start", "<image>")
120
+ self.im_end_token = kwargs.pop("im_end", "</image>")
121
+ self.slice_start_token = kwargs.pop("slice_start", "<slice>")
122
+ self.slice_end_token = kwargs.pop("slice_end", "</slice>")
123
+ self.unk_token = kwargs.pop("unk", "<unk>")
124
+ self.im_id_start = kwargs.pop("im_id_start", "<image_id>")
125
+ self.im_id_end = kwargs.pop("im_id_end", "</image_id>")
126
+ self.slice_mode = kwargs.pop("slice_mode", True)
127
+ self.mean = np.array(kwargs.pop("norm_mean", [0.5, 0.5, 0.5]))
128
+ self.std = np.array(kwargs.pop("norm_std", [0.5, 0.5, 0.5]))
129
+ self.version = kwargs.pop("version", 2.0)
130
+
131
+ def ensure_divide(self, length, patch_size):
132
+ return max(round(length / patch_size) * patch_size, patch_size)
133
+
134
+ def find_best_resize(self,
135
+ original_size,
136
+ scale_resolution,
137
+ patch_size,
138
+ allow_upscale=False):
139
+ width, height = original_size
140
+ if (width * height >
141
+ scale_resolution * scale_resolution) or allow_upscale:
142
+ r = width / height
143
+ height = int(scale_resolution / math.sqrt(r))
144
+ width = int(height * r)
145
+ best_width = self.ensure_divide(width, patch_size)
146
+ best_height = self.ensure_divide(height, patch_size)
147
+ return (best_width, best_height)
148
+
149
+ def get_refine_size(self,
150
+ original_size,
151
+ grid,
152
+ scale_resolution,
153
+ patch_size,
154
+ allow_upscale=False):
155
+ width, height = original_size
156
+ grid_x, grid_y = grid
157
+
158
+ refine_width = self.ensure_divide(width, grid_x)
159
+ refine_height = self.ensure_divide(height, grid_y)
160
+
161
+ grid_width = refine_width / grid_x
162
+ grid_height = refine_height / grid_y
163
+
164
+ best_grid_size = self.find_best_resize((grid_width, grid_height),
165
+ scale_resolution,
166
+ patch_size,
167
+ allow_upscale=allow_upscale)
168
+ refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
169
+ return refine_size
170
+
171
+ def split_to_patches(self, image, grid):
172
+ patches = []
173
+ width, height = image.size
174
+ grid_x = int(width / grid[0])
175
+ grid_y = int(height / grid[1])
176
+ for i in range(0, height, grid_y):
177
+ images = []
178
+ for j in range(0, width, grid_x):
179
+ box = (j, i, j + grid_x, i + grid_y)
180
+ patch = image.crop(box)
181
+ images.append(patch)
182
+ patches.append(images)
183
+ return patches
184
+
185
+ def slice_image(
186
+ self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False
187
+ ):
188
+ original_size = image.size
189
+ source_image = None
190
+ best_grid = self.get_sliced_grid(original_size, max_slice_nums, never_split)
191
+ patches = []
192
+
193
+ if best_grid is None:
194
+ # dont need to slice, upsample
195
+ best_size = self.find_best_resize(
196
+ original_size, scale_resolution, patch_size, allow_upscale=True
197
+ )
198
+ source_image = image.resize(best_size, resample=Image.Resampling.BICUBIC)
199
+ else:
200
+ # source image, down-sampling and ensure divided by patch_size
201
+ best_resize = self.find_best_resize(original_size, scale_resolution, patch_size)
202
+ source_image = image.copy().resize(best_resize, resample=Image.Resampling.BICUBIC)
203
+ refine_size = self.get_refine_size(
204
+ original_size, best_grid, scale_resolution, patch_size, allow_upscale=True
205
+ )
206
+ refine_image = image.resize(refine_size, resample=Image.Resampling.BICUBIC)
207
+ patches = self.split_to_patches(refine_image, best_grid)
208
+
209
+ return source_image, patches, best_grid
210
+
211
+ def get_grid_placeholder(self, grid):
212
+ if grid is None:
213
+ return ""
214
+ slice_image_placeholder = (
215
+ self.slice_start_token
216
+ + self.unk_token * self.image_feature_size
217
+ + self.slice_end_token
218
+ )
219
+
220
+ cols = grid[0]
221
+ rows = grid[1]
222
+ slices = []
223
+ for i in range(rows):
224
+ lines = []
225
+ for j in range(cols):
226
+ lines.append(slice_image_placeholder)
227
+ slices.append("".join(lines))
228
+
229
+ slice_placeholder = "\n".join(slices)
230
+ return slice_placeholder
231
+
232
+ def get_image_id_placeholder(self, idx=0):
233
+ return f"{self.im_id_start}{idx}{self.im_id_end}"
234
+
235
+ def get_sliced_images(self, image, max_slice_nums=None):
236
+ slice_images = []
237
+
238
+ if not self.slice_mode:
239
+ return [image]
240
+
241
+ max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
242
+ assert max_slice_nums > 0
243
+ source_image, patches, sliced_grid = self.slice_image(
244
+ image,
245
+ max_slice_nums, # default: 9
246
+ self.scale_resolution, # default: 448
247
+ self.patch_size # default: 14
248
+ )
249
+
250
+ slice_images.append(source_image)
251
+ if len(patches) > 0:
252
+ for i in range(len(patches)):
253
+ for j in range(len(patches[0])):
254
+ slice_images.append(patches[i][j])
255
+ return slice_images
256
+
257
+ def get_sliced_grid(self, image_size, max_slice_nums, nerver_split=False):
258
+ original_width, original_height = image_size
259
+ log_ratio = math.log(original_width / original_height)
260
+ ratio = original_width * original_height / (self.scale_resolution * self.scale_resolution)
261
+ multiple = min(math.ceil(ratio), max_slice_nums)
262
+ if multiple <= 1 or nerver_split:
263
+ return None
264
+ candidate_split_grids_nums = []
265
+ for i in [multiple - 1, multiple, multiple + 1]:
266
+ if i == 1 or i > max_slice_nums:
267
+ continue
268
+ candidate_split_grids_nums.append(i)
269
+
270
+ candidate_grids = []
271
+ for split_grids_nums in candidate_split_grids_nums:
272
+ m = 1
273
+ while m <= split_grids_nums:
274
+ if split_grids_nums % m == 0:
275
+ candidate_grids.append([m, split_grids_nums // m])
276
+ m += 1
277
+
278
+ best_grid = [1, 1]
279
+ min_error = float("inf")
280
+ for grid in candidate_grids:
281
+ error = abs(log_ratio - math.log(grid[0] / grid[1]))
282
+ if error < min_error:
283
+ best_grid = grid
284
+ min_error = error
285
+
286
+ return best_grid
287
+
288
+ def get_slice_image_placeholder(self, image_size, image_idx=0, max_slice_nums=None, use_image_id=None):
289
+ max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
290
+ assert max_slice_nums > 0
291
+ grid = self.get_sliced_grid(image_size=image_size, max_slice_nums=max_slice_nums)
292
+
293
+ image_placeholder = (
294
+ self.im_start_token
295
+ + self.unk_token * self.image_feature_size
296
+ + self.im_end_token
297
+ )
298
+ use_image_id = self.use_image_id if use_image_id is None else bool(use_image_id)
299
+ if use_image_id:
300
+ final_placeholder = self.get_image_id_placeholder(image_idx) + image_placeholder
301
+ else:
302
+ final_placeholder = image_placeholder
303
+
304
+ if self.slice_mode:
305
+ final_placeholder = final_placeholder + self.get_grid_placeholder(grid=grid)
306
+ return final_placeholder
307
+
308
+ def to_pil_image(self, image, rescale=None) -> PIL.Image.Image:
309
+ """
310
+ Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
311
+ needed.
312
+
313
+ Args:
314
+ image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
315
+ The image to convert to the PIL Image format.
316
+ rescale (`bool`, *optional*):
317
+ Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
318
+ default to `True` if the image type is a floating type, `False` otherwise.
319
+ """
320
+ if isinstance(image, PIL.Image.Image):
321
+ return image
322
+ if is_torch_tensor(image):
323
+ image = image.numpy()
324
+
325
+ if isinstance(image, np.ndarray):
326
+ if rescale is None:
327
+ # rescale default to the array being of floating type.
328
+ rescale = isinstance(image.flat[0], np.floating)
329
+ # If the channel as been moved to first dim, we put it back at the end.
330
+ if image.ndim == 3 and image.shape[0] in [1, 3]:
331
+ image = image.transpose(1, 2, 0)
332
+ if rescale:
333
+ image = image * 255
334
+ image = image.astype(np.uint8)
335
+ return PIL.Image.fromarray(image)
336
+ return image
337
+
338
+ def reshape_by_patch(self, image):
339
+ """
340
+ :param image: shape [3, H, W]
341
+ :param patch_size:
342
+ :return: [3, patch_size, HW/patch_size]
343
+ """
344
+ image = torch.from_numpy(image)
345
+ patch_size = self.patch_size
346
+ patches = torch.nn.functional.unfold(
347
+ image,
348
+ (patch_size, patch_size),
349
+ stride=(patch_size, patch_size)
350
+ )
351
+
352
+ patches = patches.reshape(image.size(0), patch_size, patch_size, -1)
353
+ patches = patches.permute(0, 1, 3, 2).reshape(image.size(0), patch_size, -1)
354
+ return patches.numpy()
355
+
356
+ def preprocess(
357
+ self,
358
+ images: Union[Image.Image, List[Image.Image], List[List[Image.Image]]],
359
+ do_pad: Optional[bool] = True, # TODO: add pad for MiniCPM-Llama3-V-2_5
360
+ max_slice_nums: int = None,
361
+ return_tensors: Optional[Union[str, TensorType]] = None,
362
+ **kwargs
363
+ ) -> MiniCPMVBatchFeature:
364
+ if isinstance(images, Image.Image):
365
+ images_list = [[images]]
366
+ elif isinstance(images[0], Image.Image):
367
+ images_list = [images]
368
+ else:
369
+ images_list = images
370
+
371
+ new_images_list = []
372
+ image_sizes_list = []
373
+ tgt_sizes_list = []
374
+
375
+ for _images in images_list:
376
+ if _images is None or len(_images) == 0:
377
+ new_images_list.append([])
378
+ image_sizes_list.append([])
379
+ tgt_sizes_list.append([])
380
+ continue
381
+ if not valid_images(_images):
382
+ raise ValueError(
383
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
384
+ "torch.Tensor, tf.Tensor or jax.ndarray."
385
+ )
386
+
387
+ _images = [self.to_pil_image(image).convert("RGB") for image in _images]
388
+ input_data_format = infer_channel_dimension_format(np.array(_images[0]))
389
+
390
+ new_images = []
391
+ image_sizes = [image.size for image in _images]
392
+ tgt_sizes = []
393
+ for image in _images:
394
+ image_patches = self.get_sliced_images(image, max_slice_nums)
395
+ image_patches = [to_numpy_array(image).astype(np.float32) / 255 for image in image_patches]
396
+ image_patches = [
397
+ self.normalize(image=image, mean=self.mean, std=self.std, input_data_format=input_data_format)
398
+ for image in image_patches
399
+ ]
400
+ image_patches = [
401
+ to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format)
402
+ for image in image_patches
403
+ ]
404
+ for slice_image in image_patches:
405
+ new_images.append(self.reshape_by_patch(slice_image))
406
+ tgt_sizes.append(np.array((slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size)))
407
+
408
+ if tgt_sizes:
409
+ tgt_sizes = np.vstack(tgt_sizes)
410
+
411
+ new_images_list.append(new_images)
412
+ image_sizes_list.append(image_sizes)
413
+ tgt_sizes_list.append(tgt_sizes)
414
+ return MiniCPMVBatchFeature(
415
+ data={"pixel_values": new_images_list, "image_sizes": image_sizes_list, "tgt_sizes": tgt_sizes_list}, tensor_type=return_tensors
416
+ )
417
+
418
+ AutoImageProcessor.register("MiniCPMVImageProcessor", MiniCPMVImageProcessor)