zaidmehdi commited on
Commit
ab54966
1 Parent(s): 99757c1

creating flask api

Browse files
Files changed (1) hide show
  1. src/main.py +22 -9
src/main.py CHANGED
@@ -1,21 +1,34 @@
1
  import pickle
2
 
 
3
  from transformers import AutoTokenizer
4
 
 
5
 
6
- def classify_arabic_dialect(text:str, model, tokenizer) -> str:
7
- text_embeddings = tokenizer(text, padding=True)
8
- predicted_class = model.predict(text_embeddings)
9
 
10
- return predicted_class
11
 
12
 
13
- def main():
14
- with open("../models/logistic_regression.pkl", "rb") as f:
15
- model = pickle.load(f)
 
 
 
 
 
 
 
16
 
17
- tokenizer = AutoTokenizer.from_pretrained("moussaKam/AraBART")
18
- return
 
 
 
 
 
19
 
20
 
21
  if __name__ == "__main__":
 
1
  import pickle
2
 
3
+ from flask import Flask, request, jsonify
4
  from transformers import AutoTokenizer
5
 
6
+ app = Flask(__name__)
7
 
8
+ with open("../models/logistic_regression.pkl", "rb") as f:
9
+ model = pickle.load(f)
 
10
 
11
+ tokenizer = AutoTokenizer.from_pretrained("moussaKam/AraBART")
12
 
13
 
14
+ @app.route("/classify", methods=["POST"])
15
+ def classify_arabic_dialect():
16
+ try:
17
+ data = request.json
18
+ text = data.get("text")
19
+ if not text:
20
+ return jsonify({"error": "No text has been received"}), 400
21
+
22
+ text_embeddings = tokenizer(text, padding=True)
23
+ predicted_class = model.predict(text_embeddings)
24
 
25
+ return jsonify({"class": predicted_class}), 200
26
+ except Exception as e:
27
+ return jsonify({"error": str(e)}), 500
28
+
29
+
30
+ def main():
31
+ app.run(debug=True)
32
 
33
 
34
  if __name__ == "__main__":