from transformers import Pipeline, T5ForConditionalGeneration, AutoTokenizer


class PersianTextFormalizerPipeline(Pipeline):
    """Custom pipeline that feeds ``"informal: " + text`` to a seq2seq model.

    The flow for one input is: prefix the text with the task tag, tokenize
    it (truncated to 128 tokens), run beam-search generation (4 beams, at
    most 128 tokens), and decode the first generated sequence to a string.

    NOTE(review): the model is presumably a T5 checkpoint fine-tuned to turn
    informal Persian into formal Persian -- confirm against the model card.
    """

    # Task tag the model input is prefixed with.
    _TASK_PREFIX = "informal: "
    # Token budget used both for input truncation and for generation.
    _MAX_LENGTH = 128
    # Beam width used for generation.
    _NUM_BEAMS = 4

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into preprocess/forward/postprocess kwargs.

        Only ``second_text`` is recognised; it is routed to ``preprocess``.
        Every other keyword is dropped, and the forward and postprocess
        steps never receive extra parameters.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)``
            tuple of dicts.
        """
        preprocess_kwargs = {}
        if "second_text" in kwargs:
            preprocess_kwargs["second_text"] = kwargs["second_text"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, text, second_text=None):
        """Tokenize one input string into model-ready tensors.

        Args:
            text: The string to rewrite. It is prefixed with the task tag
                before tokenization and truncated to ``_MAX_LENGTH`` tokens.
            second_text: Accepted for interface compatibility only; it is
                not used.

        Returns:
            The tokenizer output (``input_ids`` and ``attention_mask`` as
            PyTorch tensors).
        """
        # Call the tokenizer directly instead of ``encode`` so the attention
        # mask travels with the ids, and return the mapping as-is: the
        # pipeline's ``forward`` wrapper moves it to the right device, and
        # its batching collator pads per key, so no fixed-length padding or
        # manual ``.to(device)`` is needed here.
        return self.tokenizer(
            self._TASK_PREFIX + text,
            return_tensors="pt",
            max_length=self._MAX_LENGTH,
            truncation=True,
        )

    def _forward(self, model_inputs):
        """Run beam-search generation on the tokenized input.

        Args:
            model_inputs: The mapping produced by ``preprocess``.

        Returns:
            The generated token ids as returned by ``model.generate``.
        """
        # No ``temperature`` here: it only applies to sampling, and this is
        # plain (non-sampling) beam search.
        return self.model.generate(
            **model_inputs,
            max_length=self._MAX_LENGTH,
            num_beams=self._NUM_BEAMS,
        )

    def postprocess(self, model_outputs):
        """Decode the first generated sequence into text.

        Args:
            model_outputs: Token ids returned by ``_forward``.

        Returns:
            The decoded string with special tokens removed.
        """
        return self.tokenizer.decode(model_outputs[0], skip_special_tokens=True)