|
from typing import Dict, Any |
|
import logging |
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import torch.cuda |
|
|
|
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Logger for this module, named after the module itself.
LOGGER = logging.getLogger(__name__)
|
class EndpointHandler:
    """Text-generation handler for a custom inference endpoint.

    The causal language model and its tokenizer are loaded once, when the
    handler is constructed. Each call samples a continuation for one prompt
    and returns the decoded text.
    """

    # Sampling settings applied to every request. A request may override any
    # of them through the "parameters" mapping of its payload.
    DEFAULT_GENERATION_KWARGS: Dict[str, Any] = {
        "max_new_tokens": 100,
        "do_sample": True,
        "top_k": 10,
        "top_p": 0.95,
    }

    def __init__(
        self,
        path="",
        *,
        model_id="Ozgur98/pushed_model_mosaic_small",
        tokenizer_id="EleutherAI/gpt-neox-20b",
    ):
        """Load the model (as bfloat16) and the tokenizer onto ``device``.

        Args:
            path (str): Repository path handed in by the serving runtime. It is
                not used; weights are loaded from ``model_id`` instead.
            model_id (str): Hub id or local path of the causal LM to load.
                ``trust_remote_code=True`` lets that repository's own Python
                code run, so only point this at trusted repositories.
            tokenizer_id (str): Hub id or local path of the tokenizer.
        """
        # Use the module-level `device` instead of a hard-coded "cuda:0", so
        # the handler can still start on a host without a GPU.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, trust_remote_code=True
        ).to(device=device, dtype=torch.bfloat16)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    def __call__(self, data: Any) -> str:
        """Generate a sampled continuation for a single prompt.

        Args:
            data: Either the prompt string itself, or a payload dict of the
                form ``{"inputs": <prompt str>, "parameters": {...}}``. Any
                ``parameters`` entries are forwarded to ``model.generate`` and
                override ``DEFAULT_GENERATION_KWARGS``. The dict is not
                modified.

        Returns:
            str: The first decoded output sequence, with special tokens
            skipped and trailing whitespace stripped. For decoder-only models
            ``generate`` returns the prompt tokens followed by the new tokens,
            so the text begins with the prompt.

        Raises:
            ValueError: If ``data`` is a dict without an ``"inputs"`` key.
        """
        LOGGER.info(data)

        # Accept both the documented dict payload and a bare prompt string.
        # Handing the whole dict to the tokenizer would fail.
        if isinstance(data, dict):
            if "inputs" not in data:
                raise ValueError("Payload dict must contain an 'inputs' key.")
            prompt_text = data["inputs"]
            parameters = data.get("parameters") or {}
        else:
            prompt_text = data
            parameters = {}

        generation_kwargs = {**self.DEFAULT_GENERATION_KWARGS, **parameters}
        # Use EOS as the pad token unless the caller chose one, so generate()
        # does not have to guess. With one unpadded prompt this does not
        # change the generated text.
        generation_kwargs.setdefault("pad_token_id", self.tokenizer.eos_token_id)

        LOGGER.info("Start generation.")
        encoded = self.tokenizer(prompt_text, return_tensors="pt")
        # Pass the attention mask explicitly so generate() need not infer it.
        outputs = self.model.generate(
            encoded["input_ids"].to(device),
            attention_mask=encoded["attention_mask"].to(device),
            **generation_kwargs,
        )

        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return decoded[0].rstrip()