fletch1300 committed on
Commit 7fb554f
1 Parent(s): 3b9144c

Update handler.py

Files changed (1)
  1. handler.py +35 -85
handler.py CHANGED
@@ -1,21 +1,15 @@
 from typing import Any, Dict, List
-import os
-import logging
-import re
+
 import torch
 import transformers
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from collections import deque

-logging.basicConfig(level=logging.DEBUG)

 dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16

-
 class EndpointHandler:
     def __init__(self, path=""):
         self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
-        self.tokenizer.pad_token = self.tokenizer.eos_token
         self.model = AutoModelForCausalLM.from_pretrained(
             path,
             return_dict=True,
@@ -24,10 +18,11 @@ class EndpointHandler:
             torch_dtype=dtype,
             trust_remote_code=True,
         )
+
         generation_config = self.model.generation_config
-        generation_config.max_new_tokens = 140
-        generation_config.temperature = 0.7
-        generation_config.top_p = 0.7
+        generation_config.max_new_tokens = 200
+        generation_config.temperature = 0.8
+        generation_config.top_p = 0.8
         generation_config.num_return_sequences = 1
         generation_config.pad_token_id = self.tokenizer.eos_token_id
         generation_config.eos_token_id = self.tokenizer.eos_token_id
@@ -38,86 +33,41 @@ class EndpointHandler:
             "text-generation", model=self.model, tokenizer=self.tokenizer
         )

-    def _ensure_token_limit(self, tokens: List[int]) -> List[int]:
-        MAX_TOKEN_COUNT = 1024
-        if len(tokens) > MAX_TOKEN_COUNT:
-            # Keep only the last 2048 tokens
-            return tokens[-MAX_TOKEN_COUNT:]
-        return tokens
-
-    def _extract_response(self, text: str) -> str:
-        # Check for the start of the bot's or assistant's response
-        bot_start = text.find("<bot> response:") + len("<bot> response:")
-        assistant_start = text.find("<assistant> response:") + len("<assistant> response:")
-        response_start = max(bot_start, assistant_start)
-
-        # If neither bot nor assistant start marker is found, set to the beginning of the text
-        if response_start == -1 or (assistant_start == len("<assistant> response:") and bot_start == len("<bot> response:")):
-            response_start = 0
-
-        # Extract everything after the bot's or assistant's start marker until the next "User:" content
-        user_response_start = text.find("User:", response_start)
-        if user_response_start != -1:
-            end_point = user_response_start
-        else:
-            end_point = len(text)
-
-        # Return only the bot's or assistant's response, removing "User:" content
-        bot_response = text[response_start:end_point].strip()
-        return bot_response
-
-    def _truncate_conversation(self, conversation: str, max_tokens: int = 512) -> str:
-        # Split the conversation into exchanges
-        exchanges = re.split(r'(?=User:|Assistant:)', conversation)
-        while len(exchanges) > 0:
-            tokenized_conv = self.tokenizer.encode(' '.join(exchanges), truncation=False)
-            if len(tokenized_conv) <= max_tokens:
-                return ' '.join(exchanges)
-            exchanges.pop(0)  # Remove the oldest exchange
-        return ""  # If all exchanges are removed, return an empty string.
-
-
-    def generate_response(self, user_prompt, additional_context=None):
-        if additional_context:
-            truncated_conversation = self._truncate_conversation(additional_context)
-        else:
-            truncated_conversation = ""
-        permanent_context = ("<context>: You are a life coaching bot with the goal of providing guidance, improving understanding, reducing suffering and improving life. Gain as much understanding of the user before providing guidance with detailed actionable steps.")
-        structured_prompt = f"{permanent_context}\n{truncated_conversation}\n<user>: {user_prompt}"
-
-        structured_prompt += "<bot> response:"
-
-        input_ids = self.tokenizer.encode(structured_prompt, return_tensors="pt")
-        stop_token_ids = [self.tokenizer.encode(token)[0] for token in ['<bot>', 'User ']]  # assuming these tokens are single tokens in your tokenizer
-
-        max_length = 1024
-        outputs = input_ids
-
-        while len(outputs[0]) < max_length:
-            # Generate next token
-            next_token_logits = self.model(outputs).logits[:, -1, :]
-            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
-
-            # Check if the token is in stop_tokens list
-            if any(token.item() in stop_token_ids for token in next_token):
-                break
-
-            # Append the next_token to the outputs
-            outputs = torch.cat([outputs, next_token], dim=-1)
-
-        response_text = self._extract_response(self.tokenizer.decode(outputs[0])).strip()
-
-        return response_text
-
-    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        try:
-            debug_info = "No debug info available."
-            user_prompt = data.get("inputs", data)
-
-            self.prev_user_message = user_prompt
-            response_text = self.generate_response(user_prompt)
-
-            return [{"generated_text": response_text, "debug_info": debug_info}]
-        except Exception as e:
-            logging.error(f"An error occurred in __call__ method: {e}")
-            return [{"generated_text": str(e), "debug_info": debug_info}]
+
+    def _ensure_token_limit(self, text):
+        """Ensure text is within the model's token limit."""
+        tokens = self.tokenizer.tokenize(text)
+        if len(tokens) > 2048:
+            # Remove tokens from the beginning until the text fits
+            tokens = tokens[-2048:]
+            return self.tokenizer.decode(tokens)
+        return text
+
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        user_prompt = data.pop("inputs", data)
+
+        # Add the user's message to the conversation history
+        self.conversation_history += f"<user>: {user_prompt}\n"
+
+        # Ensure the conversation history is within token limit
+        self.conversation_history = self._ensure_token_limit(self.conversation_history)
+
+        # Add the permanent context, user's prompt, and conversation history
+        permanent_context = "<context>: You are a life coaching bot with the goal of providing guidance, improving understanding, reducing suffering and improving life. Gain as much understanding of the user before providing guidance."
+        structured_prompt = f"{permanent_context}\n{self.conversation_history}<bot> response:"
+
+        result = self.pipeline(structured_prompt, generation_config=self.generate_config)
+
+        # Extract only the bot's response without the structuring text
+        response_text = self._extract_response(result[0]['generated_text'])
+
+        # Remove the last "<bot>" from the response_text
+        response_text = response_text.rsplit("[END", 1)[0].strip()
+
+        # Add the bot's response to the conversation history
+        self.conversation_history += f"<bot>: {response_text}\n"
+        self.conversation_history = self._ensure_token_limit(self.conversation_history)
+
+        return [{"generated_text": response_text}]
+
+        return {"response": response_text}
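Note: the rewritten __call__ still calls self._extract_response, whose previous definition is removed here, and it relies on self.pipeline, self.generate_config and self.conversation_history being set up in the unchanged part of __init__. A minimal sketch of a compatible _extract_response, adapted from the removed version to the <user>: / <bot> response: markers used in the new prompt (hypothetical, not part of this commit):

    def _extract_response(self, text: str) -> str:
        # Keep only what follows the last "<bot> response:" marker.
        marker = "<bot> response:"
        start = text.rfind(marker)
        start = start + len(marker) if start != -1 else 0
        # Cut off any follow-on "<user>:" turn the model may have generated.
        end = text.find("<user>:", start)
        if end == -1:
            end = len(text)
        return text[start:end].strip()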
 