from typing import Any, Dict, List
import os
import logging
import re
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from collections import deque

logging.basicConfig(level=logging.DEBUG)

# Use bfloat16 on GPUs with compute capability >= 8 (Ampere or newer), otherwise float16.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16


class EndpointHandler:
    """Custom inference handler: loads the model and tokenizer, builds a coaching prompt,
    and returns chat-style completions."""

    def __init__(self, path=""):
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            return_dict=True,
            device_map="auto",
            load_in_8bit=True,
            torch_dtype=dtype,
            trust_remote_code=True,
        )

        generation_config = self.model.generation_config
        generation_config.max_new_tokens = 140
        generation_config.temperature = 0.7
        generation_config.top_p = 0.7
        generation_config.num_return_sequences = 1
        generation_config.pad_token_id = self.tokenizer.eos_token_id
        generation_config.eos_token_id = self.tokenizer.eos_token_id
        generation_config.early_stopping = True
        self.generate_config = generation_config

        self.pipeline = transformers.pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer
        )

    def _ensure_token_limit(self, tokens: List[int]) -> List[int]:
        MAX_TOKEN_COUNT = 1024
        if len(tokens) > MAX_TOKEN_COUNT:
            # Keep only the last MAX_TOKEN_COUNT tokens
            return tokens[-MAX_TOKEN_COUNT:]
        return tokens

    def _extract_response(self, text: str) -> str:
        """Extract only the bot's reply from the decoded model output."""
        # Look for the response marker that was appended to the prompt; if it is
        # missing, fall back to the beginning of the text.
        marker = " response:"
        marker_pos = text.find(marker)
        response_start = marker_pos + len(marker) if marker_pos != -1 else 0

        # Take everything after the marker up to the next "User:" turn, if any.
        user_response_start = text.find("User:", response_start)
        end_point = user_response_start if user_response_start != -1 else len(text)

        # Return only the bot's response, dropping any trailing "User:" content.
        return text[response_start:end_point].strip()

    def _truncate_conversation(self, conversation: str, max_tokens: int = 512) -> str:
        """Drop the oldest exchanges until the conversation fits within max_tokens."""
        # Split the conversation into exchanges
        exchanges = re.split(r'(?=User:|Assistant:)', conversation)
        while len(exchanges) > 0:
            tokenized_conv = self.tokenizer.encode(' '.join(exchanges), truncation=False)
            if len(tokenized_conv) <= max_tokens:
                return ' '.join(exchanges)
            exchanges.pop(0)  # Remove the oldest exchange
        return ""  # If all exchanges are removed, return an empty string.

    def generate_response(self, user_prompt, additional_context=None):
        """Build the structured prompt and greedily decode a response token by token."""
        if additional_context:
            truncated_conversation = self._truncate_conversation(additional_context)
        else:
            truncated_conversation = ""

        permanent_context = (
            ": You are a life coaching bot with the goal of providing guidance, "
            "improving understanding, reducing suffering and improving life. "
            "Gain as much understanding of the user before providing guidance "
            "with detailed actionable steps."
        )
        structured_prompt = f"{permanent_context}\n{truncated_conversation}\n: {user_prompt}"
        structured_prompt += " response:"

        input_ids = self.tokenizer.encode(structured_prompt, return_tensors="pt").to(self.model.device)

        # Collect stop-token ids, skipping markers that encode to nothing
        # (assuming these tokens are single tokens in your tokenizer).
        stop_token_ids = []
        for stop_marker in ['', 'User ']:
            ids = self.tokenizer.encode(stop_marker)
            if ids:
                stop_token_ids.append(ids[0])

        max_length = 1024
        outputs = input_ids
        with torch.no_grad():
            while len(outputs[0]) < max_length:
                # Greedily pick the next token from the logits at the last position
                next_token_logits = self.model(outputs).logits[:, -1, :]
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                # Stop if the predicted token is in the stop_token_ids list
                if any(token.item() in stop_token_ids for token in next_token):
                    break

                # Append the next token to the running sequence
                outputs = torch.cat([outputs, next_token], dim=-1)

        response_text = self._extract_response(self.tokenizer.decode(outputs[0])).strip()
        return response_text

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        debug_info = "No debug info available."
        try:
            user_prompt = data.get("inputs", data)
            self.prev_user_message = user_prompt
            response_text = self.generate_response(user_prompt)
            return [{"generated_text": response_text, "debug_info": debug_info}]
        except Exception as e:
            logging.error(f"An error occurred in __call__ method: {e}")
            return [{"generated_text": str(e), "debug_info": debug_info}]
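

# --- Usage sketch (assumption, not part of the handler contract) ---
# A minimal local smoke test, assuming the model files live in "./model" and a CUDA GPU
# is available for load_in_8bit. The path and the example prompt are hypothetical; in
# production, Hugging Face Inference Endpoints instantiates EndpointHandler itself and
# calls it with the request payload.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")  # hypothetical local model directory
    result = handler({"inputs": "I feel stuck in my career. Where should I start?"})
    print(result[0]["generated_text"])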