import torch
from transformers import StoppingCriteria, StoppingCriteriaList


class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation as soon as the sequence ends with any of the given stop-token runs.

    Each entry in ``stops`` is a 1-D tensor of token ids; generation halts when the
    tail of the generated sequence exactly matches one of them.
    """

    def __init__(self, stops=None, encounters=1):
        """
        Args:
            stops: iterable of 1-D LongTensors of stop-token ids. Defaults to no
                stop sequences. (A ``None`` sentinel is used instead of a mutable
                ``[]`` default to avoid the shared-default-argument pitfall.)
            encounters: accepted for interface compatibility but currently unused —
                the criterion fires on the first match regardless of this value.
        """
        super().__init__()
        self.stops = list(stops) if stops is not None else []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the first sequence in the batch is inspected (batch size 1 assumed).
        # **kwargs absorbs extra arguments newer transformers versions pass to
        # stopping criteria, keeping this override forward compatible.
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True
        return False


class Generator:
    """Thin wrapper around ``model.llama_model.generate`` with '###'-based stopping."""

    def __init__(self, model, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
                 repetition_penalty=1.0, length_penalty=1, temperature=1.0,
                 device='cuda:0', remove_invalid_values=True):
        """
        Args:
            model: object exposing ``llama_model`` (a HF generation-capable model)
                and ``llama_tokenizer``.
            max_new_tokens, num_beams, min_length, top_p, repetition_penalty,
            length_penalty, temperature, remove_invalid_values: forwarded verbatim
                to ``generate()``.
            device: device the stop-word id tensors are placed on.
        """
        self.model = model
        self.device = device
        self.max_new_tokens = max_new_tokens
        self.num_beams = num_beams
        self.min_length = min_length
        self.top_p = top_p
        self.repetition_penalty = repetition_penalty
        self.length_penalty = length_penalty
        self.temperature = temperature
        self.remove_invalid_values = remove_invalid_values
        # '###' can be encoded in two different ways.
        stop_words_ids = [torch.tensor([835]).to(self.device),
                          torch.tensor([2277, 29937]).to(self.device)]
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def generate(self, prompt):
        """Run sampling-based generation from ``prompt.context_embs[0]``.

        Args:
            prompt: object whose ``context_embs[0]`` holds the input embeddings.

        Returns:
            tuple: ``(output_text, output_token)`` — the decoded text with the
            '###' stop marker and any leading 'Assistant:' prefix stripped, and
            the raw token ids as a numpy array.
        """
        outputs = self.model.llama_model.generate(
            inputs_embeds=prompt.context_embs[0],
            max_new_tokens=self.max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=self.num_beams,
            do_sample=True,
            min_length=self.min_length,
            top_p=self.top_p,
            repetition_penalty=self.repetition_penalty,
            length_penalty=self.length_penalty,
            temperature=self.temperature,
            remove_invalid_values=self.remove_invalid_values,
        )
        output_token = outputs[0]
        # Guard against an empty output before peeking at the first token.
        if len(output_token) > 0 and output_token[0] == 0:
            # the model might output an unknown token at the beginning. remove it
            output_token = output_token[1:]
        if len(output_token) > 0 and output_token[0] == 1:
            # some users find that there is a start token at the beginning. remove it
            output_token = output_token[1:]
        # NOTE(review): ``decode`` has no ``add_special_tokens`` parameter — this
        # kwarg is silently ignored by the tokenizer. The likely intent was
        # ``skip_special_tokens=True``; left unchanged to preserve behavior.
        output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
        output_text = output_text.split('###')[0]  # remove the stop sign '###'
        output_text = output_text.split('Assistant:')[-1].strip()
        return output_text, output_token.cpu().numpy()