from typing import Dict, List, Any

import logging

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, path: str = ""):
        logger.info(f"Initializing EndpointHandler with model path: {path}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(path)
            logger.info("Tokenizer loaded successfully")

            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                device_map="auto"
            )
logger.info(f"Model loaded successfully. Device map: {self.model.device}") |

            self.model.eval()
            logger.info("Model set to evaluation mode")

            # Sampling defaults; temperature must be strictly positive when do_sample=True
            self.default_params = {
                "max_new_tokens": 100,
                "temperature": 0.7,
                "top_p": 0.9,
                "top_k": 50,
                "repetition_penalty": 1.1,
                "do_sample": True
            }
            logger.info(f"Default generation parameters: {self.default_params}")
        except Exception as e:
            logger.error(f"Error during initialization: {str(e)}")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Handle chat completion requests.

        Args:
            data: Dictionary containing:
                - messages: List of message dictionaries with 'role' and 'content'
                - generation_params: Optional dictionary of generation parameters

        Returns:
            List containing the generated response message
        """
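        # Expected request shape, for illustration (values are placeholders):
        # {
        #     "messages": [
        #         {"role": "system", "content": "You are a helpful assistant."},
        #         {"role": "user", "content": "Hello!"}
        #     ],
        #     "generation_params": {"temperature": 0.8}
        # }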
        try:
            logger.info("Processing new request")
            logger.info(f"Input data: {data}")

            messages = data.get("messages", [])
            if not messages:
                logger.warning("No input messages provided")
                return [{"role": "assistant", "content": "No input messages provided"}]

            gen_params = {**self.default_params, **data.get("generation_params", {})}
            logger.info(f"Generation parameters: {gen_params}")

logger.info("Applying chat template") |
|
prompt = self.tokenizer.apply_chat_template( |
|
messages, |
|
tokenize=False, |
|
add_generation_prompt=True |
|
) |
|
logger.info(f"Generated prompt: {prompt}") |
|
|
|
|
|
logger.info("Tokenizing input") |
|
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) |
|
logger.info(f"Input shape: {inputs.input_ids.shape}") |
|
|
|
|
|
logger.info("Generating response") |
|
with torch.no_grad(): |
|
output_tokens = self.model.generate( |
|
**inputs, |
|
**gen_params |
|
) |
|
logger.debug(f"Output shape: {output_tokens.shape}") |
|
|
|
|
|
logger.debug("Decoding response") |
|
output_text = self.tokenizer.batch_decode(output_tokens)[0] |
|
|
|
|
|
response = output_text[len(prompt):].strip() |
|
logger.info(f"Generated response length: {len(response)}") |
|
logger.debug(f"Generated response: {response}") |
|
|
|
return [{"role": "assistant", "content": response}] |

        except Exception as e:
            logger.error(f"Error during generation: {str(e)}", exc_info=True)
            raise
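

# Minimal local smoke test: a sketch of how the handler is invoked, not part of
# the endpoint contract. The model path and sample messages below are placeholder
# values, assuming a chat model whose tokenizer ships a chat template.
if __name__ == "__main__":
    handler = EndpointHandler(path="path/to/model")  # placeholder path
    result = handler({
        "messages": [
            {"role": "user", "content": "Hello, who are you?"}
        ],
        "generation_params": {"max_new_tokens": 64}
    })
    print(result)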