Spaces:
Runtime error
Runtime error
import math | |
from copy import deepcopy | |
from io import BytesIO | |
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union | |
import numpy as np | |
from transformers.image_utils import get_image_size, to_numpy_array | |
from typing_extensions import override | |
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER | |
from ..extras.packages import is_pillow_available, is_pyav_available | |
if is_pillow_available(): | |
from PIL import Image | |
from PIL.Image import Image as ImageObject | |
if is_pyav_available(): | |
import av | |
if TYPE_CHECKING: | |
import torch | |
from av.stream import Stream | |
from transformers import PreTrainedTokenizer, ProcessorMixin | |
from transformers.image_processing_utils import BaseImageProcessor | |
class EncodedImage(TypedDict):
    r"""
    An image given by its encoded bytes and/or a file path.
    """

    path: Optional[str]
    bytes: Optional[bytes]


# ``ImageObject`` is only imported when Pillow is available, so it is referenced through a
# forward-reference string; a bare name would raise NameError at import time without Pillow.
ImageInput = Union[str, EncodedImage, "ImageObject"]
VideoInput = str
def _get_paligemma_token_type_ids( | |
imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin" | |
) -> List[List[int]]: | |
r""" | |
Gets paligemma token type ids for computing loss. | |
Returns: | |
batch_token_type_ids: shape (batch_size, sequence_length) | |
""" | |
batch_token_type_ids = [] | |
for imglen, seqlen in zip(imglens, seqlens): | |
image_seqlen = imglen * getattr(processor, "image_seqlen") | |
batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen)) | |
return batch_token_type_ids | |
class BasePlugin:
    r"""
    Base multimodal plugin.

    It validates that the model accepts the given modalities and provides shared helpers for
    reading images/videos and running the HF processor. The base ``process_messages`` /
    ``process_token_ids`` / ``get_mm_inputs`` perform no expansion; subclasses override them.
    """

    def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
        # A ``None`` token marks the modality as unsupported (checked in ``_validate_input``).
        self.image_token = image_token
        self.video_token = video_token

    def _validate_input(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
    ) -> None:
        r"""
        Validates if this model accepts the input modalities.

        Raises:
            ValueError: if images (videos) are given but the plugin has no image (video) token.
        """
        if len(images) != 0 and self.image_token is None:
            raise ValueError("This model does not support image input.")
        if len(videos) != 0 and self.video_token is None:
            raise ValueError("This model does not support video input.")

    @staticmethod
    def _get_resized_shape(width: int, height: int, max_side: int) -> Tuple[int, int]:
        r"""
        Returns ``(width, height)`` scaled so the longer side becomes ``max_side``, keeping the
        aspect ratio. Each side is clamped to at least 1 pixel so extreme aspect ratios never
        produce a zero-sized image.
        """
        resize_factor = max_side / max(width, height)
        return max(1, int(width * resize_factor)), max(1, int(height * resize_factor))

    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
        r"""
        Pre-processes a single image.

        Downscales (nearest-neighbour) so the longer side is at most ``kwargs["image_resolution"]``
        pixels, then converts the image to RGB if needed.
        """
        image_resolution: int = kwargs.get("image_resolution")
        if max(image.width, image.height) > image_resolution:
            width, height = self._get_resized_shape(image.width, image.height, image_resolution)
            image = image.resize((width, height), resample=Image.NEAREST)

        if image.mode != "RGB":
            image = image.convert("RGB")

        return image

    def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
        r"""
        Computes video sample frames according to fps.

        The count is ``duration_seconds * kwargs["video_fps"]``, capped by the stream's frame
        count and by ``kwargs["video_maxlen"]``, then floored.
        """
        video_fps: float = kwargs.get("video_fps")
        video_maxlen: int = kwargs.get("video_maxlen")
        total_frames = video_stream.frames
        # duration * time_base is the stream length in seconds (PyAV convention)
        sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
        sample_frames = min(total_frames, video_maxlen, sample_frames)
        return math.floor(sample_frames)

    def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
        r"""
        Regularizes images to avoid error. Including reading and pre-processing.

        Accepts file paths, ``{"path", "bytes"}`` dicts (bytes take precedence when not None)
        and already-open PIL images.

        Raises:
            ValueError: if an item is not (and cannot be opened as) a PIL image.
        """
        results = []
        for image in images:
            if isinstance(image, str):
                image = Image.open(image)
            elif isinstance(image, dict):
                if image["bytes"] is not None:
                    image = Image.open(BytesIO(image["bytes"]))
                else:
                    image = Image.open(image["path"])

            if not isinstance(image, ImageObject):
                raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))

            results.append(self._preprocess_image(image, **kwargs))

        return results

    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
        r"""
        Regularizes videos to avoid error. Including reading, resizing and converting.

        Each video is decoded with PyAV; ``_get_video_sample_frames`` frames are taken at evenly
        spaced indices and then passed through ``_regularize_images``.
        """
        results = []
        for video in videos:
            # The context manager closes the container (and its file handle) even on errors.
            with av.open(video, "r") as container:
                video_stream = next(stream for stream in container.streams if stream.type == "video")
                total_frames = video_stream.frames
                sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
                sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
                wanted = set(sample_indices.tolist())  # O(1) membership per decoded frame
                last_wanted = max(wanted, default=-1)
                frames: List["ImageObject"] = []
                container.seek(0)
                for frame_idx, frame in enumerate(container.decode(video_stream)):
                    if frame_idx > last_wanted:
                        break  # all sampled frames collected; skip decoding the rest

                    if frame_idx in wanted:
                        frames.append(frame.to_image())

            results.append(self._regularize_images(frames, **kwargs))

        return results

    def _get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: "ProcessorMixin",
    ) -> Dict[str, "torch.Tensor"]:
        r"""
        Processes visual inputs.

        Returns: (llava and paligemma)
            pixel_values: tensor with shape (B, C, H, W)

        Returns: (qwen2-vl)
            pixel_values: tensor with shape (num_patches, patch_dim)
            image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height

            It holds num_patches == torch.prod(image_grid_thw)
        """
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
        video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
        input_dict = {"images": None}  # default key
        if len(images) != 0:
            images = self._regularize_images(
                images,
                image_resolution=getattr(processor, "image_resolution", 512),
            )
            input_dict["images"] = images

        if len(videos) != 0:
            videos = self._regularize_videos(
                videos,
                image_resolution=getattr(processor, "video_resolution", 128),
                video_fps=getattr(processor, "video_fps", 1.0),
                video_maxlen=getattr(processor, "video_maxlen", 64),
            )
            input_dict["videos"] = videos

        mm_inputs = {}
        if image_processor != video_processor:
            # separate processors: run each on its own modality
            if input_dict.get("images") is not None:
                mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))
            if input_dict.get("videos") is not None:
                mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))
        elif input_dict.get("images") is not None or input_dict.get("videos") is not None:  # same processor (qwen2-vl)
            mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))

        return mm_inputs

    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Pre-processes input messages before tokenization for VLMs.

        The base implementation only validates inputs and returns ``messages`` unchanged.
        """
        self._validate_input(images, videos)
        return messages

    def process_token_ids(
        self,
        input_ids: List[int],
        labels: Optional[List[int]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
    ) -> Tuple[List[int], Optional[List[int]]]:
        r"""
        Pre-processes token ids after tokenization for VLMs.

        The base implementation only validates inputs and returns the ids/labels unchanged.
        """
        self._validate_input(images, videos)
        return input_ids, labels

    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        seqlens: Sequence[int],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Builds batched multimodal inputs for VLMs.

        The base implementation only validates inputs and returns an empty dict.
        """
        self._validate_input(images, videos)
        return {}
