Mohavere_PersianTextFormalizer-inference-pipeline / PersianTextFormalizerPipeline.py
PardisSzah's picture
commit files to HF hub
c96e0a1
raw
history blame
849 Bytes
from transformers import Pipeline, T5ForConditionalGeneration, AutoTokenizer
class PersianTextFormalizerPipeline(Pipeline):
    """Custom pipeline that runs prefixed text through a generation model.

    Each input string is prefixed with ``TASK_PREFIX``, tokenized (truncated
    to ``MAX_LENGTH`` tokens), passed to ``self.model.generate`` using beam
    search, and the first generated sequence is decoded back to a string.
    Judging by the class name and the prefix, the wrapped model is expected
    to turn informal Persian text into formal text.
    """

    # Prefix prepended to every input before tokenization.
    TASK_PREFIX = "informal: "
    # Upper bound (in tokens) for both the encoded input and the generation.
    MAX_LENGTH = 128
    # Beam width used for generation.
    NUM_BEAMS = 4

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into preprocess/forward/postprocess kwargs.

        Only ``second_text`` is recognized and it is routed to
        ``preprocess``; every other keyword argument is ignored.

        Returns:
            A 3-tuple ``(preprocess_kwargs, forward_kwargs,
            postprocess_kwargs)``; the last two are always empty dicts.
        """
        preprocess_kwargs = {}
        if "second_text" in kwargs:
            preprocess_kwargs["second_text"] = kwargs["second_text"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, text, second_text=None):
        """Tokenize ``text`` (with the task prefix) into model inputs.

        Args:
            text: The input string to transform.
            second_text: Accepted for call compatibility; currently unused.

        Returns:
            The tokenizer output (``input_ids`` plus ``attention_mask`` when
            the tokenizer provides one) as PyTorch tensors on ``self.device``.
        """
        # Calling the tokenizer (rather than ``encode``) also yields the
        # attention mask, so generation does not have to guess it from pad
        # tokens. With an explicit mask there is no need to pad a single
        # sequence out to MAX_LENGTH; batching pads as required.
        encoded = self.tokenizer(
            self.TASK_PREFIX + text,
            return_tensors="pt",
            max_length=self.MAX_LENGTH,
            truncation=True,
        )
        return encoded.to(self.device)

    def _forward(self, model_inputs):
        """Generate output token ids for the preprocessed inputs.

        Args:
            model_inputs: Mapping produced by ``preprocess``. A bare
                ``input_ids`` tensor is also accepted for backward
                compatibility with the previous ``preprocess`` output.

        Returns:
            The tensor of generated token ids returned by
            ``self.model.generate``.
        """
        if not hasattr(model_inputs, "keys"):
            # Legacy shape: a bare tensor of input ids.
            model_inputs = {"input_ids": model_inputs}
        # NOTE: ``temperature`` was previously passed here, but it only
        # affects sampling (``do_sample=True``). With deterministic beam
        # search it was ignored and only produced a warning, so it is omitted.
        return self.model.generate(
            input_ids=model_inputs["input_ids"],
            attention_mask=model_inputs.get("attention_mask"),
            max_length=self.MAX_LENGTH,
            num_beams=self.NUM_BEAMS,
        )

    def postprocess(self, model_outputs):
        """Decode the first generated sequence into a plain string.

        Args:
            model_outputs: Generated token ids as returned by ``_forward``.

        Returns:
            The decoded text with special tokens stripped.
        """
        return self.tokenizer.decode(model_outputs[0], skip_special_tokens=True)