feat: cpu support using cuda condition branch
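The change follows one pattern in both files: check torch.cuda.is_available() once at import time, then branch at each call site so inputs and models are moved to GPU 0 (and the flash-attention SDP kernel forced) only when CUDA is present, with a plain CPU path otherwise. A distilled sketch of that pattern, not code from this repo (run_inference and batch are made-up names):

import torch

# Checked once at import; every call site branches on this flag.
cuda = torch.cuda.is_available()

@torch.inference_mode()  # no autograd bookkeeping during inference
def run_inference(model, batch):
    if cuda:
        batch = batch.to(0)  # move inputs to GPU 0
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            return model(**batch)
    # CPU fallback: same call, default kernels
    return model(**batch)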
app.py
CHANGED
@@ -2,6 +2,7 @@ import os
 
 import streamlit as st
 from pymilvus import MilvusClient
+import torch
 
 from model import encode_dpr_question, get_dpr_encoder
 from model import summarize_text, get_summarizer
@@ -56,6 +57,7 @@ st.markdown(styl, unsafe_allow_html=True)
 question = st.text_area("Text to summarize", INITIAL, height=400)
 
 
+@torch.inference_mode()
 def main(question: str):
     if question in st.session_state:
         print("Cache hit!")
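Beyond the torch import, the only change to app.py is the @torch.inference_mode() decorator on main: every tensor operation made while serving a request runs without autograd tracking, which also saves memory on a CPU-only Space. A minimal standalone check of what the decorator does (illustration only, not app code):

import torch

x = torch.ones(3, requires_grad=True)  # ordinary autograd leaf

@torch.inference_mode()
def demo(t):
    return t * 2  # computed under inference mode, so no graph is recorded

print(demo(x).requires_grad)  # False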
model.py
CHANGED
@@ -7,15 +7,21 @@ from transformers import QuestionAnsweringPipeline
 from transformers import AutoTokenizer, PegasusXForConditionalGeneration, PegasusTokenizerFast
 import torch
 
+cuda = torch.cuda.is_available()
 max_answer_len = 8
 logging.set_verbosity_error()
 
 
+@torch.inference_mode()
 def summarize_text(tokenizer: PegasusTokenizerFast, model: PegasusXForConditionalGeneration,
                    input_texts: List[str]):
     inputs = tokenizer(input_texts, padding=True,
+                       return_tensors='pt', truncation=True)
+    if cuda:
+        inputs = inputs.to(0)
+        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+            summary_ids = model.generate(inputs["input_ids"])
+    else:
         summary_ids = model.generate(inputs["input_ids"])
     summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True,
                                        clean_up_tokenization_spaces=False, batch_size=len(input_texts))
@@ -24,14 +30,13 @@ def summarize_text(tokenizer: PegasusTokenizerFast, model: PegasusXForConditiona
 
 def get_summarizer(model_id="seonglae/resrer") -> Tuple[PegasusTokenizerFast, PegasusXForConditionalGeneration]:
     tokenizer = PegasusTokenizerFast.from_pretrained(model_id)
+    model = PegasusXForConditionalGeneration.from_pretrained(model_id)
+    if cuda:
+        model = model.to(0)
     model = torch.compile(model)
     return tokenizer, model
 
 
 class AnswerInfo(TypedDict):
     score: float
     start: int
@@ -42,10 +47,16 @@ class AnswerInfo(TypedDict):
 @torch.inference_mode()
 def ask_reader(tokenizer: AutoTokenizer, model: AutoModelForQuestionAnswering,
                questions: List[str], ctxs: List[str]) -> List[AnswerInfo]:
+    if cuda:
+        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+            pipeline = QuestionAnsweringPipeline(
+                model=model, tokenizer=tokenizer, device='cuda', max_answer_len=max_answer_len)
+            answer_infos: List[AnswerInfo] = pipeline(
+                question=questions, context=ctxs)
+    else:
         pipeline = QuestionAnsweringPipeline(
+            model=model, tokenizer=tokenizer, device='cpu', max_answer_len=max_answer_len)
+        answer_infos = pipeline(
             question=questions, context=ctxs)
     for answer_info in answer_infos:
         answer_info['answer'] = sub(r'[.\(\)"\',]', '', answer_info['answer'])
@@ -54,10 +65,13 @@ def ask_reader(tokenizer: AutoTokenizer, model: AutoModelForQuestionAnswering,
 
 def get_reader(model_id="mrm8488/longformer-base-4096-finetuned-squadv2"):
     tokenizer = DPRReaderTokenizer.from_pretrained(model_id)
+    model = DPRReader.from_pretrained(model_id)
+    if cuda:
+        model = model.to(0)
     return tokenizer, model
 
 
+@torch.inference_mode()
 def encode_dpr_question(tokenizer: DPRQuestionEncoderTokenizer, model: DPRQuestionEncoder, questions: List[str]) -> torch.FloatTensor:
     """Encode a question using DPR question encoder.
     https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder
@@ -67,9 +81,13 @@ def encode_dpr_question(tokenizer: DPRQuestionEncoderTokenizer, model: DPRQuesti
         model_id (str, optional): Default for NQ or "facebook/dpr-question_encoder-multiset-base
     """
     batch_dict = tokenizer(questions, return_tensors="pt",
+                           padding=True, truncation=True)
+    if cuda:
+        batch_dict = batch_dict.to(0)
+        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+            embeddings: torch.FloatTensor = model(**batch_dict).pooler_output
+    else:
+        embeddings = model(**batch_dict).pooler_output
     return embeddings
 
 
@@ -82,5 +100,7 @@ def get_dpr_encoder(model_id="facebook/dpr-question_encoder-single-nq-base") ->
         model_id (str, optional): Default for NQ or "facebook/dpr-question_encoder-multiset-base
     """
     tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(model_id)
+    model = DPRQuestionEncoder.from_pretrained(model_id)
+    if cuda:
+        model = model.to(0)
     return tokenizer, model
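With the cuda flag threaded through model.py, the same helpers run unchanged on a CPU-only Space or a GPU host. A hedged usage sketch (the function names and default checkpoints are the ones above; the input strings are made up):

from model import get_summarizer, summarize_text
from model import get_dpr_encoder, encode_dpr_question

# Device placement is decided inside model.py by the module-level cuda flag.
tokenizer, model = get_summarizer()      # seonglae/resrer Pegasus-X checkpoint
summaries = summarize_text(tokenizer, model, ["A long passage to compress ..."])

enc_tok, enc_model = get_dpr_encoder()   # facebook/dpr-question_encoder-single-nq-base
embeddings = encode_dpr_question(enc_tok, enc_model, ["What is dense passage retrieval?"])
print(embeddings.shape)                  # (1, 768) for the DPR base encoder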