Spaces:
Runtime error
Runtime error
youssef1214
commited on
Commit
•
d3aaa0b
1
Parent(s):
a0c40d9
Update main.py
Browse files
main.py
CHANGED
@@ -30,11 +30,22 @@ import threading
|
|
30 |
import firebase_admin
|
31 |
from firebase_admin import credentials
|
32 |
from firebase_admin import firestore
|
|
|
|
|
|
|
|
|
33 |
# Firebase ininlaziton
|
34 |
cred = credentials.Certificate(
|
35 |
"text-to-emotions-firebase-adminsdk-8isbn-dffbdf01e8.json")
|
36 |
firebase_admin.initialize_app(cred)
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
# Model inilization
|
40 |
isristemmer = ISRIStemmer()
|
@@ -142,7 +153,16 @@ def original_values(num):
|
|
142 |
return 'sympathy'
|
143 |
elif num == 6:
|
144 |
return 'fear'
|
145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
|
147 |
def modelpredict(data):
|
148 |
data = txt_preprocess(data)
|
@@ -158,7 +178,13 @@ app = FastAPI()
|
|
158 |
@app.get("/")
|
159 |
def index():
|
160 |
return "Hello World"
|
161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
@app.post("/predict")
|
163 |
async def read_root(request: Request):
|
164 |
json_data = await request.json()
|
|
|
30 |
import firebase_admin
|
31 |
from firebase_admin import credentials
|
32 |
from firebase_admin import firestore
|
33 |
+
from transformers import BertTokenizer, AutoModelForSeq2SeqLM, pipeline
|
34 |
+
from arabert.preprocess import ArabertPreprocessor
|
35 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
36 |
+
import re
|
37 |
# Firebase ininlaziton
|
38 |
cred = credentials.Certificate(
|
39 |
"text-to-emotions-firebase-adminsdk-8isbn-dffbdf01e8.json")
|
40 |
firebase_admin.initialize_app(cred)
|
41 |
|
42 |
+
# Model summury
|
43 |
+
model_name="abdalrahmanshahrour/auto-arabic-summarization"
|
44 |
+
preprocessor = ArabertPreprocessor(model_name="")
|
45 |
+
|
46 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
47 |
+
modelsummary = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
48 |
+
pipeline1 = pipeline("text2text-generation",model=modelsummary,tokenizer=tokenizer)
|
49 |
|
50 |
# Model inilization
|
51 |
isristemmer = ISRIStemmer()
|
|
|
153 |
return 'sympathy'
|
154 |
elif num == 6:
|
155 |
return 'fear'
|
156 |
+
def modelsummary(data):
|
157 |
+
result = pipeline1(text,
|
158 |
+
pad_token_id= tokenizer.eos_token_id,
|
159 |
+
num_beams=4,
|
160 |
+
repetition_penalty=3.0,
|
161 |
+
max_length=600,
|
162 |
+
length_penalty=1.0,
|
163 |
+
no_repeat_ngram_size = 3)[0]['generated_text']
|
164 |
+
result = remove_punctuations(result)
|
165 |
+
return { 'summary':result}
|
166 |
|
167 |
def modelpredict(data):
|
168 |
data = txt_preprocess(data)
|
|
|
178 |
@app.get("/")
|
179 |
def index():
|
180 |
return "Hello World"
|
181 |
+
@app.post("/summary")
|
182 |
+
async def read_root(request:Request):
|
183 |
+
json_data = await request.json()
|
184 |
+
if 'text'in json_data:
|
185 |
+
return modelsummary(json_data['text'])
|
186 |
+
else:
|
187 |
+
raise HTTPException(status_code=400, detail="Missing text value")
|
188 |
@app.post("/predict")
|
189 |
async def read_root(request: Request):
|
190 |
json_data = await request.json()
|