Spaces:
Runtime error
Runtime error
import torch | |
from llava.conversation import conv_llava_llama_2 | |
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN | |
from llava.mm_utils import tokenizer_image_token | |
def prepare_text_prompt(user_prompt):
    """Render *user_prompt* through the LLaVA Llama-2 conversation template.

    The image placeholder token is put on its own line ahead of the user text.
    The combined text becomes the first-role message of a fresh copy of
    ``conv_llava_llama_2``. A ``None`` message is then appended for the second
    role, and the template is rendered with ``get_prompt()``.

    Args:
        user_prompt: the user's text. It is joined with ``+``, so it must be a
            ``str``.

    Returns:
        Whatever ``get_prompt()`` returns for the populated conversation.
    """
    user_turn = DEFAULT_IMAGE_TOKEN + '\n' + user_prompt
    # Copy the module-level template rather than appending to it directly.
    conversation = conv_llava_llama_2.copy()
    user_role, assistant_role = conversation.roles[0], conversation.roles[1]
    conversation.append_message(user_role, user_turn)
    # NOTE(review): a None message presumably leaves the reply turn open so the
    # rendered prompt ends where generation should begin — confirm in llava.
    conversation.append_message(assistant_role, None)
    return conversation.get_prompt()
# support batch implementation | |
class Prompt:
    """Tokenized and embedded text prompt(s) for a LLaVA-style model.

    Pipeline:
      * tokenization (``do_tokenization``);
      * text -> embeddings (``generate_text_embedding``);
      * context assembly with left truncation (``generate_context_embedding``).
    Padding and label preparation are deliberately not done here. They have to
    wait until targets have been appended.

    Attributes:
        model: model exposing ``model.model.embed_tokens`` (used for embeddings).
        tokenizer: callable tokenizer used for each text segment.
        device: device that token ids are moved to.
        text_prompts: the raw prompt (a ``str`` or a list of ``str``), or None.
        img_prompts, img_embs: image slots. Both are fixed to a batch of one
            empty list here.
        input_ids: list holding the token-id tensor of the first prompt
            (``unsqueeze(0)`` adds a leading dimension), or ``[]``.
        context_length: list holding ``input_ids[0].shape[1]``, or ``[]``.
        max_new_tokens, max_length: budget used to truncate the context.
        text_embs: per prompt, a list of embedding tensors. There is one tensor
            per text segment between ``<image>`` placeholders.
        context_embs: per prompt, the concatenated (and possibly truncated)
            text embeddings. It is ``[]`` when the text and image batch sizes
            differ.
    """

    def __init__(self, model, tokenizer, text_prompts=None, device='cuda:0', max_new_tokens=300, max_length=2000):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.text_prompts = text_prompts
        self.img_prompts = [[]]
        self.context_length = []
        self.input_ids = []
        self.do_tokenization(self.text_prompts)
        self.max_new_tokens = max_new_tokens
        self.max_length = max_length
        self.text_embs = self.generate_text_embedding(self.text_prompts)
        self.img_embs = [[]]
        self.update_context_embs()

    def do_tokenization(self, text_prompts):
        """Tokenize the (first) prompt into ``self.input_ids`` / ``self.context_length``.

        ``None`` or an empty list clears both attributes. For a list, only the
        first prompt is tokenized.
        """
        if text_prompts is None or (isinstance(text_prompts, list) and not text_prompts):
            self.input_ids = []
            self.context_length = []
            return
        if isinstance(text_prompts, list):
            # NOTE(review): only the first prompt of a batch is tokenized,
            # whereas generate_text_embedding embeds every prompt.
            text_prompts = text_prompts[0]
        input_ids = tokenizer_image_token(
            text_prompts, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
        ).unsqueeze(0).to(self.device)  # honour the configured device instead of .cuda()
        self.input_ids = [input_ids]
        self.context_length = [input_ids.shape[1]]

    def update_context_embs(self):
        """Rebuild ``self.context_embs``, or clear it when batch sizes disagree."""
        if len(self.text_embs) == len(self.img_embs):
            self.context_embs = self.generate_context_embedding(
                self.text_embs, self.img_embs
            )
        else:
            self.context_embs = []

    def update_text_prompt(self, text_prompts):
        """Replace the text prompt and refresh every derived attribute."""
        self.text_prompts = text_prompts
        # Re-tokenize too, so input_ids/context_length match the new prompt.
        self.do_tokenization(self.text_prompts)
        self.text_embs = self.generate_text_embedding(self.text_prompts)
        self.update_context_embs()

    def generate_text_embedding(self, text_prompts):
        """Embed each prompt segment-by-segment.

        Each prompt is split on ``'<image>'``. Every segment is tokenized,
        with special tokens added only for the first segment, and then passed
        through ``self.model.model.embed_tokens``.

        Args:
            text_prompts: a ``str`` (treated as a batch of one), a list of
                ``str``, or None.

        Returns:
            A list with one entry per prompt. Each entry is a list of segment
            embeddings. Returns ``[]`` for None.
        """
        if text_prompts is None:
            return []
        if isinstance(text_prompts, str):
            # Iterating a bare string would embed it one character at a time.
            text_prompts = [text_prompts]
        text_embs = []
        for item in text_prompts:
            prompt_segs = item.split('<image>')  # one '<image>' per image slot
            seg_tokens = [
                self.tokenizer(
                    seg, return_tensors="pt", add_special_tokens=(i == 0)
                ).to(self.device).input_ids
                for i, seg in enumerate(prompt_segs)
            ]
            text_embs.append([self.model.model.embed_tokens(seg_t) for seg_t in seg_tokens])
        return text_embs

    def generate_context_embedding(self, batch_text_embs, batch_img_embs):
        """Concatenate each prompt's segment embeddings and left-truncate them.

        Image embeddings are not interleaved yet. ``batch_img_embs`` is only
        checked for a matching batch size.

        The segments are joined along dim 1. That dim is assumed to be the
        sequence axis, i.e. tensors are (1, seq, hidden) — TODO confirm.

        If ``seq + max_new_tokens`` exceeds ``max_length``, a warning is
        printed. The oldest ``seq + max_new_tokens - max_length`` positions
        are then dropped.

        Args:
            batch_text_embs: per prompt, a list of embedding tensors.
            batch_img_embs: per prompt, the image embeddings (unused beyond
                the length check).

        Returns:
            A list with one context tensor per prompt.

        Raises:
            AssertionError: if the two batch sizes differ.
        """
        assert len(batch_text_embs) == len(batch_img_embs), "Unmatched batch size of text and image prompts"
        batch_context_embs = []
        for text_embs in batch_text_embs:
            mixed_embs = torch.cat(text_embs, dim=1)
            overflow = mixed_embs.shape[1] + self.max_new_tokens - self.max_length
            if overflow > 0:
                print('Warning: The number of tokens in current conversation exceeds the max length. '
                      'The model will not see the contexts outside the range.')
            begin_idx = max(0, overflow)
            batch_context_embs.append(mixed_embs[:, begin_idx:])
        return batch_context_embs