# homen_testing_merged6 / handler.py
from typing import Any, Dict, List
import logging
import re
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.DEBUG)

# Prefer bfloat16 on GPUs with compute capability 8.0 or newer; otherwise fall back to float16.
dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)


class EndpointHandler:
    def __init__(self, path=""):
        # Load the tokenizer; reuse the EOS token for padding since the model has no dedicated pad token.
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load the model in 8-bit with automatic device placement.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            return_dict=True,
            device_map="auto",
            load_in_8bit=True,
            torch_dtype=dtype,
            trust_remote_code=True,
        )

        # Default generation settings for this handler.
        generation_config = self.model.generation_config
        generation_config.max_new_tokens = 140
        generation_config.temperature = 0.7
        generation_config.top_p = 0.7
        generation_config.num_return_sequences = 1
        generation_config.pad_token_id = self.tokenizer.eos_token_id
        generation_config.eos_token_id = self.tokenizer.eos_token_id
        generation_config.early_stopping = True
        self.generate_config = generation_config

        self.pipeline = transformers.pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer
        )

    def _ensure_token_limit(self, tokens: List[int]) -> List[int]:
        MAX_TOKEN_COUNT = 1024
        if len(tokens) > MAX_TOKEN_COUNT:
            # Keep only the most recent MAX_TOKEN_COUNT tokens.
            return tokens[-MAX_TOKEN_COUNT:]
        return tokens

    def _extract_response(self, text: str) -> str:
        # Locate the bot's or assistant's response marker, if present.
        bot_marker = "<bot> response:"
        assistant_marker = "<assistant> response:"
        bot_start = text.find(bot_marker)
        assistant_start = text.find(assistant_marker)

        if bot_start == -1 and assistant_start == -1:
            # Neither marker was found: start from the beginning of the text.
            response_start = 0
        else:
            # Start right after whichever marker appears later in the text.
            candidates = []
            if bot_start != -1:
                candidates.append(bot_start + len(bot_marker))
            if assistant_start != -1:
                candidates.append(assistant_start + len(assistant_marker))
            response_start = max(candidates)

        # Extract everything after the marker up to the next "User:" turn, if any.
        user_response_start = text.find("User:", response_start)
        end_point = user_response_start if user_response_start != -1 else len(text)

        # Return only the bot's or assistant's response, without any trailing "User:" content.
        return text[response_start:end_point].strip()

    def _truncate_conversation(self, conversation: str, max_tokens: int = 512) -> str:
        # Split the conversation into exchanges at each "User:" or "Assistant:" turn.
        exchanges = re.split(r'(?=User:|Assistant:)', conversation)
        while len(exchanges) > 0:
            tokenized_conv = self.tokenizer.encode(' '.join(exchanges), truncation=False)
            if len(tokenized_conv) <= max_tokens:
                return ' '.join(exchanges)
            exchanges.pop(0)  # Remove the oldest exchange and retry.
        return ""  # If all exchanges are removed, return an empty string.

    def generate_response(self, user_prompt, additional_context=None):
        if additional_context:
            truncated_conversation = self._truncate_conversation(additional_context)
        else:
            truncated_conversation = ""

        permanent_context = (
            "<context>: You are a life coaching bot with the goal of providing guidance, "
            "improving understanding, reducing suffering and improving life. Gain as much "
            "understanding of the user before providing guidance with detailed actionable steps."
        )
        structured_prompt = f"{permanent_context}\n{truncated_conversation}\n<user>: {user_prompt}"
        structured_prompt += "\n<bot> response:"

        input_ids = self.tokenizer.encode(structured_prompt, return_tensors="pt").to(self.model.device)

        # Token ids that end generation; assumes '<bot>' and 'User ' each encode to a single token.
        stop_token_ids = [self.tokenizer.encode(token)[0] for token in ['<bot>', 'User ']]

        max_length = 1024
        outputs = input_ids
        with torch.no_grad():
            while len(outputs[0]) < max_length:
                # Greedily pick the next token from the logits at the last position.
                next_token_logits = self.model(outputs).logits[:, -1, :]
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                # Stop as soon as a stop token is produced.
                if any(token.item() in stop_token_ids for token in next_token):
                    break

                # Append the new token and continue generating.
                outputs = torch.cat([outputs, next_token], dim=-1)

        response_text = self._extract_response(self.tokenizer.decode(outputs[0])).strip()
        return response_text

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Define debug_info before the try block so it is always available in the except branch.
        debug_info = "No debug info available."
        try:
            user_prompt = data.get("inputs", data)
            self.prev_user_message = user_prompt
            response_text = self.generate_response(user_prompt)
            return [{"generated_text": response_text, "debug_info": debug_info}]
        except Exception as e:
            logging.error(f"An error occurred in __call__ method: {e}")
            return [{"generated_text": str(e), "debug_info": debug_info}]
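

# Local usage sketch (not part of the original handler): illustrates how the Inference
# Endpoints runtime would construct and call this handler. The checkpoint path
# "./homen_testing_merged6" and the sample prompt are placeholder assumptions.
if __name__ == "__main__":
    handler = EndpointHandler(path="./homen_testing_merged6")
    result = handler({"inputs": "I feel stuck in my career. Where should I start?"})
    print(result[0]["generated_text"])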