File size: 1,833 Bytes
8f0b71b
7fb554f
8f0b71b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be62e65
8f0b71b
7fb554f
be62e65
7fb554f
8f0b71b
 
 
 
 
 
 
 
 
7fb554f
 
8f0b71b
be62e65
 
 
8f0b71b
be62e65
8f0b71b
be62e65
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from typing import Any, Dict, List

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

# Pick the half-precision dtype used when loading the model weights.
# bfloat16 is chosen on CUDA devices whose major compute capability is 8 or
# higher; everything else falls back to float16. The is_available() guard
# keeps the module importable on hosts with no CUDA device, where
# get_device_capability() would otherwise raise.
dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)

class EndpointHandler:
    """Request handler wrapping a causal LM in a text-generation pipeline.

    The handler loads a tokenizer and model from ``path`` once at
    construction time and then answers each request by prepending a fixed
    coaching context to the user's prompt and running generation.
    """

    # Fixed context prepended to every prompt. Kept as a class constant so it
    # is built once instead of on every request.
    _PERMANENT_CONTEXT = (
        "<context>: You are a life coaching bot with the goal of improving "
        "understanding, reducing suffering and improving life. Learn about the "
        "user in order to provide guidance without making assumptions or adding "
        "information not provided by the user."
    )

    def __init__(self, path=""):
        """Load the tokenizer, model and pipeline from ``path``.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            return_dict=True,
            device_map="auto",
            load_in_8bit=True,
            torch_dtype=dtype,
            trust_remote_code=True,
        )

        # Start from the model's own generation config and override the
        # decoding parameters; the same object is reused for every request.
        # NOTE(review): temperature/top_p are set but do_sample is not; per the
        # transformers docs these only apply when sampling is enabled, so
        # confirm the model's config enables it if sampling is intended.
        generation_config = self.model.generation_config
        generation_config.max_new_tokens = 200
        generation_config.temperature = 0.4
        generation_config.top_p = 0.8
        generation_config.num_return_sequences = 1
        # The EOS token doubles as the padding token.
        generation_config.pad_token_id = self.tokenizer.eos_token_id
        generation_config.eos_token_id = self.tokenizer.eos_token_id
        self.generate_config = generation_config

        self.pipeline = transformers.pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a reply for one request.

        Args:
            data: Request payload. The prompt is read from ``data["inputs"]``;
                when that key is absent the whole payload is used as the
                prompt. The payload is not modified.

        Returns:
            The pipeline's output, unmodified (for a single prompt the
            text-generation pipeline yields a list of result dicts).
        """
        # Read without popping so the caller's payload is left intact.
        user_prompt = data.get("inputs", data)

        # Add the permanent context to the user's prompt
        combined_prompt = f"{self._PERMANENT_CONTEXT}\n<human>: {user_prompt}"

        return self.pipeline(combined_prompt, generation_config=self.generate_config)