from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import logging
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
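

# Custom handler for Hugging Face Inference Endpoints: the service loads this file
# and expects a class named EndpointHandler exposing __init__(path) and __call__(data).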
class EndpointHandler:
def __init__(self, path: str = ""):
logger.info(f"Initializing EndpointHandler with model path: {path}")
try:
self.tokenizer = AutoTokenizer.from_pretrained(path)
logger.info("Tokenizer loaded successfully")
self.model = AutoModelForCausalLM.from_pretrained(
path,
device_map="auto"
)
logger.info(f"Model loaded successfully. Device map: {self.model.device}")
self.model.eval()
logger.info("Model set to evaluation mode")
            # Default generation parameters; temperature must be strictly positive
            # when do_sample=True, otherwise transformers raises a ValueError
            self.default_params = {
                "max_new_tokens": 100,
                "temperature": 0.7,
                "top_p": 0.9,
                "top_k": 50,
                "repetition_penalty": 1.1,
                "do_sample": True
            }
logger.info(f"Default generation parameters: {self.default_params}")
except Exception as e:
logger.error(f"Error during initialization: {str(e)}")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
"""Handle chat completion requests.
Args:
data: Dictionary containing:
- messages: List of message dictionaries with 'role' and 'content'
- generation_params: Optional dictionary of generation parameters
Returns:
List containing the generated response message
"""
try:
logger.info("Processing new request")
logger.info(f"Input data: {data}")
messages = data.get("inputs", [])
if not messages:
logger.warning("No input messages provided")
return [{"role": "assistant", "content": "No input messages provided"}]
# Get generation parameters, use defaults for missing values
gen_params = {**self.default_params, **data.get("generation_params", {})}
logger.info(f"Generation parameters: {gen_params}")
# Apply the chat template
logger.info("Applying chat template")
prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
logger.info(f"Generated prompt: {prompt}")
# Tokenize the prompt
logger.info("Tokenizing input")
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
logger.info(f"Input shape: {inputs.input_ids.shape}")
# Generate response
logger.info("Generating response")
with torch.no_grad():
output_tokens = self.model.generate(
**inputs,
**gen_params
)
logger.debug(f"Output shape: {output_tokens.shape}")
            # Decode only the newly generated tokens (everything after the prompt),
            # which is more robust than slicing the decoded string by len(prompt)
            logger.debug("Decoding response")
            generated_tokens = output_tokens[0][inputs.input_ids.shape[1]:]
            response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
logger.info(f"Generated response length: {len(response)}")
logger.debug(f"Generated response: {response}")
return [{"role": "assistant", "content": response}]
except Exception as e:
logger.error(f"Error during generation: {str(e)}", exc_info=True)
raise
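

# Minimal local smoke test -- a sketch for manual testing, not part of the endpoint
# contract. The model path below is hypothetical; point it at a local checkpoint.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")  # hypothetical local path
    result = handler({
        "inputs": [
            {"role": "user", "content": "Write a one-line greeting."}
        ],
        "generation_params": {"max_new_tokens": 32}
    })
    print(result)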