import logging
from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class EndpointHandler:
    """Chat-completion handler wrapping a causal language model.

    The tokenizer and model are loaded once in ``__init__``; each call to the
    instance renders the incoming chat messages with the tokenizer's chat
    template, generates a continuation and returns it as an assistant message.
    """

    # Generation arguments that only make sense when sampling is enabled.
    _SAMPLING_ONLY_KEYS = ("temperature", "top_p", "top_k")

    def __init__(self, path: str = ""):
        """Load the tokenizer and model found at ``path``.

        Args:
            path: Model identifier or directory passed to ``from_pretrained``.

        Raises:
            Exception: Any error raised while loading is logged and re-raised.
        """
        logger.info("Initializing EndpointHandler with model path: %s", path)
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(path)
            logger.info("Tokenizer loaded successfully")

            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                device_map="auto"
            )
            logger.info(
                "Model loaded successfully. Device map: %s", self.model.device
            )

            self.model.eval()
            logger.info("Model set to evaluation mode")

            # Default generation parameters. A temperature of 0.0 is treated
            # as "decode greedily" when the request is processed (see
            # _resolve_generation_params).
            self.default_params = {
                "max_new_tokens": 100,
                "temperature": 0.0,
                "top_p": 0.9,
                "top_k": 50,
                "repetition_penalty": 1.1,
                "do_sample": True
            }
            logger.info("Default generation parameters: %s", self.default_params)
        except Exception as e:
            logger.error("Error during initialization: %s", e)
            raise

    def _resolve_generation_params(self, overrides: Any) -> Dict[str, Any]:
        """Merge request overrides into the defaults and make them valid."""
        # `or {}` also covers a request that sends "generation_params": null.
        params = {**self.default_params, **(overrides or {})}

        temperature = params.get("temperature")
        if isinstance(temperature, (int, float)) and temperature <= 0:
            # Sampling with a non-positive temperature is rejected by
            # `generate`; the intent of temperature 0 is deterministic output,
            # so switch to greedy decoding and drop the sampling-only knobs.
            params["do_sample"] = False
            for key in self._SAMPLING_ONLY_KEYS:
                params.pop(key, None)
        return params

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Handle chat completion requests.

        Args:
            data: Dictionary containing:
                - inputs: List of message dictionaries with 'role' and 'content'
                - generation_params: Optional dictionary of generation
                  parameters overriding ``self.default_params``

        Returns:
            List containing the generated response message. When ``inputs`` is
            missing or empty, the single message carries the text
            "No input messages provided" instead of a model reply.

        Raises:
            Exception: Any error raised during templating, tokenization or
                generation is logged with its traceback and re-raised.
        """
        try:
            logger.info("Processing new request")
            logger.info("Input data: %s", data)

            messages = data.get("inputs", [])
            if not messages:
                logger.warning("No input messages provided")
                return [{"role": "assistant", "content": "No input messages provided"}]

            # Get generation parameters, use defaults for missing values
            gen_params = self._resolve_generation_params(
                data.get("generation_params")
            )
            logger.info("Generation parameters: %s", gen_params)

            # Apply the chat template
            logger.info("Applying chat template")
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            logger.info("Generated prompt: %s", prompt)

            # Tokenize the prompt. The chat template already emitted any
            # special tokens it needs, so do not let the tokenizer add a
            # second set (e.g. a duplicated BOS token).
            logger.info("Tokenizing input")
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                add_special_tokens=False
            ).to(self.model.device)
            prompt_length = inputs["input_ids"].shape[-1]
            logger.info("Input shape: %s", inputs["input_ids"].shape)

            # Generate response
            logger.info("Generating response")
            with torch.no_grad():
                output_tokens = self.model.generate(
                    **inputs,
                    **gen_params
                )
            logger.debug("Output shape: %s", output_tokens.shape)

            # Decode only the newly generated tokens. Slicing by token count
            # is reliable, whereas the decoded text is not guaranteed to start
            # with the exact prompt string; skipping special tokens keeps
            # end-of-sequence markers out of the reply.
            logger.debug("Decoding response")
            new_tokens = output_tokens[0][prompt_length:]
            response = self.tokenizer.decode(
                new_tokens,
                skip_special_tokens=True
            ).strip()
            logger.info("Generated response length: %s", len(response))
            logger.debug("Generated response: %s", response)

            return [{"role": "assistant", "content": response}]

        except Exception as e:
            logger.exception("Error during generation: %s", e)
            raise