PardisSzah commited on
Commit
c96e0a1
1 Parent(s): 360b18c

commit files to HF hub

Browse files
Files changed (2) hide show
  1. PersianTextFormalizerPipeline.py +19 -0
  2. config.json +9 -0
PersianTextFormalizerPipeline.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Pipeline, T5ForConditionalGeneration, AutoTokenizer
2
+
3
+ class PersianTextFormalizerPipeline(Pipeline):
4
+
5
+ def _sanitize_parameters(self, **kwargs):
6
+ preprocess_kwargs = {}
7
+ if "second_text" in kwargs:
8
+ preprocess_kwargs["second_text"] = kwargs["second_text"]
9
+ return preprocess_kwargs, {}, {}
10
+
11
+ def preprocess(self, text, second_text=None):
12
+ inputs = self.tokenizer.encode("informal: " + text, return_tensors='pt', max_length=128, truncation=True, padding='max_length')
13
+ return inputs.to(self.device)
14
+
15
+ def _forward(self, model_inputs):
16
+ return self.model.generate(model_inputs, max_length=128, num_beams=4, temperature=0.7)
17
+
18
+ def postprocess(self, model_outputs):
19
+ return self.tokenizer.decode(model_outputs[0], skip_special_tokens=True)
config.json CHANGED
@@ -4,6 +4,15 @@
4
  "T5ForConditionalGeneration"
5
  ],
6
  "classifier_dropout": 0.0,
 
 
 
 
 
 
 
 
 
7
  "d_ff": 2048,
8
  "d_kv": 64,
9
  "d_model": 768,
 
4
  "T5ForConditionalGeneration"
5
  ],
6
  "classifier_dropout": 0.0,
7
+ "custom_pipelines": {
8
+ "text2text-PersianTextFormalizer_M": {
9
+ "impl": "PersianTextFormalizerPipeline.PersianTextFormalizerPipeline",
10
+ "pt": [
11
+ "T5ForConditionalGeneration"
12
+ ],
13
+ "tf": []
14
+ }
15
+ },
16
  "d_ff": 2048,
17
  "d_kv": 64,
18
  "d_model": 768,