zaidmehdi commited on
Commit
99757c1
1 Parent(s): 1784a22

defining classify_arabic_dialect() function

Browse files
Files changed (1) hide show
  1. src/main.py +11 -1
src/main.py CHANGED
@@ -1,10 +1,20 @@
1
  import pickle
2
 
 
 
 
 
 
 
 
 
 
3
 
4
  def main():
5
  with open("../models/logistic_regression.pkl", "rb") as f:
6
  model = pickle.load(f)
7
-
 
8
  return
9
 
10
 
 
1
  import pickle
2
 
3
+ from transformers import AutoTokenizer
4
+
5
+
6
+ def classify_arabic_dialect(text:str, model, tokenizer) -> str:
7
+ text_embeddings = tokenizer(text, padding=True)
8
+ predicted_class = model.predict(text_embeddings)
9
+
10
+ return predicted_class
11
+
12
 
13
  def main():
14
  with open("../models/logistic_regression.pkl", "rb") as f:
15
  model = pickle.load(f)
16
+
17
+ tokenizer = AutoTokenizer.from_pretrained("moussaKam/AraBART")
18
  return
19
 
20