|
from typing import List |
|
from queue import Queue |
|
|
|
import torch |
|
from PIL import Image |
|
from copy import deepcopy |
|
import requests, os |
|
|
|
# Id written into input_ids in place of every ``<image>`` marker
# (see tokenizer_image_token). Negative, presumably so it cannot collide with
# a real vocabulary id.
IMAGE_TOKEN_INDEX = -200

# Literal strings removed from user/assistant text by input_moderation, so that
# raw user input cannot smuggle in image markers or BOS/EOS strings.
blacklist = [
    '<image>',
    '<s>',
    '</s>',
]

# Upper bound on images per request, checked in build_allava_input.
max_num_images = 3
|
|
|
def input_moderation(texts: list[list[str]], tokens=None):
    """Remove reserved marker strings from every ``[human, gpt]`` pair, in place.

    Each entry of *texts* is indexed as ``pair[0]`` (human text) and
    ``pair[1]`` (gpt text). Either may be ``None``, in which case it is left
    untouched.

    Stripping repeats until the text stops changing. Deleting one marker can
    splice its neighbours into another: a single pass over ``"<ima<s>ge>"``
    leaves ``"<image>"``, which would let user input inject an image
    placeholder.

    Args:
        texts: list of mutable ``[human, gpt]`` pairs; modified in place.
        tokens: strings to remove; defaults to the module-level ``blacklist``.

    Returns:
        The same *texts* object, for convenience.
    """
    if tokens is None:
        tokens = blacklist

    def scrub(text):
        # Each change strictly shortens a non-empty token match, so this ends.
        previous = None
        while text != previous:
            previous = text
            for token in tokens:
                text = text.replace(token, '')
        return text

    for text_pair in texts:
        for idx in (0, 1):
            if text_pair[idx] is not None:
                text_pair[idx] = scrub(text_pair[idx])

    return texts
|
|
|
def insert_image_placeholder(t, num_images, placeholder='<image>', sep='\n'):
    """Return *t* prefixed by *num_images* copies of ``placeholder + sep``.

    With the defaults, each copy is an ``<image>`` marker followed by a
    newline. A zero or negative count returns *t* unchanged (same object).
    """
    # range() enforces an integer count and maps negatives to zero copies.
    count = len(range(num_images))
    if count == 0:
        return t
    return f"{placeholder}{sep}" * count + t
|
|
|
def get_conv(texts):
    """Flatten ``[human, gpt]`` pairs into an alternating list of role-tagged turns.

    Each pair contributes ``{'from': 'human', 'value': pair[0]}`` followed by
    ``{'from': 'gpt', 'value': pair[1]}``. Values pass through unchanged, so a
    ``None`` gpt value stays ``None``.
    """
    roles = (('human', 0), ('gpt', 1))
    return [
        {'from': role, 'value': pair[slot]}
        for pair in texts
        for role, slot in roles
    ]
|
|
|
|
|
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    """Tokenize *prompt*, emitting *image_token_index* for every ``<image>`` marker.

    The text between markers is tokenized piece by piece with
    ``add_special_tokens=False``. If the ids of the first piece start with
    ``tokenizer.bos_token_id``, that id is emitted once up front. The first id
    of *every* piece is then dropped, because the offset is applied uniformly.

    Args:
        prompt: text that may contain ``<image>`` markers.
        tokenizer: callable returning an object with ``input_ids``; must also
            expose ``bos_token_id``.
        image_token_index: id written in place of each marker.
        return_tensors: ``None`` for a plain list, ``'pt'`` for a
            ``torch.long`` tensor.

    Raises:
        ValueError: for any other *return_tensors* value.
    """
    pieces = prompt.split('<image>')
    piece_ids = [tokenizer(piece, add_special_tokens=False).input_ids for piece in pieces]

    out = []
    first = piece_ids[0]
    drop = 1 if len(first) > 0 and first[0] == tokenizer.bos_token_id else 0
    if drop:
        out.append(first[0])

    for position, ids in enumerate(piece_ids):
        if position > 0:
            # Exactly one image id between consecutive pieces, i.e. per marker.
            out.append(image_token_index)
        out.extend(ids[drop:])

    if return_tensors is None:
        return out
    if return_tensors == 'pt':
        return torch.tensor(out, dtype=torch.long)
    raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
|
|
def preprocess(tokenizer, data: list, return_tensors='pt'):
    """Validate *data* and tokenize it with ``preprocess_allava``.

    *data* is a list of role-tagged turns, e.g.::

        [
            {'from': 'human', 'value': ...},
            {'from': 'gpt', 'value': ...},
        ]

    Raises:
        ValueError: if *data* is not a ``list``.
    """
    if isinstance(data, list):
        return preprocess_allava(tokenizer, data, return_tensors=return_tensors)
    raise ValueError('must be a list')