class LlavaPlugin(BasePlugin):
    r"""
    Plugin for LLaVA-style models: each image placeholder in the text becomes a run of
    ``processor.image_seqlen`` copies of the image token.
    """

    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Returns a deep copy of ``messages`` with every image placeholder expanded.

        Raises:
            ValueError: if the number of placeholders differs from ``len(images)``.
        """
        self._validate_input(images, videos)
        seqlen = getattr(processor, "image_seqlen")
        result = deepcopy(messages)
        seen = 0
        for msg in result:
            text = msg["content"]
            # Swap placeholders for a temporary marker one at a time, counting as we go.
            while IMAGE_PLACEHOLDER in text:
                text = text.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                seen += 1

            msg["content"] = text.replace("{{image}}", self.image_token * seqlen)

        if seen != len(images):
            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))

        return result

    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        seqlens: Sequence[int],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Validates the inputs and returns the processor outputs from ``_get_mm_inputs``.
        """
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)
class LlavaNextPlugin(BasePlugin):
    r"""
    Plugin for LLaVA-NeXT: each image token in the text is expanded to a per-image number of
    tokens given by ``processor._get_number_of_features`` for that image's original size.
    """

    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Returns a deep copy of ``messages`` where every ``self.image_token`` is expanded to the
        feature count of the matching image (images are consumed in order of appearance).

        Raises:
            ValueError: if there are more image tokens than images, or the counts differ.
        """
        self._validate_input(images, videos)
        num_image_tokens = 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        # Empty when no images were given; the guard below raises before it is consumed.
        image_sizes = iter(mm_inputs.get("image_sizes", []))
        if "pixel_values" in mm_inputs:
            # Processed size read from the first image's first patch
            # (presumably pixel_values is (num_images, num_patches, C, H, W) — confirm).
            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

        for message in messages:
            content = message["content"]
            while self.image_token in content:
                if num_image_tokens >= len(images):
                    raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER))

                orig_height, orig_width = next(image_sizes)
                image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                if processor.vision_feature_select_strategy == "default":
                    image_seqlen -= 1  # presumably the CLS feature is dropped under "default"

                num_image_tokens += 1
                # A temporary marker keeps the while-condition from re-matching expanded tokens.
                content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)

            message["content"] = content.replace("{{image}}", self.image_token)

        if len(images) != num_image_tokens:
            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))

        return messages

    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        seqlens: Sequence[int],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Validates the inputs and returns the processor outputs from ``_get_mm_inputs``.
        """
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)
class LlavaNextVideoPlugin(BasePlugin):
    r"""
    Plugin for LLaVA-NeXT-Video: image tokens are expanded per image (as in LLaVA-NeXT) and
    each video token is expanded to a fixed run computed from the processed frame size,
    ``processor.patch_size`` and the number of frames.
    """

    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Returns a deep copy of ``messages`` with image and video tokens expanded.

        Tokens are always counted, so stray tokens without matching media are reported
        instead of being silently left unexpanded.

        Raises:
            ValueError: if there are more image (video) tokens than images (videos), or the
                counts differ.
        """
        self._validate_input(images, videos)
        num_image_tokens = 0
        num_video_tokens = 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, processor)

        # Empty when no images were given; the guard below raises before it is consumed.
        image_sizes = iter(mm_inputs.get("image_sizes", []))
        if "pixel_values" in mm_inputs:
            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

        for message in messages:
            content = message["content"]
            while self.image_token in content:
                if num_image_tokens >= len(images):
                    raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER))

                orig_height, orig_width = next(image_sizes)
                image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                if processor.vision_feature_select_strategy == "default":
                    image_seqlen -= 1

                num_image_tokens += 1
                # A temporary marker keeps the while-condition from re-matching expanded tokens.
                content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)

            message["content"] = content.replace("{{image}}", self.image_token)

        video_seqlen = 0  # only used when videos exist; stray tokens fail the guard below
        if "pixel_values_videos" in mm_inputs:
            pixel_values_video = to_numpy_array(mm_inputs["pixel_values_videos"][0])
            height, width = get_image_size(pixel_values_video[0])
            num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
            video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer

        for message in messages:
            content = message["content"]
            while self.video_token in content:
                if num_video_tokens >= len(videos):
                    raise ValueError("`len(videos)` is less than the number of {} tokens.".format(VIDEO_PLACEHOLDER))

                num_video_tokens += 1
                content = content.replace(self.video_token, "{{video}}", 1)

            message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)

        if len(images) != num_image_tokens:
            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))

        if len(videos) != num_video_tokens:
            raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER))

        return messages

    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        seqlens: Sequence[int],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Validates the inputs and returns the processor outputs from ``_get_mm_inputs``.
        """
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)
class PaliGemmaPlugin(BasePlugin):
    r"""
    Plugin for PaliGemma: image placeholders are removed from the text and the image tokens
    are instead prepended to the token ids, with matching ``token_type_ids``.
    """

    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Returns a deep copy of ``messages`` with all image placeholders stripped.

        Raises:
            ValueError: if the number of placeholders differs from ``len(images)``.
        """
        self._validate_input(images, videos)
        result = deepcopy(messages)
        seen = 0
        for msg in result:
            text = msg["content"]
            while IMAGE_PLACEHOLDER in text:
                text = text.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                seen += 1

            msg["content"] = text.replace("{{image}}", "")

        if seen != len(images):
            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))

        return result

    def process_token_ids(
        self,
        input_ids: List[int],
        labels: Optional[List[int]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
    ) -> Tuple[List[int], Optional[List[int]]]:
        r"""
        Prepends ``len(images) * processor.image_seqlen`` image-token ids to ``input_ids`` and
        the same number of ``IGNORE_INDEX`` entries to ``labels`` (when labels are given).
        """
        self._validate_input(images, videos)
        prefix_len = len(images) * getattr(processor, "image_seqlen")
        image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
        new_input_ids = [image_token_id] * prefix_len + input_ids
        new_labels = None if labels is None else [IGNORE_INDEX] * prefix_len + labels
        return new_input_ids, new_labels

    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        seqlens: Sequence[int],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Returns the processor outputs plus ``token_type_ids`` (0 for the image prefix, 1 after).
        """
        self._validate_input(images, videos)
        features = self._get_mm_inputs(images, videos, processor)
        features["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
        return features
class Qwen2vlPlugin(BasePlugin):
    r"""
    Plugin for Qwen2-VL: placeholders become ``<|vision_start|>`` + N vision tokens +
    ``<|vision_end|>``, where N comes from the processor's ``*_grid_thw`` output.
    """

    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
        r"""
        Applies the base pre-processing, then enforces a minimum side of 28 pixels and caps the
        aspect ratio (anything above 200:1 is resized to 180:1).
        """
        image = super()._preprocess_image(image, **kwargs)
        min_side = 28
        if min(image.width, image.height) < min_side:
            image = image.resize(
                (max(image.width, min_side), max(image.height, min_side)), resample=Image.NEAREST
            )

        if image.width / image.height > 200:
            image = image.resize((image.height * 180, image.height), resample=Image.NEAREST)

        if image.height / image.width > 200:
            image = image.resize((image.width, image.width * 180), resample=Image.NEAREST)

        return image

    def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
        r"""
        Rounds the base frame count down to an even number (presumably to match a temporal
        patch size of 2 — confirm against the Qwen2-VL processor).
        """
        return super()._get_video_sample_frames(video_stream, **kwargs) // 2 * 2

    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Returns a deep copy of ``messages`` with each image/video placeholder replaced by a
        vision span of ``prod(grid_thw) // merge_size**2`` tokens.

        Raises:
            ValueError: if there are more placeholders than inputs, or the counts differ.
        """
        self._validate_input(images, videos)
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
        merge_length: int = getattr(image_processor, "merge_size") ** 2
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        image_grid_thw = mm_inputs.get("image_grid_thw", [])
        video_grid_thw = mm_inputs.get("video_grid_thw", [])

        def expand(content, placeholder, token, grid_thw, used, too_many_msg):
            # Replace placeholders left to right, consuming one grid entry per placeholder.
            while placeholder in content:
                if used >= len(grid_thw):
                    raise ValueError(too_many_msg.format(placeholder))

                span = "<|vision_start|>{}<|vision_end|>".format(token * (grid_thw[used].prod() // merge_length))
                content = content.replace(placeholder, span, 1)
                used += 1

            return content, used

        image_count, video_count = 0, 0
        result = deepcopy(messages)
        for msg in result:
            text = msg["content"]
            text, image_count = expand(
                text,
                IMAGE_PLACEHOLDER,
                self.image_token,
                image_grid_thw,
                image_count,
                "`len(images)` is less than the number of {} tokens.",
            )
            text, video_count = expand(
                text,
                VIDEO_PLACEHOLDER,
                self.video_token,
                video_grid_thw,
                video_count,
                "`len(videos)` is less than the number of {} tokens.",
            )
            msg["content"] = text

        if len(images) != image_count:
            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))

        if len(videos) != video_count:
            raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER))

        return result

    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        seqlens: Sequence[int],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Validates the inputs and returns the processor outputs from ``_get_mm_inputs``.
        """
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)
class VideoLlavaPlugin(BasePlugin):
    r"""
    Plugin for Video-LLaVA: image and video tokens are expanded to fixed-length runs computed
    from the processed frame size and ``processor.patch_size``.
    """

    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Returns a deep copy of ``messages`` with image and video tokens expanded.

        Text-only samples (no images and no videos) are returned with their content unchanged
        apart from marker substitution.

        Raises:
            ValueError: if the number of image (video) tokens differs from ``len(images)``
                (``len(videos)``).
        """
        self._validate_input(images, videos)
        num_image_tokens = 0
        num_video_tokens = 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        # Default to 0 so text-only samples don't hit an unbound local below; any stray tokens
        # then fail the count checks at the end.
        image_seqlen, video_seqlen = 0, 0
        exist_images = "pixel_values_images" in mm_inputs
        exist_videos = "pixel_values_videos" in mm_inputs
        if exist_images or exist_videos:
            num_frames = 1
            if exist_images:
                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values_images"][0]))

            if exist_videos:
                pixel_values_video = to_numpy_array(mm_inputs["pixel_values_videos"][0])
                height, width = get_image_size(pixel_values_video[0])
                num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim

            # The +1 is kept for every video frame but removed for images under the "default"
            # feature-select strategy (presumably the CLS feature).
            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
            video_seqlen = image_seqlen * num_frames
            if processor.vision_feature_select_strategy == "default":
                image_seqlen -= 1

        for message in messages:
            content = message["content"]
            while self.image_token in content:
                num_image_tokens += 1
                content = content.replace(self.image_token, "{{image}}", 1)

            while self.video_token in content:
                num_video_tokens += 1
                content = content.replace(self.video_token, "{{video}}", 1)

            content = content.replace("{{image}}", self.image_token * image_seqlen)
            message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)

        if len(images) != num_image_tokens:
            raise ValueError("The number of images does not match the number of {} tokens".format(self.image_token))

        if len(videos) != num_video_tokens:
            raise ValueError("The number of videos does not match the number of {} tokens".format(self.video_token))

        return messages

    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        seqlens: Sequence[int],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Validates the inputs and returns the processor outputs from ``_get_mm_inputs``.
        """
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)
# Registry of multimodal plugins, keyed by the name passed to ``get_mm_plugin``.
PLUGINS = dict(
    base=BasePlugin,
    llava=LlavaPlugin,
    llava_next=LlavaNextPlugin,
    llava_next_video=LlavaNextVideoPlugin,
    paligemma=PaliGemmaPlugin,
    qwen2_vl=Qwen2vlPlugin,
    video_llava=VideoLlavaPlugin,
)
def get_mm_plugin(
    name: str,
    image_token: Optional[str] = None,
    video_token: Optional[str] = None,
) -> "BasePlugin":
    r"""
    Instantiates the multimodal plugin registered under ``name`` in ``PLUGINS``.

    Raises:
        ValueError: if no plugin is registered under ``name``.
    """
    if name not in PLUGINS:
        raise ValueError("Multimodal plugin `{}` not found.".format(name))

    return PLUGINS[name](image_token, video_token)