vikram-fresche committed
Commit 5a6c04a · 1 Parent(s): 44a5e55

added custom handler v2

Files changed (1)
  1. handler.py +83 -45
handler.py CHANGED
@@ -1,25 +1,44 @@
 from typing import Dict, List, Any
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+import logging
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)

 class EndpointHandler:
     def __init__(self, path: str = ""):
-        self.tokenizer = AutoTokenizer.from_pretrained(path)
-        self.model = AutoModelForCausalLM.from_pretrained(
-            path,
-            device_map="auto"
-        )
-        self.model.eval()
-
-        # Default generation parameters
-        self.default_params = {
-            "max_new_tokens": 100,
-            "temperature": 0.0,
-            "top_p": 0.9,
-            "top_k": 50,
-            "repetition_penalty": 1.1,
-            "do_sample": True
-        }
+        logger.info(f"Initializing EndpointHandler with model path: {path}")
+        try:
+            self.tokenizer = AutoTokenizer.from_pretrained(path)
+            logger.info("Tokenizer loaded successfully")
+
+            self.model = AutoModelForCausalLM.from_pretrained(
+                path,
+                device_map="auto"
+            )
+            logger.info(f"Model loaded successfully. Device map: {self.model.device}")
+
+            self.model.eval()
+            logger.info("Model set to evaluation mode")
+
+            # Default generation parameters
+            self.default_params = {
+                "max_new_tokens": 100,
+                "temperature": 0.0,
+                "top_p": 0.9,
+                "top_k": 50,
+                "repetition_penalty": 1.1,
+                "do_sample": True
+            }
+            logger.info(f"Default generation parameters: {self.default_params}")
+        except Exception as e:
+            logger.error(f"Error during initialization: {str(e)}")
+            raise

     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
         """Handle chat completion requests.
@@ -32,34 +51,53 @@ class EndpointHandler:
         Returns:
             List containing the generated response message
         """
-        messages = data.get("messages", [])
-        if not messages:
-            return [{"role": "assistant", "content": "No input messages provided"}]
+        try:
+            logger.info("Processing new request")
+            logger.debug(f"Input data: {data}")

-        # Get generation parameters, use defaults for missing values
-        gen_params = {**self.default_params, **data.get("generation_params", {})}
-
-        # Apply the chat template
-        prompt = self.tokenizer.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=True
-        )
-
-        # Tokenize the prompt
-        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
-
-        # Generate response
-        with torch.no_grad():
-            output_tokens = self.model.generate(
-                **inputs,
-                **gen_params
+            messages = data.get("messages", [])
+            if not messages:
+                logger.warning("No input messages provided")
+                return [{"role": "assistant", "content": "No input messages provided"}]
+
+            # Get generation parameters, use defaults for missing values
+            gen_params = {**self.default_params, **data.get("generation_params", {})}
+            logger.info(f"Generation parameters: {gen_params}")
+
+            # Apply the chat template
+            logger.debug("Applying chat template")
+            prompt = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True
             )
-
-        # Decode the response
-        output_text = self.tokenizer.batch_decode(output_tokens)[0]
-
-        # Extract the assistant's response by removing the input prompt
-        response = output_text[len(prompt):].strip()
-
-        return [{"role": "assistant", "content": response}]
+            logger.debug(f"Generated prompt: {prompt}")
+
+            # Tokenize the prompt
+            logger.debug("Tokenizing input")
+            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+            logger.debug(f"Input shape: {inputs.input_ids.shape}")
+
+            # Generate response
+            logger.info("Generating response")
+            with torch.no_grad():
+                output_tokens = self.model.generate(
+                    **inputs,
+                    **gen_params
+                )
+            logger.debug(f"Output shape: {output_tokens.shape}")
+
+            # Decode the response
+            logger.debug("Decoding response")
+            output_text = self.tokenizer.batch_decode(output_tokens)[0]
+
+            # Extract the assistant's response by removing the input prompt
+            response = output_text[len(prompt):].strip()
+            logger.info(f"Generated response length: {len(response)}")
+            logger.debug(f"Generated response: {response}")
+
+            return [{"role": "assistant", "content": response}]
+
+        except Exception as e:
+            logger.error(f"Error during generation: {str(e)}", exc_info=True)
+            raise
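
For a quick local sanity check of the committed handler, a call might look like the sketch below. This is not part of the commit: the model path and example messages are placeholders, and the generation_params override is only illustrative. Note that the committed default of temperature 0.0 combined with do_sample=True may be rejected by some transformers versions, so the example overrides the temperature.

# Illustrative local smoke test for handler.py (not part of the commit).
# "path/to/model" is a placeholder; point it at a chat model that defines a chat template.
from handler import EndpointHandler

handler = EndpointHandler(path="path/to/model")

payload = {
    "messages": [
        {"role": "user", "content": "Hello, who are you?"}
    ],
    # Optional overrides, merged over handler.default_params in __call__.
    "generation_params": {
        "max_new_tokens": 64,
        "temperature": 0.7,  # example override; 0.0 with do_sample=True can error
    },
}

result = handler(payload)
print(result)  # e.g. [{"role": "assistant", "content": "..."}]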