LucasAguetai commited on
Commit
00837ec
1 Parent(s): 37e2495

Fix the SqueezeBERT squeeze issue and load each pipeline only once

Browse files
Files changed (2) hide show
  1. app.py +8 -7
  2. modeles.py +31 -13
app.py CHANGED
@@ -4,9 +4,10 @@ from fastapi import FastAPI, UploadFile
4
  from typing import Union
5
  import json
6
  import csv
7
- from modeles import bert, squeezebert, deberta
8
  from uploadFile import file_to_text
9
  from typing import List
 
10
 
11
 
12
 
@@ -21,10 +22,10 @@ app.add_middleware(
21
  )
22
 
23
 
24
- @app.on_event("startup")
25
- async def startup_event():
26
- print("start")
27
 
 
 
 
28
 
29
  @app.get("/")
30
  async def root():
@@ -60,7 +61,7 @@ async def create_upload_file(texte: str, model: str):
60
  @app.post("/squeezebert/")
61
  async def qasqueezebert(context: str, question: str):
62
  try:
63
- squeezebert_answer = squeezebert(context, question)
64
  if squeezebert_answer:
65
  return squeezebert_answer
66
  else:
@@ -75,7 +76,7 @@ async def qasqueezebert(context: str, question: str):
75
  @app.post("/bert/")
76
  async def qabert(context: str, question: str):
77
  try:
78
- bert_answer = bert(context, question)
79
  if bert_answer:
80
  return bert_answer
81
  else:
@@ -90,7 +91,7 @@ async def qabert(context: str, question: str):
90
  @app.post("/deberta-v2/")
91
  async def qadeberta(context: str, question: str):
92
  try:
93
- deberta_answer = deberta(context, question)
94
  if deberta_answer:
95
  return deberta_answer
96
  else:
 
4
  from typing import Union
5
  import json
6
  import csv
7
+ from modeles import bert, squeezebert, deberta, loadSqueeze
8
  from uploadFile import file_to_text
9
  from typing import List
10
+ from transformers import pipeline
11
 
12
 
13
 
 
22
  )
23
 
24
 
 
 
 
25
 
26
# Load every QA model once at import time so each request reuses the same
# in-memory pipeline instead of re-instantiating it per call.
pipBert = pipeline('question-answering', model="ALOQAS/bert-large-uncased-finetuned-squad-v2", tokenizer="ALOQAS/bert-large-uncased-finetuned-squad-v2")
pipDeberta = pipeline('question-answering', model="ALOQAS/deberta-large-finetuned-squad-v2", tokenizer="ALOQAS/deberta-large-finetuned-squad-v2")
# SqueezeBERT is run manually (see modeles.squeezebert), so keep its raw
# tokenizer and model rather than a pipeline object.
tokenizer, model = loadSqueeze()
29
 
30
  @app.get("/")
31
  async def root():
 
61
  @app.post("/squeezebert/")
62
  async def qasqueezebert(context: str, question: str):
63
  try:
64
+ squeezebert_answer = squeezebert(context, question, model, tokenizer)
65
  if squeezebert_answer:
66
  return squeezebert_answer
67
  else:
 
76
  @app.post("/bert/")
77
  async def qabert(context: str, question: str):
78
  try:
79
+ bert_answer = bert(context, question, pipBert)
80
  if bert_answer:
81
  return bert_answer
82
  else:
 
91
  @app.post("/deberta-v2/")
92
  async def qadeberta(context: str, question: str):
93
  try:
94
+ deberta_answer = deberta(context, question, pipDeberta)
95
  if deberta_answer:
96
  return deberta_answer
97
  else:
modeles.py CHANGED
@@ -1,19 +1,37 @@
1
- from transformers import pipeline
 
2
 
3
- def squeezebert(context, question):
 
 
 
 
 
4
  # Define the specific model and tokenizer for SqueezeBERT
