Spaces:
Runtime error
Runtime error
import torch | |
from transformers import StoppingCriteria, StoppingCriteriaList | |
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the sequence ends with any of the given stop token sequences.

    Only the first sequence in the batch (``input_ids[0]``) is inspected.
    """

    def __init__(self, stops=None, encounters=1):
        """Store the stop sequences.

        Args:
            stops: Iterable of 1-D token-id tensors; generation stops when the
                first batch row ends with any one of them. ``None`` means no
                stop sequences.
            encounters: Accepted for backward compatibility; not used.
        """
        super().__init__()
        # Avoid a shared mutable default: each instance gets its own list.
        self.stops = [] if stops is None else stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        """Return True when the first batch row ends with one of the stop sequences.

        Extra keyword arguments are accepted and ignored so that callers which
        forward additional context do not fail.
        """
        generated = input_ids[0]
        generated_len = generated.shape[0]
        for stop in self.stops:
            stop_len = len(stop)
            # An empty stop would make ``-0:`` select the whole sequence, and a
            # stop longer than the generated text cannot match; comparing such
            # shapes would broadcast or raise instead of answering "no match".
            if stop_len == 0 or stop_len > generated_len:
                continue
            if torch.all(stop == generated[-stop_len:]).item():
                return True
        return False
class Generator:
    """Thin wrapper that runs sampling-based generation on ``model.llama_model``
    and post-processes the result into assistant text.

    The wrapped ``model`` must expose ``llama_model`` (with a ``generate``
    method) and ``llama_tokenizer`` (with a ``decode`` method).
    """

    def __init__(self, model, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
                 repetition_penalty=1.0, length_penalty=1, temperature=1.0, device='cuda:0',
                 remove_invalid_values=True):
        """Store the generation settings and build the '###' stopping criteria.

        All keyword arguments other than ``device`` are forwarded unchanged to
        ``model.llama_model.generate`` on every call to :meth:`generate`.
        ``device`` is where the stop-word token tensors are placed.
        """
        self.model = model
        self.device = device

        self.max_new_tokens = max_new_tokens
        self.num_beams = num_beams
        self.min_length = min_length
        self.top_p = top_p
        self.repetition_penalty = repetition_penalty
        self.length_penalty = length_penalty
        self.temperature = temperature
        self.remove_invalid_values = remove_invalid_values

        # '###' can be encoded in two different ways.
        stop_words_ids = [torch.tensor([835]).to(self.device),
                          torch.tensor([2277, 29937]).to(self.device)]
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def generate(self, prompt):
        """Generate a reply for ``prompt`` and return it with its token ids.

        Args:
            prompt: Object whose ``context_embs[0]`` is passed to the model as
                ``inputs_embeds``.

        Returns:
            A ``(output_text, output_token)`` tuple: the decoded text cut at the
            first '###' and reduced to what follows the last 'Assistant:', and
            the generated token ids (after leading special tokens are removed)
            as a NumPy array.
        """
        outputs = self.model.llama_model.generate(
            inputs_embeds=prompt.context_embs[0],
            max_new_tokens=self.max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=self.num_beams,
            do_sample=True,
            min_length=self.min_length,
            top_p=self.top_p,
            repetition_penalty=self.repetition_penalty,
            length_penalty=self.length_penalty,
            temperature=self.temperature,
            remove_invalid_values=self.remove_invalid_values
        )

        output_token = outputs[0]
        # The model might output an unknown token <unk> (id 0) at the beginning,
        # and some users find a start token <s> (id 1) there as well. Remove at
        # most one of each, in that order. The length guard keeps an output that
        # is empty (or becomes empty after stripping) from raising IndexError.
        for leading_special_id in (0, 1):
            if len(output_token) > 0 and output_token[0] == leading_special_id:
                output_token = output_token[1:]

        # NOTE(review): `add_special_tokens` is usually an encode-time option;
        # confirm the tokenizer's decode actually honors it.
        output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
        output_text = output_text.split('###')[0]  # remove the stop sign '###'
        output_text = output_text.split('Assistant:')[-1].strip()
        return output_text, output_token.cpu().numpy()