from typing import Any, Dict, List

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Prefer bfloat16 on GPUs with compute capability 8.x or newer (Ampere and
# later), which support it natively; fall back to float16 otherwise.
dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)


class EndpointHandler:
    def __init__(self, path=""):
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        # 8-bit loading via bitsandbytes; the bare `load_in_8bit=True` kwarg is
        # deprecated in recent transformers releases in favour of
        # `quantization_config`.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            return_dict=True,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            torch_dtype=dtype,
            trust_remote_code=True,
        )

        generation_config = self.model.generation_config
        generation_config.max_new_tokens = 200
        generation_config.temperature = 0.4
        generation_config.top_p = 0.8
        # temperature/top_p only take effect when sampling is enabled.
        generation_config.do_sample = True
        generation_config.num_return_sequences = 1
        generation_config.pad_token_id = self.tokenizer.eos_token_id
        generation_config.eos_token_id = self.tokenizer.eos_token_id
        self.generation_config = generation_config

        self.pipeline = transformers.pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Inference Endpoints convention: the payload carries the prompt under
        # the "inputs" key; fall back to the raw payload if it is absent.
        user_prompt = data.pop("inputs", data)

        permanent_context = (
            "<context>: You are a life coaching bot with the goal of improving "
            "understanding, reducing suffering and improving life. Learn about "
            "the user in order to provide guidance without making assumptions "
            "or adding information not provided by the user."
        )
        combined_prompt = f"{permanent_context}\n<human>: {user_prompt}"

        result = self.pipeline(combined_prompt, generation_config=self.generation_config)
        return result
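

# --- Minimal local smoke test (sketch) ---
# Assumes a CUDA machine with `bitsandbytes` installed and the handler run
# outside the Inference Endpoints runtime; "your-model-id" is a placeholder
# for a Hub repo id or local directory, not a value taken from the code above.
if __name__ == "__main__":
    handler = EndpointHandler(path="your-model-id")
    result = handler({"inputs": "How can I build more consistent habits?"})
    # The text-generation pipeline returns one dict per generated sequence.
    print(result[0]["generated_text"])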