5
- model_name = "ALOQAS/squeezebert-uncased-finetuned-squad-v2"
6
- pip = pipeline('question-answering', model=model_name, tokenizer=model_name)
7
- return pip(context=context, question=question)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- def bert(context, question):
10
- # Define the specific model and tokenizer for BERT
11
- model_name = "ALOQAS/bert-large-uncased-finetuned-squad-v2"
12
- pip = pipeline('question-answering', model=model_name, tokenizer=model_name)
13
  return pip(context=context, question=question)
14
 
15
- def deberta(context, question):
16
- # Define the specific model and tokenizer for DeBERTa
17
- model_name = "ALOQAS/deberta-large-finetuned-squad-v2"
18
- pip = pipeline('question-answering', model=model_name, tokenizer=model_name)
19
  return pip(context=context, question=question)
 
1
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
2
+ import torch
3
 
4
def loadSqueeze(model_name="ALOQAS/squeezebert-uncased-finetuned-squad-v2"):
    """Load the SqueezeBERT question-answering tokenizer and model once.

    Intended to be called a single time at application startup so every
    request reuses the same in-memory objects instead of re-downloading
    and re-instantiating the checkpoint per call.

    Parameters
    ----------
    model_name : str, optional
        Hugging Face Hub id of the checkpoint. Defaults to the ALOQAS
        SqueezeBERT fine-tuned on SQuAD v2 (previous hard-coded value,
        so existing callers are unaffected).

    Returns
    -------
    tuple
        ``(tokenizer, model)`` ready for use by :func:`squeezebert`.
    """
    # Single source of truth for the checkpoint id (was duplicated before).
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForQuestionAnswering.from_pretrained(model_name)
    return tokenizer, model
8
+
9
def squeezebert(context, question, model, tokenizer):
    """Answer *question* from *context* with a preloaded SqueezeBERT QA model.

    Parameters
    ----------
    context : str
        Passage in which to search for the answer.
    question : str
        Question to answer.
    model : transformers.AutoModelForQuestionAnswering
        Preloaded extractive-QA model (see ``loadSqueeze``).
    tokenizer : transformers.AutoTokenizer
        Tokenizer matching ``model``.

    Returns
    -------
    dict
        ``{"answer": str, "start": int, "end": int}`` where start/end are
        token indices into the encoded input (not character offsets).
    """
    # Tokenize the question/context pair, truncating to the model's
    # 512-token limit. tokenizer(...) is the supported replacement for the
    # deprecated encode_plus() used previously.
    inputs = tokenizer(question, context, max_length=512, truncation=True,
                       padding=True, return_tensors='pt')

    # Send inputs to the same device as the model (CPU or GPU).
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Inference only: no gradient tracking needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Most likely start/end token positions; +1 makes the end index
    # exclusive for slicing. For a SQuAD-v2 "no answer" prediction both
    # logits peak at [CLS], so the slice decodes to an empty string below.
    answer_start_index = torch.argmax(outputs.start_logits)
    answer_end_index = torch.argmax(outputs.end_logits) + 1

    # Convert the selected token span back to text.
    answer_tokens = inputs['input_ids'][0, answer_start_index:answer_end_index]
    answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
    return {"answer": answer, "start": answer_start_index.item(), "end": answer_end_index.item()}
30
+
31
+
32
 
33
def bert(context, question, pip):
    """Run an extractive-QA query through a preloaded BERT pipeline.

    Args:
        context: Passage containing the potential answer.
        question: Question to answer from *context*.
        pip: A preloaded Hugging Face question-answering pipeline.

    Returns:
        Whatever the pipeline returns — typically a dict with 'answer',
        'score', 'start' and 'end' keys.
    """
    answer = pip(question=question, context=context)
    return answer
35
 
36
def deberta(context, question, pip):
    """Run an extractive-QA query through a preloaded DeBERTa pipeline.

    Args:
        context: Passage containing the potential answer.
        question: Question to answer from *context*.
        pip: A preloaded Hugging Face question-answering pipeline.

    Returns:
        Whatever the pipeline returns — typically a dict with 'answer',
        'score', 'start' and 'end' keys.
    """
    result = pip(question=question, context=context)
    return result