File size: 1,440 Bytes
c04a9f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from typing import Dict, List, Any
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os


# Passed as ``max_length`` to the tokenizer (with ``truncation=True``), so
# prompts longer than this many tokens are truncated.
MAX_INPUT_SIZE = 10_000
# Passed as ``max_new_tokens`` to ``model.generate``: cap on tokens produced
# per request.
MAX_NEW_TOKENS = 4_000

def clean_json_text(text: str) -> str:
    r"""
    Tidy JSON text by trimming whitespace and removing stray backslash escapes.

    Leading/trailing whitespace is stripped, then the two-character sequences
    ``\#`` and ``\&`` are replaced by plain ``#`` and ``&``.

    Args:
        text: The raw text to clean.

    Returns:
        The cleaned text.
    """
    text = text.strip()
    # Raw strings: "\#" and "\&" are invalid escape sequences in a normal
    # string literal and trigger a SyntaxWarning on recent Python versions.
    text = text.replace(r"\#", "#").replace(r"\&", "&")
    return text

class EndpointHandler:
    """
    Callable wrapper around a causal language model and its tokenizer.

    An instance is called with a payload holding a template and a text; it
    builds a prompt from them, runs generation, and returns the cleaned text
    that follows the ``<|output|>`` marker.

    NOTE(review): the class name and ``__call__`` contract look like a
    Hugging Face Inference Endpoints custom handler -- confirm against the
    deployment configuration.
    """

    def __init__(self, path=""):
        """
        Load the model and tokenizer from ``path``.

        Args:
            path: Location handed to ``from_pretrained`` for both the model
                and the tokenizer.
        """
        # bfloat16 weights; device placement is delegated to device_map="auto".
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> str:
        """
        Generate output text for one request.

        Args:
            data: Payload of the form
                ``{"inputs": {"template": ..., "text": ...}}``.
                The payload is read, not modified.

        Returns:
            The decoded text after the first ``<|output|>`` marker, passed
            through ``clean_json_text``.

        Raises:
            KeyError: If ``"inputs"``, ``"template"`` or ``"text"`` is missing.
            IndexError: If the decoded output does not contain ``<|output|>``
                (presumably possible when the prompt is truncated at
                ``MAX_INPUT_SIZE`` tokens -- TODO confirm).
        """
        # Index instead of pop so the caller's payload is left intact.
        inputs = data["inputs"]
        template = inputs["template"]
        text = inputs["text"]
        # The prompt ends with "{" so generation continues an already-opened
        # JSON object.
        input_llm = f"<|input|>\n### Template:\n{template}\n### Text:\n{text}\n\n<|output|>" + "{"

        encoded = self.tokenizer(
            input_llm,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_INPUT_SIZE,
        )
        # Follow the model's own device rather than assuming "cuda": the model
        # is loaded with device_map="auto", so CUDA may not be where it lives.
        encoded = encoded.to(self.model.device)

        generated = self.model.generate(**encoded, max_new_tokens=MAX_NEW_TOKENS)
        output = self.tokenizer.decode(generated[0], skip_special_tokens=True)

        return clean_json_text(output.split("<|output|>")[1])