vikram-fresche committed
Commit f72bb15 · verified · 1 Parent(s): 9146a36

handler_v2 (#2)

- added custom handler v2 (44a5e555878070616b760618de1c3f5f4a18241a)
- added custom handler v2 (5a6c04a4c27038c562422fb8a54023504d49859f)
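
With v2 the handler accepts a chat-style request body instead of the bare "inputs" field used by the pipeline-based v1. As a rough illustration (field values below are placeholders; "generation_params" is optional and is merged over the handler's defaults):

    payload = {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Summarize the release notes."}
        ],
        "generation_params": {"max_new_tokens": 64, "temperature": 0.7}
    }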

Files changed (1)
  1. handler.py +99 -13
handler.py CHANGED
@@ -1,17 +1,103 @@
 from typing import Dict, List, Any
-from transformers import pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+import logging
 
-class EndpointHandler():
-    def __init__(self, path=""):
-        self.model = pipeline("text-generation", model=path)
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
 
-    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
-        """
-        data args:
-            inputs (:obj: `str` | `PIL.Image` | `np.array`)
-            kwargs
-        Return:
-            A :obj:`list` | `dict`: will be serialized and returned
+class EndpointHandler:
+    def __init__(self, path: str = ""):
+        logger.info(f"Initializing EndpointHandler with model path: {path}")
+        try:
+            self.tokenizer = AutoTokenizer.from_pretrained(path)
+            logger.info("Tokenizer loaded successfully")
+
+            self.model = AutoModelForCausalLM.from_pretrained(
+                path,
+                device_map="auto"
+            )
+            logger.info(f"Model loaded successfully. Device map: {self.model.device}")
+
+            self.model.eval()
+            logger.info("Model set to evaluation mode")
+
+            # Default generation parameters
+            self.default_params = {
+                "max_new_tokens": 100,
+                "temperature": 0.0,
+                "top_p": 0.9,
+                "top_k": 50,
+                "repetition_penalty": 1.1,
+                "do_sample": True
+            }
+            logger.info(f"Default generation parameters: {self.default_params}")
+        except Exception as e:
+            logger.error(f"Error during initialization: {str(e)}")
+            raise
+
+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
+        """Handle chat completion requests.
+
+        Args:
+            data: Dictionary containing:
+                - messages: List of message dictionaries with 'role' and 'content'
+                - generation_params: Optional dictionary of generation parameters
+
+        Returns:
+            List containing the generated response message
         """
-        inputs = data.pop("inputs", data)
-        return self.model(inputs, **data)
+        try:
+            logger.info("Processing new request")
+            logger.debug(f"Input data: {data}")
+
+            messages = data.get("messages", [])
+            if not messages:
+                logger.warning("No input messages provided")
+                return [{"role": "assistant", "content": "No input messages provided"}]
+
+            # Get generation parameters, use defaults for missing values
+            gen_params = {**self.default_params, **data.get("generation_params", {})}
+            logger.info(f"Generation parameters: {gen_params}")
+
+            # Apply the chat template
+            logger.debug("Applying chat template")
+            prompt = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True
+            )
+            logger.debug(f"Generated prompt: {prompt}")
+
+            # Tokenize the prompt
+            logger.debug("Tokenizing input")
+            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+            logger.debug(f"Input shape: {inputs.input_ids.shape}")
+
+            # Generate response
+            logger.info("Generating response")
+            with torch.no_grad():
+                output_tokens = self.model.generate(
+                    **inputs,
+                    **gen_params
+                )
+            logger.debug(f"Output shape: {output_tokens.shape}")
+
+            # Decode the response
+            logger.debug("Decoding response")
+            output_text = self.tokenizer.batch_decode(output_tokens)[0]
+
+            # Extract the assistant's response by removing the input prompt
+            response = output_text[len(prompt):].strip()
+            logger.info(f"Generated response length: {len(response)}")
+            logger.debug(f"Generated response: {response}")
+
+            return [{"role": "assistant", "content": response}]
+
+        except Exception as e:
+            logger.error(f"Error during generation: {str(e)}", exc_info=True)
+            raise
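
For a quick sanity check outside a deployed endpoint, the new handler can be exercised locally roughly as sketched below. The model path and messages are illustrative assumptions, not part of this commit; the returned shape ([{"role": "assistant", "content": ...}]) matches what __call__ produces. The example passes a positive temperature because the committed defaults combine do_sample=True with temperature=0.0, which recent transformers releases reject as an invalid sampling temperature.

    # Minimal local smoke test (sketch): assumes handler.py is importable and
    # "path/to/chat-model" (placeholder) points at a causal LM whose tokenizer
    # defines a chat template.
    from handler import EndpointHandler

    handler = EndpointHandler(path="path/to/chat-model")

    request = {
        "messages": [
            {"role": "user", "content": "Hello! What can you do?"}
        ],
        # Optional overrides, merged over the handler's defaults; a positive
        # temperature avoids the default do_sample=True / temperature=0.0 pairing.
        "generation_params": {"max_new_tokens": 64, "temperature": 0.7},
    }

    print(handler(request))  # -> [{"role": "assistant", "content": "..."}]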