1024m committed
Commit: d98df3d
Parent: 5d343f8

Upload image_preprocessing_molmo.py with huggingface_hub

Files changed (1):
  image_preprocessing_molmo.py  +548 -0
image_preprocessing_molmo.py ADDED
@@ -0,0 +1,548 @@
+"""Image processor class for Molmo"""
+from typing import List, Optional, Union, Mapping
+
+import numpy as np
+import einops
+import torch
+import torchvision.transforms
+from torchvision.transforms import InterpolationMode
+from torchvision.transforms.functional import convert_image_dtype
+
+from transformers.image_utils import (
+    OPENAI_CLIP_MEAN,
+    OPENAI_CLIP_STD,
+    ImageInput,
+    is_valid_image,
+)
+from transformers.processing_utils import ImagesKwargs
+from transformers.image_processing_utils import BaseImageProcessor
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def pad_to_bounding_box(
+    image, offset_height, offset_width, target_height,
+    target_width, value=0
+):
+    height, width = image.shape[:2]
+    after_padding_width = target_width - offset_width - width
+    after_padding_height = target_height - offset_height - height
+    return np.pad(image, [
+        [offset_height, after_padding_height],
+        [offset_width, after_padding_width],
+        [0, 0]
+    ], constant_values=value)
+
+
+def normalize_image(image, offset, scale):
+    image -= np.array(offset, dtype=np.float32)[None, None, :]
+    image /= np.array(scale, dtype=np.float32)[None, None, :]
+    return image
+
+
+def resize_and_pad(
+    image,
+    desired_output_size,
+    resize_method="torch-bilinear",
+    pad_value=0,
+    normalize=True,
+    image_mean=OPENAI_CLIP_MEAN,
+    image_std=OPENAI_CLIP_STD,
+):
+    desired_height, desired_width = desired_output_size
+    height, width = image.shape[:2]
+
+    # Cast into float32 since the training code did this in float32 and it (very rarely) affects
+    # the results after rounding.
+    image_scale_y = np.array(desired_height, np.float32) / np.array(height, np.float32)
+    image_scale_x = np.array(desired_width, np.float32) / np.array(width, np.float32)
+    image_scale = min(image_scale_x, image_scale_y)
+    scaled_height = int(np.array(height, np.float32) * image_scale)
+    scaled_width = int(np.array(width, np.float32) * image_scale)
+
+    if resize_method == "tensorflow":
+        # This is how the original training code did resizing; it can produce slightly different
+        # results than torch resize, so we keep it just in case
+        import tensorflow as tf
+        image = tf.image.convert_image_dtype(tf.constant(image), dtype=tf.float32)
+        image = tf.image.resize(
+            image,
+            [scaled_height, scaled_width],
+            method=tf.image.ResizeMethod.BILINEAR,
+            antialias=True,
+        )
+        image = tf.clip_by_value(image, 0.0, 1.0)
+        image = image.numpy()
+    elif resize_method == "torch-bilinear":
+        image = torch.permute(torch.from_numpy(image), [2, 0, 1])
+        image = convert_image_dtype(image)  # resize in float32 to match the training code
+        image = torchvision.transforms.Resize(
+            [scaled_height, scaled_width], InterpolationMode.BILINEAR, antialias=True
+        )(image)
+        image = torch.clip(image, 0.0, 1.0)
+        image = torch.permute(image, [1, 2, 0]).numpy()
+    else:
+        raise NotImplementedError(resize_method)
+
+    top_pad = (desired_height - scaled_height) // 2
+    left_pad = (desired_width - scaled_width) // 2
+    padding = [
+        [top_pad, desired_height - scaled_height - top_pad],
+        [left_pad, desired_width - scaled_width - left_pad],
+        [0, 0]
+    ]
+    image_mask = np.pad(np.ones_like(image[:, :, 0], dtype=bool), padding[:2])
+    image = np.pad(image, padding, constant_values=pad_value)
+    if normalize:
+        image = normalize_image(image, offset=image_mean, scale=image_std)
+    return image, image_mask
+
+
+def select_tiling(h, w, patch_size, max_num_patches):
+    """Decide how best to divide an image of size [w, h] into up to max_num_patches crops of size patch_size"""
+    original_size = np.stack([h, w])  # [1, 2]
+    original_res = h * w
+    tilings = []
+    for i in range(1, max_num_patches+1):
+        for j in range(1, max_num_patches+1):
+            if i*j <= max_num_patches:
+                tilings.append((i, j))
+    # sort so argmin and argmax favour smaller tilings in the event of a tie
+    tilings.sort(key=lambda x: (x[0]*x[1], x[0]))
+    candidate_tilings = np.array(tilings, dtype=np.int32)  # [n_resolutions, 2]
+    candidate_resolutions = candidate_tilings * patch_size  # [n_resolutions, 2]
+
+    # How much we would need to scale the image to fit exactly in each tiling
+    original_size = np.stack([h, w], dtype=np.float32)  # [1, 2]
+    required_scale_d = candidate_resolutions.astype(np.float32) / original_size
+    required_scale = np.min(required_scale_d, axis=-1, keepdims=True)  # [n_resolutions, 1]
+    if np.all(required_scale < 1):
+        # We are forced to downscale, so try to minimize the amount of downscaling
+        ix = np.argmax(required_scale)
+    else:
+        # Pick the resolution that required the least upscaling so that it most closely fits the image
+        required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
+        ix = np.argmin(required_scale)
+    return candidate_tilings[ix]
+
+
+class MolmoImagesKwargs(ImagesKwargs, total=False):
+    max_crops: Optional[int]
+    overlap_margins: Optional[List[int]]
+    base_image_input_size: Optional[List[int]]
+    image_token_length_w: Optional[int]
+    image_token_length_h: Optional[int]
+    image_patch_size: Optional[int]
+    image_padding_mask: Optional[bool]
+
+
+class MolmoImageProcessor(BaseImageProcessor):
+    """Preprocess images and multi-modal inputs"""
+
+    def __init__(
+        self,
+        max_crops: int = 12,
+        overlap_margins: List[int] = (4, 4),
+        base_image_input_size: List[int] = (336, 336),
+        image_token_length_w: int = 12,
+        image_token_length_h: int = 12,
+        image_patch_size: int = 14,
+        image_padding_mask: bool = True,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.max_crops = max_crops
+        self.overlap_margins = overlap_margins
+        self.base_image_input_size = base_image_input_size
+        self.image_token_length_w = image_token_length_w
+        self.image_token_length_h = image_token_length_h
+        self.image_patch_size = image_patch_size
+        self.image_padding_mask = image_padding_mask
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+
+    def image_to_patches_and_tokens(
+        self,
+        image: ImageInput,
+        image_patch_token_id: int,
+        image_col_token_id: int,
+        image_start_token_id: int,
+        image_end_token_id: int,
+        max_crops: Optional[int] = None,
+        overlap_margins: Optional[List[int]] = None,
+        base_image_input_size: Optional[Union[int, List[int]]] = None,
+        image_token_length_w: Optional[int] = None,
+        image_token_length_h: Optional[int] = None,
+        image_patch_size: Optional[int] = None,
+    ):
+        if isinstance(base_image_input_size, int):
+            base_image_input_size = (base_image_input_size, base_image_input_size)
+
+        base_image_input_d = image_patch_size
+        tokens_per_image = image_token_length_w * image_token_length_h
+        image_base_patch_w = base_image_input_size[1] // base_image_input_d
+        image_base_patch_h = base_image_input_size[0] // base_image_input_d
+
+        original_image_h, original_image_w = image.shape[:2]
+        crop_size = base_image_input_size[0]
+
+        # Discard this many patches from the (left/top, right/bottom) of crops
+        left_margin, right_margin = overlap_margins
+        # left_margin, right_margin = 2, 2
+        assert left_margin % 2 == 0  # Required for compatibility with 2x2 pooling
+        total_margin_pixels = base_image_input_d*(right_margin + left_margin)  # pixels removed per dim
+        crop_patches = base_image_input_size[0] // base_image_input_d  # patches per crop dim
+        crop_window_patches = crop_patches - (right_margin + left_margin)  # usable patches
+        crop_window_size = crop_window_patches * base_image_input_d
+        tiling = select_tiling(
+            original_image_h - total_margin_pixels,
+            original_image_w - total_margin_pixels,
+            crop_window_size,
+            max_crops
+        )
+        src, img_mask = resize_and_pad(
+            image,
+            [tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels]
+        )
+
+        # Now we have to split the image into crops, while keeping track of how each patch in
+        # each crop should be ordered in the global image; this requires a lot of tricky bookkeeping
+        n_crops = tiling[0] * tiling[1]
+        patches_arr = []
+        mask_arr = []
+        patch_ordering_arr = []
+
+        # We assume 2x2 pooling, but can allow padding the right/bottom with extra
+        # patches if the number of patches per side is not even
+        assert (crop_patches+1)//2 == image_token_length_h
+        assert (crop_patches+1)//2 == image_token_length_w
+        on = 0
+        on_patch = 0
+        for i in range(tiling[0]):
+            y0 = i*crop_window_size
+            if i == 0:
+                crop_y0 = 0
+            else:
+                crop_y0 = left_margin // 2
+
+            crop_h = image_base_patch_h - (right_margin + left_margin)
+            if i == 0:
+                crop_h += left_margin
+            if i == (tiling[0]-1):
+                crop_h += right_margin
+            for j in range(tiling[1]):
+                x0 = j*crop_window_size
+                if j == 0:
+                    crop_x0 = 0
+                else:
+                    crop_x0 = left_margin // 2
+
+                crop_w = image_base_patch_w - (right_margin + left_margin)
+                if j == 0:
+                    crop_w += left_margin
+                if j == (tiling[1]-1):
+                    crop_w += right_margin
+
+                pooled_w = (crop_w + 1) // 2
+                pooled_h = (crop_h + 1) // 2
+                patch_ordering_arr.append(
+                    pad_to_bounding_box(
+                        np.reshape(np.arange(on, on+pooled_h*pooled_w, dtype=np.int32), (pooled_h, pooled_w, 1)),
+                        crop_y0, crop_x0, image_token_length_h, image_token_length_w, value=-1
+                    )[:, :, 0]
+                )
+                patches_arr.append(src[y0:y0+crop_size, x0:x0+crop_size])
+                mask_arr.append(img_mask[y0:y0+crop_size, x0:x0+crop_size])
+
+                on += pooled_h*pooled_w
+                on_patch += 1
+        patches = np.stack(patches_arr)
+        patch_ordering = np.stack(patch_ordering_arr)
+        img_mask = np.stack(mask_arr)
+
+        # Switch to [n_crops, n_patches, pixels_per_patch] format
+        image_layout_impatch_w, image_layout_impatch_h = tiling[0], tiling[1]
+        patches = einops.rearrange(
+            patches, 'p (h dh) (w dw) c -> p (h w) (dh dw c)',
+            dh=base_image_input_d,
+            dw=base_image_input_d,
+            h=image_base_patch_h,
+            w=image_base_patch_w
+        )
+        img_mask = einops.rearrange(
+            img_mask, 'p (h dh) (w dw) -> p (h w) (dh dw)',
+            dh=base_image_input_d,
+            dw=base_image_input_d,
+            h=image_base_patch_h,
+            w=image_base_patch_w
+        )
+
+        img_mask = img_mask.astype(np.float32).mean(axis=-1)
+        patch_ordering = np.reshape(patch_ordering, [-1])
+        valid = patch_ordering >= 0
+
+        # Transpose order, to get left-to-right order instead of crop-by-crop order
+        patch_ordering_rh = np.reshape(
+            patch_ordering,
+            [tiling[0], tiling[1], image_token_length_h, image_token_length_w]
+        )
+        patch_ordering_rh = np.transpose(patch_ordering_rh, [0, 2, 1, 3])
+        patch_ordering_rh = np.reshape(patch_ordering_rh, [-1])
+
+        # The transpose will screw up which patches are masked, project the
+        # new order into the sparse structure of `patch_ordering` to fix this
+        patch_ordering[valid] = patch_ordering_rh[patch_ordering_rh >= 0]
+
+        # Now build the output tokens
+        h = tiling[0] * crop_window_patches + (right_margin+left_margin)
+        w = tiling[1] * crop_window_patches + (right_margin+left_margin)
+        per_row = np.full(
+            ((w+1)//2,),
+            image_patch_token_id,
+        )
+        per_row = np.concatenate([per_row, [image_col_token_id]], 0)
+
+        joint = np.tile(per_row, [(h+1)//2])
+        joint = [
+            [image_start_token_id],
+            joint,
+            [image_end_token_id]
+        ]
+
+        # Finally do the same for the global image
+        resized, _ = resize_and_pad(image, base_image_input_size)
+        resized = einops.rearrange(
+            resized, '(h dh) (w dw) c -> (h w) (dh dw c)',
+            dh=base_image_input_d,
+            dw=base_image_input_d,
+            h=image_base_patch_h,
+            w=image_base_patch_w
+        )
+        patches = np.concatenate([np.expand_dims(resized, 0), patches], 0)
+
+        # Global image goes first, so the order of patches in previous crops gets increased
+        patch_ordering = np.where(
+            patch_ordering >= 0,
+            patch_ordering + tokens_per_image,
+            -1
+        )
+        patch_ordering = np.concatenate([np.arange(0, tokens_per_image), patch_ordering], 0)
+        per_row = np.full(
+            (image_token_length_w,),
+            image_patch_token_id,
+        )
+        per_row = np.concatenate([per_row, [image_col_token_id]], 0)
+        extra_tokens = np.tile(per_row, [image_token_length_h])
+        joint = [
+            [image_start_token_id],
+            extra_tokens,
+            [image_end_token_id],
+        ] + joint
+
+        joint = np.concatenate(joint, 0)
+        img_mask = np.pad(img_mask, [[0, 1], [0, 0]], constant_values=-1)
+        return patches, joint, patch_ordering, img_mask
+
+    def build_image_input_idx(
+        self,
+        image_tokens: np.ndarray,
+        patch_order: np.ndarray,
+        image_patch_token_id: int,
+        no_image: Optional[bool] = None,
+        image_token_length_w: Optional[int] = None,
+        image_token_length_h: Optional[int] = None,
+    ):
+        """Converts `patch_order` into a mapping of token_id -> patch_id"""
+
+        tokens_per_image = image_token_length_w * image_token_length_h
+        if no_image is not None and no_image:
+            return np.zeros((0, tokens_per_image), np.int32)
+
+        # Indices to insert the patches
+        image_input_idx = image_tokens == image_patch_token_id
+        image_input_idx = np.nonzero(image_input_idx)[0].astype(np.int32)
+
+        if patch_order is not None:
+            n_tokens = image_input_idx.shape[0]
+            patch_order = np.reshape(patch_order, [-1])
+            n_patches = patch_order.shape[0]
+
+            valid = patch_order >= 0
+            n_valid_patches = valid.sum()
+            assert len(image_input_idx) == n_valid_patches
+
+            sorted_patch_ixs = np.zeros([n_tokens], np.int32)
+            sorted_patch_ixs[patch_order[valid]] = np.arange(n_valid_patches, dtype=np.int32)
+
+            # Project the inverted mapping into the same sparse structure
+            sorted_patch_ixs_ex = np.full(np.shape(patch_order), -1)
+            sorted_patch_ixs_ex[valid] = sorted_patch_ixs
+
+            # Do the gather and then re-mask outputs that were masked in `sorted_patch_ixs`
+            valid = (sorted_patch_ixs_ex >= 0).astype(np.int32)
+            image_input_idx = image_input_idx[sorted_patch_ixs_ex*valid]
+            image_input_idx = image_input_idx*valid - 100*(1 - valid)
+        image_input_idx = np.reshape(image_input_idx, [-1, tokens_per_image])
+        return image_input_idx
+
+    def preprocess(
+        self,
+        image: np.ndarray,
+        image_patch_token_id: int,
+        image_col_token_id: int,
+        image_start_token_id: int,
+        image_end_token_id: int,
+        max_crops: Optional[int] = None,
+        overlap_margins: Optional[List[int]] = None,
+        base_image_input_size: Optional[Union[int, List[int]]] = None,
+        image_token_length_w: Optional[int] = None,
+        image_token_length_h: Optional[int] = None,
+        image_patch_size: Optional[int] = None,
+        **kwargs,
+    ):
+        """Preprocesses an image
+
+        Returns:
+            crops: (n_crops, n_patches, patch_dim) individual crops, `n_crops` might
+                change between images but the other dimensions are fixed
+            tokens: (n_tokens,) int32 tokens, pad tokens indicate where to insert the
+                patch features, might include other special tokens as well
+            image_idx: (n_crops, n_patches) index in `tokens` to put the patch features from the
+                crops after pooling, negative values indicate patch features to exclude
+            padding_mask: (n_crops, n_patches) what percent of each crop is padding, can be None
+                if the image mask is not being used.
+        """
+
+        max_crops = max_crops or self.max_crops
+        overlap_margins = overlap_margins or self.overlap_margins
+        base_image_input_size = base_image_input_size or self.base_image_input_size
+        image_token_length_w = image_token_length_w or self.image_token_length_w
+        image_token_length_h = image_token_length_h or self.image_token_length_h
+        image_patch_size = image_patch_size or self.image_patch_size
+
+        crops, image_tokens, patch_ordering, img_mask = self.image_to_patches_and_tokens(
+            image,
+            image_patch_token_id,
+            image_col_token_id,
+            image_start_token_id,
+            image_end_token_id,
+            max_crops,
+            overlap_margins,
+            base_image_input_size,
+            image_token_length_w,
+            image_token_length_h,
+            image_patch_size,
+        )
+        patch_idx = self.build_image_input_idx(
+            image_tokens,
+            patch_ordering,
+            image_patch_token_id,
+            image_token_length_w=image_token_length_w,
+            image_token_length_h=image_token_length_h,
+        )
+        return crops, image_tokens, patch_idx, img_mask
+
+    def multimodal_preprocess(
+        self,
+        images: np.ndarray,
+        tokens: List[int],
+        image_idx: np.ndarray,
+        sequence_length: int,
+        image_patch_token_id: int,
+        image_col_token_id: int,
+        image_start_token_id: int,
+        image_end_token_id: int,
+        **kwargs,
+    ):
+        """Merge images and text tokens into multi-modal features for the model
+
+        :param images: images to use as input
+        :param tokens: input text tokens
+        :param image_idx: where to insert the images into `tokens`
+        :param image_patch_token_id: id to use for tokens that will contain image features
+        :param image_col_token_id: token id for image column special tokens
+        :param image_start_token_id: token id for image start special tokens
+        :param image_end_token_id: token id for image end special tokens
+        :param kwargs: override preprocessor default args
+        """
+        max_total_crops = kwargs.get("max_crops") or self.max_crops
+        image_token_length_w = kwargs.get("image_token_length_w") or self.image_token_length_w
+        image_token_length_h = kwargs.get("image_token_length_h") or self.image_token_length_h
+        image_patch_size = kwargs.get("image_patch_size") or self.image_patch_size
+        base_image_input_size = kwargs.get("base_image_input_size") or self.base_image_input_size
+        image_num_patch = (
+            base_image_input_size[0] // image_patch_size,
+            base_image_input_size[1] // image_patch_size,
+        )
+        image_padding_mask = kwargs.get("image_padding_mask") or self.image_padding_mask
+
+        tokens_per_image = image_token_length_w * image_token_length_h
+        n_pixels = image_patch_size * image_patch_size * 3
+        n_patches = image_num_patch[0] * image_num_patch[1]
+
+        if images is None:
+            return {
+                "input_ids": tokens,
+                "images": None,
+                "image_input_idx": None
+            }
+        else:
+            n = len(images)
+            all_crops = []
+            all_image_idx = []
+            out_tokens = []
+            all_crop_masks = []
+
+            for ix in range(n):
+                token_ix = image_idx[ix]
+                crops, image_tokens, patch_idx, img_mask = self.preprocess(
+                    images[ix],
+                    image_patch_token_id,
+                    image_col_token_id,
+                    image_start_token_id,
+                    image_end_token_id,
+                    **kwargs,
+                )
+
+                if token_ix == -1:  # -1 is an image inserted at the very start
+                    start = 0
+                    token_ix = 0
+                    end = 0
+                else:
+                    start = 0 if ix == 0 else image_idx[ix-1] + 1
+                    end = token_ix + 1
+
+                all_image_idx.append(patch_idx + token_ix)
+                all_crops.append(crops)
+                out_tokens.append(tokens[start:token_ix])
+                out_tokens.append(image_tokens)
+                if ix == (n - 1):
+                    out_tokens.append(tokens[end:])
+                if image_padding_mask:
+                    all_crop_masks.append(img_mask)
+
+            input_ids = np.concatenate(out_tokens, 0)
+            images = np.concatenate(all_crops, 0)
+            image_input_idx = np.concatenate(all_image_idx, 0)
+            if image_padding_mask:
+                image_masks = np.concatenate(all_crop_masks, 0)
+            else:
+                image_masks = None
+
+            out = {
+                "input_ids": input_ids,
+                "images": images,
+                "image_input_idx": image_input_idx
+            }
+            if image_masks is not None:
+                out["image_masks"] = image_masks
+            return out
+
+
+MolmoImageProcessor.register_for_auto_class()
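
For reference, a minimal usage sketch of the uploaded processor (not part of the committed file). It assumes the module is importable as image_preprocessing_molmo, and the text tokens, special-token ids, and image below are placeholders rather than the real Molmo tokenizer/config values.

# Hypothetical usage sketch -- all ids and the image are placeholders, not Molmo's real config.
import numpy as np
from image_preprocessing_molmo import MolmoImageProcessor

processor = MolmoImageProcessor()  # defaults: max_crops=12, 336x336 crops, 14px patches

# A dummy 480x640 RGB image in HWC uint8 layout.
image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)

out = processor.multimodal_preprocess(
    images=[image],
    tokens=np.array([1, 2, 3, 4], dtype=np.int32),  # placeholder text-token ids
    image_idx=np.array([-1]),                        # -1 inserts the image before the text
    sequence_length=1024,
    image_patch_token_id=1000,                       # placeholder special-token ids
    image_col_token_id=1001,
    image_start_token_id=1002,
    image_end_token_id=1003,
)
# out["images"]: (n_crops, 576, 588) crops; out["input_ids"]: text plus image special tokens;
# out["image_input_idx"]: where pooled patch features are scattered into the token sequence.
print(out["images"].shape, out["input_ids"].shape, out["image_input_idx"].shape)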