import torch
from transformers import StoppingCriteria, StoppingCriteriaList
from llava.conversation import conv_llava_llama_2, SeparatorStyle

class KeywordsStoppingCriteria(StoppingCriteria):
    """Stops generation once any of the given keyword strings appears in the output."""

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        # Pre-compute token ids for keywords that encode to a single token, so
        # they can be matched cheaply against the last generated token.
        self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
        self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids
                            if type(keyword_id) is list and len(keyword_id) == 1]
        self.tokenizer = tokenizer
        self.start_len = None
        self.input_ids = input_ids

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # On the first call, only record where the prompt ends; keyword checks
        # begin from the second call onwards.
        if self.start_len is None:
            self.start_len = self.input_ids.shape[1]
        else:
            # Fast path: the last generated token is itself a single-token keyword.
            for keyword_id in self.keyword_ids:
                if output_ids[0, -1] == keyword_id:
                    return True
            # Slow path: decode the continuation and search for each keyword string.
            outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
            for keyword in self.keywords:
                if keyword in outputs:
                    return True
        return False
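
# Demonstration (a minimal sketch): the stopping criteria above are
# model-agnostic, so they can be exercised with a plain text-only checkpoint.
# The 'gpt2' checkpoint and the '.' keyword below are illustrative choices,
# not part of this Space.
def _demo_stopping_criteria():
    from transformers import AutoModelForCausalLM, AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    model = AutoModelForCausalLM.from_pretrained('gpt2')
    inputs = tokenizer('The quick brown fox', return_tensors='pt')
    # Stop as soon as the decoded continuation contains the keyword.
    criteria = KeywordsStoppingCriteria(['.'], tokenizer, inputs.input_ids)
    output_ids = model.generate(
        inputs.input_ids,
        max_new_tokens=50,
        stopping_criteria=StoppingCriteriaList([criteria]),
    )
    print(tokenizer.batch_decode(output_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0])
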
class Generator:
    def __init__(self, model, tokenizer, max_new_tokens=1024, temperature=0.7, device='cuda:0'):
        self.model = model
        self.device = device
        self.tokenizer = tokenizer
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        # The stop string depends on the conversation template's separator style.
        self.stop_str = conv_llava_llama_2.sep if conv_llava_llama_2.sep_style != SeparatorStyle.TWO else conv_llava_llama_2.sep2
        self.keywords = [self.stop_str]
    def generate(self, prompt, image):
        # Keep the batch dimension: model.generate() and the shape[1] lookups
        # below expect a 2-D (batch, seq_len) tensor, so indexing [0] here
        # would break them.
        input_ids = prompt.input_ids.to(self.device)
        stopping_criteria = KeywordsStoppingCriteria(self.keywords, self.tokenizer, input_ids)
        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image.half(),
                do_sample=True,
                # Use the sampling parameters configured in __init__ rather
                # than hard-coded values.
                temperature=self.temperature,
                min_new_tokens=128,
                max_new_tokens=self.max_new_tokens,
                use_cache=True,
                stopping_criteria=StoppingCriteriaList([stopping_criteria])
            )
        input_token_len = input_ids.shape[1]
        # Sanity check: generate() should echo the prompt tokens unchanged at
        # the start of the output.
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        # Decode only the newly generated tokens and trim the stop string.
        outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        if outputs.endswith(self.stop_str):
            outputs = outputs[:-len(self.stop_str)]
        outputs = outputs.strip()
        return outputs
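
# Usage sketch: wiring the Generator to a LLaVA checkpoint. `load_llava_model`
# and `preprocess_image` below are hypothetical placeholders for whatever this
# Space uses to load the model and prepare the image tensor; only the Generator
# calls themselves come from the code above.
#
#   model, tokenizer = load_llava_model('path/to/llava-llama-2-checkpoint')
#   generator = Generator(model, tokenizer, max_new_tokens=1024, temperature=0.2)
#   prompt = tokenizer(conv_llava_llama_2.get_prompt(), return_tensors='pt')
#   image = preprocess_image(raw_image)  # float tensor; .half() is applied inside generate()
#   print(generator.generate(prompt, image))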