LucasAguetai committed
Commit c37d535
1 Parent(s): f14f444

add DeBERTa and SqueezeBERT to the FastAPI app and the modeles file

Files changed (2)
  1. app.py +30 -2
  2. modeles.py +39 -4
app.py CHANGED
@@ -7,7 +7,7 @@ from fastapi import FastAPI, UploadFile, File
 from typing import Union
 import json
 import csv
-from modeles import bert
+from modeles import bert, squeezebert, deberta
 
 
 app = FastAPI()
@@ -48,11 +48,25 @@ async def create_upload_file(texte: str, model: str):
 
     return {"model": model, "texte": texte}
 
+# # Pydantic model for SqueezeBERT requests
+# class SqueezeBERTRequest(BaseModel):
+#     context: str
+#     question: str
+@app.post("/squeezebert/")
+async def qasqueezebert(context: str, question: str):
+    try:
+        squeezebert_answer = squeezebert(context, question)
+        if squeezebert_answer:
+            return squeezebert_answer
+        else:
+            raise HTTPException(status_code=404, detail="No answer found")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
+
 # # Pydantic model for BERT requests
 # class BERTRequest(BaseModel):
 #     context: str
 #     question: str
-
 @app.post("/bert/")
 async def qabert(context: str, question: str):
     try:
@@ -64,6 +78,20 @@ async def qabert(context: str, question: str):
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
 
+# # Pydantic model for DeBERTa requests
+# class DeBERTaRequest(BaseModel):
+#     context: str
+#     question: str
+@app.post("/deberta-v2/")
+async def qadeberta(context: str, question: str):
+    try:
+        deberta_answer = deberta(context, question)
+        if deberta_answer:
+            return deberta_answer
+        else:
+            raise HTTPException(status_code=404, detail="No answer found")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
 
 def extract_data(file: UploadFile) -> Union[str, dict, list]:
     if file.filename.endswith(".txt"):
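
The new endpoints declare context and question as plain string parameters, so FastAPI exposes them as query parameters on these POST routes. A minimal client sketch, assuming the app is served locally on port 8000 (the host, port, and example strings are illustrative, not part of the commit):

    import requests

    BASE = "http://localhost:8000"  # assumed local dev server, not part of the commit

    payload = {
        "context": "SqueezeBERT is an efficient variant of BERT.",
        "question": "What is SqueezeBERT?",
    }

    # context/question are read from the query string, so pass them as params
    for route in ("/squeezebert/", "/bert/", "/deberta-v2/"):
        r = requests.post(f"{BASE}{route}", params=payload)
        print(route, r.status_code, r.json())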
modeles.py CHANGED
@@ -1,6 +1,41 @@
-# transformers installed automatically via pip, as specified in requirements.txt
-from transformers import pipeline
+from transformers import AutoTokenizer, AutoModelForQuestionAnswering
+import torch
+
+def load_and_answer(question, context, model_name):
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForQuestionAnswering.from_pretrained(model_name)
+
+    # Tokenize the question-context pair as PyTorch tensors, truncating to the 512-token limit
+    inputs = tokenizer.encode_plus(question, context, max_length=512, truncation=True, return_tensors="pt")
+
+    # Send inputs to the same device as the model
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        # Forward pass, get model outputs
+        outputs = model(**inputs)
+
+    # Extract the start and end positions of the answer in the tokens
+    answer_start_scores, answer_end_scores = outputs.start_logits, outputs.end_logits
+    answer_start_index = torch.argmax(answer_start_scores)  # Most likely start of answer
+    answer_end_index = torch.argmax(answer_end_scores) + 1  # Most likely end of answer; +1 for inclusive slicing
+
+    # Convert token indices to the actual answer text
+    answer_tokens = inputs['input_ids'][0, answer_start_index:answer_end_index]
+    answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
+    return {"answer": answer, "start": answer_start_index.item(), "end": answer_end_index.item()}
+
+def squeezebert(context, question):
+    # Define the specific model and tokenizer for SqueezeBERT
+    model_name = "ALOQAS/squeezebert-uncased-finetuned-squad-v2"
+    return load_and_answer(question, context, model_name)
 
 def bert(context, question):
-    question_answerer = pipeline("question-answering", "alexandre-huynh/bert-base-uncased-finetuned-squad-single-epoch", framework="tf")
-    return(question_answerer(context=context, question=question))
+    # Define the specific model and tokenizer for BERT
+    model_name = "ALOQAS/bert-large-uncased-finetuned-squad-v2"
+    return load_and_answer(question, context, model_name)
+
+def deberta(context, question):
+    # Define the specific model and tokenizer for DeBERTa
+    model_name = "ALOQAS/deberta-large-finetuned-squad-v2"
+    return load_and_answer(question, context, model_name)
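
Because all three wrappers funnel into load_and_answer, the models can also be sanity-checked directly, without going through the API. A quick sketch (the context and question strings are made-up examples):

    from modeles import bert, squeezebert, deberta

    context = "The Eiffel Tower was completed in 1889 and stands in Paris."
    question = "When was the Eiffel Tower completed?"

    for fn in (bert, squeezebert, deberta):
        # Each returns {"answer": ..., "start": ..., "end": ...}
        print(fn.__name__, fn(context, question))

Note that load_and_answer instantiates the tokenizer and model on every call; if request latency matters, one possible refinement (not part of this commit) is to cache them per model name:

    from functools import lru_cache
    from transformers import AutoTokenizer, AutoModelForQuestionAnswering

    @lru_cache(maxsize=None)
    def get_model(model_name):
        # Loaded once per model name, then reused across calls
        return (AutoTokenizer.from_pretrained(model_name),
                AutoModelForQuestionAnswering.from_pretrained(model_name))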