"""Inference handler wrapping a causal-LM text-generation pipeline."""

from typing import Any, Dict, List

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer


def _select_dtype() -> torch.dtype:
    """Return bfloat16 on CUDA devices with compute capability >= 8, else float16.

    Checking ``torch.cuda.is_available()`` first keeps the module importable
    on machines without a CUDA device.
    """
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        return torch.bfloat16
    return torch.float16


dtype = _select_dtype()


class EndpointHandler:
    """Serve a causal LM behind a fixed "life coaching" context.

    ``__call__`` takes a request dict whose ``"inputs"`` entry is the user's
    prompt and returns ``{"response": <generated text>}``.
    """

    # Maximum number of tokens kept from the user prompt; the oldest are dropped.
    max_input_tokens = 2048
    # Separator placed after the user prompt; the model's answer follows it.
    response_marker = "response:"
    # Stop marker the model may emit; text from its first occurrence on is discarded.
    end_marker = "[END"

    def __init__(self, path=""):
        """Load tokenizer, 8-bit model, generation settings and pipeline from *path*."""
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            return_dict=True,
            device_map="auto",
            load_in_8bit=True,
            torch_dtype=dtype,
            trust_remote_code=True,
        )

        # The model's own generation config is modified in place and reused
        # for every request.
        generation_config = self.model.generation_config
        generation_config.max_new_tokens = 200
        # NOTE(review): temperature/top_p only take effect when sampling is
        # enabled (do_sample=True); confirm the model's default config enables
        # it, otherwise these two settings are ignored.
        generation_config.temperature = 0.8
        generation_config.top_p = 0.8
        generation_config.num_return_sequences = 1
        # EOS doubles as the padding token, presumably because the tokenizer
        # defines no pad token -- confirm.
        generation_config.pad_token_id = self.tokenizer.eos_token_id
        generation_config.eos_token_id = self.tokenizer.eos_token_id
        # NOTE(review): early_stopping only influences beam search.
        generation_config.early_stopping = True
        self.generate_config = generation_config

        self.pipeline = transformers.pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer
        )

    def _ensure_token_limit(self, text, max_tokens=None):
        """Return *text* trimmed to at most *max_tokens* tokens.

        Tokens are removed from the beginning, so the most recent part of the
        text is kept. *max_tokens* defaults to ``self.max_input_tokens``.
        Text that already fits is returned unchanged.
        """
        if max_tokens is None:
            max_tokens = self.max_input_tokens
        tokens = self.tokenizer.tokenize(text)
        if len(tokens) <= max_tokens:
            return text
        if max_tokens <= 0:
            # tokens[-0:] would be the whole list, so handle this explicitly.
            return ""
        # tokenize() yields token strings, not ids; decode() expects ids, so
        # rebuild the text with convert_tokens_to_string() instead.
        return self.tokenizer.convert_tokens_to_string(tokens[-max_tokens:])

    def _extract_response(self, generated_text, prompt=""):
        """Return only the model's continuation from the pipeline's ``generated_text``.

        By default the text-generation pipeline puts the prompt before the
        continuation. If *prompt* is an exact prefix, it is stripped.
        Otherwise (decoding may not reproduce the prompt byte-for-byte) the
        text after the first response marker is returned, or the whole text
        when no marker is present.
        """
        if prompt and generated_text.startswith(prompt):
            return generated_text[len(prompt):].strip()
        _, marker, tail = generated_text.partition(self.response_marker)
        return (tail if marker else generated_text).strip()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a reply to the request's prompt.

        Args:
            data: Request payload. ``data["inputs"]`` is popped and used as the
                user prompt; if that key is absent, the whole dict is
                stringified and used instead.

        Returns:
            ``{"response": text}``: the model's continuation, with the prompt
            echo removed and everything from the first end marker onward cut off.
        """
        user_prompt = data.pop("inputs", data)
        if not isinstance(user_prompt, str):
            user_prompt = str(user_prompt)
        # Trim only the user part so the fixed context below is never truncated.
        user_prompt = self._ensure_token_limit(user_prompt)

        # Permanent context
        permanent_context = ": You are a life coaching bot..."
        structured_prompt = f"{permanent_context}\n{user_prompt}\n{self.response_marker}"

        result = self.pipeline(structured_prompt, generation_config=self.generate_config)
        response_text = self._extract_response(result[0]["generated_text"], structured_prompt)

        # Cut at the first end marker; anything after it is not part of the answer.
        response_text = response_text.split(self.end_marker, 1)[0].strip()
        return {"response": response_text}