Batch inference

#7 · opened by d-rau

Could you please give an example of how to do batch inference?

Hello @d-rau,

import torch
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import requests

# device = torch.device("cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_num_threads(4)

processor = AutoProcessor.from_pretrained(
    'allenai/Molmo-7B-D-0924',
    trust_remote_code=True,
    torch_dtype=torch.float32
)
model = AutoModelForCausalLM.from_pretrained(
    'allenai/Molmo-7B-D-0924',
    trust_remote_code=True,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
    device_map=None
)
model.to(device)  # the model loads on CPU by default; move it to the selected device
urls = [
    "https://picsum.photos/id/237/536/354",
    "https://picsum.photos/id/238/536/354",
    "https://picsum.photos/id/239/536/354"
]
prompts = [
    "What breed is this dog?",
    "Describe the colors in this image.",
    "Is this an indoor or outdoor scene?"
]
def process_single_image(image, prompt):
    # Run the Molmo processor on a single (image, prompt) pair and make sure
    # input_ids has a batch dimension.
    inputs = processor.process(
        images=[image],
        text=prompt,
        return_tensors="pt"
    )
    if 'input_ids' in inputs and inputs['input_ids'].dim() == 1:
        inputs['input_ids'] = inputs['input_ids'].unsqueeze(0)
    for k, v in inputs.items():
        if isinstance(v, torch.Tensor):
            print(f"{k}: {v.shape}")
    return inputs

all_processed_inputs = []
for url, prompt in zip(urls, prompts):
    image = Image.open(requests.get(url, stream=True).raw).convert('RGB')
    inputs = process_single_image(image, prompt)
    all_processed_inputs.append(inputs)
batch_inputs = {}

input_lengths = [inputs['input_ids'].size(-1) for inputs in all_processed_inputs]
max_length = max(input_lengths)

# Right-pad each input_ids tensor to the longest sequence in the batch.
padded_input_ids = []
for inputs in all_processed_inputs:
    curr_len = inputs['input_ids'].size(-1)
    if curr_len < max_length:
        padding = torch.zeros((1, max_length - curr_len), dtype=inputs['input_ids'].dtype)
        padded = torch.cat([inputs['input_ids'], padding], dim=1)
    else:
        padded = inputs['input_ids']
    padded_input_ids.append(padded)

batch_inputs['input_ids'] = torch.cat(padded_input_ids, dim=0)
# NOTE: torch.stack assumes every sample produced image tensors of identical shape.
batch_inputs['images'] = torch.stack([inputs['images'] for inputs in all_processed_inputs], dim=0)
batch_inputs['image_input_idx'] = torch.stack([inputs['image_input_idx'] for inputs in all_processed_inputs], dim=0)
if 'image_masks' in all_processed_inputs[0]:
    batch_inputs['image_masks'] = torch.stack([inputs['image_masks'] for inputs in all_processed_inputs], dim=0)

# Move the batched tensors to the same device as the model before generating.
batch_inputs = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch_inputs.items()}

with torch.inference_mode():
    outputs = model.generate_from_batch(
        batch_inputs,
        GenerationConfig(
            max_new_tokens=200,
            stop_strings="<|endoftext|>",
            pad_token_id=processor.tokenizer.pad_token_id,
            num_beams=1
        ),
        tokenizer=processor.tokenizer
    )
start_idx = batch_inputs['input_ids'].size(1)
generated_tokens = outputs[:, start_idx:]
generated_texts = processor.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
for url, prompt, text in zip(urls, prompts, generated_texts):
    print("\nPrompt:", prompt)
    print("\nGenerated text:", text)

I found that the snippet above had an issue with the outputs. I think the image tensors are not being handled properly.
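
The root cause seems to be that images with different sizes and aspect ratios produce a different number of crops, so the per-sample image tensors have different shapes and torch.stack on them fails. A minimal sketch of that failure mode, with hypothetical shapes that only illustrate the idea rather than the exact Molmo crop geometry:

import torch

# Hypothetical per-sample image tensors: the first image yields 5 crops, the second 7.
images_a = torch.zeros(5, 576, 588)
images_b = torch.zeros(7, 576, 588)

try:
    torch.stack([images_a, images_b], dim=0)  # requires identical shapes
except RuntimeError as err:
    print("stack fails:", err)

Padding every per-sample tensor to a common size, as pad_sequence does in the snippet below, avoids this.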

Here's code that worked for me:

import numpy as np
import requests
import torch
from PIL import Image, ImageOps
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from typing import List, Dict

processor = AutoProcessor.from_pretrained(
    "allenai/Molmo-7B-D-0924",
    trust_remote_code=True,
    torch_dtype=torch.float32,
    device_map="auto",
)
model = AutoModelForCausalLM.from_pretrained(
    "allenai/Molmo-7B-D-0924",
    trust_remote_code=True,
    torch_dtype=torch.float32,
    device_map="auto",
)
urls = [
    "https://picsum.photos/id/237/536/354",
    "https://picsum.photos/id/238/536/354",
    "https://picsum.photos/id/239/536/354",
]
prompts = [
    "What breed is this dog?",
    "Describe the colors in this image.",
    "Is this an indoor or outdoor scene?",
]

images_list = []
for url in urls:
    # Each sample is a list of images, since the preprocessing below expects a list per prompt.
    image = Image.open(requests.get(url, stream=True).raw)
    images_list.append([image])

texts = ["User: " + prompt + " Assistant:" for prompt in prompts]


def process_batch(
    processor: AutoProcessor,
    texts: List[str],
    images_list: List[List[Image.Image]]
) -> Dict[str, torch.Tensor]:
    """
    Process in batch.
    
    Args:
        processor: The original processor.
        texts: List of text inputs
        images_list: List of lists containing PIL images.
        
    Returns:
        Dict with padded input_ids, images, image_input_idx, image_masks.
    """
    batch_size = len(texts)
    tokens_list = []
    for text in texts:
        tokens = processor.tokenizer.encode(" " + text, add_special_tokens=False)
        tokens_list.append(tokens)
    images_arrays_list = []
    image_idxs_list = []
    for images in images_list:
        if images:
            image_arrays = []
            for image in images:
                if isinstance(image, Image.Image):
                    image = image.convert("RGB")
                    image = ImageOps.exif_transpose(image)
                    image_arrays.append(np.array(image))
                else:
                    assert len(image.shape) == 3 and image.shape[-1] == 3
                    image_arrays.append(image.astype(np.uint8))
            images_arrays_list.append(image_arrays)
            image_idx = [-1] * len(image_arrays)
            image_idxs_list.append(image_idx)
        else:
            images_arrays_list.append(None)
            image_idxs_list.append(None)
    images_kwargs = {
        "max_crops": 12,
        "overlap_margins": [4, 4],
        "base_image_input_size": [336, 336],
        "image_token_length_w": 12,
        "image_token_length_h": 12,
        "image_patch_size": 14,
        "image_padding_mask": True,
    }
    outputs_list = []
    for i in range(batch_size):
        tokens = tokens_list[i]
        images = images_arrays_list[i]
        image_idx = image_idxs_list[i]
        out = processor.image_processor.multimodal_preprocess(
            images=images,
            image_idx=image_idx,
            tokens=np.asarray(tokens).astype(np.int32),
            sequence_length=1536,
            image_patch_token_id=processor.special_token_ids["<im_patch>"],
            image_col_token_id=processor.special_token_ids["<im_col>"],
            image_start_token_id=processor.special_token_ids["<im_start>"],
            image_end_token_id=processor.special_token_ids["<im_end>"],
            **images_kwargs,
        )
        outputs_list.append(out)

    batch_outputs = {}
    for key in outputs_list[0].keys():
        # Pad every per-sample tensor to the longest in the batch (padding value -1).
        tensors = [torch.from_numpy(out[key]) for out in outputs_list]
        batch_outputs[key] = torch.nn.utils.rnn.pad_sequence(
            tensors, batch_first=True, padding_value=-1
        )
    # Prepend a BOS token, falling back to EOS if the tokenizer defines no BOS.
    bos = processor.tokenizer.bos_token_id or processor.tokenizer.eos_token_id
    batch_outputs["input_ids"] = torch.nn.functional.pad(
        batch_outputs["input_ids"], (1, 0), value=bos
    )
    if "image_input_idx" in batch_outputs:
        # Shift the image patch positions by one to account for the prepended BOS.
        image_input_idx = batch_outputs["image_input_idx"]
        batch_outputs["image_input_idx"] = torch.where(
            image_input_idx < 0, image_input_idx, image_input_idx + 1
        )
    return batch_outputs


inputs = process_batch(processor, texts, images_list)

inputs = {k: v.to(model.device) for k, v in inputs.items()}

output = model.generate_from_batch(
    inputs,
    GenerationConfig(
        max_new_tokens=200,
        stop_sequences=["<|endoftext|>"],
        eos_token_id=processor.tokenizer.eos_token_id,
        pad_token_id=processor.tokenizer.pad_token_id,
    ),
    tokenizer=processor.tokenizer,
)

generated_texts = processor.tokenizer.batch_decode(
    output[:, inputs["input_ids"].size(1) :], skip_special_tokens=True
)
for prompt, text in zip(prompts, generated_texts):
    print(f"\nPrompt: {prompt}")
    print(f"Response: {text}")

Hopefully this is helpful :) Honestly, I think batching should be handled in the model code itself, though!