|
|
|
|
|
|
|
|
|
def preprocess_allava(tokenizer, convs: list, return_tensors, bos_token_id=1) -> torch.Tensor:
    """Render role-tagged turns with the user/assistant chat markup and tokenize them.

    Turns are taken to alternate human/gpt starting with human. Only the index
    parity is used; ``'from'`` is never inspected.

    - Human turn: its value is stripped, wrapped as ``<|user|>``, newline,
      text, ``<|end|>``, newline, and tokenized with ``tokenizer_image_token``,
      so ``<image>`` markers become image ids.
    - GPT turn with text: wrapped as ``<|assistant|>``, newline, text,
      ``<|end|>``, newline.
    - GPT turn with value ``None``: emits only ``<|assistant|>`` plus newline,
      i.e. the open assistant turn the model continues from. The original code
      built this string but never appended it.

    Args:
        tokenizer: callable tokenizer supporting ``add_special_tokens``,
            ``truncation`` and ``return_tensors='pt'``.
        convs: list of ``{'from': ..., 'value': ...}`` dicts.
        return_tensors: forwarded to ``tokenizer_image_token`` (``'pt'`` or
            ``None``); the result is converted to a tensor either way.
        bos_token_id: id placed at the very start. Defaults to 1, the
            previously hard-coded value; presumably the tokenizer's BOS id,
            so confirm for non-default tokenizers.

    Returns:
        1-D ``torch.long`` tensor of token ids.
    """
    segments = [torch.tensor([bos_token_id], dtype=torch.long)]

    for ind, conv in enumerate(convs):
        if ind % 2 == 0:
            h = conv['value'].strip()
            prompt = f"<|user|>\n{h}<|end|>\n"
            ids = tokenizer_image_token(prompt=prompt, tokenizer=tokenizer, return_tensors=return_tensors)
            # as_tensor is a no-op for a 'pt' long tensor and converts a plain list.
            segments.append(torch.as_tensor(ids, dtype=torch.long))
        else:
            g = conv['value']
            if g is not None:
                text = f"<|assistant|>\n{g}<|end|>\n"
            else:
                text = '<|assistant|>\n'
            ids = tokenizer(text, add_special_tokens=False, truncation=True, return_tensors='pt').input_ids[0]
            segments.append(ids)

    # One concatenation instead of re-allocating the growing tensor each turn.
    return torch.cat(segments)
|
|
|
|
|
|
|
def _expand2square(pil_img, background_color):
    """Pad *pil_img* to a square canvas filled with *background_color*, centring it.

    Any non-RGB image (``L``, ``P``, ``RGBA``, ...) is converted to RGB first.
    The 3-tuple background is then valid for the canvas, and palette images are
    not pasted as raw indices onto a canvas with a different palette.
    """
    if pil_img.mode != 'RGB':
        pil_img = pil_img.convert('RGB')
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    result = Image.new(pil_img.mode, (side, side), background_color)
    result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return result


def get_image_tensors(processor, images, device):
    """Turn a list of images into per-image pixel tensors on *device*.

    Each entry of *images* may be:

    - ``None``: becomes an all-zero ``(3, crop_height, crop_width)`` tensor
      sized from ``processor.crop_size``;
    - a file path (``str`` or ``os.PathLike``): opened and converted to RGB;
    - a ``PIL.Image.Image``: used directly.

    Real images are padded to a square filled with the processor's mean colour
    (``processor.image_mean`` scaled by 255), then passed through
    ``processor.preprocess(..., return_tensors='pt')['pixel_values'][0]``.

    Raises:
        TypeError: for any other entry type.
    """
    crop_size = processor.crop_size
    tensors = []

    for fp in images:
        if fp is None:
            tensors.append(torch.zeros(3, crop_size['height'], crop_size['width'], device=device))
            continue

        if isinstance(fp, (str, os.PathLike)):
            # Close the file once the pixels have been decoded.
            with Image.open(fp) as opened:
                image = opened.convert('RGB')
        elif isinstance(fp, Image.Image):
            image = fp
        else:
            raise TypeError(f'Unsupported type {type(fp)}')

        background = tuple(int(x * 255) for x in processor.image_mean)
        image = _expand2square(image, background)
        pixel_values = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        tensors.append(pixel_values.to(device))

    return tensors
|
|
|
|
|
|
|
|
|
def _load_images(images):
    """Resolve user-supplied image references into a list of PIL images (best effort).

    *images* may be ``None``, a single item, or an iterable of items. Each item
    may be a ``PIL.Image.Image`` (kept as is), an existing local path (opened
    and converted to RGB), or anything else, which is treated as a URL and
    fetched with ``requests``. ``None`` entries are skipped. Items that fail to
    load are skipped with a warning rather than aborting the request.
    """
    import warnings

    if images is None:
        return []
    if isinstance(images, (str, os.PathLike, Image.Image)):
        images = [images]

    loaded = []
    for img in images:
        if img is None:
            continue
        try:
            if isinstance(img, Image.Image):
                loaded.append(img)
            elif os.path.exists(img):
                with Image.open(img) as opened:
                    loaded.append(opened.convert('RGB'))
            else:
                # Bounded wait so an unresponsive host cannot hang the caller.
                with requests.get(img, stream=True, timeout=30) as response:
                    response.raise_for_status()
                    # convert() forces the full read while the stream is still open.
                    loaded.append(Image.open(response.raw).convert('RGB'))
        except Exception as err:  # best effort: a bad image must not abort the request
            warnings.warn(f'Skipping image {img!r}: {err}')
    return loaded


