# llavaguard/llava_utils/generator.py
# Ahren09's picture
# Upload 227 files
# 5ca4e86 verified
import torch
from transformers import StoppingCriteria, StoppingCriteriaList
from llava.conversation import conv_llava_llama_2, SeparatorStyle
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any of the given keywords appears in the new text.

    Two checks run on every call: a comparison of the most recent token
    against the keywords that tokenize to a single id, and a substring search
    over the decoded text that follows the prompt.

    Args:
        keywords: Strings whose appearance in the generated text ends
            generation.
        tokenizer: Called as ``tokenizer(keyword).input_ids`` to tokenize each
            keyword and as ``tokenizer.batch_decode(...)`` to decode output.
        input_ids: Prompt token ids. ``input_ids.shape[1]`` is used as the
            prompt length, i.e. the offset at which generated tokens begin.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.tokenizer = tokenizer
        self.input_ids = input_ids
        # Resolved from ``input_ids`` on the first call.
        self.start_len = None

        # Some tokenizers prepend a BOS id to every encoded string. Left in
        # place, it turns each single-token keyword into a two-id sequence and
        # the single-token fast path below would never apply.
        bos_token_id = getattr(tokenizer, "bos_token_id", None)
        self.keyword_ids = []
        for keyword in keywords:
            ids = tokenizer(keyword).input_ids
            if not isinstance(ids, list):
                continue
            if bos_token_id is not None and len(ids) > 1 and ids[0] == bos_token_id:
                ids = ids[1:]
            if len(ids) == 1:
                self.keyword_ids.append(ids[0])

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when generation should stop.

        Only the first sequence of ``output_ids`` is inspected, both for the
        last-token check and for the decoded-text check.
        """
        if self.start_len is None:
            self.start_len = self.input_ids.shape[1]

        # Fast path: the newest token is itself a single-token keyword.
        # Guarded so that a prompt token can never trigger a stop.
        if self.keyword_ids and output_ids.shape[1] > self.start_len:
            if int(output_ids[0, -1]) in self.keyword_ids:
                return True

        # Slow path: keywords spanning several tokens are only visible in
        # the decoded text generated after the prompt.
        generated = self.tokenizer.batch_decode(
            output_ids[:, self.start_len:], skip_special_tokens=True
        )[0]
        return any(keyword in generated for keyword in self.keywords)
class Generator:
    """Sample a text response from a model given a tokenized prompt and an image.

    Args:
        model: Object exposing ``generate(input_ids, images=..., ...)``.
        tokenizer: Used to decode generated ids and passed to the stopping
            criteria.
        max_new_tokens: Upper bound on generated tokens, forwarded to
            ``model.generate``.
        temperature: Sampling temperature, forwarded to ``model.generate``.
        device: Stored on the instance; not otherwise used by this class.
    """

    # Lower bound on generated tokens requested from ``model.generate``.
    _MIN_NEW_TOKENS = 128

    def __init__(self, model, tokenizer, max_new_tokens=1024, temperature=0.7, device='cuda:0'):
        self.model = model
        self.device = device
        self.tokenizer = tokenizer
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature

        # The conversation template determines which separator is treated as
        # the stop string: ``sep2`` for the TWO style, ``sep`` otherwise.
        if conv_llava_llama_2.sep_style != SeparatorStyle.TWO:
            self.stop_str = conv_llava_llama_2.sep
        else:
            self.stop_str = conv_llava_llama_2.sep2
        self.keywords = [self.stop_str]

    def generate(self, prompt, image):
        """Generate and return the decoded response text.

        Args:
            prompt: Object whose ``input_ids[0]`` holds the prompt token ids
                as a 2-D tensor (its ``shape[1]`` is read as the prompt
                length).
            image: Object providing ``half()``; the result is passed to the
                model as ``images`` (presumably an image tensor -- confirm
                against the caller).

        Returns:
            The text decoded from the tokens after the prompt, stripped of
            surrounding whitespace and of a trailing stop string.
        """
        input_ids = prompt.input_ids[0]
        stopping_criteria = StoppingCriteriaList(
            [KeywordsStoppingCriteria(self.keywords, self.tokenizer, input_ids)]
        )

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image.half(),
                do_sample=True,
                # Honour the values given to the constructor instead of
                # fixed literals.
                temperature=self.temperature,
                top_p=0.9,
                # Never request a minimum above the configured maximum.
                min_new_tokens=min(self._MIN_NEW_TOKENS, self.max_new_tokens),
                max_new_tokens=self.max_new_tokens,
                use_cache=True,
                stopping_criteria=stopping_criteria,
            )

        input_token_len = input_ids.shape[1]
        print(input_token_len)
        print(output_ids.shape)

        # The output is expected to start with the prompt; report any mismatch.
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')

        outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        # An empty stop string would make the slice below discard everything
        # (``outputs[:-0]`` is ``''``), so only trim a non-empty one.
        if self.stop_str and outputs.endswith(self.stop_str):
            outputs = outputs[:-len(self.stop_str)]
        return outputs.strip()