from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

class EndpointHandler:
    def __init__(self, path: str = ""):
        logger.info(f"Initializing EndpointHandler with model path: {path}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(path)
            logger.info("Tokenizer loaded successfully")
            
            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                device_map="auto"
            )
            logger.info(f"Model loaded successfully. Device map: {self.model.device}")
            
            self.model.eval()
            logger.info("Model set to evaluation mode")
            
            # Default generation parameters. Note: temperature must be
            # strictly positive when do_sample=True, otherwise
            # model.generate() raises a ValueError.
            self.default_params = {
                "max_new_tokens": 100,
                "temperature": 0.7,
                "top_p": 0.9,
                "top_k": 50,
                "repetition_penalty": 1.1,
                "do_sample": True,
                # Avoids the pad_token_id warning for models that do not
                # define a dedicated padding token.
                "pad_token_id": self.tokenizer.eos_token_id,
            }
            logger.info(f"Default generation parameters: {self.default_params}")
        except Exception as e:
            logger.error(f"Error during initialization: {str(e)}")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Handle chat completion requests.
        
        Args:
            data: Dictionary containing:
                - inputs: List of chat messages, each a dict with 'role' and 'content'
                - generation_params: Optional dictionary of generation parameters
                  that override the defaults
                
        Returns:
            List containing the generated assistant message
        """
        try:
            logger.info("Processing new request")
            logger.info(f"Input data: {data}")
            
            messages = data.get("inputs", [])
            if not messages:
                logger.warning("No input messages provided")
                return [{"role": "assistant", "content": "No input messages provided"}]
            
            # Get generation parameters, use defaults for missing values
            gen_params = {**self.default_params, **data.get("generation_params", {})}
            logger.info(f"Generation parameters: {gen_params}")
            
            # Apply the chat template
            logger.info("Applying chat template")
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            logger.info(f"Generated prompt: {prompt}")
            
            # Tokenize the prompt
            logger.info("Tokenizing input")
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            logger.info(f"Input shape: {inputs.input_ids.shape}")
            
            # Generate response
            logger.info("Generating response")
            with torch.no_grad():
                output_tokens = self.model.generate(
                    **inputs,
                    **gen_params
                )
            logger.debug(f"Output shape: {output_tokens.shape}")
            
            # Decode only the newly generated tokens. Slicing the decoded
            # string by len(prompt) is fragile: decoding without
            # skip_special_tokens can insert special tokens that are absent
            # from the raw prompt text, shifting the offset and corrupting
            # the extracted response.
            logger.debug("Decoding response")
            new_tokens = output_tokens[0][inputs.input_ids.shape[1]:]
            response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
            logger.info(f"Generated response length: {len(response)}")
            logger.debug(f"Generated response: {response}")
            
            return [{"role": "assistant", "content": response}]
            
        except Exception as e:
            logger.error(f"Error during generation: {str(e)}", exc_info=True)
            raise
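
# Minimal local smoke test: a sketch assuming this file is run directly with a
# chat-capable checkpoint at the hypothetical path below. In production,
# Hugging Face Inference Endpoints constructs EndpointHandler and delivers the
# JSON payload to __call__ for you.
if __name__ == "__main__":
    handler = EndpointHandler(path="path/to/model")  # hypothetical local path
    payload = {
        "inputs": [
            {"role": "user", "content": "What is the capital of France?"}
        ],
        "generation_params": {"max_new_tokens": 50},
    }
    print(handler(payload))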