Batch inference (#7)
opened by d-rau
Could you please give an example of how to do batch inference?
Hello @d-rau, here's an example:
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import requests

# device = torch.device("cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_num_threads(4)

processor = AutoProcessor.from_pretrained(
    'allenai/Molmo-7B-D-0924',
    trust_remote_code=True,
    torch_dtype=torch.float32
)

model = AutoModelForCausalLM.from_pretrained(
    'allenai/Molmo-7B-D-0924',
    trust_remote_code=True,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
    device_map=None
)
model.to(device)

urls = [
    "https://picsum.photos/id/237/536/354",
    "https://picsum.photos/id/238/536/354",
    "https://picsum.photos/id/239/536/354"
]
prompts = [
    "What breed is this dog?",
    "Describe the colors in this image.",
    "Is this an indoor or outdoor scene?"
]

def process_single_image(image, prompt):
    # Run the processor on a single (image, prompt) pair.
    inputs = processor.process(
        images=[image],
        text=prompt,
        return_tensors="pt"
    )
    # Make sure input_ids has a batch dimension.
    if 'input_ids' in inputs:
        if len(inputs['input_ids'].shape) == 1:
            inputs['input_ids'] = inputs['input_ids'].unsqueeze(0)
    for k, v in inputs.items():
        if isinstance(v, torch.Tensor):
            print(f"{k}: {v.shape}")
    return inputs

all_processed_inputs = []
for url, prompt in zip(urls, prompts):
    image = Image.open(requests.get(url, stream=True).raw).convert('RGB')
    inputs = process_single_image(image, prompt)
    all_processed_inputs.append(inputs)

# Pad input_ids to the longest sequence in the batch.
batch_inputs = {}
input_lengths = [inputs['input_ids'].size(-1) for inputs in all_processed_inputs]
max_length = max(input_lengths)

padded_input_ids = []
for inputs in all_processed_inputs:
    curr_len = inputs['input_ids'].size(-1)
    if curr_len < max_length:
        padding = torch.zeros((1, max_length - curr_len), dtype=inputs['input_ids'].dtype)
        padded = torch.cat([inputs['input_ids'], padding], dim=1)
    else:
        padded = inputs['input_ids']
    padded_input_ids.append(padded)

# Stack the per-example tensors into a batch.
batch_inputs['input_ids'] = torch.cat(padded_input_ids, dim=0)
batch_inputs['images'] = torch.stack([inputs['images'] for inputs in all_processed_inputs], dim=0)
batch_inputs['image_input_idx'] = torch.stack([inputs['image_input_idx'] for inputs in all_processed_inputs], dim=0)
if 'image_masks' in all_processed_inputs[0]:
    batch_inputs['image_masks'] = torch.stack([inputs['image_masks'] for inputs in all_processed_inputs], dim=0)

# Move the batch to the same device as the model before generating.
batch_inputs = {k: v.to(model.device) for k, v in batch_inputs.items()}

with torch.inference_mode():
    outputs = model.generate_from_batch(
        batch_inputs,
        GenerationConfig(
            max_new_tokens=200,
            stop_strings="<|endoftext|>",
            pad_token_id=processor.tokenizer.pad_token_id,
            num_beams=1
        ),
        tokenizer=processor.tokenizer
    )

# Decode only the newly generated tokens (everything after the prompt).
start_idx = batch_inputs['input_ids'].size(1)
generated_tokens = outputs[:, start_idx:]
generated_texts = processor.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

for url, prompt, text in zip(urls, prompts, generated_texts):
    print("\nPrompt:", prompt)
    print("\nGenerated text:", text)
I found that the snippet above had an issue with the outputs. I think the image tensors are not being handled properly.
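One quick way to see the problem (a small sketch of my own against the `all_processed_inputs` list built in the snippet above, not code from the original): `torch.stack` only works when every example yields image tensors of exactly the same shape, which I believe is not guaranteed once images require different numbers of crops.

# Sketch: compare per-example image tensor shapes from the first snippet.
# If they differ, the naive torch.stack above cannot batch them correctly.
shapes = [tuple(inp['images'].shape) for inp in all_processed_inputs]
print(shapes)
if len(set(shapes)) > 1:
    print("Per-example 'images' tensors differ in shape; naive stacking will fail.")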
Here's code that worked for me:
import numpy as np
import requests
import torch
from PIL import Image, ImageOps
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from typing import List, Dict

processor = AutoProcessor.from_pretrained(
    "allenai/Molmo-7B-D-0924",
    trust_remote_code=True,
    torch_dtype=torch.float32,
    device_map="auto",
)
model = AutoModelForCausalLM.from_pretrained(
    "allenai/Molmo-7B-D-0924",
    trust_remote_code=True,
    torch_dtype=torch.float32,
    device_map="auto",
)

urls = [
    "https://picsum.photos/id/237/536/354",
    "https://picsum.photos/id/238/536/354",
    "https://picsum.photos/id/239/536/354",
]
prompts = [
    "What breed is this dog?",
    "Describe the colors in this image.",
    "Is this an indoor or outdoor scene?",
]

images_list = []
for url in urls:
    image = Image.open(requests.get(url, stream=True).raw)
    images_list.append([image])

texts = ["User: " + prompt + " Assistant:" for prompt in prompts]

def process_batch(
    processor: AutoProcessor,
    texts: List[str],
    images_list: List[List[Image.Image]],
) -> Dict[str, torch.Tensor]:
    """
    Process prompts and images as a batch.

    Args:
        processor: The original Molmo processor.
        texts: List of text inputs.
        images_list: List of lists containing PIL images.

    Returns:
        Dict with padded input_ids, images, image_input_idx, image_masks.
    """
    batch_size = len(texts)

    # Tokenize each prompt separately.
    tokens_list = []
    for text in texts:
        tokens = processor.tokenizer.encode(" " + text, add_special_tokens=False)
        tokens_list.append(tokens)

    # Convert PIL images to RGB numpy arrays, honoring EXIF orientation.
    images_arrays_list = []
    image_idxs_list = []
    for images in images_list:
        if images:
            image_arrays = []
            for image in images:
                if isinstance(image, Image.Image):
                    image = image.convert("RGB")
                    image = ImageOps.exif_transpose(image)
                    image_arrays.append(np.array(image))
                else:
                    assert len(image.shape) == 3 and image.shape[-1] == 3
                    image_arrays.append(image.astype(np.uint8))
            images_arrays_list.append(image_arrays)
            image_idx = [-1] * len(image_arrays)
            image_idxs_list.append(image_idx)
        else:
            images_arrays_list.append(None)
            image_idxs_list.append(None)

    # Image preprocessing settings (Molmo defaults).
    images_kwargs = {
        "max_crops": 12,
        "overlap_margins": [4, 4],
        "base_image_input_size": [336, 336],
        "image_token_length_w": 12,
        "image_token_length_h": 12,
        "image_patch_size": 14,
        "image_padding_mask": True,
    }

    # Run the multimodal preprocessor on each example individually.
    outputs_list = []
    for i in range(batch_size):
        tokens = tokens_list[i]
        images = images_arrays_list[i]
        image_idx = image_idxs_list[i]
        out = processor.image_processor.multimodal_preprocess(
            images=images,
            image_idx=image_idx,
            tokens=np.asarray(tokens).astype(np.int32),
            sequence_length=1536,
            image_patch_token_id=processor.special_token_ids["<im_patch>"],
            image_col_token_id=processor.special_token_ids["<im_col>"],
            image_start_token_id=processor.special_token_ids["<im_start>"],
            image_end_token_id=processor.special_token_ids["<im_end>"],
            **images_kwargs,
        )
        outputs_list.append(out)

    # Pad every per-example tensor to the batch maximum with -1.
    batch_outputs = {}
    for key in outputs_list[0].keys():
        tensors = [torch.from_numpy(out[key]) for out in outputs_list]
        batch_outputs[key] = torch.nn.utils.rnn.pad_sequence(
            tensors, batch_first=True, padding_value=-1
        )

    # Prepend a BOS token and shift the image token indices accordingly.
    bos = processor.tokenizer.bos_token_id or processor.tokenizer.eos_token_id
    batch_outputs["input_ids"] = torch.nn.functional.pad(
        batch_outputs["input_ids"], (1, 0), value=bos
    )
    if "image_input_idx" in batch_outputs:
        image_input_idx = batch_outputs["image_input_idx"]
        batch_outputs["image_input_idx"] = torch.where(
            image_input_idx < 0, image_input_idx, image_input_idx + 1
        )

    return batch_outputs

inputs = process_batch(processor, texts, images_list)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

output = model.generate_from_batch(
    inputs,
    GenerationConfig(
        max_new_tokens=200,
        stop_sequences=["<|endoftext|>"],
        eos_token_id=processor.tokenizer.eos_token_id,
        pad_token_id=processor.tokenizer.pad_token_id,
    ),
    tokenizer=processor.tokenizer,
)

generated_texts = processor.tokenizer.batch_decode(
    output[:, inputs["input_ids"].size(1):], skip_special_tokens=True
)

for prompt, text in zip(prompts, generated_texts):
    print(f"\nPrompt: {prompt}")
    print(f"Response: {text}")
Hopefully this is helpful :) Honestly, though, I think batching should be handled in the model code itself!