from typing import Any, Dict, List

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

# bfloat16 is supported on GPUs with compute capability >= 8 (Ampere or newer);
# fall back to float16 otherwise, and skip the CUDA query on CPU-only hosts.
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
class EndpointHandler:
    def __init__(self, path=""):
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            return_dict=True,
            device_map="auto",
            load_in_8bit=True,  # 8-bit quantization; requires bitsandbytes
            torch_dtype=dtype,
            trust_remote_code=True,
        )

        # Start from the model's own generation config and override the
        # sampling defaults used by this endpoint.
        generation_config = self.model.generation_config
        generation_config.max_new_tokens = 200
        generation_config.do_sample = True  # needed for temperature/top_p to take effect
        generation_config.temperature = 0.4
        generation_config.top_p = 0.8
        generation_config.num_return_sequences = 1
        generation_config.pad_token_id = self.tokenizer.eos_token_id
        generation_config.eos_token_id = self.tokenizer.eos_token_id
        self.generation_config = generation_config

        self.pipeline = transformers.pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer
        )
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Inference Endpoints send the request payload as {"inputs": ...};
        # fall back to the raw payload if the key is missing.
        user_prompt = data.pop("inputs", data)
        # Prepend the permanent coaching context to the user's prompt.
        permanent_context = "<context>: You are a life coaching bot with the goal of improving understanding, reducing suffering and improving life. Learn about the user in order to provide guidance without making assumptions or adding information not provided by the user."
        combined_prompt = f"{permanent_context}\n<human>: {user_prompt}"
        # The text-generation pipeline returns a list of dicts, each carrying
        # a "generated_text" key.
        result = self.pipeline(combined_prompt, generation_config=self.generation_config)
        return result
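
# ---------------------------------------------------------------------------
# Local smoke test: a minimal sketch, not part of the Inference Endpoints
# contract. Endpoints instantiate EndpointHandler with the repository path and
# invoke it with a {"inputs": ...} payload; the "." path and the sample prompt
# below are illustrative assumptions, not values from this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path=".")  # assumes model files sit alongside this file
    output = handler({"inputs": "I feel stuck in my career. Where should I start?"})
    print(output[0]["generated_text"])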