zaidmehdi committed on
Commit
9a23b5c
1 Parent(s): ab54966

add function to extract hidden state

Browse files
Files changed (2) hide show
  1. src/main.py +7 -3
  2. src/utils.py +9 -0
src/main.py CHANGED
@@ -1,14 +1,18 @@
1
  import pickle
2
 
3
  from flask import Flask, request, jsonify
4
- from transformers import AutoTokenizer
 
 
5
 
6
  app = Flask(__name__)
7
 
8
  with open("../models/logistic_regression.pkl", "rb") as f:
9
  model = pickle.load(f)
10
 
11
- tokenizer = AutoTokenizer.from_pretrained("moussaKam/AraBART")
 
 
12
 
13
 
14
  @app.route("/classify", methods=["POST"])
@@ -19,7 +23,7 @@ def classify_arabic_dialect():
19
  if not text:
20
  return jsonify({"error": "No text has been received"}), 400
21
 
22
- text_embeddings = tokenizer(text, padding=True)
23
  predicted_class = model.predict(text_embeddings)
24
 
25
  return jsonify({"class": predicted_class}), 200
 
1
  import pickle
2
 
3
  from flask import Flask, request, jsonify
4
+ from transformers import AutoModel, AutoTokenizer
5
+
6
+ from utils import extract_hidden_state
7
 
8
  app = Flask(__name__)
9
 
10
  with open("../models/logistic_regression.pkl", "rb") as f:
11
  model = pickle.load(f)
12
 
13
+ model_name = "moussaKam/AraBART"
14
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
15
+ language_model = AutoModel.from_pretrained(model_name)
16
 
17
 
18
  @app.route("/classify", methods=["POST"])
 
23
  if not text:
24
  return jsonify({"error": "No text has been received"}), 400
25
 
26
+ text_embeddings = extract_hidden_state(text, tokenizer, language_model)
27
  predicted_class = model.predict(text_embeddings)
28
 
29
  return jsonify({"class": predicted_class}), 200
src/utils.py CHANGED
@@ -2,6 +2,15 @@ import matplotlib.pyplot as plt
2
  import seaborn as sns
3
  from sklearn.metrics import accuracy_score, f1_score
4
  from sklearn.metrics import confusion_matrix
 
 
 
 
 
 
 
 
 
5
 
6
 
7
  def get_metrics(y_true, y_preds):
 
2
  import seaborn as sns
3
  from sklearn.metrics import accuracy_score, f1_score
4
  from sklearn.metrics import confusion_matrix
5
+ import torch
6
+
7
+
8
def extract_hidden_state(input_text, tokenizer, language_model):
    """Return the language model's last hidden state for ``input_text``.

    Args:
        input_text: Text (or batch of texts) accepted by ``tokenizer``.
        tokenizer: Callable tokenizer. It is invoked with ``padding=True``
            and ``return_tensors="pt"`` and must return a mapping of
            keyword arguments for ``language_model``.
        language_model: Callable model. It receives the tokenizer output
            unpacked as keyword arguments and must return an object
            exposing a ``last_hidden_state`` attribute.

    Returns:
        ``outputs.last_hidden_state`` exactly as produced by the model,
        computed with gradient tracking disabled.
    """
    # Ask for PyTorch tensors: without ``return_tensors`` a Hugging Face
    # tokenizer returns plain Python lists, which the model cannot consume.
    tokens = tokenizer(input_text, padding=True, return_tensors="pt")

    with torch.no_grad():
        # Unpack the encoding so each field (e.g. input_ids, attention_mask)
        # is passed by keyword instead of the whole mapping being bound to
        # the model's first positional parameter.
        outputs = language_model(**tokens)

    # NOTE(review): the caller in src/main.py passes this straight to a
    # pickled classifier's ``predict``; the hidden state is presumably
    # (batch, seq_len, hidden) and may need pooling to 2-D first — confirm
    # against how the classifier was trained.
    return outputs.last_hidden_state
14
 
15
 
16
  def get_metrics(y_true, y_preds):