fletch1300 committed
Commit 8f0b71b
1 Parent(s): b058d61

Create handler.py

Files changed (1)
  1. handler.py +123 -0
handler.py ADDED
@@ -0,0 +1,123 @@
+ from typing import Any, Dict, List
+ import os
+ import logging
+ import re
+ import torch
+ import transformers
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from collections import deque
+
+ logging.basicConfig(level=logging.DEBUG)
+
+ # Prefer bfloat16 on GPUs with compute capability >= 8 (Ampere or newer); fall back to float16.
+ dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
+
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
+         self.tokenizer.pad_token = self.tokenizer.eos_token
+         # Load the model in 8-bit and let accelerate place it across the available devices.
+         self.model = AutoModelForCausalLM.from_pretrained(
+             path,
+             return_dict=True,
+             device_map="auto",
+             load_in_8bit=True,
+             torch_dtype=dtype,
+             trust_remote_code=True,
+         )
+         generation_config = self.model.generation_config
+         generation_config.max_new_tokens = 140
+         generation_config.temperature = 0.7
+         generation_config.top_p = 0.7
+         generation_config.num_return_sequences = 1
+         generation_config.pad_token_id = self.tokenizer.eos_token_id
+         generation_config.eos_token_id = self.tokenizer.eos_token_id
+         generation_config.early_stopping = True
+         self.generate_config = generation_config
+
+         # Kept for convenience; generate_response below decodes manually instead of using the pipeline.
+         self.pipeline = transformers.pipeline(
+             "text-generation", model=self.model, tokenizer=self.tokenizer
+         )
+
+     def _ensure_token_limit(self, tokens: List[int]) -> List[int]:
+         MAX_TOKEN_COUNT = 1024
+         if len(tokens) > MAX_TOKEN_COUNT:
+             # Keep only the most recent MAX_TOKEN_COUNT (1024) tokens
+             return tokens[-MAX_TOKEN_COUNT:]
+         return tokens
+
+     def _extract_response(self, text: str) -> str:
+         # Locate the start of the bot's or assistant's response
+         bot_marker = "<bot> response:"
+         assistant_marker = "<assistant> response:"
+         bot_idx = text.find(bot_marker)
+         assistant_idx = text.find(assistant_marker)
+
+         # If neither marker is found, fall back to the beginning of the text
+         if bot_idx == -1 and assistant_idx == -1:
+             response_start = 0
+         else:
+             starts = []
+             if bot_idx != -1:
+                 starts.append(bot_idx + len(bot_marker))
+             if assistant_idx != -1:
+                 starts.append(assistant_idx + len(assistant_marker))
+             response_start = max(starts)
+
+         # Extract everything after the marker up to the next "User:" turn, if any
+         user_response_start = text.find("User:", response_start)
+         end_point = user_response_start if user_response_start != -1 else len(text)
+
+         # Return only the bot's or assistant's response, without the trailing "User:" content
+         return text[response_start:end_point].strip()
+
+     def _truncate_conversation(self, conversation: str, max_tokens: int = 512) -> str:
+         # Split the conversation into exchanges on each "User:" / "Assistant:" turn
+         exchanges = re.split(r'(?=User:|Assistant:)', conversation)
+         while len(exchanges) > 0:
+             tokenized_conv = self.tokenizer.encode(' '.join(exchanges), truncation=False)
+             if len(tokenized_conv) <= max_tokens:
+                 return ' '.join(exchanges)
+             exchanges.pop(0)  # Drop the oldest exchange and re-check the token count
+         return ""  # If all exchanges are removed, return an empty string
+
+     def generate_response(self, user_prompt, additional_context=None):
+         if additional_context:
+             truncated_conversation = self._truncate_conversation(additional_context)
+         else:
+             truncated_conversation = ""
+         permanent_context = (
+             "<context>: You are a life coaching bot with the goal of providing guidance, "
+             "improving understanding, reducing suffering and improving life. Gain as much "
+             "understanding of the user before providing guidance with detailed actionable steps."
+         )
+         structured_prompt = f"{permanent_context}\n{truncated_conversation}\n<user>: {user_prompt}"
+
+         structured_prompt += "<bot> response:"
+
+         input_ids = self.tokenizer.encode(structured_prompt, return_tensors="pt")
+         # Assumes '<bot>' and 'User ' each encode to a single token in this tokenizer
+         stop_token_ids = [self.tokenizer.encode(token)[0] for token in ['<bot>', 'User ']]
+
+         max_length = 1024
+         outputs = input_ids
+
+         # Greedy, token-by-token decoding; note this bypasses self.generate_config
+         # (temperature, top_p, max_new_tokens) and stops early on a stop token.
+         while len(outputs[0]) < max_length:
+             # Compute logits for the next token and pick the most likely one
+             next_token_logits = self.model(outputs).logits[:, -1, :]
+             next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+
+             # Stop if the predicted token is one of the stop tokens
+             if any(token.item() in stop_token_ids for token in next_token):
+                 break
+
+             # Append the new token and continue generating
+             outputs = torch.cat([outputs, next_token], dim=-1)
+
+         response_text = self._extract_response(self.tokenizer.decode(outputs[0])).strip()
+
+         return response_text
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         # Defined before the try block so it is always available in the except handler
+         debug_info = "No debug info available."
+         try:
+             user_prompt = data.get("inputs", data)
+
+             self.prev_user_message = user_prompt
+             response_text = self.generate_response(user_prompt)
+
+             return [{"generated_text": response_text, "debug_info": debug_info}]
+         except Exception as e:
+             logging.error(f"An error occurred in __call__ method: {e}")
+             return [{"generated_text": str(e), "debug_info": debug_info}]
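For reference, a minimal sketch of how a custom handler like this is typically exercised locally before deployment; the model directory and the prompt below are illustrative placeholders, not part of this commit:

    # Hypothetical local smoke test; "local-model-dir" stands in for a local checkout of this repo.
    handler = EndpointHandler(path="local-model-dir")
    payload = {"inputs": "I keep putting off my goals. How do I start?"}
    result = handler(payload)
    print(result[0]["generated_text"])

On Inference Endpoints, the serving toolkit is expected to instantiate EndpointHandler with the repository path and call it once per request with the JSON payload, so the list returned by __call__ is what the endpoint sends back.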