|
from typing import Optional, Union, Dict, Any, List |
|
|
|
import torch |
|
import math |
|
import PIL.Image |
|
import PIL.ImageSequence |
|
import numpy as np |
|
import PIL |
|
from PIL import Image |
|
|
|
from transformers.utils import TensorType, requires_backends, is_torch_dtype, is_torch_device |
|
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature |
|
from transformers import AutoImageProcessor |
|
from transformers.image_transforms import to_channel_dimension_format |
|
from transformers.image_utils import ( |
|
ImageInput, |
|
make_list_of_images, |
|
valid_images, |
|
is_torch_tensor, |
|
is_batched, |
|
to_numpy_array, |
|
infer_channel_dimension_format, |
|
ChannelDimension |
|
) |
|
|
|
|
|
def recursive_converter(converter, value):
    """Apply ``converter`` to every leaf of a (possibly nested) list.

    Only ``list`` instances are descended into; any other object (tuples,
    dicts, arrays, tensors, ...) is treated as a leaf and handed to
    ``converter`` as a whole. The nesting is rebuilt from fresh lists, so the
    input list itself is never mutated.

    Args:
        converter: Callable invoked once per leaf value.
        value: A leaf, or an arbitrarily nested list of leaves.

    Returns:
        ``converter(value)`` for a leaf, otherwise a new list mirroring the
        structure of ``value`` with every leaf converted.
    """
    if not isinstance(value, list):
        return converter(value)
    return [recursive_converter(converter, item) for item in value]
|
|
|
|
|
class MiniCPMVBatchFeature(BatchFeature):
    r"""
    Extend from BatchFeature for supporting various image size.

    Values may be (nested) Python lists whose leaves have different shapes, so
    tensor conversion and ``.to()`` are applied leaf by leaf through
    ``recursive_converter`` instead of stacking each key into a single tensor.
    """

    def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
        super().__init__(data)
        self.convert_to_tensors(tensor_type=tensor_type)

    def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
        """Convert every leaf value to a tensor of ``tensor_type``, in place.

        Args:
            tensor_type: Target tensor type (e.g. ``"pt"``). ``None`` leaves the
                data untouched.

        Returns:
            ``self``, to allow chaining.

        Raises:
            ValueError: If a leaf cannot be converted to a tensor.
        """
        if tensor_type is None:
            return self

        is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)

        def converter(value, key):
            # Leaves that already are tensors are kept as-is. Previously they fell
            # through the ``if not is_tensor`` branch and were replaced by ``None``.
            if is_tensor(value):
                return value
            try:
                return as_tensor(value)
            except Exception as err:
                if key == "overflowing_values":
                    raise ValueError("Unable to create tensor returning overflowing values of different lengths. ") from err
                raise ValueError(
                    "Unable to create tensor, you should probably activate padding "
                    "with 'padding=True' to have batched tensors with the same length."
                ) from err

        for key, value in self.items():
            # Bind ``key`` explicitly so the error branch reports the entry being converted.
            self[key] = recursive_converter(lambda leaf, key=key: converter(leaf, key), value)
        return self

    def to(self, *args, **kwargs) -> "MiniCPMVBatchFeature":
        """Move and/or cast tensor leaves, following ``torch.Tensor.to`` arguments.

        Floating-point tensors receive the full ``*args, **kwargs`` (device and/or
        dtype). Other tensors are only moved to the requested device, if any.
        Non-tensor leaves are returned unchanged, for example the ``image_sizes``
        tuples when no ``tensor_type`` was requested.

        Returns:
            ``self``, with its data replaced by the converted values.

        Raises:
            ValueError: If the first positional argument is neither a dtype nor a
                device specification (str, ``torch.device`` or int index).
        """
        requires_backends(self, ["torch"])
        import torch

        device = kwargs.get("device")
        # A positional first argument may be a dtype (cast only) or a device.
        if device is None and len(args) > 0:
            arg = args[0]
            if is_torch_dtype(arg):
                pass
            elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
                device = arg
            else:
                raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")

        def cast_tensor(v):
            if not torch.is_tensor(v):
                # ``torch.is_floating_point`` raises TypeError on non-tensors.
                return v
            if torch.is_floating_point(v):
                return v.to(*args, **kwargs)
            if device is not None:
                return v.to(device=device)
            return v

        self.data = {k: recursive_converter(cast_tensor, v) for k, v in self.items()}
        return self
