File size: 5,492 Bytes
8f0b71b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from typing import Any, Dict, List
import os
import logging
import re
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from collections import deque

logging.basicConfig(level=logging.DEBUG)

# Pick bfloat16 on GPUs with compute capability >= 8 (the original check used
# == 8, which excluded newer architectures); otherwise, including hosts with no
# CUDA device at all, fall back to float16. Guarding with is_available() keeps
# module import from raising on CPU-only machines.
dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)


class EndpointHandler:
    """Inference-endpoint handler that wraps a causal LM as a life-coaching chat bot.

    Constructed once with the model directory, then invoked per request via
    ``__call__``.
    """

    def __init__(self, path=""):
        """Load the tokenizer and 8-bit model from *path* and set generation defaults.

        Args:
            path: Model directory / hub id forwarded to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        # Reuse the EOS token as the padding token.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            return_dict=True,
            device_map="auto",
            load_in_8bit=True,
            torch_dtype=dtype,
            trust_remote_code=True,
        )
        generation_config = self.model.generation_config
        generation_config.max_new_tokens = 140
        generation_config.temperature = 0.7
        generation_config.top_p = 0.7
        generation_config.num_return_sequences = 1
        generation_config.pad_token_id = self.tokenizer.eos_token_id
        generation_config.eos_token_id = self.tokenizer.eos_token_id
        generation_config.early_stopping = True
        self.generate_config = generation_config

        # NOTE(review): ``generate_config`` and ``pipeline`` are not used by
        # ``generate_response`` below, which runs its own greedy decoding loop.
        self.pipeline = transformers.pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer
        )

    def _ensure_token_limit(self, tokens: List[int], max_tokens: int = 1024) -> List[int]:
        """Return at most the last *max_tokens* entries of *tokens*.

        Args:
            tokens: Token ids, oldest first.
            max_tokens: Maximum number of ids to keep (default 1024).

        Returns:
            *tokens* itself when it already fits, otherwise its trailing slice.
        """
        if len(tokens) <= max_tokens:
            return tokens
        # Slice from an explicit start index: ``tokens[-0:]`` is the whole
        # list, so ``tokens[-max_tokens:]`` would be wrong for max_tokens == 0.
        return tokens[len(tokens) - max_tokens:]

    def _extract_response(self, text: str) -> str:
        """Return the bot's reply contained in *text*.

        The reply starts right after a ``"<bot> response:"`` or
        ``"<assistant> response:"`` marker (the later-ending of the two when
        both occur; the start of *text* when neither does) and ends before the
        next ``"User:"``, if any.
        """
        response_start = 0
        for marker in ("<bot> response:", "<assistant> response:"):
            marker_pos = text.find(marker)
            # Only a marker that is actually present may move the start;
            # ``find`` returns -1 on a miss, which must not be offset by len().
            if marker_pos != -1:
                response_start = max(response_start, marker_pos + len(marker))

        # Cut the reply before the next "User:" turn the model may have produced.
        user_pos = text.find("User:", response_start)
        end_point = user_pos if user_pos != -1 else len(text)
        return text[response_start:end_point].strip()

    def _truncate_conversation(self, conversation: str, max_tokens: int = 512) -> str:
        """Drop the oldest exchanges until *conversation* fits in *max_tokens* tokens.

        The conversation is split in front of every ``"User:"`` / ``"Assistant:"``
        label; surviving exchanges are re-joined verbatim, so the returned text
        is always a suffix of the input. Returns ``""`` if nothing fits.
        """
        # Zero-width split keeps each label with its own exchange; filter out the
        # empty piece produced when the text starts with a label.
        exchanges = deque(
            part for part in re.split(r'(?=User:|Assistant:)', conversation) if part
        )
        while exchanges:
            candidate = "".join(exchanges)
            # Re-tokenize the joined text rather than summing per-exchange counts,
            # since tokens can merge across exchange boundaries.
            if len(self.tokenizer.encode(candidate, truncation=False)) <= max_tokens:
                return candidate
            exchanges.popleft()  # Remove the oldest exchange.
        return ""

    def _stop_token_ids(self) -> set:
        """Return the token ids that end generation.

        These are the EOS id (when the tokenizer has one) plus the first token
        of each stop marker ``"<bot>"`` and ``"User "``.
        """
        stop_ids = set()
        for marker in ("<bot>", "User "):
            # add_special_tokens=False so a prepended BOS id is not mistaken for
            # the marker's first token.
            marker_ids = self.tokenizer.encode(marker, add_special_tokens=False)
            if marker_ids:
                stop_ids.add(marker_ids[0])
        if self.tokenizer.eos_token_id is not None:
            stop_ids.add(self.tokenizer.eos_token_id)
        return stop_ids

    def generate_response(self, user_prompt, additional_context=None, max_length: int = 1024):
        """Greedily generate the bot's reply to *user_prompt*.

        Args:
            user_prompt: The user's latest message, interpolated into the prompt.
            additional_context: Optional prior conversation text; only the most
                recent exchanges fitting in 512 tokens are kept.
            max_length: Upper bound on prompt + generated tokens (default 1024).

        Returns:
            The decoded reply text, cut before any ``"User:"`` the model emits.
        """
        if additional_context:
            truncated_conversation = self._truncate_conversation(additional_context)
        else:
            truncated_conversation = ""
        permanent_context = ("<context>: You are a life coaching bot with the goal of providing guidance, improving understanding, reducing suffering and improving life. Gain as much understanding of the user before providing guidance with detailed actionable steps.")
        structured_prompt = f"{permanent_context}\n{truncated_conversation}\n<user>: {user_prompt}"
        structured_prompt += "<bot> response:"

        # Put the prompt on the model's device; with device_map="auto" this is
        # presumably the device holding the input embeddings — confirm on
        # multi-GPU setups.
        input_ids = self.tokenizer.encode(structured_prompt, return_tensors="pt")
        input_ids = input_ids.to(self.model.device)
        prompt_length = input_ids.shape[-1]
        stop_token_ids = self._stop_token_ids()

        outputs = input_ids
        # Inference only: no_grad avoids recording autograd history each step.
        with torch.no_grad():
            while outputs.shape[-1] < max_length:
                next_token_logits = self.model(outputs).logits[:, -1, :]
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                # Batch size is always 1 here, so .item() is well-defined.
                if next_token.item() in stop_token_ids:
                    break
                outputs = torch.cat([outputs, next_token], dim=-1)

        # Decode only the newly generated tokens so text from the prompt or the
        # conversation history can never leak into the reply.
        generated_text = self.tokenizer.decode(
            outputs[0, prompt_length:], skip_special_tokens=True
        )
        return self._extract_response(generated_text).strip()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle one endpoint request.

        Args:
            data: Request payload; the prompt is ``data["inputs"]``, falling back
                to the whole payload when that key is absent.

        Returns:
            ``[{"generated_text": ..., "debug_info": ...}]``. On failure the
            exception message is returned as ``generated_text`` instead of
            raising.
        """
        debug_info = "No debug info available."
        try:
            user_prompt = data.get("inputs", data)
            self.prev_user_message = user_prompt
            response_text = self.generate_response(user_prompt)
        except Exception as e:
            # Best-effort endpoint: log with traceback and report the error in-band.
            logging.exception("An error occurred in __call__ method: %s", e)
            return [{"generated_text": str(e), "debug_info": debug_info}]
        return [{"generated_text": response_text, "debug_info": debug_info}]