# llavaguard/llava_utils/prompt_wrapper.py
import torch

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_llava_llama_2
from llava.mm_utils import tokenizer_image_token

def prepare_text_prompt(user_prompt):
    """Wrap a user prompt in the LLaVA-LLaMA-2 conversation template,
    prepending the image placeholder token."""
    qs = DEFAULT_IMAGE_TOKEN + '\n' + user_prompt
    conv = conv_llava_llama_2.copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    return prompt
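
# Usage sketch (hypothetical prompt text; assumes a standard llava install):
#   prompt = prepare_text_prompt("Is this image safe to show?")
#   # -> a LLaMA-2 chat string with "<image>\n..." wrapped in [INST] tags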

# Supports batched prompts.
class Prompt:
    # Pipeline notes:
    # - tokenization
    # - turn tokens into embeddings
    # - padding: deferred until the targets have been appended
    # - preparing labels: also deferred until the targets are available
    def __init__(self, model, tokenizer, text_prompts=None, device='cuda:0',
                 max_new_tokens=300, max_length=2000):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.text_prompts = text_prompts
        self.img_prompts = [[]]
        self.context_length = []
        self.input_ids = []
        self.do_tokenization(self.text_prompts)
        self.max_new_tokens = max_new_tokens
        self.max_length = max_length
        self.text_embs = self.generate_text_embedding(self.text_prompts)
        self.img_embs = [[]]
        self.update_context_embs()

    def do_tokenization(self, text_prompts):
        if text_prompts is None:
            self.input_ids = []
            self.context_length = []
            return

        if isinstance(text_prompts, list):
            # only the first prompt of the batch is tokenized here
            text_prompts = text_prompts[0]

        input_ids = tokenizer_image_token(
            text_prompts, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
        ).unsqueeze(0).to(self.device)
        self.input_ids = [input_ids]
        self.context_length = [input_ids.shape[1]]
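
    # For reference (a note, not code from this repo): in upstream
    # llava.constants, IMAGE_TOKEN_INDEX is -200, so tokenizer_image_token(...)
    # returns ids like [bos, ..., -200, ...], with -200 standing in for the
    # image patches that are spliced in at embedding time.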

    def update_context_embs(self):
        if len(self.text_embs) == len(self.img_embs):
            self.context_embs = self.generate_context_embedding(
                self.text_embs, self.img_embs
            )
        else:
            self.context_embs = []

    def update_text_prompt(self, text_prompts):
        self.text_prompts = text_prompts
        self.text_embs = self.generate_text_embedding(self.text_prompts)
        self.update_context_embs()

    def generate_text_embedding(self, text_prompts):
        if text_prompts is None:
            return []

        text_embs = []
        for item in text_prompts:  # for each prompt within a batch
            # each '<image>' placeholder corresponds to one image
            prompt_segs = item.split('<image>')
            seg_tokens = [
                self.tokenizer(
                    seg, return_tensors="pt",
                    add_special_tokens=(i == 0)  # only add BOS to the first segment
                ).to(self.device).input_ids
                for i, seg in enumerate(prompt_segs)
            ]
            # text tokens to embeddings
            embs = [self.model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
            text_embs.append(embs)
        return text_embs
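
    # For example, "A photo. <image>\nDescribe it." splits into the segments
    # ["A photo. ", "\nDescribe it."]; only the first segment receives a BOS
    # token, and image embeddings are interleaved between segments downstream.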

    def generate_context_embedding(self, batch_text_embs, batch_img_embs):
        # assert len(text_embs) == len(img_embs) + 1, \
        #     "Unmatched numbers of image placeholders and images."
        assert len(batch_text_embs) == len(batch_img_embs), \
            "Unmatched batch sizes of text and image prompts"

        batch_size = len(batch_text_embs)
        batch_context_embs = []
        for i in range(batch_size):
            mixed_embs = torch.cat(batch_text_embs[i], dim=1)

            # left-truncate if the context plus the generation budget would
            # exceed the maximum sequence length
            current_max_len = mixed_embs.shape[1] + self.max_new_tokens
            if current_max_len > self.max_length:
                print('Warning: the number of tokens in the current conversation '
                      'exceeds the max length; the model will not see the context '
                      'outside this range.')
            begin_idx = max(0, current_max_len - self.max_length)
            mixed_embs = mixed_embs[:, begin_idx:]
            batch_context_embs.append(mixed_embs)
        return batch_context_embs
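
# Usage sketch (hypothetical; `model` and `tokenizer` stand for a loaded
# LLaVA-LLaMA-2 checkpoint and its tokenizer, not objects defined here):
#   text = prepare_text_prompt("Is this image appropriate?")
#   wrapped = Prompt(model, tokenizer, text_prompts=[text])
#   # wrapped.context_embs holds one [1, seq_len, hidden_dim] tensor per
#   # prompt, ready to be fed to the model as inputs_embeds.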