def build_allava_input(tokenizer, processor, texts, images, history=None, return_history=False, device='cuda'):
    """Build model inputs (token ids plus image tensors) for one request.

    Args:
        tokenizer: tokenizer passed through to ``preprocess``.
        processor: image processor passed through to ``get_image_tensors``.
        texts: a single user message (``str``) or a non-empty list of
            ``[human, gpt]`` pairs (use ``None`` as the gpt value of the turn
            to be generated).
        images: ``None``, a path/URL string, a ``PIL.Image.Image``, or a list
            of those. Items that cannot be loaded are skipped with a warning.
        history: optional list of earlier ``[human, gpt]`` pairs, prepended
            to *texts*.
        return_history: if True, also return the moderated conversation.
        device: device for the returned tensors.

    Returns:
        ``(input_ids, image_tensors, history_or_None)``.

        - ``input_ids`` has shape ``(1, L)``.
        - ``image_tensors`` is ``torch.stack`` of the per-image tensors; when
          no image is usable, it holds a single zero placeholder.
        - The third element is the moderated conversation, without image
          markers, when *return_history* is set; otherwise ``None``.

    Neither *texts* nor *history* passed by the caller is modified.
    """
    if isinstance(texts, str):
        texts = [[texts, None]]
    else:
        assert isinstance(texts, list) and len(texts) > 0 and isinstance(texts[0], list), \
            'texts must be a list of list'

    if history is not None:
        texts = history + texts

    # Copy every pair so that moderation and placeholder insertion below do not
    # write into the caller's texts/history lists.
    texts = input_moderation([list(pair) for pair in texts])

    images = _load_images(images) or [None]
    # The check previously used '<', allowing only max_num_images - 1 images
    # despite this message.
    assert len(images) <= max_num_images, f'Currently at most {max_num_images} images are supported'

    # Snapshot before image markers are inserted, so history stays marker-free.
    history = deepcopy(texts)

    num_images = 0 if images[0] is None else len(images)
    texts[0][0] = insert_image_placeholder(texts[0][0], num_images)

    conv = get_conv(texts)
    input_ids = preprocess(tokenizer, conv, return_tensors='pt').unsqueeze(0).to(device)
    image_tensors = torch.stack(get_image_tensors(processor, images, device))

    return input_ids, image_tensors, (history if return_history else None)
|
|
|
|
|
|
|
class TextIterStreamer:
    """Collect generated token ids and expose the decoded text as an iterator.

    A generation loop calls ``put`` with new token ids and ``end`` when it has
    finished. A consumer iterates the streamer and, for every ``put``,
    receives the decoded text of *all* tokens seen so far. This is cumulative
    text, not a delta. Iteration stops after ``end``.

    Args:
        tokenizer: object exposing ``decode(ids, skip_special_tokens=...)``.
        skip_prompt: if True, the first ``put`` call is discarded (presumably
            the prompt ids).
        skip_special_tokens: forwarded to ``tokenizer.decode``.
        timeout: seconds ``__next__`` waits for the next item before raising
            ``queue.Empty``; ``None`` (the default) waits forever.

    Note:
        ``end`` does not reset the accumulated tokens or the skip-prompt
        flag, so use a fresh streamer per generation.
    """

    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False, timeout=None):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.skip_special_tokens = skip_special_tokens
        self.timeout = timeout
        self.tokens = []
        self.text_queue = Queue()
        self.next_tokens_are_prompt = True

    def put(self, value):
        """Record new token ids and enqueue the decoded text of all tokens so far.

        *value* must be tensor-like (``.shape`` / ``.tolist()``). If it has more
        than one dimension, only its first row is used.
        """
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return
        if len(value.shape) > 1:
            value = value[0]
        self.tokens.extend(value.tolist())
        self.text_queue.put(
            self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))

    def end(self):
        """Signal the end of generation; iteration stops once this is reached."""
        self.text_queue.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        """Wait (up to ``timeout`` seconds, if set) for the next decoded text.

        Raises:
            StopIteration: when the end-of-stream marker from ``end`` is read.
            queue.Empty: if *timeout* is set and nothing arrives in time.
        """
        value = self.text_queue.get(timeout=self.timeout)
        if value is None:
            raise StopIteration()
        return value
|
|