vikram-fresche committed
Commit 44a5e55 · 1 Parent(s): 9146a36

added custom handler v2

Files changed (1): handler.py +61 -13
handler.py CHANGED
@@ -1,17 +1,65 @@
 from typing import Dict, List, Any
-from transformers import pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
 
-class EndpointHandler():
-    def __init__(self, path=""):
-        self.model = pipeline("text-generation", model=path)
+class EndpointHandler:
+    def __init__(self, path: str = ""):
+        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            path,
+            device_map="auto"
+        )
+        self.model.eval()
+
+        # Default generation parameters
+        self.default_params = {
+            "max_new_tokens": 100,
+            "temperature": 0.7,  # must be strictly positive when do_sample=True
+            "top_p": 0.9,
+            "top_k": 50,
+            "repetition_penalty": 1.1,
+            "do_sample": True
+        }
 
-    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
+        """Handle chat completion requests.
+
+        Args:
+            data: Dictionary containing:
+                - messages: List of message dictionaries with 'role' and 'content'
+                - generation_params: Optional dictionary of generation parameters
+
+        Returns:
+            List containing the generated response message
         """
-        data args:
-            inputs (:obj: `str` | `PIL.Image` | `np.array`)
-            kwargs
-        Return:
-            A :obj:`list` | `dict`: will be serialized and returned
-        """
-        inputs = data.pop("inputs", data)
-        return self.model(inputs, **data)
+        messages = data.get("messages", [])
+        if not messages:
+            return [{"role": "assistant", "content": "No input messages provided"}]
+
+        # Get generation parameters, use defaults for missing values
+        gen_params = {**self.default_params, **data.get("generation_params", {})}
+
+        # Apply the chat template
+        prompt = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+
+        # Tokenize the prompt
+        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+
+        # Generate response
+        with torch.no_grad():
+            output_tokens = self.model.generate(
+                **inputs,
+                **gen_params
+            )
+
+        # Keep only the newly generated tokens (everything after the prompt)
+        new_tokens = output_tokens[0][inputs["input_ids"].shape[-1]:]
+
+        # Decode the assistant's reply, dropping special tokens
+        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
+
+        return [{"role": "assistant", "content": response}]
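
For a quick local smoke test of the new handler, something like the following sketch should work, assuming it is run from inside a cloned model repository whose tokenizer ships a chat template; the "." path and the example messages are illustrative, not part of the commit:

from handler import EndpointHandler

# "." assumes handler.py sits inside the model repository itself, as it does
# for Inference Endpoints custom handlers (hypothetical local setup).
handler = EndpointHandler(path=".")

# The dict mirrors the JSON body the deployed endpoint receives:
# a "messages" list plus optional "generation_params" overrides.
response = handler({
    "messages": [
        {"role": "user", "content": "Hello, who are you?"}
    ],
    "generation_params": {"max_new_tokens": 50}
})
print(response)  # -> [{"role": "assistant", "content": "..."}]

Once deployed, the same payload can be POSTed to the endpoint; the URL and token below are placeholders:

import requests

API_URL = "https://your-endpoint.endpoints.huggingface.cloud"  # placeholder URL
headers = {"Authorization": "Bearer hf_xxx"}  # placeholder token

payload = {
    "messages": [{"role": "user", "content": "Hello, who are you?"}],
    "generation_params": {"temperature": 0.2},
}
print(requests.post(API_URL, headers=headers, json=payload).json())

Because __call__ merges generation_params over self.default_params, callers only need to send the parameters they want to override.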