# homen_testing_merged6 / handler.py
from typing import Any, Dict, List

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

# bfloat16 is supported on compute capability >= 8 (Ampere and newer); otherwise fall back to float16.
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16


class EndpointHandler:
    def __init__(self, path=""):
        # Load the tokenizer and the 8-bit quantized model from the deployed repository.
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            return_dict=True,
            device_map="auto",
            load_in_8bit=True,  # quantize weights to 8-bit via bitsandbytes
            torch_dtype=dtype,
            trust_remote_code=True,
        )

        # Default sampling settings applied to every request.
        generation_config = self.model.generation_config
        generation_config.max_new_tokens = 200
        generation_config.temperature = 0.4
        generation_config.top_p = 0.8
        generation_config.num_return_sequences = 1
        generation_config.pad_token_id = self.tokenizer.eos_token_id
        generation_config.eos_token_id = self.tokenizer.eos_token_id
        self.generate_config = generation_config

        self.pipeline = transformers.pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        user_prompt = data.pop("inputs", data)

        # Prepend the permanent coaching context to the user's prompt.
        permanent_context = "<context>: You are a life coaching bot with the goal of improving understanding, reducing suffering and improving life. Learn about the user in order to provide guidance without making assumptions or adding information not provided by the user."
        combined_prompt = f"{permanent_context}\n<human>: {user_prompt}"

        # The text-generation pipeline returns a list of {"generated_text": ...} dicts.
        result = self.pipeline(combined_prompt, generation_config=self.generate_config)
        return result
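

# --- Local usage sketch (not part of the deployed handler) ---
# A minimal smoke test, assuming the merged model weights live in a local
# directory (the "./model" path below is hypothetical) and that a CUDA GPU
# with bitsandbytes installed is available for the 8-bit load.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    output = handler({"inputs": "I feel overwhelmed at work. Where should I start?"})
    print(output[0]["generated_text"])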