|
|
|
|
|
class MiniCPMVImageProcessor(BaseImageProcessor):
    """Image processor producing MiniCPM-V style patch-flattened pixel values.

    ``get_sliced_images`` turns every input image into a list of pieces:
    - a resized overview ("source") image;
    - optionally followed by the crops of a refined grid, row-major.

    Each piece is then:
    - scaled to [0, 1];
    - normalized with ``self.mean`` / ``self.std``;
    - moved to channel-first layout;
    - flattened by ``reshape_by_patch``.

    Pieces differ in size, so results stay as nested Python lists inside a
    ``MiniCPMVBatchFeature``.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        max_slice_nums=9,
        scale_resolution=448,
        patch_size=14,
        **kwargs):
        # ``**kwargs`` builds a new dict for the base-class call, so the ``pop``
        # calls below do not change what the base class received.
        super().__init__(**kwargs)
        self.max_slice_nums = max_slice_nums
        self.scale_resolution = scale_resolution
        self.patch_size = patch_size
        self.use_image_id = kwargs.pop("use_image_id", False)
        self.image_feature_size = kwargs.pop("image_feature_size", 64)
        self.im_start_token = kwargs.pop("im_start", "<image>")
        self.im_end_token = kwargs.pop("im_end", "</image>")
        self.slice_start_token = kwargs.pop("slice_start", "<slice>")
        self.slice_end_token = kwargs.pop("slice_end", "</slice>")
        self.unk_token = kwargs.pop("unk", "<unk>")
        self.im_id_start = kwargs.pop("im_id_start", "<image_id>")
        self.im_id_end = kwargs.pop("im_id_end", "</image_id>")
        self.slice_mode = kwargs.pop("slice_mode", True)
        self.mean = np.array(kwargs.pop("norm_mean", [0.5, 0.5, 0.5]))
        self.std = np.array(kwargs.pop("norm_std", [0.5, 0.5, 0.5]))
        self.version = kwargs.pop("version", 2.0)

    def ensure_divide(self, length, patch_size):
        """Round ``length`` to the nearest multiple of ``patch_size`` (at least ``patch_size``)."""
        return max(round(length / patch_size) * patch_size, patch_size)

    def find_best_resize(self,
                         original_size,
                         scale_resolution,
                         patch_size,
                         allow_upscale=False):
        """Return a ``(width, height)`` whose sides are multiples of ``patch_size``.

        The size is first rescaled, keeping the aspect ratio, so that its area is
        about ``scale_resolution ** 2``. This happens when the area exceeds
        ``scale_resolution ** 2`` or when ``allow_upscale`` is set.
        """
        width, height = original_size
        if (width * height >
                scale_resolution * scale_resolution) or allow_upscale:
            r = width / height
            height = int(scale_resolution / math.sqrt(r))
            width = int(height * r)
        best_width = self.ensure_divide(width, patch_size)
        best_height = self.ensure_divide(height, patch_size)
        return (best_width, best_height)

    def get_refine_size(self,
                        original_size,
                        grid,
                        scale_resolution,
                        patch_size,
                        allow_upscale=False):
        """Return the overall size to resize to before cutting the image into ``grid``.

        Steps:
        - the original size is snapped to multiples of ``grid``;
        - one cell is sized with ``find_best_resize``;
        - the cell size is multiplied back by the grid.

        The result therefore divides evenly into ``grid[0] x grid[1]`` cells.
        """
        width, height = original_size
        grid_x, grid_y = grid

        refine_width = self.ensure_divide(width, grid_x)
        refine_height = self.ensure_divide(height, grid_y)

        grid_width = refine_width / grid_x
        grid_height = refine_height / grid_y

        best_grid_size = self.find_best_resize((grid_width, grid_height),
                                               scale_resolution,
                                               patch_size,
                                               allow_upscale=allow_upscale)
        refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
        return refine_size

    def split_to_patches(self, image, grid):
        """Crop ``image`` into ``grid[0]`` columns by ``grid[1]`` rows.

        Returns a list of rows (top to bottom), each a list of crops (left to right).
        """
        patches = []
        width, height = image.size
        cell_width = int(width / grid[0])
        cell_height = int(height / grid[1])
        for top in range(0, height, cell_height):
            row = []
            for left in range(0, width, cell_width):
                box = (left, top, left + cell_width, top + cell_height)
                row.append(image.crop(box))
            patches.append(row)
        return patches

    def slice_image(
        self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False
    ):
        """Resize and optionally slice one PIL image.

        Returns:
            ``(source_image, patches, best_grid)``:
            - ``source_image`` is the resized overview image;
            - ``patches`` is the row-major list of crop rows (empty when no grid
              is used);
            - ``best_grid`` is the ``[cols, rows]`` grid, or ``None``.
        """
        original_size = image.size
        best_grid = self.get_sliced_grid(original_size, max_slice_nums, never_split)
        patches = []

        if best_grid is None:
            # No slicing: the single image may be upscaled to the target area.
            best_size = self.find_best_resize(
                original_size, scale_resolution, patch_size, allow_upscale=True
            )
            source_image = image.resize(best_size, resample=Image.Resampling.BICUBIC)
        else:
            # Overview image (downscale only) plus a refined, upscalable grid.
            # ``resize`` returns a new image, so no defensive ``copy()`` is needed.
            best_resize = self.find_best_resize(original_size, scale_resolution, patch_size)
            source_image = image.resize(best_resize, resample=Image.Resampling.BICUBIC)
            refine_size = self.get_refine_size(
                original_size, best_grid, scale_resolution, patch_size, allow_upscale=True
            )
            refine_image = image.resize(refine_size, resample=Image.Resampling.BICUBIC)
            patches = self.split_to_patches(refine_image, best_grid)

        return source_image, patches, best_grid

    def get_grid_placeholder(self, grid):
        """Return the token placeholder text for the slices of ``grid``.

        Each slice contributes one ``slice_start + unk * image_feature_size + slice_end``
        span. Slices in a row are concatenated and rows are joined with ``"\\n"``.
        Returns ``""`` when ``grid`` is ``None``.
        """
        if grid is None:
            return ""
        slice_image_placeholder = (
            self.slice_start_token
            + self.unk_token * self.image_feature_size
            + self.slice_end_token
        )

        cols = grid[0]
        rows = grid[1]
        slices = []
        for _ in range(rows):
            slices.append("".join(slice_image_placeholder for _ in range(cols)))

        slice_placeholder = "\n".join(slices)
        return slice_placeholder

    def get_image_id_placeholder(self, idx=0):
        """Return the image-id marker text for image number ``idx``."""
        return f"{self.im_id_start}{idx}{self.im_id_end}"

    def get_sliced_images(self, image, max_slice_nums=None):
        """Return ``[source_image, *slices]`` for ``image`` (slices row-major).

        When ``self.slice_mode`` is off, ``[image]`` is returned unchanged (not resized).
        """
        slice_images = []

        if not self.slice_mode:
            return [image]

        max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
        assert max_slice_nums > 0
        source_image, patches, sliced_grid = self.slice_image(
            image,
            max_slice_nums,
            self.scale_resolution,
            self.patch_size
        )

        slice_images.append(source_image)
        if len(patches) > 0:
            for i in range(len(patches)):
                for j in range(len(patches[0])):
                    slice_images.append(patches[i][j])
        return slice_images

    def get_sliced_grid(self, image_size, max_slice_nums, nerver_split=False):
        """Choose the ``[cols, rows]`` slicing grid for an image of ``image_size``.

        Candidate slice counts are ``multiple - 1``, ``multiple`` and
        ``multiple + 1``, excluding 1 and anything above ``max_slice_nums``.
        Here ``multiple = min(ceil(area / scale_resolution**2), max_slice_nums)``.
        Among all factorizations of those counts, the grid whose log aspect ratio
        is closest to the image's wins; on ties, the first candidate is kept.

        Returns ``None`` when ``multiple <= 1`` or ``nerver_split`` is set. The
        misspelled parameter name is kept for backward compatibility.
        """
        original_width, original_height = image_size
        log_ratio = math.log(original_width / original_height)
        ratio = original_width * original_height / (self.scale_resolution * self.scale_resolution)
        multiple = min(math.ceil(ratio), max_slice_nums)
        if multiple <= 1 or nerver_split:
            return None
        candidate_split_grids_nums = []
        for i in [multiple - 1, multiple, multiple + 1]:
            if i == 1 or i > max_slice_nums:
                continue
            candidate_split_grids_nums.append(i)

        candidate_grids = []
        for split_grids_nums in candidate_split_grids_nums:
            m = 1
            while m <= split_grids_nums:
                if split_grids_nums % m == 0:
                    candidate_grids.append([m, split_grids_nums // m])
                m += 1

        best_grid = [1, 1]
        min_error = float("inf")
        for grid in candidate_grids:
            error = abs(log_ratio - math.log(grid[0] / grid[1]))
            if error < min_error:
                best_grid = grid
                min_error = error

        return best_grid

    def get_slice_image_placeholder(self, image_size, image_idx=0, max_slice_nums=None, use_image_id=None):
        """Return the full placeholder text for one image of ``image_size``.

        The text consists of an optional image-id marker, the overview-image span
        and, when ``self.slice_mode`` is on, the grid placeholder for the slices.
        """
        max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
        assert max_slice_nums > 0
        grid = self.get_sliced_grid(image_size=image_size, max_slice_nums=max_slice_nums)

        image_placeholder = (
            self.im_start_token
            + self.unk_token * self.image_feature_size
            + self.im_end_token
        )
        use_image_id = self.use_image_id if use_image_id is None else bool(use_image_id)
        if use_image_id:
            final_placeholder = self.get_image_id_placeholder(image_idx) + image_placeholder
        else:
            final_placeholder = image_placeholder

        if self.slice_mode:
            final_placeholder = final_placeholder + self.get_grid_placeholder(grid=grid)
        return final_placeholder

    def to_pil_image(self, image, rescale=None) -> PIL.Image.Image:
        """
        Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
        needed.

        Args:
            image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
                The image to convert to the PIL Image format.
            rescale (`bool`, *optional*):
                Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
                default to `True` if the image type is a floating type, `False` otherwise.

        Returns:
            A PIL image for array/tensor input, the input itself if it already is a PIL image,
            and any other object unchanged.
        """
        if isinstance(image, PIL.Image.Image):
            return image
        if is_torch_tensor(image):
            # ``Tensor.numpy()`` rejects tensors that require grad or are not on the CPU.
            image = image.detach().cpu().numpy()

        if isinstance(image, np.ndarray):
            if rescale is None:
                # Floating-point images are assumed to hold values in [0, 1].
                rescale = np.issubdtype(image.dtype, np.floating)

            # Heuristic: a leading axis of size 1 or 3 is treated as channels-first.
            if image.ndim == 3 and image.shape[0] in [1, 3]:
                image = image.transpose(1, 2, 0)
            if image.ndim == 3 and image.shape[-1] == 1:
                # PIL cannot build an image from an (H, W, 1) array; use a 2-D grayscale array.
                image = image[..., 0]
            if rescale:
                image = image * 255
            # Clip before the cast so out-of-range values saturate instead of wrapping around.
            image = np.clip(image, 0, 255).astype(np.uint8)
            return PIL.Image.fromarray(image)
        return image

    def reshape_by_patch(self, image):
        """
        Flatten a channel-first image patch-wise.

        :param image: numpy array of shape [C, H, W]; patches of size ``self.patch_size`` are used.
        :return: numpy array of shape [C, patch_size, (H // patch_size) * (W // patch_size) * patch_size].
            Pixels beyond the last full patch in each direction are dropped by ``unfold``.
        """
        image = torch.from_numpy(image)
        patch_size = self.patch_size
        patches = torch.nn.functional.unfold(
            image,
            (patch_size, patch_size),
            stride=(patch_size, patch_size)
        )

        patches = patches.reshape(image.size(0), patch_size, patch_size, -1)
        patches = patches.permute(0, 1, 3, 2).reshape(image.size(0), patch_size, -1)
        return patches.numpy()

    @staticmethod
    def _is_single_image(obj) -> bool:
        """Return True when ``obj`` is one image (PIL, numpy array or torch tensor)."""
        return isinstance(obj, (Image.Image, np.ndarray)) or is_torch_tensor(obj)

    def preprocess(
        self,
        images: Union[Image.Image, List[Image.Image], List[List[Image.Image]]],
        do_pad: Optional[bool] = True,
        max_slice_nums: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs
    ) -> MiniCPMVBatchFeature:
        """Convert images into patch-flattened pixel values.

        Args:
            images: Accepted forms:
                - a single image, which may be a PIL image, a numpy array or a
                  torch tensor;
                - a flat list of images (one sample), recognised by its first
                  element being a single image;
                - a list of per-sample lists, whose entries may be ``None`` or empty.
                An empty list is treated as a batch with no samples.
            do_pad: Accepted for API compatibility; not used here.
            max_slice_nums: Overrides ``self.max_slice_nums`` for this call.
            return_tensors: Passed to ``MiniCPMVBatchFeature`` as ``tensor_type``.
            **kwargs: Ignored.

        Returns:
            ``MiniCPMVBatchFeature`` holding, per sample:
            - ``pixel_values``: one flattened array per slice;
            - ``image_sizes``: the original ``image.size`` per image;
            - ``tgt_sizes``: stacked ``(H // patch_size, W // patch_size)`` per
              slice, or ``[]`` for a sample without images.

        Raises:
            ValueError: If a sample contains an unsupported image type.
        """
        if self._is_single_image(images):
            images_list = [[images]]
        elif len(images) > 0 and self._is_single_image(images[0]):
            images_list = [images]
        else:
            images_list = images

        new_images_list = []
        image_sizes_list = []
        tgt_sizes_list = []

        for _images in images_list:
            if _images is None or len(_images) == 0:
                new_images_list.append([])
                image_sizes_list.append([])
                tgt_sizes_list.append([])
                continue
            if not valid_images(_images):
                raise ValueError(
                    "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                    "torch.Tensor, tf.Tensor or jax.ndarray."
                )

            _images = [self.to_pil_image(image).convert("RGB") for image in _images]
            # All images are RGB after ``convert``, so the first one fixes the layout.
            input_data_format = infer_channel_dimension_format(np.array(_images[0]))

            new_images = []
            image_sizes = [image.size for image in _images]
            tgt_sizes = []
            for image in _images:
                image_patches = self.get_sliced_images(image, max_slice_nums)
                image_patches = [to_numpy_array(patch).astype(np.float32) / 255 for patch in image_patches]
                image_patches = [
                    self.normalize(image=patch, mean=self.mean, std=self.std, input_data_format=input_data_format)
                    for patch in image_patches
                ]
                image_patches = [
                    to_channel_dimension_format(patch, ChannelDimension.FIRST, input_channel_dim=input_data_format)
                    for patch in image_patches
                ]
                for slice_image in image_patches:
                    new_images.append(self.reshape_by_patch(slice_image))
                    tgt_sizes.append(np.array((slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size)))

            if tgt_sizes:
                tgt_sizes = np.vstack(tgt_sizes)

            new_images_list.append(new_images)
            image_sizes_list.append(image_sizes)
            tgt_sizes_list.append(tgt_sizes)
        return MiniCPMVBatchFeature(
            data={"pixel_values": new_images_list, "image_sizes": image_sizes_list, "tgt_sizes": tgt_sizes_list}, tensor_type=return_tensors
        )
|
|
|
# Module-level side effect: register MiniCPMVImageProcessor with transformers'
# AutoImageProcessor under the key "MiniCPMVImageProcessor" (a string, not a
# config class). NOTE(review): the second positional argument is bound to a
# parameter whose name and deprecation status vary across transformers
# versions; confirm it is accepted by the installed version.
AutoImageProcessor.register("MiniCPMVImageProcessor", MiniCPMVImageProcessor)
|
|