# NuExtract-1.5 / handler.py
# Alexandre-Numind's picture
# Create handler.py
# c04a9f0 verified
# raw
# history blame
# 1.44 kB
from typing import Dict, List, Any
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
# Passed as ``max_length`` to the tokenizer (with ``truncation=True``) when the
# prompt is encoded, so longer prompts are cut down to this size.
MAX_INPUT_SIZE = 10_000
# Passed as ``max_new_tokens`` to ``model.generate``.
MAX_NEW_TOKENS = 4_000
def clean_json_text(text):
    """Strip surrounding whitespace and un-escape ``#`` and ``&``.

    Every backslash-escaped ``#`` or ``&`` (a backslash immediately followed
    by the character) is replaced by the bare character. Nothing else in the
    text is altered.

    Args:
        text: The string to clean.

    Returns:
        The stripped string with the two escape forms removed.
    """
    text = text.strip()
    # The backslash is written explicitly as "\\": the bare spellings "\#" and
    # "\&" are invalid escape sequences that only work by accident and emit a
    # warning on current Python versions.
    return text.replace("\\#", "#").replace("\\&", "&")
class EndpointHandler:
    """Callable request handler wrapping a causal LM for template-based extraction.

    The model and tokenizer are loaded once from ``path``. Each call builds a
    prompt from a template and a text, generates a completion, and returns the
    cleaned text that follows the ``<|output|>`` marker.
    """

    def __init__(self, path=""):
        """Load the model and tokenizer stored at ``path``.

        Args:
            path: Location handed to both ``from_pretrained`` calls.
        """
        # Weight placement is delegated to the loader (device_map="auto"), so
        # no other code in this class should assume a particular device.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> str:
        """Run one extraction request.

        Args:
            data: Request payload shaped like
                ``{"inputs": {"template": ..., "text": ...}}``. The payload is
                only read, never modified.

        Returns:
            The decoded text following the ``<|output|>`` marker (starting
            with the ``{`` that the prompt is seeded with), passed through
            ``clean_json_text``.

        Raises:
            KeyError: If ``"inputs"``, ``"template"`` or ``"text"`` is missing.
            IndexError: If the decoded output contains no ``<|output|>``
                marker (presumably possible when truncation cuts the end of
                an over-long prompt -- TODO confirm).
        """
        # Plain lookups instead of pop(): the caller's dict stays intact.
        payload = data["inputs"]
        template = payload["template"]
        text = payload["text"]

        # The trailing "{" seeds the completion so the model continues a JSON
        # object.
        prompt = f"<|input|>\n### Template:\n{template}\n### Text:\n{text}\n\n<|output|>" + "{"

        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_INPUT_SIZE,
        )
        # Follow the model's own device rather than hard-coding "cuda", so the
        # handler also works where device_map="auto" picked something else.
        encoded = encoded.to(self.model.device)

        generated = self.model.generate(**encoded, max_new_tokens=MAX_NEW_TOKENS)
        decoded = self.tokenizer.decode(generated[0], skip_special_tokens=True)

        # The decoded sequence contains the prompt too; keep only what comes
        # after the output marker.
        return clean_json_text(decoded.split("<|output|>")[1])